This notebook demonstrates how to run audio-visual speech separation (AVSS) inference using our pretrained RTFS-Net model.

Project code and description: https://github.com/avsshw/AVSS

# Installation

First, clone the repository:

In [None]:
!git clone https://github.com/avsshw/AVSS.git

In [None]:
%cd /content/AVSS

# Inference

We trained our model on a dataset with the following layout:


```bash
NameOfTheDirectoryWithUtterances/
├── audio/
│   ├── train/
│   │   ├── mix/          # Mixed audio utterances (2-speaker mixtures)
│   │   ├── s1/          
│   │   └── s2/          
│   └── val/
│       ├── mix/
│       ├── s1/
│       └── s2/
└── mouths/
    ├── train/            
    │   ├── SpeakerID1.npz
    │   ├── SpeakerID2.npz
    │   └── ...
    └── val/
        ├── SpeakerID1.npz
        ├── SpeakerID2.npz
        └── ...
```

To evaluate the model on this dataset (optional), run the following commands:

1. Download and unpack the pretrained model:
```bash
gdown https://drive.google.com/uc?id=1t7FFsG3hPcgUYuitekSMpggYLvzV6SXW
unzip /content/AVSS/rtfs.zip
```

2. Run inference, assuming the dla_dataset directory is in the root of the project (to use another location, pass the Hydra option datasets.test.data_dir=your/dataset):
```bash
python3 inference.py inferencer.from_pretrained=rtfs_improved/model_best.pth datasets=val_inference
```

In [None]:
import os
os.environ['MPLBACKEND'] = 'Agg'
!uv run python3 inference.py inferencer.from_pretrained=rtfs_improved/model_best.pth datasets=val_inference

## Custom inference

If you wish to run inference on your custom dataset, our model expects data in the following structure:


```bash
NameOfTheDirectoryWithUtterances
├── audio
│   ├── mix
│   │   ├── FirstSpeakerID1_SecondSpeakerID1.wav # also may be flac or mp3
│   │   ├── FirstSpeakerID2_SecondSpeakerID2.wav
│   │   .
│   │   .
│   │   .
│   │   └── FirstSpeakerIDn_SecondSpeakerIDn.wav
│   ├── s1 # ground truth for the speaker s1, may not be given
│   │   ├── FirstSpeakerID1_SecondSpeakerID1.wav # also may be flac or mp3
│   │   ├── FirstSpeakerID2_SecondSpeakerID2.wav
│   │   .
│   │   .
│   │   .
│   │   └── FirstSpeakerIDn_SecondSpeakerIDn.wav
│   └── s2 # ground truth for the speaker s2, may not be given
│       ├── FirstSpeakerID1_SecondSpeakerID1.wav # also may be flac or mp3
│       ├── FirstSpeakerID2_SecondSpeakerID2.wav
│       .
│       .
│       .
│       └── FirstSpeakerIDn_SecondSpeakerIDn.wav
└── mouths # contains video information for all speakers
    ├── FirstOrSecondSpeakerID1.npz # npz mouth-crop
    ├── FirstOrSecondSpeakerID2.npz
    .
    .
    .
    └── FirstOrSecondSpeakerIDn.npz
```
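Before running inference on a custom dataset, it can help to confirm that the directory matches the layout above. The following is a minimal sketch, not part of the repository: `check_dataset` is a hypothetical helper, and it assumes that the two speaker IDs in a mixture file name are joined by a single underscore and contain no underscores themselves.

```python
from pathlib import Path

# Audio formats mentioned in the layout above.
AUDIO_EXTS = {".wav", ".flac", ".mp3"}


def check_dataset(root):
    """Return a list of human-readable problems found under `root`."""
    root = Path(root)
    mix_dir = root / "audio" / "mix"
    mouths_dir = root / "mouths"

    problems = [f"missing directory: {d}" for d in (mix_dir, mouths_dir) if not d.is_dir()]
    if problems:
        return problems

    for mix in sorted(mix_dir.iterdir()):
        if mix.suffix.lower() not in AUDIO_EXTS:
            continue  # ignore unrelated files
        # Assumption: file names look like <FirstSpeakerID>_<SecondSpeakerID>.<ext>
        speakers = mix.stem.split("_")
        if len(speakers) != 2:
            problems.append(f"{mix.name}: expected <id1>_<id2> file name")
            continue
        # Every speaker in the mixture needs a mouth-crop archive.
        for spk in speakers:
            if not (mouths_dir / f"{spk}.npz").is_file():
                problems.append(f"{mix.name}: no mouth crop for speaker {spk}")
        # s1/ and s2/ are optional, but if present they should mirror mix/.
        for gt in ("s1", "s2"):
            gt_dir = root / "audio" / gt
            if gt_dir.is_dir() and not (gt_dir / mix.name).is_file():
                problems.append(f"{gt}/{mix.name}: ground truth missing")
    return problems
```

Calling `check_dataset("inference_dataset")` returns an empty list when nothing looks wrong; otherwise each entry names the file that breaks the expected structure.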

**We provide a small example dataset with ground truths. To run inference:**

1. Download the data (to use your own dataset, pass a link to a .zip archive on Yandex Disk here instead):

In [None]:
!uv run python3 scripts/download_inference_data.py --link https://disk.yandex.ru/d/h2t8ItWMdne2ZA --download_location .

Download complete.
Extracting ./inference_dataset.zip...
Extraction complete!


For our example dataset, you can also run
```bash
sh scripts/inference.sh
```

The full command is given above so that you can adapt it to your own data.

2. Download the pretrained model:

In [None]:
!gdown https://drive.google.com/uc?id=1l72LuBr_CQxaut6WUbyyPFJIRIbH8-68
!unzip /content/AVSS/rtfs.zip

Downloading...
From (original): https://drive.google.com/uc?id=1t7FFsG3hPcgUYuitekSMpggYLvzV6SXW
From (redirected): https://drive.google.com/uc?id=1t7FFsG3hPcgUYuitekSMpggYLvzV6SXW&confirm=t&uuid=0b4a367a-a576-459f-88ea-c43522f67ab7
To: /content/AVSS-main/rtfs.zip
100% 311M/311M [00:02<00:00, 125MB/s]


3. Set the environment and run inference:

Predictions are saved to data/saved/inference_custom_dir/test. This step works even if your dataset has no ground truth (our example dataset includes it).

In [None]:
import os
os.environ['MPLBACKEND'] = 'Agg'
!sh scripts/inference.sh

RTFSNet(
  (encoder): STFT()
  (decoder): ISTFT()
  (blocks): ModuleList(
    (0): RTFSBlock(
      (input_proj): Linear(in_features=2, out_features=64, bias=True)
      (freq_rnn): FrequencyRNN(
        (rnn): LSTM(64, 64, num_layers=2, batch_first=True, bidirectional=True)
        (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (time_rnn): TimeRNN(
        (rnn): LSTM(128, 64, num_layers=2, batch_first=True, bidirectional=True)
        (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (tf_interaction): TFInteraction(
        (freq_conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), groups=128)
        (time_conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), groups=128)
        (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (ffn): Sequential(
          (0): Linear(in_features=128, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_f

4. Evaluate metrics (optional):

If your dataset includes ground-truth clean sources (s1/, s2/), you can compute metrics separately:

In [None]:
!uv run python3 calc_metrics.py \
--predictions_dir data/saved/inference_custom_dir/test \
--ground_truth_dir inference_dataset/audio \
--mixture_dir inference_dataset/audio/mix

100% 3/3 [00:02<00:00,  1.30it/s]
SI-SNRi        : 12.0220
SDRi           : 12.4403
PESQ           : 2.0737
STOI           : 0.8946


This script assumes:

- Predictions are .wav files with the same names as the mixtures.
- Ground truth is split into s1/ and s2/ subdirectories.
- Mixtures are in mixture_dir.
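
For reference, the sketch below shows how SI-SNRi is commonly defined: the scale-invariant SNR of the prediction against the clean source, minus that of the unprocessed mixture against the same source. This is why the script needs the mixtures as well as the ground truth. It is an illustration of the standard formula using NumPy, not the code in calc_metrics.py, which may differ in details such as how predictions are matched to s1 and s2; the function names are hypothetical.

```python
import numpy as np


def si_snr(estimate, target, eps=1e-8):
    """Scale-invariant signal-to-noise ratio in dB for two 1-D signals."""
    estimate = np.asarray(estimate, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    # Remove the DC offset so the metric ignores constant shifts.
    estimate = estimate - estimate.mean()
    target = target - target.mean()
    # Project the estimate onto the target to isolate the "signal" part.
    scale = np.dot(estimate, target) / (np.dot(target, target) + eps)
    projection = scale * target
    noise = estimate - projection
    ratio = np.dot(projection, projection) / (np.dot(noise, noise) + eps)
    return 10.0 * np.log10(ratio + eps)


def si_snr_improvement(estimate, mixture, target):
    """SI-SNRi: gain of the estimate over the unprocessed mixture, in dB."""
    return si_snr(estimate, target) - si_snr(mixture, target)
```

A positive value means the separated signal is closer to the clean source than the original mixture was.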