In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset

# Load the tokenizer and model
model_name = "gpt2-large"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Add padding token to the tokenizer
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = GPT2LMHeadModel.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))  # Adjust the model's embedding size

# Check if MPS (Metal Performance Shaders) is available
device = torch.device("mps" if torch.has_mps else "cpu")
model.to(device)

# Load the haiku dataset
dataset = load_dataset("davanstrien/haiku_kto")

# Inspect dataset keys
print(dataset['train'][0])

  from .autonotebook import tqdm as notebook_tqdm
  device = torch.device("mps" if torch.has_mps else "cpu")


{'prompt': "Write a haiku about the elk's bugling in the forest.", 'completion': "Autumn leaves quiver,\nElk's call echoes through trees,\nNature's symphony.", 'label': False, 'label-suggestion': None, 'label-suggestion-metadata': {'type': None, 'score': None, 'agent': None}, 'external_id': None, 'metadata': '{"prompt": "Write a haiku about the elk\'s bugling in the forest.", "generation_model": "mistralai/Mistral-7B-Instruct-v0.2"}', 'messages': [{'content': "Write a haiku about the elk's bugling in the forest.", 'role': 'user'}, {'content': "Autumn leaves quiver,\nElk's call echoes through trees,\nNature's symphony.", 'role': 'assistant'}]}


In [2]:
# Split the training data into train and validation sets (90% train, 10% validation)
train_val_split = dataset['train'].train_test_split(test_size=0.1)
train_data = train_val_split['train']
val_data = train_val_split['test']

In [3]:
# Extract haiku texts and tokenize them
def extract_and_tokenize_function(batch):
    haikus = []
    for example in batch['messages']:
        haiku = next((message['content'] for message in example if message['role'] == 'assistant'), None)
        if haiku:
            haikus.append(haiku)
    tokenized = tokenizer(haikus, truncation=True, padding='max_length', max_length=50)
    input_ids = torch.tensor(tokenized['input_ids'])
    attention_mask = torch.tensor(tokenized['attention_mask'])
    labels = input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

# Apply tokenization function with batching
train_data = train_data.map(extract_and_tokenize_function, batched=True, remove_columns=train_data.column_names)
val_data = val_data.map(extract_and_tokenize_function, batched=True, remove_columns=val_data.column_names)

train_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
val_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map: 100%|██████████| 82/82 [00:00<00:00, 3359.31 examples/s]
Map: 100%|██████████| 10/10 [00:00<00:00, 2413.02 examples/s]


In [4]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_strategy="epoch",  # Save strategy set to "epoch"
    evaluation_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
)



In [5]:
# Fine-tune the model
trainer.train()

  0%|          | 0/205 [00:00<?, ?it/s]

  5%|▍         | 10/205 [00:12<03:05,  1.05it/s]

{'loss': 4.9015, 'grad_norm': 15.342203140258789, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.24}


 10%|▉         | 20/205 [00:23<03:43,  1.21s/it]

{'loss': 4.855, 'grad_norm': 15.591681480407715, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.49}


 15%|█▍        | 30/205 [00:32<02:39,  1.10it/s]

{'loss': 4.4229, 'grad_norm': 16.52741050720215, 'learning_rate': 3e-06, 'epoch': 0.73}


 20%|█▉        | 40/205 [01:28<08:02,  2.93s/it]

{'loss': 4.1647, 'grad_norm': 19.260408401489258, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.98}


                                                
 20%|██        | 41/205 [01:45<17:21,  6.35s/it]

{'eval_loss': 3.7416024208068848, 'eval_runtime': 2.2725, 'eval_samples_per_second': 4.4, 'eval_steps_per_second': 2.2, 'epoch': 1.0}


 24%|██▍       | 50/205 [04:37<45:44, 17.70s/it]

{'loss': 3.7262, 'grad_norm': 13.97623062133789, 'learning_rate': 5e-06, 'epoch': 1.22}


 29%|██▉       | 60/205 [05:37<08:55,  3.69s/it]

{'loss': 3.4244, 'grad_norm': 15.864128112792969, 'learning_rate': 6e-06, 'epoch': 1.46}


 34%|███▍      | 70/205 [05:49<02:28,  1.10s/it]

{'loss': 3.3003, 'grad_norm': 14.642148971557617, 'learning_rate': 7.000000000000001e-06, 'epoch': 1.71}


 39%|███▉      | 80/205 [06:05<04:25,  2.13s/it]

{'loss': 2.973, 'grad_norm': 15.36148452758789, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.95}


                                                
 40%|████      | 82/205 [06:08<03:42,  1.81s/it]

{'eval_loss': 2.918102741241455, 'eval_runtime': 0.3873, 'eval_samples_per_second': 25.822, 'eval_steps_per_second': 12.911, 'epoch': 2.0}


 44%|████▍     | 90/205 [07:14<04:49,  2.52s/it]

{'loss': 2.4999, 'grad_norm': 11.069402694702148, 'learning_rate': 9e-06, 'epoch': 2.2}


 49%|████▉     | 100/205 [07:24<02:21,  1.34s/it]

{'loss': 2.4913, 'grad_norm': 12.502325057983398, 'learning_rate': 1e-05, 'epoch': 2.44}


 54%|█████▎    | 110/205 [07:35<01:21,  1.17it/s]

{'loss': 2.4232, 'grad_norm': 22.84785270690918, 'learning_rate': 1.1000000000000001e-05, 'epoch': 2.68}


 59%|█████▊    | 120/205 [07:49<03:37,  2.56s/it]

{'loss': 2.2908, 'grad_norm': 12.233875274658203, 'learning_rate': 1.2e-05, 'epoch': 2.93}


                                                 
 60%|██████    | 123/205 [07:54<02:36,  1.91s/it]

{'eval_loss': 2.7125401496887207, 'eval_runtime': 0.3606, 'eval_samples_per_second': 27.731, 'eval_steps_per_second': 13.865, 'epoch': 3.0}


 63%|██████▎   | 130/205 [08:45<03:32,  2.84s/it]

{'loss': 1.7921, 'grad_norm': 14.437853813171387, 'learning_rate': 1.3000000000000001e-05, 'epoch': 3.17}


 68%|██████▊   | 140/205 [09:02<02:02,  1.89s/it]

{'loss': 1.5513, 'grad_norm': 12.577813148498535, 'learning_rate': 1.4000000000000001e-05, 'epoch': 3.41}


 73%|███████▎  | 150/205 [09:22<02:53,  3.16s/it]

{'loss': 1.7324, 'grad_norm': 18.391651153564453, 'learning_rate': 1.5e-05, 'epoch': 3.66}


 78%|███████▊  | 160/205 [09:34<00:46,  1.02s/it]

{'loss': 1.4684, 'grad_norm': 14.281442642211914, 'learning_rate': 1.6000000000000003e-05, 'epoch': 3.9}


                                                 
 80%|████████  | 164/205 [09:37<00:35,  1.16it/s]

{'eval_loss': 2.8714218139648438, 'eval_runtime': 0.3215, 'eval_samples_per_second': 31.108, 'eval_steps_per_second': 15.554, 'epoch': 4.0}


 83%|████████▎ | 170/205 [10:51<02:51,  4.89s/it]

{'loss': 1.1427, 'grad_norm': 13.994254112243652, 'learning_rate': 1.7000000000000003e-05, 'epoch': 4.15}


 88%|████████▊ | 180/205 [11:03<00:28,  1.14s/it]

{'loss': 0.949, 'grad_norm': 14.221123695373535, 'learning_rate': 1.8e-05, 'epoch': 4.39}


 93%|█████████▎| 190/205 [11:12<00:13,  1.12it/s]

{'loss': 0.9869, 'grad_norm': 13.56718635559082, 'learning_rate': 1.9e-05, 'epoch': 4.63}


 98%|█████████▊| 200/205 [11:24<00:05,  1.03s/it]

{'loss': 1.0187, 'grad_norm': 15.463618278503418, 'learning_rate': 2e-05, 'epoch': 4.88}


                                                 
100%|██████████| 205/205 [11:48<00:00,  3.27s/it]

{'eval_loss': 3.0735347270965576, 'eval_runtime': 0.354, 'eval_samples_per_second': 28.247, 'eval_steps_per_second': 14.124, 'epoch': 5.0}


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
100%|██████████| 205/205 [12:28<00:00,  3.65s/it]

{'train_runtime': 748.1234, 'train_samples_per_second': 0.548, 'train_steps_per_second': 0.274, 'train_loss': 2.5649413271648127, 'epoch': 5.0}





TrainOutput(global_step=205, training_loss=2.5649413271648127, metrics={'train_runtime': 748.1234, 'train_samples_per_second': 0.548, 'train_steps_per_second': 0.274, 'total_flos': 87132019200000.0, 'train_loss': 2.5649413271648127, 'epoch': 5.0})

In [6]:
# Save the fine-tuned model
model.save_pretrained("./fine-tuned-haiku-model")
tokenizer.save_pretrained("./fine-tuned-haiku-model")

('./fine-tuned-haiku-model/tokenizer_config.json',
 './fine-tuned-haiku-model/special_tokens_map.json',
 './fine-tuned-haiku-model/vocab.json',
 './fine-tuned-haiku-model/merges.txt',
 './fine-tuned-haiku-model/added_tokens.json')

In [7]:
# Function to generate haiku
def generate_haiku(prompt, model, tokenizer, max_length=30):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    outputs = model.generate(
        input_ids=input_ids, 
        attention_mask=attention_mask, 
        max_length=max_length, 
        num_return_sequences=1, 
        no_repeat_ngram_size=2, 
        early_stopping= True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [8]:
# Test the fine-tuned model
prompt = "The oceans breeze"
haiku = generate_haiku(prompt, model, tokenizer)
print(f"Generated Haiku:\n{haiku}")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated Haiku:
The oceans breeze,
Silent whispers of the sea,


Nature's symphony, timeless.

...

.

