In [4]:
import tiktoken
from torch.utils.data import Dataset, DataLoader
import torch

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [5]:
tokenizer = tiktoken.get_encoding("gpt2")

In [6]:
with open("the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()
    enc_text = tokenizer.encode(raw_text)
    print(len(enc_text))

5145


In [7]:
class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size, max_length, stride,
                         shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    return dataloader

In [8]:
dataloader = create_dataloader_v1(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)

In [9]:
data_iter = iter(dataloader)
first_batch = next(data_iter)
second_batch = next(data_iter)
print(first_batch)
print(second_batch)

[tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]
[tensor([[ 367, 2885, 1464, 1807]]), tensor([[2885, 1464, 1807, 3619]])]


In [10]:
dataloader = create_dataloader_v1(raw_text, batch_size=1, max_length=8, stride=2, shuffle=False)

In [11]:
data_iter = iter(dataloader)
first_batch = next(data_iter)
second_batch = next(data_iter)
print(first_batch)
print(second_batch)

[tensor([[  40,  367, 2885, 1464, 1807, 3619,  402,  271]]), tensor([[  367,  2885,  1464,  1807,  3619,   402,   271, 10899]])]
[tensor([[ 2885,  1464,  1807,  3619,   402,   271, 10899,  2138]]), tensor([[ 1464,  1807,  3619,   402,   271, 10899,  2138,   257]])]


## Self attention

The goal of self attention is to compute the `context vector` of each token in a sequence. The context vector is an enriched embedding representation of a token. It is packed with information about the token itself and its relationship/relevance to other tokens in a sequence.

#### Implement self-attention with untrainable weights

In [12]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [13]:
attention_scores = inputs @ inputs.T
# Normalize the attention scores with a softmax
attention_weights = torch.softmax(attention_scores, dim=-1)
# Compute now the context vector
context_vector = attention_weights @ inputs
print("Context vector:\n", context_vector)

Context vector:
 tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


#### Implement self-attention with trainable weights

In self-attention without `trainaible weights`, `context vector` is a weighted sum over input vectors. However, for `trainable weights`, `context vector` is weighted sum over value vector.

##### Query, Key and Value analogy to database operation
`query` is the word/token in input sequence that the model wants to get information on, key is what is used to get the information about the query and `value` is the information received using the key.

In [14]:
import torch.nn as nn

# Let's now implement self-attention with trainable weights
# We will be using nn.Parameter to initialize and create the weights
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        """
        :param d_in: input dimension
        :param d_out: output dimension
        """
        super().__init__()
        self.W_q = nn.Parameter(torch.randn(d_in, d_out))
        self.W_k = nn.Parameter(torch.randn(d_in, d_out))
        self.W_v = nn.Parameter(torch.randn(d_in, d_out))
    
    def forward(self, x):
        """
        :param x: input tensor of shape (batch_size, d_in)
        :return: output tensor of shape (batch_size, d_out)
        """
        # Compute queries, keys and values
        Q = x @ self.W_q
        K = x @ self.W_k
        V = x @ self.W_v

        # Compute attention scores
        attention_scores = Q @ K.T
        attention_weights = torch.softmax(attention_scores / (K.shape[-1] ** 0.5), dim=-1)

        # Compute context vector
        context_vector = attention_weights @ V

        return context_vector

- Let's test self-attention v1

In [15]:
torch.manual_seed(123)
d_in, d_out = 3, 2
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
context_vector_v1 = sa_v1(inputs) # Context vector for the inputs
context_vector_v1

tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)

In [16]:
# Lets now use nn.Linear to implement self-attention

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bais=False):
        """
        :param d_in: input dimension
        :param d_out: output dimension
        """
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bais)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bais)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bais)
    
    def forward(self, x):
        """
        :param x: input tensor of shape (batch_size, d_in)
        :return: output tensor of shape (batch_size, d_out)
        """
        # Compute queries, keys and values
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Compute attention scores
        attention_scores = Q @ K.T
        attention_weights = torch.softmax(attention_scores / (K.shape[-1] ** 0.5), dim=-1)

        # Compute context vector
        context_vector = attention_weights @ V

        return context_vector

- Let's now test self-attention v2

In [17]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
context_vector_v2 = sa_v2(inputs) # Context vector for the inputs
context_vector_v2

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

In [18]:
# Let's test if the two context vectors are the same
print(torch.allclose(context_vector_v1, context_vector_v2, atol=1e-6))

False


- Let's find a way to clone the weights from v2 and copy to v1 and then check if they are same

In [19]:
# Lets clone the weights of self-attention v2 to self-attention v1
torch.manual_seed(123)
sa_v1.W_q.data = sa_v2.W_q.weight.data.T.clone()
sa_v1.W_k.data = sa_v2.W_k.weight.data.T.clone()
sa_v1.W_v.data = sa_v2.W_v.weight.data.T.clone()

context_vector_v1 = sa_v1(inputs) 
context_vector_v2 = sa_v2(inputs)

# Let's test if the two context vectors are the same
print(torch.allclose(context_vector_v1, context_vector_v2, atol=1e-6))



True


In [20]:
context_vector_v1

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

In [21]:
context_vector_v2

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

In [22]:
# Let's experiment with implementing causal self-attention using self-attention v2
queries = sa_v2.W_q(inputs)
keys = sa_v2.W_k(inputs)
attention_scores = queries @ keys.T
attention_weights = torch.softmax(attention_scores / (d_out ** 0.5), dim=-1)
attention_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [23]:
context_length = attention_scores.shape[0]
causal_mask = torch.tril(torch.ones(context_length, context_length))
causal_mask * attention_weights # This helps to mask the future tokens

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)

#### Implementing a compact causal attention

In [24]:
# Lets implement causal self attention with dropout
class CausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, qkv_bias=False, dropout=0.0):
        """
        :param d_in: input dimension
        :param context_length: length of the context
        :param qkv_bais: whether to use bias in the linear layers
        :param dropout: dropout rate
        """
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Register a buffer for the causal mask
        self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length)))
    
    def forward(self, x):
        """
        :param x: input tensor of shape (B, T, d_in)
        :return: output tensor of shape (B, T, d_out)
        """
        # Expect input shape to be (B, T, d_in)
        B, T, d_in = x.shape
        # Compute queries, keys and values
        Q = self.W_q(x) # (B, T, d_out)
        K = self.W_k(x) # (B, T, d_out)
        V = self.W_v(x) # (B, T, d_out)

        # Compute attention scores
        attention_scores = Q @ K.transpose(-1, -2) # (B, T, d_out) @ (B, d_out, T) -> (B, T, T)

        # Apply causal mask
        attention_scores = attention_scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_scores / (K.shape[-1] ** 0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Compute context vector
        context_vector = attention_weights @ V # (B, T, T) @ (B, T, d_out) -> (B, T, d_out)

        return context_vector

In [25]:
batch = torch.stack([inputs, inputs], dim=0)
context_length = batch.shape[1]
batch.shape

torch.Size([2, 6, 3])

In [26]:
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [27]:
torch.manual_seed(123)
causal_attention = CausalSelfAttention(d_in=d_in, d_out=d_out, context_length=context_length)
context_vector = causal_attention(batch)
context_vector.shape

torch.Size([2, 6, 2])

In [28]:
context_vector

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)

#### Implement vanilla MultiHeadAttention Wrapper with sequential head processing

In [29]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, n_heads, qkv_bias=False, dropout=0.0):
        """
        :param d_in: input dimension
        :param n_heads: number of heads
        :param context_length: length of the context
        :param qkv_bias: whether to use bias in the linear layers
        :param dropout: dropout rate
        """
        super().__init__()
        self.heads= nn.ModuleList([
            CausalSelfAttention(d_in=d_in, d_out=d_out, context_length=context_length, qkv_bias=qkv_bias, dropout=dropout)
            for _ in range(n_heads)
        ])
    
    
    def forward(self, x):
        """
        :param x: input tensor of shape (B, T, d_in)
        :return: output tensor of shape (B, T, d_out * n_heads)
        """

        return torch.cat([h(x) for h in self.heads], dim=-1)

In [30]:
# Set d_out to 1 so output dimension of context vector is 2 since we have 2 heads
mha = MultiHeadAttentionWrapper(d_in=d_in, d_out=1, context_length=context_length, n_heads=2)
context_vector = mha(batch)
context_vector.shape

torch.Size([2, 6, 2])

#### Implement MultiHeadAttention with parallel head processing

In [31]:
# Let's implement multi-head attention with parallel head processing
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, n_heads, qkv_bias=False, dropout=0.0):
        """
        :param d_in: input dimension
        :param n_heads: number of heads
        :param context_length: length of the context
        :param qkv_bias: whether to use bias in the linear layers
        :param dropout: dropout rate
        """
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"
        self.head_dim = d_out // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_projection = nn.Linear(d_out, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Register a buffer for the causal mask
        self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        """
        :param x: input tensor of shape (B, T, d_in)
        :return: output tensor of shape (B, T, d_out)
        """
        # Expect input shape to be (B, T, d_in)
        B, T, d_in = x.shape
        # Compute queries, keys and values
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        # Reshape Q, K, V to (B, n_heads, T, head_dim)
        Q = Q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, nh, T, hdim)
        K = K.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, nh, T, hdim)
        V = V.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, nh, T, hdim)  

        # Compute attention scores
        attention_scores = Q @ K.transpose(-1, -2) # (B, nh, T, hdim) @ (B, nh, hdim, T) -> (B, nh, T, T)
        
        # Apply causal mask
        attention_scores = attention_scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_scores / (K.shape[1] ** 0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Compute context vector
        context_vector = attention_weights @ V # (B, nh, T, T) @ (B, nh, T, hdim) -> (B, nh, T, hdim)
        # Reshape context vector to (B, T, d_out)
        context_vector = context_vector.transpose(1, 2).contiguous().view(B, T, -1)
        # Apply output projection
        context_vector = self.out_projection(context_vector) # (B, T, d_out)
        return context_vector
torch.manual_seed(123)
d_in, d_out = 3, 2
context_length = 6
mha_v2 = MultiHeadAttention(d_in=d_in, d_out=d_out, context_length=context_length, n_heads=2)
context_vector = mha_v2(batch)
context_vector.shape

torch.Size([2, 6, 2])

In [32]:
context_vector

tensor([[[ 0.1257, -0.1968],
         [ 0.1007, -0.2920],
         [ 0.0920, -0.3225],
         [ 0.0758, -0.2946],
         [ 0.0702, -0.2891],
         [ 0.0641, -0.2792]],

        [[ 0.1257, -0.1968],
         [ 0.1007, -0.2920],
         [ 0.0920, -0.3225],
         [ 0.0758, -0.2946],
         [ 0.0702, -0.2891],
         [ 0.0641, -0.2792]]], grad_fn=<UnsafeViewBackward0>)

In [35]:
# Lets now check number of parameters in the multi-head attention layer in GPT-2
context_length = 1024
d_in, d_out = 768, 768
num_heads = 12

mha = MultiHeadAttention(d_in=d_in, d_out=d_out, context_length=context_length, n_heads=num_heads)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(mha) # (2.36 M)

2359296