In [1]:
# # We always start with a dataset to train on. Let's download the tiny shakespeare dataset
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

## Imports

In [2]:
import math
import itertools
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
sns.set_theme(style="dark")
%matplotlib inline
''' %matplotlib inline sets the backend of matplotlib to
the 'inline' backend. When using the 'inline' backend,
your matplotlib graphs will be included in your notebook,
next to the code.'''

# # for creating a responsive plot
# %matplotlib ipympl
# %matplotlib widget

import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = 'cuda' # 'cuda' , 'cpu'
torch.set_default_tensor_type('torch.cuda.FloatTensor')

BATCH_SIZE = 64 # how many independent sequence we process in parallel
CONTEXT_L = 256 # the max context length for input and output

## Loading Data & Tokenizer

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print(len(text))
print(text[:100])

1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [5]:
# Get the unique chars of the text by set()
chars = sorted(list(set(text)))
NCLASS = len(chars)
print(''.join(chars))
NCLASS


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [6]:
# encoder tokenizer
ch_to_i = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [ ch_to_i[ch] for ch in s ]

# decoder tokenizer
i_to_ch = { i:ch for ch,i in ch_to_i.items()}
decoder = lambda si: ''.join([i_to_ch[i] for i in si])

print(encode('hi there!'))
print( decoder(encode('hi there!')) )

[46, 47, 1, 58, 46, 43, 56, 43, 2]
hi there!


1. Google use Sentence Piece for tokenization.

SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end system that does not depend on language-specific pre/postprocessing.

https://github.com/google/sentencepiece

2. tiktoken is a fast BPE tokeniser for use with OpenAI's models.

https://github.com/openai/tiktoken

In [7]:
enc_data = torch.tensor(encode(text), dtype=torch.long, device=DEVICE)
print(enc_data.shape, enc_data.dtype)
print(enc_data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


## Heatmap and Embedding Visualization

In [8]:
import itertools
def plot_heatmap(tensor, text=True, nrow=None, ncol=None, fig_size=(10,10)):
    if (nrow is None) or (ncol is None):
        nrow = tensor.shape[0]
        ncol = tensor.shape[1]
    plt.figure(figsize=fig_size)
    plt.imshow(tensor.detach().numpy(), cmap= 'Blues')
    # manually write text on each cell (seaborn annot doesn't look good)
    if text:
        for i, j in itertools.product(range(nrow), range(ncol)):
            # x:col, y:rows, the origin is top left corner, makes bottom <->top
            plt.text(x=j, y=i, s=f'{tensor[i,j].item():.2f}', ha='center', va='center', color='grey')
    plt.axis('off')

### 2D & 3d Embedding Visualization

In [9]:
def plot_2d_emb(emb_lkt, nclass, figsize=(8,8)):
    plt.figure(figsize=figsize)
    plt.scatter(x= emb_lkt[:,0].data, y=emb_lkt[:,1].data, s=200)
    for i in range(nclass):
        plt.text(x=emb_lkt[i,0].item(), y=emb_lkt[i,1].item(), s=i_to_ch[i], ha='center', va='center', color='white')
    plt.grid('minor')

# def plot_3d_emb(emb_lkt, nclass, figsize=(8,8)):
#     tensor = emb_lkt.data.detach().numpy()
#     fig = plt.figure(figsize=figsize)
#     ax = Axes3D(fig)
#     ax.scatter(xs= tensor[:,0], ys=tensor[:,1], zs=tensor[:,2], s=200)
#     for i in range(nclass):
#         ax.text(x=tensor[i,0], y=tensor[i,1],z=tensor[i,2], s=i_to_s[i], ha='center', va='center', color='white')
#     # displaying the plot
#     plt.grid('minor')
#     plt.show()

## Splitting dataset, prepare Context window

1. split rate 90%, 10%

2. Dev or Validation set is for hyper parameter tuning

In [10]:
n90 = int( 0.9*len(enc_data) )
train_data = enc_data[:n90]
val_data = enc_data[n90:]

TXT_LENGTH = len(train_data)

we never feed all the text into the Transformer all at once, that would be computationally very expensive, and prohibitive.

We actually only work with chunks of text sampled from the dataset. we call it context and we have a context length.



In [11]:
window = 8
train_data[:window+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [12]:
# we want the model to predict what char comes after
# any number of chars from 1 to CONTEXT_L as input

x = train_data[:window]
y = train_data[1:window+1]
for t in range(window):
    context = x[:t+1] # +1 is bc t starts from 0 and we would get empty window without it
    target = y[t]
    print(f'{context.tolist()} -> {target}')

[18] -> 47
[18, 47] -> 56
[18, 47, 56] -> 57
[18, 47, 56, 57] -> 58
[18, 47, 56, 57, 58] -> 1
[18, 47, 56, 57, 58, 1] -> 15
[18, 47, 56, 57, 58, 1, 15] -> 47
[18, 47, 56, 57, 58, 1, 15, 47] -> 58


## creating both Batch dimension and time (or context) dimensions

we have batch & time (in context window) dimensions

In [13]:
from helper_reproduciblility import set_all_seeds, set_deterministic

set_all_seeds(seed=1337)
set_deterministic()

In [14]:
def get_batch(stage:str, batch_size:int, context_length:int, verbose=False):
    data = train_data if stage=='train' else val_data
    # a BATCH_SIZE number of int for context_window_starts
    # this random init for the context window throughout the dataset is good
    # as the tone and style of the text might change from the start to the end of text
    cw_starts = torch.randint(low=0, high=len(data)-context_length , size=(batch_size,) )
    x = torch.stack([ data[ cw_start : cw_start+context_length ] for cw_start in cw_starts])
    # shift the window by for y
    y = torch.stack([ data[ cw_start+1 : cw_start+context_length+1 ] for cw_start in cw_starts])
    if verbose:
        print(f'{x=}')
        print(f'{y=}')
        
        # in each independent context window, a character is not allowed to look at the characters after itself (causal relationship)
        # to visualize: (we will use a trick called masking to do that for us in self attention)
        for b, t in itertools.product(range(batch_size),range(context_length)):
            context = x[b, :t+1]
            target = y[b, t]
            print(f'if context={context.tolist()} -> target={target}')
            '''
            we should see "independent batch dimension" rows
            and (Context window length) columns

            then we take each row and create multiple sequence
            with max size of CONTEXT_L
            '''
    x, y = x.to(device=DEVICE), y.to(device=DEVICE)
    return x, y

# create the batch of independent context windows
xb, yb = get_batch('train', batch_size=2, context_length=4, verbose=True)

x=tensor([[39, 56, 41, 47],
        [43,  1, 57, 53]])
y=tensor([[56, 41, 47, 59],
        [ 1, 57, 53, 50]])
if context=[39] -> target=56
if context=[39, 56] -> target=41
if context=[39, 56, 41] -> target=47
if context=[39, 56, 41, 47] -> target=59
if context=[43] -> target=1
if context=[43, 1] -> target=57
if context=[43, 1, 57] -> target=53
if context=[43, 1, 57, 53] -> target=50


In [15]:
xb.shape

torch.Size([2, 4])

## Attention

### testing

In [16]:
with torch.no_grad():
    token_embedding_lkt = nn.Embedding(num_embeddings=NCLASS, embedding_dim=6, device=DEVICE)
    positional_embedding = nn.Embedding(num_embeddings=4, embedding_dim=6, device=DEVICE)

In [17]:
with torch.no_grad():
    token_emb = token_embedding_lkt(xb) # (batch_size, context_length, emb_dim)
token_emb, token_emb.shape

(tensor([[[-0.6825,  0.0427, -0.5355, -1.1698,  0.6652,  0.4814],
          [ 0.6836, -0.2636,  1.2229,  0.8503,  0.0778, -0.9458],
          [-0.5517, -2.4753,  2.1252, -1.8597, -1.5283, -0.9075],
          [ 0.8876,  0.0059,  0.2721,  0.3031,  0.0387,  0.1451]],
 
         [[ 0.1119, -0.7661, -0.6197, -0.6849, -0.7264, -1.7105],
          [ 0.9018, -1.3492, -1.4882, -0.2612,  0.2189, -2.1828],
          [ 0.1194,  0.5861,  2.1792, -2.5281, -0.5484,  0.6380],
          [-1.3081, -1.2529, -0.2132, -1.0740, -0.3529, -0.2993]]]),
 torch.Size([2, 4, 6]))

In [18]:
with torch.no_grad():
    pose_emb = positional_embedding(torch.arange(end=4, device=DEVICE)) # (batch_size, context_length, emb_dim)
pose_emb, pose_emb.shape

(tensor([[ 0.3165,  2.5683, -0.4802, -0.1947, -0.1885,  0.4579],
         [-0.2376, -1.0092,  2.3632, -0.5188, -0.5385,  0.7833],
         [ 1.2441, -0.0370, -0.1270,  0.1630,  1.4123,  0.4552],
         [ 2.0497, -0.5445,  1.3161,  1.8521,  1.0068, -0.5362]]),
 torch.Size([4, 6]))

In [19]:
X = token_emb + pose_emb # (batch_size, context_length, emb_dim)
X, X.shape

(tensor([[[-0.3660,  2.6110, -1.0157, -1.3645,  0.4767,  0.9393],
          [ 0.4459, -1.2728,  3.5861,  0.3314, -0.4607, -0.1625],
          [ 0.6924, -2.5123,  1.9982, -1.6967, -0.1160, -0.4523],
          [ 2.9372, -0.5386,  1.5882,  2.1551,  1.0456, -0.3911]],
 
         [[ 0.4284,  1.8022, -1.0999, -0.8796, -0.9149, -1.2526],
          [ 0.6642, -2.3584,  0.8751, -0.7800, -0.3196, -1.3995],
          [ 1.3635,  0.5491,  2.0522, -2.3651,  0.8639,  1.0932],
          [ 0.7416, -1.7974,  1.1029,  0.7781,  0.6539, -0.8355]]]),
 torch.Size([2, 4, 6]))

In [20]:
with torch.no_grad():
    Wq = torch.randn(6, 8 )
    Wk = torch.randn(6, 8 )
    Wv = torch.randn(6, 8 )

In [21]:
queries = X @ Wq # (batch_size, context_length, qk_dim)
keys = X @ Wk # (batch_size, context_length, qk_dim)
values = X @ Wv # (batch_size, context_length, v_dim)
queries.shape

torch.Size([2, 4, 8])

In [22]:
scaled_pairwise_sim = queries @ keys.transpose(dim0=1, dim1=2)  * (8**-0.5) # (batch_size, context_length, context_length)
scaled_pairwise_sim, scaled_pairwise_sim.shape

(tensor([[[ 26.2234, -23.4735, -11.3390, -15.3823],
          [-12.3156,   1.5337,   3.4252,   0.8565],
          [-11.2435,   6.0103,   7.8472,   3.3077],
          [-24.6950,  19.0554,  13.7950,  19.4077]],
 
         [[ -2.2381,  22.7478,  14.5210,  18.2890],
          [-16.7659,  17.0041,  -2.8188,  19.9568],
          [ -1.3789, -11.9098,   8.3356, -15.5845],
          [ -7.8504,   7.6658, -10.8747,  12.2220]]]),
 torch.Size([2, 4, 4]))

In [23]:
queries.var(), keys.var(), scaled_pairwise_sim.var()

(tensor(14.4232), tensor(20.3468), tensor(204.5637))

In [24]:
b = torch.tril(input=scaled_pairwise_sim, diagonal=0)
b

tensor([[[ 26.2234,   0.0000,   0.0000,   0.0000],
         [-12.3156,   1.5337,   0.0000,   0.0000],
         [-11.2435,   6.0103,   7.8472,   0.0000],
         [-24.6950,  19.0554,  13.7950,  19.4077]],

        [[ -2.2381,   0.0000,   0.0000,   0.0000],
         [-16.7659,  17.0041,   0.0000,   0.0000],
         [ -1.3789, -11.9098,   8.3356,   0.0000],
         [ -7.8504,   7.6658, -10.8747,  12.2220]]])

In [25]:
b = b.masked_fill(mask= b==0 , value= float('-inf'))
b

tensor([[[ 26.2234,     -inf,     -inf,     -inf],
         [-12.3156,   1.5337,     -inf,     -inf],
         [-11.2435,   6.0103,   7.8472,     -inf],
         [-24.6950,  19.0554,  13.7950,  19.4077]],

        [[ -2.2381,     -inf,     -inf,     -inf],
         [-16.7659,  17.0041,     -inf,     -inf],
         [ -1.3789, -11.9098,   8.3356,     -inf],
         [ -7.8504,   7.6658, -10.8747,  12.2220]]])

In [26]:
normalized_pairwise_sim = F.softmax(input=b, dim=2) # (batch_size, context_length, context_length)
normalized_pairwise_sim, normalized_pairwise_sim.shape

(tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [9.6683e-07, 1.0000e+00, 0.0000e+00, 0.0000e+00],
          [4.4136e-09, 1.3742e-01, 8.6258e-01, 0.0000e+00],
          [4.1141e-20, 4.1194e-01, 2.1393e-03, 5.8593e-01]],
 
         [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [2.1572e-15, 1.0000e+00, 0.0000e+00, 0.0000e+00],
          [6.0400e-05, 1.6126e-09, 9.9994e-01, 0.0000e+00],
          [1.8972e-09, 1.0392e-02, 9.2192e-11, 9.8961e-01]]]),
 torch.Size([2, 4, 4]))

In [27]:
attention_embedding = normalized_pairwise_sim @ values # (batch_size, context_length, v_dim)
attention_embedding, attention_embedding.shape

(tensor([[[ 0.7733,  0.9400, -4.3096, -2.0681, -3.1198,  0.4575, -4.4294,
            4.8361],
          [-2.9267,  6.8308,  7.2012,  3.6070,  0.0332,  0.5137,  8.0824,
           -3.3107],
          [-1.7559,  6.5711,  5.9867,  3.6404, -3.2945, -2.0385,  4.1468,
           -3.7486],
          [-3.1643,  2.2514,  6.3566,  0.6931, -0.7339, -2.4181,  9.4609,
           -4.4570]],
 
         [[-0.5072, -3.3519, -2.0111, -4.3186, -3.0338, -0.0310, -1.7883,
            1.1453],
          [-0.9018,  0.6899,  4.0113,  1.8252, -1.9704, -2.3205,  2.1763,
           -4.1382],
          [-2.4032,  9.7410,  3.7773,  1.2813, -7.7708, -3.0452,  3.7339,
            0.1291],
          [-0.7633, -1.2092,  3.3778,  2.0911,  0.3892, -2.2809,  3.5021,
           -3.5093]]]),
 torch.Size([2, 4, 8]))

In [28]:
post_layer_norm = (attention_embedding - attention_embedding.mean(dim=2, keepdim=True) ) / attention_embedding.std(dim=2, keepdim=True)
post_layer_norm, post_layer_norm.shape

(tensor([[[ 0.5134,  0.5656, -1.0793, -0.3770, -0.7065,  0.4144, -1.1169,
            1.7864],
          [-1.1871,  0.9459,  1.0268,  0.2411, -0.5401, -0.4350,  1.2195,
           -1.2710],
          [-0.6824,  1.2475,  1.1121,  0.5683, -1.0390, -0.7479,  0.6857,
           -1.1443],
          [-0.8607,  0.2590,  1.1078, -0.0632, -0.3582, -0.7064,  1.7496,
           -1.1280]],
 
         [[ 0.6675, -0.8765, -0.1487, -1.4011, -0.7038,  0.9260, -0.0278,
            1.5644],
          [-0.3027,  0.2825,  1.5036,  0.6999, -0.6955, -0.8242,  0.8290,
           -1.4925],
          [-0.5820,  1.7102,  0.5845,  0.1134, -1.5952, -0.7032,  0.5763,
           -0.1041],
          [-0.3699, -0.5412,  1.2208,  0.7265,  0.0728, -0.9528,  1.2685,
           -1.4247]]]),
 torch.Size([2, 4, 8]))

### real class

In [29]:
class SelfAttention(nn.Module):
    def __init__(self, emb_dim:int, qk_dim:int, v_dim:int, masked:bool, drop_out:float) -> None:
        super().__init__()
        self.emb_dim = emb_dim
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.masked = masked
        
        self.register_parameter(name='Wq', param=nn.parameter.Parameter(data=torch.randn(emb_dim, qk_dim, requires_grad=True, device=DEVICE)) )
        self.register_parameter(name='Wk', param=nn.parameter.Parameter(data=torch.randn(emb_dim, qk_dim, requires_grad=True, device=DEVICE)) )
        self.register_parameter(name='Wv', param=nn.parameter.Parameter(data=torch.randn(emb_dim, v_dim, requires_grad=True, device=DEVICE)) )
        self.register_buffer(name='tril', tensor=torch.tril(torch.ones(CONTEXT_L, CONTEXT_L)))
        self.drop_out = nn.Dropout(p=drop_out)
        
    
    def forward(self, X):
        # X shape: (batch_size, context_length, emb_dim)
        # self attention : X1 = X2
        queries = X @ self.Wq # (batch_size, context_length, qk_dim)
        keys = X @ self.Wk # (batch_size, context_length, qk_dim)
        values = X @ self.Wv # (batch_size, context_length, v_dim)
        scaled_pairwise_sim = queries @ keys.transpose(dim0=1, dim1=2) * (self.qk_dim**-0.5) # (batch_size, context_length, context_length)
        if self.masked:
            # scaled_pairwise_sim = torch.tril(input=scaled_pairwise_sim, diagonal=0)
            scaled_pairwise_sim = scaled_pairwise_sim.masked_fill(mask= (self.tril==0) , value= float('-inf'))
        normalized_pairwise_sim = self.drop_out(F.softmax(input=scaled_pairwise_sim, dim=2)) # (batch_size, context_length, context_length)  dim = -1: last dimension
        
        attention_embedding = normalized_pairwise_sim @ values # (batch_size, context_length, v_dim)
        return attention_embedding

## Cross Attention

In [30]:
class CrossAttention(nn.Module):
    def __init__(self, emb_dim, qk_dim, v_dim) -> None:
        super().__init__()
        self.emb_dim = emb_dim
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.register_parameter(name='Wq', param=nn.parameter.Parameter(data=torch.randn(emb_dim, qk_dim, requires_grad=True, device=DEVICE)) )
        self.register_parameter(name='Wk', param=nn.parameter.Parameter(data=torch.randn(emb_dim, qk_dim, requires_grad=True, device=DEVICE)) )
        self.register_parameter(name='Wv', param=nn.parameter.Parameter(data=torch.randn(emb_dim, v_dim, requires_grad=True, device=DEVICE)) )
        
    
    def forward(self, X1, X2):
        # X shape: (batch_size, context_length, emb_dim)
        # self attention : X1 = X2
        queries = X1 @ self.Wq # (batch_size, context_length, qk_dim)
        keys = X2 @ self.Wk # (batch_size, context_length, qk_dim)
        values = X2 @ self.Wv # (batch_size, context_length, v_dim)
        pairwise_sim = queries @ keys.transpose(dim0=1, dim1=2) # (batch_size, context_length, context_length)
        normalized_pairwise_sim = F.softmax(input=pairwise_sim, dim=1) * (self.qk_dim**-0.5)
        attention_embedding = normalized_pairwise_sim @ values # (batch_size, context_length, v_dim)
        return attention_embedding

## Multi Head Self Attention

In [31]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, num_heads:int, emb_dim:int, qk_dim:int, v_dim:int, masked:bool, drop_out:float) -> None:
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(emb_dim=emb_dim, qk_dim=qk_dim, v_dim=v_dim, masked=masked, drop_out=drop_out) for _ in range(num_heads)])
        self.proj = nn.Linear(in_features=num_heads*v_dim , out_features=emb_dim) # project back to the input dimension (residual connection)
        self.deop_out = nn.Dropout(p=drop_out)

    def forward(self, X):
        out = torch.cat([head(X) for head in self.heads], dim=2) # (batch_size, context_length, num_heads*v_dim)
        return self.deop_out( self.proj(out) ) # (batch_size, context_length, emb_dim)
    

## Feed Forward

In [32]:
class FeedForward(nn.Module):
    def __init__(self, emb_dim:int, drop_out:float) -> None:
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(in_features=emb_dim, out_features=4*emb_dim),
                                nn.ReLU(),
                                nn.Linear(in_features=4*emb_dim, out_features=emb_dim), # projection back to the residual,
                                nn.Dropout(p=drop_out)
                                )

    def forward(self, x):
        return self.ff(x)

## Motif Decoder Block

In [33]:
class MotifDecoderBlock(nn.Module):
    def __init__(self, num_heads:int, emb_dim:int, qk_dim:int, v_dim:int, drop_out:float) -> None:
        super().__init__()
        # intersperse communication among nodes with multiple blocks stacked on top of each other 
        self.masked_multi_head_attention = MultiHeadSelfAttention(num_heads=num_heads, emb_dim=emb_dim, qk_dim=qk_dim, v_dim=v_dim, masked=True, drop_out=drop_out)
        self.layer_norm1 = nn.LayerNorm(normalized_shape=emb_dim)
        self.feed_forward = FeedForward(emb_dim=emb_dim, drop_out=drop_out)
        self.layer_norm2 = nn.LayerNorm(normalized_shape=emb_dim)

    def forward(self, x): # pre-norm Formulation
        att = x + self.masked_multi_head_attention( self.layer_norm1(x) )
        return att + self.feed_forward(att)

## Transformer Decoder

In [34]:
class TransformerDecoder(nn.Module):
    def __init__(self, num_blocks:int, num_heads:int, emb_dim:int, qk_dim:int, v_dim:int, drop_out:float) -> None:
        super().__init__()
        self.token_embedding_lkt = nn.Embedding(num_embeddings=NCLASS, embedding_dim=emb_dim, device=DEVICE)
        self.positional_embedding = nn.Embedding(num_embeddings=CONTEXT_L, embedding_dim=emb_dim, device=DEVICE)
        self.motif_blocks = nn.Sequential(*[MotifDecoderBlock(num_heads=num_heads, emb_dim=emb_dim, qk_dim=qk_dim, v_dim=v_dim, drop_out=drop_out) for _ in range(num_blocks)])
        self.last_layer_norm = nn.LayerNorm(normalized_shape=emb_dim)
        self.last_linear = nn.Linear(in_features=emb_dim, out_features=NCLASS)
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, SelfAttention):
            torch.nn.init.normal_(module.Wq, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.Wk, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.Wv, mean=0.0, std=0.02)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        if x.shape[-1] < CONTEXT_L: # Generating from 0 context length
            # padding 0, (CONTEXT_L - prompt.shape[0]) quantity to the top/left and 0 quantity to the bottom/right the last dimension
            x = F.pad(input=x, pad=(CONTEXT_L-x.shape[-1], 0), mode="constant", value=0) # (1 batch, context_length number of scalars/indices)
        token_emb = self.token_embedding_lkt(x) # (batch_size, context_length, emb_dim)
        pose_emb = self.positional_embedding(torch.arange(end=CONTEXT_L, device=DEVICE)) # (batch_size, context_length, emb_dim)
        X = token_emb + pose_emb # (batch_size, context_length, emb_dim)
        
        X = self.motif_blocks(X) # (batch_size, context_length, emb_dim)
        self.last_layer_norm(X) # (batch_size, context_length, emb_dim)
        logits = self.last_linear(X) # (batch_size, context_length, n_class)
        return logits
        
    @torch.no_grad()
    def generate(self, prompt, max_new_tokens):
        for _ in range(max_new_tokens):
            # prompt : (unknown number of scalars/indices)
            prompt = prompt.view(1, -1) # (1 batch, unknown number of scalars/indices)
            feed = prompt[:, -CONTEXT_L:] # (1 batch, context_length number of scalars/indices)
            logits = self(feed) # (1, context_length, n_class)
            logits = logits[:, -1: :].view(1, NCLASS) # (1, n_class) , Focus only on the last position
            probs = F.softmax(input=logits, dim=-1) # (1, n_class) -1: last dimension
            idx_next = torch.multinomial(input=probs, num_samples=1)
            prompt = torch.cat(tensors=(prompt, idx_next), dim=-1) # (1 batch, 1+unknown number of scalars/indices)
        return prompt

In [35]:
EMB_DIM = 384
NUM_HEADS = 6
QK_DIM = EMB_DIM // NUM_HEADS # 64
V_DIM = QK_DIM
NUM_BLOCKS = 6
DROP_OUT = 0.2

In [36]:
model = TransformerDecoder(num_blocks= NUM_BLOCKS, num_heads=NUM_HEADS, emb_dim=EMB_DIM,
                           qk_dim=QK_DIM, v_dim=V_DIM, drop_out=DROP_OUT)
model = model.to(device=DEVICE)
print(model)
# print the number of parameters in the model
NUM_PARAM = sum(p.numel() for p in model.parameters())
print(NUM_PARAM/1e6, 'M parameters')

TransformerDecoder(
  (token_embedding_lkt): Embedding(65, 384)
  (positional_embedding): Embedding(256, 384)
  (motif_blocks): Sequential(
    (0): MotifDecoderBlock(
      (masked_multi_head_attention): MultiHeadSelfAttention(
        (heads): ModuleList(
          (0): SelfAttention(
            (drop_out): Dropout(p=0.2, inplace=False)
          )
          (1): SelfAttention(
            (drop_out): Dropout(p=0.2, inplace=False)
          )
          (2): SelfAttention(
            (drop_out): Dropout(p=0.2, inplace=False)
          )
          (3): SelfAttention(
            (drop_out): Dropout(p=0.2, inplace=False)
          )
          (4): SelfAttention(
            (drop_out): Dropout(p=0.2, inplace=False)
          )
          (5): SelfAttention(
            (drop_out): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (deop_out): Dropout(p=0.2, inplace=False)
      )
      (layer_norm1): LayerNor

In [37]:
# model = torch.load('model.pth')

### test generate

In [38]:
# xb, yb = get_batch(stage='train', batch_size=BATCH_SIZE, context_length=CONTEXT_L)

In [39]:
# logits = model(xb)
# loss = F.cross_entropy(input=logits.view(-1, NCLASS), target=yb.view(-1))
# print(logits.shape, loss)

In [40]:
# prompt=torch.ones((1,1), dtype=torch.long, device=DEVICE)
# max_new_tokens = CONTEXT_L//2
# prompt, prompt.shape

In [41]:
# # padding 0, (CONTEXT_L - prompt.shape[1]) quantity to the top/left and 0 quantity to the bottom/right the last dimension
# prompt = F.pad(input=prompt, pad=(CONTEXT_L - prompt.shape[1], 0), mode="constant", value=0)
# prompt, prompt.shape, prompt.dtype

In [42]:
# with torch.no_grad():
#     for _ in range(max_new_tokens):
#         # prompt : (unknown number of scalars/indices)
#         prompt = prompt.view(1, -1) # (1 batch, unknown number of scalars/indices)
#         if prompt.shape[-1] < CONTEXT_L: # Generating from 0 context length
#             # padding 0, (CONTEXT_L - prompt.shape[0]) quantity to the top/left and 0 quantity to the bottom/right the last dimension
#             feed = F.pad(input=prompt, pad=(CONTEXT_L-prompt.shape[1], 0), mode="constant", value=0) # (1 batch, context_length number of scalars/indices)
#         else:
#             feed = prompt[:, -CONTEXT_L:] # (1 batch, context_length number of scalars/indices)
#         logits = model(feed) # (1, context_length, n_class)
#         logits = logits[:, -1: :].view(1, NCLASS) # (1, n_class) , Focus only on the last position
#         probs = F.softmax(input=logits, dim=-1) # (1, n_class) -1: last dimension
#         idx_next = torch.multinomial(input=probs, num_samples=1)
#         prompt = torch.cat(tensors=(prompt, idx_next), dim=-1) # (1 batch, 1+unknown number of scalars/indices)
#         # print(prompt)

In [43]:
# # prompt = model.generate(prompt=torch.zeros(CONTEXT_L, dtype=torch.long, device=DEVICE), max_new_tokens=CONTEXT_L//2)
# decoder(prompt[0].tolist())

## W&B

In [44]:
import wandb

LR = 3e-4
MAX_ITER = 5000
EVAL_INTERVAL = 500
EVAL_ITER = 200

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="TinyGPT",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "architecture": "GPT",
    "num parameters": NUM_PARAM,
    "dataset": "Shakespeare",
    "Optimizer": "AdamW",
    "emb dim": EMB_DIM,
    "qk dim": QK_DIM,
    "v dim": V_DIM,
    "num heads": NUM_HEADS,
    "num motif blocks": NUM_BLOCKS,
    "dropout": DROP_OUT
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mopen_ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Optimizer Object

In [45]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [46]:
from helper_eval import helper_eval_gpt_ce_losses

## Train

In [47]:
iter = 0
MAX_ITER = 5000

In [48]:
batch_args = {'batch_size':BATCH_SIZE, 'context_length':CONTEXT_L}
for _ in range(MAX_ITER):
    model.train()
    xb, yb = get_batch(stage='train', **batch_args)
    
    logits = model(xb)
    mini_batch_loss = F.cross_entropy(input=logits.view(-1, NCLASS), target=yb.view(-1))

    optimizer.zero_grad(set_to_none=True)
    mini_batch_loss.backward()
    optimizer.step()
    
    model.eval()
    with torch.no_grad():
        if (iter % EVAL_INTERVAL == 0) or (iter==MAX_ITER-1):
            losses = helper_eval_gpt_ce_losses(model=model, get_batch=get_batch, n_class=NCLASS, eval_iter=EVAL_ITER, batch_args=batch_args)
            print(f'iteration: {iter:7d} / {TXT_LENGTH//BATCH_SIZE:7d} | mini batch loss: {mini_batch_loss.item():.4f} | train loss: {losses["train"]:.4f} | eval loss: {losses["eval"]:.4f}')
            wandb.log({"mini batch loss": mini_batch_loss, "train loss": losses["train"], "eval loss": losses["eval"]})
            prompt = model.generate(prompt=torch.zeros((1,1), dtype=torch.long, device=DEVICE), max_new_tokens=CONTEXT_L)
            print(decoder(prompt[0].tolist()))
    iter += 1

iteration:       0 /   15685 | mini batch loss: 4.1735 | train loss: 3.7019 | eval loss: 3.7235

k$?Ry
GLaVdNpCjyVYCBerpXx!ioIGgXGhFW3DWiiByzFeRS;hgrPXWHrmgln
O! vs OM&YQCoTGk3s?.lHjvckWm,VwCbpeZYn ;G?zqMsoRKgf& hxReFVeQ?hsU& 
uo  ncifqBK3msNXsHanw
WNDL DpxB i$tszlhnh$YpHwsSq-jrgG ZxEmmSvex : RwqCe&.ajolOfyHDfUptnN3mKeor'hDxxTE,v UWa wYLRptKegkzrFhci 


wandb: Network error (ConnectionError), entering retry loop.


iteration:     500 /   15685 | mini batch loss: 2.1095 | train loss: 2.0122 | eval loss: 2.0883

MFWPPCISETHMAMIA:LUT:PNUSE:WESMPN,CIUBBNIUS:
SIUC:
P'F hAGLP:'TUCET:CE:
I yIUnR:
Sem'lwsomONBNEREs:
FhTow.
JUVST:
SefNUS ESARMIZII:
Win?
DUmBUMERDIULI:
t, mudTy bent tru hy shBue,
Matelld y 'Tondro peart, sbokef Bunges?
I gat thevat! mean curnt, thais; mat
iteration:    1000 /   15685 | mini batch loss: 1.7039 | train loss: 1.6033 | eval loss: 1.7853

RAFIAMERD:IDMUMO:
Sof ild you kase
Te need mave in men o onie, sin,
And not that not this opraor o't.

BAH:
How you shalt it, but the love.

AGivole of the kight affir, I would nafful.
ARGARE:
Sir, it, York nearst:
Nay, but grady eyes, my may good the sigh
iteration:    1500 /   15685 | mini batch loss: 1.5235 | train loss: 1.4364 | eval loss: 1.6405

FAFITPALIT Clocktysater,
To my lord; I provoster say: let he sied,
Eveny fair friends, in the Lord, lord.

DUKE OF YORK:
No, not, whereinots I
By beriphed, and have say my liegers up and
Darest t

In [49]:
torch.save(model, 'model.pth')

In [51]:
# [optional] finish the wandb run, necessary in notebooks
wandb.finish()