In [1]:
from model.model import SpecVAE
from data_loader.data_loaders import EscDataLoader
from dataset import transformers
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F
from torch.optim import Adam

In [100]:
input_size = [64, 15]
latent_dim = 32
is_featExtract = False
n_convLayer = 3
n_convChannel = [32, 16, 8]
filter_size = [1, 3, 3]
stride = [1, 2, 2]
n_fcLayer = 1
n_fcChannel = [256]
activation = "leaky_relu"

model = SpecVAE(input_size=input_size, latent_dim=latent_dim, is_featExtract=is_featExtract, n_convLayer=n_convLayer,
                n_convChannel=n_convChannel, filter_size=filter_size, stride=stride, n_fcLayer=n_fcLayer,
                n_fcChannel=n_fcChannel, activation=activation)

data_dir = "~/data/esc/esc10-spectro1/data"
path_to_meta = "~/data/esc/meta/esc10.csv"
batch_size = 32
shuffle = True
validation_split = 0
num_workers = 0
folds = [1, 2, 3, 4]
samples = None

dataloader = EscDataLoader(data_dir=data_dir, path_to_meta=path_to_meta, batch_size=batch_size,
                           shuffle=shuffle, validation_split=validation_split, num_workers=num_workers,
                           folds=folds, samples=samples)
idx, label, data = next(iter(dataloader))

In [101]:
x = data.reshape(-1, data.size(2), data.size(3))[:2]
#model.encoder(x).size(), model.flat_size
mu, logvar, z = model.encode(x)
x_recon = model.decode(z)
z.size(), x_recon.size()

(torch.Size([2, 32]), torch.Size([2, 64, 15]))

In [102]:
dim_to_sum = list(range(1, len(x.size())))
loss = F.mse_loss(x, x_recon, reduction="none").sum(dim=dim_to_sum)
loss, loss.size()

(tensor([3521315.5000, 2430733.0000], grad_fn=<SumBackward1>), torch.Size([2]))

In [103]:
n_iter = 1000

optimizer = Adam(model.parameters(), lr=0.1)
for i in range(n_iter):
    optimizer.zero_grad()
    x_recon, *_ = model(x)
    loss = F.mse_loss(x, x_recon, reduction="none").sum(dim=dim_to_sum)
    loss = loss.mean()
    loss.backward()
    optimizer.step()
    if not i % 100:
        print(f"iter: {i}, loss: {loss.item():.2f}")

iter: 0, loss: 2975407.50
iter: 100, loss: 854617.19
iter: 200, loss: 123905.02
iter: 300, loss: 88247.94
iter: 400, loss: 56548.87
iter: 500, loss: 48453.43
iter: 600, loss: 30070.66
iter: 700, loss: 21510.47
iter: 800, loss: 18699.83
iter: 900, loss: 30537.47


In [116]:
arr = transformers.LoadNumpyAry()("/Users/Dovermore/data/esc/esc10-spectro1/data/1-4211-A-12.npy")
x = transformers.SpecChunking(duration=2.5, sr=22050, hop_size=735, reverse=False)(arr)

In [117]:
arr.shape, x.shape

((64, 151), (2, 64, 75))