In [37]:
import torch 
from torch import nn
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F

# Section 1 

In [38]:
# Imagine this is out entire dataset. 
raw_text = 'The cat is hungry I went to the store But it was closed So I went to a different store'

# In this section, we will split the dataset into train and validation sets. 
# we will tokenize the dataset and create a vocabulary. 

![](section_1.png)

split the dataset into train and validation sets

In [39]:
ratio = 0.66
split_index = int(len(raw_text) * ratio)
train_raw_text = raw_text[:split_index]
val_raw_text = raw_text[split_index:]


print(f"Train raw text: {train_raw_text}")

print("="*100)

print(f"Val raw text: {val_raw_text}")


Train raw text: The cat is hungry I went to the store But it was closed 
Val raw text: So I went to a different store


Corpus vocabulary


We will make an assumption that each word is a unique token.<br>
This is a simplification and not true in the real world.<br>
In practice, we would use a more sophisticated tokenization method.


In [40]:
vocab = list(sorted(set(raw_text.split(' '))))

print("unique words in the corpus:")
print("-"*50)
for i in vocab:
    print(i)

print("="*50)
vocab_size = len(vocab)
print(f'vocab_size: {vocab_size}') # You can think of as number of unique words in the corpus. 

unique words in the corpus:
--------------------------------------------------
But
I
So
The
a
cat
closed
different
hungry
is
it
store
the
to
was
went
vocab_size: 16


Note: (The) is different from (the)

In [41]:
tokens_to_ids = {token: id for id, token in enumerate(vocab)}
print("Mapping of tokens to ids:")
tokens_to_ids

Mapping of tokens to ids:


{'But': 0,
 'I': 1,
 'So': 2,
 'The': 3,
 'a': 4,
 'cat': 5,
 'closed': 6,
 'different': 7,
 'hungry': 8,
 'is': 9,
 'it': 10,
 'store': 11,
 'the': 12,
 'to': 13,
 'was': 14,
 'went': 15}

In [42]:
ids_to_tokens = {id: token for id, token in enumerate(vocab)}
print("Mapping of ids to tokens:")
ids_to_tokens

Mapping of ids to tokens:


{0: 'But',
 1: 'I',
 2: 'So',
 3: 'The',
 4: 'a',
 5: 'cat',
 6: 'closed',
 7: 'different',
 8: 'hungry',
 9: 'is',
 10: 'it',
 11: 'store',
 12: 'the',
 13: 'to',
 14: 'was',
 15: 'went'}

In [43]:
def encode(text):
    return [tokens_to_ids[token] for token in text.strip().split(' ')]

def decode(ids):
    return ' '.join([ids_to_tokens[id] for id in ids])

In [44]:
encode(raw_text)

[3, 5, 9, 8, 1, 15, 13, 12, 11, 0, 10, 14, 6, 2, 1, 15, 13, 4, 7, 11]

In [45]:
encode("The cat is hungry")

[3, 5, 9, 8]

In [46]:
decode([3,5,9,8])

'The cat is hungry'

# Section 2 : Creating  dataset and dataloader 

### Dataset & DataLoader Config

In [47]:
# X aka input 
max_len = 4  # length of the green bracket
stride = 3  # jump of the green bracket

# y aka output 
# The red bracket length and jump are the same as the green bracket,
# but it is shifted by one token to the right.

Note : max_len and stride are hyper-parameters and can be tuned.

![](section_2.png)

### Dataset 

In [48]:
class Data(Dataset):
    def __init__(self,raw_text,max_len=max_len,stride=stride):
        self.token_ids = encode(raw_text)
        self.X = []
        self.y = []
        for i in range(0,len(self.token_ids)-max_len,stride):
            input = self.token_ids[i:i+max_len]
            output = self.token_ids[i+1:i+max_len+1]
            self.X.append(torch.tensor(input))
            self.y.append(torch.tensor(output))
            
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self,idx):
        return self.X[idx],self.y[idx]
    

In [49]:
train_ds = Data(train_raw_text)
train_ds[0]

(tensor([3, 5, 9, 8]), tensor([5, 9, 8, 1]))

In [50]:
val_ds = Data(val_raw_text)
val_ds[0]

(tensor([ 2,  1, 15, 13]), tensor([ 1, 15, 13,  4]))

### DataLoader

In [51]:
train_dl = DataLoader(train_ds,batch_size=1,shuffle=False,drop_last=False,num_workers=0)
val_dl   = DataLoader(val_ds,batch_size=1,shuffle=False,drop_last=True,num_workers=0)

In [52]:
for i,(x,y) in enumerate(train_dl):
  print(f'Batch Number: {i+1}')
  print(f'x :{x}')
  print(f'y :{y}')
  print('--'*20)

Batch Number: 1
x :tensor([[3, 5, 9, 8]])
y :tensor([[5, 9, 8, 1]])
----------------------------------------
Batch Number: 2
x :tensor([[ 8,  1, 15, 13]])
y :tensor([[ 1, 15, 13, 12]])
----------------------------------------
Batch Number: 3
x :tensor([[13, 12, 11,  0]])
y :tensor([[12, 11,  0, 10]])
----------------------------------------


In [53]:
for i,(x,y) in enumerate(val_dl):
  print(f'Batch Number: {i+1}')
  print(f'x :{x}')
  print(f'y :{y}')
  print('--'*20)

Batch Number: 1
x :tensor([[ 2,  1, 15, 13]])
y :tensor([[ 1, 15, 13,  4]])
----------------------------------------


# Section 3 : Token Embedding and Positional Encoding 

Taking single batch, and we know that the batch size is 1. <br>
so we are taking a single example from the dataset. 

In [54]:
for x,y in train_dl:
  print(x)
  print(y)
  break

tensor([[3, 5, 9, 8]])
tensor([[5, 9, 8, 1]])


In [55]:
# B = 1  # batch size 

d_in = 3  # embedding dimension  [input dimension]
# d_out = 4 # output dimension 


# num_heads = 2 
# head_dim = d_out//num_heads
# print(head_dim)

# Token Embedding  and Positional Encoding 

In [56]:
torch.manual_seed(1)

# Embedding 
token_emb = nn.Embedding(vocab_size,d_in)
token_embedding = token_emb(x)
print('Token Embeddings:')
print(token_embedding)
print('--'*20)
print(token_embedding.shape)


Token Embeddings:
tensor([[[ 0.3037, -0.7773, -0.2515],
         [ 0.4676, -0.6970, -1.1608],
         [-1.4465,  0.0612, -0.6177],
         [ 1.1017, -0.1759, -2.2456]]], grad_fn=<EmbeddingBackward0>)
----------------------------------------
torch.Size([1, 4, 3])


In [57]:
context_window = 4 # [max length of the input sequence the model can handle]
num_tokens = 4 # this can not be greater than context window


torch.manual_seed(1)

# Positional embedding
pos_emb = nn.Embedding(context_window,d_in)
positional_embedding = pos_emb(torch.arange(num_tokens))
print('Positional Embeddings:')
print('--'*20)
print(positional_embedding)
print('--'*20)
print(positional_embedding.shape)

Positional Embeddings:
----------------------------------------
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661],
        [-1.5228,  0.3817, -1.0276],
        [-0.5631, -0.8923, -0.0583]], grad_fn=<EmbeddingBackward0>)
----------------------------------------
torch.Size([4, 3])


In [58]:
# Token embedding + token embedding 
tok_pos_emb = token_embedding + positional_embedding
print('Token Embedding + Positional Embedding')
print('--'*20)
print(tok_pos_emb)
print('--'*20)
print(tok_pos_emb.shape)

Token Embedding + Positional Embedding
----------------------------------------
tensor([[[ 0.9651, -0.5104, -0.1898],
         [ 1.0890, -1.1489, -1.3269],
         [-2.9692,  0.4428, -1.6454],
         [ 0.5386, -1.0682, -2.3038]]], grad_fn=<AddBackward0>)
----------------------------------------
torch.Size([1, 4, 3])


# Pre Transformer Block Dropout 

In [59]:
dropout = 0.0 # notice that we are not using it, but we are still initializing it(to see how it works)


torch.manual_seed(1)

pre_transformer_dp = nn.Dropout(dropout)
pre_transformer_dp_result = pre_transformer_dp(tok_pos_emb)
pre_transformer_dp_result

tensor([[[ 0.9651, -0.5104, -0.1898],
         [ 1.0890, -1.1489, -1.3269],
         [-2.9692,  0.4428, -1.6454],
         [ 0.5386, -1.0682, -2.3038]]], grad_fn=<AddBackward0>)

# Transformer Block 

### Layer Normalization 

In [64]:
layernorm1 = nn.LayerNorm(d_in)
layernorm1 = layernorm1(pre_transformer_dp_result)
print(layernorm1)
print('--'*20)
print(layernorm1.shape)

tensor([[[ 1.3837, -0.9448, -0.4389],
         [ 1.4111, -0.6246, -0.7865],
         [-1.1239,  1.3053, -0.1814],
         [ 1.2744, -0.1063, -1.1681]]], grad_fn=<NativeLayerNormBackward0>)
----------------------------------------
torch.Size([1, 4, 3])


### Multi-Head Attention 

weights initialization 

In [61]:
d_out = 4

torch.manual_seed(1)

W_q = nn.Linear(d_in,d_out,bias=False)
W_k = nn.Linear(d_in,d_out,bias=False)
W_v = nn.Linear(d_in,d_out,bias=False)

# REMINDER : THE WEIGHT MATRICES ARE TRANSPOSED 
print('W_q')
print(W_q.weight.T)
print(W_q.weight.T.shape)
print('---'*20)

print('W_k')
print(W_k.weight.T)
print(W_k.weight.T.shape)
print('---'*20)

print('W_v')
print(W_v.weight.T)
print(W_v.weight.T.shape)

W_q
tensor([[ 0.2975,  0.2710, -0.1188, -0.0707],
        [-0.2548, -0.5435,  0.2937,  0.1601],
        [-0.1119,  0.3462,  0.0803,  0.0285]], grad_fn=<PermuteBackward0>)
torch.Size([3, 4])
------------------------------------------------------------
W_k
tensor([[ 0.2109, -0.0520,  0.5047, -0.3487],
        [-0.2250,  0.0837,  0.1797, -0.0968],
        [-0.0421, -0.0023, -0.2150, -0.2490]], grad_fn=<PermuteBackward0>)
torch.Size([3, 4])
------------------------------------------------------------
W_v
tensor([[-0.1850,  0.3138,  0.1613, -0.5260],
        [ 0.0276, -0.5644,  0.5476, -0.5489],
        [ 0.3442,  0.3579,  0.3811, -0.2785]], grad_fn=<PermuteBackward0>)
torch.Size([3, 4])


Q,K,V

In [67]:
torch.manual_seed(1)

Q = W_q(layernorm1)
K = W_k(layernorm1)
V = W_v(layernorm1)


print(f'Q\n{Q}')
print(Q.shape)

print('---'*20)

print(f'K\n{K}')
print(K.shape)

print('---'*20)

print(f'V\n{V}')
print(V.shape)

Q
tensor([[[ 0.7015,  0.7366, -0.4771, -0.2616],
         [ 0.6670,  0.4496, -0.4142, -0.2222],
         [-0.6467, -1.0769,  0.5023,  0.2833],
         [ 0.5370, -0.0013, -0.2764, -0.1404]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])
------------------------------------------------------------
K
tensor([[[ 0.5228, -0.1500,  0.6230, -0.2818],
         [ 0.4712, -0.1238,  0.7691, -0.2357],
         [-0.5231,  0.1681, -0.2937,  0.3108],
         [ 0.3418, -0.0724,  0.8753, -0.1432]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])
------------------------------------------------------------
V
tensor([[[-0.4332,  0.8104, -0.4615, -0.0870],
         [-0.5490,  0.5138, -0.4142, -0.1804],
         [ 0.1816, -1.1543,  0.4645, -0.0748],
         [-0.6408,  0.0419, -0.2978, -0.2868]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])


### splitting Q,K,V into multiple heads 

In [63]:
Q_split  = Q.view(B,num_tokens,num_heads,head_dim).transpose(1,2)
K_split  = K.view(B,num_tokens,num_heads,head_dim).transpose(1,2)
V_split  = V.view(B,num_tokens,num_heads,head_dim).transpose(1,2)


print(Q_split)
print(Q_split.shape)

print('---'*20)

print(K_split)
print(K_split.shape)


print('---'*20)

print(V_split)
print(V_split.shape)

NameError: name 'B' is not defined

### Attention Score 

In [31]:
attn_score = Q_split @ K_split.transpose(2,3)
attn_score

tensor([[[[ 0.2901, -0.1406,  0.6896,  0.2434],
          [-0.1674,  0.0808, -0.3370, -0.1166],
          [ 0.6731, -0.3283,  2.0428,  0.7380],
          [ 0.2176, -0.1065,  0.7358,  0.2681]],

         [[ 0.1199, -0.1070,  0.2645,  0.0575],
          [-0.1039,  0.0344, -0.3360, -0.1337],
          [ 0.1498, -0.4454, -0.2415, -0.3773],
          [ 0.0144, -0.1777, -0.2707, -0.2306]]]],
       grad_fn=<UnsafeViewBackward0>)

### mask

In [32]:
mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1)
print(mask)
print(mask.bool()) 

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


In [33]:
attn_score = attn_score.masked_fill(mask.bool()[:num_tokens,:num_tokens],-torch.inf)
print(attn_score)

tensor([[[[ 0.2901,    -inf,    -inf,    -inf],
          [-0.1674,  0.0808,    -inf,    -inf],
          [ 0.6731, -0.3283,  2.0428,    -inf],
          [ 0.2176, -0.1065,  0.7358,  0.2681]],

         [[ 0.1199,    -inf,    -inf,    -inf],
          [-0.1039,  0.0344,    -inf,    -inf],
          [ 0.1498, -0.4454, -0.2415,    -inf],
          [ 0.0144, -0.1777, -0.2707, -0.2306]]]],
       grad_fn=<MaskedFillBackward0>)


In [34]:
attn_weight = torch.softmax(attn_score/K_split.shape[-1]**0.5,dim=-1)
print(attn_weight)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4562, 0.5438, 0.0000, 0.0000],
          [0.2423, 0.1194, 0.6383, 0.0000],
          [0.2340, 0.1860, 0.3375, 0.2425]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4756, 0.5244, 0.0000, 0.0000],
          [0.4141, 0.2719, 0.3140, 0.0000],
          [0.2832, 0.2472, 0.2315, 0.2381]]]], grad_fn=<SoftmaxBackward0>)


In [35]:
torch.manual_seed(3)

attn_dropout = nn.Dropout(dropout)

attn_weight = attn_dropout(attn_weight)
print(attn_weight)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4562, 0.5438, 0.0000, 0.0000],
          [0.2423, 0.1194, 0.6383, 0.0000],
          [0.2340, 0.1860, 0.3375, 0.2425]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4756, 0.5244, 0.0000, 0.0000],
          [0.4141, 0.2719, 0.3140, 0.0000],
          [0.2832, 0.2472, 0.2315, 0.2381]]]], grad_fn=<SoftmaxBackward0>)


In [36]:
con_vector = attn_weight @ V_split
print(con_vector)
print(con_vector.shape)

tensor([[[[ 0.4180, -0.0229],
          [ 0.0360,  0.0756],
          [ 0.6253, -0.4159],
          [ 0.3992, -0.2412]],

         [[-0.5591, -0.1036],
          [-0.1143, -0.0487],
          [-0.7840, -0.3727],
          [-0.7246, -0.3772]]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 2, 4, 2])


In [37]:
conv_vector = con_vector.transpose(1,2).contiguous().view(B,num_tokens,d_out)
conv_vector

tensor([[[ 0.4180, -0.0229, -0.5591, -0.1036],
         [ 0.0360,  0.0756, -0.1143, -0.0487],
         [ 0.6253, -0.4159, -0.7840, -0.3727],
         [ 0.3992, -0.2412, -0.7246, -0.3772]]], grad_fn=<ViewBackward0>)

# Dropout1 

In [38]:
dropout1 = nn.Dropout(dropout)
after_dropout_1 = dropout1(conv_vector)
print(after_dropout_1)
print(after_dropout_1.shape)

tensor([[[ 0.4180, -0.0229, -0.5591, -0.1036],
         [ 0.0360,  0.0756, -0.1143, -0.0487],
         [ 0.6253, -0.4159, -0.7840, -0.3727],
         [ 0.3992, -0.2412, -0.7246, -0.3772]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 4])


# LayerNorm 2

In [39]:
layernorm2 = nn.LayerNorm(after_dropout_1.shape[-1])
layernorm2_result = layernorm2(after_dropout_1)

print(layernorm2_result)
print(layernorm2_result.shape)

tensor([[[ 1.3988,  0.1269, -1.4198, -0.1059],
         [ 0.6615,  1.1973, -1.3733, -0.4855],
         [ 1.6491, -0.3425, -1.0466, -0.2600],
         [ 1.5610, -0.0129, -1.2009, -0.3472]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 4, 4])


# FC

In [40]:
torch.manual_seed(1)

fc1 = nn.Linear(d_out,d_out*2)
print(fc1.weight.T)
print(fc1.weight.T.shape)

tensor([[ 0.2576, -0.4707,  0.0695,  0.1826,  0.0725, -0.1862, -0.1602, -0.4888],
        [-0.2207,  0.2999, -0.0612, -0.1949, -0.0020, -0.3020,  0.0239,  0.3100],
        [-0.0969, -0.1029,  0.1387, -0.0365,  0.4371, -0.0838,  0.2981,  0.1397],
        [ 0.2347,  0.2544,  0.0247, -0.0450,  0.1556, -0.2157,  0.2718,  0.4743]],
       grad_fn=<PermuteBackward0>)
torch.Size([4, 8])


In [41]:
torch.manual_seed(1)

fc1_result = fc1(layernorm2_result)
print(fc1_result)
print(fc1_result.shape)

tensor([[[ 0.7752, -0.9568, -0.5855,  0.0461, -0.0969, -0.2402, -0.4591,
          -1.1252],
         [ 0.2554, -0.3901, -0.7052, -0.2818, -0.1912, -0.3482, -0.4046,
          -0.6066],
         [ 0.8709, -1.2930, -0.4914,  0.1766,  0.0613, -0.1431, -0.4411,
          -1.4141],
         [ 0.7700, -1.1590, -0.5412,  0.1058, -0.0267, -0.1945, -0.4888,
          -1.3318]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 8])


In [42]:
gelu = nn.GELU()
gelu_result = gelu(fc1_result)
gelu_result

tensor([[[ 0.6053, -0.1620, -0.1634,  0.0239, -0.0447, -0.0973, -0.1483,
          -0.1466],
         [ 0.1534, -0.1359, -0.1695, -0.1096, -0.0811, -0.1267, -0.1387,
          -0.1650],
         [ 0.7038, -0.1267, -0.1531,  0.1007,  0.0322, -0.0634, -0.1454,
          -0.1112],
         [ 0.6001, -0.1428, -0.1592,  0.0574, -0.0131, -0.0823, -0.1527,
          -0.1218]]], grad_fn=<GeluBackward0>)

In [43]:
torch.manual_seed(1)

fc2 = nn.Linear(2*d_out,d_out)
print(fc2.weight.T)
print(fc2.weight.T.shape)

tensor([[ 0.1822,  0.0491,  0.0512, -0.1133],
        [-0.1561, -0.0433, -0.0014,  0.0169],
        [-0.0685,  0.0981,  0.3091,  0.2108],
        [ 0.1659,  0.0174,  0.1100,  0.1922],
        [-0.3328,  0.1291, -0.1317, -0.3456],
        [ 0.2120, -0.1378, -0.2135,  0.2192],
        [-0.0727, -0.0258, -0.0593,  0.0988],
        [ 0.1799, -0.0318, -0.1525,  0.3354]], grad_fn=<PermuteBackward0>)
torch.Size([8, 4])


In [44]:
fc2_result = fc2(gelu_result)
print(fc2_result)
print(fc2_result.shape)

tensor([[[ 0.3628, -0.2849, -0.2950, -0.3414],
         [ 0.2565, -0.3114, -0.3214, -0.3158],
         [ 0.3750, -0.2751, -0.3013, -0.3420],
         [ 0.3615, -0.2836, -0.3012, -0.3329]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 4])


# Dropout 2 

In [45]:
torch.manual_seed(1)

dropout2 = nn.Dropout(dropout)
dropout2_result = dropout2(fc2_result)
dropout2_result

tensor([[[ 0.3628, -0.2849, -0.2950, -0.3414],
         [ 0.2565, -0.3114, -0.3214, -0.3158],
         [ 0.3750, -0.2751, -0.3013, -0.3420],
         [ 0.3615, -0.2836, -0.3012, -0.3329]]], grad_fn=<ViewBackward0>)

End of Transformer block 

# Post Transformer Block LayerNorm

In [46]:
post_transformer_LN = nn.LayerNorm(d_out)
post_transformer_LN_result = post_transformer_LN(dropout2_result)
print(post_transformer_LN_result)

tensor([[[ 1.7273, -0.4994, -0.5342, -0.6937],
         [ 1.7317, -0.5580, -0.5983, -0.5755],
         [ 1.7263, -0.4707, -0.5590, -0.6967],
         [ 1.7287, -0.4992, -0.5600, -0.6695]]],
       grad_fn=<NativeLayerNormBackward0>)


# Head_out 

In [47]:
torch.manual_seed(1)

head_out = nn.Linear(d_out,vocab_size)
print(head_out.weight.T)
print(head_out.weight.T.shape)

print('----'*20)

head_out_result = head_out(post_transformer_LN_result)
print(head_out_result)
print(head_out_result.shape)

tensor([[ 0.2576, -0.4707,  0.0695,  0.1826,  0.0725, -0.1862, -0.1602, -0.4888,
          0.3300,  0.4391,  0.4906, -0.2634, -0.1444,  0.2713, -0.0234, -0.3232],
        [-0.2207,  0.2999, -0.0612, -0.1949, -0.0020, -0.3020,  0.0239,  0.3100,
         -0.4556, -0.0833, -0.2115,  0.2570, -0.0548, -0.1215, -0.3337,  0.3248],
        [-0.0969, -0.1029,  0.1387, -0.0365,  0.4371, -0.0838,  0.2981,  0.1397,
         -0.4754,  0.2140,  0.3750, -0.2654, -0.4807,  0.4980,  0.3045,  0.3036],
        [ 0.2347,  0.2544,  0.0247, -0.0450,  0.1556, -0.2157,  0.2718,  0.4743,
         -0.2412, -0.2324,  0.0059,  0.1471, -0.2384,  0.4008,  0.1552,  0.4434]],
       grad_fn=<PermuteBackward0>)
torch.Size([4, 16])
--------------------------------------------------------------------------------
tensor([[[ 0.1639, -1.1666,  0.0498,  0.5365, -0.5947, -0.3313, -0.3645,
          -1.5199,  1.4631,  0.8754,  0.9128, -0.4336,  0.3820,  0.2331,
          -0.6071, -0.9386],
         [ 0.2119, -1.1496,  0.0477,

# Softmax 

In [50]:
# we won't be using the following head_out_prob to calculate the loss, since we can use cross entropy loss 
# calculating softmax is for demonstration purpose only  

head_out_prob = F.softmax(head_out_result,dim=-1)
head_out_prob

tensor([[[0.0590, 0.0156, 0.0527, 0.0857, 0.0276, 0.0360, 0.0348, 0.0110,
          0.2164, 0.1203, 0.1248, 0.0325, 0.0734, 0.0633, 0.0273, 0.0196],
         [0.0614, 0.0157, 0.0521, 0.0857, 0.0272, 0.0355, 0.0349, 0.0112,
          0.2211, 0.1152, 0.1227, 0.0328, 0.0732, 0.0642, 0.0276, 0.0197],
         [0.0590, 0.0159, 0.0526, 0.0857, 0.0275, 0.0359, 0.0347, 0.0111,
          0.2173, 0.1199, 0.1235, 0.0331, 0.0746, 0.0625, 0.0270, 0.0197],
         [0.0595, 0.0157, 0.0525, 0.0857, 0.0274, 0.0359, 0.0348, 0.0110,
          0.2179, 0.1190, 0.1238, 0.0328, 0.0739, 0.0631, 0.0272, 0.0197]]],
       grad_fn=<SoftmaxBackward0>)

In [51]:
head_out_prob.shape

torch.Size([1, 4, 16])

In [52]:
head_out_result.argmax(dim=-1)

tensor([[8, 8, 8, 8]])

# Loss 

In [53]:
loss_fn = nn.CrossEntropyLoss()
loss_fn(head_out_result.squeeze(0),y.squeeze(0))

tensor(3.2260, grad_fn=<NllLossBackward0>)

# The End of the forward pass 