<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/GPT2_QA_Finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip -q install datasets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m471.0/480.6 kB[0m [31m18.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━

In [11]:
import transformers
print(transformers.__version__)

4.47.1


In [9]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import Dataset
import torch

# Load GPT-2 model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Tie lm_head weights if missing
if model.lm_head.weight.shape[0] != model.transformer.wte.weight.shape[0]:
    model.tie_weights()

# Add padding token to tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Example QA dataset
qa_data = [
    {"question": "What is the capital of France?", "answer": "The capital of France is Paris."},
    {"question": "Who wrote '1984'?", "answer": "George Orwell wrote '1984'."},
]

# Preprocess dataset
def preprocess_data(example):
    input_text = f"Question: {example['question']}\nAnswer:"
    target_text = example["answer"]
    inputs = tokenizer(input_text, truncation=True, padding="max_length", max_length=50)
    targets = tokenizer(target_text, truncation=True, padding="max_length", max_length=50)

    inputs["labels"] = targets["input_ids"]
    inputs["attention_mask"] = inputs["attention_mask"]
    return inputs

# Convert dataset to Huggingface Dataset object
dataset = Dataset.from_list(qa_data)
tokenized_dataset = dataset.map(preprocess_data, remove_columns=["question", "answer"])

# Training arguments
training_args = TrainingArguments(
    output_dir="./gpt2_qa_finetuned",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to=[],  # Disable W&B or any reporting
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)

# Fine-tune the model
trainer.train()

# Save the model
model.save_pretrained("./gpt2_qa_finetuned")
tokenizer.save_pretrained("./gpt2_qa_finetuned")

# Test the model
def generate_answer(question, model, tokenizer, device="cuda"):
    model.to(device)  # Move the model to the specified device
    input_text = f"Question: {question}\nAnswer:"
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)  # Move inputs to the same device
    outputs = model.generate(inputs, max_length=100, num_beams=5, early_stopping=True, pad_token_id=tokenizer.eos_token_id)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

question = "What is the capital of France?"
answer = generate_answer(question, model, tokenizer)
print(answer)


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,No log,5.542056
2,No log,3.124552
3,No log,2.559626


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


Question: What is the capital of France?
Answer: The capital of France is the capital of France.


In [12]:
def generate_answer(question, model, tokenizer, max_length=50, device="cuda"):
    """
    Generate an answer using greedy search.
    Args:
        question (str): The input question.
        model (GPT2LMHeadModel): The GPT-2 model.
        tokenizer (GPT2Tokenizer): The tokenizer.
        max_length (int): Maximum length of the generated text.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        str: The generated answer.
    """
    # Move the model to the correct device
    model.to(device)
    model.eval()  # Set model to evaluation mode

    # Prepare input
    input_text = f"Question: {question}\nAnswer:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Start generating
    generated_ids = input_ids
    for _ in range(max_length):
        # Get logits for the next token
        outputs = model(input_ids=generated_ids)
        logits = outputs.logits

        # Select the token with the highest probability (greedy search)
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)

        # Append the predicted token to the sequence
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop if the model predicts the end-of-sequence token
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    # Decode the generated tokens to text
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return answer


# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

question = "What is the capital of France?"
answer = generate_answer(question, model, tokenizer)
print(answer)



Question: What is the capital of France?
Answer: The capital of France is the capital of France.


In [15]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import Dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm

# Load GPT-2 model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Tie lm_head weights if missing
if model.lm_head.weight.shape[0] != model.transformer.wte.weight.shape[0]:
    model.tie_weights()

# Add padding token to tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Example QA dataset
qa_data = [
    {"question": "What is the capital of France?", "answer": "The capital of France is Paris."},
    {"question": "Who wrote '1984'?", "answer": "George Orwell wrote '1984'."},
]

# Preprocess dataset
def preprocess_data(example):
    input_text = f"Question: {example['question']}\nAnswer:"
    target_text = example["answer"]
    inputs = tokenizer(input_text, truncation=True, padding="max_length", max_length=50)
    targets = tokenizer(target_text, truncation=True, padding="max_length", max_length=50)

    inputs["labels"] = targets["input_ids"]
    return inputs

# Convert dataset to Huggingface Dataset object
dataset = Dataset.from_list(qa_data)
tokenized_dataset = dataset.map(preprocess_data, remove_columns=["question", "answer"])

# Define data loaders
batch_size = 2
train_loader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)

# Define optimizer and device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
def train_model(model, train_loader, optimizer, device, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        epoch_loss = 0
        for batch in tqdm(train_loader):
            input_ids = torch.stack(batch["input_ids"]).to(device)
            attention_mask = torch.stack(batch["attention_mask"]).to(device)
            labels = torch.stack(batch["labels"]).to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(train_loader)}")

# Validation loop (optional)
def validate_model(model, val_loader, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = torch.stack(batch["input_ids"]).to(device)
            attention_mask = torch.stack(batch["attention_mask"]).to(device)
            labels = torch.stack(batch["labels"]).to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            val_loss += loss.item()

    print(f"Validation Loss: {val_loss / len(val_loader)}")

# Train the model
train_model(model, train_loader, optimizer, device)

# Save the model
model.save_pretrained("./gpt2_qa_finetuned")
tokenizer.save_pretrained("./gpt2_qa_finetuned")

def generate_answer(question, model, tokenizer, max_length=50, device="cuda"):
    """
    Generate an answer using greedy search.
    Args:
        question (str): The input question.
        model (GPT2LMHeadModel): The GPT-2 model.
        tokenizer (GPT2Tokenizer): The tokenizer.
        max_length (int): Maximum length of the generated text.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        str: The generated answer.
    """
    # Move the model to the correct device
    model.to(device)
    model.eval()  # Set model to evaluation mode

    # Prepare input
    input_text = f"Question: {question}\nAnswer:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Start generating
    generated_ids = input_ids
    for _ in range(max_length):
        # Get logits for the next token
        outputs = model(input_ids=generated_ids)
        logits = outputs.logits

        # Select the token with the highest probability (greedy search)
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)

        # Append the predicted token to the sequence
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop if the model predicts the end-of-sequence token
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    # Decode the generated tokens to text
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return answer


# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

question = "What is the capital of France?"
answer = generate_answer(question, model, tokenizer)
print(answer)


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Epoch 1/10


100%|██████████| 1/1 [00:00<00:00,  8.16it/s]


Epoch 1 Loss: 7.104681968688965
Epoch 2/10


100%|██████████| 1/1 [00:00<00:00, 11.07it/s]


Epoch 2 Loss: 6.292819023132324
Epoch 3/10


100%|██████████| 1/1 [00:00<00:00, 10.98it/s]


Epoch 3 Loss: 5.607349872589111
Epoch 4/10


100%|██████████| 1/1 [00:00<00:00, 11.92it/s]


Epoch 4 Loss: 5.1459269523620605
Epoch 5/10


100%|██████████| 1/1 [00:00<00:00, 11.94it/s]


Epoch 5 Loss: 4.897861003875732
Epoch 6/10


100%|██████████| 1/1 [00:00<00:00, 12.02it/s]


Epoch 6 Loss: 5.431483268737793
Epoch 7/10


100%|██████████| 1/1 [00:00<00:00, 12.26it/s]


Epoch 7 Loss: 4.674878120422363
Epoch 8/10


100%|██████████| 1/1 [00:00<00:00, 12.11it/s]


Epoch 8 Loss: 4.1310224533081055
Epoch 9/10


100%|██████████| 1/1 [00:00<00:00, 12.14it/s]


Epoch 9 Loss: 3.3114748001098633
Epoch 10/10


100%|██████████| 1/1 [00:00<00:00, 12.12it/s]


Epoch 10 Loss: 3.6399283409118652
Question: What is the capital of France?
Answer: The capital of France is the capital of France


In [21]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import Dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Load GPT-2 model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Tie lm_head weights if missing
if model.lm_head.weight.shape[0] != model.transformer.wte.weight.shape[0]:
    model.tie_weights()

# Add padding token to tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Example QA dataset
qa_data = [
    {"question": "What is the capital of France?", "answer": "The capital of France is Paris."},
    {"question": "Who wrote '1984'?", "answer": "George Orwell wrote '1984'."},
]

# Preprocess dataset
def preprocess_data(example):
    input_text = f"Question: {example['question']}\nAnswer:"
    target_text = example["answer"]
    inputs = tokenizer(input_text, truncation=True, padding="max_length", max_length=50)
    targets = tokenizer(target_text, truncation=True, padding="max_length", max_length=50)

    inputs["labels"] = targets["input_ids"]
    return inputs

# Convert dataset to Huggingface Dataset object
dataset = Dataset.from_list(qa_data)
tokenized_dataset = dataset.map(preprocess_data, remove_columns=["question", "answer"])

# Define data loaders
batch_size = 2
train_loader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)

# Define optimizer, criterion, and device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
criterion = CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Validation loop
def validate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = torch.stack(batch["input_ids"]).to(device)
            attention_mask = torch.stack(batch["attention_mask"]).to(device)
            labels = torch.stack(batch["labels"]).to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Reshape logits and labels for CrossEntropyLoss
            shift_logits = logits[..., :-1, :].contiguous()  # Shift logits
            shift_labels = labels[..., 1:].contiguous()  # Shift labels
            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            val_loss += loss.item()

    return val_loss / len(val_loader)

# Training loop
def train_model(model, train_loader, optimizer, criterion, device, num_epochs=3, save_dir="./best_model"):
    best_val_loss = float("inf")
    os.makedirs(save_dir, exist_ok=True)

    model.train()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        epoch_loss = 0
        for batch in tqdm(train_loader):
            input_ids = torch.stack(batch["input_ids"]).to(device)
            attention_mask = torch.stack(batch["attention_mask"]).to(device)
            labels = torch.stack(batch["labels"]).to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Reshape logits and labels for CrossEntropyLoss
            shift_logits = logits[..., :-1, :].contiguous()  # Shift logits
            shift_labels = labels[..., 1:].contiguous()  # Shift labels
            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(train_loader)}")
        avg_train_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Training Loss: {avg_train_loss}")

        # Validate after each epoch
        val_loss = validate_model(model, val_loader, criterion, device)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss}")

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)
            print(f"Saved best model with Validation Loss: {val_loss}")

# Train the model
train_model(model, train_loader, optimizer, criterion, device)

# Save the model
model.save_pretrained("./gpt2_qa_finetuned")
tokenizer.save_pretrained("./gpt2_qa_finetuned")

def generate_answer(question, model, tokenizer, max_length=50, device="cuda"):
    """
    Generate an answer using greedy search.
    Args:
        question (str): The input question.
        model (GPT2LMHeadModel): The GPT-2 model.
        tokenizer (GPT2Tokenizer): The tokenizer.
        max_length (int): Maximum length of the generated text.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        str: The generated answer.
    """
    # Move the model to the correct device
    model.to(device)
    model.eval()  # Set model to evaluation mode

    # Prepare input
    input_text = f"Question: {question}\nAnswer:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Start generating
    generated_ids = input_ids
    for _ in range(max_length):
        # Get logits for the next token
        outputs = model(input_ids=generated_ids)
        logits = outputs.logits

        # Select the token with the highest probability (greedy search)
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)

        # Append the predicted token to the sequence
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop if the model predicts the end-of-sequence token
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    # Decode the generated tokens to text
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return answer


# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

question = "What is the capital of France?"
answer = generate_answer(question, model, tokenizer)
print(answer)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Epoch 1/3


100%|██████████| 1/1 [00:00<00:00,  7.02it/s]


Epoch 1 Training Loss: 11.8451509475708


Validating: 100%|██████████| 1/1 [00:00<00:00, 43.91it/s]


Epoch 1 Validation Loss: nan
Epoch 2/3


100%|██████████| 1/1 [00:00<00:00,  9.41it/s]


Epoch 2 Training Loss: 7.646153450012207


Validating: 100%|██████████| 1/1 [00:00<00:00, 55.72it/s]


Epoch 2 Validation Loss: nan
Epoch 3/3


100%|██████████| 1/1 [00:00<00:00, 10.94it/s]


Epoch 3 Training Loss: 8.785937309265137


Validating: 100%|██████████| 1/1 [00:00<00:00, 40.94it/s]


Epoch 3 Validation Loss: nan
Question: What is the capital of France?
Answer: The capital of France is the capital of France.
The capital of France is the capital of France.
The capital of France is the capital of France.
The capital of France is the capital of France.
The capital of France is the


In [18]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import Dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import os

# Load GPT-2 model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Tie lm_head weights if missing
if model.lm_head.weight.shape[0] != model.transformer.wte.weight.shape[0]:
    model.tie_weights()

# Add padding token to tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Example QA dataset
qa_data = [
    {"question": "What is the capital of France?", "answer": "The capital of France is Paris."},
    {"question": "Who wrote '1984'?", "answer": "George Orwell wrote '1984'."},
]

# Preprocess dataset
def preprocess_data(example):
    input_text = f"Question: {example['question']}\nAnswer:"
    target_text = example["answer"]
    inputs = tokenizer(input_text, truncation=True, padding="max_length", max_length=50)
    targets = tokenizer(target_text, truncation=True, padding="max_length", max_length=50)

    inputs["labels"] = targets["input_ids"]
    return inputs

# Convert dataset to Huggingface Dataset object
dataset = Dataset.from_list(qa_data)
tokenized_dataset = dataset.map(preprocess_data, remove_columns=["question", "answer"])

# Split dataset into training and validation sets
train_dataset = tokenized_dataset.select(range(1))  # First half for training
val_dataset = tokenized_dataset.select(range(1, 2))  # Second half for validation

# Define data loaders
batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Define optimizer, criterion, and device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
criterion = CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Validation loop
def validate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = torch.stack(batch["input_ids"]).to(device)
            attention_mask = torch.stack(batch["attention_mask"]).to(device)
            labels = torch.stack(batch["labels"]).to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Reshape logits and labels for CrossEntropyLoss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            val_loss += loss.item()

    return val_loss / len(val_loader)

# Training loop with validation and checkpoint saving
def train_model(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=3, save_dir="./best_model"):
    best_val_loss = float("inf")
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        epoch_loss = 0

        for batch in tqdm(train_loader, desc="Training"):
            input_ids = torch.stack(batch["input_ids"]).to(device)
            attention_mask = torch.stack(batch["attention_mask"]).to(device)
            labels = torch.stack(batch["labels"]).to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Reshape logits and labels for CrossEntropyLoss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Training Loss: {avg_train_loss}")

        # Validate after each epoch
        val_loss = validate_model(model, val_loader, criterion, device)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss}")

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)
            print(f"Saved best model with Validation Loss: {val_loss}")

# Train the model
train_model(model, train_loader, val_loader, optimizer, criterion, device)


def generate_answer(question, model, tokenizer, max_length=50, device="cuda"):
    """
    Generate an answer using greedy search.
    Args:
        question (str): The input question.
        model (GPT2LMHeadModel): The GPT-2 model.
        tokenizer (GPT2Tokenizer): The tokenizer.
        max_length (int): Maximum length of the generated text.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        str: The generated answer.
    """
    # Move the model to the correct device
    model.to(device)
    model.eval()  # Set model to evaluation mode

    # Prepare input
    input_text = f"Question: {question}\nAnswer:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Start generating
    generated_ids = input_ids
    for _ in range(max_length):
        # Get logits for the next token
        outputs = model(input_ids=generated_ids)
        logits = outputs.logits

        # Select the token with the highest probability (greedy search)
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)

        # Append the predicted token to the sequence
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop if the model predicts the end-of-sequence token
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    # Decode the generated tokens to text
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return answer


# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"

question = "What is the capital of France?"
answer = generate_answer(question, model, tokenizer)
print(answer)


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Epoch 1/3


Training: 100%|██████████| 1/1 [00:00<00:00,  8.23it/s]


Epoch 1 Training Loss: nan


Validating: 100%|██████████| 1/1 [00:00<00:00, 30.72it/s]


Epoch 1 Validation Loss: nan
Epoch 2/3


Training: 100%|██████████| 1/1 [00:00<00:00,  7.92it/s]


Epoch 2 Training Loss: nan


Validating: 100%|██████████| 1/1 [00:00<00:00, 47.48it/s]


Epoch 2 Validation Loss: nan
Epoch 3/3


Training: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


Epoch 3 Training Loss: nan


Validating: 100%|██████████| 1/1 [00:00<00:00, 51.04it/s]


Epoch 3 Validation Loss: nan
Question: What is the capital of France?
Answer: The capital of France is the capital of France.
The capital of France is the capital of France.
The capital of France is the capital of France.
The capital of France is the capital of France.
The capital of France is the
