# Assignment 1 - LLM Bootcamp

## Question 1
In the forward function of `EfficientSlidingWindowMultiheadAttention`, compute the `keys`, `queries`, and `values`. Pad the keys and values to accommodate the sliding window edges.

<b>Unfolding the keys</b>

<p>We need to reshape the keys so that the windowed attention can be computed with tensor operations. To do this, we will use the unfold function. Let's see how it works.
Suppose you have a sequence represented by a 1-dimensional tensor and want to apply a sliding window to extract subsequences (windows) of a fixed size.
For simplicity, consider the numbers from 1 to 10 and a window size of 3.</p>

In [1]:
import torch

# Sequence tensor
sequence = torch.arange(1, 11)  # Tensor: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print("Original Sequence:", sequence)

Original Sequence: tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


The <b>unfold</b> method in PyTorch can be used to create sliding windows. You need to specify:

- The dimension along which to unfold (for a 1D tensor, this is dimension 0).
- The size of the window you want to extract (in this example, 3).
- The step size for each slide (for simplicity, let's use a step size of 1, meaning the window slides one element at a time).

In [2]:
# Window size
window_size = 3

# Unfolding the sequence to create sliding windows
windows = sequence.unfold(0, window_size, 1)

print("Sliding Windows:\n", windows)

Sliding Windows:
 tensor([[ 1,  2,  3],
        [ 2,  3,  4],
        [ 3,  4,  5],
        [ 4,  5,  6],
        [ 5,  6,  7],
        [ 6,  7,  8],
        [ 7,  8,  9],
        [ 8,  9, 10]])


In sliding window attention, each row of the unfolded tensor holds the keys (and values) used to score the query at the center of that window. Note that without padding, a length-10 sequence yields only 8 windows, so the first and last positions have no centered window; this is why the keys and values are padded before unfolding. Restricting each query to its window reduces the cost of attention from quadratic in the sequence length to linear (sequence length times window size) while still capturing the local context around each position.
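
To make the role of padding concrete, here is a minimal sketch that reuses the toy sequence from above. It zero-pads both ends by `(window_size - 1) // 2` and then unfolds, so that every original position gets a window centered on it. The name `centered_windows` is only illustrative and is not part of the assignment code.

```python
import torch
import torch.nn.functional as F

sequence = torch.arange(1, 11)        # values 1..10
window_size = 3
padding = (window_size - 1) // 2      # one element on each side

# Zero-pad both ends so every original position has a full window around it
padded = F.pad(sequence, (padding, padding), value=0)

# One window per original position: shape [10, 3]
centered_windows = padded.unfold(0, window_size, 1)

print(centered_windows.shape)          # torch.Size([10, 3])
print(centered_windows[0].tolist())    # [0, 1, 2]
print(centered_windows[-1].tolist())   # [9, 10, 0]

# The middle column is the original sequence itself
print(torch.equal(centered_windows[:, padding], sequence))  # True
```

The zeros in the first and last windows are the padded positions. In a full attention implementation they would typically need to be masked out, which this sketch does not cover.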

At this point, the size of `keys` is <b>[batch_size, num_heads, seq_length, head_dim]</b>. That is, for each of the <b>batch_size</b> samples and each of the <b>num_heads</b> heads, we have <b>seq_length</b> vectors of size <b>head_dim</b>. The size of `keys_padded` should be <b>[batch_size, num_heads, seq_length + 2 * padding, head_dim]</b>. We now want to create `keys_windows` with shape <b>[batch_size, num_heads, seq_length, head_dim, window_size]</b> by applying the `unfold` function to `keys_padded` along the sequence dimension. Effectively, we are going to replace <b>one</b> key vector of size <b>head_dim</b> with <b>window_size</b> successive key vectors of size <b>head_dim</b>.
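
The shape bookkeeping described above can be checked with a short sketch. The sizes below are arbitrary illustrative values and the random `keys` tensor is a stand-in for the projected keys; only the shapes matter here.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not taken from the assignment)
batch_size, num_heads, seq_length, head_dim = 2, 4, 10, 8
window_size = 3
padding = (window_size - 1) // 2

keys = torch.randn(batch_size, num_heads, seq_length, head_dim)

# F.pad reads the pad tuple from the last dimension backwards:
# (0, 0) leaves head_dim untouched, (padding, padding) pads seq_length.
keys_padded = F.pad(keys, (0, 0, padding, padding))

# Slide along the sequence dimension (dim=2); unfold appends the
# window dimension at the end of the shape.
keys_windows = keys_padded.unfold(2, window_size, 1)

print(keys_padded.shape)    # torch.Size([2, 4, 12, 8])
print(keys_windows.shape)   # torch.Size([2, 4, 10, 8, 3])
```

Because `unfold` always places the new window dimension last, the result is `[..., head_dim, window_size]` rather than `[..., window_size, head_dim]`. Any later matrix product with the queries has to account for that ordering.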

In [5]:
# nn module provides the necessary components to build neural networks
import torch.nn as nn

# Initialize the sequence tensor as floats (nn.Linear requires a floating-point input)
sequence = torch.arange(1, 11, dtype=torch.float32).unsqueeze(0)  # Add batch dimension: [1, 10]
print("Sequence:", sequence)

Sequence: tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])


In [3]:
# Define dimensions
embed_dim = 5  # Dimension of the embedding/feature space for keys, queries, values
window_size = 3
padding = (window_size - 1) // 2

# Define linear transformations
linear_keys = nn.Linear(1, embed_dim)
linear_queries = nn.Linear(1, embed_dim)
linear_values = nn.Linear(1, embed_dim)

# Add a trailing feature dimension; nn.Linear acts on the last dimension ([..., in_features])
sequence_expanded = sequence.unsqueeze(-1)  # Shape: [1, 10, 1]

# Apply transformations
keys = linear_keys(sequence_expanded)
queries = linear_queries(sequence_expanded)
values = linear_values(sequence_expanded)

# Pad keys and values
keys_padded = torch.nn.functional.pad(keys, (0, 0, padding, padding), mode='constant', value=0)
values_padded = torch.nn.functional.pad(values, (0, 0, padding, padding), mode='constant', value=0)

# Unfold keys and values to create sliding windows
keys_windows = keys_padded.unfold(1, window_size, 1)
values_windows = values_padded.unfold(1, window_size, 1)

# Print outputs
print("Keys Windows Shape:", keys_windows.shape)
print("Values Windows Shape:", values_windows.shape)
print("Queries Shape:", queries.shape)

Keys Windows Shape: torch.Size([1, 10, 5, 3])
Values Windows Shape: torch.Size([1, 10, 5, 3])
Queries Shape: torch.Size([1, 10, 5])
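
As a sanity check on the layout of the unfolded tensor, the following sketch confirms which key ends up in which window slot. It uses a random tensor as a stand-in for the projected keys (same shape as in the cell above), so the values themselves are arbitrary; `position` and `offset` are illustrative names.

```python
import torch
import torch.nn.functional as F

window_size = 3
padding = (window_size - 1) // 2

# Stand-in for the projected keys: shape [batch, seq_length, embed_dim]
keys = torch.randn(1, 10, 5)
keys_padded = F.pad(keys, (0, 0, padding, padding))
keys_windows = keys_padded.unfold(1, window_size, 1)   # [1, 10, 5, 3]

position = 4
for offset in range(window_size):
    # Slot `offset` of the window at `position` is padded key `position + offset`,
    # i.e. original key `position + offset - padding`.
    assert torch.equal(keys_windows[0, position, :, offset],
                       keys_padded[0, position + offset])

# The middle slot is the key at the query's own position
assert torch.equal(keys_windows[0, position, :, padding], keys[0, position])

# At the left edge, the first slot comes from the zero padding
first_slot_is_padding = bool((keys_windows[0, 0, :, 0] == 0).all())
print(first_slot_is_padding)  # True
```

So for the query at index `i`, the last dimension of `keys_windows[:, i]` enumerates the keys at positions `i - padding` through `i + padding`, with zeros wherever the window runs past the ends of the sequence.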


## Question 2