In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu' # try to use GPU, if not then CPU
print(device)

'''
Hyperparameters
'''
seed = 1337  # fixed RNG seed so torch-based randomness (batch sampling, dropout, ...) is reproducible
torch.manual_seed(seed)  # seeds the random number generators on all devices (CPU and CUDA)
block_size = 80 # also becomes context size
batch_size = 16
iterations = 1_000
eval_iters = 100  # number of batches per loss estimate; the training loop also uses it as the evaluation interval
learning_rate = 3e-4
embd_dim = 400
num_decoders = 4
num_heads = 4
dropout = 0.3

In [None]:
'''
Vocabulary
'''
with open('crime_and_punishment.txt', 'r', encoding='utf-8') as f:
    txt = f.read().lower().split()  # txt is the corpus as a list of lowercase, whitespace-separated words

# sorted() instead of bare set order keeps word ids identical across kernel restarts:
# str hashing is randomized per process, so set iteration order is not stable, and
# weights saved in one session would otherwise be reloaded against a shuffled vocabulary.
words = sorted(set(txt))
vocab_size = len(words)

In [None]:
'''
Train/Test Split
'''
word_to_int = {s: i for i, s in enumerate(words)}
int_to_word = dict(enumerate(words))


def encode(word_list):
    """Map a sequence of vocabulary words to their integer ids (raises KeyError on unknown words)."""
    return [word_to_int[s] for s in word_list]


def decode(index_list):
    """Map a sequence of integer ids back to their vocabulary words.

    Ids are passed through int() so 0-d tensor elements work as well as Python ints
    (a tensor used directly as a dict key would not match the int keys).
    """
    return [int_to_word[int(i)] for i in index_list]


data = torch.tensor(encode(txt), dtype=torch.long)  # the whole corpus as one 1-D tensor of word ids

n = int(0.8 * len(data))  # contiguous split: first 80% for training, remaining 20% for testing
train_data = data[:n]
test_data = data[n:]


def get_random_batch(portion):
    """Sample a random batch of (input, target) word-id blocks from one split.

    Args:
        portion: 'train' to sample from train_data, 'test' to sample from test_data.

    Returns:
        (inputs, targets): two stacked tensors of shape (batch_size, block_size) on
        `device`; targets are the inputs shifted one position ahead (next-word labels).

    Raises:
        ValueError: if `portion` is not 'train' or 'test', or the split is too short
            to hold one block plus its shifted target.
    """
    if portion not in ('train', 'test'):
        raise ValueError(f"portion must be 'train' or 'test', got {portion!r}")
    split_data = train_data if portion == 'train' else test_data
    if len(split_data) <= block_size:
        raise ValueError(
            f"{portion} split has {len(split_data)} tokens; need more than block_size={block_size}"
        )

    # randint's `high` is exclusive, so the largest start still leaves room for the +1 target shift
    starts = torch.randint(low=0, high=len(split_data) - block_size, size=(batch_size,))

    inputs = torch.stack([split_data[i:i + block_size] for i in starts])
    targets = torch.stack([split_data[i + 1:i + block_size + 1] for i in starts])
    return inputs.to(device), targets.to(device)


'''
Create Model
'''
# Explicit import: `from GPT import *` could silently rebind notebook globals that share a name with the module's.
from GPT import GPT
gpt = GPT(vocab_size, embd_dim, block_size, num_decoders, num_heads, dropout)
gpt = gpt.to(device) # Use GPU when available


In [None]:
'''
Estimate Losses During Training
'''
@torch.no_grad()
def estimate_losses(model=None, get_batch=None, num_batches=None):
    """Estimate the mean loss on the train and test splits without tracking gradients.

    Args:
        model: module called as ``model(inputs, targets) -> (logits, loss)``;
            defaults to the notebook's ``gpt``.
        get_batch: callable ``get_batch(split) -> (inputs, targets)``;
            defaults to ``get_random_batch``.
        num_batches: number of random batches averaged per split; defaults to ``eval_iters``.

    Returns:
        dict mapping 'train' and 'test' to a 0-d tensor holding the mean loss.

    The model is switched to eval mode (dropout off) for the estimate and then
    restored to whatever mode it was in before, so training keeps using dropout.
    """
    model = gpt if model is None else model
    get_batch = get_random_batch if get_batch is None else get_batch
    num_batches = eval_iters if num_batches is None else num_batches

    was_training = model.training
    model.eval()  # disable dropout; gradient tracking is already off via @torch.no_grad()
    try:
        out = {}
        for split in ('train', 'test'):
            losses = torch.zeros(num_batches)
            for k in range(num_batches):
                inputs, targets = get_batch(split)
                _, loss = model(inputs, targets)  # call the module (not .forward) so nn.Module hooks run
                losses[k] = loss.item()
            out[split] = losses.mean()
        return out
    finally:
        model.train(was_training)  # restore the caller's mode instead of leaving the model in eval mode


'''
Optimizer
'''
# AdamW (Adam with decoupled weight decay) over every parameter the model exposes.
opt = torch.optim.AdamW(
    params=gpt.parameters(),
    lr=learning_rate,
)


'''
Training Loop
'''
gpt.train()  # make sure dropout is active while optimizing
for step in range(iterations): # optimization loop
    inputs, targets = get_random_batch('train')

    _, loss = gpt(inputs, targets)  # call the module (not .forward) so nn.Module hooks run
    opt.zero_grad(set_to_none=True) # do not accumulate gradients across steps
    loss.backward()
    opt.step()

    # eval_iters doubles as the evaluation interval; also report once after the final step
    if step % eval_iters == 0 or step == iterations - 1:
        losses = estimate_losses()
        gpt.train()  # estimate_losses() may leave the model in eval mode; re-enable dropout for training
        print(f"iter: {step}, train loss: {losses['train']:.4f}, test loss: {losses['test']:.4f}")

'''
Save Learned Parameters
'''
torch.save(gpt.state_dict(), 'model_weights.pth')

In [None]:
# Reload the trained weights. map_location lets GPU-saved weights load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors/containers (a state_dict needs nothing more).
gpt.load_state_dict(torch.load('model_weights.pth', map_location=device, weights_only=True))
gpt.eval()  # disable dropout for inference

prompt = ['good', 'afternoon']
unknown_words = [w for w in prompt if w not in word_to_int]
if unknown_words:
    raise ValueError(f"prompt words not in the (lowercased) vocabulary: {unknown_words}")
context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
with torch.no_grad():  # inference only; no autograd graph needed
    generated_ids = gpt.generate(context.unsqueeze(0), num_new_tokens=300)[0].tolist()
generated_words = decode(generated_ids)

words_per_line = 14  # wrap the generated text every 14 words
lines = [
    ' '.join(generated_words[i:i + words_per_line])
    for i in range(0, len(generated_words), words_per_line)
]
response = '\n'.join(lines)

print('\n' + response)