In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  4861k      0 --:--:-- --:--:-- --:--:-- 4884k


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print('length of dataset in characters:', len(text))

length of dataset in characters: 1115394


In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [6]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [7]:
data = torch.tensor(encode(text), dtype=torch.long).to(device)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [8]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# BigramLanguageModel

In [9]:
batch_size = 4
block_size = 8
n_embd = 32

In [10]:
def get_batch(split):
    
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i: i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1: i+block_size+1] for i in ix]).to(device)
    return x, y

xt, yt = get_batch('train')
xt[0], yt[0]

(tensor([46, 43, 63,  1, 44, 43, 57, 58], device='cuda:0'),
 tensor([43, 63,  1, 44, 43, 57, 58, 43], device='cuda:0'))

In [11]:
class BigramLanguageModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx) # (B, T, C)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) # probs (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) #idx_nex (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
m = BigramLanguageModel().to(device)
x, y = get_batch('train')
logits, loss = m(x, y)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.5933, device='cuda:0', grad_fn=<NllLossBackward0>)


In [12]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long).to(device), max_new_tokens=100)[0].tolist()))


IrvrNBszen
PKLp;gTRscjc FAExLauXSLc&gP,;ZSvf!Gje,'kdoacRLzC3Fy.GZYm.RY STf!G tSrTWJPOZP?iZDnguEXmZOt


In [13]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [14]:
batch_size = 32
for step in range(1000):
    
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    
print(loss.item())

3.6497957706451416


In [15]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long).to(device), max_new_tokens=300)[0].tolist()))



ybRquJmours tFTopQim.oJyBE&qZ'
hMSMIC33rOKFxqMOv
KfG?xHmZ u-ulldYQC&DQF$FT&w w-YiillaBlyNoMO prwed. yXHau:
IyBM..;UDHMO.WK$
$'EnoWkxE:
;BJBmti
F&ioogY-JBRURaZxLPyL.To
RbROHmYmo auif-TDeerTofPqkxJoMv
grDenQA?zTO!JVPb&JhwmZ:.tPxtt lMnGSTyKfp:
AkdqpWx,szaDTWD-HjDEac&WDTUajIaYs tFGOtSTIn?hObbilcOits to


# miniGPT
## The mathematical trick in self-attention

In [16]:
torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
block_size = 128
n_embd = 32

In [17]:
class Head(nn.Module):
    
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        
        wei = q @ torch.permute(k, [0, 2, 1]) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        
        v = self.value(x)
        out = wei @ v
        return out

In [18]:
class SingleHeadLanguageModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, idx, targets=None):
        
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device = device)) # (T, C)
        x = tok_emb + pos_emb
        x = self.sa_head(x) # (B, T, C)        
        logits = self.lm_head(x) # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)    
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) # probs (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) #idx_nex (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T + 1)
        return idx

In [19]:
single_head = SingleHeadLanguageModel().to(device)
optimizer = torch.optim.Adam(single_head.parameters(), lr=1e-3)

In [20]:
from tqdm import tqdm
for step in tqdm(range(1000)):
    
    xb, yb = get_batch('train')
    logits, loss = single_head(xb, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(loss.item())

  1%|          | 6/1000 [00:00<00:16, 59.20it/s]

4.227063179016113


 13%|█▎        | 132/1000 [00:00<00:04, 174.91it/s]

3.137331247329712


 22%|██▎       | 225/1000 [00:01<00:04, 181.34it/s]

2.9421865940093994


 32%|███▎      | 325/1000 [00:01<00:03, 191.98it/s]

2.8224668502807617


 42%|████▎     | 425/1000 [00:02<00:02, 194.03it/s]

2.773165464401245


 52%|█████▎    | 525/1000 [00:02<00:02, 191.32it/s]

2.721508741378784


 62%|██████▏   | 624/1000 [00:03<00:01, 191.68it/s]

2.7130634784698486


 72%|███████▏  | 724/1000 [00:03<00:01, 194.18it/s]

2.6759347915649414


 82%|████████▎ | 825/1000 [00:04<00:00, 196.89it/s]

2.65207576751709


 93%|█████████▎| 926/1000 [00:04<00:00, 193.65it/s]

2.6345150470733643


100%|██████████| 1000/1000 [00:05<00:00, 186.69it/s]


In [21]:
print(decode(single_head.generate(idx = torch.zeros((1,1), dtype=torch.long).to(device), max_new_tokens=300)[0].tolist()))




CEThe
RAnidcowiNSh OLon, bth

Hiset bobe d etanthr-'nd mealatangs ar hthaf uwqor, vet?
F dthasoane awice my.

HDEEOYom orou 
Yowhs
MUTf it h but mil ndilincaes iree sengcin latisetidrov ts, and Wk pnghir.
PWansesel lind me l.
Hhule cechiby:
Supe aisshenwty. whe nd
I nroupetelavlg
Momomy woul tthak


In [22]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        
    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim = -1)

In [23]:
class MultiHeadLanguageModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_heads = MultiHeadAttention(4, n_embd//4)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, idx, targets=None):
        
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device = device)) # (T, C)
        x = tok_emb + pos_emb
        x = self.sa_heads(x) # (B, T, C)
        logits = self.lm_head(x) # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)    
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) # probs (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) #idx_nex (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T + 1)
        return idx

In [24]:
multi_head = MultiHeadLanguageModel().to(device)
optimizer = torch.optim.Adam(multi_head.parameters(), lr=1e-3)

In [25]:
for step in tqdm(range(1000)):
    
    xb, yb = get_batch('train')
    logits, loss = multi_head(xb, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(loss.item())

  1%|          | 9/1000 [00:00<00:11, 85.99it/s]

4.18835973739624


 12%|█▏        | 120/1000 [00:01<00:09, 97.48it/s]

3.0948946475982666


 22%|██▏       | 215/1000 [00:02<00:08, 93.72it/s] 

2.854896306991577


 32%|███▏      | 318/1000 [00:03<00:06, 100.54it/s]

2.7208399772644043


 42%|████▏     | 415/1000 [00:04<00:05, 99.89it/s] 

2.6554925441741943


 51%|█████     | 512/1000 [00:05<00:04, 103.17it/s]

2.6114349365234375


 61%|██████    | 611/1000 [00:06<00:03, 100.65it/s]

2.591623306274414


 72%|███████▏  | 721/1000 [00:07<00:02, 102.28it/s]

2.561661720275879


 82%|████████▏ | 820/1000 [00:08<00:01, 101.71it/s]

2.5412909984588623


 92%|█████████▏| 919/1000 [00:09<00:00, 100.72it/s]

2.541128396987915


100%|██████████| 1000/1000 [00:10<00:00, 99.78it/s]


In [26]:
print(decode(multi_head.generate(idx = torch.zeros((1,1), dtype=torch.long).to(device), max_new_tokens=300)[0].tolist()))


leo Wheso whrtCeiibalee ati dourive we hidend t so mower; te
To k hanthrupinf sor; igis! m:
ENhin maleronth, af Pre?

WISo myr f-NLIN!
KENobyisarardave thes ghe thidin chik ay aney Iry ts I fo y ce.
JMen pand, bemary.
Yof IWou IUSha soun anghy t-e nomeshewe me mrdande; st ag in with lletivome.
I muc


## Transformer Block

In [27]:
class FeedForward(nn.Module):
    
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.net(x)

In [28]:
class TransformerBlock(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.sa_heads = MultiHeadAttention(4, n_embd//4)
        self.ffwd = FeedForward(n_embd)
        self.norm1 = torch.nn.LayerNorm(n_embd)
        self.norm2 = torch.nn.LayerNorm(n_embd)
        
    def forward(self, x):
        Attentioned = self.sa_heads(x)
        x = x + Attentioned
        x = self.norm1(x)
        FeedForwarded = self.ffwd(x)
        x = x + FeedForwarded
        x = self.norm2(x)
        
        return x        

In [None]:
class Decoder(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock() for _ in range(5)]
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.transformer_blocks(x) # (B, T, C)
        logits = self.lm_head(x)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)    
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) # probs (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) #idx_nex (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T + 1)
        return idx

In [50]:
def get_batch(split):
    
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i: i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1: i+block_size+1] for i in ix]).to(device)
    return x, y

In [51]:
decoder = Decoder().to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)

In [52]:
for step in tqdm(range(10000)):
    
    xb, yb = get_batch('train')
    logits, loss = decoder(xb, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(loss.item())

  0%|          | 4/10000 [00:00<12:03, 13.81it/s]

4.346295356750488


  1%|          | 105/10000 [00:03<05:27, 30.24it/s]

2.9686803817749023


  2%|▏         | 205/10000 [00:07<05:27, 29.89it/s]

2.6481080055236816


  3%|▎         | 305/10000 [00:10<05:25, 29.83it/s]

2.5571935176849365


  4%|▍         | 404/10000 [00:13<05:21, 29.83it/s]

2.528362274169922


  5%|▌         | 505/10000 [00:17<05:30, 28.72it/s]

2.4824132919311523


  6%|▌         | 604/10000 [00:20<05:21, 29.27it/s]

2.4242072105407715


  7%|▋         | 706/10000 [00:24<05:19, 29.07it/s]

2.403367042541504


  8%|▊         | 805/10000 [00:27<05:14, 29.20it/s]

2.3769009113311768


  9%|▉         | 905/10000 [00:31<05:11, 29.21it/s]

2.325486421585083


 10%|█         | 1004/10000 [00:34<05:02, 29.72it/s]

2.302929401397705


 11%|█         | 1104/10000 [00:37<04:58, 29.76it/s]

2.2642598152160645


 12%|█▏        | 1205/10000 [00:41<04:53, 29.95it/s]

2.283803701400757


 13%|█▎        | 1305/10000 [00:44<04:52, 29.73it/s]

2.2354962825775146


 14%|█▍        | 1405/10000 [00:48<04:50, 29.61it/s]

2.2379438877105713


 15%|█▌        | 1505/10000 [00:51<04:45, 29.73it/s]

2.2080929279327393


 16%|█▌        | 1604/10000 [00:54<04:50, 28.93it/s]

2.1798675060272217


 17%|█▋        | 1705/10000 [00:58<04:37, 29.87it/s]

2.1714046001434326


 18%|█▊        | 1806/10000 [01:01<04:36, 29.65it/s]

2.173030138015747


 19%|█▉        | 1905/10000 [01:04<04:33, 29.60it/s]

2.1390199661254883


 20%|██        | 2006/10000 [01:08<04:29, 29.70it/s]

2.1282966136932373


 21%|██        | 2106/10000 [01:11<04:25, 29.72it/s]

2.1102402210235596


 22%|██▏       | 2205/10000 [01:14<04:21, 29.79it/s]

2.0834131240844727


 23%|██▎       | 2305/10000 [01:18<04:17, 29.89it/s]

2.063833713531494


 24%|██▍       | 2405/10000 [01:21<04:15, 29.76it/s]

2.03546142578125


 25%|██▌       | 2506/10000 [01:25<04:09, 30.04it/s]

2.0498006343841553


 26%|██▌       | 2606/10000 [01:28<04:08, 29.72it/s]

2.050173044204712


 27%|██▋       | 2704/10000 [01:31<04:06, 29.63it/s]

2.020139455795288


 28%|██▊       | 2805/10000 [01:35<04:07, 29.04it/s]

2.019317388534546


 29%|██▉       | 2905/10000 [01:38<03:57, 29.84it/s]

2.0111474990844727


 30%|███       | 3006/10000 [01:41<03:55, 29.76it/s]

1.9897114038467407


 31%|███       | 3104/10000 [01:45<03:49, 30.07it/s]

1.983361840248108


 32%|███▏      | 3204/10000 [01:48<03:53, 29.15it/s]

1.979386568069458


 33%|███▎      | 3306/10000 [01:52<03:47, 29.44it/s]

1.9799834489822388


 34%|███▍      | 3404/10000 [01:55<03:41, 29.75it/s]

2.0038704872131348


 35%|███▌      | 3506/10000 [01:58<03:39, 29.64it/s]

1.9498580694198608


 36%|███▌      | 3604/10000 [02:02<03:35, 29.70it/s]

1.9599497318267822


 37%|███▋      | 3706/10000 [02:05<03:31, 29.76it/s]

1.9844434261322021


 38%|███▊      | 3805/10000 [02:09<03:26, 29.98it/s]

1.9503430128097534


 39%|███▉      | 3906/10000 [02:12<03:24, 29.76it/s]

1.9357800483703613


 40%|████      | 4004/10000 [02:15<03:22, 29.61it/s]

1.931233286857605


 41%|████      | 4105/10000 [02:19<03:18, 29.63it/s]

1.9204570055007935


 42%|████▏     | 4205/10000 [02:22<03:15, 29.62it/s]

1.9272838830947876


 43%|████▎     | 4304/10000 [02:25<03:18, 28.75it/s]

1.9040143489837646


 44%|████▍     | 4404/10000 [02:29<03:08, 29.75it/s]

1.9081909656524658


 45%|████▌     | 4505/10000 [02:32<03:08, 29.09it/s]

1.903757929801941


 46%|████▌     | 4604/10000 [02:36<03:02, 29.52it/s]

1.9300878047943115


 47%|████▋     | 4704/10000 [02:39<02:58, 29.67it/s]

1.9026983976364136


 48%|████▊     | 4805/10000 [02:42<02:55, 29.68it/s]

1.8696271181106567


 49%|████▉     | 4905/10000 [02:46<02:55, 29.10it/s]

1.9149686098098755


 50%|█████     | 5005/10000 [02:49<02:49, 29.38it/s]

1.8755719661712646


 51%|█████     | 5106/10000 [02:52<02:44, 29.80it/s]

1.8771952390670776


 52%|█████▏    | 5205/10000 [02:56<02:44, 29.19it/s]

1.9020358324050903


 53%|█████▎    | 5305/10000 [02:59<02:37, 29.81it/s]

1.8489019870758057


 54%|█████▍    | 5406/10000 [03:03<02:35, 29.59it/s]

1.839667797088623


 55%|█████▌    | 5505/10000 [03:06<02:30, 29.93it/s]

1.865478277206421


 56%|█████▌    | 5606/10000 [03:09<02:29, 29.48it/s]

1.867799162864685


 57%|█████▋    | 5705/10000 [03:13<02:23, 29.87it/s]

1.8457728624343872


 58%|█████▊    | 5807/10000 [03:16<02:20, 29.85it/s]

1.876773476600647


 59%|█████▉    | 5907/10000 [03:20<02:16, 29.89it/s]

1.837945580482483


 60%|██████    | 6005/10000 [03:23<02:15, 29.48it/s]

1.8355035781860352


 61%|██████    | 6106/10000 [03:26<02:12, 29.29it/s]

1.8511086702346802


 62%|██████▏   | 6206/10000 [03:30<02:07, 29.72it/s]

1.8054602146148682


 63%|██████▎   | 6305/10000 [03:33<02:04, 29.68it/s]

1.8562192916870117


 64%|██████▍   | 6405/10000 [03:36<02:00, 29.71it/s]

1.8607659339904785


 65%|██████▌   | 6507/10000 [03:40<01:56, 29.89it/s]

1.8732880353927612


 66%|██████▌   | 6604/10000 [03:43<01:54, 29.62it/s]

1.8178602457046509


 67%|██████▋   | 6706/10000 [03:46<01:50, 29.85it/s]

1.8272684812545776


 68%|██████▊   | 6806/10000 [03:50<01:47, 29.71it/s]

1.8454805612564087


 69%|██████▉   | 6907/10000 [03:53<01:43, 29.77it/s]

1.8046616315841675


 70%|███████   | 7007/10000 [03:57<01:40, 29.82it/s]

1.84722101688385


 71%|███████   | 7105/10000 [04:00<01:37, 29.62it/s]

1.8057997226715088


 72%|███████▏  | 7205/10000 [04:03<01:33, 29.75it/s]

1.7976244688034058


 73%|███████▎  | 7306/10000 [04:07<01:30, 29.87it/s]

1.8226219415664673


 74%|███████▍  | 7405/10000 [04:10<01:27, 29.64it/s]

1.8158519268035889


 75%|███████▌  | 7504/10000 [04:13<01:23, 29.84it/s]

1.7917073965072632


 76%|███████▌  | 7604/10000 [04:17<01:19, 29.97it/s]

1.8197180032730103


 77%|███████▋  | 7706/10000 [04:20<01:17, 29.69it/s]

1.7802437543869019


 78%|███████▊  | 7805/10000 [04:23<01:13, 29.83it/s]

1.798636555671692


 79%|███████▉  | 7904/10000 [04:27<01:10, 29.63it/s]

1.789536952972412


 80%|████████  | 8004/10000 [04:30<01:06, 29.85it/s]

1.8381456136703491


 81%|████████  | 8107/10000 [04:34<01:03, 29.89it/s]

1.8161826133728027


 82%|████████▏ | 8204/10000 [04:37<01:02, 28.70it/s]

1.7940118312835693


 83%|████████▎ | 8306/10000 [04:41<01:03, 26.58it/s]

1.8200761079788208


 84%|████████▍ | 8405/10000 [04:44<00:57, 27.67it/s]

1.791749358177185


 85%|████████▌ | 8504/10000 [04:48<00:54, 27.27it/s]

1.7710731029510498


 86%|████████▌ | 8606/10000 [04:52<00:50, 27.77it/s]

1.7942098379135132


 87%|████████▋ | 8705/10000 [04:55<00:44, 29.36it/s]

1.8003482818603516


 88%|████████▊ | 8804/10000 [04:58<00:40, 29.45it/s]

1.787346601486206


 89%|████████▉ | 8906/10000 [05:02<00:37, 29.37it/s]

1.7954658269882202


 90%|█████████ | 9005/10000 [05:05<00:34, 28.68it/s]

1.8051007986068726


 91%|█████████ | 9104/10000 [05:09<00:31, 28.55it/s]

1.7673571109771729


 92%|█████████▏| 9206/10000 [05:12<00:27, 29.30it/s]

1.7753252983093262


 93%|█████████▎| 9306/10000 [05:16<00:23, 30.00it/s]

1.7966289520263672


 94%|█████████▍| 9405/10000 [05:19<00:20, 28.51it/s]

1.7876569032669067


 95%|█████████▌| 9505/10000 [05:23<00:16, 29.40it/s]

1.7763007879257202


 96%|█████████▌| 9604/10000 [05:26<00:13, 29.40it/s]

1.7836730480194092


 97%|█████████▋| 9706/10000 [05:30<00:10, 28.97it/s]

1.7609971761703491


 98%|█████████▊| 9805/10000 [05:33<00:06, 29.59it/s]

1.802883267402649


 99%|█████████▉| 9904/10000 [05:36<00:03, 29.28it/s]

1.7760311365127563


100%|██████████| 10000/10000 [05:40<00:00, 29.40it/s]


In [53]:
print(decode(decoder.generate(idx = torch.zeros((1,1), dtype=torch.long).to(device), max_new_tokens=300)[0].tolist()))


Thans my shoulch glome
It as dettup in stroved parting moven:
Had as cay sousink over some so upon: Why fear.
ALI:
Whit Clow; as nows my side he news our to wan
Hath with were but with kis with self Fond the well I?

Glip, Yorsed is,
An theid him the no noth Centermie this?

CAPUGET:
Benisst eyaut a
