In [None]:
from __future__ import print_function
import argparse
import numpy  as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms

from data_loaders import Plain_Dataset, eval_data_dataloader
from deep_emotion import Deep_Emotion
from generate_data import Generate_data

# This notebook trains on the GPU only: fail fast before building anything
# that assumes a CUDA device. RuntimeError is a subclass of Exception, so
# existing `except Exception` handlers still catch it.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Make sure you have a CUDA-enabled GPU.")

device = torch.device("cuda")

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns ``(loss_sum, correct, seen)``: the loss summed over every sample
    seen, the number of correct argmax predictions, and the sample count.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    # Only build autograd graphs when we will backpropagate; validation
    # otherwise wastes memory and time recording graphs it never uses.
    with torch.set_grad_enabled(training):
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # The criterion returns a batch *mean* (nn.CrossEntropyLoss default);
            # weight it by the batch size so dividing by the sample count later
            # gives a true per-sample mean, also when the last batch is short.
            loss_sum += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            seen += batch_size
    return loss_sum, correct, seen


def Train(epochs,train_loader,val_loader,criterion,optmizer,device,model=None):
    '''
    Training loop: train for ``epochs`` epochs, validate after each epoch,
    then save the final weights.

    Args:
        epochs: number of full passes over ``train_loader``.
        train_loader: DataLoader yielding ``(data, labels)`` training batches.
        val_loader: DataLoader yielding ``(data, labels)`` validation batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            return the batch-mean loss (nn.CrossEntropyLoss default) -- TODO
            confirm if another criterion/reduction is used.
        optmizer: optimizer that steps the model's parameters.
        device: device every batch is moved to.
        model: network to train; defaults to the module-level ``net`` so
            existing callers keep working.

    Side effects:
        Prints one summary line per epoch and writes the final
        ``state_dict`` to ``deep_emotion-{epochs}-{batch_size}-{lr}.pt`` in the
        current working directory.
    '''
    if model is None:
        model = net  # backward compatibility: the original read the global `net`

    print("===================================Start Training===================================")
    for e in range(epochs):
        train_sum, train_correct, n_train = _run_epoch(model, train_loader, criterion, device, optmizer)
        val_sum, val_correct, n_val = _run_epoch(model, val_loader, criterion, device)

        # Normalise by the samples actually seen; max(..., 1) guards empty loaders.
        train_loss = train_sum / max(n_train, 1)
        train_acc = train_correct / max(n_train, 1)
        validation_loss = val_sum / max(n_val, 1)
        val_acc = val_correct / max(n_val, 1)
        print('Epoch: {} \tTraining Loss: {:.8f} \tValidation Loss {:.8f} \tTraining Acuuarcy {:.3f}% \tValidation Acuuarcy {:.3f}%'
              .format(e+1, train_loss, validation_loss, train_acc * 100, val_acc * 100))

    # Batch size and learning rate come from the objects actually used for
    # training rather than from module-level globals.
    lr = optmizer.defaults['lr']
    torch.save(model.state_dict(), 'deep_emotion-{}-{}-{}.pt'.format(epochs, train_loader.batch_size, lr))
    print("===================================Training Finished===================================")


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(description="Configuration of setup and training process")
#     parser.add_argument('-s', '--setup', type=bool, help='setup the dataset for the first time')
#     parser.add_argument('-d', '--data', type=str,required= True,
#                                help='data folder that contains data files that downloaded from kaggle (train.csv and test.csv)')
#     parser.add_argument('-hparams', '--hyperparams', type=bool,
#                                help='True when changing the hyperparameters e.g (batch size, LR, num. of epochs)')
#     parser.add_argument('-e', '--epochs', type= int, help= 'number of epochs')
#     parser.add_argument('-lr', '--learning_rate', type= float, help= 'value of learning rate')
#     parser.add_argument('-bs', '--batch_size', type= int, help= 'training/validation batch size')
#     parser.add_argument('-t', '--train', type=bool, help='True when training')
#     args = parser.parse_args()

#     if args.setup :
#         generate_dataset = Generate_data(args.data)
#         generate_dataset.split_test()
#         generate_dataset.save_images('train')
#         generate_dataset.save_images('test')
#         generate_dataset.save_images('val')

#     if args.hyperparams:
#         epochs = args.epochs
#         lr = args.learning_rate
#         batchsize = args.batch_size
#     else :
# Hyperparameters. The argparse override above is commented out, so these
# defaults are always used.
# NOTE(review): in the logged run below, validation accuracy stops improving
# within the first few epochs and then drifts down while training accuracy
# keeps rising, so most of these 3000 epochs overfit -- consider far fewer
# epochs or early stopping on validation loss.
epochs = 3000
lr = 0.001
batchsize = 128

# Model
net = Deep_Emotion()
net.to(device)
print("Model archticture: ", net)

# Data: CSV files and image folders under ./data
data_dir = 'data'
traincsv_file = data_dir + '/' + 'train.csv'
validationcsv_file = data_dir + '/' + 'val.csv'
train_img_dir = data_dir + '/' + 'train/'
validation_img_dir = data_dir + '/' + 'val/'

# Single-channel normalisation: (x - 0.5) / 0.5 maps ToTensor's [0, 1]
# output (for 8-bit images) to [-1, 1].
transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = Plain_Dataset(csv_file=traincsv_file, img_dir=train_img_dir, datatype='train', transform=transformation)
validation_dataset = Plain_Dataset(csv_file=validationcsv_file, img_dir=validation_img_dir, datatype='val', transform=transformation)
train_loader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=0)
# Shuffling the validation set changes nothing but float summation order, so
# keep a fixed order to make per-epoch validation numbers reproducible.
val_loader = DataLoader(validation_dataset, batch_size=batchsize, shuffle=False, num_workers=0)

criterion = nn.CrossEntropyLoss()
optmizer = optim.Adam(net.parameters(), lr=lr)
Train(epochs, train_loader, val_loader, criterion, optmizer, device)

Model archticture:  Deep_Emotion(
  (conv1): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (norm): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=810, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=7, bias=True)
  (localization): Sequential(
    (0): Conv2d(1, 8, kernel_size=(7, 7), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(8, 10, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)




Epoch: 1 	Training Loss: 0.01434239 	Validation Loss 0.01477167 	Training Acuuarcy 23.693% 	Validation Acuuarcy 24.854%
Epoch: 2 	Training Loss: 0.01426444 	Validation Loss 0.01477285 	Training Acuuarcy 24.813% 	Validation Acuuarcy 24.742%
Epoch: 3 	Training Loss: 0.01423561 	Validation Loss 0.01485570 	Training Acuuarcy 24.914% 	Validation Acuuarcy 24.909%
Epoch: 4 	Training Loss: 0.01422313 	Validation Loss 0.01484282 	Training Acuuarcy 24.964% 	Validation Acuuarcy 25.132%
Epoch: 5 	Training Loss: 0.01422000 	Validation Loss 0.01476074 	Training Acuuarcy 24.914% 	Validation Acuuarcy 25.021%
Epoch: 6 	Training Loss: 0.01420103 	Validation Loss 0.01485341 	Training Acuuarcy 25.008% 	Validation Acuuarcy 25.160%
Epoch: 7 	Training Loss: 0.01419886 	Validation Loss 0.01477782 	Training Acuuarcy 25.064% 	Validation Acuuarcy 24.854%
Epoch: 8 	Training Loss: 0.01419717 	Validation Loss 0.01482774 	Training Acuuarcy 25.075% 	Validation Acuuarcy 25.160%
Epoch: 9 	Training Loss: 0.01418847 	Val

Epoch: 69 	Training Loss: 0.01288929 	Validation Loss 0.01615039 	Training Acuuarcy 33.709% 	Validation Acuuarcy 21.092%
Epoch: 70 	Training Loss: 0.01284300 	Validation Loss 0.01616588 	Training Acuuarcy 33.753% 	Validation Acuuarcy 19.588%
Epoch: 71 	Training Loss: 0.01283347 	Validation Loss 0.01630144 	Training Acuuarcy 33.959% 	Validation Acuuarcy 20.201%
Epoch: 72 	Training Loss: 0.01283085 	Validation Loss 0.01619857 	Training Acuuarcy 34.049% 	Validation Acuuarcy 19.588%
Epoch: 73 	Training Loss: 0.01278778 	Validation Loss 0.01646622 	Training Acuuarcy 34.361% 	Validation Acuuarcy 19.643%
Epoch: 74 	Training Loss: 0.01275657 	Validation Loss 0.01623159 	Training Acuuarcy 34.233% 	Validation Acuuarcy 19.560%
Epoch: 75 	Training Loss: 0.01278171 	Validation Loss 0.01623350 	Training Acuuarcy 34.054% 	Validation Acuuarcy 20.061%
Epoch: 76 	Training Loss: 0.01277722 	Validation Loss 0.01627964 	Training Acuuarcy 34.188% 	Validation Acuuarcy 20.869%
Epoch: 77 	Training Loss: 0.0127

Epoch: 137 	Training Loss: 0.01189171 	Validation Loss 0.01777796 	Training Acuuarcy 40.241% 	Validation Acuuarcy 19.058%
Epoch: 138 	Training Loss: 0.01194806 	Validation Loss 0.01760658 	Training Acuuarcy 39.851% 	Validation Acuuarcy 19.309%
Epoch: 139 	Training Loss: 0.01195893 	Validation Loss 0.01770244 	Training Acuuarcy 39.410% 	Validation Acuuarcy 19.030%
Epoch: 140 	Training Loss: 0.01189453 	Validation Loss 0.01755328 	Training Acuuarcy 40.269% 	Validation Acuuarcy 19.058%
Epoch: 141 	Training Loss: 0.01192066 	Validation Loss 0.01743640 	Training Acuuarcy 40.007% 	Validation Acuuarcy 19.699%
Epoch: 142 	Training Loss: 0.01193663 	Validation Loss 0.01788429 	Training Acuuarcy 40.018% 	Validation Acuuarcy 18.417%
Epoch: 143 	Training Loss: 0.01188673 	Validation Loss 0.01768414 	Training Acuuarcy 40.670% 	Validation Acuuarcy 19.225%
Epoch: 144 	Training Loss: 0.01190130 	Validation Loss 0.01765370 	Training Acuuarcy 40.046% 	Validation Acuuarcy 19.365%
Epoch: 145 	Training Los

Epoch: 205 	Training Loss: 0.01157606 	Validation Loss 0.01832162 	Training Acuuarcy 42.855% 	Validation Acuuarcy 19.337%
Epoch: 206 	Training Loss: 0.01157484 	Validation Loss 0.01790762 	Training Acuuarcy 42.403% 	Validation Acuuarcy 19.420%
Epoch: 207 	Training Loss: 0.01159391 	Validation Loss 0.01799803 	Training Acuuarcy 41.924% 	Validation Acuuarcy 19.727%
Epoch: 208 	Training Loss: 0.01146103 	Validation Loss 0.01790224 	Training Acuuarcy 43.178% 	Validation Acuuarcy 18.975%
Epoch: 209 	Training Loss: 0.01153431 	Validation Loss 0.01837702 	Training Acuuarcy 42.342% 	Validation Acuuarcy 19.615%
Epoch: 210 	Training Loss: 0.01156236 	Validation Loss 0.01804131 	Training Acuuarcy 42.147% 	Validation Acuuarcy 19.030%
Epoch: 211 	Training Loss: 0.01143129 	Validation Loss 0.01818089 	Training Acuuarcy 43.301% 	Validation Acuuarcy 18.501%
Epoch: 212 	Training Loss: 0.01151163 	Validation Loss 0.01784898 	Training Acuuarcy 42.676% 	Validation Acuuarcy 20.089%
Epoch: 213 	Training Los

Epoch: 273 	Training Loss: 0.01136142 	Validation Loss 0.01823808 	Training Acuuarcy 43.763% 	Validation Acuuarcy 18.501%
Epoch: 274 	Training Loss: 0.01126401 	Validation Loss 0.01851093 	Training Acuuarcy 44.276% 	Validation Acuuarcy 18.752%
Epoch: 275 	Training Loss: 0.01123280 	Validation Loss 0.01900932 	Training Acuuarcy 44.605% 	Validation Acuuarcy 18.334%
Epoch: 276 	Training Loss: 0.01128146 	Validation Loss 0.01865201 	Training Acuuarcy 44.276% 	Validation Acuuarcy 18.362%
Epoch: 277 	Training Loss: 0.01127232 	Validation Loss 0.01877111 	Training Acuuarcy 44.142% 	Validation Acuuarcy 18.557%
Epoch: 278 	Training Loss: 0.01130625 	Validation Loss 0.01816634 	Training Acuuarcy 43.797% 	Validation Acuuarcy 19.337%
Epoch: 279 	Training Loss: 0.01124650 	Validation Loss 0.01865871 	Training Acuuarcy 44.287% 	Validation Acuuarcy 18.947%
Epoch: 280 	Training Loss: 0.01124773 	Validation Loss 0.01872406 	Training Acuuarcy 44.471% 	Validation Acuuarcy 17.721%
Epoch: 281 	Training Los

Epoch: 341 	Training Loss: 0.01117300 	Validation Loss 0.01867572 	Training Acuuarcy 45.179% 	Validation Acuuarcy 19.337%
Epoch: 342 	Training Loss: 0.01110235 	Validation Loss 0.01878890 	Training Acuuarcy 45.357% 	Validation Acuuarcy 18.919%
Epoch: 343 	Training Loss: 0.01119114 	Validation Loss 0.01849151 	Training Acuuarcy 44.789% 	Validation Acuuarcy 19.142%
Epoch: 344 	Training Loss: 0.01108677 	Validation Loss 0.01835387 	Training Acuuarcy 45.140% 	Validation Acuuarcy 19.030%
Epoch: 345 	Training Loss: 0.01108973 	Validation Loss 0.01908164 	Training Acuuarcy 45.285% 	Validation Acuuarcy 17.888%
Epoch: 346 	Training Loss: 0.01106288 	Validation Loss 0.01900851 	Training Acuuarcy 45.853% 	Validation Acuuarcy 18.278%
Epoch: 347 	Training Loss: 0.01109544 	Validation Loss 0.01897831 	Training Acuuarcy 44.727% 	Validation Acuuarcy 18.055%
Epoch: 348 	Training Loss: 0.01116050 	Validation Loss 0.01918055 	Training Acuuarcy 44.555% 	Validation Acuuarcy 17.721%
Epoch: 349 	Training Los

Epoch: 409 	Training Loss: 0.01099192 	Validation Loss 0.01855077 	Training Acuuarcy 45.892% 	Validation Acuuarcy 17.749%
Epoch: 410 	Training Loss: 0.01096107 	Validation Loss 0.01917923 	Training Acuuarcy 45.898% 	Validation Acuuarcy 18.417%
Epoch: 411 	Training Loss: 0.01095473 	Validation Loss 0.01937273 	Training Acuuarcy 46.082% 	Validation Acuuarcy 18.250%
Epoch: 412 	Training Loss: 0.01096551 	Validation Loss 0.01906049 	Training Acuuarcy 46.082% 	Validation Acuuarcy 19.225%
Epoch: 413 	Training Loss: 0.01096688 	Validation Loss 0.01908977 	Training Acuuarcy 45.820% 	Validation Acuuarcy 18.278%
Epoch: 414 	Training Loss: 0.01091595 	Validation Loss 0.01907256 	Training Acuuarcy 46.009% 	Validation Acuuarcy 18.807%
Epoch: 415 	Training Loss: 0.01100245 	Validation Loss 0.01923163 	Training Acuuarcy 45.675% 	Validation Acuuarcy 19.170%
Epoch: 416 	Training Loss: 0.01099315 	Validation Loss 0.01907301 	Training Acuuarcy 45.513% 	Validation Acuuarcy 18.055%
Epoch: 417 	Training Los

Epoch: 477 	Training Loss: 0.01087277 	Validation Loss 0.01899285 	Training Acuuarcy 46.411% 	Validation Acuuarcy 18.334%
Epoch: 478 	Training Loss: 0.01088640 	Validation Loss 0.01914006 	Training Acuuarcy 45.959% 	Validation Acuuarcy 18.417%
Epoch: 479 	Training Loss: 0.01087552 	Validation Loss 0.01890095 	Training Acuuarcy 46.333% 	Validation Acuuarcy 19.615%
Epoch: 480 	Training Loss: 0.01091508 	Validation Loss 0.01897337 	Training Acuuarcy 45.954% 	Validation Acuuarcy 18.696%
Epoch: 481 	Training Loss: 0.01087423 	Validation Loss 0.01921058 	Training Acuuarcy 46.132% 	Validation Acuuarcy 19.225%
Epoch: 482 	Training Loss: 0.01087821 	Validation Loss 0.01910938 	Training Acuuarcy 46.355% 	Validation Acuuarcy 18.557%
Epoch: 483 	Training Loss: 0.01083295 	Validation Loss 0.01932383 	Training Acuuarcy 46.544% 	Validation Acuuarcy 19.198%
Epoch: 484 	Training Loss: 0.01087344 	Validation Loss 0.01921573 	Training Acuuarcy 46.528% 	Validation Acuuarcy 18.696%
Epoch: 485 	Training Los

Epoch: 545 	Training Loss: 0.01076961 	Validation Loss 0.01955666 	Training Acuuarcy 47.174% 	Validation Acuuarcy 19.365%
Epoch: 546 	Training Loss: 0.01084788 	Validation Loss 0.01912453 	Training Acuuarcy 46.884% 	Validation Acuuarcy 17.080%
Epoch: 547 	Training Loss: 0.01076193 	Validation Loss 0.01941112 	Training Acuuarcy 47.503% 	Validation Acuuarcy 18.167%
Epoch: 548 	Training Loss: 0.01079662 	Validation Loss 0.01932903 	Training Acuuarcy 47.297% 	Validation Acuuarcy 19.365%
Epoch: 549 	Training Loss: 0.01084393 	Validation Loss 0.01921526 	Training Acuuarcy 46.784% 	Validation Acuuarcy 18.612%
Epoch: 550 	Training Loss: 0.01078325 	Validation Loss 0.01899786 	Training Acuuarcy 46.940% 	Validation Acuuarcy 19.142%
Epoch: 551 	Training Loss: 0.01081010 	Validation Loss 0.01933681 	Training Acuuarcy 46.717% 	Validation Acuuarcy 18.947%
Epoch: 552 	Training Loss: 0.01084779 	Validation Loss 0.01938087 	Training Acuuarcy 46.974% 	Validation Acuuarcy 18.222%
Epoch: 553 	Training Los

Epoch: 613 	Training Loss: 0.01076108 	Validation Loss 0.01906734 	Training Acuuarcy 47.269% 	Validation Acuuarcy 18.612%
Epoch: 614 	Training Loss: 0.01069507 	Validation Loss 0.01939937 	Training Acuuarcy 47.615% 	Validation Acuuarcy 19.114%
Epoch: 615 	Training Loss: 0.01074278 	Validation Loss 0.01963635 	Training Acuuarcy 47.174% 	Validation Acuuarcy 18.139%
Epoch: 616 	Training Loss: 0.01077451 	Validation Loss 0.01918268 	Training Acuuarcy 46.901% 	Validation Acuuarcy 18.612%
Epoch: 617 	Training Loss: 0.01069868 	Validation Loss 0.01953503 	Training Acuuarcy 47.280% 	Validation Acuuarcy 18.027%
Epoch: 618 	Training Loss: 0.01079869 	Validation Loss 0.01935827 	Training Acuuarcy 47.113% 	Validation Acuuarcy 18.947%
Epoch: 619 	Training Loss: 0.01076704 	Validation Loss 0.01935539 	Training Acuuarcy 47.163% 	Validation Acuuarcy 18.780%
Epoch: 620 	Training Loss: 0.01071491 	Validation Loss 0.01949786 	Training Acuuarcy 47.447% 	Validation Acuuarcy 17.442%
Epoch: 621 	Training Los

Epoch: 681 	Training Loss: 0.01069572 	Validation Loss 0.01916050 	Training Acuuarcy 47.185% 	Validation Acuuarcy 18.780%
Epoch: 682 	Training Loss: 0.01066612 	Validation Loss 0.01950778 	Training Acuuarcy 47.620% 	Validation Acuuarcy 18.167%
Epoch: 683 	Training Loss: 0.01067473 	Validation Loss 0.01949719 	Training Acuuarcy 47.709% 	Validation Acuuarcy 18.640%
Epoch: 684 	Training Loss: 0.01067209 	Validation Loss 0.01973756 	Training Acuuarcy 48.300% 	Validation Acuuarcy 17.804%
Epoch: 685 	Training Loss: 0.01071301 	Validation Loss 0.01981688 	Training Acuuarcy 47.269% 	Validation Acuuarcy 17.888%
Epoch: 686 	Training Loss: 0.01059332 	Validation Loss 0.01934944 	Training Acuuarcy 48.222% 	Validation Acuuarcy 18.306%
Epoch: 687 	Training Loss: 0.01063787 	Validation Loss 0.01963438 	Training Acuuarcy 48.066% 	Validation Acuuarcy 17.860%
Epoch: 688 	Training Loss: 0.01078285 	Validation Loss 0.01941584 	Training Acuuarcy 46.533% 	Validation Acuuarcy 18.306%
Epoch: 689 	Training Los

Epoch: 749 	Training Loss: 0.01062721 	Validation Loss 0.01923768 	Training Acuuarcy 47.837% 	Validation Acuuarcy 18.250%
Epoch: 750 	Training Loss: 0.01056410 	Validation Loss 0.01958048 	Training Acuuarcy 48.228% 	Validation Acuuarcy 18.111%
Epoch: 751 	Training Loss: 0.01068697 	Validation Loss 0.01928557 	Training Acuuarcy 47.943% 	Validation Acuuarcy 17.888%
Epoch: 752 	Training Loss: 0.01056072 	Validation Loss 0.02011394 	Training Acuuarcy 48.395% 	Validation Acuuarcy 17.080%
Epoch: 753 	Training Loss: 0.01059985 	Validation Loss 0.01930442 	Training Acuuarcy 47.893% 	Validation Acuuarcy 18.863%
Epoch: 754 	Training Loss: 0.01067402 	Validation Loss 0.01965784 	Training Acuuarcy 47.726% 	Validation Acuuarcy 17.832%
Epoch: 755 	Training Loss: 0.01067223 	Validation Loss 0.01940678 	Training Acuuarcy 47.654% 	Validation Acuuarcy 18.640%
Epoch: 756 	Training Loss: 0.01068248 	Validation Loss 0.01956335 	Training Acuuarcy 47.369% 	Validation Acuuarcy 17.609%
Epoch: 757 	Training Los

Epoch: 817 	Training Loss: 0.01052277 	Validation Loss 0.01963200 	Training Acuuarcy 48.439% 	Validation Acuuarcy 17.832%
Epoch: 818 	Training Loss: 0.01060648 	Validation Loss 0.01954505 	Training Acuuarcy 47.609% 	Validation Acuuarcy 18.780%
Epoch: 819 	Training Loss: 0.01059892 	Validation Loss 0.01941986 	Training Acuuarcy 48.674% 	Validation Acuuarcy 17.916%


In [None]:
# Save only the trained network's parameters/buffers (its state_dict, not
# the module object) to a fixed filename in the working directory.
torch.save(obj=net.state_dict(), f='Speaktrum_by_SOVA.pt')