In [1]:
from gensim.models import Word2Vec
import torch


In [2]:
words=["the","sun","rises","in","the","east"]
model=Word2Vec([words],min_count=1,vector_size=3)

In [3]:
print(model.wv['sun'])  # Example to get the vector for the word 'sun'

[-0.15122044  0.21846838 -0.16200535]


In [4]:
inputs = torch.tensor(model.wv[words])  # Convert words to their corresponding vectors
print(inputs)

tensor([[-0.0179,  0.0079,  0.1701],
        [-0.1512,  0.2185, -0.1620],
        [-0.1254,  0.2460, -0.0511],
        [ 0.2153,  0.2991, -0.1672],
        [-0.0179,  0.0079,  0.1701],
        [ 0.3003, -0.3101, -0.2372]])


In [5]:
query = inputs[1]  # 'sun'
print(inputs.shape)
scores = torch.empty(len(words))
for i,x in enumerate(inputs):
    scores[i] = torch.dot(query, x)
print(scores)

torch.Size([6, 3])
tensor([-0.0231,  0.0968,  0.0810,  0.0599, -0.0231, -0.0747])


In [6]:
weights = scores/torch.sum(scores)
print(weights)

weights_softmax = torch.softmax(scores, dim=0)
print(weights_softmax)

tensor([-0.1982,  0.8297,  0.6940,  0.5130, -0.1982, -0.6403])
tensor([0.1594, 0.1797, 0.1769, 0.1732, 0.1594, 0.1514])


In [7]:
context_vec2 = torch.zeros(3)
for i, w in enumerate(weights_softmax):
    context_vec2 += w * inputs[i]
print(context_vec2)

tensor([ 0.0277,  0.0902, -0.0488])


In [8]:
attention_scores = torch.matmul(inputs, inputs.T)
print(attention_scores)

tensor([[ 0.0293, -0.0231, -0.0045, -0.0299,  0.0293, -0.0482],
        [-0.0231,  0.0968,  0.0810,  0.0599, -0.0231, -0.0747],
        [-0.0045,  0.0810,  0.0789,  0.0551, -0.0045, -0.1018],
        [-0.0299,  0.0599,  0.0551,  0.1638, -0.0299,  0.0116],
        [ 0.0293, -0.0231, -0.0045, -0.0299,  0.0293, -0.0482],
        [-0.0482, -0.0747, -0.1018,  0.0116, -0.0482,  0.2426]])


In [9]:
attention_weights = torch.softmax(attention_scores, dim=-1)
print(attention_weights)

tensor([[0.1729, 0.1641, 0.1672, 0.1630, 0.1729, 0.1600],
        [0.1594, 0.1797, 0.1769, 0.1732, 0.1594, 0.1514],
        [0.1627, 0.1773, 0.1769, 0.1727, 0.1627, 0.1476],
        [0.1553, 0.1699, 0.1691, 0.1885, 0.1553, 0.1619],
        [0.1729, 0.1641, 0.1672, 0.1630, 0.1729, 0.1600],
        [0.1582, 0.1541, 0.1499, 0.1680, 0.1582, 0.2116]])


In [10]:
print("Sums to 1:", attention_weights.sum(dim=-1))

Sums to 1: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [11]:
contexts_vecs = torch.matmul(attention_weights, inputs)
print(contexts_vecs)

tensor([[ 0.0312,  0.0788, -0.0415],
        [ 0.0277,  0.0902, -0.0488],
        [ 0.0267,  0.0907, -0.0463],
        [ 0.0367,  0.0873, -0.0532],
        [ 0.0312,  0.0788, -0.0415],
        [ 0.0519,  0.0577, -0.0571]])


Simple self attention mechanism with trainable weights


In [12]:
print(inputs)
w_query=torch.nn.Parameter(torch.randn(inputs.shape[1],2),requires_grad=True)  
w_key=torch.nn.Parameter(torch.randn(inputs.shape[1],2),requires_grad=True)  
w_value=torch.nn.Parameter(torch.randn(inputs.shape[1],2),requires_grad=True)
print(w_key,w_value,w_query)

tensor([[-0.0179,  0.0079,  0.1701],
        [-0.1512,  0.2185, -0.1620],
        [-0.1254,  0.2460, -0.0511],
        [ 0.2153,  0.2991, -0.1672],
        [-0.0179,  0.0079,  0.1701],
        [ 0.3003, -0.3101, -0.2372]])
Parameter containing:
tensor([[ 0.1377, -0.0328],
        [ 0.6414,  1.7504],
        [ 0.6470, -0.3417]], requires_grad=True) Parameter containing:
tensor([[-0.9517, -0.6553],
        [-0.3759,  3.3428],
        [ 1.0635,  0.5636]], requires_grad=True) Parameter containing:
tensor([[ 0.5775,  0.2911],
        [ 0.5082,  0.2681],
        [-1.2600,  1.6280]], requires_grad=True)


In [13]:
query2 = torch.matmul(inputs, w_query)
keys2 = torch.matmul(inputs, w_key)
values2 = torch.matmul(inputs, w_value)
print(query2,keys2,values2)

tensor([[-0.2207,  0.2739],
        [ 0.2278, -0.2492],
        [ 0.1170, -0.0538],
        [ 0.4870, -0.1293],
        [-0.2207,  0.2739],
        [ 0.3147, -0.3819]], grad_fn=<MmBackward0>) tensor([[ 0.1127, -0.0438],
        [ 0.0145,  0.4427],
        [ 0.1074,  0.4522],
        [ 0.1133,  0.5736],
        [ 0.1127, -0.0438],
        [-0.3110, -0.4716]], grad_fn=<MmBackward0>) tensor([[ 0.1950,  0.1339],
        [-0.1105,  0.7381],
        [-0.0274,  0.8758],
        [-0.4951,  0.7645],
        [ 0.1950,  0.1339],
        [-0.4215, -1.3671]], grad_fn=<MmBackward0>)


In [14]:
attention_scores2 = torch.matmul(query2, keys2.T)
print(attention_scores2)
attention_scores2norm = torch.softmax(attention_scores2/(2**0.5), dim=-1)
print(attention_scores2norm)

tensor([[-0.0368,  0.1180,  0.1001,  0.1321, -0.0368, -0.0605],
        [ 0.0366, -0.1070, -0.0882, -0.1171,  0.0366,  0.0467],
        [ 0.0155, -0.0221, -0.0117, -0.0176,  0.0155, -0.0110],
        [ 0.0605, -0.0502, -0.0062, -0.0190,  0.0605, -0.0905],
        [-0.0368,  0.1180,  0.1001,  0.1321, -0.0368, -0.0605],
        [ 0.0522, -0.1645, -0.1389, -0.1834,  0.0522,  0.0822]],
       grad_fn=<MmBackward0>)
tensor([[0.1580, 0.1763, 0.1741, 0.1781, 0.1580, 0.1554],
        [0.1747, 0.1579, 0.1600, 0.1567, 0.1747, 0.1760],
        [0.1691, 0.1647, 0.1659, 0.1652, 0.1691, 0.1660],
        [0.1747, 0.1616, 0.1667, 0.1652, 0.1747, 0.1570],
        [0.1580, 0.1763, 0.1741, 0.1781, 0.1580, 0.1554],
        [0.1786, 0.1532, 0.1560, 0.1512, 0.1786, 0.1824]],
       grad_fn=<SoftmaxBackward0>)


In [15]:
class selfattention(torch.nn.Module):
    def __init__(self, input_dim, out_dim):
        super().__init__()
        self.w_query = torch.nn.Parameter(torch.randn(input_dim, out_dim), requires_grad=True)
        self.w_key = torch.nn.Parameter(torch.randn(input_dim, out_dim), requires_grad=True)
        self.w_value = torch.nn.Parameter(torch.randn(input_dim, out_dim), requires_grad=True)

    def forward(self, x):
        queries = torch.matmul(x, self.w_query)
        keys = torch.matmul(x, self.w_key)
        values = torch.matmul(x, self.w_value)

        attention_scores = torch.matmul(queries, keys.T) 
        attention_weights = torch.softmax(attention_scores/(keys.shape[1]**0.5), dim=-1)

        output = torch.matmul(attention_weights, values)
        return output

In [16]:
test_attention = selfattention(input_dim=3, out_dim=2)
output = test_attention(inputs)
print(output)

tensor([[ 0.0021,  0.1052],
        [-0.0153,  0.1153],
        [-0.0153,  0.1118],
        [-0.0068,  0.1233],
        [ 0.0021,  0.1052],
        [ 0.0273,  0.1248]], grad_fn=<MmBackward0>)


In [17]:
class selfattentionv2(torch.nn.Module):
    def __init__(self, input_dim, out_dim,qkv_bias=False):
        super().__init__()
        self.w_query = torch.nn.Linear(input_dim, out_dim,bias=qkv_bias)
        self.w_key = torch.nn.Linear(input_dim, out_dim,bias=qkv_bias)
        self.w_value = torch.nn.Linear(input_dim, out_dim,bias=qkv_bias)

    def forward(self, x):
        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)

        attention_scores = torch.matmul(queries, keys.T) 
        attention_weights = torch.softmax(attention_scores/(keys.shape[1]**0.5), dim=-1)

        output = torch.matmul(attention_weights, values)
        return output,attention_weights

In [18]:
test2 = selfattentionv2(input_dim=3, out_dim=2)
output2,weights = test2.forward(inputs)
print(weights)

tensor([[0.1670, 0.1655, 0.1657, 0.1665, 0.1670, 0.1682],
        [0.1659, 0.1693, 0.1686, 0.1664, 0.1659, 0.1640],
        [0.1660, 0.1691, 0.1684, 0.1661, 0.1660, 0.1645],
        [0.1648, 0.1733, 0.1713, 0.1649, 0.1648, 0.1610],
        [0.1670, 0.1655, 0.1657, 0.1665, 0.1670, 0.1682],
        [0.1666, 0.1667, 0.1669, 0.1674, 0.1666, 0.1659]],
       grad_fn=<SoftmaxBackward0>)


<b>Causal Attention Mechanism

In [19]:
mask = torch.tril(torch.ones(weights.shape))
print(mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [20]:
masked_weights = weights * mask
print(masked_weights)

tensor([[0.1670, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1659, 0.1693, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1660, 0.1691, 0.1684, 0.0000, 0.0000, 0.0000],
        [0.1648, 0.1733, 0.1713, 0.1649, 0.0000, 0.0000],
        [0.1670, 0.1655, 0.1657, 0.1665, 0.1670, 0.0000],
        [0.1666, 0.1667, 0.1669, 0.1674, 0.1666, 0.1659]],
       grad_fn=<MulBackward0>)


In [21]:
row_sums = masked_weights.sum(dim=-1, keepdim=True)
normalized_masked_weights = masked_weights / row_sums 
print(normalized_masked_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4949, 0.5051, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3297, 0.3359, 0.3345, 0.0000, 0.0000, 0.0000],
        [0.2444, 0.2570, 0.2541, 0.2446, 0.0000, 0.0000],
        [0.2008, 0.1990, 0.1992, 0.2001, 0.2008, 0.0000],
        [0.1666, 0.1667, 0.1669, 0.1674, 0.1666, 0.1659]],
       grad_fn=<DivBackward0>)


In [22]:
mask = torch.triu(torch.ones(weights.shape), diagonal=1)
masked=weights.masked_fill(mask.bool(), float('-inf'))
print(masked)

tensor([[0.1670,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1659, 0.1693,   -inf,   -inf,   -inf,   -inf],
        [0.1660, 0.1691, 0.1684,   -inf,   -inf,   -inf],
        [0.1648, 0.1733, 0.1713, 0.1649,   -inf,   -inf],
        [0.1670, 0.1655, 0.1657, 0.1665, 0.1670,   -inf],
        [0.1666, 0.1667, 0.1669, 0.1674, 0.1666, 0.1659]],
       grad_fn=<MaskedFillBackward0>)


In [23]:
weights = torch.softmax(masked/keys2.shape[-1]**0.5, dim=1)
print(weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4994, 0.5006, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3329, 0.3336, 0.3335, 0.0000, 0.0000, 0.0000],
        [0.2493, 0.2508, 0.2505, 0.2494, 0.0000, 0.0000],
        [0.2001, 0.1999, 0.1999, 0.2000, 0.2001, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1666]],
       grad_fn=<SoftmaxBackward0>)


<b>Attention weights with dropout

In [24]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
dropped_weights = dropout(weights)
print(dropped_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0012, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6669, 0.0000, 0.0000, 0.0000],
        [0.4987, 0.5017, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


<b>Causal Attention Mechanism Class

In [38]:
class causalattention(torch.nn.Module):
    def __init__(self, input_dim, out_dim,context_len,dropout,qkv_bias=False):
        super().__init__()
        self.w_query = torch.nn.Linear(input_dim, out_dim,bias=qkv_bias)
        self.w_key = torch.nn.Linear(input_dim, out_dim,bias=qkv_bias)
        self.w_value = torch.nn.Linear(input_dim, out_dim,bias=qkv_bias)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_len, context_len), diagonal=1))
    def forward(self, x):
        b,num_tokens,input_dim = x.shape
        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)

        attention_scores = torch.matmul(queries, keys.transpose(1,2))
        attention_scores=attention_scores.masked_fill(self.mask.bool()[:num_tokens,:num_tokens], float('-inf')) 
        attention_weights = torch.softmax(attention_scores/(keys.shape[1]**0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, values)
        return output

In [39]:
batch = torch.stack((inputs,inputs),dim=0)  # Create a batch of 2 identical sequences
print(batch.shape)
test3 = causalattention(input_dim=3, out_dim=2,context_len=6,dropout=0.5)
contexts_vecs = test3.forward(batch)
print(contexts_vecs)


torch.Size([2, 6, 3])
tensor([[[ 0.0488,  0.1241],
         [ 0.0244,  0.0620],
         [ 0.1123,  0.0358],
         [ 0.0844,  0.0267],
         [ 0.0098,  0.0248],
         [ 0.1196, -0.0390]],

        [[ 0.0488,  0.1241],
         [ 0.1489,  0.0140],
         [ 0.0162,  0.0413],
         [ 0.1920, -0.0277],
         [ 0.1533, -0.0219],
         [ 0.1277, -0.0183]]], grad_fn=<UnsafeViewBackward0>)


<b> Multi Head attention mechanism


In [44]:
class Multiheadattentionv1(torch.nn.Module):
    def __init__(self, input_dim, out_dim, num_heads, context_len, dropout, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention_heads = torch.nn.ModuleList([
            causalattention(input_dim, out_dim, context_len, dropout, qkv_bias) 
            for _ in range(num_heads)
        ])

    def forward(self, x):
        return torch.cat([head(x) for head in self.attention_heads], dim=-1)

In [49]:
test4 = Multiheadattentionv1(input_dim=3, out_dim=2, num_heads=2, context_len=6, dropout=0.5)
contexts_vecs_multihead = test4.forward(batch)
print(contexts_vecs_multihead)

tensor([[[ 0.0134,  0.0937,  0.0000,  0.0000],
         [ 0.0067,  0.0471,  0.0000,  0.0000],
         [ 0.0729, -0.0940,  0.0846,  0.0627],
         [ 0.0279, -0.0070,  0.1061,  0.1237],
         [ 0.0224, -0.0060,  0.0381,  0.0740],
         [-0.0546, -0.0858,  0.0357,  0.0458]],

        [[ 0.0134,  0.0937,  0.0000,  0.0000],
         [ 0.0602, -0.0793,  0.1111,  0.1393],
         [ 0.0729, -0.0940, -0.0105,  0.0301],
         [-0.0123, -0.0983,  0.0000,  0.0000],
         [ 0.0270, -0.0132, -0.0063,  0.0182],
         [-0.0012, -0.0409,  0.0251,  0.0102]]], grad_fn=<CatBackward0>)
