# Environment

In [81]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare the data

In [82]:
import gdown
import os

# Download URL, the shakespeare.txt
url = f'https://drive.google.com/uc?id=1O4PZ8wOpp6yecoy8tMuVEIFS7XgyRJy9'

data_path = '../data'
text_path = f'{data_path}/shakespeare.txt'

if not os.path.exists(data_path):
    os.makedirs(data_path)

if not os.path.exists(text_path):
  gdown.download(url, text_path, quiet=False)


In [83]:
import re

with open(text_path) as f:
  text = f.read()
  
text = re.sub(r'\d+', '', text)
text = re.sub(r' +', ' ', text)

print(f"lenth of the text {len(text)}")

lenth of the text 5046489


In [84]:
print(text[:1000])

 From fairest creatures we desire increase,
 That thereby beauty's rose might never die,
 But as the riper should by time decease,
 His tender heir might bear his memory:
 But thou contracted to thine own bright eyes,
 Feed'st thy light's flame with self-substantial fuel,
 Making a famine where abundance lies,
 Thy self thy foe, to thy sweet self too cruel:
 Thou that art now the world's fresh ornament,
 And only herald to the gaudy spring,
 Within thine own bud buriest thy content,
 And tender churl mak'st waste in niggarding:
 Pity the world, or else this glutton be,
 To eat the world's due, by the grave and thee.


 
 When forty winters shall besiege thy brow,
 And dig deep trenches in thy beauty's field,
 Thy youth's proud livery so gazed on now,
 Will be a tattered weed of small worth held:
 Then being asked, where all thy beauty lies,
 Where all the treasure of thy lusty days;
 To say within thine own deep sunken eyes,
 Were an all-eating shame, and thriftless praise.
 How much m

In [85]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"vocab size : {len(chars)}")
print("".join(chars))

vocab size : 74

 !"&'(),-.:;<>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_`abcdefghijklmnopqrstuvwxyz|}


# Prepare tokenizer

In [86]:
# import tiktoken
# tokenizer = tiktoken.get_encoding('gpt2')
# tokens = tokenizer.encode(text)
# print(f"total tokens {len(tokens)}")
# print("decode result of \"hello world.\"", tokenizer.decode([31373, 995]))

In [87]:
class SimpleTokenizer:
  def __init__(self, text):
    self.chars = sorted(list(set(text)))
    self.token2id = {c : i for i, c in enumerate(chars)}
    self.id2token = {i : c for i, c in enumerate(chars)}
    
  def encode(self, text):
    return [self.token2id[c] for c in text]
  
  def decode(self, token_ids):
    return "".join([self.id2token[token_id] for token_id in token_ids])
  

In [88]:
tokenizer = SimpleTokenizer(text)
vocab_size = len(tokenizer.chars)
print(
  tokenizer.encode("Hello"),
  tokenizer.decode([23, 50, 57, 57, 60]),
  sep='\n')

[23, 50, 57, 57, 60]
Hello


# Prepare Data for torch

In [89]:
import torch

data = torch.tensor(tokenizer.encode(text), dtype = torch.long,
                    device=device) # torch.long can be used as index directly
print(data, data.shape, data.dtype)


tensor([ 1, 21, 63,  ..., 29, 19,  0], device='cuda:0') torch.Size([5046489]) torch.int64


In [90]:
train_data_size = int(data.shape[0] * 0.9)
train_data = data[:train_data_size].detach()
val_data = data[train_data_size:].detach()

In [91]:
from torch.utils.data import Dataset, DataLoader
class SimpleDataset(Dataset):
  def __init__(self, data, block_size = 8):
    self.data = data
    self.block_size = block_size
    
  def __len__(self):
    return len(self.data) - self.block_size

  def __getitem__(self, idx):
    x = self.data[idx: idx + self.block_size]
    y = self.data[idx + self.block_size]
    return x, y

In [92]:
BATCH_SIZE = 32
train_dataset = SimpleDataset(train_data)
val_dataset = SimpleDataset(val_data)

In [93]:
print(len(train_dataset), len(val_dataset))

4541832 504641


In [94]:
data[:9]

tensor([ 1, 21, 63, 60, 58,  1, 51, 46, 54], device='cuda:0')

In [95]:
train_dataset[0]

(tensor([ 1, 21, 63, 60, 58,  1, 51, 46], device='cuda:0'),
 tensor(54, device='cuda:0'))

In [96]:
import numpy as np
class SimpleDataloader(DataLoader):
  def __init__(self, dataset, batch_size=4, shuffle=True, **kwargs):
    super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    self.shuffle = shuffle
    
  def __iter__(self):
    dataset_size = len(self.dataset)
    indices = np.arange(dataset_size)
    if self.shuffle:
        np.random.shuffle(indices)

    for start_idx in range(0, dataset_size - self.batch_size + 1, self.batch_size):
        batch_indices = indices[start_idx:start_idx + self.batch_size]
        yield (torch.stack([self.dataset[i][0] for i in batch_indices]).to(device),
              torch.stack([self.dataset[i][1] for i in batch_indices]).to(device))

In [97]:
# torch.manual_seed(0)
train_dataloader = SimpleDataloader(train_dataset, BATCH_SIZE)
val_dataloader = SimpleDataloader(val_dataset, BATCH_SIZE)

In [98]:
for i, batch in enumerate(train_dataloader):
  print(batch)
  if i == 1:
    break

(tensor([[33, 19, 10,  1, 38, 50, 57, 57],
        [53, 50,  1, 68, 54, 57, 57, 10],
        [64,  1, 53, 50,  1, 58, 70,  1],
        [50, 59,  1, 68, 54, 65, 53,  1],
        [ 0,  1, 26, 54, 59, 64, 58, 50],
        [66, 64,  1, 65, 60,  1, 53, 54],
        [63, 54, 46, 57,  1, 60, 51,  1],
        [53, 10,  1, 38, 53, 70,  1, 54],
        [60,  1, 53, 54, 58, 64, 50, 57],
        [65,  1, 70, 60, 66, 63,  1, 22],
        [46, 52, 50, 10,  0,  1, 30, 33],
        [54, 59, 52,  0,  1, 17, 66, 65],
        [60, 51,  1, 65, 53, 50,  1, 68],
        [64, 65,  1, 58, 70,  1, 68, 54],
        [58, 12,  1, 53, 60, 57, 49,  1],
        [ 1, 28, 46, 55, 50, 64, 65, 70],
        [66,  1, 58, 50, 63, 63, 70,  2],
        [66, 59, 52,  1, 40, 60, 63, 56],
        [ 1, 67, 60, 54, 48, 50, 64, 15],
        [ 1, 65, 53, 46, 65,  1, 53, 50],
        [52,  5, 49,  8,  1, 24,  5, 57],
        [38, 53, 60, 64, 50,  1, 67, 54],
        [ 1, 52, 63, 50, 46, 65,  1, 59],
        [35, 23, 20,  1, 20, 29, 

# Transformer GPT

In [99]:
# Model Configuration
VOCAB_SIZE = 74 # Should be set according to the tokenizer
EMBED_DIM = 12

In [100]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleGPT(nn.Module):
  def __init__(self):
    super().__init__()
    # (B, T)
    self.embedding = nn.Embedding(VOCAB_SIZE, VOCAB_SIZE) #(B, T, T)
    

  def forward(self, x, targets = None):
    """
    x should be in the form of (B, T)
    """
    logits = self.embedding(x)
    y = logits[:, -1,:]
    y = F.softmax(y, 1)

    if targets is None:
      loss = None
    else:
      targets_one_hot = F.one_hot(targets, VOCAB_SIZE).float()
      loss = F.cross_entropy(y, targets_one_hot)

    return y, loss
  
  def generate(self, x, max_tokens = 30):
    """
    x (B, T)

    ### returns
    y (B, T + max_tokens)

    ### Warning
    Because we do not have an EOF, so it will generate max_tokens actually
    """
    for _ in range(max_tokens):
      # x shape (B, T)
      probs, loss = self.forward(x) # y shape (B, vocab_size)
      y = torch.multinomial(probs, num_samples=1)
      x = torch.cat([x, y], dim = 1)

    return x

In [101]:
gpt = SimpleGPT().to(device)

# Training and Test

In [102]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f17cd739bd0>

In [103]:
def mubble(max_tokens = 200):
  input = torch.zeros(1, 1, dtype=torch.long,device=device)
  output = gpt.generate(input, max_tokens)
  text = tokenizer.decode(output[0].tolist())
  return text

mubble()

'\nM\nnWR!FYcaB}?x ?P!Mb cqIl,nbI},U"QQb:x[RnKY}.bd]n,MgKy,t_Rr"\'U`|zRqZGxC"GYwM\'_HdUl[!d hEwVJ;i`VAF e;F;FTtVfgbI[n|}A?gaYnEvNm}msyku|r<Htc IyP,?(,HtVFP]vDx??\n`pSto`MKZd>Hw>D!oPIK[iUfOmS&"eTj`oYSQ.N<_RFR'

In [104]:
optimizer = torch.optim.AdamW(gpt.parameters(), lr = 1e-3)

In [105]:
def epoch(eval_n = 100):
  for i, batch in enumerate(train_dataloader):
    gpt.train()
    X, targets = batch
    logits, loss = gpt.forward(X, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 1000 == 0:
      gpt.eval()
      tot_loss = 0
      for j, batch in enumerate(val_dataloader):
        X, targets = batch
        logits, loss = gpt.forward(X, targets)
        tot_loss += loss.sum()
        if j == eval_n:
          break
      ave_loss = tot_loss / eval_n
      print(ave_loss)
    
    if i == 10000:
      break

In [106]:
epoch()

tensor(4.3475, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.3354, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.3087, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.2675, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.2319, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.1988, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.1806, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.1626, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.1483, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.1417, device='cuda:0', grad_fn=<DivBackward0>)
tensor(4.1393, device='cuda:0', grad_fn=<DivBackward0>)


In [107]:
print(mubble())


 the the the the the the the che the the the hepe the the the QThe the the the the the the ous t thand the ou the the the the toul, the the the the the the the he thouris the s the the he the the the 
