In [1]:
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import pandas as pd
import numpy as np
from dataset import TMDBDataset

In [2]:
overview_cast_df = TMDBDataset(
    root = "./tmp",
    node_feature_method = "counter",
    node_feature_params = {'min_df': 0.1},
    node_feature_column_source = "keywords",
    add_additional_node_features = True,
    edge_weight_column_source = "cast",
    jaccard_distance_threshold = 0,
    graph_type = "homogenous"
)
overview_cast_df.y = np.log(overview_cast_df.y)

In [4]:
keywords_cast_df = TMDBDataset(
    root = "./tmp",
    node_feature_method = "counter",
    node_feature_params = {'min_df': 0.015},
    node_feature_column_source = "keywords",
    add_additional_node_features = True,
    edge_weight_column_source = "cast",
    jaccard_distance_threshold = 0,
    graph_type = "homogenous"
)
keywords_cast_df.y = np.log(keywords_cast_df.y)

In [46]:
df = TMDBDataset(
    root="./tmp",
    node_feature_method="counter",
    node_feature_params=0.015,
    node_feature_column_source="keywords",
    add_additional_node_features=True,
    edge_weight_column_source="cast",
    jaccard_distance_threshold=0,
    graph_type="heterogeneous",
)

In [47]:
datasets = {
    "overview": overview_cast_df,
    "keywords": keywords_cast_df,
}

In [58]:
import itertools

In [70]:
def train_baseline_models(dataset): 
    results = []
    X_train, X_test, y_train, y_test = train_test_split(dataset.x.numpy(), dataset.y.numpy().ravel(), test_size=0.2, random_state=42)
    models = {
        "LinearRegression": LinearRegression,
        "RandomForestRegressor": RandomForestRegressor,
        "XGBRegressor": XGBRegressor,
        "MLPRegressor": MLPRegressor
    }

    params = {
        "LinearRegression": {"fit_intercept": [True, False]},
        "RandomForestRegressor": {"max_depth": [3, 5, 8, 10], "max_features": [0.5, 0.7, 0.9]},
        "XGBRegressor": {"max_depth": [3, 5, 8, 10]},
        "MLPRegressor": {"hidden_layer_sizes": [(100, 50), (100, 50, 25), (100,)]}
    }

    for name, model_class in models.items():
        model_params = params[name]
        keys, values = zip(*model_params.items())
        combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
        for combination in combinations:
            model = model_class(**combination)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            mse_train = mean_squared_error(y_train, model.predict(X_train))
            mse_test = mean_squared_error(y_test, y_pred)
            results.append({"model": name, "mse_train": mse_train, "mse_test": mse_test})
    return pd.DataFrame(results)

In [72]:
overview_baseline = train_baseline_models(datasets["overview"])



In [74]:
overview_baseline.sort_values(by="mse_test").head(10)

Unnamed: 0,model,mse_train,mse_test
11,RandomForestRegressor,0.393127,1.104251
8,RandomForestRegressor,0.565075,1.108089
10,RandomForestRegressor,0.530861,1.114208
13,RandomForestRegressor,0.363508,1.115641
12,RandomForestRegressor,0.372565,1.125007
9,RandomForestRegressor,0.526836,1.127313
7,RandomForestRegressor,0.814792,1.132227
5,RandomForestRegressor,0.838574,1.134234
6,RandomForestRegressor,0.823087,1.140114
14,XGBRegressor,0.480009,1.141661


In [75]:
keywords_baseline = train_baseline_models(datasets["keywords"])



In [76]:
keywords_baseline.sort_values(by="mse_test").head(10)

Unnamed: 0,model,mse_train,mse_test
14,XGBRegressor,0.48965,1.06616
11,RandomForestRegressor,0.472122,1.087648
13,RandomForestRegressor,0.442478,1.100171
8,RandomForestRegressor,0.613914,1.102038
12,RandomForestRegressor,0.454443,1.104614
10,RandomForestRegressor,0.575502,1.105507
0,LinearRegression,0.886038,1.106514
9,RandomForestRegressor,0.591085,1.109919
5,RandomForestRegressor,0.829228,1.120952
6,RandomForestRegressor,0.819412,1.135546
