In [57]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import tqdm
from tqdm.auto import trange

In [64]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
# NOTE(review): none of the cells shown here move the model or batches to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Read the whole corpus into one string.
# The encoding is explicit so the result does not depend on the platform's default
# (e.g. cp1252 on Windows).
file_path = 'data/tiny_shakespeare.txt'
with open(file_path, 'r', encoding='utf-8') as f:
    text = f.read()

# Show the corpus size and a short preview of its beginning
print(f'Length of dataset in characters:{len(text)}')
print(text[:100])

Length of dataset in characters:1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [5]:
# The vocabulary is the sorted set of distinct characters in the corpus;
# sorted() already returns a list, so no extra list() call is needed.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


## Tokenization

In [6]:
# Creating a mapping from characters to integers (and the inverse mapping)
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }


def encode(s):
    """Encode a string into a list of integer ids; raises KeyError on characters outside `chars`."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer ids back into a string (inverse of `encode`)."""
    return ''.join(itos[i] for i in l)


print(encode("Hello, world!"))
print(decode(encode("Hello, world!")))

[20, 43, 50, 50, 53, 6, 1, 61, 53, 56, 50, 42, 2]
Hello, world!


In [7]:
# Turn the whole corpus into a single tensor of int64 token ids
encoded_ids = encode(text)
text_tensor = torch.tensor(encoded_ids, dtype=torch.long)
del encoded_ids  # the tensor owns its own copy; drop the large Python list

# Inspect the result: shape, element type and the first 100 ids
print(f'text tensor shape: {text_tensor.shape}')
print(f'data type in tensor: {text_tensor.dtype}')
print(text_tensor[:100])

text tensor shape: torch.Size([1115394])
data type in tensor: torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


## Split the dataset

In [8]:
def split_dataset(dataset, ratio=0.9):
    """
    Split the dataset into a training part and a held-out (test) part.

    params:
    @dataset: any sequence supporting len() and slicing (e.g. a 1-D tensor or a list)
    @ratio: fraction of the dataset placed in the training part, 0.9 by default;
            must lie in [0, 1]

    return:
    @train_data: the first int(ratio * len(dataset)) items
    @test_data: the remaining items

    raises:
    @ValueError: if ratio is outside [0, 1]
    """
    if not 0.0 <= ratio <= 1.0:
        # A negative ratio would yield a negative split index and silently wrap around.
        raise ValueError(f'ratio must be in [0, 1], got {ratio}')
    split_point = int(ratio * len(dataset))
    return dataset[:split_point], dataset[split_point:]

In [9]:
# Keep the first 90% of the encoded text for training and hold out the rest
train_text, test_text = split_dataset(text_tensor, ratio=0.9)

print(f'The size of train text: {train_text.shape}')
print(f'The size of test text: {test_text.shape}')

The size of train text: torch.Size([1003854])
The size of test text: torch.Size([111540])


## Build the training pair

In [10]:
block_size = 8  # width of the sliding window used to scan the training text

# A window of the training text and the same window shifted one character to the right
sample_text, next_sample_text = train_text[:block_size], train_text[1:block_size + 1]
print(sample_text)
print(next_sample_text)

tensor([18, 47, 56, 57, 58,  1, 15, 47])
tensor([47, 56, 57, 58,  1, 15, 47, 58])


In [11]:
# Every prefix of the window x is a training context; the element of y at the
# prefix's last position is its label (the character that follows the prefix):
# 18 => 47
# 18, 47 => 56
# ...
# 18, 47, 56, 57, 58,  1, 15 => 47
x = train_text[:block_size]
y = train_text[1:block_size + 1]

# index: 0   1   2   3   4   5   6   7
#   x: [18, 47, 56, 57, 58,  1, 15, 47]
#   y: [47, 56, 57, 58,  1, 15, 47, 58]
for index in range(block_size):
    context, target = x[:index + 1], y[index]
    print(f'input: {context}, output: {target}')

input: tensor([18]), output: 47
input: tensor([18, 47]), output: 56
input: tensor([18, 47, 56]), output: 57
input: tensor([18, 47, 56, 57]), output: 58
input: tensor([18, 47, 56, 57, 58]), output: 1
input: tensor([18, 47, 56, 57, 58,  1]), output: 15
input: tensor([18, 47, 56, 57, 58,  1, 15]), output: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), output: 58


In [12]:
def get_batch(dataset, block_size, batch_size, device=None):
    """
    Sample a random batch of (input, target) windows and stack them into tensors.

    Each row of x is block_size consecutive tokens starting at a random offset.
    The matching row of y is the same window shifted one token to the right,
    so y[b, t] is the token that follows x[b, t].

    params:
    @dataset: 1-D tensor of token ids to sample from
    @block_size: length of each window (the context length)
    @batch_size: number of windows to sample
    @device: optional device to move x and y to; None leaves them where the dataset is

    return:
    @x: a tensor in (batch_size, block_size)
    @y: a tensor in (batch_size, block_size)

    raises:
    @ValueError: if the dataset has fewer than block_size + 1 tokens
    """
    if len(dataset) <= block_size:
        raise ValueError(
            f'dataset needs at least block_size + 1 = {block_size + 1} tokens, got {len(dataset)}'
        )
    # Start indexes come from [0, len - block_size) (randint's high is exclusive),
    # so the shifted target window still ends inside the dataset.
    ix = torch.randint(low=0, high=len(dataset) - block_size, size=(batch_size,))
    x = torch.stack([dataset[i:i + block_size] for i in ix])
    y = torch.stack([dataset[i + 1:i + block_size + 1] for i in ix])
    if device is not None:
        x, y = x.to(device), y.to(device)
    return x, y

In [13]:
batch_size, block_size = 4, 8

# Seed right before sampling so the batch below is reproducible
torch.manual_seed(1337)
inputs, targets = get_batch(train_text, block_size=block_size, batch_size=batch_size)

In [14]:
# Print each half of the batch: a label, the tensor's shape, then its contents
for label, batch in (('inputs:', inputs), ('targets:', targets)):
    print(label)
    print(batch.shape)
    print(batch)

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [15]:
# Enumerate every (context, next-token) pair packed into the sampled batch
for b in range(batch_size):
    row_inputs, row_targets = inputs[b], targets[b]
    for i in range(block_size):
        context, target = row_inputs[:i + 1], row_targets[i]
        print(f'inputs: {context}, output: {target}')

inputs: tensor([24]), output: 43
inputs: tensor([24, 43]), output: 58
inputs: tensor([24, 43, 58]), output: 5
inputs: tensor([24, 43, 58,  5]), output: 57
inputs: tensor([24, 43, 58,  5, 57]), output: 1
inputs: tensor([24, 43, 58,  5, 57,  1]), output: 46
inputs: tensor([24, 43, 58,  5, 57,  1, 46]), output: 43
inputs: tensor([24, 43, 58,  5, 57,  1, 46, 43]), output: 39
inputs: tensor([44]), output: 53
inputs: tensor([44, 53]), output: 56
inputs: tensor([44, 53, 56]), output: 1
inputs: tensor([44, 53, 56,  1]), output: 58
inputs: tensor([44, 53, 56,  1, 58]), output: 46
inputs: tensor([44, 53, 56,  1, 58, 46]), output: 39
inputs: tensor([44, 53, 56,  1, 58, 46, 39]), output: 58
inputs: tensor([44, 53, 56,  1, 58, 46, 39, 58]), output: 1
inputs: tensor([52]), output: 58
inputs: tensor([52, 58]), output: 1
inputs: tensor([52, 58,  1]), output: 58
inputs: tensor([52, 58,  1, 58]), output: 46
inputs: tensor([52, 58,  1, 58, 46]), output: 39
inputs: tensor([52, 58,  1, 58, 46, 39]), output

In [44]:
class BigramLanguageModel(nn.Module):
    """
    Bigram language model: predicts the next token from the current token only.

    For an encoded batch such as [43, 58,  5, 57,  1, 46, 43, 39], the model
    learns pairs like 43=>58, 58=>5, 5=>57, ... , 43=>39.
    Each prediction depends only on the immediately preceding token.
    """
    def __init__(self, vocab_size) -> None:
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table:
        # row i of this (vocab_size x vocab_size) table scores every possible successor of token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets=None):
        """
        Compute next-token logits and, when targets are given, the cross-entropy loss.

        params:
        @inputs: token ids in shape (batch_size, block_size), i.e. (B, T)
        @targets: optional token ids with the same shape as inputs

        return:
        @logits: (B, T, C) when targets is None, otherwise flattened to (B*T, C)
        @loss: scalar cross-entropy tensor, or None when targets is None
        """
        # The logits can be regarded as a score table of the next token over the whole vocabulary
        logits = self.token_embedding_table(inputs)  # (B, T, C)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (not view) so non-contiguous targets (e.g. transposed slices) also work
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building an autograd graph
    def generate(self, idx, max_new_tokens):
        """
        Autoregressively extend idx by sampling max_new_tokens tokens.

        params:
        @idx: token ids in shape (B, T) used as the starting context
        @max_new_tokens: number of tokens to append

        return:
        @idx: tensor in shape (B, T + max_new_tokens) holding the context plus the samples
        """
        for _ in range(max_new_tokens):
            logits = self(idx)[0]           # (B, T, C)
            logits = logits[:, -1, :]       # (B, C): scores for the token after the last one
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
            

In [45]:
# Fix the seed so the randomly initialised embedding table is reproducible
torch.manual_seed(1337)

blm = BigramLanguageModel(vocab_size)
logits, loss = blm(inputs, targets)
print(logits.shape, loss.item(), sep='\n')

torch.Size([32, 65])
4.878634929656982


In [52]:
# Start from a single token with id 0 (the newline character) and sample 100 more
idx = torch.zeros(1, 1, dtype=torch.long)
generated_tensor = blm.generate(idx, max_new_tokens=100)[0].tolist()
decode(generated_tensor)

"\nDphq-.IsCwbjxca;P-KA:r'a;pJ&q-UgOEX.cAO-p,lQ?nEsrlvmUgbEQLQh,j;iPlgZR:CJpxIBju f&!BBEHSPmnq,P -d\npju"

In [53]:
# AdamW optimizer over every parameter of the bigram model (learning rate 1e-3)
optimizer = torch.optim.AdamW(params=blm.parameters(), lr=0.001)

In [54]:
# Train the bigram model on random batches from the training split.
# `tqdm` is imported as a module (not callable), so the progress bar comes from tqdm.auto.trange.
batch_size = 32
max_steps = 100_000   # total optimisation steps
log_interval = 5_000  # report the training loss every this many steps

progress = trange(max_steps)
for step in progress:
    # sample from the train text, reusing the notebook's block_size instead of a literal 8
    xb, yb = get_batch(train_text, block_size=block_size, batch_size=batch_size)

    # evaluate the loss and take one optimisation step
    logits, loss = blm(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % log_interval == 0:
        # write through tqdm so log lines do not break the progress bar
        progress.write(f'step: {step}, loss: {loss.item()}')

4.716796875
3.774552583694458
3.1031713485717773
2.7186434268951416
2.6022653579711914
2.569850444793701
2.5617449283599854
2.494694471359253
2.652863025665283
2.6305291652679443
2.441178321838379
2.4715094566345215
2.3911445140838623
2.4558699131011963
2.408242702484131
2.410507917404175
2.462000846862793
2.57413649559021
2.434818744659424
2.35010027885437
2.3581395149230957
2.506747245788574
2.424957752227783
2.4469985961914062
2.484031915664673
2.4075565338134766
2.521646738052368
2.4890687465667725
2.5184054374694824
2.3834164142608643
2.665809392929077
2.5752968788146973
2.3176233768463135
2.4029269218444824
2.4619038105010986
2.550365447998047
2.4382266998291016
2.480177640914917
2.5165863037109375
2.295788288116455
2.5459280014038086
2.414959669113159
2.468573808670044
2.4739668369293213
2.5384068489074707
2.5774240493774414
2.5335240364074707
2.456674575805664
2.514296531677246
2.344801902770996
2.3552749156951904
2.4615910053253174
2.4427225589752197
2.348740816116333
2.437067

In [62]:
# Sample 500 characters from the trained model, starting from token 0
prompt = torch.zeros((1, 1), dtype=torch.long)
sample_ids = blm.generate(prompt, max_new_tokens=500)
print(decode(sample_ids.squeeze().tolist()))



Na m nesoure d:
Gie g
Nougore ENor,'dinowinowintrur:
Hairsivan's fatiren ly,
ABr berther p anstepre?
A:
Angsursupo---
TItithanfo tatonound lerouple thire phal th ie, ivee brm lf ancin,
KIILAUSThine th thinode sisur til,
Mundellouit Loutly t heak;
Lothitomy tif hes speath dothas n t I an.
GBI:'st wit on hulye, d soo wiqus ilou y We t ave th rhen me. tit weroutcat be dont contingeef neno rar char ngs lordve berple farthtoncr
Whue ixt fot,
ANANER prer h?
Whalous pere?
'Trimo,
Whourend, pend s ang 
