SELF ATTENTION MECHANISM(WITHOUT TRAINABLE WEIGHTS)

In [5]:
import torch
import torch.nn.functional as f
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [None]:
# Initalizing the attention weights
attention_weights = torch.matmul(inputs, inputs.T)
attention_weights = f.softmax(attention_weights)
# Context Vectors 
context_vectors = torch.matmul(attention_weights, inputs)
context_vectors

SELF ATTENTION MECHANISM WITH TRAINABLE WEIGHTS

In [6]:
dim_of_weight_matrix = inputs.shape[-1]

# Initialize weights and queries of the query, value and key matrices
torch.manual_seed(123)
weights_query = torch.nn.Parameter(torch.randn(dim_of_weight_matrix, dim_of_weight_matrix)) 
# We can use nn.Linear instead of nn.parameter because they can add bias and good sophisicated weight inititalisation scheme
weights_value = torch.nn.Parameter(torch.randn(dim_of_weight_matrix, dim_of_weight_matrix))
weights_key = torch.nn.Parameter(torch.randn(dim_of_weight_matrix, dim_of_weight_matrix))

# Initialize queries, values and keys
query = torch.matmul(inputs,weights_query)
key = torch.matmul(inputs,weights_key)
value = torch.matmul(inputs,weights_value)

# Calculate the attention scores
attention_scores = torch.matmul(query,key.T)
attention_scores1 = attention_scores/key.shape[-1]**0.5
attention_scores1 = f.softmax(attention_scores1)

# Calculate the context vectors
context_vectors = torch.matmul(attention_scores1,key)
context_vectors

  attention_scores1 = f.softmax(attention_scores1)


tensor([[-0.0370, -0.6304,  0.1504],
        [-0.0276, -0.6395,  0.1476],
        [-0.0221, -0.6392,  0.1432],
        [ 0.0421, -0.6363,  0.0952],
        [ 0.1059, -0.6293,  0.0442],
        [-0.0144, -0.6396,  0.1394]], grad_fn=<MmBackward0>)

MASKED SELF ATTENTION WITH DROPOUT

In [7]:
dim_of_weight_matrix = inputs.shape[-1]
print(dim_of_weight_matrix)
torch.seed()
weights_query = torch.nn.Linear(dim_of_weight_matrix, dim_of_weight_matrix,bias = False)
weights_value = torch.nn.Linear(dim_of_weight_matrix, dim_of_weight_matrix,bias = False)
weights_key = torch.nn.Linear(dim_of_weight_matrix, dim_of_weight_matrix,bias = False)
query = weights_query(inputs)
key = weights_key(inputs)
value = weights_value(inputs)
attention_scores = torch.matmul(query,key.T)
# print(attention_scores)
mask = torch.triu(torch.ones(attention_scores.shape[0], attention_scores.shape[0]),diagonal=1) ==1
# print(mask)
result =  attention_scores.masked_fill(mask==True, -torch.inf)
# print(result)
attn_weights = torch.softmax(result / key.shape[-1]**0.5, dim=1)
# print(attn_weights)
dropout = torch.nn.Dropout(0.5)
attn_weights = dropout(attn_weights)
# print(attn_weights)
context_vectors = torch.matmul(attn_weights,key)
print(context_vectors)

3
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.5905,  0.2512, -0.3913],
        [-1.1032,  0.5765, -0.6179],
        [-0.4307,  0.1662, -0.2997],
        [-0.5427,  0.3251, -0.2797],
        [-0.6820,  0.3368, -0.4002]], grad_fn=<MmBackward0>)


MULTI-HEAD ATTENTION

In [56]:
inputs = torch.tensor(
   [[[0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # Row 1
     [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # Row 2
     [0.77, 0.25, 0.10, 0.05, 0.80, 0.55]] , # Row 3
     [[0.10, 0.55, 0.87, 0.66, 0.43, 0.15],  # Row 4
     [0.22, 0.58, 0.33, 0.05, 0.85, 0.85],  # Row 5
     [0.10, 0.55, 0.8, 0.05, 0.85, 0.85]]] # Row 6
)

print(inputs.shape)

dim_of_weight_matrix = inputs.shape[-1]
num_heads = 3
head_dim =  int(dim_of_weight_matrix/num_heads)
weights_query = torch.nn.Linear(dim_of_weight_matrix, dim_of_weight_matrix,bias = False)
weights_value = torch.nn.Linear(dim_of_weight_matrix, dim_of_weight_matrix,bias = False)
weights_key = torch.nn.Linear(dim_of_weight_matrix, dim_of_weight_matrix,bias = False)
query = weights_query(inputs)
key = weights_key(inputs)
value = weights_value(inputs)
query =  torch.reshape(query,(inputs.shape[0], inputs.shape[1], num_heads,head_dim))
key =  torch.reshape(key,(inputs.shape[0], inputs.shape[1], num_heads,head_dim))
value = torch.reshape(value,(inputs.shape[0], inputs.shape[1], num_heads,head_dim))
query = query.transpose(1,2)
key = key.transpose(1,2)
value = value.transpose(1,2)
multi_head_attention_scores = torch.matmul(query, key.transpose(2,3))
mask = torch.triu(torch.ones(multi_head_attention_scores.shape),diagonal=1) ==1
result =  multi_head_attention_scores.masked_fill(mask==True, -torch.inf)
multi_head_attn_weights = torch.softmax(result / key.shape[-1]**0.5, dim=-1)
dropout = torch.nn.Dropout(0.5)
multi_head_attn_weights = dropout(multi_head_attn_weights)

print(multi_head_attn_weights.shape)
multi_head_context_vectors = torch.matmul(multi_head_attn_weights,value).transpose(1,2)
print(multi_head_context_vectors.shape)
multi_head_context_vectors = torch.reshape(multi_head_context_vectors,(inputs.shape[0],inputs.shape[1],dim_of_weight_matrix))
print(multi_head_context_vectors.shape)


torch.Size([2, 3, 6])
torch.Size([2, 3, 3, 3])
torch.Size([2, 3, 3, 2])
torch.Size([2, 3, 6])
