In [2]:
with open("gameofthrones.txt", "r") as file:
    text = file.read()

In [3]:
print("Lenght of the book in characters: ", len(text))

Lenght of the book in characters:  5662324


In [4]:
print(text[:1000])



“We should start back,” Gared urged as the woods began to grow dark around them. “The wildlings are dead.”

“Do the dead frighten you?” Ser Waymar Royce asked with just the hint of a smile.

Gared did not rise to the bait. He was an old man, past fifty, and he had seen the lordlings come and go. “Dead is dead,” he said. “We have no business with the dead.”

“Are they dead?” Royce asked softly. “What proof have we?”

“Will saw them,” Gared said. “If he says they are dead, that’s proof enough for me.”

Will had known they would drag him into the quarrel sooner or later. He wished it had been later rather than sooner. “My mother told me that dead men sing no songs,” he put in.

“My wet nurse said the same thing, Will,” Royce replied. “Never believe anything you hear at a woman’s tit. There are things to be learned even from the dead.” His voice echoed, too loud in the twilit forest.

“We have a long ride before us,” Gared pointed out. “Eight days, maybe nine. And night is falling.”

Ser

In [5]:
chars = sorted(list(set(text)))
print(repr("".join(chars)))
print("Number of unique characters: ", len(chars))

'\n !(),-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz{}éê—‘’“”…'
Number of unique characters:  86


In [6]:
# Create a mapping from characters to indices (vocabulary)
char_to_idx = {ch:i for i, ch in enumerate(chars)}
idx_to_char = {i:ch for i, ch in enumerate(chars)}

encode = lambda s: [char_to_idx[c] for c in s] # s is the input string that I want to encode
decode = lambda l: "".join([idx_to_char[i] for i in l]) # l is the input list of indices that I want to decode

In [7]:
# Try to encode and decode a string
string = "In this place there is air"
encoded = encode(string)
print("Encoded string: ", encoded)
decoded = decode(encoded)
print("Decoded string: ", decoded)

Encoded string:  [30, 63, 1, 69, 57, 58, 68, 1, 65, 61, 50, 52, 54, 1, 69, 57, 54, 67, 54, 1, 58, 68, 1, 50, 58, 67]
Decoded string:  In this place there is air


In [15]:
# This is the tokenizer that CHATGPT uses

import tiktoken
encode = tiktoken.get_encoding("cl100k_base")

string = "In this place there is air"
encoded = encode.encode(string)
print("Encoded string: ", encoded)
decoded = encode.decode(encoded)
print("Decoded string: ", decoded)


Encoded string:  [644, 420, 2035, 1070, 374, 3805]
Decoded string:  In this place there is air


In [21]:
# This is the tokenizer that google used when pretraining the BERT model
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

string = "In this place there is air"
encoded = tokenizer.encode(string, add_special_tokens=False)
print("Encoded string: ", encoded)
decoded = tokenizer.decode(encoded)
print("Decoded string: ", decoded)

Encoded string:  [1999, 2023, 2173, 2045, 2003, 2250]
Decoded string:  in this place there is air


In [8]:
import torch 
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)
print(data[:100])

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([5662324])
tensor([ 0,  0, 83, 44, 54,  1, 68, 57, 64, 70, 61, 53,  1, 68, 69, 50, 67, 69,
         1, 51, 50, 52, 60,  5, 84,  1, 28, 50, 67, 54, 53,  1, 70, 67, 56, 54,
        53,  1, 50, 68,  1, 69, 57, 54,  1, 72, 64, 64, 53, 68,  1, 51, 54, 56,
        50, 63,  1, 69, 64,  1, 56, 67, 64, 72,  1, 53, 50, 67, 60,  1, 50, 67,
        64, 70, 63, 53,  1, 69, 57, 54, 62,  7,  1, 83, 41, 57, 54,  1, 72, 58,
        61, 53, 61, 58, 63, 56, 68,  1, 50, 67])


In [9]:
# Split the dataset in training and validation
# I want to store 90% of characters for training and 10% for validation
n = int(0.9*len(text))
train_data = data[:n]
val_data = data[n:]

In [16]:
import random

context_length = 8 # This is maximum number of tokens that are allowed to fit in the context. 

def get_batch(data, batch_size = 8, context_length = 8):
    # Get batch_size random indices in the data
    random_idx = random.sample(range(len(data)-context_length), batch_size)
    # Pluck the next character after each random index
    inputs = torch.zeros((batch_size, context_length), dtype=torch.long)
    targets = torch.zeros((batch_size, context_length), dtype=torch.long)
    
    for i in range(batch_size):
        inputs[i,:] = data[random_idx[i]:random_idx[i]+context_length]
        targets[i,:] = data[random_idx[i]+1:random_idx[i]+context_length+1]
    
    return inputs, targets

(tensor([[50, 52, 60,  1],
         [69, 54, 53,  1],
         [61,  1, 57, 50],
         [55, 64, 67,  1],
         [64, 53,  1, 57],
         [ 1, 57, 54, 67],
         [ 1, 50, 63, 53],
         [52, 57,  1, 62]]),
 tensor([[52, 60,  1, 51],
         [54, 53,  1, 69],
         [ 1, 57, 50, 53],
         [64, 67,  1, 53],
         [53,  1, 57, 54],
         [57, 54, 67,  1],
         [50, 63, 53,  1],
         [57,  1, 62, 54]]))

In [31]:
# Let's build a bigram language model 
import torch.nn as nn 
import torch.nn.functional as F

class BigramLM(nn.Module):
    def __init__(self, vocab_size):
        super(BigramLM, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, inputs):
        # Receive inputs in the form (B, L, C) (batch times seq length times num classes (tokens))
        # Embed the inputs
        embeddings = self.embedding_table(inputs)
        return embeddings
    
    def generate(self, inputs, max_new_tokens=10):
        # Generate next tokens given the inputs (B, L, C)
        for i in range(max_new_tokens):                   
            # Embed the inputs
            logits = self(inputs)
            # Get the last logit 
            last_logit = logits[:, -1, :] # Get the last element in the length dimension
            # Compute the probabilities
            probs = F.softmax(last_logit, dim=1)
            # Get the next token
            next_token = torch.multinomial(probs, num_samples=1)
            # Append to the inputs
            inputs = torch.cat((inputs, next_token), dim=1)
            
        return inputs

In [185]:
model = BigramLM(len(chars))

inputs_generate = torch.zeros((1,1), dtype=torch.long)
print(decode(model.generate(inputs_generate, max_new_tokens=100)[0].tolist()))


/ki!V(Jy—J{!ac!Bc3!ZwdsB07(,W7I.QM3(H8[cRsU2IPo-eU2
éWbfWoNr)nWZLAS.‘}D(5’J1rP
IDF.s5éJvdSPL:!—i

T]


In [186]:
# Let's train this model
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Hyperparameters
learning_rate = 0.01
momentum = 0.9
batch_size = 32
context_length = 8 # Is not0 really taken into consideration here because we are using a bigram model
num_iterations = 100000
device = "cpu"

# Get the optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Get loss function
loss_fn = CrossEntropyLoss()
# Send the model to the device
model.to(device)

loss_train = 0 
# Start training loop
for i in tqdm(range(num_iterations)):
    # Get a batch
    inputs, targets = get_batch(train_data, batch_size=8, context_length=8)
    # Send the inputs and targets to device
    inputs = inputs.to(device)
    targets = targets.to(device)
    # Get the predictions
    predictions = model(inputs)
    # We have to reshape the predictions and the targets to use cross entropy 
    B, L, C = predictions.shape
    predictions = predictions.view(B*L, C)
    targets = targets.view(B*L)
    # Compute the loss
    loss = loss_fn(predictions, targets)
    # Accumulate the loss
    loss_train += loss.item()
    # Zero the gradients
    optimizer.zero_grad()
    # Compute the gradients
    loss.backward()
    # Update the parameters
    optimizer.step()

    if (i%1000)==0 and i!=0:
        print(loss_train/1000)
        loss_train = 0

  1%|          | 1081/100000 [00:01<02:13, 738.40it/s]

4.531553933858872


  2%|▏         | 2119/100000 [00:03<02:24, 677.58it/s]

3.981821083068848


  3%|▎         | 3102/100000 [00:04<02:30, 645.76it/s]

3.659087922096252


  4%|▍         | 4098/100000 [00:06<02:19, 687.01it/s]

3.461212954521179


  5%|▌         | 5136/100000 [00:07<02:18, 687.16it/s]

3.3141537029743193


  6%|▌         | 6097/100000 [00:09<03:06, 504.44it/s]

3.2020576977729798


  7%|▋         | 7067/100000 [00:12<03:19, 465.17it/s]

3.128999495267868


  8%|▊         | 8029/100000 [00:14<03:40, 416.29it/s]

3.0469764938354493


  9%|▉         | 9070/100000 [00:16<03:48, 397.47it/s]

2.9944788975715637


 10%|█         | 10065/100000 [00:19<03:31, 426.21it/s]

2.9433668100833894


 11%|█         | 11057/100000 [00:21<03:25, 433.84it/s]

2.903965374946594


 12%|█▏        | 12049/100000 [00:23<03:20, 439.25it/s]

2.8573172640800477


 13%|█▎        | 13062/100000 [00:26<03:11, 455.06it/s]

2.8355580937862395


 14%|█▍        | 14073/100000 [00:28<03:03, 467.62it/s]

2.798812198638916


 15%|█▌        | 15093/100000 [00:30<02:53, 490.51it/s]

2.779308315515518


 16%|█▌        | 16071/100000 [00:32<02:58, 469.14it/s]

2.751311415910721


 17%|█▋        | 17055/100000 [00:34<02:49, 489.09it/s]

2.7435945060253144


 18%|█▊        | 18090/100000 [00:37<02:57, 462.61it/s]

2.7174368586540223


 19%|█▉        | 19041/100000 [00:39<03:09, 427.90it/s]

2.7227194583415986


 20%|██        | 20074/100000 [00:41<02:41, 495.13it/s]

2.695326076745987


 21%|██        | 21048/100000 [00:43<03:00, 437.76it/s]

2.676213918209076


 22%|██▏       | 22087/100000 [00:45<02:45, 469.55it/s]

2.6850045149326323


 23%|██▎       | 23040/100000 [00:48<02:44, 468.07it/s]

2.6635978989601137


 24%|██▍       | 24066/100000 [00:50<02:35, 489.85it/s]

2.6479468665122985


 25%|██▌       | 25072/100000 [00:52<02:17, 543.97it/s]

2.654685037612915


 26%|██▌       | 26093/100000 [00:54<02:41, 456.42it/s]

2.639138548374176


 27%|██▋       | 27042/100000 [00:56<02:18, 525.06it/s]

2.6216793146133424


 28%|██▊       | 28043/100000 [00:58<02:46, 431.67it/s]

2.6269746313095093


 29%|██▉       | 29102/100000 [01:00<02:27, 481.34it/s]

2.62004757809639


 30%|███       | 30054/100000 [01:02<02:44, 424.24it/s]

2.6209315712451935


 31%|███       | 31085/100000 [01:05<02:17, 499.68it/s]

2.5944921951293947


 32%|███▏      | 32081/100000 [01:07<02:08, 528.04it/s]

2.6002886902093887


 33%|███▎      | 33071/100000 [01:09<02:10, 512.81it/s]

2.582914157152176


 34%|███▍      | 34082/100000 [01:11<02:10, 503.58it/s]

2.5939671845436094


 35%|███▌      | 35096/100000 [01:13<02:10, 497.33it/s]

2.5889721455574035


 36%|███▌      | 36056/100000 [01:15<02:34, 412.88it/s]

2.5737917597293856


 37%|███▋      | 37058/100000 [01:17<02:28, 423.26it/s]

2.5786063179969787


 38%|███▊      | 38091/100000 [01:20<02:11, 469.16it/s]

2.558868805885315


 39%|███▉      | 39047/100000 [01:22<02:17, 444.50it/s]

2.5633427684307097


 40%|████      | 40053/100000 [01:24<02:05, 477.75it/s]

2.5725590577125548


 41%|████      | 41084/100000 [01:27<02:12, 444.58it/s]

2.5640987782478333


 42%|████▏     | 42073/100000 [01:29<02:25, 396.84it/s]

2.557411222696304


 43%|████▎     | 43084/100000 [01:31<02:04, 456.69it/s]

2.5483398950099945


 44%|████▍     | 44068/100000 [01:33<02:06, 443.00it/s]

2.5525317821502687


 45%|████▌     | 45080/100000 [01:35<01:59, 461.04it/s]

2.5494677481651307


 46%|████▌     | 46064/100000 [01:38<01:57, 459.95it/s]

2.5373767799139024


 47%|████▋     | 47045/100000 [01:40<01:50, 479.67it/s]

2.5420116231441496


 48%|████▊     | 48070/100000 [01:42<01:45, 491.69it/s]

2.530416499376297


 49%|████▉     | 49079/100000 [01:44<01:41, 500.83it/s]

2.5355921404361723


 50%|█████     | 50082/100000 [01:46<01:56, 428.09it/s]

2.535588314771652


 51%|█████     | 51075/100000 [01:49<01:40, 486.72it/s]

2.530235331058502


 52%|█████▏    | 52084/100000 [01:51<01:36, 497.60it/s]

2.5327279260158537


 53%|█████▎    | 53048/100000 [01:53<01:53, 415.23it/s]

2.5199726891517638


 54%|█████▍    | 54077/100000 [01:55<01:29, 514.40it/s]

2.509927887201309


 55%|█████▌    | 55046/100000 [01:57<01:28, 508.51it/s]

2.5223830902576445


 56%|█████▌    | 56071/100000 [02:00<01:31, 479.87it/s]

2.515603426337242


 57%|█████▋    | 57072/100000 [02:02<01:38, 437.75it/s]

2.51166874063015


 58%|█████▊    | 58044/100000 [02:04<01:11, 587.82it/s]

2.5104210658073427


 59%|█████▉    | 59101/100000 [02:05<01:01, 668.60it/s]

2.5098035786151884


 60%|██████    | 60088/100000 [02:07<00:54, 736.63it/s]

2.499261964559555


 61%|██████    | 61106/100000 [02:08<00:54, 707.78it/s]

2.5112182124853133


 62%|██████▏   | 62136/100000 [02:10<00:54, 693.91it/s]

2.5067494237422943


 63%|██████▎   | 63136/100000 [02:11<00:51, 718.13it/s]

2.511850376367569


 64%|██████▍   | 64168/100000 [02:13<00:45, 790.13it/s]

2.500953978776932


 65%|██████▌   | 65083/100000 [02:14<00:48, 722.15it/s]

2.501823412895203


 66%|██████▌   | 66104/100000 [02:15<00:46, 728.93it/s]

2.4912198399305345


 67%|██████▋   | 67112/100000 [02:17<00:45, 717.94it/s]

2.498514130115509


 68%|██████▊   | 68069/100000 [02:18<00:44, 718.12it/s]

2.498703094244003


 69%|██████▉   | 69154/100000 [02:20<00:43, 708.56it/s]

2.494292546272278


 70%|███████   | 70073/100000 [02:21<00:42, 708.00it/s]

2.499190218091011


 71%|███████   | 71117/100000 [02:22<00:39, 723.55it/s]

2.4878960676193236


 72%|███████▏  | 72117/100000 [02:24<00:39, 707.41it/s]

2.480271534204483


 73%|███████▎  | 73111/100000 [02:25<00:37, 708.67it/s]

2.4910287806987763


 74%|███████▍  | 74061/100000 [02:27<00:36, 720.00it/s]

2.481025946378708


 75%|███████▌  | 75079/100000 [02:28<00:35, 700.84it/s]

2.4824925405979155


 76%|███████▌  | 76099/100000 [02:30<00:34, 690.75it/s]

2.490652072191238


 77%|███████▋  | 77062/100000 [02:31<00:31, 718.58it/s]

2.4833962441682815


 78%|███████▊  | 78086/100000 [02:32<00:30, 730.31it/s]

2.479735641479492


 79%|███████▉  | 79144/100000 [02:34<00:30, 680.38it/s]

2.491485904932022


 80%|████████  | 80115/100000 [02:35<00:30, 658.62it/s]

2.478208238363266


 81%|████████  | 81115/100000 [02:37<00:25, 727.17it/s]

2.4843224444389342


 82%|████████▏ | 82128/100000 [02:38<00:25, 702.57it/s]

2.4902581593990325


 83%|████████▎ | 83116/100000 [02:40<00:24, 700.16it/s]

2.4761317462921144


 84%|████████▍ | 84160/100000 [02:41<00:21, 737.29it/s]

2.476588787674904


 85%|████████▌ | 85076/100000 [02:43<00:22, 659.87it/s]

2.4781423003673555


 86%|████████▌ | 86106/100000 [02:44<00:21, 653.15it/s]

2.4778020551204682


 87%|████████▋ | 87125/100000 [02:46<00:18, 703.08it/s]

2.4742320650815963


 88%|████████▊ | 88094/100000 [02:47<00:17, 665.43it/s]

2.4740773591995238


 89%|████████▉ | 89072/100000 [02:49<00:16, 652.51it/s]

2.4720924017429353


 90%|█████████ | 90074/100000 [02:50<00:20, 488.12it/s]

2.4641541171073915


 91%|█████████ | 91146/100000 [02:52<00:12, 720.57it/s]

2.4739537460803986


 92%|█████████▏| 92101/100000 [02:53<00:11, 715.80it/s]

2.473167445898056


 93%|█████████▎| 93108/100000 [02:55<00:09, 690.65it/s]

2.4770742864608764


 94%|█████████▍| 94074/100000 [02:56<00:08, 720.03it/s]

2.4739247189760207


 95%|█████████▌| 95097/100000 [02:58<00:06, 734.13it/s]

2.4682924371957777


 96%|█████████▌| 96143/100000 [02:59<00:05, 693.69it/s]

2.4696339935064318


 97%|█████████▋| 97107/100000 [03:01<00:04, 708.95it/s]

2.470742347121239


 98%|█████████▊| 98128/100000 [03:02<00:02, 710.17it/s]

2.45628770840168


 99%|█████████▉| 99122/100000 [03:03<00:01, 699.51it/s]

2.463740375876427


100%|██████████| 100000/100000 [03:05<00:00, 539.97it/s]


In [190]:
# Let's try to generate
inputs_generate = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(inputs_generate, max_new_tokens=100)[0].tolist()))


“I’te warce thes blimier, gng J‘WCAroshetothe.” Ry iroutore8’s. thergs I{jE—ld, ouse soono oure lo t


Ok, enough for the bigramLM, let's try to switch to some better architecture (that take into consideration the context)

In [191]:
# Let's try with an RNN

class RNNLM(nn.Module):
    def __init__(self, vocab_size, hidden_size=32):
        super(RNNLM, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        # Classifier to predict the next token
        self.linear = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inputs):
        # Embed the inputs
        embeddings = self.embedding_table(inputs)
        # Pass the embeddings through the RNN
        outputs, _ = self.rnn(embeddings)
        # Pass the outputs through the classifier
        logits = self.linear(outputs)
        
        return logits

    def generate(self, inputs, max_new_tokens=10):
        # Generate next tokens given the inputs (B, L, C)
        for _ in range(max_new_tokens):
            # Embed the inputs
            embeddings = self.embedding_table(inputs)
            # Feed the RNN with the embeddings 
            _, hidden = self.rnn(embeddings)
            # Remove the first dimension (get only the last hidden state for each element in the batch)
            hidden = hidden.squeeze(0) 
            # Project to the output classes
            logits = self.linear(hidden)
            # Convert to probabilities
            probs = F.softmax(logits, dim=1)
            # Get the next token
            next_token = torch.multinomial(probs, num_samples=1)
            # Concatenate to input
            inputs = torch.cat((inputs, next_token), dim=1)
            
        return inputs

In [130]:
model = RNNLM(len(chars))
# Get a batch 
inputs, targets = get_batch(train_data, batch_size=8, context_length=8)

inputs_generate = torch.zeros((1,1), dtype=torch.long)
print(decode(model.generate(inputs_generate, max_new_tokens=100)[0].tolist()))


wêT…aw};9EShV?e;gJ5-VeXM0DL7?Mj3d16/êê”QX1]DOC5xbU-—!L?SOk—v,,iA2pVRCfIa’ N{x;iuS}êE
QSa)}V ]RkR—““d


In [194]:
# Let's train this model
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Hyperparameters
learning_rate = 0.01
momentum = 0.9
batch_size = 32
context_length = 8 # Here it is important because we are using an RNN
hidden_size = 32
num_iterations = 100000
device = "cpu"

# Create the model 
model = RNNLM(len(chars), hidden_size=hidden_size)
# Send the model to the device
model.to(device)

# Get the optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Get loss function
loss_fn = CrossEntropyLoss()


# Start training loop
loss_train = 0
for i in tqdm(range(num_iterations)):
    # Get a batch
    inputs, targets = get_batch(train_data, batch_size=8, context_length=8)
    # Send the inputs and targets to device
    inputs = inputs.to(device)
    targets = targets.to(device)
    # Get the predictions
    predictions = model(inputs)
    # We have to reshape the predictions and the targets to use cross entropy 
    B, L, C = predictions.shape
    predictions = predictions.view(B*L, C)
    targets = targets.view(B*L)
    # Compute the loss
    loss = loss_fn(predictions, targets)
    # Accumulate the loss
    loss_train += loss.item()
    # Zero the gradients
    optimizer.zero_grad()
    # Compute the gradients
    loss.backward()
    # Update the parameters
    optimizer.step()

    if (i%1000)==0 and i!=0:
        print(loss_train/1000)
        loss_train = 0

  1%|          | 1104/100000 [00:02<02:46, 594.56it/s]

2.697427982211113


  2%|▏         | 2071/100000 [00:03<02:43, 600.52it/s]

2.341876891255379


  3%|▎         | 3094/100000 [00:05<02:40, 603.10it/s]

2.238795204401016


  4%|▍         | 4098/100000 [00:07<03:00, 531.63it/s]

2.2008022384643553


  5%|▌         | 5061/100000 [00:09<02:49, 559.89it/s]

2.1696298559904097


  6%|▌         | 6114/100000 [00:11<02:44, 570.66it/s]

2.145216317296028


  7%|▋         | 7080/100000 [00:12<02:39, 582.19it/s]

2.1397430332899092


  8%|▊         | 8084/100000 [00:14<02:39, 575.99it/s]

2.1239597799777985


  9%|▉         | 9095/100000 [00:16<02:34, 588.83it/s]

2.1045547369718554


 10%|█         | 10082/100000 [00:18<02:54, 515.37it/s]

2.100814215898514


 11%|█         | 11028/100000 [00:20<03:34, 415.31it/s]

2.099411347031593


 12%|█▏        | 12011/100000 [00:22<02:33, 572.65it/s]

2.081422516345978


 13%|█▎        | 13067/100000 [00:24<02:40, 541.44it/s]

2.0842183727025985


 14%|█▍        | 14076/100000 [00:26<03:31, 406.73it/s]

2.0747576047182084


 15%|█▌        | 15051/100000 [00:29<03:26, 411.92it/s]

2.0658858429193496


 16%|█▌        | 16065/100000 [00:31<03:09, 442.18it/s]

2.0405577301979063


 17%|█▋        | 17059/100000 [00:33<03:15, 423.29it/s]

2.0711536506414414


 18%|█▊        | 18065/100000 [00:36<03:03, 446.02it/s]

2.051444857954979


 19%|█▉        | 19052/100000 [00:38<02:58, 452.37it/s]

2.0651046788692473


 20%|██        | 20073/100000 [00:40<03:11, 416.34it/s]

2.043838756322861


 21%|██        | 21065/100000 [00:43<02:58, 442.20it/s]

2.0409017440080643


 22%|██▏       | 22068/100000 [00:45<02:39, 488.43it/s]

2.051616707444191


 23%|██▎       | 23076/100000 [00:47<03:06, 411.52it/s]

2.047137320280075


 24%|██▍       | 24060/100000 [00:50<02:58, 426.30it/s]

2.0392973841428756


 25%|██▌       | 25054/100000 [00:52<02:44, 455.98it/s]

2.0344725811481474


 26%|██▌       | 26076/100000 [00:55<02:46, 444.36it/s]

2.0360444580316543


 27%|██▋       | 27036/100000 [00:57<02:32, 479.74it/s]

2.037434601902962


 28%|██▊       | 28060/100000 [00:59<02:27, 487.14it/s]

2.042479345083237


 29%|██▉       | 29047/100000 [01:01<02:27, 482.29it/s]

2.0300111948251724


 30%|███       | 30073/100000 [01:04<02:30, 463.93it/s]

2.0282171087265013


 31%|███       | 31063/100000 [01:06<02:25, 472.58it/s]

2.0307335171699523


 32%|███▏      | 32054/100000 [01:08<02:26, 464.54it/s]

2.024634853839874


 33%|███▎      | 33077/100000 [01:11<02:30, 445.27it/s]

2.0246564568281173


 34%|███▍      | 34066/100000 [01:13<02:23, 460.22it/s]

2.0268342990875245


 35%|███▌      | 35086/100000 [01:16<02:24, 449.23it/s]

2.0168044986724856


 36%|███▌      | 36060/100000 [01:18<02:36, 409.00it/s]

2.02606869828701


 37%|███▋      | 37033/100000 [01:20<02:21, 443.67it/s]

2.0217321714162826


 38%|███▊      | 38084/100000 [01:23<02:36, 396.04it/s]

2.025789217829704


 39%|███▉      | 39082/100000 [01:25<02:17, 444.04it/s]

2.0061296796798707


 40%|████      | 40052/100000 [01:27<02:14, 444.33it/s]

2.014076001763344


 41%|████      | 41097/100000 [01:30<02:07, 460.55it/s]

2.0127343744039536


 42%|████▏     | 42041/100000 [01:32<02:28, 391.59it/s]

2.0196012086868285


 43%|████▎     | 43079/100000 [01:35<02:15, 420.50it/s]

2.0105534744262696


 44%|████▍     | 44076/100000 [01:37<02:10, 429.37it/s]

2.0165663855075837


 45%|████▌     | 45041/100000 [01:40<03:08, 291.18it/s]

2.0099000869989396


 46%|████▌     | 46067/100000 [01:43<02:45, 325.29it/s]

2.029762497425079


 47%|████▋     | 47039/100000 [01:47<03:18, 266.47it/s]

2.0143785809278487


 48%|████▊     | 48064/100000 [01:50<02:16, 381.68it/s]

2.008682322859764


 49%|████▉     | 49071/100000 [01:52<01:54, 444.77it/s]

2.0168909364938736


 50%|█████     | 50038/100000 [01:55<02:05, 397.72it/s]

2.003345633983612


 51%|█████     | 51041/100000 [01:57<01:57, 417.93it/s]

1.9908675293922424


 52%|█████▏    | 52039/100000 [02:00<02:06, 379.50it/s]

2.0116505811214447


 53%|█████▎    | 53068/100000 [02:03<02:05, 374.55it/s]

2.0026813308000566


 54%|█████▍    | 54090/100000 [02:05<01:32, 498.43it/s]

2.0152445281744003


 55%|█████▌    | 55096/100000 [02:07<01:28, 505.20it/s]

2.006922811150551


 56%|█████▌    | 56105/100000 [02:09<01:23, 525.39it/s]

1.9990010882616043


 57%|█████▋    | 57086/100000 [02:11<01:37, 439.17it/s]

1.9950385587215425


 58%|█████▊    | 58069/100000 [02:14<01:38, 424.96it/s]

2.0084866079092025


 59%|█████▉    | 59048/100000 [02:16<01:31, 449.10it/s]

2.0145816526412963


 60%|██████    | 60055/100000 [02:18<01:27, 455.37it/s]

2.0111402522325514


 61%|██████    | 61099/100000 [02:21<01:23, 465.73it/s]

2.0158890182971954


 62%|██████▏   | 62094/100000 [02:23<01:16, 495.37it/s]

2.0093376368284224


 63%|██████▎   | 63091/100000 [02:25<01:12, 508.76it/s]

2.0038538395166396


 64%|██████▍   | 64027/100000 [02:28<01:59, 300.23it/s]

2.010668438434601


 65%|██████▌   | 65077/100000 [02:31<01:18, 442.68it/s]

1.998256903409958


 66%|██████▌   | 66063/100000 [02:34<01:30, 373.16it/s]

2.0000395320653914


 67%|██████▋   | 67062/100000 [02:36<01:26, 382.79it/s]

2.0044624308347703


 68%|██████▊   | 68042/100000 [02:40<01:45, 303.75it/s]

2.0011655097007752


 69%|██████▉   | 69068/100000 [02:43<01:19, 390.53it/s]

1.9983348370790481


 70%|███████   | 70065/100000 [02:45<01:21, 369.24it/s]

1.998702906370163


 71%|███████   | 71083/100000 [02:48<01:17, 374.06it/s]

2.0111594500541687


 72%|███████▏  | 72044/100000 [02:51<01:21, 343.11it/s]

1.996020436167717


 73%|███████▎  | 73058/100000 [02:54<01:28, 303.68it/s]

2.0044818972349168


 74%|███████▍  | 74040/100000 [02:58<01:29, 288.47it/s]

1.9975486896038055


 75%|███████▌  | 75038/100000 [03:02<01:31, 271.34it/s]

2.0015144609212876


 76%|███████▌  | 76073/100000 [03:05<01:03, 379.37it/s]

1.9947578362226486


 77%|███████▋  | 77057/100000 [03:07<01:04, 353.46it/s]

2.0052056097984314


 78%|███████▊  | 78091/100000 [03:10<00:52, 418.71it/s]

1.9945884742736817


 79%|███████▉  | 79065/100000 [03:13<00:58, 360.51it/s]

2.006944186925888


 80%|████████  | 80066/100000 [03:16<00:58, 339.09it/s]

2.001055313706398


 81%|████████  | 81065/100000 [03:19<00:54, 348.62it/s]

1.9946700571775435


 82%|████████▏ | 82070/100000 [03:22<00:49, 361.09it/s]

1.9966384836435318


 83%|████████▎ | 83032/100000 [03:25<00:44, 377.63it/s]

1.995735340476036


 84%|████████▍ | 84050/100000 [03:28<00:43, 369.22it/s]

1.998007292509079


 85%|████████▌ | 85054/100000 [03:31<00:39, 375.49it/s]

1.9950760530233382


 86%|████████▌ | 86047/100000 [03:34<00:41, 332.52it/s]

1.9995526415109635


 87%|████████▋ | 87043/100000 [03:37<00:38, 332.42it/s]

1.9925761766433716


 88%|████████▊ | 88032/100000 [03:40<00:34, 344.43it/s]

1.988433430314064


 89%|████████▉ | 89038/100000 [03:43<00:37, 294.41it/s]

2.0026926609277726


 90%|█████████ | 90051/100000 [03:46<00:25, 384.71it/s]

1.990632580637932


 91%|█████████ | 91070/100000 [03:49<00:23, 379.61it/s]

1.9958838583230973


 92%|█████████▏| 92068/100000 [03:52<00:21, 367.49it/s]

1.9970369639396668


 93%|█████████▎| 93066/100000 [03:55<00:18, 382.60it/s]

1.9975830596685409


 94%|█████████▍| 94058/100000 [03:58<00:18, 326.56it/s]

1.998139777779579


 95%|█████████▌| 95072/100000 [04:00<00:10, 454.27it/s]

1.9913215004205704


 96%|█████████▌| 96054/100000 [04:03<00:09, 413.88it/s]

2.0079609355926515


 97%|█████████▋| 97047/100000 [04:06<00:08, 332.77it/s]

2.004225052833557


 98%|█████████▊| 98054/100000 [04:09<00:05, 371.28it/s]

2.000637235164642


 99%|█████████▉| 99041/100000 [04:12<00:02, 342.00it/s]

1.9963573273420334


100%|██████████| 100000/100000 [04:14<00:00, 392.49it/s]


In [198]:
inputs_generate = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(inputs_generate, max_new_tokens=100)[0].tolist()))



“The herisses to said, bing tridane say as calmeed. Gord,s to up. HEim besess her hersesserselped o


In [200]:
# Let's try with an RNN

class GRULM(nn.Module):
    def __init__(self, vocab_size, hidden_size=32):
        super(GRULM, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        # Classifier to predict the next token
        self.linear = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inputs):
        # Embed the inputs
        embeddings = self.embedding_table(inputs)
        # Pass the embeddings through the RNN
        outputs, _ = self.rnn(embeddings)
        # Pass the outputs through the classifier
        logits = self.linear(outputs)
        
        return logits

    def generate(self, inputs, max_new_tokens=10):
        # Generate next tokens given the inputs (B, L, C)
        for _ in range(max_new_tokens):
            # Embed the inputs
            embeddings = self.embedding_table(inputs)
            # Feed the RNN with the embeddings 
            _, hidden = self.rnn(embeddings)
            # Remove the first dimension (get only the last hidden state for each element in the batch)
            hidden = hidden.squeeze(0) 
            # Project to the output classes
            logits = self.linear(hidden)
            # Convert to probabilities
            probs = F.softmax(logits, dim=1)
            # Get the next token
            next_token = torch.multinomial(probs, num_samples=1)
            # Concatenate to input
            inputs = torch.cat((inputs, next_token), dim=1)
            
        return inputs

In [201]:
model = GRULM(len(chars))

inputs_generate = torch.zeros((1,1), dtype=torch.long)
print(decode(model.generate(inputs_generate, max_new_tokens=100)[0].tolist()))


51aFv;[i}vs/‘jl9YTo7
2Qm5AXr]slimLSIéeJ!J.VK)SV
ha“vabUiCsT14Pp4wéP.ufnyNxhKvwa;Ix2VGlpYokz4!isz .f]


In [203]:
# Let's train this model
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Hyperparameters
learning_rate = 0.01
momentum = 0.9
batch_size = 32
context_length = 8 # Here it is important because we are using an RNN
hidden_size = 32
num_iterations = 100000
device = "cpu"

# Get the model 
model = GRULM(len(chars))
# Send the model to the device
model.to(device)

# Get the optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Get loss function
loss_fn = CrossEntropyLoss()

# Start training loop
loss_train = 0
for i in tqdm(range(num_iterations)):
    # Get a batch
    inputs, targets = get_batch(train_data, batch_size=8, context_length=8)
    # Send the inputs and targets to device
    inputs = inputs.to(device)
    targets = targets.to(device)
    # Get the predictions
    predictions = model(inputs)
    # We have to reshape the predictions and the targets to use cross entropy 
    B, L, C = predictions.shape
    predictions = predictions.view(B*L, C)
    targets = targets.view(B*L)
    # Compute the loss
    loss = loss_fn(predictions, targets)
    # Accumulate the loss
    loss_train += loss.item()
    # Zero the gradients
    optimizer.zero_grad()
    # Compute the gradients
    loss.backward()
    # Update the parameters
    optimizer.step()

    if (i%1000)==0 and i!=0:
        print(loss_train/1000)
        loss_train = 0

  1%|          | 1042/100000 [00:03<06:42, 245.83it/s]

2.861344276666641


  2%|▏         | 2035/100000 [00:07<06:49, 239.42it/s]

2.4081958810091018


  3%|▎         | 3054/100000 [00:11<05:08, 314.17it/s]

2.315821086883545


  4%|▍         | 4050/100000 [00:14<05:08, 311.46it/s]

2.230983748078346


  5%|▌         | 5035/100000 [00:18<04:58, 318.55it/s]

2.192366425514221


  6%|▌         | 6058/100000 [00:21<04:53, 320.42it/s]

2.145113414287567


  7%|▋         | 7043/100000 [00:24<04:58, 311.31it/s]

2.129057896018028


  8%|▊         | 8046/100000 [00:28<05:51, 261.91it/s]

2.111267743110657


  9%|▉         | 9042/100000 [00:31<05:25, 279.67it/s]

2.1014449887275695


 10%|█         | 10050/100000 [00:35<05:01, 298.46it/s]

2.097101477622986


 11%|█         | 11043/100000 [00:39<05:25, 273.49it/s]

2.0796413689851763


 12%|█▏        | 12050/100000 [00:43<05:15, 278.38it/s]

2.0530474630594253


 13%|█▎        | 13049/100000 [00:47<05:08, 281.44it/s]

2.0662549761533735


 14%|█▍        | 14044/100000 [00:51<06:25, 222.94it/s]

2.0490580998659134


 15%|█▌        | 15036/100000 [00:55<04:43, 299.63it/s]

2.0473166716098787


 16%|█▌        | 16046/100000 [00:58<05:47, 241.81it/s]

2.034849860906601


 17%|█▋        | 17042/100000 [01:03<05:08, 269.26it/s]

2.030442239522934


 18%|█▊        | 18033/100000 [01:06<05:17, 257.87it/s]

2.0210269227027893


 19%|█▉        | 19022/100000 [01:10<05:33, 242.60it/s]

2.017351523399353


 20%|██        | 20049/100000 [01:15<05:22, 247.85it/s]

2.0095540920495987


 21%|██        | 21051/100000 [01:19<05:16, 249.36it/s]

2.008743499159813


 22%|██▏       | 22029/100000 [01:23<06:00, 215.99it/s]

2.000720155477524


 23%|██▎       | 23035/100000 [01:28<06:09, 208.04it/s]

1.9998668916225433


 24%|██▍       | 24026/100000 [01:33<07:13, 175.06it/s]

1.9922770217657089


 25%|██▌       | 25038/100000 [01:38<05:30, 226.83it/s]

1.9890901819467544


 26%|██▌       | 26041/100000 [01:42<04:42, 261.72it/s]

1.988663892865181


 27%|██▋       | 27046/100000 [01:46<04:31, 269.02it/s]

1.9913643641471863


 28%|██▊       | 28038/100000 [01:50<04:16, 280.53it/s]

1.9876826673746109


 29%|██▉       | 29043/100000 [01:54<03:54, 302.45it/s]

1.9754634321928024


 30%|███       | 30049/100000 [01:58<04:34, 254.84it/s]

1.9822206184864044


 31%|███       | 31040/100000 [02:01<04:03, 283.48it/s]

1.9736928514242171


 32%|███▏      | 32036/100000 [02:05<03:46, 300.65it/s]

1.9700905116796494


 33%|███▎      | 33046/100000 [02:08<03:40, 304.03it/s]

1.9605099716186523


 34%|███▍      | 34061/100000 [02:12<03:37, 303.39it/s]

1.9738566361665726


 35%|███▌      | 35057/100000 [02:15<03:33, 304.72it/s]

1.9773660751581192


 36%|███▌      | 36050/100000 [02:18<03:24, 312.42it/s]

1.9717947361469268


 37%|███▋      | 37054/100000 [02:22<03:23, 309.82it/s]

1.9673032451868058


 38%|███▊      | 38046/100000 [02:25<03:31, 292.83it/s]

1.9581158964633942


 39%|███▉      | 39041/100000 [02:28<03:28, 292.84it/s]

1.9603688703775406


 40%|████      | 40037/100000 [02:32<03:18, 302.45it/s]

1.962963625073433


 41%|████      | 41030/100000 [02:35<03:13, 304.27it/s]

1.9684980162382126


 42%|████▏     | 42055/100000 [02:39<03:09, 306.15it/s]

1.957919928789139


 43%|████▎     | 43030/100000 [02:42<03:03, 309.92it/s]

1.9403708752393722


 44%|████▍     | 44048/100000 [02:45<03:03, 304.40it/s]

1.955503058552742


 45%|████▌     | 45036/100000 [02:49<03:06, 294.23it/s]

1.9577618519067763


 46%|████▌     | 46030/100000 [02:52<02:54, 308.59it/s]

1.9462543556690215


 47%|████▋     | 47038/100000 [02:55<02:52, 306.16it/s]

1.9485466678142547


 48%|████▊     | 48056/100000 [02:59<02:56, 293.80it/s]

1.9463516894578934


 49%|████▉     | 49052/100000 [03:02<02:46, 306.24it/s]

1.9518857388496398


 50%|█████     | 50048/100000 [03:06<02:44, 303.50it/s]

1.9450053242444991


 51%|█████     | 51026/100000 [03:09<02:43, 299.39it/s]

1.946698520898819


 52%|█████▏    | 52053/100000 [03:12<02:37, 304.08it/s]

1.9417923276424407


 53%|█████▎    | 53040/100000 [03:16<02:42, 289.44it/s]

1.93768066072464


 54%|█████▍    | 54039/100000 [03:19<02:43, 280.31it/s]

1.9524926654100418


 55%|█████▌    | 55069/100000 [03:23<02:28, 302.80it/s]

1.9410138763189315


 56%|█████▌    | 56037/100000 [03:26<02:27, 298.97it/s]

1.9410820087194443


 57%|█████▋    | 57012/100000 [03:29<02:24, 296.53it/s]

1.942891636133194


 58%|█████▊    | 58057/100000 [03:33<02:21, 297.10it/s]

1.9301882266998291


 59%|█████▉    | 59045/100000 [03:36<02:14, 304.43it/s]

1.9335816646814346


 60%|██████    | 60058/100000 [03:40<02:11, 303.39it/s]

1.9391751631498337


 61%|██████    | 61040/100000 [03:43<02:11, 296.14it/s]

1.925966038107872


 62%|██████▏   | 62054/100000 [03:47<02:07, 296.60it/s]

1.9287257472276687


 63%|██████▎   | 63034/100000 [03:50<02:02, 302.11it/s]

1.925874592423439


 64%|██████▍   | 64058/100000 [03:53<01:58, 303.02it/s]

1.935775196790695


 65%|██████▌   | 65062/100000 [03:57<01:53, 306.92it/s]

1.9268541572093965


 66%|██████▌   | 66049/100000 [04:00<01:51, 303.32it/s]

1.9360517284870147


 67%|██████▋   | 67050/100000 [04:03<02:00, 273.56it/s]

1.9289235553741455


 68%|██████▊   | 68030/100000 [04:07<01:44, 307.38it/s]

1.9290806021690368


 69%|██████▉   | 69054/100000 [04:10<01:46, 290.54it/s]

1.9261720284223556


 70%|███████   | 70044/100000 [04:13<01:38, 303.48it/s]

1.9329953635931014


 71%|███████   | 71037/100000 [04:17<01:37, 298.43it/s]

1.9289099599123


 72%|███████▏  | 72046/100000 [04:20<01:29, 312.21it/s]

1.9214985531568527


 73%|███████▎  | 73057/100000 [04:24<01:30, 298.01it/s]

1.9214449883699418


 74%|███████▍  | 74034/100000 [04:27<01:26, 301.31it/s]

1.9320813492536544


 75%|███████▌  | 75037/100000 [04:30<01:24, 294.29it/s]

1.9244192523956298


 76%|███████▌  | 76032/100000 [04:34<01:22, 291.07it/s]

1.9123195533752442


 77%|███████▋  | 77049/100000 [04:37<01:20, 285.27it/s]

1.9264770700931548


 78%|███████▊  | 78057/100000 [04:40<01:10, 312.49it/s]

1.9189240715503693


 79%|███████▉  | 79033/100000 [04:44<01:08, 304.53it/s]

1.9187787928581237


 80%|████████  | 80053/100000 [04:47<01:08, 292.73it/s]

1.9081812020540236


 81%|████████  | 81041/100000 [04:51<01:02, 303.20it/s]

1.9163211444616317


 82%|████████▏ | 82043/100000 [04:54<00:59, 303.72it/s]

1.914406391263008


 83%|████████▎ | 83042/100000 [04:57<01:02, 271.52it/s]

1.91994831097126


 84%|████████▍ | 84056/100000 [05:01<00:52, 302.67it/s]

1.9168586745262146


 85%|████████▌ | 85056/100000 [05:04<00:49, 304.11it/s]

1.909198379278183


 86%|████████▌ | 86065/100000 [05:08<00:45, 309.36it/s]

1.9184999942779541


 87%|████████▋ | 87057/100000 [05:11<00:43, 300.35it/s]

1.9186744527816773


 88%|████████▊ | 88037/100000 [05:15<00:40, 295.10it/s]

1.9174488790035247


 89%|████████▉ | 89035/100000 [05:18<00:39, 277.30it/s]

1.9211649895906449


 90%|█████████ | 90047/100000 [05:22<00:33, 298.29it/s]

1.9195754897594453


 91%|█████████ | 91032/100000 [05:25<00:30, 290.18it/s]

1.9151574475765227


 92%|█████████▏| 92038/100000 [05:29<00:29, 272.69it/s]

1.9047248020172118


 93%|█████████▎| 93041/100000 [05:32<00:24, 289.88it/s]

1.9021962755918502


 94%|█████████▍| 94044/100000 [05:35<00:19, 310.82it/s]

1.9029730105400084


 95%|█████████▌| 95042/100000 [05:39<00:17, 278.71it/s]

1.9099848425388337


 96%|█████████▌| 96042/100000 [05:43<00:14, 268.51it/s]

1.9044176126718522


 97%|█████████▋| 97026/100000 [05:47<00:10, 277.91it/s]

1.9181441682577134


 98%|█████████▊| 98062/100000 [05:50<00:06, 317.80it/s]

1.9063461928367615


 99%|█████████▉| 99039/100000 [05:54<00:03, 292.54it/s]

1.9130243611335755


100%|██████████| 100000/100000 [05:57<00:00, 279.70it/s]


In [207]:
inputs_generate = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(inputs_generate, max_new_tokens=100)[0].tolist()))



“They me of of the trooofre togstered her. A chil?

The’ld the was of they kie her knushing, godge-


In [209]:
# Do the study with different parameters and plot all the results 
# RNN 
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Use pickle to save the loss array
import pickle

# Hyperparameters
learning_rate = 0.01
momentum = 0.9
batch_size = 32
context_length = [128,512] # Here it is important because we are using an RNN
hidden_size = [32,64,128,256]
num_iterations = 100000
device = "cuda:0"

# Start training loop
for c in context_length:
    for h in hidden_size:
        print("Training with context length: ", c, " and hidden size: ", h)
        # Get the model 
        model = GRULM(len(chars), hidden_size=h)
        # Send the model to the device
        model.to(device)
        # Get the optimizer
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
        # Get loss function
        loss_fn = CrossEntropyLoss()
        
        loss_train = 0
        loss_array = []
        for i in tqdm(range(num_iterations)):
            # Get a batch
            inputs, targets = get_batch(train_data, batch_size=batch_size, context_length=c)
            # Send the inputs and targets to device
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Get the predictions
            predictions = model(inputs)
            # We have to reshape the predictions and the targets to use cross entropy 
            B, L, C = predictions.shape
            predictions = predictions.view(B*L, C)
            targets = targets.view(B*L)
            # Compute the loss
            loss = loss_fn(predictions, targets)
            # Accumulate the loss
            loss_train += loss.item()
            # Zero the gradients
            optimizer.zero_grad()
            # Compute the gradients
            loss.backward()
            # Update the parameters
            optimizer.step()

            if (i%1000)==0 and i!=0:
                loss_array.append(loss_train/1000)
                loss_train = 0

        # Save the loss list 
        with open("loss_rnn_context"+str(c)+"_hiddensize"+str(h)+".pkl", "wb") as file:
            pickle.dump(loss_array, file)

Training with context length:  128  and hidden size:  32


  3%|▎         | 3022/100000 [00:28<15:09, 106.62it/s]


KeyboardInterrupt: 