# Fine-tuning GPT-2 on a Twitter dataset in PyTorch

In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
import os
import csv
from transformers import AdamW, get_linear_schedule_with_warmup

#If using colab unhash
#from google.colab import drive
#import logging
#logging.getLogger().setLevel(logging.CRITICAL)

import warnings

# Suppress every Python warning (including library deprecation notices) so the
# notebook output stays readable.
warnings.filterwarnings('ignore')

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Import the GPT-2 tokenizer and model
     Names of models (approximate parameter counts):
     gpt2 - 124M params
     gpt2-medium - 355M params
     gpt2-large - 774M params
     gpt2-xl - 1.5B params

In [None]:
# Hugging Face identifier of the pretrained GPT-2 checkpoint to fine-tune.
model_name = 'gpt2'

# Load the matching tokenizer first, then the language-modelling model.
tokenizer, model = (
    GPT2Tokenizer.from_pretrained(model_name),
    GPT2LMHeadModel.from_pretrained(model_name),
)

In [None]:
def choose_from_top(probs, n=5):
    """Sample an index from the ``n`` highest-probability entries of ``probs``.

    The top ``n`` probabilities are renormalised to sum to 1 and one index is
    drawn from them at random (top-k sampling).

    Args:
        probs: 1-D array-like of non-negative probabilities, e.g. a softmax
            over the vocabulary. It is converted to float64 before use.
        n: Number of highest-probability entries to sample from. Values larger
            than ``len(probs)`` are clamped to ``len(probs)``; values below 1
            are treated as 1.

    Returns:
        The chosen index into ``probs`` as a Python ``int``.

    Raises:
        ValueError: If ``probs`` is empty.
    """
    # float64: np.random.choice requires p to sum to 1 within a tight
    # tolerance, which float32 probabilities (as produced by torch.softmax)
    # can miss after renormalisation.
    probs = np.asarray(probs, dtype=np.float64)
    if probs.size == 0:
        raise ValueError("probs must not be empty")

    # Clamp n so argpartition / random.choice never see an invalid size.
    n = max(1, min(int(n), probs.size))

    # Indices of the n largest probabilities (unordered among themselves).
    top_idx = np.argpartition(probs, -n)[-n:]

    # Renormalise the selected probabilities so they sum to 1.
    top_prob = probs[top_idx]
    top_prob = top_prob / top_prob.sum()

    # Draw a position within the top-n set and map it back to a token id.
    return int(top_idx[np.random.choice(n, p=top_prob)])

### PyTorch Dataset module for tweets dataset

In [None]:
# Define a custom dataset class for handling Tweets dataset
class TweetsDataset(Dataset):
    """Map-style dataset of tweet strings read from a ';'-delimited CSV file.

    Each item is the text in the second column of one CSV row with the
    ``<|endoftext|>`` marker appended. The training loop packs several tweets
    into one sequence, and the marker keeps their boundaries.

    Args:
        tweets_dataset_path: Directory that contains the CSV file
            ('' means the current working directory).
        tweets_file_name: Name of the CSV file inside that directory.
    """

    def __init__(self, tweets_dataset_path='', tweets_file_name='data_twq.csv'):
        super().__init__()

        # Construct the path to the tweets dataset
        tweets_path = os.path.join(tweets_dataset_path, tweets_file_name)

        # Preprocessed tweets and the end-of-text marker appended to each one
        self.tweet_list = []
        self.end_of_text_token = "<|endoftext|>"

        # newline='' is what the csv module expects: the reader then handles
        # line endings itself, so quoted tweets containing newlines parse
        # correctly.
        with open(tweets_path, encoding="utf8", newline="") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=';')

            for row in csv_reader:
                # Skip blank or malformed lines that have no text column
                # instead of failing with IndexError.
                if len(row) < 2:
                    continue

                # Tweet text from the second column plus the end-of-text marker
                self.tweet_list.append(f"{row[1]}{self.end_of_text_token}")

    def __len__(self):
        """Return the number of tweets loaded."""
        return len(self.tweet_list)

    def __getitem__(self, item):
        """Return tweet number ``item`` (including the end-of-text marker)."""
        return self.tweet_list[item]

In [None]:
# Load the tweets (from data_twq.csv in the working directory) and serve them
# one at a time in shuffled order; the training loop does its own packing.
dataset = TweetsDataset()
tweet_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

### Hyperparameters
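If you run out of memory, lower <b> MAX_SEQ_LEN </b> (each forward pass processes one packed sequence of up to that many tokens) or use a smaller model. <b> BATCH_SIZE </b> is the number of sequences whose gradients are accumulated before each optimizer step, so lowering it (or <b> LEARNING_RATE </b>) does not reduce memory use.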

In [None]:
# Packed sequences whose gradients are accumulated per optimizer step
# (each forward pass still sees a single sequence).
BATCH_SIZE = 16
# Passes over the tweet dataset; a checkpoint is written after each one.
EPOCHS = 5
# Learning rate handed to the optimizer (then shaped by the warmup schedule).
LEARNING_RATE = 0.001
# Optimizer steps of linear learning-rate warmup.
WARMUP_STEPS = 5_000
# Maximum tokens per packed training sequence; longer single tweets are skipped.
MAX_SEQ_LEN = 300

### Model training

I will train the model and save its weights after each epoch. Then I will generate tweets with each saved checkpoint to see which one performs best.

In [None]:
# Move the model to the selected device and switch it to training mode.
model = model.to(device)
model.train()

# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 and
# weight_decay=0.0 reproduce that implementation's defaults.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0
)

# The linear schedule needs the number of optimizer steps. Every packed
# sequence holds at least one tweet, so an epoch has at most len(tweet_loader)
# sequences, and there is one optimizer step per BATCH_SIZE sequences; this
# gives an upper bound. (num_training_steps=-1 made the decay factor negative,
# which is clamped to 0, so the learning rate dropped to zero once warmup
# ended.)
# NOTE(review): if WARMUP_STEPS exceeds this bound the learning rate is still
# ramping up when training stops; consider lowering it for small datasets.
total_training_steps = max(1, len(tweet_loader) * EPOCHS // BATCH_SIZE)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_training_steps,
)

# Sequences since the last optimizer step, loss summed since the last report,
# and optimizer steps since the last report.
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

# Packed sequence being built from consecutive tweets. It is carried over
# between epochs; whatever is still pending after the last epoch is not
# trained on.
tmp_tweets_tens = None

# Create the folder for trained models if it doesn't exist.
models_folder = "../trained_models"
os.makedirs(models_folder, exist_ok=True)

for epoch in range(EPOCHS):

    print(f"EPOCH {epoch} started" + '=' * 30)

    for tweet in tweet_loader:
        # "Fit as many tweet sequences into MAX_SEQ_LEN sequence as possible".
        # tweet is a batch holding one string (DataLoader batch_size=1).
        tweet_tens = torch.tensor(tokenizer.encode(tweet[0])).unsqueeze(0).to(device)

        # Skip tweets that on their own are longer than MAX_SEQ_LEN.
        if tweet_tens.size(1) > MAX_SEQ_LEN:
            continue

        # First tweet of a new packed sequence.
        if tmp_tweets_tens is None:
            tmp_tweets_tens = tweet_tens
            continue

        if tmp_tweets_tens.size(1) + tweet_tens.size(1) <= MAX_SEQ_LEN:
            # Append the whole tweet and keep packing. (Appending
            # tweet_tens[:, 1:] dropped the first token of every tweet after
            # the first one in each sequence.)
            tmp_tweets_tens = torch.cat([tmp_tweets_tens, tweet_tens], dim=1)
            continue

        # The next tweet does not fit: train on the current sequence and
        # start the next one with this tweet.
        work_tweets_tens = tmp_tweets_tens
        tmp_tweets_tens = tweet_tens

        ################## Sequence ready, process it through the model ##################

        # With labels supplied, the first output is the language-modelling loss.
        outputs = model(work_tweets_tens, labels=work_tweets_tens)
        loss = outputs[0]

        # Gradients accumulate over BATCH_SIZE sequences before each step.
        loss.backward()
        sum_loss += loss.item()
        proc_seq_count += 1

        # Apply the accumulated gradients once BATCH_SIZE sequences are done.
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0
            batch_count += 1
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Print and reset the summed loss after every 10 optimizer steps.
        if batch_count == 10:
            print(f"sum loss {sum_loss}")
            batch_count = 0
            sum_loss = 0.0

    # Store the model after each epoch to compare the performance of them
    torch.save(model.state_dict(), os.path.join(models_folder, f"gpt2_xl_manbot_{epoch}.pt"))


### Generating the tweets

### Test the model by generating sample content that starts with the word "Tips"

In [None]:
# Epoch of the checkpoint to load
MODEL_EPOCH = 0

# Folder where the training cell stores its checkpoints
models_folder = "../trained_models"

# Path to the checkpoint written by the training cell for MODEL_EPOCH
model_path = os.path.join(models_folder, f"gpt2_xl_manbot_{MODEL_EPOCH}.pt")

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))

# File that collects the generated tweets
tweets_output_file_path = f'generated_content_{MODEL_EPOCH}.txt'

# Evaluation mode (disables dropout)
model.eval()

# Start from an empty output file; tweets are appended below.
if os.path.exists(tweets_output_file_path):
    os.remove(tweets_output_file_path)

# Number of tweets that reached the end-of-text token
tweet_num = 0

# Computed once instead of on every generation step.
end_of_text_ids = tokenizer.encode('<|endoftext|>')
prompt_ids = tokenizer.encode("Tips ")

with torch.no_grad():

    # Try to generate 20 tweets
    for tweet_idx in range(20):

        tweet_finished = False

        # Every tweet starts from the prompt "Tips "
        cur_ids = torch.tensor(prompt_ids).unsqueeze(0).to(device)

        # Generate at most 100 tokens per tweet
        for i in range(100):
            # Without labels no loss is computed; the first output is the logits.
            logits = model(cur_ids)[0]
            # Distribution for the token following the last position of the
            # single sequence in the batch
            softmax_logits = torch.softmax(logits[0, -1], dim=0)

            # Sample more broadly for the first 3 tokens, then narrow to top-3
            n = 20 if i < 3 else 3

            next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n)
            # Append the sampled token to the running sequence
            next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
            cur_ids = torch.cat([cur_ids, next_token], dim=1)

            # Stop once the end-of-text token is produced
            if next_token_id in end_of_text_ids:
                tweet_finished = True
                break

        # Only tweets that reached end-of-text are written; others are dropped
        if tweet_finished:
            tweet_num += 1

            output_text = tokenizer.decode(cur_ids[0].tolist())

            # utf8 so emoji / non-ASCII tweets can be written on any platform
            with open(tweets_output_file_path, 'a', encoding="utf8") as f:
                f.write(f"{output_text} \n\n")