In [1]:
import torch 
from torch import nn
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F

# Section 1 

In [54]:
raw_text = 'The cat is hungry I went to the store But it was closed So I went to a different store'

# split the text into train and validation sets

In [55]:
ratio = 0.66
split_index = int(len(raw_text) * ratio)
train_raw_text = raw_text[:split_index]
val_raw_text = raw_text[split_index:]

In [56]:
print("Train raw text:")
print("-"*100)
print(train_raw_text)


Train raw text:
----------------------------------------------------------------------------------------------------
The cat is hungry I went to the store But it was closed 


In [6]:
print("Val raw text:")
print("-"*100)
print(val_raw_text)

Val raw text:
----------------------------------------------------------------------------------------------------
So I went to a different store


# Corpus vocabulary


We will make an assumption that each word is a unique token.<br>
This is a simplification and not true in the real world.<br>
In practice, we would use a more sophisticated tokenization method.


In [7]:
vocab = list(set(raw_text.split(' ')))

for i in vocab:
    print(i)

a
The
is
So
different
went
was
store
the
closed
But
hungry
cat
to
it
I


In [8]:
vocab_size = len(vocab)
print(f'vocab_size: {vocab_size}')

vocab_size: 16


In [9]:
tokens_to_ids = {token: id for id, token in enumerate(vocab)}
print("Mapping of tokens to ids:")
tokens_to_ids

Mapping of tokens to ids:


{'a': 0,
 'The': 1,
 'is': 2,
 'So': 3,
 'different': 4,
 'went': 5,
 'was': 6,
 'store': 7,
 'the': 8,
 'closed': 9,
 'But': 10,
 'hungry': 11,
 'cat': 12,
 'to': 13,
 'it': 14,
 'I': 15}

In [10]:
ids_to_tokens = {id: token for id, token in enumerate(vocab)}
print("Mapping of ids to tokens:")
ids_to_tokens

Mapping of ids to tokens:


{0: 'a',
 1: 'The',
 2: 'is',
 3: 'So',
 4: 'different',
 5: 'went',
 6: 'was',
 7: 'store',
 8: 'the',
 9: 'closed',
 10: 'But',
 11: 'hungry',
 12: 'cat',
 13: 'to',
 14: 'it',
 15: 'I'}

In [11]:
def encode(text):
    return [tokens_to_ids[token] for token in text.strip().split(' ')]

def decode(ids):
    return ' '.join([ids_to_tokens[id] for id in ids])

In [12]:
encode(raw_text)

[1, 12, 2, 11, 15, 5, 13, 8, 7, 10, 14, 6, 9, 3, 15, 5, 13, 0, 4, 7]

In [13]:
encode("The cat is hungry")

[1, 12, 2, 11]

In [14]:
decode([9,3,6,4])

'closed So was different'

# Creating a dataset 

### Dataset 

In [15]:
class Data(Dataset):
    def __init__(self,raw_text,max_len=4,stride=3):
        self.token_ids = encode(raw_text)
        self.X = []
        self.y = []
        for i in range(0,len(self.token_ids)-max_len,stride):
            input = self.token_ids[i:i+max_len]
            output = self.token_ids[i+1:i+max_len+1]
            self.X.append(torch.tensor(input))
            self.y.append(torch.tensor(output))
            
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self,idx):
        return self.X[idx],self.y[idx]
    

In [16]:
train_ds = Data(train_raw_text)
train_ds[0]

(tensor([ 1, 12,  2, 11]), tensor([12,  2, 11, 15]))

In [17]:
val_ds = Data(val_raw_text)
val_ds[0]

(tensor([ 3, 15,  5, 13]), tensor([15,  5, 13,  0]))

### DataLoader

In [18]:
train_dl = DataLoader(train_ds,batch_size=1,shuffle=False,drop_last=False,num_workers=0)
val_dl   = DataLoader(val_ds,batch_size=1,shuffle=False,drop_last=True,num_workers=0)

In [19]:
for i,(x,y) in enumerate(train_dl):
  print(f'Batch Number: {i+1}')
  print(f'x :{x}')
  print(f'y :{y}')
  print('--'*20)

Batch Number: 1
x :tensor([[ 1, 12,  2, 11]])
y :tensor([[12,  2, 11, 15]])
----------------------------------------
Batch Number: 2
x :tensor([[11, 15,  5, 13]])
y :tensor([[15,  5, 13,  8]])
----------------------------------------
Batch Number: 3
x :tensor([[13,  8,  7, 10]])
y :tensor([[ 8,  7, 10, 14]])
----------------------------------------


In [20]:
for i,(x,y) in enumerate(val_dl):
  print(f'Batch Number: {i+1}')
  print(f'x :{x}')
  print(f'y :{y}')
  print('--'*20)

Batch Number: 1
x :tensor([[ 3, 15,  5, 13]])
y :tensor([[15,  5, 13,  0]])
----------------------------------------


Taking single batch, and we know that the batch size is 1. <br>
so we are taking a single example from the dataset. 

In [21]:
for x,y in train_dl:
  print(x)
  print(y)
  break

tensor([[ 1, 12,  2, 11]])
tensor([[12,  2, 11, 15]])


# Config 

In [22]:
B = 1  # batch size 


context_length = 4 # context window  [max length of the input sequence the model can handle]
num_tokens = 4 # this can not be greater than context window
d_in = 3  # embedding dimension 
d_out = 4 # output dimension 

dropout = 0.0 # notice that we are not using it, but we are still initializing it(to see how it works)

num_heads = 2 
head_dim = d_out//num_heads
print(head_dim)

2


# Token Embedding  and Positional Encoding 

In [23]:
torch.manual_seed(1)

# Embedding 
token_emb = nn.Embedding(vocab_size,d_in)
token_embedding = token_emb(x)
print('Token Embeddings:')
print(token_embedding)
print('--'*20)
print(token_embedding.shape)


Token Embeddings:
tensor([[[-1.6095, -0.1002, -0.6092],
         [ 0.1991,  0.0457,  0.1530],
         [-0.9798, -1.6091, -0.7121],
         [-0.0721,  0.1578, -0.7735]]], grad_fn=<EmbeddingBackward0>)
----------------------------------------
torch.Size([1, 4, 3])


In [24]:
torch.manual_seed(1)

# Positional embedding
pos_emb = nn.Embedding(context_length,d_in)
positional_embedding = pos_emb(torch.arange(num_tokens))
print('Positional Embeddings:')
print('--'*20)
print(positional_embedding)
print('--'*20)
print(positional_embedding.shape)

Positional Embeddings:
----------------------------------------
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661],
        [-1.5228,  0.3817, -1.0276],
        [-0.5631, -0.8923, -0.0583]], grad_fn=<EmbeddingBackward0>)
----------------------------------------
torch.Size([4, 3])


In [25]:
# Token embedding + token embedding 
tok_pos_emb = token_embedding + positional_embedding
print('Token Embedding + Positional Embedding')
print('--'*20)
print(tok_pos_emb)
print('--'*20)
print(tok_pos_emb.shape)

Token Embedding + Positional Embedding
----------------------------------------
tensor([[[-0.9481,  0.1668, -0.5475],
         [ 0.8204, -0.4062, -0.0132],
         [-2.5025, -1.2274, -1.7398],
         [-0.6352, -0.7345, -0.8317]]], grad_fn=<AddBackward0>)
----------------------------------------
torch.Size([1, 4, 3])


# Pre Transformer Block Dropout 

In [26]:
torch.manual_seed(1)

pre_transformer_dp = nn.Dropout(dropout)
pre_transformer_dp_result = pre_transformer_dp(tok_pos_emb)
pre_transformer_dp_result

tensor([[[-0.9481,  0.1668, -0.5475],
         [ 0.8204, -0.4062, -0.0132],
         [-2.5025, -1.2274, -1.7398],
         [-0.6352, -0.7345, -0.8317]]], grad_fn=<AddBackward0>)

# Transformer Block 

### Layer Normalization 

In [27]:
layernorm1 = nn.LayerNorm(d_in)
layernorm1 = layernorm1(pre_transformer_dp_result)
print(layernorm1)
print('--'*20)
print(layernorm1.shape)

tensor([[[-1.0955,  1.3222, -0.2267],
         [ 1.3428, -1.0556, -0.2871],
         [-1.2966,  1.1373,  0.1593],
         [ 1.2282, -0.0089, -1.2193]]], grad_fn=<NativeLayerNormBackward0>)
----------------------------------------
torch.Size([1, 4, 3])


### Multi-Head Attention 

weights initialization 

In [28]:
torch.manual_seed(123)

W_q = nn.Linear(d_in,d_out,bias=False)
W_k = nn.Linear(d_in,d_out,bias=False)
W_v = nn.Linear(d_in,d_out,bias=False)

# REMINDER : THE WEIGHT MATRICES ARE TRANSPOSED 
print('W_q')
print(W_q.weight.T)
print(W_q.weight.T.shape)
print('---'*20)

print('W_k')
print(W_k.weight.T)
print(W_k.weight.T.shape)
print('---'*20)

print('W_v')
print(W_v.weight.T)
print(W_v.weight.T.shape)

W_q
tensor([[-0.2354,  0.2177, -0.4196,  0.2615],
        [ 0.0191, -0.4919, -0.4590, -0.2133],
        [-0.2867,  0.4232, -0.3648,  0.2161]], grad_fn=<PermuteBackward0>)
torch.Size([3, 4])
------------------------------------------------------------
W_k
tensor([[-0.4900, -0.1135, -0.1362,  0.1076],
        [-0.3503, -0.4404,  0.1853,  0.1579],
        [-0.2120,  0.3780,  0.4083,  0.5573]], grad_fn=<PermuteBackward0>)
torch.Size([3, 4])
------------------------------------------------------------
W_v
tensor([[-0.2604,  0.4126,  0.4929,  0.2377],
        [ 0.1829,  0.4611,  0.2757,  0.4800],
        [-0.2569, -0.5323,  0.2516, -0.0762]], grad_fn=<PermuteBackward0>)
torch.Size([3, 4])


Q,K,V

In [29]:
torch.manual_seed(1)

Q = W_q(tok_pos_emb)
K = W_k(tok_pos_emb)
V = W_v(tok_pos_emb)


print(f'Q\n{Q}')
print(Q.shape)

print('---'*20)

print(f'K\n{K}')
print(K.shape)

print('---'*20)

print(f'V\n{V}')
print(V.shape)

Q
tensor([[[ 0.3834, -0.5202,  0.5211, -0.4018],
         [-0.1971,  0.3729, -0.1530,  0.2983],
         [ 1.0646, -0.6774,  2.2483, -0.7684],
         [ 0.3740, -0.1290,  0.9071, -0.1891]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])
------------------------------------------------------------
K
tensor([[[ 0.5223, -0.1729, -0.0635, -0.3808],
         [-0.2569,  0.0808, -0.1924,  0.0168],
         [ 2.0250,  0.1668, -0.5970, -1.4325],
         [ 0.7449,  0.0812, -0.3892, -0.6478]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])
------------------------------------------------------------
V
tensor([[[ 4.1802e-01, -2.2869e-02, -5.5906e-01, -1.0358e-01],
         [-2.8452e-01,  1.5820e-01,  2.8902e-01,  1.0334e-03],
         [ 8.7407e-01, -6.7245e-01, -2.0095e+00, -1.0513e+00],
         [ 2.4471e-01, -1.5805e-01, -7.2480e-01, -4.4010e-01]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 4])


### splitting Q,K,V into multiple heads 

In [30]:
Q_split  = Q.view(B,num_tokens,num_heads,head_dim).transpose(1,2)
K_split  = K.view(B,num_tokens,num_heads,head_dim).transpose(1,2)
V_split  = V.view(B,num_tokens,num_heads,head_dim).transpose(1,2)


print(Q_split)
print(Q_split.shape)

print('---'*20)

print(K_split)
print(K_split.shape)


print('---'*20)

print(V_split)
print(V_split.shape)

tensor([[[[ 0.3834, -0.5202],
          [-0.1971,  0.3729],
          [ 1.0646, -0.6774],
          [ 0.3740, -0.1290]],

         [[ 0.5211, -0.4018],
          [-0.1530,  0.2983],
          [ 2.2483, -0.7684],
          [ 0.9071, -0.1891]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 4, 2])
------------------------------------------------------------
tensor([[[[ 0.5223, -0.1729],
          [-0.2569,  0.0808],
          [ 2.0250,  0.1668],
          [ 0.7449,  0.0812]],

         [[-0.0635, -0.3808],
          [-0.1924,  0.0168],
          [-0.5970, -1.4325],
          [-0.3892, -0.6478]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 4, 2])
------------------------------------------------------------
tensor([[[[ 4.1802e-01, -2.2869e-02],
          [-2.8452e-01,  1.5820e-01],
          [ 8.7407e-01, -6.7245e-01],
          [ 2.4471e-01, -1.5805e-01]],

         [[-5.5906e-01, -1.0358e-01],
          [ 2.8902e-01,  1.0334e-03],
          [-2.0095e+00, -1.0513e+00],
          

### Attention Score 

In [31]:
attn_score = Q_split @ K_split.transpose(2,3)
attn_score

tensor([[[[ 0.2901, -0.1406,  0.6896,  0.2434],
          [-0.1674,  0.0808, -0.3370, -0.1166],
          [ 0.6731, -0.3283,  2.0428,  0.7380],
          [ 0.2176, -0.1065,  0.7358,  0.2681]],

         [[ 0.1199, -0.1070,  0.2645,  0.0575],
          [-0.1039,  0.0344, -0.3360, -0.1337],
          [ 0.1498, -0.4454, -0.2415, -0.3773],
          [ 0.0144, -0.1777, -0.2707, -0.2306]]]],
       grad_fn=<UnsafeViewBackward0>)

### mask

In [32]:
mask = torch.triu(torch.ones(num_tokens,num_tokens),diagonal=1)
print(mask)
print(mask.bool()) 

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


In [33]:
attn_score = attn_score.masked_fill(mask.bool()[:num_tokens,:num_tokens],-torch.inf)
print(attn_score)

tensor([[[[ 0.2901,    -inf,    -inf,    -inf],
          [-0.1674,  0.0808,    -inf,    -inf],
          [ 0.6731, -0.3283,  2.0428,    -inf],
          [ 0.2176, -0.1065,  0.7358,  0.2681]],

         [[ 0.1199,    -inf,    -inf,    -inf],
          [-0.1039,  0.0344,    -inf,    -inf],
          [ 0.1498, -0.4454, -0.2415,    -inf],
          [ 0.0144, -0.1777, -0.2707, -0.2306]]]],
       grad_fn=<MaskedFillBackward0>)


In [34]:
attn_weight = torch.softmax(attn_score/K_split.shape[-1]**0.5,dim=-1)
print(attn_weight)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4562, 0.5438, 0.0000, 0.0000],
          [0.2423, 0.1194, 0.6383, 0.0000],
          [0.2340, 0.1860, 0.3375, 0.2425]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4756, 0.5244, 0.0000, 0.0000],
          [0.4141, 0.2719, 0.3140, 0.0000],
          [0.2832, 0.2472, 0.2315, 0.2381]]]], grad_fn=<SoftmaxBackward0>)


In [35]:
torch.manual_seed(3)

attn_dropout = nn.Dropout(dropout)

attn_weight = attn_dropout(attn_weight)
print(attn_weight)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4562, 0.5438, 0.0000, 0.0000],
          [0.2423, 0.1194, 0.6383, 0.0000],
          [0.2340, 0.1860, 0.3375, 0.2425]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4756, 0.5244, 0.0000, 0.0000],
          [0.4141, 0.2719, 0.3140, 0.0000],
          [0.2832, 0.2472, 0.2315, 0.2381]]]], grad_fn=<SoftmaxBackward0>)


In [36]:
con_vector = attn_weight @ V_split
print(con_vector)
print(con_vector.shape)

tensor([[[[ 0.4180, -0.0229],
          [ 0.0360,  0.0756],
          [ 0.6253, -0.4159],
          [ 0.3992, -0.2412]],

         [[-0.5591, -0.1036],
          [-0.1143, -0.0487],
          [-0.7840, -0.3727],
          [-0.7246, -0.3772]]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 2, 4, 2])


In [37]:
conv_vector = con_vector.transpose(1,2).contiguous().view(B,num_tokens,d_out)
conv_vector

tensor([[[ 0.4180, -0.0229, -0.5591, -0.1036],
         [ 0.0360,  0.0756, -0.1143, -0.0487],
         [ 0.6253, -0.4159, -0.7840, -0.3727],
         [ 0.3992, -0.2412, -0.7246, -0.3772]]], grad_fn=<ViewBackward0>)

# Dropout1 

In [38]:
dropout1 = nn.Dropout(dropout)
after_dropout_1 = dropout1(conv_vector)
print(after_dropout_1)
print(after_dropout_1.shape)

tensor([[[ 0.4180, -0.0229, -0.5591, -0.1036],
         [ 0.0360,  0.0756, -0.1143, -0.0487],
         [ 0.6253, -0.4159, -0.7840, -0.3727],
         [ 0.3992, -0.2412, -0.7246, -0.3772]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 4])


# LayerNorm 2

In [39]:
layernorm2 = nn.LayerNorm(after_dropout_1.shape[-1])
layernorm2_result = layernorm2(after_dropout_1)

print(layernorm2_result)
print(layernorm2_result.shape)

tensor([[[ 1.3988,  0.1269, -1.4198, -0.1059],
         [ 0.6615,  1.1973, -1.3733, -0.4855],
         [ 1.6491, -0.3425, -1.0466, -0.2600],
         [ 1.5610, -0.0129, -1.2009, -0.3472]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 4, 4])


# FC

In [40]:
torch.manual_seed(1)

fc1 = nn.Linear(d_out,d_out*2)
print(fc1.weight.T)
print(fc1.weight.T.shape)

tensor([[ 0.2576, -0.4707,  0.0695,  0.1826,  0.0725, -0.1862, -0.1602, -0.4888],
        [-0.2207,  0.2999, -0.0612, -0.1949, -0.0020, -0.3020,  0.0239,  0.3100],
        [-0.0969, -0.1029,  0.1387, -0.0365,  0.4371, -0.0838,  0.2981,  0.1397],
        [ 0.2347,  0.2544,  0.0247, -0.0450,  0.1556, -0.2157,  0.2718,  0.4743]],
       grad_fn=<PermuteBackward0>)
torch.Size([4, 8])


In [41]:
torch.manual_seed(1)

fc1_result = fc1(layernorm2_result)
print(fc1_result)
print(fc1_result.shape)

tensor([[[ 0.7752, -0.9568, -0.5855,  0.0461, -0.0969, -0.2402, -0.4591,
          -1.1252],
         [ 0.2554, -0.3901, -0.7052, -0.2818, -0.1912, -0.3482, -0.4046,
          -0.6066],
         [ 0.8709, -1.2930, -0.4914,  0.1766,  0.0613, -0.1431, -0.4411,
          -1.4141],
         [ 0.7700, -1.1590, -0.5412,  0.1058, -0.0267, -0.1945, -0.4888,
          -1.3318]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 8])


In [42]:
gelu = nn.GELU()
gelu_result = gelu(fc1_result)
gelu_result

tensor([[[ 0.6053, -0.1620, -0.1634,  0.0239, -0.0447, -0.0973, -0.1483,
          -0.1466],
         [ 0.1534, -0.1359, -0.1695, -0.1096, -0.0811, -0.1267, -0.1387,
          -0.1650],
         [ 0.7038, -0.1267, -0.1531,  0.1007,  0.0322, -0.0634, -0.1454,
          -0.1112],
         [ 0.6001, -0.1428, -0.1592,  0.0574, -0.0131, -0.0823, -0.1527,
          -0.1218]]], grad_fn=<GeluBackward0>)

In [43]:
torch.manual_seed(1)

fc2 = nn.Linear(2*d_out,d_out)
print(fc2.weight.T)
print(fc2.weight.T.shape)

tensor([[ 0.1822,  0.0491,  0.0512, -0.1133],
        [-0.1561, -0.0433, -0.0014,  0.0169],
        [-0.0685,  0.0981,  0.3091,  0.2108],
        [ 0.1659,  0.0174,  0.1100,  0.1922],
        [-0.3328,  0.1291, -0.1317, -0.3456],
        [ 0.2120, -0.1378, -0.2135,  0.2192],
        [-0.0727, -0.0258, -0.0593,  0.0988],
        [ 0.1799, -0.0318, -0.1525,  0.3354]], grad_fn=<PermuteBackward0>)
torch.Size([8, 4])


In [44]:
fc2_result = fc2(gelu_result)
print(fc2_result)
print(fc2_result.shape)

tensor([[[ 0.3628, -0.2849, -0.2950, -0.3414],
         [ 0.2565, -0.3114, -0.3214, -0.3158],
         [ 0.3750, -0.2751, -0.3013, -0.3420],
         [ 0.3615, -0.2836, -0.3012, -0.3329]]], grad_fn=<ViewBackward0>)
torch.Size([1, 4, 4])


# Dropout 2 

In [45]:
torch.manual_seed(1)

dropout2 = nn.Dropout(dropout)
dropout2_result = dropout2(fc2_result)
dropout2_result

tensor([[[ 0.3628, -0.2849, -0.2950, -0.3414],
         [ 0.2565, -0.3114, -0.3214, -0.3158],
         [ 0.3750, -0.2751, -0.3013, -0.3420],
         [ 0.3615, -0.2836, -0.3012, -0.3329]]], grad_fn=<ViewBackward0>)

End of Transformer block 

# Post Transformer Block LayerNorm

In [46]:
post_transformer_LN = nn.LayerNorm(d_out)
post_transformer_LN_result = post_transformer_LN(dropout2_result)
print(post_transformer_LN_result)

tensor([[[ 1.7273, -0.4994, -0.5342, -0.6937],
         [ 1.7317, -0.5580, -0.5983, -0.5755],
         [ 1.7263, -0.4707, -0.5590, -0.6967],
         [ 1.7287, -0.4992, -0.5600, -0.6695]]],
       grad_fn=<NativeLayerNormBackward0>)


# Head_out 

In [47]:
torch.manual_seed(1)

head_out = nn.Linear(d_out,vocab_size)
print(head_out.weight.T)
print(head_out.weight.T.shape)

print('----'*20)

head_out_result = head_out(post_transformer_LN_result)
print(head_out_result)
print(head_out_result.shape)

tensor([[ 0.2576, -0.4707,  0.0695,  0.1826,  0.0725, -0.1862, -0.1602, -0.4888,
          0.3300,  0.4391,  0.4906, -0.2634, -0.1444,  0.2713, -0.0234, -0.3232],
        [-0.2207,  0.2999, -0.0612, -0.1949, -0.0020, -0.3020,  0.0239,  0.3100,
         -0.4556, -0.0833, -0.2115,  0.2570, -0.0548, -0.1215, -0.3337,  0.3248],
        [-0.0969, -0.1029,  0.1387, -0.0365,  0.4371, -0.0838,  0.2981,  0.1397,
         -0.4754,  0.2140,  0.3750, -0.2654, -0.4807,  0.4980,  0.3045,  0.3036],
        [ 0.2347,  0.2544,  0.0247, -0.0450,  0.1556, -0.2157,  0.2718,  0.4743,
         -0.2412, -0.2324,  0.0059,  0.1471, -0.2384,  0.4008,  0.1552,  0.4434]],
       grad_fn=<PermuteBackward0>)
torch.Size([4, 16])
--------------------------------------------------------------------------------
tensor([[[ 0.1639, -1.1666,  0.0498,  0.5365, -0.5947, -0.3313, -0.3645,
          -1.5199,  1.4631,  0.8754,  0.9128, -0.4336,  0.3820,  0.2331,
          -0.6071, -0.9386],
         [ 0.2119, -1.1496,  0.0477,

# Softmax 

In [50]:
# we won't be using the following head_out_prob to calculate the loss, since we can use cross entropy loss 
# calculating softmax is for demonstration purpose only  

head_out_prob = F.softmax(head_out_result,dim=-1)
head_out_prob

tensor([[[0.0590, 0.0156, 0.0527, 0.0857, 0.0276, 0.0360, 0.0348, 0.0110,
          0.2164, 0.1203, 0.1248, 0.0325, 0.0734, 0.0633, 0.0273, 0.0196],
         [0.0614, 0.0157, 0.0521, 0.0857, 0.0272, 0.0355, 0.0349, 0.0112,
          0.2211, 0.1152, 0.1227, 0.0328, 0.0732, 0.0642, 0.0276, 0.0197],
         [0.0590, 0.0159, 0.0526, 0.0857, 0.0275, 0.0359, 0.0347, 0.0111,
          0.2173, 0.1199, 0.1235, 0.0331, 0.0746, 0.0625, 0.0270, 0.0197],
         [0.0595, 0.0157, 0.0525, 0.0857, 0.0274, 0.0359, 0.0348, 0.0110,
          0.2179, 0.1190, 0.1238, 0.0328, 0.0739, 0.0631, 0.0272, 0.0197]]],
       grad_fn=<SoftmaxBackward0>)

In [51]:
head_out_prob.shape

torch.Size([1, 4, 16])

In [52]:
head_out_result.argmax(dim=-1)

tensor([[8, 8, 8, 8]])

# Loss 

In [53]:
loss_fn = nn.CrossEntropyLoss()
loss_fn(head_out_result.squeeze(0),y.squeeze(0))

tensor(3.2260, grad_fn=<NllLossBackward0>)

# The End of the forward pass 