In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import math
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

Dataset link: https://www.kaggle.com/datasets/amaanmansuri/humor-detection <br>This dataset has around 20k texts, half of which are jokes. Only the jokes are used for generation.


In [None]:
# Read the labelled corpus, keep at most the first 10,000 rows flagged as humour,
# and discard the label column because every remaining row is a joke.
jokes = pd.read_csv('Humour.csv')
jokes_df = (
    jokes.loc[jokes['humor'] == True]
    .iloc[:10000]
    .drop(columns='humor')
)

In [None]:
# Hold out 30 jokes for evaluation. A fixed seed keeps the split identical across
# "Restart & Run All" executions.
test_set = jokes_df.sample(n=30, random_state=42)
jokes_df = jokes_df.loc[~jokes_df.index.isin(test_set.index)]

# Reset the indexes so both frames are numbered from 0
test_set = test_set.reset_index(drop=True)
jokes_df = jokes_df.reset_index(drop=True)

# For the test set only: the FIRST 5 words stay in 'text' and act as the generation
# prompt; every word after them is moved to 'True end' as the reference continuation.
# Jokes with 5 words or fewer therefore get an empty 'True end'.
test_words = test_set['text'].str.split()
test_set['True end'] = test_words.str[5:].apply(' '.join)
test_set['text'] = test_words.str[:5].apply(' '.join)

In [None]:
# Preview the training frame (the rows sampled into test_set have been removed from it).
jokes_df

Unnamed: 0,text
0,What do you call a turtle without its shell? d...
1,What is a pokemon master's favorite kind of pa...
2,Why do native americans hate it when it rains ...
3,"My family tree is a cactus, we're all pricks."
4,How are music and candy similar? we throw away...
...,...
9965,How do you know you're girlfriend is getting t...
9966,Kids telling dirty jokes http://www.vice.com/s...
9967,How do we know that joan of arc was french ? s...
9968,Ever heard of the 68 position? you go down on ...


## Fine-tuning GPT2

This time PyTorch is used because, in my opinion, it is easier to fine-tune the model with this framework.


In [None]:
class Jokes(Dataset):
    """Tokenised jokes for GPT-2 fine-tuning.

    Each sample is encoded as ``<|control_code|>`` + joke text, cut to at most
    ``max_length`` tokens, and terminated with the ``<|endoftext|>`` token so the
    model learns where a joke ends.

    Args:
        control_code: Short tag placed in front of every joke (e.g. ``"joke"``).
            It is interpolated into the text as-is, so pass a string, not the data.
        truncate: If true, keep at most the first 20,000 encoded samples.
        gpt2_type: Name of the pretrained tokenizer to load.
        max_length: Maximum number of tokens kept per sample, not counting the
            end-of-text token.
        texts: Iterable of joke strings. Defaults to the notebook-level
            ``jokes_df['text']`` column when omitted.
    """

    def __init__(self, control_code, truncate=False, gpt2_type="gpt2", max_length=32, texts=None):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        if texts is None:
            texts = jokes_df['text']

        # Encode the terminator once; it is appended after truncation so it is never cut off.
        eos_ids = self.tokenizer.encode("<|endoftext|>")

        self.jokes = []
        for row in texts:
            # Truncate in token space: slicing the raw string would cut by characters
            # and chop jokes in the middle of a word.
            token_ids = self.tokenizer.encode(f"<|{control_code}|>{row}")[:max_length]
            self.jokes.append(torch.tensor(token_ids + eos_ids))

        if truncate:
            self.jokes = self.jokes[:20000]
        self.count = len(self.jokes)

    def __len__(self):
        """Number of encoded samples."""
        return self.count

    def __getitem__(self, item):
        """Return the 1-D tensor of token ids for sample ``item``."""
        return self.jokes[item]


# The first argument is the control tag, not the data: passing the text array here
# would prepend the array's repr to every single training sample.
dataset = Jokes("joke", truncate=True, gpt2_type="gpt2", texts=jokes_df['text'])

In [None]:
# Load the pretrained 'gpt2' tokenizer and language-model head; both are used by the
# training and generation cells further down.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Sequence packing: short samples are concatenated into one longer sequence so that
# a single forward pass through GPT-2 covers several of them.
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Try to merge ``new_tensor`` into the running pack along dimension 1.

    Returns a ``(tensor, merged, leftover)`` triple:
      * no pack yet       -> ``(new_tensor, True, None)``
      * merge fits        -> ``(new_tensor + packed_tensor[:, 1:], True, None)``
      * merge would exceed ``max_seq_len`` -> ``(packed_tensor, False, new_tensor)``
    """
    if packed_tensor is None:
        # Nothing accumulated so far: the incoming sample starts the pack.
        return new_tensor, True, None

    combined_len = new_tensor.shape[1] + packed_tensor.shape[1]
    if combined_len <= max_seq_len:
        # The new sample goes in front; the first column of the old pack is dropped.
        merged = torch.cat((new_tensor, packed_tensor[:, 1:]), dim=1)
        return merged, True, None

    # Too long: hand back the untouched pack and the sample that did not fit.
    return packed_tensor, False, new_tensor

In [None]:
def train(
    dataset, model,
    batch_size=16, epochs=5, lr=2e-5,
    max_seq_len=40, warmup_steps=200,
    gpt2_type="gpt2", output_dir=".", output_prefix="wreckgar",
    test_mode=False, save_model_on_epoch=False,
):
    """Fine-tune ``model`` on ``dataset`` using sequence packing and gradient accumulation.

    Samples are drawn one at a time, packed with ``pack_tensor`` into sequences of up
    to 768 tokens, and each pack is run through the model with itself as labels.
    Gradients of ``batch_size`` consecutive packs are accumulated before every
    optimizer/scheduler step; leftover gradients are applied at the end of each epoch.

    Args:
        dataset: Map-style dataset whose items are 1-D tensors of token ids.
        model: Causal LM whose forward accepts ``labels`` and returns the loss first.
        batch_size: Number of packs accumulated per optimizer step (positive int).
        epochs: Number of passes over ``dataset``.
        lr: Peak learning rate.
        warmup_steps: Number of optimizer steps of linear warmup.
        output_dir, output_prefix: Where and under which name per-epoch weights go.
        save_model_on_epoch: If true, save ``state_dict`` after every epoch.
        max_seq_len, gpt2_type, test_mode: Accepted for compatibility; not used here.

    Returns:
        The trained model (on the training device).
    """
    if not isinstance(batch_size, int):
        raise TypeError(f"batch_size must be an int, got {type(batch_size).__name__}")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    pack_len = 768  # packing limit of this notebook; `max_seq_len` is intentionally not consulted

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)

    # The scheduler needs the real number of optimizer steps: a negative value makes
    # the linear schedule clamp the learning rate to 0 as soon as warmup is over.
    # Pack boundaries depend on the shuffle, so this is an estimate from token counts.
    total_tokens = sum(dataset[i].numel() for i in range(len(dataset)))
    packs_per_epoch = max(1, math.ceil(total_tokens / pack_len))
    steps_per_epoch = math.ceil(packs_per_epoch / batch_size)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max(1, epochs * steps_per_epoch),
    )

    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    last_idx = len(train_dataloader) - 1
    loss = 0

    for epoch in range(epochs):

        print(f"Training epoch {epoch}")
        print(loss)

        input_tensor = None
        packs_since_step = 0
        optimizer.zero_grad()

        for idx, entry in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, pack_len)
            is_last = idx == last_idx

            # Keep filling the current pack until it overflows or the data runs out.
            if carry_on and not is_last:
                continue

            # The pack is complete. On the final sample, the entry that did not fit
            # has no later pack to join, so it is trained on by itself.
            ready = [input_tensor]
            if is_last and remainder is not None:
                ready.append(remainder)

            for pack in ready:
                pack = pack.to(device)
                outputs = model(pack, labels=pack)
                loss = outputs[0]
                # NOTE: the loss is not divided by batch_size, so accumulated
                # gradients are a sum (not a mean) over the packs.
                loss.backward()
                packs_since_step += 1

                if packs_since_step == batch_size:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    packs_since_step = 0

            # The overflowing entry opens the next pack instead of being dropped.
            input_tensor = None if is_last else remainder

        # Apply gradients left over from an incomplete accumulation window.
        if packs_since_step:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if save_model_on_epoch:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
            )
    return model

In [None]:
# train() takes (dataset, model, batch_size, ...): a third positional argument would
# be bound to batch_size, so the tokenizer must not be passed here.
model = train(dataset, model)



Training epoch 0
0


9970it [06:02, 27.54it/s]


Training epoch 1
tensor(0.7847, device='cuda:0', grad_fn=<NllLossBackward0>)


9970it [06:00, 27.62it/s]


Training epoch 2
tensor(0.3385, device='cuda:0', grad_fn=<NllLossBackward0>)


9970it [06:00, 27.62it/s]


Training epoch 3
tensor(0.3183, device='cuda:0', grad_fn=<NllLossBackward0>)


9970it [06:00, 27.63it/s]


Training epoch 4
tensor(0.2852, device='cuda:0', grad_fn=<NllLossBackward0>)


9970it [06:00, 27.62it/s]


In [None]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=30,  # maximum number of generated tokens (not words)
    top_p=0.8,
    temperature=1.,
):
    """Sample ``entry_count`` continuations of ``prompt`` with nucleus (top-p) sampling.

    Tokens are drawn one at a time from the smallest set of most-likely tokens whose
    cumulative probability exceeds ``top_p``. An entry stops once the end-of-text
    token is sampled or after ``entry_length`` new tokens.

    Args:
        model: Causal LM; calling it on a ``(1, seq)`` id tensor must return the
            logits as the first element of its output.
        tokenizer: Object providing ``encode(str) -> list[int]`` and ``decode(ids) -> str``.
        prompt: Text every entry starts with.
        entry_count: Number of independent entries to generate.
        entry_length: Upper bound on the number of sampled tokens per entry.
        top_p: Cumulative-probability threshold of the nucleus.
        temperature: Logit divisor; non-positive values are treated as 1.0.

    Returns:
        List of ``entry_count`` decoded strings (prompt included; the end-of-text
        token is included when it was sampled).
    """
    model.eval()

    # Run on whatever device the model lives on instead of assuming CPU.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Loop-invariant values are computed once.
    prompt_ids = tokenizer.encode(prompt)
    eos_ids = tokenizer.encode("<|endoftext|>")
    scale = temperature if temperature > 0 else 1.0
    filter_value = -float("Inf")

    generated_list = []

    with torch.no_grad():

        for _ in trange(entry_count):

            generated = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

            for _ in range(entry_length):
                # No labels are passed: only the logits of the last position are needed.
                logits = model(generated)[0][:, -1, :] / scale

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Mask everything past the nucleus. The mask is shifted right by one
                # so the token that crosses the threshold is kept, and the single
                # most likely token is never removed.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = False

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if next_token.item() in eos_ids:
                    break

            # Finished and unfinished entries are decoded the same way.
            generated_list.append(tokenizer.decode(generated[0].tolist()))

    return generated_list


In [None]:
def text_generation(test_data):
  """Generate one continuation for every prompt in ``test_data``.

  Uses the notebook-level ``model`` and ``tokenizer``. The model is moved to the
  CPU once before generation starts.

  Args:
    test_data: Iterable of prompt strings.

  Returns:
    A list with one entry per prompt; each entry is the single-element list
    returned by ``generate(..., entry_count=1)``.
  """
  cpu_model = model.to('cpu')
  # Iterate over the argument itself rather than reaching for the global test_set.
  return [
    generate(cpu_model, tokenizer, prompt, entry_count=1)
    for prompt in test_data
  ]

In [None]:
# Generate one continuation per held-out prompt (the 'text' column of test_set).
generated_joke = text_generation(test_set['text'].values)

100%|██████████| 1/1 [00:04<00:00,  4.73s/it]
100%|██████████| 1/1 [00:04<00:00,  4.24s/it]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it]
100%|██████████| 1/1 [00:00<00:00,  1.37it/s]
100%|██████████| 1/1 [00:05<00:00,  5.46s/it]
100%|██████████| 1/1 [00:04<00:00,  4.23s/it]
100%|██████████| 1/1 [00:04<00:00,  4.57s/it]
100%|██████████| 1/1 [00:05<00:00,  5.16s/it]
100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
100%|██████████| 1/1 [00:04<00:00,  4.33s/it]
100%|██████████| 1/1 [00:05<00:00,  5.28s/it]
100%|██████████| 1/1 [00:02<00:00,  2.35s/it]
100%|██████████| 1/1 [00:04<00:00,  4.30s/it]
100%|██████████| 1/1 [00:04<00:00,  4.30s/it]
100%|██████████| 1/1 [00:05<00:00,  5.16s/it]
100%|██████████| 1/1 [00:04<00:00,  4.35s/it]
100%|██████████| 1/1 [00:04<00:00,  4.53s/it]
100%|██████████| 1/1 [00:05<00:00,  5.08s/it]
100%|██████████| 1/1 [00:04<00:00,  4.31s/it]
100%|██████████| 1/1 [00:05<00:00,  5.19s/it]
100%|██████████| 1/1 [00:04<00:00,  4.81s/it]
100%|██████████| 1/1 [00:04<00:00,

## RESULTS

In [None]:
# Show every generated joke followed by a divider line and a blank line.
for joke in generated_joke:
  print(joke[0], '________\n', sep='\n')

How can a cat walk on the moon? A search of the Google  search engine revealed nothing. It may even go without saying that the answer is still in the question
________

So a frog parked his rifle in the middle of the road

A young man walked in

He didn't know what to say

He just ran out

________

What did the baby seal say? What did the tusk say?<|endoftext|>
________

What do you with 365 days to think? Thinking!<|endoftext|>
________

That allah guy sure is a smart guy and that's why she wants to marry him. She'll go down on him if she gets pregnant."

"What's your
________

What do you call two jobs that take two days to earn a living? a laborer's day job. A laborer's overtime is the same as a week of work
________

Honey the baby is crowning!
  I know my baby isn't going to show
  why he can't be.  I'll teach you to ask him
________

Top 10 ways to avoid dating.

#10: Avoid spending the night with your mother.

I know you're already spending the night at home but I'll
________

