# Lab 3: Prompt Tuning - Fine-Tuning a T5 Model for Summarization
---
## Notebook 3: Inference

**Goal:** In this notebook, you will load the trained soft prompt, attach it to the frozen T5 base model, and generate summaries for new pieces of text.

**You will learn to:**
-   Reload the base T5 model and tokenizer.
-   Load the trained Prompt Tuning adapter from a checkpoint using `peft.PeftModel`.
-   Run inference with `generate()` and decode the generated summary.


### Step 1: Reload Model and Adapter

We will load the base `t5-small` model and then apply our trained soft prompt weights on top of it using `PeftModel.from_pretrained`.
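
Prompt tuning leaves every T5 weight unchanged. The adapter stores only a small matrix of learned "virtual token" embeddings, which PEFT prepends to the embedded input (and extends the attention mask to match) before the encoder runs. The minimal pure-Python sketch below illustrates that prepend step on toy 2-dimensional embeddings. The helper `prepend_soft_prompt` is hypothetical; PEFT performs the equivalent operation on batched PyTorch tensors.

```python
from typing import List, Tuple

Vector = List[float]


def prepend_soft_prompt(
    soft_prompt: List[Vector],
    input_embeds: List[Vector],
    attention_mask: List[int],
) -> Tuple[List[Vector], List[int]]:
    """Prepend learned virtual-token embeddings to one sequence of input embeddings.

    Hypothetical helper for illustration only: PEFT does the equivalent on
    batched tensors inside the model's forward/generate call.
    """
    if input_embeds and any(len(v) != len(input_embeds[0]) for v in soft_prompt):
        raise ValueError("soft prompt and input embeddings must share a hidden size")
    # Virtual tokens come first, followed by the real (embedded) text tokens.
    combined = [list(v) for v in soft_prompt] + [list(v) for v in input_embeds]
    # Virtual tokens are always attended to, so their mask entries are 1.
    combined_mask = [1] * len(soft_prompt) + list(attention_mask)
    return combined, combined_mask


# Toy example: 2 virtual tokens, 3 real tokens, hidden size 2.
soft_prompt = [[0.1, 0.2], [0.3, 0.4]]               # trained: the only updated parameters
input_embeds = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # frozen embedding lookups of the text
attention_mask = [1, 1, 1]

embeds, mask = prepend_soft_prompt(soft_prompt, input_embeds, attention_mask)
print(len(embeds), mask)  # 5 [1, 1, 1, 1, 1]
```

Because only `soft_prompt` is trained, the saved adapter is tiny compared with `t5-small` itself, and the same base model can be reused with different soft prompts.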


In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel
import torch
import os

# --- Load Base Model and Tokenizer ---
model_checkpoint = "t5-small"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

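# --- Load PEFT Adapter ---
output_dir = "./t5-prompt-tuning-billsum"
# Collect the checkpoint folders saved during training ("checkpoint-<step>")
checkpoints = [
    os.path.join(output_dir, d)
    for d in os.listdir(output_dir)
    if d.startswith("checkpoint-")
]
if not checkpoints:
    raise FileNotFoundError(f"No 'checkpoint-*' folders found in {output_dir}; run training first.")
# Use the most recently modified checkpoint
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
print(f"Loading adapter from: {latest_checkpoint}")

# Attach the trained soft prompt to the frozen base model
inference_model = PeftModel.from_pretrained(base_model, latest_checkpoint)

# Move model to GPU if available and switch to evaluation mode (disables dropout)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_model.to(device)
inference_model.eval()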

print("✅ Inference model loaded successfully!")
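
Modification times can be misleading if checkpoint folders are copied, synced, or touched after training. Because the `transformers` `Trainer` names its folders `checkpoint-<global_step>`, an alternative is to select the folder with the highest step number. The following is a minimal sketch built around a hypothetical helper, `latest_checkpoint_by_step`, and it assumes that default naming.

```python
import os
import re
from typing import Optional

# Matches the Trainer's default folder naming, e.g. "checkpoint-1000".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint_by_step(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> folder with the largest step, or None if none exist.

    Hypothetical helper: assumes the "checkpoint-<step>" naming convention.
    """
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))
            # Compare numerically so that checkpoint-1000 beats checkpoint-900.
            if step > best_step:
                best_step, best_path = step, path
    return best_path
```

In the loading cell, `latest_checkpoint_by_step(output_dir)` could stand in for the `max(..., key=os.path.getmtime)` call; check for `None` before passing the result to `PeftModel.from_pretrained`.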


### Step 2: Perform Inference

Now, let's test the model by giving it a piece of text to summarize. We'll use the first example from the test split of the `billsum` dataset.

The process is:
1.  Tokenize the input text (including the "summarize: " prefix).
2.  Use the `generate()` method to create the summary.
3.  Decode the output tokens back into a string.


In [None]:
from datasets import load_dataset

# Load a sample from the test set to summarize
dataset = load_dataset("billsum", split="test[:1]")
original_text = dataset[0]["text"]
reference_summary = dataset[0]["summary"]

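# Prepare the input: "summarize: " is T5's task prefix for summarization
prompt = "summarize: " + original_text
# Truncate long bills; keep max_length consistent with the input length used during training
inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Generate the summary
with torch.no_grad():
    outputs = inference_model.generate(
        **inputs,             # passes both input_ids and attention_mask
        max_length=150,       # upper bound on the summary length in tokens
        num_beams=4,          # beam search: keep the 4 best partial summaries at each step
        early_stopping=True,  # stop as soon as num_beams complete candidates exist
    )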

generated_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the results
print("--- Original Text ---")
print(original_text)
print("\n--- Reference Summary ---")
print(reference_summary)
print("\n--- Generated Summary (via Prompt Tuning) ---")
print(generated_summary)
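
Setting `num_beams=4` switches `generate()` from greedy decoding to beam search. At each step, beam search extends every surviving partial sequence, scores the candidates by cumulative log-probability, and keeps only the top `num_beams`. The toy sketch below illustrates that loop, with a hypothetical `NEXT_TOKEN_PROBS` table standing in for the model. The `transformers` implementation adds details such as length normalization (`length_penalty`) and batching, so this is only an illustration of the idea.

```python
import math
from typing import Dict, List, Tuple

EOS = "</s>"

# Hypothetical next-token distributions keyed by the last token only; a real
# model conditions on the whole input text and the full generated prefix.
NEXT_TOKEN_PROBS: Dict[str, Dict[str, float]] = {
    "<start>": {"the": 0.5, "a": 0.4, EOS: 0.1},
    "the": {"bill": 0.4, "act": 0.35, EOS: 0.25},
    "a": {"bill": 0.9, EOS: 0.1},
    "bill": {EOS: 1.0},
    "act": {EOS: 1.0},
}


def beam_search(num_beams: int, max_length: int) -> List[Tuple[List[str], float]]:
    """Return finished sequences (without <start>/</s>) sorted by log-probability.

    Simplified: finished sequences leave the beam, so fewer beams may stay
    active, and sequences that never reach EOS within max_length are dropped.
    """
    beams: List[Tuple[List[str], float]] = [(["<start>"], 0.0)]
    finished: List[Tuple[List[str], float]] = []
    for _ in range(max_length):
        candidates = []
        for tokens, score in beams:
            for token, prob in NEXT_TOKEN_PROBS[tokens[-1]].items():
                candidates.append((tokens + [token], score + math.log(prob)))
        # Keep only the num_beams highest-scoring extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:num_beams]:
            if tokens[-1] == EOS:
                finished.append((tokens[1:-1], score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    return sorted(finished, key=lambda c: c[1], reverse=True)


print(beam_search(num_beams=1, max_length=5)[0][0])  # ['the', 'bill']
print(beam_search(num_beams=2, max_length=5)[0][0])  # ['a', 'bill']
```

With a single beam, the search commits to the locally best first token, "the", and ends with a sequence of probability 0.5 × 0.4 = 0.20. With two beams, it keeps "a" alive and finds "a bill", whose total probability is 0.4 × 0.9 = 0.36. This is the kind of improvement `num_beams=4` aims for, at the cost of more computation per generated token.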
