In [1]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import random
from typing import Dict

# Import missing modules for optimization
import torch.optim as optim
from torch.optim import lr_scheduler

# Import the UNet model and the training function from the training module.

from train_sml import UNet, train_model_source_separation
import torch.nn as nn
# Import our custom dataset and augmentation pipeline.
from process_sml import AudioDatasetFolder, Compose, RandomTimeCrop, RandomTimeStretch, RandomPitchShift, RandomNoise, RandomDistortion, RandomVolume,compute_waveform,to_stereo


In [2]:
# Augmentation chain handed to Compose (presumably applied in list order — confirm
# against process_sml.Compose).
# RandomTimeStretch(factor_range=(0.9, 1.1)) is deliberately left out of the chain.
augmentation_pipeline = Compose(
    [
        RandomTimeCrop(target_time=512),
        RandomPitchShift(shift_range=(-1.0, 1.0)),
        RandomNoise(noise_std=0.05),
        RandomDistortion(gamma_range=(0.8, 1.2)),
        RandomVolume(volume_range=(0.8, 1.2)),
    ]
)


In [3]:




# Seed Python's and PyTorch's RNGs up front for reproducibility.
random.seed(42)
torch.manual_seed(42)

# Pick the compute device early: CUDA when available, CPU otherwise.
device: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Component names requested from the dataset; "mixture" is the model input,
# the remaining entries are the separation targets used later in the notebook.
COMPONENT_MAP = ["mixture", "drums", "bass", "other_accompaniment", "vocals"]
IS_TRACK_ID = True

# Build the dataset from the stem index CSV.
# The augmentation pipeline is NOT attached here (no `transform` argument is passed).
dataset_multi = AudioDatasetFolder(
    csv_file="output_stems/musdb18_index_20250408_121813.csv",
    audio_dir=".",  # adjust as needed
    components=COMPONENT_MAP,
    sample_rate=16000,
    duration=5.0,
    is_track_id=IS_TRACK_ID,
    input_name="mixture",
)



In [4]:
# Inspect the shape of the "mixture" entry of dataset item 3.
x = dataset_multi[3]["mixture"]
x.shape


torch.Size([2, 1025, 157])

In [13]:
import torch
import torchaudio
import sounddevice as sd

def prepare_for_playback(wav: torch.Tensor) -> torch.Tensor:
    """Return a waveform that is either 1-D (mono) or 2-D with at most two channels.

    Args:
        wav: A 1-D tensor, or a 2-D tensor laid out as (channels, time).

    Returns:
        ``wav`` unchanged when it is 1-D or has at most two rows; otherwise
        only its first two rows.

    Raises:
        ValueError: If ``wav`` is neither 1-D nor 2-D.
    """
    rank = wav.dim()
    if rank not in (1, 2):
        raise ValueError("Unexpected waveform shape")
    # Only 2-D input can carry surplus channels; keep the first two.
    if rank == 2 and wav.size(0) > 2:
        return wav[:2]
    return wav

# Example: take dataset item 3, run its "mixture" entry through compute_waveform,
# then make sure the result has a playable shape.
x = dataset_multi[3]
wav = prepare_for_playback(compute_waveform(x["mixture"]))

# sounddevice is given a NumPy array; 2-D audio is flipped from
# (channels, time) to (time, channels).
wav_np = wav.cpu().numpy()
wav_np = wav_np.T if wav_np.ndim == 2 else wav_np

# Start playback at 16 kHz and wait for it to finish.
sd.play(wav_np, samplerate=16000)
sd.wait()


In [6]:

# 80/20 train/validation split by index order: the leading 80% of indices train,
# the trailing 20% validate.
dataset_size = len(dataset_multi)
split = int(0.8 * dataset_size)
indices = list(range(dataset_size))
train_indices = indices[:split]
val_indices = indices[split:]

# Each phase samples (in random order) only from its own subset of indices.
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

# Both loaders wrap the same dataset object; the samplers keep the phases apart.
dataloaders: Dict[str, DataLoader] = {
    "train": DataLoader(dataset_multi, batch_size=32, sampler=train_sampler),
    "val": DataLoader(dataset_multi, batch_size=32, sampler=val_sampler),
}
train_loader = dataloaders["train"]
val_loader = dataloaders["val"]

# -------------------------------
# Model Integration
# -------------------------------
# The network takes the mixture spectrogram as input and is expected to emit one
# separated-source spectrogram per target.

# Target keys (one output head each) for source separation.
label_names = ["drums", "bass", "other_accompaniment", "vocals"]

# The mixture is stereo, hence two input channels.
model = UNet(in_channels=2)

# Attach a 1x1 convolution head per target, mapping 16 feature channels to a
# 2-channel output. This assumes the decoder produces 16-channel feature maps.
for target_name in label_names:
    model.final_convs[target_name] = nn.Conv2d(16, 2, kernel_size=1)

# IMPORTANT: move to the device only after the heads exist, so the freshly added
# layers are moved together with the rest of the model.
model = model.to(device)


In [24]:
from torchsummary import summary
import torch

# Per-sample input size given to torchsummary
# (presumably channels x frequency bins x time frames — confirm).
input_shape = (2, 1025, 32)

# Print the layer-by-layer summary of the model for that input size.
summary(model, input_size=input_shape)


In [None]:

# -------------------------------
# Loss Function, Optimizer, Scheduler
# -------------------------------
# L1 loss is the training criterion for source separation.
criterion = nn.L1Loss()
# Adam over every model parameter, starting at a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Step schedule: scale the learning rate by 0.1 every 10 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# -------------------------------
# Train the Model
# -------------------------------
# The batch entry named by `input_name` feeds the model; the entries named in
# `label_names` are the separation targets.
best_model = train_model_source_separation(
    # what to train and on which data
    model=model,
    dataloaders=dataloaders,
    input_name="mixture",
    label_names=label_names,
    # optimisation setup
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=3,
    device=device,
    # logging and checkpoints
    log_dir='./logs',
    checkpoint_dir='./checkpoints',
    print_freq=10,
)

# (Optional) Test or visualize the best_model as needed.


Epoch 1/3
----------------------------------------
Train Epoch [1/3] Batch [0/4] Loss: 0.9941 LR: 0.001000


In [None]:

# Define an additional simple normalization transform:
def normalize_spec(spec: torch.Tensor) -> torch.Tensor:
    """Standardize a tensor using its own global mean and standard deviation.

    Args:
        spec: Input tensor; statistics are taken over all of its elements.

    Returns:
        ``(spec - mean) / (std + 1e-6)``, with the same shape as ``spec``. The
        small constant keeps the division finite when the standard deviation
        is zero.
    """
    centered = spec - spec.mean()
    scale = spec.std() + 1e-6
    return centered / scale

# Components requested from the dataset for this inspection pass.
COMPONENT_MAP = ["mixture", "drums", "bass", "other_accompaniment", "vocals"]
IS_TRACK_ID = True

# Rebuild the dataset; no `transform` and no `input_name` are passed here.
dataset_multi = AudioDatasetFolder(
    csv_file="output_stems/musdb18_index_20250408_121813.csv",
    audio_dir=".",  # adjust as needed
    components=COMPONENT_MAP,
    sample_rate=16000,
    duration=5.0,
    is_track_id=IS_TRACK_ID,
)

# Fetch the first batch in dataset order (shuffling is off).
loader_multi = DataLoader(dataset_multi, batch_size=32, shuffle=False)
sample_multi = next(iter(loader_multi))

# Track ids are only reported when the dataset was asked to provide them.
if IS_TRACK_ID:
    print("Loaded sample with track_id:", sample_multi['track_id'])

# Plot the spectrogram of one component: first sample in the batch, first channel.
# The same key selects the data and labels the figure, so the title always names
# what is actually drawn. (Previously the title said "mixture" with pitch shift
# and normalization while the data was the untransformed 'vocals' entry.)
plot_component = "vocals"
spec_multi = sample_multi[plot_component][0, 0]
# NOTE(review): 'Dark2_r' is a qualitative (categorical) colormap; a sequential
# map such as 'magma' would represent continuous magnitudes more faithfully.
plt.imshow(spec_multi.detach().cpu().numpy(), origin='lower', aspect='auto', cmap='Dark2_r')
plt.title(f"Spectrogram ({plot_component})")
plt.xlabel("Time")
plt.ylabel("Frequency Bin")
plt.colorbar()
plt.show()


In [None]:
# Random crop
def random_noise_crop(tensor, crop_duration=5.0, sample_rate=16000):
    """Cut a random fixed-length window out of ``tensor`` along dimension 1.

    Args:
        tensor: Array-like indexed as ``tensor[:, start:stop]`` (presumably
            laid out as (channels, samples) — confirm against the caller).
        crop_duration: Window length in seconds.
        sample_rate: Samples per second used to convert the duration into a
            sample count.

    Returns:
        ``tensor[:, start:start + crop_size]`` for a start offset drawn
        uniformly from every position where the full window still fits.

    Raises:
        ValueError: If ``tensor`` has fewer samples along dimension 1 than the
            requested crop.
    """
    crop_size = int(crop_duration * sample_rate)
    available = tensor.shape[1]
    max_start = available - crop_size
    # Fail with an actionable message instead of random.randint's opaque
    # "empty range" error when the input is too short.
    if max_start < 0:
        raise ValueError(
            f"Cannot crop {crop_size} samples from an input with only "
            f"{available} samples along dimension 1"
        )
    start = random.randint(0, max_start)
    return tensor[:, start:start + crop_size]

# NOTE(review): `big_tensor` is not defined and `compute_spectrogram` is not imported
# in any cell visible here; on a fresh kernel this cell raises NameError unless both
# are provided elsewhere — confirm and add the missing definition/import.
# Take a random crop (default: 5 s at 16 kHz) from the long source tensor.
noise_crop = random_noise_crop(big_tensor)
# Convert to spectrogram
spec = compute_spectrogram(noise_crop)
print(f"Shape of 5-second noise spectrogram: {spec.shape}")
