In [6]:
import torch
import torch.nn as nn
import librosa
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import os
from sklearn.model_selection import train_test_split
import soundfile as sf

In [None]:
# Define dataset directory
dataset_dir = "/path/to/dataset/"
batch_size = 3  # You can change this value as needed
segment_length =10*44100 # Fixed length for all audio samples (~1.5 sec at 44.1kHz)


In [31]:
# Get file paths for each category
def get_file_paths(folder_name):
    path = os.path.join(dataset_dir, folder_name)
    if not os.path.exists(path):
        print(f"Warning: {folder_name} folder not found!")
        return []
    return sorted([os.path.join(path, f) for f in os.listdir(path)])

In [32]:
song_paths = get_file_paths("song")
bass_paths = get_file_paths("bass")
vocal_paths = get_file_paths("vocal")
drum_paths = get_file_paths("drum")
music_paths = get_file_paths("music")

In [33]:
import librosa
import numpy as np

def load_audio(file_path, target_sr=44100, segment_length=441000):  
    audio, sr = librosa.load(file_path, sr=target_sr, mono=True)
    
    # Pad if audio is shorter than segment length
    if len(audio) < segment_length:
        audio = np.pad(audio, (0, segment_length - len(audio)), mode='constant')

    # Number of segments (each 10 seconds long)
    num_segments = len(audio) // segment_length  
    segments = []

    for i in range(num_segments):
        start = i * segment_length
        end = start + segment_length
        segment = audio[start:end]

        # Get max amplitude of the segment
        original_max = np.max(np.abs(segment)) if np.max(np.abs(segment)) > 0 else 1.0  

        # Normalize the segment
        segment = segment / original_max if original_max > 0 else segment
        
        segments.append(segment)

    return np.array(segments)


In [36]:
# Ensure dataset length consistency
min_len = min(len(song_paths), len(bass_paths), len(vocal_paths), len(drum_paths), len(music_paths))
print(f"Using {min_len} samples for training.")
song_paths, bass_paths, vocal_paths, drum_paths, music_paths = (
    song_paths[:min_len], bass_paths[:min_len], vocal_paths[:min_len], drum_paths[:min_len], music_paths[:min_len]
)

# Prepare dataset
train_data = []  # Stores only the audio segments for training
max_amplitudes = []  # Stores original max amplitudes for reconstruction

print("Loading dataset...")
for i in range(min_len):
    song_segments = load_audio(song_paths[i])
    bass_segments = load_audio(bass_paths[i])
    vocal_segments = load_audio(vocal_paths[i])
    drum_segments= load_audio(drum_paths[i])
    music_segments = load_audio(music_paths[i])
    
    min_segments = min(len(song_segments), len(bass_segments), len(vocal_segments), len(drum_segments), len(music_segments))

    for j in range(min_segments):
      train_data.append((
        song_segments[j], 
        bass_segments[j], 
        vocal_segments[j], 
        drum_segments[j], 
        music_segments[j]
    ))  


       
print("Dataset loaded successfully!")

Using 60 samples for training.
Loading dataset...
Dataset loaded successfully!


In [37]:
print(f"Total samples in dataset: {len(train_data)}")   
print(f"Each sample should have 5 elements (song, bass, vocal, drum, music): {len(train_data[0])}")  

# Print the shape of a few segments
print("Example shapes:")
print(f"Song segment shape: {train_data[0][0].shape}")
print(f"Bass segment shape: {train_data[0][1].shape}")
print(f"Vocal segment shape: {train_data[0][2].shape}")
print(f"Drum segment shape: {train_data[0][3].shape}")
print(f"Music segment shape: {train_data[0][4].shape}")


Total samples in dataset: 216
Each sample should have 5 elements (song, bass, vocal, drum, music): 5
Example shapes:
Song segment shape: (441000,)
Bass segment shape: (441000,)
Vocal segment shape: (441000,)
Drum segment shape: (441000,)
Music segment shape: (441000,)


In [38]:
for i in range(5):  # Checking first 5 samples
    assert train_data[i][0].shape == train_data[i][1].shape == train_data[i][2].shape == train_data[i][3].shape == train_data[i][4].shape, f"Mismatch at index {i}"
print("All segments have matching shapes")


All segments have matching shapes


In [39]:
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Define split ratios
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

# First split: Train (70%) and Temp (30%)
train_set, temp_set = train_test_split(train_data, test_size=(val_ratio + test_ratio), random_state=42)

# Second split: Validation (15%) and Test (15%)
val_set, test_set = train_test_split(temp_set, test_size=(test_ratio / (val_ratio + test_ratio)), random_state=42)

# Convert lists to PyTorch tensors
train_dataset = TensorDataset(*[torch.tensor(np.array(d), dtype=torch.float32) for d in zip(*train_set)])
val_dataset = TensorDataset(*[torch.tensor(np.array(d), dtype=torch.float32) for d in zip(*val_set)])
test_dataset = TensorDataset(*[torch.tensor(np.array(d), dtype=torch.float32) for d in zip(*test_set)])

# Prepare DataLoaders
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print dataset sizes
print(f"Training Samples: {len(train_dataset)}")
print(f"Validation Samples: {len(val_dataset)}")
print(f"Testing Samples: {len(test_dataset)}")


Training Samples: 151
Validation Samples: 32
Testing Samples: 33


In [40]:
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")
print(f"Number of training batches: {len(train_data_loader)}")
print(f"Number of testing batches: {len(test_data_loader)}")



Number of training samples: 151
Number of testing samples: 33
Number of training batches: 51
Number of testing batches: 11


In [41]:
for batch in train_data_loader:
    batch_shapes = [tuple(b.shape) for b in batch]
    print(f"Batch shape: {batch_shapes}")
    break  # Print only the first batch


Batch shape: [(3, 441000), (3, 441000), (3, 441000), (3, 441000), (3, 441000)]


In [18]:
import torch
import torch.nn as nn

# Constants
num_sources = 4  # bass, vocal, drums, music

class DemucsModel(nn.Module):
    def __init__(self):
        super(DemucsModel, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=16, stride=4, padding=8),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=16, stride=4, padding=8),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=16, stride=4, padding=8),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=16, stride=8, padding=8),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )

        # Bidirectional LSTM for temporal modeling
        self.rnn = nn.LSTM(256, 256, batch_first=True, bidirectional=True)
        self.lstm_fc = nn.Linear(512, 256)  # Merge bidirectional outputs

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=8, output_padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=16, stride=4, padding=8, output_padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 32, kernel_size=16, stride=4, padding=8, output_padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.ConvTranspose1d(32, num_sources, kernel_size=16, stride=4, padding=8, output_padding=1),
            nn.Tanh(),  # Keeps output in range (-1, 1)
        )

    def forward(self, x):
        x = self.encoder(x)  
        x = x.permute(0, 2, 1)  # Change to (batch, time, channels) for LSTM
        x, _ = self.rnn(x)  
        x = self.lstm_fc(x)  
        x = self.decoder(x.permute(0, 2, 1))  # Change back to (batch, channels, time)
        return x

# Initialize Model
model = DemucsModel()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [19]:
print(model)


DemucsModel(
  (encoder): Sequential(
    (0): Conv1d(1, 32, kernel_size=(16,), stride=(4,), padding=(8,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv1d(32, 64, kernel_size=(16,), stride=(4,), padding=(8,))
    (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv1d(64, 128, kernel_size=(16,), stride=(4,), padding=(8,))
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv1d(128, 256, kernel_size=(16,), stride=(8,), padding=(8,))
    (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
  (rnn): LSTM(256, 256, batch_first=True, bidirectional=True)
  (lstm_fc): Linear(in_features=512, out_features=256, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(8,), output_padding=(2,

In [20]:
import torch.nn.functional as F

def si_snr_loss(pred, target, eps=1e-8):
    target_energy = torch.sum(target**2, dim=-1, keepdim=True) + eps
    scale = torch.sum(target * pred, dim=-1, keepdim=True) / target_energy
    target_proj = scale * target
    noise = pred - target_proj

    si_snr = torch.sum(target_proj**2, dim=-1) / (torch.sum(noise**2, dim=-1) + eps)
    si_snr = 10 * torch.log10(si_snr + eps)
    
    return -si_snr.mean()

criterion = si_snr_loss


In [None]:
num_epochs=100

In [22]:
import torch.nn.functional as F
import csv

# Ensure target length consistency
target_length = 441000

# File to store loss and SDR values
csv_filename = "training_metrics.csv"
with open(csv_filename, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["Epoch", "Train_Loss", "Val_Loss", "Avg_SDR"])

def calculate_sdr(predicted, target):
    # Avoid log issues with small values
    eps = 1e-8
    return 10 * torch.log10(torch.sum(target**2) / (torch.sum((target - predicted)**2) + eps))

print("Starting training...")

# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_train_loss = 0
    epoch_sdr_values = []
    
    for batch_idx, batch in enumerate(train_data_loader):
        inputs, target_bass, target_vocal, target_drum, target_music = batch
        inputs = inputs.unsqueeze(1)

        # Padding inputs and targets
        inputs = F.pad(inputs, (0, target_length - inputs.shape[-1]))  
        target_bass = F.pad(target_bass, (0, target_length - target_bass.shape[-1]))
        target_vocal = F.pad(target_vocal, (0, target_length - target_vocal.shape[-1]))
        target_drum = F.pad(target_drum, (0, target_length - target_drum.shape[-1]))
        target_music = F.pad(target_music, (0, target_length - target_music.shape[-1]))

        targets = torch.stack([target_bass, target_vocal, target_drum, target_music], dim=1)
        
        # Forward pass
        outputs = model(inputs)
        if outputs.shape[-1] < target_length:
            outputs = F.pad(outputs, (0, target_length - outputs.shape[-1]))
        
        # Compute loss
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()
        
        # Compute SDR for the batch
        batch_sdr = calculate_sdr(outputs, targets).item()
        epoch_sdr_values.append(batch_sdr)
    
    # Compute average train loss and SDR
    avg_train_loss = epoch_train_loss / len(train_data_loader)
    avg_sdr = sum(epoch_sdr_values) / len(epoch_sdr_values)
    
    # Validation phase
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch in val_data_loader:
            inputs, target_bass, target_vocal, target_drum, target_music = batch
            inputs = inputs.unsqueeze(1)
            
            inputs = F.pad(inputs, (0, target_length - inputs.shape[-1]))  
            target_bass = F.pad(target_bass, (0, target_length - target_bass.shape[-1]))
            target_vocal = F.pad(target_vocal, (0, target_length - target_vocal.shape[-1]))
            target_drum = F.pad(target_drum, (0, target_length - target_drum.shape[-1]))
            target_music = F.pad(target_music, (0, target_length - target_music.shape[-1]))

            targets = torch.stack([target_bass, target_vocal, target_drum, target_music], dim=1)
            outputs = model(inputs)
            if outputs.shape[-1] < target_length:
                outputs = F.pad(outputs, (0, target_length - outputs.shape[-1]))
            
            val_loss = criterion(outputs, targets)
            epoch_val_loss += val_loss.item()
    
    avg_val_loss = epoch_val_loss / len(val_data_loader)
    
    # Save to CSV
    with open(csv_filename, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, avg_train_loss, avg_val_loss, avg_sdr])
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Avg SDR: {avg_sdr:.4f}")


Starting training...
Epoch 1/100 - Train Loss: 42.8301, Val Loss: 53.4457, Avg SDR: -8.1676
Epoch 2/100 - Train Loss: 31.0846, Val Loss: 36.6078, Avg SDR: -8.0827
Epoch 3/100 - Train Loss: 23.6489, Val Loss: 28.7485, Avg SDR: -8.0204
Epoch 4/100 - Train Loss: 22.4305, Val Loss: 26.9486, Avg SDR: -8.1171
Epoch 5/100 - Train Loss: 21.1618, Val Loss: 29.3575, Avg SDR: -7.9763
Epoch 6/100 - Train Loss: 20.6698, Val Loss: 26.1854, Avg SDR: -8.1126
Epoch 7/100 - Train Loss: 17.0494, Val Loss: 25.2461, Avg SDR: -8.0441
Epoch 8/100 - Train Loss: 15.8840, Val Loss: 23.9656, Avg SDR: -7.9517
Epoch 9/100 - Train Loss: 15.0103, Val Loss: 23.0506, Avg SDR: -7.8187
Epoch 10/100 - Train Loss: 14.2911, Val Loss: 22.8091, Avg SDR: -7.6956
Epoch 11/100 - Train Loss: 13.6515, Val Loss: 22.5939, Avg SDR: -7.6596
Epoch 12/100 - Train Loss: 13.1466, Val Loss: 21.4359, Avg SDR: -7.5673
Epoch 13/100 - Train Loss: 12.6154, Val Loss: 21.1010, Avg SDR: -7.1917
Epoch 14/100 - Train Loss: 12.1697, Val Loss: 20.682

In [23]:
print(f"Predicted shape: {outputs.shape}")
print(f"Target shape: {target_bass.shape}")  # Target shape should be same as predictions


Predicted shape: torch.Size([3, 4, 441000])
Target shape: torch.Size([3, 441000])


In [None]:
# Save the trained model
torch.save(model.state_dict(), "audio_separation_model_2.pth")
print("Model saved successfully!")


Model saved successfully!


In [25]:
import numpy as np

def remove_low_amplitude_noise(audio, threshold_ratio=0.07):
    max_amplitude = np.max(np.abs(audio))
    threshold = max_amplitude * threshold_ratio
    audio_denoised = np.where(np.abs(audio) > threshold, audio, 0)
    return audio_denoised


In [None]:
import os
import torch
import soundfile as sf

print("Starting validation...")
model.eval()  # Set model to evaluation mode
val_loss = 0
output_dir = "/path/to/output_directory/"  # Output directory
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():  # No need to compute gradients during validation
    for batch_idx, batch in enumerate(test_data_loader):
        inputs, target_bass, target_vocal, target_drum, target_music = batch
        inputs = inputs.unsqueeze(1)  # Add channel dimension
        
        output = model(inputs)  # Forward pass

        # Ensure output is exactly 441000 samples
        target_length = 441000
        if output.shape[-1] < target_length:
            pad_size = target_length - output.shape[-1]
            output = F.pad(output, (0, pad_size))

        # Save outputs as audio files
        current_batch_size = inputs.shape[0]  # Get actual batch size
        for i in range(current_batch_size):
            output_bass = remove_low_amplitude_noise(output[i, 0, :].cpu().numpy())
            output_vocal = remove_low_amplitude_noise(output[i, 1, :].cpu().numpy())
            output_drum = remove_low_amplitude_noise(output[i, 2, :].cpu().numpy())
            output_music = remove_low_amplitude_noise(output[i, 3, :].cpu().numpy())
            
            sf.write(os.path.join(output_dir, f"test_output_bass_{batch_idx*current_batch_size + i}.wav"), output_bass, samplerate=44100)
            sf.write(os.path.join(output_dir, f"test_output_vocal_{batch_idx*current_batch_size + i}.wav"), output_vocal, samplerate=44100)
            sf.write(os.path.join(output_dir, f"test_output_drum_{batch_idx*current_batch_size + i}.wav"), output_drum, samplerate=44100)
            sf.write(os.path.join(output_dir, f"test_output_music_{batch_idx*current_batch_size + i}.wav"), output_music, samplerate=44100)

        # Compute batch loss and accumulate
        batch_loss = (criterion(output[:, 0, :], target_bass) +
                      criterion(output[:, 1, :], target_vocal) +
                      criterion(output[:, 2, :], target_drum) +
                      criterion(output[:, 3, :], target_music)).item()

        val_loss += batch_loss / len(test_data_loader)  # Average over dataset

print(f"Validation Loss: {val_loss:.4f}")


Starting validation...
Validation Loss: 61.3517
