In [19]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from constants import numeric_features, categorical_features
from part2.shared import load_processed_data
from part3.Mixture import SimpleMixtureOfExperts
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from part2.shared import load_train_with_validation_data

In [20]:

# Keep only the two terminal outcomes and frame the task as binary
# classification: y = 1 for "Graduate", y = 0 for "Dropout".
df = (
    load_processed_data()
    .loc[lambda d: d["Target"].isin(["Graduate", "Dropout"])]
    # .assign returns a new frame, so we never write into the boolean-mask
    # slice (the original `df["y"] = ...` triggered SettingWithCopyWarning).
    .assign(y=lambda d: (d["Target"] == "Graduate").astype(int))
    # Remove the label columns (raw "Target" and, if present, "Target encoded",
    # presumably an encoded copy of it) -- everything except "y" becomes a
    # feature in X downstream, so leaving them in would leak the label.
    .drop(columns=["Target", "Target encoded"], errors="ignore")
)

In [21]:
# Label as a plain NumPy array; features are every remaining column.
y = df["y"].to_numpy()
X = df.drop(columns=["y"])

In [22]:
# Restrict the configured feature lists to columns actually present in X,
# keeping the order defined in `constants` (it fixes the output column order).
num_features, cat_features = (
    [col for col in group if col in X.columns]
    for group in (numeric_features, categorical_features)
)

In [23]:
# Numeric columns: fill gaps with the training-column mean, then standardise.
num_pipeline = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="mean")),
        ("scaler", StandardScaler()),
    ]
)

In [24]:
# Categorical columns: fill gaps with the most frequent value, then one-hot
# encode; categories unseen during fit encode as all zeros instead of raising.
cat_pipeline = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("encoder", OneHotEncoder(handle_unknown="ignore")),
    ]
)

In [25]:
# Route each column group through its own pipeline. Columns in neither list
# are dropped (ColumnTransformer's default remainder="drop").
full_pipeline = ColumnTransformer(
    transformers=[
        ("num", num_pipeline, num_features),
        ("cat", cat_pipeline, cat_features),
    ]
)

In [26]:
(
    X_train_raw, X_val_raw, X_test_raw,
    y_train, y_val, y_test,
) = load_train_with_validation_data(X, y)

# Fit preprocessing statistics on the training split only, then reuse the
# fitted transform for the held-out splits so nothing leaks from them.
X_train = full_pipeline.fit_transform(X_train_raw)
X_val, X_test = (full_pipeline.transform(split) for split in (X_val_raw, X_test_raw))

In [27]:
# Fixed seed so the stochastic experts (bootstrap + feature sampling in the
# random forest, data shuffling in liblinear) produce the same report on
# every Restart & Run All.
RANDOM_STATE = 42

experts = [
    # Sparse, L1-regularised linear expert.
    LogisticRegression(
        max_iter=1000, C=0.1, penalty="l1", solver="liblinear",
        random_state=RANDOM_STATE,
    ),
    # Non-linear tree-ensemble expert.
    RandomForestClassifier(
        max_depth=20, min_samples_split=5, n_estimators=100,
        random_state=RANDOM_STATE,
    ),
]
# NOTE(review): SimpleMixtureOfExperts is project code; if its gating or
# combination step is itself stochastic it needs its own seed too -- confirm.
moe = SimpleMixtureOfExperts(experts=experts)

moe.fit(X_train, y_train)
y_pred = moe.predict(X_test)

# y == 1 means "Graduate", y == 0 means "Dropout" (see target construction).
# Passing `labels` explicitly keeps target_names aligned even if one class
# were missing from y_test / y_pred.
print(classification_report(
    y_test, y_pred, labels=[0, 1], target_names=["Dropout", "Graduate"],
))

              precision    recall  f1-score   support

           0       0.93      0.86      0.90       284
           1       0.91      0.96      0.94       443

    accuracy                           0.92       727
   macro avg       0.92      0.91      0.92       727
weighted avg       0.92      0.92      0.92       727

