# Lab 4: Prefix Tuning - Fine-Tuning a GPT-2 Model for Generation
---
## Notebook 3: Inference

**Goal:** In this notebook, you will load the trained prefix tuning adapter and use the fine-tuned GPT-2 model to generate new text that mimics the style of the training data (positive movie reviews).

**You will learn to:**
-   Reload the base GPT-2 model and tokenizer.
-   Load the trained Prefix Tuning adapter from a checkpoint.
-   Generate text from a starting prompt with sampling-based decoding.


### Step 1: Reload Model and Adapter

We load the base `gpt2` model, then attach the trained prefix tuning adapter from the most recent checkpoint. The base model's weights stay unchanged; the adapter only supplies the learned prefix.


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os

# --- Load Base Model and Tokenizer ---
model_checkpoint = "gpt2"
base_model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS

# --- Load PEFT Adapter ---
output_dir = "./gpt2-prefix-tuning-imdb"
# Pick the most recently modified "checkpoint-*" directory
latest_checkpoint = max(
    [os.path.join(output_dir, d) for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
    key=os.path.getmtime
)
print(f"Loading adapter from: {latest_checkpoint}")

inference_model = PeftModel.from_pretrained(base_model, latest_checkpoint)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_model.to(device)
inference_model.eval()  # Disable dropout for inference

print("✅ Inference model loaded successfully!")
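
Selecting the checkpoint by modification time can pick the wrong directory if checkpoints were copied or touched after training. A minimal sketch of an alternative, assuming the directories follow a `checkpoint-<step>` naming scheme (the helper name `find_latest_checkpoint` is hypothetical and not part of any library):

```python
import os
import re
from typing import Optional

# Assumed naming scheme: "checkpoint-<step>", e.g. "checkpoint-500".
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint directory with the highest step number, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = CHECKPOINT_PATTERN.match(name)
        path = os.path.join(output_dir, name)
        # Ignore anything that is not a directory named checkpoint-<digits>.
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

With this helper, `latest_checkpoint = find_latest_checkpoint(output_dir)` could replace the `max(...)` expression. A `None` result means no checkpoint was found, which is easier to report than the `ValueError` that `max` raises on an empty list.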


### Step 2: Perform Inference

Now, let's test the model by giving it a starting prompt and letting it generate a positive-sounding movie review.
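
The generation call in this step samples with `top_k=50` and `top_p=0.95`. As a rough illustration of what those two filters do to the next-token distribution, here is a simplified pure-Python sketch. It is not the `transformers` implementation, which works on batched logit tensors and differs in detail; the function names are hypothetical.

```python
import math
import random


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_top_p_filter(logits, top_k=50, top_p=0.95):
    """Return {token_id: prob} for the tokens that survive both filters."""
    # Top-k: keep only the k highest-scoring token ids.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = ranked[:top_k]
    probs = softmax([logits[i] for i in kept])
    # Top-p: keep the smallest leading set whose cumulative probability
    # reaches top_p, then renormalize over that set.
    nucleus, cumulative = [], 0.0
    for token_id, p in zip(kept, probs):
        nucleus.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in nucleus)
    return {token_id: p / total for token_id, p in nucleus}


def sample_next_token(logits, top_k=50, top_p=0.95, rng=random):
    """Draw one token id from the filtered distribution."""
    dist = top_k_top_p_filter(logits, top_k, top_p)
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]
```

Lower values of `top_k` or `top_p` shrink the candidate set and make the output more predictable. Higher values allow more varied, and sometimes less coherent, reviews.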


In [None]:
# Prepare the prompt
prompt = "This movie was absolutely"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Generate the text
with torch.no_grad():
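    outputs = inference_model.generate(
        **inputs,  # Pass input_ids together with attention_mask
        max_new_tokens=50,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=3,  # Generate a few different options
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; avoids a warning
    )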

# Decode and print the generated text
print("--- Prompt ---")
print(prompt)
print("\n--- Generated Reviews ---")
for i, output in enumerate(outputs):
    print(f"{i+1}: {tokenizer.decode(output, skip_special_tokens=True)}")
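
To try several prompts without repeating the cell above, the same steps can be wrapped in a small helper. This is a sketch under the assumption that `model` and `tokenizer` behave like the `inference_model` and `tokenizer` objects loaded in Step 1; the name `generate_reviews` and its parameters are hypothetical.

```python
def generate_reviews(model, tokenizer, prompt, device="cpu",
                     num_reviews=3, max_new_tokens=50):
    """Sample `num_reviews` continuations of `prompt` and return them as strings."""
    # Tokenize and move input_ids and attention_mask to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=num_reviews,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Each returned sequence contains the prompt followed by the new tokens.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
```

A call such as `generate_reviews(inference_model, tokenizer, "The acting in this film", device=device)` would then return a list of three strings. Wrapping the call in `torch.no_grad()`, as in the cell above, remains a reasonable safeguard.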
