In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

import numpy as np
from itertools import groupby

torch.__version__

'2.2.1'

In [3]:
from torchcondirf.crf_head import CrfHead
from torchcondirf import util

In [4]:
def groupby_slices(labels, key=None):
    """Split ``labels`` into runs of consecutive items sharing the same key.

    Returns a list of ``(start, end, value)`` triples, one per run, such that
    ``labels[start:end]`` is the run and ``value`` is the run's key as produced
    by ``itertools.groupby`` (the item itself when ``key`` is None).
    """
    runs = [(run_key, sum(1 for _ in members)) for run_key, members in groupby(labels, key)]
    run_keys = [run_key for run_key, _ in runs]
    run_sizes = [size for _, size in runs]

    stop_positions = np.cumsum(run_sizes)
    # Each run starts where the previous one stopped; the first starts at 0.
    start_positions = np.r_[0, np.roll(stop_positions, 1)[1:]]
    return list(zip(start_positions, stop_positions, run_keys))

In [5]:
def myget(list_, index, other=None):
    """Return ``list_[index]``, falling back to ``other`` when the lookup fails.

    Behaves like ``dict.get`` for any subscriptable container (lists, tuples,
    dicts): a missing key (``KeyError``) or out-of-range position
    (``IndexError``) yields ``other`` instead of raising.
    """
    try:
        found = list_[index]
    except (IndexError, KeyError):
        found = other
    return found

In [6]:
myget([1, 2], 3, -1), myget({"a": 1}, "b")

(-1, None)

In [7]:
from collections.abc import Iterable
from copy import deepcopy
from typing import List
import pandas as pd

## Getting an actual controlled test

First, some tooling for masking and computing scores:

In [8]:
def convert_constraints(batch_tag_observations, unobserved_tag=-1):
    """
    Convert partial tag observations into tag constraints.

    Each example's observations are a sequence of tags in which
    ``unobserved_tag`` (default ``-1``) marks a position whose tag is unknown.
    Consecutive runs of the same observed tag are collapsed into a single
    constraint ``(start, end, [tag])``: positions ``start <= i < end`` must
    take ``tag``.

    Returns:
        One list of constraints per example, in batch order. An example with
        no observed tags gets an empty list.
    """
    return [
        [
            (start, end, [tag])
            for start, end, tag in groupby_slices(tag_observations)
            if tag != unobserved_tag
        ]
        for tag_observations in batch_tag_observations
    ]

In [9]:
def get_mask_for_scores_by_hand(
    scores_by_hand: pd.DataFrame, constraints: List[List[tuple]]
) -> pd.Series:
    """
    Boolean mask over the rows of ``scores_by_hand`` that satisfy ``constraints``.

    Args:
        scores_by_hand: frame with (at least) an ``example_index`` column and a
            ``tag_sequence`` column holding one list of tags per row.
        constraints: one list of ``(start, stop, allowed_tags)`` per example. A
            row satisfies them when every position ``start <= i < stop`` of its
            ``tag_sequence`` holds a tag in ``allowed_tags``. ``None`` means no
            constraints, so every row is selected.

    Returns:
        A bool ``pd.Series`` aligned with ``scores_by_hand.index``.
    """
    if constraints is None:
        return pd.Series(True, index=scores_by_hand.index)

    def satisfies_constraints(example_index, tag_sequence):
        return all(
            tag_sequence[i] in tags
            for start, stop, tags in constraints[example_index]
            for i in range(start, stop)
        )

    # Build the Series explicitly instead of using ``DataFrame.apply(axis=1)``:
    # on an empty frame ``apply`` can return an (empty) DataFrame rather than a
    # Series, and a DataFrame cannot be used as a boolean row indexer.
    return pd.Series(
        [
            satisfies_constraints(example_index, tag_sequence)
            for example_index, tag_sequence in zip(
                scores_by_hand["example_index"], scores_by_hand["tag_sequence"]
            )
        ],
        index=scores_by_hand.index,
        dtype=bool,
    )

In [10]:
def compute_log_partition_by_hand(scores_by_hand, constraints=None):
    """
    Log partition function per example, computed from enumerated sequence scores.

    Selects the rows of ``scores_by_hand`` allowed by ``constraints`` (all rows
    when ``constraints`` is None). For each ``example_index``, returns the
    ``logsumexp`` of their ``score`` column.

    Returns:
        A ``pd.Series`` indexed by ``example_index`` whose values are 0-d
        float tensors.
    """
    mask_for_constraints = get_mask_for_scores_by_hand(scores_by_hand, constraints)
    scores_for_constrained_sequences = scores_by_hand.loc[mask_for_constraints]

    def log_partition_for_group(scores):
        return torch.logsumexp(torch.tensor(scores.values, dtype=torch.float), dim=0)

    # Select the ``score`` column before ``apply`` so the grouping column is not
    # handed to the function (pandas warns that this behaviour is deprecated).
    return scores_for_constrained_sequences.groupby("example_index")["score"].apply(
        log_partition_for_group
    )

In [11]:
def compute_marginal_by_hand(
    scores_by_hand,
    log_partition_by_hand,
    position,
    tag,
    example_index,
    constraints=None,
):
    """
    Log marginal of ``tag`` at ``position`` in example ``example_index``, by brute force.

    Takes the ``logsumexp`` of the scores of this example's allowed sequences
    (see ``constraints``) whose tag at ``position`` equals ``tag``, and
    subtracts ``log_partition_by_hand[example_index]``.

    ``log_partition_by_hand`` should be computed with the same ``constraints``
    for the result to be a normalized log probability.
    """
    mask_for_constraints = get_mask_for_scores_by_hand(scores_by_hand, constraints)
    scores_for_constrained_sequences = scores_by_hand.loc[mask_for_constraints]
    scores_for_example = scores_for_constrained_sequences.query(
        "example_index == @example_index"
    )
    # dtype=bool keeps this a valid boolean indexer even when the example has no
    # rows (``np.array([])`` would otherwise be a float64 array).
    mask_for_sequences_with_correct_tag = np.array(
        [seq[position] == tag for seq in scores_for_example["tag_sequence"]],
        dtype=bool,
    )
    scores_for_sequences_with_correct_tag = torch.tensor(
        scores_for_example.loc[mask_for_sequences_with_correct_tag, "score"].values,
        dtype=torch.float,
    )
    return (
        torch.logsumexp(scores_for_sequences_with_correct_tag, dim=0)
        - log_partition_by_hand[example_index]
    )

### Set up our specific example

There are two real tags, 1 and 2; tag 0 is reserved for padding. The batch contains three sequences, of lengths 4, 2 and 1.

In [12]:
num_tags = 3

In [13]:
# Tags observed for each sequence in the batch; -1 marks an unobserved position.
# Nothing is observed here, so this batch only fixes the sequence lengths (4, 2, 1).
batch_tag_observations = [[-1, -1, -1, -1], [-1, -1], [-1]]
lengths = torch.LongTensor([len(obs) for obs in batch_tag_observations])
# A 0-d tensor (not a Python int) holding the longest sequence length.
max_sequence_length = lengths.max()
batch_size = len(batch_tag_observations)

In [14]:
length_mask = util.get_mask_from_sequence_lengths(lengths, max_sequence_length)
length_mask

tensor([[ True,  True,  True,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])

In [15]:
# Allowed-tag mask of shape (batch_size, max_sequence_length, num_tags), as the
# output below shows. No tags are observed, so it only encodes the lengths.
mask = util.get_mask_for_tags(
    length_mask, num_tags, convert_constraints(batch_tag_observations)
)
mask

tensor([[[ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True]],

        [[ True,  True,  True],
         [ True,  True,  True],
         [False, False, False],
         [False, False, False]],

        [[ True,  True,  True],
         [False, False, False],
         [False, False, False],
         [False, False, False]]])

#### Specify some particular log emissions:

In [16]:
# emissions are batch_size x max_sequence_length x num_tags
# Padded positions deliberately hold extreme values: they must not influence
# any result.
log_emissions = torch.Tensor(
    [
        [[1, 2, 3], [5, 6, 7], [12, 11, 10], [-1, -2, 0]],
        [
            [1, 4, 9],
            [5, 1, 5],
            [12, 121, 140],  # don't care, length is 2
            [-1, -200, 0],  # don't care, length is 2
        ],
        [
            [1, 15, 23],
            [5000, 3211, 2123],  # don't care, length is 1
            [12, 121, 140],  # don't care, length is 1
            [-1, -200, 0],  # don't care, length is 1
        ],
    ]
)
log_emissions

tensor([[[ 1.0000e+00,  2.0000e+00,  3.0000e+00],
         [ 5.0000e+00,  6.0000e+00,  7.0000e+00],
         [ 1.2000e+01,  1.1000e+01,  1.0000e+01],
         [-1.0000e+00, -2.0000e+00,  0.0000e+00]],

        [[ 1.0000e+00,  4.0000e+00,  9.0000e+00],
         [ 5.0000e+00,  1.0000e+00,  5.0000e+00],
         [ 1.2000e+01,  1.2100e+02,  1.4000e+02],
         [-1.0000e+00, -2.0000e+02,  0.0000e+00]],

        [[ 1.0000e+00,  1.5000e+01,  2.3000e+01],
         [ 5.0000e+03,  3.2110e+03,  2.1230e+03],
         [ 1.2000e+01,  1.2100e+02,  1.4000e+02],
         [-1.0000e+00, -2.0000e+02,  0.0000e+00]]])

In [17]:
crf_head = CrfHead(num_tags=num_tags, include_start_end_transitions=False)

#### Specify some particular log-transitions:

In [18]:
# Overwrite the transition parameter directly. Gradients are switched off while
# the data is replaced and switched back on afterwards.
# Row/column 0 (the padding tag) get -1e4, heavily penalizing transitions into
# or out of padding.
# NOTE(review): against the hand-computed scores below, entry [i, j] appears to
# score a move from tag j to tag i (e.g. "1112" = 2 + 6 + 11 + 0 + 5 + 5 + [2, 1]
# = 33). Confirm in crf_head.
crf_head.log_transitions.requires_grad_(
    False
)  # disable grad to set desired transitions
crf_head.log_transitions.data = torch.tensor(
    [[-1e4, -1e4, -1e4], [-1e4, 5, 2], [-1e4, 4, 3]]
)
crf_head.log_transitions.requires_grad_()  # and now reenable it
crf_head.log_transitions

Parameter containing:
tensor([[-1.0000e+04, -1.0000e+04, -1.0000e+04],
        [-1.0000e+04,  5.0000e+00,  2.0000e+00],
        [-1.0000e+04,  4.0000e+00,  3.0000e+00]], requires_grad=True)

In [19]:
# Set the head's emission scaling and divide the emissions by exp(scaling) up front.
# NOTE(review): the scores below are computed by hand on the *unscaled* emissions
# and still match the model, which suggests CrfHead multiplies emissions by
# exp(log_emissions_scaling) internally. Confirm in crf_head.
# WARNING: this cell is not idempotent; re-running it rescales `log_emissions` again.
crf_head.log_emissions_scaling.data = torch.tensor(11.3)
log_emissions = log_emissions * torch.exp(-crf_head.log_emissions_scaling)

#### With these emissions and transitions, we can compute the scores for each sequence by hand:

In [20]:
import pandas as pd

# Hand-computed score of every tag sequence over the real tags {1, 2}, keyed by
# the sequence written as a digit string (e.g. "1112" = tags 1, 1, 1, 2).
expected_scores = pd.Series(
    {
        # first example of length 4
        "1111": 32,
        "1112": 33,
        "1121": 27,
        "1122": 30,
        "1211": 29,
        "1212": 30,
        "1221": 26,
        "1222": 29,
        "2111": 30,
        "2112": 31,
        "2121": 25,
        "2122": 28,
        "2211": 29,
        "2212": 30,
        "2221": 26,
        "2222": 29,
        # second example of length 2
        "11": 10,
        "12": 13,
        "21": 12,
        "22": 17,
        # third example of length 1
        "1": 15,
        "2": 23,
    }
).sort_values(ascending=False)

expected_scores.name = "score"
# Turn the digit-string index into a list-of-int `tag_sequence` column.
expected_scores = (
    expected_scores.reset_index()
    .rename(columns={"index": "tag_sequence"})
    .apply(
        lambda row: pd.Series(
            {"tag_sequence": [int(tag) for tag in row.tag_sequence], "score": row.score}
        ),
        axis=1,
    )
)
expected_scores["length"] = expected_scores.apply(
    lambda row: len(row["tag_sequence"]), axis=1
)
# Longest sequences first; best score first within each length.
expected_scores = expected_scores.sort_values(
    ["length", "score"], ascending=[False, False], ignore_index=True
)

# Label each row with its batch example. This relies on the batch lengths being
# distinct and decreasing (4, 2, 1): the length-sorted rows then form contiguous
# blocks of (num_tags - 1) ** length sequences, one block per example in batch order.
expected_scores["example_index"] = sum(
    [[i] * ((num_tags - 1) ** lengths[i]) for i in range(batch_size)], []
)
expected_scores

Unnamed: 0,tag_sequence,score,length,example_index
0,"[1, 1, 1, 2]",33,4,0
1,"[1, 1, 1, 1]",32,4,0
2,"[2, 1, 1, 2]",31,4,0
3,"[1, 1, 2, 2]",30,4,0
4,"[1, 2, 1, 2]",30,4,0
5,"[2, 2, 1, 2]",30,4,0
6,"[2, 1, 1, 1]",30,4,0
7,"[2, 2, 2, 2]",29,4,0
8,"[1, 2, 1, 1]",29,4,0
9,"[1, 2, 2, 2]",29,4,0


In [21]:
# Sequences whose third tag (position 2) is 2; shorter sequences never match.
# Pass a list, not a generator, to ``np.array``: ``np.array(<generator>)`` builds
# a 0-d object array rather than a boolean mask.
expected_scores.loc[
    np.array([myget(seq, 2, -1) == 2 for seq in expected_scores["tag_sequence"]])
].sort_values("score", ascending=False)

Unnamed: 0,tag_sequence,score,length,example_index
3,"[1, 1, 2, 2]",30,4,0
7,"[2, 2, 2, 2]",29,4,0
9,"[1, 2, 2, 2]",29,4,0
11,"[2, 1, 2, 2]",28,4,0
12,"[1, 1, 2, 1]",27,4,0
13,"[1, 2, 2, 1]",26,4,0
14,"[2, 2, 2, 1]",26,4,0
15,"[2, 1, 2, 1]",25,4,0


## Begin testing

In [22]:
torch.set_grad_enabled(False)
torch.manual_seed(1)

<torch._C.Generator at 0x142e49ab0>

### Test that scores are computed correctly:

In [23]:
# Score each enumerated sequence on its own: a batch of one, with emissions
# truncated to the sequence's length, compared against the hand-computed score.
# Exact equality is used because every expected score is an integer.
for _, row in expected_scores.iterrows():
    predicted_score = crf_head.forward(
        log_emissions[row.example_index : row.example_index + 1, : row.length],
        lengths=lengths[row.example_index : row.example_index + 1],
        tags=torch.tensor([[int(i) for i in row.tag_sequence]]),
    )["logits"].item()
    assert row.score == predicted_score

In [83]:
# One emissions row per enumerated tag sequence, in the same order as
# `expected_scores`: 2**4 = 16 rows for the length-4 example, 2**2 = 4 for the
# length-2 example and 2**1 = 2 for the length-1 example.
all_emissions = torch.stack(
    [log_emissions[0]] * 16 + [log_emissions[1]] * 4 + [log_emissions[2]] * 2
)

In [84]:
tags1 = torch.stack(
    [torch.tensor(t) for t in expected_scores["tag_sequence"][:16].values]
)

In [85]:
# Length-2 sequences, right-padded with the padding tag 0 to the max length 4.
tags2 = torch.stack(
    [torch.tensor(t + [0, 0]) for t in expected_scores["tag_sequence"][16:20].values]
)

In [86]:
# Length-1 sequences, right-padded with the padding tag 0 to the max length 4.
tags3 = torch.stack(
    [torch.tensor(t + [0, 0, 0]) for t in expected_scores["tag_sequence"][20:].values]
)

In [87]:
all_tags = torch.vstack([tags1, tags2, tags3])
all_tags

tensor([[1, 1, 1, 2],
        [1, 1, 1, 1],
        [2, 1, 1, 2],
        [1, 1, 2, 2],
        [1, 2, 1, 2],
        [2, 2, 1, 2],
        [2, 1, 1, 1],
        [2, 2, 2, 2],
        [1, 2, 1, 1],
        [1, 2, 2, 2],
        [2, 2, 1, 1],
        [2, 1, 2, 2],
        [1, 1, 2, 1],
        [1, 2, 2, 1],
        [2, 2, 2, 1],
        [2, 1, 2, 1],
        [2, 2, 0, 0],
        [1, 2, 0, 0],
        [2, 1, 0, 0],
        [1, 1, 0, 0],
        [2, 0, 0, 0],
        [1, 0, 0, 0]])

In [88]:
# Score all 22 enumerated sequences in one batch. `lengths` repeats each example's
# length once per enumerated sequence (16 x 4, 4 x 2, 2 x 1).
predicted_score = crf_head.forward(
    all_emissions, lengths=torch.LongTensor([4] * 16 + [2] * 4 + [1] * 2), tags=all_tags
)["logits"]

In [89]:
predicted_score

tensor([33., 32., 31., 30., 30., 30., 30., 29., 29., 29., 29., 28., 27., 26.,
        26., 25., 17., 13., 12., 10., 23., 15.])

In [31]:
expected_scores

Unnamed: 0,tag_sequence,score,length,example_index
0,"[1, 1, 1, 2]",33,4,0
1,"[1, 1, 1, 1]",32,4,0
2,"[2, 1, 1, 2]",31,4,0
3,"[1, 1, 2, 2]",30,4,0
4,"[1, 2, 1, 2]",30,4,0
5,"[2, 2, 1, 2]",30,4,0
6,"[2, 1, 1, 1]",30,4,0
7,"[2, 2, 2, 2]",29,4,0
8,"[1, 2, 1, 1]",29,4,0
9,"[1, 2, 2, 2]",29,4,0


In [32]:
assert torch.allclose(
    predicted_score, torch.tensor(expected_scores["score"], dtype=torch.float)
)

### Test Viterbi

In [33]:
top_k = 10

In [34]:
%%timeit
crf_head.viterbi_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=mask, top_k=top_k
)

208 µs ± 6.23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [35]:
expected_scores  # .query("length < 4")

Unnamed: 0,tag_sequence,score,length,example_index
0,"[1, 1, 1, 2]",33,4,0
1,"[1, 1, 1, 1]",32,4,0
2,"[2, 1, 1, 2]",31,4,0
3,"[1, 1, 2, 2]",30,4,0
4,"[1, 2, 1, 2]",30,4,0
5,"[2, 2, 1, 2]",30,4,0
6,"[2, 1, 1, 1]",30,4,0
7,"[2, 2, 2, 2]",29,4,0
8,"[1, 2, 1, 1]",29,4,0
9,"[1, 2, 2, 2]",29,4,0


In [36]:
viterbi_nbest_predictions = crf_head.viterbi_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=mask, top_k=top_k
)
viterbi_nbest_predictions

(tensor([[[1, 1, 1, 2],
          [1, 1, 1, 1],
          [2, 1, 1, 2],
          [1, 1, 2, 2],
          [1, 2, 1, 2],
          [2, 2, 1, 2],
          [2, 1, 1, 1],
          [1, 2, 2, 2],
          [1, 2, 1, 1],
          [2, 2, 1, 1]],
 
         [[2, 2, 0, 0],
          [1, 2, 0, 0],
          [2, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0]],
 
         [[2, 0, 0, 0],
          [1, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0]]]),
 tensor([[    33.,     32.,     31.,     30.,     30.,     30.,     30.,     29.,
              29.,     29.],
         [    17.,     13.,     12.,     10.,  -9991.,  -9991.,  -9991.,  -9991.,
           -9991.,  -9991.],
         [    23.,     15., -10000., -10000., -10

In [37]:
from copy import deepcopy

In [38]:
def verify_viterbi_predictions(
    expected_scores, top_k, viterbi_nbest_predictions, mask_for_tags, padding_tag_id=0
):
    """
    Check top-k Viterbi output against exhaustively enumerated sequence scores.

    Args:
        expected_scores: frame with ``tag_sequence``, ``score`` and
            ``example_index`` columns listing every tag sequence of every example.
        top_k: number of predictions per example returned by the model.
        viterbi_nbest_predictions: ``(tag_sequences, scores)`` as returned by
            ``viterbi_algorithm``. ``tag_sequences[i, j]`` is the j-th padded
            prediction for example ``i`` and ``scores[i, j]`` is its score.
        mask_for_tags: boolean ``(batch, max_len, num_tags)`` mask of allowed tags
            (lengths and constraints) that the Viterbi algorithm was run with.
        padding_tag_id: tag id used for padding; it never counts as a real tag.

    For each example, only the first ``min(top_k, #allowed sequences)``
    predictions are meaningful. Each must carry the j-th best score among the
    mask-allowed sequences and be an allowed sequence with that score.

    Raises:
        AssertionError: if any check fails.
    """
    import math

    predicted_tags = viterbi_nbest_predictions[0]
    predicted_scores = viterbi_nbest_predictions[1]

    # Count the meaningful sequences per example. The padding tag is masked out on
    # a copy so the caller's mask is left untouched.
    real_tag_mask = mask_for_tags.clone()
    real_tag_mask[..., padding_tag_id] = False
    real_tags_per_position = real_tag_mask.sum(2)
    # Positions past the end of a sequence allow no real tag; count them as one
    # choice so they do not zero out the product.
    real_tags_per_position[real_tags_per_position == 0] = 1
    max_variations_per_example = real_tags_per_position.prod(1)

    num_examples = expected_scores["example_index"].max() + 1
    for i in range(num_examples):
        hand_predictions_for_example = expected_scores[
            expected_scores["example_index"] == i
        ]
        # Length of this example, taken from its own enumerated sequences.
        length = len(hand_predictions_for_example["tag_sequence"].iloc[0])
        # Keep only the enumerated sequences the mask allows, i.e. those that
        # satisfy the constraints the Viterbi algorithm was run with.
        is_allowed = [
            all(bool(mask_for_tags[i, pos, tag]) for pos, tag in enumerate(seq))
            for seq in hand_predictions_for_example["tag_sequence"]
        ]
        allowed_predictions = hand_predictions_for_example.loc[is_allowed]
        score_to_tag_sequences = {
            score: group["tag_sequence"].tolist()
            for score, group in allowed_predictions.groupby("score")
        }
        best_allowed_scores = sorted(allowed_predictions["score"], reverse=True)

        num_meaningful = min(top_k, int(max_variations_per_example[i]))
        assert num_meaningful <= len(best_allowed_scores), (
            f"example {i}: expected_scores lists only {len(best_allowed_scores)} "
            f"allowed sequences but the mask allows {num_meaningful}"
        )
        for j in range(num_meaningful):
            model_prediction = predicted_tags[i, j].tolist()[:length]
            model_score = predicted_scores[i, j].item()
            expected_score = best_allowed_scores[j]
            assert math.isclose(model_score, expected_score, abs_tol=1e-3), (
                f"example {i}, rank {j}: model score {model_score} != "
                f"expected {expected_score}"
            )
            assert model_prediction in score_to_tag_sequences[expected_score], (
                f"example {i}, rank {j}: {model_prediction} is not an allowed "
                f"sequence with score {expected_score}"
            )

In [39]:
verify_viterbi_predictions(
    expected_scores, top_k, viterbi_nbest_predictions, mask, padding_tag_id=0
)

#### Test Viterbi with constraints

In [40]:
# Example 0 must end in ..., 2, 1; example 1 must start with tag 1; example 2 is
# unconstrained.
constraints = convert_constraints(
    [[-1, -1, 2, 1], [1, -1], [-1]]  # -1 = unobserved position
)

In [41]:
mask_for_lengths_and_tags = util.get_mask_for_tags(
    length_mask, num_tags, constraints=constraints
)

In [42]:
%%timeit
crf_head.viterbi_algorithm(
    log_emissions=log_emissions,
    lengths=lengths,
    mask=mask_for_lengths_and_tags,
    top_k=top_k,
)

208 µs ± 5.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [43]:
viterbi_nbest_predictions = crf_head.viterbi_algorithm(
    log_emissions=log_emissions,
    lengths=lengths,
    mask=mask_for_lengths_and_tags,
    top_k=top_k,
)
viterbi_nbest_predictions

(tensor([[[1, 1, 2, 1],
          [1, 2, 2, 1],
          [2, 2, 2, 1],
          [2, 1, 2, 1],
          [1, 1, 2, 2],
          [1, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 1, 2, 2],
          [1, 1, 2, 1],
          [1, 1, 2, 1]],
 
         [[1, 2, 0, 0],
          [1, 1, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 0, 0]],
 
         [[2, 0, 0, 0],
          [1, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0]]]),
 tensor([[    27.,     26.,     26.,     25.,  -9970.,  -9971.,  -9971.,  -9972.,
           -9975.,  -9975.],
         [    13.,     10.,  -9991.,  -9991.,  -9991.,  -9991.,  -9991.,  -9991.,
           -9991.,  -9991.],
         [    23.,     15., -10000., -10000., -10

In [44]:
expected_scores[
    get_mask_for_scores_by_hand(expected_scores, constraints=constraints)
].sort_values(["length", "score"], ascending=[False, False])

Unnamed: 0,tag_sequence,score,length,example_index
12,"[1, 1, 2, 1]",27,4,0
13,"[1, 2, 2, 1]",26,4,0
14,"[2, 2, 2, 1]",26,4,0
15,"[2, 1, 2, 1]",25,4,0
17,"[1, 2]",13,2,1
19,"[1, 1]",10,2,1
20,[2],23,1,2
21,[1],15,1,2


In [45]:
verify_viterbi_predictions(
    expected_scores, top_k, viterbi_nbest_predictions, mask_for_lengths_and_tags, padding_tag_id=0
)

### **Test the partition function:**

In [46]:
length_mask

tensor([[ True,  True,  True,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])

In [47]:
%%timeit
crf_head._forward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=length_mask
)

257 µs ± 36.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [48]:
%%timeit
crf_head._backward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=length_mask
)

330 µs ± 46.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [49]:
# Partition function computed by the forward algorithm, which returns log_alpha
# and the per-example log partition:
log_alpha, log_partition_forward = crf_head._forward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=length_mask
)
log_alpha, log_partition_forward

(tensor([[[-1.0000e+04,  2.0000e+00,  3.0000e+00],
          [-1.9997e+04,  1.3127e+01,  1.3693e+01],
          [-1.9986e+04,  2.9211e+01,  2.7627e+01],
          [-1.9971e+04,  3.2221e+01,  3.3284e+01]],
 
         [[-1.0000e+04,  4.0000e+00,  9.0000e+00],
          [-1.9991e+04,  1.2127e+01,  1.7018e+01],
          [-1.9983e+04, -9.9808e+03, -9.9800e+03],
          [-2.9980e+04, -1.9976e+04, -1.9976e+04]],
 
         [[-1.0000e+04,  1.5000e+01,  2.3000e+01],
          [-1.9977e+04, -9.9750e+03, -9.9740e+03],
          [-2.9974e+04, -1.9970e+04, -1.9970e+04],
          [-3.9969e+04, -2.9965e+04, -2.9966e+04]]]),
 tensor([33.5805, 17.0256, 23.0003]))

In [50]:
# Partition functions computed by the backward algorithm, which returns log_beta
# and the per-example log partition:
log_beta, log_partition_backward = crf_head._backward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=length_mask
)
log_beta, log_partition_backward

(tensor([[[-1.9973e+04,  3.3410e+01,  3.1725e+01],
          [-1.9985e+04,  2.6351e+01,  2.4562e+01],
          [-2.0000e+04,  1.5313e+01,  1.3049e+01],
          [-1.0000e+04, -2.0000e+00,  0.0000e+00]],
 
         [[-1.9995e+04,  1.3049e+01,  1.7007e+01],
          [-1.0000e+04,  1.0000e+00,  5.0000e+00],
          [-1.9983e+04, -9.9789e+03, -9.9800e+03],
          [-2.9979e+04, -1.9974e+04, -1.9976e+04]],
 
         [[-1.0000e+04,  1.5000e+01,  2.3000e+01],
          [-1.9977e+04, -9.9730e+03, -9.9740e+03],
          [-2.9973e+04, -1.9968e+04, -1.9970e+04],
          [-3.9968e+04, -2.9963e+04, -2.9966e+04]]]),
 tensor([33.5805, 17.0256, 23.0003]))

In [51]:
# partition function computed from the scores computed by hand:
log_partition_by_hand = compute_log_partition_by_hand(expected_scores)
log_partition_by_hand

  return scores_for_constrained_sequences.groupby("example_index").apply(


example_index
0    tensor(33.5805)
1    tensor(17.0256)
2    tensor(23.0003)
dtype: object

In [52]:
log_partition_forward

tensor([33.5805, 17.0256, 23.0003])

In [53]:
log_partition_backward

tensor([33.5805, 17.0256, 23.0003])

In [54]:
torch.stack(log_partition_by_hand.tolist())

tensor([33.5805, 17.0256, 23.0003])

In [55]:
# Forward, backward and brute-force partition functions must agree up to float
# tolerance. Exact equality of independently accumulated float32 sums is not
# guaranteed; the constrained case below already differs in the last bits.
torch.testing.assert_close(log_partition_forward, log_partition_backward)
torch.testing.assert_close(
    log_partition_forward, torch.stack(log_partition_by_hand.tolist())
)

Run the "forward" method:

In [56]:
crf_head.forward(log_emissions=log_emissions, lengths=lengths, compute_log_beta=True)

{'log_partition': tensor([33.5805, 17.0256, 23.0003]),
 'log_alpha': tensor([[[-1.0000e+04,  2.0000e+00,  3.0000e+00],
          [-1.9997e+04,  1.3127e+01,  1.3693e+01],
          [-1.9986e+04,  2.9211e+01,  2.7627e+01],
          [-1.9971e+04,  3.2221e+01,  3.3284e+01]],
 
         [[-1.0000e+04,  4.0000e+00,  9.0000e+00],
          [-1.9991e+04,  1.2127e+01,  1.7018e+01],
          [-1.9983e+04, -9.9808e+03, -9.9800e+03],
          [-2.9980e+04, -1.9976e+04, -1.9976e+04]],
 
         [[-1.0000e+04,  1.5000e+01,  2.3000e+01],
          [-1.9977e+04, -9.9750e+03, -9.9740e+03],
          [-2.9974e+04, -1.9970e+04, -1.9970e+04],
          [-3.9969e+04, -2.9965e+04, -2.9966e+04]]]),
 'logits': None,
 'log_beta': tensor([[[-1.9973e+04,  3.3410e+01,  3.1725e+01],
          [-1.9985e+04,  2.6351e+01,  2.4562e+01],
          [-2.0000e+04,  1.5313e+01,  1.3049e+01],
          [-1.0000e+04, -2.0000e+00,  0.0000e+00]],
 
         [[-1.9995e+04,  1.3049e+01,  1.7007e+01],
          [-1.0000e+04, 

### Test constrained partition functions:

In [57]:
# Require tag 1 at position 2 of example 0 and at position 1 of example 1;
# example 2 stays unconstrained.
constraints = convert_constraints(
    [[-1, -1, 1, -1], [-1, 1], [-1]]  # -1 = unobserved position
)

In [58]:
mask_for_lengths_and_tags = util.get_mask_for_tags(
    length_mask, num_tags, constraints=constraints
)

In [59]:
constrained_log_alpha, constrained_log_partition_forward = crf_head._forward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=mask_for_lengths_and_tags
)
constrained_log_partition_forward

tensor([33.5243, 12.1269, 23.0003])

In [60]:
constrained_log_beta, constrained_log_partition_backward = crf_head._backward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=mask_for_lengths_and_tags
)
constrained_log_partition_backward

tensor([33.5243, 12.1269, 23.0003])

In [61]:
expected_scores.loc[
    get_mask_for_scores_by_hand(expected_scores, constraints=constraints)
]

Unnamed: 0,tag_sequence,score,length,example_index
0,"[1, 1, 1, 2]",33,4,0
1,"[1, 1, 1, 1]",32,4,0
2,"[2, 1, 1, 2]",31,4,0
4,"[1, 2, 1, 2]",30,4,0
5,"[2, 2, 1, 2]",30,4,0
6,"[2, 1, 1, 1]",30,4,0
8,"[1, 2, 1, 1]",29,4,0
10,"[2, 2, 1, 1]",29,4,0
18,"[2, 1]",12,2,1
19,"[1, 1]",10,2,1


In [62]:
constrained_log_partition_by_hand = compute_log_partition_by_hand(
    expected_scores, constraints=constraints
)
constrained_log_partition_by_hand = torch.stack(
    constrained_log_partition_by_hand.tolist()
)
constrained_log_partition_by_hand

  return scores_for_constrained_sequences.groupby("example_index").apply(


tensor([33.5243, 12.1269, 23.0003])

Run the "forward" method which will do all of the above:

In [63]:
crf_head.forward(
    log_emissions=log_emissions,
    lengths=lengths,
    mask=mask_for_lengths_and_tags,
    compute_log_beta=True,
)

{'log_partition': tensor([33.5243, 12.1269, 23.0003]),
 'log_alpha': tensor([[[-1.0000e+04,  2.0000e+00,  3.0000e+00],
          [-1.9997e+04,  1.3127e+01,  1.3693e+01],
          [-1.9986e+04,  2.9211e+01, -9.9824e+03],
          [-1.9971e+04,  3.2211e+01,  3.3211e+01]],
 
         [[-1.0000e+04,  4.0000e+00,  9.0000e+00],
          [-1.9991e+04,  1.2127e+01, -9.9880e+03],
          [-1.9988e+04, -9.9829e+03, -9.9839e+03],
          [-2.9983e+04, -1.9978e+04, -1.9979e+04]],
 
         [[-1.0000e+04,  1.5000e+01,  2.3000e+01],
          [-1.9977e+04, -9.9750e+03, -9.9740e+03],
          [-2.9974e+04, -1.9970e+04, -1.9970e+04],
          [-3.9969e+04, -2.9965e+04, -2.9966e+04]]]),
 'logits': None,
 'log_beta': tensor([[[-1.9974e+04,  3.3362e+01,  3.1627e+01],
          [-1.9985e+04,  2.6313e+01,  2.4313e+01],
          [-2.0000e+04,  1.5313e+01, -9.9970e+03],
          [-1.0000e+04, -2.0000e+00,  0.0000e+00]],
 
         [[-1.9999e+04,  1.0000e+01,  1.2000e+01],
          [-1.0000e+04, 

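**With constraints, the partition function computed by hand differs slightly (in the last float32 digits) from the one computed by the forward and backward algorithms. The forward and backward values agree with each other:**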

In [64]:
(
    constrained_log_partition_by_hand.tolist(),
    constrained_log_partition_forward.tolist(),
    constrained_log_partition_backward.tolist()
)

([33.52425765991211, 12.126928329467773, 23.000335693359375],
 [33.524261474609375, 12.126928329467773, 23.000335693359375],
 [33.524261474609375, 12.126928329467773, 23.000335693359375])

In [65]:
torch.testing.assert_close(
    constrained_log_partition_by_hand, constrained_log_partition_forward
)
torch.testing.assert_close(
    constrained_log_partition_by_hand, constrained_log_partition_backward
)

### Test that the loglik and weighted loglik match

In [66]:
# Score all 22 enumerated sequences in one batch, then check that the weighted
# log likelihood with unit weights equals the plain log likelihood.
# NOTE(review): `weights` has 3 entries while the batch has 22 sequences, so it is
# presumably a per-tag weight (num_tags == 3) rather than a per-example one.
# Confirm in crf_head.
all_forward_output = crf_head.forward(
    all_emissions, lengths=torch.LongTensor([4] * 16 + [2] * 4 + [1] * 2), tags=all_tags
)
weighted_log_likelihoods = crf_head.weighted_log_likelihood(
    log_emissions=all_emissions, 
    lengths=torch.LongTensor([4] * 16 + [2] * 4 + [1] * 2),
    tags=all_tags,
    log_alpha=all_forward_output["log_alpha"],
    weights=torch.tensor([1, 1, 1])
)
log_likelihoods = crf_head.log_likelihood(
    logits=all_forward_output["logits"],
    log_partition=all_forward_output["log_partition"]
)
torch.testing.assert_close(log_likelihoods, weighted_log_likelihoods)

#### Test that marginals (computed using log_alpha and log_beta) are correct:

In [67]:
# Point log marginals from the unconstrained forward (log_alpha) and backward
# (log_beta) passes computed in the partition-function cells.
log_marginals = crf_head.get_point_log_marginals(
    log_emissions=log_emissions,
    log_alpha=log_alpha,
    log_beta=log_beta,
    log_partition=log_partition_backward,
)
log_marginals  # first column is for the padding tag, so needs to be ignored

tensor([[[-3.0008e+04, -1.7010e-01, -1.8552e+00],
         [-4.0020e+04, -1.0285e-01, -2.3255e+00],
         [-4.0031e+04, -5.6278e-02, -2.9054e+00],
         [-3.0003e+04, -1.3594e+00, -2.9681e-01]],

        [[-3.0013e+04, -3.9770e+00, -1.8919e-02],
         [-3.0013e+04, -4.8987e+00, -7.4844e-03],
         [-3.9995e+04, -2.0098e+04, -2.0117e+04],
         [-5.9974e+04, -3.9767e+04, -3.9970e+04]],

        [[-2.0024e+04, -8.0003e+00, -3.3569e-04],
         [-4.4977e+04, -2.3182e+04, -2.2094e+04],
         [-5.9981e+04, -4.0082e+04, -4.0104e+04],
         [-7.9959e+04, -5.9751e+04, -5.9954e+04]]])

In [68]:
torch.exp(log_marginals)

tensor([[[0.0000e+00, 8.4358e-01, 1.5642e-01],
         [0.0000e+00, 9.0226e-01, 9.7739e-02],
         [0.0000e+00, 9.4528e-01, 5.4725e-02],
         [0.0000e+00, 2.5682e-01, 7.4318e-01]],

        [[0.0000e+00, 1.8741e-02, 9.8126e-01],
         [0.0000e+00, 7.4562e-03, 9.9254e-01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00]],

        [[0.0000e+00, 3.3535e-04, 9.9966e-01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00]]])

In [69]:
compute_marginal_by_hand(
    expected_scores,
    log_partition_by_hand=log_partition_by_hand,
    position=1,
    tag=1,
    example_index=0,
)

tensor(-0.1029)

In [70]:
# Compute every point log marginal by brute force. For each example, position and
# real tag (tag 0 is padding): logsumexp over the sequences with that tag at that
# position, minus the log partition.
real_tags = range(1, num_tags)
marginals_by_hand = []
for example_index in sorted(expected_scores["example_index"].unique()):
    seq_length = expected_scores.query("example_index == @example_index")[
        "length"
    ].tolist()[0]
    marginals_by_hand_for_example = torch.zeros(seq_length, len(real_tags))
    for position in range(seq_length):
        for j, tag in enumerate(real_tags):
            marginals_by_hand_for_example[position, j] = compute_marginal_by_hand(
                expected_scores,
                log_partition_by_hand=log_partition_by_hand,
                position=position,
                tag=tag,
                example_index=example_index,
            )
    marginals_by_hand.append(marginals_by_hand_for_example)

**We don't quite get equality, but we're close**:

In [71]:
marginals_by_hand

[tensor([[-0.1701, -1.8552],
         [-0.1029, -2.3255],
         [-0.0563, -2.9054],
         [-1.3594, -0.2968]]),
 tensor([[-3.9770, -0.0189],
         [-4.8987, -0.0075]]),
 tensor([[-8.0003e+00, -3.3569e-04]])]

In [72]:
for i in range(batch_size):
    torch.testing.assert_close(marginals_by_hand[i], log_marginals[i, : lengths[i], 1:])

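Now compute the marginals a third way. Constrain a single position to a single tag, run the constrained forward (and then backward) algorithm, and subtract the unconstrained log partition: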

In [73]:
marginals_by_constrained_forward = []
log_partition = crf_head._forward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=length_mask
)[1]

In [74]:
log_partition

tensor([33.5805, 17.0256, 23.0003])

In [75]:
# For each (position, real tag), constrain only that position of only this
# example to that tag and run the forward algorithm. The constrained log
# partition minus the unconstrained one is the log marginal of the tag there.
real_tags = range(1, num_tags)  # tag 0 is padding
for example_index in sorted(expected_scores["example_index"].unique()):
    seq_length = int(lengths[example_index])
    marginals_by_constrained_forward_for_example = torch.zeros(
        seq_length, len(real_tags)
    )
    for position in range(seq_length):
        for j, tag in enumerate(real_tags):
            # Independent empty lists: the other examples are unconstrained.
            single_tag_constraints = [[] for _ in range(batch_size)]
            single_tag_constraints[example_index] = [(position, position + 1, tag)]
            single_tag_mask = util.get_mask_for_tags(
                length_mask=length_mask,
                num_tags=num_tags,  # includes the padding tag
                constraints=single_tag_constraints,
            )
            # Index 1 of the result is the per-example log partition.
            single_tag_log_partition = crf_head._forward_algorithm(
                log_emissions=log_emissions, lengths=lengths, mask=single_tag_mask
            )[1]
            marginals_by_constrained_forward_for_example[position, j] = (
                single_tag_log_partition[example_index]
            )
    marginals_by_constrained_forward_for_example -= log_partition[example_index]
    marginals_by_constrained_forward.append(
        marginals_by_constrained_forward_for_example
    )

In [76]:
marginals_by_constrained_forward

[tensor([[-0.1701, -1.8552],
         [-0.1029, -2.3254],
         [-0.0563, -2.9054],
         [-1.3594, -0.2968]]),
 tensor([[-3.9770, -0.0189],
         [-4.8987, -0.0075]]),
 tensor([[-8.0003e+00, -3.3569e-04]])]

In [77]:
for i in range(batch_size):
    torch.testing.assert_close(
        marginals_by_hand[i], marginals_by_constrained_forward[i]
    )

In [78]:
# For each (position, real tag), constrain only that position of only this
# example to that tag and run the *backward* algorithm. The constrained log
# partition minus the unconstrained one is the log marginal of the tag there.
marginals_by_constrained_backward = []
log_partition = crf_head._backward_algorithm(
    log_emissions=log_emissions, lengths=lengths, mask=length_mask
)[1]
real_tags = range(1, num_tags)  # tag 0 is padding
for example_index in sorted(expected_scores["example_index"].unique()):
    seq_length = int(lengths[example_index])
    marginals_by_constrained_backward_for_example = torch.zeros(
        seq_length, len(real_tags)
    )
    for position in range(seq_length):
        for j, tag in enumerate(real_tags):
            # Independent empty lists: the other examples are unconstrained.
            single_tag_constraints = [[] for _ in range(batch_size)]
            single_tag_constraints[example_index] = [(position, position + 1, tag)]
            single_tag_mask = util.get_mask_for_tags(
                length_mask=length_mask,
                num_tags=num_tags,  # includes the padding tag
                constraints=single_tag_constraints,
            )
            # Index 1 of the result is the per-example log partition.
            single_tag_log_partition = crf_head._backward_algorithm(
                log_emissions=log_emissions, lengths=lengths, mask=single_tag_mask
            )[1]
            marginals_by_constrained_backward_for_example[position, j] = (
                single_tag_log_partition[example_index]
            )
    marginals_by_constrained_backward_for_example -= log_partition[example_index]
    marginals_by_constrained_backward.append(
        marginals_by_constrained_backward_for_example
    )

In [79]:
for i in range(batch_size):
    torch.testing.assert_close(
        marginals_by_hand[i], marginals_by_constrained_backward[i]
    )

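**The three ways of computing marginals (forward–backward, brute force, and single-tag constrained forward/backward) agree only up to small floating-point differences. They pass `torch.testing.assert_close` but not exact equality.**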

Uncomment the following cells to hit the assertion errors and see the discrepancy sizes.

In [80]:
# for i in range(batch_size):
#     np.testing.assert_equal(
#         np.array(marginals_by_hand[i]), np.array(log_marginals[i, : lengths[i], 1:])
#     )

In [81]:
# for i in range(batch_size):
#     np.testing.assert_equal(
#         np.array(marginals_by_constrained_forward[i]), np.array(marginals_by_hand[i])
#     )

In [82]:
# for i in range(batch_size):
#     np.testing.assert_equal(
#         np.array(marginals_by_constrained_backward[i]), np.array(marginals_by_hand[i])
#     )