# In [None]:
import sys
import os
import math
import shutil
import random
import tempfile
import unittest
import traceback
import torch
import torch.utils.data
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from tensorboardX import SummaryWriter
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import copy
from sklearn.decomposition import PCA
from sklearn import metrics
import pandas as pd
import argparse
import gc
import numpy as np
import time
from PIL import Image
from convlstm import *
# ---- Runtime / experiment configuration shared by the rest of the script ----
use_cuda = True
GPU_NUM = 0
# Pick the requested GPU when CUDA is present, otherwise a CPU device object.
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # torch.cuda.set_device() only accepts CUDA devices, so it (and the
    # current-device query) must be skipped when the CPU fallback was chosen.
    torch.cuda.set_device(device) # change allocation of current GPU
    print ('Current cuda device ', torch.cuda.current_device()) # check
else:
    print ('CUDA is not available, using device ', device)

# Name of the fold directory held out for testing.
fold = 'fold_0'
# Number of frames per clip; reused below as `num_segments`.
frame_len = 79

# Free-form labels describing this experiment.
model_name = 'FacialCueNet(AU,sym,gaze,ME)'
train_dataset_name = 'CourtroomDB'
test_dataset_name = 'CourtroomDB(fold_0)'

# val_fold_dir is a str naming the fold directory to hold out, e.g. 'fold_0'
def fold_to_trainset(fold_path, val_fold_dir):
    """Collect the feature sub-directories of every fold except the held-out one.

    Args:
        fold_path: Directory whose entries are the fold directories.
        val_fold_dir: Name (str) of the fold directory to exclude, e.g. 'fold_0'.

    Returns:
        A 5-tuple of lists ``(extracted, au, symmetry, gaze, ME)``. Each list
        holds ``'<fold>/<feature_dir>'`` strings relative to ``fold_path``,
        ordered by sorted fold name.
    """
    # Recognised feature directory names, in the order they are returned.
    feature_dirs = (
        'DeceptionDB_feature',
        'ActionUnit(FPS15)',
        'NormalizedCrossCovariance(FPS15)',
        'gaze_features(FPS15)',
        'MicroExpression(FPS15)',
    )
    collected = {name: [] for name in feature_dirs}

    for fold_name in sorted(os.listdir(fold_path)):
        if fold_name == val_fold_dir:
            continue
        fold_dir = os.path.join(fold_path, fold_name)
        # Stray regular files next to the fold directories would make
        # os.listdir() raise NotADirectoryError, so skip them.
        if not os.path.isdir(fold_dir):
            continue
        for feature in sorted(os.listdir(fold_dir)):
            if feature in collected:
                # Callers concatenate these with '/', so keep that separator.
                collected[feature].append(fold_name + '/' + feature)

    return tuple(collected[name] for name in feature_dirs)


def data_to_dict_test(data_name):
    """List the files of the held-out fold and derive a label for each video.

    The fold is taken from the module-level ``fold`` variable. The label is
    read from the second ``'_'``-separated token of each file name:
    ``'truth'`` maps to 1 and ``'lie'`` maps to 0.

    Args:
        data_name: Unused; kept so existing callers keep working.

    Returns:
        An 11-tuple ``(path, class_dict, video_name_list, au_path,
        au_file_list, symmetry_path, symmetry_file_list, gaze_path,
        gaze_file_list, ME_path, ME_file_list)``. Every ``*_path`` is the
        directory that was listed and every ``*_list`` is its sorted content.
        ``class_dict`` maps video file name to label.

    Raises:
        ValueError: If a video file name carries neither the 'truth' nor the
            'lie' tag. Previously this only printed a message and then reused
            the previous video's label (or failed on an unbound variable).
    """
    fold_root = '/Real-life_Deception_Detection_2016/pruned_10fold/{0}'.format(fold)
    label_by_tag = {'truth': 1, 'lie': 0}

    path = fold_root + '/DeceptionDB_feature'
    video_name_list = []
    class_dict = {}
    for video in sorted(os.listdir(path)):
        tokens = video.split('_')
        tag = tokens[1] if len(tokens) > 1 else None
        if tag not in label_by_tag:
            raise ValueError('Wrong video label was selected ! ({0})'.format(video))
        video_name_list.append(video)
        class_dict[video] = label_by_tag[tag]

    # The remaining modalities are plain sorted listings. The dataset pairs
    # them with the videos by position, so the sort order matters.
    au_path = fold_root + '/ActionUnit(FPS15)'
    au_file_list = sorted(os.listdir(au_path))

    symmetry_path = fold_root + '/NormalizedCrossCovariance(FPS15)'
    symmetry_file_list = sorted(os.listdir(symmetry_path))

    gaze_path = fold_root + '/gaze_features(FPS15)'
    gaze_file_list = sorted(os.listdir(gaze_path))

    ME_path = fold_root + '/MicroExpression(FPS15)'
    ME_file_list = sorted(os.listdir(ME_path))

    return path, class_dict, video_name_list, au_path, au_file_list, symmetry_path, symmetry_file_list, gaze_path, gaze_file_list, ME_path, ME_file_list

def data_to_dict_train(data_name):
    """List the files of all training folds and derive a label for each video.

    Training folds are every fold under the dataset root except the one named
    by the module-level ``fold`` variable (resolved via ``fold_to_trainset``).
    The label is read from the second ``'_'``-separated token of each file
    name: ``'truth'`` maps to 1 and ``'lie'`` maps to 0.

    Args:
        data_name: Unused; kept so existing callers keep working.

    Returns:
        An 11-tuple ``(path, class_dict, video_name_list, au_path,
        au_file_list, symmetry_path, symmetry_file_list, gaze_path,
        gaze_file_list, ME_path, ME_file_list)``. Every ``*_path`` equals the
        dataset root, and every list entry is a
        ``'<fold>/<feature_dir>/<file>'`` string relative to that root.
        ``class_dict`` is keyed by the same relative video paths.

    Raises:
        ValueError: If a video file name carries neither the 'truth' nor the
            'lie' tag. Previously this only printed a message and then reused
            the previous video's label (or failed on an unbound variable).
    """
    path = '/Real-life_Deception_Detection_2016/pruned_10fold'
    label_by_tag = {'truth': 1, 'lie': 0}
    video_dirs, au_dirs, symmetry_dirs, gaze_dirs, ME_dirs = fold_to_trainset(path, fold)

    video_name_list = []
    class_dict = {}
    for v_path in video_dirs:
        for video in sorted(os.listdir(path+'/'+v_path)):
            tokens = video.split('/')[-1].split('_')
            tag = tokens[1] if len(tokens) > 1 else None
            if tag not in label_by_tag:
                raise ValueError('Wrong video label was selected ! ({0})'.format(v_path+'/'+video))
            relative_name = v_path+'/'+video
            video_name_list.append(relative_name)
            class_dict[relative_name] = label_by_tag[tag]

    def _list_relative(feature_dirs):
        # Sorted files of each feature directory, prefixed with that directory.
        return [
            feature_dir+'/'+file_name
            for feature_dir in feature_dirs
            for file_name in sorted(os.listdir(path+'/'+feature_dir))
        ]

    au_file_list = _list_relative(au_dirs)
    symmetry_file_list = _list_relative(symmetry_dirs)
    gaze_file_list = _list_relative(gaze_dirs)
    ME_file_list = _list_relative(ME_dirs)

    # All file lists are relative to the dataset root, so every modality
    # reports that root as its base directory.
    au_path = symmetry_path = gaze_path = ME_path = path

    return path, class_dict, video_name_list, au_path, au_file_list, symmetry_path, symmetry_file_list, gaze_path, gaze_file_list, ME_path, ME_file_list

class DeceptionDB(Dataset):
    """Dataset yielding one video's deep feature, label and hand-crafted cue vector.

    Each modality is addressed by a base directory (a single path, or a list
    holding one path per sample) plus a list of file names. Sample ``idx``
    uses entry ``idx`` of every list, so the lists must be aligned by position.
    """

    def __init__(self, category_dict, data_dir, video_names, actionunit_dirs, actionunit_files, symmetry_dirs, symmetry_files, gaze_dirs, gaze_files, ME_dirs, ME_files, transform=None):
        # `transform` is accepted for interface compatibility but is not used.
        self.category_dict = category_dict  # video name -> class label
        self.data_dir = data_dir  # base dir(s) of the per-video .npy features
        self.video_names = video_names  # feature file names, one per sample
        self.AU_dir, self.AU_file_names = actionunit_dirs, actionunit_files
        self.symmetry_dir, self.symmetry_file_names = symmetry_dirs, symmetry_files
        self.gaze_dir, self.gaze_file_names = gaze_dirs, gaze_files
        self.ME_dir, self.ME_file_names = ME_dirs, ME_files

    def __len__(self):
        return len(self.video_names)

    @staticmethod
    def _locate(base, file_names, idx):
        """Join the base directory for sample `idx` with its file name."""
        directory = base[idx] if isinstance(base, list) else base
        return os.path.join(directory, file_names[idx])

    @staticmethod
    def _csv_values(csv_path):
        """Read a header-less CSV and return its contents as a NumPy array."""
        return pd.read_csv(csv_path, header=None, index_col=False).values

    def __getitem__(self, idx):
        """Return ``{'feature', 'label', 'AUs'}`` for sample `idx`.

        'AUs' is the float32 concatenation of the action-unit, symmetry, gaze
        and micro-expression values, in that order.
        """
        name = self.video_names[idx]

        # Deep feature stored as .npy, and the label wrapped as a 1-element array.
        feature = np.load(self._locate(self.data_dir, self.video_names, idx))
        label = np.expand_dims(int(self.category_dict[name]), axis=0)  # shape(1,n)

        # Action units: second column of the selected rows.
        # Row meaning -> 0:AU15, 1:AU17, 2:AU20, 3:AU25, 4:AU45 (row 3 is skipped).
        au_rows = [0, 1, 2, 4]
        au_table = self._csv_values(self._locate(self.AU_dir, self.AU_file_names, idx))
        au_part = np.array(au_table[au_rows, 1], dtype=np.float32)

        # Symmetry: the first row of the file.
        symmetry_table = self._csv_values(self._locate(self.symmetry_dir, self.symmetry_file_names, idx))
        symmetry_part = np.array(symmetry_table[0], dtype=np.float32)

        # Gaze: drop the first row and first column, flatten the rest to 36 values.
        gaze_table = self._csv_values(self._locate(self.gaze_dir, self.gaze_file_names, idx))
        gaze_part = np.array(np.reshape(gaze_table[1:, 1:], (36)), dtype=np.float32)

        # Micro-expression: second column, skipping the first row.
        me_table = self._csv_values(self._locate(self.ME_dir, self.ME_file_names, idx))
        me_part = np.array(me_table[1:, 1], dtype=np.float32)

        cues = np.concatenate((au_part, symmetry_part, gaze_part, me_part), axis=0)

        return {'feature': feature, 'label': label, 'AUs': cues}


############### Select Input ###################
# Resolve the file listings for the held-out fold first, then for the
# remaining (training) folds. Both helpers return the same 11-field layout.
data_name = 'Real life trial dataset'
(
    data_path_dir_test,
    label_dict_test,
    video_name_list_test,
    au_path_test,
    au_file_list_test,
    symmetry_path_test,
    symmetry_file_list_test,
    gaze_path_test,
    gaze_file_list_test,
    ME_path_test,
    ME_file_list_test,
) = data_to_dict_test(data_name)
(
    data_path_dir_train,
    label_dict_train,
    video_name_list_train,
    au_path_train,
    au_file_list_train,
    symmetry_path_train,
    symmetry_file_list_train,
    gaze_path_train,
    gaze_file_list_train,
    ME_path_train,
    ME_file_list_train,
) = data_to_dict_train(data_name)
################################################

############## Load Data ######################
train_dataset = DeceptionDB(
    label_dict_train,
    data_path_dir_train,
    video_name_list_train,
    au_path_train,
    au_file_list_train,
    symmetry_path_train,
    symmetry_file_list_train,
    gaze_path_train,
    gaze_file_list_train,
    ME_path_train,
    ME_file_list_train,
)
test_dataset = DeceptionDB(
    label_dict_test,
    data_path_dir_test,
    video_name_list_test,
    au_path_test,
    au_file_list_test,
    symmetry_path_test,
    symmetry_file_list_test,
    gaze_path_test,
    gaze_file_list_test,
    ME_path_test,
    ME_file_list_test,
)
# Sample counts; used later to derive the number of steps per epoch.
train_total_data_num = len(video_name_list_train)
test_total_data_num = len(video_name_list_test)
################################################

############### learning Parameters ###############
# Trailing "alt:" notes record other values the original author listed.
init_batch_size = 12  # alt: 6
max_epoch = 30  # alt: 60
num_segments = frame_len  # frames per video -- must match the stored features
use_regularizer = True
hp_reg_factor = 1  # weight of the attention regulariser
tv_reg_factor = 1e-5  # weight of the mask total-variation term (alt: 0.005)
constrast_reg_factor = 1e-5  # weight of the mask contrast term (alt: 1)
init_lr = 0.005  # alt: 1e-5
weight_decay = 1e-5
#lr_patience = 5 # or 3
dropout_ratio = 0.8  # alt: 0.4
class_num = 2
hidden = 2  # alt: 14
###################################################


class Action_Att_LSTM(nn.Module):
    """Masked ConvLSTM with temporal attention, fused with a hand-crafted cue vector.

    The forward pass predicts a spatial mask per frame, applies it to the
    frame features, runs a ConvLSTM over time, pools the outputs with learned
    attention weights and fuses the result with an embedding of the cue vector.

    Relies on the module-level settings ``num_segments``, ``dropout_ratio``,
    ``tv_reg_factor``, ``constrast_reg_factor`` and ``use_cuda``.

    Args:
        input_size: Input width of the (unused) ``lstm_cell``.
        hidden_size: Channel count of the ConvLSTM and width of the fusion layers.
        output_size: Number of output logits.
        seq_len: Accepted for interface compatibility; the sequence length is
            actually taken from the module-level ``num_segments``.
    """

    def __init__(self, input_size, hidden_size, output_size, seq_len):
        super(Action_Att_LSTM, self).__init__()
        # NOTE: layers are created in the original order so that parameter
        # initialisation and state_dict keys stay unchanged. Several of them
        # are not used by forward().
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, output_size)
        self.fc_attention = nn.Linear(hidden_size, 1)
        self.fc_out = nn.Linear(hidden_size, output_size)
        self.fc_c0_0 = nn.Linear(1792, 896)
        self.fc_c0_1 = nn.Linear(896, 448)
        self.fc_h0_0 = nn.Linear(1792,896)
        self.fc_h0_1 = nn.Linear(896, 448)
        self.input_size = input_size
        self.fc_au_0 = nn.Linear(58, hidden_size) #### 4 AU durations + 1 symmetry correlation + 36 gaze_features + 17 Micro Expression
        self.fc_au_1 = nn.Linear(hidden_size, hidden_size)
        self.fc_fusion_out_0 = nn.Linear(hidden_size*2, hidden_size)
        self.fc_fusion_out_1 = nn.Linear(hidden_size, output_size)
        # Not all fc layer used

        # Per-frame spatial mask in (0, 1): 1792 -> 896 -> 448 -> 1 channels.
        self.mask_conv = nn.Sequential(
                nn.Conv2d(1792, 896, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(896),
                nn.ReLU(),
                nn.Conv2d(896, 448, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(448),
                nn.ReLU(),
                nn.Conv2d(448, 1, kernel_size=3, padding=1, bias=False),
                nn.Sigmoid(),
                )
        self.batchnorm_2d = nn.BatchNorm2d(1792)
        self.batchnorm_3d_1 = nn.BatchNorm3d(num_segments)
        self.batchnorm_3d_2 = nn.BatchNorm3d(num_segments)
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.dropout_2d = nn.Dropout2d(p=dropout_ratio)
        self.dropout_3d = nn.Dropout3d(p=dropout_ratio)
        # The ConvLSTM output feeds fc_attention and the fusion layers, which
        # are sized by `hidden_size`, so the channel count must be that value.
        self.conv_lstm = ConvLSTM(input_size=(3, 3),
                input_dim=1792,
                hidden_dim=[hidden_size],
                kernel_size=(3, 3),
                num_layers=1,
                batch_first=True,
                bias=True,
                return_all_layers=True)

    def forward(self, input_x, AU):
        """Run the network on one batch.

        Args:
            input_x: Frame features, viewable as (-1, 1792, 3, 3).
                NOTE(review): presumably (batch, num_segments, 1792, 3, 3) -- confirm.
            AU: Cue vectors with 58 values per sample (input of ``fc_au_0``).

        Returns:
            ``(final_output, att_weight, mask, tv_loss, contrast_loss)``: the
            logits, the softmax attention weights over the segments, the
            per-frame masks and the two scaled mask regularisation terms.
        """
        input_x = self.dropout_2d(input_x)
        input_x = input_x.view(-1, 1792, 3, 3) # num_segments*batch_size
        input_x = self.batchnorm_2d(input_x)

        mask = self.mask_conv(input_x)
        mask = mask.view(-1, num_segments, 1, 3, 3) #  batch_size
        input_x = input_x.view(-1, num_segments, 1792, 3, 3) # batch_size
        input_x = self.batchnorm_3d_1(input_x)

        # Total variation of the mask along its two spatial axes.
        diff_i = torch.sum(torch.abs(mask[:, :, :, :, 1:] - mask[:, :, :, :, :-1]))
        diff_j = torch.sum(torch.abs(mask[:, :, :, 1:, :] - mask[:, :, :, :-1, :]))
        tv_loss = tv_reg_factor*(diff_i + diff_j)

        # Contrast term: rewards values above 0.5, penalises values below it.
        # .float() keeps the indicators on the mask's own device instead of
        # forcing them onto the current CUDA device.
        mask_A = (mask > 0.5).float()
        mask_B = (mask < 0.5).float()
        contrast_loss = -(mask * mask_A).mean(0).sum() * constrast_reg_factor* 0.5 + (mask * mask_B).mean(0).sum() * constrast_reg_factor * 0.5

        mask_input_x = mask * input_x
        mask_input_x = self.batchnorm_3d_2(mask_input_x)

        del input_x
        output, last_state = self.conv_lstm(mask_input_x)
        del last_state
        del mask_input_x
        output = output[0]  # outputs of the first (only) ConvLSTM layer
        # Average over the two trailing spatial dimensions.
        output = torch.mean(output,dim=4)
        output = torch.mean(output,dim=3)

        # Attention over the segments, then attention-weighted sum.
        att_weight = self.fc_attention(output).view(-1, num_segments)
        att_weight = F.softmax(att_weight, dim =1)
        weighted_output = torch.sum(output*att_weight.unsqueeze(dim=2),
                                        dim =1)

        # Embed the cue vector and fuse it with the visual summary.
        au_feature_tmp = self.fc_au_0(AU)
        au_feature = self.fc_au_1(au_feature_tmp)
        fusion_input = torch.cat([weighted_output, au_feature], dim=1)
        semi_final_output = self.fc_fusion_out_0(fusion_input)
        final_output = self.fc_fusion_out_1(semi_final_output)

        # Aggressive memory release kept from the original implementation.
        del weighted_output
        del diff_i
        del diff_j
        gc.collect()
        torch.cuda.empty_cache()

        return final_output, att_weight, mask, tv_loss, contrast_loss

    def init_hidden(self, batch_size):
        """Return a zero tensor of shape (1, batch_size, hidden_size); on CUDA if `use_cuda`."""
        result = Variable(torch.zeros(1, batch_size, self.hidden_size))
        if use_cuda:
            return result.cuda()
        else:
            return result

def train(batch_size,
        train_data,
        train_label,
        model,
        model_optimizer,
        criterion,
         train_au):
    """Run one optimisation step on a single batch.

    Args:
        batch_size: Nominal batch size used to normalise the accuracy. The
            epoch loop divides the summed accuracies by
            ``total_samples / batch_size``, so this must stay the nominal
            size even for a smaller final batch.
        train_data: Input features passed to the model.
        train_label: One-hot labels; the class index is recovered with argmax.
        model: Network whose ``forward`` returns
            ``(logits, att_weight, mask, tv_loss, contrast_loss)``.
        model_optimizer: Optimiser stepped once per call.
        criterion: Classification loss taking ``(logits, class_indices)``.
        train_au: Auxiliary cue vectors passed to the model.

    Returns:
        ``(mask, final_loss, regularization_loss, tv_loss, contrast_loss,
        train_accuracy, att_weight, corrects, logits, train_label)``.
        ``final_loss`` is a Python float. The three loss terms are detached
        tensors, so accumulating them in the caller does not keep autograd
        history alive.
    """
    model_optimizer.zero_grad()

    logits, att_weight, mask, tv_loss, contrast_loss = model.forward(train_data,train_au)
    del train_data
    gc.collect()
    torch.cuda.empty_cache()

    target = torch.max(train_label,1)[1]
    loss = criterion(logits, target)

    # Attention regulariser built from each weight and its two neighbours.
    att_reg = F.relu(att_weight[:, :-2] * att_weight[:, 2:] - att_weight[:, 1:-1].pow(2)).sqrt().mean()

    if use_regularizer:
        regularization_loss = hp_reg_factor*att_reg
        loss = loss + regularization_loss + tv_loss + contrast_loss
    else:
        # Keep the return value defined (it used to be unbound on this path).
        regularization_loss = att_reg.new_zeros(())

    loss.backward()
    model_optimizer.step()

    final_loss = loss.item()
    correct = (torch.max(logits, 1)[1] == target)
    corrects = [correct]
    train_accuracy = 100.0 * correct.sum()/batch_size
    gc.collect()
    torch.cuda.empty_cache()

    return mask, final_loss, regularization_loss.detach(), tv_loss.detach(), contrast_loss.detach(), train_accuracy, att_weight, corrects, logits, train_label

def test_step(batch_size,
            batch_x,
            batch_y,
            model,
            criterion,
              test_au):
    """Evaluate the model on a single batch without updating it.

    Args:
        batch_size: Nominal batch size used to normalise the accuracy. The
            epoch loop divides the summed accuracies by
            ``total_samples / batch_size``.
        batch_x: Input features passed to the model.
        batch_y: One-hot labels; the class index is recovered with argmax.
        model: Network whose ``forward`` returns
            ``(logits, att_weight, mask, tv_loss, contrast_loss)``.
        criterion: Classification loss taking ``(logits, class_indices)``.
        test_au: Auxiliary cue vectors passed to the model.

    Returns:
        ``(mask, test_logits, test_loss, test_reg_loss, tv_loss,
        contrast_loss, test_accuracy, att_weight, corrects, test_logits,
        batch_y)``. The logits appear twice because callers unpack them under
        two names.
    """
    test_logits, att_weight, mask, tv_loss, contrast_loss = model.forward(batch_x,test_au)
    del batch_x
    gc.collect()
    torch.cuda.empty_cache()

    target = torch.max(batch_y,1)[1]
    test_loss = criterion(test_logits, target)

    # Attention regulariser built from each weight and its two neighbours.
    att_reg = F.relu(att_weight[:, :-2] * att_weight[:, 2:] - att_weight[:, 1:-1].pow(2)).sqrt().mean()

    if use_regularizer:
        test_reg_loss = hp_reg_factor*att_reg
        test_loss = test_loss + test_reg_loss + tv_loss + contrast_loss
    else:
        # Keep the return value defined (it used to be unbound on this path).
        test_reg_loss = att_reg.new_zeros(())

    correct = (torch.max(test_logits, 1)[1] == target)
    corrects = [correct]
    test_accuracy = 100.0 * correct.sum()/batch_size
    del correct
    gc.collect()
    torch.cuda.empty_cache()

    return mask, test_logits, test_loss, test_reg_loss, tv_loss, contrast_loss, test_accuracy, att_weight, corrects, test_logits, batch_y

################ main Start ! ##################
# Builds the model, optimiser and data loaders, then alternates one training
# pass and one evaluation pass per epoch, printing the averaged statistics.

torch.cuda.manual_seed(1234)
maxEpoch = max_epoch
criterion = nn.CrossEntropyLoss()
best_test_accuracy = 0
lstm_action = Action_Att_LSTM(input_size=1792, hidden_size=hidden, output_size=2, seq_len=num_segments).cuda()
model_optimizer = torch.optim.Adam(lstm_action.parameters(), lr=init_lr, weight_decay=weight_decay)
# Multiplies the learning rate by 0.99 after every epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=model_optimizer, lr_lambda= lambda epoch: 0.99 ** epoch)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=init_batch_size)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=init_batch_size)
# Deliberately fractional: per-batch accuracies are normalised by the nominal
# batch size, so dividing their sum by N / batch_size gives the exact
# per-sample accuracy even when the last batch is smaller.
num_step_per_epoch_train = train_total_data_num/init_batch_size
num_step_per_epoch_test = test_total_data_num/init_batch_size


for epoch_num in range(maxEpoch):
    # ---------------- training pass ----------------
    lstm_action.train()
    avg_train_accuracy = 0
    train_name_list = video_name_list_train
    epoch_train_loss = 0
    epoch_train_reg_loss = 0
    epoch_train_tv_loss = 0
    epoch_train_contrast_loss = 0
    for i, train_sample in enumerate(train_data_loader):
        train_batch_label = train_sample['label'].reshape(-1)
        one_hot_train_batch_label = torch.from_numpy(np.eye(2)[train_batch_label]) # 2 means num_class !
        train_batch_feature = train_sample['feature'].cuda().float()
        train_batch_label = one_hot_train_batch_label.cuda().long()
        train_batch_au = train_sample['AUs'].cuda().float()
        train_mask, train_loss, train_reg_loss, train_tv_loss, train_contrast_loss, train_accuracy, train_tmp_att_weights, train_corrects, train_pred, train_label = train(init_batch_size, train_batch_feature, train_batch_label, lstm_action, model_optimizer, criterion, train_batch_au)
        # Accumulate plain floats so no tensors (or their autograd history)
        # are kept alive across the epoch.
        avg_train_accuracy += float(train_accuracy)
        epoch_train_loss += float(train_loss)
        epoch_train_reg_loss += float(train_reg_loss)
        epoch_train_tv_loss += float(train_tv_loss)
        epoch_train_contrast_loss += float(train_contrast_loss)
        gc.collect()
        torch.cuda.empty_cache()

    epoch_train_loss = epoch_train_loss/num_step_per_epoch_train
    epoch_train_reg_loss = epoch_train_reg_loss/num_step_per_epoch_train
    epoch_train_tv_loss = epoch_train_tv_loss/num_step_per_epoch_train
    epoch_train_contrast_loss = epoch_train_contrast_loss/num_step_per_epoch_train
    final_train_accuracy = avg_train_accuracy/num_step_per_epoch_train
    print("epoch:{} ".format(epoch_num)+ " train loss: {}".format(epoch_train_loss))
    print("epoch:{} ".format(epoch_num)+ " train accuracy: {}".format(final_train_accuracy))
    gc.collect()
    torch.cuda.empty_cache()

    # ---------------- evaluation pass ----------------
    avg_test_accuracy = 0
    lstm_action.eval()
    test_name_list = video_name_list_test
    epoch_test_loss = 0
    epoch_test_reg_loss = 0
    epoch_test_tv_loss = 0
    epoch_test_contrast_loss = 0
    for i, test_sample in enumerate(test_data_loader):
        test_batch_label = test_sample['label'].reshape(-1)
        one_hot_test_batch_label = torch.from_numpy(np.eye(2)[test_batch_label]) # 2 means num_class !
        with torch.no_grad():
            test_batch_feature = test_sample['feature'].cuda().float()
            test_batch_label = one_hot_test_batch_label.cuda().long()
            test_batch_au = test_sample['AUs'].cuda().float()
            test_mask, test_logits, test_loss, test_reg_loss, test_tv_loss, test_contrast_loss, test_accuracy, test_tmp_att_weights, test_corrects, test_pred, test_label = test_step(init_batch_size, test_batch_feature, test_batch_label, lstm_action, criterion, test_batch_au)

        avg_test_accuracy += float(test_accuracy)
        epoch_test_loss += float(test_loss)  # was the undefined name `test_los`
        epoch_test_reg_loss += float(test_reg_loss)
        epoch_test_tv_loss += float(test_tv_loss)
        epoch_test_contrast_loss += float(test_contrast_loss)
        gc.collect()
        torch.cuda.empty_cache()

    torch.cuda.empty_cache()
    epoch_test_loss = epoch_test_loss/num_step_per_epoch_test
    epoch_test_reg_loss = epoch_test_reg_loss/num_step_per_epoch_test
    epoch_test_tv_loss = epoch_test_tv_loss/num_step_per_epoch_test
    epoch_test_contrast_loss = epoch_test_contrast_loss/num_step_per_epoch_test

    final_test_accuracy = avg_test_accuracy/num_step_per_epoch_test
    print("epoch: {} ".format(epoch_num)+ " test loss:{} ".format(epoch_test_loss))
    print("epoch: {} ".format(epoch_num)+ " test accuracy:{} ".format(final_test_accuracy))

    # Only the scheduler advances here. The optimiser is already stepped once
    # per training batch; an extra step at this point would re-apply the
    # gradients left over from the last training batch.
    scheduler.step()
    print("learning rate is : {}".format(model_optimizer.param_groups[0]['lr']))
