In [2]:
"""Calculate loss on test set and change sampling rate"""
import torch
from model.mamba import Mamba
from config import ModelParams, HyperParams
from torch.utils.data import DataLoader
from utils import AudioSegmentDataset, eval  # NOTE: utils.eval shadows the built-in eval()

# Fall back to CPU so the cell still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
trained_sr = 48000
# Sample rates to evaluate on; the first entry is the baseline at the training rate.
eval_srs = (trained_sr, 44100)


def load_trained_model(device):
    """Build a fresh Mamba on `device` and load the trained checkpoint into it.

    A fresh instance per sample rate guarantees every `change_scale` call starts
    from the trained step size, regardless of whether `change_scale` sets the
    scale absolutely or multiplies the current one.
    """
    net = Mamba(ModelParams).to(device)
    checkpoint = torch.load(f"results/S5_mamba_model_{HyperParams.name}.pth", map_location=device)
    net.load_state_dict(checkpoint)
    net.eval()  # inference mode (disables dropout-style layers, if the model has any)
    return net


def make_test_loader(sr):
    """Return a DataLoader over the test split rendered at sample rate `sr`.

    Shuffling is disabled: evaluation gains nothing from it and a fixed order
    keeps the reported losses reproducible between runs.
    """
    dataset = AudioSegmentDataset('dataset/test', sr=str(sr), p_zero=0.0)
    return DataLoader(dataset, batch_size=HyperParams.batch_size, shuffle=False)


for sr in eval_srs:
    model = load_trained_model(device)
    if sr != trained_sr:
        # Rescale the step size so a model trained at `trained_sr` can run at `sr`.
        model.change_scale(trained_sr / sr)
    loader = make_test_loader(sr)
    print(f"Losses on {sr} Hz, trained with {trained_sr} Hz")
    eval(model, loader, ModelParams, HyperParams, device)


Dataset initialized with 175 files. Zero-sample probability: 0.00
Losses on 48000 Hz, trained with 48000 Hz


Evaluating: 100%|██████████| 6/6 [00:03<00:00,  1.67it/s]


ESR value: 0.0019671963527798653
ESR_dB value: -27.06152289458973
MR STFT value: 0.2501034736633301
Dataset initialized with 175 files. Zero-sample probability: 0.00
Losses on 44100 Hz, trained with 48000 Hz


Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.54it/s]

ESR value: 0.002255569212138653
ESR_dB value: -26.467438420361347
MR STFT value: 0.24015912413597107





In [5]:
import torch
from model.mamba import Mamba
from config import HyperParams, ModelParams
from utils import init_hidden

device = "cpu"
trained_sr = 48000  # sample rate the checkpoint was trained at
sr = 44100          # sample rate to run inference at

model = Mamba(ModelParams).to(device)
mamba_weights_path = f"results/S5_mamba_model_{HyperParams.name}.pth"
model.load_state_dict(torch.load(mamba_weights_path, map_location=device))
model.eval()  # inference mode, so the reference output matches a dropout-free C++ forward pass

# Fixed input signal, kept verbatim rather than generated.
# NOTE(review): presumably mirrors the C++ test harness input exactly — keep the two in sync.
# Shape after unsqueeze: (batch, N, 1).
x = torch.tensor([[0.0000000000, 0.3535531759, 0.5000000000, 0.3535540700, 0.0000012676, -0.3535522819, -0.5000000000, -0.3535549641],
                  [0.5000000000, 0.3535558581, 0.0000038028, -0.3535504639, -0.5000000000, -0.3535567522, -0.0000050704, 0.3535495698]], dtype=torch.float32).unsqueeze(-1)
batch_size, N = x.shape[0], x.shape[1]

# One conditioning vector per batch item, repeated over all N timesteps -> (batch, N, 2).
c = torch.tensor([[0.8, -0.6], [0.8, -0.6]], dtype=torch.float32)
c = c.unsqueeze(1).expand(-1, N, -1)

step_rescale = trained_sr / sr
model.change_scale(step_rescale)

with torch.no_grad():
    h1, h2 = init_hidden(ModelParams.n_layers, batch_size, ModelParams.ssm_size, device)
    y, (h1, h2) = model(x, (h1, h2), c)

print("\n8. Sample Output (first 8 timesteps, first batch, can compare to C++ implementation):")
print("-" * 25)
print(f"{'Timestep':<10} {'Parallel':<15}")
print("-" * 25)
for i in range(min(8, N)):
    p = y[0, i, 0].item()
    print(f"{i:<10} {p:<15.6f}")


8. Sample Output (first 8 timesteps, first batch, can compare to C++ implementation):
-------------------------
Timestep   Parallel       
-------------------------
0          0.000072       
1          0.011912       
2          0.030952       
3          0.056277       
4          0.062342       
5          0.054146       
6          0.023015       
7          -0.008738      


In [1]:
"""Get model weights"""
import torch
from model2json import model_2_json
from model.mamba import Mamba
from config import ModelParams, HyperParams

# Rebuild the architecture, restore the trained parameters on CPU, then hand the
# module to model_2_json for serialization.
checkpoint_path = f"results/S5_mamba_model_{HyperParams.name}.pth"
model = Mamba(ModelParams)
trained_state = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(trained_state)
model_2_json(model)

# To embed the JSON file in the plugin project, run:
#   xxd -i model_weights.json > model_weights.h
print("Model weights saved as model_weights.json")

Model weights saved as model_weights.json
