# XAI-Guided CoT Pipeline Test


In [1]:
# Imports
import os, json, time
from datetime import datetime
import pandas as pd
import numpy as np
from xgboost import XGBClassifier
from dotenv import load_dotenv

from scripts.pipeline import Pipeline
from scripts.configs import Dataset, Model
from scripts.preprocess import (
    preprocess_titanic, preprocess_mushroom, 
    preprocess_diabetes, preprocess_loan
)

# Load variables from a .env file (python-dotenv) before reading them.
load_dotenv()

# Each setting is None when the corresponding environment variable is unset.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
WANDB_PROJECT_NAME = os.environ.get("WANDB_PROJECT_NAME")
PROJECT_NAME = os.environ.get("PROJECT_NAME")
BUCKET_NAME = os.environ.get("BUCKET_NAME")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset Configurations
# Facts shared by the plain and masked variant of each dataset:
#   key -> (stem of the unmasked CSV file, preprocess function, target column, labels)
_DATASET_SPECS = {
    "titanic": ("titanic", preprocess_titanic, "Survived",
                {0: "Did not survive", 1: "Survived"}),
    "diabetes": ("diabetes", preprocess_diabetes, "Outcome",
                 {0: "No Diabetes", 1: "Has Diabetes"}),
    "loan": ("loan_approval", preprocess_loan, "loan_status",
             {0: "Rejected", 1: "Approved"}),
    "mushroom": ("mushrooms", preprocess_mushroom, "class",
                 {0: "Edible", 1: "Poisonous"}),
}


def _make_dataset(name, csv_stem, preprocess_fn, target_col, labels):
    """Build a Dataset whose config and SHAP paths are derived from `name`.

    The CSV path uses `csv_stem` instead, because some unmasked CSV files are
    not named after their dataset key.
    """
    return Dataset(
        name=name,
        path=f"data/datasets/{csv_stem}.csv",
        config_file_path=f"data/dataset_config/{name}_config.json",
        shap_vals_path=f"data/shap_values/{name}_shap.csv",
        preprocess_fn=preprocess_fn,
        target_col=target_col,
        labels=labels,
    )


datasets = {
    key: _make_dataset(key, csv_stem, preprocess_fn, target_col, dict(labels))
    for key, (csv_stem, preprocess_fn, target_col, labels) in _DATASET_SPECS.items()
}

# [TO DO] Update for the masked datasets
# Masked versions: same preprocess function and target column as the plain
# dataset, but every file is named "<key>_masked..." and the labels are just
# the class index as a string.
datasets_masked = {
    key: _make_dataset(
        f"{key}_masked", f"{key}_masked", preprocess_fn, target_col, {0: "0", 1: "1"}
    )
    for key, (_csv_stem, preprocess_fn, target_col, _labels) in _DATASET_SPECS.items()
}

print(f"Datasets: {list(datasets.keys())}")
print(f"Masked datasets: {list(datasets_masked.keys())}")

Datasets: ['titanic', 'diabetes', 'loan', 'mushroom']
Masked datasets: ['titanic', 'diabetes', 'loan', 'mushroom']


In [3]:
# Model Configurations
# One Model per pipeline role; run_pipeline hands each of these to Pipeline
# under the keyword of the same name.

# Generates the reasoning text.
reasoning_gen_model = Model(
    provider="together",
    name="deepseek-ai/DeepSeek-R1",
    temperature=0.6,
    max_tokens=4096,
)

# Judges the generated reasoning.
objective_judge_model = Model(
    provider="anthropic",
    name="claude-haiku-4-5",
    temperature=0.6,
    max_tokens=4096,
)

# Runs the CoT classification; the only model configured with temperature 0.
cot_model = Model(
    provider="google",
    name="gemini-2.5-flash",
    temperature=0.0,
    max_tokens=4096,
)

# Tuning configuration file passed to Pipeline as `tune_config_file`.
TUNE_CONFIG_FILE = "data/tune_config/xgb.json"
print("Models configured.")

Models configured.


In [None]:
# Helper Functions
def _json_default(value):
    """Fallback for json.dumps: turn NumPy scalars/arrays into plain Python values."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    # Anything else keeps the stock json behaviour of refusing the value.
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def run_pipeline(
    dataset: Dataset,
    output_dir: str,
    masked: bool = False,
    baseline: bool = True,
    objective_judge: bool = True,
    cot_ablation: bool = True
) -> dict:
    """Run the pipeline on one dataset and save its metrics as a JSON file.

    A new ``Pipeline`` is built around a fresh ``XGBClassifier`` and the
    notebook-level ``TUNE_CONFIG_FILE`` and model configurations, then run with
    the requested stages.

    Args:
        dataset: Dataset configuration handed to ``Pipeline``.
        output_dir: Directory for the result file; created if missing.
        masked: Only labels the run (file-name suffix and the ``"masked"``
            field of the result); it does not select or alter ``dataset``.
        baseline: Forwarded to ``Pipeline.run``.
        objective_judge: Forwarded to ``Pipeline.run``.
        cot_ablation: Forwarded to ``Pipeline.run``.

    Returns:
        A dict with the keys ``"dataset"``, ``"masked"``, ``"elapsed_seconds"``
        and ``"metrics"`` (every entry of ``pipeline.results`` except
        ``"reasoning"``). The same dict is what gets written to disk.

    Raises:
        TypeError: If the metrics hold a value that is neither JSON-native nor
            a NumPy scalar/array. Raised before the output file is opened, so
            no partial file is left behind.
    """
    os.makedirs(output_dir, exist_ok=True)
    suffix = "masked" if masked else "unmasked"
    # The Unix timestamp keeps repeated runs from overwriting each other.
    # NOTE: the masked datasets defined in this notebook already have names
    # ending in "_masked", so their files come out as "<name>_masked_<ts>.json"
    # with "masked" appearing twice.
    filename = f"{dataset.name}_{suffix}_{int(time.time())}.json"
    output_path = os.path.join(output_dir, filename)

    print(f"\n{'='*60}")
    print(f"Running: {dataset.name} ({suffix})")
    print(f"{'='*60}")

    # perf_counter is monotonic, so a wall-clock adjustment during a long
    # batch job cannot distort the measured duration.
    start_time = time.perf_counter()

    pipeline = Pipeline(
        dataset=dataset,
        explanable_model=XGBClassifier(),
        tune_config_file=TUNE_CONFIG_FILE,
        reasoning_gen_model=reasoning_gen_model,
        objective_judge_model=objective_judge_model,
        cot_model=cot_model
    )

    pipeline.run(
        baseline=baseline,
        objective_judge=objective_judge,
        cot_ablation=cot_ablation,
    )

    elapsed = time.perf_counter() - start_time

    results = {
        "dataset": dataset.name,
        "masked": masked,
        "elapsed_seconds": elapsed,
        # The "reasoning" entry is left out of the saved metrics.
        "metrics": {k: v for k, v in pipeline.results.items() if k != "reasoning"}
    }

    # Serialize fully before opening the file: a failure here must not leave a
    # truncated JSON file on disk after an expensive run.
    payload = json.dumps(results, indent=2, default=_json_default)
    with open(output_path, 'w', encoding="utf-8") as f:
        f.write(payload)

    print(f"Completed in {elapsed:.2f}s. Saved to {output_path}")
    return results


In [6]:

# Uncomment a key to include that dataset in the corresponding run.
MASKED_DATASETS_TO_RUN = [
    # "titanic",
    # "diabetes",
    # "loan",
    # "mushroom",
]
UNMASKED_DATASETS_TO_RUN = [
    "titanic",
    # "diabetes",
    # "loan",
    # "mushroom",
]

# === Run Experiments ===
all_results = []

# (selected keys, dataset registry, output directory, masked flag);
# masked runs go first, then unmasked.
run_plan = (
    (MASKED_DATASETS_TO_RUN, datasets_masked, "data/results/masked", True),
    (UNMASKED_DATASETS_TO_RUN, datasets, "data/results/unmasked", False),
)

for selected_keys, registry, results_dir, is_masked in run_plan:
    for key in selected_keys:
        # Append after every run so finished results survive a later failure.
        all_results.append(
            run_pipeline(registry[key], output_dir=results_dir, masked=is_masked)
        )

print(f"\nCompleted {len(all_results)} experiments.")



Running: titanic (unmasked)
[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).
Create sweep with ID: hz2727hs
Sweep URL: https://wandb.ai/hl3925-columbia-university/6998final/sweeps/hz2727hs
[XAI-MODEL] Completed hyperparameter tuning.
[XAI-MODEL] Trained model with best hyperparameters.
[XAI-MODEL] Logged explanation data to data/dataset_config/titanic_config.json
[XAI-MODEL] Explanation process completed.
[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).
[GCS CLIENT] File data/batches/titanic_zero-shot_baseline_batches.jsonl uploaded to batch_inputs/gemini/titanic_zero-shot_baseline_batches.jsonl
[ZERO-SHOT] Submitted Job: projects/372383421945/locations/us-east4/batchPredictionJobs/1202038063409135616
[ZERO-SHOT] Output base dir: gs://6998final-bucket/batch_outputs/gemini
[ZERO-SHOT] projects/372383421945/locations/us-east4/batchPredictionJobs/1202038063409135616 state: JobState.JOB_STATE_RUNNING
[ZERO-SHOT] projects/372383421945/locations/us-east4/batchPredictionJobs/12

Uploading file titanic_reasoning_batches.jsonl: 100%|██████████| 18.6k/18.6k [00:00<00:00, 29.0kB/s]


[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.COMPLETED
[REASON GENERATION] Batch completed successfully.


Downloading file titanic_reasoning_predictions.jsonl: 100%|██████████| 50.2k/50.2k [00:00<00:00, 37.8MB/s]


[REASON GENERATION] Batch outputs downloaded to data/batch_outputs/titanic_reasoning_predictions.jsonl
[PIPELINE] Reasoning generation completed.
[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).
[OBJECTIVE JUDGE] Submitted batch with id: msgbatch_01R3dNpdMdcgrQdQD2MC7ANf
[OBJECTIVE JUDGE] Batch msgbatch_01R3dNpdMdcgrQdQD2MC7ANf is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01R3dNpdMdcgrQdQD2MC7ANf is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01R3dNpdMdcgrQdQD2MC7ANf has completed processing.
[OBJECTIVE JUDGE] Batch result types: {'succeeded': 6, 'errored': 0, 'expired': 0}
[OBJECTIVE JUDGE] Saved evaluations to data/batch_outputs/titanic_objective_judge_evaluations.jsonl
[PIPELINE] Objective judge evaluation completed.
[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).
[GCS CLIENT] File data/batches/titanic_icl_batches.jsonl uploaded to batch_inputs/gemini/titanic_icl_batches.jsonl
[ICL CLASSIFIER] Submitted Job: projects/372383421945/locations/us-east4