In [None]:
#%%

from train_test import train_model, test_model, plot_losses_and_metrics
from model.model import SepModel
from data.config import *
from data.dataset import PreComputedMixtureDataset
from torch.optim.adamw import AdamW
from torch.utils.data import DataLoader, SubsetRandomSampler
import os, pandas as pd, torch, numpy as np, torch.nn as nn

# cuDNN configuration: keep cuDNN enabled and switch on its benchmark mode.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Load metadata
metadata = pd.read_csv(os.path.join(
    DATASET_MIX_AUDIO_PATH, "metadata.csv"))

dataset = PreComputedMixtureDataset(metadata_file=metadata)

# Load the saved index arrays that define the train / validation / test splits
train_indices = np.load('train_indices_new_last.npy')
val_indices = np.load('val_indices_new_last.npy')
test_indices = np.load('test_indices_new_last.npy')

# Each sampler draws (in random order) only from its own split's indices
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)

# Rebuild the data loaders, one per split
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=25)
val_loader = DataLoader(dataset, sampler=val_sampler, batch_size=25)
test_loader = DataLoader(dataset, sampler=test_sampler, batch_size=25)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
model = SepModel(in_c=1, out_c=32).to(device)
# criterion = nn.MSELoss()
criterion = nn.L1Loss()
# optimizer = AdamW(model.parameters(), lr=1e-3)
# Only request the fused implementation when the parameters live on CUDA:
# on PyTorch releases where the fused kernel is CUDA-only, fused=True with
# CPU parameters raises at construction time, which would break the CPU
# fallback chosen above.
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    amsgrad=True,
    fused=(device.type == "cuda"),
)

In [None]:
#%%

# Load the best model for testing

model = SepModel(in_c=1, out_c=32).to(device)
# Build the checkpoint path with os.path.join so it resolves on every OS
# (a literal backslash is only a path separator on Windows).
checkpoint_path = os.path.join('checkpoint', 'best_model.pth')
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()
# Test model on the same device the model was moved to, instead of a
# hard-coded 'cuda' that disagrees with `device` on CPU-only machines.
test_results = test_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=str(device)
)
# Display only the scalar metrics; the full result (including the waveform
# arrays) stays available in `test_results` without flooding the cell output.
{key: test_results[key] for key in ('test_loss', 'final_si_sdr', 'final_sdr')}

Testing Loss: 0.02364400, SI-SDR: 12.02264690, SDR Improvement: 0.89828867: 100%|[31m██████████[0m| 60/60 [00:20<00:00,  2.92it/s]


Testing Loss: 0.02364400, Final SI-SDR: 12.02264690, Final SDR Improvement: 0.89828867


{'test_loss': 0.02364400268221895,
 'final_si_sdr': tensor(12.0226),
 'final_sdr': tensor(0.8983),
 'output_waveform': array([[[ 2.2248041e-02,  1.7356899e-02, -2.6973069e-04, ...,
          -1.4096500e-01, -1.4332935e-01, -1.4488874e-01]],
 
        [[ 6.0201358e-02,  6.0275346e-02,  6.0584009e-02, ...,
          -2.2577846e-02, -2.1176932e-02, -1.8269006e-02]],
 
        [[-5.5870246e-02, -7.6062024e-02, -7.8470804e-02, ...,
          -3.6480825e-02, -3.6379941e-02, -3.9517965e-02]],
 
        ...,
 
        [[ 4.6286721e-02,  7.1489491e-02,  1.0525753e-01, ...,
          -4.0029919e-01, -4.0897611e-01, -4.1626316e-01]],
 
        [[-1.8645218e-02, -1.3576010e-02, -1.8294824e-02, ...,
          -4.2985016e-01, -4.6750450e-01, -4.8675954e-01]],
 
        [[-6.9839568e-03,  8.9819939e-04, -1.2718653e-02, ...,
           2.1192981e-02,  1.5420642e-02,  1.8379584e-02]]], dtype=float32),
 'true_percussion': array([[[-5.26011474e-02, -7.82273561e-02, -1.09248549e-01, ...,
          -1.0789