In [2]:
# module
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
import numpy as np

In [3]:
# Util Functions
from enum import Enum

# 파일을 읽어와 pandas의 DataFrame으로 반환해줌
def get_df_from(path: str, sep: str = ','):
    extension = path.split('.')[-1]
    try:
        if extension == "csv" or extension == "txt":
            df = pd.read_csv(path, sep=sep, on_bad_lines='skip')
            if df.empty:
                print("W: DataFrame is empty.")
                return None
            return df
        else:
            print("E: File extension is not supported.")
            return None
    except FileNotFoundError:
        print(f"E: File not found. Check the path: {path}")
        return None
    except pd.errors.ParserError as e:
        print(f"E: Error parsing CSV file: {e}")
        return None
    except Exception as e:
        print(f"E: An unexpected error occured: {e}")
        return None

# 리스트의 평균을 반환.(기본값 0)
def calculate_mean(lst):
    if lst: # 빈 리스트가 아닌 경우에만 평균 계산
        return np.mean(lst)
    return 0 # 빈 리스트인 경우 NaN 반환

# 리스트의 합을 반환.(기본값 0)
def calculate_sum(lst):
    if lst:
        return np.sum(lst)
    return 0


class ChartShape(Enum):
    BAR = 1
    HISTOGRAM = 2
    LINE = 3
    HEATMAP = 4

# Chart 뽑는 함수
def get_chart_from(df: pd.DataFrame, x_col: str, y_col: str, shape: ChartShape = ChartShape.BAR, real_time: bool = False):
    plt.figure(figsize=(10, 6))
    if shape == ChartShape.BAR:
        plt.bar(df[x_col], df[y_col])
    elif shape == ChartShape.HISTOGRAM:
        plt.hist(df[x_col], bins=10)
    elif shape == ChartShape.LINE:
        plt.plot(df[x_col], df[y_col])
    else:
        plt.plot(df[x_col], df[y_col])

    if real_time:
        plt.gca().xaxis.set_major_locator(mdates.MonthLocator()) # 월 단위로 큰 눈금 표시
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) # 날짜 형식 지정
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.grid(True)
    plt.tight_layout()
    plt.xticks(df[x_col], rotation=45, ha='right')
    plt.show()

def get_chart_from_series(sr: pd.Series, x_col: str, y_col: str, shape: ChartShape = ChartShape.BAR):
    if shape == ChartShape.BAR:
        sr.plot.bar()
    if shape == ChartShape.LINE:
        sr.plot.line()
    if shape == ChartShape.HISTOGRAM:
        sr.plot.hist()
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.show()

def get_integer(question: str):
    try:
        id = int(input(question))
        return id
    except:
        print(f"E: invalid integer")
        return 0

In [4]:
# Function

folder_path = '../public/kmrd-small/'

file_paths = {
    'castings': folder_path + 'castings.csv',
    'countries': folder_path + 'countries.csv',
    'genres': folder_path + 'genres.csv',
    'movies': folder_path + 'movies.txt',
    'peoples': folder_path + 'peoples.txt',
    'rates': folder_path + 'rates.csv',
}

def get_df_strict_from(path: str, sep: str = ','):
    df = get_df_from(path, sep)
    for _ in range(3):
        if df is not None:
            return df
        new_path = input('오류가 발생했습니다. 정확한 경로를 확인해주세요: ')
        new_sep = input('구분자를 입력해주세요: ')
        df = get_df_from(new_path, new_sep)
    exit()

In [5]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import xgboost as xgb

### One-Hot encoding

- 범주형 데이터를 숫자형 데이;터로 변환하는 방법 중 하나
- 일반적으로 단어를 벡터로 표현할 때 사용
- Ex) "나는" -> [1.0, 0.2, 0.3]

### 랜덤 포레스트 회귀 모델



In [6]:
rates_df = get_df_strict_from(file_paths["rates"])
movies_df = get_df_strict_from(file_paths["movies"], sep='\t')
genres_df = get_df_strict_from(file_paths["genres"])

rates_df = rates_df.drop('time', axis=1)

movies_df['title_eng'] = movies_df['title_eng'].str.replace(r',\s*\d{4}', '', regex=True)

genres_df = genres_df.drop_duplicates()

In [7]:
movie_genre = genres_df.groupby('movie')['genre'].apply(list)
movie_genre = pd.DataFrame(movie_genre)
movie_genre.columns = ['genres']
print(movie_genre.head())

                   genres
movie                    
10001       [드라마, 멜로/로맨스]
10002           [SF, 코미디]
10003           [SF, 코미디]
10004  [서부, SF, 판타지, 코미디]
10005   [판타지, 모험, SF, 액션]


In [8]:
from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer()
genre_encoded = mlb.fit_transform(movie_genre['genres'])
genre_one_hot = pd.DataFrame(genre_encoded, columns=mlb.classes_).set_index(movie_genre.index)

genres_encoded = pd.concat([movie_genre, genre_one_hot], axis=1).drop('genres', axis=1)

print(genres_encoded.head())

       SF  가족  공포  느와르  다큐멘터리  드라마  멜로/로맨스  모험  뮤지컬  미스터리  범죄  서부  서사  스릴러  \
movie                                                                        
10001   0   0   0    0      0    1       1   0    0     0   0   0   0    0   
10002   1   0   0    0      0    0       0   0    0     0   0   0   0    0   
10003   1   0   0    0      0    0       0   0    0     0   0   0   0    0   
10004   1   0   0    0      0    0       0   0    0     0   0   1   0    0   
10005   1   0   0    0      0    0       0   1    0     0   0   0   0    0   

       애니메이션  액션  에로  전쟁  코미디  판타지  
movie                               
10001      0   0   0   0    0    0  
10002      0   0   0   0    1    0  
10003      0   0   0   0    1    0  
10004      0   0   0   0    1    1  
10005      0   1   0   0    0    1  


In [9]:
user_movie_genre = pd.merge(rates_df, genres_encoded, how='left', on='movie')
user_movie_genre = pd.get_dummies(user_movie_genre, columns=['movie'])
print(user_movie_genre.head())

   user  rate   SF   가족   공포  느와르  다큐멘터리  드라마  멜로/로맨스   모험  ...  movie_10978  \
0     0     7  1.0  0.0  0.0  0.0    0.0  0.0     0.0  0.0  ...        False   
1     0     7  1.0  0.0  0.0  0.0    0.0  0.0     0.0  0.0  ...        False   
2     0     9  1.0  1.0  0.0  0.0    0.0  0.0     0.0  1.0  ...        False   
3     0     9  0.0  0.0  0.0  0.0    0.0  1.0     0.0  0.0  ...        False   
4     0     7  0.0  0.0  0.0  0.0    0.0  1.0     0.0  0.0  ...        False   

   movie_10979  movie_10980  movie_10981  movie_10982  movie_10983  \
0        False        False        False        False        False   
1        False        False        False        False        False   
2        False        False        False        False        False   
3        False        False        False        False        False   
4        False        False        False        False        False   

   movie_10985  movie_10988  movie_10994  movie_10998  
0        False        False        False  

In [10]:

X = user_movie_genre.drop('rate', axis=1)
y = user_movie_genre['rate']


In [11]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [12]:
# X = user_movie_genre_encoded.drop('rate', axis=1)
# y = user_movie_genre_encoded['rate']

# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [14]:
# n_estimators: number of Decision Tree
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)

y_pred_rf = rf_model.predict(X_test)

mse_rf = mean_squared_error(y_test, y_pred_rf)
rmse_rf = np.sqrt(mse_rf)
mae_rf = mean_absolute_error(y_test, y_pred_rf)

In [24]:
print(mse_rf, rmse_rf, mae_rf)
X_test_cp = X_test.copy()
y_test_cp = y_test.copy()

5.018067787810454 2.240104414488408 1.2579123576125413


87069                                                       10
44148                                                       10
92462                                                       10
62922                                                        9
89565                                                       10
                                   ...                        
10627                                                        1
24184                                                        9
107702                                                      10
31326                                                       10
rate_pred    [4.51, 10.0, 10.0, 9.65, 9.93, 7.39, 9.9, 4.49...
Name: rate, Length: 28143, dtype: object
