In [1]:

# ============================================================================
# EXPERIMENT 1: Zero-Shot Chain-of-Thought Baseline
# ============================================================================


In [1]:
# --- CELL 1: Install Dependencies ---
!pip install transformers>=4.35.0 datasets>=2.14.0 accelerate>=0.24.0 torch>=2.0.0 tqdm matplotlib -q

In [2]:
# --- CELL 2: Import Libraries ---
import torch
import re
import os
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from datetime import datetime

from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


In [3]:
# --- CELL 3: Hugging Face Login ---
print("Please log in to Hugging Face...")
login()

Please log in to Hugging Face...


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# --- CELL 4: Configuration ---
class Config:
    # Model
    MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
    
    # Dataset
    USE_SUBSET = False
    TEST_SUBSET_SIZE = 50
    
    # Output
    OUTPUT_DIR = "./small_project/zero_shot_cot"
    
    # Generation
    GENERATION_MAX_NEW_TOKENS = 256
    TEMPERATURE = 0.7
    TOP_P = 0.9

config = Config()
os.makedirs(config.OUTPUT_DIR, exist_ok=True)

print(f"{'='*80}")
print(f"ZERO-SHOT CHAIN-OF-THOUGHT BASELINE")
print(f"{'='*80}")
print(f"Model: {config.MODEL_NAME}")
print(f"Test samples: {config.TEST_SUBSET_SIZE if config.USE_SUBSET else 'Full dataset'}")
print(f"Output directory: {config.OUTPUT_DIR}")
print(f"{'='*80}\n")

ZERO-SHOT CHAIN-OF-THOUGHT BASELINE
Model: meta-llama/Llama-3.2-3B-Instruct
Test samples: Full dataset
Output directory: ./small_project/zero_shot_cot



In [5]:
# --- CELL 5: Helper Functions ---
def extract_answer(text):
    """Extract numerical answer from text with multiple strategies"""
    if not text:
        return None
    
    # Strategy 1: Find #### format
    match = re.search(r'####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', text)
    if match:
        return match.group(1).replace(',', '')
    
    # Strategy 2: Common patterns
    patterns = [
        r'answer is[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
        r'=\s*(-?\d+(?:,\d+)*(?:\.\d+)?)\s*$',
        r'total[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
        r'result[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).replace(',', '')
    
    # Strategy 3: Last number in text
    numbers = re.findall(r'-?\d+(?:,\d+)*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return None

def create_inference_prompt(question):
    """Create zero-shot CoT prompt"""
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step and provide the final answer after ####.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

def load_gsm8k_test():
    """Load GSM8K test set"""
    print("Loading GSM8K test set...")
    dataset = load_dataset("gsm8k", "main")
    
    test_data = []
    for item in tqdm(dataset["test"], desc="Processing test data"):
        answer = extract_answer(item["answer"])
        if answer:
            test_data.append({
                "question": item["question"],
                "answer": answer,
                "full_solution": item["answer"]
            })
    
    print(f"Loaded {len(test_data)} test examples")
    return test_data

In [6]:
# --- CELL 6: Load Data ---
full_test_data = load_gsm8k_test()

if config.USE_SUBSET:
    test_data = full_test_data[:config.TEST_SUBSET_SIZE]
    print(f"\n Using {len(test_data)} test examples (subset mode)")
else:
    test_data = full_test_data
    print(f"\n Using {len(test_data)} test examples (full dataset)")

# Save test data
with open(f"{config.OUTPUT_DIR}/test_data.json", "w") as f:
    json.dump(test_data, f, indent=2)
print(f" Test data saved to {config.OUTPUT_DIR}/test_data.json")

Loading GSM8K test set...


Processing test data:   0%|          | 0/1319 [00:00<?, ?it/s]

Loaded 1319 test examples

 Using 1319 test examples (full dataset)
 Test data saved to ./small_project/zero_shot_cot/test_data.json


In [7]:
# --- CELL 7: Load Model ---
print("\nLoading base model...")
model = AutoModelForCausalLM.from_pretrained(
    config.MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
)

tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"Model loaded: {config.MODEL_NAME}")
print(f"Device: {model.device}")
print(f"dtype: {model.dtype}")

`torch_dtype` is deprecated! Use `dtype` instead!



Loading base model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded: meta-llama/Llama-3.2-3B-Instruct
Device: cuda:0
dtype: torch.bfloat16


In [8]:
# --- CELL 8: Generation Function ---
def generate_answer(model, tokenizer, question):
    """Generate answer using zero-shot chain-of-thought"""
    prompt = create_inference_prompt(question)
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=config.GENERATION_MAX_NEW_TOKENS,
            do_sample=True,
            temperature=config.TEMPERATURE,
            top_p=config.TOP_P,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract only the assistant's response
    if "<|start_header_id|>assistant<|end_header_id|>" in full_text:
        response = full_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
        response = response.replace("<|eot_id|>", "").strip()
    else:
        response = full_text[len(prompt):].strip()
    
    # Extract answer
    predicted_answer = extract_answer(response)
    
    return response, predicted_answer

In [9]:
# --- CELL 9: Evaluation Function ---
def evaluate_zero_shot(model, tokenizer, test_data):
    """Evaluate model with zero-shot chain-of-thought"""
    print("\n" + "="*80)
    print("EVALUATING ZERO-SHOT CHAIN-OF-THOUGHT")
    print("="*80)
    
    correct = 0
    total = 0
    results_log = []
    error_analysis = {
        "no_answer_extracted": 0,
        "wrong_answer": 0,
        "correct": 0,
        "repetitive_output": 0,
    }
    
    model.eval()
    
    for idx, item in enumerate(tqdm(test_data, desc="Evaluating")):
        question = item["question"]
        correct_answer_str = item["answer"]
        
        response, predicted_answer_str = generate_answer(model, tokenizer, question)
        
        # Check for repetitive output
        words = response.split()
        if len(words) > 10:
            unique_ratio = len(set(words)) / len(words)
            if unique_ratio < 0.3:
                error_analysis["repetitive_output"] += 1
        
        # Check correctness
        is_correct = False
        if predicted_answer_str is not None and correct_answer_str is not None:
            try:
                is_correct = abs(float(predicted_answer_str) - float(correct_answer_str)) < 0.01
            except (ValueError, TypeError):
                is_correct = predicted_answer_str.strip() == correct_answer_str.strip()
        
        if is_correct:
            correct += 1
            error_analysis["correct"] += 1
        elif predicted_answer_str is None:
            error_analysis["no_answer_extracted"] += 1
        else:
            error_analysis["wrong_answer"] += 1
        
        total += 1
        
        results_log.append({
            "index": idx,
            "question": question,
            "predicted_answer": predicted_answer_str,
            "correct_answer": correct_answer_str,
            "is_correct": is_correct,
            "full_response": response,
            "response_length": len(response),
            "response_preview": response[:200],
        })
    
    accuracy = correct / total if total > 0 else 0
    
    print(f"\n{'='*80}")
    print(f"EVALUATION RESULTS")
    print(f"{'='*80}")
    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"\nError Breakdown:")
    print(f"  Correct: {error_analysis['correct']}")
    print(f"  Wrong answer: {error_analysis['wrong_answer']}")
    print(f"  No answer extracted: {error_analysis['no_answer_extracted']}")
    print(f"  Repetitive output: {error_analysis['repetitive_output']}")
    print(f"{'='*80}")
    
    return accuracy, results_log, error_analysis


In [10]:
# --- CELL 10: Run Evaluation ---
print(f"\nStart time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
accuracy, results_log, error_analysis = evaluate_zero_shot(model, tokenizer, test_data)
print(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


Start time: 2025-10-07 15:25:57

EVALUATING ZERO-SHOT CHAIN-OF-THOUGHT


Evaluating:   0%|          | 0/1319 [00:00<?, ?it/s]


EVALUATION RESULTS
Accuracy: 0.5989 (790/1319)

Error Breakdown:
  Correct: 790
  Wrong answer: 529
  No answer extracted: 0
  Repetitive output: 10
End time: 2025-10-07 17:17:43


In [12]:
# --- CELL 11: Save Results ---
print("\nSaving results...")

# Main results summary
results_summary = {
    "experiment": "zero_shot_cot",
    "model": config.MODEL_NAME,
    "timestamp": datetime.now().isoformat(),
    "config": {
        "test_size": len(test_data),
        "max_new_tokens": config.GENERATION_MAX_NEW_TOKENS,
        "temperature": config.TEMPERATURE,
        "top_p": config.TOP_P,
    },
    "results": {
        "accuracy": float(accuracy),
        "correct": int(accuracy * len(test_data)),
        "total": len(test_data),
    },
    "error_analysis": error_analysis,
}

with open(f"{config.OUTPUT_DIR}/results_summary.json", "w") as f:
    json.dump(results_summary, f, indent=2)

# Detailed results with all predictions
with open(f"{config.OUTPUT_DIR}/detailed_results.json", "w") as f:
    json.dump(results_log, f, indent=2)

print(f"✓ Results saved to:")
print(f"  - {config.OUTPUT_DIR}/results_summary.json")
print(f"  - {config.OUTPUT_DIR}/detailed_results.json")


Saving results...
✓ Results saved to:
  - ./small_project/zero_shot_cot/results_summary.json
  - ./small_project/zero_shot_cot/detailed_results.json


In [13]:
# --- CELL 12: Sample Predictions ---
print("\n" + "="*80)
print("SAMPLE PREDICTIONS")
print("="*80)

# Show correct predictions
correct_preds = [r for r in results_log if r['is_correct']]
if correct_preds:
    print(f"\nCORRECT PREDICTIONS (showing up to 3):")
    for i, r in enumerate(correct_preds[:3]):
        print(f"\n{i+1}. Question: {r['question'][:100]}...")
        print(f"   Predicted: {r['predicted_answer']}")
        print(f"   Correct: {r['correct_answer']}")
        print(f"   Response: {r['response_preview']}...")
        print("-" * 80)

# Show errors
error_preds = [r for r in results_log if not r['is_correct']]
if error_preds:
    print(f"\nINCORRECT PREDICTIONS (showing up to 3):")
    for i, r in enumerate(error_preds[:3]):
        print(f"\n{i+1}. Question: {r['question'][:100]}...")
        print(f"   Predicted: {r['predicted_answer']}")
        print(f"   Correct: {r['correct_answer']}")
        print(f"   Response: {r['response_preview']}...")
        print("-" * 80)


SAMPLE PREDICTIONS

CORRECT PREDICTIONS (showing up to 3):

1. Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for ...
   Predicted: 18
   Correct: 18
   Response: umber of eggs laid per day that are not eaten or used.
Janet lays 16 eggs per day. She eats 3 eggs for breakfast, so:
16 - 3 = 13 eggs left

Step 2: Subtract the eggs used to bake muffins.
She uses 4 ...
--------------------------------------------------------------------------------

2. Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it ...
   Predicted: 3
   Correct: 3
   Response: 2 bolts
2. Calculate how much white fiber is needed: Half of 2 bolts is 1 bolt
3. Add the amount of blue fiber and white fiber together: 2 + 1 = 3

The final answer is: ###...
--------------------------------------------------------------------------------

3. Question: Josh decides to try flipping a house.  He buys a house for $

In [14]:
# --- CELL 13: Detailed Error Analysis ---
print("\n" + "="*80)
print("DETAILED ERROR ANALYSIS")
print("="*80)

errors = [r for r in results_log if not r['is_correct']]
print(f"\nTotal errors: {len(errors)}")
print(f"Errors with no predicted answer: {sum(1 for e in errors if e['predicted_answer'] is None)}")
print(f"Errors with wrong answer: {sum(1 for e in errors if e['predicted_answer'] is not None)}")

# Check for repetitive outputs
repetitive_errors = [e for e in errors if e['response_length'] > 0 and len(set(e['full_response'].split())) / len(e['full_response'].split()) < 0.3]
print(f"Errors with repetitive output: {len(repetitive_errors)}")

# Analyze response lengths
if results_log:
    response_lengths = [r['response_length'] for r in results_log]
    print(f"\nResponse length statistics:")
    print(f"  Mean: {sum(response_lengths)/len(response_lengths):.1f} chars")
    print(f"  Min: {min(response_lengths)} chars")
    print(f"  Max: {max(response_lengths)} chars")


DETAILED ERROR ANALYSIS

Total errors: 529
Errors with no predicted answer: 0
Errors with wrong answer: 529
Errors with repetitive output: 5

Response length statistics:
  Mean: 586.7 chars
  Min: 122 chars
  Max: 1203 chars


In [None]:
# --- CELL 14: Visualizations ---
print("\nGenerating visualizations...")

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))

# 1. Accuracy pie chart
ax1.pie([error_analysis['correct'], len(test_data) - error_analysis['correct']], 
        labels=['Correct', 'Incorrect'], 
        autopct='%1.1f%%',
        colors=['#2ecc71', '#e74c3c'],
        startangle=90)
ax1.set_title(f'Overall Accuracy: {accuracy:.2%}', fontweight='bold')

# 2. Error breakdown
error_types = ['Wrong\nAnswer', 'No Answer\nExtracted', 'Repetitive\nOutput']
error_counts = [
    error_analysis['wrong_answer'], 
    error_analysis['no_answer_extracted'], 
    error_analysis['repetitive_output']
]
ax2.bar(error_types, error_counts, color=['#e74c3c', '#f39c12', '#9b59b6'])
ax2.set_ylabel('Count', fontweight='bold')
ax2.set_title('Error Type Breakdown', fontweight='bold')
ax2.grid(axis='y', alpha=0.3)

# 3. Response length distribution
response_lengths = [r['response_length'] for r in results_log]
ax3.hist(response_lengths, bins=20, color='#3498db', alpha=0.7, edgecolor='black')
ax3.set_xlabel('Response Length (characters)', fontweight='bold')
ax3.set_ylabel('Frequency', fontweight='bold')
ax3.set_title('Response Length Distribution', fontweight='bold')
ax3.grid(axis='y', alpha=0.3)

# 4. Correct vs Incorrect response lengths
correct_lengths = [r['response_length'] for r in results_log if r['is_correct']]
incorrect_lengths = [r['response_length'] for r in results_log if not r['is_correct']]

if correct_lengths and incorrect_lengths:
    ax4.boxplot([correct_lengths, incorrect_lengths], labels=['Correct', 'Incorrect'])
    ax4.set_ylabel('Response Length (characters)', fontweight='bold')
    ax4.set_title('Response Length: Correct vs Incorrect', fontweight='bold')
    ax4.grid(axis='y', alpha=0.3)
else:
    ax4.text(0.5, 0.5, 'Insufficient data', ha='center', va='center', transform=ax4.transAxes)
    ax4.set_title('Response Length Comparison', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{config.OUTPUT_DIR}/evaluation_results.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"Visualization saved to {config.OUTPUT_DIR}/evaluation_results.png")


In [None]:
# --- CELL 15: Sample Responses Analysis ---
print("\n" + "="*80)
print("SAMPLE FULL RESPONSES")
print("="*80)

# Show one complete correct response
if correct_preds:
    print("\n[CORRECT EXAMPLE]")
    sample = correct_preds[0]
    print(f"Question: {sample['question']}")
    print(f"\nModel Response:")
    print(sample['full_response'])
    print(f"\nExtracted Answer: {sample['predicted_answer']}")
    print(f"Correct Answer: {sample['correct_answer']}")

# Show one complete incorrect response
if error_preds:
    print("\n" + "="*80)
    print("\n[INCORRECT EXAMPLE]")
    sample = error_preds[0]
    print(f"Question: {sample['question']}")
    print(f"\nModel Response:")
    print(sample['full_response'])
    print(f"\nExtracted Answer: {sample['predicted_answer']}")
    print(f"Correct Answer: {sample['correct_answer']}")


In [None]:
# --- CELL 16: Generate Report ---
print("\n" + "="*80)
print("GENERATING SUMMARY REPORT")
print("="*80)

report = f"""
ZERO-SHOT CHAIN-OF-THOUGHT EVALUATION REPORT
{'='*80}

Model: {config.MODEL_NAME}
Test Set Size: {len(test_data)}
Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

OVERALL RESULTS
{'='*80}
Accuracy: {accuracy:.4f} ({int(accuracy * len(test_data))}/{len(test_data)})

ERROR ANALYSIS
{'='*80}
Total Errors: {len(errors)}
  - Wrong Answer: {error_analysis['wrong_answer']} ({error_analysis['wrong_answer']/len(test_data)*100:.1f}%)
  - No Answer Extracted: {error_analysis['no_answer_extracted']} ({error_analysis['no_answer_extracted']/len(test_data)*100:.1f}%)
  - Repetitive Output: {error_analysis['repetitive_output']} ({error_analysis['repetitive_output']/len(test_data)*100:.1f}%)

RESPONSE STATISTICS
{'='*80}
Average Response Length: {sum(response_lengths)/len(response_lengths):.1f} characters
Min Response Length: {min(response_lengths)} characters
Max Response Length: {max(response_lengths)} characters

NOTES
{'='*80}
- This is a zero-shot evaluation (no fine-tuning)
- The model uses chain-of-thought prompting
- Temperature: {config.TEMPERATURE}, Top-p: {config.TOP_P}
- Generation uses sampling (not greedy decoding)

"""

with open(f"{config.OUTPUT_DIR}/evaluation_report.txt", "w") as f:
    f.write(report)

print(report)
print(f"Report saved to {config.OUTPUT_DIR}/evaluation_report.txt")

In [None]:
# --- CELL 17: Final Summary ---
print("\n" + "="*80)
print("ZERO-SHOT CoT EVALUATION COMPLETE!")
print("="*80)
print(f"\nAll results saved to: {config.OUTPUT_DIR}/")
print("\nGenerated files:")
print("  1. test_data.json - Test dataset")
print("  2. results_summary.json - Summary statistics")
print("  3. detailed_results.json - All predictions")
print("  4. evaluation_results.png - Visualizations")
print("  5. evaluation_report.txt - Text report")
print("\nKey Findings:")
print(f"  - Baseline accuracy: {accuracy:.4f}")
print(f"  - Total predictions: {len(test_data)}")
print(f"  - Correct predictions: {int(accuracy * len(test_data))}")
print(f"  - Main error type: {'No answer extracted' if error_analysis['no_answer_extracted'] > error_analysis['wrong_answer'] else 'Wrong answer'}")
print("="*80)