# Explainability-Driven In-context Learning

# Importing the required modules

In [1]:
# modules used for data handling
import pandas as pd
import numpy as np
import json
import random

# custom preprocessing modules
from scripts.preprocess import preprocess_titanic

# modules used for model handling
from xgboost import XGBClassifier

# modules used for the GenAI pipeline
from scripts.pipeline import Pipeline
from scripts.configs import Dataset, Model

# modules used for env variables
import os
from dotenv import load_dotenv
load_dotenv()

# Settings read from the process environment after load_dotenv() has run.
# Any variable that is not set comes back as None rather than raising.
WANDB_API_KEY, WANDB_PROJECT_NAME, PROJECT_NAME, BUCKET_NAME = (
    os.environ.get(env_key)
    for env_key in (
        "WANDB_API_KEY",
        "WANDB_PROJECT_NAME",
        "PROJECT_NAME",
        "BUCKET_NAME",
    )
)



## Default Configs

In [2]:
# Explainable classifier and the path of the tuning config file; both are
# handed to the Pipeline constructor further down.
explanable_model = XGBClassifier()
tune_config_file = "data/tune_config/xgb.json"

# One Model config per LLM role, built in the order they are unpacked:
# reasoning generation, objective judge, CoT model.
# Every role uses the same 4096-token cap; only the provider, model name and
# sampling temperature differ.
reasoning_gen_model, objective_judge_model, cot_model = (
    Model(provider=provider, name=name, temperature=temperature, max_tokens=4096)
    for provider, name, temperature in (
        ("together", "deepseek-ai/DeepSeek-R1", 0.6),
        ("anthropic", "claude-haiku-4-5", 0.6),
        ("google", "gemini-2.5-flash", 0.0),
    )
)

## Test Pipeline

In [3]:
# Dataset descriptor for the Titanic run: where the CSV, the dataset config
# and the SHAP values live, which function preprocesses the data, which
# column is the target, and how the numeric class ids map to readable labels.
# NOTE(review): how Dataset uses each field is defined in scripts.configs,
# which is not visible here — confirm there before changing paths.
dataset = Dataset(
    name="titanic",
    path="data/datasets/titanic_small.csv",
    config_file_path="data/dataset_config/titanic_config.json",
    shap_vals_path="data/shap_values/titanic_shap.csv",
    preprocess_fn=preprocess_titanic,
    target_col="Survived",
    labels={0: "Did not survive", 1: "Survived"}
)

In [4]:
# Assemble the pipeline from the dataset descriptor, the explainable model
# with its tuning config file, and the three LLM configs defined in the
# "Default Configs" cell. Nothing is run here; execution starts with llm.run().
# NOTE(review): assumes the Pipeline constructor itself does no remote work —
# confirm in scripts.pipeline.
llm = Pipeline(
    dataset=dataset,
    explanable_model=explanable_model,
    tune_config_file=tune_config_file,
    reasoning_gen_model=reasoning_gen_model,
    objective_judge_model=objective_judge_model,
    cot_model=cot_model
)

In [5]:
# Run the pipeline with the baseline, objective-judge and CoT-ablation flags
# all enabled. Per the recorded output below, this starts a W&B sweep, uploads
# batch files and submits/polls remote batch jobs, so it is a long-running
# cell that needs network access and valid credentials.
llm.run(baseline=True, objective_judge=True, cot_ablation=True)

[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
Create sweep with ID: bngoq9v8
Sweep URL: https://wandb.ai/gauravpendharkar/xai-guided-cot/sweeps/bngoq9v8
[XAI-MODEL] Completed hyperparameter tuning.
[XAI-MODEL] Trained model with best hyperparameters.
[XAI-MODEL] Logged explanation data to data/dataset_config/titanic_config.json
[XAI-MODEL] Explanation process completed.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[GCS CLIENT] File data/batches/titanic_zero-shot_baseline_batches.jsonl uploaded to batch_inputs/gemini/titanic_zero-shot_baseline_batches.jsonl
[ZERO-SHOT] Submitted Job: projects/54181826632/locations/us-east4/batchPredictionJobs/6150843143005667328
[ZERO-SHOT] Output base dir: gs://xai_guided_cot_bucket/batch_outputs/gemini
[ZERO-SHOT] projects/54181826632/locations/us-east4/batchPredictionJobs/6150843143005667328 state: JobState.JOB_STATE_QUEUED
[ZERO-SHOT] projects/54181826632/locations/us-east4/batchPredictionJobs/6150843143005667328 state: JobState.J

Uploading file titanic_reasoning_batches.jsonl: 100%|██████████| 18.6k/18.6k [00:00<00:00, 34.5kB/s]


[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.COMPLETED
[REASON GENERATION] Batch completed successfully.


Downloading file titanic_reasoning_predictions.jsonl: 100%|██████████| 52.0k/52.0k [00:00<00:00, 26.9MB/s]


[REASON GENERATION] Batch outputs downloaded to data/batch_outputs/titanic_reasoning_predictions.jsonl
[PIPELINE] Reasoning generation completed.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[OBJECTIVE JUDGE] Submitted batch with id: msgbatch_019CHTtjiLir3HpUDR3oUE8M
[OBJECTIVE JUDGE] Batch msgbatch_019CHTtjiLir3HpUDR3oUE8M is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_019CHTtjiLir3HpUDR3oUE8M is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_019CHTtjiLir3HpUDR3oUE8M has completed processing.
[OBJECTIVE JUDGE] Batch result types: {'succeeded': 7, 'errored': 0, 'expired': 0}
[OBJECTIVE JUDGE] Saved evaluations to data/batch_outputs/titanic_objective_judge_evaluations.jsonl
[PIPELINE] Objective judge evaluation completed.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[GCS CLIENT] File data/batches/titanic_icl_batches.jsonl uploaded to batch_inputs/gemini/titanic_icl_batches.jsonl
[ICL CLASSIFIER] Submitted Job: projects/54181826632/locations/us-east4/ba

In [6]:
# Show what the run collected: per-experiment accuracy / macro-F1 scores and
# the generated reasoning text (bare expression, rendered as the cell output).
llm.results

{'zero_shot_baseline': {'xgboost': {'accuracy': 0.8484848484848485,
   'macro_f1_score': 0.8331648129423661},
  'zero-shot-prompting': {'macro_f1_score': 0.7954545454545454,
   'accuracy': 0.8181818181818182}},
 'zero_shot_cot_ablation': {'xgboost': {'accuracy': 0.8484848484848485,
   'macro_f1_score': 0.8331648129423661},
  'zero-shot-cot': {'macro_f1_score': 0.5909090909090909,
   'accuracy': 0.6363636363636364}},
 'reasoning': {60: 'The model correctly predicts that the passenger did not survive (predicted label 0 matches ground truth 0.0). The SHAP values reveal that the strongest negative contributions come from Sex (0.0, male) at -1.49 and Pclass (3.0, third class) at -0.996, aligning with historical trends where males and lower-class passengers had lower survival rates. Age (21.0) further reduces survival probability with a SHAP of -0.312, as young adult males were less prioritized during evacuation. Although having no siblings/spouses (SibSp=0.0) contributes positively (SHAP +0