In [2]:
import torch
import torch.nn as nn

## Playing with nn.Dropout

In [12]:
dropout = nn.Dropout(p=0.5)

In [16]:
input = torch.rand(6,6)

In [17]:
input

tensor([[0.5615, 0.4451, 0.2669, 0.2333, 0.4857, 0.2484],
        [0.6900, 0.2211, 0.8846, 0.7803, 0.5993, 0.3093],
        [0.8048, 0.6307, 0.9297, 0.0072, 0.2327, 0.5284],
        [0.8300, 0.1190, 0.3450, 0.4250, 0.8378, 0.9796],
        [0.4211, 0.4509, 0.1264, 0.7267, 0.9654, 0.5539],
        [0.9365, 0.5828, 0.7313, 0.4985, 0.6818, 0.7144]])

In [18]:
dropout(input)

tensor([[0.0000, 0.8901, 0.5339, 0.0000, 0.9714, 0.4967],
        [1.3800, 0.4423, 1.7692, 1.5606, 0.0000, 0.6186],
        [1.6095, 0.0000, 1.8594, 0.0144, 0.4654, 1.0568],
        [0.0000, 0.0000, 0.0000, 0.8500, 1.6755, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1079],
        [1.8730, 1.1657, 0.0000, 0.9969, 1.3636, 0.0000]])

## Casual Attention class

In [None]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout, context_length):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, tokens, d_in = x.shape
        x_q = self.W_query(x)
        x_k = self.W_key(x)
        x_v = self.W_value(x)

        attn_scores = x_q @ x_k.transpose(1,2)
        attn_scores.masked_fill_(self.mask.bool(), -torch.inf)

        attn_weight = torch.softmax(attn_scores / x_k.shape[-1]**0.5, dim=-1)
        print(attn_weight)

        attn_weight_dropout = self.dropout(attn_weight)

        context_vec = attn_weight_dropout @ x_v
        return context_vec

In [60]:
input = torch.randn(2,5,4)
d_in = 4
d_out = 2
ca = CasualAttention(d_in=d_in, d_out=d_out, dropout=0.5, context_length=5)

result = ca(input)
print(result)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6796, 0.3204, 0.0000, 0.0000, 0.0000],
         [0.4944, 0.2727, 0.2329, 0.0000, 0.0000],
         [0.1987, 0.2550, 0.1263, 0.4200, 0.0000],
         [0.1705, 0.1604, 0.4521, 0.0621, 0.1549]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7893, 0.2107, 0.0000, 0.0000, 0.0000],
         [0.2941, 0.2640, 0.4419, 0.0000, 0.0000],
         [0.3541, 0.2514, 0.1877, 0.2068, 0.0000],
         [0.1530, 0.1474, 0.2505, 0.2418, 0.2072]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[[ 0.0000,  0.0000],
         [ 0.1998, -0.5034],
         [ 0.1700, -0.4284],
         [ 0.8279,  0.1842],
         [ 0.3108, -0.1971]],

        [[ 0.0000,  0.0000],
         [ 1.3091,  1.6676],
         [-0.0162,  0.4568],
         [ 0.1576, -0.6356],
         [-0.1655, -0.5785]]], grad_fn=<UnsafeViewBackward0>)
