In [None]:
! pip install transformers datasets

In [None]:
from datasets import load_dataset

# Hub identifier of the persona-grounded dialogue corpus used for fine-tuning.
dataset_name = "google/Synthetic-Persona-Chat"

# No split is requested, so the result is indexed by split name below.
dataset = load_dataset(path=dataset_name)

# Only the training split is used for fine-tuning.
train_dataset = dataset["train"]

In [None]:
# Load the tokenizer
from transformers import AutoTokenizer

# Checkpoint identifier; the tokenizer is loaded from it here.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The GPT-2 tokenizer doesn't have a default pad token, so the EOS token is reused for padding.
# Registering dedicated special tokens for dialogue formatting is a possible refinement
# (not done here).
# NOTE(review): with pad == EOS, any component that recognises padding by token id
# cannot tell real EOS tokens from padding — confirm downstream label masking is acceptable.
tokenizer.pad_token = tokenizer.eos_token

# Helpers plus the function that turns one raw dataset row into a single training string.
def _first_present(example, *keys):
    """Return the value stored under the first of ``keys`` that exists in ``example``.

    Raises:
        KeyError: if none of the candidate column names is present.
    """
    for key in keys:
        if key in example:
            return example[key]
    raise KeyError(
        f"none of the columns {keys!r} found; available columns: {list(example)!r}"
    )


def _as_items(value):
    """Normalise a field to a list of strings.

    A plain string is split into its non-empty, stripped lines; joining a string
    directly would otherwise insert the separator between every *character*.
    Any other iterable is returned as a list unchanged; ``None`` yields ``[]``.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [line.strip() for line in value.splitlines() if line.strip()]
    return list(value)


def format_conversation(example, eos_token=None):
    """Format one example as ``P1: ... P2: ... <|startofchat|> turn<EOS>turn ... <EOS>``.

    Args:
        example: Mapping for a single row. The persona and conversation fields are
            looked up under "User 1 Personas" / "User 2 Personas" / "Conversation"
            first, then under "user 1 personas" / "user 2 personas" /
            "Best Generated Conversation".
            NOTE(review): the fallback names are believed to be the column names
            published on the Hub — confirm against the dataset card.
            Each field may be a list of strings or one newline-delimited string.
        eos_token: Separator placed after every turn. Defaults to the module-level
            ``tokenizer.eos_token``.

    Returns:
        ``{"text": formatted_string}``.

    Raises:
        KeyError: if a required field is missing under every known column name.
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token

    # Combine each speaker's persona sentences into a single string.
    persona_1 = " ".join(
        _as_items(_first_present(example, "User 1 Personas", "user 1 personas"))
    )
    persona_2 = " ".join(
        _as_items(_first_present(example, "User 2 Personas", "user 2 personas"))
    )
    personas = f"P1: {persona_1} P2: {persona_2}"

    # Separate turns with EOS so the model learns to produce the next response
    # after reading the previous one.
    turns = _as_items(
        _first_present(example, "Conversation", "Best Generated Conversation")
    )
    conversation = eos_token.join(turns)

    # Personas first, then the chat marker, then the conversation closed by EOS.
    full_text = f"{personas} <|startofchat|> {conversation} {eos_token}"
    return {"text": full_text}

# Format every training row; all original columns are dropped, leaving what
# format_conversation returns.
processed_dataset = train_dataset.map(
    function=format_conversation,
    remove_columns=train_dataset.column_names,
)

In [None]:
# Maximum number of tokens kept per example (a common choice; adjust as needed).
block_size = 128


def tokenize_function(examples):
    """Tokenize the "text" entries of ``examples``, truncating to ``block_size`` tokens.

    Args:
        examples: Mapping with a "text" entry holding the string(s) to encode.

    Returns:
        The tokenizer's output for those texts.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, max_length=block_size, truncation=True)
    return encoded

# Tokenize in batches across 4 worker processes and drop the raw "text" column.
tokenized_dataset = processed_dataset.map(
    function=tokenize_function,
    remove_columns=["text"],
    batched=True,
    num_proc=4,
)

# Causal-LM collation (GPT-2 is a causal LM, hence mlm=False). The collator pads each
# batch and copies input_ids into labels; it does NOT chunk or concatenate examples —
# sequence length is bounded only by the truncation applied during tokenization.
from transformers import DataCollatorForLanguageModeling

_base_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal Language Modeling for GPT-2
)


def data_collator(features):
    """Collate a batch for causal LM training without masking real EOS tokens.

    The stock collator ignores (label -100) every position whose id equals the pad
    token id. Because the pad token is the EOS token in this notebook, that would
    also drop the EOS turn separators from the loss, so the model would never be
    trained to end a turn. Labels are therefore rebuilt from the attention mask:
    only true padding positions are ignored.

    Args:
        features: List of tokenized examples, as passed by the Trainer's DataLoader.

    Returns:
        The padded batch with ``labels`` equal to ``input_ids`` everywhere except
        padding positions, which are set to -100.
    """
    batch = _base_collator(features)
    if "attention_mask" not in batch:
        # Without a mask padding cannot be told apart from EOS; keep the stock labels.
        return batch
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100
    batch["labels"] = labels
    return batch

In [None]:
from transformers import GPT2LMHeadModel, TrainingArguments, Trainer

# Start from the pretrained weights named by `model_name`.
model = GPT2LMHeadModel.from_pretrained(model_name)

# Hyperparameters and checkpointing policy for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./gpt2-persona-chat-finetuned",  # output and checkpoints
    overwrite_output_dir=True,
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=4,  # per GPU/TPU core
    # Logging and checkpointing cadence
    logging_steps=500,
    save_steps=10_000,
    save_total_limit=2,  # keep only the 2 most recent checkpoints
    prediction_loss_only=True,
)

# Wire the model, arguments, data and collator together.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Start fine-tuning
trainer.train()

# Save the final model together with its tokenizer. The Trainer was not given the
# tokenizer, so save_model() alone would leave the directory without tokenizer files
# and it could not be reloaded on its own in a fresh session.
final_model_dir = "./final_gpt2_persona_chat"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

In [None]:
from transformers import pipeline

# Build a text-generation pipeline from the saved fine-tuned weights, reusing the
# in-memory tokenizer (with its pad token already configured).
generator = pipeline(
    task='text-generation',
    tokenizer=tokenizer,
    model='./final_gpt2_persona_chat',
)

# Example prompt in the training layout: personas, the <|startofchat|> marker, then the
# opening turn. Training text terminates every turn with the EOS token, so the prompt
# ends with it as well; without it a model trained on that layout would be expected to
# emit the turn-ending EOS first and stop before producing a reply.
prompt = (
    "P1: I love hiking and the outdoors. P2: I collect stamps. "
    "<|startofchat|> P1: Hello! How are you today?"
    f"{tokenizer.eos_token}"
)

# Generate text
generated_text = generator(
    prompt,
    # Budget for the reply only; max_length would also count the prompt's tokens and
    # leave almost nothing for the generated continuation.
    max_new_tokens=50,
    num_return_sequences=1,
    do_sample=True,  # Enable sampling for creative generation
    temperature=0.7,
)

print(generated_text[0]['generated_text'])