In [None]:
import tensorflow as tf
import numpy as np
import os
import pickle
from tensorflow.python.keras import layers, Sequential, losses, metrics, optimizers, callbacks, models
from tensorflow.python.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.applications import vgg16, resnet
from tensorflow.python.keras.optimizer_v2 import adam
import gc
import tensorflow_addons as tfa
import random
import math

In [None]:
# Names of the emotion classes; the model's softmax layer has one unit per entry.
emotion_labels = [
    'neutral',
    'happiness',
    'surprise',
    'sadness',
    'anger',
    'disgust',
    'fear',
    'contempt',
]
# Number of classes, derived from the label list so the two cannot drift apart.
emotions_count = len(emotion_labels)

# Image side lengths in pixels (create_model builds a 48x48x3 network input).
image_height = image_width = 48

In [None]:
def Shuffle(img, emo):
    """Shuffle two parallel sequences with one shared random permutation.

    Args:
        img: Iterable of samples.
        emo: Iterable of labels, paired with ``img`` by position.

    Returns:
        A pair of tuples ``(img, emo)`` in the new order. Every sample keeps
        the label it was paired with. If the inputs differ in length, the
        surplus items of the longer one are dropped (``zip`` semantics).
        Two empty tuples are returned when there is nothing to shuffle.
    """
    pairs = list(zip(img, emo))
    if not pairs:
        # zip(*[]) yields nothing, so unpacking it into two names would raise
        # ValueError; return an explicit empty result instead.
        return (), ()
    random.shuffle(pairs)
    img, emo = zip(*pairs)
    return img, emo

In [None]:
def load_data(i):
    """Load the dataset from disk and build one augmented training set.

    Reads ``./dataset/images.npy`` and ``./dataset/emotions_multi.npy``,
    rescales the pixels with ``x / 127.5 - 1``, expands them to three channels
    and splits both arrays at sample 31858 (28317 + 3541): everything before
    the split is training data, everything after it is test data.

    Args:
        i: Augmentation selector.
            0 - append a left/right flipped copy of the training images;
            1 - append copies rotated by +20 and -20 degrees;
            2 - append copies with brightness shifted by +0.3 and -0.3;
            3 - append copies with contrast factors 1.25 and 0.75;
            any other value leaves the training set unaugmented.

    Returns:
        ``(training_images, training_emotions, test_images, test_emotions)``.
        The augmented copies are concatenated after the originals and the
        labels are repeated once per copy, so images and labels stay aligned.
    """
    pixels = tf.convert_to_tensor(np.load("./dataset/images.npy"))
    pixels = layers.Rescaling(1./127.5, offset=-1)(pixels)
    # assumes the stored images carry a trailing channel axis of size 1 — TODO confirm
    pixels = tf.image.grayscale_to_rgb(pixels)
    labels = tf.convert_to_tensor(np.load("./dataset/emotions_multi.npy"))

    # The first 28317 + 3541 samples form the training set, the rest the test set.
    split = 28317 + 3541
    training_images, test_images = pixels[:split], pixels[split:]
    training_emotions, test_emotions = labels[:split], labels[split:]

    # Build the extra copies requested by the selector (None means no augmentation).
    variants = None
    if i == 0:
        variants = [tf.image.flip_left_right(training_images)]
    elif i == 1:
        variants = [
            tfa.image.rotate(images=training_images, angles=degrees*math.pi/180,
                             fill_mode='constant', fill_value=0)
            for degrees in (20, -20)
        ]
    elif i == 2:
        variants = [
            tf.image.adjust_brightness(training_images, delta)
            for delta in (0.3, -0.3)
        ]
    elif i == 3:
        variants = [
            tf.image.adjust_contrast(training_images, factor)
            for factor in (1.25, 0.75)
        ]

    if variants is not None:
        # Originals come first; the labels are tiled once per image copy.
        copies = len(variants) + 1
        training_images = tf.concat([training_images] + variants, 0)
        training_emotions = tf.concat([training_emotions] * copies, 0)

    print("aug index", i)
    print("training shape",training_images.shape, training_emotions.shape)
    return training_images, training_emotions, test_images, test_emotions

In [None]:
# model_acc converts tensors with .numpy(), which only works in eager mode.
tf.config.run_functions_eagerly(True)
def model_acc(y_true, y_pred):
    """Fraction of samples whose predicted class carries the top label value.

    A sample counts as correct when the label value at the arg-max of the
    prediction equals the largest label value of that sample, so with tied
    top labels any of the tied classes is accepted.

    Args:
        y_true: Label tensor indexed as (sample, class).
        y_pred: Prediction tensor with the same layout as ``y_true``.

    Returns:
        Python float ``correct / batch_size``.

    Raises:
        ZeroDivisionError: If the batch contains no samples.
    """
    size = y_true.shape[0]
    # Score the whole batch at once rather than looping over samples in Python
    # with several host round-trips per sample.
    predicted_class = tf.argmax(y_pred, axis=-1)
    label_at_prediction = tf.gather(y_true, predicted_class, batch_dims=1)
    top_label = tf.reduce_max(y_true, axis=-1)
    hits = tf.equal(label_at_prediction, top_label)
    # Single device-to-host sync; dividing Python ints keeps the float result
    # identical to counting the hits one by one.
    acc = int(tf.reduce_sum(tf.cast(hits, tf.int32)).numpy())
    return acc/size

In [None]:
def create_model():
    """Build the emotion classifier on top of an ImageNet-pretrained VGG16.

    The VGG16 convolutional base (without its classification head) is left
    fully trainable and followed by global average pooling, two 4096-unit
    ReLU layers and a softmax over ``emotions_count`` classes.

    Returns:
        An uncompiled ``Sequential`` model taking
        ``(image_height, image_width, 3)`` inputs.
    """
    # Use the notebook-level size constants instead of repeating the literal 48s.
    backbone = vgg16.VGG16(include_top=False,
                           weights="imagenet",
                           input_shape=(image_height, image_width, 3))
    backbone.trainable = True
    return Sequential([
        backbone,
        layers.GlobalAveragePooling2D(),
        layers.Dense(4096, activation='relu'),
        layers.Dense(4096, activation='relu'),
        layers.Dense(emotions_count, activation='softmax'),
    ])

In [None]:
def train(learning_rate, loss, num_epochs, batch_size, index):
    """Train a fresh model on one augmented variant of the dataset.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        loss: Loss passed to ``compile``.
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size.
        index: Augmentation selector forwarded to ``load_data``.

    Returns:
        The object returned by ``fit``; its ``history`` attribute holds the
        per-epoch loss and ``model_acc`` values. The test split serves as the
        validation data.
    """
    net = create_model()
    net.compile(
        optimizer=adam.Adam(learning_rate=learning_rate),
        loss=loss,
        metrics=[model_acc],
    )

    x_train, y_train, x_test, y_test = load_data(index)
    history = net.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(x_test, y_test),
        shuffle=True,
    )

    # Drop the model and the data before returning so they can be reclaimed
    # ahead of the next run.
    del net, x_train, y_train, x_test, y_test
    gc.collect()
    return history

In [None]:
# Make sure the output folder for the pickled training histories exists.
os.makedirs('./history/', exist_ok=True)

# Hyper-parameters shared by every augmentation run.
learning_rate = 1e-4
num_epochs = 40
loss = losses.MeanSquaredError()
batch_size = 32

# Output file per augmentation index, in the order load_data numbers them
# (0 flip, 1 rotate, 2 brightness, 3 contrast).
history_files = (
    './history/aug_flip.txt',
    './history/aug_rotate.txt',
    './history/aug_brightness.txt',
    './history/aug_contrast.txt',
)

for aug, history_save_path in enumerate(history_files):
    history = train(learning_rate, loss, num_epochs, batch_size, aug)
    # The history dict is pickled, so the file is binary despite its .txt name.
    with open(history_save_path, 'wb') as file_pi:
        pickle.dump(history.history, file_pi)
    # Release the finished run before the next model is built.
    del history
    gc.collect()