In [11]:
import os, sys
sys.path.append("motion_generation")
sys.path.append("rig_agnostic_encoding/functions")
sys.path.append("rig_agnostic_encoding/models")

from motion_generation.MoE import MoE
import motion_generation
from motion_generation.GRU import GRU
from motion_generation.LSTM import LSTM
from motion_generation.MotionGeneration import MotionGenerationModel
from motion_generation.MotionGeneration_v2 import MotionGenerationModel as MotionGenerationModel_v2
from motion_generation.MotionGeneration_v3 import MotionGenerationModel as MotionGenerationModel_v3
from motion_generation.MotionGenerationVAE import MotionGenerationModel as MotionGenerationModelVAE
from motion_generation.MotionGenerationRNN import MotionGenerationModelRNN
from motion_generation.MotionGenerationBatch import MotionGenerationModelBatch
from rig_agnostic_encoding.models.MLP import MLP
from rig_agnostic_encoding.models.RBF import RBF
import  rig_agnostic_encoding.models.RBF  as RB
from rig_agnostic_encoding.models.DEC import DEC
from rig_agnostic_encoding.models.MLP_MIX import MLP_MIX
from rig_agnostic_encoding.models.VAE import VAE
from rig_agnostic_encoding.models.VaDE import VaDE

from rig_agnostic_encoding.functions.DataProcessingFunctions import clean_checkpoints
from GlobalSettings import MODEL_PATH
import bz2
from cytoolz import sliding_window, accumulate
from operator import add
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import func as F
import _pickle as pickle
import json as js
import importlib

In [12]:
class RBF_Layer(nn.Module):
    """
       from JeremyLinux on GitHub {https://github.com/JeremyLinux/PyTorch-Radial-Basis-Function-Layer/blob/master/Torch%20RBF/torch_rbf.py}

       Transforms incoming data using a given radial basis function:
       u_{i} = rbf(||x - c_{i}|| / s_{i})
       Arguments:
           in_features: size of each input sample
           out_features: size of each output sample
       Shape:
           - Input: (N, in_features) where N is an arbitrary batch size
           - Output: (N, out_features) where N is an arbitrary batch size
       Attributes:
           centres: the learnable centres of shape (out_features, in_features).
               The values are initialised from a standard normal distribution.
               Normalising inputs to have mean 0 and standard deviation 1 is
               recommended.

           sigmas: the learnable scaling factors of shape (out_features).
               The values are initialised as ones.

           basis_func: the radial basis function used to transform the scaled
               distances.
       """

    def __init__(self, in_features: int=0, out_features: int=0, basis_func=None):
        super(RBF_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.centres = nn.Parameter(torch.Tensor(out_features, in_features))
        self.sigmas = nn.Parameter(torch.Tensor(out_features))
        self.out_layer = nn.Linear(in_features=in_features, out_features=out_features)
        self.basis_func = basis_func
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.centres, 0, 1)
        nn.init.constant_(self.sigmas, .01)

    def forward(self, x):
        size = (x.size(0), self.out_features, self.in_features)
        x = x.unsqueeze(1).expand(size)
        c = self.centres.unsqueeze(0).expand(size)
        distances = (x - c).pow(2).sum(-1).pow(0.5) * self.sigmas.unsqueeze(0)  # ALT. / (2*sigma**2)
        return self.out_layer(self.basis_func(distances))

    def freeze(self, flag=False):
        self.centres.requires_grad = flag
        self.sigmas.requires_grad = flag

In [13]:

class MotionGenerationModel(pl.LightningModule):
    def __init__(self, config:dict=None, Model=None, pose_autoencoder=None, middle_layer=None, feature_dims=None,
                 input_slicers:list=None, output_slicers:list=None,
                 train_set=None, val_set=None, test_set=None, name="MotionGeneration", workers=8):
        super().__init__()

        self.feature_dims = feature_dims
        self.config=config
        self.workers = workers

        self.loss_fn = config["loss_fn"] if "loss_fn" in config else nn.functional.mse_loss
        self.opt = config["optimizer"] if "optimizer" in config else torch.optim.Adam
        self.scheduler = config["scheduler"] if "scheduler" in config else None
        self.scheduler_param = config["scheduler_param"] if "scheduler_param" in config else None
        self.batch_size = config["batch_size"]
        self.learning_rate = config["lr"]
        self.seq_len = config["seq_len"]

        self.autoregress_prob = config["autoregress_prob"]
        self.autoregress_inc = config["autoregress_inc"]
        self.autoregress_ep = config["autoregress_ep"]
        self.autoregress_max_prob = config["autoregress_max_prob"]

        self.best_val_loss = np.inf
        self.phase_smooth_factor = 0.9

        self.pose_autoencoder = pose_autoencoder if pose_autoencoder is not None else \
            MLP(config=config, dimensions=[feature_dims["pose_dim"]], name="PoseAE")
        self.middle_layer = middle_layer if middle_layer is not None else \
            nn.Linear(in_features=self.pose_autoencoder.dimensions[-1], out_features=self.pose_autoencoder.dimensions[-1])
        self.use_label = pose_autoencoder is not None and pose_autoencoder.use_label

        cost_hidden_dim = config["cost_hidden_dim"]
        self.cost_encoder = MLP(config=config,
                            dimensions=[feature_dims["cost_dim"], cost_hidden_dim, cost_hidden_dim, cost_hidden_dim],
                            name="CostEncoder", single_module=-1)

        self.generationModel =  Model(config=config,
                                      dimensions=[feature_dims["g_input_dim"], feature_dims["g_output_dim"]],
                                      phase_input_dim = feature_dims["phase_dim"])

        self.input_dims = input_slicers
        self.output_dims = output_slicers
        self.in_slices = [0] + list(accumulate(add, input_slicers))
        self.out_slices = [0] + list(accumulate(add, output_slicers))

        self.train_set = train_set
        self.val_set = val_set
        self.test_set = test_set
        self.name = name

        self.automatic_optimization = False

    def forward(self, x):
        x_tensors = [x[:, d0:d1] for d0, d1 in zip(self.in_slices[:-1], self.in_slices[1:])]

        pose_h, mu, logvar = self.middle_layer(self.pose_autoencoder.encode(x_tensors[1]))
        embedding = torch.cat([pose_h, x_tensors[1], self.cost_encoder(x_tensors[2])], dim=1)
        out = self.generationModel(embedding, x_tensors[0])
        out_tensors = [out[:, d0:d1] for d0, d1 in zip(self.out_slices[:-1], self.out_slices[1:])]
        phase = self.update_phase(x_tensors[0], out_tensors[0])
        new_pose = self.pose_autoencoder.decode(out_tensors[1])

        return [phase, new_pose, out_tensors[-1]], pose_h, mu, logvar

    def computeCost(self, targets, trajectory):
        # targetPos = targets[:, :self.feature_dims["targetPosition"]]
        # targetRot = targets[:, self.feature_dims["targetPosition"]:]
        # posT = trajectory[:, :self.feature_dims["tPos"]]
        # rotT = trajectory[:, self.feature_dims["tPos"]:]
        #
        # targetPos = targetPos.reshape((-1, 12, 3))
        # posT = posT.reshape((-1, 12, 3))
        # # targetRot = targetRot.reshape((-1, 12, 3,3))
        # # rotT = rotT.reshape((-1, 12, 3, 3))
        #
        # posCost = torch.sum(((targetPos - posT)**2), axis=2).reshape((-1, self.feature_dims["posCost"]))
        # colLength = torch.sqrt(torch.clip(torch.sum(rotT**2, axis=2), 0))
        # rotT = rotT / colLength[:, :, :, None]

        # rotT = torch.transpose(rotT, dim0=2, dim1=3)
        # trace =torch.diagonal(targetRot @ rotT, offset=0, dim1=2, dim2=3).sum(dim=2)
        # rotCost = torch.abs(torch.arccos((torch.clamp( (trace - 1) / 2.0, -1, 1))))
        # torch.nan_to_num_(rotCost, 0)
        # rotCost = rotCost.reshape((-1, self.feature_dims["rotCost"]))

        return 0, 0

    def step(self, x, y):
        opt = self.optimizers()

        n = x.size()[1]
        tot_loss = 0
        tot_kl_loss = 0
        tot_recon_loss = 0

        x_c = x[:,0,:]
        autoregress_bools = torch.randn(n) < self.autoregress_prob
        for i in range(1, n):
            y_c = y[:,i-1,:]

            out, z, mu, logvar = self(x_c)
            recon = torch.cat(out, dim=1)
            loss = self.loss_fn(recon, y_c)
            kl_loss =  self.middle_layer.loss_function(z, mu, logvar)

            recon_loss = loss.detach()
            kl_loss = kl_loss.detach()

            tot_recon_loss += recon_loss
            tot_kl_loss += kl_loss
            tot_loss += recon_loss + kl_loss

            opt.zero_grad()
            self.manual_backward(loss+kl_loss)
            opt.step()

            if autoregress_bools[i]:
                x_c = torch.cat(out,dim=1).detach()
            else:
                x_c = x[:,i,:]

        tot_loss /= float(n)
        tot_recon_loss /= float(n)
        tot_kl_loss /= float(n)
        return tot_loss, tot_recon_loss, tot_kl_loss

    def step_eval(self, x, y):
        n = x.size()[1]
        tot_loss = 0
        tot_kl_loss = 0
        tot_recon_loss = 0

        x_c = x[:,0,:]
        autoregress_bools = torch.randn(n) < self.autoregress_prob
        for i in range(1, n):
            y_c = y[:,i-1,:]

            out, z, mu, logvar = self(x_c)
            recon = torch.cat(out, dim=1)
            loss = self.loss_fn(recon, y_c)
            kl_loss = self.middle_layer.loss_function(z, mu, logvar)

            recon_loss = loss.detach()
            kl_loss = kl_loss.detach()

            tot_recon_loss += recon_loss
            tot_kl_loss += kl_loss
            tot_loss += recon_loss + kl_loss

            if autoregress_bools[i]:
                x_c = torch.cat(out,dim=1).detach()
            else:
                x_c = x[:,i,:]

        tot_loss /= float(n)
        tot_recon_loss /= float(n)
        tot_kl_loss /= float(n)
        return tot_loss, tot_recon_loss, tot_kl_loss

    def training_step(self, batch, batch_idx):
        x, y = batch

        x = x.view(-1, self.seq_len, x.shape[-1])
        y = y.view(-1, self.seq_len, y.shape[-1])
        loss, recon_loss, kl_loss = self.step(x,y)
        self.log("ptl/train_loss", loss, prog_bar=True)
        # self.log("ptl/train_recon_loss", recon_loss, prog_bar=True)
        # self.log("ptl/train_kl_loss", kl_loss, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch

        x = x.view(-1, self.seq_len, x.shape[-1])
        y = y.view(-1, self.seq_len, y.shape[-1])
        loss, recon_loss, kl_loss = self.step_eval(x,y)
        self.log("ptl/val_loss", loss, prog_bar=True)
        self.log("ptl/val_recon_loss", recon_loss, prog_bar=True)
        self.log("ptl/val_kl_loss", kl_loss, prog_bar=True)
        return {"val_loss":loss}

    def test_step(self, batch, batch_idx):
        x, y = batch

        x = x.view(-1, self.seq_len, x.shape[-1])
        y = y.view(-1, self.seq_len, y.shape[-1])
        loss, recon_loss, kl_loss = self.step_eval(x,y)
        self.log("ptl/test_loss", loss, prog_bar=True)
        self.log("ptl/test_recon_loss", recon_loss, prog_bar=True)
        self.log("ptl/test_kl_loss", kl_loss, prog_bar=True)
        return {"test_loss":loss}

    def validation_epoch_end(self, outputs):
        if self.current_epoch > 0 and self.current_epoch % self.autoregress_ep == 0:
            self.autoregress_prob = min(self.autoregress_max_prob, self.autoregress_prob+self.autoregress_inc)

        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log("avg_val_loss", avg_loss)
        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            self.save_checkpoint(best_val_loss=self.best_val_loss.item())

    def save_checkpoint(self, best_val_loss=np.inf, checkpoint_dir=MODEL_PATH):
        path = os.path.join(checkpoint_dir, self.name)

        config = {}
        for key, val in self.config.items():
            config[key] = val

        pose_autoencoder_path = self.pose_autoencoder.save_checkpoint(best_val_loss=best_val_loss, checkpoint_dir=path)
        cost_encoder_path = self.cost_encoder.save_checkpoint(best_val_loss=best_val_loss, checkpoint_dir=path)
        generationModel_path = self.generationModel.save_checkpoint(best_val_loss=best_val_loss, checkpoint_dir=path)

        model = {"name":self.name,
                 "config":config,
                 "pose_autoencoder_path":pose_autoencoder_path,
                 "cost_encoder_path": cost_encoder_path,
                 "motionGenerationModelPath":generationModel_path,
                 "in_slices":self.in_slices,
                 "out_slices":self.out_slices,
                 "feature_dims":self.feature_dims,
                 "middle_layer_dict":self.middle_layer.state_dict(),
                 }

        if not os.path.exists(path):
            os.mkdir(path)
        with bz2.BZ2File(os.path.join(path,
                                      str(best_val_loss)+".pbz2"), "w") as f:
            pickle.dump(model, f)

    @staticmethod
    def load_checkpoint(filename, Model, MiddleModel:nn.Module=None):
        with bz2.BZ2File(filename, "rb") as f:
            obj = pickle.load(f)

        pose_autoencoder = MLP.load_checkpoint(obj["pose_autoencoder_path"])
        cost_encoder = MLP.load_checkpoint(obj["cost_encoder_path"])
        generationModel = Model.load_checkpoint(obj["motionGenerationModelPath"])

        model = MotionGenerationModel(config=obj["config"], feature_dims=obj["feature_dims"], Model=Model,
                                      input_slicers=obj["in_slices"], output_slicers=obj["out_slices"],
                                      name=obj["name"])
        if MiddleModel is None:
            MiddleModel = nn.Linear(in_features=pose_autoencoder.dimensions[-1], out_features=pose_autoencoder.dimensions[-11])

        MiddleModel.load_state_dict(obj["middle_layer_dict"])
        model.in_slices = obj["in_slices"]
        model.out_slices = obj["out_slices"]
        model.pose_autoencoder = pose_autoencoder
        model.cost_encoder = cost_encoder
        model.generationModel = generationModel

        return model

    def update_phase(self, p1, p2):
        return self.phase_smooth_factor * p2 + (1-self.phase_smooth_factor)*p1

    def configure_optimizers(self):
        optimizer = self.opt(self.parameters(), lr=self.learning_rate)
        if self.scheduler is not None:
            scheduler = self.scheduler(optimizer, **self.scheduler_param)
            return [optimizer], [scheduler]
        return optimizer

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, pin_memory=True, num_workers=self.workers)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, pin_memory=True, num_workers=self.workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, pin_memory=True, num_workers=self.workers)

    def swap_pose_encoder(self, pose_encoder=None,
                          input_dim=None, output_dim=None,
                          feature_dims=None, freeze=False):
        self.pose_autoencoder = pose_encoder
        self.input_dims = input_dim
        self.output_dims = output_dim
        self.feature_dims = feature_dims
        self.in_slices = [0] + list(accumulate(add, self.input_dims))
        self.out_slices = [0] + list(accumulate(add, self.output_dims))

        if freeze:
            self.generationModel.freeze()
            self.cost_encoder.freeze()


In [17]:
config = {
    "hidden_dim": 512,
    "k": 32,
    "lr": 1e-4,
    "batch_size": 16,
    "keep_prob": .2,
    "loss_fn":torch.nn.functional.mse_loss,
    "optimizer":torch.optim.AdamW,
    "scheduler":torch.optim.lr_scheduler.StepLR,
    "scheduler_param": {"step_size":80, "gamma":.9},
    "basis_func":"gaussian",
    "n_centroid":128,
    "k_experts": 4,
    "gate_size": 128,
    "g_hidden_dim": 512,
    "num_layers": 4,
    "autoregress_prob":0,
    "autoregress_inc":.2,
    "autoregress_ep":50,
    "autoregress_max_prob":0,
    "cost_hidden_dim":128,
    "seq_len":15,
    "device":"cuda"
    }

In [4]:
MAX_FILES = -1
data_path = "/home/nuoc/Documents/MEX/data/Dataset_R1_One_1"
data_path2 = "/home/nuoc/Documents/MEX/data/data/Dataset2_R5_Two_1"
file_paths = []
file_paths2 = []
for dname, dirs, files in os.walk(data_path):
    for i, file in enumerate(files):
        file_paths.append(os.path.join(dname, file))
        if MAX_FILES > 0 and i >= MAX_FILES:
            break
for dname, dirs, files in os.walk(data_path2):
    for i, file in enumerate(files):
        file_paths2.append(os.path.join(dname, file))
        if MAX_FILES > 0 and i >= MAX_FILES:
            break




In [5]:
phase_features = ["phase_vec_l2"]
pose_features = ["pos", "rotMat2", "velocity"]
cost_features = ["posCost", "rotCost"]
features = phase_features + pose_features + cost_features
clips = []
feature_dims = {}

In [6]:
data = F.process_data_multithread(file_paths, features)
data2 = F.process_data_multithread(file_paths2, features)

feature_dims = data[0][1]
feature_dims2 = data2[0][1]

clips = [np.copy(d[0]) for d in data]
clips2 = [np.copy(d[0]) for d in data2]



2021-05-07 14:26:52,901	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-05-07 14:27:12,263	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


In [7]:
phase_dim = sum([feature_dims[feature] for feature in phase_features])
pose_dim = sum([feature_dims[feature] for feature in pose_features])
pose_dim2 = sum([feature_dims2[feature] for feature in pose_features])
cost_dim = sum([feature_dims[feature] for feature in cost_features])
print(phase_dim, " ", pose_dim, " ", cost_dim)
print(pose_dim2)

8   372   24
384


In [8]:
x_tensors = torch.stack([F.normaliseT(torch.from_numpy(clip)).float() for clip in clips])
y_tensors = torch.stack([torch.from_numpy(clip).float() for clip in clips])

x_tensors2 = torch.stack([F.normaliseT(torch.from_numpy(clip)).float() for clip in clips2])
y_tensors2 = torch.stack([torch.from_numpy(clip).float() for clip in clips2])

pose_data1 = x_tensors[:,  :,  phase_dim:phase_dim+pose_dim]
pose_data2 = x_tensors2[:, :, phase_dim:phase_dim+pose_dim2]
pose_data = torch.cat((pose_data1, pose_data2), dim=2)

dataset = TensorDataset(pose_data, pose_data)
datasetR1_R5 = TensorDataset(torch.Tensor(x_tensors), torch.Tensor(y_tensors2))
datasetR1 = TensorDataset(torch.Tensor(x_tensors), torch.Tensor(y_tensors))
datasetR5 = TensorDataset(torch.Tensor(x_tensors2), torch.Tensor(y_tensors2))

N = len(x_tensors)

train_ratio = int(.7*N)
val_ratio = int((N-train_ratio) / 2.0)
test_ratio = N - train_ratio - val_ratio
train_set, val_set, test_set = random_split(dataset, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_set2, val_set2, test_set2 = random_split(datasetR1_R5, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_setR1, val_setR1, test_setR1 = random_split(datasetR1, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_setR5, val_setR5, test_setR5 = random_split(datasetR5, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
print(len(train_set), len(val_set), len(test_set))
print(len(train_set2), len(val_set2), len(test_set2))

168 36 36
168 36 36


In [5]:
dataset_name = "R1_One_MoGenData"
dataset_name2 = "R5__Two_MoGenData"

In [10]:
data = {"data":[train_set, val_set, test_set],
        "feature_dims":feature_dims,
        "dims":[phase_dim, pose_dim, cost_dim]}

data2 = {"data":[train_set2, val_set2, test_set2],
        "feature_dims":feature_dims2,
        "dims":[phase_dim, pose_dim2, cost_dim]}
F.save(data, dataset_name, "/home/nuoc/Documents/MEX/data")
F.save(data2, dataset_name2, "/home/nuoc/Documents/MEX/data")

In [6]:
obj = F.load("/home/nuoc/Documents/MEX/data/"+dataset_name+".pbz2")
train_set, val_set, test_set = obj["data"]
feature_dims = obj["feature_dims"]
phase_dim, pose_dim, cost_dim= obj["dims"]

obj = F.load("/home/nuoc/Documents/MEX/data/"+dataset_name2+".pbz2")
train_set2, val_set2, test_set2 = obj["data"]
feature_dims2 = obj["feature_dims"]
phase_dim2, pose_dim2, cost_dim2= obj["dims"]

In [7]:
print(len(train_set), train_set[0][0].shape)
print(len(val_set), val_set[0][0].shape)
print(len(test_set), test_set[0][0].shape)
print(feature_dims)
print(feature_dims2)


168 torch.Size([300, 756])
36 torch.Size([300, 756])
36 torch.Size([300, 756])
{'phase_vec_l2': 8, 'pos': 93, 'rotMat2': 186, 'velocity': 93, 'posCost': 12, 'rotCost': 12}
{'phase_vec_l2': 8, 'pos': 96, 'rotMat2': 192, 'velocity': 96, 'posCost': 12, 'rotCost': 12}


In [18]:
model_name = "VAE_R1-R5"
AE = VAE(config=config, input_dims=[pose_dim, pose_dim2]
                 train_set=train_set, val_set=val_set, test_set=test_set)

In [19]:
MAX_EPOCHS = 100

checkpoint_callback = ModelCheckpoint(monitor="avg_val_loss", save_top_k=3)
earlystopping = EarlyStopping(monitor="avg_val_loss", patience=20)
logger=TensorBoardLogger(save_dir="logs/", name=model_name, version="0.1")

trainer = pl.Trainer(
    default_root_dir="/home/nuoc/Documents/MEX/src/motion_generation/checkpoints",
    gpus=1, precision=16,
    callbacks=[earlystopping],
    min_epochs=20,
    logger=logger,
    max_epochs=MAX_EPOCHS,
    stochastic_weight_avg=True
)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.


In [21]:
trainer.fit(AE)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type      | Params
--------------------------------------------
0 | cluster_model | VAE_Layer | 2.1 K 
--------------------------------------------
2.1 K     Trainable params
0         Non-trainable params
2.1 K     Total params
0.008     Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

1

In [38]:
model_name = "VAE_MOE_R1_One"

feature_dims3 = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim,
    "cost_dim": cost_dim,
    "g_input_dim": config["k"] + pose_dim + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim
    }

in_slice = [phase_dim, pose_dim, cost_dim]
out_slice = [phase_dim, config["k"], cost_dim]

temp = MLP(config=config, dimensions=[pose_dim])
temp2 = MLP(config=config, dimensions=[pose_dim2])

pose_encoder = temp
pose_encoder.encoder.load_state_dict(AE.active_models[0].encoder.state_dict())
pose_encoder.decoder = temp2.decoder
# middle_layer = torch.nn.Sequential()
pose_encoder.decoder.load_state_dict(AE.active_models[1].decoder.state_dict())

# pose_encoder = MLPMIX.active_models[1]
# pose_encoder.decoder = MLPMIX.active_models[1].decoder
# middle_layer = MLPMIX.cluster_model

tempVAE = VAE(config=config, input_dims=[pose_dim, pose_dim2],
                 train_set=train_set, val_set=val_set, test_set=test_set)

# middle_layer = RBF_Layer(in_features=config["k"], out_features=config["k"], basis_func=RB.gaussian)
# middle_layer.load_state_dict(model2.middle_layer.state_dict())
middle_layer = tempVAE.cluster_model
middle_layer.load_state_dict(AE.cluster_model.state_dict())
# cost_encoder = model2.cost_encoder

<All keys matched successfully>

In [51]:
# print(pose_encoder)
print(MLPMIX.cluster_model)

Sequential(
  (0): Linear(in_features=512, out_features=512, bias=True)
  (1): ELU(alpha=1.0)
)


In [39]:
# kk = int(len(train_setR5) / 4.0)
# train_setR52 = train_setR5[:kk]
# generation_model = model2.generationModel
model = MotionGenerationModel(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=feature_dims3,
                                 input_slicers=in_slice, output_slicers=out_slice,
                                 train_set=train_set2, val_set=val_set2, test_set=test_set2,
                                 name=model_name
                                   )
# model.generationModel.gate.load_state_dict(model2.generationModel.gate.state_dict())
# model.generationModel.load_state_dict(model2.generationModel.state_dict())
# model.cost_encoder.load_state_dict(model2.cost_encoder.state_dict())
#
# model.cost_encoder.freeze()
# model.generationModel.freeze()
# model.middle_layer.requires_grad_(False)



In [42]:
MAX_EPOCHS = 100

checkpoint_callback = ModelCheckpoint(monitor="avg_val_loss", save_top_k=3)
earlystopping = EarlyStopping(monitor="avg_val_loss", patience=20)
logger=TensorBoardLogger(save_dir="logs/", name=model_name, version="0.1")

trainer = pl.Trainer(
    default_root_dir="/home/nuoc/Documents/MEX/src/motion_generation/checkpoints",
    gpus=1, precision=16,
    # callbacks=[earlystopping],
    min_epochs=20,
    logger=logger,
    max_epochs=MAX_EPOCHS,
    stochastic_weight_avg=True
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.


In [43]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP       | 946 K 
1 | middle_layer     | VAE_Layer | 2.1 K 
2 | cost_encoder     | MLP       | 36.2 K
3 | generationModel  | MoE       | 2.3 M 
-----------------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.107    Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






1

In [41]:
trainer.test(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_kl_loss': 0.0,
 'ptl/test_loss': 0.4568394422531128,
 'ptl/test_recon_loss': 0.06087976694107056,
 'test_loss': 0.6854567527770996}
--------------------------------------------------------------------------------


[{'test_loss': 0.6854567527770996,
  'ptl/test_loss': 0.4568394422531128,
  'ptl/test_recon_loss': 0.06087976694107056,
  'ptl/test_kl_loss': 0.0}]

In [23]:
clean_checkpoints(path=os.path.join(MODEL_PATH,model_name))
# clean_checkpoints(path="/home/nuoc/Documents/MEX/models/version_0.2/MLP_MoE_R1_One_1_Full_v3_smooth_loss")

In [30]:
model2 = model

In [61]:
filename = "/home/nuoc/Documents/MEX/models/version_0.2/MLP_MoE_R1_One_1_Full_TUNE/0.06325527280569077.pbz2"
with bz2.BZ2File(filename, "rb") as f:
    obj = pickle.load(f)

pose_autoencoder = MLP.load_checkpoint(obj["pose_autoencoder_path"])
cost_encoder = MLP.load_checkpoint(obj["cost_encoder_path"])
generationModel = MoE.load_checkpoint(obj["motionGenerationModelPath"])

model = MotionGenerationModel_v2(config=obj["config"], feature_dims=obj["feature_dims"], Model=MoE,
                              input_slicers=obj["in_slices"], output_slicers=obj["out_slices"],
                              name=obj["name"])

model.in_slices = obj["in_slices"]
model.out_slices = obj["out_slices"]
model.pose_autoencoder = pose_autoencoder
model.cost_encoder = cost_encoder
model.generationModel = generationModel

In [64]:
model.test_set = test_set

[0, 62, 434, 458]
MLP(
  (encoder): Sequential(
    (fc0): Linear(in_features=372, out_features=512, bias=True)
    (act0): ELU(alpha=1.0)
    (drop): Dropout(p=0.2, inplace=False)
    (fc1): Linear(in_features=512, out_features=512, bias=True)
    (act1): ELU(alpha=1.0)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
  )
  (decoder): Sequential(
    (fc0): Linear(in_features=256, out_features=512, bias=True)
    (act0): ELU(alpha=1.0)
    (drop): Dropout(p=0.2, inplace=False)
    (fc1): Linear(in_features=512, out_features=512, bias=True)
    (act1): ELU(alpha=1.0)
    (fc2): Linear(in_features=512, out_features=372, bias=True)
  )
)
MLP(
  (encoder): Sequential(
    (fc0): Linear(in_features=24, out_features=128, bias=True)
    (act0): ELU(alpha=1.0)
    (drop): Dropout(p=0.2, inplace=False)
    (fc1): Linear(in_features=128, out_features=128, bias=True)
    (act1): ELU(alpha=1.0)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
  )
  (decoder): Sequent

In [29]:
model.autoregress_prob = 0
trainer.test(model)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_loss': 0.03257506340742111, 'test_loss': 0.07678437978029251}
--------------------------------------------------------------------------------


[{'test_loss': 0.07678437978029251, 'ptl/test_loss': 0.03257506340742111}]

In [30]:
model.autoregress_prob = 1
trainer.test(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_loss': 0.07573438435792923, 'test_loss': 0.10452431440353394}
--------------------------------------------------------------------------------


[{'test_loss': 0.10452431440353394, 'ptl/test_loss': 0.07573438435792923}]

In [28]:
# model.autoregress_prob = 0

n = 5
idx = np.random.randint(0, len(test_set), n)
original = []
generated = []
# pose_idx_upper = feature_dims2["phase_dim"] + feature_dims["pos"] + feature_dims["rotMat2"]
pose_idx_upper = model.in_slices[1] + feature_dims["pos"] + feature_dims["rotMat2"]
print(pose_idx_upper)


287


In [29]:
model.autoregress_prob = 1
with torch.no_grad():
    model.eval()
    model.cpu()
    # for i in range(1):
    # original =
    # x = x_tensors[idx]
    x = torch.stack([test_set2[i][0] for i in idx])
    y = torch.stack([test_set2[i][1] for i in idx])
    shape = x.shape
    # x = x.view(-1, shape[-1])

    x = x.view(-1, 15, shape[-1])
    # n = shape[1]
    # g_frames = [x_c.unsqueeze(1)]
    g_frames = []
    #
    for i in range(0, 15):
        x_c = x[:,i,:]
        out= model(x_c)
        x_c = torch.cat(out,dim=1).detach()
        g_frames.append(x_c.unsqueeze(1))

    # out = torch.cat(model(x), dim=1).view(shape)
            # x_c = torch.cat(out, dim=1)
    # g_frames.append(x_c.unsqueeze(1))
        # original.append(o_frames)
    # generated.append(torch.cat(g_frames, dim=1))
    # generated = out
    print(g_frames[0].shape, g_frames[1].shape, g_frames[2].shape)
    generated = torch.cat(g_frames, dim=1)
    generated = generated.view(y.shape)


torch.Size([100, 1, 404]) torch.Size([100, 1, 404]) torch.Size([100, 1, 404])


In [30]:
print(y.size())
# generated = generated.view(-1, 298, generated.shape[-1])
print(generated.size())

torch.Size([5, 300, 404])
torch.Size([5, 300, 404])


In [31]:
phase_dim = in_slice[0]
toPosDim = phase_dim+feature_dims["pos"]
toRotDim = toPosDim + feature_dims["rotMat2"]

gPos = generated[:, :, phase_dim:toPosDim]
gRot = generated[:, :, toPosDim:toRotDim]

oPos = y[:, :, phase_dim:toPosDim]
oRot = y[:, :, toPosDim:toRotDim]

print(gPos.shape, gRot.shape)
print(oPos.shape, oRot.shape)


torch.Size([5, 300, 93]) torch.Size([5, 300, 186])
torch.Size([5, 300, 93]) torch.Size([5, 300, 186])


In [32]:
n = 5
clip_length = gPos.shape[1]
gPos_r = gPos.reshape((n, clip_length, -1, 3))
gRot_r = gRot.reshape((n, clip_length, -1, 3, 2))
oPos_r = oPos.reshape((n, clip_length, -1, 3))
oRot_r = oRot.reshape((n, clip_length, -1, 3, 2))

print("Pos loss: ", torch.nn.functional.mse_loss(gPos_r, oPos_r))
print("Rot loss: ", torch.nn.functional.mse_loss(gRot_r, oRot_r))

Pos loss:  tensor(0.0090)
Rot loss:  tensor(0.0342)


In [33]:
template = js.load(open("/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json"))

In [34]:
def insert_pos(positions, rotations, name="Replay"):
    shape = positions.shape
    for c in range(shape[0]):
        for f in range(shape[1]):
            for j in range(shape[2]):
                template["frames"][f]["joints"][j]["position"]["x"] = positions[c,f,j,0].item()
                template["frames"][f]["joints"][j]["position"]["y"] = positions[c,f,j,1].item()
                template["frames"][f]["joints"][j]["position"]["z"] = positions[c,f,j,2].item()

                for r, cell in enumerate(["x", "y", "z"]):
                    for col, column in enumerate(["c0", "c1"]):
                        template["frames"][f]["joints"][j]["rotMat"][column][cell] = rotations[c,f,j,r, col].item()
        with open("{}_{}.json".format(name, c), "w") as f:
            js.dump(template, f)

insert_pos(oPos_r, oRot_r, "/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/Replay_original_6-31_R2")
insert_pos(gPos_r, gRot_r, "/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/Replay_generated_6-31_R2")
