In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Path of the quotes corpus; the whole file is read into memory as one string.
filepath = '/content/drive/MyDrive/ML Learning/edited_quotes.txt'
with open(filepath, mode='r', encoding='utf-8') as corpus_file:
  text = corpus_file.read()

In [None]:
# Sanity check: show the first 1000 characters of the corpus ...
print(text[:1000])
# ... and its total length in characters.
print(len(text))

I'm selfish, impatient and a little insecure. I make mistakes, I am out of control and at times hard to handle. But if you can't handle me at my worst, then you sure as hell don't deserve me at my best.
You've gotta dance like there's nobody watching,Love like you'll never be hurt,Sing like there's nobody listening,And live like it's heaven on earth.
You know you're in love when you can't fall asleep because reality is finally better than your dreams.
A friend is someone who knows all about you and still loves you.
Darkness cannot drive out darkness: only light can do that. Hate cannot drive out hate: only love can do that.
We accept the love we think we deserve.
Only once in your life, I truly believe, you find someone who can completely turn your world around. You tell them things that youve never shared with another soul and they absorb everything you say and actually want to hear more. You share hopes for the future, dreams that will never come true, goals that were never achieved 

In [None]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
vocab = sorted(set(text))
# The size of the vocab (number of distinct characters)
vocab_size = len(vocab)

print(vocab)
print('\n')
print(vocab_size)

['\n', ' ', '!', "'", '(', ')', ',', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


75


In [None]:
# Vocab mapping: lookup tables between characters and their integer ids,
# where a character's id is its position in the sorted `vocab` list.
test = 'Hello World!'
enco_map = {ch: idx for idx, ch in enumerate(vocab)}
deco_map = dict(enumerate(vocab))


def encode(e):
  """Return the list of integer ids for the characters of string `e`.

  Raises KeyError for any character that is not in `vocab`.
  """
  return [enco_map[x] for x in e]


def decode(d):
  """Return the string spelled by the iterable of integer ids `d`.

  Raises KeyError for any id that is not in `deco_map`.
  """
  return ''.join(deco_map[y] for y in d)


# Round-trip check: decode(encode(s)) should reproduce s.
print(encode('This is encoding'))
print(decode(encode('Testing the decoder, 1234567890QWERTYUIOPASDFGHJKLZXCVBNMqwertyuiopasdfghjklzxcvbnm')))

[40, 56, 57, 67, 1, 57, 67, 1, 53, 62, 51, 63, 52, 57, 62, 55]
Testing the decoder, 1234567890QWERTYUIOPASDFGHJKLZXCVBNMqwertyuiopasdfghjklzxcvbnm


In [None]:
# Encoding the dataset: the whole corpus becomes one 1-D tensor of int64
# character ids (one id per character of `text`).
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)
# Only show a prefix; the full tensor is far too large to print.
print(data[:200])

torch.Size([99202206])
tensor([29,  3, 61,  1, 67, 53, 60, 54, 57, 67, 56,  6,  1, 57, 61, 64, 49, 68,
        57, 53, 62, 68,  1, 49, 62, 52,  1, 49,  1, 60, 57, 68, 68, 60, 53,  1,
        57, 62, 67, 53, 51, 69, 66, 53,  7,  1, 29,  1, 61, 49, 59, 53,  1, 61,
        57, 67, 68, 49, 59, 53, 67,  6,  1, 29,  1, 49, 61,  1, 63, 69, 68,  1,
        63, 54,  1, 51, 63, 62, 68, 66, 63, 60,  1, 49, 62, 52,  1, 49, 68,  1,
        68, 57, 61, 53, 67,  1, 56, 49, 66, 52,  1, 68, 63,  1, 56, 49, 62, 52,
        60, 53,  7,  1, 22, 69, 68,  1, 57, 54,  1, 73, 63, 69,  1, 51, 49, 62,
         3, 68,  1, 56, 49, 62, 52, 60, 53,  1, 61, 53,  1, 49, 68,  1, 61, 73,
         1, 71, 63, 66, 67, 68,  6,  1, 68, 56, 53, 62,  1, 73, 63, 69,  1, 67,
        69, 66, 53,  1, 49, 67,  1, 56, 53, 60, 60,  1, 52, 63, 62,  3, 68,  1,
        52, 53, 67, 53, 66, 70, 53,  1, 61, 53,  1, 49, 68,  1, 61, 73,  1, 50,
        53, 67])


In [None]:
# Splitting into train and dev: the first 90% of the tokens are used for
# training and the remaining tail is held out as the dev set.
n = int(0.9 * data.shape[0])
train_data, dev_data = data[:n], data[n:]

In [None]:
# Length (in tokens) of the context windows used below.
block_size = 8
# Take block_size + 1 tokens: the extra token is needed so that the last
# position of the window still has a "next token" to act as its target.
train_data[:block_size + 1]

tensor([29,  3, 61,  1, 67, 53, 60, 54, 57])

In [None]:
# One window of inputs and the same window shifted right by one token,
# so that y[t] is the token that follows x[t].
x = train_data[:block_size]
y = train_data[1:block_size+1]
# A single window contains block_size training examples: the context grows
# one token at a time and the target is always the token that comes next.
for t in range(block_size):
  context, target = x[:t+1], y[t]
  print(f'when the context is {context} the target is {target}')

when the context is tensor([29]) the target is 3
when the context is tensor([29,  3]) the target is 61
when the context is tensor([29,  3, 61]) the target is 1
when the context is tensor([29,  3, 61,  1]) the target is 67
when the context is tensor([29,  3, 61,  1, 67]) the target is 53
when the context is tensor([29,  3, 61,  1, 67, 53]) the target is 60
when the context is tensor([29,  3, 61,  1, 67, 53, 60]) the target is 54
when the context is tensor([29,  3, 61,  1, 67, 53, 60, 54]) the target is 57


In [None]:
torch.manual_seed(1337)
batch_size = 4  # number of windows sampled per batch
block_size = 8  # number of tokens in each window

def get_batch(split):
  """Sample a random batch of (input, target) windows from one data split.

  `split == 'train'` samples from `train_data`; any other value samples from
  `dev_data`. Reads the module-level `batch_size` and `block_size` at call
  time.

  Returns a tuple `(x, y)` of tensors, each of shape (batch_size, block_size),
  where row r of `y` is the window that starts one token after row r of `x`.
  """
  source = train_data if split == 'train' else dev_data
  # randint's upper bound is exclusive, so the largest start still leaves
  # room for the target window, which ends at start + block_size + 1.
  starts = torch.randint(len(source) - block_size, (batch_size, ))
  inputs = torch.stack([source[s:s+block_size] for s in starts])
  targets = torch.stack([source[s+1:s+block_size+1] for s in starts])
  return inputs, targets

# Draw one training batch and show its input and target tensors.
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

# Spell out every training example contained in the batch: for each row b and
# each time step t, the context is row b of the inputs up to and including
# position t, and the target is row b of the targets at position t.
# (Index the batch tensors xb / yb here -- not the standalone x / y window
# from the earlier cell, which would print the same first window for every b.)
for b in range(batch_size):
  for t in range(block_size):
    context = xb[b, :t+1]
    target = yb[b, t]
    print(f'when the context is {context} the target is {target}')

inputs:
torch.Size([4, 8])
tensor([[53,  6,  1, 57, 68,  1, 67, 53],
        [49, 60, 60, 73,  1, 55, 66, 53],
        [56, 49, 70, 53,  1, 51, 63, 61],
        [ 1, 63, 50, 58, 53, 51, 68, 67]])
targets:
torch.Size([4, 8])
tensor([[ 6,  1, 57, 68,  1, 67, 53, 53],
        [60, 60, 73,  1, 55, 66, 53, 49],
        [49, 70, 53,  1, 51, 63, 61, 53],
        [63, 50, 58, 53, 51, 68, 67,  1]])
----
when the context is tensor([29]) the target is 3
when the context is tensor([29,  3]) the target is 61
when the context is tensor([29,  3, 61]) the target is 1
when the context is tensor([29,  3, 61,  1]) the target is 67
when the context is tensor([29,  3, 61,  1, 67]) the target is 53
when the context is tensor([29,  3, 61,  1, 67, 53]) the target is 60
when the context is tensor([29,  3, 61,  1, 67, 53, 60]) the target is 54
when the context is tensor([29,  3, 61,  1, 67, 53, 60, 54]) the target is 57
when the context is tensor([29]) the target is 3
when the context is tensor([29,  3]) the ta

In [None]:
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
  """Bigram language model: next-token logits depend only on the current token.

  The only parameters are a (vocab_size, vocab_size) embedding table; row i
  holds the logits over the next token given that the current token id is i.
  """

  def __init__(self, vocab_size):
    super().__init__()
    # looking up a token id directly yields the logits for the token after it
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    """Compute next-token logits and, when targets are given, the loss.

    idx: integer token ids, shape (B, T).
    targets: optional integer token ids, shape (B, T).

    Returns (logits, loss). Without targets, logits has shape (B, T, C) and
    loss is None. With targets, logits is flattened to (B*T, C) and loss is
    the cross-entropy between those logits and the flattened targets.
    """
    logits = self.token_embedding_table(idx)
    if targets is None:
      return logits, None

    # cross_entropy wants (N, C) logits against (N,) class ids, so the batch
    # and time dimensions are merged into one
    batch, time, channels = logits.shape
    flat_logits = logits.view(batch * time, channels)
    flat_targets = targets.view(batch * time)
    return flat_logits, F.cross_entropy(flat_logits, flat_targets)

  def generate(self, idx, max_new_tokens):
    """Extend each sequence in `idx` by `max_new_tokens` sampled tokens.

    idx: integer token ids, shape (B, T). Returns a tensor of shape
    (B, T + max_new_tokens) whose first T columns are the original `idx`.
    """
    for _ in range(max_new_tokens):
      step_logits, _loss = self(idx)
      # only the last time step is used to predict the next token
      last_logits = step_logits[:, -1, :]
      probs = F.softmax(last_logits, dim=-1)
      # sample one token id per sequence and append it to the running context
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx

# Build the model and run one forward pass on the sample batch (xb, yb).
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
# Because targets were passed, logits come back flattened to (B*T, vocab_size).
print(logits.shape)
# Loss of the untrained model on this one batch.
print(loss)

torch.Size([32, 75])
tensor(4.9311, grad_fn=<NllLossBackward0>)


In [None]:
# Sample 400 new tokens from the model, starting from a single sequence that
# holds only token id 0, then decode the ids back to text.
start_idx = torch.zeros((1, 1), dtype=torch.long)
generated = m.generate(idx=start_idx, max_new_tokens=400)
print(decode(generated[0].tolist()))


JlHwJ8kKZ1'GRak8iXaaO59n906gg8Eq ZVBImr9L..zZ
fzYfowYhE[O;WJ9VIc0VzQnF]X49w7C8'N8g
QLJ]Xe1A,
hNHjFHjdNng9cKbdlWOLEZdk?nHPTAzn5u?Kdl6CnDAniz6INAJx1JDFhToSVhg2]FlAL1bCYNMn k!5y(;X'Kn?zSaKTC.
fWhGbZyvWKjkp)nq4Dna2VAwTv]WJO;9Run92RDvvtHEbycEPJlL.Q30Q[;6iz?dX2?nC v6Y(eN'NVLdlsPmUWofw;c'JDPfH,Hj.8r0IJ6zTF6fPy!J8,.a
7
Kxb?1PK8rL?p!e6Yf'7h!]p09M(Is.CL!ITq;hG(w7vv]nex1u?:QrN'
v];XQ(k!S Kw)0,dQ?VnN:gA;h3e8S


In [None]:
# Optimiser: AdamW over all parameters of the model `m`, learning rate 1e-3
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [None]:
# get_batch reads batch_size at call time, so this enlarges every batch below.
batch_size = 32
for steps in range(10000):
  # sample a fresh mini-batch of training windows
  xb, yb = get_batch('train')

  # clear old gradients, then forward pass, backward pass and parameter update
  optimizer.zero_grad(set_to_none=True)
  logits, loss = m(xb, yb)
  loss.backward()
  optimizer.step()

# Loss of the final mini-batch only (a single noisy sample, not an average).
print(loss.item())

2.3351197242736816


In [None]:
# Sample 400 new tokens from the model again, starting from a single sequence
# that holds only token id 0, then decode the ids back to text.
start_idx = torch.zeros((1, 1), dtype=torch.long)
generated = m.generate(idx=start_idx, max_new_tokens=400)
print(decode(generated[0].tolist()))


A tis nd aicur I whu d athid dous That treneay ticinth?Ne s))OI k'thaimingheveouthW:9006)verHeeiles I ougiroll Mntou mmasp brovicouruty henthericr kngicanioured wourerre Seicingag ORst. wit won l ing bo t be ou!
Anoutor, fDVI thevestor fove he(Oas ick, t bape alithe coreve tu d.Ace a grellde t WhadoverThe yond tikintrs atrove. isoLWed d s ou, bustoridaveat o Thing cyong benderoffo tal f y HXQg ahe


Self attention maths trick
Get all the previous letters to talk to the current prediction by averaging the current letter together with all the letters before it.

In [None]:
# Small random toy input for the averaging experiments below.
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [None]:
# Inefficient way: explicit Python loops over every batch row and time step.
# xbow[b, t] is the mean of x[b, 0..t] (current position included).
xbow = torch.zeros((B,T,C))
for b in range(B):
  for t in range(T):
    xbow[b, t] = x[b, :t+1].mean(dim=0)

In [None]:
# Version 2: the same running average expressed as one matrix multiply.
# tril keeps only the lower triangle of a matrix, so row t has ones in
# columns 0..t and zeros afterwards.
lower_ones = torch.ones(T, T).tril()
# normalise each row to sum to 1 -> row t averages positions 0..t
wei = lower_ones / lower_ones.sum(dim=1, keepdim=True)
# (T, T) @ (B, T, C): the weight matrix is broadcast over the batch dimension
xbow2 = torch.matmul(wei, x)
torch.allclose(xbow, xbow2)

True

In [None]:
# Version 3: use softmax
# tril keeps only the lower triangle of a matrix
tril = torch.ones(T, T).tril()
# Start from all-zero scores and set every position above the diagonal to
# -inf; softmax then gives those positions weight 0 and spreads the weight
# evenly over the rest of each row.
wei = torch.zeros((T,T)).masked_fill(tril == 0, float('-inf'))
wei = wei.softmax(dim=-1)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

In [None]:
# Version 4: self attention!
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# Single head doing self attention.
# The three projections are created in the order value, key, query so that
# they draw their initial weights from the seeded RNG in a fixed order.
head_size = 16
value = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k, q = key(x), query(x)  # each (B, T, head_size)
# score of every query position against every key position: (B, T, T)
wei = torch.matmul(q, k.transpose(-2, -1))

# tril keeps only the lower triangle of a matrix: positions above the diagonal
# are masked to -inf so that softmax assigns them zero weight.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)

# weighted sum of the value vectors: (B, T, T) @ (B, T, head_size)
v = value(x)
out = torch.matmul(wei, v)
