In [1]:
import os 
import numpy as np
import pandas as pd
import librosa
import pyworld
import time
import shutil
import matplotlib.pyplot as plt

from tools import *
from model import *

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
data_dir = "../data/NTT_corevo"
figure_dir = "../figure/NTT_corevo/Classifier"
model_dir = "../model/NTT_corevo/Classifier"
model_name = [
    "Classifier_lr3-4_e10000_b4_label0",
    "Classifier_lr3-4_e10000_b4_label1",
    "Classifier_lr3-4_e10000_b4_label2",
    "Classifier_lr3-4_e10000_b4_label3",
    "Classifier_lr3-4_e10000_b4_label4",
    "Classifier_lr3-4_e10000_b4_label5",
    "Classifier_lr3-4_e10000_b4_child",
    "Classifier_lr3-4_e10000_b4_adult",
    "Classifier_lr3-4_e10000_b4_elder",
    "Classifier_lr3-4_e10000_b4_male",
    "Classifier_lr3-4_e10000_b4_female",
    
]
model_dir_vae = "../model/NTT_corevo/VAE"
model_name_vae = "VAE_lr3-4_e10000_b4"

In [3]:
seed_value = 0
np.random.seed(seed_value)
torch.manual_seed(seed_value)

<torch._C.Generator at 0x7f71c9cf88d0>

In [4]:
sampling_rate = 16000
num_mcep = 36
frame_period = 5.0
n_frames = 512
label_num = 6

In [5]:
def data_load(batch_size = 1, label = -1):
    data_list = []
    label_list = []
    
    if (label == -1):
        random_label = True 
    else:
        random_label =  False
        
    for i in range(batch_size):
        
        if random_label :
            label = np.random.randint(0, label_num)
            
        sample_data_dir = os.path.join(data_dir, "labeled/{:02}".format(label))
        file = np.random.choice(os.listdir(sample_data_dir))
        
        frames = 0
        count = 0
        while frames < n_frames:

            wav, _ = librosa.load(os.path.join(sample_data_dir, file), sr = sampling_rate, mono = True)
            wav = librosa.util.normalize(wav, norm=np.inf, axis=None)
            wav = wav_padding(wav = wav, sr = sampling_rate, frame_period = frame_period, multiple = 4)
            f0, timeaxis, sp, ap, mc = world_decompose(wav = wav, fs = sampling_rate, frame_period = frame_period, num_mcep = num_mcep)

            if (count == 0):
                mc_transposed = np.array(mc).T
            else:
                mc_transposed = np.concatenate([mc_transposed, np.array(mc).T], axis =1)
            frames = np.shape(mc_transposed)[1]

            mean = np.mean(mc_transposed)
            std = np.std(mc_transposed)
            mc_norm = (mc_transposed - mean)/std

            count += 1

        start_ = np.random.randint(frames - n_frames + 1)
        end_ = start_ + n_frames
        
        data_list.append(mc_norm[:,start_:end_])
        label_list.append(label)

    return torch.Tensor(data_list).view(batch_size, 1, num_mcep, n_frames), torch.Tensor(label_list).view(batch_size, 1)


In [6]:
def save_figure(losses_list):
    if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
    losses_list = np.array(losses_list)
    plt.figure()
    for i in range(len(losses_list)):
        losses = losses_list[i]
        x = np.linspace(0, len(losses), len(losses))
        plt.plot(x, losses, label="label : {}".format(i))
        plt.legend(bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=0)
    plt.savefig(figure_dir + "/" + "_result.png")

In [7]:
def model_save(model, label):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    torch.save(model.state_dict(), os.path.join(model_dir, model_name[label]))
    
def model_load(label):
    model = Classifier()
    model.load_state_dict(torch.load(os.path.join(model_dir, model_name[label])))
    return model
    
def model_load_VAE():
    model = VAE()
    model.load_state_dict(torch.load(os.path.join(model_dir_vae, model_name_vae)))
    return model

In [8]:
learning_rate = 1e-3
learning_rate_ = 1e-4
num_epoch = 10000
batch_size = 4

num_label = 6

In [None]:
model_vae = model_load_VAE()
model_vae.eval()

VAE(
  (conv1): Conv2d(1, 8, kernel_size=(3, 9), stride=(1, 1), padding=(1, 4))
  (conv1_bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_gated): Conv2d(1, 8, kernel_size=(3, 9), stride=(1, 1), padding=(1, 4))
  (conv1_gated_bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_sigmoid): Sigmoid()
  (conv2): Conv2d(8, 16, kernel_size=(4, 8), stride=(2, 2), padding=(1, 3))
  (conv2_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_gated): Conv2d(8, 16, kernel_size=(4, 8), stride=(2, 2), padding=(1, 3))
  (conv2_gated_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_sigmoid): Sigmoid()
  (conv3): Conv2d(16, 16, kernel_size=(4, 8), stride=(2, 2), padding=(1, 3))
  (conv3_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3_gated): Conv2d(16, 16, kernel_size=(4, 8), stride=(2, 2),

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

losses_list = []

for label in range(num_label):

    model = Classifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    losses = []

    for epoch in range(num_epoch):
        epoch += 1
        lr = learning_rate_ *(1. / num_epoch) * (epoch) + learning_rate * (1. / num_epoch) * (num_epoch - epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
            
        x_, label_ = data_load(batch_size)
        z_ = model_vae.predict(x_)
        
        optimizer.zero_grad()
        loss = model.calc_loss(z_, label, label_)
        loss.backward()
        losses.append(loss.item())
        optimizer.step()

        print("Label {}  :  Epoch {}  :  Loss  {:.05}  :  LR {:.05}". format(label, epoch, loss.item(), lr))

    losses_list.append(losses)
    save_figure(losses_list)
    
    model_save(model, label)


cuda
Label 0  :  Epoch 1  :  Loss  0.71724  :  LR 0.00099991
Label 0  :  Epoch 2  :  Loss  0.65969  :  LR 0.00099982
Label 0  :  Epoch 3  :  Loss  0.66108  :  LR 0.00099973
Label 0  :  Epoch 4  :  Loss  0.64219  :  LR 0.00099964
Label 0  :  Epoch 5  :  Loss  0.68867  :  LR 0.00099955
Label 0  :  Epoch 6  :  Loss  0.83124  :  LR 0.00099946
Label 0  :  Epoch 7  :  Loss  0.70824  :  LR 0.00099937
Label 0  :  Epoch 8  :  Loss  0.67475  :  LR 0.00099928
Label 0  :  Epoch 9  :  Loss  0.68984  :  LR 0.00099919
Label 0  :  Epoch 10  :  Loss  0.63004  :  LR 0.0009991
Label 0  :  Epoch 11  :  Loss  0.71634  :  LR 0.00099901
Label 0  :  Epoch 12  :  Loss  0.63477  :  LR 0.00099892
Label 0  :  Epoch 13  :  Loss  0.63702  :  LR 0.00099883
Label 0  :  Epoch 14  :  Loss  0.65015  :  LR 0.00099874
Label 0  :  Epoch 15  :  Loss  0.96689  :  LR 0.00099865
Label 0  :  Epoch 16  :  Loss  0.82843  :  LR 0.00099856
Label 0  :  Epoch 17  :  Loss  0.62008  :  LR 0.00099847
Label 0  :  Epoch 18  :  Loss  0.853

Label 0  :  Epoch 145  :  Loss  0.40074  :  LR 0.00098695
Label 0  :  Epoch 146  :  Loss  0.3782  :  LR 0.00098686
Label 0  :  Epoch 147  :  Loss  0.53597  :  LR 0.00098677
Label 0  :  Epoch 148  :  Loss  0.89408  :  LR 0.00098668
Label 0  :  Epoch 149  :  Loss  0.39251  :  LR 0.00098659
Label 0  :  Epoch 150  :  Loss  0.61427  :  LR 0.0009865
Label 0  :  Epoch 151  :  Loss  0.384  :  LR 0.00098641
Label 0  :  Epoch 152  :  Loss  0.68266  :  LR 0.00098632
Label 0  :  Epoch 153  :  Loss  1.1504  :  LR 0.00098623
Label 0  :  Epoch 154  :  Loss  0.94472  :  LR 0.00098614
Label 0  :  Epoch 155  :  Loss  0.38117  :  LR 0.00098605
Label 0  :  Epoch 156  :  Loss  0.52416  :  LR 0.00098596
Label 0  :  Epoch 157  :  Loss  0.49876  :  LR 0.00098587
Label 0  :  Epoch 158  :  Loss  0.35558  :  LR 0.00098578
Label 0  :  Epoch 159  :  Loss  0.37203  :  LR 0.00098569
Label 0  :  Epoch 160  :  Loss  0.94648  :  LR 0.0009856
Label 0  :  Epoch 161  :  Loss  0.35661  :  LR 0.00098551
Label 0  :  Epoch 16

## the kind of classifiers
- label (1~6)
- female or male
- chilld or adult or elder

In [None]:
losses_list = []

for episode in range(5):
    
    # 0 : child
    # 1 : adult
    # 2 : elder
    # 3 : male
    # 4 : female

    model = Classifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    losses = []

    for epoch in range(num_epoch):
        epoch += 1
        lr = learning_rate_ *(1. / num_epoch) * (epoch) + learning_rate * (1. / num_epoch) * (num_epoch - epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
            
        x_, label_ = data_load(batch_size)
        z_ = model_vae.predict(x_)
        for i in range(batch_size):
            # 0 : "FE_CH"
            # 1 : "FE_AD"
            # 2 : "FE_EL"
            # 3 : "MA_CH"
            # 4 : "MA_AD"
            # 5 : "MA_EL"
            if episode == 0:
                if label_[i] == 0 or label_[i] == 3:
                    label_[i] = 0
                else:
                    label_[i] = 1
            elif episode == 1:
                if label_[i] == 1 or label_[i] == 4:
                    label_[i] = 0
                else:
                    label_[i] = 1
            elif episode == 2:
                if label_[i] == 2 or label_[i] == 5:
                    label_[i] = 0
                else:
                    label_[i] = 1
            elif episode == 3:
                if label_ [i]== 3 or label_[i] == 4 or label_[i] == 5:
                    label_[i] = 0
                else:
                    label_[i] = 1
            elif episode == 4:
                if label_[i] == 0 or label_[i] == 1 or label_[i] == 2:
                    label_[i] = 0
                else:
                    label_[i] = 1
                    
        optimizer.zero_grad()
        loss = model.calc_loss(z_, 0, label_)
        loss.backward()
        losses.append(loss.item())
        optimizer.step()

        print("Label {}  :  Epoch {}  :  Loss  {:.05}  :  LR {:.05}". format(label, epoch, loss.item(), lr))

    losses_list.append(losses)
    save_figure(losses_list)
    
    model_save(model, episode+6)