"""
Google Colab Setup for RE-SepFormer CI Project
Run this notebook in Google Colab to set up the environment and download data
"""

RE-SepFormer CI Project Workflow ===

1. DATA PREPARATION:
   - LibriSpeech train-clean-100 (100 hours)
   - WHAM! noise dataset (or synthetic noise)
   - Create multi-talker babble (1, 2, 4 speakers)
   - Mix at SNRs 1-10 dB

2. MODEL ARCHITECTURE (RE-SepFormer):
   - Uses SpeechBrain's ResourceEfficientSeparator
   - Key features: Non-overlapping chunks, memory averaging
   - 8M parameters (vs 26M for SepFormer)
   
3. TRAINING:
   - Loss: SI-SNR (Scale-Invariant SNR)
   - Optimizer: Adam with learning rate scheduling
   - 100 epochs (reduced to 5 for demo)
   
4. EVALUATION:
   - Metrics: SI-SDR improvement, PESQ, STOI
   - Test on different noise types and SNRs
   - Compare with Paper 1 results
   
5. CI TESTING:
   - Export processed audio files
   - 20 sentences per condition
   - Ready for behavioral testing

To run the complete pipeline:
1. Execute all cells in order
2. Monitor training progress
3. Review evaluation metrics
4. Export audio for CI testing

In [1]:
%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH

# Clone SpeechBrain repository
!git clone https://github.com/speechbrain/speechbrain/
%cd /content/speechbrain/templates/speech_recognition/

In [2]:
!python -m pip install git+https://github.com/speechbrain/speechbrain.git

Collecting git+https://github.com/speechbrain/speechbrain.git
  Cloning https://github.com/speechbrain/speechbrain.git to /tmp/pip-req-build-ijbiqpuv
  Running command git clone --filter=blob:none --quiet https://github.com/speechbrain/speechbrain.git /tmp/pip-req-build-ijbiqpuv
  Resolved https://github.com/speechbrain/speechbrain.git to commit c75ab5489431fd0a2a7d21160bc37677801cb506
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install speechbrain
!pip install pesq
!pip install pystoi
!pip install mir_eval
!pip install hyperpyyaml
!pip install soundfile librosa
!pip install pandas numpy scipy matplotlib seaborn tqdm

Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting pesq
  Downloading pesq-0.0.4.tar.gz (38 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pesq
  Building wheel for pesq (setup.py) ... [?25l[?25hdone
  Created wheel for pesq: filename=pesq-0.0.4-cp311-cp311-linux_x86_64.whl size=274952 sha256=5ba6189616eb10c5dcd89ed72136e49d147997cafd60227dd16b2799b5ca8fd0
  Stored in directory: /root/.cache/pip/wheels/ae/f1/23/2698d0bf31eec2b2aa50623b5d93b6206c49c7155d0e31345d
Successfully built pesq
Installing collected packages: pesq
Successfully installed pesq-0.0.4
Collecting pystoi
  Downloading pystoi-0.4.1-py2.py3-none-any.whl.metadata (4.0 kB)
Downloading pystoi-0.4.1-py2.py3-none-any.whl (8.2 kB)
Installing collected packages: pystoi
Successfully installed pystoi-0.4.1
Collecting mir_eval
  Downloading mir_eval-0.8.2-py3-none-any.whl.metadata (3.0 kB)
Downloading mir_eval-0.8.2-py3-none-any.whl (102 kB)
[2K   [90m━━━━━━━

In [4]:
import os

In [5]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.6.0+cu124
CUDA available: False


In [6]:
import urllib.request
import tarfile
from tqdm import tqdm
import shutil

In [8]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 54.8 gigabytes of available RAM

You are using a high-RAM runtime!


In [29]:
os.chdir("/content/speechbrain/")

In [30]:
os.getcwd()

'/content/speechbrain'

In [31]:
#%cd speechbrain/
!pip install -r requirements.txt

Ignoring SoundFile: markers 'sys_platform == "win32"' don't match your environment
Collecting black==24.3.0 (from -r lint-requirements.txt (line 1))
  Downloading black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.0/76.0 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting click==8.1.7 (from -r lint-requirements.txt (line 2))
  Downloading click-8.1.7-py3-none-any.whl.metadata (3.0 kB)
Collecting flake8==7.0.0 (from -r lint-requirements.txt (line 3))
  Downloading flake8-7.0.0-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting isort==5.13.2 (from -r lint-requirements.txt (line 4))
  Downloading isort-5.13.2-py3-none-any.whl.metadata (12 kB)
Collecting pycodestyle==2.11.0 (from -r lint-requirements.txt (line 5))
  Downloading pycodestyle-2.11.0-py2.py3-none-any.whl.metadata (4.5 kB)
Collecting pydoclint==0.4.1 (from -r lint-requirements.txt (line 6))
  Downloading pydoclint-

In [9]:
def download_with_progress(url, filename):
    """Download file with progress bar"""
    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
        urllib.request.urlretrieve(url, filename, reporthook=t.update_to)

In [10]:
!df -h /content

Filesystem      Size  Used Avail Use% Mounted on
overlay         226G   42G  184G  19% /


In [38]:
os.getcwd()

'/content/speechbrain'

In [13]:
os.chdir('/content')

In [15]:
os.getcwd()


'/content'

In [11]:
os.makedirs('data/raw', exist_ok=True)


In [20]:
# Create directory structure
directories = [
    'data/raw/librispeech',
    'data/raw/wham',
    'data/raw/ieee',
    'data/processed/train',
    'data/processed/val',
    'data/processed/test',
    'models/checkpoints',
    'scripts',
    'configs',
    'utils',
    'results/plots',
    'notebooks'
]


for dir_path in directories:
    os.makedirs(dir_path, exist_ok=True)

In [21]:
if not os.path.exists('data/raw/librispeech/train-clean-100'):
    print("Downloading LibriSpeech train-clean-100...")
    download_with_progress(
        'https://www.openslr.org/resources/12/train-clean-100.tar.gz',
        'train-clean-100.tar.gz'
    )

    print("Extracting...")
    with tarfile.open('train-clean-100.tar.gz', 'r:gz') as tar:
        tar.extractall('data/raw/librispeech/')

    # Clean up
    os.remove('train-clean-100.tar.gz')
    print("LibriSpeech downloaded and extracted!")
else:
    print("LibriSpeech already exists!")

Downloading LibriSpeech train-clean-100...


train-clean-100.tar.gz: 6.39GB [04:27, 23.9MB/s]                            


Extracting...
LibriSpeech downloaded and extracted!


In [23]:
!curl -L -o wham_noise.zip https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/wham_noise.zip
!unzip -q wham_noise.zip -d data/raw/wham/
!rm wham_noise.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 16.9G  100 16.9G    0     0  45.3M      0  0:06:22  0:06:22 --:--:-- 45.2M


In [24]:
import soundfile as sf
import torch
import torchaudio
from tqdm import tqdm
import random
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [25]:
class CIDataPreparer:
    def __init__(self, librispeech_path, wham_path, output_path, sample_rate=16000):
        """Initialize data preparer"""
        self.librispeech_path = Path(librispeech_path)
        self.wham_path = Path(wham_path)
        self.output_path = Path(output_path)
        self.sample_rate = sample_rate

        # Create output directories
        self.output_path.mkdir(parents=True, exist_ok=True)
        (self.output_path / 'train').mkdir(exist_ok=True)
        (self.output_path / 'val').mkdir(exist_ok=True)
        (self.output_path / 'test').mkdir(exist_ok=True)

    def load_audio(self, path, target_sr=16000):
        """Load audio file and resample if necessary"""
        try:
            audio, sr = torchaudio.load(path)
        except:
            # Fallback to soundfile
            audio, sr = sf.read(path)
            audio = torch.FloatTensor(audio).unsqueeze(0)

        # Convert to mono if stereo
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Resample if necessary
        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(sr, target_sr)
            audio = resampler(audio)

        return audio.numpy()[0], target_sr

    def create_multi_talker_babble(self, speech_files, num_talkers):
        """Create multi-talker babble from speech files"""
        # Randomly select speakers
        selected_files = random.sample(speech_files, min(num_talkers, len(speech_files)))

        # Load first speaker
        babble, sr = self.load_audio(selected_files[0])

        # Mix additional speakers
        for file in selected_files[1:]:
            audio, _ = self.load_audio(file)

            # Match lengths
            min_len = min(len(babble), len(audio))
            babble = babble[:min_len] + audio[:min_len]

        # Normalize
        babble = babble / (np.max(np.abs(babble)) + 1e-8)

        return babble

    def mix_audio_at_snr(self, speech, noise, snr_db):
        """Mix speech and noise at specified SNR"""
        # Match lengths
        min_len = min(len(speech), len(noise))
        speech = speech[:min_len]
        noise = noise[:min_len]

        # Calculate power
        speech_power = np.mean(speech ** 2) + 1e-8
        noise_power = np.mean(noise ** 2) + 1e-8

        # Calculate scaling factor
        snr_linear = 10 ** (snr_db / 10)
        noise_scale = np.sqrt(speech_power / (noise_power * snr_linear))

        # Scale noise and mix
        noise_scaled = noise * noise_scale
        mixture = speech + noise_scaled

        # Prevent clipping
        max_val = np.max(np.abs(mixture))
        if max_val > 0.95:
            scale = 0.95 / max_val
            mixture *= scale
            speech *= scale
            noise_scaled *= scale

        return mixture, speech, noise_scaled

    def prepare_training_data(self, num_mixtures=100):  # Reduced for Colab
        """Prepare training data (reduced for Colab demo)"""
        print("Preparing training data...")

        # Get LibriSpeech file list
        speech_files = list(self.librispeech_path.glob('**/*.flac'))
        print(f"Found {len(speech_files)} speech files")

        # Get WHAM! noise files
        noise_files = list(self.wham_path.glob('**/*.wav'))
        print(f"Found {len(noise_files)} noise files")

        if len(speech_files) == 0:
            raise ValueError("No speech files found!")

        # SNR range as per Paper 1
        snr_range = range(1, 11)  # 1 to 10 dB

        # Create mixtures for each noise type
        noise_types = ['wham', '1talker', '2talker', '4talker']

        metadata = []

        for noise_type in noise_types:
            print(f"\\nCreating {noise_type} mixtures...")

            for i in tqdm(range(min(num_mixtures, len(speech_files)))):
                try:
                    # Select random speech file
                    speech_file = random.choice(speech_files)
                    speech, sr = self.load_audio(speech_file)

                    # Skip if too short
                    if len(speech) < sr:  # Less than 1 second
                        continue

                    # Select random SNR
                    snr = random.choice(snr_range)

                    # Create or select noise
                    if noise_type == 'wham' and len(noise_files) > 0:
                        noise_file = random.choice(noise_files)
                        noise, _ = self.load_audio(noise_file)
                    else:
                        num_talkers = int(noise_type[0]) if noise_type != 'wham' else 1
                        # Exclude current speech file from babble
                        other_files = [f for f in speech_files if f != speech_file]
                        if len(other_files) >= num_talkers:
                            noise = self.create_multi_talker_babble(other_files, num_talkers)
                        else:
                            # Fallback to synthetic noise
                            noise = np.random.randn(len(speech)) * 0.1

                    # Mix at specified SNR
                    mixture, clean, noise_scaled = self.mix_audio_at_snr(speech, noise, snr)

                    # Save files
                    mixture_name = f'{noise_type}_{i:05d}_snr{snr}.wav'
                    clean_name = f'{noise_type}_{i:05d}_clean.wav'
                    noise_name = f'{noise_type}_{i:05d}_noise.wav'

                    sf.write(self.output_path / 'train' / mixture_name, mixture, sr)
                    sf.write(self.output_path / 'train' / clean_name, clean, sr)
                    sf.write(self.output_path / 'train' / noise_name, noise_scaled, sr)

                    # Store metadata
                    metadata.append({
                        'mixture_path': str(self.output_path / 'train' / mixture_name),
                        'clean_path': str(self.output_path / 'train' / clean_name),
                        'noise_path': str(self.output_path / 'train' / noise_name),
                        'noise_type': noise_type,
                        'snr': snr,
                        'speech_file': str(speech_file),
                        'duration': len(mixture) / sr
                    })
                except Exception as e:
                    print(f"Error processing file: {e}")
                    continue

        # Save metadata
        df = pd.DataFrame(metadata)
        df.to_csv(self.output_path / 'train_metadata.csv', index=False)
        print(f"\\nCreated {len(metadata)} training mixtures")
        print(f"Total duration: {df['duration'].sum() / 3600:.1f} hours")

        return df

In [16]:
def main():
    """Main function to run data preparation"""
    # Paths for Colab
    librispeech_path = "/content/data/raw/librispeech/train-clean-100"
    wham_path = ""
    output_path = "/content/data/processed"

    # Initialize preparer
    preparer = CIDataPreparer(librispeech_path, wham_path, output_path)

    # Prepare training data (reduced for Colab)
    train_df = preparer.prepare_training_data(num_mixtures=100)  # Reduced from 5590

    print("\\nData preparation complete!")

"""

with open('scripts/data_preparation.py', 'w') as f:
    f.write(data_prep_code)

print("Data preparation script saved!")"""

'\n\nwith open(\'scripts/data_preparation.py\', \'w\') as f:\n    f.write(data_prep_code)\n\nprint("Data preparation script saved!")'

In [None]:
with open('configs/resepformer_config.yaml', 'w') as f:
    f.write(config_yaml)

print("Configuration file saved!")

In [39]:
os.chdir("/content/models")
os.getcwd()

'/content/models'

In [40]:
from speechbrain.inference.separation import SepformerSeparation as separator

model = separator.from_hparams(source= "speechbrain/resepformer-wsj02mix", savedir="pretrained_models/resepformer-wsj02mix")


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Fetching from HuggingFace Hub 'speechbrain/resepformer-wsj02mix' if not cached


hyperparams.yaml: 0.00B [00:00, ?B/s]

DEBUG:speechbrain.utils.fetching:Fetch: Local file found, creating symlink '/root/.cache/huggingface/hub/models--speechbrain--resepformer-wsj02mix/snapshots/b8e127bf2b3585c95eebbe7b786e9d3f16675156/hyperparams.yaml' -> '/content/models/pretrained_models/resepformer-wsj02mix/hyperparams.yaml'
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in pretrained_models/resepformer-wsj02mix.
INFO:speechbrain.utils.fetching:Fetch encoder.ckpt: Fetching from HuggingFace Hub 'speechbrain/resepformer-wsj02mix' if not cached


encoder.ckpt:   0%|          | 0.00/9.07k [00:00<?, ?B/s]

DEBUG:speechbrain.utils.fetching:Fetch: Local file found, creating symlink '/root/.cache/huggingface/hub/models--speechbrain--resepformer-wsj02mix/snapshots/b8e127bf2b3585c95eebbe7b786e9d3f16675156/encoder.ckpt' -> '/content/models/pretrained_models/resepformer-wsj02mix/encoder.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["encoder"] = /content/models/pretrained_models/resepformer-wsj02mix/encoder.ckpt
INFO:speechbrain.utils.fetching:Fetch masknet.ckpt: Fetching from HuggingFace Hub 'speechbrain/resepformer-wsj02mix' if not cached


masknet.ckpt:   0%|          | 0.00/186M [00:00<?, ?B/s]

DEBUG:speechbrain.utils.fetching:Fetch: Local file found, creating symlink '/root/.cache/huggingface/hub/models--speechbrain--resepformer-wsj02mix/snapshots/b8e127bf2b3585c95eebbe7b786e9d3f16675156/masknet.ckpt' -> '/content/models/pretrained_models/resepformer-wsj02mix/masknet.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["masknet"] = /content/models/pretrained_models/resepformer-wsj02mix/masknet.ckpt
INFO:speechbrain.utils.fetching:Fetch decoder.ckpt: Fetching from HuggingFace Hub 'speechbrain/resepformer-wsj02mix' if not cached


decoder.ckpt:   0%|          | 0.00/9.00k [00:00<?, ?B/s]

DEBUG:speechbrain.utils.fetching:Fetch: Local file found, creating symlink '/root/.cache/huggingface/hub/models--speechbrain--resepformer-wsj02mix/snapshots/b8e127bf2b3585c95eebbe7b786e9d3f16675156/decoder.ckpt' -> '/content/models/pretrained_models/resepformer-wsj02mix/decoder.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["decoder"] = /content/models/pretrained_models/resepformer-wsj02mix/decoder.ckpt
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: encoder, masknet, decoder
DEBUG:speechbrain.utils.parameter_transfer:Redirecting (loading from local path): encoder -> /content/models/pretrained_models/resepformer-wsj02mix/encoder.ckpt
DEBUG:speechbrain.utils.parameter_transfer:Redirecting (loading from local path): masknet -> /content/models/pretrained_models/resepformer-wsj02mix/masknet.ckpt
DEBUG:speechbrain.utils.parameter_transfer:Redirecting (loading from local path): decoder -> /content/models/pretrained_models/resepformer-wsj

In [106]:
import torch
import torch.nn as nn
import speechbrain as sb
from speechbrain.lobes.models.dual_path import Encoder, Decoder
from speechbrain.lobes.models.resepformer import ResourceEfficientSeparator
# Naming change in newer speechbrain versions might be nnet.normalization
from speechbrain.nnet.normalization import LayerNorm
import torch.nn.functional as F

class RESepFormerModel(nn.Module):
    """
    RE-SepFormer implementation using SpeechBrain's ResourceEfficientSeparator
    Configured to match Paper 2's specifications
    """
    def __init__(self,
                 n_src=2,
                 n_filters=128,
                 kernel_size=16,
                 stride=8,
                 segment_size=150,
                 dropout=0.0,
                 bidirectional=True,
                 mem_type='av',
                 norm_type='gln'):
        super().__init__()

        self.n_src = n_src
        self.n_filters = n_filters

        # Encoder - converts waveform to latent representation
        # CORRECTED: Removed the invalid 'stride' argument.
        # The stride is automatically set to kernel_size // 2 (which is 8 here).
        self.encoder = Encoder(
            kernel_size=kernel_size,
            out_channels=n_filters,
        )

        # Normalization layer
        # CORRECTED: Initialized with `n_filters` instead of the undefined 'normalized_shape'.
        self.norm = LayerNorm(n_filters)

        # RE-SepFormer separator - the key component from Paper 2
        self.separator = ResourceEfficientSeparator(
            input_dim=n_filters,
            causal=False,
            num_spk=n_src,
            nonlinear='relu',
            layer=8,
            unit=512,
            segment_size=segment_size,
            dropout=dropout,
            # Removed mem_type as it might be causing the "Unsupported segment model class" error
            # mem_type=mem_type
        )

        # Decoder - converts back to waveform
        self.decoder = Decoder(
            in_channels=n_filters,
            out_channels=1,
            kernel_size=kernel_size,
            stride=stride,
            bias=False
        )

    def forward(self, mixture):
        """
        Args:
            mixture: [batch, time]
        Returns:
            separated: [batch, n_src, time]
        """
        # Encode the mixture: [B, time] -> [B, N, L]
        mixture_encoded = self.encoder(mixture)

        # CORRECTED: Transpose for LayerNorm, apply norm, and transpose back.
        # LayerNorm expects features in the last dimension.
        # [B, N, L] -> [B, L, N]
        mixture_transposed = mixture_encoded.transpose(1, 2)
        # [B, L, N] -> [B, L, N]
        mixture_normalized_transposed = self.norm(mixture_transposed)
        # [B, L, N] -> [B, N, L]
        mixture_normalized = mixture_normalized_transposed.transpose(1, 2)

        # Separate sources. The separator expects input of shape [B, N, L]
        # and returns masks of shape [B, n_src, N, L]
        masks = self.separator(mixture_normalized)

        # Apply masks to get separated encodings
        # mixture_encoded [B, N, L] -> unsqueeze to [1, B, N, L] for broadcasting
        # masks          [B, n_src, N, L] -> permute to [n_src, B, N, L]
        masked = mixture_encoded.unsqueeze(1) * masks.permute(1, 0, 2, 3)
        masked = masked.permute(1, 0, 2, 3) # [n_src, B, N, L]

        # Decode each source
        separated_sources = []
        for i in range(self.n_src):
            # [B, N, L] -> [B, 1, time]
            decoded = self.decoder(masked[i])
            separated_sources.append(decoded.squeeze(1))

        # Stack sources: list of [B, time] -> [B, n_src, time]
        separated = torch.stack(separated_sources, dim=1)

        return separated

In [107]:
model = RESepFormerModel(n_src=1)  # Single source for enhancement
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

Model created with 20,865 parameters




```

Data Preparation Following Paper 1


Create CI-specific Dataset
```



In [67]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random
from pathlib import Path

In [108]:
class CochlearImplantDataset(Dataset):
    """
    Dataset following Paper 1's methodology:
    - LibriSpeech + WHAM! noise
    - Multi-talker babble creation
    - SNRs from 1-10 dB
    """
    def __init__(self,
                 speech_dir,
                 noise_dir,
                 sample_rate=16000,
                 segment_duration=4.0,
                 snr_range=(1, 10),
                 noise_types=['wham', '1talker', '2talker', '4talker']):

        self.speech_dir = Path(speech_dir)
        self.noise_dir = Path(noise_dir)
        self.sample_rate = sample_rate
        self.segment_len = int(segment_duration * sample_rate)
        self.snr_range = snr_range
        self.noise_types = noise_types

        # Collect audio files
        self.speech_files = list(self.speech_dir.glob('**/*.flac'))
        self.noise_files = list(self.noise_dir.glob('**/*.wav'))

        print(f"Found {len(self.speech_files)} speech files")
        print(f"Found {len(self.noise_files)} noise files")

        # Create mixture combinations
        self._create_mixtures()

    def _create_mixtures(self):
        """Create mixture combinations as per Paper 1"""
        self.mixtures = []

        # Following Paper 1: 5590 mixtures per noise type
        mixtures_per_type = min(100, len(self.speech_files))  # Reduced for demo

        for noise_type in self.noise_types:
            for _ in range(mixtures_per_type):
                speech_file = random.choice(self.speech_files)
                snr = random.randint(*self.snr_range)

                if noise_type == 'wham' and self.noise_files:
                    noise_file = random.choice(self.noise_files)
                else:
                    # For multi-talker babble
                    noise_file = None

                self.mixtures.append({
                    'speech_file': speech_file,
                    'noise_file': noise_file,
                    'noise_type': noise_type,
                    'snr': snr
                })

    def __len__(self):
        return len(self.mixtures)

    def _load_audio(self, path):
        """Load and resample audio if needed"""
        import librosa
        audio, sr = librosa.load(path, sr=self.sample_rate, mono=True)
        return audio

    def _create_babble(self, num_talkers):
        """Create multi-talker babble"""
        babble_files = random.sample(self.speech_files, min(num_talkers, len(self.speech_files)))

        babble = None
        for file in babble_files:
            audio = self._load_audio(file)
            if babble is None:
                babble = audio
            else:
                # Mix at equal levels
                min_len = min(len(babble), len(audio))
                babble[:min_len] += audio[:min_len]

        # Normalize
        babble = babble / (np.max(np.abs(babble)) + 1e-8)
        return babble

    def _mix_at_snr(self, speech, noise, snr_db):
        """Mix speech and noise at specified SNR"""
        # Match lengths
        if len(noise) < len(speech):
            noise = np.tile(noise, (len(speech) // len(noise) + 1))
        noise = noise[:len(speech)]

        # Calculate gains
        speech_power = np.mean(speech ** 2)
        noise_power = np.mean(noise ** 2)

        noise_gain = np.sqrt(speech_power / (noise_power * (10 ** (snr_db / 10))))

        # Mix
        mixture = speech + noise * noise_gain

        # Prevent clipping
        max_val = np.max(np.abs(mixture))
        if max_val > 0.95:
            scale = 0.95 / max_val
            mixture *= scale
            speech *= scale

        return mixture, speech

    def __getitem__(self, idx):
        """Get a mixture sample"""
        item = self.mixtures[idx]

        # Load speech
        speech = self._load_audio(item['speech_file'])

        # Create or load noise
        if item['noise_type'] == 'wham' and item['noise_file']:
            noise = self._load_audio(item['noise_file'])
        else:
            # Create multi-talker babble
            num_talkers = int(item['noise_type'][0]) if item['noise_type'] != 'wham' else 1
            noise = self._create_babble(num_talkers)

        # Mix at specified SNR
        mixture, clean_speech = self._mix_at_snr(speech, noise, item['snr'])

        # Truncate or pad to fixed length
        if len(mixture) > self.segment_len:
            start = random.randint(0, len(mixture) - self.segment_len)
            mixture = mixture[start:start + self.segment_len]
            clean_speech = clean_speech[start:start + self.segment_len]
        else:
            pad_len = self.segment_len - len(mixture)
            mixture = np.pad(mixture, (0, pad_len))
            clean_speech = np.pad(clean_speech, (0, pad_len))

        return {
            'mixture': torch.FloatTensor(mixture),
            'clean': torch.FloatTensor(clean_speech),
            'snr': item['snr'],
            'noise_type': item['noise_type']
        }

In [109]:
os.chdir("/content")
os.getcwd()

'/content'

In [90]:
# Create dataset
dataset = CochlearImplantDataset(
    speech_dir='data/raw/librispeech/LibriSpeech/train-clean-100',
    noise_dir='data/raw/wham/wham_noise/tr',
    segment_duration=4.0  # 4 seconds for manageable training
)

print(f"Created dataset with {len(dataset)} mixtures")

Found 28539 speech files
Found 20000 noise files
Created dataset with 400 mixtures




```
Training Pipeline
```



In [110]:
###Loss Functions and Training Setup


def si_snr_loss(estimate, target, eps=1e-8):
    """
    Scale-Invariant Signal-to-Noise Ratio loss
    Used in both Paper 1 and Paper 2
    """
    # Ensure same length
    min_len = min(estimate.shape[-1], target.shape[-1])
    estimate = estimate[..., :min_len]
    target = target[..., :min_len]

    # Remove mean
    estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True)
    target = target - torch.mean(target, dim=-1, keepdim=True)

    # Compute SI-SNR
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    target_energy = torch.sum(target ** 2, dim=-1, keepdim=True) + eps

    projection = dot * target / target_energy
    noise = estimate - projection

    si_snr = 10 * torch.log10(
        torch.sum(projection ** 2, dim=-1) / (torch.sum(noise ** 2, dim=-1) + eps)
    )

    return -torch.mean(si_snr)  # Negative for minimization


In [111]:
class Trainer:
    """Training pipeline for RE-SepFormer"""
    def __init__(self, model, device='cuda'):
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=3, factor=0.5
        )

    def train_epoch(self, dataloader):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0

        for batch in tqdm(dataloader, desc="Training"):
            mixture = batch['mixture'].to(self.device)
            clean = batch['clean'].to(self.device)

            # Forward pass
            separated = self.model(mixture)

            # For single source enhancement, use first output
            if separated.shape[1] > 1:
                estimate = separated[:, 0, :]
            else:
                estimate = separated.squeeze(1)

            # Compute loss
            loss = si_snr_loss(estimate, clean)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            self.optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        return avg_loss

    def evaluate(self, dataloader):
        """Evaluate model"""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                mixture = batch['mixture'].to(self.device)
                clean = batch['clean'].to(self.device)

                # Forward pass
                separated = self.model(mixture)

                if separated.shape[1] > 1:
                    estimate = separated[:, 0, :]
                else:
                    estimate = separated.squeeze(1)

                # Compute loss
                loss = si_snr_loss(estimate, clean)
                total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        return avg_loss

Train the Model

In [112]:
# Create data loaders
train_loader = DataLoader(
    dataset,
    batch_size=4,  # Small batch size for Colab
    shuffle=True,
    num_workers=2
)

In [113]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RESepFormerModel(n_src=1)  # Single source for CI enhancement
trainer = Trainer(model, device)

In [115]:
# Training loop (reduced for demo)
num_epochs = 70  # Paper 1 used 100 epochs
best_loss = float('inf')

for epoch in range(num_epochs):
    print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

    # Train
    train_loss = trainer.train_epoch(train_loader)
    print(f"Training loss: {train_loss:.4f}")

    # Save checkpoint
    if train_loss < best_loss:
        best_loss = train_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'loss': train_loss
        }, 'best_model.pth')
        print("Saved best model!")


--- Epoch 1/70 ---


Training:   0%|          | 0/100 [00:02<?, ?it/s]


ValueError: Unsupported segment model class