In [1]:
import torch
from torch import nn

In [2]:
class BoundaryPredictor(nn.Module):
    """Predicts per-token segment boundaries from hidden states.

    A two-layer MLP scores every position. Depending on ``bp_type`` the score
    is turned into soft and hard boundaries:

    * ``'gumbel'``: sample a relaxed Bernoulli (temperature ``temp``) and
      threshold it; a straight-through estimator lets the hard boundaries
      carry the gradient of the soft sample.
    * ``'entropy'`` / ``'unigram'``: use the sigmoid probabilities directly
      and threshold them (no gradient flows through the hard boundaries).

    Args:
        d_model: size of the last dimension of ``hidden``.
        d_inner: hidden width of the scoring MLP.
        activation_function: ``'relu'`` or ``'gelu'``.
        temp: relaxed-Bernoulli temperature (used by ``'gumbel'`` only).
        prior: success probability of the binomial prior used by
            ``calc_loss`` when ``bp_type == 'gumbel'``.
        bp_type: one of ``'gumbel'``, ``'entropy'``, ``'unigram'``.
        threshold: cut-off that turns soft boundaries into hard 0/1 ones.

    Raises:
        ValueError: if ``activation_function`` is not supported.
    """

    def __init__(self, d_model, d_inner, activation_function,
                 temp, prior, bp_type, threshold=0.5):
        super().__init__()
        self.temp = temp
        self.prior = prior
        self.bp_type = bp_type
        self.threshold = threshold

        if activation_function == 'relu':
            activation_fn = nn.ReLU(inplace=True)
        elif activation_function == 'gelu':
            activation_fn = torch.nn.GELU()
        else:
            # Fail fast: otherwise `activation_fn` is unbound below and the
            # error surfaces as a confusing UnboundLocalError.
            raise ValueError(
                f"Unsupported activation_function {activation_function!r}; "
                "expected 'relu' or 'gelu'"
            )

        self.boundary_predictor = nn.Sequential(
            nn.Linear(d_model, d_inner),
            activation_fn,
            nn.Linear(d_inner, 1),
        )

        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, hidden):
        """Score every position and return ``(soft_boundaries, hard_boundaries)``.

        Args:
            hidden: ``seq_len x bs x d_model`` tensor.

        Returns:
            Two ``bs x seq_len`` tensors. For ``'gumbel'`` the hard boundaries
            are 0/1 in value but differentiable via the straight-through trick.

        Raises:
            ValueError: if ``self.bp_type`` is not supported.
        """
        # [seq_len x bs x 1] -> [bs x seq_len]
        boundary_logits = self.boundary_predictor(hidden).squeeze(-1).transpose(0, 1)
        boundary_probs = torch.sigmoid(boundary_logits)

        if self.bp_type == 'gumbel':
            bernoulli = torch.distributions.relaxed_bernoulli.RelaxedBernoulli(
                temperature=self.temp,
                probs=boundary_probs,
            )

            soft_boundaries = bernoulli.rsample()

            hard_boundaries = (soft_boundaries > self.threshold).float()
            # Straight-through: forward value is the hard 0/1 mask, backward
            # uses the gradient of the soft sample.
            hard_boundaries = (
                hard_boundaries - soft_boundaries.detach() + soft_boundaries
            )
        elif self.bp_type in ['entropy', 'unigram']:
            soft_boundaries = boundary_probs
            hard_boundaries = (soft_boundaries > self.threshold).float()
        else:
            raise ValueError(f"Unsupported bp_type {self.bp_type!r}")

        return soft_boundaries, hard_boundaries

    def calc_loss(self, preds, gt):
        """Boundary loss for a ``B x T`` batch.

        * ``'entropy'`` / ``'unigram'``: ``BCEWithLogitsLoss`` between ``preds``
          and the 0/1 targets ``gt``. NOTE(review): this loss expects logits,
          whereas ``forward`` returns sigmoid probabilities as
          ``soft_boundaries`` — confirm what callers pass as ``preds``.
        * ``'gumbel'``: negative binomial log-likelihood (success probability
          ``self.prior``, ``T`` trials) of the number of boundaries in each
          row, averaged over the batch and divided by ``T``. ``gt`` must be
          ``None``.

        Raises:
            ValueError: if ``self.bp_type`` is not supported.
        """
        if self.bp_type in ['entropy', 'unigram']:
            assert preds is not None and gt is not None
            return self.loss(preds, gt.float())
        if self.bp_type in ['gumbel']:
            assert gt is None
            seq_len = preds.size(-1)
            binomial = torch.distributions.binomial.Binomial(
                seq_len,
                probs=torch.tensor([float(self.prior)], device=preds.device),
            )
            return -binomial.log_prob(preds.sum(dim=-1)).mean() / seq_len
        # Previously this silently returned None.
        raise ValueError(f"Unsupported bp_type {self.bp_type!r}")

    def calc_stats(self, preds, gt):
        """Accuracy, precision and recall of boundary predictions.

        Args:
            preds: ``B x T`` predictions; non-zero means "boundary".
            gt: ``B x T`` targets; non-zero means "boundary".

        Returns:
            dict with keys ``'acc'``, ``'precision'``, ``'recall'``.
            Precision and recall are 0 when there are no true positives.
        """
        preds, gt = preds.bool(), gt.bool()
        TP = (preds & gt).sum().item()
        FP = (preds & ~gt).sum().item()
        FN = (~preds & gt).sum().item()

        acc = (preds == gt).sum().item() / gt.numel()

        if TP == 0:
            precision, recall = 0, 0
        else:
            precision = TP / (TP + FP)
            recall = TP / (TP + FN)

        stats = {
            'acc': acc,
            'precision': precision,
            'recall': recall
        }

        return stats

In [3]:
class GumbelBoundaryPredictor(nn.Module):
    """Boundary predictor that always samples boundaries with a relaxed
    Bernoulli and a straight-through hard threshold.

    This is the ``'gumbel'`` path of a general boundary predictor on its own:
    ``bp_type`` is stored for interface compatibility but never consulted,
    and ``self.loss`` (BCE-with-logits) is created but not used by
    ``calc_loss``.

    Args:
        d_model: size of the last dimension of ``hidden``.
        d_inner: hidden width of the scoring MLP.
        activation_function: ``'relu'`` or ``'gelu'``.
        temp: relaxed-Bernoulli temperature.
        prior: success probability of the binomial prior used in ``calc_loss``.
        bp_type: stored on the module; unused by this class.
        threshold: cut-off that turns the relaxed sample into hard 0/1 boundaries.

    Raises:
        ValueError: if ``activation_function`` is not supported.
    """

    def __init__(self, d_model, d_inner, activation_function,
                 temp, prior, bp_type, threshold=0.5):
        super().__init__()
        self.temp = temp
        self.prior = prior
        self.bp_type = bp_type
        self.threshold = threshold

        if activation_function == 'relu':
            activation_fn = nn.ReLU(inplace=True)
        elif activation_function == 'gelu':
            activation_fn = torch.nn.GELU()
        else:
            # Fail fast: otherwise `activation_fn` is unbound below and the
            # error surfaces as a confusing UnboundLocalError.
            raise ValueError(
                f"Unsupported activation_function {activation_function!r}; "
                "expected 'relu' or 'gelu'"
            )

        self.boundary_predictor = nn.Sequential(
            nn.Linear(d_model, d_inner),
            activation_fn,
            nn.Linear(d_inner, 1),
        )

        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, hidden):
        """Sample boundaries for every position.

        Args:
            hidden: ``seq_len x bs x d_model`` tensor.

        Returns:
            ``(soft_boundaries, hard_boundaries)``, both ``bs x seq_len``.
            The hard boundaries are 0/1 in value but carry the gradient of
            the soft sample (straight-through estimator).
        """
        # [seq_len x bs x 1] -> [bs x seq_len]
        boundary_logits = self.boundary_predictor(hidden).squeeze(-1).transpose(0, 1)
        boundary_probs = torch.sigmoid(boundary_logits)

        bernoulli = torch.distributions.relaxed_bernoulli.RelaxedBernoulli(
            temperature=self.temp,
            probs=boundary_probs,
        )

        soft_boundaries = bernoulli.rsample()

        hard_boundaries = (soft_boundaries > self.threshold).float()
        # Straight-through: forward value is the hard mask, backward uses
        # the gradient of the soft sample.
        hard_boundaries = (
            hard_boundaries - soft_boundaries.detach() + soft_boundaries
        )

        return soft_boundaries, hard_boundaries

    def calc_loss(self, preds, gt):
        """Binomial regulariser on the number of boundaries per row.

        Args:
            preds: ``B x T`` (hard) boundaries.
            gt: ignored.

        Returns:
            Negative binomial log-likelihood (``T`` trials, success
            probability ``self.prior``) of each row's boundary count,
            averaged over the batch and divided by ``T``.
        """
        seq_len = preds.size(-1)
        binomial = torch.distributions.binomial.Binomial(
            seq_len,
            probs=torch.tensor([float(self.prior)], device=preds.device),
        )
        return -binomial.log_prob(preds.sum(dim=-1)).mean() / seq_len

    def calc_stats(self, preds, gt):
        """Accuracy, precision and recall of boundary predictions.

        Args:
            preds: ``B x T`` predictions; non-zero means "boundary".
            gt: ``B x T`` targets; non-zero means "boundary".

        Returns:
            dict with keys ``'acc'``, ``'precision'``, ``'recall'``.
            Precision and recall are 0 when there are no true positives.
        """
        preds, gt = preds.bool(), gt.bool()
        TP = (preds & gt).sum().item()
        FP = (preds & ~gt).sum().item()
        FN = (~preds & gt).sum().item()

        acc = (preds == gt).sum().item() / gt.numel()

        if TP == 0:
            precision, recall = 0, 0
        else:
            precision = TP / (TP + FP)
            recall = TP / (TP + FN)

        stats = {
            'acc': acc,
            'precision': precision,
            'recall': recall
        }

        return stats


In [4]:
import torch


def final(foo,
          upsample):
    """Convert the segment offsets produced by ``common`` into mixing weights.

    Args:
        foo: ``B x L x S`` offsets; an entry equal to 0 marks token ``l`` as
            assigned to segment ``s``.
        upsample: if True, normalise over the segment axis (dim 2); otherwise
            normalise over the token axis (dim 1), which averages the tokens
            of each segment.

    Returns:
        ``B x L x S`` weights: assigned entries share their normaliser, every
        other entry is 0. The ``1e-9`` keeps empty rows/columns at 0 rather
        than NaN.
    """
    # Keep ``1 - foo`` (which equals 1 where foo == 0) rather than a constant
    # so that gradients still reach ``foo`` through the kept entries.
    keep = foo == 0
    weights = torch.where(keep, 1 - foo, torch.zeros_like(foo))

    norm_dim = 2 if upsample else 1
    denominator = weights.sum(dim=norm_dim, keepdim=True) + 1e-9
    return weights / denominator


def common(boundaries, upsample=False):
    """Signed offsets between every candidate segment index and each token's segment.

    Args:
        boundaries: ``B x L`` tensor of 0/1 boundary indicators.
        upsample: if False (downsampling), a token's segment is the number of
            boundaries strictly before it; if True, the count includes the
            token itself and one extra segment slot is added.

    Returns:
        ``B x L x S`` tensor with ``out[b, l, s] = s - segment(b, l)``, where
        ``S`` is the largest per-row boundary count (plus one when
        upsampling). Returns ``None`` when ``S`` is 0, i.e. when downsampling
        a batch that contains no boundaries at all.
    """
    boundaries = boundaries.clone()

    num_segments = boundaries.sum(dim=-1).max().item()
    num_segments = num_segments + 1 if upsample else num_segments
    if num_segments == 0:
        return None

    # Running boundary count: inclusive when upsampling, exclusive otherwise.
    segment_of_token = boundaries.cumsum(1)
    if not upsample:
        segment_of_token = segment_of_token - boundaries

    segment_ids = torch.zeros_like(boundaries).unsqueeze(2) + torch.arange(
        start=0, end=num_segments, device=boundaries.device
    )
    return segment_ids - segment_of_token.unsqueeze(-1)


def downsample(boundaries, hidden, null_group):
    """Mean-pool tokens into segments and prepend a learned "null" group.

    Conventions (from the original design):
        - the first element of ``boundaries`` is assumed to be 0 and is irrelevant;
        - a 1 marks a segment boundary;
        - a "null" group is prepended at index 0;
        - tokens after the last boundary form a trailing group that is dropped,
          since upsampling never reads it.

    Args:
        boundaries: ``B x L`` boundary indicators.
        hidden: ``L x B x D`` token states.
        null_group: tensor repeated along dim 1 to batch size; presumably
            ``1 x 1 x D`` — TODO confirm against the caller.

    Returns:
        ``(S + 1) x B x D`` pooled states (null group first). When the batch
        has no boundaries at all, only the null group is returned.
    """
    null_rows = null_group.repeat(1, hidden.size(1), 1)

    offsets = common(boundaries, upsample=False)  # B x L x S
    if offsets is None:
        return null_rows

    weights = final(foo=offsets, upsample=False)  # B x L x S
    pooled = torch.einsum('lbd,bls->sbd', hidden, weights)
    return torch.cat([null_rows, pooled], dim=0)


def upsample(boundaries, shortened_hidden):
    """Broadcast segment states back to token positions.

    Conventions (from the original design):
        - the first element of ``boundaries`` is assumed to be 0 and is irrelevant;
        - a 1 marks a segment boundary;
        - a token reads the shortened state indexed by the inclusive running
          count of boundaries up to that token, so index 0 (the null group)
          feeds the tokens before the first boundary. The intent is that the
          i-th group only reaches tokens of the (i+1)-th group, to avoid
          leaking future information.

    Args:
        boundaries: ``B x L`` boundary indicators.
        shortened_hidden: ``S x B x D`` segment states (null group first).

    Returns:
        ``L x B x D`` upsampled states.
    """
    weights = final(common(boundaries, upsample=True), upsample=True)  # B x L x S
    return torch.einsum('sbd,bls->lbd', shortened_hidden, weights)


In [132]:
# Tiny toy sizes so every tensor printed below stays readable.
# Seed first: the predictor's weights and its Gumbel samples are random,
# so without a seed the outputs change on every Restart & Run All.
torch.manual_seed(0)
d_model = 3
d_inner = 3
my_bp = GumbelBoundaryPredictor(
    d_model=d_model,
    d_inner=d_inner,
    activation_function='gelu',
    temp=1.0,
    prior=0.5,
    bp_type='gumbel',
    threshold=0.5,
)
my_bp

GumbelBoundaryPredictor(
  (boundary_predictor): Sequential(
    (0): Linear(in_features=3, out_features=3, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=3, out_features=1, bias=True)
  )
  (loss): BCEWithLogitsLoss()
)

In [166]:
# NOTE(review): disabled cell — the output below comes from an earlier run and is
# not reproduced on Restart & Run All. Uncomment to sample boundaries from `my_bp`
# for a random [seq_len=10 x bs=1 x d_model] input; the predictor returns [bs x seq_len].
# my_hidden_states = torch.randn(10, 1, d_model)
# soft_boundaries, hard_boundaries = my_bp(my_hidden_states)
# hard_boundaries.shape, soft_boundaries.shape

(torch.Size([1, 10]), torch.Size([1, 10]))

In [167]:
# NOTE(review): disabled cell — it needs `hard_boundaries` from the commented-out
# sampling cell, so the output below is stale. `retain_grad()` keeps `.grad` on this
# non-leaf (straight-through) tensor after a backward pass.
# hard_boundaries.retain_grad()
# hard_boundaries


tensor([[1., 1., 1., 0., 0., 1., 1., 1., 0., 1.]], grad_fn=<AddBackward0>)

In [281]:
# Hand-crafted boundary pattern (batch of 1, 10 tokens) used as a
# differentiable input for the pooling experiments below.
my_hard_boundaries = torch.tensor(
    [[0, 0, 1, 1, 0, 1, 0, 0, 0, 1]],
    dtype=torch.float32,
).requires_grad_(True)
my_hard_boundaries.retain_grad()
my_hard_boundaries

tensor([[0., 0., 1., 1., 0., 1., 0., 0., 0., 1.]], requires_grad=True)

In [282]:
# Downsampling weights for the hand-crafted boundaries.
foo = common(my_hard_boundaries, upsample=False)  # B x L x S signed offsets
bar = final(foo, upsample=False)                  # B x L x S pooling weights
(bar.shape, bar)

(torch.Size([1, 10, 4]),
 tensor([[[0.3333, 0.0000, 0.0000, 0.0000],
          [0.3333, 0.0000, 0.0000, 0.0000],
          [0.3333, 0.0000, 0.0000, 0.0000],
          [0.0000, 1.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.5000, 0.0000],
          [0.0000, 0.0000, 0.5000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.2500],
          [0.0000, 0.0000, 0.0000, 0.2500],
          [0.0000, 0.0000, 0.0000, 0.2500],
          [0.0000, 0.0000, 0.0000, 0.2500]]], grad_fn=<DivBackward0>))

In [283]:
# A single pooling weight (token 6 -> segment 3) to differentiate in the next cell.
l = bar[0][6][3]
l

tensor(0.2500, grad_fn=<SelectBackward0>)

In [284]:
# Backpropagate that single weight into the boundary indicators.
torch.autograd.backward(l)
(my_hard_boundaries, my_hard_boundaries.grad)

(tensor([[0., 0., 1., 1., 0., 1., 0., 0., 0., 1.]], requires_grad=True),
 tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1875, -0.1250,
          -0.0625,  0.0000]]))

In [286]:
# Jacobian of each token's summed pooling weights w.r.t. the boundary
# indicators, built one row (one backward pass) at a time.
BOUNDARY_PATTERN = [[0, 0, 1, 1, 0, 1, 0, 0, 0, 1]]
grad_values = []

for token_idx in range(len(BOUNDARY_PATTERN[0])):
    # Fresh leaf every iteration so gradients from earlier rows never accumulate.
    my_hard_boundaries = torch.tensor(BOUNDARY_PATTERN, dtype=torch.float32, requires_grad=True)
    foo = common(my_hard_boundaries, upsample=False)
    bar = final(foo, upsample=False)
    # Summing the whole row is safe: the zero entries of `bar` are still in the
    # graph but receive zero gradient, so only the token's own segment weight counts.
    l = bar[0, token_idx, :].sum()
    l.backward()
    grad_values.append(my_hard_boundaries.grad)
    # Clear .grad so the tensor left in the kernel carries no stale gradient
    # (the list keeps its own reference to the row).
    my_hard_boundaries.grad = None

In [290]:
# Stack the per-token gradient rows (each 1 x L) into an L x L Jacobian.
jacobian = torch.stack([row.squeeze(0) for row in grad_values], dim=0)
jacobian

tensor([[-0.2222, -0.1111,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.1111, -0.1111,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.1111,  0.2222,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.2500,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.2500,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1875, -0.1250,
         -0.0625,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0625, -0.1250,
         -0.0625,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0625,  0.1250,
         -0.0625,  0.0000],
        [ 0.0000,  

In [293]:
# Diagonal Jacobian blocks, one per segment of `my_hard_boundaries` (each segment
# ends at a boundary token), derived from the data instead of hard-coded ranges.
segment_ends = (my_hard_boundaries[0].detach().nonzero().flatten() + 1).tolist()
segment_starts = [0] + segment_ends[:-1]
tuple(jacobian[start:end, start:end] for start, end in zip(segment_starts, segment_ends))

(tensor([[-0.2222, -0.1111,  0.0000],
         [ 0.1111, -0.1111,  0.0000],
         [ 0.1111,  0.2222,  0.0000]]),
 tensor([[0.]]),
 tensor([[-0.2500,  0.0000],
         [ 0.2500,  0.0000]]),
 tensor([[-0.1875, -0.1250, -0.0625,  0.0000],
         [ 0.0625, -0.1250, -0.0625,  0.0000],
         [ 0.0625,  0.1250, -0.0625,  0.0000],
         [ 0.0625,  0.1250,  0.1875,  0.0000]]))

In [223]:
# Inspect the gradients on the predictor's parameters.
# NOTE(review): in the cells visible here no backward pass reaches `my_bp`, so on a
# fresh kernel every `param.grad` prints as None; the saved output is from removed cells.
for name, param in my_bp.named_parameters():
    print(f"{name}: {param.grad=}")

boundary_predictor.0.weight: param.grad=tensor([[ 0.0017,  0.0109, -0.0136],
        [-0.0021, -0.0129,  0.0162],
        [ 0.0012,  0.0077, -0.0096]])
boundary_predictor.0.bias: param.grad=tensor([-0.0136,  0.0161, -0.0096])
boundary_predictor.2.weight: param.grad=tensor([[-0.0021,  0.0033,  0.0351]])
boundary_predictor.2.bias: param.grad=tensor([0.0616])


In [224]:
# Step-by-step replay of `common` / `final` for downsampling.
# NOTE(review): the saved outputs of this cell and the following step-by-step cells
# (In[224]–In[235]) were produced with an older 4-token `my_hard_boundaries`, before
# In[281] redefined it; re-run to refresh them.
n_segments = torch.max(torch.sum(my_hard_boundaries, dim=-1)).item()
n_segments

2.0

In [225]:
# Candidate segment index s for every (batch, token) pair: B x L x S.
segment_range = torch.arange(0, n_segments, device=my_hard_boundaries.device)
tmp = torch.zeros_like(my_hard_boundaries).unsqueeze(2) + segment_range
tmp

tensor([[[0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.]]])

In [226]:

# Exclusive running count of boundaries = segment index of each token.
n_preceding_boundaries = my_hard_boundaries.cumsum(dim=1) - my_hard_boundaries
n_preceding_boundaries

tensor([[0., 0., 1., 2.]], grad_fn=<SubBackward0>)

In [227]:
# Signed offset between candidate segment s and the token's own segment.
foo = tmp - n_preceding_boundaries[..., None]
foo

tensor([[[ 0.,  1.],
         [ 0.,  1.],
         [-1.,  0.],
         [-2., -1.]]], grad_fn=<SubBackward0>)

In [228]:
# Non-zero offsets point at a segment the token does not belong to; these get masked.
autoregressive = foo.ne(0)
autoregressive

tensor([[[False,  True],
         [False,  True],
         [ True, False],
         [ True,  True]]])

In [229]:
# 1 where the token belongs to segment s, 0 elsewhere; built from `foo` so the
# kept entries stay differentiable w.r.t. it.
lel = foo.neg().add(1)
lel[autoregressive] = 0
lel

tensor([[[1., 0.],
         [1., 0.],
         [0., 1.],
         [0., 0.]]], grad_fn=<IndexPutBackward0>)

In [235]:
# Normalise each segment column so its tokens are averaged. The result goes to a
# new name, so re-running this cell does not divide `lel` a second time.
lel_normalized = lel / (lel.sum(dim=1, keepdim=True) + 1e-9)
lel_normalized

tensor([[[0.5000, 0.0000],
         [0.5000, 0.0000],
         [0.0000, 1.0000],
         [0.0000, 0.0000]]], grad_fn=<DivBackward0>)