In [1]:
# Running example: a six-word sentence; each word gets one embedding row further below.
sentence = "Your journey starts with one step"
print(sentence)

Your journey starts with one step


In [2]:
import torch

# Toy embeddings: one 3-dimensional row vector per word of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1 -> "Your"
    [0.55, 0.87, 0.66],  # x^2 -> "journey"
    [0.57, 0.85, 0.64],  # x^3 -> "starts"
    [0.22, 0.58, 0.33],  # x^4 -> "with"
    [0.77, 0.25, 0.10],  # x^5 -> "one"
    [0.05, 0.80, 0.55],  # x^6 -> "step"
])

In [3]:
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the query/key/value projections, and therefore of each
            returned context vector.
        qkv_bias: Whether the three linear projections include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes. `.T` reverses *all* axes, which is only
        # correct for 2-D input and breaks as soon as a batch dimension is added.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Fix the RNG so the randomly initialised projection weights, and therefore the
# attention weights printed by this cell, are identical on every Restart & Run All.
torch.manual_seed(123)
sa = SelfAttention(d_in=3, d_out=2)

# Recompute the intermediate attention weights outside the module so that they
# can be inspected and masked step by step.
queries = sa.W_query(inputs)
keys = sa.W_key(inputs)
attn_scores = queries @ keys.T
# Scale by sqrt(d_k), then normalise each row (one row per query token).
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1600, 0.1694, 0.1684, 0.1712, 0.1491, 0.1820],
        [0.1614, 0.1673, 0.1661, 0.1730, 0.1477, 0.1845],
        [0.1619, 0.1671, 0.1660, 0.1727, 0.1488, 0.1835],
        [0.1631, 0.1674, 0.1666, 0.1707, 0.1541, 0.1781],
        [0.1720, 0.1625, 0.1630, 0.1661, 0.1750, 0.1613],
        [0.1580, 0.1695, 0.1680, 0.1732, 0.1429, 0.1884]],
       grad_fn=<SoftmaxBackward0>)


In [5]:
## In language modelling, each token may attend only to itself and to earlier tokens, so future tokens must be masked out when training the model to predict the next token.
## This is done by forcing the attention weights of future tokens (the entries above the diagonal) to zero.

In [None]:
# The score matrix is square: its side length is the number of tokens.
context_length = attn_scores.size(0)
# Ones on and below the diagonal (allowed positions), zeros above it (future tokens).
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [None]:
# Element-wise product with the 0/1 mask zeroes every weight that points at a future token.
attn_weights_masked = attn_weights.mul(mask_simple)
# The rows no longer sum to 1 after masking, so they still need to be renormalised.
print(attn_weights_masked)

tensor([[0.1600, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1614, 0.1673, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1619, 0.1671, 0.1660, 0.0000, 0.0000, 0.0000],
        [0.1631, 0.1674, 0.1666, 0.1707, 0.0000, 0.0000],
        [0.1720, 0.1625, 0.1630, 0.1661, 0.1750, 0.0000],
        [0.1580, 0.1695, 0.1680, 0.1732, 0.1429, 0.1884]],
       grad_fn=<MulBackward0>)


In [8]:
## Renormalise the masked weights: divide every row by its own sum so that it adds up to 1 again
row_sums = attn_weights_masked.sum(dim=-1, keepdim=True)  # keepdim -> shape (tokens, 1) for broadcasting
attn_weights_masked = attn_weights_masked.div(row_sums)
print(attn_weights_masked)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4910, 0.5090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3271, 0.3376, 0.3354, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.2506, 0.2495, 0.2556, 0.0000, 0.0000],
        [0.2051, 0.1938, 0.1943, 0.1981, 0.2086, 0.0000],
        [0.1580, 0.1695, 0.1680, 0.1732, 0.1429, 0.1884]],
       grad_fn=<DivBackward0>)


In [9]:
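## While this method works, it is roundabout: softmax is first computed over all scores (future tokens included) and the rows are then masked and renormalized. The renormalization cancels the future tokens' contribution to the softmax denominator, so no information actually leaks into the final weights, but masking the scores with -inf before a single softmax (shown next) yields the same weights in one step.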

In [10]:
# Strictly upper-triangular ones: entry (i, j) is 1 exactly where j > i, i.e. where j is a future token.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [None]:
# 1 above the diagonal marks the future positions that must be hidden.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
# Replace the scores of future tokens with -inf: exp(-inf) = 0, so softmax assigns
# them zero weight and they never contribute to the normalisation.
attn_scores_masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(attn_scores_masked)

tensor([[-3.5897e-02,        -inf,        -inf,        -inf,        -inf,
                -inf],
        [-7.2614e-02, -2.1653e-02,        -inf,        -inf,        -inf,
                -inf],
        [-7.1101e-02, -2.6383e-02, -3.5499e-02,        -inf,        -inf,
                -inf],
        [-4.4235e-02, -7.3606e-03, -1.3702e-02,  2.0537e-02,        -inf,
                -inf],
        [-2.3816e-02, -1.0412e-01, -1.0012e-01, -7.3068e-02,  1.8775e-04,
                -inf],
        [-6.1314e-02,  3.7585e-02,  2.5256e-02,  6.8604e-02, -2.0406e-01,
          1.8710e-01]], grad_fn=<MaskedFillBackward0>)


In [12]:
# Scale by sqrt(d_k) and apply a row-wise softmax; the -inf entries come out as exactly 0.
attn_weights = attn_scores_masked.div(keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4910, 0.5090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3271, 0.3376, 0.3354, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.2506, 0.2495, 0.2556, 0.0000, 0.0000],
        [0.2051, 0.1938, 0.1943, 0.1981, 0.2086, 0.0000],
        [0.1580, 0.1695, 0.1680, 0.1732, 0.1429, 0.1884]],
       grad_fn=<SoftmaxBackward0>)


In [14]:
## Dropout
# An all-ones matrix makes the effect of dropout easy to see.
example = torch.ones(6, 6)
print(example)
print(" ")

# With p=0.5 roughly half of the entries are zeroed; the survivors are scaled by
# 1 / (1 - 0.5) = 2 so that the expected value of each entry is unchanged.
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
print(dropout(example))

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
 
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [15]:
## Dropout acts on the finished attention weights, i.e. after softmax but before they are multiplied with the values.
torch.manual_seed(123)
attn_weights_dropped = dropout(attn_weights)
print(attn_weights_dropped)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6541, 0.6751, 0.6708, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5013, 0.4990, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3876, 0.0000, 0.3962, 0.0000, 0.0000],
        [0.0000, 0.3390, 0.3360, 0.3465, 0.2857, 0.0000]],
       grad_fn=<MulBackward0>)


In [None]:
# Simulate a batch of two sequences by stacking the same inputs along a new leading (batch) axis.
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [None]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the attention weights.

    Every token attends only to itself and to earlier tokens; positions after
    it are masked out before the softmax.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the query/key/value projections, and therefore of each
            returned context vector.
        context_length: Side length of the stored causal mask, i.e. the
            longest sequence the mask can cover.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three linear projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer rather than a parameter: it is not trained,
        # but it still follows the module when it is moved between devices.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Compute causally masked context vectors.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)``: a single sequence
                ``(num_tokens, d_in)``, a batch ``(b, num_tokens, d_in)``, or
                any further leading batch dimensions.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        # Read the sequence length from the second-to-last axis so that both
        # batched and unbatched inputs are accepted.
        num_tokens = x.shape[-2]
        keys = self.W_key(x)       # (..., num_tokens, d_out)
        queries = self.W_query(x)  # (..., num_tokens, d_out)
        values = self.W_value(x)   # (..., num_tokens, d_out)

        # Transpose only the last two axes of the keys:
        # (..., num_tokens, d_out) @ (..., d_out, num_tokens) -> (..., num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask with `:num_tokens` because a sequence may be shorter
        # than context_length; the 2-D mask broadcasts over leading batch axes.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1  # (..., num_tokens, num_tokens)
        )
        # Randomly drop attention weights (active in training mode only).
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values  # (..., num_tokens, d_out)
        return context_vec

In [None]:
# Seed first so that the layer weights created inside CausalAttention are reproducible.
torch.manual_seed(123)
context_length = batch.size(1)  # tokens per sequence
d_in = inputs.size(1)           # embedding size of each token
d_out = 2
# A dropout probability of 0.0 disables dropout for this demonstration.
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)  # (batch, num_tokens, d_out)

context_vecs.shape: torch.Size([2, 6, 2])


In [20]:
# `batch` stacks the same `inputs` twice and dropout is 0.0, so both batch entries print identical context vectors.
print("context_vecs:", context_vecs)

context_vecs: tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
