# Task
Separate speakers from a user-provided mixed audio file, transcribe their speech using ASR, and diarize the transcriptions to generate a `time | speaker | text` output table, leveraging a ConvTasNet model with potential transfer learning, and also providing the separated audio tracks.

## Setup Environment and Install Dependencies

### Subtask:
Install all necessary Python packages and configure the Colab environment, including setting up GPU if available.


**Reasoning**:
The first instruction is to install all necessary Python packages using pip. This should be done in a code block.



In [None]:
!pip install torch torchaudio numpy pandas tqdm matplotlib faster-whisper resemblyzer scikit-learn torchcodec



**Reasoning**:
The next step is to verify the GPU availability in the Colab environment as per the subtask instructions.



In [None]:
import torch

if torch.cuda.is_available():
    print("GPU is available.")
    print("CUDA Device Name:", torch.cuda.get_device_name(0))
    print("CUDA Device Count:", torch.cuda.device_count())
else:
    print("GPU is not available. Please ensure a GPU runtime is selected in Colab.")


GPU is available.
CUDA Device Name: NVIDIA A100-SXM4-80GB
CUDA Device Count: 1


In [None]:
import os

# Create a directory for LibriSpeech data
!mkdir -p /content/LibriSpeech

# Download LibriSpeech train-clean-100 (for training)
!wget -nc https://www.openslr.org/resources/12/train-clean-100.tar.gz -P /content/

# Download LibriSpeech dev-clean (for validation)
!wget -nc https://www.openslr.org/resources/12/dev-clean.tar.gz -P /content/

# Download LibriSpeech test-clean (for testing)
!wget -nc https://www.openslr.org/resources/12/test-clean.tar.gz -P /content/

print("LibriSpeech datasets downloaded.")

--2025-12-08 12:00:42--  https://www.openslr.org/resources/12/train-clean-100.tar.gz
Resolving www.openslr.org (www.openslr.org)... 136.243.171.4
Connecting to www.openslr.org (www.openslr.org)|136.243.171.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6387309499 (5.9G) [application/x-gzip]
Saving to: ‘/content/train-clean-100.tar.gz’


2025-12-08 12:06:33 (17.4 MB/s) - ‘/content/train-clean-100.tar.gz’ saved [6387309499/6387309499]

--2025-12-08 12:06:33--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 136.243.171.4
Connecting to www.openslr.org (www.openslr.org)|136.243.171.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘/content/dev-clean.tar.gz’


2025-12-08 12:11:47 (1.03 MB/s) - ‘/content/dev-clean.tar.gz’ saved [337926286/337926286]

--2025-12-08 12:11:47--  https://www.openslr.org/resources/12/test-clean.tar.gz
Resolvi

**Reasoning**:
The next step is to extract the downloaded `.tar.gz` archives into the `/content/LibriSpeech` directory to make the audio files accessible for the `MixtureDataset`.



In [None]:
import os

# Extract LibriSpeech train-clean-100
!tar -xzf /content/train-clean-100.tar.gz -C /content/LibriSpeech/

# Extract LibriSpeech dev-clean
!tar -xzf /content/dev-clean.tar.gz -C /content/LibriSpeech/

# Extract LibriSpeech test-clean
!tar -xzf /content/test-clean.tar.gz -C /content/LibriSpeech/

print("LibriSpeech datasets extracted.")

LibriSpeech datasets extracted.


In [None]:
NUM_SPEAKERS = 4  # Number of speakers to separate
SAMPLE_RATE = 8000 # Sample rate of the audio
DURATION = 2.0    # Duration of audio snippets in seconds

N_FFT = 256       # Number of FFT points for STFT
HOP_LENGTH = 128  # Hop length for STFT
BATCH_SIZE = 4    # Batch size for training
EPOCHS = 50       # Number of training epochs, reduced for quicker testing
LR = 1e-4         # Learning rate

# Define data root paths. These have been updated to the correct LibriSpeech extraction paths.
DATA_ROOT_TRAIN = "/content/LibriSpeech/LibriSpeech/train-clean-100"
DATA_ROOT_TEST = "/content/LibriSpeech/LibriSpeech/test-clean"
DATA_ROOT_VAL = "/content/LibriSpeech/LibriSpeech/dev-clean"

print("Constants defined: NUM_SPEAKERS, SAMPLE_RATE, DURATION, N_FFT, HOP_LENGTH, BATCH_SIZE, EPOCHS, LR, DATA_ROOT_TRAIN, DATA_ROOT_TEST, DATA_ROOT_VAL")

Constants defined: NUM_SPEAKERS, SAMPLE_RATE, DURATION, N_FFT, HOP_LENGTH, BATCH_SIZE, EPOCHS, LR, DATA_ROOT_TRAIN, DATA_ROOT_TEST, DATA_ROOT_VAL


In [None]:
import os
import torch
import torchaudio
import random
from torch.utils.data import Dataset, DataLoader

class MixtureDataset(Dataset):
    def __init__(self, data_root, num_speakers, sample_rate, duration, max_samples=None):
        self.data_root = data_root
        self.num_speakers = num_speakers
        self.sample_rate = sample_rate
        self.duration = duration
        self.segment_length = int(sample_rate * duration)
        self.speaker_paths = self._find_speaker_paths()
        self.max_samples = max_samples
        print(f"Initialized MixtureDataset with {len(self.speaker_paths)} unique speakers and segment length {self.segment_length} samples.")

    def _find_speaker_paths(self):
        speaker_paths = []
        for root, dirs, files in os.walk(self.data_root):
            for file in files:
                if file.endswith('.flac') or file.endswith('.wav'):
                    # Adjusted speaker_id extraction for nested LibriSpeech structure
                    # Assuming path like /data_root/speaker_id/chapter_id/audio.flac
                    parts = root.split(os.sep)
                    if len(parts) >= 2: # Ensure there are enough parts to get speaker_id
                         # speaker_id is usually two levels up from the audio file in LibriSpeech
                        speaker_paths.append(os.path.join(root, file))
        return speaker_paths

    def _load_audio(self, path):
        try:
            audio, sr = torchaudio.load(path)
            if sr != self.sample_rate:
                # Resample if sample rate does not match
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
                audio = resampler(audio)
            if audio.shape[0] > 1: # Convert stereo to mono if needed
                audio = torch.mean(audio, dim=0, keepdim=True)
            return audio.squeeze(0) # Remove channel dimension
        except Exception as e:
            print(f"Error loading audio file {path}: {e}")
            return None

    def _mix_audios(self, audios):
        # Pad or truncate audios to the desired segment_length
        processed_audios = []
        for audio in audios:
            if audio.shape[0] < self.segment_length:
                # Pad with zeros
                padded_audio = torch.zeros(self.segment_length)
                padded_audio[:audio.shape[0]] = audio
                processed_audios.append(padded_audio)
            elif audio.shape[0] > self.segment_length:
                # Randomly crop
                start_idx = random.randint(0, audio.shape[0] - self.segment_length)
                processed_audios.append(audio[start_idx : start_idx + self.segment_length])
            else:
                processed_audios.append(audio)

        # Sum the processed audios to create the mixture
        mixed_audio = torch.sum(torch.stack(processed_audios), dim=0)
        return mixed_audio, processed_audios

    def __len__(self):
        # For simplicity, let's say it's proportional to the number of speaker files
        base_len = len(self.speaker_paths) * self.num_speakers # Multiplier for more samples
        return min(base_len, self.max_samples) if self.max_samples is not None else base_len

    def __getitem__(self, idx):
        selected_speaker_files = random.sample(self.speaker_paths, self.num_speakers)

        speaker_audios = []
        for filepath in selected_speaker_files:
            audio = self._load_audio(filepath)
            if audio is not None:
                speaker_audios.append(audio)

        if len(speaker_audios) < self.num_speakers:
            # Handle cases where not enough valid audios are found
            print(f"Warning: Could not load {self.num_speakers} distinct audios from {self.data_root}. Retrying item {idx}.")
            # Attempt to find another valid sample recursively or return dummy data/skip
            # For simplicity, we'll try another random index.
            return self.__getitem__(random.randint(0, len(self) - 1))

        mixed_audio, separated_audios = self._mix_audios(speaker_audios)

        # Stack separated audios to return a tensor of shape (num_speakers, segment_length)
        separated_audios_tensor = torch.stack(separated_audios)

        return mixed_audio, separated_audios_tensor

print("MixtureDataset class defined.")

MixtureDataset class defined.


In [None]:
import torch.nn.functional as F
import itertools

def sdr(estimated_signal, reference_signal):
    # Ensure signals are 1D for this calculation
    estimated_signal = estimated_signal.squeeze()
    reference_signal = reference_signal.squeeze()

    # Calculate signal part
    s_target = (torch.sum(reference_signal * estimated_signal) / (torch.sum(reference_signal ** 2) + 1e-8)) * reference_signal
    # Calculate noise part
    e_noise = estimated_signal - s_target

    # Calculate SDR
    sdr_value = 10 * torch.log10((torch.sum(s_target ** 2) + 1e-8) / (torch.sum(e_noise ** 2) + 1e-8) + 1e-8)
    return sdr_value

def pit_loss_si_sdr(estimates, targets):
    # estimates: (batch_size, num_speakers, segment_length)
    # targets: (batch_size, num_speakers, segment_length)

    batch_size, num_speakers, segment_length = estimates.shape

    losses = []
    for i in range(batch_size):
        batch_item_losses = []
        # Generate all permutations of speaker indices
        permutations = list(itertools.permutations(range(num_speakers)))

        min_loss_for_item = torch.tensor(float('inf')).to(estimates.device)

        for p in permutations:
            current_permutation_loss = 0.0
            for j in range(num_speakers):
                # Calculate negative SDR between estimated source j and target source p[j]
                sdr_val = sdr(estimates[i, j], targets[i, p[j]])
                current_permutation_loss += -sdr_val # Minimize negative SDR = Maximize SDR

            if current_permutation_loss < min_loss_for_item:
                min_loss_for_item = current_permutation_loss

        losses.append(min_loss_for_item)

    return torch.mean(torch.stack(losses))

print("PIT loss (SI-SDR based) function redefined to handle multiple speakers.")

PIT loss (SI-SDR based) function redefined to handle multiple speakers.


In [None]:
import torch.nn as nn
import torch

class Encoder(nn.Module):
    def __init__(self, N, L):
        super(Encoder, self).__init__()
        # N: Number of filters in the encoder/decoder
        # L: Length of the filters (kernel_size)

        self.conv1d = nn.Conv1d(in_channels=1, out_channels=N, kernel_size=L, stride=L // 2, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (batch_size, 1, T) - T is the number of samples in the mixture audio
        # Output: (batch_size, N, T') - T' is the number of frames

        # Expand to (batch_size, 1, T) if it's (batch_size, T)
        if x.dim() == 2:
            x = x.unsqueeze(1)

        return self.relu(self.conv1d(x))

class Decoder(nn.Module):
    def __init__(self, N, L):
        super(Decoder, self).__init__()
        # N: Number of filters in the encoder/decoder
        # L: Length of the filters (kernel_size)

        # Transposed convolution to reconstruct the time-domain signal
        self.deconv1d = nn.ConvTranspose1d(in_channels=N, out_channels=1, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, x):
        # x: (batch_size, N, T') - T' is the number of frames
        # Output: (batch_size, 1, T_out) - T_out is the number of samples in the reconstructed audio
        return self.deconv1d(x)

class SeparationBlock(nn.Module):
    def __init__(self, N, B, H, P, X, R):
        super(SeparationBlock, self).__init__()
        # N: Number of filters in encoder/decoder
        # B: Number of channels in bottleneck layer
        # H: Number of hidden units in LSTM
        # P: Kernel size of 1D conv in each block
        # X: Number of convolutional blocks in each repetition
        # R: Number of repetitions

        self.N = N # Store N as an instance variable
        self.R = R
        self.X = X

        # Bottleneck layer
        self.conv_bottleneck = nn.Conv1d(N, B, 1)

        self.blocks = nn.ModuleList()
        for r in range(R):
            for x in range(X):
                self.blocks.append(ConvBlock(B, H, P))

        # Output layer - generates S masks (S = num_speakers)
        self.conv_out = nn.Conv1d(B, N * NUM_SPEAKERS, 1)
        self.softmax = nn.Softmax(dim=1) # Softmax over speaker dimension

    def forward(self, x):
        # x: (batch_size, N, T')

        # Bottleneck
        x = self.conv_bottleneck(x)

        # Apply separation blocks
        for block in self.blocks:
            x = block(x)

        # Output convolution to get S*N features
        x = self.conv_out(x) # (batch_size, S*N, T')

        # Reshape to (batch_size, S, N, T') and apply softmax
        x = x.view(x.shape[0], NUM_SPEAKERS, self.N, x.shape[2]) # Use self.N
        masks = self.softmax(x) # (batch_size, S, N, T')

        return masks

class ConvBlock(nn.Module):
    def __init__(self, B, H, P):
        super(ConvBlock, self).__init__()
        # B: Number of channels in bottleneck layer
        # H: Number of hidden units in LSTM (or feature maps in ConvTasNet's dilated conv)
        # P: Kernel size of 1D conv

        # Dilated convolutional block
        self.conv1x1 = nn.Conv1d(B, H, 1)
        self.prelu = nn.PReLU()
        self.norm = nn.GroupNorm(1, H, eps=1e-08)

        # Dilated Conv1D (using dynamic dilation rates later if needed, for now a fixed one)
        # For simplicity, let's start with a fixed dilation for the basic block
        # In ConvTasNet, dilation rates increase exponentially across blocks
        # For a single block, we can use a standard convolution here, or define a specific dilation
        # Let's simplify and use a non-dilated conv for a single 'block'
        self.depthwise_conv = nn.Conv1d(H, H, P, padding=(P-1)//2, groups=H)
        self.norm2 = nn.GroupNorm(1, H, eps=1e-08)
        self.prelu2 = nn.PReLU()
        self.conv1x1_out = nn.Conv1d(H, B, 1)

    def forward(self, x):
        # x: (batch_size, B, T')
        residual = x
        x = self.conv1x1(x)
        x = self.prelu(x)
        x = self.norm(x)
        x = self.depthwise_conv(x)
        x = self.prelu2(x)
        x = self.norm2(x)
        x = self.conv1x1_out(x)
        return x + residual # Residual connection

class SeparationModel(nn.Module):
    def __init__(self, N=512, L=16, B=128, H=128, P=3, X=8, R=3):
        super(SeparationModel, self).__init__()
        # N: Number of filters in encoder/decoder
        # L: Length of the filters (kernel_size) in encoder/decoder
        # B: Number of channels in bottleneck layer
        # H: Number of hidden units in LSTM (or feature maps in ConvTasNet's dilated conv)
        # P: Kernel size of 1D conv in each block
        # X: Number of convolutional blocks in each repetition
        # R: Number of repetitions

        self.encoder = Encoder(N, L)
        self.separation_block = SeparationBlock(N, B, H, P, X, R) # Pass N to SeparationBlock
        self.decoder = Decoder(N, L)

    def forward(self, mixture):
        # mixture: (batch_size, T)

        # Encoder: time-domain mixture to frequency-domain representation
        w = self.encoder(mixture) # (batch_size, N, T')

        # Separation: apply mask to representation
        masks = self.separation_block(w) # (batch_size, S, N, T')

        # Apply masks to encoded features
        separated_features = masks * w.unsqueeze(1) # (batch_size, S, N, T')

        # Sum across the N dimension after view, to make it (batch_size, S, N_features_per_speaker_per_frame, T')
        # then reshape for decoder to accept (batch_size * S, N, T')

        # Reshape for decoding: each speaker's features are decoded independently
        batch_size, num_speakers, N_filters, T_frames = separated_features.shape

        # Flatten batch_size and num_speakers for decoding
        separated_features_flat = separated_features.view(batch_size * num_speakers, N_filters, T_frames)

        # Decoder: reconstruct time-domain signals
        separated_audios_flat = self.decoder(separated_features_flat)

        # Reshape back to (batch_size, num_speakers, T_out)
        # The decoder might output a slightly different length due to padding/stride issues
        # We need to make sure the output length matches the original mixture's length after encoder
        # For simplicity, let's assume the decoder output length can be truncated or padded to match
        # the original length if necessary. For now, we return as is and handle length matching externally.

        # Let's adjust the length of the separated audios to match the input mixture's length
        # This assumes the input `mixture` is (batch_size, T_mixture)
        # The encoder output `w` will have a derived length `T_prime`
        # The decoder output `separated_audios_flat` will have `T_out`
        # We need to make sure `T_out` matches `T_mixture` or is close enough.

        # For a standard ConvTasNet, T_out should be close to T_mixture
        # Calculate expected output length from encoder to verify
        T_mixture = mixture.shape[-1]
        # The encoder output `w` shape can be used to determine the `T_prime`
        # which is `w.shape[-1]`

        # Ensure the output length matches the input length after processing
        separated_audios = separated_audios_flat.view(batch_size, num_speakers, -1)

        if separated_audios.shape[-1] > T_mixture:
            separated_audios = separated_audios[..., :T_mixture]
        elif separated_audios.shape[-1] < T_mixture:
            padding = T_mixture - separated_audios.shape[-1]
            separated_audios = F.pad(separated_audios, (0, padding))

        return separated_audios

print("SeparationModel (ConvTasNet) class defined.")

SeparationModel (ConvTasNet) class defined.


In [None]:
from torch.utils.data import DataLoader

# Create MixtureDataset instances with a reduced number of samples for faster iteration
# Adjust max_samples as needed for your iteration speed. For example, 1000 samples for train/val.
# If you remove max_samples=None, it will use the full length.

train_dataset = MixtureDataset(
    data_root=DATA_ROOT_TRAIN,
    num_speakers=NUM_SPEAKERS,
    sample_rate=SAMPLE_RATE,
    duration=DURATION,
    max_samples=10000 # Reduced for faster training, adjust as needed
)

val_dataset = MixtureDataset(
    data_root=DATA_ROOT_VAL,
    num_speakers=NUM_SPEAKERS,
    sample_rate=SAMPLE_RATE,
    duration=DURATION,
    max_samples=2000 # Reduced for faster validation, adjust as needed
)

test_dataset = MixtureDataset(
    data_root=DATA_ROOT_TEST,
    num_speakers=NUM_SPEAKERS,
    sample_rate=SAMPLE_RATE,
    duration=DURATION # Test dataset can use full length or also be limited
)

print("MixtureDataset instances created for training, validation, and testing.")

# Initialize DataLoader objects
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2 # Typically 2 or 4 workers are good for Colab
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

print("DataLoader objects initialized for training, validation, and testing.")

Initialized MixtureDataset with 28539 unique speakers and segment length 16000 samples.
Initialized MixtureDataset with 2703 unique speakers and segment length 16000 samples.
Initialized MixtureDataset with 2620 unique speakers and segment length 16000 samples.
MixtureDataset instances created for training, validation, and testing.
DataLoader objects initialized for training, validation, and testing.


In [None]:
!pip install asteroid

Collecting asteroid
  Downloading asteroid-0.7.0-py3-none-any.whl.metadata (11 kB)
Collecting asteroid-filterbanks>=0.4.0 (from asteroid)
  Downloading asteroid_filterbanks-0.4.0-py3-none-any.whl.metadata (3.3 kB)
Collecting pytorch-lightning>=2.0.0 (from asteroid)
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics<=0.11.4 (from asteroid)
  Downloading torchmetrics-0.11.4-py3-none-any.whl.metadata (15 kB)
Collecting pb-bss-eval>=0.0.2 (from asteroid)
  Downloading pb_bss_eval-0.0.2-py3-none-any.whl.metadata (3.1 kB)
Collecting torch-stoi>=0.1.2 (from asteroid)
  Downloading torch_stoi-0.2.3-py3-none-any.whl.metadata (3.6 kB)
Collecting torch-optimizer<0.2.0,>=0.0.1a12 (from asteroid)
  Downloading torch_optimizer-0.1.0-py3-none-any.whl.metadata (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.5/53.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting julius (from asteroid)
  Downloading julius-0.2.7.tar.g

In [None]:
import torch.optim as optim
import asteroid.models
from huggingface_hub import hf_hub_download

# Check if GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The asteroid library is assumed to be installed from previous steps.
print("Asteroid library is assumed to be installed from previous steps.")

# Define the Hugging Face model ID and filename
pretrained_model_id = 'mpariente/ConvTasNet_Libri3Mix_sepnoisy'
pretrained_model_filename = 'pytorch_model.bin' # Standard filename for PyTorch models on Hugging Face Hub

print(f"Targeting pre-trained model: {pretrained_model_id}/{pretrained_model_filename}")

pretrained_state_dict = None # Initialize to None in case of failure

# Mount Google Drive (if not already mounted in a previous cell)
from google.colab import drive
if not os.path.exists('/content/gdrive'):
    drive.mount('/content/gdrive')
    print("Google Drive mounted.")
else:
    print("Google Drive already mounted.")

# Path to your saved checkpoint in Google Drive
saved_model_path = '/content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth'

# Re-instantiate the custom SeparationModel with H=256 (to match the pre-trained model architecture)
model = SeparationModel(
    N=512, L=16, B=128, H=256, P=3, X=8, R=3 # Updated H to 256
)
model.to(device)
print("SeparationModel re-instantiated with H=256 and moved to device.")

# Load pre-trained weights or your fine-tuned weights if available
if os.path.exists(saved_model_path):
    print(f"Loading fine-tuned weights from {saved_model_path}")
    try:
        model.load_state_dict(torch.load(saved_model_path, map_location=device))
        print("Fine-tuned weights loaded successfully.")
    except Exception as e:
        print(f"Error loading fine-tuned weights from {saved_model_path}: {e}")
        print("Attempting to load original pre-trained weights instead.")
        try:
            local_model_path = hf_hub_download(repo_id=pretrained_model_id, filename=pretrained_model_filename)
            pretrained_state_dict = torch.load(local_model_path, map_location="cpu", weights_only=False)
            adjusted_state_dict = {}
            for key, value in pretrained_state_dict.items():
                if key.startswith('model.'):
                    adjusted_key = key[len('model.'):]
                    adjusted_state_dict[adjusted_key] = value
                else:
                    adjusted_state_dict[key] = value
            model.load_state_dict(adjusted_state_dict, strict=False)
            print("Original pre-trained weights loaded into custom SeparationModel (strict=False used).")
        except Exception as e_orig:
            print(f"Error loading original pre-trained weights: {e_orig}")
            print("Model starting with randomly initialized weights.")
else:
    print(f"No fine-tuned model found at {saved_model_path}. Attempting to load original pre-trained weights.")
    try:
        local_model_path = hf_hub_download(repo_id=pretrained_model_id, filename=pretrained_model_filename)
        pretrained_state_dict = torch.load(local_model_path, map_location="cpu", weights_only=False)
        adjusted_state_dict = {}
        for key, value in pretrained_state_dict.items():
            if key.startswith('model.'):
                adjusted_key = key[len('model.'):]
                adjusted_state_dict[adjusted_key] = value
            else:
                adjusted_state_dict[key] = value
        model.load_state_dict(adjusted_state_dict, strict=False)
        print("Original pre-trained weights loaded into custom SeparationModel (strict=False used).")
    except Exception as e_orig:
        print(f"Error loading original pre-trained weights: {e_orig}")
        print("Model starting with randomly initialized weights.")

# Redefine the Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)
print("Adam optimizer redefined with learning rate:", LR)

# Initialize the learning rate scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
print("Learning rate scheduler (ReduceLROnPlateau) initialized.")

Using device: cuda
Asteroid library is assumed to be installed from previous steps.
Targeting pre-trained model: mpariente/ConvTasNet_Libri3Mix_sepnoisy/pytorch_model.bin
Google Drive already mounted.
SeparationModel re-instantiated with H=256 and moved to device.
Loading fine-tuned weights from /content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth
Fine-tuned weights loaded successfully.
Adam optimizer redefined with learning rate: 0.0001
Learning rate scheduler (ReduceLROnPlateau) initialized.


In [23]:
import torch
import os

# Ensure the model is in training mode initially
model.train()
print("Model set to training mode.")

# Initialize variables to keep track of the best validation loss and model path
best_val_loss = float('inf') # We are starting fresh for this training session
best_model_path = 'best_separation_model.pth'

# Create a directory to save model checkpoints if it doesn't exist
model_checkpoint_dir = '/content/gdrive/MyDrive/model_checkpoints'
os.makedirs(model_checkpoint_dir, exist_ok=True)
best_model_path = os.path.join(model_checkpoint_dir, best_model_path)
print(f"Model checkpoints will be saved to: {model_checkpoint_dir}")

# Check if a best model already exists and update best_val_loss if it does
if os.path.exists(best_model_path):
    # You might want to load the model state dict here as well, but it was already loaded above
    # if the previous run was successful.
    print(f"Previous best model checkpoint exists at {best_model_path}. Starting best_val_loss from scratch.")
    # If you wanted to strictly track improvement from the loaded model's previous best, you would load the prior best_val_loss here.
    # For a fresh start, keeping float('inf') is fine to ensure any improvement is saved.


# Training loop
print(f"Starting training for {EPOCHS} epochs...")
for epoch in range(EPOCHS):
    model.train() # Set model to training mode for the epoch
    total_epoch_loss = 0.0  # Accumulate loss for the entire epoch
    running_loss_100_batches = 0.0 # For periodic print

    for batch_idx, (mixture, targets) in enumerate(train_loader):
        # Move data to the device
        mixture = mixture.to(device)
        targets = targets.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        estimates = model(mixture)

        # Calculate PIT loss
        loss = pit_loss_si_sdr(estimates, targets)
        batch_loss = loss.item()

        # Accumulate loss for the entire epoch
        total_epoch_loss += batch_loss
        # Accumulate loss for periodic print
        running_loss_100_batches += batch_loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print training loss periodically
        if batch_idx % 100 == 99: # Print every 100 batches
            print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx+1}/{len(train_loader)}], Train Loss: {running_loss_100_batches/100:.4f}")
            running_loss_100_batches = 0.0 # Reset for the next 100 batches

    # Calculate average training loss for the epoch using total_epoch_loss
    if len(train_loader) > 0:
        avg_train_loss = total_epoch_loss / len(train_loader)
    else:
        avg_train_loss = 0.0 # Handle empty loader case for safety

    print(f"Epoch [{epoch+1}/{EPOCHS}], Final Train Loss: {avg_train_loss:.4f}")

    # Validation phase
    model.eval() # Set model to evaluation mode
    val_loss = 0.0
    with torch.no_grad(): # Disable gradient calculations during validation
        for batch_idx, (mixture, targets) in enumerate(val_loader):
            mixture = mixture.to(device)
            targets = targets.to(device)

            estimates = model(mixture)
            loss = pit_loss_si_sdr(estimates, targets)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Validation Loss: {avg_val_loss:.4f}")

    # Learning rate scheduler step
    scheduler.step(avg_val_loss)

    # Save the best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path} with validation loss: {best_val_loss:.4f}")

print("Training complete.")

Model set to training mode.
Model checkpoints will be saved to: /content/gdrive/MyDrive/model_checkpoints
Previous best model checkpoint exists at /content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth. Starting best_val_loss from scratch.
Starting training for 50 epochs...
Epoch [1/50], Batch [100/2500], Train Loss: -0.9833
Epoch [1/50], Batch [200/2500], Train Loss: -1.4753
Epoch [1/50], Batch [300/2500], Train Loss: -1.1727
Epoch [1/50], Batch [400/2500], Train Loss: -1.6172
Epoch [1/50], Batch [500/2500], Train Loss: -1.1874
Epoch [1/50], Batch [600/2500], Train Loss: 0.0385
Epoch [1/50], Batch [700/2500], Train Loss: -0.6141
Epoch [1/50], Batch [800/2500], Train Loss: -1.0439
Epoch [1/50], Batch [900/2500], Train Loss: -0.2394
Epoch [1/50], Batch [1000/2500], Train Loss: -0.7097
Epoch [1/50], Batch [1100/2500], Train Loss: 0.2035
Epoch [1/50], Batch [1200/2500], Train Loss: -0.6759
Epoch [1/50], Batch [1300/2500], Train Loss: -1.3297
Epoch [1/50], Batch [1400/2500], T

KeyboardInterrupt: 

# Task
Load the best performing model from `/content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth`, obtain a sample mixed audio from the `test_loader`, perform audio separation using the loaded model, calculate and display the SDR for each separated track, and provide an interface to play the original mixed audio and the separated audio tracks.

## Load Best Model

### Subtask:
Load the state dictionary of the best performing model (saved with the lowest validation loss of 1.3082) into the `SeparationModel` instance.


**Reasoning**:
The subtask is to load the state dictionary of the best performing model into the `SeparationModel` instance and set the model to evaluation mode. I will use `torch.load()` and `model.load_state_dict()` to load the weights from `best_model_path` and then call `model.eval()`.



In [25]:
print(f"Loading best model from: {best_model_path}")

try:
    # Load the best model's state dictionary
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    print("Best model weights loaded successfully.")

    # Set the model to evaluation mode
    model.eval()
    print("Model set to evaluation mode.")

except Exception as e:
    print(f"Error loading best model: {e}")
    print("Please ensure the best model was saved correctly and the path is accessible.")

Loading best model from: /content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth
Best model weights loaded successfully.
Model set to evaluation mode.


**Reasoning**:
The previous attempt to find a sample meeting the SDR criteria failed due to reaching the `max_iterations` limit. To give the search more opportunity to find a suitable sample, I will increase the `max_iterations` value in the loop to allow for a more extensive search within the `test_loader`.



In [329]:
import torch
import itertools
import IPython.display as ipd
from torch.utils.data import DataLoader # Re-import DataLoader for local scope if needed in loop reset

# 1. Set the model to evaluation mode
model.eval()
print("Model set to evaluation mode.")

# Initialize variables to store the found sample
found_mixed_audio = None
found_separated_audios = None
found_target_audios = None
found_sdr_values = None

# 2. Initialize an infinite loop to search for a suitable audio sample
# Using a counter to avoid truly infinite loops in case condition is never met
max_iterations = 5000 # Increased max_iterations to allow more search attempts
iteration_count = 0

print("Searching for a mixed audio sample with all SDRs > 3 dB...")

# Re-initialize test_loader iterator if it was exhausted in a previous run
test_loader_iter = iter(test_loader)

while iteration_count < max_iterations:
    try:
        # 3. Get a batch of mixed audio and corresponding target audios from the test_loader
        mixed_audio_batch, targets_batch = next(test_loader_iter)
        mixed_audio_batch = mixed_audio_batch.to(device)
        targets_batch = targets_batch.to(device)

        # Process each sample in the batch
        for sample_idx in range(mixed_audio_batch.shape[0]):
            iteration_count += 1
            if iteration_count > max_iterations:
                break

            mixed_audio_sample = mixed_audio_batch[sample_idx].unsqueeze(0) # (1, segment_length)
            targets_sample = targets_batch[sample_idx].unsqueeze(0) # (1, num_speakers, segment_length)

            # a. Pass the mixed audio sample through the model
            with torch.no_grad():
                estimates = model(mixed_audio_sample) # (1, num_speakers, segment_length)

            # Move to CPU for SDR calculation if `sdr` function is not fully on GPU
            estimated_sources = estimates.squeeze(0).cpu()
            true_sources = targets_sample.squeeze(0).cpu()

            num_speakers = estimated_sources.shape[0]

            # b. Calculate the individual SDR values for all possible permutations
            min_neg_sdr_sum = float('inf')
            best_permutation_sdr_values = []

            for p in itertools.permutations(range(num_speakers)):
                current_permutation_sdr_values = []
                current_neg_sdr_sum = 0.0
                for i in range(num_speakers):
                    # sdr function returns positive SDR, we want to maximize it
                    sdr_val = sdr(estimated_sources[i], true_sources[p[i]])
                    current_permutation_sdr_values.append(sdr_val.item())
                    current_neg_sdr_sum -= sdr_val.item() # Sum of negative SDRs

                if current_neg_sdr_sum < min_neg_sdr_sum:
                    min_neg_sdr_sum = current_neg_sdr_sum
                    best_permutation_sdr_values = current_permutation_sdr_values

            # d. Check if all individual SDR values from this best permutation are greater than 3 dB
            condition_met = all(sdr_val > 3.0 for sdr_val in best_permutation_sdr_values)

            if condition_met:
                # e. If the condition is met, store the sample and break
                found_mixed_audio = mixed_audio_sample.cpu().squeeze(0)
                found_separated_audios = estimates.cpu().squeeze(0)
                found_target_audios = targets_sample.cpu().squeeze(0)
                found_sdr_values = best_permutation_sdr_values
                print(f"Found suitable sample after {iteration_count} iterations!\n")
                for i, sdr_val in enumerate(found_sdr_values):
                    print(f"  Separated Speaker {i+1} SDR: {sdr_val:.2f} dB")
                break # Break from inner loop (per sample in batch)

        if found_mixed_audio is not None: # Break from outer loop (per batch)
            break

    except StopIteration:
        print("End of test_loader reached without finding a suitable sample. Re-initializing iterator...")
        # If test_loader is exhausted, re-initialize its iterator to continue searching
        test_loader_iter = iter(test_loader)
        if iteration_count >= max_iterations: # Check again after reset if max_iterations reached
            break

if found_mixed_audio is None:
    print(f"Could not find a sample where all SDRs > 3 dB after {iteration_count} iterations.")
else:
    print("\n--- Found Sample Details ---")
    print("Original Mixed Audio:")
    ipd.display(ipd.Audio(found_mixed_audio.numpy(), rate=SAMPLE_RATE))

    print("\nSeparated Audios (with best permutation SDRs):")
    for i in range(NUM_SPEAKERS):
        print(f"  Speaker {i+1} (SDR: {found_sdr_values[i]:.2f} dB):")
        ipd.display(ipd.Audio(found_separated_audios[i].numpy(), rate=SAMPLE_RATE))

    print("\nOriginal Target Audios (for comparison):")
    for i in range(NUM_SPEAKERS):
        print(f"  Target Speaker {i+1}:")
        ipd.display(ipd.Audio(found_target_audios[i].numpy(), rate=SAMPLE_RATE))

print("Sample generation and evaluation complete.")


Model set to evaluation mode.
Searching for a mixed audio sample with all SDRs > 3 dB...
Could not find a sample where all SDRs > 3 dB after 5000 iterations.
Sample generation and evaluation complete.


# Task
**Modify the Adam optimizer initialization in cell `9ee552f1` to include `weight_decay=1e-5`, and then execute the training loop in cell `new_training_loop_final`.**

## Add Weight Decay to Optimizer

### Subtask:
Modify the Adam optimizer initialization in cell `9ee552f1` to include `weight_decay=1e-5`, implementing L2 regularization to help prevent overfitting.


**Reasoning**:
To implement L2 regularization and help prevent overfitting, I need to modify the `optim.Adam` optimizer initialization in cell `9ee552f1` by adding `weight_decay=1e-5` as an argument.



In [330]:
import torch.optim as optim
import asteroid.models
from huggingface_hub import hf_hub_download

# Check if GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The asteroid library is assumed to be installed from previous steps.
print("Asteroid library is assumed to be installed from previous steps.")

# Define the Hugging Face model ID and filename
pretrained_model_id = 'mpariente/ConvTasNet_Libri3Mix_sepnoisy'
pretrained_model_filename = 'pytorch_model.bin' # Standard filename for PyTorch models on Hugging Face Hub

print(f"Targeting pre-trained model: {pretrained_model_id}/{pretrained_model_filename}")

pretrained_state_dict = None # Initialize to None in case of failure

# Mount Google Drive (if not already mounted in a previous cell)
from google.colab import drive
if not os.path.exists('/content/gdrive'):
    drive.mount('/content/gdrive')
    print("Google Drive mounted.")
else:
    print("Google Drive already mounted.")

# Path to your saved checkpoint in Google Drive
saved_model_path = '/content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth'

# Re-instantiate the custom SeparationModel with H=256 (to match the pre-trained model architecture)
model = SeparationModel(
    N=512, L=16, B=128, H=256, P=3, X=8, R=3 # Updated H to 256
)
model.to(device)
print("SeparationModel re-instantiated with H=256 and moved to device.")

# Load pre-trained weights or your fine-tuned weights if available
if os.path.exists(saved_model_path):
    print(f"Loading fine-tuned weights from {saved_model_path}")
    try:
        model.load_state_dict(torch.load(saved_model_path, map_location=device))
        print("Fine-tuned weights loaded successfully.")
    except Exception as e:
        print(f"Error loading fine-tuned weights from {saved_model_path}: {e}")
        print("Attempting to load original pre-trained weights instead.")
        try:
            local_model_path = hf_hub_download(repo_id=pretrained_model_id, filename=pretrained_model_filename)
            pretrained_state_dict = torch.load(local_model_path, map_location="cpu", weights_only=False)
            adjusted_state_dict = {}
            for key, value in pretrained_state_dict.items():
                if key.startswith('model.'):
                    adjusted_key = key[len('model.'):]
                    adjusted_state_dict[adjusted_key] = value
                else:
                    adjusted_state_dict[key] = value
            model.load_state_dict(adjusted_state_dict, strict=False)
            print("Original pre-trained weights loaded into custom SeparationModel (strict=False used).")
        except Exception as e_orig:
            print(f"Error loading original pre-trained weights: {e_orig}")
            print("Model starting with randomly initialized weights.")
else:
    print(f"No fine-tuned model found at {saved_model_path}. Attempting to load original pre-trained weights.")
    try:
        local_model_path = hf_hub_download(repo_id=pretrained_model_id, filename=pretrained_model_filename)
        pretrained_state_dict = torch.load(local_model_path, map_location="cpu", weights_only=False)
        adjusted_state_dict = {}
        for key, value in pretrained_state_dict.items():
            if key.startswith('model.'):
                adjusted_key = key[len('model.'):]
                adjusted_state_dict[adjusted_key] = value
            else:
                adjusted_state_dict[key] = value
        model.load_state_dict(adjusted_state_dict, strict=False)
        print("Original pre-trained weights loaded into custom SeparationModel (strict=False used).")
    except Exception as e_orig:
        print(f"Error loading original pre-trained weights: {e_orig}")
        print("Model starting with randomly initialized weights.")

# Redefine the Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
print("Adam optimizer redefined with learning rate:" , LR, "and weight decay: 1e-5")

# Initialize the learning rate scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
print("Learning rate scheduler (ReduceLROnPlateau) initialized.")

Using device: cuda
Asteroid library is assumed to be installed from previous steps.
Targeting pre-trained model: mpariente/ConvTasNet_Libri3Mix_sepnoisy/pytorch_model.bin
Google Drive already mounted.
SeparationModel re-instantiated with H=256 and moved to device.
Loading fine-tuned weights from /content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth
Fine-tuned weights loaded successfully.
Adam optimizer redefined with learning rate: 0.0001 and weight decay: 1e-5
Learning rate scheduler (ReduceLROnPlateau) initialized.


**Reasoning**:
The subtask instructs to execute the training loop in cell `new_training_loop_final` after modifying the optimizer. This code block will run the training process with the updated optimizer settings.



In [None]:
import torch
import os

# Ensure the model is in training mode initially
model.train()
print("Model set to training mode.")

# Initialize variables to keep track of the best validation loss and model path
best_val_loss = float('inf') # We are starting fresh for this training session
best_model_path = 'best_separation_model.pth'

# Create a directory to save model checkpoints if it doesn't exist
model_checkpoint_dir = '/content/gdrive/MyDrive/model_checkpoints'
os.makedirs(model_checkpoint_dir, exist_ok=True)
best_model_path = os.path.join(model_checkpoint_dir, best_model_path)
print(f"Model checkpoints will be saved to: {model_checkpoint_dir}")

# Check if a best model already exists and update best_val_loss if it does
if os.path.exists(best_model_path):
    # You might want to load the model state dict here as well, but it was already loaded above
    # if the previous run was successful.
    print(f"Previous best model checkpoint exists at {best_model_path}. Starting best_val_loss from scratch.")
    # If you wanted to strictly track improvement from the loaded model's previous best, you would load the prior best_val_loss here.
    # For a fresh start, keeping float('inf') is fine to ensure any improvement is saved.


# Training loop
print(f"Starting training for {EPOCHS} epochs...")
for epoch in range(EPOCHS):
    model.train() # Set model to training mode for the epoch
    total_epoch_loss = 0.0  # Accumulate loss for the entire epoch
    running_loss_100_batches = 0.0 # For periodic print

    for batch_idx, (mixture, targets) in enumerate(train_loader):
        # Move data to the device
        mixture = mixture.to(device)
        targets = targets.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        estimates = model(mixture)

        # Calculate PIT loss
        loss = pit_loss_si_sdr(estimates, targets)
        batch_loss = loss.item()

        # Accumulate loss for the entire epoch
        total_epoch_loss += batch_loss
        # Accumulate loss for periodic print
        running_loss_100_batches += batch_loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print training loss periodically
        if batch_idx % 100 == 99: # Print every 100 batches
            print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx+1}/{len(train_loader)}], Train Loss: {running_loss_100_batches/100:.4f}")
            running_loss_100_batches = 0.0 # Reset for the next 100 batches

    # Calculate average training loss for the epoch using total_epoch_loss
    if len(train_loader) > 0:
        avg_train_loss = total_epoch_loss / len(train_loader)
    else:
        avg_train_loss = 0.0 # Handle empty loader case for safety

    print(f"Epoch [{epoch+1}/{EPOCHS}], Final Train Loss: {avg_train_loss:.4f}")

    # Validation phase
    model.eval() # Set model to evaluation mode
    val_loss = 0.0
    with torch.no_grad(): # Disable gradient calculations during validation
        for batch_idx, (mixture, targets) in enumerate(val_loader):
            mixture = mixture.to(device)
            targets = targets.to(device)

            estimates = model(mixture)
            loss = pit_loss_si_sdr(estimates, targets)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Validation Loss: {avg_val_loss:.4f}")

    # Learning rate scheduler step
    scheduler.step(avg_val_loss)

    # Save the best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path} with validation loss: {best_val_loss:.4f}")

print("Training complete.")

Model set to training mode.
Model checkpoints will be saved to: /content/gdrive/MyDrive/model_checkpoints
Previous best model checkpoint exists at /content/gdrive/MyDrive/model_checkpoints/best_separation_model.pth. Starting best_val_loss from scratch.
Starting training for 50 epochs...
Epoch [1/50], Batch [100/2500], Train Loss: -2.6939
Epoch [1/50], Batch [200/2500], Train Loss: -1.4301
Epoch [1/50], Batch [300/2500], Train Loss: -1.0831
Epoch [1/50], Batch [400/2500], Train Loss: -1.3121
Epoch [1/50], Batch [500/2500], Train Loss: -1.5188
Epoch [1/50], Batch [600/2500], Train Loss: -0.8210
Epoch [1/50], Batch [700/2500], Train Loss: -0.8527
Epoch [1/50], Batch [800/2500], Train Loss: -0.3847
Epoch [1/50], Batch [900/2500], Train Loss: -1.3203
Epoch [1/50], Batch [1000/2500], Train Loss: -1.1385
Epoch [1/50], Batch [1100/2500], Train Loss: -1.2034
Epoch [1/50], Batch [1200/2500], Train Loss: -0.2098
Epoch [1/50], Batch [1300/2500], Train Loss: -0.1599
Epoch [1/50], Batch [1400/2500],