# Explainability-Driven In-Context Learning

# Importing the required modules

In [None]:
# standard-library modules used for data handling and env variables
import json
import os
import random

# third-party modules used for data handling, modeling and env variables
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from xgboost import XGBClassifier

# user-defined modules
from scripts.configs import Dataset, Model
from scripts.preprocess import preprocess_titanic
from scripts.postprocess import (parse_reasoning_llm_results,
                                 parse_baseline_llm_results,
                                 parse_objective_judge_results,
                                 summarize_baseline_results)
from scripts.prompt_generator import (zero_shot_prompt_generator,
                                      reasoning_generator_prompt,
                                      objective_judge_prompt_generator)

from scripts.explanable_tree_model import ExplainableModel
from scripts.zero_shot_baseline import ZeroShotBaseline
from scripts.diverse_examples import get_diverse_examples
from scripts.reason_generation import ReasonGenerator
from scripts.objective_judge import ObjectiveJudge

# Load variables from a local .env file into the process environment, then
# read the settings this notebook names. Each value is None when the
# corresponding variable is not set; the values are never printed here.
load_dotenv()

WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
WANDB_PROJECT_NAME = os.environ.get("WANDB_PROJECT_NAME")
PROJECT_NAME = os.environ.get("PROJECT_NAME")
BUCKET_NAME = os.environ.get("BUCKET_NAME")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Single dataset descriptor shared by every stage of the notebook
# (tree model, zero-shot baseline, reason generation, judge).
titanic_dataset = Dataset(
    # identity and label column
    name="titanic",
    target_col="Survived",
    # raw data and the function used to preprocess it
    path="data/datasets/titanic.csv",
    preprocess_fn=preprocess_titanic,
    # per-dataset config file and SHAP values file
    config_file_path="data/dataset_config/titanic_config.json",
    shap_vals_path="data/shap_values/titanic_shap.csv",
)

# Tree Model Performance (XGBoost)

In [3]:
# Tune, train and explain an XGBoost classifier on the Titanic dataset.
# The estimator starts from library defaults; the tuning file is handed to
# `explain` (presumably it defines the hyperparameter search space -- confirm
# against ExplainableModel).
clf = XGBClassifier()
xmodel = ExplainableModel(dataset=titanic_dataset, estimator=clf)
xmodel.explain(params_grid_file="data/tune_config/xgb.json")

[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).
Create sweep with ID: yc32vi2s
Sweep URL: https://wandb.ai/gauravpendharkar/xai-guided-cot/sweeps/yc32vi2s


[34m[1mwandb[0m: Agent Starting Run: 1jdbwahq with config:
[34m[1mwandb[0m: 	learning_rate: 0.27854961660395194
[34m[1mwandb[0m: 	max_depth: 6
[34m[1mwandb[0m: 	min_child_weight: 1
[34m[1mwandb[0m: 	n_estimators: 339
[34m[1mwandb[0m: 	random_state: 42
[34m[1mwandb[0m: 	reg_lambda: 0.08415814901838715
[34m[1mwandb[0m: 	subsample: 0.6877590380220275
[34m[1mwandb[0m: Currently logged in as: [33mmitugaurav15[0m ([33mgauravpendharkar[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [anthropic, google.genai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


0,1
accuracy,▁
accuracy_std,▁
f1_macro,▁
f1_macro_std,▁

0,1
accuracy,0.79801
accuracy_std,0.03593
f1_macro,0.7885
f1_macro_std,0.03672


[34m[1mwandb[0m: Agent Starting Run: 9wwfoqxj with config:
[34m[1mwandb[0m: 	learning_rate: 0.1139874026806906
[34m[1mwandb[0m: 	max_depth: 3
[34m[1mwandb[0m: 	min_child_weight: 3
[34m[1mwandb[0m: 	n_estimators: 400
[34m[1mwandb[0m: 	random_state: 42
[34m[1mwandb[0m: 	reg_lambda: 0.9535245332694242
[34m[1mwandb[0m: 	subsample: 0.7282078377467783


0,1
accuracy,▁
accuracy_std,▁
f1_macro,▁
f1_macro_std,▁

0,1
accuracy,0.81731
accuracy_std,0.03909
f1_macro,0.80587
f1_macro_std,0.03998


[34m[1mwandb[0m: Agent Starting Run: b3lg0oto with config:
[34m[1mwandb[0m: 	learning_rate: 0.008197120293719072
[34m[1mwandb[0m: 	max_depth: 6
[34m[1mwandb[0m: 	min_child_weight: 3
[34m[1mwandb[0m: 	n_estimators: 471
[34m[1mwandb[0m: 	random_state: 42
[34m[1mwandb[0m: 	reg_lambda: 0.2404745471506243
[34m[1mwandb[0m: 	subsample: 0.8750482603977314


0,1
accuracy,▁
accuracy_std,▁
f1_macro,▁
f1_macro_std,▁

0,1
accuracy,0.80677
accuracy_std,0.0468
f1_macro,0.79149
f1_macro_std,0.05384


[34m[1mwandb[0m: Agent Starting Run: itl3kizj with config:
[34m[1mwandb[0m: 	learning_rate: 0.0971233997861034
[34m[1mwandb[0m: 	max_depth: 4
[34m[1mwandb[0m: 	min_child_weight: 5
[34m[1mwandb[0m: 	n_estimators: 478
[34m[1mwandb[0m: 	random_state: 42
[34m[1mwandb[0m: 	reg_lambda: 1.187004446147595
[34m[1mwandb[0m: 	subsample: 0.7698499456753943


0,1
accuracy,▁
accuracy_std,▁
f1_macro,▁
f1_macro_std,▁

0,1
accuracy,0.81554
accuracy_std,0.03343
f1_macro,0.80343
f1_macro_std,0.03509


[34m[1mwandb[0m: Agent Starting Run: w4juistf with config:
[34m[1mwandb[0m: 	learning_rate: 0.06860036760191021
[34m[1mwandb[0m: 	max_depth: 3
[34m[1mwandb[0m: 	min_child_weight: 5
[34m[1mwandb[0m: 	n_estimators: 363
[34m[1mwandb[0m: 	random_state: 42
[34m[1mwandb[0m: 	reg_lambda: 1.8351968143525192
[34m[1mwandb[0m: 	subsample: 0.72028470097465


0,1
accuracy,▁
accuracy_std,▁
f1_macro,▁
f1_macro_std,▁

0,1
accuracy,0.81379
accuracy_std,0.04102
f1_macro,0.80039
f1_macro_std,0.04393


Completed hyperparameter tuning.


[34m[1mwandb[0m: Sorting runs by -summary_metrics.f1_macro


Trained model with best hyperparameters.
Logged explanation data to data/dataset_config/titanic_config.json
Explanation process completed.


# Zero-Shot Baseline

In [3]:
# Zero-shot baseline: the LLM predicts the label directly from the feature
# values of each example, submitted as a batch inference job.
# NOTE(review): the masked results recorded in this notebook contain rows with
# finish_reason MAX_TOKENS and no prediction -- confirm whether max_tokens=512
# is large enough for this model before comparing accuracies.
model = Model(name="gemini-2.5-flash", temperature=0.0, max_tokens=512)
baseline = ZeroShotBaseline(
    dataset=titanic_dataset,
    model=model,
    prompt_gen_fn=zero_shot_prompt_generator,
)
baseline.create_batch_prompts()

# Preview the first batch request (prompt text truncated to 800 characters).
# Slicing keeps this a no-op when no batches were generated.
for batch in baseline.batches[:1]:
    print("KEY:", batch["key"])
    print("PROMPT:\n", batch["request"]["contents"][0]["parts"][0]["text"][:800], "\n---\n")

# Persist the prompts locally, upload them to GCS, then run the batch job and
# fetch its outputs.
baseline.save_batches_as_jsonl()
print("Saved batch file to:", baseline.output_file)
baseline.upload_batches_to_gcs()
print("GCS URI:", baseline.gcp_uri)
baseline.submit_batch_inference_job()
baseline.download_job_outputs_from_gcs()

[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).
KEY: baseline_unmasked_batch-0
PROMPT:
 
            You are a classifier for the tabular dataset 'titanic'.
            Each example has features and a target label called 'Survived'.
            Given the following feature values for one example, predict the label.
            Return EXACTLY one of the following labels (just the value, no extra words):
            1, 0

            Here are the feature values:
            Pclass: 1.0
Sex: 1.0
Age: 24.0
SibSp: 0.0
Parch: 0.0
Fare: 69.3
Embarked: 1.0

            Question: What is the predicted value of 'Survived' for this example?

            Note: Answer with exactly one of the allowed label values, nothing else.
     
---

Saved batch file to: data/batches/titanic_baseline_batches.jsonl
File data/batches/titanic_baseline_batches.jsonl uploaded to batch_inputs/gemini/titanic_baseline_batches.jsonl
GCS URI: gs://6998final-bucket/batch_inputs/gemini/titanic_baseline_batches.jsonl

In [8]:
# Parse the baseline batch predictions into two result tables (unmasked and
# masked prompts). The path is derived from the dataset name, like the other
# batch-output paths in this notebook.
results_path = f"data/batch_outputs/{titanic_dataset.name}_baseline_predictions.jsonl"
unmasked_df, masked_df = parse_baseline_llm_results(
    results_jsonl_path=results_path,
    config_file_path="data/dataset_config/titanic_config.json",
)

# Show at most 40 rows per table so the notebook output stays readable.
with pd.option_context("display.max_rows", 40, "display.width", None):
    print("=== Unmasked Results ===")
    display(unmasked_df)

    print("\n=== Masked Results ===")
    display(masked_df)


# Get summary statistics
summary = summarize_baseline_results(unmasked_df, masked_df)
print("\n=== Summary ===")
for name, stats in summary.items():
    print(f"\n{name.upper()}:")
    for k, v in stats.items():
        # Floats (e.g. accuracies) are shown with 3 decimals; other values as-is.
        print(f"  {k}: {v:.3f}" if isinstance(v, float) else f"  {k}: {v}")

=== Unmasked Results ===


Unnamed: 0,batch_id,test_idx,prediction,ground_truth,correct,finish_reason,completed,raw_output
0,0,641,1.0,1,True,STOP,1,1
1,1,496,1.0,1,True,STOP,1,1
2,2,262,1.0,0,False,STOP,1,1
3,3,311,1.0,1,True,STOP,1,1
4,4,551,0.0,0,True,STOP,1,0
...,...,...,...,...,...,...,...,...
138,138,362,0.0,0,True,STOP,1,0
139,139,56,1.0,0,False,STOP,1,1
140,140,137,1.0,1,True,STOP,1,1
141,141,651,1.0,1,True,STOP,1,1



=== Masked Results ===


Unnamed: 0,batch_id,test_idx,prediction,ground_truth,correct,finish_reason,completed,raw_output
0,0,641,1.0,1,True,STOP,1,1
1,1,496,,1,,MAX_TOKENS,0,
2,2,262,0.0,0,True,STOP,1,0
3,3,311,,1,,MAX_TOKENS,0,
4,4,551,0.0,0,True,STOP,1,0
...,...,...,...,...,...,...,...,...
138,138,362,0.0,0,True,STOP,1,0
139,139,56,,0,,MAX_TOKENS,0,
140,140,137,0.0,1,False,STOP,1,0
141,141,651,,1,,MAX_TOKENS,0,



=== Summary ===

UNMASKED:
  total: 143
  completed: 141
  correct: 102
  accuracy: 0.713
  accuracy_of_completed: 0.723

MASKED:
  total: 143
  completed: 122
  correct: 86
  accuracy: 0.601
  accuracy_of_completed: 0.705


# Reason Generation

In [3]:
# Model that writes the natural-language reasoning for the selected examples.
reasoning_model = Model(name="deepseek-ai/DeepSeek-R1", temperature=0.6, max_tokens=4096)

# Generator that pairs the dataset with the reasoning model and prompt template.
rg = ReasonGenerator(
    dataset=titanic_dataset, model=reasoning_model, prompt_gen_fn=reasoning_generator_prompt
)

# Build the batch prompts; the recorded run reports choosing 3 diverse
# examples after picking k=3 clusters by silhouette score.
rg.create_batch_prompts()

Found best number of clusters: k=3 with silhouette score: 0.2747174239109675
Chosen 3 diverse examples.
[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).


In [5]:
# Write the reasoning batch prompts built above to a JSONL file. The output
# location is decided inside ReasonGenerator (the recorded upload log names
# titanic_reasoning_batches.jsonl -- presumably this file; confirm).
rg.save_batches_as_jsonl()

In [6]:
# Submit the reasoning batch job. Per the recorded output, this call uploads
# the batch file, reports the job status until it is COMPLETED, and then
# downloads titanic_reasoning_predictions.jsonl, so it can take a while.
rg.submit_batches()

Uploading file titanic_reasoning_batches.jsonl: 100%|██████████| 6.40k/6.40k [00:00<00:00, 10.2kB/s]


Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.IN_PROGRESS
Current Status: BatchJobStatus.COMPLETED
Batch completed successfully.


Downloading file titanic_reasoning_predictions.jsonl: 100%|██████████| 20.4k/20.4k [00:00<00:00, 5.88MB/s]


In [None]:
# Load the raw reasoning-model batch output: one JSON object per non-empty line.
results_jsonl_path = f"data/batch_outputs/{titanic_dataset.name}_reasoning_predictions.jsonl"
# Read explicitly as UTF-8 so decoding does not depend on the platform default.
with open(results_jsonl_path, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline): json.loads("") would raise.
    results = [json.loads(line) for line in f if line.strip()]

In [None]:
# Bare last expression so the notebook displays the parsed result. In the
# recorded output this is a dict mapping an integer key to the reasoning text
# (the keys look like example indices -- TODO confirm against the parser).
parse_reasoning_llm_results(results_jsonl_path)

{862: "The model correctly predicts survival (1) for this passenger, matching the ground truth label (1.0). The prediction is primarily driven by two dominant features: **Sex** (SHAP +2.37) and **Pclass** (SHAP +1.97). The high positive SHAP value for Sex=1.0 (likely indicating female, as females had higher survival rates) strongly increases survival probability, consistent with the dataset's top feature importance (0.44). Similarly, Pclass=1.0 (first class) contributes significantly to survival due to prioritized evacuation. Supporting features like **Fare** (SHAP +0.21, moderate cost aligning with first-class status) and **SibSp=0.0** (SHAP +0.10, no siblings/spouses competing for resources) provide additional positive contributions. Although **Age=48.0** (SHAP +0.05) slightly favors survival (possibly due to adulthood and evacuation priority), and **Parch=0.0** (SHAP -0.02, no children/parents) has a negligible negative effect, their impacts are dwarfed by the advantages of Sex and 

# LLM as Judge

In [3]:
# Configuration of the LLM that scores the generated reasoning.
objective_judge_model = Model(name="claude-haiku-4-5", temperature=0.6, max_tokens=4096)

# Parse the reasoning produced in the previous section so it can be handed
# to the judge.
results_jsonl_path = f"data/batch_outputs/{titanic_dataset.name}_reasoning_predictions.jsonl"
reasoning = parse_reasoning_llm_results(results_jsonl_path)

In [4]:
# Judge that pairs the dataset with the judge model and its prompt template.
judge = ObjectiveJudge(
    dataset=titanic_dataset, model=objective_judge_model, prompt_gen_fn=objective_judge_prompt_generator
)

# Build one judge request per reasoning entry parsed above
# (the recorded judge run reports 3 results for the 3 reasoning entries).
judge.create_batch_prompts(reasoning=reasoning)

[Titanic] Dropped 179 rows due to NaNs (kept 712 rows).


In [5]:
# Sanity check: print the prompt text of the first judge request (the content
# of its first message) to verify the features, SHAP values and reasoning
# were filled into the template.
print(judge.batches[0]["params"]["messages"][0]["content"])


ROLE:
You are an expert, objective judge for the tabular dataset 'titanic'.
Your role is to assess the faithfulness and quality of the model's reasoning against the provided data.
Each example has features and a target label called 'Survived'.

--- INPUT DATA ---
Here are the **Feature Values** for this specific example:
Survived: 1.0
Pclass: 1.0
Sex: 1.0
Age: 48.0
SibSp: 0.0
Parch: 0.0
Fare: 25.9292
Embarked: 0.0

Here are the **SHAP Values** (the source of truth for feature contribution):
Pclass: 1.9708366
Sex: 2.3702703
Age: 0.051405907
SibSp: 0.10275364
Parch: -0.019484863
Fare: 0.21362345
Embarked: 0.028390752

Here are the overall **Feature Importances** (Global context):
Pclass: 0.21643278002738953
Sex: 0.4413147568702698
Age: 0.07450001686811447
SibSp: 0.07695005089044571
Parch: 0.054212767630815506
Fare: 0.06814616173505783
Embarked: 0.0684434026479721

The model predicted: 1
The ground truth label is: 1.0

Here is the **Model's Reasoning** for its prediction:
The model corre

In [6]:
# Submit the judge batch. Per the recorded output, this call waits until the
# batch has finished processing and saves the evaluations to
# data/batch_outputs/titanic_objective_judge_evaluations.jsonl.
judge.submit_batch()

Submitted batch with id: msgbatch_01A6bvfAv4UFBPjSTjaRWWRq
Batch msgbatch_01A6bvfAv4UFBPjSTjaRWWRq is still processing...
Batch msgbatch_01A6bvfAv4UFBPjSTjaRWWRq has completed processing.
Batch result types: {'succeeded': 3, 'errored': 0, 'expired': 0}
Saved evaluations to data/batch_outputs/titanic_objective_judge_evaluations.jsonl


In [4]:
# Parse the judge evaluations saved by the batch run. The bare last
# expression displays the result; in the recorded output it is a dict of
# per-example scores with the keys 'faithfulness', 'consistency' and 'coherence'.
# `parse_objective_judge_results` is imported with the other
# scripts.postprocess helpers in the notebook's import cell.
results_jsonl_path = f"data/batch_outputs/{titanic_dataset.name}_objective_judge_evaluations.jsonl"

parse_objective_judge_results(results_jsonl_path)

{862: {'faithfulness': 4.75, 'consistency': 4.75, 'coherence': 4.75},
 147: {'faithfulness': 4.75, 'consistency': 4.75, 'coherence': 4.75},
 302: {'faithfulness': 4.75, 'consistency': 4.75, 'coherence': 4.75}}

# CoT-based Classification