In [1]:
from torch.utils.data import DataLoader
import torch
import random
from typing import Dict

# Import missing modules for optimization
import torch.optim as optim
from torch.optim import lr_scheduler
import torchaudio
# Import our custom dataset and augmentation pipeline.
from process_sml import (
    AudioDatasetFolder, Compose,compute_waveform_griffinlim,reconstruct_waveform,
    RandomPitchShift_wav,RandomVolume_wav,RandomAbsoluteNoise_wav,RandomSpeed_wav,RandomFade_wav,RandomFrequencyMasking_spec,RandomTimeMasking_spec,RandomTimeStretch_spec)
# Import the UNet model and the training function from the training module.
from train_sml import UNet, train_model_source_separation,LiteResUNet,infer_and_save,GLRUNET
import torch.nn as nn

# Define the component map for the dataset.
COMPONENT_MAP = ["mixture", "drums", "bass", "other_accompaniment", "vocals"]
label_names = ["drums", "bass", "other_accompaniment", "vocals"]

dataset_val = AudioDatasetFolder(
    csv_file='output_stems/test_one.csv',
    audio_dir='.',  # adjust as needed
    components=COMPONENT_MAP,
    sample_rate=16000,
    duration=10.0,
    is_track_id=True,
    input_name= "mixture",

)
data_loader = DataLoader(dataset_val, batch_size=1)


In [2]:
model = LiteResUNet(backbone="resnet18",source_names=label_names,pretrained=True,in_channels=2)
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [13]:
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

# grab one batch
sample_multi = next(iter(data_loader))

# Option A) move just the spectrogram you care about
spec = sample_multi['vocals'][0]    # shape = (channels, freq, time)
spec = spec.to(device)


In [3]:
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

# grab one batch
sample_multi = next(iter(data_loader))

# Option A) move just the spectrogram you care about
spec = sample_multi['mixture'][0]    # shape = (channels, freq, time)
spec = spec.to(device)
print(spec.shape)
presaved_weights = torch.load(r"checkpoints\checkpoint_epoch_83.pth")
state_of_dict = presaved_weights['model_state_dict']

model.load_state_dict(state_dict=state_of_dict)
model = model.to(device)
# now you can e.g. plot it (if you move it back to CPU first) or feed it into a model:
model_input = spec.unsqueeze(0)      # add batch dim if needed
model_input = model_input.to(device)
out = model(model_input)

torch.Size([2, 129, 5001])


In [4]:

sample_multi = next(iter(data_loader))

# Option A) move just the spectrogram you care about
spec = sample_multi['vocals'][0]    # shape = (channels, freq, time)

In [3]:
spec.shape

torch.Size([2, 129, 5001])

In [6]:
C = spec.size(0) // 2           # 4 // 2 == 2
mag   = spec[0:C,  :,  :]       # → [2, 1025, 157]
phase = spec[C: ,  :,  :]       # → [2, 1025, 157]
complex_spec = torch.polar(mag, phase)
reconstruction = reconstruct_waveform(complex_spec)


In [7]:
reconstruction.shape

torch.Size([2, 160000])

In [4]:
out["vocals"].shape

torch.Size([1, 2, 129, 5001])

In [10]:
vocals =out["vocals"]

In [14]:

wav = compute_waveform_griffinlim(mag_spec=spec,n_fft=256,hop_length=32)

In [7]:
wav=wav.squeeze(dim=0)

In [13]:
import torch
from torcheval.metrics.functional import peak_signal_noise_ratio

peak_signal_noise_ratio(reconstruction.to(device="cpu"), wav.to(device="cpu"))

tensor(20.3448)

In [13]:
wav.shape

torch.Size([2, 160000])

In [8]:
vocals.to(device)

tensor([[[[3.3765e-01, 3.3765e-01, 3.0886e-01,  ..., 1.1192e-02,
           1.2870e-02, 1.2870e-02],
          [3.3765e-01, 3.3765e-01, 3.0886e-01,  ..., 1.1192e-02,
           1.2870e-02, 1.2870e-02],
          [2.9694e-01, 2.9694e-01, 2.7252e-01,  ..., 9.3838e-03,
           1.0822e-02, 1.0822e-02],
          ...,
          [1.0166e+01, 1.0166e+01, 9.1265e+00,  ..., 9.5611e-03,
           1.0430e-02, 1.0430e-02],
          [8.8001e+00, 8.8001e+00, 7.9124e+00,  ..., 1.1152e-02,
           1.2243e-02, 1.2243e-02],
          [8.8001e+00, 8.8001e+00, 7.9124e+00,  ..., 1.1152e-02,
           1.2243e-02, 1.2243e-02]],

         [[3.3652e-01, 3.3652e-01, 3.0712e-01,  ..., 2.0638e-02,
           2.3692e-02, 2.3692e-02],
          [3.3652e-01, 3.3652e-01, 3.0712e-01,  ..., 2.0638e-02,
           2.3692e-02, 2.3692e-02],
          [2.9528e-01, 2.9528e-01, 2.7038e-01,  ..., 1.8374e-02,
           2.0996e-02, 2.0996e-02],
          ...,
          [1.0129e+01, 1.0129e+01, 9.0952e+00,  ..., 7.5931

In [9]:

inverse_melscale_transform = torchaudio.transforms.InverseMelScale(n_stft=2048 // 2 + 1)
inverse_melscale_transform.to(device)
spectrogram = inverse_melscale_transform(vocals)

In [2]:
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

# grab one batch
sample_multi = next(iter(data_loader))

# Option A) move just the spectrogram you care about
spec = sample_multi['mixture'][0]    # shape = (channels, freq, time)

spec = spec.unsqueeze(0)
spec = spec.to(device)

model = GLRUNET(backbone="resnet18",source_names=label_names,pretrained=True,in_channels=2, use_griffinlim=True)  
model.to(device)

# 2) Load your checkpoint, but allow GL’s own parameters to stay at their defaults
presaved_weights = torch.load(r"ALL_CKP\semi_success_ckp\spec-griffinlim.pth")
state_of_dict = presaved_weights['model_state_dict']
model.load_state_dict(state_of_dict, strict=False)

# 3) Switch the model to eval mode
model.eval()

# 4) At inference time, you want two steps:
#    (a) run the UNet backbone + decoder to get your output spectrograms
#    (b) apply Griffin–Lim (either via your built‑in layer or directly) to get waveforms

# suppose `mixture_spec` is your input magnitude (B, C_in, F, T) torch‑Tensor on `device`
with torch.no_grad():
    # 4a) get your raw spectrogram predictions
    print(spec.shape)
    preds = model(spec)    # Dict[str, Tensor], each Tensor = (B, C_out, F, T)

    # if you coded your final convs to output just magnitudes,
    # you can call your GL layer like this:
    model.use_griffinlim = True    # turn it on
    waveforms = {}
    for name, spec in preds.items():
        # spec: (B, C_out, F, T)  —  ideally this is power spectrogram or magnitude
        # if your GL layer is stored in `model.griffinlim` you can call:
        wave = model(spec)         # (B, C_out, L) waveform tensor
        waveforms[name] = wave.cpu()          # or keep on device as you wish



torch.Size([1, 2, 1025, 313])
shape of the meg torch.Size([1, 2, 1025, 313])
shape of the flat torch.Size([2, 1025, 313])
shape of the meg torch.Size([1, 2, 1025, 313])
shape of the flat torch.Size([2, 1025, 313])
shape of the meg torch.Size([1, 2, 1025, 313])
shape of the flat torch.Size([2, 1025, 313])
shape of the meg torch.Size([1, 2, 1025, 313])
shape of the flat torch.Size([2, 1025, 313])


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 319488]

In [3]:
from torchsummary import summary

model.to(device)
summary(
    model, 
    input_size=(2, 128, 313),    # omit batch dim
    batch_size=1,                # optional, for FLOPs estimate
    device=str(device)
)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [1, 64, 64, 157]           6,272
       BatchNorm2d-2           [1, 64, 64, 157]             128
              ReLU-3           [1, 64, 64, 157]               0
         MaxPool2d-4            [1, 64, 32, 79]               0
            Conv2d-5            [1, 64, 32, 79]          36,864
       BatchNorm2d-6            [1, 64, 32, 79]             128
              ReLU-7            [1, 64, 32, 79]               0
            Conv2d-8            [1, 64, 32, 79]          36,864
       BatchNorm2d-9            [1, 64, 32, 79]             128
             ReLU-10            [1, 64, 32, 79]               0
       BasicBlock-11            [1, 64, 32, 79]               0
           Conv2d-12            [1, 64, 32, 79]          36,864
      BatchNorm2d-13            [1, 64, 32, 79]             128
             ReLU-14            [1, 64,

In [4]:
import torch

model.eval()
dummy = torch.randn(1, 2, 128, 313, device=device)
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = model(dummy)
print("Peak RAM:", torch.cuda.max_memory_allocated() / (1024**2), "MB")


Peak RAM: 67.6708984375 MB


<All keys matched successfully>

In [5]:
# grab one batch
sample_multi = next(iter(data_loader))

# Option A) move just the spectrogram you care about
spec = sample_multi['mixture'][0]    # shape = (channels, freq, time)
spec = spec.to(device)
model = model.to(device)
# now you can e.g. plot it (if you move it back to CPU first) or feed it into a model:
model_input = spec.unsqueeze(0)      # add batch dim if needed
model_input = model_input.to(device)
out = model(model_input)

In [7]:
vocals = out["vocals"].squeeze(0)

In [8]:
inverse_melscale_transform = torchaudio.transforms.InverseMelScale(n_stft=2048 // 2 + 1)
inverse_melscale_transform.to(device)
spectrogram = inverse_melscale_transform(vocals)

In [9]:
spectrogram.shape

torch.Size([2, 1025, 313])

In [4]:
spec.shape

torch.Size([2, 1025, 313])

In [11]:
wav_spec= compute_waveform_griffinlim(spectrogram)

In [13]:
wav_spec.shape

torch.Size([1, 2, 159744])

In [8]:
wav.shape

torch.Size([1, 2, 159744])

In [15]:
import torchaudio

# wav1 is shape [2, 64000], dtype=float
waveform = wav.squeeze().detach().cpu()     # now shape [2, 64000]
torchaudio.save("vocal_D-recon-f-256-h-32.wav", waveform, sample_rate=16000)


In [None]:

# after training:
infer_and_save(
    model=model,
    dataloader=data_loader,
    device=device,
    output_dir="./inference_outputs",
    input_name="mixture",
    label_names=["drums", "bass", "other_accompaniment", "vocals"],
    sample_rate=16000,
)



✅ All inference outputs saved to ./inference_outputs
