# 2. Coding Attention Mechanisms

## 2.1 A simple self-attention mechanism without trainable weights

### 2.1.1 Working with an example

In [None]:
import torch
# Toy 3-dimensional embeddings for the six-token sentence "Your journey starts with one step".
# Row k holds x^(k+1); row 1 ("journey") serves as the running query example.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1  "Your"
    [0.55, 0.87, 0.66],  # x^2  "journey"  <-- query
    [0.57, 0.85, 0.64],  # x^3  "starts"
    [0.22, 0.58, 0.33],  # x^4  "with"
    [0.77, 0.25, 0.10],  # x^5  "one"
    [0.05, 0.80, 0.55],  # x^6  "step"
])

### 2.1.2 Computing Dot-Product Attention Scores for Query x²

In [6]:
input_query = inputs[1]  # "journey" (x^2)

# Unnormalised attention score of every token w.r.t. the query: one dot product per row.
attn_scores_2 = torch.stack([torch.dot(token, input_query) for token in inputs])

print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


### 2.1.3 Normalizing Attention Scores into Weights for Query x²

In [7]:
# Simplest normalisation: divide every score by the total so the weights sum to 1.
total_score = attn_scores_2.sum()
attn_weights_2_tmp = attn_scores_2 / total_score
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


#### Using softmax

In [11]:
def softmax_naive(x, dim=0):
    """
    Textbook softmax: exp(x) divided by the sum of exp(x) along ``dim``.

    "Naive" because it exponentiates the raw inputs directly, so large values
    can overflow to inf (and yield nan); compare with ``torch.softmax``.

    Args:
        x (Tensor): Scores to normalise.
        dim (int): Axis along which the result sums to 1. Defaults to 0, the
            original behaviour; use ``dim=-1`` to normalise each row of a matrix.

    Returns:
        Tensor: Same shape as ``x``, non-negative, summing to 1 along ``dim``.
    """
    exp_x = torch.exp(x)  # computed once and reused for numerator and denominator
    # keepdim=True keeps the reduced axis so the division broadcasts correctly
    # for any ``dim``, not just the leading one.
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

# Hand-written softmax ...
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum(), "\n")

# ... versus PyTorch's built-in implementation (1-D input, so dim=-1 == dim=0)
attn_weights_2 = torch.softmax(attn_scores_2, dim=-1)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.) 

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


### 2.1.4 Computing the Context Vector for Query x² Using Attention Weights

In [12]:
query = inputs[1]  # "journey" (x^2)

# Context vector = attention-weighted sum of all input embeddings.
context_vec_2 = torch.zeros_like(query)
for weight, x_i in zip(attn_weights_2, inputs):
    context_vec_2 += weight * x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


### 2.1.5 Computing attention weights for all input tokens

In [15]:
# Computing pairwise dot-product attention scores between all tokens (explicit loops).
# The matrix size comes from the data rather than a hard-coded 6, so the cell
# keeps working if the example sentence changes length.
num_tokens = inputs.shape[0]
attn_scores = torch.empty(num_tokens, num_tokens)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores, "\n")

# Vectorized equivalent: a single matrix product computes every dot product at once.
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]]) 

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


### 2.1.6 Applying Softmax to Obtain Attention Weights

In [18]:
# Softmax along the last axis turns each row of scores into a distribution over tokens.
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights, "\n")

# Check row 2 ("journey") from the tensor itself rather than summing numbers
# re-typed from the printed output, which would not detect any real change.
row_2_sum = attn_weights[1].sum()
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]]) 

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


### 2.1.7 Computing Context Vectors for All Tokens Using Attention Weights

In [19]:
# Matrix form of the weighted sum: row i of the result is the context vector of token i.
all_context_vecs = torch.matmul(attn_weights, inputs)
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## 2.2 Implementing self-attention with trainable weights

### 2.2.1 Working with an example

In [22]:
x_2 = inputs[1]          # second token, "journey" (x^2), used as the query
d_in = inputs.shape[-1]  # input embedding size (3 for this example)
d_out = 2                # size of the projected query/key/value vectors

### 2.2.2 Initializing Fixed Query, Key, and Value Projection Matrices

In [None]:
# Seeded random projection matrices for the step-by-step walk-through.
# requires_grad=False keeps them "fixed" (outside autograd), which is how the
# outputs recorded below were produced — they print without a grad_fn.
# The SelfAttention classes further down use trainable weights instead.
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

### 2.2.3 Computing Query, Key, and Value Vectors for Token x²

In [24]:
# Project the single query token into query, key and value space.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))
print(query_2)

tensor([0.4306, 1.4551])


### 2.2.4 Computing Key and Value Matrices for All Tokens

In [25]:
# Project every token at once: one row per token, d_out columns each.
keys = inputs @ W_key
values = inputs @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)


keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


### 2.2.5 End-to-End Scaled Dot-Product Self-Attention for Query x² (Projected Q, K, V)

In [30]:
# Score between the query x^2 and its own key (projected space)
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

# Scores between the query x^2 and every key: one vector-matrix product
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

# Scaled softmax: divide by sqrt(d_k), then normalise over the keys
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

# Context vector: attention-weighted sum of the value vectors
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

tensor(1.8524)
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
tensor([0.3061, 0.8210])


### 2.2.6 Implementing a compact self-attention Python class

In [31]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """
    Minimal self-attention module (scaled dot-product attention).

    Queries, keys and values are obtained by multiplying the input with three
    learnable (d_in, d_out) matrices initialised with ``torch.rand``. Every
    token attends to every token: there is no causal mask, dropout or
    multi-head splitting.

    The input may be a single sequence of shape (num_tokens, d_in) or carry
    leading batch dimensions, e.g. (batch_size, num_tokens, d_in): the score
    computation only swaps the last two axes of the keys.
    """
    def __init__(self, d_in, d_out):
        """
        Create the three projection matrices.

        Args:
            d_in (int): Dimensionality of input embeddings.
            d_out (int): Dimensionality of projected query, key, and value vectors.
        """
        super().__init__()
        # Created in the order Q, K, V; under a fixed seed this order decides
        # which random values each matrix receives.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """
        Compute a context vector for every token in ``x``.

        Args:
            x (Tensor): Token embeddings of shape (..., num_tokens, d_in).

        Returns:
            Tensor: Context vectors of shape (..., num_tokens, d_out); each is an
                    attention-weighted sum of the value vectors of all tokens.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Raw scores, shape (..., num_tokens, num_tokens). transpose(-2, -1)
        # rather than .T: on a batched 3-D tensor .T would reverse *all* axes
        # (deprecated in PyTorch) and mix the batch axis into the product.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega

        # Divide by sqrt(d_k) before the softmax. The usual motivation
        # (Vaswani et al., 2017) is that unscaled dot products grow with d_k
        # and push the softmax into a saturated, small-gradient regime.
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)

        # Context vectors: weighted sums of the value vectors
        return attn_weights @ values

In [32]:
# Re-seed so the three torch.rand calls in __init__ reproduce the walk-through matrices.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [33]:
import torch
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    """
    Scaled dot-product self-attention with nn.Linear Q/K/V projections.

    Compared to v1, the projections are nn.Linear layers, which take care of
    weight initialisation and an optional bias. There is no causal mask,
    dropout or multi-head splitting.

    The input may be a single sequence of shape (num_tokens, d_in) or carry
    leading batch dimensions, e.g. (batch_size, num_tokens, d_in).
    """
    def __init__(self, d_in, d_out, qkv_bias=False):
        """
        Create the query, key and value projection layers.

        Args:
            d_in (int): Dimensionality of input embeddings.
            d_out (int): Dimensionality of projected query, key, and value vectors.
            qkv_bias (bool): Whether to include a bias term in Q, K, V projections.
        """
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """
        Compute a context vector for every token in ``x``.

        Args:
            x (Tensor): Token embeddings of shape (..., num_tokens, d_in).

        Returns:
            Tensor: Context vectors of shape (..., num_tokens, d_out), each an
                    attention-weighted combination of the value vectors.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Raw scores, shape (..., num_tokens, num_tokens). transpose(-2, -1)
        # swaps only the token/feature axes, so batched inputs work too
        # (.T on a 3-D tensor would reverse every axis).
        attn_scores = queries @ keys.transpose(-2, -1)

        # Divide by sqrt(d_k) before the softmax; the usual motivation is to
        # keep large dot products from saturating the softmax.
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)

        return attn_weights @ values

In [35]:
# Seed before construction so nn.Linear's random initialisation is reproducible.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


## 2.3 Hiding future words with causal attention

### 2.3.1 Computing Scaled Dot-Product Attention Weights Using Linear Q and K Projections

In [None]:
# Recompute scaled attention weights from sa_v2's own Q/K projections;
# these are the weights the causal mask is applied to below.
queries, keys = sa_v2.W_query(inputs), sa_v2.W_key(inputs)
attn_scores = torch.matmul(queries, keys.T)
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


### 2.3.2 Creating a Causal (Lower-Triangular) Attention Mask

In [44]:
context_length = attn_scores.shape[0]
# Ones on and below the diagonal: token i may attend to tokens 0..i only.
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


### 2.3.3 Applying the Causal Mask to Attention Weights

In [45]:
# Element-wise product zeroes every weight above the diagonal (future tokens).
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)


### 2.3.4 Renormalizing Masked Attention Weights Row-Wise

In [46]:
# Renormalise each row so the remaining (current + past) weights sum to 1 again.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm, "\n")

# Alternative: mask the raw *scores* with -inf strictly above the diagonal,
# so a softmax applied afterwards gives those positions exactly zero weight.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>) 

tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)


### 2.3.5 Applying Scaled Softmax After Causal Masking to Obtain Attention Weights

In [47]:
# Scaled softmax over each row (last axis, as elsewhere); the -inf entries become exactly 0.
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


### 2.3.6 Masking additional attention weights with dropout


In [50]:
torch.manual_seed(123)             # make the random dropout pattern reproducible
dropout = torch.nn.Dropout(p=0.5)  # each element is zeroed with probability 0.5
example = torch.ones(6, 6)         # all-ones input makes the pattern easy to see
print(dropout(example))            # survivors are rescaled by 1 / (1 - p) = 2

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


### 2.3.7 Applying Dropout to Attention Weights

In [51]:
# Re-seed so the dropout mask is reproducible; surviving weights are doubled (1 / (1 - 0.5)).
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6380, 0.6816, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5090, 0.5085, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4120, 0.0000, 0.3869, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.3413, 0.3308, 0.3249, 0.0000]],
       grad_fn=<MulBackward0>)


### 2.3.8 Implementing a compact causal attention class

In [52]:
# Batch of two identical sequences: shape (batch=2, num_tokens=6, d_in=3).
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [53]:
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    """
    Causal (masked) self-attention with scaled dot-product scores and dropout.

    Position i may attend only to positions 0..i: scores strictly above the
    diagonal are replaced with -inf before the softmax, so future tokens get
    zero weight. Dropout is then applied to the attention weights.

    Designed for batched input of shape (batch_size, num_tokens, d_in). Because
    every transpose acts on the last two axes, an unbatched (num_tokens, d_in)
    tensor is accepted as well.
    """
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        """
        Initialize the causal self-attention layer.

        Args:
            d_in (int): Dimensionality of input embeddings.
            d_out (int): Dimensionality of projected query, key, and value vectors.
            context_length (int): Maximum sequence length; size of the causal mask.
            dropout (float): Dropout probability applied to attention weights.
            qkv_bias (bool): Whether to include bias terms in Q, K, V projections.
        """
        super().__init__()
        self.d_out = d_out

        # Linear projections for queries (Q), keys (K) and values (V)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark "future" positions. As a buffer
        # it follows the module across devices and is stored in state_dict,
        # but it is not a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """
        Forward pass of the causal self-attention mechanism.

        Args:
            x (Tensor): Input of shape (batch_size, num_tokens, d_in), or
                        (num_tokens, d_in), with num_tokens <= context_length.

        Returns:
            Tensor: Context vectors of shape (batch_size, num_tokens, d_out)
                    (or (num_tokens, d_out)); each token attends only to itself
                    and earlier tokens.

        Raises:
            ValueError: If num_tokens exceeds the context_length the mask was
                        built for.
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[-1]
        if num_tokens > max_tokens:
            # Without this check the mask slice would silently be too small and
            # masked_fill_ would fail with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Raw scores, shape (..., num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)

        # Shorter sequences use the top-left num_tokens x num_tokens corner of
        # the mask, which broadcasts over any leading batch dimensions. Filling
        # in place is fine: attn_scores is a fresh tensor created just above.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal_mask, -torch.inf)

        # Scaled softmax over keys, then dropout for regularisation
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Context vectors: weighted sums of the value vectors
        return attn_weights @ values

In [55]:
torch.manual_seed(123)
context_length = batch.shape[1]  # tokens per sequence in the batch
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


## 2.4 Multi-Head Self-Attention

In [56]:
import torch
import torch.nn as nn

class MultiHeadAttentionWrapper(nn.Module):
    """
    Multi-head attention formed by running several CausalAttention heads side by side.

    Every head receives the same input and owns its own Q/K/V projections and
    causal mask; the per-head outputs are concatenated along the last
    (feature) axis.

    Note:
        This is a straightforward educational implementation. No output
        projection (W_out) follows the concatenation, so the result has
        num_heads * d_out features.
    """
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        """
        Build ``num_heads`` independent causal attention heads.

        Args:
            d_in (int): Dimensionality of input embeddings.
            d_out (int): Output dimensionality of each individual head.
            context_length (int): Maximum sequence length for the causal mask.
            dropout (float): Dropout probability used inside every head.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): Whether the Q, K, V projections include bias terms.
        """
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(
                CausalAttention(d_in, d_out, context_length, dropout,
                                qkv_bias=qkv_bias)
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """
        Run every head on ``x`` and join the results.

        Args:
            x (Tensor): Input tensor of shape (batch_size, num_tokens, d_in).

        Returns:
            Tensor: Shape (batch_size, num_tokens, num_heads * d_out), the head
                    outputs laid out one after another along the last axis.
        """
        head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)

In [57]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2  # per-head output size d_out; two heads -> 4 features
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, dropout=0.0, num_heads=2
)

context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])
