# XAI-Guided CoT Pipeline: Tutorial


- XAI-Guided-CoT is an automated pipeline for Chain-of-Thought (CoT) prompting. CoT prompting comes in two main types: zero-shot, in which no examples are given, and few-shot, in which a few examples are given. We focus on the second type. In few-shot CoT, every example includes some data or information, the expected output, and a reasoning. The reasoning explains how to arrive at the expected output from the information.

- In the context of a tabular binary classification task, the data is a row of the dataset and the expected output is the class label. Writing this reasoning by hand is very time-consuming, and at times the patterns are so complex that a human cannot articulate them in words. Therefore, we design a pipeline in which an LLM generates this reasoning automatically from explainability attributes such as feature importances and SHAP values.

- The figure below depicts the end-to-end generative-AI workflow.
<div style="text-align: center;">
    <img src="images/genai_workflow.png" alt="Alt text" width="500">
</div>

- In the workflow, a tree-based explainable model is trained and tuned. All the data splits and predictions are recorded for later evaluation. After training and tuning, the explainability attributes are computed: feature importances and SHAP values. The SHAP values CSV file is clustered to obtain diverse decision-making examples. Reasoning is generated for these diverse examples and is used as the reasoning component in Chain-of-Thought (CoT) prompting.
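
To make the structure of a few-shot CoT example concrete, here is a minimal sketch that assembles a prompt from examples carrying a data row, a reasoning, and an expected label. The helper name `build_few_shot_cot_prompt`, the prompt wording, and the sample row are illustrative assumptions; they are not the prompt format used inside `scripts/`.

```python
# Hypothetical sketch: helper name, prompt layout, and sample data are assumptions.
def build_few_shot_cot_prompt(examples, test_row):
    """Build a few-shot CoT prompt from (row, reasoning, label) examples."""
    parts = []
    for i, ex in enumerate(examples, start=1):
        features = ", ".join(f"{k}={v}" for k, v in ex["row"].items())
        parts.append(
            f"Example {i}\n"
            f"Data: {features}\n"
            f"Reasoning: {ex['reasoning']}\n"
            f"Label: {ex['label']}"
        )
    # The test row is appended last and the reasoning is left open for the LLM.
    query = ", ".join(f"{k}={v}" for k, v in test_row.items())
    parts.append(f"Now classify\nData: {query}\nReasoning:")
    return "\n\n".join(parts)


examples = [
    {
        "row": {"Sex": "male", "Pclass": 3},
        "reasoning": "Male gender and a third-class ticket both lower the survival odds.",
        "label": "Did not survive",
    },
]
prompt = build_few_shot_cot_prompt(examples, {"Sex": "female", "Pclass": 1})
print(prompt)
```

In the pipeline, the `reasoning` field is the part that is generated automatically instead of being written by hand.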

<b><u>Note:</u></b> The `ObjectiveJudge` component is optional. However, `ExplanableModel`, `ReasonGenerator` and `ICLClassifier` are mandatory components.

Let us understand how the pipeline and its individual components can be used!

## 1. Environmental Setup

### 1.1 Create a virtual environment

First, create a virtual environment and install the required modules.

<b><u>Note:</u></b> Run the two commented-out commands in a terminal to create and activate the virtual environment, then select it as the kernel for the Jupyter notebook.

In [30]:
!python3 --version

Python 3.13.7


In [1]:
# !python3 -m venv venv
# !source venv/bin/activate
!pip3 install -r requirements.txt

Collecting git+https://github.com/toon-format/toon-python.git (from -r requirements.txt (line 16))
  Cloning https://github.com/toon-format/toon-python.git to /private/var/folders/29/jhvlmlq53ln75t56n34jqkrc0000gn/T/pip-req-build-a6jv368b
  Running command git clone --filter=blob:none --quiet https://github.com/toon-format/toon-python.git /private/var/folders/29/jhvlmlq53ln75t56n34jqkrc0000gn/T/pip-req-build-a6jv368b
  Resolved https://github.com/toon-format/toon-python.git to commit 6b26984a01279defdcf40ab9d09bc418fe8133ac
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [2]:
!pip3 install git+https://github.com/toon-format/toon-python.git

Collecting git+https://github.com/toon-format/toon-python.git
  Cloning https://github.com/toon-format/toon-python.git to /private/var/folders/29/jhvlmlq53ln75t56n34jqkrc0000gn/T/pip-req-build-yjz8xmk_
  Running command git clone --filter=blob:none --quiet https://github.com/toon-format/toon-python.git /private/var/folders/29/jhvlmlq53ln75t56n34jqkrc0000gn/T/pip-req-build-yjz8xmk_
  Resolved https://github.com/toon-format/toon-python.git to commit 6b26984a01279defdcf40ab9d09bc418fe8133ac
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


### 1.2 Set up the directory structure for the data folder

In [3]:
!chmod +x setup.sh
!./setup.sh

Setup root: /Users/gauravpendharkar/xai-guided-cot
Created directories:
  - /Users/gauravpendharkar/xai-guided-cot/data
  - /Users/gauravpendharkar/xai-guided-cot/data/datasets
  - /Users/gauravpendharkar/xai-guided-cot/data/dataset_config
  - /Users/gauravpendharkar/xai-guided-cot/data/shap_values
  - /Users/gauravpendharkar/xai-guided-cot/data/batch_outputs
  - /Users/gauravpendharkar/xai-guided-cot/data/tune_config
  - /Users/gauravpendharkar/xai-guided-cot/data/batches


This creates the directories expected by the pipeline.

### 1.3 Getting access to the LLMs

- Now, we use three model providers in the pipeline as follows:

1. Reason Generation: Together AI (https://www.together.ai/)
2. Objective Judge: Anthropic API (https://platform.claude.com/)
3. Classification: Google Vertex AI API (https://cloud.google.com/)

- Go to the Together AI and Anthropic websites, create an account on each, and get an API key.
- For Google Cloud, we authenticate with the `gcloud` CLI instead of an API key, so install the Google Cloud SDK first.

In [None]:
!gcloud auth application-default login

- Your default browser opens and asks you to log in to your Google Account. After logging in, you are automatically redirected to https://docs.cloud.google.com/sdk/auth_success.

- For the other two providers, copy your API keys and paste them into the `.env` file (created for you by `setup.sh`) under the environment variables `TOGETHER_API_KEY` and `CLAUDE_API_KEY`.

### 1.4 Log in to wandb

Weights and Biases is used for the hyperparameter tuning of the tree-based explainable model.

In [3]:
import wandb
wandb.login()

wandb: Currently logged in as: mitugaurav15 (gauravpendharkar) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

You can also use the API key approach. To get the API key, go to https://wandb.ai/site and sign up or log in with your preferred method. After creating your account, you are redirected to a landing page to "track your first run". You can copy your API key from the yellow box.

After getting your API key, paste it into the `.env` file under the `WANDB_API_KEY` variable. This repository also needs `WANDB_PROJECT_NAME`, which is already included in the `.env` file. Note that `WANDB_PROJECT_NAME` is not a secret; it is kept in the `.env` file because it defines the wandb environment.

### 1.5 Setting up GCP

(Ensure you are authenticated with GCP)

For the code in this repository to work as expected, you need to set up GCP as follows:

1. **Install Google Cloud SDK (gcloud CLI)**
   - If you haven't already, install the Google Cloud SDK from https://cloud.google.com/sdk/docs/install
   - Verify installation by running: `gcloud --version`

2. **Create a GCP Project**
   - Go to https://console.cloud.google.com/
   - Click on the project dropdown at the top and select "New Project"
   - Enter a project name and note the Project ID (you'll need this later)
   - Click "Create"

3. **Enable Billing**
   - Ensure billing is enabled on your account (required for Vertex AI and Cloud Storage usage)
   - Go to Billing in the GCP Console and link a billing account to your project

4. **Enable Required APIs**
   - Enable the following APIs in your project:
     - **Vertex AI API**: Required for batch inference jobs with Gemini models
     - **Cloud Storage API**: Required for storing batch input files
   - You can enable them via:
     - GCP Console: Navigate to "APIs & Services" > "Library" and search for each API
     - Or via command line:
       ```bash
       gcloud services enable aiplatform.googleapis.com
       gcloud services enable storage-component.googleapis.com
       ```

5. **Create a GCS Bucket**
   - Go to Cloud Storage in the GCP Console (https://console.cloud.google.com/storage)
   - Click "Create Bucket"
   - Choose a globally unique bucket name
   - Select a location/region (e.g., `us-east4`, `us-central1`) - note this location as you'll need it
   - Choose storage class and access control settings (defaults are usually fine)
   - Click "Create"

6. **Set Environment Variables**
   - After setup, note the following values:
     - **Project ID**: Found in the GCP Console project dropdown
     - **Bucket Name**: The name you gave your GCS bucket
     - **Location**: The region/location where your bucket was created (e.g., `us-east4`)
   - Add these to your `.env` file as `PROJECT_ID`, `BUCKET_NAME`, and `LOCATION`.

<b><u>Note:</u></b> The `LOCATION` of the GCS bucket must match the location to which you submit the batch inference job. This is a constraint of the code in this repository; in principle, the two locations could differ.
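
Before moving on, it can help to confirm that every variable named in the steps above is actually visible to Python. The sketch below is a hypothetical helper, not part of the repository. It assumes the values from `.env` have already been loaded into the process environment, and it treats `WANDB_API_KEY` as required even though `wandb.login()` can make it unnecessary.

```python
import os

# Variable names come from the setup steps above; the helper itself is hypothetical.
REQUIRED_ENV_VARS = [
    "TOGETHER_API_KEY",
    "CLAUDE_API_KEY",
    "WANDB_API_KEY",  # assumption: may be unnecessary if you logged in via wandb.login()
    "WANDB_PROJECT_NAME",
    "PROJECT_ID",
    "BUCKET_NAME",
    "LOCATION",
]


def missing_env_vars(env=None):
    """Return the required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_ENV_VARS if not env.get(name)]


missing = missing_env_vars()
if missing:
    print("Missing in environment:", ", ".join(missing))
else:
    print("All required variables are set.")
```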
     

Now, you have all the necessary setup to run the code in this repository! 

### 1.6 Optional Unit Tests

Optionally, to check that everything works, run all the unit tests in `tests/` with the following command in your terminal:

```bash
pytest -v

```
<b><u>Note:</u></b>
- `pytest` must be executed inside the terminal
- All tests are expected to pass successfully if all the setup is correct.

### 1.7 Ignore warnings

In [None]:
# suppress warnings
import warnings 
warnings.filterwarnings('ignore')

## 2. Different Supported Configurations

This pipeline uses three types of configuration classes to define how the pipeline operates: `Dataset`, `Model`, and `COT`.

### 2.1 Dataset Configuration (`Dataset`)

The `Dataset` configuration defines everything about your input data:

- <b>name</b>: the name of the dataset.
- <b>path</b>: the path to the original dataset CSV file. It is recommended to store all datasets inside `data/datasets`, but any valid path works.
- <b>config_file_path</b>: the path to the dataset configuration file. All the data splits, tree-based model predictions, and feature importances are logged in this file. The file itself need not exist, but its parent directory must.
- <b>shap_vals_path</b>: the path to the CSV file that stores the SHAP values for the training dataset. The file itself need not exist, but its parent directory must.
- <b>preprocess_fn</b>: the preprocessing function for the dataset. It makes the dataset compatible with a tree-based model for training. If your dataset does not need any preprocessing, pass a function that returns the input dataframe unchanged.
- <b>target_col</b>: the target column of the dataset.
- <b>labels</b>: the mapping from class values to label names. It must contain exactly two entries.

<b>Note</b>: 

Pydantic is used to validate the dataset file extension, the file's existence, the presence of a preprocessing function, and that there are exactly two labels (binary classification).

In [None]:
# module used for dataset 
# configuration
from scripts.configs import Dataset
from scripts.preprocess import preprocess_titanic

dataset = Dataset(
    name="titanic", # name of the dataset
    path="data/datasets/titanic.csv", # path to the csv file 
    config_file_path="data/dataset_config/titanic_config.json", # path to the dataset configuration file
    shap_vals_path="data/shap_values/titanic_shap.csv", # path to the SHAP values file
    preprocess_fn=preprocess_titanic, # preprocessing function
    target_col="Survived", # target column
    labels={0: "Did not survive", 1: "Survived"} # labels for the target variable
)
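
The validation rules mentioned in the note above can be pictured with a small stand-in. The sketch below deliberately uses a standard-library `dataclass` instead of Pydantic so that it runs without extra dependencies; `DatasetSketch` and `identity_preprocess` are hypothetical names, and the real checks in `scripts/configs` may differ in detail.

```python
import os
import tempfile
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class DatasetSketch:
    """Stand-in for the Dataset config; mirrors only the checks listed in the note."""
    name: str
    path: str
    preprocess_fn: Callable
    target_col: str
    labels: Dict[int, str]

    def __post_init__(self):
        if not self.path.endswith(".csv"):
            raise ValueError("dataset path must point to a .csv file")
        if not os.path.isfile(self.path):
            raise ValueError(f"dataset file not found: {self.path}")
        if not callable(self.preprocess_fn):
            raise ValueError("preprocess_fn must be callable")
        if len(self.labels) != 2:
            raise ValueError("labels must contain exactly two classes")


def identity_preprocess(df):
    """Preprocessing function for datasets that need no changes."""
    return df


# A temporary empty CSV file stands in for a real dataset.
with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
    ok = DatasetSketch("demo", tmp.name, identity_preprocess, "y", {0: "no", 1: "yes"})
    try:
        DatasetSketch("demo", tmp.name, identity_preprocess, "y", {0: "a", 1: "b", 2: "c"})
    except ValueError as err:
        error_message = str(err)

print(error_message)
```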

### 2.2 Model Configuration (`Model`)

The `Model` configuration specifies which language model to use and how it should generate responses:

- <b>provider</b>: the model provider. It must be `google`, `anthropic`, or `together`.
- <b>name</b>: the model name. This must be the same as the string identifier used by the respective API.
- <b>temperature</b>: the temperature of the model. This controls the determinism of the output. A lower temperature (closer to 0) is preferred for the use-cases in this code repository.
- <b>max_tokens</b>: the maximum number of tokens. This decides how many tokens the chosen model can generate. For Gemini models, it is worth looking at the difference between thinking budget and maximum tokens.

<b>Note</b>: 

The system validates that your provider and model combination is supported (refer: `scripts/constants.py`), and that `temperature >= 0` and `max_tokens > 0`.

In [None]:
# module used for 
# model configuration
from scripts.configs import Model

model = Model(
    provider="anthropic", # name of the model provider
    name="claude-haiku-4-5", # name of the model 
    temperature=0.7, # temperature for the model
    max_tokens=1000 # maximum number of tokens the LLM can generate
)

### 2.3 Chain-of-Thought Configuration (`COT`)

The `COT` configuration controls the reasoning process used by the language model:

- **Examples**: `num_examples_per_agent` - How many example cases each reasoning agent sees
- **Reasoning Steps**: `reasoning` - A dictionary mapping step numbers to reasoning templates or descriptions
- **Thinking Budget**: `thinking_budget` - Maximum number of reasoning steps allowed (prevents infinite loops)

**Example**: You might set `num_examples_per_agent=5` to show 5 examples, define reasoning steps like `reasoning={1: "Analyze features", 2: "Consider patterns"}`, and set `thinking_budget=10` to limit reasoning to 10 steps.

**Note**: The thinking budget must be ≥ 0.

In [None]:
# module used for 
# chain-of-thought configuration
from scripts.configs import COT

cot = COT(
    num_examples_per_agent=5, # number of examples per agent
    reasoning={}, # no reasoning (zero-shot cot)
    thinking_budget=100 # thinking budget   
)

## 3. System Configuration for the tutorial

For this pipeline, the system is defined by seven configurations as follows:

1. <b>a dataset</b>. Let us use the `titanic_small.csv` already available in the repository under `data/datasets`.
2. <b>a tree-based explainable model</b>. Let us choose `XGBClassifier`.
3. <b>a reasoning generation model</b>. Let us use `deepseek-ai/DeepSeek-R1` from the provider `together` with a temperature of 0.6 and a maximum generation limit of 4096 tokens.
4. <b>an objective judge model</b>. Let us use `claude-haiku-4-5` by `anthropic` with a temperature of 0.6 and a maximum generation limit of 4096 tokens.
5. <b>a chain-of-thought model</b>. Let us use `gemini-2.5-flash` by `google` with a temperature of 0.0 and a maximum generation limit of 4096 tokens.
6. <b>a chain-of-thought configuration</b>. Fixed inside `scripts/pipeline.py`.
7. <b>a hyperparameter tuning parameter grid</b>. Let us use the sample grid `xgb.json` provided inside `data/tune_config`.

<b><u>Note:</u></b> Different model providers use different syntax for model inference, so there are a few limitations:

- The explainable model must be a tree-based model from the `SUPPORTED_EXPLAINABLE_MODELS` defined in `scripts/constants.py`.
- The reasoning model can only be `deepseek-ai/DeepSeek-R1` from the Together AI API.
- The objective judge can be either `claude-sonnet-4-5` or `claude-haiku-4-5` from the Anthropic API.
- The chain-of-thought model must be either `gemini-2.5-flash` or `gemini-2.5-pro` from the Google Vertex AI API.
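
The limitations above, together with the `temperature >= 0` and `max_tokens > 0` rules from Section 2.2, amount to an allow-list check. A minimal sketch follows; the dictionary `SUPPORTED_LLMS_SKETCH` and the function `validate_model_config` are hypothetical and only mirror the text, since the actual constants in `scripts/constants.py` may be organised differently.

```python
# Hypothetical allow-list built from the limitations listed above.
SUPPORTED_LLMS_SKETCH = {
    "together": {"deepseek-ai/DeepSeek-R1"},
    "anthropic": {"claude-sonnet-4-5", "claude-haiku-4-5"},
    "google": {"gemini-2.5-flash", "gemini-2.5-pro"},
}


def validate_model_config(provider, name, temperature, max_tokens):
    """Raise ValueError if the configuration violates the rules described above."""
    if name not in SUPPORTED_LLMS_SKETCH.get(provider, set()):
        raise ValueError(f"unsupported provider/model pair: {provider}/{name}")
    if temperature < 0:
        raise ValueError("temperature must be >= 0")
    if max_tokens <= 0:
        raise ValueError("max_tokens must be > 0")


# A valid configuration passes silently.
validate_model_config("google", "gemini-2.5-flash", 0.0, 4096)

# A model paired with the wrong provider is rejected.
try:
    validate_model_config("google", "claude-haiku-4-5", 0.0, 4096)
except ValueError as err:
    mismatch_error = str(err)

print(mismatch_error)
```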

In [8]:
from xgboost import XGBClassifier

# explainable model
explanable_model = XGBClassifier()

# hyperparameter tuning 
# parameter grid
tune_config_file = "data/tune_config/xgb.json"

# dataset config
dataset = Dataset(
    name="titanic", 
    path="data/datasets/titanic_small.csv", 
    config_file_path="data/dataset_config/titanic_small_config.json", 
    shap_vals_path="data/shap_values/titanic_small_shap.csv", 
    preprocess_fn=preprocess_titanic, 
    target_col="Survived", 
    labels={0: "Did not survive", 1: "Survived"}
)

# reasoning generation 
# model config
reasoning_gen_model = Model(
    provider="together",
    name = "deepseek-ai/DeepSeek-R1",
    temperature=0.6,
    max_tokens=4096
)

# objective judge 
# model config
objective_judge_model = Model(
    provider="anthropic",
    name="claude-haiku-4-5",
    temperature=0.6,
    max_tokens=4096
)

# chain-of-thought 
# model config
cot_model = Model(
    provider="google",
    name="gemini-2.5-flash",
    temperature=0.0,
    max_tokens=4096
)

## 4. What is Batch Inference? Why use it over a standard inference approach?

Batch inference is when multiple requests are packaged and submitted together as a single job.

In the context of this pipeline, the aim is to evaluate how XAI attributes influence Chain-of-Thought prompting. The focus is on classification performance, not inference time. <b>Therefore, there is no immediate need for outputs.</b>

There are a few other reasons as well:

- Every test sample must be processed in isolation from the others. If several test samples were put into a single prompt, the reasoning for an earlier sample could serve as context for a later one.

- Every test sample gets its own request and therefore the full context window, which is an asset for Chain-of-Thought prompting.

- Batch inference is cost-effective because model providers can achieve higher throughput by parallelizing the inference across different accelerators.
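
As a rough illustration of "one request per test sample", the sketch below turns each test row into its own JSON line. The field names `request_id` and `prompt`, the `row-<index>` id scheme, and the helper `build_batch_lines` are assumptions chosen for readability; each provider defines its own request schema, which this sketch does not reproduce.

```python
import json


def build_batch_lines(test_rows, prompt_prefix):
    """Return one JSON line per test sample, each with its own request id."""
    lines = []
    for row_index, row in test_rows.items():
        request = {
            "request_id": f"row-{row_index}",  # hypothetical id scheme
            "prompt": f"{prompt_prefix}\nData: {json.dumps(row, sort_keys=True)}",
        }
        lines.append(json.dumps(request))
    return lines


# Two made-up test rows keyed by their row index.
test_rows = {
    7: {"Pclass": 3, "Sex": "male"},
    12: {"Pclass": 1, "Sex": "female"},
}
batch_lines = build_batch_lines(test_rows, "Classify the passenger.")
for line in batch_lines:
    print(line)
```

Because every line is an independent request, no test sample can see the reasoning produced for another one.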

## 5. The Pipeline

In [9]:
# reduce verbosity
# of logs
import os
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_QUIET"] = "true"
os.environ["WANDB_CONSOLE"] = "off" 

To use the pipeline, simply create an object of the `Pipeline` class and pass the system configurations into it.

### Initialize object

In [None]:
# module used for pipeline
from scripts.pipeline import Pipeline

pipeline = Pipeline(
    dataset=dataset, # dataset config
    explanable_model=explanable_model, # explainable model config
    tune_config_file=tune_config_file, # hyperparameter tuning config file path
    reasoning_gen_model=reasoning_gen_model, # reasoning generation model config
    objective_judge_model=objective_judge_model, # objective judge model config
    cot_model=cot_model # chain-of-thought model config
)

`Pipeline.run()` has four boolean arguments that can be used to get additional results:

- <b>baseline</b>: if set to true, computes the zero-shot prompting baseline.
- <b>objective_judge</b>: if set to true, evaluates quality of the output generated by `ReasonGenerator`.
- <b>cot_ablation</b>: if set to true, computes the zero-shot CoT performance.
- <b>masked</b>: if set to true, bypasses the training and tuning of the explainable tree-based model because its performance is independent of the semantics of the dataset metadata.

Under the default settings, the pipeline performs five steps:

1. Trains and tunes the tree-based explainable model (`ExplanableModel`).
2. Extracts the explainability attributes: feature importances and SHAP values.
3. Generates natural language reasoning from the numerical explainability attributes (`ReasonGenerator`).
4. Passes this reasoning to the CoT model as part of the examples, which serve as references for the binary classification task.
5. Computes standard evaluation metrics (`accuracy` and `macro_f1_score`) for the tree-based model and XAI-Guided-CoT.
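
For reference, the two metrics in step 5 can be written out in a few lines of plain Python. This is a generic sketch of accuracy and macro-averaged F1, not the evaluation code of the pipeline, and the labels below are made up.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the ground truth."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    scores = []
    for cls in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
        fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); a class that never appears scores 0.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 1, 0, 1]
print(f"accuracy={accuracy(y_true, y_pred):.3f}, macro_f1={macro_f1(y_true, y_pred):.3f}")
```

Macro averaging gives both classes equal weight, which matters when one class is much rarer than the other.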

### Without LLM-as-a-Judge

In [11]:
pipeline.run()

[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
Create sweep with ID: i6nxqdvb
Sweep URL: https://wandb.ai/gauravpendharkar/xai-guided-cot/sweeps/i6nxqdvb


wandb: Agent Starting Run: a71pcg40 with config:
wandb: 	learning_rate: 0.015021334658267554
wandb: 	max_depth: 6
wandb: 	min_child_weight: 5
wandb: 	n_estimators: 284
wandb: 	random_state: 42
wandb: 	reg_lambda: 0.01986677845807551
wandb: 	subsample: 0.982515323119008


[XAI-MODEL] Completed hyperparameter tuning.
[XAI-MODEL] Trained model with best hyperparameters.
[XAI-MODEL] Logged explanation data to data/dataset_config/titanic_small_config.json
[XAI-MODEL] Explanation process completed.
[DIVERSE EXAMPLES] Found best number of clusters: k=10 with silhouette score: 0.47964845807183637
[DIVERSE EXAMPLES] Chosen 10 diverse examples.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).


Uploading file titanic_reasoning_batches.jsonl: 100%|██████████| 26.4k/26.4k [00:00<00:00, 47.2kB/s]


[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.COMPLETED
[REASON GENERATION] Batch completed successfully.


Downloading file titanic_reasoning_predictions.jsonl: 100%|██████████| 64.2k/64.2k [00:00<00:00, 19.5MB/s]


[REASON GENERATION] Batch outputs downloaded to data/batch_outputs/titanic_reasoning_predictions.jsonl
[PIPELINE] Reasoning generation completed.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[GCS CLIENT] File data/batches/titanic_icl_batches.jsonl uploaded to batch_inputs/gemini/titanic_icl_batches.jsonl
[ICL CLASSIFIER] Submitted Job: projects/54181826632/locations/us-east4/batchPredictionJobs/2588022497999847424
[ICL CLASSIFIER] Output base dir: gs://xai_guided_cot_bucket/batch_outputs/gemini/titanic_cot_1765888259/
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2588022497999847424 state: JobState.JOB_STATE_QUEUED
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2588022497999847424 state: JobState.JOB_STATE_RUNNING
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2588022497999847424 state: JobState.JOB_STATE_RUNNING
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJo

### With LLM-as-a-Judge 

In [12]:
pipeline.run(objective_judge=True)

[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
Create sweep with ID: mwsbp9ts
Sweep URL: https://wandb.ai/gauravpendharkar/xai-guided-cot/sweeps/mwsbp9ts
[XAI-MODEL] Completed hyperparameter tuning.
[XAI-MODEL] Trained model with best hyperparameters.
[XAI-MODEL] Logged explanation data to data/dataset_config/titanic_small_config.json
[XAI-MODEL] Explanation process completed.
[DIVERSE EXAMPLES] Found best number of clusters: k=10 with silhouette score: 0.32333119714604575
[DIVERSE EXAMPLES] Chosen 10 diverse examples.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).


Uploading file titanic_reasoning_batches.jsonl: 100%|██████████| 26.6k/26.6k [00:00<00:00, 49.1kB/s]


[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.COMPLETED
[REASON GENERATION] Batch completed successfully.


Downloading file titanic_reasoning_predictions.jsonl: 100%|██████████| 74.3k/74.3k [00:00<00:00, 16.9MB/s]


[REASON GENERATION] Batch outputs downloaded to data/batch_outputs/titanic_reasoning_predictions.jsonl
[PIPELINE] Reasoning generation completed.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[OBJECTIVE JUDGE] Submitted batch with id: msgbatch_01DuSTJojjmihTjefqjF3UNi
[OBJECTIVE JUDGE] Batch msgbatch_01DuSTJojjmihTjefqjF3UNi is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01DuSTJojjmihTjefqjF3UNi has completed processing.
[OBJECTIVE JUDGE] Batch result types: {'succeeded': 10, 'errored': 0, 'expired': 0}
[OBJECTIVE JUDGE] Saved evaluations to data/batch_outputs/titanic_objective_judge_evaluations.jsonl
[PIPELINE] Objective judge evaluation completed.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[GCS CLIENT] File data/batches/titanic_icl_batches.jsonl uploaded to batch_inputs/gemini/titanic_icl_batches.jsonl
[ICL CLASSIFIER] Submitted Job: projects/54181826632/locations/us-east4/batchPredictionJobs/3946983685558894592
[ICL CLASSIFIER] Output base dir: gs://xai

- In a few cases, batch inference jobs can fail, expire, or be cancelled. It is therefore important to know how to use the individual components. In this pipeline, `ReasonGenerator` and `ICLClassifier` are mandatory components, so an error is raised if the batch job of either one does not complete.

- Because pydantic validates the hyperparameter ranges and the model names, a batch is unlikely to fail because of invalid input. It might still expire or be cancelled due to issues on the model provider's side, but that is very rare.
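
The behaviour described above (wait for the job, raise if it fails, expires, or is cancelled) can be sketched as a small polling loop. The status strings, the `get_status` callable, and `wait_for_batch` itself are hypothetical placeholders; each provider's SDK exposes its own status values and client calls.

```python
import time

# Hypothetical status names; real providers use their own enums.
TERMINAL_FAILURES = {"FAILED", "EXPIRED", "CANCELLED"}


def wait_for_batch(get_status, poll_seconds=30.0, max_polls=1000, sleep=time.sleep):
    """Poll until the job completes; raise if it ends in a failure state."""
    for _ in range(max_polls):
        status = get_status()
        if status == "COMPLETED":
            return status
        if status in TERMINAL_FAILURES:
            raise RuntimeError(f"batch job ended with status {status}")
        sleep(poll_seconds)
    raise TimeoutError("batch job did not finish within the polling limit")


# Simulated provider that reports progress twice and then completes.
statuses = iter(["IN_PROGRESS", "IN_PROGRESS", "COMPLETED"])
result = wait_for_batch(lambda: next(statuses), poll_seconds=0)
print(result)
```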

## 6. Individual Components

While using the individual components, there are three steps:

1. <b>Initialization of the component</b>

This step is just creating an object of the respective class with the correct system configuration described in previous sections.

2. <b>Running the inference</b>

This step includes many sub-steps as follows:

- First, the batches for all test examples are created and structured as a JSONL file, a file in which every line is a JSON document.

- This JSONL file is uploaded to the model provider's cloud storage (e.g. a Google Cloud Storage bucket for Vertex AI batch inference).

- The batch inference job is submitted to the model provider. It might take as long as 24 hours to start executing. However, in the batch inferences run during debugging and testing, the job started running within a couple of seconds.

- We synchronously monitor the job and automatically download the JSONL result file to local storage. This file contains the output for each request in the input JSONL file, whether the request succeeded or failed. However, the order of the outputs does not match that of the input file, so a meaningful `request_id` is pivotal for mapping the results for further evaluation.

3. <b>Parsing the result</b>

This step parses the output text generated by the LLM. The catch is the error handling: the output of the LLM is non-deterministic and may not always follow the requested output format.
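
The two difficulties mentioned in these steps, out-of-order results and outputs that ignore the requested format, can be handled together. The sketch below assumes each result line has a `request_id` and a `text` field and that the model was asked to end with `Label: 0` or `Label: 1`; all three are assumptions for illustration, and the parsers in `scripts/postprocess` may work differently.

```python
import json
import re

# Assumed output convention: the answer contains "Label: 0" or "Label: 1".
LABEL_PATTERN = re.compile(r"label\s*:\s*([01])\b", re.IGNORECASE)


def parse_batch_results(jsonl_lines):
    """Map request_id to a predicted label; collect ids whose output cannot be parsed."""
    predictions, unparsed = {}, []
    for line in jsonl_lines:
        record = json.loads(line)
        match = LABEL_PATTERN.search(record.get("text") or "")
        if match is None:
            unparsed.append(record["request_id"])
        else:
            predictions[record["request_id"]] = int(match.group(1))
    return predictions, unparsed


# Results arrive in a different order than the inputs were submitted.
raw_results = [
    json.dumps({"request_id": "row-12", "text": "Reasoning ... Label: 1"}),
    json.dumps({"request_id": "row-3", "text": "I am not sure about this passenger."}),
    json.dumps({"request_id": "row-7", "text": "Reasoning ... label: 0"}),
]
predictions, unparsed = parse_batch_results(raw_results)
print(predictions)
print(unparsed)
```

Keeping the unparsed ids separate makes it possible to retry or exclude those samples explicitly instead of silently dropping them.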

In [14]:
# print the results
# properly
from pprint import pprint

### 6.1 ReasonGenerator

This component converts numerical explainability attributes into natural language reasoning.

#### 6.1.1 Initialization

In [None]:
# module used for reasoning generation
from scripts.reason_generation import ReasonGenerator
from scripts.prompt_generator import reasoning_prompt_generator

reason_generator = ReasonGenerator(
    dataset=dataset, # dataset config
    model=reasoning_gen_model, # reasoning generation model config
    prompt_gen_fn=reasoning_prompt_generator # reasoning prompt generator function
)

#### 6.1.2 Running the inference

In [15]:
reason_generator.create_batch_prompts() # create batch prompts
reason_generator.save_batches_as_jsonl() # save batches as jsonl (locally)
reason_generator.submit_batches() # submit batches (to the model provider)

[DIVERSE EXAMPLES] Found best number of clusters: k=10 with silhouette score: 0.32333119714604575
[DIVERSE EXAMPLES] Chosen 10 diverse examples.
[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).


Uploading file titanic_reasoning_batches.jsonl: 100%|██████████| 26.6k/26.6k [00:00<00:00, 50.1kB/s]


[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.IN_PROGRESS
[REASON GENERATION] Current Status: BatchJobStatus.COMPLETED
[REASON GENERATION] Batch completed successfully.


Downloading file titanic_reasoning_predictions.jsonl: 100%|██████████| 67.9k/67.9k [00:00<00:00, 7.02MB/s]

[REASON GENERATION] Batch outputs downloaded to data/batch_outputs/titanic_reasoning_predictions.jsonl





As explained above, the SHAP values CSV file is clustered first, and the clustering finds 10 distinct decision-making patterns. The dataset is preprocessed to ensure that the row indices of the diverse examples point to the correct data points. After that, the `titanic_reasoning_batches.jsonl` file is uploaded to Together AI's cloud storage. The batch inference job is submitted and starts processing almost instantly. Finally, the batch completes successfully and the `titanic_reasoning_predictions.jsonl` file is downloaded from Together AI's cloud storage and saved locally to `data/batch_outputs/titanic_reasoning_predictions.jsonl`.
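
The `[DIVERSE EXAMPLES]` log lines show that the number of clusters is chosen by silhouette score. The sketch below illustrates that idea in plain Python on toy 2-D vectors that stand in for per-row SHAP values. The use of k-means, the farthest-point initialisation, and picking the row closest to each centroid as the representative are assumptions made for this illustration; the repository may cluster and select examples differently.

```python
import math


def dist(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def kmeans(points, k, max_iter=100):
    """Lloyd's k-means with deterministic farthest-point initialisation."""
    centroids = [points[0]]
    while len(centroids) < k:
        centroids.append(max(points, key=lambda p: min(dist(p, c) for c in centroids)))
    labels = None
    for _ in range(max_iter):
        new_labels = [min(range(k), key=lambda c: dist(p, centroids[c])) for p in points]
        if new_labels == labels:  # assignments stopped changing
            break
        labels = new_labels
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = tuple(sum(col) / len(members) for col in zip(*members))
    return labels, centroids


def silhouette(points, labels):
    """Mean silhouette coefficient; points in singleton clusters contribute 0."""
    clusters = sorted(set(labels))
    if len(clusters) < 2:
        return -1.0
    total = 0.0
    for i, p in enumerate(points):
        by_cluster = {c: [] for c in clusters}
        for j, q in enumerate(points):
            if j != i:
                by_cluster[labels[j]].append(dist(p, q))
        own = by_cluster[labels[i]]
        if not own:
            continue
        a = sum(own) / len(own)
        b = min(sum(d) / len(d) for c, d in by_cluster.items() if c != labels[i])
        total += (b - a) / max(a, b) if max(a, b) > 0 else 0.0
    return total / len(points)


def pick_diverse_examples(points, k_values):
    """Choose k by silhouette score, then return one representative row per cluster."""
    best_k, best_score, best_fit = None, -2.0, None
    for k in k_values:
        labels, centroids = kmeans(points, k)
        score = silhouette(points, labels)
        if score > best_score:
            best_k, best_score, best_fit = k, score, (labels, centroids)
    labels, centroids = best_fit
    reps = []
    for c in range(best_k):
        members = [i for i, lab in enumerate(labels) if lab == c]
        if members:
            reps.append(min(members, key=lambda i: dist(points[i], centroids[c])))
    return best_k, best_score, sorted(reps)


# Toy data: three well-separated groups of four vectors each.
blob = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
toy_shap = [(x + dx, y + dy) for dx, dy in [(0, 0), (10, 0), (0, 10)] for x, y in blob]
best_k, best_score, representatives = pick_diverse_examples(toy_shap, range(2, 5))
print(best_k, representatives)
```

The returned indices play the role of the row indices that key the reasoning dictionary in the next step.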

#### 6.1.3 Parsing the result

In [None]:
# module used for parsing the results
from scripts.postprocess import parse_reasoning_llm_results

reasoning = parse_reasoning_llm_results(
    results_jsonl_path=reason_generator.destination_file_name
)

In [27]:
pprint(dict(list(reasoning.items())[0:1]))

{104: 'The model correctly predicted that the passenger did not survive (0), '
      'matching the ground truth label (0.0). This outcome is primarily driven '
      'by the strong negative contributions of the two most important '
      "features: Sex (SHAP = -1.15) and Pclass (SHAP = -0.99). The passenger's "
      'male gender (Sex=0.0) and third-class ticket (Pclass=3.0) substantially '
      'decrease survival probability, as these are the highest-impact features '
      'in the model (overall importances of 0.44 and 0.19 respectively). While '
      "the passenger's age (30.5 years) provides a modest positive "
      'contribution (SHAP=0.14), it is outweighed by the dominant negative '
      'factors. Other features like Fare (8.05, SHAP=-0.02) and SibSp/Parch '
      '(both 0.0, SHAP≈0) have negligible impacts. The cumulative effect of '
      'these SHAP values—especially the large negative contributions from Sex '
      'and Pclass—clearly pushes the prediction below the deci

In the reasoning output, each key is the index of the row in the dataset after preprocessing and each value is the natural language reasoning generated by `ReasonGenerator`.

### 6.2 ObjectiveJudge

This component evaluates the natural language reasoning generated by the `ReasonGenerator`. We use the LLM-as-a-Judge approach because there are no manually annotated reasoning outputs against which to compare the generations and compute metrics like ROUGE.

#### 6.2.1 Initialization

In [None]:
# module used for objective judge
from scripts.objective_judge import ObjectiveJudge
from scripts.prompt_generator import objective_judge_prompt_generator

judge = ObjectiveJudge(
    dataset=dataset, # dataset config
    model=objective_judge_model, # objective judge model config
    prompt_gen_fn=objective_judge_prompt_generator # objective judge prompt generator function
)

#### 6.2.2 Running the inference

In [19]:
judge.create_batch_prompts(reasoning=reasoning) # reasoning for the examples (generated by ReasonGenerator)
judge.submit_batch() # submit the batch inference job

[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[OBJECTIVE JUDGE] Submitted batch with id: msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo
[OBJECTIVE JUDGE] Batch msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo is still processing...
[OBJECTIVE JUDGE] Batch msgbatch_01Jf6sMCTJVo3KwKiaSWQ7Jo has completed processing.
[OBJECTIVE JUDGE] Batch result types: {'succeeded': 10, 'errored': 0, 'expired': 0}
[OBJECTIVE JUDGE] Saved evaluations to data/batch_outputs/titanic_objective_judge_evaluations.jsonl


The dataset is preprocessed again so that the reasoning output from `ReasonGenerator` can be mapped to the correct rows in the dataset. The batch inference job is submitted to the Anthropic API and starts processing immediately. All 10 requests succeed, and the evaluations are downloaded from Anthropic's cloud storage to `data/batch_outputs/titanic_objective_judge_evaluations.jsonl` on the local system.
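
The repeated "still processing" lines in the log come from polling the batch until it finishes. The sketch below shows that general pattern only; it is not the `ObjectiveJudge` implementation. The helper name `wait_for_batch`, the status strings, and the polling parameters are all assumptions made for illustration.

```python
import time
from typing import Callable, Iterable


def wait_for_batch(
    get_status: Callable[[], str],
    done_states: Iterable[str],
    interval_s: float = 30.0,
    max_polls: int = 120,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status() until it returns one of done_states, then return it."""
    done = set(done_states)
    for _ in range(max_polls):
        status = get_status()
        if status in done:
            return status
        sleep(interval_s)  # wait before asking again
    raise TimeoutError(f"Batch not finished after {max_polls} polls")


# Simulated status sequence (placeholder values, no API call is made).
fake_statuses = iter(["in_progress", "in_progress", "ended"])
final = wait_for_batch(
    get_status=lambda: next(fake_statuses),
    done_states={"ended"},
    sleep=lambda seconds: None,  # skip real waiting in this demo
)
print(final)
```

Passing the status lookup as a callable keeps the loop independent of any provider, so the same helper could wrap a different batch API.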

#### 6.2.3 Parsing the result

In [None]:
# module used for parsing the results
from scripts.postprocess import parse_objective_judge_results

reasoning_eval = parse_objective_judge_results(
    results_jsonl_path=judge.destination_file_name # path to the jsonl file for objective judge evaluations
)

In [21]:
pprint(reasoning_eval)

{3: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 31: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 41: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 77: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 91: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 104: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 109: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 126: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 172: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75},
 176: {'coherence': 4.75, 'consistency': 4.75, 'faithfulness': 4.75}}


Here, the output is a dictionary in which each key is a row index of the dataset after preprocessing and each value is a dictionary that maps each metric to its score.
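
A dictionary of this shape is easy to summarize. The following is a minimal sketch, assuming the structure shown above; `average_scores`, `rows_below`, and the threshold of 4.0 are hypothetical and not part of the repository.

```python
from statistics import mean

# Scores copied from the output above (first three rows only, for brevity).
reasoning_eval = {
    3: {"coherence": 4.75, "consistency": 4.75, "faithfulness": 4.75},
    31: {"coherence": 4.75, "consistency": 4.75, "faithfulness": 4.75},
    41: {"coherence": 4.75, "consistency": 4.75, "faithfulness": 4.75},
}


def average_scores(evaluations: dict[int, dict[str, float]]) -> dict[str, float]:
    """Average each metric over all evaluated rows."""
    metrics = sorted({m for scores in evaluations.values() for m in scores})
    return {
        metric: mean(
            scores[metric] for scores in evaluations.values() if metric in scores
        )
        for metric in metrics
    }


def rows_below(evaluations: dict[int, dict[str, float]], threshold: float) -> list[int]:
    """Row indices where at least one metric falls below the threshold."""
    return sorted(
        idx
        for idx, scores in evaluations.items()
        if any(value < threshold for value in scores.values())
    )


print(average_scores(reasoning_eval))
print(rows_below(reasoning_eval, threshold=4.0))  # hypothetical cut-off
```

Rows flagged by `rows_below` would be natural candidates for manual review before they are used as CoT examples.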

### 6.3 ICLClassifier

This component predicts a class label for a given test sample by using the supplied reasoning examples. The natural language reasoning used in this component is generated by the `ReasonGenerator`.

#### 6.3.1 Initialization

In [None]:
# module used for ICL classification
from scripts.icl_classification import ICLClassifier
from scripts.prompt_generator import cot_prompt_generator

cot_config = COT(
    num_examples_per_agent=10, # number of examples every agent sees
    reasoning=reasoning, # reasoning for the examples (generated by ReasonGenerator)
    thinking_budget=1000 # maximum number of tokens for internal thinking
)

icl_classifier = ICLClassifier(
    dataset=dataset, # dataset config
    model=cot_model, # chain-of-thought model config
    cot=cot_config, # chain-of-thought configuration
    prompt_gen_fn=cot_prompt_generator # chain-of-thought prompt generator function
)

#### 6.3.2 Running the inference

In [23]:
icl_classifier.create_batch_prompts() # create batch prompts
icl_classifier.save_batches_as_jsonl() # save batches as jsonl (locally)
icl_classifier.upload_batches_to_gcs() # upload batches to gcp bucket
icl_classifier.submit_batch_inference_job() # submit batch inference job (vertex ai)
icl_classifier.download_job_outputs_from_gcs() # download job outputs from gcp bucket

[Titanic] Dropped 39 rows due to NaNs (kept 161 rows).
[GCS CLIENT] File data/batches/titanic_icl_batches.jsonl uploaded to batch_inputs/gemini/titanic_icl_batches.jsonl
[ICL CLASSIFIER] Submitted Job: projects/54181826632/locations/us-east4/batchPredictionJobs/2433774210762407936
[ICL CLASSIFIER] Output base dir: gs://xai_guided_cot_bucket/batch_outputs/gemini/titanic_cot_1765890383/
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2433774210762407936 state: JobState.JOB_STATE_RUNNING
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2433774210762407936 state: JobState.JOB_STATE_RUNNING
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2433774210762407936 state: JobState.JOB_STATE_RUNNING
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/2433774210762407936 state: JobState.JOB_STATE_RUNNING
[ICL CLASSIFIER] projects/54181826632/locations/us-east4/batchPredictionJobs/24337742107

The dataset is preprocessed to map the natural language reasoning to the correct rows in the dataset. The locally saved batch input file `data/batches/titanic_icl_batches.jsonl` is uploaded to the GCP bucket at `batch_inputs/gemini/titanic_icl_batches.jsonl`. The job is submitted to Vertex AI and starts running immediately. After the job succeeds, the output is downloaded automatically from the GCP bucket to the local system at `data/batch_outputs/titanic_cot_predictions.jsonl`.

#### 6.3.3 Parsing the result

In [24]:
# module used for parsing the results
from scripts.postprocess import parse_cot_llm_results

predictions = parse_cot_llm_results(
    results_jsonl_path=icl_classifier.destination_file_name
)

[POSTPROCESS] 0 requests were interrupted due to token limit and are ignored for evaluation.


In [28]:
pprint(dict(list(predictions.items())[-5:]))

{15: 1, 62: 1, 117: 0, 118: 1, 120: 0}


The output is a dictionary in which each key is a row index of the dataset after preprocessing and each value is the predicted class label.
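
Because the predictions are keyed by row index, they can be compared with ground-truth labels keyed the same way. The sketch below assumes that layout; the `accuracy` helper and the `true_labels` values are invented for illustration and are not results from the pipeline.

```python
# Predictions copied from the output above.
predictions = {15: 1, 62: 1, 117: 0, 118: 1, 120: 0}

# Hypothetical ground-truth labels, invented only for this illustration.
true_labels = {15: 1, 62: 0, 117: 0, 118: 1, 120: 0}


def accuracy(predictions: dict[int, int], true_labels: dict[int, int]) -> float:
    """Fraction of predicted rows whose label matches the ground truth."""
    shared = predictions.keys() & true_labels.keys()
    if not shared:
        raise ValueError("No overlapping row indices to compare.")
    correct = sum(predictions[idx] == true_labels[idx] for idx in shared)
    return correct / len(shared)


print(accuracy(predictions, true_labels))
```

Only row indices present in both dictionaries are scored, so requests dropped during postprocessing do not count as errors here.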

This tutorial covers the pipeline components shown in the GenAI workflow. The `ZeroShotBaseline` is used only to establish a baseline performance against which `XAI-Guided-CoT` is compared. If you want to use this component, the experimental results are the right place to look. Functionally, the `ZeroShotBaseline` uses the same LLM configuration as `ICLClassifier`; the only differences are the prompt and the `thinkingBudget` hyperparameter. The two could be combined into a single class but are kept separate for clarity.

## Troubleshooting Guide

This guide addresses common issues you might encounter while using the code in this repository. Issues are organized by category for easier navigation.

### 1. Installation and Dependencies

**Problem: Dependency conflicts or installation errors**

- **Solution**: Always use a virtual environment to isolate your project dependencies:

```bash
  python3 -m venv venv
  source venv/bin/activate  # On Windows: venv\Scripts\activate
  pip install -r requirements.txt
```

  - **Verification**: Run `pytest -v` from the `tests/` directory to verify all dependencies are correctly installed.

**Problem: Missing `toon-python` package**

- **Solution**: Install it separately: `pip install git+https://github.com/toon-format/toon-python.git`

---

### 2. Environment Variables and Authentication

**Problem: Missing or incorrect environment variables**

- **Symptoms**: `KeyError` when accessing environment variables, `None` values, or authentication failures
- **Solution**: 
  - Verify your `.env` file contains all required variables:
    - `PROJECT_ID`, `BUCKET_NAME`, `LOCATION` (GCP)
    - `TOGETHER_API_KEY`, `CLAUDE_API_KEY` (LLM APIs)
    - `WANDB_API_KEY`, `WANDB_PROJECT_NAME` (Weights & Biases)
  - Ensure no extra spaces or quotes around values
  - Restart your Python kernel/terminal after modifying `.env`
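
The checks above can be automated with a short script. This is a minimal sketch that only inspects values already loaded into the process environment; the helper name `find_env_problems` is hypothetical, and the variable names are the ones listed above.

```python
import os

# Variable names taken from the list above.
REQUIRED_ENV_VARS = [
    "PROJECT_ID", "BUCKET_NAME", "LOCATION",
    "TOGETHER_API_KEY", "CLAUDE_API_KEY",
    "WANDB_API_KEY", "WANDB_PROJECT_NAME",
]


def find_env_problems(env=os.environ) -> dict[str, str]:
    """Map each problematic variable name to a short description of the issue."""
    problems = {}
    for name in REQUIRED_ENV_VARS:
        value = env.get(name)
        if value is None or value == "":
            problems[name] = "missing or empty"
        elif value != value.strip():
            problems[name] = "leading or trailing whitespace"
        elif value[0] in "'\"" or value[-1] in "'\"":
            problems[name] = "wrapped in quotes"
    return problems


for name, issue in find_env_problems().items():
    print(f"{name}: {issue}")
```

Running it in the same kernel as the notebook shows what the pipeline actually sees, which may differ from the `.env` file if the kernel was not restarted.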

**Problem: GCP authentication failures**

- **Symptoms**: `google.auth.exceptions.DefaultCredentialsError` or upload failures
- **Solution**:
  - Re-run: `gcloud auth application-default login`
  - Verify billing is enabled on your GCP project
  - Ensure Vertex AI API and Cloud Storage API are enabled
  - Check that `PROJECT_ID` matches your GCP project

**Problem: API key authentication failures**

- **Symptoms**: 401/403 errors from Together AI or Anthropic
- **Solution**:
  - Verify API keys are correct and not expired
  - Check API quotas/limits haven't been exceeded
  - Ensure keys are in `.env` without quotes

---

### 3. Configuration Errors

**Problem: Invalid model/provider combination**

- **Symptoms**: `ValueError: Invalid model name: X for provider: Y`
- **Solution**: 
  - Check `scripts/constants.py` for supported combinations:
    - Reasoning model: `deepseek-ai/DeepSeek-R1` (Together AI only)
    - Objective judge: `claude-sonnet-4-5` or `claude-haiku-4-5` (Anthropic only)
    - CoT model: `gemini-2.5-flash` or `gemini-2.5-pro` (Google only)
  - To use other models, update `VALID_PROVIDERS` and `VALID_MODELS` in `scripts/constants.py`

**Problem: Dataset configuration errors**

- **Symptoms**: `FileNotFoundError`, `TypeError`, or `ValueError` during Dataset initialization
- **Solution**:
  - Verify CSV file exists and path is correct (relative or absolute)
  - Ensure `preprocess_fn` is a callable function (not a string)
  - Verify `labels` dictionary has exactly 2 entries: `{0: "Class 0", 1: "Class 1"}`
  - Check that `target_col` exists in your dataset
  - Refer to `scripts/configs.py` for validation rules
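
Before constructing the dataset config, the same conditions can be checked by hand. The function below is a hypothetical pre-flight check written for this guide, not the validation in `scripts/configs.py`; it reads only the raw CSV header, so a `target_col` created during preprocessing would be reported as missing.

```python
import csv
from pathlib import Path


def check_dataset_inputs(csv_path, preprocess_fn, labels, target_col) -> list[str]:
    """Return human-readable problems; an empty list means every check passed."""
    errors = []
    path = Path(csv_path)
    if not path.is_file():
        errors.append(f"CSV file not found: {path}")
    else:
        # Read only the header row to see which columns exist.
        with path.open(newline="") as handle:
            header = next(csv.reader(handle), [])
        if target_col not in header:
            errors.append(f"target_col {target_col!r} is not a column in {path.name}")
    if not callable(preprocess_fn):
        errors.append("preprocess_fn must be a callable, not a string")
    if not isinstance(labels, dict) or len(labels) != 2:
        errors.append("labels must be a dictionary with exactly 2 entries")
    return errors


# Example call with a path that does not exist (placeholder values).
print(check_dataset_inputs("missing.csv", len, {0: "Class 0", 1: "Class 1"}, "target"))
```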

**Problem: Invalid hyperparameters**

- **Symptoms**: `ValueError: Temperature must be >= 0` or `max_tokens must be > 0`
- **Solution**:
  - Set `temperature >= 0` (typically 0.0-1.0)
  - Set `max_tokens > 0` (check model-specific limits)
  - For Gemini models: remember `maxOutputTokens = thinkingBudget + externalTokens`
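
The Gemini relation above is simple arithmetic, and a small helper makes it harder to misconfigure. This is a sketch under that stated relation; the function name is hypothetical, and `max_output_tokens=1500` is a made-up value.

```python
def external_token_budget(max_output_tokens: int, thinking_budget: int) -> int:
    """Tokens left for the visible answer once the thinking budget is reserved."""
    if max_output_tokens <= 0:
        raise ValueError("max_output_tokens must be > 0")
    if thinking_budget < 0:
        raise ValueError("thinking_budget must be >= 0")
    external = max_output_tokens - thinking_budget
    if external <= 0:
        raise ValueError(
            "thinking_budget leaves no tokens for the answer; "
            "raise max_output_tokens or lower thinking_budget"
        )
    return external


# thinking_budget=1000 matches the COT config used in this tutorial;
# max_output_tokens=1500 is an illustrative value only.
print(external_token_budget(max_output_tokens=1500, thinking_budget=1000))
```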

---

### 4. GCP and Cloud Storage Issues

**Problem: Bucket upload failures**

- **Symptoms**: `AssertionError: ICL Classifier is a mandatory component. Kindly debug the issue`
- **Solution**:
  - Verify bucket name exists and is correct
  - Check GCP permissions (Storage Admin role)
  - Ensure bucket `LOCATION` matches environment variable
  - Verify `gcloud` authentication: `gcloud auth application-default login`

**Problem: Vertex AI batch job submission failures**

- **Symptoms**: `AssertionError: Error submitting batch job to Vertex AI`
- **Solution**:
  - Ensure Vertex AI API is enabled: `gcloud services enable aiplatform.googleapis.com`
  - Verify `PROJECT_ID` and `LOCATION` are correct
  - Ensure `LOCATION` matches your bucket region
  - Check billing is enabled

**Problem: Missing predictions.jsonl file**

- **Symptoms**: `ValueError: No predictions.jsonl file found in GCS location`
- **Solution**:
  - Wait for batch job to complete (check job status in GCP Console)
  - Verify batch job succeeded (not failed/cancelled)
  - Check GCS bucket path matches expected output directory

---

### 5. Data and Preprocessing Issues

**Problem: SHAP values CSV missing 'idx' column**

- **Symptoms**: `KeyError: SHAP values CSV must contain an 'idx' column`
- **Solution**:
  - Re-run the `ExplainableModel` step to regenerate SHAP values
  - Ensure preprocessing function maintains consistent row indices
  - Verify SHAP CSV structure matches expected format
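
A quick structural check of the SHAP CSV can narrow this down. The sketch below assumes only that the file has a header row and an integer `idx` column; the function name, the uniqueness check, and the sample values are hypothetical.

```python
import csv
import io


def check_shap_csv(csv_lines) -> list[str]:
    """Return problems found in the 'idx' column of a SHAP CSV."""
    reader = csv.DictReader(csv_lines)
    if "idx" not in (reader.fieldnames or []):
        return ["missing 'idx' column"]
    problems = []
    seen = set()
    for line_no, row in enumerate(reader, start=2):  # line 1 is the header
        try:
            idx = int(row["idx"])
        except (TypeError, ValueError):
            problems.append(f"line {line_no}: idx {row['idx']!r} is not an integer")
            continue
        if idx in seen:
            problems.append(f"line {line_no}: duplicate idx {idx}")
        seen.add(idx)
    return problems


# Illustrative content only; the SHAP numbers are placeholders.
sample = io.StringIO("idx,Sex,Pclass\n3,-0.4,-0.2\n31,0.1,0.3\n")
print(check_shap_csv(sample))
```

For a real file, pass an open handle instead: `with open(path, newline="") as f: check_shap_csv(f)`.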

**Problem: Empty reasoning outputs**

- **Symptoms**: `AssertionError: Reasoning outputs are required`
- **Solution**:
  - Check raw LLM outputs in `data/batch_outputs/titanic_reasoning_predictions.jsonl`
  - Verify LLM followed expected output format
  - May need to adjust prompt or model temperature
  - Re-run pipeline or re-submit batch using `data/batches/titanic_reasoning_batches.jsonl`

**Problem: Preprocessing errors**

- **Symptoms**: Errors during dataset loading or preprocessing
- **Solution**:
  - Ensure preprocessing function handles missing values correctly
  - Verify all required columns exist in dataset
  - Check data types are compatible with tree-based models
  - Ensure target column has binary values (0 and 1)

---

### 6. Batch Inference Issues

**Problem: Batch jobs failing, expiring, or cancelling**

- **Symptoms**: Job status shows `FAILED`, `EXPIRED`, or `CANCELLED`
- **Solution**:
  - Check GCP Console (for Vertex AI) or provider dashboard for detailed error messages
  - Verify input JSONL format is valid
  - Check for token limit issues (increase `max_tokens`)
  - Ensure API quotas/limits aren't exceeded
  - Retry the batch submission

**Problem: Token limit interruptions**

- **Symptoms**: `[POSTPROCESS] X requests were interrupted due to token limit`
- **Solution**:
  - Increase `max_tokens` in model configuration
  - Reduce `num_examples_per_agent` in COT config
  - Simplify prompts or reduce reasoning complexity
  - For Gemini: adjust `thinkingBudget` relative to `max_tokens`

**Problem: LLM output format mismatches**

- **Symptoms**: Parsing errors, missing predictions, or empty reasoning outputs
- **Solution**:
  - Check raw outputs in JSONL files (`data/batch_outputs/`)
  - Verify LLM follows expected format (check prompt templates)
  - Adjust prompts to be more explicit about output format
  - Consider using lower temperature for more deterministic outputs
  - **Note**: DeepSeek-R1 has been observed to violate output formats more frequently than Gemini/Anthropic models

---

### 7. Wandb and Hyperparameter Tuning Issues

**Problem: Wandb sweep creation failures**

- **Symptoms**: Errors when creating wandb sweep
- **Solution**:
  - Verify `WANDB_API_KEY` and `WANDB_PROJECT_NAME` are set in `.env`
  - Run `wandb.login()` if not already authenticated
  - Check sweep configuration JSON file is valid
  - Ensure wandb project exists or can be created

**Problem: Hyperparameter tuning errors**

- **Symptoms**: Errors during model training in sweep
- **Solution**:
  - Verify hyperparameter ranges in config file are valid
  - Check dataset is properly preprocessed
  - Ensure sufficient training data
  - Review wandb logs for specific error messages

**Problem: Wandb config parsing issues (older versions)**

- **Symptoms**: `best_run.config` returns text instead of dict
- **Solution**: 
  - Update to latest `wandb` version: `pip install --upgrade wandb`
  - If using older version, parse manually: `json.loads(best_run.config)`
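
If the code has to work with both behaviours, the manual parse can be wrapped so it applies only when needed. This is a minimal sketch; `config_as_dict` is a hypothetical helper, and the hyperparameter name in the example is a placeholder.

```python
import json


def config_as_dict(config) -> dict:
    """Return a run config as a dict, parsing it first if it arrived as JSON text."""
    if isinstance(config, str):
        config = json.loads(config)
    return dict(config)


# Both forms yield the same dictionary (placeholder hyperparameter).
print(config_as_dict('{"max_depth": 3}'))
print(config_as_dict({"max_depth": 3}))
```

In the notebook this would be called as `config_as_dict(best_run.config)`.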

---

### 8. General Debugging Tips

**Check intermediate outputs:**
- Inspect JSONL files in `data/batches/` and `data/batch_outputs/`
- Verify batch files are properly formatted
- Check for malformed JSON or missing fields
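
These checks can be scripted. The sketch below scans JSONL content line by line and reports malformed JSON or missing fields; the helper name is hypothetical, and `custom_id` is used only as an example field, since the fields the pipeline actually requires are not listed here.

```python
import json


def find_malformed_lines(lines, required_fields=()) -> list[tuple[int, str]]:
    """Return (line_number, problem) pairs for every bad JSONL line."""
    problems = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((line_no, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((line_no, "not a JSON object"))
            continue
        missing = [field for field in required_fields if field not in record]
        if missing:
            problems.append((line_no, f"missing fields: {missing}"))
    return problems


# Illustrative lines; with a real file, pass the open file handle instead.
sample_lines = ['{"custom_id": "3"}', "not json", "", '{"other": 1}']
for line_no, problem in find_malformed_lines(sample_lines, required_fields=("custom_id",)):
    print(f"line {line_no}: {problem}")
```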

**Verify component initialization:**
- Test individual components separately before running full pipeline
- Use `masked=True` in `pipeline.run()` to skip model training if debugging other components
- Check component attributes after initialization

**Monitor batch job status:**
- Check GCP Console for Vertex AI batch job status
- Monitor Together AI dashboard for batch job progress
- Check Anthropic API for batch processing status

**Common fixes:**
- Restart Python kernel after changing `.env` file
- Re-authenticate with `gcloud auth application-default login`
- Clear cached files and re-run preprocessing
- Verify all directory paths exist before running pipeline
- Run unit tests (`pytest -v`) after customizing code

---

### 9. Getting Additional Help

If you encounter issues not covered here:

1. **Check the code**: Review `scripts/configs.py` and `scripts/constants.py` for validation rules
2. **Run unit tests**: Execute `pytest -v` to verify basic functionality
3. **Inspect logs**: Check console output and error messages for specific details
4. **Review intermediate files**: Examine JSONL files in `data/batches/` and `data/batch_outputs/`
5. **GitHub Issues**: Feel free to raise an issue on GitHub with:
   - Error message and traceback
   - Your configuration (without sensitive keys)
   - Steps to reproduce
   - Relevant log files or intermediate outputs