# FLAN-T5-Base LoRA Fine-tuning for News Summarization (Kaggle GPU)

This notebook uses LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning.

**LoRA Benefits**:
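- Trains only a small fraction of the parameters (well under 1%; about 0.36% for the r=8, q/v configuration used below), so training is faster and saved checkpoints are tiny
- Lower memory usage, since gradients and optimizer state are kept only for the adapter weights
- Less prone to catastrophic forgetting, because the base model's weights stay frozen
- Adapters can easily be merged into the base model or swapped for other adapters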

In [None]:
# Install required packages (notebook shell magic; `-q` keeps pip output short).
# `peft` provides LoRA; `rouge-score` and `bert-score` are used for evaluation.
# NOTE(review): without `-U`, packages already present in the Kaggle image (e.g. torch) are left as-is.
!pip install -q torch transformers datasets rouge-score bert-score numpy tqdm accelerate sentencepiece peft

In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import json
from tqdm import tqdm
import os

# Report the runtime environment so it is visible in the notebook output
# whether training will run on a GPU and how much device memory it has.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"CUDA device: {gpu_name}")
    print(f"CUDA memory: {gpu_mem_gb:.2f} GB")

## Dataset Class

In [None]:
class NewsSummarizationDataset(Dataset):
    """Map-style dataset pairing news articles with reference summaries for T5.

    Items are tokenized lazily on access. The article is wrapped in an
    instruction prompt and padded/truncated to ``max_input_length``. The
    summary is padded/truncated to ``max_target_length`` and returned as
    ``labels``, with padding positions replaced by ``IGNORE_INDEX`` so they do
    not contribute to the training loss.

    Args:
        texts: Sequence of article bodies (each converted with ``str``).
        summaries: Sequence of reference summaries, index-aligned with ``texts``.
        tokenizer: Tokenizer callable that returns ``input_ids`` and
            ``attention_mask`` tensors of shape ``(1, length)`` when called with
            ``return_tensors='pt'``; its ``pad_token_id`` (if any) is masked
            out of the labels.
        max_input_length: Token length the prompt is padded/truncated to.
        max_target_length: Token length the summary is padded/truncated to.
    """

    # Label value skipped by PyTorch's cross-entropy loss (its default
    # ``ignore_index``), which the seq2seq model's loss relies on.
    IGNORE_INDEX = -100

    def __init__(self, texts, summaries, tokenizer, max_input_length=512, max_target_length=128):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized (prompt, summary) pair at ``idx``.

        Returns:
            dict with 1-D tensors ``input_ids`` and ``attention_mask`` (length
            ``max_input_length``) and ``labels`` (length ``max_target_length``).
        """
        text = str(self.texts[idx])
        summary = str(self.summaries[idx])

        # Tokenize inputs with prompt for T5
        prompt = f"Summarize the following news article: {text}"
        inputs = self.tokenizer(
            prompt,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize targets
        targets = self.tokenizer(
            summary,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # squeeze(0) drops only the batch dimension added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence to a scalar.
        labels = targets['input_ids'].squeeze(0).clone()
        # Padding positions must be ignored by the loss. Left as the pad id, the
        # model would be trained to predict padding, diluting the real signal.
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is not None:
            labels[labels == pad_id] = self.IGNORE_INDEX

        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': labels
        }

## Load Dataset

In [None]:
def load_cnn_dailymail(split='test', num_samples=None):
    """Load articles and reference highlights from CNN/DailyMail (config 3.0.0).

    Args:
        split: Split name passed to ``load_dataset`` (e.g. ``'train'``,
            ``'validation'``, ``'test'``).
        num_samples: If not ``None``, keep only the first ``num_samples`` rows
            (capped at the split size). ``None`` keeps the whole split.

    Returns:
        ``(texts, summaries)``: two index-aligned lists holding the
        ``article`` and ``highlights`` fields.
    """
    print(f"Loading CNN/DailyMail {split} dataset...")
    dataset = load_dataset("cnn_dailymail", "3.0.0", split=split)
    # Compare against None so that num_samples=0 selects nothing instead of
    # silently falling through to the full split.
    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))
    # Column access avoids building a per-row dict (which would also pull in
    # the unused columns); list() normalises the result across library versions.
    texts = list(dataset['article'])
    summaries = list(dataset['highlights'])
    print(f"Loaded {len(texts)} samples")
    return texts, summaries

# Load small slices of each split so the notebook run stays short.
print("Loading datasets...")
train_texts, train_summaries = load_cnn_dailymail('train', num_samples=1000)
val_texts, val_summaries = load_cnn_dailymail('validation', num_samples=100)
test_texts, test_summaries = load_cnn_dailymail('test', num_samples=100)

# Report how many examples each split actually contributed.
print("\nDataset sizes:")
for split_label, split_texts in (
    ("Training", train_texts),
    ("Validation", val_texts),
    ("Test", test_texts),
):
    print(f"  {split_label}: {len(split_texts)}")

## Initialize Model with LoRA

In [None]:
# Load the FLAN-T5 base checkpoint and its tokenizer, then wrap the model with
# LoRA adapters so only the low-rank update matrices are trained.
model_name = "google/flan-t5-base"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Loading {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name)
base_model = T5ForConditionalGeneration.from_pretrained(model_name)

# LoRA setup: rank-8 adapters on the attention query ("q") and value ("v")
# projections of a sequence-to-sequence model; bias terms are not trained.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "v"],
    r=8,
    lora_alpha=32,  # scaling factor applied to the low-rank update
    lora_dropout=0.05,
    bias="none",
)

# Wrap the base model with the adapters and move everything to the device.
model = get_peft_model(base_model, lora_config).to(device)

# Show how many parameters are actually trainable.
model.print_trainable_parameters()
print(f"\nModel loaded on {device}")

## Prepare Datasets

In [None]:
# Wrap the raw train/validation lists in tokenizing Dataset objects.
print("Creating datasets...")
train_dataset, val_dataset = (
    NewsSummarizationDataset(split_texts, split_summaries, tokenizer)
    for split_texts, split_summaries in (
        (train_texts, train_summaries),
        (val_texts, val_summaries),
    )
)
print("Datasets created!")

## LoRA Fine-tuning Configuration

In [None]:
# Training arguments
output_dir = "./flan_t5_base_lora_finetuned"

# Mixed precision: T5-family checkpoints (FLAN-T5 included) are widely reported
# to overflow to inf/NaN activations under fp16, so fp16 is deliberately off.
# bf16 is used only when the GPU supports it natively (compute capability
# >= 8.0, i.e. Ampere or newer); older cards such as Kaggle's T4/P100 train in
# full fp32 instead.
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    # The Trainer default (5e-5) targets full fine-tuning. LoRA adapters start
    # as a zero update and are commonly trained with a much larger step size;
    # 1e-3 is a typical LoRA choice for T5-family models.
    learning_rate=1e-3,
    per_device_train_batch_size=8,  # Can use larger batch size with LoRA
    per_device_eval_batch_size=8,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir=f'{output_dir}/logs',
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,  # restore the checkpoint with the best eval loss
    save_total_limit=2,
    prediction_loss_only=True,  # evaluation reports loss only (no generation)
    fp16=False,
    bf16=use_bf16,
    report_to="none",
)

if training_args.bf16:
    precision_mode = "bf16"
elif training_args.fp16:
    precision_mode = "fp16"
else:
    precision_mode = "off (fp32)"

print("LoRA Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  LoRA rank: {lora_config.r}")
print(f"  Mixed precision: {precision_mode}")
print(f"  Output directory: {output_dir}")

## Start LoRA Fine-tuning

In [None]:
# Build the Trainer from the LoRA-wrapped model, the arguments above and the
# tokenized train/validation datasets.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Announce the run, then train (evaluating and checkpointing per epoch).
for banner_line in (
    "Starting LoRA fine-tuning...",
    "This will be much faster than full fine-tuning!",
    "-" * 80,
):
    print(banner_line)

trainer.train()

print("\nLoRA fine-tuning completed!")

## Save Model

In [None]:
# Save the LoRA adapters (a PEFT model's save_pretrained writes the adapter
# weights and config, not the full base model) plus the tokenizer.
print("Saving LoRA adapters...")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"LoRA adapters saved to {output_dir}")

# Also save to Kaggle output for download. When the notebook's working
# directory is already /kaggle/working, output_dir resolves to this same
# folder, so the second write is skipped instead of rewriting identical files.
if os.path.exists('/kaggle/working'):
    kaggle_output_dir = "/kaggle/working/flan_t5_base_lora_finetuned"
    if os.path.abspath(kaggle_output_dir) != os.path.abspath(output_dir):
        model.save_pretrained(kaggle_output_dir)
        tokenizer.save_pretrained(kaggle_output_dir)
    print(f"LoRA adapters also saved to {kaggle_output_dir} (downloadable from Kaggle)")

## Merge LoRA and Evaluate

In [None]:
# Merge LoRA weights into base model for evaluation: merge_and_unload() folds
# the adapter updates into the base weights and returns the plain base model,
# so generation runs without the adapter wrapper.
print("Merging LoRA adapters into base model...")
model = model.merge_and_unload()
# Switch to inference mode explicitly. generate() does not change the module's
# train/eval mode, and active dropout would make beam-search outputs
# non-deterministic.
model.eval()
print("Merge complete!")

## Evaluation on Test Set

In [None]:
def generate_summary(model, tokenizer, text, max_length=128, min_length=30,
                     max_input_length=512, device=None):
    """Generate a summary for a single article using beam search.

    The article is wrapped in the same instruction prompt used by
    ``NewsSummarizationDataset`` for training.

    Args:
        model: Seq2seq model exposing ``generate`` and ``parameters``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        text: Article body to summarize.
        max_length: Maximum generated length in tokens.
        min_length: Minimum generated length in tokens.
        max_input_length: Prompt truncation length in tokens (matches the
            training dataset default).
        device: Device to place the inputs on; defaults to the device of the
            model's first parameter.

    Returns:
        The decoded summary string with special tokens removed.
    """
    if device is None:
        # Follow the model's own placement instead of relying on a global.
        device = next(model.parameters()).device

    prompt = f"Summarize the following news article: {text}"
    # A single unbatched sequence needs no padding; padding it to
    # max_input_length would only make the encoder process masked filler tokens.
    inputs = tokenizer(
        prompt,
        max_length=max_input_length,
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Generate model summaries for a short prefix of the test split; the matching
# reference summaries are sliced to the same prefix for scoring.
print("Generating summaries on test set...")
num_eval_samples = 10  # Evaluate on 10 samples
reference_summaries = test_summaries[:num_eval_samples]
test_texts_subset = test_texts[:num_eval_samples]

generated_summaries = [
    generate_summary(model, tokenizer, article) for article in tqdm(test_texts_subset)
]

print(f"Generated {len(generated_summaries)} summaries")

## Calculate Metrics

In [None]:
# ROUGE: per-pair F-measure for each variant, averaged afterwards.
print("Calculating ROUGE scores...")
rouge_types = ['rouge1', 'rouge2', 'rougeL']
scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
rouge_scores = {rouge_name: [] for rouge_name in rouge_types}

for gen_sum, ref_sum in zip(generated_summaries, reference_summaries):
    pair_scores = scorer.score(ref_sum, gen_sum)  # argument order: (target, prediction)
    for rouge_name in rouge_types:
        rouge_scores[rouge_name].append(pair_scores[rouge_name].fmeasure)

# BERTScore: per-pair precision/recall/F1 tensors, averaged below.
print("Calculating BERTScore...")
P, R, F1 = bert_score(generated_summaries, reference_summaries, lang='en', verbose=True)

# Compile results (ROUGE entries first, then BERTScore).
results = {rouge_name: {'f1': np.mean(rouge_scores[rouge_name])} for rouge_name in rouge_types}
results['bertscore'] = {
    'precision': P.mean().item(),
    'recall': R.mean().item(),
    'f1': F1.mean().item(),
}

print("\n=== LoRA Fine-tuned Results ===")
print(json.dumps(results, indent=2))

## Save Results

In [None]:
# Save results to JSON file.
results_path = 'flan_t5_base_lora_results.json'
kaggle_results_path = '/kaggle/working/flan_t5_base_lora_results.json'

# PEFT's LoraConfig converts a list target_modules into a set when it is
# constructed (it may also be a regex string or None). json cannot serialise
# sets, so normalise to a sorted list to keep json.dump from raising TypeError.
target_modules = lora_config.target_modules
if target_modules is not None and not isinstance(target_modules, str):
    target_modules = sorted(target_modules)

results_data = {
    "model": "FLAN-T5-Base",
    "method": "LoRA Fine-tuned",
    "lora_config": {
        "r": lora_config.r,
        "lora_alpha": lora_config.lora_alpha,
        "target_modules": target_modules
    },
    "results": results
}

with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(results_data, f, indent=2)
print("\nResults saved to flan_t5_base_lora_results.json")

# Save to Kaggle output. Skip the second write when the working directory is
# already /kaggle/working (both paths then name the same file).
if os.path.exists('/kaggle/working'):
    if os.path.abspath(results_path) != os.path.abspath(kaggle_results_path):
        with open(kaggle_results_path, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2)
    print("Results saved to /kaggle/working/flan_t5_base_lora_results.json")

## Display Example Summaries

In [None]:
# Print up to three article / reference / model-summary triples so output
# quality can be inspected by eye.
print("\n=== EXAMPLE SUMMARIES ===\n")
for i, article in enumerate(test_texts_subset[:3]):
    article_preview = article[:200]
    print(f"--- Example {i+1} ---")
    print(f"\nOriginal Article (first 200 chars):\n{article_preview}...")
    print(f"\nReference Summary:\n{reference_summaries[i]}")
    print(f"\nLoRA Fine-tuned Summary:\n{generated_summaries[i]}")
    print("-" * 80)
    print()