# TRANSFER LEARNING ON A TRANSFORMER TO GENERATE INFORMATION ABOUT COSTA RICAN DISHES

## Fine-Tuning GPT-2 for Recipe Generation: Training and Analysis

## Dataset version 1

In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import random 

# Select the best available accelerator: NVIDIA GPU (CUDA) first, then
# Apple Silicon (MPS), falling back to CPU when neither is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Function to split the dataset
def split_dataset(filename, train_ratio=0.8, seed=None):
    """Shuffle the recipes in *filename* and split them into train/validation lists.

    The file is split on the literal marker ``---end-of-recipe---``. Chunks that
    are empty or whitespace-only (for example, the remainder after a trailing
    marker) are dropped before splitting, so they do not count toward either side.

    Args:
        filename: path to the UTF-8 encoded recipe file.
        train_ratio: fraction of recipes assigned to the training split; the
            training size is ``int(n * train_ratio)`` (rounded down).
        seed: optional seed for a private RNG, making the split reproducible.
            When ``None``, the module-level ``random`` state is used, so a prior
            ``random.seed(...)`` call still controls the shuffle.

    Returns:
        ``(train_data, validation_data)``: two lists of raw recipe strings.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        content = [
            chunk
            for chunk in file.read().split('---end-of-recipe---')
            if chunk.strip()  # skip blank leftovers around delimiters
        ]

    if seed is None:
        random.shuffle(content)
    else:
        random.Random(seed).shuffle(content)

    train_size = int(len(content) * train_ratio)
    return content[:train_size], content[train_size:]

# Seed the global RNG so the shuffled train/validation split is identical on
# every run; otherwise each re-run of this cell silently overwrites the split
# files with a different validation set.
random.seed(42)

# Split the raw recipe file and persist each part so the datasets below can load it
train_data, validation_data = split_dataset('../dataset-transformers/dishes_train_v1.txt')
train_filename = 'train_dataset_v1.txt'
validation_filename = 'validation_dataset_v1.txt'

# Re-join with the same delimiter so RecipeDataset can split the files again
with open(train_filename, 'w', encoding='utf-8') as f:
    f.write('---end-of-recipe---'.join(train_data))

with open(validation_filename, 'w', encoding='utf-8') as f:
    f.write('---end-of-recipe---'.join(validation_data))

# Pretrained GPT-2 tokenizer. GPT-2 defines no padding token of its own, so the
# end-of-sequence token is reused for padding to a fixed length.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Pretrained GPT-2 with its language-modeling head, placed on the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)

# Custom Dataset class for recipes
class RecipeDataset(Dataset):
    """Causal language-modeling dataset of recipes read from a delimited text file.

    The file is split on the literal marker ``---end-of-recipe---``; empty or
    whitespace-only chunks are skipped. Each remaining recipe is tokenized once,
    up front, and truncated/padded to exactly ``block_size`` tokens.

    Args:
        tokenizer: tokenizer exposing ``encode_plus`` that returns
            ``input_ids``/``attention_mask`` tensors (e.g. a Hugging Face GPT-2
            tokenizer with ``pad_token`` set).
        filename: path of the UTF-8 text file to read.
        block_size: fixed token length of every example.
    """

    # Label value that the Hugging Face LM cross-entropy loss ignores.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, filename, block_size=128):
        self.tokenizer = tokenizer
        self.examples = []

        with open(filename, 'r', encoding='utf-8') as f:
            recipes = f.read().split('---end-of-recipe---')

        for recipe in recipes:
            if not recipe.strip():
                continue  # blank leftovers around delimiters carry no training signal

            tokens = tokenizer.encode_plus(recipe,
                                           add_special_tokens=True,
                                           max_length=block_size,
                                           padding='max_length',
                                           truncation=True,
                                           return_tensors='pt')
            self.examples.append(tokens)

    def __len__(self):
        """Return the number of non-empty recipes."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return one example as ``input_ids``/``attention_mask``/``labels`` tensors.

        ``labels`` equals ``input_ids`` except at padding positions
        (``attention_mask == 0``), which are set to ``IGNORE_INDEX`` so they do
        not contribute to the loss.
        """
        # encode_plus(..., return_tensors='pt') adds a batch dimension; drop it
        input_ids = self.examples[i]['input_ids'][0]
        attention_mask = self.examples[i]['attention_mask'][0]
        # Clone so masking labels never mutates the cached input_ids
        labels = input_ids.clone()
        # The notebook sets pad_token = eos_token. Without this mask the model
        # would be trained to emit long runs of EOS, and the loss would be
        # diluted by padding.
        labels[attention_mask == 0] = self.IGNORE_INDEX
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

# Training dataset: ONLY the training split written to train_filename. Loading
# the full source file here would leak every validation recipe into training
# and make the evaluation loss meaningless.
dataset = RecipeDataset(tokenizer, train_filename)

# Standalone DataLoader over the training set. It is not passed to the Trainer
# below, which batches train_dataset itself.
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)


# Fine-tuning configuration: log and evaluate every 50 steps, checkpoint every
# 500 steps, and keep at most two checkpoints on disk.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy`; switch the keyword if the installed version rejects it.
training_args = TrainingArguments(
    output_dir='./gpt2_finetuned_recipes_v1',
    num_train_epochs=20,
    per_device_train_batch_size=2,
    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=50,
)

# Held-out validation split produced by split_dataset above
validation_dataset = RecipeDataset(tokenizer, validation_filename)

# Initialize trainer for model fine-tuning
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=validation_dataset
)

# Train the model
trainer.train()

# Print the raw log history
print(trainer.state.log_history)

# Extract the loss curves: training-log entries carry 'loss',
# evaluation entries carry 'eval_loss'
training_loss_run1 = [log['loss'] for log in trainer.state.log_history if 'loss' in log]
validation_loss_v1 = [log['eval_loss'] for log in trainer.state.log_history if 'eval_loss' in log]
print(training_loss_run1)
print(validation_loss_v1)
