# Fine-Tuning FLAN-T5 for Dialogue Summarization with LoRA

This notebook demonstrates how to fine-tune the `google/flan-t5-base` model for dialogue summarization using the `knkarthick/dialogsum` dataset with Parameter-Efficient Fine-Tuning (PEFT) via LoRA (Low-Rank Adaptation).

## Overview
- **Model**: google/flan-t5-base
- **Dataset**: knkarthick/dialogsum (subset: 1,000 train / 200 val)
- **Technique**: LoRA (Low-Rank Adaptation)
- **Task**: Seq2Seq Dialogue Summarization

## Key Parameters
| Parameter | Value |
|-----------|-------|
| LoRA Rank (r) | 16 |
| LoRA Alpha | 32 |
| Dropout | 0.05 |
| Target Modules | q, v |
| Learning Rate | 1e-3 |
| Epochs | 1 |
| Batch Size | 8 (no gradient accumulation) |
| Train Subset | 1,000 samples |
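
To make the rank and alpha rows concrete, here is a minimal NumPy sketch of a LoRA-adapted linear layer. It assumes the standard LoRA formulation (a frozen weight `W` plus a low-rank update `B @ A` scaled by `alpha / r`). The 768-wide layer size, the random initialisation, and all variable names are illustrative assumptions, not values read from the model or from PEFT internals.

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out = 768, 768   # assumed layer width, for illustration only
r, alpha = 16, 32        # LoRA rank and alpha from the table above
scale = alpha / r        # update scale: 32 / 16 = 2.0

W = rng.normal(size=(d_out, d_in))       # frozen pretrained weight
A = rng.normal(size=(r, d_in)) * 0.01    # trainable down-projection
B = np.zeros((d_out, r))                 # trainable up-projection, zero at init


def lora_forward(x):
    """Frozen projection plus the scaled low-rank correction."""
    return W @ x + scale * (B @ (A @ x))


x = rng.normal(size=d_in)
# With B = 0 the adapter is a no-op, so training starts from the base model.
starts_as_base = bool(np.allclose(lora_forward(x), W @ x))

full_params = W.size             # weights if W were trained directly
lora_params = A.size + B.size    # trainable weights with r = 16
print(starts_as_base, full_params, lora_params)
```

Under these assumptions a single adapted projection trains 24,576 weights instead of 589,824, which is where the parameter savings come from.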

## 1. Install Required Libraries

In [1]:
!pip install transformers datasets peft evaluate rouge_score accelerate -q




## 2. Import Libraries

In [2]:
import torch
import evaluate
import numpy as np
import warnings
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, PeftConfig
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

warnings.filterwarnings('ignore')

## 3. Configuration

In [3]:
MODEL_ID = "google/flan-t5-base"
DATASET_ID = "knkarthick/dialogsum"
OUTPUT_DIR = "./flan-t5-lora-dialogue-summary"

MAX_INPUT_LENGTH = 512       # reduced from 1024 — saves memory & speeds tokenization
MAX_TARGET_LENGTH = 128

LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q", "v"]

LEARNING_RATE = 1e-3
NUM_EPOCHS = 1               # reduced from 4 — single pass is enough for a quick test
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 1
TRAIN_SUBSET_SIZE = 1000

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cpu


## 4. Load Tokenizer and Model

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32)

print(f"Model parameters: {model.num_parameters():,}")

`torch_dtype` is deprecated! Use `dtype` instead!


Model parameters: 247,577,856


## 5. Load Dataset

In [5]:
dataset = load_dataset(DATASET_ID)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})


In [6]:
sample = dataset["train"][0]
print(f"Dialogue:\n{sample['dialogue']}\n")
print(f"Summary:\n{sample['summary']}")

Dialogue:
#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?
#Person2#: I found it would be a good idea to get a check-up.
#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.
#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?
#Person1#: Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good.
#Person2#: Ok.
#Person1#: Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith?
#Person2#: Yes.
#Person1#: Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit.
#Person2#: I've tried hundreds of times, but I just can't seem to kick the habit.
#Person1#: Well, we have classes and some medications that might help. I'll give you more information before you leave.
#Person2#: Ok, thanks doctor.

Summary:
Mr. Smith's getting a check-up, and Doctor Hawki

## 6. Data Preprocessing and Tokenization

In [7]:
def preprocess_function(examples):
    inputs = [
        f"Summarize the following conversation:\n\n{dialogue}\n\nSummary:" 
        for dialogue in examples["dialogue"]
    ]
    
    model_inputs = tokenizer(
        inputs, 
        max_length=MAX_INPUT_LENGTH, 
        truncation=True, 
        padding=False
    )
    
    labels = tokenizer(
        text_target=examples["summary"], 
        max_length=MAX_TARGET_LENGTH, 
        truncation=True, 
        padding=False
    )
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
tokenized_dataset = dataset.map(
    preprocess_function, 
    batched=True, 
    remove_columns=["id", "topic", "dialogue", "summary"]
)

# Use a small subset for faster training
train_subset = tokenized_dataset["train"].shuffle(seed=42).select(range(TRAIN_SUBSET_SIZE))
val_subset = tokenized_dataset["validation"].shuffle(seed=42).select(range(200))

print(f"Train size (subset): {len(train_subset):,}")
print(f"Validation size (subset): {len(val_subset):,}")
print(f"Test size: {len(tokenized_dataset['test']):,}")

Train size (subset): 1,000
Validation size (subset): 200
Test size: 1,500





## 7. Setup LoRA

In [13]:
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

peft_model = get_peft_model(model, lora_config)
peft_model.enable_input_require_grads()  # required for gradient_checkpointing
peft_model.print_trainable_parameters()

trainable params: 1,769,472 || all params: 249,347,328 || trainable%: 0.7096
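
The trainable-parameter count printed above can be reproduced by hand. The sketch below assumes the usual `flan-t5-base` layout: 12 encoder and 12 decoder blocks, 768-wide `q` and `v` projections, and decoder blocks that carry both self-attention and cross-attention. These architecture numbers are assumptions; they are not printed anywhere in this notebook.

```python
# Assumed flan-t5-base layout (not printed anywhere in this notebook).
D_MODEL = 768
ENCODER_BLOCKS = 12
DECODER_BLOCKS = 12
LORA_R = 16
TARGETS_PER_ATTENTION = 2  # "q" and "v"

# Each encoder block has one attention layer; each decoder block has two
# (self-attention and cross-attention).
attention_layers = ENCODER_BLOCKS + 2 * DECODER_BLOCKS

# One adapted projection adds A (r x d) and B (d x r).
params_per_module = LORA_R * (D_MODEL + D_MODEL)

trainable = attention_layers * TARGETS_PER_ATTENTION * params_per_module
base_params = 247_577_856   # from the "Model parameters" output earlier
total = base_params + trainable

print(f"{trainable:,} trainable of {total:,} ({100 * trainable / total:.4f}%)")
```

If the assumed layout is right, the three numbers match the `print_trainable_parameters()` output, which suggests that LoRA was attached to every `q` and `v` projection in both the encoder and the decoder.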


## 8. Define Evaluation Metrics

In [14]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Some configurations return a tuple; the first element holds the predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.array(predictions)
    # If predictions are logits (3D), take argmax to get token IDs
    if predictions.ndim == 3:
        predictions = np.argmax(predictions, axis=-1)

    # Replace -100 padding with pad_token_id first; clipping first would turn
    # -100 into 0 and make this replacement a no-op
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    # Then clip to the valid token ID range to prevent OverflowError when decoding
    predictions = np.clip(predictions, 0, tokenizer.vocab_size - 1)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {k: round(v * 100, 4) for k, v in result.items()}
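
For intuition about what the ROUGE-1 number measures, here is a simplified pure-Python sketch of unigram-overlap F1. It is not the `rouge_score` implementation used by `evaluate`: it splits on whitespace and skips stemming, and the function name `rouge1_f1` is made up for this illustration.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between two whitespace-tokenized strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: a token counts at most as often as it occurs in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Five of six tokens match in each direction, so precision = recall = 5/6.
score = rouge1_f1("the cat sat on the mat", "the cat lay on the mat")
print(round(score * 100, 2))
```

ROUGE-2 applies the same idea to bigrams, and ROUGE-L replaces the n-gram overlap with the longest common subsequence.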

## 9. Configure Training

In [15]:
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    weight_decay=0.01,
    save_total_limit=1,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    logging_steps=25,
    eval_strategy="no",         # skip evaluation during training for speed
    save_strategy="epoch",
    report_to="none",
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=peft_model)

trainer = Seq2SeqTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_subset,         # use subset
    eval_dataset=val_subset,            # use subset
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
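
Because preprocessing used `padding=False`, padding happens per batch in the data collator, and `DataCollatorForSeq2Seq` pads labels with `-100` by default so that padded positions do not contribute to the loss. The NumPy sketch below illustrates that idea under simplifying assumptions: the helper names, the toy token IDs, and the 10-token vocabulary are made up, and this is not the library's code.

```python
import numpy as np

IGNORE_INDEX = -100  # label value skipped by the loss


def pad_labels(label_lists, pad_value=IGNORE_INDEX):
    """Right-pad variable-length label lists to the longest one in the batch."""
    max_len = max(len(seq) for seq in label_lists)
    return np.array([seq + [pad_value] * (max_len - len(seq)) for seq in label_lists])


def masked_cross_entropy(logits, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != IGNORE_INDEX
    safe_labels = np.where(mask, labels, 0)                  # placeholder index, masked out below
    picked = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(picked * mask).sum() / mask.sum()


labels = pad_labels([[5, 2, 1], [7, 1]])   # made-up token IDs
logits = np.zeros((2, 3, 10))              # uniform predictions over a toy vocabulary
loss = masked_cross_entropy(logits, labels)
print(labels.tolist(), round(float(loss), 4))
```

The same `-100` convention is why `compute_metrics` swaps `-100` back to `pad_token_id` before decoding.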

## 10. Train the Model

In [16]:
train_result = trainer.train()

print(f"Training Loss: {train_result.training_loss:.4f}")

| Step | Training Loss |
|------|---------------|
| 25 | 1.5039 |
| 50 | 1.3113 |
| 75 | 1.2948 |
| 100 | 1.2952 |
| 125 | 1.4007 |


Training Loss: 1.3612
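
The step numbers in the log follow from the configuration values. This is a rough sketch of the arithmetic, assuming a single device and that the last partial batch, if any, still counts as a step; it restates the constants from the Configuration cell rather than reading them from the trainer.

```python
import math

# Values restated from the Configuration cell.
TRAIN_SUBSET_SIZE = 1000
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 1
NUM_EPOCHS = 1
LOGGING_STEPS = 25

effective_batch = BATCH_SIZE * GRAD_ACCUM_STEPS
steps_per_epoch = math.ceil(TRAIN_SUBSET_SIZE / effective_batch)
total_steps = steps_per_epoch * NUM_EPOCHS
logged_steps = list(range(LOGGING_STEPS, total_steps + 1, LOGGING_STEPS))

print(effective_batch, total_steps, logged_steps)
```

With 1,000 samples and an effective batch of 8, one epoch is 125 optimizer steps, which matches the last logged step.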


## 11. Save the Model

In [17]:
peft_model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"Model saved to: {OUTPUT_DIR}")

Model saved to: ./flan-t5-lora-dialogue-summary


## 12. Evaluate on Test Set

In [18]:
def evaluate_on_test_set(trainer, tokenized_dataset):
    test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"])
    
    print("Test Set Results:")
    print(f"  ROUGE-1:    {test_results['eval_rouge1']:.2f}%")
    print(f"  ROUGE-2:    {test_results['eval_rouge2']:.2f}%")
    print(f"  ROUGE-L:    {test_results['eval_rougeL']:.2f}%")
    print(f"  ROUGE-Lsum: {test_results['eval_rougeLsum']:.2f}%")
    
    return test_results

test_results = evaluate_on_test_set(trainer, tokenized_dataset)

Test Set Results:
  ROUGE-1:    41.00%
  ROUGE-2:    15.57%
  ROUGE-L:    33.18%
  ROUGE-Lsum: 33.15%


## 13. Inference - Load and Test Saved Model

In [19]:
config = PeftConfig.from_pretrained(OUTPUT_DIR)

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    config.base_model_name_or_path, 
    torch_dtype=torch.float32
)

inference_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
inference_model.to(device)
inference_model.eval()

print(f"Model loaded on {device}")


Model loaded on cpu


In [20]:
def generate_summary(model, tokenizer, dialogue, device):
    input_text = f"Summarize the following conversation:\n\n{dialogue}\n\nSummary:"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_TARGET_LENGTH,  # same 128-token cap used for the labels
            do_sample=True,                    # sampling: output varies from run to run
            top_p=0.9,
            temperature=0.7
        )
    
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary
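
The `temperature` and `top_p` arguments control how each next token is drawn. The NumPy sketch below shows the general idea of temperature scaling followed by nucleus (top-p) filtering for a single decoding step. It is an illustration under simplifying assumptions, not the Transformers implementation, and `sample_top_p` and the toy logits are hypothetical.

```python
import numpy as np


def sample_top_p(logits, top_p=0.9, temperature=0.7, rng=None):
    """Draw one token ID: temperature scaling, then nucleus filtering."""
    if rng is None:
        rng = np.random.default_rng()
    scaled = np.asarray(logits, dtype=float) / temperature   # <1 sharpens the distribution
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    order = np.argsort(-probs)                # most likely token first
    cumulative = np.cumsum(probs[order])
    # Keep the smallest prefix whose cumulative probability reaches top_p.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    kept = order[:cutoff]

    kept_probs = probs[kept] / probs[kept].sum()   # renormalize over the nucleus
    return int(rng.choice(kept, p=kept_probs)), kept.tolist()


# Toy 4-token vocabulary; a seeded generator makes the draw repeatable.
token, nucleus = sample_top_p([2.0, 1.0, 0.0, -3.0], rng=np.random.default_rng(0))
print(token, nucleus)
```

For these toy logits only the two most likely tokens survive the filter, so the unlikely tail can never be sampled. Because the real `generate` call also samples, repeated runs of `generate_summary` can produce different summaries for the same dialogue.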

In [21]:
sample_dialogue = """
#Person1#: I'm thinking of upgrading my computer hardware.
#Person2#: What kind of upgrades do you have in mind?
#Person1#: I definitely need more RAM and maybe a better graphics card for gaming.
#Person2#: Make sure your power supply can handle the new card.
"""

print("Input Dialogue:")
print(sample_dialogue)
print("\nGenerated Summary:")
print(generate_summary(inference_model, tokenizer, sample_dialogue, device))

Input Dialogue:

#Person1#: I'm thinking of upgrading my computer hardware.
#Person2#: What kind of upgrades do you have in mind?
#Person1#: I definitely need more RAM and maybe a better graphics card for gaming.
#Person2#: Make sure your power supply can handle the new card.


Generated Summary:
#Person1# thinks about upgrading his computer hardware. #Person1# needs more RAM and maybe a better graphics card for gaming.


In [22]:
test_sample = dataset["test"][0]

print("Dialogue:")
print(test_sample["dialogue"])
print("\nGround Truth Summary:")
print(test_sample["summary"])
print("\nGenerated Summary:")
print(generate_summary(inference_model, tokenizer, test_sample["dialogue"], device))

Dialogue:
#Person1#: Ms. Dawson, I need you to take a dictation for me.
#Person2#: Yes, sir...
#Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready?
#Person2#: Yes, sir. Go ahead.
#Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited.
#Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications?
#Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications.
#Person2#: But sir, many employees use Instant Messaging to communicate with their clients.
#Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with th

## 14. Batch Evaluation Function

In [23]:
def evaluate_samples(model, tokenizer, dataset, device, num_samples=10):
    predictions = []
    references = []
    
    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        generated = generate_summary(model, tokenizer, sample["dialogue"], device)
        predictions.append(generated)
        references.append(sample["summary"])
    
    rouge = evaluate.load("rouge")
    results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
    
    print(f"ROUGE Scores ({len(predictions)} samples):")  # actual count, which may be below num_samples
    print(f"  ROUGE-1: {results['rouge1']*100:.2f}%")
    print(f"  ROUGE-2: {results['rouge2']*100:.2f}%")
    print(f"  ROUGE-L: {results['rougeL']*100:.2f}%")
    
    return results

## 15. Interactive Test — Enter Your Own Conversation

Enter a conversation below and the model will generate a summary for it.

In [24]:
# Enter your conversation here (use #Person1#: and #Person2#: format)
conversation = """
#Person1#: Hey, have you finished the project report yet?
#Person2#: Almost. I just need to add the conclusion and proofread it.
#Person1#: The deadline is tomorrow morning, right?
#Person2#: Yes, I'll have it done by tonight. Can you review it after?
#Person1#: Sure, just send it over when you're done.
#Person2#: Great, I'll email it to you by 9 PM.
"""

print("=" * 60)
print("INPUT CONVERSATION:")
print("=" * 60)
print(conversation)
print("=" * 60)
print("MODEL-GENERATED SUMMARY:")
print("=" * 60)
print(generate_summary(inference_model, tokenizer, conversation, device))

INPUT CONVERSATION:

#Person1#: Hey, have you finished the project report yet?
#Person2#: Almost. I just need to add the conclusion and proofread it.
#Person1#: The deadline is tomorrow morning, right?
#Person2#: Yes, I'll have it done by tonight. Can you review it after?
#Person1#: Sure, just send it over when you're done.
#Person2#: Great, I'll email it to you by 9 PM.

MODEL-GENERATED SUMMARY:
#Person2# finishes the project report and wants to add the conclusion and proofread it. #Person1# will email it to #Person2# by 9 PM.


## Summary & Deployment

### Training
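This notebook fine-tuned `google/flan-t5-base` with LoRA on a 1,000-sample subset of the training split for one epoch as a quick test. On the 1,500-example test set the adapted model scored ROUGE-1 41.00, ROUGE-2 15.57, and ROUGE-L 33.18.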

### Deployment Steps

**Step 1 — Upload model to Hugging Face Hub**
```bash
pip install huggingface_hub
huggingface-cli login
# Edit HF_USERNAME in upload_to_hf.py, then:
python upload_to_hf.py
```

**Step 2 — Create a Hugging Face Space**
1. Go to https://huggingface.co/new-space
2. Choose **Gradio** SDK, name it `flan-t5-dialogue-summarizer`
3. Upload the three files from the `huggingface_space/` folder (`app.py`, `requirements.txt`, `README.md`)
4. Edit `MODEL_REPO` in `app.py` to `"YOUR_HF_USERNAME/flan-t5-dialogue-summarizer"`
5. The Space builds automatically — your app is live!

**Step 3 — GitHub Pages**
1. Push this project to a GitHub repo
2. Go to repo **Settings → Pages → Source → Deploy from branch → `main` → `/docs`**
3. Edit `docs/index.html` and replace `YOUR_HF_USERNAME` with your actual username
4. Your site is live at `https://YOUR_GH_USERNAME.github.io/REPO_NAME/`