# Model Training
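This notebook trains the XGBoost (baseline) and MLP (hybrid) models used for Pokemon type prediction, and saves the fitted models, the label binarizer, and the tuned decision threshold to the models directory.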

In [None]:
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.multioutput import MultiOutputClassifier
from xgboost import XGBClassifier
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.preprocessing import MultiLabelBinarizer

from pokemon_predictor import config
from pokemon_predictor.tabular import load_data
from pokemon_predictor.losses import FocalLoss

## Train XGBoost

In [None]:
def train_xgboost():
    """Train the baseline XGBoost model and persist it with a label binarizer.

    Loads the ``'rgb'`` feature split through ``load_data``, fits an
    ``XGBClassifier`` wrapped in ``MultiOutputClassifier``, prints the score
    on the held-out split, and writes two artifacts to ``config.MODELS_DIR``:

    * ``xgboost_model.pkl`` -- the fitted multi-output model.
    * ``mlb.pkl`` -- a ``MultiLabelBinarizer`` fitted on every row of
      ``y_labels.csv`` (``type1`` plus ``type2`` when it is present).

    Returns:
        None. Progress is reported with ``print``; results are files on disk.
    """
    print("Training XGBoost Model...")
    # The class list returned by load_data is not needed here: the binarizer
    # saved below is refit from the raw label file.
    X_train, X_test, y_train, y_test, _ = load_data('rgb', split_data=True)

    model = MultiOutputClassifier(
        XGBClassifier(
            n_estimators=100,
            max_depth=5,
            learning_rate=0.1,
            colsample_bytree=0.75,
            n_jobs=-1,
            random_state=42,
        )
    )
    model.fit(X_train, y_train)

    score = model.score(X_test, y_test)
    print(f"XGBoost Test Accuracy (Subset): {score:.4f}")

    # Make sure the output directory exists so the dumps below cannot fail
    # with FileNotFoundError on a fresh checkout.
    models_dir = config.MODELS_DIR
    models_dir.mkdir(parents=True, exist_ok=True)

    out_path = models_dir / "xgboost_model.pkl"
    joblib.dump(model, out_path)
    print(f"Saved XGBoost model to {out_path}")

    # Rebuild the per-sample type lists from the processed labels; a missing
    # second type (NaN) is left out so it never becomes a class of its own.
    # NOTE(review): this binarizer is fitted independently of the `classes`
    # returned by load_data -- confirm both yield the same class order.
    y_labels = pd.read_csv(config.PROCESSED_DATA_DIR / "y_labels.csv")
    type_lists = [
        [primary, secondary] if pd.notna(secondary) else [primary]
        for primary, secondary in zip(y_labels['type1'], y_labels['type2'])
    ]
    mlb_full = MultiLabelBinarizer()
    mlb_full.fit(type_lists)

    mlb_path = models_dir / "mlb.pkl"
    joblib.dump(mlb_full, mlb_path)
    print(f"Saved MLB to {mlb_path}")

# Run the XGBoost training cell: fits the model and writes its artifacts.
train_xgboost()

## Train MLP

In [None]:
def train_mlp():
    """Train the hybrid-feature MLP, save it, and tune a decision threshold.

    Loads the ``'hybrid'`` feature split through ``load_data``, trains a
    two-hidden-layer network with sigmoid outputs (one unit per class) using
    ``FocalLoss`` and early stopping on validation loss, and writes two
    artifacts to ``config.MODELS_DIR``:

    * ``mlp_model_optimized.h5`` -- the trained Keras model.
    * ``best_threshold.pkl`` -- the probability cut-off (a plain ``float``)
      with the highest micro-averaged F1 among the candidates
      0.10, 0.15, ..., 0.85; stays at 0.5 if no candidate scores above 0.

    Returns:
        None. Progress is reported with ``print``; results are files on disk.
    """
    from sklearn.metrics import f1_score

    print("\nTraining MLP Model (Hybrid)...")
    X_train, X_test, y_train, y_test, classes = load_data('hybrid', split_data=True)

    model = Sequential([
        Input(shape=(X_train.shape[1],)),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(len(classes), activation='sigmoid')
    ])

    model.compile(optimizer=Adam(0.001), loss=FocalLoss(), metrics=['accuracy', 'binary_accuracy'])

    early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

    # Create the output directory up front so a missing path is caught before
    # training time is spent, not when the model is saved.
    models_dir = config.MODELS_DIR
    models_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the held-out split is used both as validation data for
    # early stopping and for threshold selection below, so the reported F1 is
    # an optimistic estimate -- consider a separate validation split.
    model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=50,
        batch_size=32,
        callbacks=[early_stop],
        verbose=1
    )

    out_path = models_dir / "mlp_model_optimized.h5"
    model.save(out_path)
    print(f"Saved MLP model to {out_path}")

    print("Finding best threshold...")
    y_pred = model.predict(X_test)
    best_thresh = 0.5
    best_f1 = 0.0

    for thresh in np.arange(0.1, 0.9, 0.05):
        y_pred_bin = (y_pred > thresh).astype(int)
        f1 = f1_score(y_test, y_pred_bin, average='micro')
        # Strict '>' keeps the lowest threshold when several tie on F1.
        if f1 > best_f1:
            best_f1 = f1
            # Store a builtin float so the pickled value has one consistent
            # type whether or not the 0.5 default was replaced.
            best_thresh = float(thresh)

    print(f"Best Threshold: {best_thresh:.2f} (F1: {best_f1:.4f})")
    joblib.dump(best_thresh, models_dir / "best_threshold.pkl")

# Run the MLP training cell: fits the network, saves it, and tunes the threshold.
train_mlp()