In [1]:
import tiktoken, torch, torch.nn as nn
from torch.utils.data import DataLoader
from utils.TextDataset import TextDataset
from utils.GPTLargeLanguageModel import GPTLargeLanguageModel
from utils.generate_tokens import generate_tokens

def get_text(file):
    """Read a whole text file decoded as UTF-8.

    Args:
        file: Path (str or path-like) of the file to read.

    Returns:
        The complete file contents as one ``str``.
    """
    with open(file, encoding="utf-8") as handle:
        return handle.read()

def text_to_tokens(text, tokenizer):
    """Encode a string into a batch-of-one tensor of token ids.

    Args:
        text: The string to encode.
        tokenizer: Any object exposing ``encode(str) -> list[int]``
            (in this notebook, a tiktoken encoding).

    Returns:
        A ``torch.long`` tensor of shape ``(1, num_tokens)``.
    """
    token_ids = tokenizer.encode(text)
    # Force an integer dtype: for empty input torch.tensor([]) would default to
    # float32, which cannot be used as token indices downstream.
    return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)

def tokens_to_text(tokens, tokenizer):
    """Decode a single sequence of token ids back into a string.

    Args:
        tokens: Integer tensor holding ONE sequence, either 1-D ``(T,)``,
            batched ``(1, T)``, or a 0-d scalar id.
        tokenizer: Any object exposing ``decode(list[int]) -> str``
            (in this notebook, a tiktoken encoding).

    Returns:
        The decoded string.

    Raises:
        ValueError: If ``tokens`` holds more than one sequence, or has more than
            two dimensions.
    """
    if tokens.dim() > 1:
        if tokens.dim() > 2 or tokens.size(0) != 1:
            raise ValueError(
                f"tokens_to_text expects a single sequence, got shape {tuple(tokens.shape)}"
            )
        tokens = tokens[0]
    # reshape(-1) keeps a length-1 or 0-d tensor as a list; squeeze(0) followed by
    # tolist() would turn it into a bare int that decode() cannot accept.
    return tokenizer.decode(tokens.reshape(-1).tolist())

# Mode switch: True trains the model and saves its weights; False loads previously saved weights.
TRAIN: bool = False

In [2]:
# Text data: the whole corpus as a single UTF-8 string
text = get_text("The_Complete_Works_of_William_Shakespeare.txt")

# Reproducibility: seeds the model's weight initialisation and the order in
# which the shuffling DataLoader below draws batches.
SEED = 123
torch.manual_seed(SEED)

# Tokenizer (GPT-2 byte-pair encoding)
tokenizer = tiktoken.get_encoding("gpt2")

# LLM parameters — these values match the GPT-2 small (124M) configuration
vocab_size = tokenizer.n_vocab  # 50257 for "gpt2"; derived so it can't drift from the tokenizer
num_layers = 12
context_length = 1024
dimension = 768
num_heads = 12
dropout = 0.1

# Dataloader parameters
batch_size = 4
# Presumably the window step inside TextDataset, i.e. half-context overlap
# between consecutive samples — TODO confirm against utils/TextDataset.py.
stride = context_length // 2

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataloader
dataset = TextDataset(text, tokenizer, max_length=context_length, stride=stride)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)

# LLM
LLM = GPTLargeLanguageModel(vocab_size, num_layers, context_length, dimension, num_heads, dropout).to(device)

In [3]:
def train_model(model, dataloader, optimizer, epochs, device=None, checkpoint_tag=None):
    """Train ``model`` with next-token cross-entropy for ``epochs`` passes.

    Ctrl-C (KeyboardInterrupt) is caught deliberately. The current model and
    optimizer state are saved to ``interrupted_*`` files and the function
    returns normally, so a long run can be stopped without losing progress.

    Args:
        model: Module mapping a batch of inputs ``x`` to logits. It is assumed
            that ``model(x)`` is ``(batch, seq, vocab)`` and ``y`` is
            ``(batch, seq)``, as the ``flatten`` calls below require matching
            leading dims (TODO confirm against GPTLargeLanguageModel / TextDataset).
        dataloader: Sized iterable yielding ``(x, y)`` tensor pairs.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of full passes over ``dataloader``.
        device: Where to move each batch. When ``None``, the device of the
            model's first parameter is used (CPU if the model has none).
        checkpoint_tag: Suffix for the interrupt checkpoint filenames. When
            ``None``, it is built from the notebook globals
            ``num_layers``, ``context_length`` and ``num_heads``.

    Raises:
        Any exception other than KeyboardInterrupt propagates unchanged.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    loader_count = len(dataloader)
    try:
        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}")
            model.train()

            for i, (x, y) in enumerate(dataloader):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                logits = model(x)
                loss = nn.functional.cross_entropy(logits.flatten(0, 1), y.flatten())
                loss.backward()
                optimizer.step()

                # Drop the logits before the next forward pass. Otherwise the old
                # (batch x seq x vocab) tensor stays alive during it and raises peak memory.
                del logits

                print(f"Loss {i+1}/{loader_count}: {loss.item():.3f}", end="\r")
    except KeyboardInterrupt:
        if checkpoint_tag is None:
            checkpoint_tag = f"{num_layers}_{context_length}_{num_heads}"
        # Save the model that was actually being trained, not a notebook global.
        torch.save(model.state_dict(), f'interrupted_model_weights_{checkpoint_tag}.pth')
        torch.save(optimizer.state_dict(), f'interrupted_optimizer_{checkpoint_tag}.pth')
    # Leading newline so "Done" does not overwrite the last "\r" progress line.
    print("\nDone")


optimizer = torch.optim.AdamW(LLM.parameters(), lr=0.0004, weight_decay=0.1)

# Checkpoint files are keyed only by these three hyperparameters. Changing e.g.
# `dimension` or `vocab_size` without renaming makes load_state_dict below fail
# on a shape mismatch.
checkpoint_tag = f"{num_layers}_{context_length}_{num_heads}"

if TRAIN:
    train_model(LLM, dataloader, optimizer, epochs=1)
    torch.save(LLM.state_dict(), f'model_weights_{checkpoint_tag}.pth')
    torch.save(optimizer.state_dict(), f'optimizer_{checkpoint_tag}.pth')
else:
    # map_location lets weights saved on CUDA load on a CPU-only machine.
    # weights_only restricts unpickling to tensors and plain containers.
    # To resume training, restore optimizer_{checkpoint_tag}.pth the same way
    # via optimizer.load_state_dict(...).
    LLM.load_state_dict(
        torch.load(f'model_weights_{checkpoint_tag}.pth', map_location=device, weights_only=True)
    )

# A freshly constructed nn.Module, and the model after train_model, are both in
# train mode. Switch to eval so dropout is off for the generation cell that follows.
LLM.eval()

In [4]:
# Sample a continuation of a short prompt and print it as text.
prompt = "I'm tired"
prompt_ids = text_to_tokens(prompt, tokenizer).to(device)

# Positional arguments for generate_tokens. The names are inferred from the
# values, not from its signature — TODO confirm against utils/generate_tokens.py.
max_new_tokens = 50   # presumably the number of tokens to append
context_size = 1024   # presumably the context window (equals context_length here)
temperature = 1.5     # presumably the sampling temperature
top_k = 25            # presumably the top-k sampling cutoff

output = generate_tokens(LLM, prompt_ids, max_new_tokens, context_size, temperature, top_k)
print(tokens_to_text(output, tokenizer))

I'm tired to see of all in such a man as goodly a king of a

and a word.
inbr of men must

in a good of his
I will to a goodly_.
you were,

with
