<a href="https://colab.research.google.com/github/AyishaR/Spokendigit/blob/master/Spokendigit_Five_features.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
from torch.utils.data import Dataset, random_split, DataLoader, TensorDataset
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import tarfile
import os
import librosa
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import librosa.display
import sklearn
import matplotlib
import csv
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score

#Data

In [4]:
digit = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']

#Dataset 



Spoken_mnist_finalfts.csv is generated in 'Spokenmnist feature extraction.ipynb'.

In [9]:
finalfts = pd.read_csv("Spoken_digit_finalfts.csv")

In [6]:
from sklearn.preprocessing import StandardScaler
scale = StandardScaler()
finalfts[finalfts.columns[1:]] = scale.fit_transform(finalfts[finalfts.columns[1:]])

In [10]:
spokendset = TensorDataset(torch.tensor(np.array(finalfts[finalfts.columns[1:]].astype('float32'))),torch.tensor(finalfts['Label'])) 

In [11]:
# 90-10 split
size = len(spokendset)
val_size = int(0.1 * size)
train_size = size - val_size 

train_dset, val_dset = random_split(spokendset, [train_size, val_size])

train_size, val_size

(21298, 2366)

In [12]:
train_dl = DataLoader(train_dset, 512, True)
val_dl = DataLoader(val_dset, 512)

# Device


In [13]:
def get_default_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

In [14]:
device = get_default_device()
device

device(type='cpu')

In [15]:
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)

#Train

In [16]:
class SpokenDigitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(173, 1024)
        self.l2 = nn.Linear(1024, 512)
        self.l3 = nn.Linear(512, 64)
        self.l4 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = self.l4(x)
        return x

    def training_step(self, batch):
        inputs, labels = batch
        outputs = self(inputs)
        loss = F.cross_entropy(outputs, labels)
        return loss

    def validation_step(self, batch):
        inputs, labels = batch
        outputs = self(inputs)
        loss = F.cross_entropy(outputs, labels)
        _, pred = torch.max(outputs, 1)
        accuracy = torch.tensor(torch.sum(pred==labels).item()/len(pred))
        return [loss.detach(), accuracy.detach()] 

In [17]:
def evaluate(model, loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in tqdm(loader)]
    outputs = torch.tensor(outputs).T
    loss, accuracy = torch.mean(outputs, dim=1)
    return {"loss" : loss.item(), "accuracy" : accuracy.item()}

In [18]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [19]:
def fit(model, train_loader, val_loader, epochs, lr, optimizer_function = torch.optim.Adam):
    history = []
    optimizer = optimizer_function(model.parameters(), lr)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=epochs, steps_per_epoch=len(train_loader))
    for epoch in range(epochs):
        print("Epoch ", epoch)
        #Train
        model.train()
        lrs = []
        tr_loss = []
        for batch in tqdm(train_loader):
            loss = model.training_step(batch)
            tr_loss.append(loss)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            lrs.append(get_lr(optimizer))
            sched.step()
        #Validate
        result = evaluate(model, val_loader)
        result["lrs"] = lrs
        result["train loss"] = torch.stack(tr_loss).mean().item()
 
        print("Last lr: ", lrs[-1]," Train_loss: ", result["train loss"], " Val_loss: ", result['loss'], " Accuracy: ", result['accuracy'])
        history.append(result)         
    return history

In [20]:
model = to_device(SpokenDigitModel(), device)
history = []
evaluate(model, val_dl)

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




{'accuracy': 0.11331318318843842, 'loss': 2.9422481060028076}

In [22]:
history.append(fit(model, train_dl, val_dl, 64, 0.01))

Epoch  0


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00046125331558460707  Train_loss:  2.0962071418762207  Val_loss:  1.7959665060043335  Accuracy:  0.3839131295681
Epoch  1


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.000649371685235026  Train_loss:  1.6721926927566528  Val_loss:  1.5381710529327393  Accuracy:  0.4640895426273346
Epoch  2


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0009593538973555506  Train_loss:  1.477335810661316  Val_loss:  1.4922049045562744  Accuracy:  0.4857262074947357
Epoch  3


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0013828987634391283  Train_loss:  1.4045330286026  Val_loss:  1.3956390619277954  Accuracy:  0.5016951560974121
Epoch  4


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0019086639366120376  Train_loss:  1.307019591331482  Val_loss:  1.325338363647461  Accuracy:  0.5412465333938599
Epoch  5


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0025225696547385964  Train_loss:  1.2461926937103271  Val_loss:  1.2245386838912964  Accuracy:  0.5874459147453308
Epoch  6


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0032081757903132157  Train_loss:  1.1579763889312744  Val_loss:  1.1713407039642334  Accuracy:  0.6044811010360718
Epoch  7


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.003947122109908103  Train_loss:  1.0937646627426147  Val_loss:  1.3624060153961182  Accuracy:  0.5438949465751648
Epoch  8


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.004719619953220507  Train_loss:  1.0818008184432983  Val_loss:  1.1536433696746826  Accuracy:  0.6204500794410706
Epoch  9


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0055049821647689574  Train_loss:  1.045858383178711  Val_loss:  1.120049238204956  Accuracy:  0.6202707290649414
Epoch  10


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.006282177086898712  Train_loss:  1.0241625308990479  Val_loss:  1.11703360080719  Accuracy:  0.6232703924179077
Epoch  11


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0070303917784052174  Train_loss:  1.0066732168197632  Val_loss:  1.1656148433685303  Accuracy:  0.6003095507621765
Epoch  12


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.007729589376026378  Train_loss:  0.9818165898323059  Val_loss:  1.0709947347640991  Accuracy:  0.6334463357925415
Epoch  13


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.008361045672905848  Train_loss:  0.9537393450737  Val_loss:  1.1668928861618042  Accuracy:  0.5946098566055298
Epoch  14


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00890785054469004  Train_loss:  0.9436930418014526  Val_loss:  0.9845379590988159  Accuracy:  0.6786851286888123
Epoch  15


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009355360795286486  Train_loss:  0.8967756628990173  Val_loss:  1.148621916770935  Accuracy:  0.6161729097366333
Epoch  16


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009691592295271681  Train_loss:  0.9657419323921204  Val_loss:  1.1835858821868896  Accuracy:  0.6138954162597656
Epoch  17


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009907540911652659  Train_loss:  0.9208858013153076  Val_loss:  1.005327820777893  Accuracy:  0.6741303205490112
Epoch  18


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009997423634623305  Train_loss:  0.9113320708274841  Val_loss:  1.1275980472564697  Accuracy:  0.6135711669921875
Epoch  19


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009992134106552784  Train_loss:  0.8320357799530029  Val_loss:  0.9701324701309204  Accuracy:  0.6864116787910461
Epoch  20


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009960221326213456  Train_loss:  0.7584152817726135  Val_loss:  1.0560153722763062  Accuracy:  0.6812278628349304
Epoch  21


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009903926786310545  Train_loss:  0.7687380909919739  Val_loss:  0.9775892496109009  Accuracy:  0.6935485601425171
Epoch  22


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009823527201405897  Train_loss:  0.7155079245567322  Val_loss:  1.0294524431228638  Accuracy:  0.6616500020027161
Epoch  23


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009719417773875232  Train_loss:  0.7189133167266846  Val_loss:  1.0722204446792603  Accuracy:  0.6818960905075073
Epoch  24


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009592110251299626  Train_loss:  0.6910059452056885  Val_loss:  0.9910828471183777  Accuracy:  0.7036851644515991
Epoch  25


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009442230410981642  Train_loss:  0.6466673612594604  Val_loss:  0.9722771644592285  Accuracy:  0.706240177154541
Epoch  26


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009270514983950873  Train_loss:  0.6433963179588318  Val_loss:  1.0949640274047852  Accuracy:  0.691853404045105
Epoch  27


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.009077808033578922  Train_loss:  0.5948503613471985  Val_loss:  1.0835723876953125  Accuracy:  0.6901459097862244
Epoch  28


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.008865056806604618  Train_loss:  0.5879800915718079  Val_loss:  1.0072591304779053  Accuracy:  0.7080212831497192
Epoch  29


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.008633307076963694  Train_loss:  0.5506159663200378  Val_loss:  1.149868130683899  Accuracy:  0.6802279949188232
Epoch  30


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00838369800531021  Train_loss:  0.5240427851676941  Val_loss:  1.2850494384765625  Accuracy:  0.6902319192886353
Epoch  31


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00811745653949763  Train_loss:  0.5351822972297668  Val_loss:  1.0928733348846436  Accuracy:  0.702400267124176
Epoch  32


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.007835891383543871  Train_loss:  0.5026690363883972  Val_loss:  1.312723159790039  Accuracy:  0.694044828414917
Epoch  33


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.007540386564725742  Train_loss:  0.4993046224117279  Val_loss:  1.2389271259307861  Accuracy:  0.7184158563613892
Epoch  34


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.007232394630423575  Train_loss:  0.42036834359169006  Val_loss:  1.286938190460205  Accuracy:  0.7246192097663879
Epoch  35


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.006913429508156801  Train_loss:  0.3983043432235718  Val_loss:  1.347787618637085  Accuracy:  0.719978392124176
Epoch  36


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.006585059063906823  Train_loss:  0.38067197799682617  Val_loss:  1.080854892730713  Accuracy:  0.7255527377128601
Epoch  37


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.006248897395306571  Train_loss:  0.35344985127449036  Val_loss:  1.2585722208023071  Accuracy:  0.729240357875824
Epoch  38


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0059065968975793946  Train_loss:  0.3225083351135254  Val_loss:  1.268265962600708  Accuracy:  0.723447322845459
Epoch  39


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.005559840141227017  Train_loss:  0.3173750340938568  Val_loss:  1.3798859119415283  Accuracy:  0.7174626588821411
Epoch  40


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.005210331601391554  Train_loss:  0.3014935851097107  Val_loss:  1.3726149797439575  Accuracy:  0.7314785122871399
Epoch  41


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00485978927954578  Train_loss:  0.283986359834671  Val_loss:  1.4610297679901123  Accuracy:  0.7263143658638
Epoch  42


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.004509936258695003  Train_loss:  0.2430132031440735  Val_loss:  1.469723105430603  Accuracy:  0.739548921585083
Epoch  43


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.004162492233600785  Train_loss:  0.2118765115737915  Val_loss:  1.5703235864639282  Accuracy:  0.7378144860267639
Epoch  44


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0038191650576594933  Train_loss:  0.1854405403137207  Val_loss:  1.6355253458023071  Accuracy:  0.7459709048271179
Epoch  45


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.003481642347986829  Train_loss:  0.16453081369400024  Val_loss:  1.7091223001480103  Accuracy:  0.7364902496337891
Epoch  46


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0031515831899733408  Train_loss:  0.1514352560043335  Val_loss:  1.7436630725860596  Accuracy:  0.744560718536377
Epoch  47


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0028306099820869924  Train_loss:  0.13385431468486786  Val_loss:  1.7914905548095703  Accuracy:  0.7359277009963989
Epoch  48


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0025203004610094945  Train_loss:  0.1268480122089386  Val_loss:  1.821528434753418  Accuracy:  0.7440841197967529
Epoch  49


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.002222179946306651  Train_loss:  0.10821071267127991  Val_loss:  1.8669211864471436  Accuracy:  0.7497912049293518
Epoch  50


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0019377138427538756  Train_loss:  0.09666414558887482  Val_loss:  1.9420559406280518  Accuracy:  0.7512873411178589
Epoch  51


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0016683004371715396  Train_loss:  0.0888422355055809  Val_loss:  1.9657294750213623  Accuracy:  0.7541740536689758
Epoch  52


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0014152640251771198  Train_loss:  0.07715751230716705  Val_loss:  2.0424296855926514  Accuracy:  0.7541740536689758
Epoch  53


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.001179848401639479  Train_loss:  0.070724718272686  Val_loss:  2.082763195037842  Accuracy:  0.7553655505180359
Epoch  54


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.000963210746832791  Train_loss:  0.06397011876106262  Val_loss:  2.1146492958068848  Accuracy:  0.7591858506202698
Epoch  55


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0007664159383425639  Train_loss:  0.059719499200582504  Val_loss:  2.153053045272827  Accuracy:  0.7542600035667419
Epoch  56


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0005904313166835123  Train_loss:  0.0573914498090744  Val_loss:  2.1731905937194824  Accuracy:  0.7544123530387878
Epoch  57


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0004361219303587067  Train_loss:  0.05467969924211502  Val_loss:  2.189188003540039  Accuracy:  0.7542600035667419
Epoch  58


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.0003042462837328597  Train_loss:  0.052888400852680206  Val_loss:  2.199924945831299  Accuracy:  0.7527638673782349
Epoch  59


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00019545260862094147  Train_loss:  0.05168016627430916  Val_loss:  2.2133004665374756  Accuracy:  0.7537834048271179
Epoch  60


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  0.00011027567791908655  Train_loss:  0.05067480728030205  Val_loss:  2.2165989875793457  Accuracy:  0.7553459405899048
Epoch  61


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  4.913417694027851e-05  Train_loss:  0.05024070292711258  Val_loss:  2.2178587913513184  Accuracy:  0.7569084167480469
Epoch  62


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  1.232864537599174e-05  Train_loss:  0.05011792480945587  Val_loss:  2.2188334465026855  Accuracy:  0.7554982304573059
Epoch  63


HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Last lr:  4e-08  Train_loss:  0.04999082535505295  Val_loss:  2.2191309928894043  Accuracy:  0.7554982304573059


In [23]:
@torch.no_grad()
def predict_dl(model, dl):
    torch.cuda.empty_cache()
    batch_probs = []
    batch_targ = []
    for xb, yb in tqdm(dl):
        probs = model(xb)
        batch_probs.append(probs.cpu().detach())
        batch_targ.append(yb.cpu().detach())
    batch_probs = torch.cat(batch_probs)
    batch_targ = torch.cat(batch_targ)
    return [list(values).index(max(values)) for values in batch_probs], batch_targ

In [26]:
r = evaluate(model, val_dl)
yp, yt = predict_dl(model, val_dl)
print("Loss: ", r['loss'], "\nAccuracy: ", r['accuracy'], "\nF-score: ", f1_score(yt, yp, average='micro'))

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Loss:  2.2191309928894043 
Accuracy:  0.7554982304573059 
F-score:  0.7540152155536771


In [27]:
torch.save(model, '/content/spokendigit_lr_all.pth')

  "type " + obj.__name__ + ". It won't be checked "
