This only works for tasks that are not permutation-invariant. In permutation-invariant tasks (such as speaker diarization), the class order is not guaranteed to be the same from one sliding-window chunk to the next, so the windows cannot be aggregated directly.
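
To make the mismatch concrete, here is a minimal sketch using made-up NumPy arrays (not pyannote output). The two windows describe the same frames and the same two speakers, but with the class columns swapped. The names `window_a` and `window_b` are purely illustrative.

```python
import numpy as np

# Two overlapping windows covering the same 4 frames.
# Columns are "speaker slots"; window_b has the same activity with the slots swapped.
window_a = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
window_b = window_a[:, ::-1]

# Naive aggregation: average the overlapping windows column by column.
naive = (window_a + window_b) / 2

# If the permutation were known, aligning the columns first would keep the output sharp.
aligned = (window_a + window_b[:, ::-1]) / 2

print(naive)    # every entry is 0.5: the speaker identity is lost
print(aligned)  # identical to window_a
```

In practice the permutation between chunks is unknown, which is why a plain sliding-window average is not usable for such tasks.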

In [182]:
from pyannote.audio import Model
from pyannote.audio.models.segmentation import PyanNet
from pyannote.audio.tasks.segmentation.multilabel import MultiLabelSegmentation
from pyannote.database import registry, FileFinder
from pathlib import Path

registry.load_database("../_data/sample/sample.yaml")
protocol = registry.get_protocol("Sample.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()})

# authtoken = Path("../auth_token.txt").read_text().strip()
task = MultiLabelSegmentation(protocol, "vad")
model = PyanNet()

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../../.cache/torch/pyannote/models--pyannote--segmentation/snapshots/2ffce0501d0aecad81b43a06d538186e292d0070/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.0. Bad things might happen unless you revert torch to 1.x.
Specifications(problem=<Problem.MULTI_LABEL_CLASSIFICATION: 2>, resolution=<Resolution.FRAME: 1>, duration=5.0, min_duration=None, warm_up=(0.0, 0.0), classes=['speaker#1', 'speaker#2', 'speaker#3'], powerset_max_classes=None, permutation_invariant=True)
Specifications(problem=<Problem.MONO_LABEL_CLASSIFICATION: 1>, resolution=<Resolution.FRAME: 1>, duration=5, min_duration=5, warm_up=(0.0, 0.0), classes=['nothing', 'ov'], powerset_max_classes=None, permutation_invariant=False)


In [157]:
from pyannote.audio import Inference

f = next(protocol.train())
inf = Inference(model, batch_size=64, step=0.5, duration=model.specifications.duration)
output = inf(f)
output.data.shape

(51, 293, 3)

In [172]:
from pyannote.core import Annotation, Timeline, Segment, SlidingWindow
import torch
import itertools

reference: Annotation = f["annotation"]
uem: Timeline = f["annotated"]

# full file length
support = Segment(0.0, uem.extent().end)

# if not permutation invariant, the output is already (n_frames, n_classes)-shaped
if len(output.data.shape) == 2:
    pred = output.data
# if permutation invariant, no aggregation was done: the data is (n_chunks, n_frames_per_chunk, n_classes)-shaped
# we don't know how to match classes across chunks (windows), so we can't aggregate
else:
    raise NotImplementedError("cannot aggregate the output of a permutation-invariant model: classes may not match across chunks")

# get the targets tensor
ref_t = reference.discretize(
    support=support,
    resolution=model.example_output.frames,
).data
ref_t = torch.from_numpy(ref_t).long()

# get the uem boolean tensor
uem_t = torch.from_numpy(uem.support().to_annotation().rename_labels(generator=itertools.cycle(["uem"])).discretize(
    support=support,
    resolution=model.example_output.frames,
    labels=["uem"],
).data).bool()

In [173]:
ref_t

tensor([[0, 0],
        [0, 0],
        [0, 0],
        ...,
        [1, 0],
        [1, 0],
        [1, 0]])

In [174]:
uem_t

tensor([[False],
        [False],
        [False],
        ...,
        [ True],
        [ True],
        [ True]])

In [175]:
pred.round(decimals=1)

array([[0. , 0. , 0. ],
       [0. , 0. , 0. ],
       [0. , 0. , 0. ],
       ...,
       [0.8, 0.8, 0. ],
       [0. , 0. , 0. ],
       [0. , 0. , 0. ]], dtype=float32)

In [176]:
(pred.shape, ref_t.shape, uem_t.shape)

((1759, 3), torch.Size([1758, 2]), torch.Size([1758, 1]))

As you can see, the frame counts can still differ by one (1759 for `pred` versus 1758 for `ref_t` and `uem_t`), due to floating-point inaccuracies.

Also, even if we cut the last frame of `pred`, there is still one row of `[0., 0., 0.]`, which comes from NaNs.
This is because no data was found for these frames, and empty frames are NaNs (the aggregation replaces NaNs with zeros so that the output is at least workable).
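
As a rough illustration of where such rows can come from, the sketch below runs a toy overlap-add average in NumPy. It is not pyannote's actual aggregation code; the arrays `total` and `count` are hypothetical. Frames that no window covers end up as 0/0, i.e. NaN, and are then replaced with zeros.

```python
import numpy as np

n_frames, n_classes = 6, 2
total = np.zeros((n_frames, n_classes))  # sum of window scores per frame
count = np.zeros((n_frames, 1))          # number of windows covering each frame

# A single window covers frames 0..3 only; frames 4 and 5 never receive data.
total[0:4] += 0.8
count[0:4] += 1

# Averaging divides by the coverage count, so uncovered frames become 0/0 = NaN.
with np.errstate(invalid="ignore", divide="ignore"):
    average = total / count

# Replace NaNs with zeros so that the result stays usable downstream.
filled = np.nan_to_num(average, nan=0.0)
print(filled)
```

The zero rows at the end of `filled` are therefore not real predictions, only placeholders for frames without data.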

So here we might actually want to use `(1757, x)`-shaped tensors, i.e. drop the extra frame of `pred` as well as the trailing empty frame.
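
One possible way to do this is sketched below with a hypothetical helper, `trim_to_common_length`, which is not part of pyannote. It uses placeholder arrays with the same shapes as `pred`, `ref_t` and `uem_t` above. Since it only relies on `.shape[0]` and slicing, it should also work on torch tensors.

```python
import numpy as np

def trim_to_common_length(*arrays, drop_last=0):
    """Trim all arrays along axis 0 to the shortest length, minus `drop_last` frames."""
    n = min(a.shape[0] for a in arrays) - drop_last
    if n <= 0:
        raise ValueError("nothing left after trimming")
    return tuple(a[:n] for a in arrays)

# Placeholder arrays with the shapes observed above (the contents are dummy values).
pred_demo = np.zeros((1759, 3), dtype=np.float32)
ref_demo = np.zeros((1758, 2), dtype=np.int64)
uem_demo = np.zeros((1758, 1), dtype=bool)

# drop_last=1 also removes the trailing frame that only contains NaN-derived zeros.
pred_c, ref_c, uem_c = trim_to_common_length(pred_demo, ref_demo, uem_demo, drop_last=1)
print(pred_c.shape, ref_c.shape, uem_c.shape)  # (1757, 3) (1757, 2) (1757, 1)
```

How many trailing frames to drop is an assumption here (`drop_last=1`); it would need to be checked per file.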