In [7]:
import torch
from torch import nn

In [8]:
import nawrot_downsampler

The code from Nawrot et al. (2023) builds the average-downsampling matrix directly from the gate values, in a way that lets gradients flow back through that construction, so we can't substitute a standard average downsampler.

In [33]:
class NawrotDownsampler(nn.Module):
    """Batch-first wrapper around the ``nawrot_downsampler`` helpers.

    Every public method takes and returns batch-first tensors
    (``[bs, seq_len, ...]``) and transposes to the sequence-first layout
    (``[seq_len, bs, ...]``) right before calling into ``nawrot_downsampler``.

    Args:
        embedding_dim: Size of the feature dimension of the inputs. It is also
            passed as the second size argument of the boundary predictor.
        downsample_rate: Forwarded to the boundary predictor as its ``prior``.
    """

    def __init__(self, embedding_dim: int, downsample_rate: float):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.downsample_rate = downsample_rate
        self.boundary_predictor = nawrot_downsampler.BoundaryPredictor(
            embedding_dim,
            embedding_dim,
            "relu",
            temp=1.0,
            prior=self.downsample_rate,
            bp_type="gumbel",
        )

        # Learned vector handed to ``nawrot_downsampler.downsample`` as its
        # null group; allocated as zeros and then drawn from N(0, 1).
        self.null_group = nn.Parameter(torch.zeros(1, 1, embedding_dim))
        nn.init.normal_(self.null_group)

    def compute_boundaries(self, x: torch.Tensor) -> torch.Tensor:
        """Predicts hard segment boundaries for ``x``.

        Args:
            x: Batch-first input of shape ``[bs, seq_len, emb_dim]``.

        Returns:
            The second element returned by the boundary predictor (the hard
            boundaries); the first element is discarded.
        """
        # The predictor expects [seq_len, bs, emb_dim].
        seq_first = x.transpose(0, 1)

        _, hard_boundaries = self.boundary_predictor(seq_first)

        return hard_boundaries

    def downsample(self, x: torch.Tensor, hard_boundaries: torch.Tensor) -> torch.Tensor:
        """Pools ``x`` into groups delimited by ``hard_boundaries``.

        Args:
            x: Batch-first input of shape ``[bs, seq_len, emb_dim]``.
            hard_boundaries: Boundaries as returned by ``compute_boundaries``.

        Returns:
            The pooled tensor, batch-first again.
        """
        # nawrot_downsampler works sequence-first: [seq_len, bs, emb_dim].
        seq_first = x.transpose(0, 1)

        pooled = nawrot_downsampler.downsample(
            hard_boundaries,
            seq_first,
            self.null_group,
        )

        # Back to batch-first for the caller.
        return pooled.transpose(0, 1)

    def downsample_position_ids(self, position_ids: torch.Tensor, hard_boundaries: torch.Tensor) -> torch.Tensor:
        """Pools position ids with the same boundaries used for the hidden states.

        A zero (non-learned) null group is used instead of ``self.null_group``.

        Args:
            position_ids: Either ``[bs, seq_len]`` or ``[bs, seq_len, d]``
                (e.g. ``d == 1``). Integer ids are converted to floating point
                first, because pooled positions are generally fractional.
            hard_boundaries: Boundaries as returned by ``compute_boundaries``.

        Returns:
            The pooled position ids, batch-first, with the same number of
            dimensions as the ``position_ids`` that was passed in.
        """
        # Only add the trailing feature axis when the caller did not supply
        # one; unconditionally unsqueezing would hand a 4-D tensor to the
        # helper when ids already have shape [bs, seq_len, 1].
        added_feature_axis = position_ids.dim() == 2
        if added_feature_axis:
            position_ids = position_ids.unsqueeze(-1)

        # NOTE(review): assumes the helper needs floating-point inputs whose
        # dtype is compatible with the boundaries - confirm against
        # nawrot_downsampler.downsample.
        if not position_ids.is_floating_point():
            float_dtype = (
                hard_boundaries.dtype
                if hard_boundaries.is_floating_point()
                else torch.get_default_dtype()
            )
            position_ids = position_ids.to(float_dtype)

        # nawrot_downsampler works sequence-first: [seq_len, bs, d].
        position_ids = position_ids.transpose(0, 1)

        # Zero null group created directly with the ids' dtype and device
        # (no CPU allocation followed by a copy, no dtype mismatch).
        zero_null_group = torch.zeros(
            1,
            1,
            position_ids.size(-1),
            dtype=position_ids.dtype,
            device=position_ids.device,
        )

        position_ids = nawrot_downsampler.downsample(
            hard_boundaries,
            position_ids,
            zero_null_group,
        )

        # Back to batch-first, and drop the feature axis only if we added it.
        position_ids = position_ids.transpose(0, 1)
        if added_feature_axis:
            position_ids = position_ids.squeeze(-1)
        return position_ids

    def upsample(self, x: torch.Tensor, hard_boundaries: torch.Tensor) -> torch.Tensor:
        """Expands a pooled tensor back along the sequence axis.

        Args:
            x: Batch-first pooled tensor, as returned by ``downsample``.
            hard_boundaries: The boundaries that were used for pooling.

        Returns:
            The upsampled tensor, batch-first again.
        """
        # nawrot_downsampler works sequence-first.
        seq_first = x.transpose(0, 1)

        expanded = nawrot_downsampler.upsample(
            hard_boundaries,
            seq_first,
        )

        # Back to batch-first for the caller.
        return expanded.transpose(0, 1)

    def consistency_loss(self, hard_boundaries: torch.Tensor) -> torch.Tensor:
        """Returns the boundary predictor's ``calc_loss`` for ``hard_boundaries``.

        The hard boundaries are passed as ``preds`` and no ground truth is
        supplied (``gt=None``).
        """
        return self.boundary_predictor.calc_loss(
            preds=hard_boundaries, gt=None
        )

In [34]:
# Model width; same value as the embedding size of the demo tensors in this notebook.
d_model: int = 512

In [35]:
# Round trip on random data: predict boundaries, pool, then un-pool.
# The embedding size comes from `d_model` instead of a repeated literal.
x = torch.randn(2, 10, d_model)
downsampler = NawrotDownsampler(d_model, 0.25)

hard_boundaries = downsampler.compute_boundaries(x)

x_downsampled = downsampler.downsample(x, hard_boundaries)

x_upsampled = downsampler.upsample(x_downsampled, hard_boundaries)

# Upsampling should restore the original [bs, seq_len, emb_dim] shape.
assert x_upsampled.shape == x.shape

x_downsampled.shape, x_upsampled.shape


(torch.Size([2, 7, 512]), torch.Size([2, 10, 512]))

In [36]:
# Display the module tree of the downsampler built above.
downsampler

NawrotDownsampler(
  (boundary_predictor): BoundaryPredictor(
    (boundary_predictor): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=512, out_features=1, bias=True)
    )
    (loss): BCEWithLogitsLoss()
  )
)

In [37]:
# Position ids 0..seq_len-1 for every sequence, shaped [bs, seq_len, 1].
# Sizes, dtype and device are all taken from `x` so the ids cannot drift out
# of sync with it, and the tensor is created directly on x's device.
batch_size, seq_len = x.shape[:2]
position_ids = (
    torch.arange(seq_len, dtype=x.dtype, device=x.device)
    .view(1, seq_len, 1)
    .expand(batch_size, -1, -1)
)
position_ids.shape

torch.Size([2, 10, 1])

In [41]:
# Pool the position ids with the same hard boundaries that were computed for `x`.
downsampler.downsample_position_ids(position_ids, hard_boundaries)

tensor([[[0.0000],
         [0.0000],
         [1.5000],
         [3.5000],
         [5.0000],
         [6.5000],
         [8.0000]],

        [[0.0000],
         [0.5000],
         [2.5000],
         [4.0000],
         [5.0000],
         [6.5000],
         [8.5000]]], grad_fn=<TransposeBackward0>)

In [40]:
# Display the hard boundaries returned by compute_boundaries for `x`.
hard_boundaries

tensor([[1., 0., 1., 0., 1., 1., 0., 1., 1., 0.],
        [0., 1., 0., 1., 1., 1., 0., 1., 0., 1.]], grad_fn=<AddBackward0>)

In [42]:
# Loss from the boundary predictor's calc_loss for these hard boundaries (no ground truth).
downsampler.consistency_loss(hard_boundaries)

tensor(0.4121, grad_fn=<DivBackward0>)