Experiment1: what if instead of V = Wk(x), we use V = x

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-11-23 05:32:24--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-11-23 05:32:24 (26.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
!mkdir checkpoints

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import os

In [None]:
torch.manual_seed(170310087)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'{device = }')

device = 'cuda'


# PREPARING THE DATA

In [None]:
with open('input.txt', 'r') as f:
    text = f.read()
vocab = sorted(list(set(text)))
vocab_size = len(vocab)
print('vocab:', ''.join(vocab))
print(f'{len(vocab)=}')

vocab: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
len(vocab)=65


In [None]:
int_map = {ch:i for i,ch in enumerate(vocab)}
char_map = {i:ch for i,ch in enumerate(vocab)}

def encode(s):
    return [int_map[c] for c in s]

def decode(ii):
    return ''.join([char_map[i] for i in ii])


In [None]:
class TinyShakespeare(torch.utils.data.Dataset):
    def __init__(self, data, cont_width):
        self.data = data
        self.cont_width = cont_width

    def __len__(self):
        return len(self.data) - self.cont_width

    def __getitem__(self, idx):
        x = self.data[idx: idx+self.cont_width].to(device)
        y = self.data[idx+1: idx+self.cont_width+1].to(device)
        return (x, y)

In [None]:
data = torch.tensor(encode(text))
cont_width = 32
train_split = 0.9


# MODEL

In [None]:
class SelfAttentionHead(nn.Module):
    def __init__(self, embd_size, head_size, cont_width):
        super().__init__()
        self.embd_size = embd_size # C
        self.head_size = head_size # H
        self.Wk = nn.Linear(embd_size, head_size, bias=True) # (C,H)
        self.Wq = nn.Linear(embd_size, head_size, bias=True) # (C,H)
        self.Wv = nn.Linear(embd_size, head_size, bias=True) # (C,H)
        self.register_buffer('mask', torch.tril(torch.ones(cont_width, cont_width))) # (T,T)

    def forward(self, x):
        # x is (B,T,C)
        _, T, _ = x.shape # we are extracting the shape because T may differ during generation
        Q = self.Wq(x) # (B,T,H)
        K = self.Wk(x) # (B,T,H)
        V = self.Wv(x) # (B,T,H)

        att = Q @ K.transpose(-2,-1) / self.head_size # (B,T,T)
        att = att.masked_fill(self.mask[:T, :T]==0, float('-inf')) # for decoder head only
        att = att.softmax(dim=2) # softmax for each query over all keys

        out = att @ V # (B,T,H)
        return out


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embd_size, cont_width, n_heads):
        super().__init__()
        head_size = embd_size // n_heads
        self.heads = nn.ModuleList([SelfAttentionHead(embd_size, head_size, cont_width) for _ in range(n_heads)])
        self.linear = nn.Linear(embd_size, embd_size)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=2)  # [4, 8, 0]
        out = self.linear(out) # (B,T,C)
        return out


class Block(nn.Module):
    def __init__(self, embd_size, cont_width, n_heads):
        super().__init__()
        self.lay_norm1 = nn.LayerNorm(embd_size)
        self.multihead = MultiHeadSelfAttention(embd_size, cont_width, n_heads)
        self.dropout1 = nn.Dropout(0.1)
        self.lay_norm2 = nn.LayerNorm(embd_size)
        self.feedforward = nn.Sequential(nn.Linear(embd_size, 4*embd_size),
                                         nn.ReLU(),
                                         nn.Linear(4*embd_size, embd_size))
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, x):
        x = x + self.dropout1(self.multihead(self.lay_norm1(x)))
        x = x + self.dropout2(self.feedforward(self.lay_norm2(x)))
        return x



class DecoderModel(nn.Module):

    def __init__(self, vocab_size, cont_width, embd_size, n_heads, n_blocks):
        super().__init__()
        self.cont_width = cont_width
        self.emb_table = nn.Embedding(vocab_size, embd_size)
        self.pos_table = nn.Embedding(cont_width, embd_size)
        self.dropout = nn.Dropout(0.1)
        self.blocks = nn.Sequential(*[Block(embd_size, cont_width, n_heads) for _ in range(n_blocks)])
        self.lay_norm = nn.LayerNorm(embd_size)
        self.lin = nn.Linear(embd_size, vocab_size)

    def forward(self, x):
        B, T = x.shape
        x = self.emb_table(x) + self.pos_table(torch.arange(T, device=device)) # (B,T,C) + (T,C)
        # x = self.dropout(x) # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.lay_norm(x) # (B,T,C)
        x = self.lin(x) # (B,T,vocab)
        return x

    def generate(self, x, max_tokens):
        for _ in range(max_tokens):
            inp = x[:, -cont_width:] # (B,T)
            out = self(inp) # (B,T,vocab)
            probs = out[:,-1,:].softmax(dim=-1) # (B,vocab)
            pred = torch.multinomial(probs, num_samples=1) # (B, )
            x = torch.cat((x, pred), dim=1)
        return x



In [None]:
train_set = TinyShakespeare(data[:int(len(data) * train_split)], cont_width)
val_set = TinyShakespeare(data[int(len(data) * train_split):], cont_width)
print(f'{len(train_set)=}')
print(f'{len(val_set)=}')

len(train_set)=1003822
len(val_set)=111508


In [None]:
embd_size = 128
n_heads = 8
n_blocks = 6
eval_batches = 5
batch_size = 8192

@torch.no_grad()
def evaluate(model, eval_batches):
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=2048, shuffle=True)
    valloader = torch.utils.data.DataLoader(val_set, batch_size=2048, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    loss = [0, 0]
    for split, dataloader in enumerate([trainloader, valloader]):
        for i, (x, y) in enumerate(trainloader):
            if i == eval_batches:
                break
            logits = model(x)
            loss[split] += criterion(logits.transpose(1, 2), y).item()

    return loss[0]/eval_batches, loss[1]/eval_batches

trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
# valloader = torch.utils.data.DataLoader(val_set, batch_size=4, shuffle=False)
model = DecoderModel(vocab_size, cont_width, embd_size, n_heads, n_blocks).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
model.train()
start = time.time()

for epoch in range(5):
    for i, (x, y) in enumerate(trainloader):
        logits = model(x)
        loss = criterion(logits.transpose(1, 2), y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if (i+1) % 5 == 0:
            train_loss, val_loss = evaluate(model, eval_batches)
            print(f'\n\t\t{epoch=} {i=}, {train_loss = :.4f}, {val_loss = :.4f} time elapsed{time.time() - start: .2f}s')
            PATH = f'checkpoints/model_{epoch:02d}_{i:03d}.pt'
            torch.save({'batch': i,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        }, PATH)

  4%|▍         | 5/123 [00:14<05:52,  2.99s/it]


		epoch=0 i=4, train_loss = 3.2865, val_loss = 3.2961 time elapsed 14.73s


  8%|▊         | 10/123 [00:26<05:09,  2.74s/it]


		epoch=0 i=9, train_loss = 3.1351, val_loss = 3.1281 time elapsed 26.30s


 12%|█▏        | 15/123 [00:38<04:54,  2.73s/it]


		epoch=0 i=14, train_loss = 2.8861, val_loss = 2.8841 time elapsed 37.92s


 16%|█▋        | 20/123 [00:49<04:43,  2.76s/it]


		epoch=0 i=19, train_loss = 2.7253, val_loss = 2.7239 time elapsed 49.63s


 20%|██        | 25/123 [01:01<04:34,  2.81s/it]


		epoch=0 i=24, train_loss = 2.6170, val_loss = 2.6160 time elapsed 61.48s


 24%|██▍       | 30/123 [01:13<04:21,  2.81s/it]


		epoch=0 i=29, train_loss = 2.5482, val_loss = 2.5538 time elapsed 73.30s


 28%|██▊       | 35/123 [01:25<04:01,  2.74s/it]


		epoch=0 i=34, train_loss = 2.5019, val_loss = 2.5024 time elapsed 85.02s


 33%|███▎      | 40/123 [01:37<03:52,  2.80s/it]


		epoch=0 i=39, train_loss = 2.4534, val_loss = 2.4537 time elapsed 97.13s


 37%|███▋      | 45/123 [01:50<03:47,  2.91s/it]


		epoch=0 i=44, train_loss = 2.4084, val_loss = 2.4110 time elapsed 109.98s


 41%|████      | 50/123 [02:01<03:23,  2.79s/it]


		epoch=0 i=49, train_loss = 2.3585, val_loss = 2.3640 time elapsed 121.87s


 45%|████▍     | 55/123 [02:14<03:13,  2.85s/it]


		epoch=0 i=54, train_loss = 2.3205, val_loss = 2.3199 time elapsed 133.91s


 49%|████▉     | 60/123 [02:26<03:10,  3.03s/it]


		epoch=0 i=59, train_loss = 2.2803, val_loss = 2.2816 time elapsed 146.62s


 53%|█████▎    | 65/123 [02:39<02:50,  2.95s/it]


		epoch=0 i=64, train_loss = 2.2412, val_loss = 2.2404 time elapsed 158.92s


 57%|█████▋    | 70/123 [02:52<02:46,  3.15s/it]


		epoch=0 i=69, train_loss = 2.2093, val_loss = 2.2080 time elapsed 171.97s


 61%|██████    | 75/123 [03:04<02:20,  2.92s/it]


		epoch=0 i=74, train_loss = 2.1686, val_loss = 2.1715 time elapsed 184.07s


 65%|██████▌   | 80/123 [03:16<02:02,  2.85s/it]


		epoch=0 i=79, train_loss = 2.1404, val_loss = 2.1390 time elapsed 196.21s


 69%|██████▉   | 85/123 [03:28<01:46,  2.80s/it]


		epoch=0 i=84, train_loss = 2.1053, val_loss = 2.1057 time elapsed 208.36s


 73%|███████▎  | 90/123 [03:40<01:32,  2.79s/it]


		epoch=0 i=89, train_loss = 2.0723, val_loss = 2.0722 time elapsed 220.37s


 77%|███████▋  | 95/123 [03:52<01:18,  2.82s/it]


		epoch=0 i=94, train_loss = 2.1084, val_loss = 2.1100 time elapsed 232.53s


 81%|████████▏ | 100/123 [04:04<01:03,  2.77s/it]


		epoch=0 i=99, train_loss = 2.0443, val_loss = 2.0502 time elapsed 244.46s


 85%|████████▌ | 105/123 [04:16<00:51,  2.85s/it]


		epoch=0 i=104, train_loss = 2.0126, val_loss = 2.0171 time elapsed 256.74s


 89%|████████▉ | 110/123 [04:28<00:37,  2.85s/it]


		epoch=0 i=109, train_loss = 1.9830, val_loss = 1.9841 time elapsed 268.86s


 93%|█████████▎| 115/123 [04:41<00:23,  2.91s/it]


		epoch=0 i=114, train_loss = 1.9564, val_loss = 1.9507 time elapsed 281.29s


 98%|█████████▊| 120/123 [04:53<00:08,  2.94s/it]


		epoch=0 i=119, train_loss = 1.9260, val_loss = 1.9244 time elapsed 293.64s


100%|██████████| 123/123 [04:58<00:00,  2.43s/it]
  4%|▍         | 5/123 [00:12<05:35,  2.84s/it]


		epoch=1 i=4, train_loss = 1.8821, val_loss = 1.8841 time elapsed 310.82s


  8%|▊         | 10/123 [00:24<05:25,  2.88s/it]


		epoch=1 i=9, train_loss = 1.9263, val_loss = 1.9234 time elapsed 323.04s


 12%|█▏        | 15/123 [00:36<05:15,  2.92s/it]


		epoch=1 i=14, train_loss = 1.8767, val_loss = 1.8722 time elapsed 335.39s


 16%|█▋        | 20/123 [00:49<05:01,  2.93s/it]


		epoch=1 i=19, train_loss = 1.8388, val_loss = 1.8407 time elapsed 347.69s


 20%|██        | 25/123 [01:01<04:39,  2.85s/it]


		epoch=1 i=24, train_loss = 1.8187, val_loss = 1.8121 time elapsed 359.83s


 24%|██▎       | 29/123 [01:08<03:22,  2.15s/it]


		epoch=1 i=29, train_loss = 1.7956, val_loss = 1.7940 time elapsed 371.92s


 28%|██▊       | 35/123 [01:25<04:08,  2.82s/it]


		epoch=1 i=34, train_loss = 1.7783, val_loss = 1.7764 time elapsed 384.23s


 33%|███▎      | 40/123 [01:37<03:53,  2.81s/it]


		epoch=1 i=39, train_loss = 1.7552, val_loss = 1.7576 time elapsed 396.38s


 37%|███▋      | 45/123 [01:49<03:41,  2.84s/it]


		epoch=1 i=44, train_loss = 1.7377, val_loss = 1.7375 time elapsed 408.59s


 41%|████      | 50/123 [02:02<03:28,  2.85s/it]


		epoch=1 i=49, train_loss = 1.7175, val_loss = 1.7229 time elapsed 420.76s


 45%|████▍     | 55/123 [02:14<03:14,  2.86s/it]


		epoch=1 i=54, train_loss = 1.7108, val_loss = 1.7131 time elapsed 432.91s


 49%|████▉     | 60/123 [02:26<03:02,  2.90s/it]


		epoch=1 i=59, train_loss = 1.7102, val_loss = 1.7079 time elapsed 445.20s


 53%|█████▎    | 65/123 [02:38<02:49,  2.93s/it]


		epoch=1 i=64, train_loss = 1.6762, val_loss = 1.6839 time elapsed 457.52s


 57%|█████▋    | 70/123 [02:51<02:35,  2.93s/it]


		epoch=1 i=69, train_loss = 1.6686, val_loss = 1.6705 time elapsed 469.89s


 61%|██████    | 75/123 [03:03<02:17,  2.86s/it]


		epoch=1 i=74, train_loss = 1.6520, val_loss = 1.6538 time elapsed 482.13s


 65%|██████▌   | 80/123 [03:15<02:01,  2.83s/it]


		epoch=1 i=79, train_loss = 1.6329, val_loss = 1.6363 time elapsed 494.42s


 69%|██████▉   | 85/123 [03:27<01:46,  2.81s/it]


		epoch=1 i=84, train_loss = 1.6485, val_loss = 1.6473 time elapsed 506.53s


 73%|███████▎  | 90/123 [03:39<01:32,  2.81s/it]


		epoch=1 i=89, train_loss = 1.6229, val_loss = 1.6245 time elapsed 518.64s


 77%|███████▋  | 95/123 [03:52<01:19,  2.82s/it]


		epoch=1 i=94, train_loss = 1.6085, val_loss = 1.6103 time elapsed 530.80s


 81%|████████▏ | 100/123 [04:04<01:06,  2.88s/it]


		epoch=1 i=99, train_loss = 1.5935, val_loss = 1.6024 time elapsed 543.14s


 85%|████████▌ | 105/123 [04:16<00:52,  2.90s/it]


		epoch=1 i=104, train_loss = 1.5851, val_loss = 1.5888 time elapsed 555.44s


 89%|████████▉ | 110/123 [04:29<00:37,  2.92s/it]


		epoch=1 i=109, train_loss = 1.5754, val_loss = 1.5781 time elapsed 567.79s


 93%|█████████▎| 115/123 [04:41<00:23,  2.95s/it]


		epoch=1 i=114, train_loss = 1.5765, val_loss = 1.5846 time elapsed 580.17s


 98%|█████████▊| 120/123 [04:53<00:08,  2.92s/it]


		epoch=1 i=119, train_loss = 1.5671, val_loss = 1.5642 time elapsed 592.57s


100%|██████████| 123/123 [04:58<00:00,  2.43s/it]
  4%|▍         | 5/123 [00:12<05:40,  2.89s/it]


		epoch=2 i=4, train_loss = 1.5488, val_loss = 1.5487 time elapsed 609.83s


  8%|▊         | 10/123 [00:24<05:28,  2.91s/it]


		epoch=2 i=9, train_loss = 1.5375, val_loss = 1.5402 time elapsed 622.13s


 12%|█▏        | 15/123 [00:36<05:20,  2.96s/it]


		epoch=2 i=14, train_loss = 1.5291, val_loss = 1.5300 time elapsed 634.54s


 16%|█▋        | 20/123 [00:49<04:59,  2.91s/it]


		epoch=2 i=19, train_loss = 1.5214, val_loss = 1.5230 time elapsed 646.92s


 20%|██        | 25/123 [01:01<04:38,  2.84s/it]


		epoch=2 i=24, train_loss = 1.5380, val_loss = 1.5441 time elapsed 659.11s


 24%|██▍       | 30/123 [01:13<04:21,  2.81s/it]


		epoch=2 i=29, train_loss = 1.5136, val_loss = 1.5138 time elapsed 671.33s


 28%|██▊       | 35/123 [01:25<04:08,  2.82s/it]


		epoch=2 i=34, train_loss = 1.5120, val_loss = 1.5045 time elapsed 683.53s


 33%|███▎      | 40/123 [01:38<03:55,  2.83s/it]


		epoch=2 i=39, train_loss = 1.4957, val_loss = 1.4892 time elapsed 695.67s


 37%|███▋      | 45/123 [01:50<03:41,  2.84s/it]


		epoch=2 i=44, train_loss = 1.4877, val_loss = 1.4820 time elapsed 707.87s


 41%|████      | 50/123 [02:02<03:33,  2.92s/it]


		epoch=2 i=49, train_loss = 1.4826, val_loss = 1.4847 time elapsed 720.28s


 45%|████▍     | 55/123 [02:15<03:22,  2.97s/it]


		epoch=2 i=54, train_loss = 1.4744, val_loss = 1.4802 time elapsed 732.85s


 49%|████▉     | 60/123 [02:28<03:12,  3.06s/it]


		epoch=2 i=59, train_loss = 1.4776, val_loss = 1.4723 time elapsed 745.65s


 53%|█████▎    | 65/123 [02:40<02:53,  2.99s/it]


		epoch=2 i=64, train_loss = 1.4693, val_loss = 1.4746 time elapsed 758.25s


 57%|█████▋    | 70/123 [02:52<02:32,  2.88s/it]


		epoch=2 i=69, train_loss = 1.4710, val_loss = 1.4605 time elapsed 770.52s


 61%|██████    | 75/123 [03:05<02:16,  2.85s/it]


		epoch=2 i=74, train_loss = 1.4559, val_loss = 1.4546 time elapsed 782.89s


 65%|██████▌   | 80/123 [03:17<02:02,  2.85s/it]


		epoch=2 i=79, train_loss = 1.4540, val_loss = 1.4544 time elapsed 795.24s


 69%|██████▉   | 85/123 [03:29<01:47,  2.83s/it]


		epoch=2 i=84, train_loss = 1.4477, val_loss = 1.4481 time elapsed 807.43s


 73%|███████▎  | 90/123 [03:42<01:33,  2.83s/it]


		epoch=2 i=89, train_loss = 1.4437, val_loss = 1.4450 time elapsed 819.61s


 77%|███████▋  | 95/123 [03:54<01:19,  2.85s/it]


		epoch=2 i=94, train_loss = 1.4465, val_loss = 1.4372 time elapsed 831.82s


 81%|████████▏ | 100/123 [04:06<01:05,  2.84s/it]


		epoch=2 i=99, train_loss = 1.4375, val_loss = 1.4365 time elapsed 843.99s


 85%|████████▌ | 105/123 [04:18<00:51,  2.88s/it]


		epoch=2 i=104, train_loss = 1.4483, val_loss = 1.4543 time elapsed 856.26s


 89%|████████▉ | 110/123 [04:31<00:37,  2.90s/it]


		epoch=2 i=109, train_loss = 1.4366, val_loss = 1.4381 time elapsed 868.65s


 93%|█████████▎| 115/123 [04:43<00:23,  2.88s/it]


		epoch=2 i=114, train_loss = 1.4304, val_loss = 1.4302 time elapsed 880.96s


 98%|█████████▊| 120/123 [04:55<00:08,  2.90s/it]


		epoch=2 i=119, train_loss = 1.4193, val_loss = 1.4189 time elapsed 893.24s


100%|██████████| 123/123 [05:00<00:00,  2.44s/it]
  4%|▍         | 5/123 [00:12<05:38,  2.87s/it]


		epoch=3 i=4, train_loss = 1.4119, val_loss = 1.4142 time elapsed 910.48s


  8%|▊         | 10/123 [00:24<05:23,  2.86s/it]


		epoch=3 i=9, train_loss = 1.4115, val_loss = 1.4094 time elapsed 922.74s


 12%|█▏        | 15/123 [00:36<05:04,  2.82s/it]


		epoch=3 i=14, train_loss = 1.4042, val_loss = 1.4057 time elapsed 934.92s


 16%|█▋        | 20/123 [00:49<04:53,  2.85s/it]


		epoch=3 i=19, train_loss = 1.4058, val_loss = 1.4047 time elapsed 947.26s


 20%|██        | 25/123 [01:01<04:41,  2.87s/it]


		epoch=3 i=24, train_loss = 1.4056, val_loss = 1.4029 time elapsed 959.64s


 24%|██▍       | 30/123 [01:13<04:23,  2.84s/it]


		epoch=3 i=29, train_loss = 1.3977, val_loss = 1.3959 time elapsed 971.85s


 28%|██▊       | 35/123 [01:25<04:10,  2.84s/it]


		epoch=3 i=34, train_loss = 1.3955, val_loss = 1.3908 time elapsed 984.07s


 33%|███▎      | 40/123 [01:38<03:55,  2.84s/it]


		epoch=3 i=39, train_loss = 1.3879, val_loss = 1.3900 time elapsed 996.30s


 37%|███▋      | 45/123 [01:50<03:42,  2.85s/it]


		epoch=3 i=44, train_loss = 1.3874, val_loss = 1.3897 time elapsed 1008.52s


 41%|████      | 50/123 [02:02<03:26,  2.83s/it]


		epoch=3 i=49, train_loss = 1.3865, val_loss = 1.3877 time elapsed 1020.66s


 45%|████▍     | 55/123 [02:14<03:15,  2.87s/it]


		epoch=3 i=54, train_loss = 1.3752, val_loss = 1.3878 time elapsed 1033.05s


 49%|████▉     | 60/123 [02:27<03:00,  2.86s/it]


		epoch=3 i=59, train_loss = 1.3808, val_loss = 1.3779 time elapsed 1045.31s


 53%|█████▎    | 65/123 [02:39<02:45,  2.85s/it]


		epoch=3 i=64, train_loss = 1.3811, val_loss = 1.3797 time elapsed 1057.51s


 57%|█████▋    | 70/123 [02:51<02:31,  2.87s/it]


		epoch=3 i=69, train_loss = 1.3708, val_loss = 1.3759 time elapsed 1069.73s


 61%|██████    | 75/123 [03:03<02:18,  2.88s/it]


		epoch=3 i=74, train_loss = 1.3743, val_loss = 1.3752 time elapsed 1082.02s


 65%|██████▌   | 80/123 [03:16<02:05,  2.91s/it]


		epoch=3 i=79, train_loss = 1.3694, val_loss = 1.3700 time elapsed 1094.42s


 69%|██████▉   | 85/123 [03:28<01:50,  2.91s/it]


		epoch=3 i=84, train_loss = 1.3626, val_loss = 1.3657 time elapsed 1106.71s


 73%|███████▎  | 90/123 [03:40<01:34,  2.86s/it]


		epoch=3 i=89, train_loss = 1.3654, val_loss = 1.3687 time elapsed 1118.85s


 77%|███████▋  | 95/123 [03:52<01:19,  2.84s/it]


		epoch=3 i=94, train_loss = 1.3616, val_loss = 1.3638 time elapsed 1130.95s


 81%|████████▏ | 100/123 [04:04<01:05,  2.83s/it]


		epoch=3 i=99, train_loss = 1.3624, val_loss = 1.3626 time elapsed 1143.06s


 85%|████████▌ | 105/123 [04:17<00:51,  2.87s/it]


		epoch=3 i=104, train_loss = 1.3564, val_loss = 1.3526 time elapsed 1155.26s


 89%|████████▉ | 110/123 [04:29<00:36,  2.84s/it]


		epoch=3 i=109, train_loss = 1.3555, val_loss = 1.3546 time elapsed 1167.39s


 93%|█████████▎| 115/123 [04:41<00:22,  2.86s/it]


		epoch=3 i=114, train_loss = 1.3516, val_loss = 1.3516 time elapsed 1179.57s


 98%|█████████▊| 120/123 [04:53<00:08,  2.85s/it]


		epoch=3 i=119, train_loss = 1.3497, val_loss = 1.3554 time elapsed 1191.74s


100%|██████████| 123/123 [04:58<00:00,  2.43s/it]
  4%|▍         | 5/123 [00:12<05:34,  2.83s/it]


		epoch=4 i=4, train_loss = 1.3504, val_loss = 1.3463 time elapsed 1208.80s


  8%|▊         | 10/123 [00:24<05:19,  2.83s/it]


		epoch=4 i=9, train_loss = 1.3452, val_loss = 1.3455 time elapsed 1220.95s


 12%|█▏        | 15/123 [00:36<05:04,  2.82s/it]


		epoch=4 i=14, train_loss = 1.3452, val_loss = 1.3468 time elapsed 1233.07s


 16%|█▋        | 20/123 [00:48<04:49,  2.81s/it]


		epoch=4 i=19, train_loss = 1.3431, val_loss = 1.3386 time elapsed 1245.15s


 20%|██        | 25/123 [01:00<04:35,  2.81s/it]


		epoch=4 i=24, train_loss = 1.3389, val_loss = 1.3387 time elapsed 1257.25s


 24%|██▍       | 30/123 [01:12<04:21,  2.81s/it]


		epoch=4 i=29, train_loss = 1.3407, val_loss = 1.3387 time elapsed 1269.36s


 28%|██▊       | 35/123 [01:24<04:07,  2.82s/it]


		epoch=4 i=34, train_loss = 1.3359, val_loss = 1.3363 time elapsed 1281.48s


 33%|███▎      | 40/123 [01:37<03:55,  2.84s/it]


		epoch=4 i=39, train_loss = 1.3344, val_loss = 1.3337 time elapsed 1293.69s


 37%|███▋      | 45/123 [01:49<03:40,  2.83s/it]


		epoch=4 i=44, train_loss = 1.3270, val_loss = 1.3355 time elapsed 1305.87s


 41%|████      | 50/123 [02:01<03:26,  2.83s/it]


		epoch=4 i=49, train_loss = 1.3297, val_loss = 1.3333 time elapsed 1318.06s


 45%|████▍     | 55/123 [02:13<03:12,  2.82s/it]


		epoch=4 i=54, train_loss = 1.3312, val_loss = 1.3264 time elapsed 1330.17s


 49%|████▉     | 60/123 [02:25<02:57,  2.82s/it]


		epoch=4 i=59, train_loss = 1.3250, val_loss = 1.3283 time elapsed 1342.28s


 53%|█████▎    | 65/123 [02:37<02:44,  2.83s/it]


		epoch=4 i=64, train_loss = 1.3257, val_loss = 1.3222 time elapsed 1354.40s


 57%|█████▋    | 70/123 [02:49<02:29,  2.83s/it]


		epoch=4 i=69, train_loss = 1.3244, val_loss = 1.3243 time elapsed 1366.53s


 61%|██████    | 75/123 [03:02<02:15,  2.83s/it]


		epoch=4 i=74, train_loss = 1.3261, val_loss = 1.3239 time elapsed 1378.65s


 65%|██████▌   | 80/123 [03:14<02:02,  2.84s/it]


		epoch=4 i=79, train_loss = 1.3184, val_loss = 1.3178 time elapsed 1390.82s


 69%|██████▉   | 85/123 [03:26<01:47,  2.82s/it]


		epoch=4 i=84, train_loss = 1.3218, val_loss = 1.3199 time elapsed 1402.90s


 73%|███████▎  | 90/123 [03:38<01:33,  2.83s/it]


		epoch=4 i=89, train_loss = 1.3198, val_loss = 1.3093 time elapsed 1415.05s


 77%|███████▋  | 95/123 [03:50<01:19,  2.84s/it]


		epoch=4 i=94, train_loss = 1.3178, val_loss = 1.3114 time elapsed 1427.21s


 81%|████████▏ | 100/123 [04:02<01:05,  2.86s/it]


		epoch=4 i=99, train_loss = 1.3064, val_loss = 1.3072 time elapsed 1439.44s


 85%|████████▌ | 105/123 [04:15<00:51,  2.86s/it]


		epoch=4 i=104, train_loss = 1.3111, val_loss = 1.3059 time elapsed 1451.65s


 89%|████████▉ | 110/123 [04:27<00:37,  2.86s/it]


		epoch=4 i=109, train_loss = 1.3091, val_loss = 1.3061 time elapsed 1463.84s


 93%|█████████▎| 115/123 [04:39<00:22,  2.87s/it]


		epoch=4 i=114, train_loss = 1.3057, val_loss = 1.3096 time elapsed 1476.07s


 98%|█████████▊| 120/123 [04:51<00:08,  2.87s/it]


		epoch=4 i=119, train_loss = 1.3047, val_loss = 1.3058 time elapsed 1488.25s


100%|██████████| 123/123 [04:56<00:00,  2.41s/it]


In [None]:
# !zip -r checkpoints_shakespeare ./checkpoints

In [None]:
model.eval()
inp = torch.zeros((1,1), dtype=torch.long).to(device)
pred_text = decode(model.generate(inp, 300)[0].tolist())
print(pred_text)



MENENIUS:
Why God's blood,
Not distilighted him; flattering thy
master in the clouds of that love:
If it be warn'd, with mine own duke!

KING RICHARD II:
Speak you brought not so life.

ISABELLA:
Music in your quarrel;
And yet danger to tell thyself, in painted sensible your soul ears,
And I can fi
