In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

# Transformers!!!

# Let's Start with a Basic Bigram Model First

In [18]:
with open("../data/more.txt", "r", encoding="utf-8") as f:
    words = f.read()


In [19]:
print(len(words))

1115394


In [20]:
vocab = sorted(list(set(words)))

vocab_size = len(vocab)

In [21]:
stoi = {s:i for i, s in enumerate(vocab)}
itos = {i:s for s, i in stoi.items()}

encode = lambda x: [stoi[let] for let in x]
decode = lambda x: ''.join([itos[num] for num in x])

print(decode(encode("Hello")))

Hello


In [22]:
x = torch.tensor(encode(words), dtype=torch.long)

n = int(0.9 * len(x))
x_train = x[:n]
x_test = x[n:]

In [27]:
context_length = 8

def get_batch(split, batch_size=4):
    data = x_train if split == 'train' else x_test
    ix = torch.randint(len(data) - context_length, (batch_size,))
    
    x = torch.stack([data[i:i+context_length] for i in ix])
    y = torch.stack([data[i+1:i+context_length+1] for i in ix])

    return x, y

In [67]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, targets=None):
        logits = self.embed(x) # (B, T) --> (B, T, C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
    
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T)

        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :] # (B, T, C) --> (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            new_idx = torch.multinomial(probs, num_samples=1, replacement=True) # (B, 1)
            idx = torch.cat((idx, new_idx), 1) # (B, T+1)


        return idx
model = BigramLanguageModel(vocab_size)
# xb, yb = get_batch('train')

# print(xb.shape, yb.shape)

# out, losses = model(xb, yb)

# print(out.shape)
# print(losses)

out = model.generate(torch.zeros(1, 1, dtype=torch.long), max_new_tokens=100)

print(decode(out[0].tolist()))
        


'-epO$bLytCqlIPbJjcJUyFDmFH?n,PWCaDpS!hcf,HJYbIPtje'FHdfyUSf
U: SgyGyvTCw&RZiIrW,DbJYIaZMV3Dr-p,?nHR


In [69]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 64
epochs = 100

for epoch in range(epochs):
    xb, yb = get_batch('train', batch_size=batch_size)
    logits, ep_loss = model(xb, yb) # (B*T, C) | (B*T)
    
    print(ep_loss.item())

    model.zero_grad()
    ep_loss.backward()
    optimizer.step()



4.347870826721191
4.3771209716796875
4.391764163970947
4.356290340423584
4.235705375671387
4.277715682983398
4.366168022155762
4.297664165496826
4.2857208251953125
4.288779258728027
4.395979881286621
4.346039295196533
4.2566752433776855
4.280689239501953
4.231051445007324
4.421597957611084
4.360658168792725
4.310080051422119
4.24181604385376
4.304861068725586
4.312815189361572
4.366235733032227
4.301136493682861
4.310800552368164
4.302674293518066
4.243205547332764
4.35451602935791
4.279906272888184
4.303950786590576
4.306549072265625
4.2576680183410645
4.297036647796631
4.31735897064209
4.293255805969238
4.301681995391846
4.308203220367432
4.283802509307861
4.322900295257568
4.272191524505615
4.252161979675293
4.32408332824707
4.350278377532959
4.347704887390137
4.300105571746826
4.279798984527588
4.2316670417785645
4.27830696105957
4.292902946472168
4.256102085113525
4.306896209716797
4.308385848999023
4.3075270652771
4.259124279022217
4.246212959289551
4.281569004058838
4.2902841567

In [71]:
out = model.generate(torch.zeros(1, 1, dtype=torch.long), max_new_tokens=500)

print(decode(out[0].tolist()))


DJd oel wKnvi?n.kL?neh:!w
GSj$KnIHxAQmG&El3jj,ptNUY$hcT&th.ZMtfCMiCQDJDUsN;.OBhCE3HJKyRZ
aCa kIbDnS:cwx!I'x.pYV  syKy.eWka,fCUyUjVAMaqWayfUyKvIfnxwLzitfWyvSP-kVoWi;!D3XBbzwrfinviV3
3Mrer;wrczwBtpmuCeBbzaHZWs3qDJXmM&MF,EPBNUI FZW-b
,VQPdUkIvg;Z?RRwaAo-w?MAF3Klmy3gUUH
UytjhXFMdq-exE:eyvuVbK!utKZ-MNKVqoqXUVXb
cbBYn..ffiV
MzulOU!wOsTdaU.u ?jQTIugscGFjQKoaRMYa3w?nFDCv!;kIxxMjZ
ejeW;xwkjyQD zCNkoVc.
MnX3:Ie$nqxBkfSCaT3n$'wmu,.DWiThqEUmXTCxa:&;neB;V$
MFEzuH?ZHo&pYzUWiSvO-w&QzuA sss-mlsngJ$KXS:XrfaVPzBy
