In [52]:
import torch
import auraloss
import numpy as np
from tqdm import tqdm
import pyloudnorm as pyln

from dataset.dataset_ram import LA2A_Dataset
#from models.lstm_mlp_input import LSTMModelID
#from models.lstm_mlp import LSTMModel
from models.squeeze_lstm import Squeeze_LSTM

# Numerical settings for this evaluation session.
torch.set_float32_matmul_precision('highest')  # keep full-precision float32 matmuls
torch.backends.cudnn.benchmark = True  # cuDNN autotuning; only relevant for CUDA runs

In [53]:
location = torch.device(type="cpu")  # device used to map the checkpoint tensors on load

In [54]:
# Data / loader settings for the evaluation run.
config = dict(
    dataset_dir="path_to_dataset",  # placeholder: point at the LA2A dataset root
    train_data_len_sec=10,
    batch_size=50,
    learning_rate=1e-3,
    num_workers=10,
)

# Model hyperparameters as used for training.
model_config = dict(
    input_size=1,
    output_size=1,
    hidden_size=64,
    train_loss='l1+stft',
)

In [55]:
# Restore the trained weights for evaluation.
# `load_from_checkpoint` is a classmethod (assumes Squeeze_LSTM is a PyTorch Lightning
# module): calling it on a freshly built instance constructs a throwaway model whose
# constructor arguments are ignored, and Lightning >= 2.0 rejects that call outright.
# The architecture is rebuilt from the hyperparameters stored in the checkpoint, so
# `model_config` is deliberately not forwarded here.
# NOTE(review): the loaded model reports LSTM(3, 64) while model_config['input_size'] is 1
# -- confirm which input size the checkpoint was trained with.
CHECKPOINT_PATH = '/home/abalykin/git/neural_la/artefacts/squeeze_lstm_v5/lstm_mlp_input-epoch=80-val_loss=1.78.ckpt'
model = Squeeze_LSTM.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH, map_location=location)
model.eval()

Squeeze_LSTM(
  (l1): L1Loss()
  (esr): ESRLoss()
  (stft): STFTLoss(
    (spectralconv): SpectralConvergenceLoss()
    (logstft): STFTMagnitudeLoss(
      (distance): L1Loss()
    )
    (linstft): STFTMagnitudeLoss(
      (distance): L1Loss()
    )
  )
  (mel_stft): MelSTFTLoss(
    (spectralconv): SpectralConvergenceLoss()
    (logstft): STFTMagnitudeLoss(
      (distance): L1Loss()
    )
    (linstft): STFTMagnitudeLoss(
      (distance): L1Loss()
    )
  )
  (lstm): LSTM(3, 64)
  (linear): Linear(in_features=64, out_features=1, bias=True)
)

In [56]:
# Metric objects used to score predictions against the reference audio.
SAMPLE_RATE = 44100  # audio sample rate shared by the mel-STFT loss and the loudness meter

l1 = torch.nn.L1Loss()
esr = auraloss.time.ESRLoss()
stft = auraloss.freq.STFTLoss()
mel_stft = auraloss.freq.MelSTFTLoss(sample_rate=SAMPLE_RATE)
meter = pyln.Meter(rate=SAMPLE_RATE)  # integrated-loudness (LUFS) meter

In [57]:
# Hold-out split; the second argument is the clip length in seconds
# (presumably the same slicing as training -- confirm in LA2A_Dataset).
dataset_root = config['dataset_dir']
clip_seconds = config['train_data_len_sec']
test_dataset = LA2A_Dataset(dataset_root, clip_seconds, mode="test")

100%|██████████| 3/3 [00:03<00:00,  1.25s/it]


In [58]:
# Iterate the test split one clip at a time. batch_size=1 is relied on by the
# evaluation loop, which squeezes each prediction down to a single signal for the
# loudness meter. Order does not affect averaged metrics, so shuffling is disabled
# to keep per-clip results reproducible between runs.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=1,
    num_workers=config['num_workers'],
    # Pinned host memory only speeds up host->GPU copies; pointless (and warned about) on CPU.
    pin_memory=location.type == 'cuda',
)

In [59]:
# Per-clip metric values, filled by the evaluation loop (three independent lists).
l1_rec, stft_rec, lufs_rec = [], [], []

In [60]:
# Score every test clip: L1 and STFT loss between prediction and reference, plus the
# absolute difference in integrated loudness (LUFS) between the two signals.
with torch.no_grad():
    for audio_in, output, params in tqdm(test_dataloader, desc="test"):
        pred = model(audio_in, params)
        # Add a channel axis so the reference matches the prediction layout
        # (assumes pred is (batch, 1, samples) -- TODO confirm against Squeeze_LSTM.forward).
        output = output[:, None, :]

        l1_loss = l1(pred, output).item()
        stft_loss = stft(pred, output).item()

        # squeeze() yields a 1-D signal because the loader uses batch_size=1 (mono assumed).
        pred_lufs = meter.integrated_loudness(pred.squeeze().cpu().numpy())
        target_lufs = meter.integrated_loudness(output.squeeze().cpu().numpy())
        lufs_err = np.abs(target_lufs - pred_lufs)

        l1_rec.append(l1_loss)
        stft_rec.append(stft_loss)
        # Near-silent clips can measure as -inf (or NaN) loudness; a non-finite error
        # would turn the mean into inf/NaN, so such clips are left out of the LUFS metric.
        if np.isfinite(lufs_err):
            lufs_rec.append(lufs_err)

273it [00:09, 29.29it/s]


In [61]:
# Average each per-clip metric. np.mean accepts the lists directly, so the record
# lists are not rebound to arrays and the evaluation cell can be re-run safely
# (rebinding them made a second run fail on ndarray.append).
l1_mean = float(np.mean(l1_rec))
stft_mean = float(np.mean(stft_rec))
# Every clip may have been filtered out of the LUFS metric; report NaN explicitly
# instead of letting np.mean warn about an empty input.
lufs_mean = float(np.mean(lufs_rec)) if len(lufs_rec) > 0 else float('nan')

In [62]:
# Mean time-domain L1 error between prediction and reference over the test clips.
l1_mean

0.05864832728534882

In [63]:
# Mean auraloss STFT loss between prediction and reference over the test clips.
stft_mean

1.8061781885859731

In [64]:
# Mean absolute integrated-loudness difference (LU) between prediction and reference,
# over the clips whose difference was kept by the evaluation loop.
lufs_mean

5.23028657095154