In [9]:
# Standard library
import random
from typing import Dict, Optional

# Third-party
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# Import missing modules for optimization
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

# Import our custom dataset and augmentation pipeline.
from process_sml import (
    AudioDatasetFolder, Compose,
    RandomPitchShift_wav,RandomVolume_wav,RandomAbsoluteNoise_wav,RandomSpeed_wav,RandomFade_wav,RandomFrequencyMasking_spec,RandomTimeMasking_spec,RandomTimeStretch_spec,
    compute_waveform,reconstruct_waveform)
# Import the UNet model and the training function from the training module.
from train_sml import UNet, train_model_source_separation,LiteResUNet

# ---- Configuration (everything a reader might want to tune) ---------------
SEED = 42
SAMPLE_RATE = 16000        # Hz
CLIP_DURATION_S = 20.0     # seconds per example -> 320000 samples at 16 kHz
BATCH_SIZE = 8
INDEX_CSV = "output_stems/musdb18_index_20250408_121813.csv"

# Components requested from the dataset, in this order.
COMPONENT_MAP = ["mixture", "drums", "bass", "other_accompaniment", "vocals"]
# Separation targets: every component except the mixture.
label_names = COMPONENT_MAP[1:]

# Seed before anything that may draw random numbers (including the
# augmentation objects built below) so a fresh-kernel run is reproducible.
torch.manual_seed(SEED)
random.seed(SEED)

# Spectrogram-domain augmentations (not currently passed to the dataset).
argS = Compose([
    RandomTimeStretch_spec(),
    # These two masking transforms work properly together.
    RandomFrequencyMasking_spec(),
    RandomTimeMasking_spec(),
])

# Waveform-domain augmentations (not currently passed to the dataset).
argW = Compose([
    # RandomPitchShift_wav(),  # disabled
    RandomVolume_wav(),
    # RandomSpeed_wav(),       # disabled
    RandomAbsoluteNoise_wav(),
    RandomFade_wav(),
])

# Compute device (not used by the visible cells of this notebook yet).
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset_multi = AudioDatasetFolder(
    csv_file=INDEX_CSV,
    audio_dir=".",  # adjust as needed
    components=COMPONENT_MAP,
    sample_rate=SAMPLE_RATE,
    duration=CLIP_DURATION_S,
    # Uncomment to enable augmentation:
    # spec_transform=argS,
    # wav_transform=argW,
    is_track_id=True,
    # input_name="mixture",
)

# shuffle=False keeps the first batch identical across runs for inspection.
loader_multi = DataLoader(dataset_multi, batch_size=BATCH_SIZE, shuffle=False)
sample_multi = next(iter(loader_multi))

# First example of the batch for the 'mixture' component, ALL channels kept.
# Per the recorded outputs below it is a real-valued (4, 1025, 626) tensor.
spec = sample_multi["mixture"][0]


In [3]:
# First batch element of the drums stem (a tensor; see the shape check below).
drums_wav = sample_multi["drums"].select(0, 0)

In [10]:
# Is the mixture tensor stored as a complex STFT? (Recorded answer: no.)
spec.is_complex()

False

In [5]:
# (channels, samples) of the drums waveform.
drums_wav.size()

torch.Size([2, 320000])

In [2]:
# Shape of the mixture tensor for one example.
spec.size()

torch.Size([4, 1025, 626])

In [9]:
# NOTE(review): `spec_multi` is not assigned in any visible cell (hidden kernel
# state), so this fails on Restart & Run All. Its recorded shape is half of
# `spec`'s channels — it may be the first two mixture channels; TODO add the
# defining cell.
spec_multi.size()

torch.Size([2, 1025, 626])

In [10]:
# NOTE(review): `phase_multi` is not assigned in any visible cell (hidden kernel
# state). Its displayed values stay within about ±3.1416, consistent with phase
# angles in radians — TODO add the defining cell.
phase_multi

tensor([[[ 0.0000,  0.0000,  0.0000,  ..., -2.2632, -3.1416, -3.1416],
         [ 0.1314,  1.6906, -1.5155,  ..., -1.6732, -0.7753,  0.0839],
         [-2.9658, -0.5233, -1.7235,  ...,  0.7273, -0.8559,  1.5805],
         ...,
         [ 0.1550,  2.1844,  0.1257,  ...,  0.8283, -1.7063,  1.7371],
         [-0.0819, -1.7108, -1.6442,  ...,  0.8761,  0.1350, -2.5726],
         [-2.9659, -0.5269, -1.3851,  ..., -3.1369, -3.1344, -3.1319]],

        [[ 0.0000,  0.0000,  0.0000,  ..., -2.2632, -3.1416, -3.1416],
         [ 3.1126,  2.5359,  2.0272,  ..., -2.2548, -1.2661, -0.2525],
         [-2.9839,  0.1515,  1.8673,  ..., -1.4304,  0.8589, -0.7995],
         ...,
         [ 0.1540,  2.1629,  0.0480,  ...,  0.3488, -0.4115, -2.8410],
         [-0.0829, -1.7318, -1.6658,  ...,  0.8868, -1.6112,  1.6722],
         [-2.9659, -0.5269, -1.3851,  ..., -3.1076, -3.1051, -3.1026]]])

In [7]:
# Shape of the phase tensor; matches spec_multi's shape.
phase_multi.size()

torch.Size([2, 1025, 626])

In [16]:
# Stack channel-wise: spec_multi's channels first, then phase_multi's -> (4, H, W).
combined = torch.cat([spec_multi, phase_multi], dim=0)
combined.size()

torch.Size([4, 1025, 626])

In [17]:
# Undo the concatenation by splitting at the original channel counts (a
# hard-coded 2 only works when both tensors have exactly two channels), then
# confirm the round trip is exact.
spec_multi_recovered, phase_multi_recovered = torch.split(
    combined, [spec_multi.shape[0], phase_multi.shape[0]], dim=0
)
assert torch.equal(spec_multi_recovered, spec_multi)
assert torch.equal(phase_multi_recovered, phase_multi)


In [18]:
# Display the recovered first part of `combined`; it should equal `spec_multi`.
spec_multi_recovered

tensor([[[0.0031, 0.0060, 0.0123,  ..., 0.0566, 0.0644, 0.0167],
         [0.0011, 0.0029, 0.0067,  ..., 0.0533, 0.0508, 0.0133],
         [0.0013, 0.0053, 0.0081,  ..., 0.0481, 0.0362, 0.0091],
         ...,
         [0.0027, 0.0074, 0.0087,  ..., 0.0212, 0.0211, 0.0047],
         [0.0005, 0.0061, 0.0101,  ..., 0.0303, 0.0192, 0.0038],
         [0.0016, 0.0049, 0.0077,  ..., 0.0194, 0.0082, 0.0010]],

        [[0.0029, 0.0054, 0.0118,  ..., 0.0552, 0.0650, 0.0171],
         [0.0006, 0.0034, 0.0071,  ..., 0.0549, 0.0517, 0.0133],
         [0.0027, 0.0017, 0.0045,  ..., 0.0398, 0.0507, 0.0136],
         ...,
         [0.0027, 0.0074, 0.0086,  ..., 0.0219, 0.0206, 0.0045],
         [0.0005, 0.0059, 0.0088,  ..., 0.0303, 0.0199, 0.0040],
         [0.0016, 0.0049, 0.0076,  ..., 0.0190, 0.0072, 0.0007]]])

In [2]:
# Complex tensor from magnitude and angle: abs * (cos(angle) + i*sin(angle)).
complex_x = torch.polar(abs=spec_multi, angle=phase_multi)

In [3]:
# Turn the complex tensor back into audio with process_sml.reconstruct_waveform
# (presumably an inverse STFT — confirm its STFT settings match the dataset's).
wav = reconstruct_waveform(complex_x)

In [8]:
# NOTE(review): same call as the later `wav1 = compute_waveform(spec_multi)`
# cell, and `x` is not used in any visible cell — candidate for deletion.
x = compute_waveform(spec_multi)

In [7]:
import torchaudio

# `spec` is the (4, 1025, 626) mixture feature tensor, not audio, and
# torchaudio.save needs a (channels, time) waveform. Write `wav` instead: the
# waveform from reconstruct_waveform(torch.polar(spec_multi, phase_multi)),
# i.e. using the stored phase. This gives a reference to compare against the
# phase-free compute_waveform(spec_multi) output in reconstructed_pre.wav.
# Swap in another (channels, time) tensor (e.g. drums_wav) for a different reference.
waveform = wav.detach().cpu()
if waveform.dim() != 2:
    raise ValueError(
        f"expected a (channels, time) waveform, got shape {tuple(waveform.shape)}"
    )
torchaudio.save("direct.wav", waveform, sample_rate=16000)


In [7]:
# Waveform built from `spec_multi` alone (no phase tensor is passed); the next
# cell writes it to reconstructed_pre.wav.
wav1 = compute_waveform(spec_multi)

In [9]:
import torchaudio

# Write the compute_waveform output. Nothing here rescales the signal, so if
# the file sounds clipped, check that wav1 stays within [-1, 1].
output_path = "reconstructed_pre.wav"
waveform_tensor = wav1.detach().cpu()
if waveform_tensor.dim() != 2:
    raise ValueError(
        f"expected a (channels, time) waveform, got shape {tuple(waveform_tensor.shape)}"
    )
torchaudio.save(output_path, waveform_tensor, 16000)

# Report the path actually written (previously printed a different filename).
print(f"Wrote WAV to {output_path} — now open it with your favorite audio player!")


Wrote WAV to reconstructed.wav — now open it with your favorite audio player!


In [None]:

class ComputeSpectrogram:
    """Callable transform: waveform -> spectrogram via ``compute_spectrogram``.

    Stores the STFT settings and forwards them to ``compute_spectrogram``. When
    ``power`` is None, the helper's (presumably complex) output is converted to
    a magnitude with ``.abs()``; otherwise it is returned unchanged.
    """

    def __init__(self, n_fft: int = 2048, hop_length: int = 512,
                 power: Optional[float] = None, normalized: bool = False):
        """
        Args:
            n_fft (int): FFT window size.
            hop_length (int): Hop between successive frames, in samples.
            power (float, optional): Forwarded to ``compute_spectrogram``. If None,
                the result's magnitude (``.abs()``) is returned; otherwise the
                result is returned as-is (presumably a power spectrogram).
            normalized (bool): Forwarded to ``compute_spectrogram``.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.power = power
        self.normalized = normalized

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform (Tensor): Audio tensor, passed straight to ``compute_spectrogram``.

        Returns:
            Tensor: ``compute_spectrogram``'s output, or its magnitude when ``power`` is None.
        """
        # NOTE(review): compute_spectrogram is not imported in this notebook's
        # imports cell (only compute_waveform/reconstruct_waveform come from
        # process_sml) — confirm where it lives and import it, or this raises NameError.
        spec = compute_spectrogram(waveform, self.n_fft, self.hop_length,
                                   self.power, self.normalized)
        return spec.abs() if self.power is None else spec
