## Basic imports

In [1]:
import sys 
import os
import numpy as np 
import matplotlib.pyplot as plt
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

print("Python: %s" % sys.version)
print("Pytorch: %s" % torch.__version__)

# determine device to run network on (runs on gpu if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#audioviz
import librosa as libr
import librosa.display as display
import IPython.display

import pandas as pd

Python: 3.6.5 (default, Jun 21 2018, 23:07:39) 
[GCC 5.4.0 20160609]
Pytorch: 0.4.0


## Hyperparameters

In [2]:
n_seconds = 3
n_epochs = 50
sampling_rate = 16000
number_of_mels =128
training_set = ['train-clean-100']
lr = 0.001

## Speech preprocessing
Buidling tensorToMFCC transformation for learning

In [3]:
class tensorToMFCC:
    def __call__(self, y):
#         y = y.numpy()
        dims = y.shape
        y = libr.feature.melspectrogram(np.reshape(y, (dims[1],)), 16000, n_mels=number_of_mels,
                               fmax=8000)
        y = libr.feature.mfcc(S = libr.power_to_db(y))
        y = torch.from_numpy(y)                           
        return y.float()

In [4]:
transform  = tensorToMFCC()

## LibriSpeechDataSet
Load personalized data set, inspred by this [repository](https://github.com/oscarknagg/voicemap/tree/pytorch-python-3.6)

In [8]:
%load_ext autoreload
%autoreload 2
sys.path.insert(0, './../../Utils')
from datasets import LibriSpeechDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
path = '../../../Datasets/data/'
valid_sequence = LibriSpeechDataset(path, training_set, n_seconds, downsampling=1, 
                                    transform = transform, stochastic=False, )

Initialising LibriSpeechDataset with minimum length = 3s and subsets = ['train-clean-100']
Finished indexing data. 27853 usable files found.


In [11]:
train_loader = DataLoader(valid_sequence,
                      batch_size=32,
                      shuffle=True,
                      num_workers=8
                     # pin_memory=True # CUDA only
                     )

In [12]:
recording, speaker  = iter(train_loader).next()

In [13]:
print(recording.shape)
print(valid_sequence.datasetid_to_speaker_id)

torch.Size([32, 20, 94])
{0: 19, 1: 19, 2: 19, 3: 19, 4: 19, 5: 19, 6: 19, 7: 19, 8: 19, 9: 19, 10: 19, 11: 19, 12: 19, 13: 19, 14: 19, 15: 19, 16: 19, 17: 19, 18: 19, 19: 19, 20: 19, 21: 19, 22: 19, 23: 19, 24: 19, 25: 19, 26: 19, 27: 19, 28: 19, 29: 19, 30: 19, 31: 19, 32: 19, 33: 19, 34: 19, 35: 19, 36: 19, 37: 19, 38: 19, 39: 19, 40: 19, 41: 19, 42: 19, 43: 19, 44: 19, 45: 19, 46: 19, 47: 19, 48: 19, 49: 19, 50: 19, 51: 19, 52: 19, 53: 19, 54: 19, 55: 19, 56: 19, 57: 19, 58: 19, 59: 19, 60: 19, 61: 19, 62: 19, 63: 19, 64: 19, 65: 19, 66: 19, 67: 19, 68: 19, 69: 19, 70: 19, 71: 19, 72: 19, 73: 19, 74: 19, 75: 19, 76: 19, 77: 19, 78: 19, 79: 19, 80: 19, 81: 19, 82: 19, 83: 19, 84: 19, 85: 19, 86: 19, 87: 19, 88: 19, 89: 19, 90: 19, 91: 19, 92: 19, 93: 19, 94: 19, 95: 19, 96: 19, 97: 19, 98: 19, 99: 19, 100: 19, 101: 19, 102: 19, 103: 19, 104: 19, 105: 19, 106: 19, 107: 19, 108: 19, 109: 26, 110: 26, 111: 26, 112: 26, 113: 26, 114: 26, 115: 26, 116: 26, 117: 26, 118: 26, 119: 26, 120:

## Cyphercat utilities

In [29]:
sys.path.insert(0,'../cyphercat/Utils/')
from train import *
from metrics import * 
from data_downloaders import * 

## Models

In [117]:
class ConvBlock(nn.Module):
    def __init__(self, n_input, n_out, kernel_size):
        super(ConvBlock, self).__init__()
        self.cnn_block = nn.Sequential(
            nn.Conv1d(n_input, n_out, kernel_size, padding=1),
            nn.BatchNorm1d(n_out),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4)
        )
    
    def forward(self, x):
        return self.cnn_block(x)


class CNN_classifier(nn.Module):
    def __init__(self, in_size, n_hidden, n_classes):
        super(CNN_classifier, self).__init__()
        self.down_path = nn.ModuleList()
        self.down_path.append(ConvBlock(in_size, 2*in_size, 3))
        self.down_path.append(ConvBlock(2*in_size, 4*in_size, 3))
        self.down_path.append(ConvBlock(4*in_size, 8*in_size, 3))
        self.fc = nn.Sequential(
            nn.Linear(8*in_size, n_hidden),
            nn.ReLU()
        )
        self.out = nn.Linear(n_hidden, n_classes)
    def forward(self, x):
        for down in self.down_path:
            x = down(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.out(x)
        

In [118]:
test = ConvBlock(20, 40, 3)
aa = test(recording)
print(aa.shape)

torch.Size([32, 40, 23])


In [119]:
classifier = CNN_classifier(20, 512, 250)
classifier.to(device)

CNN_classifier(
  (down_path): ModuleList(
    (0): ConvBlock(
      (cnn_block): Sequential(
        (0): Conv1d(20, 40, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): ConvBlock(
      (cnn_block): Sequential(
        (0): Conv1d(40, 80, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): ConvBlock(
      (cnn_block): Sequential(
        (0): Conv1d(80, 160, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool1d(kernel_size=4, st

In [120]:
test = classifier(recording.to(device))
print(test.shape)

torch.Size([32, 250])


In [121]:
optimizer = optim.Adam(classifier.parameters(), lr)
criterion = nn.CrossEntropyLoss()

In [123]:
train(classifier, train_loader, train_loader, optimizer, criterion, 4, verbose = True)

[0/4][0/871] loss = 0.606142
[0/4][1/871] loss = 0.833479
[0/4][2/871] loss = 0.870916
[0/4][3/871] loss = 0.556991
[0/4][4/871] loss = 0.677827
[0/4][5/871] loss = 0.879278
[0/4][6/871] loss = 1.123237
[0/4][7/871] loss = 1.146893
[0/4][8/871] loss = 0.469473
[0/4][9/871] loss = 0.706817
[0/4][10/871] loss = 0.670719
[0/4][11/871] loss = 0.624397
[0/4][12/871] loss = 0.965856
[0/4][13/871] loss = 0.546783
[0/4][14/871] loss = 1.136407
[0/4][15/871] loss = 0.628688
[0/4][16/871] loss = 0.764205
[0/4][17/871] loss = 0.821027
[0/4][18/871] loss = 0.553897
[0/4][19/871] loss = 0.972488
[0/4][20/871] loss = 0.857228
[0/4][21/871] loss = 0.763063
[0/4][22/871] loss = 0.444977
[0/4][23/871] loss = 0.800634
[0/4][24/871] loss = 0.752950
[0/4][25/871] loss = 0.470171
[0/4][26/871] loss = 0.918299
[0/4][27/871] loss = 0.767410
[0/4][28/871] loss = 0.622329
[0/4][29/871] loss = 1.141006
[0/4][30/871] loss = 0.387066
[0/4][31/871] loss = 0.722553
[0/4][32/871] loss = 0.604469
[0/4][33/871] loss =

[0/4][272/871] loss = 0.500158
[0/4][273/871] loss = 0.095075
[0/4][274/871] loss = 0.441360
[0/4][275/871] loss = 0.731188
[0/4][276/871] loss = 0.506123
[0/4][277/871] loss = 0.913871
[0/4][278/871] loss = 0.763265
[0/4][279/871] loss = 0.567514
[0/4][280/871] loss = 0.840574
[0/4][281/871] loss = 0.487565
[0/4][282/871] loss = 0.686182
[0/4][283/871] loss = 0.521310
[0/4][284/871] loss = 0.552009
[0/4][285/871] loss = 0.559339
[0/4][286/871] loss = 0.823292
[0/4][287/871] loss = 0.521465
[0/4][288/871] loss = 0.736146
[0/4][289/871] loss = 0.447685
[0/4][290/871] loss = 0.555761
[0/4][291/871] loss = 0.474333
[0/4][292/871] loss = 0.382764
[0/4][293/871] loss = 0.874279
[0/4][294/871] loss = 0.623081
[0/4][295/871] loss = 0.482669
[0/4][296/871] loss = 0.704053
[0/4][297/871] loss = 0.604503
[0/4][298/871] loss = 1.219313
[0/4][299/871] loss = 0.665884
[0/4][300/871] loss = 0.744529
[0/4][301/871] loss = 0.447657
[0/4][302/871] loss = 0.465703
[0/4][303/871] loss = 0.882265
[0/4][30

[0/4][537/871] loss = 0.848495
[0/4][538/871] loss = 0.595367
[0/4][539/871] loss = 0.382260
[0/4][540/871] loss = 0.875940
[0/4][541/871] loss = 1.085059
[0/4][542/871] loss = 0.856184
[0/4][543/871] loss = 0.954956
[0/4][544/871] loss = 0.321076
[0/4][545/871] loss = 0.823223
[0/4][546/871] loss = 0.429673
[0/4][547/871] loss = 0.511656
[0/4][548/871] loss = 0.225589
[0/4][549/871] loss = 0.644203
[0/4][550/871] loss = 0.368731
[0/4][551/871] loss = 0.752738
[0/4][552/871] loss = 0.715618
[0/4][553/871] loss = 0.500563
[0/4][554/871] loss = 0.779583
[0/4][555/871] loss = 0.650141
[0/4][556/871] loss = 0.411600
[0/4][557/871] loss = 0.358216
[0/4][558/871] loss = 0.746149
[0/4][559/871] loss = 0.534022
[0/4][560/871] loss = 0.340913
[0/4][561/871] loss = 1.712141
[0/4][562/871] loss = 0.551709
[0/4][563/871] loss = 0.723158
[0/4][564/871] loss = 0.816231
[0/4][565/871] loss = 0.569625
[0/4][566/871] loss = 0.657191
[0/4][567/871] loss = 0.514804
[0/4][568/871] loss = 0.421452
[0/4][56

[0/4][805/871] loss = 0.373591
[0/4][806/871] loss = 0.648297
[0/4][807/871] loss = 0.944637
[0/4][808/871] loss = 0.569335
[0/4][809/871] loss = 0.417899
[0/4][810/871] loss = 0.377177
[0/4][811/871] loss = 0.325487
[0/4][812/871] loss = 0.702675
[0/4][813/871] loss = 0.513794
[0/4][814/871] loss = 0.896940
[0/4][815/871] loss = 0.404514
[0/4][816/871] loss = 0.354605
[0/4][817/871] loss = 0.546188
[0/4][818/871] loss = 0.444403
[0/4][819/871] loss = 0.654827
[0/4][820/871] loss = 0.366370
[0/4][821/871] loss = 0.568579
[0/4][822/871] loss = 0.404767
[0/4][823/871] loss = 0.737543
[0/4][824/871] loss = 0.740360
[0/4][825/871] loss = 0.420372
[0/4][826/871] loss = 0.472826
[0/4][827/871] loss = 0.483720
[0/4][828/871] loss = 0.518994
[0/4][829/871] loss = 0.575969
[0/4][830/871] loss = 0.626409
[0/4][831/871] loss = 0.644930
[0/4][832/871] loss = 0.243360
[0/4][833/871] loss = 0.454085
[0/4][834/871] loss = 0.557265
[0/4][835/871] loss = 0.434871
[0/4][836/871] loss = 0.290109
[0/4][83

[1/4][201/871] loss = 0.314550
[1/4][202/871] loss = 0.494019
[1/4][203/871] loss = 0.230780
[1/4][204/871] loss = 0.295290
[1/4][205/871] loss = 0.289952
[1/4][206/871] loss = 0.376882
[1/4][207/871] loss = 0.522677
[1/4][208/871] loss = 0.434045
[1/4][209/871] loss = 0.479376
[1/4][210/871] loss = 0.420578
[1/4][211/871] loss = 0.346392
[1/4][212/871] loss = 0.256478
[1/4][213/871] loss = 0.512533
[1/4][214/871] loss = 0.192977
[1/4][215/871] loss = 0.236449
[1/4][216/871] loss = 0.171530
[1/4][217/871] loss = 0.688737
[1/4][218/871] loss = 0.467985
[1/4][219/871] loss = 0.398610
[1/4][220/871] loss = 0.394310
[1/4][221/871] loss = 0.608255
[1/4][222/871] loss = 0.453971
[1/4][223/871] loss = 0.567919
[1/4][224/871] loss = 0.574748
[1/4][225/871] loss = 0.465908
[1/4][226/871] loss = 0.296976
[1/4][227/871] loss = 0.247070
[1/4][228/871] loss = 0.157715
[1/4][229/871] loss = 0.283309
[1/4][230/871] loss = 0.352053
[1/4][231/871] loss = 0.466960
[1/4][232/871] loss = 0.238433
[1/4][23

[1/4][470/871] loss = 0.289295
[1/4][471/871] loss = 0.463939
[1/4][472/871] loss = 0.527691
[1/4][473/871] loss = 0.349960
[1/4][474/871] loss = 0.840998
[1/4][475/871] loss = 0.211262
[1/4][476/871] loss = 0.202573
[1/4][477/871] loss = 0.348726
[1/4][478/871] loss = 0.900509
[1/4][479/871] loss = 0.163034
[1/4][480/871] loss = 0.346225
[1/4][481/871] loss = 0.334680
[1/4][482/871] loss = 0.364963
[1/4][483/871] loss = 0.346305
[1/4][484/871] loss = 0.416859
[1/4][485/871] loss = 0.475186
[1/4][486/871] loss = 0.345425
[1/4][487/871] loss = 0.202526
[1/4][488/871] loss = 0.281622
[1/4][489/871] loss = 0.421336
[1/4][490/871] loss = 0.463623
[1/4][491/871] loss = 0.274433
[1/4][492/871] loss = 0.213852
[1/4][493/871] loss = 0.178241
[1/4][494/871] loss = 0.057315
[1/4][495/871] loss = 0.457868
[1/4][496/871] loss = 0.241155
[1/4][497/871] loss = 0.232324
[1/4][498/871] loss = 0.254899
[1/4][499/871] loss = 0.172568
[1/4][500/871] loss = 0.329944
[1/4][501/871] loss = 0.239781
[1/4][50

[1/4][736/871] loss = 0.295902
[1/4][737/871] loss = 0.265833
[1/4][738/871] loss = 0.284050
[1/4][739/871] loss = 0.447272
[1/4][740/871] loss = 0.563723
[1/4][741/871] loss = 0.607190
[1/4][742/871] loss = 0.520896
[1/4][743/871] loss = 0.579802
[1/4][744/871] loss = 0.215276
[1/4][745/871] loss = 0.176168
[1/4][746/871] loss = 0.221822
[1/4][747/871] loss = 0.332656
[1/4][748/871] loss = 0.182341
[1/4][749/871] loss = 0.308443
[1/4][750/871] loss = 0.382852
[1/4][751/871] loss = 0.229601
[1/4][752/871] loss = 0.208584
[1/4][753/871] loss = 0.322607
[1/4][754/871] loss = 0.343681
[1/4][755/871] loss = 0.229473
[1/4][756/871] loss = 0.522615
[1/4][757/871] loss = 0.416594
[1/4][758/871] loss = 0.252918
[1/4][759/871] loss = 0.333833
[1/4][760/871] loss = 0.509518
[1/4][761/871] loss = 0.376913
[1/4][762/871] loss = 0.333532
[1/4][763/871] loss = 0.138522
[1/4][764/871] loss = 1.096651
[1/4][765/871] loss = 0.337384
[1/4][766/871] loss = 0.138108
[1/4][767/871] loss = 0.495101
[1/4][76

[2/4][133/871] loss = 0.229496
[2/4][134/871] loss = 0.089417
[2/4][135/871] loss = 0.163339
[2/4][136/871] loss = 0.163436
[2/4][137/871] loss = 0.094495
[2/4][138/871] loss = 0.932701
[2/4][139/871] loss = 0.258320
[2/4][140/871] loss = 0.302737
[2/4][141/871] loss = 0.253386
[2/4][142/871] loss = 0.184534
[2/4][143/871] loss = 0.125260
[2/4][144/871] loss = 0.059689
[2/4][145/871] loss = 0.099678
[2/4][146/871] loss = 0.317738
[2/4][147/871] loss = 0.237657
[2/4][148/871] loss = 0.121846
[2/4][149/871] loss = 0.172290
[2/4][150/871] loss = 0.149948
[2/4][151/871] loss = 0.334070
[2/4][152/871] loss = 0.232470
[2/4][153/871] loss = 0.139506
[2/4][154/871] loss = 0.697997
[2/4][155/871] loss = 0.151884
[2/4][156/871] loss = 0.272545
[2/4][157/871] loss = 0.324880
[2/4][158/871] loss = 0.158603
[2/4][159/871] loss = 0.322109
[2/4][160/871] loss = 0.451528
[2/4][161/871] loss = 0.235000
[2/4][162/871] loss = 0.115819
[2/4][163/871] loss = 0.376688
[2/4][164/871] loss = 0.157532
[2/4][16

[2/4][400/871] loss = 0.365092
[2/4][401/871] loss = 0.179471
[2/4][402/871] loss = 0.151550
[2/4][403/871] loss = 0.184824
[2/4][404/871] loss = 0.292933
[2/4][405/871] loss = 0.104059
[2/4][406/871] loss = 0.185899
[2/4][407/871] loss = 0.194509
[2/4][408/871] loss = 0.109068
[2/4][409/871] loss = 0.306228
[2/4][410/871] loss = 0.179180
[2/4][411/871] loss = 0.363191
[2/4][412/871] loss = 0.376004
[2/4][413/871] loss = 0.608052
[2/4][414/871] loss = 0.258195
[2/4][415/871] loss = 0.376156
[2/4][416/871] loss = 0.497192
[2/4][417/871] loss = 0.122754
[2/4][418/871] loss = 0.521612
[2/4][419/871] loss = 0.105423
[2/4][420/871] loss = 0.384308
[2/4][421/871] loss = 0.133348
[2/4][422/871] loss = 0.090323
[2/4][423/871] loss = 0.154178
[2/4][424/871] loss = 0.131603
[2/4][425/871] loss = 0.191704
[2/4][426/871] loss = 0.533507
[2/4][427/871] loss = 0.041117
[2/4][428/871] loss = 0.217448
[2/4][429/871] loss = 0.207607
[2/4][430/871] loss = 0.196674
[2/4][431/871] loss = 0.266562
[2/4][43

[2/4][670/871] loss = 0.472562
[2/4][671/871] loss = 0.143917
[2/4][672/871] loss = 0.334205
[2/4][673/871] loss = 0.562962
[2/4][674/871] loss = 0.118856
[2/4][675/871] loss = 0.242816
[2/4][676/871] loss = 0.326814
[2/4][677/871] loss = 0.248060
[2/4][678/871] loss = 0.116389
[2/4][679/871] loss = 0.215037
[2/4][680/871] loss = 0.225134
[2/4][681/871] loss = 0.120423
[2/4][682/871] loss = 0.141040
[2/4][683/871] loss = 0.567153
[2/4][684/871] loss = 0.282903
[2/4][685/871] loss = 0.454524
[2/4][686/871] loss = 0.196210
[2/4][687/871] loss = 0.380043
[2/4][688/871] loss = 0.221285
[2/4][689/871] loss = 0.205020
[2/4][690/871] loss = 0.231327
[2/4][691/871] loss = 0.114154
[2/4][692/871] loss = 0.211775
[2/4][693/871] loss = 0.306555
[2/4][694/871] loss = 0.297767
[2/4][695/871] loss = 0.276443
[2/4][696/871] loss = 0.214291
[2/4][697/871] loss = 0.426396
[2/4][698/871] loss = 0.294207
[2/4][699/871] loss = 0.233862
[2/4][700/871] loss = 0.233752
[2/4][701/871] loss = 0.030407
[2/4][70

[3/4][64/871] loss = 0.104583
[3/4][65/871] loss = 0.226405
[3/4][66/871] loss = 0.144496
[3/4][67/871] loss = 0.294540
[3/4][68/871] loss = 0.256875
[3/4][69/871] loss = 0.082528
[3/4][70/871] loss = 0.052861
[3/4][71/871] loss = 0.033519
[3/4][72/871] loss = 0.223120
[3/4][73/871] loss = 0.293435
[3/4][74/871] loss = 0.056356
[3/4][75/871] loss = 0.235800
[3/4][76/871] loss = 0.343708
[3/4][77/871] loss = 0.070114
[3/4][78/871] loss = 0.075021
[3/4][79/871] loss = 0.159329
[3/4][80/871] loss = 0.160795
[3/4][81/871] loss = 0.346019
[3/4][82/871] loss = 0.351157
[3/4][83/871] loss = 0.319737
[3/4][84/871] loss = 0.285064
[3/4][85/871] loss = 0.206135
[3/4][86/871] loss = 0.214927
[3/4][87/871] loss = 0.175868
[3/4][88/871] loss = 0.188105
[3/4][89/871] loss = 0.240414
[3/4][90/871] loss = 0.522637
[3/4][91/871] loss = 0.177558
[3/4][92/871] loss = 0.470265
[3/4][93/871] loss = 0.350297
[3/4][94/871] loss = 0.171116
[3/4][95/871] loss = 0.044652
[3/4][96/871] loss = 0.169817
[3/4][97/8

[3/4][335/871] loss = 0.135348
[3/4][336/871] loss = 0.018406
[3/4][337/871] loss = 0.095358
[3/4][338/871] loss = 0.028901
[3/4][339/871] loss = 0.147231
[3/4][340/871] loss = 0.249188
[3/4][341/871] loss = 0.427700
[3/4][342/871] loss = 0.119276
[3/4][343/871] loss = 0.350552
[3/4][344/871] loss = 0.092043
[3/4][345/871] loss = 0.120730
[3/4][346/871] loss = 0.123452
[3/4][347/871] loss = 0.527358
[3/4][348/871] loss = 0.229697
[3/4][349/871] loss = 0.182891
[3/4][350/871] loss = 0.240843
[3/4][351/871] loss = 0.384003
[3/4][352/871] loss = 0.280371
[3/4][353/871] loss = 0.113897
[3/4][354/871] loss = 0.106368
[3/4][355/871] loss = 0.073120
[3/4][356/871] loss = 0.090883
[3/4][357/871] loss = 0.078180
[3/4][358/871] loss = 0.239889
[3/4][359/871] loss = 0.147764
[3/4][360/871] loss = 0.106978
[3/4][361/871] loss = 0.130246
[3/4][362/871] loss = 0.420814
[3/4][363/871] loss = 0.405921
[3/4][364/871] loss = 0.188517
[3/4][365/871] loss = 0.260480
[3/4][366/871] loss = 0.040405
[3/4][36

[3/4][601/871] loss = 0.281836
[3/4][602/871] loss = 0.111690
[3/4][603/871] loss = 0.142812
[3/4][604/871] loss = 0.181896
[3/4][605/871] loss = 0.187117
[3/4][606/871] loss = 0.249034
[3/4][607/871] loss = 0.298755
[3/4][608/871] loss = 0.298142
[3/4][609/871] loss = 0.200690
[3/4][610/871] loss = 0.138644
[3/4][611/871] loss = 0.191181
[3/4][612/871] loss = 0.058599
[3/4][613/871] loss = 0.262951
[3/4][614/871] loss = 0.054302
[3/4][615/871] loss = 0.120169
[3/4][616/871] loss = 0.071738
[3/4][617/871] loss = 0.466623
[3/4][618/871] loss = 0.231511
[3/4][619/871] loss = 0.588600
[3/4][620/871] loss = 0.258341
[3/4][621/871] loss = 0.159001
[3/4][622/871] loss = 0.014511
[3/4][623/871] loss = 0.210985
[3/4][624/871] loss = 0.708283
[3/4][625/871] loss = 0.247127
[3/4][626/871] loss = 0.262272
[3/4][627/871] loss = 0.101579
[3/4][628/871] loss = 0.243825
[3/4][629/871] loss = 0.344876
[3/4][630/871] loss = 0.123957
[3/4][631/871] loss = 0.078010
[3/4][632/871] loss = 0.174808
[3/4][63


Accuracy = 95.71 %


Test:

Accuracy = 95.71 %




## Results
### Set-up
- Audio fetures MFCC
- 5 eposh training
- 3 second recordings
- Adam optimizer
- lr = 0.001
### Performance
- 95.71 accuracu traiing

In [124]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [125]:
save_checkpoint({
            'epoch': 5,
            'arch': 'CNN_voice_classifier',
            'state_dict': classifier.state_dict(),
            'optimizer' : optimizer.state_dict(),
        }, False, filename = 'model_weights/CNN_voice_classifier_5.pth')