# Assignment 3: Fine-tune GPT-2 for Creative Story Generation
This notebook fine-tunes GPT-2 Medium on the first 5,000 stories of the TinyStories dataset and then uses the fine-tuned model to generate creative stories from a text prompt.

In [1]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip -v install transformers datasets accelerate

Using pip 25.3 from C:\Users\ayush\AppData\Local\Programs\Python\Python310\lib\site-packages\pip (python 3.10)



[notice] A new release of pip is available: 25.3 -> 26.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise CPU
has_gpu = torch.cuda.is_available()
device = "cuda" if has_gpu else "cpu"
print(f"Using device: {device}")
if has_gpu:
    # Report the name of the first visible GPU
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA GeForce RTX 4050 Laptop GPU


## Step 1: Load GPT-2 Medium Model and Tokenizer

In [3]:
# Checkpoint to fine-tune: GPT-2 Medium
model_name = "gpt2-medium"

# Fetch the pretrained weights and the matching tokenizer
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Reuse the end-of-sequence token as the padding token, and tell the model
# config which id that is (pad_token_id equals eos_token_id after the assignment)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

# Summary of what was loaded
n_params = model.num_parameters()
print(f"Model loaded: {model_name}")
print(f"Parameters: {n_params:,}")

Model loaded: gpt2-medium
Parameters: 354,823,168


## Step 2: Load and Prepare Story Dataset

In [4]:
# TinyStories: only the first 5,000 examples of the train split are loaded
dataset = load_dataset(
    "roneneldan/TinyStories",
    split="train[:5000]",
)

# Quick sanity check: number of stories and the start of the first one
n_stories = len(dataset)
first_story = dataset[0]["text"]
print(f"Dataset size: {n_stories} stories")
print(f"\nSample story:\n{first_story[:500]}...")

Dataset size: 5000 stories

Sample story:
One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them b...


In [5]:
# Tokenize the dataset
def tokenize_function(examples):
    """Tokenize a batch of stories, truncating each one to at most 256 tokens.

    Args:
        examples: Batch dict passed in by ``dataset.map(..., batched=True)``;
            its ``"text"`` entry holds the raw story strings.

    Returns:
        The tokenizer output for the batch. Sequences are left unpadded and
        therefore have variable length.
    """
    # No padding here on purpose: the DataCollatorForLanguageModeling used by
    # the Trainer pads each batch to its longest sequence, so short stories
    # are not blown up to 256 tokens of padding at preprocessing time.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=256,
    )

# Replace the raw "text" column with the tokenized features
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
print("Tokenization complete!")

Tokenization complete!


## Step 3: Fine-tune the Model

In [6]:
# Collator for causal language modeling (mlm=False disables masked-LM masking)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Mixed precision is only requested when a CUDA device is available
use_mixed_precision = torch.cuda.is_available()

# Training configuration
training_args = TrainingArguments(
    # Checkpointing: where to write, how often, and how many to keep
    output_dir="./story_gpt2_model",
    overwrite_output_dir=True,
    save_steps=500,
    save_total_limit=2,
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_steps=100,
    # 4 samples per step x 4 accumulated steps = 16 samples per optimizer update (per device)
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    fp16=use_mixed_precision,
    # Logging: report every 50 steps, no external experiment tracker
    logging_steps=50,
    report_to="none",
)

# Wire model, data, and configuration together
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

print("Starting fine-tuning...")

Starting fine-tuning...


In [None]:
# Run the fine-tuning loop; keep the returned training summary in a variable
train_output = trainer.train()
print("\nFine-tuning complete!")

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
50,2.0611
100,1.8177
150,1.785


In [None]:
# Write the fine-tuned weights and the tokenizer files into the same directory
final_dir = "./story_gpt2_model/final"
for artifact in (model, tokenizer):
    artifact.save_pretrained(final_dir)
print("Model saved!")

## Step 4: Generate Stories

In [None]:
def generate_story(prompt, max_length=300, temperature=0.8, top_p=0.92):
    """
    Generate a creative story from a prompt.

    Args:
        prompt: Starting text for the story (e.g., "Once upon a time")
        max_length: Maximum total length in tokens of the returned sequence,
            counting the prompt tokens as well as the generated ones
        temperature: Higher = more creative (0.7-1.0 recommended)
        top_p: Nucleus sampling parameter (0.9-0.95 recommended)

    Returns:
        The decoded text (prompt followed by its continuation) with special
        tokens stripped.
    """
    model.eval()
    model.to(device)

    # Tokenize through the tokenizer's __call__ so the attention mask is
    # returned together with the input ids. The pad token is the EOS token
    # here, so generate() cannot infer the mask on its own and would warn
    # about unreliable results if it were omitted.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)

    # Sample a continuation; no gradients are needed for inference
    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3
        )

    # Only one sequence is requested, so decode the first (and only) result
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return story

print("Story generation function ready!")

In [None]:
# Opening lines used to demonstrate the fine-tuned model
prompts = [
    "Once upon a time, in a magical forest,",
    "The little girl found a mysterious box in her grandmother's attic.",
    "A brave knight set out on a journey to find the lost treasure."
]

# Header banner
banner = "=" * 60
print(banner)
print("GENERATED STORIES")
print(banner)

# One story per prompt, numbered from 1 and followed by a dashed separator
separator = "\n" + "-" * 60
for story_number, prompt in enumerate(prompts, start=1):
    print(f"\n--- Story {story_number} ---")
    print(f"Prompt: {prompt}\n")
    story = generate_story(prompt)
    print(story)
    print(separator)

## Step 5: Interactive Story Generator

In [None]:
# Prompt-driven story generation: keeps asking until the user types 'quit'
banner = "=" * 60
divider = "-" * 40

print("\n" + banner)
print("INTERACTIVE STORY GENERATOR")
print(banner)
print("Enter a story prompt to generate a creative story.")
print("Type 'quit' to exit.\n")

# Read a prompt on every iteration; 'quit' (any letter case) ends the loop
while (user_prompt := input("Enter your story prompt: ").strip()).lower() != 'quit':
    # Blank input: ask again instead of generating from nothing
    if not user_prompt:
        print("Please enter a prompt!\n")
        continue

    print("\nGenerating story...\n")
    story = generate_story(user_prompt, max_length=350)
    print("Generated Story:")
    print(divider)
    print(story)
    print(divider + "\n")

print("Goodbye!")