In [None]:
"""
You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
"""
# If you're using Google Colab and not running locally, run this cell.

## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install text-unidecode
!pip install ipython

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[asr]

## Install TorchAudio
!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html

# Streaming End-to-End Speaker Diarization 




## Streaming Diarization Inference with Sortformer

As explained in the [Sortformer Diarization Training](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb) tutorial, Sortformer is trained with Sort-Loss to generate speaker segments in arrival-time order. When a diarization model emits speakers in such a pre-defined order, two permutation-matching problems disappear. First, we do not need to match permutations when training the diarization model together with multi-speaker automatic speech recognition (ASR) models. Second, we do not need to match permutations between inference windows when the model runs in streaming mode and processes a sequence of audio chunks.
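To make "arrival-time order" concrete, here is a minimal sketch in plain Python. It reorders frame-level speaker labels so that the speaker who talks first becomes index 0, the next new speaker becomes index 1, and so on. The helper `sort_speakers_by_arrival` is hypothetical and written only for illustration; it is not part of NeMo, and it assumes that a speaker's arrival time is simply the index of their first active frame.

```python
def sort_speakers_by_arrival(activity):
    """Reorder speaker rows so the earliest-arriving speaker comes first.

    `activity` is a list of per-speaker lists of 0/1 frame labels.
    Speakers who never speak are placed last, in their original order.
    """
    n_frames = len(activity[0]) if activity else 0

    def first_active_frame(row):
        # Index of the first active frame, or n_frames if the speaker is silent.
        return next((t for t, v in enumerate(row) if v), n_frames)

    order = sorted(range(len(activity)), key=lambda s: first_active_frame(activity[s]))
    return order, [activity[s] for s in order]


# Three speakers over six frames: speaker 2 talks first, then 0, then 1.
labels = [
    [0, 0, 1, 1, 0, 0],
    [0, 0, 0, 0, 1, 1],
    [1, 1, 0, 0, 0, 0],
]
order, sorted_labels = sort_speakers_by_arrival(labels)
print(order)  # [2, 0, 1]
```

Because every window is expected to follow the same ordering rule, speaker index 0 in one chunk refers to the same person as speaker index 0 in the next, as long as the model has enough context to know who arrived first.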

### Arrival-Order Speaker Cache

We propose the [Arrival-Order Speaker Cache (AOSC)](https://arxiv.org/pdf/2507.18446), which stores frame-level embeddings from the pre-encode NEST module. Unlike the [speaker-tracing buffer](https://arxiv.org/pdf/2006.02616) used in previous [EEND-based online systems](https://arxiv.org/pdf/2101.08473), AOSC organizes embeddings by speaker index in arrival-time order. Combined with Sortformer's built-in arrival-ordering mechanism, this automatically resolves between-chunk permutations.

<img src="images/cache_fifo_chunk.png" alt="Cache, FIFO and Chunk" style="width: 800px;"/>

### Speaker cache updates in streaming Sortformer 

Short-chunk processing often reduces accuracy due to limited context. To address this, we combine AOSC with a FIFO queue, enhancing context utilization and enabling less frequent AOSC updates (rather than per-chunk), improving robustness and efficiency. As shown in the above figure, the system includes a speaker cache, FIFO queue, and input buffer (holding the current chunk and future context). Frames pushed out of the queue are processed by the speaker cache update mechanism.

<img src="images/streaming_steps.png" alt="Streaming steps" style="width: 1200px;"/>

The AOSC update acts as a no-op function if the input sequence is shorter than the maximum cache length `spkcache_len`. For longer sequences, it compresses the input to `spkcache_len` frames by keeping only the highest-scoring embeddings based on the model's frame-level predictions. You can find more details about the speaker cache update mechanism in the [Streaming Sortformer](https://arxiv.org/pdf/2507.18446) paper.
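The compression rule described above can be sketched in a few lines of plain Python. This is a deliberately simplified illustration, not the NeMo implementation: `compress_speaker_cache` is a hypothetical helper, and it assumes a single importance score per frame, whereas the actual AOSC derives scores from per-speaker predictions and organizes frames by speaker index.

```python
def compress_speaker_cache(frames, scores, spkcache_len):
    """Keep at most `spkcache_len` frames, preferring the highest scores.

    frames: list of frame embeddings (any objects).
    scores: one importance score per frame (higher means more worth keeping).
    The surviving frames keep their original temporal order.
    """
    if len(frames) <= spkcache_len:
        # No-op: the sequence already fits in the cache.
        return list(frames)
    # Indices of the top-scoring frames (ties go to the earlier frame).
    top = sorted(range(len(frames)), key=lambda i: -scores[i])[:spkcache_len]
    return [frames[i] for i in sorted(top)]


frames = ["f0", "f1", "f2", "f3", "f4"]
scores = [0.9, 0.1, 0.8, 0.3, 0.7]
print(compress_speaker_cache(frames, scores, spkcache_len=3))  # ['f0', 'f2', 'f4']
print(compress_speaker_cache(frames, scores, spkcache_len=8))  # all five frames, unchanged
```

Keeping the selected frames in temporal order, rather than in score order, preserves the arrival-time information that Sortformer relies on.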

### A toy example for diarization with a streaming Sortformer model 

Download a toy example audio file (`an4_diarize_test.wav`) and its ground-truth label file (`an4_diarize_test.rttm`).

In [None]:
import os
import wget
ROOT = os.getcwd()
data_dir = os.path.join(ROOT,'data')
os.makedirs(data_dir, exist_ok=True)
an4_audio = os.path.join(data_dir,'an4_diarize_test.wav')
an4_rttm = os.path.join(data_dir,'an4_diarize_test.rttm')
if not os.path.exists(an4_audio):
    an4_audio_url = "https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav"
    an4_audio = wget.download(an4_audio_url, data_dir)
if not os.path.exists(an4_rttm):
    an4_rttm_url = "https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.rttm"
    an4_rttm = wget.download(an4_rttm_url, data_dir)

Let's plot the waveform and listen to the audio. You'll notice that there are two speakers in the audio clip.

In [None]:
import IPython.display
import matplotlib.pyplot as plt
import numpy as np
import librosa

sr = 16000
signal, sr = librosa.load(an4_audio,sr=sr) 

fig,ax = plt.subplots(1,1)
fig.set_figwidth(20)
fig.set_figheight(2)
plt.plot(np.arange(len(signal)),signal,'gray')
fig.suptitle('Reference merged an4 audio', fontsize=16)
plt.xlabel('time (secs)', fontsize=18)
ax.margins(x=0)
plt.ylabel('signal strength', fontsize=16)
# Relabel the x-axis ticks from sample indices to seconds
ticks, _ = plt.xticks()
plt.xticks(ticks, ticks / sr)

IPython.display.Audio(signal, rate=sr)

Now that we have downloaded an example audio file that contains multiple speakers, we can feed the audio clip into the Sortformer diarizer and get the speaker diarization results.

### Download Sortformer diarizer model

To download the streaming Sortformer diarizer from the [HuggingFace model card](https://huggingface.co/nvidia), you need to get a [HuggingFace Access Token](https://huggingface.co/docs/hub/en/security-tokens) and register it in your Python environment using the [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli).

If you are having trouble getting a HuggingFace token, you can download the Sortformer model from the [Streaming Sortformer HuggingFace model card](https://huggingface.co/nvidia) and specify the path to the downloaded `.nemo` file.

In [None]:
from nemo.collections.asr.models import SortformerEncLabelModel
from huggingface_hub import get_token as get_hf_token
import torch

hf_token = get_hf_token()
if hf_token is not None and hf_token.startswith("hf_"):
    # If you have logged into the HuggingFace Hub and have an access token
    diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
else:
    # Otherwise, download the ".nemo" file from https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 and specify its path.
    # Load onto the GPU when one is available; otherwise fall back to the CPU.
    map_location = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    diar_model = SortformerEncLabelModel.restore_from(restore_path="/path/to/diar_streaming_sortformer_4spk-v2.nemo", map_location=map_location, strict=False)
diar_model.eval()

### Diarization output display function

To visualize the streaming diarization output, we use the same diarization output display function as in the offline Sortformer diarizer tutorial.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_diarout(preds):
    preds_mat = preds.cpu().numpy().transpose()
    cmap_str, grid_color_p= 'viridis', 'gray'
    LW, FS = 0.4, 36

    yticklabels = ["spk0", "spk1", "spk2", "spk3"]
    yticks = np.arange(len(yticklabels))
    fig, axs = plt.subplots(1, 1, figsize=(30, 3)) 

    axs.imshow(preds_mat, cmap=cmap_str, interpolation='nearest') 
    axs.set_title('Predictions', fontsize=FS)
    axs.set_xticks(np.arange(-.5, preds_mat.shape[1], 1), minor=True)
    axs.set_yticks(yticks)
    axs.set_yticklabels(yticklabels)
    axs.set_xlabel("80 ms Frames", fontsize=FS)
    axs.grid(which='minor', color=grid_color_p, linestyle='-', linewidth=LW)

    plt.savefig('plot.png', dpi=300) 
    plt.show()

## Running Streaming Sortformer diarizer

### Parameter configuration for Streaming Sortformer 

Now it's time to set up the model with streaming parameters (all measured in 80 ms frames).

* `chunk_len`: The number of frames in a processing chunk.
* `chunk_right_context`: The right context length.
* `fifo_len`: The number of previous frames attached before the chunk, from the FIFO queue.
* `spkcache_update_period`: The number of frames extracted from the FIFO queue to update the speaker cache.
* `spkcache_len`: The total number of frames in the speaker cache.

Note that the input buffer latency is determined by `chunk_len` + `chunk_right_context`.

The following restrictions apply to the Streaming Sortformer parameters:

* All streaming parameters must be non-negative integers.
* `chunk_len` and `spkcache_update_period` must be greater than 0.
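
As a quick sanity check, the latency relation and the restrictions above can be written out in plain Python. The helpers `input_buffer_latency` and `check_params` are hypothetical names used only for this sketch (NeMo performs its own validation); the parameter values are the ones used for the 1.04 s latency setup in this tutorial.

```python
FRAME_SEC = 0.08  # one diarization frame is 80 ms


def input_buffer_latency(chunk_len, chunk_right_context, frame_sec=FRAME_SEC):
    """Latency in seconds of the input buffer (chunk plus right context)."""
    return (chunk_len + chunk_right_context) * frame_sec


def check_params(**params):
    """Raise ValueError if the restrictions listed above are violated."""
    for name, value in params.items():
        if not isinstance(value, int) or isinstance(value, bool) or value < 0:
            raise ValueError(f"{name} must be a non-negative integer, got {value!r}")
    for name in ("chunk_len", "spkcache_update_period"):
        if params.get(name, 1) <= 0:
            raise ValueError(f"{name} must be greater than 0")


params = dict(chunk_len=6, chunk_right_context=7, fifo_len=188,
              spkcache_update_period=144, spkcache_len=188)
check_params(**params)
latency = input_buffer_latency(params["chunk_len"], params["chunk_right_context"])
print(f"{latency:.2f} s")  # 1.04 s
```

With `chunk_len = 6` and `chunk_right_context = 7`, the buffer spans 13 frames of 80 ms each, which gives the 1.04 s latency.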

In [None]:
import time
import math
import torch
import torch.amp
from tqdm import tqdm 

# If cuda is available, assign the model to cuda
if torch.cuda.is_available():
    diar_model.to(torch.device("cuda"))

# Mixed-precision context, reused in the streaming loop
autocast = torch.amp.autocast(diar_model.device.type, enabled=True)

# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
diar_model.sortformer_modules.chunk_len = 6
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.chunk_right_context = 7
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False

# Validate that the streaming parameters are set correctly
diar_model.sortformer_modules._check_streaming_parameters()

### Feature extraction from an audio file  


Now, set up the input audio signal and convert it to log-mel features. Note that we are simulating the streaming scenario. In real life, we won't be able to access the whole utterance at once.

In [None]:
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)

### Run a streaming loop for streaming diarization

The following variables are needed to run the simulated streaming speaker diarization session. 

In [None]:
batch_size = 1
processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)

streaming_state = diar_model.sortformer_modules.init_streaming_state(
    batch_size=batch_size,
    async_streaming=True,
    device=diar_model.device,
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)

streaming_loader = diar_model.sortformer_modules.streaming_feat_loader(
    feat_seq=processed_signal,
    feat_seq_length=processed_signal_length,
    feat_seq_offset=processed_signal_offset,
)

num_chunks = math.ceil(
    processed_signal.shape[2] / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor)
)

plot_preds = torch.zeros(
    (batch_size, num_chunks * diar_model.sortformer_modules.chunk_len, diar_model.sortformer_modules.n_spk), device=diar_model.device
)

Now we are ready to run the streaming diarization steps. Check the output at the end of the streaming `for` loop. Note that this loop only simulates streaming input. In a real-life setting, `chunk_feat_seq_t` would instead be computed from a real-time audio stream, such as a microphone input.
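
To show what the simulated loader iterates over, the sketch below computes chunk boundaries in feature-frame units, using the same `ceil(num_feat_frames / (chunk_len * subsampling_factor))` chunk count as the previous cell. `iter_chunks` is a hypothetical helper, not NeMo's `streaming_feat_loader`: it ignores left context and offsets, and it assumes a subsampling factor of 8 (that is, 80 ms diarization frames built from 10 ms feature frames).

```python
import math


def iter_chunks(num_feat_frames, chunk_len, right_context, subsampling_factor=8):
    """Yield (index, start, end, end_with_context) in feature-frame units.

    Each chunk covers `chunk_len * subsampling_factor` feature frames; up to
    `right_context * subsampling_factor` future frames are appended as lookahead.
    """
    step = chunk_len * subsampling_factor
    lookahead = right_context * subsampling_factor
    num_chunks = math.ceil(num_feat_frames / step)
    for i in range(num_chunks):
        start = i * step
        end = min(start + step, num_feat_frames)
        yield i, start, end, min(end + lookahead, num_feat_frames)


for chunk in iter_chunks(num_feat_frames=120, chunk_len=6, right_context=7):
    print(chunk)
# (0, 0, 48, 104)
# (1, 48, 96, 120)
# (2, 96, 120, 120)
```

The last chunks have little or no lookahead left, which is why the final window in the example ends exactly at the end of the feature sequence.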

In [None]:
# To simulate real-time streaming, each step is padded with sleep so that it lasts at least one chunk duration
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
print(f"Chunk duration: {chunk_duration_seconds} seconds")

for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in tqdm(
    streaming_loader,
    total=num_chunks,
    desc="Streaming Steps",
    disable=False,
):
    loop_start_time = time.time()
    with torch.inference_mode():
        with autocast:
            streaming_state, total_preds = diar_model.forward_streaming_step(
                processed_signal=chunk_feat_seq_t,
                processed_signal_length=feat_lengths,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )
    # Plot the predictions accumulated so far
    plot_preds[:, :total_preds.shape[1]] = total_preds
    plot_diarout(plot_preds[0, :])
    # Sleep for whatever remains of the chunk duration after inference and plotting
    elapsed = time.time() - loop_start_time
    time.sleep(max(0.0, chunk_duration_seconds - elapsed))