In [None]:
import os, sys
sys.path.append("../func")
sys.path.append("../autoencoder")
sys.path.append("../motion_generation_models")

from MoE import MoE
from MoE_Z import MoE as MoE_Z
import motion_generation
from GRU import GRU
from GRU_Z import GRU as GRU_Z
from LSTM import LSTM
from LSTM_Z import LSTM as LSTM_Z

from MotionGeneration import MotionGenerationModel as MoGen
from MotionGenerationR import MotionGenerationModel as MoGenR

from MotionGenerationEmbedd import MotionGenerationModel as MoGenZ
from MotionGenerationEmbeddR import MotionGenerationModel as MoGenZR

from MotionGenerationVAE import MotionGenerationModel as MoGenVAE
from MotionGenerationVAER import MotionGenerationModel as MoGenVAER

from MotionGenerationVAE_Embedd import MotionGenerationModel as MoGenVAE_Z
from MotionGenerationVAE_EmbeddR import MotionGenerationModel as MoGenVAE_ZR

from MLP import MLP
from MLP_Adversarial import MLP_ADV
from MLP_MIX import MLP_MIX
from MLP_MIX import MLP_layer
from RBF import RBF
from VAE import VAE
from DEC import DEC

from rig_agnostic_encoding.functions.DataProcessingFunctions import clean_checkpoints
from GlobalSettings import MODEL_PATH, RESULTS_PATH
import bz2
from cytoolz import concat, sliding_window, accumulate
from operator import add
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import ray
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune import CLIReporter

from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import func as F
import _pickle as pickle
import json as js
import importlib
import random
import traceback
import time
import Extract as ext
import plotly.graph_objs as go
import plotly.express as ex
from plotly.subplots import make_subplots
import scipy.signal as signal
from timeit import default_timer as timer
import importlib

In [None]:
import func
import MLP_Adversarial
import MotionGeneration
import MotionGenerationR
import MotionGenerationEmbedd
import MotionGenerationEmbeddR
import MotionGenerationVAE
import MotionGenerationVAER
import MotionGenerationVAE_Embedd
import MotionGenerationVAE_EmbeddR


In [None]:
# Default hyper-parameters. This dict is passed as `config=` to the model constructors used in this notebook.
config = {
    "hidden_dim": 256,                                  # dimension of the hidden layers
    "k": 256,                                           # input dimension of the cluster layer
    "z_dim": 128,                                       # dimension of the embeddings
    "lr": 1e-4,                                         # learning rate
    "batch_size": 32,                                   # batch size
    "keep_prob": 0,                                     # dropout probability (NOTE(review): the key says "keep" - confirm which one the models expect)
    "loss_fn":torch.nn.functional.mse_loss,             # use MSE as the default loss function
    "optimizer":torch.optim.AdamW,                      # use AdamW as the default optimizer
    "scheduler":torch.optim.lr_scheduler.StepLR,        # use StepLR as the default learning rate scheduler
    "scheduler_param": {"step_size":80, "gamma":.9},    # StepLR arguments: multiply the learning rate by 0.9 every 80 scheduler steps
    "basis_func":"gaussian",                            # the basis function of the RBF layer is gaussian
    "n_centroid":64,                                    # number of clusters
    "k_experts": 4,                                     # number of experts
    "gate_size": 128,                                   # dimension of the gate module
    "g_hidden_dim": 512,                                # dimension of the experts
    "num_layers": 4,                                    # number of layers for the LSTM model
    "autoregress_prob":0,                               # teacher-forced learning probability
    "autoregress_inc":.5,                               # how much the probability is increased after each period
    "autoregress_ep":10,                                # the period length
    "autoregress_max_prob":1,                           # maximum teacher-forced learning probability
    "cost_hidden_dim":128,                              # dimension of the hidden layer of the cost encoder
    "seq_len":13,                                       # batch sequence length, e.g. a clip of 299 frames is divided into 23 chunks of 13 frames each; the chunks are composed into a (batch_size x 23 x 13) matrix
    "device":"cpu",                                     # device on which the network is executed
    "use_label":False                                   # whether to use the pose label
    }

# Function definitions

In [None]:
def train(path):
    """
    All-in-one template: load the pre-processed pose data, build the adversarial
    autoencoder (MLP_ADV), train and test it, save a final checkpoint and then call
    clean_checkpoints on the model folder.

    Parameters:
        path (str): the path to the pre-processed dataset. The first of "R2", "R3",
                    "R4", "R5" found in it selects the autoencoder name ("AE_R2", ...);
                    otherwise "AE_R1" is used.

    Returns:
        ae (MLP_ADV): the trained autoencoder
    """
    (train_set_p, val_set_p, test_set_p), [phase_dim, pose_dim, cost_dim, target_dim], feature_dims = get_pose_datasets(path)

    # The two sizes of a single training input are handed to the autoencoder as h_dim / w_dim.
    first_input = train_set_p[0][0]
    h_dim = first_input.shape[0]
    w_dim = first_input.shape[1]

    # Name the autoencoder after the rig tag found in the dataset path.
    ae_name = "AE_R1"
    for rig in ("R2", "R3", "R4", "R5"):
        if rig in path:
            ae_name = "AE_" + rig
            break

    ae = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
                 pos_dim=feature_dims["pos"], rot_dim=feature_dims["rotMat2"], vel_dim=feature_dims["velocity"],
                 train_set=train_set_p, val_set=val_set_p, test_set=test_set_p, name=ae_name)
    # NOTE(review): the logger writes to the relative folder "RESULTS/" while the trainer
    # root is RESULTS_PATH - confirm that the two locations are meant to differ.
    logger = TensorBoardLogger(save_dir="RESULTS/", name=ae_name, version="0.11")

    # The trainer is configured for a single GPU with 16-bit precision.
    trainer = pl.Trainer(
        default_root_dir=RESULTS_PATH,
        gpus=1, precision=16,
        min_epochs=20,
        logger=logger,
        max_epochs=200,
    )

    trainer.fit(ae)
    trainer.test(ae)
    ae.save_checkpoint(best_val_loss="final")
    clean_checkpoints(path=os.path.join(MODEL_PATH, ae_name))
    return ae

In [None]:
def get_datasets(path, train_prob=0.8):
    """
    Loads and partitions the pre-processed dataset.

    Inputs are frames [0, T-1) of every clip passed through F.normaliseT; targets are
    frames [1, T) of the same clip, left as loaded (next-frame prediction pairs).

    Parameters:
        path (str):         the path to the pre-processed dataset (back-slashes are accepted)
        train_prob (float): fraction of the clips used as training set; the remainder is
                            halved between validation and test

    Returns:
        datasets (tuple):    (training_set, validation_set, test_set) produced by random_split
                             with a fixed seed (2021). The returned test set is the test split
                             concatenated with the validation split.
        dims (list):         [phase_dim, pose_dim, cost_dim, target_dim]
        feature_dims (dict): the per-feature dimensions stored with the first clip
    """
    phase_features = ["phase_vec_l2"]
    pose_features = ["pos", "rotMat2", "velocity"]
    cost_features = ["posCost", "rotCost"]
    target_features = ["targetPosition", "targetRotation"]

    path = path.replace("\\", "/")
    data = F.load(path)["data"]

    # Each entry of `data` is indexed as (clip_array, feature_dims).
    feature_dims = data[0][1]
    clips = [np.copy(entry[0]) for entry in data]

    phase_dim = sum(feature_dims[feature] for feature in phase_features)
    pose_dim = sum(feature_dims[feature] for feature in pose_features)
    cost_dim = sum(feature_dims[feature] for feature in cost_features)
    target_dim = sum(feature_dims[feature] for feature in target_features)

    x_tensors = torch.stack([F.normaliseT(torch.from_numpy(clip[:-1])).float() for clip in clips])
    y_tensors = torch.stack([torch.from_numpy(clip[1:]).float() for clip in clips])

    dataset_p = TensorDataset(x_tensors, y_tensors)
    n_clips = len(x_tensors)

    # Split sizes: train_prob of the clips for training, the rest halved (test gets the odd one).
    n_train = int(train_prob * n_clips)
    n_val = int((n_clips - n_train) / 2.0)
    n_test = n_clips - n_train - n_val
    train_set_p, val_set_p, test_set_p = random_split(dataset_p, [n_train, n_val, n_test],
                                                      generator=torch.Generator().manual_seed(2021))
    # The validation clips are also evaluated as part of the test set.
    test_set_p += val_set_p
    return (train_set_p, val_set_p, test_set_p), [phase_dim, pose_dim, cost_dim, target_dim], feature_dims

def get_pose_datasets(path, train_prob=0.8):
    """
    Variant of get_datasets for training autoencoders: only the pose slice of every
    (normalised) frame is kept, and each sample uses the same tensor as input and label.

    Parameters:
        path (str):         the path to the pre-processed dataset (back-slashes are accepted)
        train_prob (float): fraction of the clips used as training set; the remainder is
                            halved between validation and test

    Returns:
        datasets (tuple):    (training_set, validation_set, test_set) produced by random_split
                             with a fixed seed (2021). The returned test set is the test split
                             concatenated with the validation split.
        dims (list):         [phase_dim, pose_dim, cost_dim, target_dim]
        feature_dims (dict): the per-feature dimensions stored with the first clip
    """
    phase_features = ["phase_vec_l2"]
    pose_features = ["pos", "rotMat2", "velocity"]
    cost_features = ["posCost", "rotCost"]
    target_features = ["targetPosition", "targetRotation"]

    path = path.replace("\\", "/")
    data = F.load(path)["data"]

    # Each entry of `data` is indexed as (clip_array, feature_dims).
    feature_dims = data[0][1]
    clips = [np.copy(entry[0]) for entry in data]

    phase_dim = sum(feature_dims[feature] for feature in phase_features)
    pose_dim = sum(feature_dims[feature] for feature in pose_features)
    cost_dim = sum(feature_dims[feature] for feature in cost_features)
    target_dim = sum(feature_dims[feature] for feature in target_features)

    # Frames [0, T-1) of every clip, normalised with F.normaliseT.
    x_tensors = torch.stack([F.normaliseT(torch.from_numpy(clip[:-1])).float() for clip in clips])

    # The pose block sits right after the phase block in every frame vector.
    pose_data = x_tensors[:, :, phase_dim:phase_dim + pose_dim]
    dataset_p = TensorDataset(pose_data, pose_data)
    n_clips = len(x_tensors)

    n_train = int(train_prob * n_clips)
    n_val = int((n_clips - n_train) / 2.0)
    n_test = n_clips - n_train - n_val
    train_set_p, val_set_p, test_set_p = random_split(dataset_p, [n_train, n_val, n_test],
                                                      generator=torch.Generator().manual_seed(2021))
    # The validation clips are also evaluated as part of the test set.
    test_set_p += val_set_p
    return (train_set_p, val_set_p, test_set_p), [phase_dim, pose_dim, cost_dim, target_dim], feature_dims

def get_datasets_reduc(path, train_prob=0.8, ae_path="../models/AE_0_F_R1/0.0001.256.pbz2"):
    """
    Designed for training FS models.
    Loads the dataset at `path` (inputs) together with its level-0 counterpart (targets)
    and partitions the paired clips.

    The file name is expected to look like "<level>_<a>_<b>.pbz2"; the counterpart is found
    by replacing "<level>_F" with "0_F" in the path.

    Parameters:
        path (str):         the path to the pre-processed dataset (back-slashes are accepted)
        train_prob (float): fraction of the clips used as training set; the remainder is
                            halved between validation and test
        ae_path (str):      checkpoint path template containing "0_F_R1"; that token is replaced
                            by "<level>_<a>_<b>" and by "0_<a>_<b>" to locate the two autoencoder
                            checkpoints

    Returns:
        datasets (tuple):        (training_set, validation_set, test_set) produced by random_split
                                 with a fixed seed (2021). The returned test set is the test split
                                 concatenated with the validation split.
        dims (list):             [phase_dim, pose_dim, pose_dim2, cost_dim, target_dim] where
                                 pose_dim2 is the pose dimension of the level-0 data
        feature_dims (dict):     per-feature dimensions of the dataset at `path`
        feature_dims2 (dict):    per-feature dimensions of the level-0 dataset
        ae (MLP_ADV):            the autoencoder loaded for `path`, with its decoder replaced by
                                 the decoder of the level-0 autoencoder
    """

    phase_features = ["phase_vec_l2"]
    pose_features = ["pos", "rotMat2", "velocity"]
    cost_features = ["posCost", "rotCost"]
    target_features = ["targetPosition", "targetRotation"]

    # Split the file name into its level prefix and the "<a>_<b>" name.
    path = path.replace("\\", "/")
    tokens = path.split("/")[-1].split("_")
    level = tokens[0]
    name = (tokens[1] + "_" + tokens[2]).replace(".pbz2", "")

    path2 = path.replace(level + "_F", "0_F")
    data = F.load(path)["data"]
    data2 = F.load(path2)["data"]

    # Each entry of the loaded data is indexed as (clip_array, feature_dims).
    feature_dims = data[0][1]
    feature_dims2 = data2[0][1]
    clips = [np.copy(entry[0]) for entry in data]
    clips2 = [np.copy(entry[0]) for entry in data2]

    phase_dim = sum(feature_dims[feature] for feature in phase_features)
    pose_dim = sum(feature_dims[feature] for feature in pose_features)
    pose_dim2 = sum(feature_dims2[feature] for feature in pose_features)
    cost_dim = sum(feature_dims[feature] for feature in cost_features)
    target_dim = sum(feature_dims[feature] for feature in target_features)

    # Inputs come from the dataset at `path` (normalised), targets from the level-0 dataset (as loaded).
    x_tensors = torch.stack([F.normaliseT(torch.from_numpy(clip[:-1])).float() for clip in clips])
    y_tensors = torch.stack([torch.from_numpy(clip[1:]).float() for clip in clips2])

    ae_path1 = ae_path.replace("0_F_R1", level + "_" + name)
    ae_path2 = ae_path.replace("0_F_R1", "0_" + name)
    ae = MLP_ADV.load_checkpoint(ae_path1)
    ae2 = MLP_ADV.load_checkpoint(ae_path2)

    # Encode with the autoencoder for `path`, decode with the level-0 decoder.
    ae.decoder = ae2.decoder

    dataset_p = TensorDataset(x_tensors, y_tensors)
    n_clips = len(x_tensors)

    n_train = int(train_prob * n_clips)
    n_val = int((n_clips - n_train) / 2.0)
    n_test = n_clips - n_train - n_val
    train_set_p, val_set_p, test_set_p = random_split(dataset_p, [n_train, n_val, n_test],
                                                      generator=torch.Generator().manual_seed(2021))
    # The validation clips are also evaluated as part of the test set.
    test_set_p += val_set_p
    return (train_set_p, val_set_p, test_set_p), [phase_dim, pose_dim, pose_dim2, cost_dim, target_dim], feature_dims, feature_dims2, ae

def get_keyJoints(clip_path):
    """
    Collects the indices of the key joints found in the first frame of a clip.

    A joint counts as a key joint when its "key" entry is truthy and the sum of its
    "rotCost" entry is positive.

    Parameters:
        clip_path (str): path to a clip sample

    Returns:
        keyJoints (list): indices of the key joints, in ascending order
    """
    # NOTE: the clip is unpickled - only use this on clip files from a trusted source.
    clip = pickle.loads(F.load(clip_path))
    first_frame = clip["frames"][0]

    def _is_key_joint(joint):
        return joint["key"] and joint["rotCost"].sum() > 0

    return [idx for idx in range(len(first_frame)) if _is_key_joint(first_frame[idx])]


def compute_delta(pose_x, pose_y, feature_dims):
    """
    Given two clips (generated, ground_truth) with only pose data, computes the frame-to-frame
    change of each clip separately: the mean Euclidean distance travelled by the joints and the
    mean angle between consecutive joint orientation columns.

    The computation is vectorised over frames. A single-frame clip yields empty tensors
    instead of raising.

    Parameters:
        pose_x (torch.Tensor): the generated clip, one row per frame, laid out as
                               [positions | rotations | ...] (3 values per joint position,
                               6 values per joint rotation, reshaped to 3x2)
        pose_y (torch.Tensor): the ground truth clip with the same layout; its number of rows
                               is used as the frame count for both clips
        feature_dims (dict):   a dict of dimensions; "pos" and "rotMat2" are read
    Returns:
        (delta_px, delta_py): per-step positional differences for pose_x and pose_y, each of
                              shape (frames - 1,)
        (delta_rx, delta_ry): per-step rotational differences (radians) for pose_x and pose_y,
                              each of shape (frames - 1,)
    """
    n_frames = pose_y.shape[0]
    pos_end = feature_dims["pos"]
    rot_end = pos_end + feature_dims["rotMat2"]

    def _position_steps(pose):
        # (frames, joints, 3) -> distance moved per joint between consecutive frames -> mean over joints
        joints = pose[:, :pos_end].reshape(n_frames, -1, 3)
        return torch.sqrt(torch.sum((joints[:-1] - joints[1:]) ** 2, dim=2)).mean(dim=1)

    def _rotation_steps(pose):
        # (frames, joints, 3, 2): cosine similarity is taken along the 3-axis, i.e. per rotation column
        rotations = pose[:, pos_end:rot_end].reshape(n_frames, -1, 3, 2)
        cosine = torch.nn.functional.cosine_similarity(rotations[:-1], rotations[1:], dim=2)
        # arccos yields NaN when rounding pushes the cosine outside [-1, 1]; those count as 0
        return torch.nan_to_num(torch.arccos(cosine), 0).mean(dim=(1, 2))

    delta_px, delta_py = _position_steps(pose_x), _position_steps(pose_y)
    delta_rx, delta_ry = _rotation_steps(pose_x), _rotation_steps(pose_y)
    return (delta_px, delta_py), (delta_rx, delta_ry)

def compute_acc_cost(out, y, keyJoints, feature_dims, phase_dim, pose_dim, threshold=0.1):
    """
    Given the generated clip and the ground truth clip, computes per key joint the accuracy,
    the position cost and the rotation cost.

    Every row of `out` / `y` is read with the layout
    [phase | pos | rotMat2 | rest of pose | posCost | rotCost | targetPosition | targetRotation].
    Joint positions are indexed by the joint id (3 values each), joint rotations by the joint
    id (6 values each); targets are indexed by the joint's position in `keyJoints`.

    A frame is a "hit" for a joint when the joint is closer than `threshold` to its target
    (a distance of exactly 0 counts as a hit). The accuracy of a joint is
    1 - |hits_out / (hits_y + 1) - 1|.

    Parameters:
        out (torch.Tensor):    the generated clip, one row per frame
        y (torch.Tensor):      the ground truth clip, one row per frame
        keyJoints (list):      ids of the key joints
        feature_dims (dict):   a dict of dimensions of all components
        phase_dim (int):       size of the phase block at the start of every row
        pose_dim (int):        size of the pose block following the phase block
        threshold (float):     distance below which a joint counts as having reached its target

    Returns:
        acc (list):     accuracy per key joint
        pCost_x (list): summed joint-to-target distance per key joint for `out`
        pCost_y (list): summed joint-to-target distance per key joint for `y`
        rCost_x (list): summed joint-to-target angle (radians) per key joint for `out`
        rCost_y (list): summed joint-to-target angle (radians) per key joint for `y`
    """
    pos_end = phase_dim + feature_dims["pos"]
    rot_end = pos_end + feature_dims["rotMat2"]
    target_pos_start = phase_dim + pose_dim + feature_dims["posCost"] + feature_dims["rotCost"]
    target_pos_end = target_pos_start + feature_dims["targetPosition"]
    target_rot_end = target_pos_end + feature_dims["targetRotation"]

    def _position_stats(clip):
        # Returns (summed distances, hit counts), one entry per key joint.
        joint_pos = clip[:, phase_dim:pos_end]
        target_pos = clip[:, target_pos_start:target_pos_end]
        costs, hits = [], []
        for slot, joint in enumerate(keyJoints):
            offset = joint_pos[:, 3 * joint:3 * joint + 3] - target_pos[:, 3 * slot:3 * slot + 3]
            dist = torch.sqrt(torch.sum(offset ** 2, dim=1))
            costs.append(torch.sum(dist))
            hits.append(torch.sum((dist < threshold).to(dist.dtype)))
        return costs, hits

    def _rotation_costs(clip):
        # Returns the summed angle between joint and target rotation columns, one entry per key joint.
        joint_rot = clip[:, pos_end:rot_end]
        target_rot = clip[:, target_pos_end:target_rot_end]
        costs = []
        for slot, joint in enumerate(keyJoints):
            current = joint_rot[:, 6 * joint:6 * joint + 6].reshape(-1, 3, 2)
            wanted = target_rot[:, 6 * slot:6 * slot + 6].reshape(-1, 3, 2)
            cosine = torch.nn.functional.cosine_similarity(current, wanted, dim=1)
            # arccos yields NaN when rounding pushes the cosine outside [-1, 1]; those count as 0
            costs.append(torch.sum(torch.nan_to_num(torch.arccos(cosine), 0)))
        return costs

    pCost_x, hits_x = _position_stats(out)
    pCost_y, hits_y = _position_stats(y)

    # The +1 in the denominator keeps the ratio finite when the ground truth never hits.
    acc = [1 - abs(dx / (dy + 1) - 1) for dx, dy in zip(hits_x, hits_y)]

    rCost_x = _rotation_costs(out)
    rCost_y = _rotation_costs(y)

    return acc, pCost_x, pCost_y, rCost_x, rCost_y

In [None]:
def compute_ae_results(ae, data_path, ae_path=None):
    """
    A simple function for testing autoencoders on the test split returned by get_datasets.

    Parameters:
        ae:               the autoencoder to evaluate (used as-is when `ae_path` is None)
        data_path (str):  the path to the pre-processed dataset
        ae_path (str):    optional checkpoint path written for rig "R1". When given, "R1" is
                          replaced by the first of "R2".."R5" found in `data_path` and the
                          autoencoder is loaded from that checkpoint instead of using `ae`.

    Returns:
        result (dict): model name, parameter count and model size (from ae.summarize()),
                       plus one entry per test clip for the forward-pass time, the
                       reconstruction / adversarial / position / rotation errors and the
                       frame-to-frame deltas returned by compute_delta
    """
    (train_set_p, val_set_p, test_set_p), [phase_dim, pose_dim, cost_dim, target_dim], feature_dims = get_datasets(data_path)

    if ae_path is not None:
        for rig in ("R2", "R3", "R4", "R5"):
            if rig in data_path:
                ae_path = ae_path.replace("R1", rig)
                break
        ae = MLP_ADV.load_checkpoint(ae_path)

    summary = ae.summarize()

    # Keep only the pose block of every input / target clip.
    upper_b = phase_dim + pose_dim
    te_x = [tset[0][:, phase_dim:upper_b] for tset in test_set_p]
    te_y = [tset[1][:, phase_dim:upper_b] for tset in test_set_p]

    recon_loss, adv_loss, pos_loss = [], [], []
    rot_loss, delta_pos, delta_rot = [], [], []
    elapsed_time = []
    with torch.no_grad():
        ae = ae.cpu()
        for x, y in zip(te_x, te_y):
            # Only the forward pass is timed.
            start = timer()
            out = ae(x)
            end = timer()

            recon_l, pos_l, rot_l = ae.loss(out, y)
            # Least-squares style score of the discriminator on the reconstruction against the target value 1.
            adv_l = 0.5 * torch.mean((ae.convDiscriminator(out.reshape(1, 1, out.shape[0], -1)) - 1) ** 2)
            delta_p, delta_r = compute_delta(out, y, feature_dims=feature_dims)

            elapsed_time.append(end - start)
            recon_loss.append(recon_l)
            adv_loss.append(adv_l)
            pos_loss.append(pos_l)
            rot_loss.append(rot_l)
            delta_pos.append(delta_p)
            delta_rot.append(delta_r)

    result = dict(name=ae.name, params=summary.total_parameters, mem=summary.model_size,
                  elapsed_times=elapsed_time, recon_error=recon_loss, adv_error=adv_loss,
                  pos_error=pos_loss, rot_error=rot_loss, delta_pos=delta_pos, delta_rot=delta_rot)
    return result

def get_G(path):
    """
    Pick the motion generator class (MoGenNet) from the tags in a model path.

    "MoE" is checked before "LSTM"; a "ZCAT" tag selects the corresponding *_Z variant.
    Returns None when the path names neither generator.
    """
    wants_z = "ZCAT" in path
    if "MoE" in path:
        return MoE_Z if wants_z else MoE
    if "LSTM" in path:
        return LSTM_Z if wants_z else LSTM
    return None
def get_C(path, config, pose_dim):
    """
    Build the cluster model named in a model path.

    The tags are checked in the order "RBF", "VAE", "DEC"; the matching class is constructed
    with the given config and [pose_dim] as input dimensions and its `cluster_model` is
    returned. Without any tag a plain MLP_layer() is returned.
    """
    if "RBF" in path:
        cluster_cls = RBF
    elif "VAE" in path:
        cluster_cls = VAE
    elif "DEC" in path:
        cluster_cls = DEC
    else:
        return MLP_layer()
    return cluster_cls(config=config, input_dims=[pose_dim]).cluster_model


def get_M(path):
    """
    Pick the MotionGenerationModel class (OMG) from the tags in a model path.

    "VAE" selects the VAE family, "ZCAT" selects the embedding (_Z) variant of the family;
    without either tag the plain MoGen class is returned.
    """
    wants_z = "ZCAT" in path
    if "VAE" in path:
        return MoGenVAE_Z if wants_z else MoGenVAE
    return MoGenZ if wants_z else MoGen

def get_template(path):
    """
    Get an animation file template (.json) for the rig named in `path`.

    The tags "R1".."R5" are checked in that order and the first one found selects
    "<tag>_template.json", which is read from the current working directory.

    Parameters:
        path (str): any path or name containing the rig tag

    Returns:
        The parsed JSON template, or "" when `path` contains no rig tag.
    """
    for rig in ("R1", "R2", "R3", "R4", "R5"):
        if rig in path:
            # Use a context manager so the file handle is closed after parsing.
            with open(rig + "_template.json") as template_file:
                return js.load(template_file)
    return ""

def get_clip(dataPath, base_path="G:/data/Dataset_R1_Two_1/False_1_0.pbz2"):
    """
    Get the path of the animation clip that matches the rig named in `dataPath`.

    Parameters:
        dataPath (str):  a dataset path; the first of "R2", "R3", "R4", "R5" found in it
                         selects the rig
        base_path (str): clip path written for rig "R1". Every "R1" in it is replaced by the
                         selected rig tag. The default is a machine-specific absolute path,
                         so pass your own location when running elsewhere.

    Returns:
        path (str): `base_path` adapted to the rig, or `base_path` unchanged when no
                    other rig tag is found
    """
    for rig in ("R2", "R3", "R4", "R5"):
        if rig in dataPath:
            return base_path.replace("R1", rig)
    return base_path

def compute_model_results(ae, dataset, data_path, model_path, model_name, template, clip_path_for_keyJ):
    """
    A simple function for testing the OMG model on the test split of `dataset`.

    Parameters:
        ae:                       an autoencoder whose convDiscriminator is attached to the
                                  pose autoencoder loaded from the checkpoint
        dataset (tuple):          the value returned by get_datasets
        data_path (str):          dataset path; "R3" / "R4" / "R5" in it selects the rig,
                                  otherwise "R2" is assumed
        model_path (str):         checkpoint path written for rig "R2" ("R2" is replaced by the
                                  selected rig)
        model_name (str):         used to name the generated animation output
        template:                 animation template forwarded to F.local_generate_animation
        clip_path_for_keyJ (str): clip used to look up the key joints

    Returns:
        result (dict): model name, parameter count and model size (from model.summarize()),
                       plus one entry per test clip for the forward-pass time, the errors,
                       the frame-to-frame deltas, the accuracy and the position / rotation costs.
                       NOTE(review): the original overwrote this dict with 0 right before
                       returning, which looks like a debugging leftover - confirm that no
                       caller relies on the 0.
    """
    (train_set_p, val_set_p, test_set_p), [phase_dim, pose_dim, cost_dim, target_dim], feature_dims = dataset

    # The checkpoint path is written for R2; retarget it when the data belongs to another rig.
    label = "R2"
    for rig in ("R3", "R4", "R5"):
        if rig in data_path:
            model_path = model_path.replace("R2", rig)
            label = rig
            break

    gModel = get_G(model_path)
    cModel = get_C(model_path, config=config, pose_dim=pose_dim)
    M = get_M(model_path)

    # NOTE: the checkpoint is unpickled - only load model files from a trusted source.
    with bz2.BZ2File(model_path, "rb") as f:
        obj = pickle.load(f)
    pose_autoencoder = MLP.load_checkpoint(obj["pose_autoencoder_path"])
    cost_encoder = MLP.load_checkpoint(obj["cost_encoder_path"])
    generationModel = gModel.load_checkpoint(obj["motionGenerationModelPath"])

    model = M(config=obj["config"], feature_dims=obj["feature_dims"], Model=gModel, pose_autoencoder=pose_autoencoder,
                                      use_advLoss=obj["use_adv_loss"],
                                      input_slicers=obj["in_slices"], output_slicers=obj["out_slices"],
                                      name=obj["name"])

    cModel.load_state_dict(obj["middle_layer_dict"])
    pose_autoencoder.convDiscriminator = ae.convDiscriminator

    model.middle_layer = cModel
    model.in_slices = obj["in_slices"]
    model.out_slices = obj["out_slices"]
    model.pose_autoencoder = pose_autoencoder
    model.cost_encoder = cost_encoder
    model.generationModel = generationModel

    # NOTE(review): this replaces the model assembled above; only cModel (with its loaded
    # state dict) is carried over. Confirm whether the manual assembly is still needed.
    model = M.load_checkpoint(filename=model_path, Model=gModel, MiddleModel=cModel)
    summary = model.summarize()
    model = model.cpu()
    model.generationModel.device="cpu"
    model.generationModel = model.generationModel.cpu()

    recon_loss, adv_loss, pos_loss = [], [], []
    rot_loss, delta_pos, delta_rot = [], [], []
    acc, pCost, rCost, elapsed_time = [], [], [], []

    keyJoints = get_keyJoints(clip_path_for_keyJ)

    use_vae = "VAE" in model_path
    with torch.no_grad():
        for sample in test_set_p:
            x, y = sample[0], sample[1]
            shape = y.shape

            model.generationModel.reset_hidden(batch_size=y.shape[0])

            # Only the forward pass is timed.
            if use_vae:
                start = timer()
                out, z, mu, logvar = model(x)
                end = timer()
            else:
                start = timer()
                out, _ = model(x)
                end = timer()

            out = torch.cat(out, dim=1)
            pose_x = out[:, phase_dim:phase_dim+pose_dim]
            pose_y = y[:, phase_dim:phase_dim+pose_dim]
            recon_l, pos_l, rot_l = model.pose_autoencoder.loss(pose_x, pose_y)
            adv_l = 0.5 * torch.mean((model.pose_autoencoder.convDiscriminator(pose_x.reshape(1,1,shape[0],-1)) - 1)** 2)

            # compute_delta expects pose-only clips, so pass the pose slices rather than the full vectors.
            delta_p, delta_r = compute_delta(pose_x, pose_y, feature_dims=feature_dims)
            # Keep the per-sample accuracy under its own name so the `acc` accumulator is not overwritten.
            sample_acc, pCost_x, pCost_y, rCost_x, rCost_y = compute_acc_cost(out, y, keyJoints, feature_dims, phase_dim, pose_dim)

            elapsed_time.append(end-start)
            recon_loss.append(recon_l)
            adv_loss.append(adv_l)
            pos_loss.append(pos_l)
            rot_loss.append(rot_l)
            delta_pos.append(delta_p)
            delta_rot.append(delta_r)
            acc.append(sample_acc)
            pCost.append((pCost_x, pCost_y))
            rCost.append((rCost_x, rCost_y))

    F.local_generate_animation(model, test_set_p, feature_dims, template,target_dim, output_path="../animations/"+model_name+"_"+label, use_vae=use_vae, n=3)
    result = dict(name=model.name, params=summary.total_parameters, mem=summary.model_size,
                  elapsed_times=elapsed_time, recon_error=recon_loss, adv_error=adv_loss,
                  pos_error=pos_loss, rot_error=rot_loss, delta_pos=delta_pos, delta_rot=delta_rot,
                  accuracy=acc, potCost=pCost, rotCost=rCost)
    return result

def compute_model_results_reduc(ae, dataset, data_path, model_path, model_name, template, clip_path_for_keyJ):
    """
    Evaluate one FS-OMG checkpoint on the test split and collect per-sample metrics.

    The checkpoint at `model_path` is loaded, moved to the CPU and run over every
    sample of the test split.  Three animations are additionally written through
    `F.local_generate_animation`.

    Args:
        ae: autoencoder whose `convDiscriminator` is attached to the loaded pose
            autoencoder and whose `decoder` is the last-resort fallback when the
            stored decoder weights do not fit.  NOTE: in that fallback branch
            `ae.decoder` itself receives the checkpoint weights.
        dataset: tuple
            `((train, val, test), [phase_dim, pose_dim, pose_dim2, cost_dim, target_dim],
              feature_dims, feature_dims2, ae2)`.
            Only the test split, the second pose dimension / feature dims and
            `ae2` are used for the evaluation.
        data_path: dataset path; only inspected for "R3"/"R4"/"R5" to rewrite
            the "R2" part of `model_path`.
        model_path: path of the bz2-compressed, pickled model checkpoint.
        model_name: display name, used in the animation output path.
        template: forwarded unchanged to `F.local_generate_animation`.
        clip_path_for_keyJ: forwarded to `get_keyJoints`.

    Returns:
        dict with the keys `name`, `params`, `mem`, `elapsed_times`,
        `recon_error`, `adv_error`, `pos_error`, `rot_error`, `delta_pos`,
        `delta_rot`, `accuracy`, `potCost` and `rotCost`; every metric is a
        list with one entry per test sample.
    """
    (train_set_p, val_set_p, test_set_p), \
        [phase_dim, pose_dim, pose_dim2, cost_dim, target_dim], \
        feature_dims, feature_dims2, ae2 = dataset

    # Built with the *full* pose dimension before switching to the reduced one.
    ae3 = MLP(config=config, dimensions=[pose_dim])
    pose_dim = pose_dim2
    feature_dims = feature_dims2

    # The checkpoint paths are written for "R2"; swap in the tag found in data_path.
    label = "R2"
    for tag in ("R3", "R4", "R5"):
        if tag in data_path:
            model_path = model_path.replace(label, tag)
            label = tag
            break

    gModel = get_G(model_path)
    cModel = get_C(model_path, config=config, pose_dim=pose_dim)
    M = get_M(model_path)

    # SECURITY NOTE: pickle executes arbitrary code on load - only open trusted checkpoints.
    with bz2.BZ2File(model_path, "rb") as f:
        obj = pickle.load(f)
    with bz2.BZ2File(obj["pose_autoencoder_path"], "rb") as f:
        mlp = pickle.load(f)
    print(model_path)

    ae2.encoder = ae3.encoder
    ae2.encoder.load_state_dict(mlp["encoder"])
    # Try progressively different decoders until the stored weights fit.
    try:
        ae2.decoder.load_state_dict(mlp["decoder"])
    except Exception:
        try:
            ae3 = MLP(config=config, dimensions=[pose_dim])
            ae2.decoder = ae3.decoder
            ae2.decoder.load_state_dict(mlp["decoder"])
        except Exception:
            ae2.decoder = ae.decoder
            ae2.decoder.load_state_dict(mlp["decoder"])

    pose_autoencoder = MLP_ADV.load_checkpoint(obj["pose_autoencoder_path"])
    ae2.convDiscriminator = ae.convDiscriminator
    ae2 = pose_autoencoder
    cost_encoder = MLP.load_checkpoint(obj["cost_encoder_path"])
    generationModel = gModel.load_checkpoint(obj["motionGenerationModelPath"])

    # NOTE(review): this hand-assembled model is replaced by M.load_checkpoint
    # below; the statement order is kept because cModel is mutated here and then
    # handed to load_checkpoint.
    model = M(config=obj["config"], feature_dims=obj["feature_dims"], Model=gModel, pose_autoencoder=ae2,
              use_advLoss=obj["use_adv_loss"],
              input_slicers=obj["in_slices"], output_slicers=obj["out_slices"],
              name=obj["name"])

    cModel.load_state_dict(obj["middle_layer_dict"])
    pose_autoencoder.convDiscriminator = ae.convDiscriminator

    model.middle_layer = cModel
    model.in_slices = obj["in_slices"]
    model.out_slices = obj["out_slices"]
    model.pose_autoencoder = ae2
    model.cost_encoder = cost_encoder
    model.generationModel = generationModel

    model = M.load_checkpoint(filename=model_path, Model=gModel, MiddleModel=cModel)
    summary = model.summarize()
    model = model.cpu()
    model.generationModel.device = "cpu"
    model.generationModel = model.generationModel.cpu()

    recon_loss, adv_loss, pos_loss = [], [], []
    rot_loss, delta_pos, delta_rot = [], [], []
    # `accuracy` (not `acc`) so the per-sample value below cannot shadow the list.
    accuracy, pCost, rCost, elapsed_time = [], [], [], []

    keyJoints = get_keyJoints(clip_path_for_keyJ)

    use_vae = "VAE" in model_path
    with torch.no_grad():
        for sample in test_set_p:
            x, y = sample[0], sample[1]
            shape = y.shape

            model.generationModel.reset_hidden(batch_size=y.shape[0])

            # Only the forward pass is timed.
            start = timer()
            outputs = model(x)
            end = timer()
            # VAE variants return (out, z, mu, logvar), the others (out, _).
            out = outputs[0]

            out = torch.cat(out, dim=1)
            pose_x = out[:, phase_dim:phase_dim+pose_dim]
            pose_y = y[:, phase_dim:phase_dim+pose_dim]
            recon_l, pos_l, rot_l = model.pose_autoencoder.loss(pose_x, pose_y)
            adv_l = 0.5 * torch.mean((model.pose_autoencoder.convDiscriminator(pose_x.reshape(1,1,shape[0],-1)) - 1)** 2)

            delta_p, delta_r = compute_delta(out, y, feature_dims=feature_dims)
            sample_acc, pCost_x, pCost_y, rCost_x, rCost_y = compute_acc_cost(out, y, keyJoints, feature_dims, phase_dim, pose_dim)

            elapsed_time.append(end-start)
            recon_loss.append(recon_l)
            adv_loss.append(adv_l)
            pos_loss.append(pos_l)
            rot_loss.append(rot_l)
            delta_pos.append(delta_p)
            delta_rot.append(delta_r)
            accuracy.append(sample_acc)
            pCost.append((pCost_x, pCost_y))
            rCost.append((rCost_x, rCost_y))

    F.local_generate_animation(model, test_set_p, feature_dims, template,target_dim, output_path="../animations/"+model_name+"_"+label, use_vae=use_vae, n=3)
    # Return the collected metrics (previously overwritten with 0 before returning).
    return dict(name=model.name, params=summary.total_parameters, mem=summary.model_size,
                elapsed_times=elapsed_time, recon_error=recon_loss, adv_error=adv_loss,
                pos_error=pos_loss, rot_error=rot_loss, delta_pos=delta_pos, delta_rot=delta_rot,
                accuracy=accuracy, potCost=pCost, rotCost=rCost)

# Training

In [None]:
# Pre-processed datasets "0_F_R1" ... "0_F_R5", in ascending order.
data_paths = ["../datasets/0_F_R{}.pbz2".format(idx) for idx in range(1, 6)]

## Autoencoders

In [None]:
# Train one autoencoder per dataset.
aes = list(map(train, data_paths))

# Persist every trained autoencoder under the "final" tag.
for ae in aes:
    ae.save_checkpoint(best_val_loss="final")

# Evaluate each trained autoencoder on the dataset it was trained on and store the results.
results = list(map(compute_ae_results, aes, data_paths))
F.save(results, "ae_results", "../results")

## Reference OMG models

In [None]:
# Reference OMG checkpoints as (display name, checkpoint path) pairs; the two
# parallel lists used by the evaluation cell are derived from this table.
_reference_models = [
    ("AE+MoE",       "../models/AE_MoE_256_ZIN0_F_R1/final.pbz2"),
    ("AE+LSTM",      "../models/AE_LSTM_256_ZIN0_F_R1_LSTM/final.pbz2"),
    ("RBF-CAT+LSTM", "../models/RBF_LSTM_256_ZCAT0_F_R1_LSTM/final.pbz2"),
    ("RBF-IN+LSTM",  "../models/RBF_LSTM_256_ZIN0_F_R1_LSTM/final.pbz2"),
    ("RBF-CAT+MoE",  "../models/RBF_MoE_256_ZCAT0_F_R1_ZCAT/final.pbz2"),
    ("RBF-IN+MoE",   "../models/RBF_MoE_256_ZINF_R1/final.pbz2"),
    ("VAE-CAT+LSTM", "../models/VAE_LSTM_256_ZCAT0_F_R1_LSTM/final.pbz2"),
    ("VAE-IN+LSTM",  "../models/VAE_LSTM_256_ZIN0_F_R1_LSTM/final.pbz2"),
    ("VAE-CAT+MoE",  "../models/VAE_MoE_256_ZCAT0_F_R1_ZCAT/final.pbz2"),
    ("VAE-IN+MoE",   "../models/VAE_MoE_256_ZINF_R1/final.pbz2"),
    ("DEC-CAT+MoE",  "../models/DEC_MoE_256_ZCAT0_F_R1_ZCAT/final.pbz2"),
    ("DEC-IN+MoE",   "../models/DEC_MoE_256_ZINF_R1/final.pbz2"),
]

ref_model_paths = [checkpoint for _name, checkpoint in _reference_models]
model_names = [name for name, _checkpoint in _reference_models]

In [None]:
# Evaluate every reference OMG model on each dataset and store the results.
data = [get_datasets(path) for path in data_paths]
results = []
for i, dataPath in enumerate(data_paths):
    dataset, ae = data[i], aes[i]
    template, clip_path = get_template(dataPath), get_clip(dataPath)

    # One entry per reference model, keyed by its display name.
    result = {
        model_name: compute_model_results(ae, dataset, dataPath, model_path, model_name, template, clip_path)
        for model_path, model_name in zip(ref_model_paths, model_names)
    }
    results.append(result)

F.save(results, "reference_results", "../results")

## Transferred OMG models

In [None]:
# Transferred OMG checkpoints, keyed by display name -> checkpoint path.
# Key suffixes: "_RAW" entries point at "..._RAW_F_R2_..." directories, "_F"/"_T"
# entries at "..._R1_to_F_R2_..." directories (presumably frozen / trainable
# variants - TODO confirm).
# NOTE(review): for most models the "_T" path ends in "_trainable" and the "_F"
# path has no suffix, but for RBF-CAT+LSTM and VAE-CAT+LSTM it is the "_F" path
# that ends in "_frozen" while the "_T" path has no suffix - verify against the
# directories on disk.
t_models = {
    # --- "_RAW" checkpoints ---
    "AE+LSTM_RAW":"../models/AE_LSTM_256_AE_0.10_RAW_F_R2_ZIN/final.pbz2",
    "AE+MoE_RAW":"../models/AE_MoE_256_AE_0.10_RAW_F_R2_ZIN/final.pbz2",
    "DEC-CAT+MoE_RAW":"../models/DEC_MoE_256_ZCAT_0.10_RAW_F_R2_ZCAT/final.pbz2",
    "DEC-IN+MoE_RAW":"../models/DEC_MoE_256_ZIN_0.10_RAW_F_R2_ZIN/final.pbz2",
    "RBF-IN+LSTM_RAW":"../models/RBF_LSTM_256_AE_0.10_RAW_F_R2_ZIN/final.pbz2",
    "RBF-CAT+LSTM_RAW":"../models/RBF_LSTM_256_ZCAT_0.10_RAW_F_R2_ZCAT/final.pbz2",
    "RBF-CAT+MoE_RAW":"../models/RBF_MoE_256_ZCAT_0.10_RAW_F_R2_ZCAT/final.pbz2",
    "RBF-IN+MoE_RAW": "../models/RBF_MoE_256_ZIN_0.10_RAW_F_R2_ZIN/final.pbz2",
    "VAE-IN+LSTM_RAW":"../models/VAE_LSTM_256_AE_0.10_RAW_F_R2_ZIN/final.pbz2",
    "VAE-CAT+LSTM_RAW":"../models/VAE_LSTM_256_ZCAT_0.10_RAW_F_R2_ZCAT/final.pbz2",
    "VAE-CAT+MoE_RAW":"../models/VAE_MoE_256_ZCAT_0.10_RAW_F_R2_ZCAT/final.pbz2",
    "VAE-IN+MoE_RAW":"../models/VAE_MoE_256_ZIN_0.10_RAW_F_R2_ZIN/final.pbz2",
    
    # --- AE "R1_to_F_R2" checkpoints ---
    "AE+LSTM_F" : "../models/AE_LSTM_256_ZIN_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "AE+LSTM_T" : "../models/AE_LSTM_256_ZIN_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",
    "AE+MoE_F":"../models/AE_MoE_256_AE_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "AE+MoE_T":"../models/AE_MoE_256_AE_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",

    # --- DEC "R1_to_F_R2" checkpoints ---
    "DEC-CAT+MoE_F":"../models/DEC_MoE_256_ZCAT_0.10_R1_to_F_R2_ZCAT/final.pbz2",
    "DEC-CAT+MoE_T":"../models/DEC_MoE_256_ZCAT_0.10_R1_to_F_R2_ZCAT_trainable/final.pbz2",
    "DEC-IN+MoE_F":"../models/DEC_MoE_256_ZIN_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "DEC-IN+MoE_T":"../models/DEC_MoE_256_ZIN_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",

    # --- RBF "R1_to_F_R2" checkpoints (note the "_frozen" suffix on RBF-CAT+LSTM_F) ---
    "RBF-CAT+LSTM_T":"../models/RBF_LSTM_256_ZCAT_0.10_R1_to_F_R2_ZCAT/final.pbz2",
    "RBF-CAT+LSTM_F":"../models/RBF_LSTM_256_ZCAT_0.10_R1_to_F_R2_ZCAT_frozen/final.pbz2",
    "RBF-IN+LSTM_F":"../models/RBF_LSTM_256_ZIN_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "RBF-IN+LSTM_T":"../models/RBF_LSTM_256_ZIN_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",
    "RBF-CAT+MoE_F":"../models/RBF_MoE_256_ZCAT_0.10_R1_to_F_R2_ZCAT/final.pbz2",
    "RBF-CAT+MoE_T":"../models/RBF_MoE_256_ZCAT_0.10_R1_to_F_R2_ZCAT_trainable/final.pbz2",
    "RBF-IN+MoE_F":"../models/RBF_MoE_256_ZIN_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "RBF-IN+MoE_T":"../models/RBF_MoE_256_ZIN_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",

    # --- VAE "R1_to_F_R2" checkpoints (note the "_frozen" suffix on VAE-CAT+LSTM_F) ---
    "VAE-CAT+LSTM_T":"../models/VAE_LSTM_256_ZCAT_0.10_R1_to_F_R2_ZCAT/final.pbz2",
    "VAE-CAT+LSTM_F":"../models/VAE_LSTM_256_ZCAT_0.10_R1_to_F_R2_ZCAT_frozen/final.pbz2",
    "VAE-IN+LSTM_F":"../models/VAE_LSTM_256_ZIN_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "VAE-IN+LSTM_T":"../models/VAE_LSTM_256_ZIN_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",
    "VAE-CAT+MoE_F":"../models/VAE_MoE_256_ZCAT_0.10_R1_to_F_R2_ZCAT/final.pbz2",
    "VAE-CAT+MoE_T":"../models/VAE_MoE_256_ZCAT_0.10_R1_to_F_R2_ZCAT_trainable/final.pbz2",
    "VAE-IN+MoE_F":"../models/VAE_MoE_256_ZIN_0.10_R1_to_F_R2_ZIN/final.pbz2",
    "VAE-IN+MoE_T":"../models/VAE_MoE_256_ZIN_0.10_R1_to_F_R2_ZIN_trainable/final.pbz2",
}


In [None]:
# Evaluate every transferred OMG model on each dataset; the datasets are
# re-loaded with train_prob=0.16 for this experiment.
data2 = [get_datasets(path, train_prob=0.16) for path in data_paths]
transfer_results = []
for i, dataPath in enumerate(data_paths):
    ae, dataset = aes[i], data2[i]
    template, clip_path = get_template(dataPath), get_clip(dataPath)

    # One entry per transferred model, keyed by its display name.
    result = {
        model_name: compute_model_results(ae, dataset, dataPath, model_path, model_name, template, clip_path)
        for model_name, model_path in t_models.items()
    }
    transfer_results.append(result)
F.save(transfer_results, "transfer_results", "../results")


## Reference FS-OMG models

In [None]:
# Pre-trained autoencoder checkpoints, one per dataset below (same order).
aes_path = [
    "../models/AE_R1/final.256.pbz2",
    "../models/AE_R2/final.256.pbz2",
    "../models/AE_R3/final.256.pbz2",
    "../models/AE_R4/final.256.pbz2",
    "../models/AE_R5/final.256.pbz2",
]
# NOTE: from here on `aes` and `data_paths` refer to the FS-OMG ("2_F_*") setup
# and replace the values bound by the earlier cells.
aes = [MLP_ADV.load_checkpoint(path) for path in aes_path]

data_paths = [
    "../datasets/2_F_R1.pbz2",
    "../datasets/2_F_R2.pbz2",
    "../datasets/2_F_R3.pbz2",
    "../datasets/2_F_R4.pbz2",
    "../datasets/2_F_R5.pbz2",
]
data = [get_datasets_reduc(path, train_prob=0.8) for path in data_paths]

# Reference FS-OMG checkpoints, keyed by display name.
tr_models_ref = {
    "AE+MoE":"../models/AE_MoE_256_ZIN2_F_R1/final.pbz2",
    "VAE-CAT+MoE":"../models/VAE_MoE_256_ZCAT2_F_R1_ZCAT/final.pbz2",
    "DEC-CAT+MoE":"../models/DEC_MoE_256_ZCAT2_F_R1_ZCAT/final.pbz2",
    "RBF-CAT+MoE":"../models/RBF_MoE_256_ZCAT2_F_R1_ZCAT/final.pbz2",
}

transfer_results_reduc_ref = []

for i, dataPath in enumerate(data_paths):
    result = {}
    ae = aes[i]
    # Use the datasets loaded above; the previous `data3` was never defined in
    # this notebook and raised NameError on a fresh kernel.
    dataset = data[i]
    template = get_template(dataPath)
    clip_path = get_clip(dataPath)

    for model_name, model_path in tr_models_ref.items():
        model_result = compute_model_results_reduc(ae, dataset, dataPath, model_path, model_name, template,clip_path)
        result[model_name] = model_result
    transfer_results_reduc_ref.append(result)
F.save(transfer_results_reduc_ref, "transfer_results_reduc_ref_2-5", "../results")

## Transferred FS-OMG models

In [None]:
# Transferred FS-OMG checkpoints, keyed by display name.
tr_models = {
    "AE+MoE_T":"../models/AE_MoE_256_AE_0.12_R1_to_F_R2_ZIN_reduced_trainable/final.pbz2",
    "AE+MoE_F":"../models/AE_MoE_256_AE_0.12_R1_to_F_R2_ZIN_reduced/final.pbz2",
    "DEC-CAT+MoE_F":"../models/DEC_MoE_256_ZCAT_0.12_R1_to_F_R2_ZCAT_reduced/final.pbz2",
    "DEC-CAT+MoE_T":"../models/DEC_MoE_256_ZCAT_0.12_R1_to_F_R2_ZCAT_reduced_trainable/final.pbz2",
    "RBF-CAT+MoE_F":"../models/RBF_MoE_256_ZCAT_0.12_R1_to_F_R2_ZCAT_reduced/final.pbz2",
    "RBF-CAT+MoE_T":"../models/RBF_MoE_256_ZCAT_0.12_R1_to_F_R2_ZCAT_reduced_trainable/final.pbz2",
    "VAE-CAT+MoE_F":"../models/VAE_MoE_256_ZCAT_0.12_R1_to_F_R2_ZCAT_reduced/final.pbz2",
    "VAE-CAT+MoE_T":"../models/VAE_MoE_256_ZCAT_0.12_R1_to_F_R2_ZCAT_reduced_trainable/final.pbz2"
}
# Datasets re-loaded with train_prob=0.16 for the transfer experiment.
data = [get_datasets_reduc(path, train_prob=0.16) for path in data_paths]
transfer_results_reduc = []
for i, dataPath in enumerate(data_paths):
    print(i)
    result = {}
    ae = aes[i]
    # Use the datasets loaded above; the previous `data4` was never defined in
    # this notebook and raised NameError on a fresh kernel.
    dataset = data[i]
    template = get_template(dataPath)
    clip_path = get_clip(dataPath)

    for model_name, model_path in tr_models.items():
        model_result = compute_model_results_reduc(ae, dataset, dataPath, model_path, model_name, template,clip_path)
        result[model_name] = model_result
    transfer_results_reduc.append(result)
F.save(transfer_results_reduc, "transfer_results_reduc", "../results")