# Patch pooling (10pts)
Your task is to implement patch pooling. Patch pooling takes an input sequence $(a_0, a_1, \ldots, a_{S-1})$ of $D$-dimensional embeddings and produces an output sequence $(b_0, \ldots, b_{k-1})$ of $D$-dimensional embeddings. The output sequence is never longer than the input sequence, and its length $k$ is bounded by $P$. Each element of the input sequence is called a token. Each element of the output sequence is called a patch. Consecutive patches are constructed by mean pooling consecutive contiguous spans of tokens.

You are given two tensors:
1. `batch` - a $3$-dimensional tensor, which is an input to a standard transformer model with the following dimensions:
* B - batch size
* S - sequence length
* D - dimension of embedding of a single token

`batch[x,y,:]` is the embedding of the $y+1$-th token of the $x+1$-th sequence in the `batch`.

2. `patch_lengths` - $2$-dimensional integer-valued tensor with the following dimensions:
* B - batch size
* P - maximal number of patches

`patch_lengths[x,y]` is the number of tokens forming patch number $y+1$ in the $x+1$-th sequence in the `batch`.

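The output should be a $3$-dimensional tensor of shape $(B, P, D)$ containing a batch of sequences of patch embeddings. Positions corresponding to patches of length $0$ are filled with the padding embedding (see the example below).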

# Example
The following snippet
```python
batch = torch.tensor([[[ 1.,  1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.],
         [ 2.,  2.,  2.,  2.,  2.],
         [ 3.,  3.,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  3.,  3.]],

        [[ 4.,  4.,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  4.,  4.],
         [ 5.,  5.,  5.,  5.,  5.],
         [-1., -1., -1., -1., -1.]],

        [[ 6.,  6.,  6.,  6.,  6.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.]]])
patch_lengths = torch.tensor([[3, 1, 2],
        [4, 1, 0],
        [1, 0, 0]])
patch_pooling = PatchPooling()
output = patch_pooling(batch, patch_lengths)
output
```

should output

```python
torch.tensor([[[1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]],

        [[4., 4., 4., 4., 4.],
         [5., 5., 5., 5., 5.],
         [-1., -1., -1., -1., -1.]],

        [[6., 6., 6., 6., 6.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.]]])
```

Remarks:

1. In this problem you can assume that embeddings of the padding token are vectors with all coordinates equal to $-1$.

2. Solutions will be graded with unit tests. You are given a single test case, which will be a part of evaluation.

3. Solutions not satisfying the below requirements will be graded up to 4pts:
* You are not allowed to call custom python functions
* You are not allowed to use Python loops
* You are not allowed to use any other imports

In [1]:
%%file test_patch_pooling.py

import pytest
import torch

class PatchPooling(torch.nn.Module):
    """Mean-pool contiguous token spans ("patches") of a padded batch."""

    def forward(self, batch: torch.Tensor, patch_lengths: torch.Tensor) -> torch.Tensor:
        """Average consecutive token spans of every sequence into patches.

        Args:
            batch: Tensor of shape ``(B, S, D)`` holding token embeddings.
            patch_lengths: Integer tensor of shape ``(B, P)``;
                ``patch_lengths[x, y]`` is the number of consecutive tokens
                forming patch ``y`` of sequence ``x``. Patches are laid out
                back to back starting at token 0.

        Returns:
            Tensor of shape ``(B, P, D)``. Entry ``[x, y]`` is the mean of
            the tokens of patch ``y``; patches whose length is not positive
            are filled with the padding embedding (all coordinates ``-1``).

        Note:
            Token indices are clamped to ``S - 1``, so a row whose patch
            lengths sum to more than ``S`` re-reads the last token instead
            of raising.
        """
        B, S, D = batch.shape
        B_1, P = patch_lengths.shape

        assert B == B_1

        ### Your code goes here ###

        # Every slot starts out as the padding embedding.
        output = torch.full((B, P, D), -1.0, dtype=batch.dtype, device=batch.device)

        # Nothing to pool; this also keeps max() below off an empty tensor.
        if patch_lengths.numel() == 0:
            return output

        # Exclusive prefix sum: index of the first token of every patch.
        start_indices = torch.cumsum(patch_lengths, dim=1) - patch_lengths  # (B, P)

        # Longest patch determines how many tokens are gathered per patch.
        max_patch_len = int(patch_lengths.max().item())
        offset = torch.arange(max_patch_len, device=batch.device).view(1, 1, -1)  # (1, 1, L)

        # Absolute token index of every (patch, offset) pair. Offsets past the
        # end of a patch may run out of range, so clamp them; those positions
        # are discarded through token_mask below.
        token_indices = (start_indices.unsqueeze(2) + offset).clamp(0, S - 1)  # (B, P, L)
        token_mask = offset < patch_lengths.unsqueeze(2)  # (B, P, L)

        # The row selector broadcasts against token_indices.
        batch_indices = torch.arange(B, device=batch.device).view(B, 1, 1)
        gathered = batch[batch_indices, token_indices]  # (B, P, L, D)

        # Replace (rather than multiply by zero) the tokens outside a patch:
        # inf * 0 and nan * 0 are nan, which would let a non-finite value in
        # a neighbouring token corrupt this patch.
        masked_gathered = gathered.masked_fill(~token_mask.unsqueeze(3), 0)
        patch_sums = masked_gathered.sum(dim=2)  # (B, P, D)

        # Clamp to 1 to avoid dividing by zero for empty patches, and keep the
        # divisor in the dtype of the sums so the result is not promoted
        # (e.g. half-precision input stays half precision).
        divisor = patch_lengths.clamp(min=1).unsqueeze(2).to(patch_sums.dtype)  # (B, P, 1)
        patch_means = patch_sums / divisor  # (B, P, D)

        # Keep the padding embedding wherever the patch is empty.
        output = torch.where((patch_lengths > 0).unsqueeze(2), patch_means, output)

        ###########################

        return output




class TestPatchPooling:
    """Unit tests for ``PatchPooling.forward``."""

    @pytest.mark.parametrize(
        "batch,patch_lengths,expected_output",
        [
            (
                # Three sequences of six 5-dimensional tokens; rows of -1 are padding.
                torch.tensor(
                    [
                        [
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [2.0, 2.0, 2.0, 2.0, 2.0],
                            [3.0, 3.0, 3.0, 3.0, 3.0],
                            [3.0, 3.0, 3.0, 3.0, 3.0],
                        ],
                        [
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [5.0, 5.0, 5.0, 5.0, 5.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                        [
                            [6.0, 6.0, 6.0, 6.0, 6.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                    ]
                ),
                # Number of tokens in each of the (at most three) patches per sequence.
                torch.tensor([[3, 1, 2], [4, 1, 0], [1, 0, 0]]),
                # One mean embedding per patch; zero-length patches stay at -1.
                torch.tensor(
                    [
                        [
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [2.0, 2.0, 2.0, 2.0, 2.0],
                            [3.0, 3.0, 3.0, 3.0, 3.0],
                        ],
                        [
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [5.0, 5.0, 5.0, 5.0, 5.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                        [
                            [6.0, 6.0, 6.0, 6.0, 6.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                    ]
                ),
            )
        ],
    )
    def test_forward(
        self,
        batch: torch.Tensor,
        patch_lengths: torch.Tensor,
        expected_output: torch.Tensor,
    ) -> None:
        """The pooled output is element-wise close to the expected patches."""
        # given
        module = PatchPooling()

        # when
        actual = module(batch=batch, patch_lengths=patch_lengths)

        # then
        elementwise_close = torch.isclose(actual, expected_output)
        assert elementwise_close.all()

Writing test_patch_pooling.py


In [2]:
!python -m pytest test_patch_pooling.py

platform linux -- Python 3.12.12, pytest-8.4.2, pluggy-1.6.0
rootdir: /content
plugins: langsmith-0.6.8, typeguard-4.4.4, anyio-4.12.1
collected 1 item                                                               [0m

test_patch_pooling.py [32m.[0m[32m                                                  [100%][0m

