In [None]:
"""
Complete ScienceQA Chain-of-Thought Implementation
Run cells sequentially in a Jupyter notebook or Google Colab
"""

# =============================================================================
# PART 1: SETUP & DATA LOADING
# =============================================================================

# Cell 1: Install dependencies
!pip install transformers datasets torch pillow accelerate evaluate scikit-learn -q

# Cell 2: Import libraries
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from datasets import load_dataset, Dataset
from PIL import Image
import json
import numpy as np
from typing import Dict, List, Optional
import time

# Cell 3: Download ScienceQA dataset with retry logic
def load_dataset_with_retry(dataset_name, max_retries=5):
    """Load a Hugging Face dataset, retrying failed attempts with a growing back-off.

    Args:
        dataset_name: Hub id of the dataset, e.g. ``"derek-thomas/ScienceQA"``.
        max_retries: Total number of load attempts; must be >= 1. After failed
            attempt ``k`` (1-based) the function sleeps ``10 * k`` seconds.

    Returns:
        Whatever ``load_dataset`` returns for ``dataset_name``.

    Raises:
        ValueError: if ``max_retries`` < 1 (previously this returned None).
        Exception: the error from the last attempt, once all attempts failed.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    # `download_config` must be a DownloadConfig instance: passing a plain dict
    # makes every attempt fail inside `datasets` with
    # "'dict' object has no attribute 'extract_compressed_file'".
    # `trust_remote_code` is not passed: recent `datasets` releases reject it,
    # and a plain `load_dataset` call works for this dataset.
    from datasets import DownloadConfig
    download_config = DownloadConfig(max_retries=10)

    for attempt in range(max_retries):
        try:
            print(f"Attempt {attempt + 1}/{max_retries} to load dataset...")
            dataset = load_dataset(dataset_name, download_config=download_config)
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                wait_time = (attempt + 1) * 10
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print("All attempts failed. Trying alternative method...")
                raise
        else:
            print("Dataset loaded successfully!")
            return dataset

# Try loading the dataset. If the retry helper fails, fall back to a plain
# load into a local cache directory.
# `except Exception` is used instead of a bare `except:` so that Ctrl-C
# (KeyboardInterrupt) during the long retry sleeps stops the cell instead of
# silently starting the fallback download.
try:
    dataset = load_dataset_with_retry("derek-thomas/ScienceQA")
except Exception:
    print("\nTrying alternative loading method...")
    # Alternative: load into an explicit local cache directory (non-streaming)
    try:
        dataset = load_dataset("derek-thomas/ScienceQA", streaming=False, cache_dir="./scienceqa_cache")
    except Exception:
        print("\nIf loading still fails, you can:")
        print("1. Download dataset manually from: https://huggingface.co/datasets/derek-thomas/ScienceQA")
        print("2. Or use the official ScienceQA GitHub: https://github.com/lupantech/ScienceQA")
        print("3. Or use a smaller subset for testing")
        raise

print(f"\nTrain examples: {len(dataset['train'])}")
print(f"Validation examples: {len(dataset['validation'])}")
print(f"Test examples: {len(dataset['test'])}")

# Cell 4: Explore dataset structure - print each field's type for the first
# training row, then the row itself (long text fields truncated to 100 chars).
sample = dataset['train'][0]
print("\nSample structure:")
for key, value in sample.items():
    print(f"  {key}: {type(value)}")

print("\nSample example:")
for label in ('Question', 'Choices', 'Answer'):
    print(f"{label}: {sample[label.lower()]}")
# Optional text fields are only shown when non-empty
for field in ('hint', 'lecture', 'solution'):
    if sample[field]:
        print(f"{field.capitalize()}: {sample[field][:100]}...")

# Cell 5: Setup device - the first CUDA GPU when one is visible, else the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"\nUsing device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")



[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25h

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Attempt 1/5 to load dataset...
Attempt 1 failed: 'dict' object has no attribute 'extract_compressed_file'
Waiting 10 seconds before retry...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Attempt 2/5 to load dataset...
Attempt 2 failed: 'dict' object has no attribute 'extract_compressed_file'
Waiting 20 seconds before retry...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Attempt 3/5 to load dataset...
Attempt 3 failed: 'dict' object has no attribute 'extract_compressed_file'
Waiting 30 seconds before retry...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Attempt 4/5 to load dataset...
Attempt 4 failed: 'dict' object has no attribute 'extract_compressed_file'
Waiting 40 seconds before retry...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'derek-thomas/ScienceQA' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Attempt 5/5 to load dataset...
Attempt 5 failed: 'dict' object has no attribute 'extract_compressed_file'
All attempts failed. Trying alternative method...

Trying alternative loading method...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001-1028f23e353fbe(…):   0%|          | 0.00/377M [00:00<?, ?B/s]

data/validation-00000-of-00001-6c7328ff6(…):   0%|          | 0.00/126M [00:00<?, ?B/s]

data/test-00000-of-00001-f0e719df791966f(…):   0%|          | 0.00/122M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12726 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4241 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4241 [00:00<?, ? examples/s]


Train examples: 12726
Validation examples: 4241
Test examples: 4241

Sample structure:
  image: <class 'PIL.PngImagePlugin.PngImageFile'>
  question: <class 'str'>
  choices: <class 'list'>
  answer: <class 'int'>
  hint: <class 'str'>
  task: <class 'str'>
  grade: <class 'str'>
  subject: <class 'str'>
  topic: <class 'str'>
  category: <class 'str'>
  skill: <class 'str'>
  lecture: <class 'str'>
  solution: <class 'str'>

Sample example:
Question: Which of these states is farthest north?
Choices: ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma']
Answer: 0
Lecture: Maps have four cardinal directions, or main directions. Those directions are north, south, east, and...
Solution: To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Vi...

Using device: cuda
GPU: NVIDIA L4
GPU Memory: 23.80 GB


In [None]:
# =============================================================================
# PART 2: IMAGE CAPTIONING
# =============================================================================

# Cell 6: Load the ViT-GPT2 captioner plus its image processor and tokenizer
print("\nLoading image captioning model...")
caption_model_name = "nlpconnect/vit-gpt2-image-captioning"
caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model_name).to(device)
caption_processor = ViTImageProcessor.from_pretrained(caption_model_name)
caption_tokenizer = AutoTokenizer.from_pretrained(caption_model_name)

caption_model.eval()  # inference only
print("Image captioning model loaded!")

# Cell 7: Image caption generation function
def generate_caption(image, max_length=16, num_beams=4):
    """Caption a PIL image with the global captioning model.

    Args:
        image: A ``PIL.Image.Image`` in any mode (converted to RGB first).
            ``None`` or any non-PIL object yields an empty caption.
        max_length: ``max_length`` forwarded to ``caption_model.generate``.
        num_beams: Beam width forwarded to ``caption_model.generate``.

    Returns:
        The decoded caption with special tokens removed, or ``""``.
    """
    # Guard clause: nothing to caption
    if image is None or not isinstance(image, Image.Image):
        return ""

    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
    encoded = caption_processor(images=rgb_image, return_tensors="pt")

    with torch.no_grad():
        generated = caption_model.generate(
            encoded.pixel_values.to(device),
            max_length=max_length,
            num_beams=num_beams,
        )

    return caption_tokenizer.decode(generated[0], skip_special_tokens=True)

# Cell 8: Generate captions for dataset (this may take time)
def add_image_captions(examples):
    """Batched ``Dataset.map`` callback that adds an ``image_caption`` column.

    Mutates and returns ``examples``. Rows whose image is ``None`` get ``""``;
    every other row is captioned with ``generate_caption``.
    """
    examples['image_caption'] = [
        "" if img is None else generate_caption(img)
        for img in examples['image']
    ]
    return examples

# Caption every split. This is slow: add_image_captions calls generate_caption
# once per image, even though map() hands it batches of 8. Each resulting
# split keeps its original columns plus the new `image_caption` column.
print("\nGenerating captions for training set...")
dataset_train = dataset['train'].map(function=add_image_captions, batch_size=8, batched=True)

print("Generating captions for validation set...")
dataset_val = dataset['validation'].map(function=add_image_captions, batch_size=8, batched=True)

print("Generating captions for test set...")
dataset_test = dataset['test'].map(function=add_image_captions, batch_size=8, batched=True)

print("Captions generated!")




Loading image captioning model...


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/982M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/982M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/241 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]

Image captioning model loaded!

Generating captions for training set...


Map:   0%|          | 0/12726 [00:00<?, ? examples/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Generating captions for validation set...


Map:   0%|          | 0/4241 [00:00<?, ? examples/s]

Generating captions for test set...


Map:   0%|          | 0/4241 [00:00<?, ? examples/s]

Captions generated!


In [None]:
# =============================================================================
# PART 3: DATA FORMATTING
# =============================================================================

# Cell 9: Format data for UnifiedQA CoT
def format_input_text(example):
    """Build the QCM prompt (Question, Context, Multiple choices).

    Layout, one part per line:
      * the first available context source (image caption, else hint) gets
        a ``"Context: "`` prefix; a hint that follows a caption is added as
        its own plain line;
      * the question text;
      * ``"Options: (A) ... (B) ..."`` with space-separated lettered choices.
    """
    caption = example['image_caption']
    hint = example['hint']

    lines = []
    if caption:
        lines.append(f"Context: {caption}")
    if hint:
        lines.append(hint if lines else f"Context: {hint}")

    lines.append(example['question'])

    lettered = ' '.join(
        f"({chr(65 + idx)}) {option}" for idx, option in enumerate(example['choices'])
    )
    lines.append(f"Options: {lettered}")

    return '\n'.join(lines)

def format_target_text(example):
    """Build the ALE target (Answer, Lecture, Explanation).

    Produces ``"The answer is (X)."``, then ``"BECAUSE: <lecture>"`` if a
    lecture exists, then the solution text if it exists. The pieces are joined
    with single spaces, and X is the letter for the 0-based ``answer`` index.
    """
    letter = chr(ord('A') + example['answer'])
    pieces = [f"The answer is ({letter})."]

    lecture = example['lecture']
    if lecture:
        pieces.append(f"BECAUSE: {lecture}")

    solution = example['solution']
    if solution:
        pieces.append(solution)

    return ' '.join(pieces)

def preprocess_function(examples):
    """Turn a list of example dicts into parallel prompt/target text lists.

    Returns:
        ``{'input_text': [...], 'target_text': [...]}``, built with
        ``format_input_text`` and ``format_target_text`` respectively.
    """
    prompts = [format_input_text(row) for row in examples]
    answers = [format_target_text(row) for row in examples]
    return dict(input_text=prompts, target_text=answers)

# Cell 10: Format datasets
def dataset_to_dicts(dataset_split, columns: Optional[List[str]] = None):
    """Materialise a dataset split as a list of per-row dicts.

    Args:
        dataset_split: An object exposing ``.features`` (a mapping whose keys
            are column names) and column access ``dataset_split[name]``, such
            as a ``datasets.Dataset``.
        columns: Optional subset of column names to keep, in the given order.
            ``None`` (the default) keeps every feature, which is the original
            behaviour. Passing only the text columns avoids pulling a column
            the caller never reads (e.g. ``image``) into every row dict.

    Returns:
        ``[{column: value, ...}, ...]``, one dict per row.
    """
    keys = list(dataset_split.features.keys()) if columns is None else list(columns)
    column_values = [dataset_split[k] for k in keys]
    return [dict(zip(keys, row)) for row in zip(*column_values)]

# Materialise each captioned split as row dicts, then build the
# prompt/target text pairs for every split.
train_dicts, val_dicts, test_dicts = (
    dataset_to_dicts(split) for split in (dataset_train, dataset_val, dataset_test)
)
train_formatted, val_formatted, test_formatted = (
    preprocess_function(rows) for rows in (train_dicts, val_dicts, test_dicts)
)

print(f"\nFormatted {len(train_formatted['input_text'])} training examples",
      f"Formatted {len(val_formatted['input_text'])} validation examples",
      f"Formatted {len(test_formatted['input_text'])} test examples",
      sep="\n")

# Cell 11: Show the first formatted training pair
print("\n=== Formatted Example ===", "INPUT:", train_formatted['input_text'][0], sep="\n")
print("\nTARGET:", train_formatted['target_text'][0], sep="\n")


Formatted 12726 training examples
Formatted 4241 validation examples
Formatted 4241 test examples

=== Formatted Example ===
INPUT:
Context: an aerial view of a painting of a forest 
Which of these states is farthest north?
Options: (A) West Virginia (B) Louisiana (C) Arizona (D) Oklahoma

TARGET:
The answer is (A). BECAUSE: Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.
A compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.
The north arrow points to the North Pole. On most maps, north is at the top of the map. To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.


In [None]:
# # =============================================================================
# # PART 4: MODEL & TOKENIZATION
# # =============================================================================
# # Cell 12: Load UnifiedQA model
# print("\nLoading UnifiedQA model...")
# model_name = "allenai/unifiedqa-t5-base"
# tokenizer = T5Tokenizer.from_pretrained(model_name)
# model = T5ForConditionalGeneration.from_pretrained(model_name)
# model.to(device)
# print("UnifiedQA model loaded!")

# # Cell 13: Create tokenized datasets
# def tokenize_function(examples, max_input_length=512, max_target_length=512):
#     model_inputs = tokenizer(
#         examples['input_text'],
#         max_length=max_input_length,
#         truncation=True,
#         padding='max_length'
#     )

#     labels = tokenizer(
#         examples['target_text'],
#         max_length=max_target_length,
#         truncation=True,
#         padding='max_length'
#     )

#     model_inputs['labels'] = labels['input_ids']

#     # Replace padding token id's in labels with -100
#     model_inputs['labels'] = [
#         [(l if l != tokenizer.pad_token_id else -100) for l in label]
#         for label in model_inputs['labels']
#     ]

#     return model_inputs

# # Create HuggingFace datasets
# train_dataset = Dataset.from_dict(train_formatted)
# val_dataset = Dataset.from_dict(val_formatted)
# test_dataset = Dataset.from_dict(test_formatted)


# def create_subset(dataset, n_samples):
#     """Create a subset of the dataset"""
#     indices = list(range(min(n_samples, len(dataset))))
#     return dataset.select(indices)

# VAL_SUBSET_SIZE = 100  # Use only 100 validation examples
# val_dataset = create_subset(val_dataset, VAL_SUBSET_SIZE)

# print("\nTokenizing datasets...")
# tokenized_train = train_dataset.map(
#     tokenize_function,
#     batched=True,
#     remove_columns=train_dataset.column_names,
#     desc="Tokenizing train"
# )

# tokenized_val = val_dataset.map(
#     tokenize_function,
#     batched=True,
#     remove_columns=val_dataset.column_names,
#     desc="Tokenizing validation"
# )

# tokenized_test = test_dataset.map(
#     tokenize_function,
#     batched=True,
#     remove_columns=test_dataset.column_names,
#     desc="Tokenizing test"
# )

# print("Tokenization complete!")
# print(f"Train: {len(tokenized_train)} examples")
# print(f"Validation: {len(tokenized_val)} examples")
# print(f"Test: {len(tokenized_test)} examples")

# =============================================================================
# PART 4: MODEL & TOKENIZATION
# =============================================================================

# Cell 12: Load the UnifiedQA T5-base checkpoint and tokenizer that get fine-tuned
print("\nLoading UnifiedQA model...")
model_name = "allenai/unifiedqa-t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
print("UnifiedQA model loaded!",
      f"Tokenizer vocab size: {len(tokenizer)}",
      f"Tokenizer pad token id: {tokenizer.pad_token_id}",
      sep="\n")

# Cell 13: Create subset FIRST (before tokenization)
def create_subset(dataset, n_samples):
    """Return the first ``min(n_samples, len(dataset))`` rows via ``dataset.select``."""
    keep = min(n_samples, len(dataset))
    return dataset.select([idx for idx in range(keep)])

# Wrap the formatted {input_text, target_text} columns as HuggingFace datasets
train_dataset, val_dataset, test_dataset = (
    Dataset.from_dict(cols) for cols in (train_formatted, val_formatted, test_formatted)
)

# Keep only the first 500 validation/test rows so generation-based eval stays fast
VAL_SUBSET_SIZE = 500
TEST_SUBSET_SIZE = 500

print(f"\nOriginal validation size: {len(val_dataset)}")
val_dataset = create_subset(val_dataset, VAL_SUBSET_SIZE)
print(f"Reduced validation size: {len(val_dataset)}")

print(f"\nOriginal test size: {len(test_dataset)}")
test_dataset = create_subset(test_dataset, TEST_SUBSET_SIZE)
print(f"Reduced test size: {len(test_dataset)}")

# Cell 14: Tokenization function
def tokenize_function(examples, max_input_length=512, max_target_length=512):
    """Batched map callback: tokenize prompts and targets to fixed-length id lists.

    Prompts are padded/truncated to ``max_input_length`` and targets to
    ``max_target_length`` (both ``padding='max_length'``). In the labels, every
    position equal to the tokenizer's pad id becomes -100 so the loss skips it.

    Returns:
        The prompt encoding (``input_ids``, ``attention_mask``) with an extra
        ``labels`` entry.
    """
    pad_id = tokenizer.pad_token_id

    encoded = tokenizer(
        examples['input_text'],
        max_length=max_input_length,
        truncation=True,
        padding='max_length',
    )
    target_ids = tokenizer(
        examples['target_text'],
        max_length=max_target_length,
        truncation=True,
        padding='max_length',
    )['input_ids']

    encoded['labels'] = [
        [-100 if tok == pad_id else tok for tok in seq]
        for seq in target_ids
    ]
    return encoded

# Cell 15: Tokenize each split, dropping the raw text columns
print("\nTokenizing datasets...")
tokenized_train = train_dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    desc="Tokenizing train",
)
tokenized_val = val_dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
    desc="Tokenizing validation",
)
tokenized_test = test_dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=test_dataset.column_names,
    desc="Tokenizing test",
)

# Expose only the model inputs, returned as torch tensors
tokenized_train.set_format(columns=['input_ids', 'attention_mask', 'labels'], type='torch')
tokenized_val.set_format(columns=['input_ids', 'attention_mask', 'labels'], type='torch')
tokenized_test.set_format(columns=['input_ids', 'attention_mask', 'labels'], type='torch')

print("Tokenization complete!",
      f"Train: {len(tokenized_train)} examples",
      f"Validation: {len(tokenized_val)} examples",
      f"Test: {len(tokenized_test)} examples",
      sep="\n")

# Cell 16: Sanity-check the first tokenized training row
print("\n=== Verifying Tokenization ===")
sample = tokenized_train[0]
print(f"Sample keys: {sample.keys()}",
      f"Input IDs shape: {sample['input_ids'].shape}",
      f"Attention mask shape: {sample['attention_mask'].shape}",
      f"Labels shape: {sample['labels'].shape}",
      f"Max input ID: {sample['input_ids'].max().item()}",
      f"Min input ID: {sample['input_ids'].min().item()}",
      f"Vocab size: {len(tokenizer)}",
      sep="\n")

# Any id >= len(tokenizer) could not be decoded back to a token
print("⚠️ WARNING: Found token IDs larger than vocab size!"
      if sample['input_ids'].max().item() >= len(tokenizer)
      else "✓ All token IDs are valid")

print(f"\nUsing device: {device}")


Loading UnifiedQA model...
UnifiedQA model loaded!
Tokenizer vocab size: 32100
Tokenizer pad token id: 0

Original validation size: 4241
Reduced validation size: 500

Original test size: 4241
Reduced test size: 500

Tokenizing datasets...


Tokenizing train:   0%|          | 0/12726 [00:00<?, ? examples/s]

Tokenizing validation:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenizing test:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenization complete!
Train: 12726 examples
Validation: 500 examples
Test: 500 examples

=== Verifying Tokenization ===
Sample keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
Input IDs shape: torch.Size([512])
Attention mask shape: torch.Size([512])
Labels shape: torch.Size([512])
Max input ID: 22142
Min input ID: 0
Vocab size: 32100
✓ All token IDs are valid

Using device: cuda


In [None]:
print("\nUsing device:", device)


Using device: cuda


In [None]:
# =============================================================================
# PART 5: TRAINING
# =============================================================================
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import numpy as np

# Cell 14: Define compute metrics function (FIXED)
def _parse_answer_letter(text, allow_fallback=True):
    """Return the upper-cased option letter stated in *text*, or None.

    First looks for the training-target phrasing ``"The answer is (X)"`` and
    returns the character after the parenthesis. If that phrase is absent and
    *allow_fallback* is true, it looks case-insensitively for ``"answer is"``,
    followed by an optional ``(`` and an upper-case letter A-J.
    """
    marker = "The answer is ("
    pos = text.find(marker)
    if pos != -1:
        start = pos + len(marker)
        return text[start:start + 1].upper() or None
    if not allow_fallback:
        return None
    pos = text.lower().find("answer is")
    if pos == -1:
        return None
    # Strip the optional "(" so "answer is (B)" parses like "answer is B"
    remaining = text[pos + len("answer is"):].strip().lstrip("(")
    if remaining and remaining[0] in 'ABCDEFGHIJ':
        return remaining[0]
    return None


def compute_metrics(eval_pred):
    """Compute answer-letter accuracy for ``Seq2SeqTrainer`` evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` token-id arrays. With
            ``predict_with_generate=True`` the predictions are generated ids;
            labels use -100 for ignored positions.

    Returns:
        ``{'accuracy': correct / total, 'correct': int, 'total': int}``. An
        example counts as correct when the option letter parsed from the
        decoded prediction equals the one parsed from the decoded label.
    """
    predictions, labels = eval_pred
    # Some models return a tuple whose first element holds the token ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    vocab_size = len(tokenizer)

    # Map every id the tokenizer cannot decode to the pad token, which
    # skip_special_tokens then drops. This covers negative padding values
    # such as -100, and ids past the tokenizer vocab (the model's output
    # vocabulary may be larger than the tokenizer's).
    predictions = np.asarray(predictions)
    predictions = np.where((predictions >= 0) & (predictions < vocab_size), predictions, pad_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # -100 marks ignored label positions and cannot be decoded either
    labels = np.asarray(labels)
    labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    correct = 0
    total = 0
    for pred, label in zip(decoded_preds, decoded_labels):
        total += 1
        pred_answer = _parse_answer_letter(pred)
        # Labels are built by format_target_text, so only the exact phrasing applies
        true_answer = _parse_answer_letter(label, allow_fallback=False)
        if pred_answer is not None and pred_answer == true_answer:
            correct += 1

    accuracy = correct / total if total > 0 else 0

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total
    }

# Verify GPU availability; `use_gpu` drives the mixed-precision / CPU flags below
print("\n=== Device Check ===")
use_gpu = torch.cuda.is_available()
if not use_gpu:
    print("✗ No GPU available, will use CPU")
else:
    print(f"✓ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Cell 15: Training arguments (optimized for faster training)
print("\n=== Setting up training ===")
training_args = Seq2SeqTrainingArguments(
    output_dir='./unifiedqa_scienceqa_cot',
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,  # Larger batch for evaluation
    learning_rate=5e-5,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    save_strategy='steps',
    # load_best_model_at_end requires save_steps to be a multiple of eval_steps
    save_steps=1000,
    eval_strategy='steps',
    eval_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',  # key returned by compute_metrics
    greater_is_better=True,
    save_total_limit=2,  # Keep only 2 checkpoints
    predict_with_generate=True,  # eval decodes generated text for compute_metrics
    generation_max_length=512,
    generation_num_beams=4,
    # NOTE(review): T5 checkpoints are known to be prone to fp16 overflow
    # (NaN losses); on bf16-capable GPUs consider bf16=True instead.
    fp16=use_gpu,  # Mixed precision only if GPU available
    use_cpu=not use_gpu,  # replaces the deprecated `no_cuda` flag
    report_to="none",
    dataloader_num_workers=0,  # Avoid multiprocessing issues
)

print(f"Training for {training_args.num_train_epochs} epoch(s)")
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"Evaluation every {training_args.eval_steps} steps")
print(f"Total training examples: {len(tokenized_train)}")
print(f"Total validation examples: {len(tokenized_val)}")

# Cell 16: Wire the model, arguments, datasets and answer-accuracy metric into a trainer
print("\n=== Initializing trainer ===")
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
)
print("Trainer initialized successfully!")

# Cell 17: Train the model. If training raises, report the full traceback and
# retry once with in-training evaluation and generation disabled.
import traceback

print("\n" + "="*80)
print("STARTING TRAINING")
print("="*80)

try:
    trainer.train()
    print("\n" + "="*80)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("="*80)
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    # Keep the real cause visible; otherwise the fallback run below hides it
    traceback.print_exc()
    print("\nTrying alternative training configuration...")

    # NOTE: `model` is updated in place by trainer.train(), so if the failure
    # happened mid-run the fallback continues from partially trained weights.
    if use_gpu:
        torch.cuda.empty_cache()  # release cached GPU memory from the failed run

    # Fallback: Disable evaluation during training
    training_args_fallback = Seq2SeqTrainingArguments(
        output_dir='./unifiedqa_scienceqa_cot',
        num_train_epochs=1,
        per_device_train_batch_size=4,
        learning_rate=5e-5,
        warmup_steps=500,
        weight_decay=0.01,
        logging_steps=100,
        save_strategy='epoch',
        eval_strategy='no',  # Disable evaluation
        predict_with_generate=False,  # Disable generation
        fp16=use_gpu,
        use_cpu=not use_gpu,  # replaces the deprecated `no_cuda` flag
        report_to="none",
    )

    trainer_fallback = Seq2SeqTrainer(
        model=model,
        args=training_args_fallback,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
    )

    trainer_fallback.train()
    print("\nTraining completed with fallback configuration!")

# Cell 18: Persist the fine-tuned weights and the tokenizer side by side
print("\n=== Saving model ===")
model.save_pretrained('./unifiedqa_scienceqa_cot_final')
tokenizer.save_pretrained('./unifiedqa_scienceqa_cot_final')
print("✓ Model saved to: ./unifiedqa_scienceqa_cot_final",
      "✓ Tokenizer saved to: ./unifiedqa_scienceqa_cot_final",
      "\nTraining pipeline complete!",
      sep="\n")


=== Device Check ===
✓ GPU Available: NVIDIA L4
✓ GPU Memory: 23.80 GB

=== Setting up training ===
Training for 1 epoch(s)
Batch size: 4
Evaluation every 1000 steps
Total training examples: 12726
Total validation examples: 500

=== Initializing trainer ===
Trainer initialized successfully!

STARTING TRAINING


Step,Training Loss,Validation Loss,Accuracy,Correct,Total
1000,1.5194,1.119508,0.448,224,500
2000,1.0684,0.694071,0.438,219,500
3000,0.9073,0.594868,0.482,241,500


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].



TRAINING COMPLETED SUCCESSFULLY!

=== Saving model ===
✓ Model saved to: ./unifiedqa_scienceqa_cot_final
✓ Tokenizer saved to: ./unifiedqa_scienceqa_cot_final

Training pipeline complete!


In [None]:
# =============================================================================
# SIMPLE UPLOAD TO HUGGING FACE
# =============================================================================

# Cell 1: Login to Hugging Face (one-time setup)
# Never paste the token into the notebook source: it is saved with the
# notebook and leaks whenever the notebook is shared. The token is read from
# the HF_TOKEN environment variable, or from the HF_TOKEN Colab secret.
import os
from huggingface_hub import login

HF_TOKEN = os.environ.get("HF_TOKEN")  # Get from https://huggingface.co/settings/tokens
if not HF_TOKEN:
    try:
        from google.colab import userdata  # only importable on Colab
        HF_TOKEN = userdata.get("HF_TOKEN")
    except Exception:  # best effort: not on Colab, secret missing, or access not granted
        HF_TOKEN = None
if not HF_TOKEN:
    raise RuntimeError(
        "No Hugging Face token found: set the HF_TOKEN environment variable "
        "or add an HF_TOKEN Colab secret (https://huggingface.co/settings/tokens)."
    )
login(token=HF_TOKEN)
print("✓ Logged in to Hugging Face")

# Cell 2: Reload the fine-tuned checkpoint written by the training cell
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('./unifiedqa_scienceqa_cot_final')
model = T5ForConditionalGeneration.from_pretrained('./unifiedqa_scienceqa_cot_final')

print("✓ Model and tokenizer loaded")

# Cell 3: Upload the model, then the tokenizer, to one Hub repository
REPO_NAME = "VishalM12/scienceqa-unifiedqa-cot"
print("\nUploading to " + REPO_NAME + "...")

model.push_to_hub(REPO_NAME)
print("✓ Model uploaded")

tokenizer.push_to_hub(REPO_NAME)
print("✓ Tokenizer uploaded")

print("\n✅ Done! Model available at: https://huggingface.co/" + REPO_NAME)

✓ Logged in to Hugging Face
✓ Model and tokenizer loaded

Uploading to VishalM12/scienceqa-unifiedqa-cot...


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...bsevkej/model.safetensors:   0%|          |  553kB /  892MB            

✓ Model uploaded


README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  .../tmp9l9g760q/spiece.model: 100%|##########|  792kB /  792kB            

✓ Tokenizer uploaded

✅ Done! Model available at: https://huggingface.co/VishalM12/scienceqa-unifiedqa-cot


In [None]:
# =============================================================================
# PART 6: EVALUATION (SIMPLIFIED - 500 SAMPLES)
# =============================================================================

# Cell 19: Helper Functions
import numpy as np
from tqdm import tqdm

def extract_answer_letter(text):
    """Extract the option letter from a generated chain-of-thought answer.

    Finds the first "The answer is" marker. The text after it, up to the next
    '.' or the next marker, must start with a parenthesised letter such as "(B)".

    Args:
        text: Decoded model output. ``None`` or empty is treated as unparsable.

    Returns:
        The option letter as an uppercase string (e.g. ``"B"``), or ``None``
        when no "(X)" letter follows the marker.
    """
    marker = "The answer is"
    if not text or marker not in text:
        return None

    # Segment after the first marker, cut at the first '.'.
    answer_part = text.split(marker)[1].split(".")[0].strip()
    if len(answer_part) >= 3 and answer_part[0] == '(':
        letter = answer_part[1]
        # Only ASCII letters are option labels; "(1)" or "(?)" is unparsable.
        # Uppercase so callers can map with ord(letter) - ord('A').
        if letter.isascii() and letter.isalpha():
            return letter.upper()
    return None

def extract_choice_index(predicted_answer, choices):
    """Map a predicted answer string to the index of one of ``choices``.

    Matching order:
      1. Option letter: the whole prediction equals a letter ("b") or contains
         a parenthesised letter ("(b)"), checked a, b, c, ... in order.
      2. Choice text contained in the prediction (case-insensitive). If several
         choices match, the longest wins, so "10" is preferred over "1".

    Args:
        predicted_answer: Answer text. ``None`` or empty yields -1.
        choices: Candidate answer strings, in option order.

    Returns:
        Index into ``choices``, or -1 when nothing matches.
    """
    import string

    if not predicted_answer:
        return -1
    pred = predicted_answer.lower().strip()

    for i, letter in enumerate(string.ascii_lowercase[:len(choices)]):
        if pred == letter or f"({letter})" in pred:
            return i

    best_idx, best_len = -1, 0
    for i, choice in enumerate(choices):
        normalized = choice.lower().strip()
        # An empty choice is a substring of everything, so skip it. Keep the
        # longest match so a choice that is a substring of another can't win.
        if normalized and normalized in pred and len(normalized) > best_len:
            best_idx, best_len = i, len(normalized)
    return best_idx

def calculate_metrics(predictions, labels, subjects, has_image, has_text,
                      count_invalid_as_wrong=False):
    """Compute accuracy (percent) overall, per subject, and per context type.

    Args:
        predictions: Predicted choice indices. Negative values mark outputs
            that could not be parsed into an answer.
        labels: Gold choice indices, aligned with ``predictions``.
        subjects: Subject string per sample (e.g. ``'natural science'``).
        has_image: Per-sample flag, True when the sample has an image.
        has_text: Per-sample flag, True when the sample has a text hint.
        count_invalid_as_wrong: If False (default), samples with a negative
            prediction are excluded from every metric. If True they are kept
            and scored as incorrect, giving accuracy over all samples.

    Returns:
        dict with 'overall' (percent, 0.0 when nothing is scored) and 'total'
        (number of scored samples). It also has 'NAT'/'SOC'/'LAN' and
        'text_only'/'image_only'/'both'/'no_context' keys for every group
        with at least one scored sample.
    """
    # Coerce to arrays so list inputs and 0/1 flags work with the masks below
    # (`list == str` is a scalar, and `~` on ints is a bitwise not).
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    subjects = np.asarray(subjects)
    has_image = np.asarray(has_image, dtype=bool)
    has_text = np.asarray(has_text, dtype=bool)

    if count_invalid_as_wrong:
        scored = np.ones(predictions.shape, dtype=bool)
    else:
        scored = predictions >= 0
    correct = predictions == labels

    def _accuracy(mask):
        return correct[mask].mean() * 100

    results = {}

    # Overall. With no scored samples, mean() of an empty array would be NaN.
    n_scored = int(scored.sum())
    results['overall'] = _accuracy(scored) if n_scored else 0.0
    results['total'] = n_scored

    # By Subject, keyed by the first three letters uppercased (NAT/SOC/LAN).
    for subj in ['natural science', 'social science', 'language science']:
        mask = (subjects == subj) & scored
        if mask.any():
            results[subj.upper()[:3]] = _accuracy(mask)

    # By Context
    context_masks = {
        'text_only': has_text & ~has_image,
        'image_only': has_image & ~has_text,
        'both': has_image & has_text,
        'no_context': ~has_image & ~has_text
    }
    for name, mask in context_masks.items():
        mask = mask & scored
        if mask.any():
            results[name] = _accuracy(mask)

    return results

# Cell 20: Load Test Dataset (500 samples only)
NUM_EVAL_SAMPLES = 500  # size of the evaluation subset (first N test examples)

print("\n" + "="*60)
print("LOADING TEST DATASET")
print("="*60)

test_dataset = load_dataset('derek-thomas/ScienceQA', split='test')
# Clamp with min() so select() never asks for rows beyond the end of the split.
test_dataset = test_dataset.select(range(min(NUM_EVAL_SAMPLES, len(test_dataset))))

print(f"Test samples: {len(test_dataset)}")

# Cell 21: Load Fine-tuned Model
MODEL_DIR = './unifiedqa_scienceqa_cot_final'

print("\nLoading fine-tuned model...")
model = T5ForConditionalGeneration.from_pretrained(MODEL_DIR)
tokenizer = T5Tokenizer.from_pretrained(MODEL_DIR)
# NOTE(review): `device` is not defined in this cell; it is expected from an
# earlier cell of the notebook.
model.to(device)
model.eval()  # inference mode: disables dropout
print("Model loaded")

# Cell 22: Run Evaluation
print("\nRunning evaluation...")

# Read columns directly instead of iterating whole rows (each row also carries
# the image field) just to fetch one text value.
questions = test_dataset['question']
choices_list = test_dataset['choices']
contexts = [hint or '' for hint in test_dataset['hint']]
labels = np.array(test_dataset['answer'])
subjects = np.array(test_dataset['subject'])
images_data = list(test_dataset['image'])

# Generate image captions if needed
print("Generating captions...")
image_captions = []
caption_failures = 0
for img in tqdm(images_data, desc="Captioning"):
    caption = ""
    if img is not None:
        try:
            caption = generate_caption(img)
        except Exception as err:
            # Best effort: fall back to no caption, but don't swallow Ctrl-C
            # (a bare `except:` would) and don't fail silently.
            caption_failures += 1
            if caption_failures == 1:
                print(f"Captioning failed ({type(err).__name__}: {err}); using empty captions")
    image_captions.append(caption)
if caption_failures:
    print(f"Captioning failed for {caption_failures} image(s)")


def _format_eval_prompt(caption, hint, question, choices):
    """Build one model input: optional context, the question, lettered options."""
    parts = []
    if caption:
        parts.append(f"Context: {caption}")
    if hint:
        # When a caption is present it already opened the "Context:" line.
        parts.append(hint if parts else f"Context: {hint}")
    parts.append(question)
    options = ' '.join(f"({chr(65 + j)}) {choice}" for j, choice in enumerate(choices))
    parts.append(f"Options: {options}")
    return '\n'.join(parts)


# Batch inference
all_predictions = []
all_generated_texts = []  # raw model outputs, kept for inspecting parse failures
batch_size = 8

print("Running inference...")
for i in tqdm(range(0, len(questions), batch_size)):
    batch_end = min(i + batch_size, len(questions))

    batch_inputs = [
        _format_eval_prompt(image_captions[idx], contexts[idx],
                            questions[idx], choices_list[idx])
        for idx in range(i, batch_end)
    ]

    # Generate
    inputs = tokenizer(batch_inputs, max_length=512, truncation=True,
                       padding=True, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512, num_beams=4)

    predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    all_generated_texts.extend(predictions)

    for pred_text in predictions:
        answer_letter = extract_answer_letter(pred_text)
        # Uppercase before mapping: a lowercase '(a)' would otherwise become
        # ord('a') - 65 == 32, an index that is "valid" but never correct.
        pred_idx = ord(answer_letter.upper()) - ord('A') if answer_letter else -1
        all_predictions.append(pred_idx)

# Cell 23: Calculate and Print Results
# Assemble aligned per-sample arrays; a sample "has text" when its hint is non-empty.
predictions = np.array(all_predictions)
has_image = np.array([image is not None for image in images_data])
has_text = np.array([bool(hint) for hint in contexts])

results = calculate_metrics(predictions, labels, subjects, has_image, has_text)

banner = "=" * 60
print("\n" + banner)
print("EVALUATION RESULTS (500 samples)")
print(banner)
print(f"\nOverall Accuracy: {results['overall']:.2f}% ({results['total']} samples)")

# Each section prints only the groups that calculate_metrics reported.
report_sections = (
    ("\nBy Subject:", ('NAT', 'SOC', 'LAN')),
    ("\nBy Context:", ('text_only', 'image_only', 'both', 'no_context')),
)
for heading, keys in report_sections:
    print(heading)
    for key in keys:
        if key in results:
            print(f"  {key}: {results[key]:.2f}%")

# Cell 24: Save Results (JSON)
# Write the metrics dict produced by calculate_metrics as indented JSON.
results_json_path = 'evaluation_results.json'
with open(results_json_path, 'w') as out_file:
    json.dump(results, out_file, indent=2)

print(f"\nResults saved to {results_json_path}")

# Cell 25: Save Results (CSV)
import pandas as pd

# Display labels for the context-type keys reported by calculate_metrics.
context_names = {
    'text_only': 'Text Only',
    'image_only': 'Image Only',
    'both': 'Text + Image',
    'no_context': 'No Context'
}

# One row per metric. Only the overall row has a sample count; the metrics
# dict carries no per-group counts, so group rows get a '-' placeholder.
csv_data = [{
    'Category': 'Overall',
    'Subcategory': 'All',
    'Accuracy': f"{results['overall']:.2f}",
    'Sample_Count': results['total'],
}]
csv_data.extend(
    {'Category': 'Subject', 'Subcategory': key,
     'Accuracy': f"{results[key]:.2f}", 'Sample_Count': '-'}
    for key in ('NAT', 'SOC', 'LAN')
    if key in results
)
csv_data.extend(
    {'Category': 'Context', 'Subcategory': label,
     'Accuracy': f"{results[key]:.2f}", 'Sample_Count': '-'}
    for key, label in context_names.items()
    if key in results
)

df = pd.DataFrame(csv_data)
df.to_csv('evaluation_results.csv', index=False)

print("Results saved to evaluation_results.csv")
print("\nEvaluation complete!")


LOADING TEST DATASET
Test samples: 500

Loading fine-tuned model...
Model loaded

Running evaluation...
Generating captions...


Captioning: 100%|██████████| 500/500 [00:51<00:00,  9.62it/s]


Running inference...


100%|██████████| 63/63 [14:15<00:00, 13.58s/it]


EVALUATION RESULTS (500 samples)

Overall Accuracy: 41.60% (500 samples)

By Subject:
  NAT: 40.38%
  SOC: 36.89%
  LAN: 47.73%

By Context:
  text_only: 37.76%
  image_only: 41.11%
  both: 40.14%
  no_context: 45.45%

Results saved to evaluation_results.json
Results saved to evaluation_results.csv

Evaluation complete!



