In [1]:
import os
import torch
from torch.utils.data import DataLoader 
import torch.nn as nn
import pandas as pd
import numpy as np
from tqdm import tqdm

from utils.util import torch_fix_seed, get_video_name_list
from dataset import AUImageList
from preprocess import JAANet_ImageTransform
from networks import JAANet_networks

In [2]:
# Resolve the list of video names first, then build the image-level dataset on top of it.
video_names = get_video_name_list(
    '/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/labels/PIMD_A/emo_and_au(video1-25)-video_name_list.csv'
)
# Preprocessing is created with phase='test' because this notebook only runs inference.
test_transform = JAANet_ImageTransform(phase='test')

dataset = AUImageList(
    labels_path="/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/labels/PIMD_A/emo_and_au(video1-25).csv",
    video_name_list=video_names,
    au_transform=test_transform,
)

# One image per batch, no shuffling: outputs are collected in dataset order.
loader_options = dict(batch_size=1, shuffle=False, num_workers=2)
dataloader = DataLoader(dataset, **loader_options)

In [3]:
# Project helper from utils.util (presumably seeds the RNGs for reproducibility -- confirm there).
torch_fix_seed()

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"use device:{device}")

use device:cuda:0


In [4]:
#* JAANet(AU Estimator)
# Directory holding one ``<name>.pth`` state-dict snapshot per JAANet sub-network.
JAANET_SNAPSHOT_DIR = '/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/params/JAANet-snapshots'


def _load_jaanet_module(network_name, snapshot_name, **network_kwargs):
    """Build ``network_dict[network_name](**network_kwargs)`` and restore its snapshot.

    Args:
        network_name: key into ``JAANet_networks.network_dict``.
        snapshot_name: file stem of the ``.pth`` file inside ``JAANET_SNAPSHOT_DIR``.
        **network_kwargs: constructor arguments forwarded to the network class.

    Returns:
        The constructed module with the snapshot's state dict loaded.
    """
    module = JAANet_networks.network_dict[network_name](**network_kwargs)
    snapshot_path = os.path.join(JAANET_SNAPSHOT_DIR, f'{snapshot_name}.pth')
    module.load_state_dict(torch.load(snapshot_path, map_location=device))
    return module


def _share_parameters(target, source):
    """Make each parameter of ``target`` use the data tensor of the matching ``source`` parameter.

    Parameters are paired by iteration order of ``.parameters()``.

    Raises:
        ValueError: if the two modules expose a different number of parameters
            (a plain ``zip`` would silently leave the surplus ones untouched).
    """
    target_params = list(target.parameters())
    source_params = list(source.parameters())
    if len(target_params) != len(source_params):
        raise ValueError(
            f"parameter count mismatch: target has {len(target_params)}, source has {len(source_params)}"
        )
    for param_t, param_s in zip(target_params, source_params):
        param_t.data = param_s.data


region_learning = _load_jaanet_module('HMRegionLearning', 'region_learning', input_dim=3, unit_dim=8)

align_net = _load_jaanet_module(
    'AlignNet', 'align_net',
    crop_size=176, map_size=44, au_num=12, land_num=49, input_dim=64, fill_coeff=0.56,
)

local_attention_refine = _load_jaanet_module('LocalAttentionRefine', 'local_attention_refine', au_num=12, unit_dim=8)

local_au_net = _load_jaanet_module('LocalAUNetv1', 'local_au_net', au_num=12, input_dim=64, unit_dim=8)

global_au_feat = _load_jaanet_module('HLFeatExtractor', 'global_au_feat', input_dim=64, unit_dim=8)

# The full AUNet is loaded only as a weight source for the two split layers below.
au_net = _load_jaanet_module('AUNet', 'au_net', au_num=12, input_dim=12000, unit_dim=8)

# AUNet_0 / AUNet_1 take their weights from au_net.au_output[0] / au_net.au_output[1]
# (presumably so the activation between the two stages can be read out separately).
au_net_layer0 = JAANet_networks.network_dict['AUNet_0'](input_dim=12000, unit_dim=8)
_share_parameters(au_net_layer0, au_net.au_output[0])

au_net_layer1 = JAANet_networks.network_dict['AUNet_1'](au_num=12, input_dim=512)
_share_parameters(au_net_layer1, au_net.au_output[1])


In [5]:
# Every sub-network used in the forward pass below, keyed by a readable name.
module_dict = {
    'region_learning': region_learning,
    'align_net': align_net,
    'local_attention_refine': local_attention_refine,
    'local_au_net': local_au_net,
    'global_au_feat': global_au_feat,
    'au_net_layer0': au_net_layer0,
    'au_net_layer1': au_net_layer1
}

# Names of modules that should stay trainable; empty here, so every module is frozen.
training_module_list = []

for module_name, module in module_dict.items():
    if module_name in training_module_list:
        continue
    for param in module.parameters():
        param.requires_grad = False
    # Switch to eval mode once per module, outside the parameter loop, so it is
    # also applied to a module that exposes no parameters.
    module.eval()

for module_name, module in module_dict.items():
    # module.training reports train/eval mode, not whether parameters require gradients.
    print(f"{module_name} is trainable: {module.training}")

# nn.Module.to() moves parameters/buffers in place and returns the same module,
# so the existing variables keep pointing at the moved modules.
for module in module_dict.values():
    module.to(device)

region_learning is trainable: False
align_net is trainable: False
local_attention_refine is trainable: False
local_au_net is trainable: False
global_au_feat is trainable: False
au_net_layer0 is trainable: False
au_net_layer1 is trainable: False


In [6]:
# Per-image outputs, appended in dataloader order.
# NOTE(review): middle_features stores the full flattened feature vector of every
# image as nested Python lists, which can take a lot of RAM on large datasets.
middle_features = []       # flattened concat of align / global / local AU features
middle_features_512 = []   # output of au_net_layer0
au_logits = []             # raw output of au_net_layer1
au_posteriors = []         # softmax probability at index 1 of the 2-way axis, per AU
img_path_list = []

# The progress bar is used as a context manager so it is closed even if a batch fails.
with torch.no_grad(), tqdm(total=len(dataloader)) as pbar:
    for i, batch in enumerate(dataloader):
        imgs, img_paths, emos, _ = batch
        imgs = imgs.to(device)

        # forward pass through the JAANet sub-networks
        region_feat = region_learning(imgs)
        align_feat, align_output, aus_map = align_net(region_feat)
        aus_map = aus_map.to(device)
        output_aus_map = local_attention_refine(aus_map.detach())
        local_au_out_feat = local_au_net(region_feat, output_aus_map)
        global_au_out_feat = global_au_feat(region_feat)
        concat_au_feat = torch.cat((align_feat, global_au_out_feat, local_au_out_feat), dim=1)
        concat_au_feat = concat_au_feat.view(concat_au_feat.size(0), -1)
        middle_features += concat_au_feat.detach().cpu().numpy().tolist()

        au_net_layer0_outputs = au_net_layer0(concat_au_feat)
        middle_features_512 += au_net_layer0_outputs.detach().cpu().numpy().tolist()

        au_net_outputs = au_net_layer1(au_net_layer0_outputs)
        au_logits += au_net_outputs.detach().cpu().numpy().tolist()

        # Reshape to (batch, 2, n_outputs / 2) and normalise over the 2-way axis.
        # The softmax result must be the tensor that is sliced: previously it was
        # assigned to a misspelled name and the raw logits were stored instead.
        au_net_outputs = au_net_outputs.view(au_net_outputs.size(0), 2, au_net_outputs.size(1) // 2)
        au_net_outputs = torch.softmax(au_net_outputs, dim=1)
        # Index 1 is kept as the posterior (presumably the "AU present" class -- confirm
        # against the JAANet training code).
        au_net_outputs = au_net_outputs[:, 1, :]
        au_posteriors += au_net_outputs.detach().cpu().numpy().tolist()

        # save outputs
        img_path_list += img_paths

        # update tqdm bar
        pbar.update(1)

100%|██████████| 19931/19931 [30:55<00:00, 10.74it/s]


In [7]:
# Column labels for the 12 action units, in the order the network emits them.
AU_COLUMNS = [
    "AU01", "AU02", "AU04", "AU06", "AU07", "AU10",
    "AU12", "AU14", "AU15", "AU17", "AU23", "AU24",
]


def _with_img_path(frame):
    """Return ``frame`` with the image-path column placed in front of it."""
    return pd.concat([df_path, frame], axis=1)


# One frame per collected output; feature frames use plain integer column labels.
df_path = pd.DataFrame(img_path_list, columns=["img_path"])
df_au = pd.DataFrame(au_posteriors, columns=AU_COLUMNS)
df_mid = pd.DataFrame(middle_features, columns=list(range(12000)))
df_mid_512 = pd.DataFrame(middle_features_512, columns=list(range(512)))
df_logits = pd.DataFrame(au_logits, columns=list(range(24)))

# Attach the image path to every table so rows can be traced back to their frame.
au_list = _with_img_path(df_au)
mid_list = _with_img_path(df_mid)
mid512_list = _with_img_path(df_mid_512)
logits_list = _with_img_path(df_logits)

In [8]:
# Export of the full intermediate-feature table (mid_list) is intentionally disabled:
# mid_list.to_pickle("/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/processed/JAANet_feature.pkl")

# Destination path -> table; written in this order.
pickle_targets = {
    "/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/processed/PIMD_A/JAANet_posterior.pkl": au_list,
    "/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/processed/PIMD_A/JAANet_feature_512.pkl": mid512_list,
    "/mnt/iot-qnap3/mochida/medical-care/emotionestimation/data/processed/PIMD_A/JAANet_logits.pkl": logits_list,
}
for pickle_path, frame in pickle_targets.items():
    frame.to_pickle(pickle_path)