In [13]:
# download tiny shakespear dataset
import urllib.request
import os

url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
filename = 'tinyshakespeare.txt'
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)
    print('Downloaded %s' % filename)

# pull text from file
with open(filename, 'r') as f:
    text = f.read()

print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [2]:
# get vocab
vocab = list(sorted(set(text)))
vocab_size = len(vocab)
print(f'length of vocab: {len(vocab)}')
print(f"entire vocabulary: {''.join(vocab)}")

length of vocab: 65
entire vocabulary: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [33]:
# character level encoding and decoding
stoi = {c: i for i, c in enumerate(vocab)}
# itos = {i: c for i, c in enumerate(vocab)}
itos = {i: c for c, i in stoi.items()}
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])
print(encode('hello there !'))
print(decode(encode('yo whats up')))

[46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43, 1, 2]
yo whats up


In [34]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data[:10])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [5]:
# train test split
train_size = int(0.85 * len(data))
train_data = data[:train_size]
test_data = data[train_size:]
print(f'train size: {len(train_data)}, test size: {len(test_data)}')

train size: 948084, test size: 167310


In [38]:
print(decode(train_data.tolist()[:100]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [6]:
block_size = 8
train_dataset = data[:block_size + 1]
print(train_dataset)

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])


In [7]:
x = train_dataset[:block_size]
y = train_dataset[1:block_size+1]
print(x)
print(y)
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'context: {context}, target: {target}')


tensor([18, 47, 56, 57, 58,  1, 15, 47])
tensor([47, 56, 57, 58,  1, 15, 47, 58])
context: tensor([18]), target: 47
context: tensor([18, 47]), target: 56
context: tensor([18, 47, 56]), target: 57
context: tensor([18, 47, 56, 57]), target: 58
context: tensor([18, 47, 56, 57, 58]), target: 1
context: tensor([18, 47, 56, 57, 58,  1]), target: 15
context: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
context: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [8]:
torch.manual_seed(1337)
batch_size = 4 # how many sequences we will process in parallel, each of these sequences is block_size long
block_size = 8 # the length of each sequence

def get_batch(split):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    print(ix)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

x, y = get_batch('train')
print(x)
print(y)
for xb in range(batch_size):
    for t in range(block_size):
        context = x[xb, :t+1]
        target = y[xb, t]
        print(f'context: {context}, target: {target}')


tensor([ 21203, 291993, 823924,  59812])
tensor([[53, 61, 57, 10,  0, 20, 43,  1],
        [39, 41, 43, 42,  1, 58, 46, 43],
        [52, 41, 43,  8,  0,  0, 24, 17],
        [26, 33, 31, 10,  0, 25, 53, 57]])
tensor([[61, 57, 10,  0, 20, 43,  1, 58],
        [41, 43, 42,  1, 58, 46, 43,  1],
        [41, 43,  8,  0,  0, 24, 17, 27],
        [33, 31, 10,  0, 25, 53, 57, 58]])
context: tensor([53]), target: 61
context: tensor([53, 61]), target: 57
context: tensor([53, 61, 57]), target: 10
context: tensor([53, 61, 57, 10]), target: 0
context: tensor([53, 61, 57, 10,  0]), target: 20
context: tensor([53, 61, 57, 10,  0, 20]), target: 43
context: tensor([53, 61, 57, 10,  0, 20, 43]), target: 1
context: tensor([53, 61, 57, 10,  0, 20, 43,  1]), target: 58
context: tensor([39]), target: 41
context: tensor([39, 41]), target: 43
context: tensor([39, 41, 43]), target: 42
context: tensor([39, 41, 43, 42]), target: 1
context: tensor([39, 41, 43, 42,  1]), target: 58
context: tensor([39, 41, 43, 4

In [9]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token in the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both of shape (batch_size, block_size) aka (B, T)
        logits = self.token_embedding_table(idx) # Batch x time x channel
        if targets is None:
            loss = None
        else:
            # loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1)) # we could do this, but its hard to understand, so
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) 

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is BxT
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :] # keep only the last token for each sequence in the batch aka BxC
            probs = F.softmax(logits, dim=-1) # BxC
            next_tokens = torch.multinomial(probs, num_samples=1) # Bx1
            idx = torch.cat([idx, next_tokens], dim=1) # Bx(T+1)
        return idx

m = BigramLanguageModel(vocab_size)
x, y = get_batch('train')
logits, loss = m(x, y)
print(logits.shape)
print(loss)

tensor([771964, 788097, 277203, 183544])
torch.Size([32, 65])
tensor(4.7706, grad_fn=<NllLossBackward0>)


In [10]:
# this loss is horrible, because we are not training the model, we are just randomly initializing it
# -ln(1/vocab_size) = -ln(1/65) = 4.17 so this is a good sanity check to see just how bad our model is, cuz its just guessing randomly

In [11]:
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), 100)[0].tolist()))
# it'll be complete garbage, cuz we haven't trained it yet


pxMHoRFJa!JKmRjtXzfN:CERiC-KuDHoiMIB!o3QHN
,SPyiFhRKuxZOMsB-ZJhsucL:wfzLSPyZalylgQUEU cLq,SqV&vW:hhi
