In [2]:
import torch

In [31]:
# Dropout demo on an all-ones 6x6 matrix: with p=0.5 (module in training mode)
# each entry is either zeroed or rescaled by 1 / (1 - p) = 2.
ones = torch.ones((6, 6))
dropout = torch.nn.Dropout(p=0.5)
dropout(ones)

tensor([[0., 0., 2., 0., 0., 0.],
        [0., 2., 0., 2., 0., 2.],
        [0., 0., 0., 2., 0., 2.],
        [2., 0., 2., 2., 2., 2.],
        [0., 2., 2., 0., 2., 2.],
        [0., 0., 0., 0., 0., 0.]])

In [None]:
# Values 0..9 laid out as a 2x5 float32 matrix.
example_2 = torch.arange(10).reshape(2, 5).to(torch.float32)
# Entries that survive dropout are scaled by 1 / (1 - p) = 1 / 0.5 = 2.
dropout(example_2)

tensor([[ 0.,  0.,  4.,  6.,  0.],
        [ 0.,  0.,  0., 16., 18.]])

In [76]:
torch.manual_seed(42)


class SelfAttention(torch.nn.Module):
    '''Single-head scaled dot-product self-attention (no masking, no dropout).

    Args:
        d_in: size of each input row.
        d_out: size of the query / key / value projections and of each output row.
    '''

    def __init__(self, d_in, d_out):
        super().__init__()
        # The layers are created in query, key, value order so that a seeded
        # run initialises them exactly as before.
        self.W_q = torch.nn.Linear(d_in, d_out)
        self.W_k = torch.nn.Linear(d_in, d_out)
        self.W_v = torch.nn.Linear(d_in, d_out)

    def forward(self, x):
        '''Return one context vector per row of `x` (written with a 2D `x` in mind).'''
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        # Scores are divided by sqrt(d_out) before the row-wise softmax.
        scale = k.shape[-1] ** 0.5
        scores = torch.matmul(q, k.transpose(-1, -2))
        weights = (scores / scale).softmax(dim=-1)
        return torch.matmul(weights, v)


# Six rows of three features each.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

selfAttn_01 = SelfAttention(d_in=inputs.size(-1), d_out=2)
selfAttn_01(inputs)

tensor([[-0.1358,  0.2496],
        [-0.1344,  0.2541],
        [-0.1344,  0.2541],
        [-0.1351,  0.2531],
        [-0.1348,  0.2520],
        [-0.1350,  0.2537]], grad_fn=<MmBackward0>)

In [68]:
# Show the 6x3 `inputs` tensor for reference.
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [81]:
# Replay the attention-weight computation by hand, reusing the projection
# layers of `selfAttn_01`, so the intermediate tensors can be inspected.
queries, keys, values = (
    selfAttn_01.W_q(inputs),
    selfAttn_01.W_k(inputs),
    selfAttn_01.W_v(inputs),
)
attn_scores = torch.matmul(queries, keys.T)
attn_weights = torch.softmax(attn_scores / keys.size(-1) ** 0.5, dim=-1)
# The final step (attn_weights @ values -> context vectors) is intentionally
# left out; only the weights are displayed.
attn_weights

tensor([[0.1492, 0.1777, 0.1769, 0.1666, 0.1544, 0.1751],
        [0.1679, 0.1772, 0.1769, 0.1569, 0.1602, 0.1609],
        [0.1677, 0.1774, 0.1771, 0.1568, 0.1600, 0.1609],
        [0.1610, 0.1740, 0.1736, 0.1634, 0.1604, 0.1676],
        [0.1602, 0.1794, 0.1789, 0.1592, 0.1567, 0.1656],
        [0.1627, 0.1724, 0.1721, 0.1639, 0.1619, 0.1670]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
# Causal attention: hide future positions (strictly above the diagonal)
# before normalising each row.
# The mask is built from the *shape* of the score matrix rather than from its
# values, so a score that happens to be exactly 0 above the diagonal is still
# masked out.
mask = torch.ones_like(attn_scores).triu(diagonal=1).bool()
masked_attn_scores = attn_scores.masked_fill(mask, -torch.inf)
# NOTE: unlike the other softmax calls in this notebook, no sqrt(d_k) scaling
# is applied to the scores here.
masked_attn_weights = masked_attn_scores.softmax(dim=-1)
masked_attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4809, 0.5191, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3162, 0.3423, 0.3415, 0.0000, 0.0000, 0.0000],
        [0.2353, 0.2625, 0.2618, 0.2403, 0.0000, 0.0000],
        [0.1886, 0.2214, 0.2204, 0.1869, 0.1828, 0.0000],
        [0.1611, 0.1747, 0.1744, 0.1627, 0.1599, 0.1672]],
       grad_fn=<SoftmaxBackward0>)

In [119]:
# Apply dropout (p=0.5) to the causally masked attention weights;
# surviving weights are scaled by 1 / (1 - p) = 2.
dropout = torch.nn.Dropout(p=0.5)
dropout(masked_attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6846, 0.6830, 0.0000, 0.0000, 0.0000],
        [0.4707, 0.0000, 0.5236, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4427, 0.4408, 0.0000, 0.3656, 0.0000],
        [0.3222, 0.3495, 0.0000, 0.0000, 0.3198, 0.3343]],
       grad_fn=<MulBackward0>)

In [3]:
# Incorporating "causal-attention" & "dropout" into the "SelfAttention".
# Before `CausalAttention`: build a small batch by stacking the same
# 6x3 input matrix twice along a new leading dimension.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
batch = torch.stack([inputs, inputs])
batch.shape, batch

(torch.Size([2, 6, 3]),
 tensor([[[0.4300, 0.1500, 0.8900],
          [0.5500, 0.8700, 0.6600],
          [0.5700, 0.8500, 0.6400],
          [0.2200, 0.5800, 0.3300],
          [0.7700, 0.2500, 0.1000],
          [0.0500, 0.8000, 0.5500]],
 
         [[0.4300, 0.1500, 0.8900],
          [0.5500, 0.8700, 0.6600],
          [0.5700, 0.8500, 0.6400],
          [0.2200, 0.5800, 0.3300],
          [0.7700, 0.2500, 0.1000],
          [0.0500, 0.8000, 0.5500]]]))

In [4]:
# Confirm the stacked batch has shape (2, 6, 3).
batch.shape

torch.Size([2, 6, 3])

In [None]:
# @ME:
class CausalAttention(torch.nn.Module):
    '''Single-head causal self-attention with dropout on the attention weights.

    Each position attends only to itself and to earlier positions; later
    positions are masked out before the softmax.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query / key / value projections and of the output.
        context_length: side length of the pre-built causal mask. Inputs with
            more tokens than this are not supported, because the sliced mask
            would no longer match the score matrix.
        qkv_bias: whether the three projection layers carry a bias term.
        dropout: dropout probability applied to the attention weights.
            Defaults to 0.5, the rate `torch.nn.Dropout()` uses when called
            without arguments, so existing callers see no change.
    '''

    def __init__(self, d_in, d_out, context_length, qkv_bias=False, dropout=0.5):
        super().__init__()
        self.W_q = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.context_length = context_length
        # Boolean mask, True strictly above the diagonal (the "future" positions).
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            'mask',
            torch.ones(context_length, context_length).triu(diagonal=1).bool(),
        )

    def forward(self, x):
        '''Compute causally masked context vectors.

        Args:
            x: tensor of shape (num_tokens, d_in) or (batch, num_tokens, d_in).

        Returns:
            Tensor with the same leading dimensions as `x` and last dimension d_out.
        '''
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        attn_scores = queries @ keys.transpose(-1, -2)
        # The token axis is second-to-last for both batched and unbatched input;
        # indexing it as shape[1] would pick the feature axis of a 2D input.
        num_tokens = x.shape[-2]
        causal_mask = self.mask[:num_tokens, :num_tokens]
        masked_attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
        # -inf entries stay -inf after scaling and become 0 after the softmax.
        masked_attn_weights = torch.softmax(
            masked_attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        context_vectors = self.dropout(masked_attn_weights) @ values
        return context_vectors

# Seed first so the projection weights (and the dropout pattern) are repeatable.
torch.manual_seed(123)
context_length = batch.size(1)
d_in, d_out = batch.size(-1), 5
causal_attn_01 = CausalAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=15,  # maximum no. of tokens the model can handle
)
causal_attn_01(batch)

tensor([[[-0.7921,  1.6215,  1.8395,  1.0113,  2.0149],
         [-0.4677,  0.9573,  1.0860,  0.5971,  1.1897],
         [-0.5417, -0.1326,  0.5225, -0.3157,  0.4448],
         [ 0.3715,  0.5817, -0.0979,  0.4880,  0.1611],
         [ 0.1905,  0.2111, -0.1004,  0.1856,  0.0123],
         [ 0.2346,  0.1525, -0.2216, -0.0023,  0.0449]],

        [[ 0.1933,  1.8278,  0.7824,  1.1983,  1.4552],
         [-0.4344,  0.6190,  0.9309,  0.6796,  0.6374],
         [-0.4125, -0.3499,  0.3693,  0.0550, -0.2304],
         [-0.7117, -0.7082,  0.5755,  0.0113, -0.4770],
         [ 0.0428,  0.4046,  0.1732,  0.2653,  0.3221],
         [-0.1133, -0.2754,  0.0182, -0.0340, -0.2721]]],
       grad_fn=<UnsafeViewBackward0>)

In [56]:
# @BOOK:
from torch import nn

class CausalAttention_Book(nn.Module):
    """Single-head causal self-attention for batched input, with dropout.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query / key / value projections and of the output.
        context_length: side length of the pre-built causal mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Float matrix with ones strictly above the diagonal; converted to a
        # boolean mask inside forward().
        upper = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', upper)

    def forward(self, x):
        """Return context vectors for `x` of shape (batch, num_tokens, d_in)."""
        # Unpacking requires a 3D input. `num_tokens` must not exceed
        # `context_length`, otherwise the sliced mask no longer matches the
        # score matrix.
        _, num_tokens, _ = x.shape
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Slice the mask so batches shorter than the supported context still work.
        future = self.mask[:num_tokens, :num_tokens].bool()
        scores = (q @ k.transpose(1, 2)).masked_fill(future, -torch.inf)
        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)
        return weights @ v

# Seed so the projection weights are repeatable, then run the book's
# implementation on the same batch with dropout disabled.
torch.manual_seed(123)

context_length = batch.size(1)
d_in, d_out = 3, 2
ca = CausalAttention_Book(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
)

context_vecs_book = ca(batch)

print(context_vecs_book)
# print("context_vecs.shape:", context_vecs_book.shape)

tensor([[[ 0.1856, -1.4968],
         [ 0.3356, -0.7668],
         [-0.2352, -0.0725],
         [-0.0579, -0.5042],
         [-0.3479, -0.7335],
         [-0.6489, -0.5583]],

        [[-0.2747, -1.4670],
         [ 0.3525, -0.2532],
         [ 0.5015, -0.0186],
         [ 0.2912, -0.0851],
         [ 0.1539,  0.2645],
         [ 0.0893,  0.2897]]], grad_fn=<UnsafeViewBackward0>)
