In [None]:
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch_itl import model, sampler, cost, kernel, estimator

### Set paths and load model config / ckpt

In [None]:
# Set trained model paths
base_experiment_path = './LS_Experiments'
model_name = 'Rafd_itl_model_20201118-134437'

# get model config and ckpt
base_model_path = os.path.join(base_experiment_path, model_name, 'model/')

# Start from None so that a missing file is reported here instead of
# surfacing later as a NameError (or silently reusing stale kernel state).
model_config_path = None
model_ckpt_path = None

# sorted() makes the choice deterministic if several candidates are present
# (os.listdir returns entries in arbitrary order; the last match wins).
for fname in sorted(os.listdir(base_model_path)):
    if ('config' in fname) and fname.endswith('.json'):
        model_config_path = os.path.join(base_model_path, fname)
    elif ('ckpt' in fname) and fname.endswith('.pt'):
        model_ckpt_path = os.path.join(base_model_path, fname)
    else:
        # The file exists, it is simply neither a config nor a checkpoint.
        print(fname, 'is neither a model config (.json) nor a checkpoint (.pt), skipping')

if model_config_path is None or model_ckpt_path is None:
    raise FileNotFoundError(
        'Expected a "*config*.json" and a "*ckpt*.pt" file in {!r}, found config={!r}, ckpt={!r}'
        .format(base_model_path, model_config_path, model_ckpt_path))
print(model_config_path, model_ckpt_path)

# load ckpt and json
with open(model_config_path, 'r') as f:
    model_config = json.load(f)


### Read data

In [None]:
# ----------------------------------
# Reading input/output data
# ----------------------------------
dataset = model_config['Data']['dataset']
theta_type = model_config['Data']['theta_type']
inc_neutral = model_config['Data']['include_neutral']
# Kept as its own variable: the non-facealigner loader below receives it as `dtype`.
input_data_version = model_config['Data']['input_data_version']
use_facealigner = input_data_version == 'facealigner'

data_path = './datasets/Rafd_Aligned/Rafd_LANDMARKS'  # set data path
if dataset == 'Rafd':
    # dirty hack only used to get Rafd speaker ids, not continuously ordered
    # The default is a machine-specific location; set the RAFD_CSV_PATH
    # environment variable to point at your own copy of Rafd.csv.
    data_csv_path = os.environ.get('RAFD_CSV_PATH', '/home/mlpboon/Downloads/Rafd/Rafd.csv')

print('Reading data')
if use_facealigner:
    if dataset == 'KDEF':
        from datasets.datasets import kdef_landmarks_facealigner
        x_train, y_train, x_test, y_test, train_list, test_list = \
            kdef_landmarks_facealigner(data_path, inc_neutral=inc_neutral)
    elif dataset == 'Rafd':
        from datasets.datasets import rafd_landmarks_facealigner
        x_train, y_train, x_test, y_test, train_list, test_list = \
            rafd_landmarks_facealigner(data_path, data_csv_path, inc_neutral=inc_neutral)
    else:
        # Fail here rather than with a NameError on x_train further down.
        raise ValueError("Unsupported dataset {!r} for facealigner data; expected 'KDEF' or 'Rafd'"
                         .format(dataset))
else:
    from datasets.datasets import import_kdef_landmark_synthesis
    x_train, y_train, x_test, y_test = import_kdef_landmark_synthesis(dtype=input_data_version)

# n: number of training inputs, m: size of axis 1 of y_train (later used as the
# number of thetas to sample), nf: size of axis 2 of y_train (later used as the
# size of the identity frequency kernel).
n = x_train.shape[0]
m = y_train.shape[1]
nf = y_train.shape[2]
print('data dimensions', n, m, nf)

In [None]:
# set ITL model
# The input kernel below is rebuilt from `gamma_inp` only, so a model trained
# with a learnable input kernel is not supported by this notebook.
assert model_config['Kernels']['kernel_input_learnable'] == False, \
    'this notebook only supports a fixed (non-learnable) input kernel'
kernel_input = kernel.Gaussian(model_config['Kernels']['gamma_inp'])
kernel_output = kernel.Gaussian(model_config['Kernels']['gamma_out'])
kernel_freq = np.eye(nf) # can be added to ckpt or manually set as np.load(kernel_file)

# define emotion sampler - this can also be added to ckpt
if theta_type == 'aff':
    from datasets.datasets import import_affectnet_va_embedding
    affect_net_csv_path = './utils/landmark_utils/validation.csv'  # to be set if theta_type == 'aff'
    aff_emo_dict = import_affectnet_va_embedding(affect_net_csv_path)

    sampler_ = sampler.CircularSampler(data=dataset+theta_type,
                                       inc_neutral=inc_neutral,
                                       sample_dict=aff_emo_dict)
elif theta_type == '':
    sampler_ = sampler.CircularSampler(data=dataset,
                                       inc_neutral=inc_neutral)
else:
    # Without this branch an unknown theta_type surfaces as a NameError on `sampler_`.
    raise ValueError("Unsupported theta_type {!r}; expected 'aff' or ''".format(theta_type))
sampler_.m = m

itl_model = model.SpeechSynthesisKernelModel(kernel_input, kernel_output,
                                             kernel_freq=torch.from_numpy(kernel_freq).float())

### Load model and predict

In [None]:
# Map the checkpoint onto the CPU: the cells below call `.numpy()` on inputs and
# predictions, so everything is expected to live on the CPU, and a checkpoint
# saved from a CUDA device would otherwise fail to load on a CPU-only machine.
ckpt = torch.load(model_ckpt_path, map_location='cpu')
# Put the model in test mode with the training inputs, the sampled thetas and the stored coefficients.
itl_model.test_mode(x_train=x_train, thetas=sampler_.sample(m), alpha=ckpt['itl_alpha'])
# Predict the outputs of every test input for the m sampled thetas.
pred_test = itl_model.forward(x_test, sampler_.sample(m))

In [None]:
# Rescale the predictions by 128 (presumably the image resolution the landmarks
# were normalised by - TODO confirm) and show the first prediction of the first
# test input as 68 (x, y) rows.
check_output = 128 * pred_test
check_output[0, 0].reshape(68, 2)

In [None]:
%matplotlib inline
plt_x = x_test[0].numpy().reshape(68, 2)
plt_xt = pred_test[1, 4].detach().numpy().reshape(68, 2)
if use_facealigner:
    plt_x = plt_x * 128
    plt_xt = plt_xt * 128
plt_uv = plt_xt - plt_x
plt.quiver(plt_x[:, 0], plt_x[:, 1], plt_uv[:, 0], plt_uv[:, 1], angles='xy')
ax = plt.gca()
ax.invert_yaxis()
plt.show()

### Continuous generation

In [None]:
def emotion_space_sampling(theta1, theta2, num_samples):
    """Sample evenly spaced unit-circle points between two 2-D directions.

    Only the angle of each input is used; its norm is ignored. Both angles are
    first mapped to [0, 2*pi) and the samples cover the angular range between
    them in that representation, which is not necessarily the shorter arc.

    Args:
        theta1: indexable with at least two entries (x, y); start direction.
        theta2: indexable with at least two entries (x, y); end direction.
        num_samples: number of points to return, both end points included.

    Returns:
        Array of shape (num_samples, 2) holding (cos, sin) pairs, ordered from
        the angle of ``theta1`` to the angle of ``theta2``.
    """
    def _positive_angle(point):
        # arctan2 yields values in (-pi, pi]; shift negative ones by a full turn.
        phi = np.arctan2(point[1], point[0])
        return phi + (2 * np.pi) if phi < 0 else phi

    first = _positive_angle(theta1)
    second = _positive_angle(theta2)

    # Always build the grid in ascending order and flip afterwards if needed.
    descending = first > second
    low, high = (second, first) if descending else (first, second)

    arc = np.linspace(low, high, num_samples)
    points = np.column_stack((np.cos(arc), np.sin(arc)))

    return points[::-1] if descending else points

class EdgeMap(object):
    """Rasterise a set of landmark points into a line-drawing image.

    Landmarks are connected along fixed index chains covering indices 0..67
    (presumably the standard 68-point facial landmark layout - TODO confirm).
    Drawing uses ``cv2.line``, so ``cv2`` must be imported at notebook level
    before an instance is called.
    """

    def __init__(self, out_res, num_parts=3):
        """Store the canvas size and build the landmark chains.

        Args:
            out_res: height and width of the square output image.
            num_parts: number of channels of the output image.
        """
        self.out_res = out_res
        self.num_parts = num_parts

        # (first, last + 1) landmark index ranges drawn as open polylines ...
        open_chains = [(0, 17), (17, 22), (22, 27), (27, 31), (31, 36)]
        # ... and ranges drawn as closed loops (first index appended again).
        closed_loops = [(36, 42), (42, 48), (48, 60), (60, 68)]

        # Each group is [landmark indices, line colour].
        self.groups = [[np.arange(lo, hi, 1), 255] for lo, hi in open_chains]
        self.groups += [[list(np.arange(lo, hi, 1)) + [lo], 255] for lo, hi in closed_loops]

    def __call__(self, shape):
        """Draw the chains for one set of landmarks.

        Args:
            shape: landmarks indexable as ``shape[i][0]`` (x) and
                ``shape[i][1]`` (y) for every index used in ``self.groups``.

        Returns:
            float32 array of shape (out_res, out_res, num_parts) with the
            lines drawn at thickness 1.
        """
        canvas = np.zeros((self.out_res, self.out_res, self.num_parts), dtype=np.float32)
        for indices, colour in self.groups:
            # Walk over consecutive landmark pairs of the chain.
            for src, dst in zip(indices[:-1], indices[1:]):
                # Coordinates are truncated to integers for cv2.
                p0 = (int(shape[src][0]), int(shape[src][1]))
                p1 = (int(shape[dst][0]), int(shape[dst][1]))
                cv2.line(canvas, p0, p1, colour, 1)
        return canvas

In [None]:
%matplotlib inline
import cv2
ckpt = torch.load(model_ckpt_path)
itl_model.test_mode(x_train=x_train, thetas=sampler_.sample(m), alpha=ckpt['itl_alpha'])
sampled_emotions = emotion_space_sampling(aff_emo_dict['Fear'], aff_emo_dict['Anger'], 10)
EM = EdgeMap(out_res=128, num_parts=1)
for i in range(len(sampled_emotions)):
    pred_test = itl_model.forward(x_test, torch.from_numpy(sampled_emotions[i][np.newaxis]).float())
    im_em = EM(pred_test[0, 0].detach().numpy().reshape(68,2)*128)
    plt.imshow(np.squeeze(im_em))
    plt.pause(0.5)