In [None]:
import torch
from dataset_tool import compute_loudness, compute_centroid
from IPython.display import Audio
import pickle
import librosa as li
from noisebandnet.model import NoiseBandNet
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import librosa.display

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU so
# the notebook still executes on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Directory holding the checkpoint and the pickled config loaded below.
TRAIN_PATH = 'trained_models/metal'
MODEL_PATH = f'{TRAIN_PATH}/model_10000.ckpt'
CONFIG_PATH = f'{TRAIN_PATH}/config.pickle'

# NumPy file with the control curve that is fed to the synthesiser.
CONTROL_PARAM_PATH = 'inference_controls/control_metal_1.npy'

In [None]:
# Deserialize the saved config object; its attributes supply the model
# hyperparameters and the sampling rate used further down.
# NOTE(review): pickle.load can execute arbitrary code - only open config
# files that come from a trusted source.
with open(CONFIG_PATH, "rb") as config_file:
    config = pickle.load(config_file)

In [None]:
# Build the network from the hyperparameters stored in the config, then move
# it to the target device and cast its parameters to float32.
synth = NoiseBandNet(
    hidden_size=config.hidden_size,
    n_band=config.n_band,
    synth_window=config.synth_window,
    n_control_params=config.n_control_params,
)
synth = synth.to(device).float()

In [None]:
# Restore the trained weights. map_location remaps the checkpoint's tensors
# onto `device`, so a checkpoint written during a GPU run still loads when
# `device` is the CPU (or a different GPU).
# NOTE(review): the model is left in training mode; if NoiseBandNet contains
# dropout or batch-norm layers, call synth.eval() before inference - confirm.
synth.load_state_dict(torch.load(MODEL_PATH, map_location=device))

In [None]:
# Read the control curve from disk as a NumPy array.
control_curve = np.load(CONTROL_PARAM_PATH)

# Prepend two singleton axes, cast to float32 and move to the target device.
control_tensor = torch.from_numpy(control_curve)[None, None].float().to(device)

# Swap the last two axes and wrap the tensor in a list, which is the form
# passed to the model as `control_params`. The three-index permute means the
# tensor must be 3-D here, i.e. the array loaded from disk is 1-D.
control_param = [control_tensor.permute(0, 2, 1)]

In [None]:
# Run the synthesiser with gradient tracking disabled.
with torch.no_grad():
    y_audio = synth(control_params=control_param)

# Take element [0][0] of the output (presumably first batch item, first
# channel - confirm) as a NumPy array and render an audio player for it.
waveform = y_audio[0][0].detach().cpu().numpy()
Audio(waveform, rate=config.sampling_rate)

In [None]:
# STFT settings. The same hop length has to reach both the transform and the
# plot, otherwise the time axis of the figure is scaled incorrectly.
N_FFT = 1024
HOP_LENGTH = 256

# Element [0][0] of the model output as a NumPy array.
waveform = y_audio[0][0].detach().cpu().numpy()

# Magnitude spectrogram in dB relative to its largest value (peak = 0 dB).
D = li.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH)
S_db = li.amplitude_to_db(np.abs(D), ref=np.max)

fig, ax = plt.subplots()
img = librosa.display.specshow(S_db, x_axis='time', y_axis='log', ax=ax, sr=config.sampling_rate, cmap='magma', hop_length=HOP_LENGTH)
# Title and colour scale so the figure can be read on its own.
ax.set(title='Generated audio: log-frequency spectrogram')
colorbar = fig.colorbar(img, ax=ax, format='%+2.0f dB')