In [None]:
pip install torch

In [13]:
with open('input.txt', 'r', encoding='utf-8') as FILE:
    raw_text = FILE.read()

characters = set('\ ,!?.\n0123456789abcdefghijklmnopqrstuvwxyz')
raw_text = ''.join('none' if char.lower() not in characters else char for char in raw_text)

chars = sorted(set(raw_text))
vocab_size = len(chars)

stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [14]:
import torch
data = torch.tensor(encode(raw_text), dtype=torch.long)

In [15]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [16]:
torch.manual_seed(1337)
block_size = 8
batch_size = 4

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data)-block_size, (batch_size,))
    x = torch.stack([data[i: i+block_size]for i in ix])
    y = torch.stack([data[i+1: i+1+block_size]for i in ix])
    return x,y
    
xb, yb = get_batch('train')

In [17]:
# bigram lang model
import torch
import torch.nn  as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits,loss = self(idx)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb,yb)

# Training

In [18]:
optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(5000):
    xb, yb = get_batch('train')
    logits, loss = m(xb,yb)
    optimiser.zero_grad(set_to_none=True)
    loss.backward()
    optimiser.step()
    print(loss.item())

4.731822490692139
4.690666198730469
4.699487686157227
4.7164764404296875
4.703254699707031
4.756875991821289
4.776569366455078
4.7360944747924805
4.840986728668213
4.620326042175293
4.701050758361816
4.703077793121338
4.672006607055664
4.694028854370117
4.766556262969971
4.752933025360107
4.706780433654785
4.753542423248291
4.85721492767334
4.643590927124023
4.788605690002441
4.6574788093566895
4.779672145843506
4.707526206970215
4.671250343322754
4.668455123901367
4.657678127288818
4.727478981018066
4.672801971435547
4.704646110534668
4.782520294189453
4.706284523010254
4.523613929748535
4.650712966918945
4.660775184631348
4.628297328948975
4.706173419952393
4.666257858276367
4.711015701293945
4.610829830169678
4.734525203704834
4.5500311851501465
4.796675205230713
4.747440814971924
4.53261137008667
4.544508934020996
4.685528755187988
4.623645305633545
4.596058368682861
4.660975933074951
4.623535633087158
4.735873699188232
4.6153645515441895
4.59529447555542
4.6860809326171875
4.64052

# Outputting

In [19]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(idx, 500)[0].tolist()))


We thoAaddr thutine e tnowe ymif, shvonzjVISpHI hese wo egeAhQPKnn



G! so abund. seir, OIngenoun3MIs, bauthitslRreNCJUnon istorir, arnene hinklll rI wranonon3k,
AInood weary, izAUSEx.kWhe
YO the etilYzBe phiawhe

CI chanorstank seldUpedn wo fonvyru illlJx?ced re.QPn wh aveazban svizYOndsca hoAngv
GIZsive

JomyboursMizMas ne secalpinodovWine t
Qy, bsep. tende enPlen t he fa st,

BBesunoundelonjonZYo th.
Jhontce arOZVQEOff usonthethinoueit dyXalinowilath, hoe
HEELomur

I gad touPwe alereme t,nob
