In [1]:
import os, sys; sys.path.append(os.path.dirname(os.getcwd()))

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

import util.RAVDESS_dataset_util as Rd
import multimodal_vae
from train_mvae import build_model, train

from config_args import ConfigModelArgs, ConfigTrainArgs

In [None]:
# Model and training hyper-parameters, each constructed with no arguments from config_args.
cfg_model = ConfigModelArgs()
cfg_train = ConfigTrainArgs()

In [None]:
# Load the RAVDESS face/emotion dataset; every image is rescaled and centre-cropped
# to cfg_model.img_size before being converted to a tensor.
face_dataset = Rd.FaceEmotionDataset(
    root_dir=cfg_model.dataset_path,
    transform=transforms.Compose([
        Rd.Rescale(cfg_model.img_size),
        Rd.CenterCrop(cfg_model.img_size),
        Rd.ToTensor(),
    ]),
)

# 90/10 train/test split. Multiply before dividing so the training set is exactly
# floor(90%) of the data (`len // 100 * 90` under-counted by up to 89 samples).
trainingset_len = len(face_dataset) * 90 // 100
testset_len = len(face_dataset) - trainingset_len

# Seed the split with the training seed so the held-out test set is identical on every run.
split_kwargs = {}
if cfg_train.seed is not None:
    split_kwargs['generator'] = torch.Generator().manual_seed(cfg_train.seed)
training_dataset, testing_dataset = torch.utils.data.random_split(
    face_dataset,
    [trainingset_len, testset_len],
    **split_kwargs,
)

dataset_loader = DataLoader(training_dataset, batch_size=cfg_train.batch_size,
                            shuffle=True, num_workers=cfg_train.num_workers)

# Evaluation does not need shuffling; a fixed order keeps test-set results repeatable.
testset_loader = DataLoader(testing_dataset, batch_size=cfg_train.batch_size,
                            shuffle=False, num_workers=cfg_train.num_workers)

print('training set size: ', trainingset_len, '\ntest set size: ', testset_len)

In [None]:
# Build the multimodal VAE from the model configuration and cast its parameters to
# float64 via .double() (presumably to match the dtype of the dataset tensors — TODO confirm).
# use_cuda comes from the training config so build_model and train() receive the same flag.
model: torch.nn.Module = build_model(
    cat_dim=cfg_model.cat_dim,
    latent_space_dim=cfg_model.z_dim,
    hidden_dim=cfg_model.hidden_dim,
    loss_weights=cfg_model.loss_weights,
    expert_type=cfg_model.expert_type,
    use_cuda=cfg_train.use_cuda,
).double()

In [None]:
# Train the MVAE with every hyper-parameter taken from the training config.
# The result is indexed below by 'multimodal_loss', 'emotion_loss' and 'face_loss'.
train_kwargs = dict(
    learning_rate=cfg_train.learning_rate,
    optim_betas=cfg_train.optim_betas,
    num_epochs=cfg_train.num_epochs,
    batch_size=cfg_train.batch_size,
    checkpoint_every=cfg_train.checkpoint_every,
    checkpoint_path=cfg_train.checkpoint_path,
    save_model=cfg_train.save_model,
    seed=cfg_train.seed,
    use_cuda=cfg_train.use_cuda,
    cfg=cfg_train,
)
training_losses = train(mvae_model=model, dataset_loader=dataset_loader, **train_kwargs)

In [None]:
# Total loss for each of the three loss series. The first recorded value is skipped,
# presumably so an initial spike does not dominate the y-axis — TODO confirm.
fig, ax = plt.subplots()
for loss_key, line_colour, series_label in (
    ('multimodal_loss', 'red', 'multimodal'),
    ('emotion_loss', 'green', 'emotion'),
    ('face_loss', 'blue', 'face'),
):
    ax.plot(training_losses[loss_key].total_loss[1:], color=line_colour, label=series_label)
ax.set_title("Total loss")
ax.set_ylabel("Loss ")
ax.set_xlabel("Epochs ")
ax.legend(loc="upper right")

In [None]:
# Reconstruction loss for each of the three loss series (all recorded values shown).
fig, ax = plt.subplots()
ax.set_title("Reconstruction loss")
ax.set_ylabel("Loss ")
ax.set_xlabel("Epochs ")
for loss_key, line_colour, series_label in (
    ('multimodal_loss', 'red', 'multimodal'),
    ('emotion_loss', 'green', 'emotion'),
    ('face_loss', 'blue', 'face'),
):
    series = training_losses[loss_key].reconstruction_loss
    ax.plot(series, color=line_colour, label=series_label)
ax.legend(loc="upper right")

In [None]:
# KL-divergence term for each of the three loss series; the first recorded value is
# skipped, as in the total-loss plot.
series_styles = {
    'multimodal_loss': ('red', 'multimodal'),
    'emotion_loss': ('green', 'emotion'),
    'face_loss': ('blue', 'face'),
}
fig, ax = plt.subplots()
ax.set_title("KLD loss")
ax.set_ylabel("Loss ")
ax.set_xlabel("Epochs ")
for loss_key, (line_colour, series_label) in series_styles.items():
    ax.plot(training_losses[loss_key].kld_loss[1:], color=line_colour, label=series_label)
ax.legend(loc="upper right")

In [None]:
# Face-reconstruction component of the loss for each of the three loss series.
fig, ax = plt.subplots()
for loss_key, line_colour, series_label in (
    ('multimodal_loss', 'red', 'multimodal'),
    ('emotion_loss', 'green', 'emotion'),
    ('face_loss', 'blue', 'face'),
):
    ax.plot(training_losses[loss_key].faces_reconstruction_loss,
            color=line_colour, label=series_label)
ax.set_title("Face reconstruction loss")
ax.set_ylabel("Loss")
ax.set_xlabel("Epochs")
ax.legend(loc="upper right")

In [None]:
# Emotion-reconstruction component of the loss for each of the three loss series.
fig, ax = plt.subplots()
for loss_key, line_colour, series_label in (
    ('multimodal_loss', 'red', 'multimodal'),
    ('emotion_loss', 'green', 'emotion'),
    ('face_loss', 'blue', 'face'),
):
    ax.plot(training_losses[loss_key].emotions_reconstruction_loss,
            color=line_colour, label=series_label)
ax.set_title("Emotion reconstruction loss")
ax.set_ylabel("Loss ")
ax.set_xlabel("Epochs ")
ax.legend(loc="upper right")

In [None]:
def _to_uint8_hwc(image_tensor, img_size):
    """Convert one image tensor into an (img_size, img_size, 3) uint8 array for imshow.

    The tensor is reshaped to (3, img_size, img_size), scaled by 255 and clipped to
    [0, 255] before the uint8 cast, so out-of-range values saturate instead of wrapping.
    """
    pixels = image_tensor.detach().cpu().reshape(3, img_size, img_size).numpy()
    return np.clip(pixels * 255.0, 0, 255).astype(np.uint8).transpose((1, 2, 0))


def testBatch(model, dataset_loader, img_size=64, use_cuda=True, num_samples=4):
    """Show input faces next to their reconstructions for the first batch of a loader.

    Draws a two-row image grid (top: inputs, bottom: reconstructions) for up to
    ``num_samples`` items. It then prints the true emotion categories and the
    argmax of the reconstructed emotions, both mapped through ``Rd.emocat``.

    Args:
        model: multimodal VAE called as ``model(faces=..., emotions=...)``. The first
            two outputs are used as reconstructed images and emotion scores.
        dataset_loader: loader yielding dicts with ``'image'`` and ``'cat'`` entries.
        img_size: height/width used to reshape each image for display.
        use_cuda: move the batch to the GPU before the forward pass.
        num_samples: maximum number of items to show (capped at the batch size).

    Raises:
        ValueError: if there is nothing to display (``num_samples < 1``).
    """
    sample = next(iter(dataset_loader))
    images = sample['image']
    labels = sample['cat']

    if use_cuda:
        images = images.cuda()
        labels = labels.cuda()

    # Cap at the batch size so a short (e.g. final) batch cannot cause an IndexError.
    num_shown = min(num_samples, images.shape[0])
    if num_shown < 1:
        raise ValueError("num_samples must be >= 1 and the batch must be non-empty")

    # Inference only: no autograd graph, and eval mode for layers such as dropout or
    # batch-norm (if the model has any). The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            reconstructed_images, reconstructed_emotions, _, _ = model(faces=images, emotions=labels)
    finally:
        model.train(was_training)

    input_row = np.concatenate(
        [_to_uint8_hwc(images[idx], img_size) for idx in range(num_shown)], axis=1)
    reconstructed_row = np.concatenate(
        [_to_uint8_hwc(reconstructed_images[idx], img_size) for idx in range(num_shown)], axis=1)

    plt.figure(figsize=(40, 10))
    plt.imshow(np.concatenate((input_row, reconstructed_row), axis=0))

    print([Rd.emocat[label.item()] for label in labels[:num_shown]])
    print([Rd.emocat[emo.item()] for emo in torch.argmax(reconstructed_emotions, 1)[:num_shown]])

In [None]:
# Inputs vs. reconstructions for one test batch, using the configured image size and device flag
# (the function's default img_size=64 does not necessarily match cfg_model.img_size).
testBatch(model, testset_loader, img_size=cfg_model.img_size, use_cuda=cfg_train.use_cuda)

In [None]:
from tqdm.auto import tqdm


def emotion_accuracy(model, dataset_loader, use_cuda=True):
    """Fraction of samples whose emotion category is recovered by the model.

    The model is fed only the emotion labels (``faces=None``). The argmax of its
    reconstructed emotion scores is compared with the true label.
    NOTE(review): this measures emotion→emotion reconstruction, not face→emotion
    classification — confirm that is the intended metric.

    Args:
        model: multimodal VAE called as ``model(faces=None, emotions=labels)``. Its
            second output holds the per-category emotion scores.
        dataset_loader: loader yielding dicts with a ``'cat'`` entry of labels.
        use_cuda: move labels to the GPU before the forward pass.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the loader yields no samples (accuracy would be undefined).
    """
    match = 0
    total = 0

    # Inference only; restore the caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for sample in tqdm(dataset_loader):
                labels = sample['cat']
                if use_cuda:
                    labels = labels.cuda()

                _, reconstructed_emotions, _, _ = model(faces=None, emotions=labels)
                emotion_cat = torch.argmax(reconstructed_emotions, 1)

                # Flatten labels so a (B, 1) label tensor does not broadcast to (B, B).
                match += (emotion_cat == labels.reshape(-1)).sum().item()
                total += len(labels)
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("dataset_loader yielded no samples; accuracy is undefined")
    return match / total

In [None]:
# Emotion-reconstruction accuracy on the held-out test split.
test_emotion_accuracy = emotion_accuracy(model, testset_loader)
print(test_emotion_accuracy)

In [None]:
# Persist the trained weights and the loss history. Off by default; set to True to write the files.
save_model = False

MODEL_SAVE_PATH = "ravdess_mvae_pretrained.pt"
STATS_SAVE_PATH = "ravdess_mvae_pretrained_stats.pt"

if save_model:
    # This notebook has no Hydra config or logger, so the files go to the working
    # directory and progress is reported with print().
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Saved model to '{MODEL_SAVE_PATH}'.")

    torch.save(training_losses, STATS_SAVE_PATH)
    print(f"Saved training stats to '{STATS_SAVE_PATH}'.")