# Environment

In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare the data

In [2]:
import gdown
import os

# Download URL, the shakespeare.txt
url = f'https://drive.google.com/uc?id=1O4PZ8wOpp6yecoy8tMuVEIFS7XgyRJy9'

data_path = '../data'
text_path = f'{data_path}/shakespeare.txt'

if not os.path.exists(data_path):
    os.makedirs(data_path)

if not os.path.exists(text_path):
  gdown.download(url, text_path, quiet=False)


In [3]:
import re

with open(text_path) as f:
  text = f.read()
  
text = re.sub(r'\d+', '', text)
text = re.sub(r' +', ' ', text)

print(f"lenth of the text {len(text)}")

lenth of the text 5046489


In [4]:
print(text[:1000])

 From fairest creatures we desire increase,
 That thereby beauty's rose might never die,
 But as the riper should by time decease,
 His tender heir might bear his memory:
 But thou contracted to thine own bright eyes,
 Feed'st thy light's flame with self-substantial fuel,
 Making a famine where abundance lies,
 Thy self thy foe, to thy sweet self too cruel:
 Thou that art now the world's fresh ornament,
 And only herald to the gaudy spring,
 Within thine own bud buriest thy content,
 And tender churl mak'st waste in niggarding:
 Pity the world, or else this glutton be,
 To eat the world's due, by the grave and thee.


 
 When forty winters shall besiege thy brow,
 And dig deep trenches in thy beauty's field,
 Thy youth's proud livery so gazed on now,
 Will be a tattered weed of small worth held:
 Then being asked, where all thy beauty lies,
 Where all the treasure of thy lusty days;
 To say within thine own deep sunken eyes,
 Were an all-eating shame, and thriftless praise.
 How much m

In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"vocab size : {len(chars)}")
print("".join(chars))

vocab size : 74

 !"&'(),-.:;<>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_`abcdefghijklmnopqrstuvwxyz|}


# Prepare tokenizer

In [6]:
# import tiktoken
# tokenizer = tiktoken.get_encoding('gpt2')
# tokens = tokenizer.encode(text)
# print(f"total tokens {len(tokens)}")
# print("decode result of \"hello world.\"", tokenizer.decode([31373, 995]))

In [7]:
class SimpleTokenizer:
  def __init__(self, text):
    self.chars = sorted(list(set(text)))
    self.token2id = {c : i for i, c in enumerate(chars)}
    self.id2token = {i : c for i, c in enumerate(chars)}
    
  def encode(self, text):
    return [self.token2id[c] for c in text]
  
  def decode(self, token_ids):
    return "".join([self.id2token[token_id] for token_id in token_ids])
  

In [8]:
tokenizer = SimpleTokenizer(text)
vocab_size = len(tokenizer.chars)
print(
  tokenizer.encode("Hello"),
  tokenizer.decode([23, 50, 57, 57, 60]),
  sep='\n')

[23, 50, 57, 57, 60]
Hello


# Prepare Data for torch

In [9]:
import torch

data = torch.tensor(tokenizer.encode(text), dtype = torch.long,
                    device=device) # torch.long can be used as index directly
print(data, data.shape, data.dtype)


tensor([ 1, 21, 63,  ..., 29, 19,  0], device='cuda:0') torch.Size([5046489]) torch.int64


In [10]:
train_data_size = int(data.shape[0] * 0.9)
train_data = data[:train_data_size].detach()
val_data = data[train_data_size:].detach()

In [11]:
from torch.utils.data import Dataset, DataLoader
class SimpleDataset(Dataset):
  def __init__(self, data, block_size = 8):
    self.data = data
    self.block_size = block_size
    
  def __len__(self):
    return len(self.data) - self.block_size

  def __getitem__(self, idx):
    x = self.data[idx: idx + self.block_size]
    y = self.data[idx + self.block_size]
    return x, y

In [12]:
BATCH_SIZE = 512
BLOCK_SIZE = 8
train_dataset = SimpleDataset(train_data, BLOCK_SIZE)
val_dataset = SimpleDataset(val_data, BLOCK_SIZE)

In [13]:
print(len(train_dataset), len(val_dataset))

4541832 504641


In [14]:
data[:9]

tensor([ 1, 21, 63, 60, 58,  1, 51, 46, 54], device='cuda:0')

In [15]:
train_dataset[0]

(tensor([ 1, 21, 63, 60, 58,  1, 51, 46], device='cuda:0'),
 tensor(54, device='cuda:0'))

In [16]:
import numpy as np
class SimpleDataloader(DataLoader):
  def __init__(self, dataset, batch_size=4, shuffle=True, **kwargs):
    super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    self.shuffle = shuffle
    
  def __iter__(self):
    dataset_size = len(self.dataset)
    indices = np.arange(dataset_size)
    if self.shuffle:
        np.random.shuffle(indices)

    for start_idx in range(0, dataset_size - self.batch_size + 1, self.batch_size):
        batch_indices = indices[start_idx:start_idx + self.batch_size]
        yield (torch.stack([self.dataset[i][0] for i in batch_indices]).to(device),
              torch.stack([self.dataset[i][1] for i in batch_indices]).to(device))

In [17]:
# torch.manual_seed(0)
train_dataloader = SimpleDataloader(train_dataset, BATCH_SIZE)
val_dataloader = SimpleDataloader(val_dataset, BATCH_SIZE)

In [18]:
for i, batch in enumerate(train_dataloader):
  print(batch)
  if i == 1:
    break

(tensor([[49, 63, 60,  ..., 10,  1, 35],
        [65, 68, 54,  ...,  1, 58, 70],
        [31, 33, 30,  ..., 17, 24, 35],
        ...,
        [ 1, 47, 63,  ..., 52,  1, 53],
        [24,  1, 58,  ..., 50, 57, 51],
        [ 0,  1, 20,  ..., 65,  1, 38]], device='cuda:0'), tensor([53,  1, 20, 65, 66,  1, 66, 62, 53, 64, 51,  1, 59,  1,  1, 10,  1, 70,
        33, 64, 50, 10, 57, 57, 55, 27,  1,  1, 50, 61, 63,  1, 50, 68, 46, 17,
         0, 46, 46,  1, 68, 65, 66, 53, 70, 49, 20, 35, 50, 59,  0, 53, 50, 30,
        60, 63, 10, 68,  1,  8, 53, 70, 64, 58,  1,  1, 66, 38, 59, 24, 35, 70,
        65, 30, 52, 65, 60, 59, 58, 49, 65,  1, 54, 63, 27, 35, 57,  1,  1, 64,
        53, 58, 34, 34,  1, 59, 46,  5, 54,  1, 59, 50, 65, 54, 56, 50, 54,  1,
         1, 50,  5, 26,  0, 50, 57,  1, 34,  5,  1, 53, 64, 54, 64,  1, 46, 50,
         1, 35, 30, 63,  5,  1, 65, 56,  1,  1,  1, 18,  1, 38, 46, 50,  1, 15,
         1,  1,  1, 38, 58, 18, 60,  1, 49, 27,  1,  1, 46, 64, 50, 50,  1, 49,
       

# Transformer GPT

In [19]:
# Model Configuration
VOCAB_SIZE = 74 # Should be set according to the tokenizer
EMBED_DIM = 128
MAX_TOKEN = 10000

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
  def __init__(self, tok_embed, head_size, out_embed):
    super().__init__()
    self.Q = nn.Linear(tok_embed, head_size)
    self.K = nn.Linear(tok_embed, head_size)
    self.V = nn.Linear(tok_embed, out_embed)
    self.register_buffer('tril', torch.tril(torch.ones(MAX_TOKEN, MAX_TOKEN, dtype=torch.int32)))


  def forward(self, x):
    # x (B, T, tok_embed)
    B,T,C = x.shape
    Q = self.Q(x)  # Q (B, T, head_size)
    K = self.K(x)  # K (B, T, head_size)
    V = self.V(x)  # V (B, T, out_embed)
    weight = Q @ torch.transpose(K, -2, -1) * C ** -0.5 # weight (B, T, T)
    weight = torch.masked_fill(weight, self.tril[:T, :T] == 0, float('-inf'))
    weight = torch.softmax(weight, -1)
    logits = weight @ V # (B, T, T) @ (B, T, out_embed) -> (B, T, out_embed)
    return logits


class SimpleGPT(nn.Module):
  def __init__(self):
    super().__init__()
    # (B, T)
    self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM) #(B, T, T)
    self.pos_embedding = nn.Embedding(MAX_TOKEN, EMBED_DIM)
    self.attn1 =SelfAttention(EMBED_DIM, 256, VOCAB_SIZE)
    

  def forward(self, x, targets = None):
    """
    x should be in the form of (B, T)
    """
    B, T = x.shape
    txt_embedding = self.embedding(x)
    pos_embedding = self.pos_embedding(torch.arange(MAX_TOKEN, device=device))[:T,:]
    
    embedding = txt_embedding + pos_embedding

    logits = self.attn1(embedding)

    y = logits[:, -1,:]
    y = F.softmax(y, 1)

    if targets is None:
      loss = None
    else:
      targets_one_hot = F.one_hot(targets, VOCAB_SIZE).float()
      loss = F.cross_entropy(y, targets_one_hot)

    return y, loss
  
  def generate(self, x, max_tokens = 30):
    """
    x (B, T)

    ### returns
    y (B, T + max_tokens)

    ### Warning
    Because we do not have an EOF, so it will generate max_tokens actually
    """
    for _ in range(max_tokens):
      # x shape (B, T)
      probs, loss = self.forward(x) # y shape (B, vocab_size)
      y = torch.multinomial(probs, num_samples=1)
      x = torch.cat([x, y], dim = 1)

    return x

In [21]:
gpt = SimpleGPT().to(device)

# Training and Test

In [22]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f8f1872dbd0>

In [23]:
def mubble(max_tokens = 200):
  input = torch.zeros(1, 1, dtype=torch.long,device=device)
  output = gpt.generate(input, max_tokens)
  text = tokenizer.decode(output[0].tolist())
  return text

mubble()

'\nMAXw\naFJga]}"| ?P!MbfWDIlQnwI},UAZQ<:a[tnKQ}.bC)g,Kg:A,`BTQ"PpGDmRnZAxb"kWwMy_HdUv[!dbwGnu}"g`VAF e!-TbTtbfAmI[nw}x?WaYnEvTm||sykB:rJktcIISP;P(CXtuFP&dox??n`pSXW`MKwd>&wIS!cPQA[ Uf\nkb[CeT\n)\nYSQ.N,YKFR'

In [24]:
optimizer = torch.optim.AdamW(gpt.parameters(), lr = 1e-3)

In [25]:
@torch.no_grad()
def eval_loss(dataloader, eval_n = 100):
  gpt.eval()
  tot_loss = 0
  for j, batch in enumerate(dataloader):
    X, targets = batch
    logits, loss = gpt.forward(X, targets)
    tot_loss += loss.sum()
    if j == eval_n:
      break
  ave_loss = tot_loss / eval_n
  return ave_loss



def epoch(train_N = 10000, loss_eval_N = 1000):
  for i, batch in enumerate(train_dataloader):
    gpt.train()
    X, targets = batch
    logits, loss = gpt.forward(X, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % loss_eval_N == 0:
      print(f"val_loss: {eval_loss(val_dataloader)},train_loss: {eval_loss(train_dataloader)}")
    
    if i == train_N:
      break

In [26]:
epoch()

val_loss: 4.345279216766357,train_loss: 4.345304012298584
val_loss: 4.120543479919434,train_loss: 4.112907409667969
val_loss: 4.117610931396484,train_loss: 4.112486362457275
val_loss: 4.111502170562744,train_loss: 4.106204509735107
val_loss: 4.110670566558838,train_loss: 4.10813045501709
val_loss: 4.111058235168457,train_loss: 4.10964298248291
val_loss: 4.110049247741699,train_loss: 4.105050563812256
val_loss: 4.109231948852539,train_loss: 4.104127883911133
val_loss: 4.109132766723633,train_loss: 4.104745388031006


In [39]:
print(mubble())


 thth thththththhththththCehthththththPhththththththth.ehhhthFh&hdhyehthDehthhth
 thth,hthththththS`hththththththO
 thDehthwhMhththththththlhthththththzehFhehththbhthth[hthth.<hthfhththththpe ehd'ehth
