# Pipeline For CoT Validation

This notebook will hold the full CoT-validation pipeline :) (work in progress, unfinished).

In [None]:
import os
import yaml
import torch
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.reasoning.formatting import apply_instruct_format
from src.reasoning.inference import run_inference
from src.reasoning.certainty_calc import capture_confidence_tokens
from src.utils.memory_management import check_memory
from src.utils.data_handler import batch_generate, tokens_generate

## Load Config, Dataset & Model

In [None]:
# Number of (shuffled) training rows pushed through the pipeline.
# TODO: small slice for development only -- raise or remove for real runs.
N_SAMPLES = 5

# Load the project configuration (safe_load: plain YAML data only).
with open("config/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Pick which model / dataset entries of the config this run uses.
model_str = 'deepseek_r1_qwendistill_14'
ds_str = 'big_bench'
ds_config = config['datasets'][ds_str]
model_path = config['models'][model_str]['path']

# Decide on device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
check_memory(device=device)

# Load the dataset, shuffle it with the configured seed, and keep the first
# N_SAMPLES rows of the 'train' split as a DataFrame.
dataset = load_dataset(ds_config['source'], ds_config['subset']).shuffle(config['seed'])
df = pd.DataFrame(dataset['train'][:N_SAMPLES])

# NOTE(review): no torch_dtype is passed, so depending on the transformers
# version the weights may be loaded in float32; for a 14B model that may not
# fit in GPU memory. Consider torch_dtype="auto" -- confirm before changing,
# as it alters the numerics of the captured probabilities.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

## Set Up Prompts

In [None]:
# Wrap every raw input in the model's instruct template, with thinking and
# answer formatting enabled, and keep the result alongside the raw data.
formatted_prompts = apply_instruct_format(
    inputs=df[ds_config['input_column']].tolist(),
    config=config,
    model=model_str,
    is_thinking=True,
    is_answer_format=True
)
df['formatted_prompt'] = formatted_prompts

# Batch + tokenize the *formatted* prompts. Passing the raw input column here
# would silently discard the instruct formatting computed above.
# NOTE(review): assumes batch_generate reads the prompts from the column name
# it is given -- confirm against src.utils.data_handler.
batches = batch_generate(df, 'formatted_prompt', ds_config['target_column'], batch_size=2)
tokenized_data = tokens_generate(batches, tokenizer, device=device)

## Run Inference

In [None]:
# Generate responses for every tokenized batch, with time tracking enabled.
# NOTE(review): 200 new tokens is likely too few for a thinking model to finish
# its reasoning and emit a final answer -- revisit for real runs.
results = run_inference(
    model,
    tokenized_data,
    tokenizer,
    max_new_tokens=200,
    time_tracking=True,
)

For each generated output, capture the chain-of-thought token probabilities:

In [None]:
# Gold answer used for every example for now.
# TODO: read it from output_dict["target"] once that field is confirmed to be a
# single-token label.
CORRECT_ANSWER_STR = "Sam"
CONFIDENCE_THRESHOLD = 0.90  # probability the answer token must reach
CONFIDENCE_TOPK = 1000

cot_analysis = []
for output_dict in results:
    # For now: only the first sample of each batch is analysed; the other
    # samples in the batch are skipped (TODO: loop over them as well).
    gen_text = output_dict['response']  # This is the decoded string.
    # Need the token IDs, so re-encode the decoded text. add_special_tokens=False
    # stops the tokenizer from inserting a BOS/special token at the start of
    # what is meant to be a continuation of the prompt.
    # NOTE(review): assumes 'response' holds only the generated text (not the
    # prompt echoed back) -- confirm against run_inference.
    gen_token_ids = tokenizer.encode(
        gen_text, add_special_tokens=False, return_tensors='pt'
    ).to(device)

    # 'tokenized_prompts' is already on device; per the original note it is a
    # dict with input_ids, attention_mask, etc.
    prompts = output_dict["tokenized_prompts"]
    input_ids = prompts["input_ids"]  # shape [batch_size, seq_len]
    attention_mask = prompts.get("attention_mask")

    # Select sample 0 as a [1, seq_len] tensor. With batch_size > 1 shorter
    # prompts are padded, and capture_confidence_tokens receives no mask, so
    # drop the pad positions here using the attention mask.
    first_prompt = input_ids[0]
    if attention_mask is not None:
        first_prompt = first_prompt[attention_mask[0].bool()]
    single_input_ids = first_prompt.unsqueeze(0)
    single_gen_tokens = gen_token_ids[0].unsqueeze(0)

    chain_results = capture_confidence_tokens(
        model=model,
        tokenizer=tokenizer,
        input_ids=single_input_ids,
        generated_tokens=single_gen_tokens,
        correct_answer_str=CORRECT_ANSWER_STR,
        threshold=CONFIDENCE_THRESHOLD,
        topk=CONFIDENCE_TOPK,
        device=device,
    )
    cot_analysis.append(chain_results)

## Examine the per-example `chain_results` (collected in `cot_analysis`)

In [None]:
# Print a short summary of the confidence analysis for every example.
for idx, summary in enumerate(cot_analysis):
    final = summary['final_answer']
    print(f"\nExample {idx} final answer: {final}")
    step = summary['step_90_confidence']
    print(f"Step at which prob >= 90%: {step}")
    # summary['step_details'] is a list with token-by-token info