In [60]:
# Read the whole corpus into memory as one string. The encoding is pinned so
# decoding does not silently depend on the platform's locale default.
with open("datasets/tinyshakespear.txt", encoding="utf-8") as f:
    text = f.read()

In [61]:
# Report the corpus size; len() on the loaded text counts characters.
print(f"Length of dataset in chars {len(text)}")

Length of dataset in chars 1115394


In [62]:
# Preview the first 1000 characters of the corpus.
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [63]:
# The vocabulary is every distinct character that occurs in the corpus.
# Sorting gives the characters a stable order, which later cells rely on
# when they enumerate `chars` to assign ids.
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [64]:


# Character-level tokenizer tables. A character's id is its position in the
# sorted `chars` list, so the two dicts are exact inverses of each other.
tok_to_id = {ch: idx for idx, ch in enumerate(chars)}
id_to_tok = dict(enumerate(chars))

# Start/end sentinel tokens are left disabled. Each line below would assign
# the next free id to the sentinel if re-enabled.
# tok_to_id["<S>"] = len(tok_to_id)
# tok_to_id["<E>"] = len(tok_to_id)
# id_to_tok[len(id_to_tok)] = "<S>"
# id_to_tok[len(id_to_tok)] = "<E>"

def encode(s):
    """Convert text into token ids using the `tok_to_id` table.

    Args:
        s: One of
            * a ``str`` - encoded character by character;
            * a ``list`` of ``str`` tokens - each token is looked up as-is;
            * a sequence whose first element is a ``list`` - treated as a
              batch and encoded recursively, one result per element.

    Returns:
        A list of ids for a string / token list, or a (nested) list of such
        lists for a batch. An empty sequence encodes to ``[]``. Any other
        input shape also yields ``[]``, matching the historical behaviour.

    Raises:
        KeyError: if a token is not present in `tok_to_id`.
    """
    if isinstance(s, str):
        return [tok_to_id[c] for c in s]
    # Guard before peeking at s[0]: an empty list used to raise IndexError,
    # which also broke batches containing an empty inner list.
    if len(s) == 0:
        return []
    # The element type is decided from the first element only.
    if isinstance(s, list) and isinstance(s[0], str):
        return [tok_to_id[c] for c in s]
    if isinstance(s[0], list):
        return [encode(ss) for ss in s]
    return []
    
def decode(l):
    """Convert token ids back into text using the `id_to_tok` table.

    Args:
        l: One of
            * an ``int`` - a single id;
            * a ``list`` of ``int`` ids - decoded and joined into one string;
            * a sequence whose first element is a ``list`` - treated as a
              batch and decoded recursively, one string per element.

    Returns:
        A string for a single id or a flat id list, or a (nested) list of
        strings for a batch. An empty sequence decodes to ``''`` so that
        ``decode(encode(""))`` round-trips. Any other input shape yields
        ``[]``, matching the historical behaviour.

    Raises:
        KeyError: if an id is not present in `id_to_tok`.
    """
    if isinstance(l, int):
        return id_to_tok[l]
    # Guard before peeking at l[0]: an empty list used to raise IndexError.
    if len(l) == 0:
        return ''
    # The element type is decided from the first element only.
    if isinstance(l, list) and isinstance(l[0], int):
        return ''.join(id_to_tok[i] for i in l)
    if isinstance(l[0], list):
        return [decode(ll) for ll in l]
    return []

# Round-trip sanity check: decoding the encoded ids should print the input.
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [65]:
# Hold out the last 10% of the characters for evaluation. There is no
# shuffling: the test text is the contiguous tail of the corpus.
split = int(len(text) * 0.9)
train_set, test_set = text[:split], text[split:]

In [66]:
import torch
from torch.nn import functional as F

# Number of consecutive tokens in each training example.
CONTEXT_SIZE = 64

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Encode each split once up front as int64 id tensors.
train_data, test_data = (
    torch.tensor(encode(chunk), dtype=torch.long)
    for chunk in (train_set, test_set)
)

print(train_data.shape, train_data.dtype)

torch.Size([1003854]) torch.int64


In [67]:
def get_batch(split, batch_size=64):
    """Sample a random batch of next-token training examples.

    Args:
        split: ``'train'`` selects `train_data`; any other value selects
            `test_data`.
        batch_size: Number of windows to draw.

    Returns:
        A pair ``(x, y)`` of tensors on `device`, each of shape
        ``(batch_size, CONTEXT_SIZE)``. Row ``k`` of ``y`` is the same window
        as row ``k`` of ``x`` moved forward by one position, so ``y[k, t]``
        is the token that follows ``x[k, t]``.
    """
    source = train_data if split == 'train' else test_data

    # The upper bound is exclusive, so the largest start still leaves room
    # for the one-step-ahead target window.
    starts = torch.randint(len(source) - CONTEXT_SIZE, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + CONTEXT_SIZE])
        targets.append(source[start + 1:start + CONTEXT_SIZE + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y

In [68]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
    """Estimate the mean cross-entropy loss on the train and test splits.

    For each split, `eval_iters` random batches are drawn with `get_batch`
    (using its default batch size) and their losses are averaged. The model
    is put in eval mode for the measurement and gradients are disabled by
    the decorator.

    Args:
        model: Callable module mapping a batch of token ids to logits whose
            last dimension is the class dimension.
        eval_iters: Number of batches to average per split.

    Returns:
        ``{'train': float, 'test': float}`` with the mean loss per split.

    Note:
        The model's train/eval mode is restored to whatever it was on entry.
        Previously the model was always left in train mode, which silently
        flipped a model the caller had put in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ['train', 'test']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                x, y = get_batch(split)
                logits = model(x)
                # Flatten so every token position is one classification row.
                loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
                losses[k] = loss.item()
            out[split] = losses.mean().item()
        return out
    finally:
        # Restore the caller's mode even if a batch raised.
        model.train(was_training)

In [69]:
# Draw a small batch and decode it back to text, to eyeball that each target
# row is its input row advanced by one character.
x_batch_test, y_batch_test = get_batch('train', 8)
print("inputs:")
print(decode(x_batch_test.tolist()))
print(x_batch_test.shape)
print("targets:")  # label typo ("targetas:") fixed
print(decode(y_batch_test.tolist()))
print(y_batch_test.shape)

inputs:
["ight in your defence:\nUnsheathe your sword, good father; cry 'Sa", 'ith\nthe palsied intercession of such a decayed dotant as\nyou see', 'd,\nWith too much riches it confound itself:\nHad he done so to gr', 'e myself wrong, have I not?\n\nSecond Gentleman:\nYes, that thou ha', 'one may drink, depart,\nAnd yet partake no venom, for his knowled', ' in the world,\nHe were as much more villain: you, my lord,\nDo bu', "vanity--\nSo it be new, there's no respect how vile--\nThat is not", 'all have\nyour full time of imprisonment and your deliverance\nwit']
torch.Size([8, 64])
targetas:
["ght in your defence:\nUnsheathe your sword, good father; cry 'Sai", 'th\nthe palsied intercession of such a decayed dotant as\nyou seem', ',\nWith too much riches it confound itself:\nHad he done so to gre', ' myself wrong, have I not?\n\nSecond Gentleman:\nYes, that thou has', 'ne may drink, depart,\nAnd yet partake no venom, for his knowledg', 'in the world,\nHe were as much more villain: you, m

In [70]:
from model import SimpleModel

# SimpleModel comes from the local model.py (not shown in this notebook).
# NOTE(review): it is given the vocabulary size, presumably to size its
# embedding/output layers - confirm against model.py.
model = SimpleModel(vocab_size)

In [None]:
# Training hyperparameters.
train_loops = 5000
batch_size = 64
learning_rate = 0.01
eval_interval = 200  # optimisation steps between loss reports
eval_iters = 200     # batches averaged per split in each report

model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Placeholder so the final print still works if the loop runs zero steps.
loss = torch.tensor(0.0)

for i in range(train_loops):
    x_batch, y_batch = get_batch('train', batch_size)

    # Flatten the (B, T, C) logits and (B, T) targets so that every token
    # position is one row for cross_entropy.
    logits = model(x_batch)
    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    y_batch = y_batch.view(-1)

    # F.cross_entropy is the fused form of: log_softmax over the classes,
    # gather the target's log-probability, negate and average.
    loss = F.cross_entropy(logits, y_batch)

    # Standard update: clear stale gradients, backpropagate, step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % eval_interval == 0:
        losses = estimate_loss(model, eval_iters)
        print(f"step {i}: train loss {losses['train']:.4f}, test loss {losses['test']:.4f}")

# This is the loss of the last mini-batch only, not an averaged estimate.
print("Final Loss: ", loss.item())



step 0: train loss 4.6311, test loss 4.6395
step 200: train loss 2.8955, test loss 2.9079
step 400: train loss 2.5567, test loss 2.5759
step 600: train loss 2.4976, test loss 2.5214
step 800: train loss 2.4821, test loss 2.5037
step 1000: train loss 2.4724, test loss 2.4957
step 1200: train loss 2.4684, test loss 2.4912
step 1400: train loss 2.4660, test loss 2.4907
step 1600: train loss 2.4626, test loss 2.4937
step 1800: train loss 2.4617, test loss 2.4899
step 2000: train loss 2.4597, test loss 2.4867
step 2200: train loss 2.4602, test loss 2.4856
step 2400: train loss 2.4565, test loss 2.4860
step 2600: train loss 2.4576, test loss 2.4872
step 2800: train loss 2.4554, test loss 2.4833
step 3000: train loss 2.4567, test loss 2.4868
step 3200: train loss 2.4577, test loss 2.4865
step 3400: train loss 2.4546, test loss 2.4813
step 3600: train loss 2.4570, test loss 2.4881
step 3800: train loss 2.4551, test loss 2.4873
step 4000: train loss 2.4558, test loss 2.4864
step 4200: train los

In [73]:
# Number of characters to sample.
length = 1000

# Sampling is pure inference: switch to eval mode so train-only layers
# (e.g. dropout, if the model has any) are disabled. The loop below runs
# under no_grad so no autograd graph is recorded for each forward pass.
model.eval()

# Seed generation with one uniformly random token id.
curr_tok = torch.randint(vocab_size, (1,)).to(device)
print(curr_tok)
with torch.no_grad():
    for i in range(length):
        # Only the most recently sampled token is fed back, so each step
        # conditions on a single previous character.
        next_logits = model(curr_tok)
        next_prob = F.softmax(next_logits, dim=-1)
        # squeeze() drops singleton dimensions before sampling
        # (presumably (1, vocab) -> (vocab,); confirm against the model).
        next_tok = torch.multinomial(next_prob.squeeze(), num_samples=1)
        # Stream characters as they are sampled, without newlines between them.
        print(decode(next_tok.tolist()), end="")
        curr_tok = next_tok


tensor([49], device='cuda:0')
sto'TI sa.
Ye,

ofoyove itound t t ong'this Silineesp'd, rkstoura mourg nto'dove, thanouse thicerell y ORGry, y.
G st win e hond routhot'er litimavechind m, t, II th ditepousthe,
Fabr tt malls nd thafo!

Ty s qulantawik,
O:
Wherlade I I e wofo fer ss avilld Y:

AS:
ARDUSTh frad, teveretifigr b wenll llow'n heithte weas wigon thonk vidastigaverwerd quo o we, whessenawhiath oryin m nororeve:

ENII blliligut at RDomonon w lo what y'se an w.
YO:
G w,
Wee inatheathis anen into Wietuk heday cig he myondedeint, onewshe Exp sanesuthelvy lindit, th y-wid.
COMIUS to!
ROUM:
LOrauby bore tomed werererot had:
Ththe.
ARUENoee us pars,
INGoinoutar ond h martordanshehargllisboruie
TI,
S:
Rem t l t and borbe gre ibeam ashacke pren sor aft youe g m than y t wer the chal, s allely fe OMENUTrf t RDUKI:

d?
Fot singe, ct ntt agn,
MPUL:
offausw d aia top o tre.
I:



D r mult y w y.
Whe bavindof hinins aton oupothestlwalinje atthon y bashes tXFithet'lea y,
m.

A:
I ce, niuserto