# In [None]:
# trainer.py

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.base import clone
import config
import shap_analysis


def _take_rows(data, idx):
    """Return the rows of *data* at positional indices *idx*.

    Uses ``.iloc`` for pandas objects (plain ``[]`` on a DataFrame selects
    columns, not rows) and ordinary indexing otherwise (e.g. NumPy arrays).
    """
    if hasattr(data, "iloc"):
        return data.iloc[idx]
    return data[idx]


def _as_2d_targets(y):
    """Return *y* as a 2-D NumPy array of shape (n_samples, n_targets).

    A 1-D target vector (single target) becomes one column so that
    ``y[:, i]`` indexing works uniformly.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    return y


def _cross_val_r2(base_model, X, y, splits):
    """Return the mean validation R^2 of fresh clones of *base_model*.

    *y* is a 1-D target array aligned with the rows of *X*; *splits* is a
    sequence of ``(train_idx, val_idx)`` positional index pairs.
    """
    scores = []
    for train_idx, val_idx in splits:
        # Clone so each fold starts from an unfitted estimator.
        model = clone(base_model)
        model.fit(_take_rows(X, train_idx), y[train_idx])
        preds = model.predict(_take_rows(X, val_idx))
        scores.append(r2_score(y[val_idx], preds))
    return np.mean(scores)


def train_models(models, X_train, X_test, y_train, y_test,
                 output_path="results/final_results.csv"):
    """Cross-validate, fit, evaluate and explain every (model, target) pair.

    For each model in *models* and each target in ``config.TARGET_COLUMNS``:

    1. Run K-fold cross-validation on the training set (``config.N_SPLITS``
       shuffled folds seeded with ``config.RANDOM_STATE``) and record the
       mean validation R^2.
    2. Refit a fresh clone on the full training set and score it on the
       test set (R^2, RMSE, MAE).
    3. Pass the refitted model to ``shap_analysis.run_shap``.

    Parameters
    ----------
    models : dict
        Mapping of display name -> unfitted estimator with ``fit`` /
        ``predict``. Each estimator is cloned before every fit, so the
        objects in *models* are never fitted themselves.
    X_train, X_test : array-like
        Feature matrices; NumPy arrays and pandas DataFrames are both
        supported for fold row selection.
    y_train, y_test : array-like
        Target values; column ``i`` corresponds to
        ``config.TARGET_COLUMNS[i]``. A 1-D array is treated as a single
        target column.
    output_path : str or path-like, optional
        Destination of the results CSV. Missing parent directories are
        created before training starts.

    Returns
    -------
    pandas.DataFrame
        One row per (model, target) with columns
        ``Model, Target, CV_R2, Test_R2, Test_RMSE, Test_MAE``.
        The same table is written to *output_path*.
    """
    from pathlib import Path

    # Create the output directory before the (expensive) training loop:
    # without it to_csv raises FileNotFoundError at the very end and the
    # whole run's results are lost.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    y_train = _as_2d_targets(y_train)
    y_test = _as_2d_targets(y_test)

    kf = KFold(n_splits=config.N_SPLITS,
               shuffle=True,
               random_state=config.RANDOM_STATE)
    # Materialise the folds once so every model and target is evaluated on
    # exactly the same splits (and the shuffle is not recomputed per loop).
    splits = list(kf.split(X_train))

    results = []
    for model_name, base_model in models.items():
        print(f"\nModel: {model_name}")

        for i, target in enumerate(config.TARGET_COLUMNS):
            y_tr = y_train[:, i]
            y_te = y_test[:, i]

            cv_r2 = _cross_val_r2(base_model, X_train, y_tr, splits)

            # Fresh unfitted copy trained on all training rows for the
            # held-out test evaluation and SHAP analysis.
            final_model = clone(base_model)
            final_model.fit(X_train, y_tr)
            test_preds = final_model.predict(X_test)

            results.append([
                model_name,
                target,
                cv_r2,
                r2_score(y_te, test_preds),
                np.sqrt(mean_squared_error(y_te, test_preds)),
                mean_absolute_error(y_te, test_preds),
            ])

            # Run SHAP
            shap_analysis.run_shap(
                final_model,
                model_name,
                target,
                X_train,
                X_test,
                config.FEATURE_COLUMNS,
            )

    df_results = pd.DataFrame(results, columns=[
        "Model", "Target",
        "CV_R2", "Test_R2", "Test_RMSE", "Test_MAE",
    ])

    df_results.to_csv(output_path, index=False)

    return df_results