In [3]:
# Load Libraries 

from __future__ import print_function, division

import copy
import math
import os
import statistics
import time
from shutil import copyfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

# NOTE(review): later cells also use cv2, PIL's Image and tqdm, which are not
# imported in any visible cell — confirm where they come from and import them here.

# Data Preprocessing Pipeline

1. Train-Test Split 
2. Data Preprocessing (change to black and white, getting the facial shape etc) 
3. Data Augmentation etc ..

## Convert AFF-Wild2 videos into images
Selected video frames are saved as images: one frame is extracted each time the emotion label changes

In [None]:
def videoToImage(fileName):
    """Save one image per emotion-label change of an AFF-Wild2 training video.

    Reads the annotation file ``fileName`` from the Train_Set annotation
    folder, records the frame index of every row whose label differs from the
    previous valid label, grabs those frames from the ``.mp4`` with the same
    base name and writes each one as a PNG into the folder of its label.

    Args:
        fileName: Name of the annotation ``.txt`` file (file name only, the
            directory is fixed inside the function).
    """
    label_data_dir = os.getcwd()+"\\datasets\\AFF Wild\\annotations\\Train_Set\\"
    vid_data_dir = os.getcwd()+"\\datasets\\AFF Wild\\videos\\Train_Set\\"
    img_data_dir = os.getcwd()+"\\datasets\\AFF Wild\\images\\"
    labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']

    # Collect the frames at which the label changes.
    imgLabels = []
    lastValue = None
    frameIndex = 0
    with open(label_data_dir+fileName, "r") as file:
        for x in file:
            # Rows starting with 'N' (header) and blank rows do not describe a frame.
            if not x.strip() or x[0]=='N':
                continue
            # Rows starting with '-' carry no usable label but still advance the frame index.
            # assumes one annotation row per video frame after the header — TODO confirm
            if x[0]!='-':
                currentValue = int(x[0])
                if lastValue is None or lastValue!=currentValue:
                    imgLabels.append({'frame': frameIndex, 'label': currentValue})
                    lastValue = currentValue
            frameIndex+=1

    # Grab each selected frame and store it right away, so a failed read can
    # never shift the label that belongs to a later frame.
    cap = cv2.VideoCapture(vid_data_dir+fileName[:-4]+".mp4")
    try:
        count = 0
        for x in imgLabels:
            cap.set(cv2.CAP_PROP_POS_FRAMES, x['frame'])
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV delivers BGR; convert so the saved PNG has correct colours.
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img.save(img_data_dir+labels[x['label']]+'\\'+fileName[:-4]+'-'+str(count)+'.png')
            count+=1
    finally:
        cap.release()

In [None]:
# Create the per-label output folders (safe to re-run: existing folders are kept)
image_root = os.getcwd()+"\\datasets\\AFF Wild\\images"
os.makedirs(image_root, exist_ok=True)
labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
for label in labels:
    os.makedirs(image_root+"\\"+label, exist_ok=True)

# Convert every annotated training video into label-change images
data_dir = os.getcwd() + "\\datasets\\AFF Wild\\annotations\\Train_Set\\"
files = os.listdir(data_dir)
for file in files:
    videoToImage(file)

## Data Augmentation

### Creating custom transformers
This transformer uses the Haar cascade frontal-face classifier to detect the face and crops the image to the detected face region <br/>
This transformer also converts the image from RGB to a 3-channel grayscale image

In [None]:
class CustomTransform(object):
    """Grayscale, face-crop and resize an image.

    The image is converted to a 3-channel grayscale array. If the Haar cascade
    frontal-face detector finds exactly one face, the array is cropped to that
    face; otherwise the full image is kept. The result is resized to
    ``output_size``.

    Args:
        output_size: ``int`` for a square output, or a tuple that is passed
            to ``cv2.resize`` unchanged.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size
        # Loaded on first use so the cascade XML is parsed once, not per image.
        self._faceCascade = None

    def __getstate__(self):
        # Drop the OpenCV detector when the transform is pickled; it is
        # rebuilt lazily on the next call.
        state = self.__dict__.copy()
        state['_faceCascade'] = None
        return state

    def _getFaceCascade(self):
        # Return the cached face detector, creating it on first use.
        if self._faceCascade is None:
            self._faceCascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
        return self._faceCascade

    def __call__(self, img):
        """Return the processed image as an array of 3 identical gray channels."""
        opencvImage = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(opencvImage, cv2.COLOR_BGR2GRAY)
        tripleGray = np.stack((gray,)*3, axis=-1)
        faces = self._getFaceCascade().detectMultiScale(
                tripleGray,
                scaleFactor=1.3,
                minNeighbors=3,
                minSize=(30, 30)
        )
        # Crop only when the detection is unambiguous (exactly one face).
        if len(faces)==1:
            x, y, w, h = faces[0]
            tripleGray = np.ascontiguousarray(tripleGray[y:y+h, x:x+w])
        if isinstance(self.output_size, int):
            resized = cv2.resize(tripleGray, (self.output_size, self.output_size))
        else:
            resized = cv2.resize(tripleGray, self.output_size)
        return resized

## Dataset Creation
There are 3 variations of the balanced datasets created here, the undersampled dataset, oversampled dataset and progressive dataset <br/>
The balanced datasets are then transformed and converted into tensor files

## Undersampled dataset
Create a balanced dataset by sampling image data so that every label matches the size of the smallest label<br/>
Final tensor files are then stored together to form the dataset

### CK+

In [None]:
# CK+: split every label folder 80/20 into train/val, then draw a class-balanced
# training sample whose size equals the smallest per-label training set.
datadir = os.getcwd()+"\\datasets\\CK+\\sorted_dataset\\"
# Every split lives under the same root so that the sampling step below reads
# exactly the files the split step wrote.
splitdir = datadir+"CK_undersampled_dataset\\"

labels = ['anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
train_size =[]


for x in range(6):
    for split in ('train', 'val', 'train_sampled'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+labels[x])
    train_size.append(math.floor(len(files)*0.8))
    # Positions (in listing order) of the files held out for validation.
    valSample = set(np.random.choice(len(files), len(files)-train_size[x], replace=False).tolist())
    for count, file in enumerate(files):
        if count in valSample:
            copyfile(datadir+labels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
        else:
            copyfile(datadir+labels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)

minLabel = min(train_size)

for x in range(6):
    files = os.listdir(splitdir+"train\\"+labels[x])
    # NOTE(review): indices are drawn with replacement, so a file can be picked
    # several times (copies are prefixed 0, 1, ...) — confirm this is intended
    # for the undersampled set.
    trainSample = np.random.randint(low=0, high=train_size[x], size=minLabel)
    for count, file in enumerate(files):
        freq = np.count_nonzero(trainSample == count)
        for i in range(freq):
            copyfile(splitdir+"train\\"+labels[x]+"\\"+file, splitdir+"train_sampled\\"+labels[x]+"\\"+str(i)+file)

In [None]:
# CK+: preprocess the undersampled splits and store every dataset item as a .pt file.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train_sampled': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

# Same root the CK+ split cell writes to.
data_dir = os.getcwd() + "\\datasets\\CK+\\sorted_dataset\\CK_undersampled_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train_sampled', 'val']}
# NOTE(review): ImageFolder numbers classes by sorted folder name. CK+ is split
# into six label folders while the other datasets use seven (with 'neutral'),
# so the class indices saved here presumably do not line up with the other
# datasets in the shared folder — confirm before training on the mix.

# The tensor folders are shared by all datasets: numbering continues after the
# files already present instead of restarting at 0 and overwriting them.
# Clear the folders before a full re-run to avoid duplicates.
data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\val\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\train\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['train_sampled']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

### FER2013

In [None]:
# FER2013: split every label folder 80/20 into train/val, then draw a
# class-balanced training sample sized like the smallest per-label training set.
# The loops read `datadir`, so the FER2013 root is assigned to that name here
# rather than relying on whatever an earlier cell left in it.
datadir = os.getcwd()+"\\datasets\\FER2013\\"
splitdir = datadir+"FER_undersampled_dataset\\"
# Folder names used by this project vs. the names shipped with FER2013 (same order).
labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
FERlabels = ['neutral', 'angry', 'disgust', 'fear', 'happy', 'sad', 'surprise']
train_size=[]

for x in range(7):
    for split in ('train', 'val', 'train_sampled'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+"train\\"+FERlabels[x])
    train_size.append(math.floor(len(files)*0.8))
    # Positions (in listing order) of the files held out for validation.
    valSample = set(np.random.choice(len(files), len(files)-train_size[x], replace=False).tolist())

    for count, file in enumerate(files):
        if count in valSample:
            copyfile(datadir+"train\\"+FERlabels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
        else:
            copyfile(datadir+"train\\"+FERlabels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)

minLabel = min(train_size)

for x in range(7):
    files = os.listdir(splitdir+"train\\"+labels[x])
    # Drawn with replacement: a file picked k times is copied k times (prefix 0..k-1).
    trainSample = np.random.randint(low=0, high=train_size[x], size=minLabel)
    for count, file in enumerate(files):
        freq = np.count_nonzero(trainSample == count)
        for i in range(freq):
            copyfile(splitdir+"train\\"+labels[x]+"\\"+file, splitdir+"train_sampled\\"+labels[x]+"\\"+str(i)+file)

In [None]:
# FER2013: preprocess the undersampled splits and store every dataset item as a .pt file.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train_sampled': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\FER2013\\FER_undersampled_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train_sampled', 'val']}

# The tensor folders are shared by all datasets: numbering continues after the
# files already present instead of restarting at 0 and overwriting them.
# Clear the folders before a full re-run to avoid duplicates.
data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\val\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\train\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['train_sampled']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

### AffectNet

In [None]:
# Sort AffectNet images into folders of respective emotion labels
source_data_dir = os.getcwd() + "\\datasets\\AffectNet\\Automatically_Annotated_compressed\\"
target_data_dir = os.getcwd() + "\\datasets\\AffectNet_sorted\\"
image_root = source_data_dir+"\\Automatically_Annotated\\Automatically_Annotated_images\\"

# Value of the 'expression' column -> destination folder; other values are skipped.
expression_folders = {
    0: "neutral",
    1: "happy",
    2: "sadness",
    3: "suprise",
    4: "fear",
    5: "disgust",
    6: "anger",
}

df = pd.read_csv(source_data_dir+"automatically_annotated.csv")
for index, row in tqdm(df.iterrows()):
    imageName = row['subDirectory_filePath'].split("/")
    folder = expression_folders.get(row['expression'])
    if folder is not None:
        # The copy keeps only the file-name part (second path component).
        copyfile(image_root+row['subDirectory_filePath'], target_data_dir+"\\"+folder+"\\"+imageName[1])

In [None]:
# AffectNet: take at most MAX_PER_LABEL images per label and split them 80/20
# into disjoint train and val folders.
datadir = os.getcwd()+"\\datasets\\AffectNet_sorted\\"
splitdir = datadir+"AffectNet_undersampled_dataset\\"

labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
train_size=[]
# Hardcoded minimum size of label
MAX_PER_LABEL = 890

for x in range(7):
    for split in ('train', 'val'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+labels[x])
    trainCount = 0
    valCount=0
    maxInput = min(len(files), MAX_PER_LABEL)
    train_size.append(math.floor(maxInput*0.8))
    # Positions (in listing order) of the files that go to validation.
    valSample = set(np.random.choice(len(files), maxInput-train_size[x], replace=False).tolist())

    for count, file in enumerate(files):
        if valCount>=len(valSample) and trainCount>=train_size[x]:
            break
        if count in valSample:
            copyfile(datadir+labels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
            valCount+=1
        # elif keeps the splits disjoint: a validation file is never also used for training.
        elif trainCount<train_size[x]:
            copyfile(datadir+labels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)
            trainCount+=1

In [None]:
# AffectNet: preprocess the undersampled splits and store every dataset item as a .pt file.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\AffectNet_sorted\\AffectNet_undersampled_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

# The tensor folders are shared by all datasets: numbering continues after the
# files already present instead of restarting at 0 and overwriting them.
# Clear the folders before a full re-run to avoid duplicates.
data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\val\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\train\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
# AffectNet has no resampled folder; its training split is keyed 'train'.
for tensor in tqdm(image_datasets['train']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

### Aff-Wild2

In [None]:
# Aff-Wild2: split every label folder into 60% train and 40% held-out images
# (held-out files alternate between val and test by position), then draw a
# class-balanced training sample sized like the smallest per-label training set.
datadir = os.getcwd()+"\\datasets\\AFF Wild\\"
splitdir = datadir+"Aff_undersampled_dataset\\"

# Defined here so the cell does not depend on a list left over from another cell.
labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
train_size=[]

for x in range(7):
    for split in ('train', 'val', 'test', 'train_sampled'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+"images\\"+labels[x])
    train_size.append(math.floor(len(files)*0.6))
    # Positions (in listing order) of the files held out of training.
    valSample = set(np.random.choice(len(files), len(files)-train_size[x], replace=False).tolist())

    for count, file in enumerate(files):
        if count in valSample:
            # Even positions go to val, odd positions to test.
            if count%2==0:
                copyfile(datadir+"images\\"+labels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
            else:
                copyfile(datadir+"images\\"+labels[x]+"\\"+file, splitdir+"test\\"+labels[x]+"\\"+file)
        else:
            copyfile(datadir+"images\\"+labels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)

minLabel = min(train_size)

for x in range(7):
    # List the same train folder the copies below read from.
    files = os.listdir(splitdir+"train\\"+labels[x])
    # Drawn with replacement: a file picked k times is copied k times (prefix 0..k-1).
    trainSample = np.random.randint(low=0, high=train_size[x], size=minLabel)
    for count, file in enumerate(files):
        freq = np.count_nonzero(trainSample == count)
        for i in range(freq):
            copyfile(splitdir+"train\\"+labels[x]+"\\"+file, splitdir+"train_sampled\\"+labels[x]+"\\"+str(i)+file)

In [None]:
# Aff-Wild2: preprocess the undersampled splits and store every dataset item as a .pt file.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train_sampled': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
    'test': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\AFF Wild\\Aff_undersampled_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train_sampled', 'val', 'test']}

# The tensor folders are shared by all datasets: numbering continues after the
# files already present instead of restarting at 0 and overwriting them.
# Clear the folders before a full re-run to avoid duplicates.
data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\val\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\train\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['train_sampled']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\undersampled_dataset_tensors\\test\\"
os.makedirs(data_dir, exist_ok=True)
count = len(os.listdir(data_dir))
for tensor in tqdm(image_datasets['test']):
    torch.save(tensor, data_dir+str(count)+'.pt')
    count+=1

## Oversampled dataset/Progressive dataset
Create a balanced dataset by sampling image data so that every label matches the median label size<br/>
Final tensor files are either stored together (oversampled dataset) or separately (progressive dataset)

### CK+

In [None]:
# CK+: 80/20 train/val split per label, then resample every label's training
# files (with replacement) up or down to the median training-set size.
datadir = os.getcwd()+"\\datasets\\CK+\\sorted_dataset\\"

labels = ['anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
train_size =[]


for label in labels:
    names = os.listdir(datadir+label)
    n_train = math.floor(len(names)*0.8)
    train_size.append(n_train)
    # Positions (in listing order) of the files held out for validation.
    valSample = np.random.choice(len(names), len(names)-n_train, replace=False)
    for position, name in enumerate(names):
        split = "val" if position in valSample else "train"
        copyfile(datadir+label+"\\"+name, datadir+"\\CK_balanced_dataset\\"+split+"\\"+label+"\\"+name)

median = math.ceil(statistics.median(train_size))
print("Median: "+str(median))

for label, n_train in zip(labels, train_size):
    names = os.listdir(datadir+"\\CK_balanced_dataset\\train\\"+label)
    trainSample = np.random.randint(low=0, high=n_train, size=median)
    for position, name in enumerate(names):
        # A file drawn k times is copied k times, prefixed 0..k-1.
        for i in range(np.count_nonzero(trainSample == position)):
            copyfile(datadir+"\\CK_balanced_dataset\\train\\"+label+"\\"+name, datadir+"\\CK_balanced_dataset\\train_sampled\\"+label+"\\"+str(i)+name)

In [None]:
# CK+: preprocess the balanced splits and store every dataset item as a .pt file,
# once in the CK+-only (progressive) folder and once in the combined (oversampled) folder.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train_sampled': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\CK+\\sorted_dataset\\CK_balanced_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train_sampled', 'val']}
# NOTE(review): ImageFolder numbers classes by sorted folder name. CK+ is split
# into six label folders while the other datasets use seven (with 'neutral'),
# so the class indices saved here presumably do not line up with the other
# datasets in the combined folder — confirm before training on the mix.

# The progressive folder belongs to this dataset alone and is numbered from 0.
# The combined folder is shared by all datasets, so its numbering continues
# after the files already present instead of overwriting them; clear it before
# a full re-run to avoid duplicates.
count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\CK_balanced_dataset_tensors\\val\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\val\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\CK_balanced_dataset_tensors\\train\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\train\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['train_sampled']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

### FER2013
FER2013 has slightly different naming conventions, so slight adjustments need to be made

In [None]:
# FER2013: 80/20 train/val split per label, then resample every label's training
# files (with replacement) up or down to the median training-set size.
# The loops read `datadir`, so the FER2013 root is assigned to that name here
# rather than relying on whatever an earlier cell left in it.
datadir = os.getcwd()+"\\datasets\\FER2013\\"
splitdir = datadir+"FER_balanced_dataset\\"
# Folder names used by this project vs. the names shipped with FER2013 (same order).
labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
FERlabels = ['neutral', 'angry', 'disgust', 'fear', 'happy', 'sad', 'surprise']
train_size=[]

for x in range(7):
    for split in ('train', 'val', 'train_sampled'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+"train\\"+FERlabels[x])
    train_size.append(math.floor(len(files)*0.8))
    # Positions (in listing order) of the files held out for validation.
    valSample = set(np.random.choice(len(files), len(files)-train_size[x], replace=False).tolist())

    for count, file in enumerate(files):
        if count in valSample:
            copyfile(datadir+"train\\"+FERlabels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
        else:
            copyfile(datadir+"train\\"+FERlabels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)

median = math.ceil(statistics.median(train_size))
print("Median: "+str(median))

for x in range(7):
    files = os.listdir(splitdir+"train\\"+labels[x])
    # Drawn with replacement: a file picked k times is copied k times (prefix 0..k-1).
    trainSample = np.random.randint(low=0, high=train_size[x], size=median)
    for count, file in enumerate(files):
        freq = np.count_nonzero(trainSample == count)
        for i in range(freq):
            copyfile(splitdir+"train\\"+labels[x]+"\\"+file, splitdir+"train_sampled\\"+labels[x]+"\\"+str(i)+file)

In [None]:
# FER2013: preprocess the balanced splits and store every dataset item as a .pt file,
# once in the FER-only (progressive) folder and once in the combined (oversampled) folder.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train_sampled': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\FER2013\\FER_balanced_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train_sampled', 'val']}

# The progressive folder belongs to this dataset alone and is numbered from 0.
# The combined folder is shared by all datasets, so its numbering continues
# after the files already present instead of overwriting them; clear it before
# a full re-run to avoid duplicates.
count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\FER_balanced_dataset_tensors\\val\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\val\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\FER_balanced_dataset_tensors\\train\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\train\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['train_sampled']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

### AffectNet

In [None]:
# AffectNet: take at most MAX_PER_LABEL images per label and split them 80/20
# into disjoint train and val folders.
datadir = os.getcwd()+"\\datasets\\AffectNet_sorted\\"
splitdir = datadir+"AffectNet_balanced_dataset\\"

labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
train_size=[]
# Hardcoded median size of label
MAX_PER_LABEL = 20854

for x in range(7):
    for split in ('train', 'val'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+labels[x])
    trainCount = 0
    valCount=0
    maxInput = min(len(files), MAX_PER_LABEL)
    train_size.append(math.floor(maxInput*0.8))
    # Positions (in listing order) of the files that go to validation.
    valSample = set(np.random.choice(len(files), maxInput-train_size[x], replace=False).tolist())

    for count, file in enumerate(files):
        if valCount>=len(valSample) and trainCount>=train_size[x]:
            break
        if count in valSample:
            copyfile(datadir+labels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
            valCount+=1
        # elif keeps the splits disjoint: a validation file is never also used for training.
        elif trainCount<train_size[x]:
            copyfile(datadir+labels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)
            trainCount+=1

In [None]:
# AffectNet: preprocess the balanced splits and store every dataset item as a .pt file,
# once in the AffectNet-only (progressive) folder and once in the combined (oversampled) folder.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\AffectNet_sorted\\AffectNet_balanced_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

# The progressive folder belongs to this dataset alone and is numbered from 0.
# The combined folder is shared by all datasets, so its numbering continues
# after the files already present instead of overwriting them; clear it before
# a full re-run to avoid duplicates.
count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\AffectNet_balanced_dataset_tensors\\val\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\val\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\AffectNet_balanced_dataset_tensors\\train\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\train\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
# AffectNet has no resampled folder; its training split is keyed 'train'.
for tensor in tqdm(image_datasets['train']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

### AFF-Wild2

In [None]:
# AFF-Wild2: split every label folder into 60% train and 40% held-out images
# (held-out files alternate between val and test by position), then resample
# every label's training files (with replacement) to the median training-set size.
datadir = os.getcwd()+"\\datasets\\AFF Wild\\"
splitdir = datadir+"AFF_balanced_dataset\\"

# Defined here so the cell does not depend on a list left over from another cell.
labels = ['neutral', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'suprise']
train_size=[]

for x in range(7):
    for split in ('train', 'val', 'test', 'train_sampled'):
        os.makedirs(splitdir+split+"\\"+labels[x], exist_ok=True)
    files = os.listdir(datadir+"images\\"+labels[x])
    train_size.append(math.floor(len(files)*0.6))
    # Positions (in listing order) of the files held out of training.
    valSample = set(np.random.choice(len(files), len(files)-train_size[x], replace=False).tolist())

    for count, file in enumerate(files):
        if count in valSample:
            # Even positions go to val, odd positions to test.
            if count%2==0:
                copyfile(datadir+"images\\"+labels[x]+"\\"+file, splitdir+"val\\"+labels[x]+"\\"+file)
            else:
                copyfile(datadir+"images\\"+labels[x]+"\\"+file, splitdir+"test\\"+labels[x]+"\\"+file)
        else:
            copyfile(datadir+"images\\"+labels[x]+"\\"+file, splitdir+"train\\"+labels[x]+"\\"+file)

median = math.ceil(statistics.median(train_size))
print("Median: "+str(median))

for x in range(7):
    files = os.listdir(splitdir+"train\\"+labels[x])
    # Drawn with replacement: a file picked k times is copied k times (prefix 0..k-1).
    trainSample = np.random.randint(low=0, high=train_size[x], size=median)
    for count, file in enumerate(files):
        freq = np.count_nonzero(trainSample == count)
        for i in range(freq):
            copyfile(splitdir+"train\\"+labels[x]+"\\"+file, splitdir+"train_sampled\\"+labels[x]+"\\"+str(i)+file)

In [None]:
# AFF-Wild2: preprocess the balanced splits and store every dataset item as a .pt file,
# once in the AFF-only (progressive) folder and once in the combined (oversampled) folder.
data_transforms = {
    # Training images also get random flips; the outcome is frozen into the saved tensor.
    'train_sampled': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
    'test': transforms.Compose([
        CustomTransform(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}

data_dir = os.getcwd() + "\\datasets\\AFF Wild\\AFF_balanced_dataset\\"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train_sampled', 'val', 'test']}

# The progressive folder belongs to this dataset alone and is numbered from 0.
# The combined folder is shared by all datasets, so its numbering continues
# after the files already present instead of overwriting them; clear it before
# a full re-run to avoid duplicates.
count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\AFF_balanced_dataset_tensors\\val\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\val\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['val']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\AFF_balanced_dataset_tensors\\train\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\train\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['train_sampled']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

count=0
progressive_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\AFF_balanced_dataset_tensors\\test\\"
oversampled_data_dir = os.getcwd() + "\\datasets\\sorted_dataset\\combined_balanced_dataset_tensors\\test\\"
os.makedirs(progressive_data_dir, exist_ok=True)
os.makedirs(oversampled_data_dir, exist_ok=True)
combined_count = len(os.listdir(oversampled_data_dir))
for tensor in tqdm(image_datasets['test']):
    torch.save(tensor, progressive_data_dir+str(count)+'.pt')
    torch.save(tensor, oversampled_data_dir+str(combined_count)+'.pt')
    count+=1
    combined_count+=1

# Training Pipeline

To Do: 

1. Add in F1 Metrics 
2. Figure out how to add tensorboard to get training diagnostics 
3. Progressive training is implemented but still needs to be verified

In [3]:
# ResNet-50 trained from scratch
class Resnet50_scratch_dag(nn.Module):
    """ResNet-50 written out layer by layer with explicit per-layer attributes.

    Layer attributes follow the ``conv{stage}_{block}_{role}`` naming scheme
    (``1x1_reduce`` -> ``3x3`` -> ``1x1_increase``, plus ``1x1_proj`` on the
    first block of every stage).  These names are the ``state_dict`` keys, so
    they must not be renamed or reordered if saved weights are to keep loading.

    Stages: conv2_x (3 blocks, 256 channels out), conv3_x (4 blocks, 512),
    conv4_x (6 blocks, 1024), conv5_x (3 blocks, 2048).  The head is a 7x7
    average pool followed by a 1x1 convolution with 8631 output channels.

    ``self.meta`` records per-channel mean/std and the nominal input size; it
    is informational only and is NOT applied inside ``forward``.
    NOTE(review): the 8631-way head and the mean values look like a face
    identification checkpoint; confirm the expected input preprocessing
    against whatever weights are loaded.
    """

    def __init__(self):
        super(Resnet50_scratch_dag, self).__init__()
        self.meta = {'mean': [131.0912, 103.8827, 91.4953],
                     'std': [1, 1, 1],
                     'imageSize': [224, 224, 3]}
        # Stem: 7x7 stride-2 conv + BN + ReLU, then 3x3 stride-2 max-pool (ceil_mode).
        self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=[7, 7], stride=(2, 2), padding=(3, 3), bias=False)
        self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv1_relu_7x7_s2 = nn.ReLU()
        self.pool1_3x3_s2 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=(0, 0), dilation=1, ceil_mode=True)
        # conv2_x: 3 bottleneck blocks (64 mid, 256 out), stride 1; block 1 has a projection shortcut.
        self.conv2_1_1x1_reduce = nn.Conv2d(64, 64, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_1_1x1_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_1_1x1_reduce_relu = nn.ReLU()
        self.conv2_1_3x3 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv2_1_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_1_3x3_relu = nn.ReLU()
        self.conv2_1_1x1_increase = nn.Conv2d(64, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_1_1x1_increase_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_1_1x1_proj = nn.Conv2d(64, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_1_1x1_proj_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_1_relu = nn.ReLU()
        self.conv2_2_1x1_reduce = nn.Conv2d(256, 64, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_2_1x1_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_2_1x1_reduce_relu = nn.ReLU()
        self.conv2_2_3x3 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv2_2_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_2_3x3_relu = nn.ReLU()
        self.conv2_2_1x1_increase = nn.Conv2d(64, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_2_1x1_increase_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_2_relu = nn.ReLU()
        self.conv2_3_1x1_reduce = nn.Conv2d(256, 64, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_3_1x1_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_3_1x1_reduce_relu = nn.ReLU()
        self.conv2_3_3x3 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv2_3_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_3_3x3_relu = nn.ReLU()
        self.conv2_3_1x1_increase = nn.Conv2d(64, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv2_3_1x1_increase_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2_3_relu = nn.ReLU()
        # conv3_x: 4 bottleneck blocks (128 mid, 512 out); block 1 downsamples with stride 2.
        self.conv3_1_1x1_reduce = nn.Conv2d(256, 128, kernel_size=[1, 1], stride=(2, 2), bias=False)
        self.conv3_1_1x1_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_1_1x1_reduce_relu = nn.ReLU()
        self.conv3_1_3x3 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv3_1_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_1_3x3_relu = nn.ReLU()
        self.conv3_1_1x1_increase = nn.Conv2d(128, 512, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_1_1x1_increase_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_1_1x1_proj = nn.Conv2d(256, 512, kernel_size=[1, 1], stride=(2, 2), bias=False)
        self.conv3_1_1x1_proj_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_1_relu = nn.ReLU()
        self.conv3_2_1x1_reduce = nn.Conv2d(512, 128, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_2_1x1_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_2_1x1_reduce_relu = nn.ReLU()
        self.conv3_2_3x3 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv3_2_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_2_3x3_relu = nn.ReLU()
        self.conv3_2_1x1_increase = nn.Conv2d(128, 512, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_2_1x1_increase_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_2_relu = nn.ReLU()
        self.conv3_3_1x1_reduce = nn.Conv2d(512, 128, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_3_1x1_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_3_1x1_reduce_relu = nn.ReLU()
        self.conv3_3_3x3 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv3_3_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_3_3x3_relu = nn.ReLU()
        self.conv3_3_1x1_increase = nn.Conv2d(128, 512, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_3_1x1_increase_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_3_relu = nn.ReLU()
        self.conv3_4_1x1_reduce = nn.Conv2d(512, 128, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_4_1x1_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_4_1x1_reduce_relu = nn.ReLU()
        self.conv3_4_3x3 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv3_4_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_4_3x3_relu = nn.ReLU()
        self.conv3_4_1x1_increase = nn.Conv2d(128, 512, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv3_4_1x1_increase_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3_4_relu = nn.ReLU()
        # conv4_x: 6 bottleneck blocks (256 mid, 1024 out); block 1 downsamples with stride 2.
        self.conv4_1_1x1_reduce = nn.Conv2d(512, 256, kernel_size=[1, 1], stride=(2, 2), bias=False)
        self.conv4_1_1x1_reduce_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_1_1x1_reduce_relu = nn.ReLU()
        self.conv4_1_3x3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv4_1_3x3_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_1_3x3_relu = nn.ReLU()
        self.conv4_1_1x1_increase = nn.Conv2d(256, 1024, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_1_1x1_increase_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_1_1x1_proj = nn.Conv2d(512, 1024, kernel_size=[1, 1], stride=(2, 2), bias=False)
        self.conv4_1_1x1_proj_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_1_relu = nn.ReLU()
        self.conv4_2_1x1_reduce = nn.Conv2d(1024, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_2_1x1_reduce_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_2_1x1_reduce_relu = nn.ReLU()
        self.conv4_2_3x3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv4_2_3x3_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_2_3x3_relu = nn.ReLU()
        self.conv4_2_1x1_increase = nn.Conv2d(256, 1024, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_2_1x1_increase_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_2_relu = nn.ReLU()
        self.conv4_3_1x1_reduce = nn.Conv2d(1024, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_3_1x1_reduce_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_3_1x1_reduce_relu = nn.ReLU()
        self.conv4_3_3x3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv4_3_3x3_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_3_3x3_relu = nn.ReLU()
        self.conv4_3_1x1_increase = nn.Conv2d(256, 1024, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_3_1x1_increase_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_3_relu = nn.ReLU()
        self.conv4_4_1x1_reduce = nn.Conv2d(1024, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_4_1x1_reduce_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_4_1x1_reduce_relu = nn.ReLU()
        self.conv4_4_3x3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv4_4_3x3_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_4_3x3_relu = nn.ReLU()
        self.conv4_4_1x1_increase = nn.Conv2d(256, 1024, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_4_1x1_increase_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_4_relu = nn.ReLU()
        self.conv4_5_1x1_reduce = nn.Conv2d(1024, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_5_1x1_reduce_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_5_1x1_reduce_relu = nn.ReLU()
        self.conv4_5_3x3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv4_5_3x3_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_5_3x3_relu = nn.ReLU()
        self.conv4_5_1x1_increase = nn.Conv2d(256, 1024, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_5_1x1_increase_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_5_relu = nn.ReLU()
        self.conv4_6_1x1_reduce = nn.Conv2d(1024, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_6_1x1_reduce_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_6_1x1_reduce_relu = nn.ReLU()
        self.conv4_6_3x3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv4_6_3x3_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_6_3x3_relu = nn.ReLU()
        self.conv4_6_1x1_increase = nn.Conv2d(256, 1024, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv4_6_1x1_increase_bn = nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4_6_relu = nn.ReLU()
        # conv5_x: 3 bottleneck blocks (512 mid, 2048 out); block 1 downsamples with stride 2.
        self.conv5_1_1x1_reduce = nn.Conv2d(1024, 512, kernel_size=[1, 1], stride=(2, 2), bias=False)
        self.conv5_1_1x1_reduce_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_1_1x1_reduce_relu = nn.ReLU()
        self.conv5_1_3x3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv5_1_3x3_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_1_3x3_relu = nn.ReLU()
        self.conv5_1_1x1_increase = nn.Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv5_1_1x1_increase_bn = nn.BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_1_1x1_proj = nn.Conv2d(1024, 2048, kernel_size=[1, 1], stride=(2, 2), bias=False)
        self.conv5_1_1x1_proj_bn = nn.BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_1_relu = nn.ReLU()
        self.conv5_2_1x1_reduce = nn.Conv2d(2048, 512, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv5_2_1x1_reduce_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_2_1x1_reduce_relu = nn.ReLU()
        self.conv5_2_3x3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv5_2_3x3_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_2_3x3_relu = nn.ReLU()
        self.conv5_2_1x1_increase = nn.Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv5_2_1x1_increase_bn = nn.BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_2_relu = nn.ReLU()
        self.conv5_3_1x1_reduce = nn.Conv2d(2048, 512, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv5_3_1x1_reduce_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_3_1x1_reduce_relu = nn.ReLU()
        self.conv5_3_3x3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
        self.conv5_3_3x3_bn = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_3_3x3_relu = nn.ReLU()
        self.conv5_3_1x1_increase = nn.Conv2d(512, 2048, kernel_size=[1, 1], stride=(1, 1), bias=False)
        self.conv5_3_1x1_increase_bn = nn.BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv5_3_relu = nn.ReLU()
        # Head: 7x7 average pool, then a 1x1 conv acting as the 8631-way classifier.
        self.pool5_7x7_s1 = nn.AvgPool2d(kernel_size=[7, 7], stride=[1, 1], padding=0)
        self.classifier = nn.Conv2d(2048, 8631, kernel_size=[1, 1], stride=(1, 1))

    def forward(self, data):
        """Run the network on a batch of 3-channel images.

        Args:
            data: input tensor with 3 channels in dimension 1.  For the
                224x224 size listed in ``self.meta['imageSize']`` the final
                feature map is 7x7, which the 7x7 average pool reduces to 1x1.

        Returns:
            tuple ``(classifier, pool5_7x7_s1)``:
                classifier: output of the 1x1 classifier conv flattened to
                    ``(batch, -1)``; ``(batch, 8631)`` for 224x224 inputs.
                pool5_7x7_s1: the pooled 2048-channel features fed to the
                    classifier; ``(batch, 2048, 1, 1)`` for 224x224 inputs.
        """
        # Stem
        conv1_7x7_s2 = self.conv1_7x7_s2(data)
        conv1_7x7_s2_bn = self.conv1_7x7_s2_bn(conv1_7x7_s2)
        conv1_7x7_s2_bnxx = self.conv1_relu_7x7_s2(conv1_7x7_s2_bn)
        pool1_3x3_s2 = self.pool1_3x3_s2(conv1_7x7_s2_bnxx)
        # conv2_x.  Residual sums use torch.add(a, b); the old positional form
        # torch.add(a, 1, b) (alpha as 2nd argument) is deprecated and gives
        # the same result for alpha=1.
        conv2_1_1x1_reduce = self.conv2_1_1x1_reduce(pool1_3x3_s2)
        conv2_1_1x1_reduce_bn = self.conv2_1_1x1_reduce_bn(conv2_1_1x1_reduce)
        conv2_1_1x1_reduce_bnxx = self.conv2_1_1x1_reduce_relu(conv2_1_1x1_reduce_bn)
        conv2_1_3x3 = self.conv2_1_3x3(conv2_1_1x1_reduce_bnxx)
        conv2_1_3x3_bn = self.conv2_1_3x3_bn(conv2_1_3x3)
        conv2_1_3x3_bnxx = self.conv2_1_3x3_relu(conv2_1_3x3_bn)
        conv2_1_1x1_increase = self.conv2_1_1x1_increase(conv2_1_3x3_bnxx)
        conv2_1_1x1_increase_bn = self.conv2_1_1x1_increase_bn(conv2_1_1x1_increase)
        conv2_1_1x1_proj = self.conv2_1_1x1_proj(pool1_3x3_s2)
        conv2_1_1x1_proj_bn = self.conv2_1_1x1_proj_bn(conv2_1_1x1_proj)
        conv2_1 = torch.add(conv2_1_1x1_proj_bn, conv2_1_1x1_increase_bn)
        conv2_1x = self.conv2_1_relu(conv2_1)
        conv2_2_1x1_reduce = self.conv2_2_1x1_reduce(conv2_1x)
        conv2_2_1x1_reduce_bn = self.conv2_2_1x1_reduce_bn(conv2_2_1x1_reduce)
        conv2_2_1x1_reduce_bnxx = self.conv2_2_1x1_reduce_relu(conv2_2_1x1_reduce_bn)
        conv2_2_3x3 = self.conv2_2_3x3(conv2_2_1x1_reduce_bnxx)
        conv2_2_3x3_bn = self.conv2_2_3x3_bn(conv2_2_3x3)
        conv2_2_3x3_bnxx = self.conv2_2_3x3_relu(conv2_2_3x3_bn)
        conv2_2_1x1_increase = self.conv2_2_1x1_increase(conv2_2_3x3_bnxx)
        conv2_2_1x1_increase_bn = self.conv2_2_1x1_increase_bn(conv2_2_1x1_increase)
        conv2_2 = torch.add(conv2_1x, conv2_2_1x1_increase_bn)
        conv2_2x = self.conv2_2_relu(conv2_2)
        conv2_3_1x1_reduce = self.conv2_3_1x1_reduce(conv2_2x)
        conv2_3_1x1_reduce_bn = self.conv2_3_1x1_reduce_bn(conv2_3_1x1_reduce)
        conv2_3_1x1_reduce_bnxx = self.conv2_3_1x1_reduce_relu(conv2_3_1x1_reduce_bn)
        conv2_3_3x3 = self.conv2_3_3x3(conv2_3_1x1_reduce_bnxx)
        conv2_3_3x3_bn = self.conv2_3_3x3_bn(conv2_3_3x3)
        conv2_3_3x3_bnxx = self.conv2_3_3x3_relu(conv2_3_3x3_bn)
        conv2_3_1x1_increase = self.conv2_3_1x1_increase(conv2_3_3x3_bnxx)
        conv2_3_1x1_increase_bn = self.conv2_3_1x1_increase_bn(conv2_3_1x1_increase)
        conv2_3 = torch.add(conv2_2x, conv2_3_1x1_increase_bn)
        conv2_3x = self.conv2_3_relu(conv2_3)
        # conv3_x
        conv3_1_1x1_reduce = self.conv3_1_1x1_reduce(conv2_3x)
        conv3_1_1x1_reduce_bn = self.conv3_1_1x1_reduce_bn(conv3_1_1x1_reduce)
        conv3_1_1x1_reduce_bnxx = self.conv3_1_1x1_reduce_relu(conv3_1_1x1_reduce_bn)
        conv3_1_3x3 = self.conv3_1_3x3(conv3_1_1x1_reduce_bnxx)
        conv3_1_3x3_bn = self.conv3_1_3x3_bn(conv3_1_3x3)
        conv3_1_3x3_bnxx = self.conv3_1_3x3_relu(conv3_1_3x3_bn)
        conv3_1_1x1_increase = self.conv3_1_1x1_increase(conv3_1_3x3_bnxx)
        conv3_1_1x1_increase_bn = self.conv3_1_1x1_increase_bn(conv3_1_1x1_increase)
        conv3_1_1x1_proj = self.conv3_1_1x1_proj(conv2_3x)
        conv3_1_1x1_proj_bn = self.conv3_1_1x1_proj_bn(conv3_1_1x1_proj)
        conv3_1 = torch.add(conv3_1_1x1_proj_bn, conv3_1_1x1_increase_bn)
        conv3_1x = self.conv3_1_relu(conv3_1)
        conv3_2_1x1_reduce = self.conv3_2_1x1_reduce(conv3_1x)
        conv3_2_1x1_reduce_bn = self.conv3_2_1x1_reduce_bn(conv3_2_1x1_reduce)
        conv3_2_1x1_reduce_bnxx = self.conv3_2_1x1_reduce_relu(conv3_2_1x1_reduce_bn)
        conv3_2_3x3 = self.conv3_2_3x3(conv3_2_1x1_reduce_bnxx)
        conv3_2_3x3_bn = self.conv3_2_3x3_bn(conv3_2_3x3)
        conv3_2_3x3_bnxx = self.conv3_2_3x3_relu(conv3_2_3x3_bn)
        conv3_2_1x1_increase = self.conv3_2_1x1_increase(conv3_2_3x3_bnxx)
        conv3_2_1x1_increase_bn = self.conv3_2_1x1_increase_bn(conv3_2_1x1_increase)
        conv3_2 = torch.add(conv3_1x, conv3_2_1x1_increase_bn)
        conv3_2x = self.conv3_2_relu(conv3_2)
        conv3_3_1x1_reduce = self.conv3_3_1x1_reduce(conv3_2x)
        conv3_3_1x1_reduce_bn = self.conv3_3_1x1_reduce_bn(conv3_3_1x1_reduce)
        conv3_3_1x1_reduce_bnxx = self.conv3_3_1x1_reduce_relu(conv3_3_1x1_reduce_bn)
        conv3_3_3x3 = self.conv3_3_3x3(conv3_3_1x1_reduce_bnxx)
        conv3_3_3x3_bn = self.conv3_3_3x3_bn(conv3_3_3x3)
        conv3_3_3x3_bnxx = self.conv3_3_3x3_relu(conv3_3_3x3_bn)
        conv3_3_1x1_increase = self.conv3_3_1x1_increase(conv3_3_3x3_bnxx)
        conv3_3_1x1_increase_bn = self.conv3_3_1x1_increase_bn(conv3_3_1x1_increase)
        conv3_3 = torch.add(conv3_2x, conv3_3_1x1_increase_bn)
        conv3_3x = self.conv3_3_relu(conv3_3)
        conv3_4_1x1_reduce = self.conv3_4_1x1_reduce(conv3_3x)
        conv3_4_1x1_reduce_bn = self.conv3_4_1x1_reduce_bn(conv3_4_1x1_reduce)
        conv3_4_1x1_reduce_bnxx = self.conv3_4_1x1_reduce_relu(conv3_4_1x1_reduce_bn)
        conv3_4_3x3 = self.conv3_4_3x3(conv3_4_1x1_reduce_bnxx)
        conv3_4_3x3_bn = self.conv3_4_3x3_bn(conv3_4_3x3)
        conv3_4_3x3_bnxx = self.conv3_4_3x3_relu(conv3_4_3x3_bn)
        conv3_4_1x1_increase = self.conv3_4_1x1_increase(conv3_4_3x3_bnxx)
        conv3_4_1x1_increase_bn = self.conv3_4_1x1_increase_bn(conv3_4_1x1_increase)
        conv3_4 = torch.add(conv3_3x, conv3_4_1x1_increase_bn)
        conv3_4x = self.conv3_4_relu(conv3_4)
        # conv4_x
        conv4_1_1x1_reduce = self.conv4_1_1x1_reduce(conv3_4x)
        conv4_1_1x1_reduce_bn = self.conv4_1_1x1_reduce_bn(conv4_1_1x1_reduce)
        conv4_1_1x1_reduce_bnxx = self.conv4_1_1x1_reduce_relu(conv4_1_1x1_reduce_bn)
        conv4_1_3x3 = self.conv4_1_3x3(conv4_1_1x1_reduce_bnxx)
        conv4_1_3x3_bn = self.conv4_1_3x3_bn(conv4_1_3x3)
        conv4_1_3x3_bnxx = self.conv4_1_3x3_relu(conv4_1_3x3_bn)
        conv4_1_1x1_increase = self.conv4_1_1x1_increase(conv4_1_3x3_bnxx)
        conv4_1_1x1_increase_bn = self.conv4_1_1x1_increase_bn(conv4_1_1x1_increase)
        conv4_1_1x1_proj = self.conv4_1_1x1_proj(conv3_4x)
        conv4_1_1x1_proj_bn = self.conv4_1_1x1_proj_bn(conv4_1_1x1_proj)
        conv4_1 = torch.add(conv4_1_1x1_proj_bn, conv4_1_1x1_increase_bn)
        conv4_1x = self.conv4_1_relu(conv4_1)
        conv4_2_1x1_reduce = self.conv4_2_1x1_reduce(conv4_1x)
        conv4_2_1x1_reduce_bn = self.conv4_2_1x1_reduce_bn(conv4_2_1x1_reduce)
        conv4_2_1x1_reduce_bnxx = self.conv4_2_1x1_reduce_relu(conv4_2_1x1_reduce_bn)
        conv4_2_3x3 = self.conv4_2_3x3(conv4_2_1x1_reduce_bnxx)
        conv4_2_3x3_bn = self.conv4_2_3x3_bn(conv4_2_3x3)
        conv4_2_3x3_bnxx = self.conv4_2_3x3_relu(conv4_2_3x3_bn)
        conv4_2_1x1_increase = self.conv4_2_1x1_increase(conv4_2_3x3_bnxx)
        conv4_2_1x1_increase_bn = self.conv4_2_1x1_increase_bn(conv4_2_1x1_increase)
        conv4_2 = torch.add(conv4_1x, conv4_2_1x1_increase_bn)
        conv4_2x = self.conv4_2_relu(conv4_2)
        conv4_3_1x1_reduce = self.conv4_3_1x1_reduce(conv4_2x)
        conv4_3_1x1_reduce_bn = self.conv4_3_1x1_reduce_bn(conv4_3_1x1_reduce)
        conv4_3_1x1_reduce_bnxx = self.conv4_3_1x1_reduce_relu(conv4_3_1x1_reduce_bn)
        conv4_3_3x3 = self.conv4_3_3x3(conv4_3_1x1_reduce_bnxx)
        conv4_3_3x3_bn = self.conv4_3_3x3_bn(conv4_3_3x3)
        conv4_3_3x3_bnxx = self.conv4_3_3x3_relu(conv4_3_3x3_bn)
        conv4_3_1x1_increase = self.conv4_3_1x1_increase(conv4_3_3x3_bnxx)
        conv4_3_1x1_increase_bn = self.conv4_3_1x1_increase_bn(conv4_3_1x1_increase)
        conv4_3 = torch.add(conv4_2x, conv4_3_1x1_increase_bn)
        conv4_3x = self.conv4_3_relu(conv4_3)
        conv4_4_1x1_reduce = self.conv4_4_1x1_reduce(conv4_3x)
        conv4_4_1x1_reduce_bn = self.conv4_4_1x1_reduce_bn(conv4_4_1x1_reduce)
        conv4_4_1x1_reduce_bnxx = self.conv4_4_1x1_reduce_relu(conv4_4_1x1_reduce_bn)
        conv4_4_3x3 = self.conv4_4_3x3(conv4_4_1x1_reduce_bnxx)
        conv4_4_3x3_bn = self.conv4_4_3x3_bn(conv4_4_3x3)
        conv4_4_3x3_bnxx = self.conv4_4_3x3_relu(conv4_4_3x3_bn)
        conv4_4_1x1_increase = self.conv4_4_1x1_increase(conv4_4_3x3_bnxx)
        conv4_4_1x1_increase_bn = self.conv4_4_1x1_increase_bn(conv4_4_1x1_increase)
        conv4_4 = torch.add(conv4_3x, conv4_4_1x1_increase_bn)
        conv4_4x = self.conv4_4_relu(conv4_4)
        conv4_5_1x1_reduce = self.conv4_5_1x1_reduce(conv4_4x)
        conv4_5_1x1_reduce_bn = self.conv4_5_1x1_reduce_bn(conv4_5_1x1_reduce)
        conv4_5_1x1_reduce_bnxx = self.conv4_5_1x1_reduce_relu(conv4_5_1x1_reduce_bn)
        conv4_5_3x3 = self.conv4_5_3x3(conv4_5_1x1_reduce_bnxx)
        conv4_5_3x3_bn = self.conv4_5_3x3_bn(conv4_5_3x3)
        conv4_5_3x3_bnxx = self.conv4_5_3x3_relu(conv4_5_3x3_bn)
        conv4_5_1x1_increase = self.conv4_5_1x1_increase(conv4_5_3x3_bnxx)
        conv4_5_1x1_increase_bn = self.conv4_5_1x1_increase_bn(conv4_5_1x1_increase)
        conv4_5 = torch.add(conv4_4x, conv4_5_1x1_increase_bn)
        conv4_5x = self.conv4_5_relu(conv4_5)
        conv4_6_1x1_reduce = self.conv4_6_1x1_reduce(conv4_5x)
        conv4_6_1x1_reduce_bn = self.conv4_6_1x1_reduce_bn(conv4_6_1x1_reduce)
        conv4_6_1x1_reduce_bnxx = self.conv4_6_1x1_reduce_relu(conv4_6_1x1_reduce_bn)
        conv4_6_3x3 = self.conv4_6_3x3(conv4_6_1x1_reduce_bnxx)
        conv4_6_3x3_bn = self.conv4_6_3x3_bn(conv4_6_3x3)
        conv4_6_3x3_bnxx = self.conv4_6_3x3_relu(conv4_6_3x3_bn)
        conv4_6_1x1_increase = self.conv4_6_1x1_increase(conv4_6_3x3_bnxx)
        conv4_6_1x1_increase_bn = self.conv4_6_1x1_increase_bn(conv4_6_1x1_increase)
        conv4_6 = torch.add(conv4_5x, conv4_6_1x1_increase_bn)
        conv4_6x = self.conv4_6_relu(conv4_6)
        # conv5_x
        conv5_1_1x1_reduce = self.conv5_1_1x1_reduce(conv4_6x)
        conv5_1_1x1_reduce_bn = self.conv5_1_1x1_reduce_bn(conv5_1_1x1_reduce)
        conv5_1_1x1_reduce_bnxx = self.conv5_1_1x1_reduce_relu(conv5_1_1x1_reduce_bn)
        conv5_1_3x3 = self.conv5_1_3x3(conv5_1_1x1_reduce_bnxx)
        conv5_1_3x3_bn = self.conv5_1_3x3_bn(conv5_1_3x3)
        conv5_1_3x3_bnxx = self.conv5_1_3x3_relu(conv5_1_3x3_bn)
        conv5_1_1x1_increase = self.conv5_1_1x1_increase(conv5_1_3x3_bnxx)
        conv5_1_1x1_increase_bn = self.conv5_1_1x1_increase_bn(conv5_1_1x1_increase)
        conv5_1_1x1_proj = self.conv5_1_1x1_proj(conv4_6x)
        conv5_1_1x1_proj_bn = self.conv5_1_1x1_proj_bn(conv5_1_1x1_proj)
        conv5_1 = torch.add(conv5_1_1x1_proj_bn, conv5_1_1x1_increase_bn)
        conv5_1x = self.conv5_1_relu(conv5_1)
        conv5_2_1x1_reduce = self.conv5_2_1x1_reduce(conv5_1x)
        conv5_2_1x1_reduce_bn = self.conv5_2_1x1_reduce_bn(conv5_2_1x1_reduce)
        conv5_2_1x1_reduce_bnxx = self.conv5_2_1x1_reduce_relu(conv5_2_1x1_reduce_bn)
        conv5_2_3x3 = self.conv5_2_3x3(conv5_2_1x1_reduce_bnxx)
        conv5_2_3x3_bn = self.conv5_2_3x3_bn(conv5_2_3x3)
        conv5_2_3x3_bnxx = self.conv5_2_3x3_relu(conv5_2_3x3_bn)
        conv5_2_1x1_increase = self.conv5_2_1x1_increase(conv5_2_3x3_bnxx)
        conv5_2_1x1_increase_bn = self.conv5_2_1x1_increase_bn(conv5_2_1x1_increase)
        conv5_2 = torch.add(conv5_1x, conv5_2_1x1_increase_bn)
        conv5_2x = self.conv5_2_relu(conv5_2)
        conv5_3_1x1_reduce = self.conv5_3_1x1_reduce(conv5_2x)
        conv5_3_1x1_reduce_bn = self.conv5_3_1x1_reduce_bn(conv5_3_1x1_reduce)
        conv5_3_1x1_reduce_bnxx = self.conv5_3_1x1_reduce_relu(conv5_3_1x1_reduce_bn)
        conv5_3_3x3 = self.conv5_3_3x3(conv5_3_1x1_reduce_bnxx)
        conv5_3_3x3_bn = self.conv5_3_3x3_bn(conv5_3_3x3)
        conv5_3_3x3_bnxx = self.conv5_3_3x3_relu(conv5_3_3x3_bn)
        conv5_3_1x1_increase = self.conv5_3_1x1_increase(conv5_3_3x3_bnxx)
        conv5_3_1x1_increase_bn = self.conv5_3_1x1_increase_bn(conv5_3_1x1_increase)
        conv5_3 = torch.add(conv5_2x, conv5_3_1x1_increase_bn)
        conv5_3x = self.conv5_3_relu(conv5_3)
        # Head: pool, classify with the 1x1 conv, then flatten per sample.
        pool5_7x7_s1 = self.pool5_7x7_s1(conv5_3x)
        classifier_preflatten = self.classifier(pool5_7x7_s1)
        classifier = classifier_preflatten.view(classifier_preflatten.size(0), -1)
        return classifier, pool5_7x7_s1

def resnet50_scratch_dag(weights_path=None, **kwargs):
    """
    Build a Resnet50_scratch_dag and optionally load saved weights into it.

    Args:
        weights_path (str): If set, a state_dict is read from this path with
            ``torch.load`` and loaded into the model (strict key matching).
            If falsy, the model keeps its random initialisation.
        **kwargs: accepted for call-site compatibility; currently unused.

    Returns:
        Resnet50_scratch_dag: the constructed model (on CPU).
    """
    model = Resnet50_scratch_dag()
    if weights_path:
        # The model has just been built on the CPU, so load the checkpoint onto
        # the CPU too: this lets checkpoints saved from a GPU be read on
        # machines without CUDA.  Callers move the model to a device afterwards.
        # NOTE: torch.load unpickles the file -- only load weights from a
        # trusted source.
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model

In [4]:
class FacialExpressionModel:
    """
    Build, fine-tune and evaluate an image classifier for one experiment.

    Args:
        model_name (str): One of "resnet", "alexnet", "vgg" or "vggface".
        num_classes (int): Number of output classes of the replaced head.
        dataloaders (dict): Iterables of ``(inputs, labels)`` batches keyed by
            "train", "val" and "test".
        dataset_sizes (dict): Number of samples per phase, keyed by "train"
            and "val"; used to average the running loss / accuracy.
        lr (float): SGD learning rate.
        momentum (float): SGD momentum.
        feature_extract (bool): If True, freeze the backbone and only train
            the parameters that still require gradients (the new head).
        use_pretrained (bool): Passed to the torchvision model constructors.
        num_epochs (int): Number of epochs run by ``train`` (default 50).
        writer_training: Optional object with ``add_scalar(tag, value, step)``
            that receives the per-epoch training loss.
        writer_validation: Optional object with ``add_scalar(tag, value, step)``
            that receives the per-epoch validation loss and accuracy.
            When a writer is not given, a module-level variable of the same
            name is used if it exists; otherwise logging is skipped.
    """

    def __init__(self, model_name, num_classes, dataloaders, dataset_sizes, lr, momentum,
                 feature_extract=False, use_pretrained=True, num_epochs=50,
                 writer_training=None, writer_validation=None):

        self.model_name = model_name
        self.num_classes = num_classes
        self.feature_extract = feature_extract
        self.use_pretrained = use_pretrained
        self.learning_rate = lr
        self.momentum = momentum
        self.num_epochs = num_epochs
        self.writer_training = writer_training
        self.writer_validation = writer_validation
        self.model_ft = None
        self.input_size = None
        self.dataloaders = dataloaders
        self.dataset_sizes = dataset_sizes

    def train(self):
        """
        Run the whole experiment: build the model, fine-tune it, evaluate it.

        Returns:
            tuple: ``(model_ft, val_hist, accuracy)`` where ``val_hist`` is the
            per-epoch validation accuracy list from ``train_model`` and
            ``accuracy`` is the value returned by ``evaluate_model``, i.e. the
            pair ``(accuracy_percent, macro_f1)``.
        """
        # Initialise Model
        model_ft, _ = self.initialise_model()

        # Set Device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_ft.to(device)

        # Set Optimizer
        parameters = self.params_to_update(model_ft)
        optimizer = optim.SGD(parameters, lr=self.learning_rate, momentum=self.momentum)

        # Set Criterion
        criterion = nn.CrossEntropyLoss()

        # Train Model (train_model returns exactly two values)
        model_ft, val_hist = self.train_model(model_ft, device, self.dataloaders, self.dataset_sizes,
                                              criterion, optimizer, num_epochs=self.num_epochs)

        # TODO: progressive variant - call train_model once per training
        # dataloader (a few epochs each) and collect one val history per stage.

        # Evaluate Model on the loaders this experiment was configured with,
        # not on whatever global happens to be named `dataloaders`.
        accuracy = self.evaluate_model(model_ft, self.dataloaders, device)

        return model_ft, val_hist, accuracy

    def params_to_update(self, model_ft):
        """
        Select the parameters handed to the optimizer.

        Returns:
            list: Only the parameters with ``requires_grad`` set when
            ``feature_extract`` is True, otherwise every model parameter.
        """
        if self.feature_extract:
            return [param for param in model_ft.parameters() if param.requires_grad]
        # A list (rather than the raw generator) can be iterated more than once.
        return list(model_ft.parameters())

    def set_parameter_requires_grad(self, model, feature_extracting):
        """
        Freeze every parameter of ``model`` in place when ``feature_extracting``
        is truthy; otherwise leave the model untouched.

        Returns:
            None
        """
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False

    def initialise_model(self):
        """
        Build the architecture named by ``self.model_name`` and replace its
        classification head with one producing ``self.num_classes`` outputs.

        The result is also stored on ``self.model_ft`` / ``self.input_size``.

        Returns:
            tuple: ``(model_ft, input_size)``

        Raises:
            ValueError: If ``self.model_name`` is not a supported architecture.
        """
        if self.model_name == "resnet":
            model_ft = models.resnet50(pretrained=self.use_pretrained)
            self.set_parameter_requires_grad(model_ft, self.feature_extract)
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, self.num_classes)
            input_size = 224

        elif self.model_name == "alexnet":
            model_ft = models.alexnet(pretrained=self.use_pretrained)
            self.set_parameter_requires_grad(model_ft, self.feature_extract)
            num_ftrs = model_ft.classifier[6].in_features
            model_ft.classifier[6] = nn.Linear(num_ftrs, self.num_classes)
            input_size = 224

        elif self.model_name == "vgg":
            model_ft = models.vgg11_bn(pretrained=self.use_pretrained)
            self.set_parameter_requires_grad(model_ft, self.feature_extract)
            num_ftrs = model_ft.classifier[6].in_features
            model_ft.classifier[6] = nn.Linear(num_ftrs, self.num_classes)
            input_size = 224

        elif self.model_name == 'vggface':
            model_ft = resnet50_scratch_dag(weights_path="resnet50_scratch_dag.pth")
            self.set_parameter_requires_grad(model_ft, self.feature_extract)
            num_ftrs = model_ft.classifier.in_channels
            # The head must be assigned to `classifier` (the attribute the
            # forward pass calls); a misspelled attribute leaves the original
            # head in place. The existing head exposes `in_channels` and is
            # applied to the pooled feature map *before* flattening, so the
            # replacement has to be convolutional too.
            # NOTE(review): kernel_size=1 assumes the pooled map is 1x1 -
            # confirm against the Resnet50_scratch_dag definition.
            model_ft.classifier = nn.Conv2d(num_ftrs, self.num_classes, kernel_size=1)
            input_size = 224

        else:
            # Fail here with a clear message instead of returning None and
            # crashing later on `model_ft.to(device)`.
            raise ValueError("Unsupported model_name: {!r}".format(self.model_name))

        self.model_ft = model_ft
        self.input_size = input_size

        return model_ft, input_size

    @staticmethod
    def _logits(outputs):
        """Return the class-score tensor from a forward output that may be a tuple."""
        # The vggface network returns (scores, pooled_features); the
        # torchvision models return the scores tensor directly.
        if isinstance(outputs, (tuple, list)):
            return outputs[0]
        return outputs

    def _resolve_writers(self):
        """Return the ``(training, validation)`` scalar writers; either may be None."""
        train_writer = getattr(self, "writer_training", None)
        val_writer = getattr(self, "writer_validation", None)
        # Backward compatibility: notebooks that define module-level
        # `writer_training` / `writer_validation` keep logging to them.
        if train_writer is None:
            train_writer = globals().get("writer_training")
        if val_writer is None:
            val_writer = globals().get("writer_validation")
        return train_writer, val_writer

    def train_model(self, model, device, dataloaders, dataset_sizes, criterion, optimizer, num_epochs=25):
        """
        Train ``model`` for ``num_epochs`` epochs, validating after each one.

        Args:
            model: Network to optimise (already on ``device``).
            device: Device the batches are moved to.
            dataloaders (dict): Batch iterables keyed by "train" and "val".
            dataset_sizes (dict): Sample counts keyed by "train" and "val".
            criterion: Loss called as ``criterion(outputs, labels)``.
            optimizer: Optimizer stepping the model parameters.
            num_epochs (int): Number of epochs to run.

        Returns:
            tuple: ``(model, val_acc_history)`` - the model with the weights of
            the best validation epoch restored (the initial weights if no epoch
            scored above 0), and the list of per-epoch validation accuracies
            (0-dim double tensors).
        """
        # start time
        since = time.time()
        train_writer, val_writer = self._resolve_writers()

        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0
        val_acc_history = []

        # for each epoch
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # train and evaluate the current model at each epoch
            for phase in ["train", "val"]:
                if phase == "train":
                    model.train()  # set model to training mode
                else:
                    model.eval()  # set model to evaluate mode

                # keep track of loss / correct predictions; a tensor counter
                # keeps `.double()` valid even if the loader yields no batch
                running_loss = 0.0
                running_corrects = torch.zeros((), dtype=torch.long, device=device)

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward prop; track history only in train
                    with torch.set_grad_enabled(phase == "train"):
                        outputs = self._logits(model(inputs))  # predict
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if training phase
                        if phase == "train":
                            loss.backward()
                            optimizer.step()

                    # stats (loss is a batch mean, so weight it by batch size)
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

                # TODO: step a learning-rate scheduler here after the train phase.

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

                # Store Losses and Accuracy in Tensorboard
                if phase == "train":
                    if train_writer is not None:
                        train_writer.add_scalar("loss", epoch_loss, epoch)
                elif val_writer is not None:
                    val_writer.add_scalar("loss", epoch_loss, epoch)
                    val_writer.add_scalar("accuracy", epoch_acc, epoch)

                if phase == "val":
                    if epoch_acc > best_acc:
                        best_acc = epoch_acc
                        # Same variable that is restored below, so it is always defined.
                        best_model_wts = copy.deepcopy(model.state_dict())
                    val_acc_history.append(epoch_acc)

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:.4f}'.format(best_acc))

        model.load_state_dict(best_model_wts)

        return model, val_acc_history

    def evaluate_model(self, model_ft, dataloaders, device):
        """
        Score ``model_ft`` on ``dataloaders['test']``.

        The model is switched to evaluation mode (and left in it).

        Returns:
            tuple: ``(accuracy, f1)`` - accuracy in percent and the
            macro-averaged (unweighted) F1 score.

        Raises:
            ValueError: If the test loader yields no samples.
        """
        # Evaluation mode so layers such as dropout / batch-norm behave
        # deterministically even when this is called without train_model.
        model_ft.eval()

        correct = 0
        total = 0

        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in dataloaders['test']:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = self._logits(model_ft(inputs))
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_preds.append(predicted.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

        if total == 0:
            raise ValueError("dataloaders['test'] yielded no samples to evaluate")

        all_preds = np.concatenate(all_preds, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        # Unweighted Average F1 Score
        f1 = f1_score(all_labels, all_preds, average='macro')

        # Accuracy
        accuracy = correct / total * 100

        return accuracy, f1

In [None]:
# Run experiment
# Relies on `dataloaders`, `dataset_sizes`, `SummaryWriter` and
# `FacialExpressionModel` having been defined by earlier cells.
writer_training = SummaryWriter("./runs/ExperimentName/Training")
writer_validation = SummaryWriter("./runs/ExperimentName/Validation")

try:
    # TODO: expose the number of epochs as an experiment setting as well
    new_experiment = FacialExpressionModel("resnet", 2, dataloaders, dataset_sizes, lr=0.001, momentum=0.9,
                                           feature_extract=False, use_pretrained=True)
    # `accuracy` receives whatever `evaluate_model` returns, i.e. the pair
    # (accuracy in percent, macro F1).
    model_ft, val_hist, accuracy = new_experiment.train()
finally:
    # Flush and close BOTH writers, even if training raised, so the event
    # files are complete on disk.
    writer_training.flush()
    writer_validation.flush()
    writer_training.close()
    writer_validation.close()

In [None]:
# Load the extension and start TensorBoard for kaggle

%load_ext tensorboard.notebook
%tensorboard --logdir logs