In [1]:
import os 
import numpy as np
import pandas as pd
import librosa
import pyworld
import time
import shutil
import matplotlib.pyplot as plt

from tools import *
from model import *

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
# Filesystem locations, all relative to the notebook's working directory.
data_dir = "../data/NTT_corevo"  # dataset root; data_load reads audio from its "labeled/<NN>" sub-folders
figure_dir = "../figure/NTT_corevo/VAE"  # destination of the loss-curve PNGs written by save_figure
model_dir = "../model/NTT_corevo/VAE"  # directory checkpoints are saved to / loaded from
model_name = "VAE_lr3_e10000_b4"  # base checkpoint file name (appears to encode lr=1e-3, 10000 epochs, batch 4 — confirm)

In [3]:
# Seed the NumPy and PyTorch global random number generators so runs are repeatable.
seed_value = 0
np.random.seed(seed_value)  # drives the np.random.* sampling done in data_load
torch.manual_seed(seed_value)  # returns the torch Generator, which the notebook echoes as cell output

<torch._C.Generator at 0x7f4f85a82530>

In [4]:
# Audio analysis / feature settings consumed by data_load.
sampling_rate = 16000  # sample rate passed to librosa.load and world_decompose
num_mcep = 36  # num_mcep argument of world_decompose; also the feature height of each batch item
frame_period = 5.0  # passed to wav_padding and world_decompose (presumably milliseconds — confirm in tools)
n_frames = 1024  # number of consecutive frames cropped for each training example
label_num = 6  # random labels are drawn from the half-open range [0, label_num)

In [5]:
def _load_normalized_mcep(path):
    """Load one audio file and return its normalised mel-cepstral features.

    The waveform is loaded at ``sampling_rate``, peak-normalised, padded with
    ``wav_padding`` and analysed with ``world_decompose``. The transposed
    feature matrix is repeated along the frame axis until it holds at least
    ``n_frames`` frames, then standardised with its global mean and standard
    deviation.

    Parameters
    ----------
    path : str
        Path of the audio file to read.

    Returns
    -------
    numpy.ndarray
        2-D array whose second axis (frames) has length >= ``n_frames``.
        The first axis is assumed to be ``num_mcep`` — TODO confirm against
        ``world_decompose``.

    Raises
    ------
    ValueError
        If the file yields no frames at all (the previous implementation
        looped forever in that case).
    """
    wav, _ = librosa.load(path, sr=sampling_rate, mono=True)
    wav = librosa.util.normalize(wav, norm=np.inf, axis=None)
    wav = wav_padding(wav=wav, sr=sampling_rate, frame_period=frame_period, multiple=4)
    _f0, _timeaxis, _sp, _ap, mc = world_decompose(
        wav=wav, fs=sampling_rate, frame_period=frame_period, num_mcep=num_mcep
    )

    mc_transposed = np.array(mc).T
    file_frames = np.shape(mc_transposed)[1]
    if file_frames == 0:
        raise ValueError("no frames could be extracted from {!r}".format(path))

    # Short files are repeated end-to-end until they cover n_frames. The file
    # is analysed once and tiled, instead of being re-read for every repeat.
    repeats = -(-n_frames // file_frames)  # ceil(n_frames / file_frames)
    if repeats > 1:
        mc_transposed = np.tile(mc_transposed, (1, repeats))

    mean = np.mean(mc_transposed)
    std = np.std(mc_transposed)
    if std == 0:
        # Constant features: avoid a division by zero that would yield NaN/inf.
        std = 1.0
    return (mc_transposed - mean) / std


def data_load(batch_size = 1, label = -1):
    """Draw a random batch of normalised mel-cepstrum crops.

    For each batch item a class label is chosen (randomly from
    ``[0, label_num)`` when ``label == -1``, otherwise the given ``label``),
    a random file is picked from ``<data_dir>/labeled/<label:02>``, and a
    random window of ``n_frames`` consecutive frames is cut from its
    normalised features.

    Parameters
    ----------
    batch_size : int
        Number of examples to draw.
    label : int
        Class to sample from, or ``-1`` to pick a random class per example.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Float32 features of shape ``(batch_size, 1, num_mcep, n_frames)`` and
        float32 labels of shape ``(batch_size, 1)``. Both live on the CPU.

    Raises
    ------
    ValueError
        If a chosen file yields no frames, or a class directory is empty
        (raised by ``np.random.choice``).
    """
    random_label = (label == -1)

    data_list = []
    label_list = []

    for _ in range(batch_size):
        if random_label:
            label = np.random.randint(0, label_num)

        sample_data_dir = os.path.join(data_dir, "labeled/{:02}".format(label))
        # os.listdir returns entries in arbitrary order; sorting makes the
        # seeded np.random.choice pick the same file on every machine.
        file = np.random.choice(sorted(os.listdir(sample_data_dir)))

        mc_norm = _load_normalized_mcep(os.path.join(sample_data_dir, file))
        frames = np.shape(mc_norm)[1]

        # Random crop of exactly n_frames frames.
        start_ = np.random.randint(frames - n_frames + 1)
        end_ = start_ + n_frames

        data_list.append(mc_norm[:, start_:end_])
        label_list.append(label)

    # Build one contiguous float32 array first: converting a Python list of
    # ndarrays directly with torch.Tensor(...) is a very slow path.
    data = np.asarray(data_list, dtype=np.float32).reshape(batch_size, 1, num_mcep, n_frames)
    labels = torch.tensor(label_list, dtype=torch.float32).view(batch_size, 1)
    return torch.from_numpy(data), labels


In [6]:
def save_figure(losses, epoch):
    """Plot the loss history and write it to ``figure_dir`` as PNG files.

    Two files are written: ``epoch_<epoch:05>.png`` (a per-checkpoint
    snapshot) and ``result.png`` (always overwritten with the latest curve).
    ``figure_dir`` is created if it does not exist.

    Parameters
    ----------
    losses : sequence of float
        Loss value recorded at each training iteration so far.
    epoch : int
        Current iteration number, used only in the snapshot file name.
    """
    os.makedirs(figure_dir, exist_ok=True)

    losses = np.asarray(losses)
    # One x position per recorded loss, numbered from 1 like the training loop.
    steps = np.arange(1, len(losses) + 1)

    fig, ax = plt.subplots()
    try:
        ax.plot(steps, losses, label="vae")
        ax.set_xlabel("iteration")
        ax.set_ylabel("loss")
        ax.legend(bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=0)
        fig.savefig(os.path.join(figure_dir, "epoch_{:05}".format(epoch) + ".png"))
        fig.savefig(os.path.join(figure_dir, "result.png"))
    finally:
        # Release the figure; otherwise one figure per call stays alive for
        # the whole training run.
        plt.close(fig)

In [7]:
def model_save(model, model_dir, model_name):
    """Write ``model``'s state dict to ``<model_dir>/<model_name>``.

    The target directory is created first when it does not exist yet.

    Parameters
    ----------
    model : object exposing ``state_dict()``
        Network whose parameters are serialised.
    model_dir : str
        Directory that receives the checkpoint file.
    model_name : str
        File name of the checkpoint inside ``model_dir``.
    """
    checkpoint_path = os.path.join(model_dir, model_name)
    directory_missing = not os.path.exists(model_dir)
    if directory_missing:
        os.makedirs(model_dir)
    weights = model.state_dict()
    torch.save(weights, checkpoint_path)
    
def model_load():
    """Build a fresh ``VAE`` and load the weights stored at ``<model_dir>/<model_name>``.

    The checkpoint tensors are mapped to the CPU while loading, so a
    checkpoint saved from a CUDA model can also be restored on a machine
    without a GPU. The returned model lives on the CPU; move it with
    ``.to(device)`` if needed.

    Returns
    -------
    VAE
        Model instance with the saved parameters loaded.
    """
    model = VAE()
    checkpoint_path = os.path.join(model_dir, model_name)
    # Without map_location, torch.load tries to restore tensors onto the
    # device they were saved from and fails when CUDA is unavailable.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model

In [8]:
# Training hyperparameters.
learning_rate = 1e-3  # step size handed to the Adam optimizer
num_epoch = 10000  # number of optimisation steps; each "epoch" of the training loop draws one random batch
batch_size = 4  # examples requested from data_load per step

In [9]:
# Train the VAE: every iteration draws one random batch, takes one Adam step
# and records the loss; a loss plot and checkpoints are written periodically.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.train()

losses = []
checkpoint_interval = 500  # iterations between loss plots / checkpoints

# Iterations are numbered from 1 so that checkpoint names and log lines match.
for epoch in range(1, num_epoch + 1):
    x_, label_ = data_load(batch_size)  # label_ is not used by the unsupervised VAE loss
    # data_load returns CPU tensors; put the batch on the same device as the
    # model (a no-op when it is already there).
    x_ = x_.to(device)

    optimizer.zero_grad()
    loss = model.calc_loss(x_)
    loss.backward()
    optimizer.step()

    # Read the scalar once: every .item() call forces a device synchronisation.
    loss_value = loss.item()
    losses.append(loss_value)

    if epoch % checkpoint_interval == 0:
        save_figure(losses, epoch)
        # Keep a numbered snapshot and also refresh the "latest" checkpoint.
        model_save(model, model_dir, model_name + "_{}".format(epoch))
        model_save(model, model_dir, model_name)

    print("Epoch {}  :  Loss  {}". format(epoch, loss_value))

'\ndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")\nprint(device)\n\nmodel = VAE().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n\nmodel.train()\n\nlosses = []\n\nfor epoch in range(num_epoch):\n    epoch += 1\n    \n    x_, label_ = data_load(batch_size)\n    optimizer.zero_grad()\n    loss = model.calc_loss(x_)\n    loss.backward()\n    losses.append(loss.item())\n    optimizer.step()\n    \n    if epoch % 500 == 0:\n        save_figure(losses, epoch)\n        model_save(model, model_dir, model_name + "_{}".format(epoch))\n        model_save(model, model_dir, model_name)\n\n    print("Epoch {}  :  Loss  {}". format(epoch, loss.item()))\n'