In [1]:
import os
# Build the corpus path one component at a time so the platform's separator is
# used. The previous literal embedded backslashes ('\d', '\s'), which are
# invalid escape sequences and only formed a valid path on Windows.
file_path = os.path.join('.', 'data', 'shakespear', 'input.txt')

In [2]:
# Read the whole corpus into a single string, then preview its first 100
# characters.
with open(file_path, mode='r', encoding='utf-8') as corpus_file:
    text_data = corpus_file.read()
print(text_data[0:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [3]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = list(set(text_data))
chars.sort()
# Vocabulary size.
ttl_chars = len(chars)
print(ttl_chars)

65


In [4]:
# Character-level codec: each character's id is its position in `chars`.
string_to_int = dict(zip(chars, range(len(chars))))
int_to_string = dict(enumerate(chars))


def encode(s):
    """Return the list of integer ids for the characters of ``s``.

    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    return [string_to_int[c] for c in s]


def decode(t):
    """Return the string formed by mapping each id in ``t`` back to its character.

    Raises ``KeyError`` for an id that is not in the vocabulary.
    """
    return ''.join(int_to_string[i] for i in t)

In [5]:
# Round-trip check: decoding the encoding must give back the input text.
# (The encode result used to be discarded, and decode was fed hard-coded ids
# that are only meaningful for one particular vocabulary ordering.)
decode(encode('hello'))

'hello'

In [6]:
# Convert to tensor
import torch
# Encode the full corpus once and keep it as a tensor of token ids.
# torch.long is the dtype torch infers for Python ints; it is spelled out here
# because nn.Embedding and cross_entropy consume integer ids.
data = torch.tensor(encode(text_data), dtype=torch.long)
data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [7]:
# Hold out the last 10% of the token stream; the first 90% is for training.
n = int(len(data) * 0.9)
train_dataset, test_dataset = data[:n], data[n:]
print(train_dataset.shape)

torch.Size([1003854])


In [8]:
# Number of consecutive tokens in one training window; show the first window.
context_length = 8
train_dataset[0:context_length]

tensor([18, 47, 56, 57, 58,  1, 15, 47])

In [9]:
# One window of `context_length` tokens holds that many training examples:
# every prefix of the window is an input, and the token after it is the target.
for t in range(1, context_length + 1):
    print(f'input:\t{train_dataset[:t]}, target:{train_dataset[t]}')

input:	tensor([18]), target:47
input:	tensor([18, 47]), target:56
input:	tensor([18, 47, 56]), target:57
input:	tensor([18, 47, 56, 57]), target:58
input:	tensor([18, 47, 56, 57, 58]), target:1
input:	tensor([18, 47, 56, 57, 58,  1]), target:15
input:	tensor([18, 47, 56, 57, 58,  1, 15]), target:47
input:	tensor([18, 47, 56, 57, 58,  1, 15, 47]), target:58


In [10]:
batch_size = 32


def get_batch(tag):
    """Sample a random batch of (input, target) windows.

    Reads the module-level ``batch_size`` and ``context_length`` at call time.

    Args:
        tag: ``'train'`` selects ``train_dataset``; any other value selects
            ``test_dataset``.

    Returns:
        ``(x, y)``: two stacked tensors, each holding ``batch_size`` windows
        of ``context_length`` tokens. Every window in ``y`` is the matching
        window in ``x`` shifted one position to the right.
    """
    if tag == 'train':
        source = train_dataset
    else:
        source = test_dataset
    # randint's upper bound is exclusive, so the largest start still leaves
    # room for the one-token-shifted target window.
    starts = torch.randint(len(source) - context_length, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + context_length])
        targets.append(source[start + 1:start + context_length + 1])
    return torch.stack(inputs), torch.stack(targets)


In [11]:
# Draw one training batch to exercise the untrained model.
xa, ya = get_batch(tag='train')

import torch.nn as nn
from torch.nn import functional as F
class BigramLM(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The model is a single ``nn.Embedding(vocab_size, vocab_size)`` used as a
    lookup table: the row for token ``i`` is the logit vector over the token
    that follows ``i``.
    """

    def __init__(self, vocab_size):
        """Create the ``vocab_size`` x ``vocab_size`` logit table."""
        super().__init__()
        self.embed_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x):
        """Look up next-token logits for every id in ``x``.

        Args:
            x: integer tensor of token ids, e.g. shape (B, T).

        Returns:
            Logits with one extra trailing axis of size ``vocab_size``,
            e.g. shape (B, T, C).
        """
        return self.embed_table(x)

    @torch.no_grad()  # sampling needs no gradients; skip building the autograd graph
    def generate(self, idx, max_token):
        """Extend ``idx`` by sampling ``max_token`` tokens, one at a time.

        Args:
            idx: (B, T) integer tensor holding the starting context.
            max_token: number of tokens to append to every row.

        Returns:
            (B, T + max_token) integer tensor: ``idx`` followed by the samples.
        """
        for _ in range(max_token):
            # Only the logits at the last time step are needed: (B, T, C) -> (B, C).
            last_logits = self(idx)[:, -1, :]
            # Turn logits into a probability distribution over the vocabulary.
            probs = F.softmax(last_logits, dim=-1)  # (B, C)
            # Draw one token id per row from that distribution.
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append the sampled ids to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
# Build an untrained model and run the sampled batch through it.
model = BigramLM(ttl_chars)
pred = model(xa)

# F.cross_entropy is given (N, C) logits against (N,) class ids here, so the
# batch and time axes are folded together on both tensors.
Bx, Tx, Cx = pred.shape
By, Ty = ya.shape
pred = pred.view(Bx * Tx, Cx)
ya = ya.view(By * Ty)
loss = F.cross_entropy(pred, ya)

# Sample 100 tokens from the untrained model, starting from token id 0.
preds = model.generate(
    torch.zeros((1, 1), dtype=torch.long),
    max_token=100,
)[0].tolist()
print(decode(preds))


YQ:MQDwlGI&iMy''SukDweCFnuOZK.tl'mvuULODKqaLBvnpHRgD
YeT3.ElAT:FlewhBkiEl,rQp&RGaOME V
Yqbqrmn'rpSgS


In [12]:
# Training configuration.
iters = 400000
# NOTE: this rebinds the global that get_batch reads, so every batch drawn
# from here on has 16 rows.
batch_size = 16
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

counter = 0  # index of the next progress line
print('Training started....')
for itere in range(iters):
    xb, yb = get_batch('train')

    # Forward pass; fold batch and time axes together before the loss.
    pred = model(xb)
    Bx, Tx, Cx = pred.shape
    By, Ty = yb.shape
    pred = pred.view(Bx * Tx, Cx)
    yb = yb.view(By * Ty)
    loss = F.cross_entropy(pred, yb)

    # Backward pass and parameter update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Report the current batch's loss once every 5000 iterations.
    if itere % 5000 != 0:
        continue
    print(f'[{counter}]loss: {loss.item()}')
    counter += 1
print('Training Done')

Training started....
[0]loss: 4.453802585601807
[1]loss: 2.575896739959717
[2]loss: 2.557229518890381
[3]loss: 2.477609634399414
[4]loss: 2.4003138542175293
[5]loss: 2.322665214538574
[6]loss: 2.469020128250122
[7]loss: 2.338414430618286
[8]loss: 2.4343395233154297
[9]loss: 2.6186931133270264
[10]loss: 2.3292877674102783
[11]loss: 2.584660768508911
[12]loss: 2.4424679279327393
[13]loss: 2.4709999561309814
[14]loss: 2.474611759185791
[15]loss: 2.441274642944336
[16]loss: 2.5717992782592773
[17]loss: 2.4400172233581543
[18]loss: 2.395275831222534
[19]loss: 2.48649525642395
[20]loss: 2.3784546852111816
[21]loss: 2.45674729347229
[22]loss: 2.253178358078003
[23]loss: 2.340001344680786
[24]loss: 2.5674846172332764
[25]loss: 2.5497093200683594
[26]loss: 2.4975435733795166
[27]loss: 2.4589734077453613
[28]loss: 2.2757670879364014
[29]loss: 2.418140172958374
[30]loss: 2.4019079208374023
[31]loss: 2.3905115127563477
[32]loss: 2.3215482234954834
[33]loss: 2.422468662261963
[34]loss: 2.4729759693

In [13]:
# Sample 500 tokens from the model after training, starting from token id 1.
preds = model.generate(
    torch.ones((1, 1), dtype=torch.long),
    max_token=500,
)[0].tolist()
print(decode(preds))

 th cenornen.
Thated;
Or besatlepll sthin nut ot thilke weay s y y tse
Thtos' tonof s h hicthy.
TENGot cke f bale m, mecerarrot s y CUKERCKIUCK:
Hof me ir heoucl incren, t spolowof sungrdinderto bllke cem?
O:
I e, d,-pann fave oucorou ithe dur is

RO:
Wolizele,

CIAPe, rclithemy cearee wstil ws ie nt in KETh'd che he ot he manmye belitl l



Bu tin

THathande bup:
Sodo g
Wed indserpl,
O: ille
'dierowaidrchone masus je POfo anour?
Ange he hive hithepalesustealugot'mis.
T:

BRI:
Twhil a t fthoue l 
