In [1]:
import torch
import torch.nn as nn
class Self_Attention_v2(nn.Module):
    """Single-head scaled dot-product self-attention (no causal mask, no dropout).

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections and of the output.
        q_bias: whether the query/key/value ``nn.Linear`` layers use a bias.

    ``forward`` accepts an unbatched ``(num_tokens, d_in)`` tensor or a batched
    ``(..., num_tokens, d_in)`` tensor and returns a tensor with the same
    leading shape and last dimension ``d_out``.
    """

    def __init__(self, d_in, d_out, q_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=q_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=q_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=q_bias)

    def forward(self, x):
        """Return one context vector per token of ``x``."""
        query_matrix = self.W_query(x)
        key_matrix = self.W_key(x)
        value_matrix = self.W_value(x)

        # Swap only the last two dims: `.T` reverses *all* dims, which is wrong
        # (and deprecated) for batched 3-D inputs.
        attn_scores = query_matrix @ key_matrix.transpose(-2, -1)
        # Scale by sqrt(d_k) so the softmax does not saturate as d_out grows.
        attn_weights = torch.softmax(attn_scores / key_matrix.shape[-1] ** 0.5, dim=-1)

        context_vec = attn_weights @ value_matrix
        return context_vec
        

In [2]:
# Toy 3-d embeddings, one row per token of "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # your
        [0.55, 0.89, 0.66],  # journey
        [0.57, 0.85, 0.64],  # starts
        [0.22, 0.58, 0.33],  # with
        [0.77, 0.25, 0.10],  # one
        [0.05, 0.80, 0.55],  # step
    ]
)

<IPython.core.display.Javascript object>

In [3]:
# Input embedding size comes from the data; the projection size is a free choice.
d_in, d_out = inputs.size(1), 2

In [11]:
torch.manual_seed(789)  # fixes the nn.Linear initialisation below
sa = Self_Attention_v2(d_in, d_out)

# Recompute the attention weights step by step with sa's own projection layers.
queries = sa.W_query(inputs)
keys = sa.W_key(inputs)
attn_scores = queries.matmul(keys.T)
attn_weights = (attn_scores / keys.size(-1) ** 0.5).softmax(dim=-1)
attn_weights

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

tensor([[0.1922, 0.1641, 0.1653, 0.1551, 0.1722, 0.1510],
        [0.2047, 0.1653, 0.1663, 0.1495, 0.1665, 0.1477],
        [0.2037, 0.1653, 0.1664, 0.1499, 0.1665, 0.1481],
        [0.1870, 0.1664, 0.1669, 0.1572, 0.1662, 0.1565],
        [0.1831, 0.1666, 0.1670, 0.1589, 0.1659, 0.1585],
        [0.1936, 0.1659, 0.1667, 0.1543, 0.1666, 0.1530]],
       grad_fn=<SoftmaxBackward0>)

In [12]:
# Lower-triangular ones: row i keeps positions 0..i (the current and earlier tokens).
context_length = attn_weights.size(0)
mask = torch.ones(context_length, context_length).tril()
mask


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [13]:
masked_attn_weight = attn_weights * mask  # zero out weights on future tokens

In [10]:
# Entries above the diagonal are now zero, but the rows no longer sum to 1,
# so they have to be renormalised.
masked_attn_weight

tensor([[0.1922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2047, 0.1653, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2037, 0.1653, 0.1664, 0.0000, 0.0000, 0.0000],
        [0.1870, 0.1664, 0.1669, 0.1572, 0.0000, 0.0000],
        [0.1831, 0.1666, 0.1670, 0.1589, 0.1659, 0.0000],
        [0.1936, 0.1659, 0.1667, 0.1543, 0.1666, 0.1530]],
       grad_fn=<MulBackward0>)

In [14]:
# Renormalise each row so the surviving (non-future) weights sum to 1 again.
sum_row = masked_attn_weight.sum(dim=-1, keepdim=True)
attn_weight_causal = masked_attn_weight.div(sum_row)

In [15]:
# Per-row sums of the masked weights; only the last row (nothing masked) sums to 1.
sum_row

tensor([[0.1922],
        [0.3700],
        [0.5354],
        [0.6774],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)

In [17]:
# Causal weights after renormalisation: each row sums to 1 over the current and earlier tokens only.
attn_weight_causal

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5533, 0.4467, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3805, 0.3088, 0.3107, 0.0000, 0.0000, 0.0000],
        [0.2760, 0.2456, 0.2464, 0.2320, 0.0000, 0.0000],
        [0.2176, 0.1980, 0.1985, 0.1888, 0.1971, 0.0000],
        [0.1936, 0.1659, 0.1667, 0.1543, 0.1666, 0.1530]],
       grad_fn=<DivBackward0>)

# A better way: mask the scores with -inf before the softmax


In [18]:
# Raw (unscaled, unmasked) query-key scores computed earlier; these get masked with -inf below.
attn_scores


tensor([[ 0.2899,  0.0665,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4712,  0.1687,  0.1778,  0.0268,  0.1788,  0.0097],
        [ 0.4594,  0.1642,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.0990,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0847,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1225,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)

In [30]:
# 1s strictly above the main diagonal mark the future positions to hide.
mask_new = torch.ones(context_length, context_length).triu(diagonal=1)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [31]:
# 1s mark the future positions each token must not attend to.
mask_new

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [34]:
# Hide future positions with -inf so the softmax gives them exactly zero weight.
masked = attn_scores.masked_fill(mask_new == 1, float('-inf'))
masked

<IPython.core.display.Javascript object>

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4712, 0.1687,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1642, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.0990, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0847, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1225, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

In [35]:
attn_weights_new = (masked / keys.size(-1) ** 0.5).softmax(dim=-1)  # scale by sqrt(d_k), then normalise

<IPython.core.display.Javascript object>

In [36]:
# Identical to attn_weight_causal: masking with -inf before the softmax gives the
# same result as masking afterwards and renormalising, but in a single step.
attn_weights_new

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5533, 0.4467, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3805, 0.3088, 0.3107, 0.0000, 0.0000, 0.0000],
        [0.2760, 0.2456, 0.2464, 0.2320, 0.0000, 0.0000],
        [0.2176, 0.1980, 0.1985, 0.1888, 0.1971, 0.0000],
        [0.1936, 0.1659, 0.1667, 0.1543, 0.1666, 0.1530]],
       grad_fn=<SoftmaxBackward0>)

In [41]:
# Weighted sum of value vectors: one context vector per token.
values = sa.W_value(inputs)
context_vectors = torch.matmul(attn_weights_new, values)

In [42]:
# One d_out-dimensional (here 2-d) context vector per input token.
context_vectors

tensor([[-0.0872,  0.0286],
        [-0.1003,  0.0493],
        [-0.1008,  0.0628],
        [-0.0989,  0.0485],
        [-0.0520,  0.1095],
        [-0.0759,  0.0690]], grad_fn=<MmBackward0>)

# Dropout: a quick demo on a matrix of ones

In [57]:
torch.manual_seed(123)
# With p=0.5, each element is zeroed with probability 0.5 and survivors are scaled by 1/(1-p) = 2.
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones((6, 6))
dropout(example)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])

# Implementing a causal-attention class that handles batched inputs


In [74]:
class Causal_Attention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum number of tokens supported; sizes the causal mask.
        dropout: dropout probability applied to the attention weights.
        q_bias: whether the query/key/value ``nn.Linear`` layers use a bias.

    ``forward`` accepts a batched ``(batch, num_tokens, d_in)`` tensor (or an
    unbatched ``(num_tokens, d_in)`` one) with ``num_tokens <= context_length``.

    Raises:
        ValueError: from ``forward`` if the input has more tokens than ``context_length``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, q_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=q_bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=q_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=q_bias)
        self.dropout = nn.Dropout(dropout)
        # A buffer, not a Parameter: the fixed upper-triangular mask travels with the
        # module (device moves, state_dict) but is never updated by training.
        # 1s strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer(
            'mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Return causal context vectors of shape ``(..., num_tokens, d_out)``."""
        num_tokens = x.shape[-2]
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {context_length}"
            )

        keys = self.W_keys(x)
        queries = self.W_query(x)
        values = self.W_values(x)

        # Swap only the token/feature dims so batched and unbatched inputs both work.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask because inputs may be shorter than the supported context size.
        # masked_fill_ is in place; -inf becomes weight 0 after the softmax.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vectors = attn_weights @ values
        return context_vectors

In [75]:
torch.manual_seed(123)
# Two copies of the same sequence -> shape (batch=2, num_tokens=6, d_in=3).
batch = torch.stack([inputs, inputs], dim=0)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [76]:
batch.shape  # (batch, num_tokens, d_in)

torch.Size([2, 6, 3])

In [77]:
context_length = batch.size(1)  # number of tokens per sequence; sizes the causal mask

In [78]:
ca = Causal_Attention(d_in, d_out, context_length, dropout=0)  # dropout disabled

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [79]:
# Both batch items are the same sequence and dropout is 0, so both outputs match.
ca(batch)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

tensor([[[-0.4519,  0.2216],
         [-0.5913,  0.0009],
         [-0.6325, -0.0663],
         [-0.5693, -0.0865],
         [-0.5540, -0.0999],
         [-0.5311, -0.1096]],

        [[-0.4519,  0.2216],
         [-0.5913,  0.0009],
         [-0.6325, -0.0663],
         [-0.5693, -0.0865],
         [-0.5540, -0.0999],
         [-0.5311, -0.1096]]], grad_fn=<UnsafeViewBackward0>)

In [80]:
ca1 = Causal_Attention(d_in, d_out, context_length, dropout=0.5)  # 50% dropout on attention weights

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [81]:
# Same inputs as above, but with dropout p=0.5 the two batch items now differ.
# nn.Module starts in training mode, so dropout randomly zeroes attention weights
# (call ca1.eval() to disable it). Values also differ from `ca` because ca1 has
# freshly initialised weights.
ca1(batch)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

tensor([[[0.9544, 0.2127],
         [1.1849, 0.6633],
         [0.7545, 0.4147],
         [0.7494, 0.4676],
         [0.4580, 0.2461],
         [0.8510, 0.5984]],

        [[0.9544, 0.2127],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.5335, 0.4195],
         [0.9262, 0.5676],
         [0.5962, 0.4636]]], grad_fn=<UnsafeViewBackward0>)