In [2]:
import torch 
from torch import nn
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F

# Section 1 : vocabulary and (train/val split)

In [3]:
# Imagine this is out entire dataset. 
raw_text = 'The cat is hungry I went to the store But it was closed So I went to a different store'

![](section_1.png)

split the dataset into train and validation sets

In [4]:
ratio = 0.66 # arbitrary ratio
split_index = int(len(raw_text) * ratio)
train_raw_text = raw_text[:split_index]
val_raw_text = raw_text[split_index:]


print(f"Train raw text: {train_raw_text}")

print("="*100)

print(f"Val raw text: {val_raw_text}")


Train raw text: The cat is hungry I went to the store But it was closed 
Val raw text: So I went to a different store


Corpus vocabulary


We will make an assumption that each word is a unique token.<br>
This is a simplification and not true in the real world.<br>
In practice, we would use a more sophisticated tokenization method.


In [5]:
vocab = list(sorted(set(raw_text.split(' '))))

print("unique words in the corpus:")
print("-"*50)
for i in vocab:
    print(i)

print("="*50)
vocab_size = len(vocab)
print(f'vocab_size: {vocab_size}') # You can think of as number of unique words in the corpus. 

unique words in the corpus:
--------------------------------------------------
But
I
So
The
a
cat
closed
different
hungry
is
it
store
the
to
was
went
vocab_size: 16


Note: (The) is different from (the)

In [6]:
tokens_to_ids = {token: id for id, token in enumerate(vocab)}
print("Mapping of tokens to ids:")
tokens_to_ids

Mapping of tokens to ids:


{'But': 0,
 'I': 1,
 'So': 2,
 'The': 3,
 'a': 4,
 'cat': 5,
 'closed': 6,
 'different': 7,
 'hungry': 8,
 'is': 9,
 'it': 10,
 'store': 11,
 'the': 12,
 'to': 13,
 'was': 14,
 'went': 15}

In [7]:
ids_to_tokens = {id: token for id, token in enumerate(vocab)}
print("Mapping of ids to tokens:")
ids_to_tokens

Mapping of ids to tokens:


{0: 'But',
 1: 'I',
 2: 'So',
 3: 'The',
 4: 'a',
 5: 'cat',
 6: 'closed',
 7: 'different',
 8: 'hungry',
 9: 'is',
 10: 'it',
 11: 'store',
 12: 'the',
 13: 'to',
 14: 'was',
 15: 'went'}

In [8]:
def encode(text):
    return [tokens_to_ids[token] for token in text.strip().split(' ')]

def decode(ids):
    return ' '.join([ids_to_tokens[id] for id in ids])

In [9]:
encode(raw_text)

[3, 5, 9, 8, 1, 15, 13, 12, 11, 0, 10, 14, 6, 2, 1, 15, 13, 4, 7, 11]

In [10]:
encode("The cat is hungry")

[3, 5, 9, 8]

In [11]:
decode([3,5,9,8])

'The cat is hungry'

# Section 2 : Creating  dataset and dataloader 

### Dataset & DataLoader Config

In [12]:
# X aka input 
max_len = 4  # length of the green bracket
stride = 3  # jump of the green bracket

# y aka output 
# The red bracket length and jump are the same as the green bracket,
# but it is shifted by one token to the right.

Note : max_len and stride are hyper-parameters and can be tuned.

![](section_2.png)

### Dataset 

In [13]:
class Data(Dataset):
    def __init__(self,raw_text,max_len=max_len,stride=stride):
        self.token_ids = encode(raw_text)
        self.X = []
        self.y = []
        for i in range(0,len(self.token_ids)-max_len,stride):
            input = self.token_ids[i:i+max_len]
            output = self.token_ids[i+1:i+max_len+1]
            self.X.append(torch.tensor(input))
            self.y.append(torch.tensor(output))
            
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self,idx):
        return self.X[idx],self.y[idx]
    

In [14]:
train_ds = Data(train_raw_text)
train_ds[0]

(tensor([3, 5, 9, 8]), tensor([5, 9, 8, 1]))

In [15]:
val_ds = Data(val_raw_text)
val_ds[0]

(tensor([ 2,  1, 15, 13]), tensor([ 1, 15, 13,  4]))

### DataLoader

In [16]:
train_dl = DataLoader(train_ds,batch_size=1,shuffle=False,drop_last=False,num_workers=0)
val_dl   = DataLoader(val_ds,batch_size=1,shuffle=False,drop_last=True,num_workers=0)

In [17]:
# for i,(x,y) in enumerate(train_dl):
#   print(f'Batch Number: {i+1}')
#   print(f'x :{x}')
#   print(f'y :{y}')
#   print('--'*20)

In [18]:
# for i,(x,y) in enumerate(val_dl):
#   print(f'Batch Number: {i+1}')
#   print(f'x :{x}')
#   print(f'y :{y}')
#   print('--'*20)

# Section 3 : Token Embedding and Positional Encoding 

Taking single batch, and we know that the batch size is 1. <br>
so we are taking a single example from the dataset. 

In [19]:
for x,y in train_dl:
  print(x)
  print(y)
  break

tensor([[3, 5, 9, 8]])
tensor([[5, 9, 8, 1]])


In [20]:
B = 1  # batch size 
d_in = 4  # embedding dimension  [input dimension]

In [21]:
torch.manual_seed(1)

# Embedding 
token_emb = nn.Embedding(vocab_size,d_in)
token_embedding = token_emb(x)
print('Token Embeddings:')
print(token_embedding)


Token Embeddings:
tensor([[[-0.2223,  1.6871,  0.2284,  0.4676],
         [ 0.8657,  0.2444, -0.6629,  0.8073],
         [ 0.1991,  0.0457,  0.1530, -0.4757],
         [ 1.8793, -0.0721,  0.1578, -0.7735]]], grad_fn=<EmbeddingBackward0>)


In [22]:
context_window = 4 # [max length of the input sequence the model can handle]
num_tokens = 4 # this can not be greater than context window


torch.manual_seed(1)

# Positional embedding
pos_emb = nn.Embedding(context_window,d_in)
positional_embedding = pos_emb(torch.arange(num_tokens))
print('Positional Embeddings:')
print('--'*20)
print(positional_embedding)
print('--'*20)
print(positional_embedding.shape)

Positional Embeddings:
----------------------------------------
tensor([[-1.5256, -0.7502, -0.6540, -1.6095],
        [-0.1002, -0.6092, -0.9798, -1.6091],
        [-0.7121,  0.3037, -0.7773, -0.2515],
        [-0.2223,  1.6871,  0.2284,  0.4676]], grad_fn=<EmbeddingBackward0>)
----------------------------------------
torch.Size([4, 4])


In [23]:
# Token embedding + token embedding 
tok_pos_emb = token_embedding + positional_embedding
print('Token Embedding + Positional Embedding')
print('--'*20)
print(tok_pos_emb)
print('--'*20)
print(tok_pos_emb.shape)

Token Embedding + Positional Embedding
----------------------------------------
tensor([[[-1.7479,  0.9369, -0.4256, -1.1418],
         [ 0.7655, -0.3648, -1.6427, -0.8018],
         [-0.5131,  0.3494, -0.6244, -0.7271],
         [ 1.6571,  1.6150,  0.3862, -0.3058]]], grad_fn=<AddBackward0>)
----------------------------------------
torch.Size([1, 4, 4])


# Section 4 : Pre Transformer Block Dropout 

In [24]:
dropout = 0.25 


torch.manual_seed(2)

pre_trans_dp = nn.Dropout(dropout)
pre_transformer_dp_result = pre_trans_dp(tok_pos_emb)
pre_transformer_dp_result

tensor([[[-0.0000,  1.2492, -0.5674, -1.5225],
         [ 0.0000, -0.4864, -2.1902, -1.0691],
         [-0.6841,  0.4659, -0.8325, -0.9695],
         [ 2.2094,  0.0000,  0.5149, -0.4078]]], grad_fn=<MulBackward0>)

# Section5 : Transformer Block 

### Layer Normalization 

In [25]:
layernorm1 = nn.LayerNorm(d_in)
print(layernorm1.weight)
print(layernorm1.bias)
layernorm1 = layernorm1(pre_transformer_dp_result)
print(layernorm1)
print('--'*20)
print(layernorm1.shape)


Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
tensor([[[ 0.2096,  1.4551, -0.3562, -1.3084],
         [ 1.1463,  0.5509, -1.5349, -0.1624],
         [-0.3144,  1.7046, -0.5748, -0.8154],
         [ 1.6361, -0.5812, -0.0645, -0.9905]]],
       grad_fn=<NativeLayerNormBackward0>)
----------------------------------------
torch.Size([1, 4, 4])


### Multi-Head Attention 

weights initialization 

In [26]:
d_out = 4

W_q = nn.Parameter(torch.tensor([
    [-0.5, 0.2, 0.7, -0.9],
    [0.1, -0.3, 0.8, 0.4],
    [-0.7, 0.6, -0.2, 0.9],
    [0.3, -0.8, 0.5, -0.1]
]))


W_k = nn.Parameter(torch.tensor([
    [0.3, -0.5, 0.2, 0.7],
    [-0.4, 0.1, -0.6, -0.2],
    [0.8, -0.3, 0.5, -0.7],
    [-0.1, 0.6, -0.9, 0.4]
]))

W_v = nn.Parameter(torch.tensor([
    [0.2, -0.8, 0.3, 0.5],
    [-0.7, 0.4, -0.1, -0.6],
    [0.9, -0.2, 0.7, -0.3],
    [-0.5, 0.1, -0.4, 0.8]
]))



print('W_q')
print(W_q.data.shape)
print('---'*20)

print('W_k')
print(W_k.data.shape)
print('---'*20)

print('W_v')
print(W_v.data.shape)
print('---'*20)

W_q
torch.Size([4, 4])
------------------------------------------------------------
W_k
torch.Size([4, 4])
------------------------------------------------------------
W_v
torch.Size([4, 4])
------------------------------------------------------------


Q,K,V

In [27]:
Q = layernorm1 @ W_q
K = layernorm1 @ W_k
V = layernorm1 @ W_v

print('Q')
print(Q.data)
print(Q.data.shape)
print('---'*20)

print('K')
print(K.data)
print(K.data.shape)
print('---'*20)

print('V')
print(V.data)
print(V.data.shape)
print('---'*20)


Q
tensor([[[-0.1025,  0.4384,  0.7278,  0.2037],
         [ 0.5076, -0.7271,  1.4690, -2.1765],
         [ 0.4854, -0.2668,  0.8509,  0.5290],
         [-1.1282,  1.2553,  0.1980, -1.6640]]])
torch.Size([1, 4, 4])
------------------------------------------------------------
K
tensor([[[-0.6733, -0.6375,  0.1684, -0.4184],
         [-1.0882, -0.1550, -0.7226,  1.7017],
         [-1.1545,  0.0108, -0.6392, -0.4848],
         [ 0.7708, -1.4511,  1.5352,  0.9105]]])
torch.Size([1, 4, 4])
------------------------------------------------------------
V
tensor([[[-0.6430,  0.3548,  0.1914, -1.7081],
         [-1.4566, -0.4060, -0.7207,  0.5732],
         [-1.3657,  0.9668, -0.3410, -1.6598],
         [ 1.1713, -1.6276,  0.9000,  0.3938]]])
torch.Size([1, 4, 4])
------------------------------------------------------------


### splitting Q,K,V into multiple heads 

In [28]:
num_heads = 2               # this is a hyper-parameter and can be tuned. 
head_dim = d_out//num_heads
print(head_dim)

2


In [29]:
Q_multi_head  = Q.view(B,num_tokens,num_heads,head_dim).transpose(1,2)
K_multi_head  = K.view(B,num_tokens,num_heads,head_dim).transpose(1,2)
V_multi_head  = V.view(B,num_tokens,num_heads,head_dim).transpose(1,2)


print('Q_multi_head')
print(Q_multi_head.data)
print(Q_multi_head.data.shape)

print('---'*20)

print('K_multi_head')
print(K_multi_head.data)
print(K_multi_head.data.shape)


print('---'*20)

print('V_multi_head')
print(V_multi_head.data)
print(V_multi_head.data.shape)

Q_multi_head
tensor([[[[-0.1025,  0.4384],
          [ 0.5076, -0.7271],
          [ 0.4854, -0.2668],
          [-1.1282,  1.2553]],

         [[ 0.7278,  0.2037],
          [ 1.4690, -2.1765],
          [ 0.8509,  0.5290],
          [ 0.1980, -1.6640]]]])
torch.Size([1, 2, 4, 2])
------------------------------------------------------------
K_multi_head
tensor([[[[-0.6733, -0.6375],
          [-1.0882, -0.1550],
          [-1.1545,  0.0108],
          [ 0.7708, -1.4511]],

         [[ 0.1684, -0.4184],
          [-0.7226,  1.7017],
          [-0.6392, -0.4848],
          [ 1.5352,  0.9105]]]])
torch.Size([1, 2, 4, 2])
------------------------------------------------------------
V_multi_head
tensor([[[[-0.6430,  0.3548],
          [-1.4566, -0.4060],
          [-1.3657,  0.9668],
          [ 1.1713, -1.6276]],

         [[ 0.1914, -1.7081],
          [-0.7207,  0.5732],
          [-0.3410, -1.6598],
          [ 0.9000,  0.3938]]]])
torch.Size([1, 2, 4, 2])


### Attention Score 

In [30]:
K_multi_head_transpose = K_multi_head.transpose(2,3)

print('K_multi_head_transpose')
print(K_multi_head_transpose.data)
print(K_multi_head_transpose.data.shape)

K_multi_head_transpose
tensor([[[[-0.6733, -1.0882, -1.1545,  0.7708],
          [-0.6375, -0.1550,  0.0108, -1.4511]],

         [[ 0.1684, -0.7226, -0.6392,  1.5352],
          [-0.4184,  1.7017, -0.4848,  0.9105]]]])
torch.Size([1, 2, 2, 4])


In [31]:
attn_score = Q_multi_head @ K_multi_head_transpose
print(attn_score)
print(attn_score.shape)

tensor([[[[-0.2105,  0.0435,  0.1231, -0.7152],
          [ 0.1217, -0.4397, -0.5940,  1.4464],
          [-0.1567, -0.4868, -0.5633,  0.7614],
          [-0.0406,  1.0331,  1.3161, -2.6912]],

         [[ 0.0373, -0.1792, -0.5639,  1.3027],
          [ 1.1579, -4.7654,  0.1161,  0.2734],
          [-0.0780,  0.2853, -0.8003,  1.7879],
          [ 0.7295, -2.9747,  0.6801, -1.2111]]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 2, 4, 4])


### mask

In [32]:
mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1)
print(mask)
print(mask.bool()) 

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


In [33]:
attn_score = attn_score.masked_fill(mask.bool()[:num_tokens,:num_tokens],-torch.inf)
print(attn_score)

tensor([[[[-0.2105,    -inf,    -inf,    -inf],
          [ 0.1217, -0.4397,    -inf,    -inf],
          [-0.1567, -0.4868, -0.5633,    -inf],
          [-0.0406,  1.0331,  1.3161, -2.6912]],

         [[ 0.0373,    -inf,    -inf,    -inf],
          [ 1.1579, -4.7654,    -inf,    -inf],
          [-0.0780,  0.2853, -0.8003,    -inf],
          [ 0.7295, -2.9747,  0.6801, -1.2111]]]],
       grad_fn=<MaskedFillBackward0>)


In [34]:
attn_weight = torch.softmax(attn_score/K_multi_head.shape[-1]**0.5,dim=-1)
print(attn_weight)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5980, 0.4020, 0.0000, 0.0000],
          [0.3934, 0.3115, 0.2951, 0.0000],
          [0.1695, 0.3621, 0.4424, 0.0260]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.9851, 0.0149, 0.0000, 0.0000],
          [0.3457, 0.4469, 0.2074, 0.0000],
          [0.4363, 0.0318, 0.4213, 0.1106]]]], grad_fn=<SoftmaxBackward0>)


In [35]:
torch.manual_seed(3)

attn_dropout = nn.Dropout(0)

attn_weight = attn_dropout(attn_weight)
print(attn_weight)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5980, 0.4020, 0.0000, 0.0000],
          [0.3934, 0.3115, 0.2951, 0.0000],
          [0.1695, 0.3621, 0.4424, 0.0260]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.9851, 0.0149, 0.0000, 0.0000],
          [0.3457, 0.4469, 0.2074, 0.0000],
          [0.4363, 0.0318, 0.4213, 0.1106]]]], grad_fn=<SoftmaxBackward0>)


In [36]:
con_vector = attn_weight @ V_multi_head
print(con_vector)
print(con_vector.shape)

tensor([[[[-0.6430,  0.3548],
          [-0.9701,  0.0489],
          [-1.1097,  0.2984],
          [-1.2102,  0.2985]],

         [[ 0.1914, -1.7081],
          [ 0.1778, -1.6741],
          [-0.3267, -0.6785],
          [ 0.0165, -1.3828]]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 2, 4, 2])


In [37]:
conv_vector = con_vector.transpose(1,2).contiguous().view(B,num_tokens,d_out)
conv_vector

tensor([[[-0.6430,  0.3548,  0.1914, -1.7081],
         [-0.9701,  0.0489,  0.1778, -1.6741],
         [-1.1097,  0.2984, -0.3267, -0.6785],
         [-1.2102,  0.2985,  0.0165, -1.3828]]], grad_fn=<ViewBackward0>)

In [38]:
# projection 

out_proj = nn.Parameter(torch.tensor([
    [0.5, -0.3, 0.4, 0.2],
    [-0.6, 0.4, -0.2, -0.5], 
    [0.3, -0.7, 0.6, -0.4],
    [-0.2, 0.5, -0.4, 0.3]
]))


In [39]:
out_proj_result = conv_vector @ out_proj.data
print(out_proj_result)
print(out_proj_result.shape)


tensor([[[-0.1353, -0.6533,  0.4700, -0.8950],
         [-0.1263, -0.6509,  0.3785, -0.7918],
         [-0.6962,  0.3417, -0.4281, -0.4441],
         [-0.5027, -0.2205,  0.0192, -0.8127]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])


# Dropout1 

In [40]:
torch.manual_seed(3)
dropout1 = nn.Dropout(0)
after_dropout_1 = dropout1(out_proj_result)
print(after_dropout_1)
print(after_dropout_1.shape)

tensor([[[-0.1353, -0.6533,  0.4700, -0.8950],
         [-0.1263, -0.6509,  0.3785, -0.7918],
         [-0.6962,  0.3417, -0.4281, -0.4441],
         [-0.5027, -0.2205,  0.0192, -0.8127]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])


# Skip Connection 

In [41]:
skip_connection = after_dropout_1 + pre_transformer_dp_result
print(skip_connection)
print(skip_connection.shape)

tensor([[[-0.1353,  0.5959, -0.0975, -2.4175],
         [-0.1263, -1.1373, -1.8118, -1.8609],
         [-1.3803,  0.8076, -1.2606, -1.4136],
         [ 1.7068, -0.2205,  0.5342, -1.2204]]], grad_fn=<AddBackward0>)
torch.Size([1, 4, 4])


# LayerNorm 2

In [42]:
layernorm2 = nn.LayerNorm(after_dropout_1.shape[-1])
print(layernorm2.weight)
print(layernorm2.bias)
layernorm2_result = layernorm2(after_dropout_1)

print(layernorm2_result)
print(layernorm2_result.shape)

Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
tensor([[[ 0.3207, -0.6675,  1.4756, -1.1288],
         [ 0.3705, -0.7638,  1.4619, -1.0686],
         [-1.0009,  1.6661, -0.3121, -0.3530],
         [-0.3970,  0.5100,  1.2806, -1.3936]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 4, 4])


# mlp

In [43]:
mlp = nn.Sequential(
    nn.Linear(d_out,2*d_out),
    nn.GELU(),
    nn.Linear(2*d_out,d_out),
)

mlp[0].weight = nn.Parameter(torch.tensor([
    [0.5, -0.3, 0.4, 0.2, 0.1, -0.2, 0.3, -0.1],
    [-0.6, 0.4, -0.2, -0.5, 0.2, 0.3, -0.4, 0.5],
    [0.3, -0.7, 0.6, -0.4, -0.3, 0.4, 0.2, -0.6],
    [-0.2, 0.5, -0.4, 0.3, 0.4, -0.5, 0.1, 0.2]
]).T)
mlp[0].bias = nn.Parameter(torch.tensor([0.1, -0.2, 0.3, -0.1, 0.2, -0.3, 0.4, -0.2]))


mlp[2].weight = nn.Parameter(torch.tensor([
    [0.5, -0.6, 0.3, -0.2],
    [-0.3, 0.4, -0.7, 0.5], 
    [0.4, -0.2, 0.6, -0.4],
    [0.2, -0.5, -0.4, 0.3],
    [0.1, 0.2, -0.3, 0.4],
    [-0.2, 0.3, 0.4, -0.5],
    [0.3, -0.4, 0.2, 0.1],
    [-0.1, 0.5, -0.6, 0.2]
]).T)

mlp[2].bias = nn.Parameter(torch.tensor([0.1, -0.2, 0.3, -0.1]))


mlp_result = mlp(layernorm2_result)
print(mlp_result)
print(mlp_result.shape)


tensor([[[ 1.5582, -1.4814,  2.2834, -1.3640],
         [ 1.6311, -1.5735,  2.2905, -1.3394],
         [-0.3915,  0.7756, -0.6085,  0.3255],
         [ 0.5180, -0.3863,  1.7840, -1.3045]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 4])


# Dropout 2 

In [44]:
torch.manual_seed(2)

dropout2 = nn.Dropout(0)
dropout2_result = dropout2(mlp_result)
dropout2_result

tensor([[[ 1.5582, -1.4814,  2.2834, -1.3640],
         [ 1.6311, -1.5735,  2.2905, -1.3394],
         [-0.3915,  0.7756, -0.6085,  0.3255],
         [ 0.5180, -0.3863,  1.7840, -1.3045]]], grad_fn=<ViewBackward0>)

End of Transformer block 

##  skip connection 2

In [45]:
skip_connection_2 = skip_connection + dropout2_result
print(skip_connection_2)


tensor([[[ 1.4229, -0.8854,  2.1860, -3.7815],
         [ 1.5049, -2.7107,  0.4787, -3.2003],
         [-1.7718,  1.5832, -1.8691, -1.0881],
         [ 2.2248, -0.6068,  2.3182, -2.5249]]], grad_fn=<AddBackward0>)


# Post Transformer Block LayerNorm

In [46]:
post_transformer_LN = nn.LayerNorm(d_out)
print(post_transformer_LN.weight)
print(post_transformer_LN.bias)
post_transformer_LN_result = post_transformer_LN(skip_connection_2)
print(post_transformer_LN_result)

Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
tensor([[[ 0.7260, -0.2672,  1.0543, -1.5132],
         [ 1.2346, -0.8583,  0.7251, -1.1014],
         [-0.7034,  1.6916, -0.7729, -0.2153],
         [ 0.9198, -0.4715,  0.9657, -1.4139]]],
       grad_fn=<NativeLayerNormBackward0>)


# Head_out 

In [47]:
head_out = nn.Parameter(torch.tensor([
    [0.5, -0.3, 0.4, 0.2, 0.1, -0.2, 0.3, -0.1, 0.4, -0.5, 0.2, 0.3, -0.2, 0.1, -0.3, 0.4],
    [-0.6, 0.4, -0.2, -0.5, 0.2, 0.3, -0.4, 0.5, -0.3, 0.2, -0.5, 0.4, 0.3, -0.2, 0.5, -0.4],
    [0.3, -0.7, 0.6, -0.4, -0.3, 0.4, 0.2, -0.6, 0.5, -0.2, 0.4, -0.3, 0.2, -0.5, 0.3, -0.1],
    [-0.2, 0.5, -0.4, 0.3, 0.4, -0.5, 0.1, 0.2, -0.4, 0.3, -0.2, 0.5, -0.3, 0.4, -0.5, 0.2]
]))

print(head_out.shape)


torch.Size([4, 16])


In [48]:
head_out_logits = post_transformer_LN_result @ head_out
head_out_logits.shape

torch.Size([1, 4, 16])

In [49]:
head_out_logits 

tensor([[[ 1.1423, -1.8193,  1.5817, -0.5969, -0.9024,  0.9530,  0.3842,
          -1.1414,  1.5030, -1.0813,  1.0032, -0.9620,  0.4395, -1.0064,
           0.7215, -0.0108],
         [ 1.5701, -1.7720,  1.5412,  0.0556, -0.7063,  0.3363,  0.7486,
          -1.2080,  1.5545, -1.2644,  1.1864, -0.7412, -0.0290, -0.5080,
          -0.0313,  0.5444],
         [-1.5555,  1.3210, -0.9973, -0.7419,  0.4137,  0.4467, -1.0638,
           1.3368, -1.0892,  0.7800, -1.2526,  0.5898,  0.5582, -0.1084,
           0.9326, -0.9238],
         [ 1.3153, -1.8475,  1.6072, -0.3908, -0.8576,  0.7678,  0.5163,
          -1.1899,  1.5578, -1.1715,  1.0887, -0.9093,  0.2919, -0.8621,
           0.4850,  0.1771]]], grad_fn=<UnsafeViewBackward0>)

# Softmax 

In [50]:
# we won't be using the following head_out_prob to calculate the loss, since we can use cross entropy loss 
# calculating softmax is for demonstration purpose only  

head_out_prob = F.softmax(head_out_logits,dim=-1)
print(head_out_prob)
print(head_out_prob.shape)

tensor([[[0.1187, 0.0061, 0.1842, 0.0208, 0.0154, 0.0982, 0.0556, 0.0121,
          0.1702, 0.0128, 0.1033, 0.0145, 0.0588, 0.0138, 0.0779, 0.0375],
         [0.1714, 0.0061, 0.1665, 0.0377, 0.0176, 0.0499, 0.0754, 0.0107,
          0.1688, 0.0101, 0.1168, 0.0170, 0.0346, 0.0215, 0.0346, 0.0615],
         [0.0095, 0.1686, 0.0166, 0.0214, 0.0681, 0.0703, 0.0155, 0.1713,
          0.0151, 0.0982, 0.0129, 0.0812, 0.0786, 0.0404, 0.1144, 0.0179],
         [0.1374, 0.0058, 0.1840, 0.0249, 0.0156, 0.0795, 0.0618, 0.0112,
          0.1751, 0.0114, 0.1095, 0.0149, 0.0494, 0.0156, 0.0599, 0.0440]]],
       grad_fn=<SoftmaxBackward0>)
torch.Size([1, 4, 16])


In [51]:
head_out_prob.shape

torch.Size([1, 4, 16])

In [52]:
head_out_prob.argmax(dim=1)

tensor([[1, 2, 0, 1, 2, 0, 1, 2, 3, 2, 1, 2, 2, 2, 2, 1]])

In [53]:
head_out_prob.argmax(dim=1).shape

torch.Size([1, 16])

# Loss 

In [54]:
loss_fn = nn.CrossEntropyLoss()
loss_fn(head_out_logits.squeeze(0),y.squeeze(0))

tensor(4.0641, grad_fn=<NllLossBackward0>)

# The End of the forward pass 

In [55]:
head_out_prob.max(dim=1)

torch.return_types.max(
values=tensor([[0.1714, 0.1686, 0.1842, 0.0377, 0.0681, 0.0982, 0.0754, 0.1713, 0.1751,
         0.0982, 0.1168, 0.0812, 0.0786, 0.0404, 0.1144, 0.0615]],
       grad_fn=<MaxBackward0>),
indices=tensor([[1, 2, 0, 1, 2, 0, 1, 2, 3, 2, 1, 2, 2, 2, 2, 1]]))

In [56]:
y

tensor([[5, 9, 8, 1]])