# Fine-tuning FLAN-T5-base for Mythology QA with PEFT on Low VRAM

This notebook demonstrates how to fine-tune `google/flan-t5-base` for a question-answering task on your mythology dataset. It is optimized for low-VRAM environments (such as 4GB GPUs) using PEFT (LoRA) and 8-bit quantization. A single JSONL file is split into train and validation sets; no separate test set is created.

## 1. Setup and Imports

In [None]:
# Install the training stack quietly (-q). -U upgrades every listed package,
# including torch, which may replace a preinstalled CUDA build (e.g. on Colab).
%pip install transformers datasets peft bitsandbytes accelerate torch tensorboard -q -U

In [None]:
import os
import torch
import json
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
import numpy as np

## 2. Configuration
Adjust these parameters based on your setup and dataset.

In [None]:
# --- Model Configuration ---
MODEL_NAME = "google/flan-t5-base"

# Single JSONL file; it is split into train/validation when the dataset is loaded.
SINGLE_DATA_FILE = "/content/temp.jsonl"

# --- Output Directories ---
OUTPUT_DIR = "./flan-t5-base-mythology-results" # Checkpoints and logs
PEFT_ADAPTER_DIR = "./flan-t5-base-mythology-peft-adapter" # Final adapter

# --- Training Hyperparameters (CRUCIAL TUNING FOR 4GB VRAM) ---
LEARNING_RATE = 2e-4
BATCH_SIZE = 1 # Keep 1 or 2 for 4GB VRAM
# Effective batch size is BATCH_SIZE * GRAD_ACCUMULATION_STEPS; raise this
# (e.g. 16, 32, 64) to compensate for a small BATCH_SIZE.
GRAD_ACCUMULATION_STEPS = 16
NUM_EPOCHS = 3 # Adjust as needed based on validation performance
# Token limit for both the formatted input and the answer; longer examples are
# filtered out during preprocessing. Reduce if OOM errors occur.
MAX_SEQ_LENGTH = 512

# --- PEFT/LoRA Configuration ---
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q", "v"] # For FLAN-T5 Query/Value projection layers

# --- Quantization Configuration ---
# Pick at most one: USE_8BIT for 8-bit, USE_4BIT for 4-bit (NF4), or set both
# to False to load the model unquantized. If both are True, 8-bit wins.
USE_8BIT = True
USE_4BIT = False

if USE_8BIT and USE_4BIT:
    print("Warning: USE_8BIT and USE_4BIT are both True; using 8-bit quantization.")

if USE_8BIT:
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
elif USE_4BIT:
    # Native bfloat16 needs a GPU with compute capability >= 8.0; otherwise
    # compute in float32. float16 is avoided because T5 activations are prone
    # to overflow in it.
    _bf16_native = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if _bf16_native else torch.float32,
    )
else:
    bnb_config = None

print(f"Using Quantization: {'8-bit' if USE_8BIT else ('4-bit' if USE_4BIT else 'None')}")

## 3. Load Dataset
This section loads a single JSONL file (`SINGLE_DATA_FILE`). It randomly splits the file into a `train` set (80%) and a `validation` set (20%), using a fixed seed so the split is reproducible. No separate test split is created.

In [None]:
# Read the entire JSONL file as one split, then hold out a validation set.
print(f"Attempting to load single file {SINGLE_DATA_FILE} and split...")
full_dataset = load_dataset("json", data_files=SINGLE_DATA_FILE, split='train')

# 80/20 train/validation split; the fixed seed makes it reproducible.
train_testvalid = full_dataset.train_test_split(test_size=0.2, seed=42)

# train_test_split names its held-out part "test"; expose it as "validation".
dataset = DatasetDict({
    new_name: train_testvalid[old_name]
    for new_name, old_name in (("train", "train"), ("validation", "test"))
})
print("Dataset loaded from single file and split:")
print(dataset)

## 4. Preprocessing & Filtering
Format the inputs, tokenize them, and filter out examples that exceed the maximum sequence length.

In [None]:
# Load the tokenizer matching MODEL_NAME (this always (re)loads it).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
def preprocess_and_check_length(examples):
    """
    Format a batch of QA rows, record their true token lengths, and tokenize
    them for the model.

    Args:
        examples: a batched `datasets` dict with list-valued 'question',
            'context' and 'answer' columns.

    Returns:
        The same dict with these columns added:
          - 'input_length' / 'target_length': untruncated token counts, used
            by the later length filter;
          - 'input_ids' / 'attention_mask': encoder inputs, truncated to
            MAX_SEQ_LENGTH but NOT padded;
          - 'labels': target token ids, truncated to MAX_SEQ_LENGTH, NOT padded.

    Padding is intentionally left to the data collator, which pads each batch
    only to its longest member (labels with -100 so padding is ignored by the
    loss). Padding everything to MAX_SEQ_LENGTH here would make every step
    process MAX_SEQ_LENGTH encoder and decoder positions even for short
    answers, wasting memory on a low-VRAM GPU.
    """
    # --- Format Inputs and Targets ---
    inputs = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['context'])]
    targets = [str(ans) for ans in examples['answer']] # Ensure answers are strings

    # --- Tokenize WITHOUT truncation to get the true token counts ---
    input_tokens = tokenizer(inputs)
    target_tokens = tokenizer(targets)
    examples['input_length'] = [len(ids) for ids in input_tokens['input_ids']]
    examples['target_length'] = [len(ids) for ids in target_tokens['input_ids']]

    # --- Tokenize WITH truncation (no padding) for the model ---
    # Truncation is only a safety net: over-long rows are dropped by the filter.
    model_inputs = tokenizer(inputs, max_length=MAX_SEQ_LENGTH, truncation=True)
    labels = tokenizer(targets, max_length=MAX_SEQ_LENGTH, truncation=True)

    examples['input_ids'] = model_inputs['input_ids']
    examples['attention_mask'] = model_inputs['attention_mask']
    # Unpadded labels contain no pad tokens, so no -100 masking is needed here.
    examples['labels'] = labels['input_ids']

    return examples

In [None]:
print("Applying preprocessing and length checking...")
# Batched map: the preprocessing function receives whole column lists per call.
processed_datasets = dataset.map(preprocess_and_check_length, batched=True)
print("Preprocessing and length checking complete.")
print("Dataset structure after length check:")
print(processed_datasets)

### Filtering Step
Now, remove examples where the original input or target length exceeds the limit.

In [None]:
def filter_long_examples(example, max_length=None):
    """
    Return whether an example's untruncated input and target lengths fit the limit.

    Supports both calling conventions of `Dataset.filter`:
      - unbatched: 'input_length' / 'target_length' are ints and one bool is
        returned;
      - batched (`batched=True`): they are lists and a list of bools is
        returned, one per row. Comparing a list to an int directly would
        raise TypeError.

    Args:
        example: row (or batch) dict with 'input_length' and 'target_length'.
        max_length: token limit; uses MAX_SEQ_LENGTH when None.
    """
    limit = MAX_SEQ_LENGTH if max_length is None else max_length
    input_lengths = example['input_length']
    target_lengths = example['target_length']
    if isinstance(input_lengths, (list, tuple)):
        return [i <= limit and t <= limit for i, t in zip(input_lengths, target_lengths)]
    return input_lengths <= limit and target_lengths <= limit

In [None]:
print("Filtering out examples exceeding max length...")
# DatasetDict.num_rows maps each split name to its row count.
original_sizes = processed_datasets.num_rows

# Batched filtering: the predicate is called once per batch of rows.
filtered_datasets = processed_datasets.filter(filter_long_examples, batched=True)

filtered_sizes = filtered_datasets.num_rows
print("\nFiltering complete.")

# Before/after comparison per split
for heading, sizes in (
    ("\nDataset sizes before filtering:", original_sizes),
    ("Dataset sizes after filtering:", filtered_sizes),
):
    print(heading)
    print(sizes)


# Final dataset to be used for training
tokenized_datasets = filtered_datasets

## 5. Load Model, Apply Quantization & PEFT

In [None]:
print("Loading base model...")
# bnb_config is None when quantization is disabled; device_map="auto" lets
# the model be placed automatically (GPU preferred).
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
)
print("Base model loaded.")

In [None]:
if bnb_config:
    print("Preparing model for K-bit training (if quantized)...")
    # PEFT helper for quantized models; with use_gradient_checkpointing=True it
    # also enables gradient checkpointing.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    print("Model prepared for K-bit training.")
else:
    # Still enable gradient checkpointing if not quantizing but wanting memory savings
    model.gradient_checkpointing_enable()
    # LoRA freezes the base weights, so the embedding outputs would not
    # require grad; with reentrant checkpointing the checkpointed blocks then
    # have no grad path and the LoRA weights receive no gradients. Make the
    # input-embedding outputs require grad to keep the backward path intact.
    model.enable_input_require_grads()

In [None]:
print("Configuring LoRA...")
# Low-rank adapters on the modules named in LORA_TARGET_MODULES; biases stay frozen.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=LORA_TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

print("Applying LoRA adapter...")
model = get_peft_model(model, lora_config)
print("LoRA applied.")
model.print_trainable_parameters()

## 6. Configure Training Arguments

In [None]:
# Mixed precision: T5 checkpoints are prone to float16 overflow (NaN/inf loss),
# so use bfloat16 only where the GPU supports it natively (compute capability
# >= 8.0) and fall back to full float32 otherwise.
_use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE, # The default eval batch size (8) can OOM on a small GPU
    gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
    gradient_checkpointing=True,
    # Non-reentrant checkpointing keeps gradients flowing to LoRA weights even
    # when the frozen embedding outputs do not require grad.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_strategy="steps",
    logging_steps=10, # Log training loss every 10 steps
    eval_strategy="epoch", # Formerly `evaluation_strategy`, which newer transformers rejects
    save_strategy="epoch",       # Save model checkpoint every epoch
    save_total_limit=2,          # Keep only the last 2 checkpoints
    load_best_model_at_end=True, # Load the best checkpoint (by validation loss) at the end of training
    metric_for_best_model="eval_loss", # Use validation loss to determine the best model
    greater_is_better=False,     # Lower validation loss is better
    bf16=_use_bf16,
    fp16=False,                  # float16 is avoided for T5 (see above)
    optim="adamw_bnb_8bit" if USE_8BIT or USE_4BIT else "adamw_torch", # Memory-efficient optimizer if quantized
    report_to="tensorboard",     # Log metrics for TensorBoard visualization
)

## 7. Initialize Trainer

In [None]:
# Seq2seq collator: pads each batch at collation time; labels are padded with
# -100 so those positions are excluded from the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    label_pad_token_id=-100,
    model=model,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"], # Used for per-epoch evaluation
    # `processing_class` replaces the deprecated `tokenizer` argument
    # (transformers >= 4.46); it is saved alongside the model by save_model.
    processing_class=tokenizer,
    data_collator=data_collator,
    # Stop if eval_loss fails to improve for 3 consecutive evaluations.
    # NOTE: with NUM_EPOCHS = 3 and per-epoch evaluation this cannot trigger;
    # it only matters if NUM_EPOCHS is increased.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

## 8. Train the Model
This will take time, depending on dataset size and hardware. Monitor VRAM usage!

In [None]:
print("Starting training...")
train_result = trainer.train()
print("Training finished.")

# Print the training summary metrics, then persist them via the Trainer.
metrics = train_result.metrics
for report in (trainer.log_metrics, trainer.save_metrics):
    report("train", metrics)

## 9. Save Final PEFT Adapter
This saves only the trained adapter weights, which are much smaller than the full model.

In [None]:
print(f"Saving the final PEFT adapter to {PEFT_ADAPTER_DIR}...")
# On a PEFT-wrapped model, save_model stores the adapter (not the full base
# model), together with the tokenizer the Trainer was given.
trainer.save_model(PEFT_ADAPTER_DIR)

print("Adapter saved.", "Fine-tuning process complete.", sep="\n")

## 10. Loading the Saved Adapter (Example for Inference Later)

In [None]:
from peft import PeftModel, PeftConfig

print("Loading adapter for inference...")
config = PeftConfig.from_pretrained(PEFT_ADAPTER_DIR)
base_name = config.base_model_name_or_path  # base checkpoint recorded in the adapter config

# Rebuild the base model with the same quantization config used for training,
# then attach the trained adapter on top of it.
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    base_name,
    device_map="auto",
    quantization_config=bnb_config,
)
inference_model = PeftModel.from_pretrained(base_model, PEFT_ADAPTER_DIR)
inference_tokenizer = AutoTokenizer.from_pretrained(base_name)
print("Inference model loaded.")

# Example inference; the prompt format mirrors preprocess_and_check_length.
context = "Your mythology context here..."
question = "Your question here..."
input_text = f"question: {question} context: {context}"
inputs = inference_tokenizer(input_text, return_tensors="pt")
inputs = inputs.to(inference_model.device)  # tensors must live on the model's device

outputs = inference_model.generate(**inputs, max_new_tokens=50)
answer = inference_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\nExample Inference Answer: {answer}")