In [9]:
from torch.utils.data import DataLoader
import torch
import random
from typing import Dict

# Import missing modules for optimization
import torch.optim as optim
from torch.optim import lr_scheduler

# Import our custom dataset and augmentation pipeline.
from process_sml import (
    AudioDatasetFolder, Compose,
    RandomPitchShift_wav,RandomVolume_wav,RandomAbsoluteNoise_wav,RandomSpeed_wav,RandomFade_wav,RandomFrequencyMasking_spec,RandomTimeMasking_spec,RandomTimeStretch_spec,
    compute_waveform,reconstruct_waveform)
# Import the UNet model and the training function from the training module.
from train_sml import UNet, train_model_source_separation,LiteResUNet
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Stem names requested from the dataset; passed below as `components=`.
COMPONENT_MAP = ["mixture", "drums", "bass", "other_accompaniment", "vocals"]
# Same list without "mixture". Not referenced again in the cells visible here.
label_names = ["drums", "bass", "other_accompaniment", "vocals"]


# Spectrogram-domain augmentation pipeline.
# NOTE: currently unused by the dataset below, because `spec_transform=argS` is commented out.
argS = Compose([

    # Spectrogram transformations.
    RandomTimeStretch_spec(),
    # These two maskings work properly together.
    RandomFrequencyMasking_spec(),
    RandomTimeMasking_spec(),

])
# Waveform-domain augmentation pipeline.
# NOTE: currently unused by the dataset below, because `wav_transform=argW` is commented out.
argW = Compose(
 [
    # RandomPitchShift_wav(),   (disabled)
    RandomVolume_wav(),
    # RandomSpeed_wav(),        (disabled)
    RandomAbsoluteNoise_wav(),
    RandomFade_wav(),
 ]   
)


# Seed torch and Python's `random` for reproducibility.
# NOTE(review): only these two generators are seeded; if the project transforms draw
# from another RNG (e.g. NumPy), that one needs seeding too — confirm.
torch.manual_seed(42)
random.seed(42)

# Pick CUDA when available, otherwise CPU. Not used by any other cell visible here.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the dataset from the stem index CSV, with no augmentation attached.
dataset_multi = AudioDatasetFolder(
    csv_file='output_stems/musdb18_index_20250408_121813.csv',
    audio_dir='.',  # adjust as needed
    components=COMPONENT_MAP,
    sample_rate=16000,
    duration=20.0,
    # spec_transform=argS,  # disabled: spectrogram augmentation pipeline
    # wav_transform=argW,   # disabled: waveform augmentation pipeline
    is_track_id=True,
    # input_name= "mixture"
)

# Unshuffled loader, so the first batch is the same on every run; pull that one batch.
loader_multi = DataLoader(dataset_multi, batch_size=8, shuffle=False)
sample_multi = next(iter(loader_multi))


# Take the first item of the batch under the 'mixture' key (index 0 is the batch axis;
# no channel is selected and nothing is plotted here).
# NOTE(review): despite the name `spec`, later cells report it as non-complex and write
# it out with torchaudio.save as audio — it is presumably a waveform; confirm.
spec = sample_multi['mixture'][0]


In [3]:
# First item of the batch under the "drums" key (index 0 is the batch axis).
drums_wav = sample_multi["drums"][0]

In [10]:
# Sanity check: is the 'mixture' sample complex-valued (as a raw STFT would be)?
# The recorded output is False, i.e. the tensor is real-valued.
torch.is_complex(spec)

False

In [7]:
import torchaudio

# Write the first 'mixture' sample (`spec`) straight to disk as audio.
# NOTE(review): the tensor's shape is not established in this cell — check `spec.shape`
# before relying on a particular (channels, samples) layout.
waveform = spec.detach().cpu()     # detach from autograd and move to CPU before saving
# 16000 matches the sample_rate the dataset was created with.
torchaudio.save("direct.wav", waveform, sample_rate=16000)


In [None]:
import torchaudio

# Detach from autograd and move to CPU so the tensor can be written to disk.
# Note: no dtype conversion or clamping to [-1, 1] is performed here; the tensor is
# saved exactly as the dataset produced it.
waveform_tensor = spec.detach().cpu()

# Keep the output path in a single variable so the file that is written and the
# message that is printed cannot drift apart (the message used to name a different file).
reconstructed_wav_path = "reconstructed_pre.wav"

# Write out at 16 kHz, the sample rate the dataset was created with.
torchaudio.save(reconstructed_wav_path, waveform_tensor, 16000)

print(f"Wrote WAV to {reconstructed_wav_path} — now open it with your favorite audio player!")


Wrote WAV to reconstructed.wav — now open it with your favorite audio player!


In [None]:

class ComputeSpectrogram:
    """Callable transform that converts a waveform into a spectrogram.

    The computation itself is delegated to ``compute_spectrogram``; this class only
    stores the settings and, when ``power`` is ``None``, reduces the result to its
    absolute value.
    """

    # NOTE(review): ``compute_spectrogram`` is not among the names imported in the
    # visible cells (only ``compute_waveform`` / ``reconstruct_waveform`` are) —
    # confirm it is defined or imported before an instance is called.

    def __init__(self, n_fft: int = 2048, hop_length: int = 512, power: float = None, normalized: bool = False):
        """
        Store the spectrogram settings.

        Args:
            n_fft (int): FFT window size, forwarded to ``compute_spectrogram``.
            hop_length (int): Hop length between windows, forwarded likewise.
            power (float, optional): Forwarded likewise. When ``None`` the result of
                ``compute_spectrogram`` is passed through ``.abs()`` before being
                returned; otherwise it is returned unchanged.
            normalized (bool): Normalization flag, forwarded likewise.
        """
        self.n_fft, self.hop_length = n_fft, hop_length
        self.power, self.normalized = power, normalized

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute the spectrogram of ``waveform`` with the stored settings.

        Args:
            waveform (Tensor): Audio tensor. Which shapes are accepted depends on
                ``compute_spectrogram`` (presumably (time), (channel, time) or
                (batch, channel, time) — confirm against that helper).

        Returns:
            Tensor: ``compute_spectrogram(...)`` as-is when ``power`` is set, or its
            element-wise absolute value when ``power`` is ``None``.
        """
        result = compute_spectrogram(
            waveform,
            self.n_fft,
            self.hop_length,
            self.power,
            self.normalized,
        )

        # An explicit power means the helper's output is returned untouched.
        if self.power is not None:
            return result

        # No power requested: take the magnitude of the helper's output.
        return result.abs()
