# Fine-Tuning Flan-T5-Base with LoRA for Dialogue Summarization

This notebook walks through fine-tuning `google/flan-t5-base` using **LoRA** (Low-Rank Adaptation) on the `knkarthick/dialogsum` dataset.

Key features:
- 4-bit quantization via `BitsAndBytesConfig` to save VRAM
- LoRA with r=16, alpha=32 targeting the q and v projection modules
- Training on the first 500 rows of the training split
- Saving and reloading only the LoRA adapters for inference

## 1. Install Dependencies

Run this cell if the packages are not already installed.
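Before installing, you can check which packages are already present. The sketch below is a convenience check only. Its package list is an assumption inferred from this notebook's imports, plus `rouge_score` (needed by `evaluate`'s ROUGE metric) and `accelerate` (needed for `device_map="auto"`). The actual `requirements.txt` contents are not shown here and may differ.

```python
# Sketch: report which of the notebook's dependencies are already installed.
# ASSUMPTION: this package list is inferred from the notebook's imports and
# may not match requirements.txt exactly.
from importlib.metadata import PackageNotFoundError, version

REQUIRED = [
    "torch", "transformers", "peft", "datasets", "evaluate",
    "bitsandbytes", "accelerate", "rouge_score", "numpy",
]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


status = installed_versions(REQUIRED)
missing = [name for name, ver in status.items() if ver is None]
for name, ver in status.items():
    print(f"{name:>14}: {ver or 'NOT INSTALLED'}")
print("Missing:", missing or "none")
```

If `missing` is empty, the install cell can be skipped.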

In [2]:
# Install the pinned dependencies listed in requirements.txt
!pip install -r requirements.txt


## 2. Imports

In [2]:
import os
import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from peft import (
    LoraConfig,
    PeftModel,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
import evaluate

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")



PyTorch version: 2.10.0+cpu
CUDA available: False
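The output above reports `CUDA available: False`. Both 4-bit bitsandbytes loading and `bf16=True` training assume a GPU, so it can help to derive those flags from the detected hardware instead of hard-coding them. The sketch below is a hypothetical helper (`precision_settings` is not a transformers API). You would pass it `torch.cuda.is_available()` and, only when CUDA is present, `torch.cuda.is_bf16_supported()`.

```python
# Sketch: derive precision-related settings from the detected hardware so that
# later cells do not hard-code bf16=True on a CPU-only machine.
# ASSUMPTION: `precision_settings` is a hypothetical helper, not a library API.


def precision_settings(cuda_available: bool, bf16_supported: bool) -> dict:
    """Return flags to feed into Seq2SeqTrainingArguments / BitsAndBytesConfig."""
    use_bf16 = cuda_available and bf16_supported
    return {
        # Seq2SeqTrainingArguments(bf16=...) raises on unsupported hardware
        "bf16": use_bf16,
        # TrainingArguments(use_cpu=True) forces CPU training
        "use_cpu": not cuda_available,
        # 4-bit bitsandbytes loading is only attempted when a CUDA GPU exists
        "load_in_4bit": cuda_available,
        # T5 checkpoints are reported to overflow in fp16, so fall back to fp32
        "compute_dtype": "bfloat16" if use_bf16 else "float32",
    }


# Example for the CPU-only session shown above
print(precision_settings(cuda_available=False, bf16_supported=False))
```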


## 3. Configuration

In [2]:
# -- Model and dataset --
MODEL_NAME = "google/flan-t5-base"
DATASET_NAME = "knkarthick/dialogsum"
NUM_TRAIN_SAMPLES = 500

# -- Paths --
OUTPUT_DIR = "./lora-flan-t5-summarization"
ADAPTER_DIR = "./lora-adapters"

# -- Tokenization --
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128

# -- Training hyperparameters --
BATCH_SIZE = 4
LEARNING_RATE = 1e-3
NUM_EPOCHS = 3
LOGGING_STEPS = 25
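With these settings, the length of the training run follows from simple arithmetic. The sketch assumes a single device, `gradient_accumulation_steps=1` (the Trainer default), and that the final partial batch is kept (`dataloader_drop_last=False`).

```python
import math

# Sketch: optimizer steps implied by the configuration above.
# ASSUMPTIONS: one device, no gradient accumulation, last partial batch kept.
NUM_TRAIN_SAMPLES = 500
BATCH_SIZE = 4
NUM_EPOCHS = 3
LOGGING_STEPS = 25

steps_per_epoch = math.ceil(NUM_TRAIN_SAMPLES / BATCH_SIZE)   # 125
total_steps = steps_per_epoch * NUM_EPOCHS                     # 375
num_log_points = total_steps // LOGGING_STEPS                  # 15 loss logs

print(f"{steps_per_epoch=} {total_steps=} {num_log_points=}")
```

So the run logs the training loss 15 times and, with `eval_strategy="epoch"`, evaluates 3 times.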

## 4. Load Tokenizer and Quantized Model

We load the base model in 4-bit NF4 precision via `BitsAndBytesConfig`, with double quantization enabled (it also quantizes the per-block scaling constants) to cut memory further. Computation runs in `bfloat16`.
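To put the memory savings in rough numbers, the sketch below estimates weight storage at different bit widths. It assumes about 247.6M base parameters (the 249,347,328 total that PEFT reports later, minus the 1,769,472 LoRA parameters) and that every weight uses the same bit width. In practice bitsandbytes leaves some layers (for example embeddings and the output head) in higher precision and stores quantization constants, so the real 4-bit footprint is somewhat larger.

```python
# Back-of-the-envelope weight memory for flan-t5-base at different precisions.
# ASSUMPTIONS: uniform bit width for all weights; no quantization constants
# or unquantized layers are counted.
BASE_PARAMS = 249_347_328 - 1_769_472  # 247,577,856


def weight_megabytes(num_params: int, bits_per_param: float) -> float:
    """Approximate weight storage in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bits_per_param / 8 / 1e6


for label, bits in [("fp32", 32), ("bf16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: {weight_megabytes(BASE_PARAMS, bits):7.1f} MB")
```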

In [5]:
# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load model with quantization
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare model for k-bit training (freeze base weights, cast layernorm to fp32, etc.)
model = prepare_model_for_kbit_training(model)

print(f"Model loaded: {MODEL_NAME}")
print(f"Model dtype: {model.dtype}")


## 5. Apply LoRA Adapters

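LoRA configuration:
- **r = 16** (rank of the low-rank update matrices)
- **alpha = 32** (scaling factor; the low-rank update is multiplied by alpha / r = 2)
- **target_modules = ["q", "v"]** (query and value projections in every T5 attention block)
- **lora_dropout = 0.05**, **bias = "none"** (dropout on the LoRA input; bias terms stay frozen)
- **task_type = SEQ_2_SEQ_LM** (sequence-to-sequence language modeling)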
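The configuration above can be read as the following computation for each wrapped `q` or `v` projection: the frozen weight `W0` is left untouched and a scaled low-rank product `B @ A` is added on top. This is a minimal NumPy sketch, not PEFT's implementation. It assumes a square 768 x 768 projection (as in T5-base), omits dropout, and uses a hypothetical helper `lora_forward`.

```python
import numpy as np

# Minimal sketch of a LoRA-wrapped linear layer:
#   h = W0 x + (alpha / r) * B (A x)
# W0 is frozen; only A (r x d_in) and B (d_out x r) are trained.
# ASSUMPTIONS: 768 x 768 projection, no dropout, toy random weights.
rng = np.random.default_rng(0)
d_in = d_out = 768
r, alpha = 16, 32

W0 = rng.normal(size=(d_out, d_in))         # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable, random init
B = np.zeros((d_out, r))                    # trainable, zero init


def lora_forward(x, W0, A, B, alpha, r):
    """Base projection plus the scaled low-rank update."""
    return W0 @ x + (alpha / r) * (B @ (A @ x))


x = rng.normal(size=(d_in,))
h = lora_forward(x, W0, A, B, alpha, r)

# Because B starts at zero, the adapted layer initially matches the base layer.
print(np.allclose(h, W0 @ x))  # True
# Trainable parameters added per wrapped projection: r * (d_in + d_out)
print(A.size + B.size)         # 24576
```

Zero-initializing `B` means training starts exactly from the pretrained model's behavior, and after training the update can be folded into a single matrix `W0 + (alpha / r) * B @ A`.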

In [5]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 1,769,472 || all params: 249,347,328 || trainable%: 0.7096
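The trainable-parameter count printed above can be checked by hand. The sketch assumes the flan-t5-base dimensions (d_model = 768, 12 heads of size 64, 12 encoder and 12 decoder layers). Each encoder layer has one self-attention block. Each decoder layer has a self-attention block and a cross-attention block.

```python
# Worked check of the trainable-parameter count printed above.
# ASSUMPTIONS: flan-t5-base config values listed in the lead-in.
D_MODEL = 768
INNER_DIM = 12 * 64
NUM_ENCODER_LAYERS = NUM_DECODER_LAYERS = 12
RANK = 16

attention_blocks = NUM_ENCODER_LAYERS + 2 * NUM_DECODER_LAYERS   # 36
lora_modules = attention_blocks * 2                              # q and v -> 72
# q and v map d_model -> inner_dim; LoRA adds A (r x d_in) and B (d_out x r)
params_per_module = RANK * D_MODEL + INNER_DIM * RANK            # 24,576
trainable = lora_modules * params_per_module

ALL_PARAMS = 249_347_328  # "all params" as reported by print_trainable_parameters
print(f"trainable params: {trainable:,}")
print(f"trainable%: {100 * trainable / ALL_PARAMS:.4f}")
```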


## 6. Load and Preprocess the Dataset

We take the first 500 rows of the `train` split of `knkarthick/dialogsum` for training, and the first 100 rows (or fewer, if the split is smaller) of the `validation` split for evaluation.

In [6]:
dataset = load_dataset(DATASET_NAME)

train_dataset = dataset["train"].select(range(NUM_TRAIN_SAMPLES))
val_dataset = dataset["validation"].select(range(min(100, len(dataset["validation"]))))

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"\nSample dialogue (truncated):\n{train_dataset[0]['dialogue'][:300]}...")
print(f"\nSample summary:\n{train_dataset[0]['summary']}")

Training samples: 500
Validation samples: 100

Sample dialogue (truncated):
#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?
#Person2#: I found it would be a good idea to get a check-up.
#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.
#Person2#: I know. I figure as long as there is nothing wrong, why go see the doc...

Sample summary:
Mr. Smith's getting a check-up, and Doctor Hawkins advises him to have one every year. Hawkins'll give some information about their classes and medications to help Mr. Smith quit smoking.


In [7]:
PREFIX = "Summarize the following dialogue:\n\n"


def preprocess(examples):
    inputs = [PREFIX + dialogue for dialogue in examples["dialogue"]]
    targets = examples["summary"]

    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        truncation=True,
    )

    labels = tokenizer(
        targets,
        max_length=MAX_TARGET_LENGTH,
        padding="max_length",
        truncation=True,
    )

    # Replace pad token ids in labels with -100 so they are ignored by the loss
    labels["input_ids"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_train = train_dataset.map(
    preprocess, batched=True, remove_columns=train_dataset.column_names
)
tokenized_val = val_dataset.map(
    preprocess, batched=True, remove_columns=val_dataset.column_names
)

print(f"Tokenized training features: {tokenized_train.column_names}")
print(f"Tokenized validation features: {tokenized_val.column_names}")

Tokenized training features: ['input_ids', 'attention_mask', 'labels']
Tokenized validation features: ['input_ids', 'attention_mask', 'labels']
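The `-100` masking in `preprocess` matters because the sequence-to-sequence loss skips positions labeled `-100`, so padding does not contribute to the gradient. The sketch below reproduces that behavior in NumPy. It assumes T5's `pad_token_id` is 0, uses toy token ids and logits, and `masked_cross_entropy` is a hypothetical stand-in that mirrors `ignore_index=-100` semantics rather than calling PyTorch.

```python
import numpy as np

# Sketch: padded label positions become -100 and are excluded from the loss.
# ASSUMPTIONS: pad id 0 (T5); toy logits over a 5-token vocabulary.
PAD_ID = 0
IGNORE_INDEX = -100


def mask_padding(label_rows, pad_id=PAD_ID):
    """Replace pad ids with -100, as preprocess() does."""
    return [[tok if tok != pad_id else IGNORE_INDEX for tok in row] for row in label_rows]


def masked_cross_entropy(logits, labels):
    """Mean token-level cross-entropy over positions whose label is not -100."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    keep = labels != IGNORE_INDEX
    safe_labels = np.where(keep, labels, 0)  # any valid index; masked out below
    picked = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(picked * keep).sum() / keep.sum()


labels = mask_padding([[3, 1, 0, 0]])  # [[3, 1, -100, -100]]
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 4, 5))    # batch 1, 4 positions, vocab of 5
print(labels, masked_cross_entropy(logits, labels))
```

Changing the logits at the two masked positions leaves the loss unchanged, which is exactly what the masking is for.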


## 7. Set Up Evaluation Metrics (ROUGE)

In [8]:
rouge = evaluate.load("rouge")


def compute_metrics(eval_preds):
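    preds, labels = eval_preds
    # Predictions can arrive as a tuple whose first element holds the generated ids
    if isinstance(preds, tuple):
        preds = preds[0]

    # Decode predictions (-100 marks padding added when gathering generations)
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)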
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Decode labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    result = rouge.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    return {k: round(v, 4) for k, v in result.items()}
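The `rougeL` score returned by `compute_metrics` is an F1 over the longest common subsequence (LCS) of tokens between a generated summary and its reference. The sketch below is a simplified re-implementation to show the idea. It assumes lowercase alphanumeric tokenization and no stemming. The `rouge_score` package behind `evaluate` applies a Porter stemmer when `use_stemmer=True`, so its numbers can differ slightly.

```python
import re

# Simplified ROUGE-L F1 based on the longest common subsequence of tokens.
# ASSUMPTIONS: lowercase alphanumeric tokens, no stemming.


def tokenize(text):
    return re.findall(r"[a-z0-9]+", text.lower())


def lcs_length(a, b):
    """Dynamic-programming LCS length, O(len(a) * len(b)) time."""
    prev = [0] * (len(b) + 1)
    for tok_a in a:
        curr = [0]
        for j, tok_b in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if tok_a == tok_b else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction, reference):
    pred, ref = tokenize(prediction), tokenize(reference)
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


print(rouge_l_f1("Mike moves the meeting to 2 PM.",
                 "Mike asks Sarah to move the meeting to 2 PM."))
```

Here the LCS is 6 tokens ("mike the meeting to 2 pm"), giving precision 6/7 and recall 6/10. Without stemming, "moves" and "move" do not match.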

## 8. Training

In [9]:
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    logging_steps=LOGGING_STEPS,
    eval_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    # bf16 requires a GPU that supports it; fall back to fp32 otherwise
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    processing_class=tokenizer,  # named `tokenizer=` in transformers < 4.46
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Starting training...")

In [None]:
trainer.train()

## 9. Evaluate on the Validation Set

In [None]:
eval_results = trainer.evaluate()
print("Evaluation results:")
for key, value in eval_results.items():
    print(f"  {key}: {value}")
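The training arguments set `load_best_model_at_end=True` and `metric_for_best_model="rougeL"`. As a result, the model evaluated here comes from the epoch with the highest `eval_rougeL`, not necessarily the last epoch. The sketch below mimics that selection rule: the metric name gets an `eval_` prefix, and higher is better unless the name ends in `loss`. `pick_best_epoch` is a hypothetical helper, not a transformers API.

```python
# Sketch of the best-checkpoint rule used with load_best_model_at_end=True.
# ASSUMPTION: `pick_best_epoch` is hypothetical; it operates on a list of
# per-epoch evaluation dicts such as the "eval_*" entries in trainer.state.log_history.


def pick_best_epoch(eval_history, metric_for_best_model="rougeL", greater_is_better=None):
    key = metric_for_best_model
    if not key.startswith("eval_"):
        key = f"eval_{key}"
    if greater_is_better is None:
        greater_is_better = not metric_for_best_model.endswith("loss")
    scores = [entry[key] for entry in eval_history]
    best = max(scores) if greater_is_better else min(scores)
    return scores.index(best)  # first epoch that reaches the best score


# Usage after training (requires the trainer above):
# history = [e for e in trainer.state.log_history if "eval_rougeL" in e]
# print("Best epoch index:", pick_best_epoch(history))
```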

## 10. Save LoRA Adapters

We save only the LoRA adapter weights, not the full base model. This keeps the checkpoint small: the 1,769,472 trainable parameters take roughly 7 MB in fp32, compared with roughly 1 GB for the base model's weights in full precision.

In [None]:
# Save only the LoRA adapters and tokenizer
model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)

print(f"LoRA adapters saved to: {ADAPTER_DIR}")

# Show what was saved
adapter_files = os.listdir(ADAPTER_DIR)
print("\nSaved files:")
for f in adapter_files:
    size = os.path.getsize(os.path.join(ADAPTER_DIR, f))
    print(f"  {f} ({size / 1024:.1f} KB)")

## 11. Reload Adapters and Run Inference

This section demonstrates how to load the saved LoRA adapters onto a fresh base model and run inference on a custom input string. This is the pattern you would use in production or on a different machine.

In [None]:
# -- Reload from scratch to prove it works independently --

# Step 1: Load the base model again with 4-bit quantization
reload_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    quantization_config=reload_bnb_config,
    device_map="auto",
)

# Step 2: Load the LoRA adapters on top of the base model
inference_model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
inference_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)

print("Base model + LoRA adapters loaded successfully for inference.")

In [None]:
# -- Run inference on a custom dialogue string --

custom_dialogue = (
    "Sarah: Hey Mike, are we still on for the project meeting tomorrow?\n"
    "Mike: Yes, but can we push it to 2 PM instead of 10 AM? I have a dentist appointment in the morning.\n"
    "Sarah: Sure, 2 PM works. Should I book the conference room?\n"
    "Mike: That would be great. Also, can you send me the latest draft before the meeting?\n"
    "Sarah: Will do. I will email it to you tonight.\n"
    "Mike: Perfect. Thanks, Sarah!"
)

input_text = PREFIX + custom_dialogue

inputs = inference_tokenizer(
    input_text,
    return_tensors="pt",
    max_length=MAX_INPUT_LENGTH,
    truncation=True,
).to(inference_model.device)

with torch.no_grad():
    outputs = inference_model.generate(
        **inputs,
        max_new_tokens=MAX_TARGET_LENGTH,
        num_beams=4,
        early_stopping=True,
    )

summary = inference_tokenizer.decode(outputs[0], skip_special_tokens=True)

print("=" * 60)
print("DIALOGUE:")
print("=" * 60)
print(custom_dialogue)
print("\n" + "=" * 60)
print("GENERATED SUMMARY:")
print("=" * 60)
print(summary)
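The generation call above uses `num_beams=4`. Beam search keeps the several highest-scoring partial sequences at each step instead of committing to the single most likely next token. The toy sketch below illustrates this on a hand-written probability table, where greedy decoding (one beam) misses the most likely complete sequence. It is a simplified beam search, not Hugging Face's implementation, which also applies `length_penalty` and handles `early_stopping` differently. `toy_logprobs` is a hypothetical stand-in for the model.

```python
import math

# Simplified beam search over a toy next-token table.
# ASSUMPTIONS: `toy_logprobs` stands in for the model; finished sequences
# leave the beam, and search stops when no unfinished beam remains.


def beam_search(next_token_logprobs, bos, eos, num_beams=4, max_new_tokens=10):
    beams = [([bos], 0.0)]  # (token sequence, cumulative log-prob)
    finished = []
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            for tok, lp in next_token_logprobs(seq).items():
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:num_beams]:
            (finished if seq[-1] == eos else beams).append((seq, score))
        if not beams:
            break
    finished.extend(beams)  # unfinished beams still count at the length limit
    return max(finished, key=lambda c: c[1])


TOY_PROBS = {
    "<s>": {"A": 0.6, "B": 0.4},
    "A": {"</s>": 0.55, "C": 0.45},
    "B": {"</s>": 0.9, "C": 0.1},
    "C": {"</s>": 1.0},
}


def toy_logprobs(seq):
    return {tok: math.log(p) for tok, p in TOY_PROBS[seq[-1]].items()}


print(beam_search(toy_logprobs, "<s>", "</s>", num_beams=2))  # B path, prob 0.36
print(beam_search(toy_logprobs, "<s>", "</s>", num_beams=1))  # greedy: A path, prob 0.33
```

Greedy decoding picks "A" first because it is locally more likely (0.6), but the best complete sequence goes through "B" (0.4 x 0.9 = 0.36 versus 0.6 x 0.55 = 0.33).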

In [None]:
# -- Helper function for quick inference on any dialogue --

def summarize_dialogue(dialogue: str, model=inference_model, tok=inference_tokenizer) -> str:
    """Summarize a dialogue string using the fine-tuned LoRA model."""
    prompt = PREFIX + dialogue
    inputs = tok(prompt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=MAX_TARGET_LENGTH, num_beams=4, early_stopping=True)
    return tok.decode(output_ids[0], skip_special_tokens=True)


# Try it out with another example
another_dialogue = (
    "Tom: Did you see the email from the client?\n"
    "Jane: Yes, they want the delivery moved up by two weeks.\n"
    "Tom: That is going to be tight. Can we pull it off?\n"
    "Jane: If we get two more developers, maybe. I will talk to the manager.\n"
    "Tom: Alright, keep me posted."
)

print("Summary:", summarize_dialogue(another_dialogue))

## Done

You have successfully:
1. Loaded `google/flan-t5-base` with 4-bit quantization
2. Applied LoRA adapters (r=16, alpha=32) to the q and v modules
3. Trained on 500 samples from `knkarthick/dialogsum`
4. Saved the lightweight LoRA adapters
5. Reloaded them onto a fresh base model for inference

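To use these adapters elsewhere, copy the `lora-adapters/` directory and follow the reload pattern in section 11. The base model itself is not in that directory; it is downloaded from the Hugging Face Hub by name (`google/flan-t5-base`).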
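Before loading copied adapters on another machine, it can help to confirm they were trained on the base model you are about to load. The sketch reads the saved `adapter_config.json` with the standard library. It assumes PEFT writes the keys `base_model_name_or_path`, `r`, `lora_alpha`, and `target_modules`. `check_adapter_config` is a hypothetical helper; `peft.PeftConfig.from_pretrained(ADAPTER_DIR)` is the library-level way to read the same file.

```python
import json
from pathlib import Path

# Sketch: sanity-check a saved adapter directory before PeftModel.from_pretrained.
# ASSUMPTIONS: key names as described in the lead-in; hypothetical helper name.


def check_adapter_config(adapter_dir, expected_base_model):
    config = json.loads((Path(adapter_dir) / "adapter_config.json").read_text())
    if config.get("base_model_name_or_path") != expected_base_model:
        raise ValueError(
            f"Adapter was trained on {config.get('base_model_name_or_path')!r}, "
            f"not {expected_base_model!r}"
        )
    return {key: config.get(key) for key in ("r", "lora_alpha", "target_modules")}


# Usage (after section 10 has written the adapters):
# print(check_adapter_config("./lora-adapters", "google/flan-t5-base"))
```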