In [1]:
import os 
import numpy as np
import pandas as pd
import librosa
import pyworld
import time
import shutil
import matplotlib.pyplot as plt

from tools import *
from model import *

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [None]:
data_dir = "../data/NTT_corevo"
figure_dir = "../figure/NTT_corevo/Classifier"
model_dir = "../model/NTT_corevo/Classifier"
model_name = [
    "Classifier_lr3_e10000_b16_label0",
    "Classifier_lr3_e10000_b16_label1",
    "Classifier_lr3_e10000_b16_label2",
    "Classifier_lr3_e10000_b16_label3",
    "Classifier_lr3_e10000_b16_label4",
    "Classifier_lr3_e10000_b16_label5"
]
model_dir_vae = "../model/NTT_corevo/VAE"
model_name_vae = "VAE_lr3_e10000_b4"

In [3]:
seed_value = 0
np.random.seed(seed_value)
torch.manual_seed(seed_value)

<torch._C.Generator at 0x7f0f77f424f0>

In [4]:
sampling_rate = 16000
num_mcep = 36
frame_period = 5.0
n_frames = 1024
label_num = 6

In [5]:
def data_load(batch_size = 1, label = -1):
    data_list = []
    label_list = []
    
    if (label == -1):
        random_label = True 
    else:
        random_label =  False
        
    for i in range(batch_size):
        
        if random_label :
            label = np.random.randint(0, label_num)
            
        sample_data_dir = os.path.join(data_dir, "labeled/{:02}".format(label))
        file = np.random.choice(os.listdir(sample_data_dir))
        
        frames = 0
        count = 0
        while frames < n_frames:

            wav, _ = librosa.load(os.path.join(sample_data_dir, file), sr = sampling_rate, mono = True)
            wav = librosa.util.normalize(wav, norm=np.inf, axis=None)
            wav = wav_padding(wav = wav, sr = sampling_rate, frame_period = frame_period, multiple = 4)
            f0, timeaxis, sp, ap, mc = world_decompose(wav = wav, fs = sampling_rate, frame_period = frame_period, num_mcep = num_mcep)

            if (count == 0):
                mc_transposed = np.array(mc).T
            else:
                mc_transposed = np.concatenate([mc_transposed, np.array(mc).T], axis =1)
            frames = np.shape(mc_transposed)[1]

            mean = np.mean(mc_transposed)
            std = np.std(mc_transposed)
            mc_norm = (mc_transposed - mean)/std

            count += 1

        start_ = np.random.randint(frames - n_frames + 1)
        end_ = start_ + n_frames
        
        data_list.append(mc_norm[:,start_:end_])
        label_list.append(label)

    return torch.Tensor(data_list).view(batch_size, 1, num_mcep, n_frames), torch.Tensor(label_list).view(batch_size, 1)


In [6]:
def save_figure(losses_list):
    if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
    losses_list = np.array(losses_list)
    plt.figure()
    for i in range(len(losses_list)):
        losses = losses_list[i]
        x = np.linspace(0, len(losses), len(losses))
        plt.plot(x, losses, label="label : {}".format(i))
        plt.legend(bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=0)
    plt.savefig(figure_dir + "/" + "result.png")

In [7]:
def model_save(model, label):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    torch.save(model.state_dict(), os.path.join(model_dir, model_name[label]))
    
def model_load(label):
    model = Classifier()
    model.load_state_dict(torch.load(os.path.join(model_dir, model_name[label])))
    return model
    
def model_load_VAE():
    model = VAE()
    model.load_state_dict(torch.load(os.path.join(model_dir_vae, model_name_vae)))
    return model

In [8]:
learning_rate = 1e-3
learning_rate_ = 1e-4
num_epoch = 1000
batch_size = 16

num_label = 6

In [None]:
model_vae = model_load_VAE()
model_vae.eval()

VAE(
  (conv1): Conv2d(1, 8, kernel_size=(3, 9), stride=(1, 1), padding=(1, 4))
  (conv1_bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_gated): Conv2d(1, 8, kernel_size=(3, 9), stride=(1, 1), padding=(1, 4))
  (conv1_gated_bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_sigmoid): Sigmoid()
  (conv2): Conv2d(8, 16, kernel_size=(4, 8), stride=(2, 2), padding=(1, 3))
  (conv2_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_gated): Conv2d(8, 16, kernel_size=(4, 8), stride=(2, 2), padding=(1, 3))
  (conv2_gated_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_sigmoid): Sigmoid()
  (conv3): Conv2d(16, 16, kernel_size=(4, 8), stride=(2, 2), padding=(1, 3))
  (conv3_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3_gated): Conv2d(16, 16, kernel_size=(4, 8), stride=(2, 2),

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

losses_list = []

for label in range(num_label):

    model = Classifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    losses = []

    for epoch in range(num_epoch):
        epoch += 1
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate_ + learning_rate * (1. / num_epoch) * (num_epoch - epoch)

        x_, label_ = data_load(batch_size)
        z_ = model_vae.predict(x_)
        
        optimizer.zero_grad()
        loss = model.calc_loss(z_, label, label_)
        loss.backward()
        losses.append(loss.item())
        optimizer.step()

        print("Label {}  :  Epoch {}  :  Loss  {}". format(label, epoch, loss.item()))

    losses_list.append(losses)
    save_figure(losses_list)
    
    model_save(model, label)


cuda
Label 0  :  Epoch 1  :  Loss  5.726112365722656
Label 0  :  Epoch 2  :  Loss  5.656563758850098
Label 0  :  Epoch 3  :  Loss  5.455935478210449
Label 0  :  Epoch 4  :  Loss  5.747557640075684
Label 0  :  Epoch 5  :  Loss  5.487583160400391
Label 0  :  Epoch 6  :  Loss  5.47265625
Label 0  :  Epoch 7  :  Loss  5.477867126464844
Label 0  :  Epoch 8  :  Loss  5.433572769165039
Label 0  :  Epoch 9  :  Loss  5.340226650238037
Label 0  :  Epoch 10  :  Loss  5.566501140594482
Label 0  :  Epoch 11  :  Loss  5.3289031982421875
Label 0  :  Epoch 12  :  Loss  5.377022743225098
Label 0  :  Epoch 13  :  Loss  5.424824237823486
Label 0  :  Epoch 14  :  Loss  5.386098384857178
Label 0  :  Epoch 15  :  Loss  5.800251483917236
Label 0  :  Epoch 16  :  Loss  5.531680107116699
Label 0  :  Epoch 17  :  Loss  5.974213123321533
Label 0  :  Epoch 18  :  Loss  5.3442487716674805
Label 0  :  Epoch 19  :  Loss  5.42458963394165
Label 0  :  Epoch 20  :  Loss  5.545693397521973
Label 0  :  Epoch 21  :  Loss 

Label 0  :  Epoch 167  :  Loss  2.6357674598693848
Label 0  :  Epoch 168  :  Loss  2.5490782260894775
Label 0  :  Epoch 169  :  Loss  4.916698455810547
Label 0  :  Epoch 170  :  Loss  4.504040241241455
Label 0  :  Epoch 171  :  Loss  2.508470296859741
Label 0  :  Epoch 172  :  Loss  2.448788642883301
Label 0  :  Epoch 173  :  Loss  4.314709663391113
Label 0  :  Epoch 174  :  Loss  3.5073928833007812
Label 0  :  Epoch 175  :  Loss  4.941216468811035
Label 0  :  Epoch 176  :  Loss  2.4356436729431152
Label 0  :  Epoch 177  :  Loss  2.3745970726013184
Label 0  :  Epoch 178  :  Loss  4.580811500549316
Label 0  :  Epoch 179  :  Loss  2.395322799682617
Label 0  :  Epoch 180  :  Loss  4.769852638244629
Label 0  :  Epoch 181  :  Loss  6.544195652008057
Label 0  :  Epoch 182  :  Loss  4.526068687438965
Label 0  :  Epoch 183  :  Loss  2.3282999992370605
Label 0  :  Epoch 184  :  Loss  4.423346996307373
Label 0  :  Epoch 185  :  Loss  4.839785099029541
Label 0  :  Epoch 186  :  Loss  4.7160229682

Label 0  :  Epoch 330  :  Loss  8.161238670349121
Label 0  :  Epoch 331  :  Loss  1.3957847356796265
Label 0  :  Epoch 332  :  Loss  1.3797165155410767
Label 0  :  Epoch 333  :  Loss  8.073201179504395
Label 0  :  Epoch 334  :  Loss  5.511494159698486
Label 0  :  Epoch 335  :  Loss  4.324766635894775
Label 0  :  Epoch 336  :  Loss  3.6421592235565186
Label 0  :  Epoch 337  :  Loss  3.2816972732543945
Label 0  :  Epoch 338  :  Loss  1.4385432004928589
Label 0  :  Epoch 339  :  Loss  5.1154561042785645
Label 0  :  Epoch 340  :  Loss  7.971086502075195
Label 0  :  Epoch 341  :  Loss  6.504565238952637
Label 0  :  Epoch 342  :  Loss  1.4016271829605103
Label 0  :  Epoch 343  :  Loss  1.3987314701080322
Label 0  :  Epoch 344  :  Loss  4.992334365844727
Label 0  :  Epoch 345  :  Loss  5.412891864776611
Label 0  :  Epoch 346  :  Loss  3.944627046585083
Label 0  :  Epoch 347  :  Loss  6.969985485076904
Label 0  :  Epoch 348  :  Loss  3.544969320297241
Label 0  :  Epoch 349  :  Loss  5.66298198

Label 0  :  Epoch 493  :  Loss  10.356945037841797
Label 0  :  Epoch 494  :  Loss  0.7261342406272888
Label 0  :  Epoch 495  :  Loss  5.35597562789917
Label 0  :  Epoch 496  :  Loss  4.50029182434082
Label 0  :  Epoch 497  :  Loss  4.310047149658203
Label 0  :  Epoch 498  :  Loss  4.867672920227051
Label 0  :  Epoch 499  :  Loss  15.338064193725586
Label 0  :  Epoch 500  :  Loss  0.6720958948135376
Label 0  :  Epoch 501  :  Loss  5.04562520980835
Label 0  :  Epoch 502  :  Loss  5.194514751434326
Label 0  :  Epoch 503  :  Loss  4.296910762786865
Label 0  :  Epoch 504  :  Loss  5.742684841156006
Label 0  :  Epoch 505  :  Loss  5.26560640335083
Label 0  :  Epoch 506  :  Loss  5.352705955505371
Label 0  :  Epoch 507  :  Loss  0.7071807384490967
Label 0  :  Epoch 508  :  Loss  4.37840461730957
Label 0  :  Epoch 509  :  Loss  5.546141147613525
Label 0  :  Epoch 510  :  Loss  5.121591091156006
Label 0  :  Epoch 511  :  Loss  5.674420356750488
Label 0  :  Epoch 512  :  Loss  0.661842942237854


Label 0  :  Epoch 656  :  Loss  6.415811061859131
Label 0  :  Epoch 657  :  Loss  0.3830257058143616
Label 0  :  Epoch 658  :  Loss  0.392039030790329
Label 0  :  Epoch 659  :  Loss  6.373258590698242
Label 0  :  Epoch 660  :  Loss  6.800322532653809
Label 0  :  Epoch 661  :  Loss  6.9670562744140625
Label 0  :  Epoch 662  :  Loss  0.39922213554382324
Label 0  :  Epoch 663  :  Loss  0.38861849904060364
Label 0  :  Epoch 664  :  Loss  0.40824511647224426
Label 0  :  Epoch 665  :  Loss  0.4165562093257904
Label 0  :  Epoch 666  :  Loss  0.3818776607513428
Label 0  :  Epoch 667  :  Loss  5.457451820373535
Label 0  :  Epoch 668  :  Loss  5.948068618774414
Label 0  :  Epoch 669  :  Loss  5.833586692810059
Label 0  :  Epoch 670  :  Loss  0.37054890394210815
Label 0  :  Epoch 671  :  Loss  0.3622702956199646
Label 0  :  Epoch 672  :  Loss  12.747671127319336
Label 0  :  Epoch 673  :  Loss  6.105894088745117
Label 0  :  Epoch 674  :  Loss  6.3774566650390625
Label 0  :  Epoch 675  :  Loss  6.4

Label 0  :  Epoch 818  :  Loss  8.691385269165039
Label 0  :  Epoch 819  :  Loss  6.370518684387207
Label 0  :  Epoch 820  :  Loss  13.865407943725586
Label 0  :  Epoch 821  :  Loss  0.29654213786125183
Label 0  :  Epoch 822  :  Loss  6.257419109344482
Label 0  :  Epoch 823  :  Loss  0.29010042548179626
Label 0  :  Epoch 824  :  Loss  13.855255126953125
Label 0  :  Epoch 825  :  Loss  6.471115589141846
Label 0  :  Epoch 826  :  Loss  7.280145168304443
Label 0  :  Epoch 827  :  Loss  0.27961966395378113
Label 0  :  Epoch 828  :  Loss  0.2846851646900177
Label 0  :  Epoch 829  :  Loss  12.596653938293457
Label 0  :  Epoch 830  :  Loss  7.495057582855225
Label 0  :  Epoch 831  :  Loss  0.28800225257873535
Label 0  :  Epoch 832  :  Loss  0.26821646094322205
Label 0  :  Epoch 833  :  Loss  6.719986438751221
Label 0  :  Epoch 834  :  Loss  6.137880325317383
Label 0  :  Epoch 835  :  Loss  0.2768911123275757
Label 0  :  Epoch 836  :  Loss  6.002377033233643
Label 0  :  Epoch 837  :  Loss  0.3

Label 0  :  Epoch 980  :  Loss  0.2437860667705536
Label 0  :  Epoch 981  :  Loss  14.22875690460205
Label 0  :  Epoch 982  :  Loss  0.23630645871162415
Label 0  :  Epoch 983  :  Loss  0.2491975873708725
Label 0  :  Epoch 984  :  Loss  0.25381800532341003
Label 0  :  Epoch 985  :  Loss  6.0369086265563965
Label 0  :  Epoch 986  :  Loss  7.096216201782227
Label 0  :  Epoch 987  :  Loss  0.2752039134502411
Label 0  :  Epoch 988  :  Loss  0.23426982760429382
Label 0  :  Epoch 989  :  Loss  0.24291613698005676
Label 0  :  Epoch 990  :  Loss  6.7336649894714355
Label 0  :  Epoch 991  :  Loss  0.2654569149017334
Label 0  :  Epoch 992  :  Loss  13.331853866577148
Label 0  :  Epoch 993  :  Loss  13.876363754272461
Label 0  :  Epoch 994  :  Loss  14.177189826965332
Label 0  :  Epoch 995  :  Loss  0.23729640245437622
Label 0  :  Epoch 996  :  Loss  8.23703670501709
Label 0  :  Epoch 997  :  Loss  13.536689758300781
Label 0  :  Epoch 998  :  Loss  0.23415546119213104
Label 0  :  Epoch 999  :  Los

Label 1  :  Epoch 146  :  Loss  2.626923084259033
Label 1  :  Epoch 147  :  Loss  4.483407497406006
Label 1  :  Epoch 148  :  Loss  2.577331781387329
Label 1  :  Epoch 149  :  Loss  4.568768501281738
Label 1  :  Epoch 150  :  Loss  2.5444083213806152
Label 1  :  Epoch 151  :  Loss  2.516195297241211
Label 1  :  Epoch 152  :  Loss  4.452020645141602
Label 1  :  Epoch 153  :  Loss  4.924588203430176
Label 1  :  Epoch 154  :  Loss  2.472895860671997
Label 1  :  Epoch 155  :  Loss  2.414358139038086
Label 1  :  Epoch 156  :  Loss  4.64703369140625
Label 1  :  Epoch 157  :  Loss  6.376899719238281
Label 1  :  Epoch 158  :  Loss  4.670575141906738
Label 1  :  Epoch 159  :  Loss  2.408500909805298
Label 1  :  Epoch 160  :  Loss  2.364827871322632
Label 1  :  Epoch 161  :  Loss  4.830861568450928
Label 1  :  Epoch 162  :  Loss  4.539597511291504
Label 1  :  Epoch 163  :  Loss  2.341322898864746
Label 1  :  Epoch 164  :  Loss  2.306396484375
Label 1  :  Epoch 165  :  Loss  4.344466686248779
Lab

Label 1  :  Epoch 309  :  Loss  1.2064141035079956
Label 1  :  Epoch 310  :  Loss  4.231778621673584
Label 1  :  Epoch 311  :  Loss  5.515091419219971
Label 1  :  Epoch 312  :  Loss  1.2189085483551025
Label 1  :  Epoch 313  :  Loss  4.348423957824707
Label 1  :  Epoch 314  :  Loss  6.120331287384033
Label 1  :  Epoch 315  :  Loss  8.106607437133789
Label 1  :  Epoch 316  :  Loss  1.144545316696167
Label 1  :  Epoch 317  :  Loss  4.638182163238525
Label 1  :  Epoch 318  :  Loss  4.727034568786621
Label 1  :  Epoch 319  :  Loss  1.140649676322937
Label 1  :  Epoch 320  :  Loss  4.436498165130615
Label 1  :  Epoch 321  :  Loss  4.647326946258545
Label 1  :  Epoch 322  :  Loss  1.1699620485305786
Label 1  :  Epoch 323  :  Loss  1.082218050956726
Label 1  :  Epoch 324  :  Loss  4.287473678588867
Label 1  :  Epoch 325  :  Loss  1.1254078149795532
Label 1  :  Epoch 326  :  Loss  1.0305148363113403
Label 1  :  Epoch 327  :  Loss  1.038556456565857
Label 1  :  Epoch 328  :  Loss  8.94701385498

Label 1  :  Epoch 472  :  Loss  0.4243825674057007
Label 1  :  Epoch 473  :  Loss  6.456613063812256
Label 1  :  Epoch 474  :  Loss  11.888083457946777
Label 1  :  Epoch 475  :  Loss  0.4059038758277893
Label 1  :  Epoch 476  :  Loss  6.262466907501221
Label 1  :  Epoch 477  :  Loss  6.561830997467041
Label 1  :  Epoch 478  :  Loss  0.40719348192214966
Label 1  :  Epoch 479  :  Loss  0.4025326371192932
Label 1  :  Epoch 480  :  Loss  12.047168731689453
Label 1  :  Epoch 481  :  Loss  0.3935282230377197
Label 1  :  Epoch 482  :  Loss  6.787137031555176
Label 1  :  Epoch 483  :  Loss  6.586402416229248
Label 1  :  Epoch 484  :  Loss  6.593989849090576
Label 1  :  Epoch 485  :  Loss  12.724185943603516
Label 1  :  Epoch 486  :  Loss  0.3935880661010742
Label 1  :  Epoch 487  :  Loss  0.3935615122318268
Label 1  :  Epoch 488  :  Loss  12.56721305847168
Label 1  :  Epoch 489  :  Loss  6.408730506896973
Label 1  :  Epoch 490  :  Loss  12.50107192993164
Label 1  :  Epoch 491  :  Loss  5.99628

Label 1  :  Epoch 634  :  Loss  0.23744447529315948
Label 1  :  Epoch 635  :  Loss  0.24767519533634186
Label 1  :  Epoch 636  :  Loss  6.938794136047363
Label 1  :  Epoch 637  :  Loss  14.504412651062012
Label 1  :  Epoch 638  :  Loss  7.10319709777832
Label 1  :  Epoch 639  :  Loss  0.23862677812576294
Label 1  :  Epoch 640  :  Loss  14.505327224731445
Label 1  :  Epoch 641  :  Loss  7.192025184631348
Label 1  :  Epoch 642  :  Loss  7.324878215789795
Label 1  :  Epoch 643  :  Loss  0.2368135005235672
Label 1  :  Epoch 644  :  Loss  0.24456025660037994
Label 1  :  Epoch 645  :  Loss  0.22943617403507233
Label 1  :  Epoch 646  :  Loss  0.23483207821846008
Label 1  :  Epoch 647  :  Loss  7.886693000793457
Label 1  :  Epoch 648  :  Loss  7.156430244445801
Label 1  :  Epoch 649  :  Loss  0.23486104607582092
Label 1  :  Epoch 650  :  Loss  7.899627685546875
Label 1  :  Epoch 651  :  Loss  0.2345493584871292
Label 1  :  Epoch 652  :  Loss  0.22556549310684204
Label 1  :  Epoch 653  :  Loss 

Label 1  :  Epoch 795  :  Loss  7.776064395904541
Label 1  :  Epoch 796  :  Loss  0.17535719275474548
Label 1  :  Epoch 797  :  Loss  7.837337493896484
Label 1  :  Epoch 798  :  Loss  23.31252098083496
Label 1  :  Epoch 799  :  Loss  0.17196814715862274
Label 1  :  Epoch 800  :  Loss  0.1806044578552246
Label 1  :  Epoch 801  :  Loss  0.17221936583518982
Label 1  :  Epoch 802  :  Loss  7.511679649353027
Label 1  :  Epoch 803  :  Loss  0.17944841086864471
Label 1  :  Epoch 804  :  Loss  0.17557013034820557
Label 1  :  Epoch 805  :  Loss  0.17953097820281982
Label 1  :  Epoch 806  :  Loss  7.965255260467529
Label 1  :  Epoch 807  :  Loss  0.17466667294502258
Label 1  :  Epoch 808  :  Loss  7.177225112915039
Label 1  :  Epoch 809  :  Loss  15.312322616577148
Label 1  :  Epoch 810  :  Loss  7.406091213226318
Label 1  :  Epoch 811  :  Loss  15.26274299621582
Label 1  :  Epoch 812  :  Loss  0.17904207110404968
Label 1  :  Epoch 813  :  Loss  0.19386908411979675
Label 1  :  Epoch 814  :  Loss

## the kind of classifiers
- label (1~6)
- female or male
- chilld or adult or elder