In [43]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-04-12 03:17:10--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: 'input.txt.5'

     0K .......... .......... .......... .......... ..........  4% 4.78M 0s
    50K .......... .......... .......... .......... ..........  9% 10.3M 0s
   100K .......... .......... .......... .......... .......... 13% 6.35M 0s
   150K .......... .......... .......... .......... .......... 18% 12.4M 0s
   200K .......... .......... .......... .......... .......... 22% 9.68M 0s
   250K .......... .......... .......... .......... .......... 27% 29.4M 0s
   300K .......... .......... .......... .......... .......... 32% 61.1M 0s
   350K .......... 

In [44]:
# Read the whole corpus into memory as one string.
with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

print('text length:', len(text))
# Peek at the opening 100 characters to confirm the file looks right.
print(text[:100])

text length: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [45]:
# The vocabulary is every distinct character in the corpus, in sorted order.
unique_chars = set(text)
chars = sorted(unique_chars)
vocab_size = len(chars)
print('vocab:', "".join(chars))
print('vocab size:', vocab_size)

vocab: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65


In [46]:
# Character <-> integer lookup tables; a character's id is its position in `chars`.
stoi = {ch: i for i, ch in enumerate(chars)}  # map from character to integer
itos = dict(enumerate(chars))  # map from integer to character


def encode(s):
    """Encode a string into a list of integer token ids.

    Raises KeyError if `s` contains a character that is not a key of `stoi`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids back into a string.

    Raises KeyError if an id is not a key of `itos`.
    """
    return ''.join(itos[i] for i in l)


# sanity check: encoding and then decoding must reproduce the original text
sample = text[:10]
sample_ids = encode(sample)
print('encoded:', sample, '->', sample_ids)
print('decoded:', sample_ids, '->', decode(sample_ids))

encoded: First Citi -> [18, 47, 56, 57, 58, 1, 15, 47, 58, 47]
decoded: [18, 47, 56, 57, 58, 1, 15, 47, 58, 47] -> First Citi


In [47]:
import torch
# Encode the full corpus once and keep it as a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print("shape:", data.size(), "dtype:", data.dtype)
# The first ten entries are the same ids that encode(text[:10]) produces.
print(data[:10])

shape: torch.Size([1115394]) dtype: torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [48]:
# Split the data into training and validation sets: the first 90% of the
# tokens are training data, the remainder is held out for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [49]:
# Number of tokens in one input chunk.
block_size = 8
# Show the first block_size + 1 training tokens (displayed as the cell result).
train_data[0:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [50]:
# Take the first block_size + 1 = 9 characters of the training set: they hold
# block_size = 8 (context, target) examples, which the loop below prints.
#
# Training on all 8 examples is not only computationally convenient and
# efficient; it also exposes the model to contexts of different lengths.

x = train_data[:block_size]
y = train_data[1:block_size + 1]
for end in range(1, block_size + 1):
    # the context is the first `end` tokens; the target is the token right after it
    context, target = x[:end], y[end - 1]
    print(f"when input is {context}, target is {target}")

when input is tensor([18]), target is 47
when input is tensor([18, 47]), target is 56
when input is tensor([18, 47, 56]), target is 57
when input is tensor([18, 47, 56, 57]), target is 58
when input is tensor([18, 47, 56, 57, 58]), target is 1
when input is tensor([18, 47, 56, 57, 58,  1]), target is 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]), target is 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), target is 58


In [51]:
# how many independent sequences will we process in parallel
batch_size = 4
# fix the RNG seed so the randomly drawn batches are reproducible
torch.manual_seed(1337)

def get_batch(split):
    """Sample a small random batch of input sequences and their targets.

    `split` selects the corpus: 'train' reads from `train_data`, any other
    value reads from `val_data`.

    Returns a pair `(x, y)` of (batch_size, block_size) tensors, one row per
    example. `y` is `x` shifted by one position in the corpus, so `y[b, t]`
    is the token that follows `x[b, t]`.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per example. randint's upper bound is exclusive,
    # so the shifted target window always stays inside `source`.
    ix = torch.randint(len(source) - block_size, (batch_size,))
    # Broadcast every start offset against 0..block_size-1 to index all
    # windows at once instead of stacking per-example slices.
    window = ix.unsqueeze(1) + torch.arange(block_size)
    x = source[window]      # inputs
    y = source[window + 1]  # targets: the same windows, one token later
    return x, y

# Draw one training batch and show the raw input and target tensors.
xb, yb = get_batch('train')
print('inputs:', xb.shape, xb)
print('targets:', yb.shape, yb)

print('----')

# Outer loop walks the batch dimension (one row per independent sequence);
# inner loop walks the time dimension, pairing each prefix with its target.
for inputs_row, targets_row in zip(xb, yb):
    for t in range(block_size):
        context = inputs_row[:t+1]
        target = targets_row[t]
        print(f"when input is {context.tolist()}, target is {target}")

# with batch_size = 4 and block_size = 8, that's 4*8 = 32 training examples in a single batch!

inputs: torch.Size([4, 8]) tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets: torch.Size([4, 8]) tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24], target is 43
when input is [24, 43], target is 58
when input is [24, 43, 58], target is 5
when input is [24, 43, 58, 5], target is 57
when input is [24, 43, 58, 5, 57], target is 1
when input is [24, 43, 58, 5, 57, 1], target is 46
when input is [24, 43, 58, 5, 57, 1, 46], target is 43
when input is [24, 43, 58, 5, 57, 1, 46, 43], target is 39
when input is [44], target is 53
when input is [44, 53], target is 56
when input is [44, 53, 56], target is 1
when input is [44, 53, 56, 1], target is 58
when input is [44, 53, 56, 1, 58], target is 46
when input is [44, 53, 56, 1, 58, 

In [52]:
# xb is the batch of input token ids returned by get_batch; a tensor like this
# is what will be fed to the model.
print(xb) # example input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


In [59]:
# start basic with a bigram model
import torch.nn as nn
from torch.nn import functional as F
# Re-seed the global RNG so that re-running this cell gives the same random
# numbers for everything that follows in it.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram language model: predicts the next token from the current token only.

    The whole model is one (vocab_size, vocab_size) lookup table. Row `i`
    holds the logits (unnormalised scores) over the next token given that the
    current token is `i`.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        # nn.Embedding is a thin wrapper around a (vocab_size, vocab_size) tensor
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of integer token ids (batch_size, block_size).
            targets: optional (B, T) tensor of integer token ids, the token
                that should follow each position of `idx`.

        Returns:
            A `(logits, loss)` pair. Without `targets`, `logits` has shape
            (B, T, C) with C = vocab_size and `loss` is None. With `targets`,
            `logits` is returned flattened to (B*T, C) and `loss` is the
            cross-entropy over all B*T positions.
        """
        # every integer in `idx` plucks out one row of the embedding table
        # logits[b, t, c] is the unnormalised score (not yet a probability) that
        # token c comes next after token idx[b, t]
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            return logits, None

        # cross entropy is the standard loss function for classification.
        # For multidimensional input PyTorch wants the channels as the second
        # dimension; flattening batch and time together avoids that.
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; don't build an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Extend every sequence in `idx` by `max_new_tokens` sampled tokens.

        Args:
            idx: (B, T) tensor of integer token ids, the current context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            (B, T + max_new_tokens) tensor: `idx` followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # self(idx) calls forward; no targets are passed, so the loss is None
            logits, _unused_loss = self(idx)
            # keep only the predictions at the last time step
            last_logits = logits[:, -1, :]  # (B, C)
            # softmax turns the logits into a probability distribution
            probs = F.softmax(last_logits, dim=-1)  # (B, C)
            # draw one next token per sequence in the batch
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled token to the running sequence
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T+1)
        return idx

# Instantiate the untrained model and run one forward pass on the batch.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)  # targets were passed, so logits come back flattened to (B*T, C)

print(logits.shape)
# we expect -ln(1/65) loss
print(loss)

# Sample 100 tokens, starting from a (1,1) context holding the single token 0
# (0 is the newline character).
idx = torch.zeros(1, 1, dtype=torch.long)
generated = m.generate(idx, max_new_tokens=100)
print(decode(generated[0].tolist()))


torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [61]:
# Create a PyTorch optimizer over the model's parameters.
# AdamW is an advanced and popular optimizer; SGD (stochastic gradient
# descent) is a simpler optimizer that works well too.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=0.001)

In [65]:
# get_batch reads the module-level batch_size, so this raises it for training.
batch_size = 32
for step in range(10000):
    # sample a fresh training batch and evaluate the loss on it
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # clear the old gradients, backpropagate, then update the parameters
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # report the loss of the current batch every 1000 steps
    if not step % 1000:
        print(f'step {step}, loss {loss.item()}')

step 0, loss 3.1474289894104004
step 1000, loss 2.731760263442993
step 2000, loss 2.6109120845794678
step 3000, loss 2.587460994720459
step 4000, loss 2.5016682147979736
step 5000, loss 2.455209493637085
step 6000, loss 2.3881759643554688
step 7000, loss 2.4599270820617676
step 8000, loss 2.4092462062835693
step 9000, loss 2.3968563079833984


In [67]:
# Sample 300 new tokens from the model, starting from a (1,1) context holding
# the single token 0 (0 is the newline character).
idx = torch.zeros(1, 1, dtype=torch.long)
sampled = m.generate(idx, max_new_tokens=300)[0]
print(decode(sampled.tolist()))


An, yhe' m ane! :
Hobsad, s IAle h mexK:
We, makisoung, hall-hithin p wate st--
TE:
Ifur met th hire onsteiahe cour, RDilothay Mube t VOLUKERUS:

ABEXExe s s hpr ug y'd it trr,
I lay:
EOMy athaveanghur amex.
Whes:
F: I

Buru.
Serth'Whth:
AUESI isin. thed thmyo as ELI g.
CHUSisthethellend:
GAnd fo,
C
