<a href="https://colab.research.google.com/github/JerinJoe/Res-Unet-for-Audio-Mask-Prediction/blob/main/Res_UNet_for_audio_mask_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab filesystem at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
# Clone the bsseval sources from GitHub into ./bsseval.
# NOTE(review): the recorded run of this cell failed with "could not read Username",
# so this URL may not point at a public repository — confirm the repository location.
!git clone https://github.com/fgnt/bsseval.git

Cloning into 'bsseval'...
fatal: could not read Username for 'https://github.com': No such device or address


In [2]:
!pip install bsseval -q

[31mERROR: Could not find a version that satisfies the requirement bsseval (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for bsseval[0m[31m
[0m

In [1]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import librosa
import librosa.display
import matplotlib.pyplot as plt
from bsseval import bss_eval_sources
import numpy as np
import zipfile
from torch.utils.data import Dataset, DataLoader

ModuleNotFoundError: No module named 'bsseval'

In [None]:
class ResidualBlock(nn.Module):
    """Two 3-tap 1-D convolutions wrapped in a residual (skip) connection.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.

    The input is laid out as ``(batch, in_channels, length)``. Both
    convolutions use ``kernel_size=3`` with ``padding=1``, so the length is
    preserved and the output is ``(batch, out_channels, length)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # The skip path must deliver `out_channels` channels to be addable to the
        # conv output. When the channel counts already match it is the identity
        # (no extra parameters); otherwise a 1x1 convolution projects the input.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # Out-of-place add: keeps the caller's tensor and the skip path untouched.
        out = out + residual
        return self.relu(out)

class ResUNet(nn.Module):
    """Residual convolutional encoder followed by a 1x1 mask-prediction head.

    Args:
        num_sources: Number of sources to predict masks for; the head emits
            ``num_sources * 64`` channels.
        in_channels: Number of channels of the input tensor. Defaults to 4,
            the value this notebook uses for MUSDB18.

    ``forward`` takes ``(batch, in_channels, length)`` and returns masks of
    shape ``(batch, num_sources * 64, length // 4)``: the two ``MaxPool1d(2)``
    stages each halve the length.

    NOTE(review): despite the name, only an encoder and a mask head are
    defined here — there is no decoder or U-Net skip path.
    """

    def __init__(self, num_sources, in_channels=4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=3, padding=1),
            ResidualBlock(64, 64),
            nn.MaxPool1d(2),
            ResidualBlock(64, 64),
            nn.MaxPool1d(2),
            ResidualBlock(64, 64)
        )
        # Masking layer added after the encoder
        self.mask_layer = nn.Conv1d(64, num_sources * 64, kernel_size=1)

    def forward(self, x):
        """Encode ``x`` and return the predicted masks."""
        features = self.encoder(x)
        # Sigmoid squashes every mask value into the range between 0 and 1.
        masks = self.mask_layer(features).sigmoid()
        # The masks must be returned explicitly; without this the model yields None.
        return masks


In [None]:
class MUSDB18Dataset(Dataset):
    """Zip-backed dataset yielding stacked magnitude spectrograms and source waveforms.

    Args:
        zip_path: Path (or file-like object) of the zip archive with the audio.
        subset: Top-level folder inside the archive; ``'train'`` selects the
            training size, any other value the test size.
        sr: Sampling rate passed to ``librosa.load``.
        n_fft: FFT size passed to ``librosa.stft``.
        hop_length: Hop length passed to ``librosa.stft``.

    Each item is ``(spec, sources)``: ``spec`` concatenates, along axis 0, the
    magnitude spectrogram of the mixture followed by those of the four
    sources; ``sources`` is the list of the four source waveforms.
    """

    def __init__(self, zip_path, subset='train', sr=44100, n_fft=256, hop_length=128):
        self.zip_path = zip_path
        self.subset = subset
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.zip_ref = zipfile.ZipFile(zip_path, 'r')  # Open zip file

    def close(self):
        """Close the underlying zip archive."""
        self.zip_ref.close()

    def __len__(self):
        # Assuming 100 mixtures for training and 50 for any other subset (modify if different)
        return 100 if self.subset == 'train' else 50

    def _load_wav(self, member):
        """Decode one archive member into a waveform at ``self.sr``."""
        with self.zip_ref.open(member, 'r') as wav_file:
            audio, _ = librosa.load(wav_file, sr=self.sr)
        return audio

    def _magnitude(self, audio):
        """Return the magnitude STFT of ``audio``."""
        # librosa.stft takes the window size as `n_fft`; it has no `n_per_side` argument.
        return np.abs(librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length))

    def __getitem__(self, idx):
        mix = self._load_wav(f'{self.subset}/mix_{idx+1}.wav')

        # Assuming 4 sources stored in separate files within the zip (modify if structure differs).
        # NOTE(review): these member names do not depend on `idx`, so every item
        # gets the same four sources — confirm the archive layout.
        sources = [self._load_wav(f'{self.subset}/sources/s{i}.wav') for i in range(1, 5)]

        # Stack mixture and source magnitude spectrograms along the frequency axis.
        spec = np.concatenate(
            [self._magnitude(mix)] + [self._magnitude(source) for source in sources],
            axis=0,
        )

        return spec, sources  # Mixture/source spectrogram stack and source waveforms

In [None]:
# Location of the MUSDB18 zip file inside Google Drive. The path has to start with the
# mount point passed to drive.mount() at the top of the notebook ('/content/gdrive').
zip_path = '/content/gdrive/MyDrive/musdb18.zip'  # Replace with your path

In [None]:
# Training hyperparameters (adjust as needed)
batch_size = 8         # samples per batch drawn by the data loaders
learning_rate = 1e-3   # step size given to the optimizer
num_epochs = 10        # number of passes over the training loader

def train(model, train_loader, optimizer, loss_fn, epochs=None):
  """Run the optimisation loop over ``train_loader``.

  Args:
    model: Module mapping a batch of inputs to predicted masks.
    train_loader: Iterable of ``(inputs, targets)`` pairs; the targets are ignored.
    optimizer: Optimizer already bound to the parameters of ``model``.
    loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
    epochs: Number of passes over ``train_loader``. ``None`` (the default)
      falls back to the module-level ``num_epochs``.

  Prints the epoch counter and the mean batch loss of every epoch.
  """
  if epochs is None:
    epochs = num_epochs

  model.train()  # Set model to training mode
  for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    running_loss = 0.0
    num_batches = 0
    for data, _ in train_loader:  # Target sources are ignored
      # Forward pass
      predicted_masks = model(data)

      # NOTE(review): the regression target is an all-zeros tensor, so this
      # objective pushes every mask towards 0 — confirm the intended target.
      loss = loss_fn(predicted_masks, torch.zeros_like(predicted_masks))

      # Backward pass and parameter update
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
      num_batches += 1

    # Report the epoch mean; an empty loader has no loss to report.
    if num_batches:
      print(f'Training Loss: {running_loss / num_batches:.4f}')
    else:
      print('Training Loss: n/a (train_loader yielded no batches)')

In [None]:
def evaluate(model, test_loader):
  """Run ``model`` over ``test_loader`` with gradients disabled.

  Args:
    model: Module called on the first element of every loader item.
    test_loader: Iterable of ``(inputs, targets)`` pairs; targets are ignored.

  Returns:
    A flat list with one NumPy array per sample, in loader order.
  """
  model.eval()  # Switch the model to evaluation mode
  collected = []
  with torch.no_grad():  # No autograd bookkeeping is needed for inference
    for batch in test_loader:
      inputs, _ = batch
      batch_outputs = model(inputs).cpu().numpy()
      # Split the batch along its first axis: one entry per sample.
      for sample_output in batch_outputs:
        collected.append(sample_output)
  return collected

In [None]:
# Build one dataset per split, then wrap each one in a DataLoader.
# Only the training loader shuffles; the test loader keeps dataset order.
train_dataset, test_dataset = (
    MUSDB18Dataset(zip_path, subset=split) for split in ('train', 'test')
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Instantiate the network for 4 sources, the MSE criterion, and an Adam optimizer
# over the network's parameters.
model = ResUNet(num_sources=4)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Fit the model on the training loader.
train(model=model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn)

In [None]:
# Run the model over the test loader and keep the masks it predicts.
predicted_masks = evaluate(model=model, test_loader=test_loader)

In [None]:
# Collect the ground-truth source waveforms of every test item.
# The dataset is zip-backed and defines no `root` directory, so filesystem paths cannot be
# built from it; its __getitem__ already returns the source waveforms as the second tuple
# element, so they are read through the dataset itself.
ground_truth_sources = []  # One list of source waveforms per test item
for i in range(len(test_dataset)):
  _, item_sources = test_dataset[i]
  ground_truth_sources.append(item_sources)

# Prepare predicted sources for bsseval.
# `evaluate` returns a flat list with one mask array per test sample, and the test loader
# does not shuffle, so the mask at position k belongs to test_dataset[k].
predicted_sources = []
for sample_idx, mask in enumerate(predicted_masks):
  # NOTE(review): assumes `mask` broadcasts against the item's spectrogram stack — confirm shapes.
  predicted_sources.append(mask * test_dataset[sample_idx][0])

# Score the separation with bss_eval_sources: ground-truth sources are passed first,
# predicted sources second.
# NOTE(review): the eval_SDR / eval_SIR keyword arguments could not be checked against the
# bsseval API from this notebook — confirm that the function accepts them.
image_id = 0  # Sample index echoed by the closing status print; not used by the metric call

audio_sr = 44100  # Sampling rate in Hz handed to the spectrogram plots further down
results = bss_eval_sources(ground_truth_sources, predicted_sources, eval_SDR=True, eval_SIR=True)

# Print the first two entries of the result.
# NOTE(review): assumes results[0] is SDR and results[1] is SIR — verify the return order.
print(f'SDR results (dB): {results[0]}')
print(f'SIR results (dB): {results[1]}')

# Visualization (example for the first test sample)
n_fft, hop_length = 256, 128
n_bins = 1 + n_fft // 2  # number of frequency rows librosa.stft produces for this n_fft


def _show_spectrogram(ax, magnitude, title):
  """Draw one magnitude spectrogram (frequency rows x time columns) in dB on `ax`."""
  # specshow expects frequency along rows, so the matrix is passed untransposed;
  # hop_length is forwarded so the time axis matches the STFT settings.
  librosa.display.specshow(librosa.amplitude_to_db(magnitude), x_axis='time', y_axis='log',
                           sr=audio_sr, hop_length=hop_length, ax=ax)
  ax.set_title(title)
  # NOTE(review): this caps the frequency axis at 80 (Hz on a log-frequency axis) — confirm
  # that a frequency limit, not a dB limit, is intended.
  ax.set_ylim([None, 80])


# A dataset item is already a stack of magnitude spectrograms with the mixture first, so the
# mixture is the first `n_bins` rows (running an STFT on it again would be wrong).
mixture_spec = test_dataset[0][0][:n_bins]
predicted_mask_1 = predicted_masks[0][0]  # First row of the first sample's predicted mask
# NOTE(review): assumes the mask broadcasts against the mixture spectrogram — confirm shapes.
predicted_source_1 = predicted_mask_1 * mixture_spec

fig, axes = plt.subplots(3, 1, figsize=(10, 6))

_show_spectrogram(axes[0], mixture_spec, 'Mixture Spectrogram')
_show_spectrogram(axes[1], np.abs(predicted_source_1), 'Predicted Source 1 Spectrogram')

# Ground truth source 1 spectrogram (if available)
if ground_truth_sources:
  # librosa.stft takes the window size as `n_fft`; it has no `n_per_side` argument.
  ground_truth_spec_1 = np.abs(
      librosa.stft(ground_truth_sources[0][0], n_fft=n_fft, hop_length=hop_length))
  _show_spectrogram(axes[2], ground_truth_spec_1, 'Ground Truth Source 1 Spectrogram')

fig.tight_layout()
plt.show()

print(f'Visualizations created for test sample {image_id+1}')