Welcome to the official GitHub repository of RTFS-Net, accepted at ICLR 2024.
The 'cocktail party problem' highlights the difficulty machines face in isolating a single voice from overlapping conversations and background noise.
- T-domain Methods: These offer high-quality audio separation but suffer from high computational complexity and slow processing due to their extensive parameter count.
- TF-domain Methods: More efficient in computation but historically underperform compared to T-domain methods. They face three key challenges:
- Lack of independent modeling of time and frequency dimensions.
- Insufficient use of visual cues from multiple receptive fields for enhancing model performance.
- Poor handling of complex features, leading to loss of critical amplitude and phase information.
- Approach: Integrates audio and high-fidelity visual cues using a novel TF-domain method.
- Innovations:
- RTFS Blocks: Compress and independently model acoustic dimensions (time and frequency), minimizing information loss while creating a low-complexity subspace.
- Cross-dimensional Attention Fusion (CAF) Block: Efficiently fuses audio and visual information for enhanced voice separation while requiring only 1.3% of the computational complexity of the previous SOTA method.
- Spectral Source Separation ($S^3$) Block: Effectively extracts the target speaker's voice features using complex numbers (a minimal sketch follows this list).
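The sketch below illustrates the general idea of separating a source with a complex time-frequency mask, which is the spirit of the $S^3$ block. It is not the authors' implementation; the STFT settings and the random stand-in mask are our own assumptions.

```python
# Illustrative complex-mask separation in the spirit of the S^3 block.
# The STFT parameters and the random "mask" below are placeholders, not the paper's values.
import torch

def apply_complex_mask(mixture_tf: torch.Tensor, mask_tf: torch.Tensor) -> torch.Tensor:
    """Multiply a complex mask with the complex STFT of the mixture.

    Working directly on complex bins preserves both amplitude and phase,
    which is exactly what the TF-domain formulation is designed to retain.
    """
    return mixture_tf * mask_tf

window = torch.hann_window(512)
mix = torch.stft(torch.randn(1, 16000), n_fft=512, hop_length=128,
                 window=window, return_complex=True)   # (batch, freq, frames), complex
mask = torch.randn_like(mix)                           # stand-in for the network's predicted mask
target = apply_complex_mask(mix, mask)
wave = torch.istft(target, n_fft=512, hop_length=128, window=window)
print(mix.shape, wave.shape)
```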
Comparison of RTFS-Net with existing AVSS methods.
The dataflow of RTFS-Net is described below. See our paper for more details.
The red and blue solid lines signify the flow directions of auditory and visual features respectively. The snowflake indicates the weights are frozen and the component is not involved in training.
The core of RTFS-Net is the dual-path core architecture shown below, the RTFS Block. After compressing the data to a more efficient size, we first process the frequency dimension, then the time dimension, and then both dimensions in tandem using TF-domain self-attention to capture their inter-dependencies. Finally, we carefully restore the data to its original dimensions using our Temporal-Frequency Attention Reconstruction units.
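As a rough illustration of this dual-path processing (not the released code: the layer types and sizes are our simplifications, and the compression and reconstruction stages are omitted), the block can be thought of as follows:

```python
# Simplified dual-path sketch: RNN over frequency, RNN over time, then joint
# TF self-attention. Compression and TFAR reconstruction are intentionally omitted.
import torch
import torch.nn as nn

class DualPathSketch(nn.Module):
    def __init__(self, channels: int, hidden: int = 32, heads: int = 4):
        super().__init__()
        self.freq_rnn = nn.LSTM(channels, hidden, batch_first=True, bidirectional=True)
        self.freq_proj = nn.Linear(2 * hidden, channels)
        self.time_rnn = nn.LSTM(channels, hidden, batch_first=True, bidirectional=True)
        self.time_proj = nn.Linear(2 * hidden, channels)
        self.attn = nn.MultiheadAttention(channels, num_heads=heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, t, f = x.shape                                  # (batch, channels, time, freq)
        # 1. Frequency path: each time frame is a sequence over frequency bins.
        seq = x.permute(0, 2, 3, 1).reshape(b * t, f, c)
        seq = seq + self.freq_proj(self.freq_rnn(seq)[0])     # residual connection
        x = seq.reshape(b, t, f, c).permute(0, 3, 1, 2)
        # 2. Time path: each frequency bin is a sequence over time frames.
        seq = x.permute(0, 3, 2, 1).reshape(b * f, t, c)
        seq = seq + self.time_proj(self.time_rnn(seq)[0])     # residual connection
        x = seq.reshape(b, f, t, c).permute(0, 3, 2, 1)
        # 3. Joint self-attention over the flattened time-frequency grid.
        seq = x.permute(0, 2, 3, 1).reshape(b, t * f, c)
        seq = seq + self.attn(seq, seq, seq, need_weights=False)[0]
        return seq.reshape(b, t, f, c).permute(0, 3, 1, 2)

x = torch.randn(1, 32, 40, 33)          # toy (batch, channels, time frames, freq bins)
print(DualPathSketch(32)(x).shape)      # torch.Size([1, 32, 40, 33])
```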
Audio-visual speech separation (AVSS) methods aim to integrate different modalities to generate high-quality separated speech, thereby enhancing the performance of downstream tasks such as speech recognition. Most existing state-of-the-art (SOTA) models operate in the time domain. However, their overly simplistic approach to modeling acoustic features often necessitates larger and more computationally intensive models in order to achieve SOTA performance. In this paper, we present a novel time-frequency domain AVSS method: Recurrent Time-Frequency Separation Network (RTFS-Net), which applies its algorithms on the complex time-frequency bins yielded by the Short-Time Fourier Transform. We model and capture the time and frequency dimensions of the audio independently using a multi-layered RNN along each dimension. Furthermore, we introduce a unique attention-based fusion technique for the efficient integration of audio and visual information, and a new mask separation approach that takes advantage of the intrinsic spectral nature of the acoustic features for a clearer separation. RTFS-Net outperforms the previous SOTA method using only 10% of the parameters and 18% of the MACs. This is the first time-frequency domain AVSS method to outperform all contemporary time-domain counterparts.
Before you begin, ensure you have Miniconda installed on your system. Then, follow these steps:
- Clone this repository to your local machine:

```bash
git clone https://github.com/spkgyk/RTFS-Net.git
cd RTFS-Net
```
- Create a Conda environment with the required dependencies using the provided script:

```bash
source setup/conda_env.sh
```
Note: AVSS is a GPU-intensive task, so make sure you have access to a GPU for both installation and training.
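A quick way to confirm the environment can actually see a GPU (our suggestion, not part of the setup script):

```python
# Sanity check before training: verify PyTorch detects a CUDA device.
import torch

if torch.cuda.is_available():
    print(f"CUDA device found: {torch.cuda.get_device_name(0)}")
else:
    print("No CUDA device visible; training on CPU is not practical for AVSS.")
```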
Training the AVSS model is a straightforward process using the `train.py` script. You can customize the training by specifying a configuration file and, if necessary, a checkpoint for resuming training. Here's how to get started:
- Run the training script with a default configuration file:

```bash
python train.py --conf-dir config/lrs2_RTFSNet_4_layer.yaml
```
- If you encounter unexpected interruptions during training and wish to resume from a checkpoint, use the following command (replace the checkpoint path with your specific checkpoint):

```bash
python train.py --conf-dir config/lrs2_RTFSNet_4_layer.yaml \
  --checkpoint ../experiments/audio-visual/RTFS-Net/LRS2/4_layers/checkpoints/epoch=150-val_loss=-13.16.ckpt
```
Feel free to explore and fine-tune the configuration files in the `config` directory to suit your specific requirements and dataset.
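If you want to see which settings a configuration exposes before editing it, a quick inspection along these lines works (this assumes PyYAML is available in the environment; the top-level keys depend on the file itself):

```python
# Print the top-level sections of a training configuration before fine-tuning it.
import yaml

with open("config/lrs2_RTFSNet_4_layer.yaml") as f:
    conf = yaml.safe_load(f)

for section, options in conf.items():
    keys = list(options) if isinstance(options, dict) else options
    print(f"{section}: {keys}")
```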
To train all the models, please use the run script provided:
```bash
bash run.sh
```
Use the `test.py` script for evaluating your trained model on the test set. For examples, again see `run.sh`.