In [6]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [7]:
# Read the whole corpus into memory. The encoding is passed explicitly so the
# decoded text does not depend on the platform's default locale encoding.
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [8]:
# Corpus size in characters (each character becomes one token further down).
print(f'Number of characters: {len(text)}')

Number of characters: 1115394


In [9]:
# Eyeball the first 1000 characters of the raw corpus.
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [10]:
# Character-level vocabulary: every distinct character that occurs in the
# corpus (the newline character '\n' included), in sorted order.
chars = sorted(set(text))
# One token per distinct character. Vocabulary size trades off against the
# length of the encoded sequences.
vocab_size = len(chars)
print('Vocabulary size:', vocab_size)
print(''.join(chars))

Vocabulary size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


**Tokenising**

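The tokeniser used below works at the character level: every distinct character in the corpus is one token. Other tokenisers, which split text into sub-word units instead, include [SentencePiece](https://github.com/google/sentencepiece) and [tiktoken](https://github.com/openai/tiktoken) (used by OpenAI models such as ChatGPT).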

In [11]:
# Lookup tables between characters and integer token ids; a character's id is
# its position in the sorted `chars` list.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(string):
    """Return the list of token ids for `string`, one id per character."""
    return [stoi[c] for c in string]


def decode(encoded):
    """Return the string spelled by the token ids in `encoded`."""
    return ''.join(itos[i] for i in encoded)


# Round-trip sanity check: decoding an encoding gives back the input string.
print('string: hello world!')
print('encoded:', encode('hello world!'))
print('decoded:', decode(encode('hello world!')))

string: hello world!
encoded: [46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
decoded: hello world!


In [12]:
# Encode the entire corpus once and keep it as one long 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print('Shape of training/dev/test data:', data.shape)
# First 100 token ids, for comparison with the text preview above them.
print(data[:100])

Shape of training/dev/test data: torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [13]:
# The first 80% of the token stream is used for training; the remaining tail
# is held out as validation data.
n = int(0.8 * len(data))
training_data, val_data = data[:n], data[n:]

In [14]:
# Number of tokens of history in one input window.
context_len = 8
# A window of context_len + 1 consecutive tokens holds context_len inputs plus
# the one extra token needed as the target of the last input position.
data[:context_len+1] # first training data

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [15]:
# Within a single window, every prefix of the inputs is a training example
# whose target is the token that comes right after that prefix.
X = data[:context_len]
y = data[1:context_len+1]  # the same window shifted one token to the right
for i, target in enumerate(y.tolist()):
    input_ = X[:i+1]
    print(f'input: {input_}, target: {target}')

input: tensor([18]), target: 47
input: tensor([18, 47]), target: 56
input: tensor([18, 47, 56]), target: 57
input: tensor([18, 47, 56, 57]), target: 58
input: tensor([18, 47, 56, 57, 58]), target: 1
input: tensor([18, 47, 56, 57, 58,  1]), target: 15
input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [16]:
# Fix the torch RNG so the sampled batches below are repeatable.
torch.manual_seed(1337)

# Module-level hyperparameters; get_batch reads both of them at call time.
batch_size = 4   # number of independent windows per batch
context_len = 8  # number of tokens in each window

def get_batch(split):
    """Sample a random mini-batch of input/target windows.

    Args:
        split: 'train' draws from `training_data`; any other value draws
            from `val_data`.

    Returns:
        A pair `(X, y)` of stacked windows, one row per sampled start offset
        (`batch_size` rows of `context_len` tokens each). Row `r` of `y` is
        the same window as row `r` of `X`, shifted one token to the right.

    Reads the module-level `batch_size` and `context_len` when called.
    """
    source = val_data if split != 'train' else training_data
    # The exclusive upper bound keeps the target window, which reaches one
    # token further than the input window, inside the tensor.
    starts = torch.randint(len(source) - context_len, (batch_size,))
    X = torch.stack([source[s:s + context_len] for s in starts])
    y = torch.stack([source[s + 1:s + context_len + 1] for s in starts])
    return X, y

Xb, yb = get_batch('train')

# Show shape and contents of the inputs, then of the matching targets.
for label, batch in (('inputs:', Xb), ('targets:', yb)):
    print(label)
    print(batch.shape)
    print(batch)

inputs:
torch.Size([4, 8])
tensor([[58, 63,  8,  0,  0, 19, 24, 27],
        [39, 59, 45, 46, 58,  1, 46, 43],
        [49, 43, 57,  1, 53, 50, 42,  1],
        [52, 41, 47, 43, 52, 58,  1, 56]])
targets:
torch.Size([4, 8])
tensor([[63,  8,  0,  0, 19, 24, 27, 33],
        [59, 45, 46, 58,  1, 46, 43,  1],
        [43, 57,  1, 53, 50, 42,  1, 46],
        [41, 47, 43, 52, 58,  1, 56, 47]])


**Batch and time dimensions**

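If each input has a context length of $T$ (so named because it is the 'time' dimension: the amount of history the model can refer to for one input) and a batch holds $B$ such inputs, then one batch actually contains $B \times T$ input-target pairs, because every prefix of every row is a training example with its own target.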

In [17]:
# Enumerate every (context, target) pair contained in the batch: one pair per
# row (batch dimension) and per prefix length (time dimension).
for b, (row_inputs, row_targets) in enumerate(zip(Xb, yb)):  # batch dimension
    for t, target in enumerate(row_targets):  # time dimension
        context = row_inputs[:t+1]
        print(f'context: {context}, target: {target}')

context: tensor([58]), target: 63
context: tensor([58, 63]), target: 8
context: tensor([58, 63,  8]), target: 0
context: tensor([58, 63,  8,  0]), target: 0
context: tensor([58, 63,  8,  0,  0]), target: 19
context: tensor([58, 63,  8,  0,  0, 19]), target: 24
context: tensor([58, 63,  8,  0,  0, 19, 24]), target: 27
context: tensor([58, 63,  8,  0,  0, 19, 24, 27]), target: 33
context: tensor([39]), target: 59
context: tensor([39, 59]), target: 45
context: tensor([39, 59, 45]), target: 46
context: tensor([39, 59, 45, 46]), target: 58
context: tensor([39, 59, 45, 46, 58]), target: 1
context: tensor([39, 59, 45, 46, 58,  1]), target: 46
context: tensor([39, 59, 45, 46, 58,  1, 46]), target: 43
context: tensor([39, 59, 45, 46, 58,  1, 46, 43]), target: 1
context: tensor([49]), target: 43
context: tensor([49, 43]), target: 57
context: tensor([49, 43, 57]), target: 1
context: tensor([49, 43, 57,  1]), target: 53
context: tensor([49, 43, 57,  1, 53]), target: 50
context: tensor([49, 43, 57,

In [18]:
class BigramLanguageModel(nn.Module):
    """Bigram model: the next-token scores depend only on the current token.

    The only parameter is a (vocab_size, vocab_size) embedding table; row `i`
    holds the unnormalised scores (logits) for the token that follows token `i`.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, ix, targets=None):
        """Look up next-token logits and, when targets are given, the loss.

        Args:
            ix: integer tensor of token ids with shape (B, T).
            targets: optional integer tensor of token ids with shape (B, T).

        Returns:
            `(logits, loss)`. Without targets, `logits` has shape
            (B, T, vocab_size) and `loss` is None. With targets, `logits` is
            returned flattened to (B*T, vocab_size) and `loss` is the scalar
            cross-entropy over all B*T positions.
        """
        # The embedding layer simply indexes its table with ix: (B, T, vocab_size).
        logits = self.token_embedding_table(ix)

        # Identity check: `== None` on a tensor goes through Tensor.__eq__
        # instead of asking whether an argument was supplied.
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # F.cross_entropy takes (N, C) scores against (N,) class ids, so the
        # batch and time axes are folded into one.
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous targets are accepted as well.
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building a graph per step
    def generate(self, ix, max_tokens):
        """Extend each row of `ix` by `max_tokens` sampled tokens.

        Args:
            ix: integer tensor of token ids with shape (B, T).
            max_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_tokens): the input followed by
            the sampled continuation.
        """
        for _ in range(max_tokens):
            logits, _loss = self(ix)
            # Only the last time step is used to predict the next token.
            last_logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(last_logits, dim=-1)
            ix_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            ix = torch.cat((ix, ix_next), dim=1)
        return ix

# Untrained model: run one batch through it and look at the logits shape and
# the initial loss.
m = BigramLanguageModel(vocab_size)
out, loss = m(Xb, yb)
print(out.shape)
print(loss)

# Sample 500 tokens from the untrained model, starting from the single token
# id 0 (the newline character in this vocabulary).
ix = torch.zeros(1, 1, dtype=torch.long)
encoded = m.generate(ix, max_tokens=500)
print(encoded.shape)
print(decode(encoded[0].tolist()))

torch.Size([32, 65])
tensor(4.6627, grad_fn=<NllLossBackward0>)
torch.Size([1, 501])

l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq!LznIeJydZJSrFSrPLR!:VwWSmFNxbjPiNYQ:sry,OfKrxfvJI$WS3JqCbB-TSQXeKroeZfPL&,:opkl;Bvtz$LmOMyDjxxaZWtpv,OxZQsWZalk'uxajqgoSXAWt'e.Q$.lE-aV
;spkRHcpkdot:u'-NGEzkMPy'hZCWhv.w.q!f'mOxF&IDRR,x
?$Ox?xj.BHJsGhwVtcuyoMIRfhoPL&fg-NwJmOQalcEDveP$IYUMv&JMHkzd:O;yXCV?wy.RRyMys-fg;kHOB EacboP g;txxfPL
NTMlX'FNYcpkHSGHNuoKXe..ehnsarggGFrSjIr!SXJ?KeMl!.?,MlbDP!sfyfBPeNqwjLtIxiwDDjSJzydFm$CfhqkCe,n:kyRBubVbxdojhEz


In [19]:
# AdamW over all model parameters with a learning rate of 1e-3.
optimiser = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [178]:
# get_batch reads the global batch_size, so this reassignment changes the size
# of every batch sampled from here on.
batch_size = 32
epochs = 100

# Each "epoch" here is one optimiser step on one freshly sampled mini-batch,
# not a full pass over the training data.
for e in range(epochs):
    Xb, yb = get_batch('train')
    optimiser.zero_grad(set_to_none=True)  # clear gradients left from the previous step
    logits, loss = m(Xb, yb)
    loss.backward()
    optimiser.step()

# Loss of the final mini-batch only.
print(loss.item())

2.3259451389312744


In [182]:
# Sample 500 tokens from the trained model, again starting from token id 0.
ix = torch.zeros(1, 1, dtype=torch.long)
encoded = m.generate(ix, max_tokens=500)
print(decode(encoded.squeeze(0).tolist()))



Tu agomeispto wesaif ot, bones ' moa
LIsesfoltront sts
hint! mprd sthivit nd
otor LA:
HRID:
N OPAprl atauravelss.
I re, o tin pot,
IDUSp al t alet hife harmy a gat t er atl ss
ARI mastourardwss,
I:
thastinstcacet,
we tyeramo'Thawhe arte sco clore ors t'sth;
Dordmmarerliry inires l keray fonds me ourotea t itan K:
Tharye a youne ceay ouatheat hechact dem? yofopy d, aveit napas y d paven:


IIOLinear t me me buthil de ompid I ureng thr dghor! fr'semeng.

trd,OXFounde; thend.
ABOMA thifame,
WARYou


In [184]:
# Toy tensors: batch of 4, 8 time steps, 26 channels per token.
B, T, C = 4, 8, 26
tok_emb = torch.randn(B, T, C)

# Key and query projections for one head. k and q are computed here but do not
# enter the weights W below.
head_size = 16
keys = nn.Linear(C, head_size, bias=False)
queries = nn.Linear(C, head_size, bias=False)
k, q = keys(tok_emb), queries(tok_emb)

# Lower-triangular mask: position t may only look at positions <= t.
tril = torch.tril(torch.ones(T, T))
# All-zero scores with the masked entries set to -inf; the row-wise softmax
# then gives a uniform average over the visible positions.
W = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
W = F.softmax(W, dim=-1)
out = W @ tok_emb  # (T, T) @ (B, T, C) -> (B, T, C)
out.shape

torch.Size([4, 8, 26])