# GPT-2 Training

In [11]:
import numpy as np
import pandas as pd

In [12]:
# Read in movie titles metadata.txt
def read_movie_metadata() -> pd.DataFrame:
    """Parse ``data/movie_titles_metadata.txt`` into a DataFrame.

    Every input line holds six fields separated by ``" +++$+++ "``: movie ID,
    movie title, movie year, IMDB rating, number of IMDB votes, and genres in
    the format ``['genre1','genre2',...,'genreN']``.

    Rows whose year cannot be converted to an integer are reported on stdout
    and skipped. Blank lines are ignored.

    Returns:
        A DataFrame with the columns ``id``, ``title`` (lowercased), ``year``,
        ``IMDB_rating``, ``IMDB_votes`` and ``genres`` (list of strings).
        The index is the 0-based position of the row in the file, so skipped
        rows leave gaps in the index.
    """
    movie_metadata = {}
    with open("data/movie_titles_metadata.txt", "r") as f:
        for idx, raw_line in enumerate(f):
            # A blank (e.g. trailing) line carries no record
            if not raw_line.strip():
                continue

            # Field separator: " +++$+++ "
            fields = raw_line.split(" +++$+++ ")

            # Drop the one-character prefix ("m") from the id and convert to int
            movie_id = int(fields[0][1:])

            title = fields[1].lower()

            # Some years carry a trailing marker such as "/I". Remove every
            # trailing "I" and the slash before it instead of blindly cutting
            # two characters, so "1989/I" and "2000/II" both parse correctly.
            year_text = fields[2]
            if year_text.endswith("I"):
                year_text = year_text.rstrip("I").rstrip("/")

            try:
                year = int(year_text)
            except ValueError:
                # Only a malformed year is tolerated; anything else should surface
                print(f"Movie {movie_id}/'{title}' - Invalid year: {year_text}")
                continue

            rating = float(fields[3])
            votes = int(fields[4])

            # Strip off spaces, [], \n, and '' from genres; an empty list "[]"
            # yields [] rather than a list holding one empty string
            genre_text = fields[5].strip("\n[]").replace("'", "").replace(" ", "")
            genres = [genre for genre in genre_text.split(",") if genre]

            movie_metadata[idx] = {
                "id": movie_id,
                "title": title,
                "year": year,
                "IMDB_rating": rating,
                "IMDB_votes": votes,
                "genres": genres,
            }

    return pd.DataFrame.from_dict(movie_metadata, orient="index")

In [13]:
def read_character_metadata(movie_metadata_df: pd.DataFrame) -> pd.DataFrame:
    """Parse ``data/movie_characters_metadata.txt`` into a DataFrame.

    Every input line holds six fields separated by ``" +++$+++ "``: character
    ID, character name, movie ID, movie title, gender ("?" for unlabeled
    cases) and position in credits ("?" for unlabeled cases).

    Characters whose movie ID is not present in ``movie_metadata_df["id"]``
    are reported on stdout and skipped. Blank lines are ignored.

    Args:
        movie_metadata_df: Movie table; only its ``id`` column is read.

    Returns:
        A DataFrame with the columns ``id``, ``name`` (lowercased),
        ``movie_id``, ``movie_title`` (lowercased), ``gender`` (``False`` for
        male, ``True`` for female, NaN for "?") and ``position`` (int, or NaN
        when unlabeled). The index is the 0-based position of the row in the
        file, so skipped rows leave gaps in the index.
    """
    # Build the lookup once; scanning the id column for every row is needlessly slow
    known_movie_ids = set(movie_metadata_df["id"].tolist())

    # Gender codes are matched case-insensitively so "M"/"F" are converted too
    gender_codes = {"m": False, "f": True, "?": np.nan}

    character_metadata = {}
    with open("data/movie_characters_metadata.txt", "r") as f:
        for idx, raw_line in enumerate(f):
            if not raw_line.strip():
                continue

            # Field separator: " +++$+++ "
            fields = raw_line.split(" +++$+++ ")

            # Drop the one-character prefix from both ids and convert to int
            character_id = int(fields[0][1:])
            name = fields[1].lower()
            movie_id = int(fields[2][1:])

            if movie_id not in known_movie_ids:
                print(
                    f"Character - {character_id}/'{name}': Movie ID: {movie_id} does not exist in movie_titles_metadata.txt or has been removed."
                )
                continue

            movie_title = fields[3].lower()

            # Unrecognised gender values are kept as the raw text, as before
            gender = gender_codes.get(fields[4].strip().lower(), fields[4])

            # Strip whitespace rather than dropping the last character, so a
            # final line without a newline does not lose its last digit
            position_text = fields[5].strip()
            if not position_text or position_text.endswith("?"):
                position = np.nan
            else:
                position = int(position_text)

            character_metadata[idx] = {
                "id": character_id,
                "name": name,
                "movie_id": movie_id,
                "movie_title": movie_title,
                "gender": gender,
                "position": position,
            }

    return pd.DataFrame.from_dict(character_metadata, orient="index")

In [14]:
def read_line_data(
    movie_metadata_df: pd.DataFrame, character_metadata_df: pd.DataFrame
) -> pd.DataFrame:
    """Parse ``data/movie_lines.txt`` into a DataFrame of utterances.

    Every input line holds five fields separated by ``" +++$+++ "``: line ID,
    character ID, movie ID, character name and the utterance text.

    Lines whose character ID or movie ID is missing from the supplied tables
    are reported on stdout and skipped. Blank lines are ignored.

    Args:
        movie_metadata_df: Movie table; only its ``id`` column is read.
        character_metadata_df: Character table; only its ``id`` column is read.

    Returns:
        A DataFrame with the columns ``id``, ``character_id``, ``movie_id``,
        ``character_name`` (lowercased) and ``line`` (utterance text without
        the trailing newline). The index is the 0-based position of the row
        in the file, so skipped rows leave gaps in the index.
    """
    # Build both lookups once instead of scanning the id columns for every row
    known_character_ids = set(character_metadata_df["id"].tolist())
    known_movie_ids = set(movie_metadata_df["id"].tolist())

    line_data = {}
    with open("data/movie_lines.txt", "r") as f:
        for idx, raw_line in enumerate(f):
            if not raw_line.strip():
                continue

            # Remove the line terminator so it does not end up inside the
            # utterance text; limit the split so the text field stays whole
            # even if it happens to contain the separator
            fields = raw_line.rstrip("\r\n").split(" +++$+++ ", 4)

            # Drop the one-character prefix from each id and convert to int
            line_id = int(fields[0][1:])

            character_id = int(fields[1][1:])
            if character_id not in known_character_ids:
                print(
                    f"Line - {line_id}: Character ID: {character_id} does not exist in movie_characters_metadata.txt or has been removed."
                )
                continue

            movie_id = int(fields[2][1:])
            if movie_id not in known_movie_ids:
                print(
                    f"Line - {line_id}: Movie ID: {movie_id} does not exist in movie_titles_metadata.txt or has been removed."
                )
                continue

            character_name = fields[3].lower()

            # A record with no text field at all is treated as an empty utterance
            text = fields[4] if len(fields) > 4 else ""

            line_data[idx] = {
                "id": line_id,
                "character_id": character_id,
                "movie_id": movie_id,
                "character_name": character_name,
                "line": text,
            }

    return pd.DataFrame.from_dict(line_data, orient="index")

In [15]:
def read_conversations_data(
    movie_metadata_df: pd.DataFrame,
    character_metadata_df: pd.DataFrame,
    line_data_df: pd.DataFrame,
) -> pd.DataFrame:
    """Parse ``data/movie_conversations.txt`` into a DataFrame.

    Every input line holds four fields separated by ``" +++$+++ "``: first
    speaker character ID, second speaker character ID, movie ID, and the list
    of lines in chronological order ``['lineID1', 'lineID2', ..., 'lineIDN']``.

    A conversation is reported on stdout and skipped when either character
    ID, the movie ID, or any referenced line ID is missing from the supplied
    tables. Blank lines are ignored.

    Args:
        movie_metadata_df: Movie table; only its ``id`` column is read.
        character_metadata_df: Character table; only its ``id`` column is read.
        line_data_df: Utterance table; only its ``id`` column is read.

    Returns:
        A DataFrame with the columns ``id`` (0-based position of the row in
        the file), ``character ID 1``, ``character ID 2``, ``movie ID`` and
        ``lines`` (list of int line ids). The index equals ``id``, so skipped
        rows leave gaps.
    """
    # Build the lookups once; scanning an id column for every reference is
    # extremely slow on a table with hundreds of thousands of rows
    known_character_ids = set(character_metadata_df["id"].tolist())
    known_movie_ids = set(movie_metadata_df["id"].tolist())
    known_line_ids = set(line_data_df["id"].tolist())

    conversation_data = {}
    with open("data/movie_conversations.txt", "r") as f:
        for idx, raw_line in enumerate(f):
            if not raw_line.strip():
                continue

            # Field separator: " +++$+++ "
            fields = raw_line.split(" +++$+++ ")

            # The file has no conversation id, so the row position serves as one
            conversation_id = idx

            # Drop the one-character prefix from both character ids
            first_character_id = int(fields[0][1:])
            second_character_id = int(fields[1][1:])
            if first_character_id not in known_character_ids:
                print(
                    f"Conversation - {conversation_id}: Character ID: {first_character_id} does not exist in movie_characters_metadata.txt or has been removed."
                )
                continue

            if second_character_id not in known_character_ids:
                print(
                    f"Conversation - {conversation_id}: Character ID: {second_character_id} does not exist in movie_characters_metadata.txt or has been removed."
                )
                continue

            movie_id = int(fields[2][1:])
            if movie_id not in known_movie_ids:
                print(
                    f"Conversation - {conversation_id}: Movie ID: {movie_id} does not exist in movie_titles_metadata.txt or has been removed."
                )
                continue

            # Strip off spaces, L, [], \n, and '' from the line list
            id_text = (
                fields[3].strip("\n[]").replace("'", "").replace(" ", "").replace("L", "")
            )
            line_ids = [int(line_id) for line_id in id_text.split(",")]

            # Like the other validity checks, a dangling reference drops the
            # whole conversation rather than only printing a warning
            if any(line_id not in known_line_ids for line_id in line_ids):
                print(
                    f"Conversation - {conversation_id}: A lineID it references does not exist in movie_lines.txt or has been removed."
                )
                continue

            conversation_data[idx] = {
                "id": conversation_id,
                "character ID 1": first_character_id,
                "character ID 2": second_character_id,
                "movie ID": movie_id,
                "lines": line_ids,
            }

    return pd.DataFrame.from_dict(conversation_data, orient="index")

In [16]:
# Load the four corpus tables in dependency order, then report their sizes
movie_metadata_df = read_movie_metadata()
character_metadata_df = read_character_metadata(movie_metadata_df)
line_data_df = read_line_data(movie_metadata_df, character_metadata_df)
conversation_data_df = read_conversations_data(
    movie_metadata_df,
    character_metadata_df,
    line_data_df,
)

for report_line in (
    f"Number of movies: {len(movie_metadata_df)}",
    f"Number of characters: {len(character_metadata_df)}",
    f"Number of lines: {len(line_data_df)}",
    f"Number of conversations: {len(conversation_data_df)}",
):
    print(report_line)

Number of movies: 617
Number of characters: 9035
Number of lines: 304713
Number of conversations: 83097


In [17]:
# Shorter aliases for the loaded tables; these are the same objects, not copies
movie_df, character_df, line_df, conversation_df = (
    movie_metadata_df,
    character_metadata_df,
    line_data_df,
    conversation_data_df,
)

In [18]:
# Report the table sizes again through the short aliases
for report_line in (
    f"Number of movies: {len(movie_df)}",
    f"Number of characters: {len(character_df)}",
    f"Number of lines: {len(line_df)}",
    f"Number of conversations: {len(conversation_df)}",
):
    print(report_line)

Number of movies: 617
Number of characters: 9035
Number of lines: 304713
Number of conversations: 83097


In [19]:
# Independent working copy of the conversation table (copy() is deep by default)
df = conversation_df.copy()

In [20]:
def convert_to_conversation(conversation_data, line_df):
    """Turn conversations into single-turn training strings.

    Each conversation is split into consecutive, non-overlapping pairs of
    lines (1st+2nd, 3rd+4th, ...). Every pair becomes one string of the form
    ``"<SOS> {input} <BOT> {response} <EOS>"``. A conversation with an odd
    number of lines loses its last line, since it has no response.

    Args:
        conversation_data: Iterable of conversations, each a sequence of line
            ids in chronological order.
        line_df: DataFrame with an ``id`` column and a ``line`` column
            holding the text for each id.

    Returns:
        A list of chat strings, in conversation order.

    Raises:
        IndexError: If a referenced line id is not present in ``line_df``.
    """
    # Build the id -> text lookup once. Filtering the whole table for every
    # single line id makes the conversion quadratic in the corpus size.
    # setdefault keeps the first text when an id occurs more than once.
    text_by_id = {}
    for line_id, text in zip(line_df["id"].tolist(), line_df["line"].tolist()):
        text_by_id.setdefault(line_id, text)

    chats = []
    for line_ids in conversation_data:
        # Number of lines that can be paired up (drops a trailing odd line)
        paired_count = len(line_ids) - len(line_ids) % 2

        for start in range(0, paired_count, 2):
            try:
                input_text = text_by_id[line_ids[start]]
                response_text = text_by_id[line_ids[start + 1]]
            except KeyError as exc:
                # Same exception type the original per-row lookup produced
                raise IndexError(f"line id {exc.args[0]} not found in line_df") from exc

            chats.append(" ".join(["<SOS>", input_text, "<BOT>", response_text, "<EOS>"]))

    return chats


# Build the training strings from every conversation's list of line ids
chats = convert_to_conversation(conversation_df["lines"].values, line_df)

In [21]:
# Number of input/response training strings that were produced
print(len(chats))

138135


In [22]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pretrained GPT-2 tokenizer and language-model head
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Add special tokens for start/end of sentence, padding, and bot response.
# "<BOT>" is registered through add_tokens, not as one of the named special tokens.
# NOTE(review): only the tokenizer is told about "<EOS>"/"<PAD>" here; the
# generate() warning later in this notebook still reports eos_token_id 50256,
# so the model's own generation settings look unchanged -- confirm.
tokenizer.add_special_tokens({"pad_token": "<PAD>", "bos_token": "<SOS>", "eos_token": "<EOS>"})
tokenizer.add_tokens(["<BOT>"])
# Must run after the tokens are added so the embedding matrix covers the new ids
model.resize_token_embeddings(len(tokenizer))

  from .autonotebook import tqdm as notebook_tqdm


Embedding(50261, 768)

In [23]:
# Smoke test before fine-tuning: generate a continuation for one prompt in the training format
print(tokenizer.decode(model.generate(**tokenizer("<SOS> Hi how are you? <BOT> ", return_tensors="pt"))[0]))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<SOS> Hi how are you?  <BOT>  <PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>


In [24]:
from torch.utils.data import Dataset
class ChatData(Dataset):
    """Dataset of chat strings, tokenized once at construction time.

    All chats are encoded eagerly with truncation and ``max_length`` padding,
    so every item has the same length ``max_len``.

    Args:
        chats: Sequence of chat strings to encode.
        tokenizer: Callable tokenizer; it is invoked with ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors="pt"`` and must
            return a mapping with ``input_ids`` and ``attention_mask``.
        max_len: Fixed length every chat is padded or truncated to.
            Defaults to 128, the value that used to be hard-coded.
    """

    def __init__(self, chats, tokenizer, max_len=128):
        self.chats = chats
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Chats longer than max_len are cut off by the tokenizer's truncation
        self.encoded_data = self.tokenizer(
            self.chats,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        self.input_id = self.encoded_data["input_ids"]
        self.attention_mask = self.encoded_data["attention_mask"]

    def __len__(self):
        """Return the number of chats."""
        return len(self.chats)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` pair for chat ``idx``."""
        return self.input_id[idx], self.attention_mask[idx]

In [115]:
import torch
import torch_directml
from torch.optim import Adam
from tqdm import tqdm
from pathlib import Path
from collections import deque
from better_profanity import profanity
import re


def train_model(chat_data, model, optimizer, device, epochs=10, save_every=5):
    """Fine-tune ``model`` on batches of tokenized chats and save checkpoints.

    Args:
        chat_data: Iterable yielding ``(input_ids, attention_mask)`` batches.
        model: Model called as ``model(input_ids, attention_mask=...,
            labels=...)`` whose output exposes a ``loss`` attribute.
        optimizer: Optimizer over the model's parameters.
        device: Device the batches are moved to; the model is expected to be
            on that device already.
        epochs: Number of passes over ``chat_data``.
        save_every: A checkpoint is written after every epoch whose 0-based
            index is a multiple of this value (so epoch 0 is always saved).
            ``0`` or ``None`` disables intermediate checkpoints.

    Side effects:
        Writes ``models/gpt2/checkpoints/model_{epoch}.pt`` and
        ``models/gpt2/final/model_final.pt`` relative to the working
        directory, creating the directories when they do not exist.
    """
    checkpoint_dir = Path("models/gpt2/checkpoints")
    final_dir = Path("models/gpt2/final")
    # Create the output folders up front so a missing directory cannot fail
    # the run only after an epoch of training has already been spent
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    final_dir.mkdir(parents=True, exist_ok=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        for input_id, attention_mask in chat_data:
            input_id = input_id.to(device)
            attention_mask = attention_mask.to(device)

            # Positions the attention mask marks as padding get the label
            # -100, which the Hugging Face language-model loss ignores.
            # Otherwise most of the loss would come from predicting padding.
            labels = input_id.masked_fill(attention_mask == 0, -100)

            output = model(input_id, attention_mask=attention_mask, labels=labels)
            loss = output.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        if save_every and epoch % save_every == 0:
            torch.save(model.state_dict(), checkpoint_dir / f"model_{epoch}.pt")

    torch.save(model.state_dict(), final_dir / "model_final.pt")

def truncate_history(history, tokenizer, max_length=128):
    """Join history entries and keep only the most recent ``max_length`` tokens.

    Each entry is tokenized separately, the token ids are concatenated in
    order, and when the result is longer than ``max_length`` only the last
    ``max_length`` tokens are kept. The kept ids are decoded back to text
    with special tokens preserved.

    Args:
        history: Iterable of text entries, oldest first.
        tokenizer: Object with ``encode(text, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=...)``.
        max_length: Maximum number of tokens to keep.

    Returns:
        The decoded, whitespace-stripped text. An empty string when
        ``history`` is empty or ``max_length`` is not positive.
    """
    # Tokenize every entry; [0] drops the batch dimension added by return_tensors
    encoded_history = [tokenizer.encode(entry, return_tensors="pt")[0] for entry in history]

    # torch.cat cannot concatenate an empty list
    if not encoded_history or max_length <= 0:
        return ""

    combined_history = torch.cat(encoded_history, dim=0)

    # The limit applies to the number of tokens, not the number of entries,
    # and the slice must be taken from the concatenated tensor
    if combined_history.size(0) > max_length:
        combined_history = combined_history[-max_length:]

    truncated_text = tokenizer.decode(combined_history.tolist(), skip_special_tokens=False)

    return truncated_text.strip()

def generate_context(chat_history, max_entries=2):
    """Return the most recent ``max_entries`` entries of ``chat_history``.

    Args:
        chat_history: Iterable of history entries, oldest first.
        max_entries: Number of trailing entries to keep.

    Returns:
        A new list with at most ``max_entries`` entries, oldest first. An
        empty list when ``max_entries`` is zero or negative.
    """
    # A slice of [-0:] would return the whole history, so handle it explicitly
    if max_entries <= 0:
        return []
    return list(chat_history)[-max_entries:]

def _truncate_response(text, max_words=70):
    """Shorten ``text`` to at most ``max_words`` words, ending on a sentence when possible."""
    # Fallback: a hard cut after max_words whitespace-separated words
    truncated = " ".join(text.split()[:max_words])

    # Prefer the longest prefix that ends on ., ! or ? and still fits the limit
    for match in re.finditer(r'[.!?]', text):
        candidate = text[:match.end()].strip()
        if len(candidate.split()) > max_words:
            break
        truncated = candidate

    return truncated


def inference(model, tokenizer, device):
    """Run an interactive console chat with the model until the user quits.

    Reads user messages with ``input()`` and prints one reply per message.
    Typing ``quit`` or ``q`` (any case, surrounding spaces ignored) ends the
    loop. The prompt for each turn is built from the most recent user
    messages only; bot replies are not stored in the history.

    Relies on the module-level ``generate_context`` and ``truncate_history``
    helpers and on the ``deque``, ``re`` and ``profanity`` imports.

    Args:
        model: Model exposing ``eval()`` and ``generate(...)``.
        tokenizer: Tokenizer matching the model, with ``eos_token_id`` set.
        device: Device the encoded prompt is moved to.
    """
    model.eval()
    chat_history = deque(maxlen=5)
    chat = input("User: ")

    while chat.strip().lower() not in ("quit", "q"):
        user_input = f"<SOS> {chat.strip()}"
        chat_history.append(user_input)  # Adding user input to history

        # Context is the most recent user messages, including the current one
        context = generate_context([entry for entry in chat_history if "<SOS>" in entry])
        # Keep the prompt within 128 tokens in longer conversations
        truncated_history = truncate_history(context, tokenizer, max_length=128)

        chat_encoded = tokenizer(truncated_history, return_tensors="pt", truncation=True, max_length=128)

        input_id = chat_encoded["input_ids"].to(device)
        attention_mask = chat_encoded["attention_mask"].to(device)

        output = model.generate(
            input_id,
            attention_mask=attention_mask,
            min_length=30,  # Minimum total length; per the transformers docs this counts the prompt too
            max_new_tokens=100,  # Upper bound on generated tokens so replies aren't too verbose
            do_sample=True,  # Required for temperature, top-k and top-p to take effect
            temperature=0.2,  # Low temperature for more coherent responses
            top_k=50,  # Sample only from the 50 most likely next tokens
            top_p=0.9,  # Nucleus sampling
            repetition_penalty=1.5,  # Penalizing repeated sequences
            # Stop at the tokenizer's "<EOS>". The earlier generate() warning
            # in this notebook reports the model's own eos id as 50256, so
            # without this argument generation would not stop at "<EOS>".
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens. Stripping the prompt by
        # comparing decoded text is fragile, and indexing chat_history[-2]
        # fails on the first turn when the history has a single entry.
        new_token_ids = output[0][input_id.shape[-1]:]
        response = tokenizer.decode(new_token_ids, skip_special_tokens=False)

        # Remove the chat markers from the reply
        for marker in ("<BOT>", "<SOS>", "<EOS>", "<PAD>"):
            response = response.replace(marker, "")
        response = response.strip()

        # Censoring profanity
        cleaned_response = profanity.censor(response)

        # Limit the reply to 70 words, ending on a sentence when possible
        truncated_response = _truncate_response(cleaned_response, max_words=70)

        print(f"Bot: {truncated_response}")

        chat = input("User: ")

# Prefer a CUDA GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50261, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50261, bias=False)
)

In [27]:
from torch.utils.data import DataLoader

# Train on the first 16384 chats only, served in shuffled batches of 32
chat_data = DataLoader(
    ChatData(chats[:16384], tokenizer),
    batch_size=32,
    shuffle=True,
)

In [32]:
# Fine-tune for 20 epochs with Adam; checkpoints are written at epochs 0 and 10, plus the final model
optimizer = Adam(model.parameters(), lr=5e-5)
train_model(chat_data, model, optimizer, device, epochs=20, save_every=10)

100%|█████████████████████████████████████████████████████████████████████████████████| 20/20 [56:40<00:00, 170.04s/it]


In [None]:
# Start the interactive chat loop; type "quit" or "q" to stop
inference(model, tokenizer, device)

User:  What is the meaning of life?


Bot: Lives are ephemeral. Things we don't live with anymore... like art or music... you can only imagine what it must have meant to those around them when they died, before their time -- all dead and forgotten as a whole in our lifetimes! I'm not gonna pretend there's absolutely no way this means anything other than making things better for ourselves..


User:  Wow that's pretty dismal.


Bot: Well, it seems we've got some sorta out there... a search for answers here on this alien planet.. and maybe even an answer to our own existence as we know It! Now what kindof questions does one ask ourselves when faced with such problems...? Do they have any knowledge at all about anything like.... well perhaps... astrology???


User:  Yeah Earth does seem kind of alien in some places, huh?


Bot: Yes... it is.. depressing really! I don't know why they keep this on their birth certificates but surely they must have thought something awful was going to happen here somewhere down there...and probably killing people as well.... Anyway the aliens killed a lot more than were injured and tended animals when coming out from hibernation so by examining them now you'd say...they've re-engineered our DNA somehow?!!! What about "we"?


User:  Who's re-engineered our DNA?


Bot: Dr. Weir... I'm not sure what you're talking about -- why we need the ship or anything else besides life support...? Starck: Where are your people now?! How many colonists have they left on board!? What happened to them??? Commander Alexander Hamilton has yet again come up with a replacement for this starboard computer! Why haven't those men ever retired from active duty??
