Code taken from `shortening.py` at https://github.com/PiotrNawrot/dynamic-pooling:

In [1]:
import torch


def final(foo,
          upsample):
    """
        Turn a segment-offset tensor into normalised pooling weights.

        An entry of ``foo`` equal to zero marks token ``l`` as a member of
        segment ``s`` (this is how `common` encodes membership). Member
        entries become ``1 - foo``; numerically that is 1, but it stays
        differentiable w.r.t. ``foo``. Every other entry becomes 0.
        The weights are then normalised to sum to one over the token axis
        (dim 1) when ``upsample`` is False, or over the segment axis (dim 2)
        when ``upsample`` is True.

        Input:
            foo: B x L x S
            upsample: bool, selects the normalisation axis (see above)
        Output:
            B x L x S weights. A slice with no members (e.g. an empty padding
            segment) stays all-zero rather than becoming 0 / 0, in every
            floating dtype including float16.
    """
    non_member = foo != 0
    # `1 - foo` rather than a constant 1: the kept value is 1, but gradients
    # w.r.t. `foo` (and therefore the boundaries) flow through it.
    weights = 1 - foo

    weights[non_member] = 0

    dim = 2 if upsample else 1

    # Epsilon that prevents 0 / 0 for empty slices. 1e-9 underflows to 0 in
    # float16, so there it is raised to the dtype's smallest normal number.
    # For float32 / float64 / bfloat16 (and integer inputs) it stays exactly
    # 1e-9, so results and gradients are unchanged for those dtypes.
    eps = 1e-9
    if weights.is_floating_point():
        eps = max(eps, torch.finfo(weights.dtype).tiny)

    weights = weights / (weights.sum(dim=dim, keepdim=True) + eps)

    return weights


def common(boundaries, upsample=False):
    """
        Build the segment-offset tensor shared by downsampling and upsampling.

        Each token ``l`` gets a group index ``g[b, l]`` from the cumulative
        sum of ``boundaries`` along dim 1. When ``upsample`` is True this is
        the inclusive cumsum; otherwise it is the exclusive one (the token's
        own boundary is subtracted). Entry ``[b, l, s]`` of the result is
        ``s - g[b, l]``, so it is zero exactly when token ``l`` is assigned to
        segment ``s``.

        The number of segments S is the largest per-row boundary count, plus
        one when upsampling. Returns None when S is zero.

        Input:
            boundaries: B x L
        Output:
            B x L x S tensor, or None
    """
    n_segments = boundaries.sum(dim=-1).max().item()
    if upsample:
        n_segments += 1
    if n_segments == 0:
        return None

    # Group index per token (B x L): inclusive cumsum for upsampling,
    # exclusive cumsum for downsampling.
    group_idx = boundaries.cumsum(1)
    if not upsample:
        group_idx = group_idx - boundaries

    segment_ids = torch.arange(start=0, end=n_segments, device=boundaries.device)

    # (S,) broadcast against (B, L, 1) gives (B, L, S).
    return segment_ids - group_idx.unsqueeze(-1)


def downsample(boundaries, hidden, null_group):
    """
        Average-pool token states into segments and prepend a null segment.

        - Membership comes from `common(boundaries, upsample=False)`. A 1 at
          position l ends the current group at token l, and token l + 1 opens
          the next group.
        - Each pooled segment is the mean of its member tokens, with weights
          from `final`.
        - `null_group`, repeated across the batch, is placed in front of the
          pooled segments.
        - S is the largest per-row boundary count. Tokens after the last
          boundary of the row with the most boundaries are not pooled into
          any segment. Rows with fewer boundaries may have empty segments,
          which pool to zero vectors.

        Input:
            boundaries: B x L
            hidden: L x B x D
            null_group: presumably 1 x 1 x D (repeated to 1 x B x D) - TODO confirm
        Output:
            shortened_hidden: null rows followed by S x B x D pooled segments,
            or only the null rows when the largest row sum of `boundaries` is 0
    """
    offsets = common(boundaries, upsample=False)  # B x L x S, or None
    null_rows = null_group.repeat(1, hidden.size(1), 1)

    if offsets is None:
        return null_rows

    weights = final(foo=offsets, upsample=False)  # B x L x S
    pooled = torch.einsum('lbd,bls->sbd', hidden, weights)
    return torch.cat([null_rows, pooled], dim=0)


def upsample(boundaries, shortened_hidden):
    """
        Broadcast segment states back to token positions.

        Token l reads segment ``cumsum(boundaries)[l]`` (inclusive cumsum, via
        `common(..., upsample=True)`), where S = max boundary count + 1.
        Paired with `downsample` on the same boundaries and a single null
        row, segment 0 is the null group and segment j + 1 is group j. Each
        token therefore receives the average of the most recent group that
        ended at or before it (the null group if none has). It never receives
        a group that ends later.

        Input:
            boundaries: B x L
            shortened_hidden: S x B x D
        Output:
            upsampled_hidden: L x B x D
    """
    # With upsample=True, S >= 1, so `common` never returns None here.
    weights = final(common(boundaries, upsample=True), upsample=True)  # B x L x S
    return torch.einsum('sbd,bls->lbd', shortened_hidden, weights)


The line `shortened_hidden = torch.einsum('lbd,bls->sbd', hidden, bar)` is where the boundaries predicted by the MLP and the hidden states rejoin on the computation graph. In the forward pass, multiplication by `bar` acts as an average-pooling operation, and the `boundaries` tensor is used to construct the matrix `bar`. To train the MLP, gradients must flow backward through the construction of `bar`. Here I analyse the derivative of one value of `bar` with respect to the boundaries, as implicitly defined by this construction:

In [2]:
# Example hard boundary decisions for one sequence of 10 tokens (B=1, L=10).
# A leaf tensor with requires_grad=True populates .grad on backward by itself,
# so no retain_grad() call is needed (it is a no-op on leaves).
my_hard_boundaries = torch.tensor([[0, 0, 1, 1, 0, 1, 0, 0, 0, 1]], dtype=torch.float32, requires_grad=True)
my_hard_boundaries

tensor([[0., 0., 1., 1., 0., 1., 0., 0., 0., 1.]], requires_grad=True)

In [3]:
# Offsets (foo) and pooling weights (bar) for the example boundaries.
foo = common(my_hard_boundaries, upsample=False)  # B x L x S offsets
bar = final(foo, upsample=False)  # bar[0, l, s]: weight of token l in segment s's average
(bar.shape, bar)

(torch.Size([1, 10, 4]),
 tensor([[[0.3333, 0.0000, 0.0000, 0.0000],
          [0.3333, 0.0000, 0.0000, 0.0000],
          [0.3333, 0.0000, 0.0000, 0.0000],
          [0.0000, 1.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.5000, 0.0000],
          [0.0000, 0.0000, 0.5000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.2500],
          [0.0000, 0.0000, 0.0000, 0.2500],
          [0.0000, 0.0000, 0.0000, 0.2500],
          [0.0000, 0.0000, 0.0000, 0.2500]]], grad_fn=<DivBackward0>))

In [4]:
# A single entry of the pooling matrix: token 6's weight in segment 3.
l = bar[0][6][3]
l

tensor(0.2500, grad_fn=<SelectBackward0>)

In [5]:
# Back-propagate that single entry to the boundaries and show the gradient.
torch.autograd.backward(l)
(my_hard_boundaries, my_hard_boundaries.grad)

(tensor([[0., 0., 1., 1., 0., 1., 0., 0., 0., 1.]], requires_grad=True),
 tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1875, -0.1250,
          -0.0625,  0.0000]]))

Repeating this for every row of `bar`, we can compute the full Jacobian of each row's non-zero value with respect to the boundaries:

In [6]:
# Row i of `bar` has a single non-zero entry (token i's weight in its own
# segment). Summing the row isolates that entry, because the zero entries
# have zero gradient w.r.t. the boundaries. The forward graph is built once
# and kept alive with retain_graph=True instead of being rebuilt per row.
my_hard_boundaries = torch.tensor([[0, 0, 1, 1, 0, 1, 0, 0, 0, 1]], dtype=torch.float32, requires_grad=True)
foo = common(my_hard_boundaries, upsample=False)
bar = final(foo, upsample=False)

# grad_values[i] has shape (1, 10): d bar[0, i, :].sum() / d my_hard_boundaries.
grad_values = [
    torch.autograd.grad(bar[0, i, :].sum(), my_hard_boundaries, retain_graph=True)[0]
    for i in range(bar.size(1))
]

In [7]:
# Stack the ten (1, 10) gradient rows into a 10 x 10 Jacobian:
# row i = d bar[0, i, :].sum() / d my_hard_boundaries.
jacobian = torch.vstack(grad_values)
jacobian

tensor([[-0.2222, -0.1111,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.1111, -0.1111,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.1111,  0.2222,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.2500,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.2500,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1875, -0.1250,
         -0.0625,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0625, -0.1250,
         -0.0625,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0625,  0.1250,
         -0.0625,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0625,  0.1250,
          0.1875,  0.0000]])

In [8]:
# Diagonal blocks of the Jacobian, one per segment; token ranges are read off
# the pooling matrix: [0, 3), [3, 4), [4, 6), [6, 10).
tuple(jacobian[start:stop, start:stop] for start, stop in ((0, 3), (3, 4), (4, 6), (6, 10)))

(tensor([[-0.2222, -0.1111,  0.0000],
         [ 0.1111, -0.1111,  0.0000],
         [ 0.1111,  0.2222,  0.0000]]),
 tensor([[0.]]),
 tensor([[-0.2500,  0.0000],
         [ 0.2500,  0.0000]]),
 tensor([[-0.1875, -0.1250, -0.0625,  0.0000],
         [ 0.0625, -0.1250, -0.0625,  0.0000],
         [ 0.0625,  0.1250, -0.0625,  0.0000],
         [ 0.0625,  0.1250,  0.1875,  0.0000]]))

In [9]:
# Fresh leaf tensor for the -2 scaling experiment below. As a leaf with
# requires_grad=True it accumulates .grad on backward without retain_grad()
# (which is a no-op on leaves).
my_hard_boundaries = torch.tensor([[0, 0, 1, 1, 0, 1, 0, 0, 0, 1]], dtype=torch.float32, requires_grad=True)
my_hard_boundaries

tensor([[0., 0., 1., 1., 0., 1., 0., 0., 0., 1.]], requires_grad=True)

In [10]:
# Check that multiplying foo by -2 does not affect the forward pass.
# final() keeps only the entries where foo == 0, and -2 * foo is zero in
# exactly the same places, where 1 - foo evaluates to 1 either way. Only the
# gradient through 1 - foo is scaled by -2.

foo_origonal = common(my_hard_boundaries, upsample=False)
bar_origonal = final(foo_origonal, upsample=False)

foo_modified = -2. * common(my_hard_boundaries, upsample=False)
bar_modified = final(foo_modified, upsample=False)

bar_modified


tensor([[[0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.5000, 0.0000],
         [0.0000, 0.0000, 0.5000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.2500],
         [0.0000, 0.0000, 0.0000, 0.2500],
         [0.0000, 0.0000, 0.0000, 0.2500],
         [0.0000, 0.0000, 0.0000, 0.2500]]], grad_fn=<DivBackward0>)

In [11]:
# Expect True: the pooling matrix is the same with and without the -2 scaling.
torch.allclose(bar_modified, bar_origonal)

True

In [12]:
# Same Jacobian as before, but computed through the -2 * foo variant.
# Summing row i isolates its single non-zero entry, because the zero entries
# have zero gradient w.r.t. the boundaries. The forward graph is built once
# and kept alive with retain_graph=True instead of being rebuilt per row.
my_hard_boundaries = torch.tensor([[0, 0, 1, 1, 0, 1, 0, 0, 0, 1]], dtype=torch.float32, requires_grad=True)
foo_modified = -2. * common(my_hard_boundaries, upsample=False)
bar_modified = final(foo_modified, upsample=False)

# grad_values_modified[i] has shape (1, 10).
grad_values_modified = [
    torch.autograd.grad(bar_modified[0, i, :].sum(), my_hard_boundaries, retain_graph=True)[0]
    for i in range(bar_modified.size(1))
]

In [13]:
# Stack the ten (1, 10) gradient rows of the -2 variant into a 10 x 10 Jacobian.
jacobian_modified = torch.vstack(grad_values_modified)
jacobian_modified

tensor([[ 0.4444,  0.2222,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.2222,  0.2222,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.2222, -0.4444,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.5000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.5000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.3750,  0.2500,
          0.1250,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1250,  0.2500,
          0.1250,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1250, -0.2500,
          0.1250,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1250, -0.2500,
         -0.3750,  0.0000]])

In [14]:
# Per-segment diagonal blocks of the -2 variant's Jacobian, using the same
# token ranges as before: [0, 3), [3, 4), [4, 6), [6, 10).
tuple(jacobian_modified[start:stop, start:stop] for start, stop in ((0, 3), (3, 4), (4, 6), (6, 10)))

(tensor([[ 0.4444,  0.2222,  0.0000],
         [-0.2222,  0.2222,  0.0000],
         [-0.2222, -0.4444,  0.0000]]),
 tensor([[0.]]),
 tensor([[ 0.5000,  0.0000],
         [-0.5000,  0.0000]]),
 tensor([[ 0.3750,  0.2500,  0.1250,  0.0000],
         [-0.1250,  0.2500,  0.1250,  0.0000],
         [-0.1250, -0.2500,  0.1250,  0.0000],
         [-0.1250, -0.2500, -0.3750,  0.0000]]))