In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")


## Implementation of multi-head attention

In [10]:
## defining our input variable
# Fix the RNG seed so every Restart & Run All draws the same random tensors.
torch.manual_seed(42)
# Toy input: batch of 1 sequence, 3 tokens, 6 features per token.
x = torch.randn(1, 3, 6)
print(x)

tensor([[[ 1.4735,  0.1597, -1.3380,  1.9263,  1.5446,  0.8857],
         [-0.5353, -2.0590, -0.0887,  0.3936, -0.2552, -0.5573],
         [-0.5595,  1.0235, -0.3610,  0.5880,  1.1423, -1.6424]]])


In [11]:
# Confirm the dimensions of the input tensor.
print(x.size())


torch.Size([1, 3, 6])


* The shape of our embedding tensor is [1, 3, 6]: a batch size of 1, three tokens, and an embedding dimension (`d_in`) of 6.
* Next, let's create the query, key and value weight matrices, each of shape [6, 6]; the `head_dim` is calculated later as `d_out // num_heads`.

In [12]:
# Draw the query, key and value weight matrices (6x6 each), in that order.
w_q, w_k, w_v = (torch.randn(6, 6) for _ in range(3))

# Display the three matrices one after another.
print(w_q, w_k, w_v, sep="\n")

tensor([[ 2.4236, -0.7896,  0.2153,  0.9358,  1.7820,  0.7586],
        [-0.2089,  0.5353,  0.0289,  0.1919,  0.8375,  0.1344],
        [-1.7585, -1.1331,  0.9242, -1.0687, -0.9793,  0.0620],
        [-0.3945, -0.0312, -0.9360, -0.2015,  0.6547, -0.9413],
        [-0.3483,  0.1869, -0.7988, -1.1441, -0.5444, -0.5085],
        [-0.9239,  0.0112,  0.2564, -0.9360, -0.5950, -0.2950]])
tensor([[-0.8277,  1.3515, -0.0028, -1.4203,  0.5166,  0.1209],
        [-0.4109,  0.5809,  1.8617, -0.2859, -2.2864, -0.1366],
        [-0.8955,  1.5974,  0.3516, -0.4752, -1.2126, -1.3231],
        [-0.4409, -0.6348,  0.1634,  1.1214, -0.0749, -0.4081],
        [-1.3281,  0.1743,  1.8257, -1.6399,  0.5695, -0.4977],
        [-0.9622, -0.2941,  1.3717, -0.7383,  1.5301, -0.9479]])
tensor([[ 0.5557,  0.6057,  0.4444,  0.4568,  0.3508,  1.5324],
        [ 0.8020,  0.1285, -1.4518,  0.5329,  0.0419,  1.3991],
        [-0.9131,  1.2236,  0.2359, -1.7598, -0.1423,  0.2901],
        [-0.0852, -0.2009,  0.4286, -0

In [13]:
# Project the input embedding with each weight matrix to obtain the
# query, key and value tensors.
q, k, v = (x @ weights for weights in (w_q, w_k, w_v))

print("query vector:\n", q)
print("key vector:\n", k)
print("value vector:\n", v)


query vector:
 tensor([[[ 3.7744,  0.6766, -3.7243, -0.1449,  3.9631, -1.8036],
         [-0.2627, -0.6452, -0.5642, -0.0670, -1.8633, -0.7646],
         [-0.0472,  1.5755, -2.3085,  0.1706,  0.9542, -0.9591]]])
key vector:
 tensor([[[-3.8398, -1.2673,  4.1725, -2.5294,  4.1091, -0.4678],
         [ 2.0701, -2.1916, -5.0290,  2.6623,  3.5114,  0.8285],
         [ 0.1700, -0.4296,  1.7086,  0.6723, -4.0978,  1.0186]]])
value vector:
 tensor([[[ 3.5635, -1.0415,  1.6482,  0.6800,  1.0036,  5.5205],
         [-2.0365, -0.7059,  3.3572, -1.3898, -0.0931, -3.5801],
         [ 2.6271, -0.2575,  2.1821, -2.4089, -0.1117, -1.4339]]])


In [None]:
# Split the key projection `k` into heads:
# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim).
# This cell previously read `keys` and `b` before either was defined.
b, num_tokens, d_out = k.shape
num_heads = 2  # example head count; it must divide d_out evenly
head_dim = d_out // num_heads
keys = k.view(b, num_tokens, num_heads, head_dim)

In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), \
       "d_out must be divisible by num_heads"
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = nn.Linear(d_out, d_out) # linear layer to combine head output
    self.dropout = nn.Dropout(dropout)
    self.register_buffer(
        "mask",
        torch.tril(torch.ones(context_length,context_length), diagonal=1)
    )

  def forward(self,x):
    b, num_tokens,d_in = x.shape
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    #