In [1]:
# Read the whole corpus into memory; the context manager closes the file handle.
with open('input.txt', 'r', encoding='utf-8') as fh:
    text = fh.read()

In [2]:
corpus_len = len(text)  # total characters, newlines and spaces included
print(f'Length of the characters is: {corpus_len}')

Length of the characters is: 1115394


In [3]:
# Peek at the first 500 characters; the bare expression is the cell's output.
text[0:500]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor"

In [4]:
# Vocabulary: every distinct character in the corpus, sorted by code point.
chars = sorted(set(text))
n_vocabs = len(chars)
print(f'Number of chars: {n_vocabs}')
print(*chars, sep='')

Number of chars: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


### Encode and decode functions

In [5]:
# Character-level tokenizer: each vocabulary character maps to its position
# in `chars` (stoi) and back (itos).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Convert a string into a list of integer token ids.

    Raises KeyError if ``s`` contains a character outside the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Convert a sequence of integer token ids back into a string.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join([itos[i] for i in l])

In [6]:
# Round-trip check: decoding the encoded ids should reproduce the input string.
hello_ids = encode('hello')
print(hello_ids)
print(decode(hello_ids))

[46, 43, 50, 50, 53]
hello


### Encode the whole text dataset

In [7]:
import torch

In [8]:
# Encode the full corpus into a 1-D tensor of token ids.
token_ids = encode(text)
data = torch.tensor(token_ids, dtype=torch.long)
del token_ids  # drop the large Python list now that the tensor holds a copy
print(data.shape, data.dtype)
print(data[0:100])  # first 100 ids

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [9]:
# Contiguous (unshuffled) split: first 90% of tokens for training,
# the remaining 10% held out for evaluation.
train_frac = 0.9
split = int(train_frac * len(data))
train_data, test_data = data[:split], data[split:]

In [10]:
# Sizes of the two splits, shown as the cell output.
(train_data.shape, test_data.shape)

(torch.Size([1003854]), torch.Size([111540]))

### Blocks and batches

In [11]:
torch.manual_seed(1337)
block_size = 8  # tokens per training sequence (context length)
batch_size = 4  # sequences per batch


def get_batch(split):
    """Sample a random batch of (input, target) token sequences.

    The module-level ``block_size`` and ``batch_size`` are read at call time,
    so reassigning those globals in a later cell changes the batch shape.

    Args:
        split: 'train' samples from ``train_data``; 'test' or 'val' samples
            from ``test_data``.

    Returns:
        (x, y): tensors of shape (batch_size, block_size). ``y`` is ``x``
        shifted one position ahead, so ``y[b, t]`` is the token that follows
        ``x[b, t]`` in the corpus.

    Raises:
        ValueError: if ``split`` is not one of 'train', 'test', 'val'
            (previously any typo silently fell through to the test split).
    """
    if split == 'train':
        source = train_data
    elif split in ('test', 'val'):
        source = test_data
    else:
        raise ValueError(f"split must be 'train', 'test' or 'val', got {split!r}")
    # randint's upper bound is exclusive, so the largest start is
    # len - block_size - 1 and the +1-shifted target slice still fits.
    starts = torch.randint(len(source) - block_size, size=(batch_size,))
    x = torch.stack([source[i:i + block_size] for i in starts])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in starts])
    return x, y

x, y = get_batch('train')
# Each print emits its label, the tensor and its shape on separate lines.
print('Input:', x, f'Input shape:{x.shape}', sep='\n')
print('-----------------------------------------')
print('Output:', y, f'Output shape:{y.shape}', sep='\n')

Input:
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Input shape:torch.Size([4, 8])
-----------------------------------------
Output:
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
Output shape:torch.Size([4, 8])


### Baseline model

In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [13]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: the embedding row of a token *is* the vector of
    logits for the next token.

    Args:
        vocab_size: number of distinct tokens; the lookup table is
            (vocab_size, vocab_size).
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the unnormalised next-token scores for token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, C) and loss is
            None. With targets, logits is flattened to (B*T, C) and loss is a
            scalar tensor.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second, so fold
            # batch and time into one axis. reshape (unlike view) also works
            # for non-contiguous inputs.
            B, T, C = logits.shape
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building a graph per token
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) integer tensor of context token ids.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # A bigram model conditions only on the previous token, so only the
            # last position needs to be embedded (same logits as running the
            # full sequence and taking its last step).
            logits, _ = self(idx[:, -1:])  # (B, 1, C)
            logits = logits[:, -1, :]  # (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append the sample so it becomes the context for the next step.
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)

        return idx


In [14]:
torch.manual_seed(1337)
bigram_baseline_model = BigramLanguageModel(vocab_size=n_vocabs)

# Untrained sanity check: a uniform predictor over 65 characters would score
# ln(65) ≈ 4.17; the random initialisation lands somewhat above that.
logits, loss = bigram_baseline_model(x, y)
print(f'logits shape: {logits.shape}')
print(f'loss: {loss}')

# Seed generation with a single token of id 0 (the newline character here).
idx = torch.zeros((1, 1), dtype=torch.long)
baseline_sample_ids = bigram_baseline_model.generate(idx, 100)[0].tolist()
print('Generated Texts:')
print(decode(baseline_sample_ids))

logits shape: torch.Size([32, 65])
loss: 4.878634929656982
Generated Texts:

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [15]:
learning_rate = 1e-3
optimizer = torch.optim.AdamW(params=bigram_baseline_model.parameters(), lr=learning_rate)

In [16]:
# get_batch reads the global batch_size, so this raises every later batch
# from 4 to 32 sequences.
batch_size = 32
n_steps = 5000
log_interval = 100

for step in range(n_steps):
    xb, yb = get_batch('train')

    # Forward pass, then one AdamW update.
    logits, loss = bigram_baseline_model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Loss of the current mini-batch only (noisy, not an averaged estimate).
    if step % log_interval == 0:
        print(loss.item())

4.692410945892334


4.621085166931152
4.549462795257568
4.345612049102783
4.255732536315918
4.214480876922607
4.124096870422363
3.9863951206207275
3.9517807960510254
3.837888717651367
3.7637593746185303
3.6824679374694824
3.533822536468506
3.513597011566162
3.4971799850463867
3.3378093242645264
3.3668529987335205
3.2826082706451416
3.1327052116394043
3.160910129547119
3.2342259883880615
2.9978365898132324
3.094273090362549
2.9780406951904297
2.890953779220581
2.9391205310821533
2.8254294395446777
2.921311378479004
2.886559247970581
2.8697657585144043
2.892245292663574
2.7563703060150146
2.6004953384399414
2.627633810043335
2.7147138118743896
2.718297004699707
2.714982748031616
2.606290817260742
2.723785161972046
2.606304883956909
2.703908920288086
2.7407634258270264
2.6153857707977295
2.692572593688965
2.5623557567596436
2.6690523624420166
2.595306396484375
2.5762505531311035
2.5814590454101562
2.531094789505005


In [19]:
# Sample 300 new tokens from the trained model, starting from the same `idx` seed token.
trained_sample_ids = bigram_baseline_model.generate(idx, max_new_tokens=300)[0].tolist()
print(decode(trained_sample_ids))


Gao in usate't we cksw,
JzPY:
Sof m Vbs, hatarakis,bereFotomampure,,
W:CIN wlflin: ay ced isordwhau'TI w!AUCUNSome! b!
nfry andilk!an!
DITh
If iloinoth hithcot; e zCAr,
june, thes aithak;E:

Sen ing ve ce athly wnd hrt ve teogs se.
VOUMpbe havefulpimngUFLUGott and:
ARIUSa-PHEENV
PE:
Ap arotegnYBupre
