# XGBoost Grid Search

This notebook runs a basic hyperparameter-tuning experiment for XGBoost models using grid search.

In [None]:
# Import libraries
import pandas as pd
from xgboost_grid_search import train_best_model

In [None]:

# Dataset CSV paths: one *_CN_MCI and one *_MCI_AD file for each of 2-5 visits.
# Every entry must be followed by a comma: adjacent string literals without one
# are silently concatenated by Python into a single (non-existent) path.
files = {
    'Dataset_1/2visit_CN_MCI.csv', 'Dataset_1/2visit_MCI_AD.csv',
    'Dataset_1/3visit_CN_MCI.csv', 'Dataset_1/3visit_MCI_AD.csv',
    'Dataset_1/4visit_CN_MCI.csv', 'Dataset_1/4visit_MCI_AD.csv',
    'Dataset_1/5visit_CN_MCI.csv', 'Dataset_1/5visit_MCI_AD.csv',
}

# Hyperparameter search space: 3 candidate values for each of 5 parameters,
# i.e. 3**5 = 243 combinations in the full grid.
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'learning_rate': [0.01, 0.1, 0.2],
    'subsample': [0.6, 0.8, 1.0],
    'colsample_bytree': [0.6, 0.8, 1.0],
}

In [None]:
# Batch grid search across datasets and progression types, saving artifacts and reports
import os
from datetime import datetime
from xgboost_grid_search import train_best_model, create_delta_features

# Datasets paired with the progression type that is forwarded to
# train_best_model as `progression_type`.
files = [
    ("Dataset_1/2visit_CN_MCI.csv", "MCI"),
    ("Dataset_1/2visit_MCI_AD.csv", "AD"),
    ("Dataset_1/3visit_CN_MCI.csv", "MCI"),
    ("Dataset_1/3visit_MCI_AD.csv", "AD"),
    ("Dataset_1/4visit_CN_MCI.csv", "MCI"),
    ("Dataset_1/4visit_MCI_AD.csv", "AD"),
    ("Dataset_1/5visit_CN_MCI.csv", "MCI"),
    ("Dataset_1/5visit_MCI_AD.csv", "AD"),
]

# Hyperparameter search space: 3 candidate values for each of 5 parameters,
# i.e. 3**5 = 243 combinations in the full grid.
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7],
    'learning_rate': [0.01, 0.1, 0.2],
    'subsample': [0.6, 0.8, 1.0],
    'colsample_bytree': [0.6, 0.8, 1.0],
}

# Directory that receives one "<dataset>_<prog>_cv_scores.csv" path per dataset.
results_dir = "grid_results"
os.makedirs(results_dir, exist_ok=True)

# Outcome bookkeeping so the batch can be inspected after the loop instead of
# relying on scrolling through printed output.
trained = {}   # dataset base name -> (model, cols) as returned by train_best_model
failures = []  # (path, exception) for every dataset that could not be processed

for path, prog in files:
    # Pure string/path work cannot meaningfully fail, so it stays outside `try`.
    base = os.path.splitext(os.path.basename(path))[0]
    csv_out = os.path.join(results_dir, f"{base}_{prog}_cv_scores.csv")
    print(f"\n=== Running grid search for {base} ({prog}) ===")
    try:
        df = pd.read_csv(path)
        # NOTE(review): presumably csv_path receives the CV score report and
        # save_dir the model artifacts -- confirm against xgboost_grid_search.
        model, cols = train_best_model(
            df,
            progression_type=prog,
            param_grid=param_grid,
            csv_path=csv_out,
            save_dir="saved_models",
            model_base_name=base,
            save_artifacts=True,
        )
    except Exception as e:
        # Deliberate best-effort: one bad dataset must not abort the batch.
        # Keep the exception object (with its traceback) for later inspection
        # and include the exception type, since str(e) alone is often ambiguous
        # (e.g. a KeyError prints only the missing key).
        failures.append((path, e))
        print(f"Error processing {path}: {type(e).__name__}: {e}")
    else:
        trained[base] = (model, cols)

# Final summary so partial failures are not mistaken for a clean run.
print(f"\nCompleted {len(trained)}/{len(files)} grid searches.")
if failures:
    print("Failed datasets:")
    for failed_path, failed_exc in failures:
        print(f"  {failed_path}: {type(failed_exc).__name__}: {failed_exc}")