In [1]:
from torch.utils.data import DataLoader
import torch
import random
from typing import Dict

# Optimizer and LR-scheduler modules (not referenced by the cells shown in this inference notebook)
import torch.optim as optim
from torch.optim import lr_scheduler

# Import our custom dataset and augmentation pipeline.
from process_sml import (
    AudioDatasetFolder, Compose,compute_waveform_griffinlim,
    RandomPitchShift_wav,RandomVolume_wav,RandomAbsoluteNoise_wav,RandomSpeed_wav,RandomFade_wav,RandomFrequencyMasking_spec,RandomTimeMasking_spec,RandomTimeStretch_spec)
# Import the model classes (UNet, LiteResUNet) and the training/inference helpers from the training module.
from train_sml import UNet, train_model_source_separation,LiteResUNet,infer_and_save
import torch.nn as nn

# Stem layout for the dataset. "mixture" is the model input (see input_name below);
# every other entry is a source the model is asked to separate.
COMPONENT_MAP = ["mixture", "drums", "bass", "other_accompaniment", "vocals"]
label_names = [component for component in COMPONENT_MAP if component != "mixture"]

# Validation set described by the CSV (duration=10.0, sample_rate=16000).
dataset_val = AudioDatasetFolder(
    csv_file='output_stems/test_one.csv',
    audio_dir='.',  # adjust as needed
    components=COMPONENT_MAP,
    sample_rate=16000,
    duration=10.0,
    is_track_id=True,
    input_name="mixture",
)

# shuffle defaults to False, so batches follow the dataset's own order.
data_loader = DataLoader(dataset_val, batch_size=18)


In [2]:
# Build the separation network with one output head per entry in label_names.
# NOTE(review): the checkpoint loaded below overwrites every key ("All keys matched"),
# so pretrained=True may only cost a backbone-weights download — confirm before changing.
model = LiteResUNet(
    backbone="resnet18",
    source_names=label_names,
    pretrained=True,
    in_channels=2,
)
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



In [3]:
# Load the training checkpoint. map_location remaps tensors that were saved on a GPU
# onto `device`, so this cell also works on a CPU-only machine.
# torch.load unpickles the file — only load checkpoints from a trusted source.
presaved_weights = torch.load("checkpoints/checkpoint_epoch_115.pth", map_location=device)
# The checkpoint is indexed by key; the model weights live under 'model_state_dict'.
state_of_dict = presaved_weights['model_state_dict']


In [4]:
# Copy the checkpoint weights into the freshly built model. Strict key matching is the
# default, so any missing or unexpected key raises here instead of failing silently.
model.load_state_dict(state_of_dict)

<All keys matched successfully>

In [6]:
# Quick sanity check: run the model on the first mixture of one validation batch.
model = model.to(device)
# Inference: put BatchNorm/Dropout layers into evaluation behaviour. Without this the
# backbone's BatchNorm would use per-batch statistics (training mode).
model.eval()

sample_multi = next(iter(data_loader))

# Presumably (channels, freq, time) — TODO confirm against AudioDatasetFolder.
spec = sample_multi['mixture'][0].to(device)
model_input = spec.unsqueeze(0)  # add a leading batch dimension of size 1

# No autograd graph is needed for inference; this saves memory and time.
with torch.no_grad():
    out = model(model_input)

In [7]:
# Take the predicted vocals and drop the size-1 batch dimension added before the forward pass.
vocals = out["vocals"].squeeze(dim=0)

In [8]:
# Turn the predicted vocals back into a waveform via the project helper
# (the name suggests Griffin-Lim phase reconstruction — see process_sml).
wav_spec = compute_waveform_griffinlim(vocals)

In [9]:
import torchaudio

# Drop autograd history and move the reconstructed audio to the CPU for writing.
# Presumably (channels, frames); the exact frame count depends on the Griffin-Lim
# helper's STFT settings — TODO confirm.
waveform = wav_spec.detach().cpu()
# torchaudio.save requires a 2-D (channels, frames) tensor; promote a 1-D mono result.
if waveform.dim() == 1:
    waveform = waveform.unsqueeze(0)
# Keep this in sync with the sample_rate the dataset was loaded with.
torchaudio.save("transformed.wav", waveform, sample_rate=16000)


In [None]:

# Run the loaded checkpoint over every batch of the validation loader and save the
# inference outputs under ./inference_outputs.
# NOTE(review): confirm infer_and_save switches the model to eval mode itself.
infer_and_save(
    model=model,
    dataloader=data_loader,
    device=device,
    output_dir="./inference_outputs",
    input_name="mixture",
    # Reuse the list the model was built with rather than a duplicated literal.
    label_names=label_names,
    sample_rate=16000,
)



✅ All inference outputs saved to ./inference_outputs
