# Explainability-Driven In-context Learning

# Importing the required modules

In [None]:
# modules used for data handling
import pandas as pd
import numpy as np
import json
import random

# custom preprocessing modules
from scripts.preprocess import preprocess_titanic, preprocess_mushroom, preprocess_diabetes, preprocess_loan

# modules used for model handling
from xgboost import XGBClassifier

# modules used for genai pipeline
from scripts.pipeline import Pipeline
from scripts.configs import Dataset, Model

# modules used for env variables
import os
from dotenv import load_dotenv
load_dotenv()

# Pull the W&B and project/bucket settings out of the process environment in
# one pass; os.getenv yields None for any variable that is not set.
WANDB_API_KEY, WANDB_PROJECT_NAME, PROJECT_NAME, BUCKET_NAME = (
    os.getenv(env_name)
    for env_name in (
        "WANDB_API_KEY",
        "WANDB_PROJECT_NAME",
        "PROJECT_NAME",
        "BUCKET_NAME",
    )
)

  from .autonotebook import tqdm as notebook_tqdm


## Default Configs

In [2]:
# Classifier and tuning-config path that the Pipeline cell below receives as
# `explanable_model` and `tune_config_file`.
tune_config_file = "data/tune_config/xgb.json"
explanable_model = XGBClassifier()

# Token limit shared by all three LLM configurations.
_llm_token_budget = {"max_tokens": 4096}

# LLM configuration passed to Pipeline as `reasoning_gen_model`.
reasoning_gen_model = Model(
    provider="together",
    name="deepseek-ai/DeepSeek-R1",
    temperature=0.6,
    **_llm_token_budget,
)

# LLM configuration passed to Pipeline as `objective_judge_model`.
objective_judge_model = Model(
    provider="anthropic",
    name="claude-haiku-4-5",
    temperature=0.6,
    **_llm_token_budget,
)

# LLM configuration passed to Pipeline as `cot_model`; the only one of the
# three configured with temperature 0.0.
cot_model = Model(
    provider="google",
    name="gemini-2.5-flash",
    temperature=0.0,
    **_llm_token_budget,
)

## Test Pipeline

In [6]:
# Constructor arguments for each dataset this notebook can be run against,
# keyed by dataset name. Keeping every option here (instead of commenting
# blocks in and out) means switching datasets is a one-line change to
# ACTIVE_DATASET and no configuration is hidden in dead code.
DATASET_CONFIGS = {
    "titanic": dict(
        name="titanic",
        path="data/datasets/titanic_small.csv",
        config_file_path="data/dataset_config/titanic_config.json",
        shap_vals_path="data/shap_values/titanic_shap.csv",
        preprocess_fn=preprocess_titanic,
        target_col="Survived",
        labels={0: "Did not survive", 1: "Survived"},
    ),
    "mushroom": dict(
        name="mushroom",
        path="data/datasets/mushrooms.csv",
        config_file_path="data/dataset_config/mushroom_config.json",
        shap_vals_path="data/shap_values/mushroom_shap.csv",
        preprocess_fn=preprocess_mushroom,
        target_col="class",
        labels={0: "Edible", 1: "Poisonous"},
    ),
}

# Name of the dataset used for this run; must be a key of DATASET_CONFIGS
# (an unknown name raises KeyError on the lookup below).
ACTIVE_DATASET = "mushroom"

dataset = Dataset(**DATASET_CONFIGS[ACTIVE_DATASET])

In [7]:
# Gather everything the pipeline is built from: the selected dataset, the
# classifier with its tuning-config path, and the three LLM configurations.
_pipeline_args = {
    "dataset": dataset,
    "explanable_model": explanable_model,
    "tune_config_file": tune_config_file,
    "reasoning_gen_model": reasoning_gen_model,
    "objective_judge_model": objective_judge_model,
    "cot_model": cot_model,
}
llm = Pipeline(**_pipeline_args)

In [8]:
# Execute the pipeline with the baseline, objective-judge and CoT-ablation
# flags all switched on.
_run_flags = {"baseline": True, "objective_judge": True, "cot_ablation": True}
llm.run(**_run_flags)

[Mushroom] Dropped 2480 rows due to NaNs (kept 5644 rows).
Create sweep with ID: pt4fhcq4
Sweep URL: https://wandb.ai/hl3925-columbia-university/6998final/sweeps/pt4fhcq4
[XAI-MODEL] Completed hyperparameter tuning.
[XAI-MODEL] Trained model with best hyperparameters.
[XAI-MODEL] Logged explanation data to data/dataset_config/mushroom_config.json
[XAI-MODEL] Explanation process completed.
[Mushroom] Dropped 2480 rows due to NaNs (kept 5644 rows).
[GCS CLIENT] File data/batches/mushroom_zero-shot_baseline_batches.jsonl uploaded to batch_inputs/gemini/mushroom_zero-shot_baseline_batches.jsonl
[ZERO-SHOT] Submitted Job: projects/372383421945/locations/us-east4/batchPredictionJobs/274138210496413696
[ZERO-SHOT] Output base dir: gs://6998final-bucket/batch_outputs/gemini
[ZERO-SHOT] projects/372383421945/locations/us-east4/batchPredictionJobs/274138210496413696 state: JobState.JOB_STATE_QUEUED
[ZERO-SHOT] projects/372383421945/locations/us-east4/batchPredictionJobs/274138210496413696 state:

Uploading file mushroom_reasoning_batches.jsonl: 100%|██████████| 208k/208k [00:00<00:00, 414kB/s] 


[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.COMPLETED
[REASON GENERATION] Batch completed successfully.


Downloading file mushroom_reasoning_predictions.jsonl: 100%|██████████| 316k/316k [00:00<00:00, 13.3MB/s]


[REASON GENERATION] Batch outputs downloaded to data/batch_outputs/mushroom_reasoning_predictions.jsonl
[PIPELINE] Reasoning generation completed.
[Mushroom] Dropped 2480 rows due to NaNs (kept 5644 rows).
[OBJECTIVE JUDGE] Submitted batch with id: msgbatch_01NYv2WgZVk2yPpf1bSwxw4b
[OBJECTIVE JUDGE] Batch msgbatch_01NYv2WgZVk2yPpf1bSwxw4b is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01NYv2WgZVk2yPpf1bSwxw4b is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01NYv2WgZVk2yPpf1bSwxw4b has completed processing.
[OBJECTIVE JUDGE] Batch result types: {'succeeded': 53, 'errored': 0, 'expired': 0}
[OBJECTIVE JUDGE] Saved evaluations to data/batch_outputs/mushroom_objective_judge_evaluations.jsonl
[PIPELINE] Objective judge evaluation completed.
[Mushroom] Dropped 2480 rows due to NaNs (kept 5644 rows).
[GCS CLIENT] File data/batches/mushroom_icl_batches.jsonl uploaded to batch_inputs/gemini/mushroom_icl_batches.jsonl
[ICL CLASSIFIER] Submitted Job: projects/372383421945/locatio

In [1]:
# Display the pipeline's `results` attribute.
# NOTE(review): this cell depends on `llm` created by the Pipeline cell above in
# the same kernel session. Its saved execution count is 1 and its saved output
# is a NameError, i.e. it was last run on a kernel where `llm` did not exist.
# Re-run the notebook top to bottom before relying on this output.
llm.results

NameError: name 'llm' is not defined