In [1]:
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Queries, keys and values come from three independent ``nn.Linear`` layers.
    Scores are divided by ``sqrt(d_out)`` and normalised with a softmax over
    the key axis. No causal mask and no dropout are applied.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of the query/key/value projections (and of each
            returned context vector).
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with
                leading batch dimensions ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes. ``keys.T`` reverses *every* axis, which
        # is correct for 2-D input only and breaks on batched tensors.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key axis.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec



In [4]:
import torch

# Toy embeddings: one 3-dimensional row per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
])

In [5]:
# Re-derive the (unmasked) attention weights step by step from a freshly
# initialised SelfAttention_v2. No seed is set in the visible cells before the
# layer is built, so the printed values depend on the RNG state at that point.
d_in = inputs.size(1)
d_out = 2
sa_v2 = SelfAttention_v2(d_in, d_out)

queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = torch.matmul(queries, keys.T)
# Scale by sqrt(d_out), then normalise every row over the key axis.
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1696, 0.1685, 0.1683, 0.1646, 0.1618, 0.1672],
        [0.1722, 0.1691, 0.1687, 0.1636, 0.1586, 0.1678],
        [0.1724, 0.1685, 0.1682, 0.1640, 0.1590, 0.1680],
        [0.1693, 0.1695, 0.1692, 0.1639, 0.1612, 0.1669],
        [0.1733, 0.1584, 0.1585, 0.1718, 0.1680, 0.1700],
        [0.1680, 0.1748, 0.1743, 0.1601, 0.1573, 0.1656]],
       grad_fn=<SoftmaxBackward0>)


In [6]:
# Lower-triangular matrix of ones: row i keeps columns 0..i and zeroes the rest.
context_length = attn_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [7]:
# Element-wise product with the 0/1 mask zeroes every weight above the diagonal.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.1696, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1722, 0.1691, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1724, 0.1685, 0.1682, 0.0000, 0.0000, 0.0000],
        [0.1693, 0.1695, 0.1692, 0.1639, 0.0000, 0.0000],
        [0.1733, 0.1584, 0.1585, 0.1718, 0.1680, 0.0000],
        [0.1680, 0.1748, 0.1743, 0.1601, 0.1573, 0.1656]],
       grad_fn=<MulBackward0>)


In [9]:
# Renormalise: divide each row by its own sum so the surviving weights add
# up to 1 again (keepdim=True keeps a column vector for broadcasting).
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5046, 0.4954, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3386, 0.3311, 0.3303, 0.0000, 0.0000, 0.0000],
        [0.2519, 0.2522, 0.2518, 0.2440, 0.0000, 0.0000],
        [0.2088, 0.1909, 0.1909, 0.2070, 0.2024, 0.0000],
        [0.1680, 0.1748, 0.1743, 0.1601, 0.1573, 0.1656]],
       grad_fn=<DivBackward0>)


In [15]:
torch.manual_seed(123)
# The three weight matrices are drawn in this fixed order (query, key, value)
# so that the seeded values are reproducible.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)
attn_scores = torch.matmul(queries, keys.T)

# Ones strictly above the diagonal mark the positions to hide; those scores
# are replaced by -inf (out of place, attn_scores itself is left unchanged).
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(masked)

tensor([[0.9231,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.2705, 1.8524,   -inf,   -inf,   -inf,   -inf],
        [1.2544, 1.8284, 1.7877,   -inf,   -inf,   -inf],
        [0.6973, 1.0167, 0.9941, 0.5925,   -inf,   -inf],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707,   -inf],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


In [17]:
# Softmax over each row of the scaled, -inf-masked scores; the printed result
# shows the masked positions coming out as 0.
attn_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3986, 0.6014, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2526, 0.3791, 0.3683, 0.0000, 0.0000, 0.0000],
        [0.2265, 0.2839, 0.2794, 0.2103, 0.0000, 0.0000],
        [0.1952, 0.2363, 0.2331, 0.1820, 0.1534, 0.0000],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


In [18]:
torch.manual_seed(123)
# Dropout with p=0.5 applied to a 6x6 matrix of ones; in the printed result
# every entry is either 0 or 2.
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones((6, 6))
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [19]:
# Apply the same dropout layer (p=0.5) to the causal attention weights.
dropped_attn_weights = dropout(attn_weights)
print(dropped_attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7582, 0.7366, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5587, 0.4206, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3068, 0.0000],
        [0.3115, 0.4183, 0.0000, 0.2839, 0.2178, 0.3588]])


In [26]:
import torch.nn as nn

class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Each token may attend only to itself and to earlier tokens: scores for
    later positions are set to ``-inf`` before the softmax.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of the query/key/value projections (and of each
            returned context vector).
        context_length: Side length of the pre-built causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it is stored with the module rather than
        # rebuilt on every forward pass.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causally masked context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        # Unpacking three values also rejects inputs that are not 3-D.
        _, num_tokens, _ = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        # In-place fill (trailing underscore). The out-of-place
        # ``masked_fill`` returns a new tensor; discarding that result leaves
        # the scores unmasked and lets tokens attend to future positions.
        # The mask is cropped to the actual sequence length.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec



In [27]:
torch.manual_seed(123)
# Two copies of the same sequence form a batch of shape (2, num_tokens, d_in).
batch = torch.stack([inputs, inputs])
context_length = batch.size(1)
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])
