In [None]:
import torch
from datasets import load_dataset, Dataset
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
import numpy as np
import gc
import os
from tqdm.auto import tqdm

# # Create output directories if they don't exist
# os.makedirs("./results/answer_generation", exist_ok=True)
# os.makedirs("./logs/answer_generation", exist_ok=True)

os.makedirs("/kaggle/working/results/answer_generation", exist_ok=True)
os.makedirs("/kaggle/working/logs/answer_generation", exist_ok=True)


# Enable memory optimization for PyTorch
torch.cuda.empty_cache()

# 1. Load dataset - load only a subset directly
print("Loading dataset...")
dataset = load_dataset("zjsd/RedStone-QA-mcq", split=f"train[:{int(0.02*100)}%]")
print(f"Dataset loaded with {len(dataset)} examples")

# 2. Define preprocessing function for answer generation
def preprocess_for_answer_generation(examples, batch_size=64):
    """Preprocess data for training the model to generate the correct answer letter."""
    inputs = []
    labels = []
    
    for i in range(0, len(examples["text"]), batch_size):
        batch_texts = examples["text"][i:i+batch_size]
        batch_questions = examples["question"][i:i+batch_size]
        batch_answers = examples["answer"][i:i+batch_size]
        
        for text, question, answer in zip(batch_texts, batch_questions, batch_answers):
            combined = f"generate answer: {text} question: {question}"
            inputs.append(combined)
            # Extract just the letter from "Answer:X" format
            labels.append(answer.replace("Answer:", "").strip())
    
    return {
        "input": inputs,
        "output": labels
    }

# 3. Process dataset in chunks to save memory
print("Preprocessing data for answer generation...")
answer_dataset = Dataset.from_dict(preprocess_for_answer_generation(dataset))

# Free up memory
del dataset
gc.collect()
torch.cuda.empty_cache()

# 4. Split dataset into train and validation
print("Splitting dataset into train and validation sets...")
answer_dataset = answer_dataset.train_test_split(test_size=0.1, seed=42)

answer_train_dataset = answer_dataset["train"]
answer_val_dataset = answer_dataset["test"]

print(f"Answer generation: {len(answer_train_dataset)} training examples, {len(answer_val_dataset)} validation examples")

# Free up memory
del answer_dataset
gc.collect()

# 5. Load tokenizer and model
print("Loading tokenizer and model...")
model_name = "google/flan-t5-small"  # Using the small model for faster training
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 6. Define tokenization function
def tokenize_function(examples, max_input_length=512, max_target_length=128):
    model_inputs = tokenizer(
        examples["input"],
        max_length=max_input_length,
        padding="max_length",
        truncation=True
    )
    
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["output"],
            max_length=max_target_length,
            padding="max_length",
            truncation=True
        )
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# 7. Tokenize datasets with smaller batch size to reduce memory usage
print("Tokenizing datasets...")
batch_size = 32

answer_train_tokenized = answer_train_dataset.map(
    tokenize_function, 
    batched=True,
    batch_size=batch_size,
    remove_columns=["input", "output"]
)

answer_val_tokenized = answer_val_dataset.map(
    tokenize_function, 
    batched=True,
    batch_size=batch_size,
    remove_columns=["input", "output"]
)

# Free up memory
del answer_train_dataset, answer_val_dataset
gc.collect()
torch.cuda.empty_cache()

# 8. Define custom metrics computation function
def compute_answer_metrics(eval_pred):
    predictions, labels = eval_pred
    # Replace -100 with the pad_token_id
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Get the predicted tokens
    predicted_tokens = np.argmax(predictions, axis=-1)
    
    # Decode to get the actual letters - process in batches to save memory
    batch_size = 32
    num_examples = len(predicted_tokens)
    decoded_preds = []
    decoded_labels = []
    
    for i in range(0, num_examples, batch_size):
        batch_preds = [tokenizer.decode(pred, skip_special_tokens=True) 
                      for pred in predicted_tokens[i:i+batch_size]]
        batch_labels = [tokenizer.decode(label, skip_special_tokens=True) 
                       for label in labels[i:i+batch_size]]
        decoded_preds.extend(batch_preds)
        decoded_labels.extend(batch_labels)
    
    # Calculate accuracy
    correct = sum(1 for pred, label in zip(decoded_preds, decoded_labels) if pred.strip() == label.strip())
    accuracy = correct / len(decoded_labels) if len(decoded_labels) > 0 else 0
    
    # Print just a few examples for debugging
    print("\nAnswer Generation Examples (Prediction, Reference):")
    for i in range(min(3, len(decoded_preds))):
        print(f"  {decoded_preds[i]} | {decoded_labels[i]}")
    
    # Return metrics
    return {
        "accuracy": accuracy,
        "exact_match_ratio": correct / len(decoded_labels) if len(decoded_labels) > 0 else 0,
    }

# 9. Define training arguments for answer generation
answer_training_args = TrainingArguments(
    output_dir="./results/answer_generation",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    logging_dir="./logs/answer_generation",
    logging_steps=100,
    report_to=["tensorboard"],
    # Memory optimizations
    fp16=True if torch.cuda.is_available() else False,  # Use mixed precision if available
    gradient_accumulation_steps=2,  # Accumulate gradients to simulate larger batch sizes
    dataloader_num_workers=1,  # Parallelize data loading
    dataloader_pin_memory=True,  # Speed up data transfer to GPU
    # Ensure progress bar is shown
    disable_tqdm=False,
)

# 10. Define data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="max_length",
    max_length=512,
    pad_to_multiple_of=8  # Optimize for tensor operations
)

# 11. Train answer generation model
print("Initializing answer generation trainer...")
answer_trainer = Trainer(
    model=model,
    args=answer_training_args,
    train_dataset=answer_train_tokenized,
    eval_dataset=answer_val_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_answer_metrics,
)

print("Training answer generation model...")
answer_trainer.train()

# Get validation results
print("Evaluating answer generation model...")
answer_eval_results = answer_trainer.evaluate()
print("\n" + "="*50)
print("ANSWER GENERATION EVALUATION RESULTS")
print("="*50)
print(f"Accuracy: {answer_eval_results['eval_accuracy']:.4f}")
print(f"Exact Match: {answer_eval_results['eval_exact_match_ratio']:.4f}")
print("="*50 + "\n")

# Save model
print("Saving final model...")
answer_trainer.save_model("./results/answer_generation/final_model")
print("Model saved successfully!")

# 12. Memory-efficient inference function for demonstration
def generate_answer(context, question, model, tokenizer):
    # Set model to evaluation mode
    model.eval()
    
    # Get the device
    device = next(model.parameters()).device
    
    # Prepare input
    answer_input = f"generate answer: {context} question: {question}"
    answer_input_ids = tokenizer(answer_input, return_tensors="pt").input_ids.to(device)
    
    print("Generating answer letter...")
    with torch.no_grad():  # Disable gradient calculation to save memory
        answer_outputs = model.generate(
            answer_input_ids, 
            max_length=10,
            num_beams=4,
            early_stopping=True
        )
    answer_letter = tokenizer.decode(answer_outputs[0], skip_special_tokens=True).strip()
    print(f"Generated answer letter: {answer_letter}")
    
    # Clean up GPU memory
    del answer_outputs, answer_input_ids
    torch.cuda.empty_cache()
    
    return answer_letter

# Optional: Demonstrate the trained model
print("\n" + "="*50)
print("DEMONSTRATION")
print("="*50)
try:
    # Sample context and question
    context = "The Python programming language was created by Guido van Rossum and first released in 1991. It emphasizes code readability with its notable use of significant whitespace."
    question = "Who created the Python programming language?"
    
    # Generate answer
    answer_letter = generate_answer(context, question, model, tokenizer)
    print(f"Context: {context}")
    print(f"Question: {question}")
    print(f"Generated Answer: {answer_letter}")
except Exception as e:
    print(f"Error during demonstration: {e}")

print("\nTraining of Answer Generation model complete!")

In [None]:
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
import torch
from torch.utils.data import Dataset
import random
import gc
from tqdm.auto import tqdm

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.empty_cache()

# Define constants
MODEL_NAME = "google/flan-t5-small"
MAX_INPUT_LENGTH = 384  # Reduced from 512
MAX_TARGET_LENGTH = 48  # Reduced from 64
BATCH_SIZE = 4  # Reduced from 8
GRADIENT_ACCUMULATION_STEPS = 4  # Increases effective batch size without increasing memory
LEARNING_RATE = 3e-4
NUM_EPOCHS = 3
SAMPLE_RATIO = 0.1  # Reduced from 0.15
MIXED_PRECISION = "fp16"  # Use mixed precision training

# Function to free memory
def free_memory():
    gc.collect()
    torch.cuda.empty_cache()

# Load the dataset in streaming mode to reduce memory usage
print("Loading dataset...")
dataset = load_dataset("zjsd/RedStone-QA-mcq", split="train", streaming=True)

# Convert to regular dataset for sampling
# Buffer a smaller amount to not load everything in memory
dataset = dataset.take(int(1.66e6 * SAMPLE_RATIO * 1.2))  # Buffer slightly more than needed
dataset = list(dataset)
random.shuffle(dataset)
dataset = dataset[:int(1.66e6 * SAMPLE_RATIO)]

# Split into train and validation
train_val_split = 0.9
train_size = int(len(dataset) * train_val_split)

# Create train and validation datasets
train_dataset = dataset[:train_size]
val_dataset = dataset[train_size:train_size + min(2000, len(dataset) - train_size)]  # Limit validation set

print(f"Total sampled examples: {len(dataset)}")
print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")

# Free memory
del dataset
free_memory()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Define a processing function for batch tokenization
def preprocess_function(examples):
    batch_size = len(examples)
    input_texts = []
    target_texts = []
    
    for i in range(batch_size):
        context = examples[i]['text']
        question = examples[i]['question']
        answer = examples[i]['answer'].replace("Answer:", "").strip()
        
        # Format the input (context + question)
        input_text = f"Context: {context} Question: {question}"
        input_texts.append(input_text)
        target_texts.append(answer)
    
    # Tokenize inputs
    model_inputs = tokenizer(
        input_texts,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        truncation=True
    )
    
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            target_texts,
            max_length=MAX_TARGET_LENGTH,
            padding="max_length",
            truncation=True
        )
    
    model_inputs["labels"] = labels["input_ids"]
    
    # Replace padding token id with -100 so it's ignored in the loss
    for i in range(len(model_inputs["labels"])):
        model_inputs["labels"][i] = [
            -100 if token == tokenizer.pad_token_id else token 
            for token in model_inputs["labels"][i]
        ]
    
    return model_inputs

# Process data in batches to save memory
print("Processing training data...")
batch_size = 512
train_processed = []
for i in tqdm(range(0, len(train_dataset), batch_size)):
    batch = train_dataset[i:i+batch_size]
    processed_batch = preprocess_function(batch)
    for j in range(len(batch)):
        train_processed.append({
            "input_ids": processed_batch["input_ids"][j],
            "attention_mask": processed_batch["attention_mask"][j],
            "labels": processed_batch["labels"][j]
        })

print("Processing validation data...")
val_processed = []
for i in tqdm(range(0, len(val_dataset), batch_size)):
    batch = val_dataset[i:i+batch_size]
    processed_batch = preprocess_function(batch)
    for j in range(len(batch)):
        val_processed.append({
            "input_ids": processed_batch["input_ids"][j],
            "attention_mask": processed_batch["attention_mask"][j],
            "labels": processed_batch["labels"][j]
        })

# Free memory
del train_dataset, val_dataset
free_memory()

# Create PyTorch datasets
class MemoryEfficientDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.examples[idx]["input_ids"]),
            "attention_mask": torch.tensor(self.examples[idx]["attention_mask"]),
            "labels": torch.tensor(self.examples[idx]["labels"])
        }

train_dataset = MemoryEfficientDataset(train_processed)
val_dataset = MemoryEfficientDataset(val_processed)

# Free memory
del train_processed, val_processed
free_memory()

# Load the model with 8-bit precision to save memory
print("Loading model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./answer_generation_model",
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,  # Keep only the best model
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_dir="./logs",
    logging_steps=50,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.05,
    predict_with_generate=False,  # Save memory during evaluation
    fp16=MIXED_PRECISION == "fp16",  # Mixed precision training
    optim="adamw_torch",
    report_to="none",  # Disable W&B reporting
    disable_tqdm=False,  # Enable tqdm progress bar
)

# Set up the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True),
)

# Train the model
print("Training model...")
trainer.train()

# Save the model
print("Saving model...")
model.save_pretrained("./answer_generation_model_final")
tokenizer.save_pretrained("./answer_generation_model_final")

# Test on a few examples
print("Testing on some examples...")
model.eval()
test_examples = [
    val_dataset[i] for i in range(min(3, len(val_dataset)))
]

for example in test_examples:
    input_ids = example["input_ids"].unsqueeze(0).to(model.device)
    attention_mask = example["attention_mask"].unsqueeze(0).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=MAX_TARGET_LENGTH,
            num_beams=2,  # Reduced for memory efficiency
            early_stopping=True
        )
    
    input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    predicted_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"Input: {input_text}")
    print(f"Predicted Answer: {predicted_answer}")
    print("=" * 50)

print("Training complete!")

In [None]:
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
import torch
from torch.utils.data import Dataset
import random
import gc
from tqdm.auto import tqdm
import re

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.empty_cache()

# Define constants
MODEL_NAME = "google/flan-t5-small"
MAX_INPUT_LENGTH = 384
MAX_TARGET_LENGTH = 96
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 3e-4
NUM_EPOCHS = 3
SAMPLE_RATIO = 0.1
MIXED_PRECISION = "fp16"

# Function to free memory
def free_memory():
    gc.collect()
    torch.cuda.empty_cache()

# Extract actual answer content and distractors
def extract_answers_and_distractors(row):
    correct_answer = row['answer'].replace("Answer:", "").strip()
    choices = row['choices']
    
    # Identify correct answer letter
    correct_letter = None
    
    # If answer is just a letter (A, B, C, D)
    if re.match(r'^[A-D]$', correct_answer):
        correct_letter = correct_answer
    else:
        # If answer starts with a letter followed by a period/colon (A., A:)
        match = re.match(r'^([A-D])[:\.]?\s*', correct_answer)
        if match:
            correct_letter = match.group(1)
    
    # If we still don't have a letter, try to infer from content
    if not correct_letter:
        correct_content = correct_answer
        for choice in choices:
            parts = re.match(r'^([A-D])[.\s]+(.*)', choice)
            if parts:
                letter, content = parts.groups()
                # If the content matches (approximately) the correct answer
                if content.strip().lower() in correct_answer.lower() or correct_answer.lower() in content.strip().lower():
                    correct_letter = letter
                    correct_content = content.strip()
                    break
    
    # If we still don't have a letter, take a guess based on position
    if not correct_letter and "A" in correct_answer:
        correct_letter = "A"
    elif not correct_letter and "B" in correct_answer:
        correct_letter = "B"
    elif not correct_letter and "C" in correct_answer:
        correct_letter = "C"
    elif not correct_letter and "D" in correct_answer:
        correct_letter = "D"
    elif not correct_letter and len(choices) > 0:
        correct_letter = "A"  # Default to first option
    
    # Extract correct answer content and distractors
    correct_content = None
    distractors = []
    
    for choice in choices:
        parts = re.match(r'^([A-D])[.\s]+(.*)', choice)
        if parts:
            letter, content = parts.groups()
            content = content.strip()
            if letter == correct_letter:
                correct_content = content
            elif content:  # Only add non-empty distractors
                distractors.append(content)
    
    # If we didn't extract content from choices, use the original answer
    if not correct_content:
        if re.match(r'^[A-D][:\.]?\s+(.+)$', correct_answer):
            correct_content = re.match(r'^[A-D][:\.]?\s+(.+)$', correct_answer).group(1)
        else:
            correct_content = correct_answer
    
    return correct_content, distractors

# Load the dataset in streaming mode to reduce memory usage
print("Loading dataset...")
dataset = load_dataset("zjsd/RedStone-QA-mcq", split="train", streaming=True)

# Convert to regular dataset for sampling
dataset = dataset.take(int(1.66e6 * SAMPLE_RATIO * 1.2))
dataset = list(dataset)
random.shuffle(dataset)
dataset = dataset[:int(1.66e6 * SAMPLE_RATIO)]

# Process dataset to extract actual answers and distractors
print("Extracting answers and distractors...")
processed_dataset = []
for row in tqdm(dataset):
    actual_answer, distractors = extract_answers_and_distractors(row)
    if actual_answer and distractors:  # Only keep examples with valid answers and distractors
        row['actual_answer'] = actual_answer
        row['distractors_list'] = distractors
        processed_dataset.append(row)

# Free memory
del dataset
free_memory()

# Split into train and validation
train_val_split = 0.9
train_size = int(len(processed_dataset) * train_val_split)

# Create train and validation datasets
train_dataset = processed_dataset[:train_size]
val_dataset = processed_dataset[train_size:train_size + min(2000, len(processed_dataset) - train_size)]

print(f"Total processed examples: {len(processed_dataset)}")
print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")

# Free memory
del processed_dataset
free_memory()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Define a processing function for batch tokenization
def preprocess_function(examples):
    batch_size = len(examples)
    input_texts = []
    target_texts = []
    
    for i in range(batch_size):
        context = examples[i]['text']
        question = examples[i]['question']
        answer = examples[i]['actual_answer']
        
        # Format distractors as a single string with separators
        distractors = " | ".join(examples[i]['distractors_list'])
        
        # Format the input (context + question + correct answer)
        input_text = f"Context: {context} Question: {question} Correct Answer: {answer}"
        input_texts.append(input_text)
        target_texts.append(distractors)
    
    # Tokenize inputs
    model_inputs = tokenizer(
        input_texts,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        truncation=True
    )
    
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            target_texts,
            max_length=MAX_TARGET_LENGTH,
            padding="max_length",
            truncation=True
        )
    
    model_inputs["labels"] = labels["input_ids"]
    
    # Replace padding token id with -100 so it's ignored in the loss
    for i in range(len(model_inputs["labels"])):
        model_inputs["labels"][i] = [
            -100 if token == tokenizer.pad_token_id else token 
            for token in model_inputs["labels"][i]
        ]
    
    return model_inputs

# Process data in batches to save memory
print("Processing training data...")
batch_size = 512
train_processed = []
for i in tqdm(range(0, len(train_dataset), batch_size)):
    batch = train_dataset[i:i+batch_size]
    processed_batch = preprocess_function(batch)
    for j in range(len(batch)):
        train_processed.append({
            "input_ids": processed_batch["input_ids"][j],
            "attention_mask": processed_batch["attention_mask"][j],
            "labels": processed_batch["labels"][j]
        })

print("Processing validation data...")
val_processed = []
for i in tqdm(range(0, len(val_dataset), batch_size)):
    batch = val_dataset[i:i+batch_size]
    processed_batch = preprocess_function(batch)
    for j in range(len(batch)):
        val_processed.append({
            "input_ids": processed_batch["input_ids"][j],
            "attention_mask": processed_batch["attention_mask"][j],
            "labels": processed_batch["labels"][j]
        })

# Free memory
del train_dataset, val_dataset
free_memory()

# Create PyTorch datasets
class MemoryEfficientDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.examples[idx]["input_ids"]),
            "attention_mask": torch.tensor(self.examples[idx]["attention_mask"]),
            "labels": torch.tensor(self.examples[idx]["labels"])
        }

train_dataset = MemoryEfficientDataset(train_processed)
val_dataset = MemoryEfficientDataset(val_processed)

# Free memory
del train_processed, val_processed
free_memory()

# Load the model
print("Loading model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./distractor_generation_model",
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_dir="./logs",
    logging_steps=50,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.05,
    predict_with_generate=False,
    fp16=MIXED_PRECISION == "fp16",
    optim="adamw_torch",
    report_to="none",
    disable_tqdm=False,
)

# Set up the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True),
)

# Train the model
print("Training model...")
trainer.train()

# Save the model
print("Saving model...")
model.save_pretrained("./distractor_generation_model_final")
tokenizer.save_pretrained("./distractor_generation_model_final")

# Test on a few examples
print("Testing on some examples...")
model.eval()
test_examples = [
    val_dataset[i] for i in range(min(3, len(val_dataset)))
]

for example in test_examples:
    input_ids = example["input_ids"].unsqueeze(0).to(model.device)
    attention_mask = example["attention_mask"].unsqueeze(0).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=MAX_TARGET_LENGTH,
            num_beams=2,
            early_stopping=True
        )
    
    input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    predicted_distractors = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"Input: {input_text}")
    print(f"Predicted Distractors: {predicted_distractors}")
    print("=" * 50)

print("Training complete!")

In [None]:
# Add these imports at the top of your script
from huggingface_hub import login, HfApi


from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")

print("Logging in to Hugging Face Hub...")
login(token=HF_TOKEN)

# Define your Hugging Face repository name
# Format: "username/repository-name"
HF_REPO_ID = "aayeshanakarmi/distractor-generation-redstone-flant5small-2"  # Replace with your desired repo name

# Save the model to the Hub
print(f"Uploading model to Hugging Face Hub as {HF_REPO_ID}...")
model.push_to_hub(HF_REPO_ID, use_auth_token=HF_TOKEN)
tokenizer.push_to_hub(HF_REPO_ID, use_auth_token=HF_TOKEN)


print(f"Model successfully uploaded to Hugging Face Hub: https://huggingface.co/{HF_REPO_ID}")

# Answer Generation Model 

In [None]:
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
import torch
from torch.utils.data import Dataset
import random
import gc
from tqdm.auto import tqdm
import re

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.empty_cache()

# Define constants
MODEL_NAME = "google/flan-t5-small"
MAX_INPUT_LENGTH = 384
MAX_TARGET_LENGTH = 48
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 3e-4
NUM_EPOCHS = 3
SAMPLE_RATIO = 0.1
MIXED_PRECISION = "fp16"

# Function to free memory
def free_memory():
    gc.collect()
    torch.cuda.empty_cache()

# Enhanced function to extract actual answer content (not just A, B, C, D)
def extract_actual_answer(row):
    correct_answer = row['answer'].replace("Answer:", "").strip()
    choices = row['choices']
    
    # If answer is just a single letter (A, B, C, D)
    if re.match(r'^[A-D]$', correct_answer):
        letter = correct_answer
        for choice in choices:
            # Look for the choice that starts with this letter
            if choice.startswith(letter + ".") or choice.startswith(letter + " ") or choice.startswith(letter + ":"):
                # Extract everything after the letter and separator
                content = re.sub(r'^[A-D][.\s:]+', '', choice).strip()
                return content
    
    # If answer starts with letter followed by period/colon/space (e.g., "A. text", "A: text", "A text")
    match = re.match(r'^([A-D])[.\s:]+(.+)$', correct_answer)
    if match:
        # Extract the content part after the letter
        content = match.group(2).strip()
        if content:
            return content
        
        # If no content after the letter in the answer, find it in choices
        letter = match.group(1)
        for choice in choices:
            if choice.startswith(letter + ".") or choice.startswith(letter + " ") or choice.startswith(letter + ":"):
                content = re.sub(r'^[A-D][.\s:]+', '', choice).strip()
                return content
    
    # Handle case where the answer might be the full text that matches one of the choices
    for choice in choices:
        # Extract the content part of the choice (removing any leading A., B., etc.)
        choice_content = re.sub(r'^[A-D][.\s:]+', '', choice).strip()
        # If the answer matches this content exactly, return it
        if correct_answer == choice_content:
            return correct_answer
    
    # If we couldn't match it to a specific choice or extract a letter,
    # just return the original answer as a fallback
    return correct_answer

# Load the dataset in streaming mode to reduce memory usage
print("Loading dataset...")
dataset = load_dataset("zjsd/RedStone-QA-mcq", split="train", streaming=True)

# Convert to regular dataset for sampling
dataset = dataset.take(int(1.66e6 * SAMPLE_RATIO * 1.2))
dataset = list(dataset)
random.shuffle(dataset)
dataset = dataset[:int(1.66e6 * SAMPLE_RATIO)]

# Process dataset to extract actual answers
print("Extracting actual answers...")
processed_dataset = []
for row in tqdm(dataset):
    actual_answer = extract_actual_answer(row)
    if actual_answer.strip():  # Only keep examples with non-empty answers
        row['actual_answer'] = actual_answer
        processed_dataset.append(row)

# Free memory
del dataset
free_memory()

# Split into train and validation
train_val_split = 0.9
train_size = int(len(processed_dataset) * train_val_split)

# Create train and validation datasets
train_dataset = processed_dataset[:train_size]
val_dataset = processed_dataset[train_size:train_size + min(2000, len(processed_dataset) - train_size)]

print(f"Total examples with actual answers: {len(processed_dataset)}")
print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")

# Free memory
del processed_dataset
free_memory()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Define a processing function for batch tokenization
def preprocess_function(examples):
    batch_size = len(examples)
    input_texts = []
    target_texts = []
    
    for i in range(batch_size):
        context = examples[i]['text']
        question = examples[i]['question']
        answer = examples[i]['actual_answer']
        
        # Format the input (context + question)
        input_text = f"Context: {context} Question: {question}"
        input_texts.append(input_text)
        target_texts.append(answer)
    
    # Tokenize inputs
    model_inputs = tokenizer(
        input_texts,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        truncation=True
    )
    
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            target_texts,
            max_length=MAX_TARGET_LENGTH,
            padding="max_length",
            truncation=True
        )
    
    model_inputs["labels"] = labels["input_ids"]
    
    # Replace padding token id with -100 so it's ignored in the loss
    for i in range(len(model_inputs["labels"])):
        model_inputs["labels"][i] = [
            -100 if token == tokenizer.pad_token_id else token 
            for token in model_inputs["labels"][i]
        ]
    
    return model_inputs

# Process data in batches to save memory
print("Processing training data...")
batch_size = 512
train_processed = []
for i in tqdm(range(0, len(train_dataset), batch_size)):
    batch = train_dataset[i:i+batch_size]
    processed_batch = preprocess_function(batch)
    for j in range(len(batch)):
        train_processed.append({
            "input_ids": processed_batch["input_ids"][j],
            "attention_mask": processed_batch["attention_mask"][j],
            "labels": processed_batch["labels"][j]
        })

print("Processing validation data...")
val_processed = []
for i in tqdm(range(0, len(val_dataset), batch_size)):
    batch = val_dataset[i:i+batch_size]
    processed_batch = preprocess_function(batch)
    for j in range(len(batch)):
        val_processed.append({
            "input_ids": processed_batch["input_ids"][j],
            "attention_mask": processed_batch["attention_mask"][j],
            "labels": processed_batch["labels"][j]
        })

# Free memory
del train_dataset, val_dataset
free_memory()

# Create PyTorch datasets
class MemoryEfficientDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.examples[idx]["input_ids"]),
            "attention_mask": torch.tensor(self.examples[idx]["attention_mask"]),
            "labels": torch.tensor(self.examples[idx]["labels"])
        }

train_dataset = MemoryEfficientDataset(train_processed)
val_dataset = MemoryEfficientDataset(val_processed)

# Free memory
del train_processed, val_processed
free_memory()

# Load the model
print("Loading model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./answer_generation_model",
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_dir="./logs",
    logging_steps=50,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.05,
    predict_with_generate=False,
    fp16=MIXED_PRECISION == "fp16",
    optim="adamw_torch",
    report_to="none",
    disable_tqdm=False,
)

# Set up the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True),
)

# Train the model
print("Training model...")
trainer.train()

# Save the model
print("Saving model...")
model.save_pretrained("./answer_generation_model_final")
tokenizer.save_pretrained("./answer_generation_model_final")

# Test on a few examples
print("Testing on some examples...")
model.eval()
test_examples = [
    val_dataset[i] for i in range(min(3, len(val_dataset)))
]

for example in test_examples:
    input_ids = example["input_ids"].unsqueeze(0).to(model.device)
    attention_mask = example["attention_mask"].unsqueeze(0).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=MAX_TARGET_LENGTH,
            num_beams=2,
            early_stopping=True
        )
    
    input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    predicted_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"Input: {input_text}")
    print(f"Predicted Answer: {predicted_answer}")
    print("=" * 50)

print("Training complete!")

In [None]:
# Add these imports at the top of your script
from huggingface_hub import login, HfApi


from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")

print("Logging in to Hugging Face Hub...")
login(token=HF_TOKEN)

# Define your Hugging Face repository name
# Format: "username/repository-name"
HF_REPO_ID = "aayeshanakarmi/mcq-answer-generation-redstone-flant5small-2"  # Replace with your desired repo name

# Save the model to the Hub
print(f"Uploading model to Hugging Face Hub as {HF_REPO_ID}...")
model.push_to_hub(HF_REPO_ID, use_auth_token=HF_TOKEN)
tokenizer.push_to_hub(HF_REPO_ID, use_auth_token=HF_TOKEN)


print(f"Model successfully uploaded to Hugging Face Hub: https://huggingface.co/{HF_REPO_ID}")

# Inferencing

In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import random

class MCQGenerator:
    def __init__(self):
        # Load the answer generation model
        print("Loading answer generation model...")
        self.answer_model = AutoModelForSeq2SeqLM.from_pretrained("aayeshanakarmi/mcq-answer-generation-redstone-flant5small-2")
        self.answer_tokenizer = AutoTokenizer.from_pretrained("aayeshanakarmi/mcq-answer-generation-redstone-flant5small-2")
        
        # Load the distractor generation model
        print("Loading distractor generation model...")
        self.distractor_model = AutoModelForSeq2SeqLM.from_pretrained("aayeshanakarmi/distractor-generation-redstone-flant5small-2")
        self.distractor_tokenizer = AutoTokenizer.from_pretrained("aayeshanakarmi/distractor-generation-redstone-flant5small-2")
        
        # Set models to evaluation mode
        self.answer_model.eval()
        self.distractor_model.eval()
        
        # Move models to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.answer_model.to(self.device)
        self.distractor_model.to(self.device)
        
        print(f"Models loaded and running on {self.device}")
    
    def generate_answer(self, context, question, max_length=48):
        """Generate the correct answer based on context and question"""
        # Format input as expected by the answer model
        input_text = f"Context: {context} Question: {question}"
        
        # Tokenize input
        input_ids = self.answer_tokenizer(input_text, return_tensors="pt", max_length=384, 
                                          truncation=True).input_ids.to(self.device)
        
        # Generate answer
        with torch.no_grad():
            output_ids = self.answer_model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        
        # Decode the output
        answer = self.answer_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return answer.strip()
    
    def generate_distractors(self, context, question, answer, max_length=96, num_distractors=3):
        """Generate distractors based on context, question and correct answer"""
        # Format input as expected by the distractor model
        input_text = f"Context: {context} Question: {question} Correct Answer: {answer}"
        
        # Tokenize input
        input_ids = self.distractor_tokenizer(input_text, return_tensors="pt", max_length=384, 
                                             truncation=True).input_ids.to(self.device)
        
        # Generate distractors
        with torch.no_grad():
            output_ids = self.distractor_model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=4,
                do_sample=True,
                temperature=0.8,
                early_stopping=True
            )
        
        # Decode the output
        distractors_text = self.distractor_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        
        # Split the distractors - they were joined with | during training
        distractors = [d.strip() for d in distractors_text.split("|") if d.strip()]
        
        # Ensure we have unique distractors that are different from the answer
        unique_distractors = []
        for d in distractors:
            if d.lower() != answer.lower() and d not in unique_distractors:
                unique_distractors.append(d)
                if len(unique_distractors) >= num_distractors:
                    break
        
        # If we don't have enough distractors, generate some basic alternatives
        while len(unique_distractors) < num_distractors:
            fallback = f"Alternative answer {len(unique_distractors) + 1}"
            if fallback not in unique_distractors and fallback.lower() != answer.lower():
                unique_distractors.append(fallback)
        
        return unique_distractors
    
    def generate_mcq(self, context, question, shuffle=True):
        """Generate a complete MCQ with context, question, and options"""
        # Generate the correct answer
        correct_answer = self.generate_answer(context, question)
        
        # Generate distractors
        distractors = self.generate_distractors(context, question, correct_answer)
        
        # Create options (correct answer + distractors)
        options = [correct_answer] + distractors
        
        # Shuffle options if requested
        correct_idx = 0
        if shuffle:
            correct_idx = random.randint(0, len(options) - 1)
            shuffled_options = options.copy()
            # Move correct answer to the randomly chosen position
            shuffled_options[0], shuffled_options[correct_idx] = shuffled_options[correct_idx], shuffled_options[0]
            options = shuffled_options
        
        # Format as MCQ
        option_labels = ["A", "B", "C", "D"][:len(options)]
        formatted_options = [f"{label}. {option}" for label, option in zip(option_labels, options)]
        
        # Identify the correct answer label
        correct_label = option_labels[correct_idx]
        
        # Create the complete MCQ
        mcq = {
            "context": context,
            "question": question,
            "options": formatted_options,
            "correct_answer": f"{correct_label}. {options[correct_idx]}",
            "correct_label": correct_label
        }
        
        return mcq
    
    def format_mcq_as_text(self, mcq):
        """Format the MCQ as a readable text"""
        text = f"Context:\n{mcq['context']}\n\n"
        text += f"Question:\n{mcq['question']}\n\n"
        text += "Options:\n"
        for option in mcq['options']:
            text += f"{option}\n"
        text += f"\nCorrect Answer: {mcq['correct_label']}"
        return text


# Example usage
def generate_example_mcq():
    # Example context and question
    context = """The water cycle, also known as the hydrologic cycle, describes the continuous movement of water on, above, and below the surface of the Earth. Water can change states among liquid, vapor, and ice at various places in the water cycle. Although the balance of water on Earth remains fairly constant over time, individual water molecules can come and go. The water moves from one reservoir to another, such as from river to ocean, or from the ocean to the atmosphere, by the physical processes of evaporation, condensation, precipitation, infiltration, surface runoff, and subsurface flow. In doing so, the water goes through different forms: liquid, solid (ice) and vapor."""
    
    question = "What causes water to move from the ocean to the atmosphere in the water cycle?"
    
    # Initialize the MCQ generator
    generator = MCQGenerator()
    
    # Generate MCQ
    mcq = generator.generate_mcq(context, question)
    
    # Format as text
    formatted_mcq = generator.format_mcq_as_text(mcq)
    
    return formatted_mcq

# Run the example if the script is executed directly
if __name__ == "__main__":
    print(generate_example_mcq())

Loading answer generation model...


config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Loading distractor generation model...


config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Models loaded and running on cpu
Context:
The water cycle, also known as the hydrologic cycle, describes the continuous movement of water on, above, and below the surface of the Earth. Water can change states among liquid, vapor, and ice at various places in the water cycle. Although the balance of water on Earth remains fairly constant over time, individual water molecules can come and go. The water moves from one reservoir to another, such as from river to ocean, or from the ocean to the atmosphere, by the physical processes of evaporation, condensation, precipitation, infiltration, surface runoff, and subsurface flow. In doing so, the water goes through different forms: liquid, solid (ice) and vapor.

Question:
What causes water to move from the ocean to the atmosphere in the water cycle?

Options:
A. Alternative answer 3
B. In doing so, the water goes through different forms: liquid, solid (ice) and vapor.
C. Alternative answer 2
D. evaporation, condensation, precipitation, infiltr

In [2]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import random

class MCQGenerator:
    def __init__(self):
        # Load answer generation model
        print("Loading answer generation model...")
        self.answer_model_id = "aayeshanakarmi/mcq-answer-generation-redstone-flant5small-2"
        self.answer_model = AutoModelForSeq2SeqLM.from_pretrained(self.answer_model_id)
        self.answer_tokenizer = AutoTokenizer.from_pretrained(self.answer_model_id)
        
        # Load distractor generation model
        print("Loading distractor generation model...")
        self.distractor_model_id = "aayeshanakarmi/distractor-generation-redstone-flant5small-2"
        self.distractor_model = AutoModelForSeq2SeqLM.from_pretrained(self.distractor_model_id)
        self.distractor_tokenizer = AutoTokenizer.from_pretrained(self.distractor_model_id)
        
        # Move models to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.answer_model.to(self.device)
        self.distractor_model.to(self.device)
        
        # Set models to evaluation mode
        self.answer_model.eval()
        self.distractor_model.eval()
        
        print(f"Models loaded and ready on {self.device}")
    
    def generate_answer(self, context, question):
        """Generate the correct answer using the answer generation model"""
        input_text = f"Context: {context} Question: {question}"
        
        # Tokenize input for answer model
        inputs = self.answer_tokenizer(
            input_text, 
            return_tensors="pt", 
            max_length=384, 
            truncation=True
        ).to(self.device)
        
        # Generate answer
        with torch.no_grad():
            outputs = self.answer_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=48,
                num_beams=3,
                early_stopping=True
            )
        
        # Decode the answer
        answer = self.answer_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer.strip()
    
    def generate_distractors(self, context, question, answer):
        """Generate distractors using the distractor generation model"""
        input_text = f"Context: {context} Question: {question} Correct Answer: {answer}"
        
        # Tokenize input for distractor model
        inputs = self.distractor_tokenizer(
            input_text, 
            return_tensors="pt", 
            max_length=384, 
            truncation=True
        ).to(self.device)
        
        # Generate distractors
        with torch.no_grad():
            outputs = self.distractor_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=96,
                num_beams=3,
                early_stopping=True
            )
        
        # Decode the distractors
        distractors_text = self.distractor_tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Split distractors by the separator used during training
        distractors = [d.strip() for d in distractors_text.split("|")]
        
        # Filter out empty distractors and remove duplicates
        distractors = [d for d in distractors if d and d.lower() != answer.lower()]
        
        # Return up to 3 distractors (to make 4 total options with the correct answer)
        return distractors[:3]
    
    def format_mcq(self, question, answer, distractors):
        """Format the MCQ with options A, B, C, D"""
        # Combine answer and distractors
        all_options = [answer] + distractors
        
        # Shuffle options to randomize the correct answer position
        random.shuffle(all_options)
        
        # Find the position of the correct answer
        correct_option_idx = all_options.index(answer)
        correct_option_letter = chr(65 + correct_option_idx)  # A, B, C, or D
        
        # Format the question and options
        formatted_question = question + "\n"
        for i, option in enumerate(all_options):
            option_letter = chr(65 + i)  # A, B, C, or D
            formatted_question += f"{option_letter}. {option}\n"
        
        return {
            "question": question,
            "options": all_options,
            "formatted_mcq": formatted_question,
            "correct_answer": answer,
            "correct_option": correct_option_letter
        }
    
    def generate_mcq(self, context, question):
        """Generate a complete MCQ from context and question"""
        # Generate the correct answer
        answer = self.generate_answer(context, question)
        print(f"Generated answer: {answer}")
        
        # Generate distractors
        distractors = self.generate_distractors(context, question, answer)
        print(f"Generated distractors: {distractors}")
        
        # Format as MCQ
        mcq = self.format_mcq(question, answer, distractors)
        
        return mcq


# Example usage
if __name__ == "__main__":
    # Create MCQ generator
    generator = MCQGenerator()
    
    # Example context and question
    context = """Neural networks are a set of algorithms, modeled loosely after the human brain, that are designed to recognize patterns. They interpret sensory data through a kind of machine perception, labeling or clustering raw input. The patterns they recognize are numerical, contained in vectors, into which all real-world data, be it images, sound, text or time series, must be translated. Neural networks help us cluster and classify data. You can think of them as a clustering and classification layer on top of the data you store and manage. They help to group unlabeled data according to similarities among the example inputs, and they classify data when they have a labeled dataset to train on."""
    
    question = "What is the primary function of neural networks?"
    
    # Generate MCQ
    mcq = generator.generate_mcq(context, question)
    
    # Print the results
    print("\n--- Generated MCQ ---")
    print(mcq["formatted_mcq"])
    print(f"Correct answer: {mcq['correct_option']} - {mcq['correct_answer']}")
    
    # Another example
    print("\n--- Another Example ---")
    context2 = """The water cycle, also known as the hydrologic cycle, describes the continuous movement of water on, above, and below the surface of the Earth. Water can change states among liquid, vapor, and ice at various points in the cycle. Although the balance of water on Earth remains fairly constant over time, individual water molecules can move around the globe. The water cycle involves the following processes: evaporation, transpiration, condensation, precipitation, and collection."""
    
    question2 = "What happens during the condensation phase of the water cycle?"
    
    mcq2 = generator.generate_mcq(context2, question2)
    
    print(mcq2["formatted_mcq"])
    print(f"Correct answer: {mcq2['correct_option']} - {mcq2['correct_answer']}")

Loading answer generation model...
Loading distractor generation model...
Models loaded and ready on cpu
Generated answer: Neural networks help us cluster and classify data. You can think of them as a clustering and classification layer on top of the data you store and manage.
Generated distractors: ['Neural networks help us cluster and classify data according to similarities among the example inputs, and they classify data when they have a labeled dataset to train on.']

--- Generated MCQ ---
What is the primary function of neural networks?
A. Neural networks help us cluster and classify data according to similarities among the example inputs, and they classify data when they have a labeled dataset to train on.
B. Neural networks help us cluster and classify data. You can think of them as a clustering and classification layer on top of the data you store and manage.

Correct answer: B - Neural networks help us cluster and classify data. You can think of them as a clustering and classi

In [4]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import os
# from dotenv import load_dotenv

# # Load environment variables from .env file (if needed)
# load_dotenv()

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")


class SimpleMCQGenerator:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = "aayeshanakarmi/Flan-T5-Small-Test1MCQ-Quizard-5"
        self.auth_token = HF_TOKEN
        
        print(f"Using device: {self.device}")
        
        # Initialize tokenizer and model
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name, use_auth_token=self.auth_token)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name, use_auth_token=self.auth_token).to(self.device)

    def generate_mcq(self, question, context, max_length=128):
        """
        Generate a multiple-choice question based on the provided question and context.
        
        Args:
            question (str): The question to generate MCQs for
            context (str): The context/passage related to the question
            max_length (int): Maximum length of the generated output
            
        Returns:
            dict: Dictionary containing the correct answer and distractors
        """
        input_text = f"Generate a multiple-choice question (MCQ) based on the given question and context. Ensure the output includes one correct answer and three unique incorrect answers (distractors). Question: {question} Context: {context}"
        
        inputs = self.tokenizer(
            input_text, 
            return_tensors="pt", 
            max_length=512, 
            truncation=True
        ).to(self.device)
        
        outputs = self.model.generate(
            inputs["input_ids"], 
            max_length=max_length, 
            num_beams=5, 
            early_stopping=True
        )
        
        raw_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        try:
            parts = raw_output.split('|')
            correct_answer = parts[0].replace('Correct Answer:', '').strip()
            distractors = [opt.strip() for opt in parts[1].replace('Incorrect Answers:', '').split(',')]
            
            # Ensure we have exactly 3 unique distractors
            unique_distractors = list(set(distractors))[:3]
            while len(unique_distractors) < 3:
                unique_distractors.append(f"Option {len(unique_distractors) + 1}")
            
            return {
                'question': question,
                'context': context,
                'correct_answer': correct_answer,
                'distractors': unique_distractors,
                'raw_output': raw_output  # Include raw output for debugging
            }
        except Exception as e:
            print(f"Error parsing MCQ output: {e}")
            return {
                'question': question,
                'context': context,
                'correct_answer': raw_output,
                'distractors': [f"Option {i+1}" for i in range(3)],
                'raw_output': raw_output,
                'error': str(e)
            }

# Example usage in a Jupyter notebook cell:

# Create the generator
mcq_gen = SimpleMCQGenerator()

# Example context and question
context = "The mitochondria is the powerhouse of the cell. It produces energy in the form of ATP through cellular respiration."
question = "What is the function of mitochondria in a cell?"

# Generate MCQ
mcq_result = mcq_gen.generate_mcq(question, context)

# Display the result
print(f"Question: {question}")
print(f"Correct Answer: {mcq_result['correct_answer']}")
print("Distractors:")
for i, distractor in enumerate(mcq_result['distractors']):
    print(f"  {i+1}. {distractor}")


Using device: cpu




tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.59k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

Question: What is the function of mitochondria in a cell?
Correct Answer: Producing energy through cellular respiration
Distractors:
  1. To store energy
  2. Option 2
  3. Option 3


In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-race-Distractor")
model = AutoModelForSeq2SeqLM.from_pretrained("potsawee/t5-large-generation-race-Distractor")

context = r"""
World number one Novak Djokovic says he is hoping for a "positive decision" to allow him 
to play at Indian Wells and the Miami Open next month. The United States has extended 
its requirement for international visitors to be vaccinated against Covid-19. Proof of vaccination 
will be required to enter the country until at least 10 April, but the Serbian has previously 
said he is unvaccinated. The 35-year-old has applied for special permission to enter the country. 
Indian Wells and the Miami Open - two of the most prestigious tournaments on the tennis calendar 
outside the Grand Slams - start on 6 and 20 March respectively. Djokovic says he will return to 
the ATP tour in Dubai next week after claiming a record-extending 10th Australian Open title 
and a record-equalling 22nd Grand Slam men's title last month.""".replace("\n", "")
question = "What is the best title for the passage?"
answer = "Djokovic's application for special permission to enter the United States"

input_text = " ".join([question, tokenizer.sep_token, answer, tokenizer.sep_token, context])
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
distractors = tokenizer.decode(outputs[0], skip_special_tokens=False)
distractors = distractors.replace(tokenizer.pad_token, "").replace(tokenizer.eos_token, "")
distractors = [y.strip() for y in distractors.split(tokenizer.sep_token)]
print(distractors)

tokenizer_config.json:   0%|          | 0.00/2.35k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.23k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

['The United States has extended its requirement for international visitors to be vaccinated against Covid-19', "Djokovic's return to the ATP tour in Dubai", "Djokovic's hope for a positive decision to allow him to play at Indian Wells and the Miami Open"]


In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("potsawee/t5-large-generation-race-Distractor")
model = AutoModelForSeq2SeqLM.from_pretrained("potsawee/t5-large-generation-race-Distractor")

# Define multiple context, question, and answer tuples
data_samples = [
    (
        """World number one Novak Djokovic says he is hoping for a "positive decision" to allow him 
        to play at Indian Wells and the Miami Open next month. The United States has extended 
        its requirement for international visitors to be vaccinated against Covid-19. Proof of vaccination 
        will be required to enter the country until at least 10 April, but the Serbian has previously 
        said he is unvaccinated. The 35-year-old has applied for special permission to enter the country. 
        Indian Wells and the Miami Open - two of the most prestigious tournaments on the tennis calendar 
        outside the Grand Slams - start on 6 and 20 March respectively. Djokovic says he will return to 
        the ATP tour in Dubai next week after claiming a record-extending 10th Australian Open title 
        and a record-equalling 22nd Grand Slam men's title last month.""".replace("\n", ""),
        "What is the best title for the passage?",
        "Djokovic's application for special permission to enter the United States"
    ),
    (
        """NASA has announced plans to launch a new space telescope that will search for habitable 
        exoplanets. The telescope, named the Habitable Worlds Observatory, will focus on detecting 
        Earth-like planets in the habitable zones of their stars. Scientists hope this mission will 
        provide valuable insights into the possibility of life beyond our solar system. The telescope 
        is expected to launch in the early 2040s and will build upon the findings of previous missions 
        like the James Webb Space Telescope.""".replace("\n", ""),
        "What is NASA's new telescope called?",
        "Habitable Worlds Observatory"
    )
]

# Process each data sample
for context, question, answer in data_samples:
    input_text = " ".join([question, tokenizer.sep_token, answer, tokenizer.sep_token, context])
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=128)
    distractors = tokenizer.decode(outputs[0], skip_special_tokens=False)
    
    # Clean up distractors
    distractors = distractors.replace(tokenizer.pad_token, "").replace(tokenizer.eos_token, "")
    distractors = [y.strip() for y in distractors.split(tokenizer.sep_token)]

    # Print results
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print(f"Distractors: {distractors}\n")


Question: What is the best title for the passage?
Answer: Djokovic's application for special permission to enter the United States
Distractors: ['The United States has extended its requirement for international visitors to be vaccinated against Covid-19', "Djokovic's return to the ATP tour in Dubai", "Djokovic's hope for a positive decision to allow him to play at Indian Wells and the Miami Open"]

Question: What is NASA's new telescope called?
Answer: Habitable Worlds Observatory
Distractors: ['Habitable Planets Observatory', 'Habitable Planets', 'Habitable Planets']



In [6]:
import os
import sys
import random
import time
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.notebook import tqdm

# === Configuration ===
# Model parameters
model_name = "google/flan-t5-base"
save_dir = "./model_weights/"
run_name = f"flan-t5-base-race-distractor-generation"

# Training parameters
learning_rate = 3e-5
batch_size = 4
num_workers = 2
num_epochs = 5
max_length = 512
valid_steps = 2000
warmup_steps = 500
weight_decay = 0.01
gradient_accumulation_steps = 4  # For effective batch size of 16

# Create save directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)

# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Check for GPU availability
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {torch_device}")

# Load tokenizer and add special tokens
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_length)
if tokenizer.sep_token is None:
    tokenizer.add_special_tokens({"sep_token": "<sep>"})
    print(f"Added special token: {tokenizer.sep_token}")
else:
    print(f"Using existing sep_token: {tokenizer.sep_token}")


# === Dataset Class ===
class RaceDistractorGeneration(Dataset):
    def __init__(self, tokenizer, data_split, shuffle_distractors=False):
        """
        Dataset class for distractor generation using RACE dataset
        - input: question <sep> answer <sep> article
        - output: distractor1 <sep> distractor2 <sep> distractor3
        """
        # Load RACE dataset
        self.data = load_dataset("race", "all", split=data_split)
        self.tokenizer = tokenizer
        self.shuffle_distractors = shuffle_distractors
        
        # Create mapping for answer options
        self.label_mapping = {label: i for i, label in enumerate(["A", "B", "C", "D"])}
        self.all_labels = [0, 1, 2, 3]
        
        print(f"RACE Distractor Generation Dataset Initialized - {data_split} split: {len(self.data)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            example = self.data[idx]
            
            # Extract data from example
            question = example["question"]
            context = example["article"]
            options = example["options"]
            
            # Find correct answer
            label_example = example["answer"]
            answer_i = self.label_mapping[label_example]
            answer = options[answer_i]
            
            # Get distractors (incorrect options)
            distractor_ids = [i for i in self.all_labels if i != answer_i]
            if self.shuffle_distractors:
                random.shuffle(distractor_ids)
            distractors = [options[i] for i in distractor_ids]
            
            # Format input and output
            input_text = f"{question} {tokenizer.sep_token} {answer} {tokenizer.sep_token} {context}"
            output_text = f"{distractors[0]} {tokenizer.sep_token} {distractors[1]} {tokenizer.sep_token} {distractors[2]}"
            
            return {'input': input_text, 'output': output_text}
        except Exception as e:
            print(f"Error processing example {idx}: {e}")
            # Return empty strings as fallback
            return {'input': '', 'output': ''}


# === Collate Function ===
def collate_fn(batch):
    """
    Collate function for DataLoader
    """
    # Filter out None values
    batch = [item for item in batch if item['input'] != '' and item['output'] != '']
    
    if len(batch) == 0:
        return None
    
    # Extract input and output sequences
    input_sequences = [item['input'] for item in batch]
    output_sequences = [item['output'] for item in batch]
    
    # Tokenize inputs
    input_encoding = tokenizer(
        input_sequences,
        padding="longest",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    
    # Tokenize outputs
    target_encoding = tokenizer(
        output_sequences,
        padding="longest",
        max_length=max_length,
        truncation=True,
    )
    
    # Convert to tensor and replace padding token ids with -100 (ignored by loss)
    labels = torch.tensor(target_encoding.input_ids)
    labels[labels == tokenizer.pad_token_id] = -100
    
    return {
        'input_ids': input_encoding.input_ids,
        'attention_mask': input_encoding.attention_mask,
        'labels': labels,
    }


# === Training Function ===
def train():
    # Create datasets
    train_data = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="train",
        shuffle_distractors=True,
    )
    
    valid_data = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="validation",
        shuffle_distractors=False,
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=collate_fn,
    )
    
    valid_loader = DataLoader(
        valid_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    # Load model
    print(f"Loading {model_name}...")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    # Resize token embeddings if we added new tokens
    model.resize_token_embeddings(len(tokenizer))
    
    # Move model to device
    if torch_device == "cuda":
        model.to(torch_device)
    
    # Print model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Setup optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    
    # Calculate total training steps
    total_steps = len(train_loader) * num_epochs // gradient_accumulation_steps
    
    # Setup learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # Training loop
    print("Starting training...")
    model.train()
    best_val_loss = float('inf')
    early_stop_counter = 0
    training_step = 0
    
    start_time = time.time()
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        steps_in_epoch = 0
        
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        progress_bar = tqdm(train_loader, desc=f"Training")
        
        for batch in progress_bar:
            if batch is None:
                continue
            
            # Move batch to device
            input_ids = batch['input_ids'].to(torch_device)
            attention_mask = batch['attention_mask'].to(torch_device)
            labels = batch['labels'].to(torch_device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Calculate loss
            loss = outputs.loss / gradient_accumulation_steps
            epoch_loss += loss.item() * gradient_accumulation_steps
            steps_in_epoch += 1
            
            # Backward pass
            loss.backward()
            
            # Update weights every gradient_accumulation_steps
            if (training_step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{loss.item() * gradient_accumulation_steps:.4f}",
                'step': training_step
            })
            
            training_step += 1
            
            # Validation
            if training_step % valid_steps == 0:
                # Save checkpoint
                checkpoint_path = f"{save_dir}/{run_name}-step{training_step}.pt"
                torch.save({
                    'epoch': epoch,
                    'step': training_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, checkpoint_path)
                print(f"\nSaved checkpoint to {checkpoint_path}")
                
                # Validate
                model.eval()
                val_loss = validate(model, valid_loader)
                print(f"Validation Loss: {val_loss:.4f}")
                model.train()
                
                # Early stopping
                if val_loss < best_val_loss:
                    early_stop_counter = 0
                    best_val_loss = val_loss
                    best_model_path = f"{save_dir}/{run_name}-best.pt"
                    torch.save({
                        'epoch': epoch,
                        'step': training_step,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'validation_loss': val_loss,
                    }, best_model_path)
                    print(f"New best model saved at {best_model_path}")
                else:
                    early_stop_counter += 1
                    print(f"Validation loss did not improve. Early stopping counter: {early_stop_counter}/3")
                    if early_stop_counter >= 3:
                        print("Early stopping triggered.")
                        return
        
        # Epoch completed
        avg_epoch_loss = epoch_loss / steps_in_epoch
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_epoch_loss:.4f}")
    
    # Training completed
    final_model_path = f"{save_dir}/{run_name}-final.pt"
    torch.save({
        'epoch': num_epochs,
        'step': training_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, final_model_path)
    print(f"Training completed. Final model saved at {final_model_path}")
    
    total_time = time.time() - start_time
    print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")


# === Validation Function ===
def validate(model, dataloader):
    """
    Validate the model on the validation dataset
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating")
        for batch in progress_bar:
            if batch is None:
                continue
            
            # Move batch to device
            input_ids = batch['input_ids'].to(torch_device)
            attention_mask = batch['attention_mask'].to(torch_device)
            labels = batch['labels'].to(torch_device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Calculate loss
            loss = outputs.loss.item()
            total_loss += loss
            num_batches += 1
            
            # Update progress bar
            progress_bar.set_postfix({'loss': f"{loss:.4f}"})
    
    # Calculate average loss
    avg_loss = total_loss / max(num_batches, 1)
    return avg_loss


# === Generation Function ===
def generate_distractors(model, text, num_return_sequences=1):
    """
    Generate distractors for a given question, answer, and context
    """
    model.eval()
    
    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True
    )
    
    # Move to device
    inputs = {k: v.to(torch_device) for k, v in inputs.items()}
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=4,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            early_stopping=True
        )
    
    # Decode outputs
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    
    return generated_texts


# === Main Function ===
def main():
    """
    Main function to run the training process
    """
    print("Starting finetuning process...")
    print(f"Model: {model_name}")
    print(f"Device: {torch_device}")
    print(f"Learning rate: {learning_rate}")
    print(f"Batch size: {batch_size} (effective: {batch_size * gradient_accumulation_steps})")
    print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"Number of epochs: {num_epochs}")
    print(f"Maximum sequence length: {max_length}")
    print(f"Validation steps: {valid_steps}")
    print(f"Warmup steps: {warmup_steps}")
    
    # Start training
    train()
    
    print("Training complete!")


# === Testing Function ===
def test_model(model_path):
    """
    Test the trained model on custom examples
    """
    # Load model
    model = AutoModelForSeq2SeqLM

Using device: cuda
Added special token: <sep>


In [7]:
# Kaggle Notebook Setup
# FLAN-T5-Base Finetuning on RACE Dataset for Distractor Generation

# Check if GPU is available
!nvidia-smi

# Install required packages (if not already installed)
!pip install -q transformers datasets tqdm

# Set up wandb for experiment tracking (optional)
# !pip install wandb
# import wandb
# wandb.login()

# Import necessary libraries
import os
import sys
import random
import time
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.notebook import tqdm

# Set up the output directory
os.makedirs("./model_weights", exist_ok=True)

# Set random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Now copy and paste the code from the previous artifacts here

# After pasting all the code, run this to execute it
if __name__ == "__main__":
    print("="*50)
    print("FLAN-T5-Base Finetuning on RACE Dataset")
    print("="*50)
    
    print("\n=== FLAN-T5 Distractor Generation Tools ===")
    print("1. Train the model")
    print("2. Evaluate on test set")
    print("3. Interactive distractor generation")
    print("4. Exit")
    

    main()
    evaluate_on_test_set()
    interactive_distractor_generation()

Fri Mar  7 03:54:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   76C    P0             35W /   70W |    1183MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00

Training:   0%|          | 0/21967 [00:00<?, ?it/s]


Saved checkpoint to ./model_weights//flan-t5-base-race-distractor-generation-step2000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.9140
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt

Saved checkpoint to ./model_weights//flan-t5-base-race-distractor-generation-step4000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.8089
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt

Saved checkpoint to ./model_weights//flan-t5-base-race-distractor-generation-step6000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7825
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt

Saved checkpoint to ./model_weights//flan-t5-base-race-distractor-generation-step8000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7585
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt

Saved checkpoint to ./model_weights//flan-t5-base-race-distractor-generation-step10000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7456
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt


RuntimeError: [enforce fail at inline_container.cc:603] . unexpected pos 1060654976 vs 1060654864

# Optimized training

In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
from tqdm.auto import tqdm
import gc
import logging
from datetime import datetime

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    # Model settings
    model_name = "google/flan-t5-base"
    max_length = 512
    
    # Training settings
    batch_size = 4
    gradient_accumulation_steps = 4  # Effectively gives batch size of 16
    learning_rate = 2e-5
    weight_decay = 0.01
    num_epochs = 3
    warmup_ratio = 0.1
    
    # Mixed precision
    use_mixed_precision = True
    
    # Checkpointing settings
    save_steps = 500
    eval_steps = 500
    
    # Output directories
    model_dir = "./model_weights"
    run_name = f"flan-t5-base-race-distractor-{datetime.now().strftime('%Y%m%d_%H%M')}"
    
    # System settings
    seed = 42
    num_workers = 2
    
config = Config()

# Create model directory if it doesn't exist
os.makedirs(config.model_dir, exist_ok=True)

# Set seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(config.seed)

# Dataset class
class RaceDistractorGeneration(Dataset):
    def __init__(self, tokenizer, data_split, shuffle_distractors=False):
        """
        Task: Generate distractors based on the question, correct answer, and article
        
        Args:
            tokenizer: Tokenizer for encoding inputs and outputs
            data_split: 'train', 'validation', or 'test'
            shuffle_distractors: Whether to shuffle the order of distractors during training
        """
        logger.info(f"Loading RACE dataset for {data_split} split")
        data = load_dataset("race", "all", split=data_split)
        self.data = data
        self.tokenizer = tokenizer
        self.separator = " <sep> "
        self.label_mapping = {label: i for i, label in enumerate(["A", "B", "C", "D"])}
        self.all_labels = [0, 1, 2, 3]
        self.shuffle_distractors = shuffle_distractors
        logger.info(f"RaceDistractorGeneration loaded with {len(self.data)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            example = self.data[idx]
            question = example["question"]
            context = example["article"]
            options = example["options"]
            label_example = example["answer"]
            answer_i = self.label_mapping[label_example]
            answer = options[answer_i]
            
            # Get distractor indices (all options except the correct answer)
            distractor_ids = [i for i in self.all_labels if i != answer_i]
            if self.shuffle_distractors:
                random.shuffle(distractor_ids)
            
            # Get the actual distractors
            distractors = [options[i] for i in distractor_ids]
            
            # Prepare input and output
            input_text = f"Generate distractors: Question: {question}{self.separator}Answer: {answer}{self.separator}Context: {context}"
            output_text = f"{distractors[0]}{self.separator}{distractors[1]}{self.separator}{distractors[2]}"
            
            return {
                "input": input_text,
                "output": output_text
            }
        except Exception as e:
            logger.warning(f"Error processing item {idx}: {e}")
            # Return a simple example as fallback
            return {
                "input": "Error processing item",
                "output": "Error"
            }

# Collate function for data loader
def collate_fn(batch):
    input_texts = [item["input"] for item in batch]
    output_texts = [item["output"] for item in batch]
    
    # Tokenize inputs
    input_encodings = tokenizer(
        input_texts,
        padding="longest",
        max_length=config.max_length,
        truncation=True,
        return_tensors="pt",
    )
    
    # Tokenize outputs
    output_encodings = tokenizer(
        output_texts,
        padding="longest",
        max_length=config.max_length,
        truncation=True,
        return_tensors="pt",
    )
    
    # Replace padding token id's with -100 so they're ignored in loss calculation
    labels = output_encodings.input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100
    
    return {
        "input_ids": input_encodings.input_ids,
        "attention_mask": input_encodings.attention_mask,
        "labels": labels,
    }

# Smart model checkpoint saver to handle disk space issues
class CheckpointSaver:
    def __init__(self, save_dir, run_name, max_checkpoints=3):
        self.save_dir = save_dir
        self.run_name = run_name
        self.max_checkpoints = max_checkpoints
        self.checkpoint_paths = []
        
    def save_checkpoint(self, model, optimizer, scheduler, scaler, epoch, step, loss):
        # Create checkpoint
        checkpoint = {
            'epoch': epoch,
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'scaler_state_dict': scaler.state_dict() if scaler else None,
            'loss': loss,
        }
        
        # Create filename
        checkpoint_path = f"{self.save_dir}/{self.run_name}-step{step}.pt"
        
        # Try to save checkpoint
        try:
            # Use a temporary file first
            temp_path = f"{checkpoint_path}.tmp"
            torch.save(checkpoint, temp_path)
            
            # Rename to final path if successful
            if os.path.exists(temp_path):
                os.rename(temp_path, checkpoint_path)
                logger.info(f"Checkpoint saved: {checkpoint_path}")
                
                # Add to list of checkpoints
                self.checkpoint_paths.append(checkpoint_path)
                
                # Remove old checkpoints if exceeding max_checkpoints
                if len(self.checkpoint_paths) > self.max_checkpoints:
                    old_checkpoint = self.checkpoint_paths.pop(0)
                    if os.path.exists(old_checkpoint):
                        os.remove(old_checkpoint)
                        logger.info(f"Removed old checkpoint: {old_checkpoint}")
                        
                return True
            else:
                logger.error(f"Failed to create temporary checkpoint file")
                return False
                
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            
            # Try saving a smaller checkpoint as fallback
            try:
                smaller_checkpoint = {
                    'epoch': epoch,
                    'step': step,
                    'model_state_dict': model.state_dict(),
                }
                fallback_path = f"{self.save_dir}/{self.run_name}-step{step}-fallback.pt"
                torch.save(smaller_checkpoint, fallback_path)
                logger.info(f"Fallback checkpoint saved: {fallback_path}")
                return True
            except Exception as e2:
                logger.error(f"Failed to save fallback checkpoint: {e2}")
                return False

# Function to free up memory
def free_memory():
    gc.collect()
    torch.cuda.empty_cache()

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Add separator token if not in vocabulary
if "<sep>" not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"sep_token": "<sep>"})

# Training function
def train():
    # Check if CUDA is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    
    # Load datasets
    train_dataset = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="train",
        shuffle_distractors=True
    )
    
    valid_dataset = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="validation",
        shuffle_distractors=False
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        drop_last=True
    )
    
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    # Load model
    logger.info(f"Loading model: {config.model_name}")
    model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)
    
    # Resize embeddings if needed due to added tokens
    model.resize_token_embeddings(len(tokenizer))
    
    # Move model to device
    model.to(device)
    
    # Initialize optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    
    # Calculate total training steps
    total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    warmup_steps = int(total_steps * config.warmup_ratio)
    
    # Initialize scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # Initialize gradient scaler for mixed precision
    scaler = GradScaler() if config.use_mixed_precision else None
    
    # Initialize checkpoint saver
    checkpoint_saver = CheckpointSaver(config.model_dir, config.run_name)
    
    # Training loop
    logger.info("Starting training")
    best_valid_loss = float("inf")
    no_improvement_count = 0
    training_step = 0
    
    for epoch in range(config.num_epochs):
        model.train()
        epoch_loss = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
        optimizer.zero_grad()
        
        for step, batch in enumerate(progress_bar):
            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            # Forward pass with mixed precision if enabled
            if config.use_mixed_precision:
                with autocast():
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels
                    )
                    loss = outputs.loss / config.gradient_accumulation_steps
                
                # Backward pass with gradient scaling
                scaler.scale(loss).backward()
                
                if (step + 1) % config.gradient_accumulation_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                    
                    # Update training step
                    training_step += 1
            else:
                # Standard training without mixed precision
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss / config.gradient_accumulation_steps
                loss.backward()
                
                if (step + 1) % config.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    
                    # Update training step
                    training_step += 1
            
            # Update progress bar
            epoch_loss += loss.item() * config.gradient_accumulation_steps
            progress_bar.set_postfix({"loss": f"{loss.item() * config.gradient_accumulation_steps:.4f}"})
            
            # Evaluate and save checkpoint at specified intervals
            if training_step > 0 and training_step % config.eval_steps == 0:
                valid_loss = evaluate(model, valid_loader, device)
                logger.info(f"Step {training_step} - Validation Loss: {valid_loss:.4f}")
                
                model.train()
                
                # Save checkpoint
                if training_step % config.save_steps == 0:
                    success = checkpoint_saver.save_checkpoint(
                        model, optimizer, scheduler, scaler, epoch, training_step, valid_loss
                    )
                    if not success:
                        logger.warning("Failed to save checkpoint, continuing training")
                
                # Early stopping check
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    no_improvement_count = 0
                    
                    # Save best model
                    logger.info(f"New best validation loss: {best_valid_loss:.4f}")
                    try:
                        best_model_path = f"{config.model_dir}/{config.run_name}-best.pt"
                        torch.save(model.state_dict(), best_model_path)
                        logger.info(f"Best model saved: {best_model_path}")
                    except Exception as e:
                        logger.error(f"Failed to save best model: {e}")
                else:
                    no_improvement_count += 1
                    logger.info(f"No improvement for {no_improvement_count} evaluations")
                    
                    if no_improvement_count >= 3:
                        logger.info("Early stopping triggered")
                        return
        
        # End of epoch
        avg_epoch_loss = epoch_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{config.num_epochs} - Average Loss: {avg_epoch_loss:.4f}")
        
        # Free memory at the end of each epoch
        free_memory()

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

In [3]:
# Evaluation function
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            total_loss += outputs.loss.item()
    
    return total_loss / len(dataloader)

# Main function to run the entire training process
def main():
    logger.info("Initializing training process")
    
    # Set seed for reproducibility
    set_seed(config.seed)
    
    try:
        # Start training
        logger.info(f"Starting training with run name: {config.run_name}")
        train()
        
        # Save final model
        try:
            final_model_path = f"{config.model_dir}/{config.run_name}-final.pt"
            model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)
            model.resize_token_embeddings(len(tokenizer))
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            
            # Try to load best model if it exists
            best_model_path = f"{config.model_dir}/{config.run_name}-best.pt"
            if os.path.exists(best_model_path):
                logger.info(f"Loading best model from {best_model_path}")
                model.load_state_dict(torch.load(best_model_path, map_location=device))
            
            # Save the model
            model.save_pretrained(final_model_path)
            tokenizer.save_pretrained(final_model_path)
            logger.info(f"Final model saved to {final_model_path}")
        except Exception as e:
            logger.error(f"Error saving final model: {e}")
    
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
    finally:
        # Clean up resources
        free_memory()
        logger.info("Training process completed")

# Run main function if script is executed directly
if __name__ == "__main__":
    main()

README.md:   0%|          | 0.00/11.0k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.08M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/37.4M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/2.05M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4934 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/87866 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4887 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  scaler = GradScaler() if config.use_mixed_precision else None


Epoch 1/3:   0%|          | 0/21966 [00:00<?, ?it/s]

  with autocast():
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  model.load_state_dict(torch.load(best_model_path, map_location=device))


# TRIAL 101

In [8]:
import os
import sys
import random
import time
import shutil
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.notebook import tqdm

# === Configuration ===
# Model parameters
model_name = "google/flan-t5-base"
save_dir = "./model_weights/"
run_name = f"flan-t5-base-race-distractor-generation"

# Training parameters
learning_rate = 3e-5
batch_size = 4
num_workers = 2
num_epochs = 5
max_length = 512
valid_steps = 2000
warmup_steps = 500
weight_decay = 0.01
gradient_accumulation_steps = 4  # For effective batch size of 16

# Checkpoint handling
save_full_checkpoint = False  # Set to True only for final model
checkpoint_every_epoch = False  # Set to False to only save at validation steps

# Create save directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)

# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Check for GPU availability
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {torch_device}")

# Load tokenizer and add special tokens
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_length)
if tokenizer.sep_token is None:
    tokenizer.add_special_tokens({"sep_token": "<sep>"})
    print(f"Added special token: {tokenizer.sep_token}")
else:
    print(f"Using existing sep_token: {tokenizer.sep_token}")


# === Dataset Class ===
class RaceDistractorGeneration(Dataset):
    def __init__(self, tokenizer, data_split, shuffle_distractors=False):
        """
        Dataset class for distractor generation using RACE dataset
        - input: question <sep> answer <sep> article
        - output: distractor1 <sep> distractor2 <sep> distractor3
        """
        # Load RACE dataset
        self.data = load_dataset("race", "all", split=data_split)
        self.tokenizer = tokenizer
        self.shuffle_distractors = shuffle_distractors
        
        # Create mapping for answer options
        self.label_mapping = {label: i for i, label in enumerate(["A", "B", "C", "D"])}
        self.all_labels = [0, 1, 2, 3]
        
        print(f"RACE Distractor Generation Dataset Initialized - {data_split} split: {len(self.data)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            example = self.data[idx]
            
            # Extract data from example
            question = example["question"]
            context = example["article"]
            options = example["options"]
            
            # Find correct answer
            label_example = example["answer"]
            answer_i = self.label_mapping[label_example]
            answer = options[answer_i]
            
            # Get distractors (incorrect options)
            distractor_ids = [i for i in self.all_labels if i != answer_i]
            if self.shuffle_distractors:
                random.shuffle(distractor_ids)
            distractors = [options[i] for i in distractor_ids]
            
            # Format input and output
            input_text = f"{question} {tokenizer.sep_token} {answer} {tokenizer.sep_token} {context}"
            output_text = f"{distractors[0]} {tokenizer.sep_token} {distractors[1]} {tokenizer.sep_token} {distractors[2]}"
            
            return {'input': input_text, 'output': output_text}
        except Exception as e:
            print(f"Error processing example {idx}: {e}")
            # Return empty strings as fallback
            return {'input': '', 'output': ''}


# === Collate Function ===
def collate_fn(batch):
    """
    Collate function for DataLoader
    """
    # Filter out None values
    batch = [item for item in batch if item['input'] != '' and item['output'] != '']
    
    if len(batch) == 0:
        return None
    
    # Extract input and output sequences
    input_sequences = [item['input'] for item in batch]
    output_sequences = [item['output'] for item in batch]
    
    # Tokenize inputs
    input_encoding = tokenizer(
        input_sequences,
        padding="longest",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    
    # Tokenize outputs
    target_encoding = tokenizer(
        output_sequences,
        padding="longest",
        max_length=max_length,
        truncation=True,
    )
    
    # Convert to tensor and replace padding token ids with -100 (ignored by loss)
    labels = torch.tensor(target_encoding.input_ids)
    labels[labels == tokenizer.pad_token_id] = -100
    
    return {
        'input_ids': input_encoding.input_ids,
        'attention_mask': input_encoding.attention_mask,
        'labels': labels,
    }


# === Training Function ===
def train():
    # Check disk space
    total, used, free = shutil.disk_usage("/")
    print(f"Disk space: Total={total // (2**30)}GB, Used={used // (2**30)}GB, Free={free // (2**30)}GB")
    if free < 5 * (2**30):  # Less than 5GB free
        print("WARNING: Low disk space. Consider cleaning up or using smaller checkpoints.")
    
    # Create datasets
    train_data = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="train",
        shuffle_distractors=True,
    )
    
    valid_data = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="validation",
        shuffle_distractors=False,
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=collate_fn,
    )
    
    valid_loader = DataLoader(
        valid_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    # Load model
    print(f"Loading {model_name}...")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    # Resize token embeddings if we added new tokens
    model.resize_token_embeddings(len(tokenizer))
    
    # Move model to device
    if torch_device == "cuda":
        model.to(torch_device)
    
    # Print model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Setup optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    
    # Calculate total training steps
    total_steps = len(train_loader) * num_epochs // gradient_accumulation_steps
    
    # Setup learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # Training loop
    print("Starting training...")
    model.train()
    best_val_loss = float('inf')
    early_stop_counter = 0
    training_step = 0
    
    start_time = time.time()
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        steps_in_epoch = 0
        
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        progress_bar = tqdm(train_loader, desc=f"Training")
        
        for batch in progress_bar:
            if batch is None:
                continue
            
            # Move batch to device
            input_ids = batch['input_ids'].to(torch_device)
            attention_mask = batch['attention_mask'].to(torch_device)
            labels = batch['labels'].to(torch_device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Calculate loss
            loss = outputs.loss / gradient_accumulation_steps
            epoch_loss += loss.item() * gradient_accumulation_steps
            steps_in_epoch += 1
            
            # Backward pass
            loss.backward()
            
            # Update weights every gradient_accumulation_steps
            if (training_step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{loss.item() * gradient_accumulation_steps:.4f}",
                'step': training_step
            })
            
            training_step += 1
            
            # Validation
            if training_step % valid_steps == 0:
                # Clean up previous checkpoint files (except best)
                for f in os.listdir(save_dir):
                    if f.startswith(run_name) and f.endswith(".pt") and "best" not in f:
                        try:
                            os.remove(os.path.join(save_dir, f))
                            print(f"Removed old checkpoint: {f}")
                        except:
                            pass
                
                # Save checkpoint more efficiently
                checkpoint_path = f"{save_dir}/{run_name}-step{training_step}.pt"
                try:
                    # Save only model weights instead of full state
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"\nSaved model weights to {checkpoint_path}")
                except RuntimeError as e:
                    print(f"ERROR saving checkpoint: {e}")
                    print("Continuing without saving checkpoint...")
                
                # Validate
                model.eval()
                val_loss = validate(model, valid_loader)
                print(f"Validation Loss: {val_loss:.4f}")
                model.train()
                
                # Early stopping
                if val_loss < best_val_loss:
                    early_stop_counter = 0
                    best_val_loss = val_loss
                    best_model_path = f"{save_dir}/{run_name}-best.pt"
                    try:
                        # Save only model weights for best model
                        torch.save(model.state_dict(), best_model_path)
                        print(f"New best model saved at {best_model_path}")
                    except RuntimeError as e:
                        print(f"ERROR saving best model: {e}")
                else:
                    early_stop_counter += 1
                    print(f"Validation loss did not improve. Early stopping counter: {early_stop_counter}/3")
                    if early_stop_counter >= 3:
                        print("Early stopping triggered.")
                        return
        
        # Epoch completed
        avg_epoch_loss = epoch_loss / steps_in_epoch
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_epoch_loss:.4f}")
        
        # Save epoch checkpoint if enabled
        if checkpoint_every_epoch:
            epoch_model_path = f"{save_dir}/{run_name}-epoch{epoch+1}.pt"
            try:
                torch.save(model.state_dict(), epoch_model_path)
                print(f"Epoch checkpoint saved at {epoch_model_path}")
            except RuntimeError as e:
                print(f"ERROR saving epoch checkpoint: {e}")
    
    # Training completed
    final_model_path = f"{save_dir}/{run_name}-final.pt"
    try:
        if save_full_checkpoint:
            torch.save({
                'epoch': num_epochs,
                'step': training_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, final_model_path)
        else:
            torch.save(model.state_dict(), final_model_path)
        print(f"Training completed. Final model saved at {final_model_path}")
    except RuntimeError as e:
        print(f"ERROR saving final model: {e}")
    
    total_time = time.time() - start_time
    print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")


# === Validation Function ===
def validate(model, dataloader):
    """
    Validate the model on the validation dataset
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating")
        for batch in progress_bar:
            if batch is None:
                continue
            
            # Move batch to device
            input_ids = batch['input_ids'].to(torch_device)
            attention_mask = batch['attention_mask'].to(torch_device)
            labels = batch['labels'].to(torch_device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Calculate loss
            loss = outputs.loss.item()
            total_loss += loss
            num_batches += 1
            
            # Update progress bar
            progress_bar.set_postfix({'loss': f"{loss:.4f}"})
    
    # Calculate average loss
    avg_loss = total_loss / max(num_batches, 1)
    return avg_loss


# === Generation Function ===
def generate_distractors(model, text, num_return_sequences=1):
    """
    Generate distractors for a given question, answer, and context
    """
    model.eval()
    
    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True
    )
    
    # Move to device
    inputs = {k: v.to(torch_device) for k, v in inputs.items()}
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=4,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            early_stopping=True
        )
    
    # Decode outputs
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    
    return generated_texts


# === Evaluation Function ===
def evaluate_on_test_set():
    print("Evaluating on test set...")
    # Load best model
    best_model_path = f"{save_dir}/{run_name}-best.pt"
    if not os.path.exists(best_model_path):
        print(f"Best model not found at {best_model_path}")
        return
    
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        model.load_state_dict(torch.load(best_model_path))
        model.to(torch_device)
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    
    # Create test dataset
    test_data = RaceDistractorGeneration(
        tokenizer=tokenizer,
        data_split="test",
        shuffle_distractors=False,
    )
    
    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    # Evaluate
    test_loss = validate(model, test_loader)
    print(f"Test Loss: {test_loss:.4f}")
    
    # Generate examples
    print("\nExample generations:")
    for i in range(min(3, len(test_data))):
        example = test_data[i]
        input_text = example['input']
        reference = example['output']
        
        generated = generate_distractors(model, input_text)
        
        print(f"\nInput: {input_text}")
        print(f"Reference: {reference}")
        print(f"Generated: {generated[0]}")


# === Interactive Generation Function ===
def interactive_distractor_generation():
    print("\nInteractive distractor generation mode")
    print("Loading best model...")
    
    best_model_path = f"{save_dir}/{run_name}-best.pt"
    if not os.path.exists(best_model_path):
        print(f"Best model not found at {best_model_path}")
        return
    
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        model.load_state_dict(torch.load(best_model_path))
        model.to(torch_device)
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    
    print("\nEnter 'q' to quit")
    while True:
        question = input("\nEnter question: ")
        if question.lower() == 'q':
            break
            
        answer = input("Enter correct answer: ")
        if answer.lower() == 'q':
            break
            
        context = input("Enter context: ")
        if context.lower() == 'q':
            break
        
        input_text = f"{question} {tokenizer.sep_token} {answer} {tokenizer.sep_token} {context}"
        generated = generate_distractors(model, input_text, num_return_sequences=3)
        
        print("\nGenerated distractors:")
        for i, distractors in enumerate(generated):
            print(f"Option {i+1}: {distractors}")


# === Main Function ===
def main():
    """
    Main function to run the training process
    """
    print("Starting finetuning process...")
    print(f"Model: {model_name}")
    print(f"Device: {torch_device}")
    print(f"Learning rate: {learning_rate}")
    print(f"Batch size: {batch_size} (effective: {batch_size * gradient_accumulation_steps})")
    print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"Number of epochs: {num_epochs}")
    print(f"Maximum sequence length: {max_length}")
    print(f"Validation steps: {valid_steps}")
    print(f"Warmup steps: {warmup_steps}")
    
    # Start training
    train()
    
    print("Training complete!")

Using device: cuda
Added special token: <sep>


In [None]:
if __name__ == "__main__":
    try:
        print("="*50)
        print("FLAN-T5-Base Finetuning on RACE Dataset")
        print("="*50)
        
        print("\n=== FLAN-T5 Distractor Generation Tools ===")
        main()
        evaluate_on_test_set()
        interactive_distractor_generation()


    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()

FLAN-T5-Base Finetuning on RACE Dataset

=== FLAN-T5 Distractor Generation Tools ===
Starting finetuning process...
Model: google/flan-t5-base
Device: cuda
Learning rate: 3e-05
Batch size: 4 (effective: 16)
Gradient accumulation steps: 4
Number of epochs: 5
Maximum sequence length: 512
Validation steps: 2000
Warmup steps: 500
Disk space: Total=8062GB, Used=6171GB, Free=1891GB
RACE Distractor Generation Dataset Initialized - train split: 87866 examples
RACE Distractor Generation Dataset Initialized - validation split: 4887 examples
Loading google/flan-t5-base...
Total parameters: 247,536,384
Trainable parameters: 247,536,384
Starting training...

Epoch 1/5


Training:   0%|          | 0/21967 [00:00<?, ?it/s]

Removed old checkpoint: flan-t5-base-race-distractor-generation-step2000.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step6000.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step12000.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step10000.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step8000.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step4000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step2000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.9140
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step2000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step4000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.8089
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step4000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step6000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7825
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step6000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step8000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7585
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step8000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step10000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7456
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step10000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step12000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7299
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step12000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step14000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7256
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step14000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step16000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7177
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step16000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step18000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7139
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step18000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step20000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7050
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Epoch 1 completed. Average loss: 2.0825

Epoch 2/5


Training:   0%|          | 0/21967 [00:00<?, ?it/s]

Removed old checkpoint: flan-t5-base-race-distractor-generation-step20000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step22000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.7015
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step22000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step24000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6989
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step24000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step26000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6975
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step26000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step28000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6939
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step28000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step30000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6920
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step30000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step32000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6884
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step32000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step34000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6862
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step34000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step36000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6829
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
Removed old checkpoint: flan-t5-base-race-distractor-generation-step36000.pt

Saved model weights to ./model_weights//flan-t5-base-race-distractor-generation-step38000.pt


Validating:   0%|          | 0/1222 [00:00<?, ?it/s]

Validation Loss: 1.6804
New best model saved at ./model_weights//flan-t5-base-race-distractor-generation-best.pt
