In [3]:
import os, sys
sys.path.append("motion_generation")
sys.path.append("rig_agnostic_encoding/functions")
sys.path.append("rig_agnostic_encoding/models")

from motion_generation.MoE import MoE
from motion_generation.MoE_Z import MoE as MoE_Z
import motion_generation
from motion_generation.GRU import GRU
from motion_generation.GRU_Z import GRU as GRU_Z
from motion_generation.LSTM import LSTM
from motion_generation.LSTM_Z import LSTM as LSTM_Z
from motion_generation.MotionGeneration import MotionGenerationModel as MoGen
from motion_generation.MotionGenerationEmbedd import MotionGenerationModel as MoGenZ
from motion_generation.MotionGenerationVAE import MotionGenerationModel as MoGenVAE
from motion_generation.MotionGenerationVAE_Embedd import MotionGenerationModel as MoGenVAE_Z

from MLP import MLP
from MLP_Adversarial import MLP_ADV
from MLP_MIX import MLP_MIX
from MLP_MIX import MLP_layer

from RBF import RBF
from VAE import VAE
from DEC import DEC

In [9]:
from rig_agnostic_encoding.functions.DataProcessingFunctions import clean_checkpoints
from GlobalSettings import MODEL_PATH
import bz2
from cytoolz import concat, sliding_window, accumulate
from operator import add
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import func as F
import _pickle as pickle
import json as js
import importlib
import random
import Extract as ext
import ray

In [23]:
pose_label_feature = ["level"]
# ray.init(num_cpus=4)
r1 = "R1_template.json"
r2 = "R2_template.json"
r3 = "R3_template.json"
r4 = "R4_template.json"
r5 = "R5_template.json"

def parse_pose_label(path):
    pose_labels, _ = ray.get(ext.parse.remote(path, 0))
    pose_labels, _ = F.load_features(pickle.loads(pose_labels), pose_label_feature)
    pose_labels = pose_labels[0]
    pose_labels[pose_labels < 2] = 0
    pose_labels /= 2
    return pose_labels

pose_labels1 = parse_pose_label(r1)
pose_labels2 = parse_pose_label(r2)
pose_labels3 = parse_pose_label(r3)
pose_labels4 = parse_pose_label(r4)
pose_labels5 = parse_pose_label(r5)


In [24]:
F.save(pose_labels1, "pose_label1", "")
F.save(pose_labels2, "pose_label2", "")
F.save(pose_labels3, "pose_label3", "")
F.save(pose_labels4, "pose_label4", "")
F.save(pose_labels5, "pose_label5", "")

In [1]:
config = {
    "hidden_dim": 256,
    "k": 256,
    "z_dim": 256,
    "lr": 1e-4,
    "batch_size": 16,
    "keep_prob": 0,
    "loss_fn":torch.nn.functional.mse_loss,
    "optimizer":torch.optim.AdamW,
    "scheduler":torch.optim.lr_scheduler.StepLR,
    "scheduler_param": {"step_size":80, "gamma":.9},
    "basis_func":"gaussian",
    "n_centroid":64,
    "k_experts": 4,
    "gate_size": 128,
    "g_hidden_dim": 512,
    "num_layers": 4,
    "autoregress_prob":0,
    "autoregress_inc":.3,
    "autoregress_ep":20,
    "autoregress_max_prob":1,
    "cost_hidden_dim":128,
    "seq_len":13,
    "device":"cuda"
    }

NameError: name 'torch' is not defined

In [3]:
def getFilesNames(file_paths, data_path, MAX_FILES=-1):
    for dname, dirs, files in os.walk(data_path):
        for i, file in enumerate(files):
            file_paths.append(os.path.join(dname, file))
            if MAX_FILES > 0 and i >= MAX_FILES:
                break
    return file_paths

data_path = "/home/nuoc/Documents/MEX/data/data/Dataset_R1_Two_1"
data_path2 = "/home/nuoc/Documents/MEX/data/data/Dataset_R2_Two_1"
# data_path3 = "/home/nuoc/Documents/MEX/data/data/Dataset_R3_Two_1"
# data_path4 = "/home/nuoc/Documents/MEX/data/data/Dataset_R4_Two_1"
file_paths = getFilesNames([],data_path)
file_paths2 = getFilesNames([],data_path2)
# file_paths3 = getFilesNames([],data_path3)
# file_paths4 = getFilesNames([],data_path4)

print(len(file_paths))
print(len(file_paths2))





240
240


In [3]:
class MLP_ADV(pl.LightningModule):
    def __init__(self, config:dict=None, dimensions:list=None, pose_labels=None,
                 h_dim=0, w_dim=0,
                 train_set=None, val_set=None, test_set=None, pos_dim=0, rot_dim=0, vel_dim=0,
                 name:str="model", single_module:int=0, save_period=5,
                 workers=6):

        super(MLP_ADV, self).__init__()
        self.name = name
        self.dimensions = dimensions
        self.single_module = single_module
        self.h_dim = h_dim
        self.w_dim = w_dim
        # self.hparams = config
        self.act = nn.ELU
        self.save_period = save_period
        self.workers = workers
        self.config=config

        self.pos_dim = pos_dim
        self.rot_dim = pos_dim + rot_dim
        self.vel_dim = self.rot_dim + vel_dim

        self.hidden_dim = config["hidden_dim"]
        self.k = config["k"]
        self.learning_rate = config["lr"]
        self.batch_size = config["batch_size"]

        self.dimensions = dimensions if len(dimensions) > 1 else \
            [dimensions[0], self.hidden_dim, self.hidden_dim, self.k]

        self.loss_fn = config["loss_fn"] if "loss_fn" in config else nn.functional.mse_loss
        self.opt = config["optimizer"] if "optimizer" in config else torch.optim.AdamW
        self.scheduler = config["scheduler"] if "scheduler" in config else None
        self.scheduler_param = config["scheduler_param"] if "scheduler_param" in config else None

        self.pose_labels = pose_labels  # should be Tensor(1,63) for example
        self.use_label = pose_labels is not None

        self.train_set, self.val_set, self.test_set = train_set, val_set, test_set
        self.best_val_loss = np.inf

        self.encoder, self.decoder = nn.Module(), nn.Module()
        self.build()

        if "device" not in config:
            config["device"] = "cuda"

        self.encoder.to(config["device"])
        self.decoder.to(config["device"])

        self.encoder.apply(self.init_params)
        self.decoder.apply(self.init_params)

        h_out = int((((((h_dim-2) - 3) / 3 + 1) - 2) - 3 ) / 3 + 1)
        w_out = int((((((w_dim-2) - 3) / 3 + 1) - 2) - 3 ) / 3 + 1)

        self.convDiscriminator = nn.Sequential(
            nn.Conv2d(
                in_channels = 1,out_channels = 1,
                kernel_size = 3,stride = 1,padding = 0,
            ),
            nn.MaxPool2d(kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(1), nn.ELU(),
            nn.Conv2d(
                in_channels=1, out_channels=1,
                kernel_size=3, stride=1, padding=0,
            ),
            nn.MaxPool2d(kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(1), nn.ELU(),
            nn.Flatten(),
            nn.Linear(in_features=int(h_out*w_out), out_features=1)
        )
        # self.automatic_optimization = False
        self.save_hyperparameters()

    def build(self):
        layer_sizes = list(sliding_window(2, self.dimensions))
        if self.single_module == -1 or self.single_module == 0:
            layers = []
            for i, size in enumerate(layer_sizes):
                layers.append(("fc"+str(i), nn.Linear(size[0], size[1])))
                if i < len(self.dimensions)-2:
                    layers.append(("act"+str(i), self.act()))
            self.encoder = nn.Sequential(OrderedDict(layers))
        else:
            self.encoder = nn.Sequential()

        if self.single_module == 0 or self.single_module == 1:
            layers = []
            if self.pose_labels is not None:
                layer_sizes[-1] = (layer_sizes[-1][0], layer_sizes[-1][1] + self.pose_labels.numel())
            for i, size in enumerate(layer_sizes[-1::-1]):
                layers.append(("fc"+str(i), nn.Linear(size[1], size[0])))
                if i < len(self.dimensions)-2:
                    layers.append(("act"+str(i), self.act()))
            self.decoder = nn.Sequential(OrderedDict(layers))
        else:
            self.decoder = nn.Sequential()

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))

    def encode(self, x):
        return self.encoder(x)

    def decode(self, h):
        if self.use_label:
            h = torch.cat((h, self.pose_labels.expand(h.shape[0],-1)), dim=1)
        return self.decoder(h)

    def decode_label(self, h):
        h = torch.cat((h, self.pose_labels.expand(h.shape[0],-1)), dim=1)
        return self.decoder(h)

    def loss(self, x, y):
        px, py = x[:, :self.pos_dim].detach(), y[:, :self.pos_dim].detach()
        rx, ry = x[:, self.pos_dim:self.rot_dim].detach() % 2 * np.pi, y[:, self.pos_dim:self.rot_dim].detach() % 2 * np.pi

        px_norm, py_norm = torch.sum(px ** 2), torch.sum(py ** 2)
        pos_loss = torch.mean((px - py) ** 2 / (px_norm * py_norm))

        rot_loss = nn.functional.mse_loss(rx, ry)
        recon_loss = self.loss_fn(x, y)
        return recon_loss, pos_loss, rot_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, y = batch
        prediction = self(x)
        recon_loss, pos_loss, rot_loss = self.loss(prediction, y)
        # d_opt, g_opt = self.optimizers()
        # sch1, sch2 = self.lr_schedulers()

        # d_opt.zero_grad()
        # g_opt.zero_grad()

        if optimizer_idx == 0:
            d_real = self.convDiscriminator(y.unsqueeze(1))
            d_fake = self.convDiscriminator(prediction.unsqueeze(1))
            d_loss = 0.5 * (torch.mean(d_real - 1)**2 + torch.mean(d_fake**2))
            loss = d_loss
            # self.manual_backward(d_loss)
            # d_opt.step()
            # sch1.step()
            self.log("ptl/train_d_loss", d_loss)
        else:

            d_fake = self.convDiscriminator(prediction.unsqueeze(1))
            g_loss = 0.5 * torch.mean((d_fake-1)**2) + recon_loss
            # self.manual_backward(g_loss+recon_loss)
            # g_opt.step()
            # sch2.step()
            loss = g_loss
            self.log("ptl/train_g_loss", g_loss)

        self.log("ptl/train_loss", recon_loss)
        self.log("ptl/train_pos_loss", pos_loss)
        self.log("ptl/train_rot_loss", rot_loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        prediction = self(x)
        recon_loss, pos_loss, rot_loss = self.loss(prediction, y)

        d_real = self.convDiscriminator(y.unsqueeze(1))
        d_fake = self.convDiscriminator(prediction.unsqueeze(1))
        d_loss = 0.5 * (torch.mean(d_real - 1)**2) + torch.mean(d_fake**2)

        d_fake = self.convDiscriminator(prediction.unsqueeze(1))
        g_loss = 0.5 * torch.mean((d_fake-1)**2)

        self.log("ptl/val_loss", recon_loss, prog_bar=True)
        self.log("ptl/val_d_loss", d_loss, prog_bar=True)
        self.log("ptl/val_g_loss", g_loss, prog_bar=True)
        self.log("ptl/val_pos_loss", pos_loss, prog_bar=True)
        self.log("ptl/val_rot_loss", rot_loss, prog_bar=True)
        return {"val_loss":recon_loss}

    def test_step(self, batch, batch_idx):
        x, y = batch

        prediction = self(x)
        recon_loss, pos_loss, rot_loss = self.loss(prediction, y)

        d_real = self.convDiscriminator(y.unsqueeze(1))
        d_fake = self.convDiscriminator(prediction.unsqueeze(1))
        d_loss = 0.5 * (torch.mean(d_real - 1)**2) + torch.mean(d_fake**2)

        d_fake = self.convDiscriminator(prediction.unsqueeze(1))
        g_loss = 0.5 * torch.mean((d_fake-1)**2)

        self.log("ptl/test_loss", recon_loss, prog_bar=True)
        self.log("ptl/test_d_loss", d_loss, prog_bar=True)
        self.log("ptl/test_g_loss", g_loss, prog_bar=True)
        self.log("ptl/test_pos_loss", pos_loss)
        self.log("ptl/test_rot_loss", rot_loss)
        return {"test_loss":recon_loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log("avg_val_loss", avg_loss)
        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            if self.current_epoch % self.save_period == 0:
                self.save_checkpoint(best_val_loss=self.best_val_loss.item())

    def save_checkpoint(self, best_val_loss:float=np.inf, checkpoint_dir=MODEL_PATH):
        config = {
            "hidden_dim":self.hidden_dim,
            "k":self.k,
            "lr":self.learning_rate,
            "batch_size":self.batch_size,
            "optimizer":self.opt,
            "scheduler":self.scheduler,
            "scheduler_param":self.scheduler_param,
            "device":self.config["device"],
        }
        model = {"config":config, "name":self.name,"dimensions":self.dimensions,
                 "pose_labels": self.pose_labels,
                 "single_module":self.single_module,
                 "dims": [self.pos_dim, self.rot_dim, self.vel_dim, self.h_dim, self.w_dim],
                 "encoder":self.encoder.state_dict(),
                 "decoder":self.decoder.state_dict(),
                 "discriminator":self.convDiscriminator.state_dict()}

        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        path = os.path.join(checkpoint_dir, self.name)
        if not os.path.exists(path):
            os.mkdir(path)

        filePath = os.path.join(path, str(best_val_loss)+"."+str(self.k)+".pbz2")
        with bz2.BZ2File(filePath, "w") as f:
            pickle.dump(model, f)
        return filePath

    @staticmethod
    def load_checkpoint(filePath):
        with bz2.BZ2File(filePath, "rb") as f:
            obj = pickle.load(f)
        model = MLP_ADV(config=obj["config"], single_module=obj["single_module"], pose_labels=obj["pose_labels"],
                        h_dim=obj["dims"][-2], w_dim=obj["dims"][-1],
                    name=obj["name"], dimensions=obj["dimensions"])

        model.encoder.load_state_dict(obj["encoder"])
        model.decoder.load_state_dict(obj["decoder"])
        model.convDiscriminator.load_state_dict(obj["discriminator"])
        model.pos_dim = obj["dims"][0]
        model.rot_dim = obj["dims"][1]
        model.vel_dim = obj["dims"][2]

        return model

    def freeze(self, flag=False):
        self.encoder.requires_grad_(flag)
        self.decoder.requires_grad_(flag)
        self.convDiscriminator.requires_grad_(flag)

    def configure_optimizers(self):
        optimizer_D = self.opt(self.convDiscriminator.parameters(), lr=self.learning_rate)
        optimizer_G = self.opt(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=self.learning_rate)
        if self.scheduler is not None:
            scheduler_D = self.scheduler(optimizer_D, **self.scheduler_param)
            scheduler_G = self.scheduler(optimizer_G, **self.scheduler_param)
            return [optimizer_D, optimizer_G], [scheduler_D, scheduler_G]
        else:
            return [optimizer_D, optimizer_G]

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, pin_memory=True, num_workers=self.workers)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, pin_memory=True, num_workers=self.workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, pin_memory=True, num_workers=self.workers)

    @staticmethod
    def init_params(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(.01)


In [5]:
phase_features = ["phase_vec_l2"]
pose_features = ["pos", "rotMat2", "velocity"]
cost_features = ["posCost", "rotCost"]
pose_label_feature = ["chainPos", "isLeft", "geoDistanceNormalised"]
target_features = ["targetPosition", "targetRotation"]
features = phase_features + pose_features + cost_features + target_features
clips = []
feature_dims = {}


In [5]:
data = F.process_data_multithread(file_paths, features)
data2 = F.process_data_multithread(file_paths2, features)

2021-05-14 14:18:09,737	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-05-14 14:18:31,022	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


In [6]:
pose_labels1 = F.process_data_multithread([file_paths[0]], pose_label_feature)
pose_labels2 = F.process_data_multithread([file_paths2[0]], pose_label_feature)

2021-05-14 14:18:52,963	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-05-14 14:18:58,016	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


In [7]:
obj = {"data":data, "data2":data2, "pose_label1":pose_labels1, "pose_label2":pose_labels2}
F.save(obj, filename="transfer_learning_set_R1-R2_Two_wTarget", path="/home/nuoc/Documents/MEX/data/")

In [3]:
obj = F.load("/home/nuoc/Documents/MEX/data/transfer_learning_set_R1-R2_Two_wTarget.pbz2")
data = obj["data"]
data2 = obj["data2"]

In [4]:
pose_labels1 = obj["pose_label1"]
pose_labels2 = obj["pose_label2"]

In [6]:
feature_dims = data[0][1]
feature_dims2 = data2[0][1]
clips = [np.copy(d[0]) for d in data]
clips2 = [np.copy(d[0]) for d in data2]

pose_labels1 = pose_labels1[0]
pose_labels2 = pose_labels2[0]

In [7]:
phase_dim = sum([feature_dims[feature] for feature in phase_features])
pose_dim = sum([feature_dims[feature] for feature in pose_features])
pose_dim2 = sum([feature_dims2[feature] for feature in pose_features])
cost_dim = sum([feature_dims[feature] for feature in cost_features])
target_dim = sum([feature_dims[feature] for feature in target_features])
print(phase_dim, " ", cost_dim, " ", target_dim)

8   24   48


In [8]:
x_tensors = torch.stack([F.normaliseT(torch.from_numpy(clip[:-1])).float() for clip in clips])
y_tensors = torch.stack([torch.from_numpy(clip[1:]).float() for clip in clips])

x_tensors2 = torch.stack([F.normaliseT(torch.from_numpy(clip[:-1])).float() for clip in clips2])
y_tensors2 = torch.stack([torch.from_numpy(clip[1:]).float() for clip in clips2])


pose_data1 = x_tensors[:,  :,  phase_dim:phase_dim+pose_dim]
pose_data2 = x_tensors2[:, :, phase_dim:phase_dim+pose_dim2]
pose_data = torch.cat((pose_data1, pose_data2), dim=2)

In [9]:
dataset_p = TensorDataset(pose_data, pose_data)
dataset_p1 = TensorDataset(pose_data1, pose_data1)
dataset_p2 = TensorDataset(pose_data2, pose_data2)
datasetR1 = TensorDataset(x_tensors, y_tensors)
datasetR2 = TensorDataset(x_tensors2, y_tensors2)

N = len(x_tensors)

train_ratio = int(.7*N)
val_ratio = int((N-train_ratio) / 2.0)
test_ratio = N - train_ratio - val_ratio

train_set_p, val_set_p, test_set_p = random_split(dataset_p, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_set_p1, val_set_p1, test_set_p1 = random_split(dataset_p1, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_set_p2, val_set_p2, test_set_p2 = random_split(dataset_p2, [val_ratio*2, val_ratio, train_ratio-test_ratio], generator=torch.Generator().manual_seed(2021))
train_set_p2_F, val_set_p2_F, test_set_p2_F = random_split(dataset_p2, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_setR1, val_setR1, test_setR1 = random_split(datasetR1, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
train_setR2, val_setR2, test_setR2 = random_split(datasetR2, [val_ratio, val_ratio, train_ratio], generator=torch.Generator().manual_seed(2021))
train_setR2_F, val_setR2_F, test_setR2_F = random_split(datasetR2, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))

In [None]:
def extract_targets(train_set, val_set, test_set, target_dim):
    t1, t2, t3, t33 = [], [], [], []
    for i in range(len(train_set)):
        x = train_set[i][0]
        y = train_set[i][1]
        t1.append((x[:, :-target_dim], y[:, :-target_dim]))
    for i in range(len(val_set)):
        x = val_set[i][0]
        y = val_set[i][1]
        t2.append((x[:, :-target_dim:], y[:, :-target_dim:]))
    for i in range(len(test_set)):
        x = test_set[i][0]
        y = test_set[i][1]
        t3.append((x[:, :-target_dim:], y[:, :-target_dim:]))
        t33.append((x[:, :-target_dim:], y))
    return t1, t2, t3, t33




In [None]:
t1, v1, te1, te11 = extract_targets(train_set=train_setR1, val_set=val_setR1, test_set=test_setR1, target_dim=target_dim)
t2, v2, te2, te22 = extract_targets(train_set=train_setR2, val_set=val_setR2, test_set=test_setR2, target_dim=target_dim)

# AE ONLY

In [10]:
h_dim = train_set_p1[0][0].shape[0]
w_dim = train_set_p1[0][0].shape[1]

ae_name1 = "AE_R1"
# ae1 = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
#                  pos_dim=feature_dims["pos"], rot_dim=feature_dims["rotMat2"], vel_dim=feature_dims["velocity"],
#                  train_set=train_set_p1, val_set=val_set_p1, test_set=test_set_p1, name=ae_name1)
#
# ae1 = MLP_ADV.load_checkpoint("/home/nuoc/Documents/MEX/models/version_0.3/AE_R1/0.001.256.pbz2")
# ae1.name = ae_name1
def load_ae(path):
    with bz2.BZ2File(path, "rb") as f:
        obj = pickle.load(f)
    model = MLP_ADV(config=obj["config"], single_module=obj["single_module"], pose_labels=obj["pose_labels"],
                h_dim=h_dim, w_dim=w_dim,
                name=obj["name"], dimensions=obj["dimensions"])

    model.encoder.load_state_dict(obj["encoder"])
    model.decoder.load_state_dict(obj["decoder"])
    model.convDiscriminator.load_state_dict(obj["discriminator"])
    model.pos_dim = obj["dims"][0]
    model.rot_dim = obj["dims"][1]
    model.vel_dim = obj["dims"][2]
    return model

ae1 = load_ae("/home/nuoc/Documents/MEX/models/version_0.3/AE_R1/0.001.256.pbz2")

In [13]:
ae_name2 = "AE_R2"
ae2 = MLP_ADV(config=config, dimensions=[pose_dim2],h_dim=h_dim, w_dim=w_dim,
                 pos_dim=feature_dims2["pos"], rot_dim=feature_dims2["rotMat2"], vel_dim=feature_dims2["velocity"],
                 train_set=train_set_p2_F, val_set=val_set_p2_F, test_set=test_set_p2_F, name=ae_name2)
#
# ae2 = MLP_ADV.load_checkpoint("/home/nuoc/Documents/MEX/models/version_0.3/AE_R2/0.002.256.pbz2")
# ae2 = load_ae("/home/nuoc/Documents/MEX/models/version_0.3/AE_R2/0.002.256.pbz2")
# ae2.name = ae_name2

In [5]:
ae_name2 = "AE_R2"
ae2 = MLP_ADV(config=config, dimensions=[100], w_dim=100, h_dim=100, name=ae_name2)


In [6]:
print(ae2.summarize())


  | Name              | Type       | Params
-------------------------------------------------
0 | encoder           | Sequential | 157 K 
1 | decoder           | Sequential | 157 K 
2 | convDiscriminator | Sequential | 125   
-------------------------------------------------
314 K     Trainable params
0         Non-trainable params
314 K     Total params
1.259     Total estimated model params size (MB)


  | Name              | Type       | Params
-------------------------------------------------
0 | encoder           | Sequential | 157 K 
1 | decoder           | Sequential | 157 K 
2 | convDiscriminator | Sequential | 125   
-------------------------------------------------
314 K     Trainable params
0         Non-trainable params
314 K     Total params
1.259     Total estimated model params size (MB)


In [20]:
ae1.test_set = test_set_p1
ae2.test_set = test_set_p2


In [21]:
trainer = pl.Trainer()
trainer.test(ae1)
trainer.test(ae2)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_d_loss': 0.39770177006721497,
 'ptl/test_g_loss': 0.11850268393754959,
 'ptl/test_loss': 0.060140326619148254,
 'ptl/test_pos_loss': 4.877923694308894e-13,
 'ptl/test_rot_loss': 1.749091386795044,
 'test_loss': 0.07371098548173904}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_d_loss': 0.3837208151817322,
 'ptl/test_g_loss': 0.12805208563804626,
 'ptl/test_loss': 0.051305778324604034,
 'ptl/test_pos_loss': 2.1457103671976285e-13,
 'ptl/test_rot_loss': 3.004026412963867,
 'test_loss': 0.06679833680391312}
--------------------------------------------------------------------------------


[{'test_loss': 0.06679833680391312,
  'ptl/test_loss': 0.051305778324604034,
  'ptl/test_d_loss': 0.3837208151817322,
  'ptl/test_g_loss': 0.12805208563804626,
  'ptl/test_pos_loss': 2.1457103671976285e-13,
  'ptl/test_rot_loss': 3.004026412963867}]

In [103]:
def fit(model, name, version="0.1", MIN_EPOCHS=20, MAX_EPOCHS=100, useEarlyStopping=False, patience=10):
    if useEarlyStopping:
        earlystopping = EarlyStopping(monitor="avg_val_loss",patience=patience)
        callbacks = [earlystopping]
    else:
        callbacks = []
    logger=TensorBoardLogger(save_dir="logs/", name=name, version=version)

    trainer = pl.Trainer(logger=logger, gpus=1, precision=16)
    model.autoregress_prob = 1
    trainer.test(model)
    model.autoregress_prob = .3

    trainer = pl.Trainer(
        default_root_dir="/home/nuoc/Documents/MEX/src/motion_generation/checkpoints",
        gpus=1, precision=16,
        callbacks= callbacks,
        min_epochs=MIN_EPOCHS,
        logger=logger,
        max_epochs=MAX_EPOCHS,
        stochastic_weight_avg=True
    )
    trainer.logger.log_hyperparams(model.hparams_initial)
    trainer.fit(model)
    model.autoregress_prob = 1
    trainer.test(model)

In [None]:
fit(ae1, ae_name1, MAX_EPOCHS=300, useEarlyStopping=True)

In [None]:
fit(ae2, ae_name2, version="0.2", MAX_EPOCHS=300, useEarlyStopping=True)

In [None]:
clean_checkpoints(os.path.join(MODEL_PATH,ae_name1))
clean_checkpoints(os.path.join(MODEL_PATH,ae_name2))

ae1.save_checkpoint(best_val_loss=0.001)
ae2.save_checkpoint(best_val_loss=0.002)

In [None]:
generate_animation_ae(model=ae1, test_set=test_set_p1, feature_dims=feature_dims,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Replay/"+ae_name1)
generate_animation_ae(model=ae2, test_set=test_set_p2_F, feature_dims=feature_dims2,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Replay/"+ae_name2)


In [37]:

class MLP_layer(nn.Module):
    def __init__(self, in_features=0, hidden_dim=0, out_features=0, device="cuda"):
        super().__init__()
        self.device=device
        if in_features==0:
            self.model = nn.Sequential()
        else:
            self.model = nn.Sequential(
                nn.Linear(in_features=in_features, out_features=hidden_dim),
                nn.ELU(),
                nn.Linear(in_features=hidden_dim, out_features=out_features)
            )
    def forward(self, x):
        return self.model(x)
    def loss(self, *args):
        return torch.zeros(1, device=self.device)

In [100]:
model_name = "MLP_MoE_R1_ADV"
model_name2 = "MLP_MoE_R2_ADV"

featureDim = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim,
    "cost_dim": cost_dim,
    "target_dim":target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims["pos"],
    "rot_dim":feature_dims["rotMat2"],
    "vel_dim":feature_dims["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }
featureDim2 = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim2,
    "cost_dim": cost_dim,
    "target_dim": target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims2["pos"],
    "rot_dim":feature_dims2["rotMat2"],
    "vel_dim":feature_dims2["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }

in_slice = [phase_dim, pose_dim, cost_dim, target_dim]
in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]

out_slice = [phase_dim, config["k"], cost_dim]

temp = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
               pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"])
temp2 = MLP_ADV(config=config, dimensions=[pose_dim2],h_dim=h_dim, w_dim=w_dim,
               pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"])

pose_encoder = temp
pose_encoder2 = temp2
pose_encoder.encoder.load_state_dict(ae1.encoder.state_dict())
pose_encoder.decoder.load_state_dict(ae1.decoder.state_dict())
pose_encoder.convDiscriminator.load_state_dict(ae1.convDiscriminator.state_dict())

pose_encoder2.encoder.load_state_dict(ae2.encoder.state_dict())
pose_encoder2.decoder.load_state_dict(ae2.decoder.state_dict())
pose_encoder2.convDiscriminator.load_state_dict(ae2.convDiscriminator.state_dict())

<All keys matched successfully>

In [39]:
middle_layer = MLP_layer(in_features=0, hidden_dim=0, out_features=0)
model1 = MoGen(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim, use_advLoss=True,
                                 input_slicers=in_slice, output_slicers=out_slice,
                                 train_set=train_setR1, val_set=val_setR1, test_set=val_setR1+test_setR1,
                                 name=model_name
                                   )

In [101]:
model2 = MoGen(config=config, Model=MoE, pose_autoencoder=pose_encoder2, middle_layer=middle_layer,
                                 feature_dims=featureDim2, use_advLoss=True,
                                 input_slicers=in_slice2, output_slicers=out_slice,
                                 train_set=train_setR2, val_set=val_setR2, test_set=val_setR2+test_setR2,
                                 name=model_name2
                                   )

In [26]:
model2_F = MoGen(config=config, Model=MoE, pose_autoencoder=pose_encoder2, middle_layer=middle_layer,
                                 feature_dims=featureDim2, use_advLoss=True,
                                 input_slicers=in_slice2, output_slicers=out_slice,
                                 train_set=train_setR2_F, val_set=val_setR2_F, test_set=val_setR2_F+test_setR2_F,
                                 name=model_name2
                                   )

In [40]:
fit(model1, version="0.5", MAX_EPOCHS=120, name=model_name)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | MLP_layer | 0     
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.4 M 
-----------------------------------------------
2.9 M     Trainable params
0         Non-trainable params
2.9 M     Total params
11.785    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.0,
 'ptl/test_adv_loss': 0.7820941805839539,
 'ptl/test_loss': 0.24763910472393036,
 'ptl/test_min_pos_cost': 0.8251054286956787,
 'ptl/test_min_rot_cost': 0.7447419762611389,
 'ptl/test_min_target_pos_cost': 0.3732735812664032,
 'ptl/test_min_target_rot_cost': 0.34755954146385193,
 'ptl/test_pos_loss': 4.514921556619811e-07,
 'ptl/test_rot_loss': 12.75084114074707,
 'ptl/test_sum_pos_cost': 9.909703254699707,
 'ptl/test_sum_rot_cost': 8.95154094696045,
 'ptl/test_sum_target_pos_cost': 4.481944561004639,
 'ptl/test_sum_target_rot_cost': 4.179253578186035,
 'test_loss': 0.24354878067970276}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…





HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.0,
 'ptl/test_adv_loss': 0.1426272690296173,
 'ptl/test_loss': 0.08220662921667099,
 'ptl/test_min_pos_cost': 0.8256732225418091,
 'ptl/test_min_rot_cost': 0.7462548017501831,
 'ptl/test_min_target_pos_cost': 0.3732735812664032,
 'ptl/test_min_target_rot_cost': 0.34755954146385193,
 'ptl/test_pos_loss': 5.829712823590683e-10,
 'ptl/test_rot_loss': 8.5796537399292,
 'ptl/test_sum_pos_cost': 9.911153793334961,
 'ptl/test_sum_rot_cost': 8.95975399017334,
 'ptl/test_sum_target_pos_cost': 4.481944561004639,
 'ptl/test_sum_target_rot_cost': 4.179253578186035,
 'test_loss': 0.09885984659194946}
--------------------------------------------------------------------------------


In [104]:
model2.generationModel.gate.load_state_dict(model1.generationModel.gate.state_dict())
model2.generationModel.load_state_dict(model1.generationModel.state_dict())
model2.generationModel.gate.requires_grad_(False)
model2.generationModel.requires_grad_(False)
model2.name="hp_param_test_model22222"
fit(model2, name=model2.name, version="0.2", MAX_EPOCHS=1)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | MLP_layer | 0     
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.4 M 
-----------------------------------------------
497 K     Trainable params
2.4 M     Non-trainable params
2.9 M     Total params
11.785    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.0,
 'ptl/test_adv_loss': 0.15142472088336945,
 'ptl/test_loss': 0.17537879943847656,
 'ptl/test_min_pos_cost': 0.7270989418029785,
 'ptl/test_min_rot_cost': 0.645555317401886,
 'ptl/test_min_target_pos_cost': 0.4006081521511078,
 'ptl/test_min_target_rot_cost': 0.3534235954284668,
 'ptl/test_pos_loss': 1.5835328603941434e-09,
 'ptl/test_rot_loss': 5.346451282501221,
 'ptl/test_sum_pos_cost': 8.732467651367188,
 'ptl/test_sum_rot_cost': 7.755195617675781,
 'ptl/test_sum_target_pos_cost': 4.815293312072754,
 'ptl/test_sum_target_rot_cost': 4.250393867492676,
 'test_loss': 0.16512751579284668}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.0,
 'ptl/test_adv_loss': 0.1578119844198227,
 'ptl/test_loss': 0.17361554503440857,
 'ptl/test_min_pos_cost': 0.7270669937133789,
 'ptl/test_min_rot_cost': 0.6457052826881409,
 'ptl/test_min_target_pos_cost': 0.4006081521511078,
 'ptl/test_min_target_rot_cost': 0.3534235954284668,
 'ptl/test_pos_loss': 1.6307420969141617e-09,
 'ptl/test_rot_loss': 5.489764213562012,
 'ptl/test_sum_pos_cost': 8.729594230651855,
 'ptl/test_sum_rot_cost': 7.756972312927246,
 'ptl/test_sum_target_pos_cost': 4.815293312072754,
 'ptl/test_sum_target_rot_cost': 4.250393867492676,
 'test_loss': 0.16107825934886932}
--------------------------------------------------------------------------------


In [29]:
model2_F.generationModel.gate.load_state_dict(model1.generationModel.gate.state_dict())
model2_F.generationModel.load_state_dict(model1.generationModel.state_dict())

fit(model2_F, name=model_name2, version="full", MAX_EPOCHS=100)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type       | Params
------------------------------------------------
0 | pose_autoencoder | MLP_ADV    | 455 K 
1 | middle_layer     | Sequential | 0     
2 | cost_encoder     | MLP        | 42.4 K
3 | generationModel  | MoE        | 2.4 M 
------------------------------------------------
2.9 M     Trainable params
0         Non-trainable params
2.9 M     Total params
11.785    Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.16165263950824738,
 'ptl/test_loss': 0.07438100874423981,
 'ptl/test_min_pos_cost': -0.2278854101896286,
 'ptl/test_min_rot_cost': 0.07796655595302582,
 'ptl/test_min_target_pos_cost': -0.07601615786552429,
 'ptl/test_min_target_rot_cost': 0.10112982988357544,
 'ptl/test_pos_loss': nan,
 'ptl/test_rot_loss': nan,
 'ptl/test_sum_pos_cost': -2.7257392406463623,
 'ptl/test_sum_rot_cost': 0.9410366415977478,
 'ptl/test_sum_target_pos_cost': -0.8902585506439209,
 'ptl/test_sum_target_rot_cost': 1.2244354486465454,
 'test_loss': 0.0652601420879364}
--------------------------------------------------------------------------------


In [75]:
# clean_checkpoints(os.path.join(MODEL_PATH,model_name))
clean_checkpoints(os.path.join(MODEL_PATH,model_name2))

# model1.save_checkpoint(best_val_loss=0.001)
model2.save_checkpoint(best_val_loss=0.001)
# model2_F.save_checkpoint(best_val_loss=0.001)

'/home/nuoc/Documents/MEX/models/version_0.3/MLP_MoE_R2_ADV/0.001.pbz2'

In [74]:
# generate_animation(model=model1, test_set=test_setR1, feature_dims=feature_dims,
#                    template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json",
#                    output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Replay/"+model_name)
generate_animation(model=model2, test_set=test_setR2, feature_dims=feature_dims2,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Replay/"+model_name2+"v2")
# generate_animation(model=model2_F, test_set=test_setR2, feature_dims=feature_dims2,
#                    template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
#                    output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Replay/"+model_name2+"F")


torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])
torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])


# MLP MIX - Z as Input

In [13]:
featureDim = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim,
    "cost_dim": cost_dim,
    "target_dim":target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims["pos"],
    "rot_dim":feature_dims["rotMat2"],
    "vel_dim":feature_dims["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }
featureDim2 = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim2,
    "cost_dim": cost_dim,
    "target_dim": target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims2["pos"],
    "rot_dim":feature_dims2["rotMat2"],
    "vel_dim":feature_dims2["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }


in_slice = [phase_dim, pose_dim, cost_dim, target_dim]
in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]

out_slice = [phase_dim, config["k"], cost_dim]

mlpmix_name1 = "MLPMIX-R1-Z-Input"
mlpmix1 = MLP_MIX(config=config, input_dims=[pose_dim])
temp = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
               pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"],)

mlpmix1.active_models[0] = temp
mlpmix1.active_models[0].encoder.load_state_dict(ae1.encoder.state_dict())
mlpmix1.active_models[0].decoder.load_state_dict(ae1.decoder.state_dict())

pose_encoder = mlpmix1.active_models[0]
middle_layer = mlpmix1.cluster_model
pose_encoder.convDiscriminator.load_state_dict(ae1.convDiscriminator.state_dict())

AttributeError: cannot assign module before Module.__init__() call

In [55]:
model1 = MoGenZ(config=config, Model=MoE_Z, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim, use_advLoss=True,
                                 input_slicers=in_slice, output_slicers=out_slice,
                                 train_set=train_setR1, val_set=val_setR1, test_set=test_setR1+val_setR1,
                                 name=mlpmix_name1
                                   )


In [23]:
fit(model1, mlpmix_name1, version="0.3",MAX_EPOCHS=150)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type       | Params
------------------------------------------------
0 | pose_autoencoder | MLP_ADV    | 455 K 
1 | middle_layer     | Sequential | 49.4 K
2 | cost_encoder     | MLP        | 42.4 K
3 | generationModel  | MoE        | 3.1 M 
------------------------------------------------
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params
14.669    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.7822639346122742,
 'ptl/test_loss': 0.24796360731124878,
 'ptl/test_min_pos_cost': -0.506141722202301,
 'ptl/test_min_rot_cost': 0.3058273494243622,
 'ptl/test_min_target_pos_cost': -0.12436017394065857,
 'ptl/test_min_target_rot_cost': 0.3015017807483673,
 'ptl/test_pos_loss': nan,
 'ptl/test_rot_loss': nan,
 'ptl/test_sum_pos_cost': -6.050604820251465,
 'ptl/test_sum_rot_cost': 3.707205057144165,
 'ptl/test_sum_target_pos_cost': -1.4784694910049438,
 'ptl/test_sum_target_rot_cost': 3.63797664642334,
 'test_loss': 0.2492237240076065}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.1482972353696823,
 'ptl/test_loss': 0.07878115773200989,
 'ptl/test_min_pos_cost': -0.5057556629180908,
 'ptl/test_min_rot_cost': 0.3058273494243622,
 'ptl/test_min_target_pos_cost': -0.12436017394065857,
 'ptl/test_min_target_rot_cost': 0.3015017807483673,
 'ptl/test_pos_loss': nan,
 'ptl/test_rot_loss': nan,
 'ptl/test_sum_pos_cost': -6.061001300811768,
 'ptl/test_sum_rot_cost': 3.680527687072754,
 'ptl/test_sum_target_pos_cost': -1.4784694910049438,
 'ptl/test_sum_target_rot_cost': 3.63797664642334,
 'test_loss': 0.06701288372278214}
--------------------------------------------------------------------------------


In [51]:
mlpmix_name2 = "MLPMIX-R2-Z-Concat-Reduced"
mlpmix2 = MLP_MIX(config=config, input_dims=[pose_dim2])
temp = MLP_ADV(config=config, dimensions=[pose_dim2], h_dim=h_dim, w_dim=w_dim,)

in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]
out_slice = [phase_dim, config["k"], cost_dim]

temp.encoder.load_state_dict(ae2.encoder.state_dict())
temp.decoder.load_state_dict(ae2.decoder.state_dict())

pose_encoder = temp
middle_layer = mlpmix2.cluster_model
pose_encoder.convDiscriminator.load_state_dict(ae2.convDiscriminator.state_dict())

middle_layer.load_state_dict(model1.middle_layer.state_dict())
middle_layer.requires_grad_(False)


model2 = MoGenZ(config=config, Model=MoE_Z, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim2,use_advLoss=True,
                                 input_slicers=in_slice2, output_slicers=out_slice,
                                 train_set=train_setR2, val_set=val_setR2, test_set=val_setR2+test_setR2,
                                 name=mlpmix_name2
                                   )

model2.generationModel.gate.load_state_dict(model1.generationModel.gate.state_dict())
model2.generationModel.load_state_dict(model1.generationModel.state_dict())

RuntimeError: Error(s) in loading state_dict for Sequential:
	Missing key(s) in state_dict: "0.weight", "0.bias", "2.weight", "2.bias". 
	Unexpected key(s) in state_dict: "centres", "sigmas". 

In [41]:
fit(model2, mlpmix_name2, version="0.2", MAX_EPOCHS=100)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type       | Params
------------------------------------------------
0 | pose_autoencoder | MLP_ADV    | 455 K 
1 | middle_layer     | Sequential | 49.4 K
2 | cost_encoder     | MLP        | 42.4 K
3 | generationModel  | MoE        | 3.1 M 
------------------------------------------------
3.6 M     Trainable params
49.4 K    Non-trainable params
3.7 M     Total params
14.669    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.7915453314781189,
 'ptl/test_loss': 0.2761686444282532,
 'ptl/test_min_pos_cost': -0.23226255178451538,
 'ptl/test_min_rot_cost': 0.08159427344799042,
 'ptl/test_min_target_pos_cost': -0.07874662429094315,
 'ptl/test_min_target_rot_cost': 0.10372575372457504,
 'ptl/test_pos_loss': nan,
 'ptl/test_rot_loss': nan,
 'ptl/test_sum_pos_cost': -2.763516664505005,
 'ptl/test_sum_rot_cost': 0.9894691705703735,
 'ptl/test_sum_target_pos_cost': -0.921808660030365,
 'ptl/test_sum_target_rot_cost': 1.2529796361923218,
 'test_loss': 0.274198442697525}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.16393332183361053,
 'ptl/test_loss': 0.1229998841881752,
 'ptl/test_min_pos_cost': -0.23197278380393982,
 'ptl/test_min_rot_cost': 0.08170578628778458,
 'ptl/test_min_target_pos_cost': -0.07874662429094315,
 'ptl/test_min_target_rot_cost': 0.10372575372457504,
 'ptl/test_pos_loss': nan,
 'ptl/test_rot_loss': nan,
 'ptl/test_sum_pos_cost': -2.7753686904907227,
 'ptl/test_sum_rot_cost': 0.9832243323326111,
 'ptl/test_sum_target_pos_cost': -0.921808660030365,
 'ptl/test_sum_target_rot_cost': 1.2529796361923218,
 'test_loss': 0.15640360116958618}
--------------------------------------------------------------------------------


In [42]:
clean_checkpoints(os.path.join(MODEL_PATH,mlpmix_name1))
clean_checkpoints(os.path.join(MODEL_PATH,mlpmix_name2))

model1.save_checkpoint(best_val_loss=0.001)
model2.save_checkpoint(best_val_loss=0.001)


'/home/nuoc/Documents/MEX/models/version_0.3/MLPMIX-R2-Z-Concat-Reduced/0.001.pbz2'

In [45]:
generate_animation(model=model1, test_set=test_setR1, feature_dims=feature_dims,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+mlpmix_name1)
generate_animation(model=model2, test_set=test_setR2, feature_dims=feature_dims2,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+mlpmix_name2)

torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])
torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])

torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])
torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])


# RBF -  Z as Input

In [22]:
featureDim = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim,
    "cost_dim": cost_dim,
    "target_dim":target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims["pos"],
    "rot_dim":feature_dims["rotMat2"],
    "vel_dim":feature_dims["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }
featureDim2 = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim2,
    "cost_dim": cost_dim,
    "target_dim": target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims2["pos"],
    "rot_dim":feature_dims2["rotMat2"],
    "vel_dim":feature_dims2["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }

rbf_name1 = "RBF-R1-Z-Input"
rbf1 = RBF(config=config, input_dims=[pose_dim],
           pos_dim=[featureDim["pos_dim"]], rot_dim=[featureDim["rot_dim"]], vel_dim=[featureDim["vel_dim"]],
           train_set=train_set_p1, val_set=val_set_p1, test_set=test_set_p1,
           name="RBF-z-Input")

temp = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
                 pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"])

rbf1.active_models[0] = temp
rbf1.active_models[0].encoder.load_state_dict(ae1.encoder.state_dict())
rbf1.active_models[0].decoder.load_state_dict(ae1.decoder.state_dict())
rbf1.active_models[0].convDiscriminator.load_state_dict(ae1.convDiscriminator.state_dict())


<All keys matched successfully>

In [23]:
in_slice = [phase_dim, pose_dim, cost_dim, target_dim]
out_slice = [phase_dim, config["k"], cost_dim]

pose_encoder = rbf1.active_models[0]
middle_layer = rbf1.cluster_model
# middle_layer.requires_grad_(False)

model1 = MoGen(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim, use_advLoss=True,
                                 input_slicers=in_slice, output_slicers=out_slice,
                                 train_set=train_setR1, val_set=val_setR1, test_set=test_setR1+val_setR1,
                                 name=rbf_name1
                                   )

In [24]:
fit(model1, rbf_name1, version="0.2",MAX_EPOCHS=120)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | RBF_Layer | 65.8 K
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.4 M 
-----------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.048    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 1.2711893320083618,
 'ptl/test_adv_loss': 0.7819749116897583,
 'ptl/test_loss': 0.24783267080783844,
 'ptl/test_min_pos_cost': 0.825386643409729,
 'ptl/test_min_rot_cost': 0.7450899481773376,
 'ptl/test_min_target_pos_cost': 0.37322527170181274,
 'ptl/test_min_target_rot_cost': 0.34755951166152954,
 'ptl/test_pos_loss': 4.4907221763423877e-07,
 'ptl/test_rot_loss': 12.717830657958984,
 'ptl/test_sum_pos_cost': 9.910783767700195,
 'ptl/test_sum_rot_cost': 8.953682899475098,
 'ptl/test_sum_target_pos_cost': 4.4819440841674805,
 'ptl/test_sum_target_rot_cost': 4.179253101348877,
 'test_loss': 0.2490869164466858}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.07342609763145447,
 'ptl/test_adv_loss': 0.15071813762187958,
 'ptl/test_loss': 0.1079227402806282,
 'ptl/test_min_pos_cost': 0.8259875178337097,
 'ptl/test_min_rot_cost': 0.7457101941108704,
 'ptl/test_min_target_pos_cost': 0.37322527170181274,
 'ptl/test_min_target_rot_cost': 0.34755951166152954,
 'ptl/test_pos_loss': 5.331523000862148e-10,
 'ptl/test_rot_loss': 9.50162410736084,
 'ptl/test_sum_pos_cost': 9.912969589233398,
 'ptl/test_sum_rot_cost': 8.955196380615234,
 'ptl/test_sum_target_pos_cost': 4.4819440841674805,
 'ptl/test_sum_target_rot_cost': 4.179253101348877,
 'test_loss': 0.10524389147758484}
--------------------------------------------------------------------------------


In [52]:
rbf_name2 = "RBF-R2-Z-In-Reduced"
rbf2 = RBF(config=config, input_dims=[pose_dim2])
temp = MLP_ADV(config=config, dimensions=[pose_dim2], h_dim=h_dim, w_dim=w_dim,
                 pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"])


in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]
out_slice = [phase_dim, config["k"], cost_dim]

rbf2.active_models[0] = temp
rbf2.active_models[0].encoder.load_state_dict(ae2.encoder.state_dict())
rbf2.active_models[0].decoder.load_state_dict(ae2.decoder.state_dict())

pose_encoder = rbf2.active_models[0]
middle_layer = rbf2.cluster_model
pose_encoder.convDiscriminator.load_state_dict(ae2.convDiscriminator.state_dict())

middle_layer.load_state_dict(model1.middle_layer.state_dict())
middle_layer.requires_grad_(False)


model2 = MoGen(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim2, use_advLoss=True,
                                 input_slicers=in_slice2, output_slicers=out_slice,
                                 train_set=train_setR2, val_set=val_setR2, test_set=val_setR2+test_setR2,
                                 name=rbf_name2
                                   )

model2.generationModel.gate.load_state_dict(model1.generationModel.gate.state_dict())
model2.generationModel.load_state_dict(model1.generationModel.state_dict())

<All keys matched successfully>

In [53]:
fit(model2, rbf_name2, version="0.1", MAX_EPOCHS=100)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | RBF_Layer | 65.8 K
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.4 M 
-----------------------------------------------
2.9 M     Trainable params
65.8 K    Non-trainable params
3.0 M     Total params
12.048    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.07342610508203506,
 'ptl/test_adv_loss': 0.6219784021377563,
 'ptl/test_loss': 0.366985946893692,
 'ptl/test_min_pos_cost': 0.7270944714546204,
 'ptl/test_min_rot_cost': 0.6450060606002808,
 'ptl/test_min_target_pos_cost': 0.4006081521511078,
 'ptl/test_min_target_rot_cost': 0.3534235954284668,
 'ptl/test_pos_loss': 1.990012599151214e-08,
 'ptl/test_rot_loss': 10.056227684020996,
 'ptl/test_sum_pos_cost': 8.735350608825684,
 'ptl/test_sum_rot_cost': 7.751901626586914,
 'ptl/test_sum_target_pos_cost': 4.815293312072754,
 'ptl/test_sum_target_rot_cost': 4.250393867492676,
 'test_loss': 0.3664238154888153}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 0.07342610508203506,
 'ptl/test_adv_loss': 0.17492277920246124,
 'ptl/test_loss': 0.1535314917564392,
 'ptl/test_min_pos_cost': 0.727071225643158,
 'ptl/test_min_rot_cost': 0.645881175994873,
 'ptl/test_min_target_pos_cost': 0.4006081521511078,
 'ptl/test_min_target_rot_cost': 0.3534235954284668,
 'ptl/test_pos_loss': 8.506585436052205e-10,
 'ptl/test_rot_loss': 5.528582572937012,
 'ptl/test_sum_pos_cost': 8.73080825805664,
 'ptl/test_sum_rot_cost': 7.7568039894104,
 'ptl/test_sum_target_pos_cost': 4.815293312072754,
 'ptl/test_sum_target_rot_cost': 4.250393867492676,
 'test_loss': 0.15254905819892883}
--------------------------------------------------------------------------------


In [54]:
clean_checkpoints(os.path.join(MODEL_PATH,rbf_name1))
clean_checkpoints(os.path.join(MODEL_PATH,rbf_name2))

model1.save_checkpoint(best_val_loss=0.001)
model2.save_checkpoint(best_val_loss=0.001)





'/home/nuoc/Documents/MEX/models/version_0.3/RBF-R2-Z-In-Reduced/0.001.pbz2'

In [55]:
# generate_animation(model=model1, test_set=test_setR1, feature_dims=feature_dims,
#                    template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json",
#                    output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+rbf_name1)
generate_animation(model=model2, test_set=test_setR2, feature_dims=feature_dims2,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+rbf_name2+"_after")




torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])
torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])


# vae -  Z as embedding

In [56]:
featureDim = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim,
    "cost_dim": cost_dim,
    "target_dim":target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims["pos"],
    "rot_dim":feature_dims["rotMat2"],
    "vel_dim":feature_dims["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }
featureDim2 = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim2,
    "cost_dim": cost_dim,
    "target_dim": target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims2["pos"],
    "rot_dim":feature_dims2["rotMat2"],
    "vel_dim":feature_dims2["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }

vae_name1 = "VAE-R1-Z-In"
vae1 = VAE(config=config, input_dims=[pose_dim],
           pos_dim=[featureDim["pos_dim"]], rot_dim=[featureDim["rot_dim"]], vel_dim=[featureDim["vel_dim"]],
           train_set=train_set_p1, val_set=val_set_p1, test_set=test_set_p1,
           name="VAE-z-concat")

temp = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
                 pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"])


vae1.active_models[0] = temp
vae1.active_models[0].encoder.load_state_dict(ae1.encoder.state_dict())
vae1.active_models[0].decoder.load_state_dict(ae1.decoder.state_dict())
vae1.active_models[0].convDiscriminator.load_state_dict(ae1.convDiscriminator.state_dict())


<All keys matched successfully>

In [57]:
in_slice = [phase_dim, pose_dim, cost_dim,target_dim]
out_slice = [phase_dim, config["k"], cost_dim]

pose_encoder = vae1.active_models[0]
middle_layer = vae1.cluster_model
# middle_layer.requires_grad_(False)

model1 = MoGenVAE(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim, use_advLoss=True,
                                 input_slicers=in_slice, output_slicers=out_slice,
                                 train_set=train_setR1, val_set=val_setR1, test_set=test_setR1+val_setR1,
                                 name=vae_name1
                                   )

In [58]:
fit(model1, vae_name1, version="0.1",MAX_EPOCHS=120)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | VAE_Layer | 131 K 
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.4 M 
-----------------------------------------------
3.1 M     Trainable params
0         Non-trainable params
3.1 M     Total params
12.311    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.7820942401885986,
 'ptl/test_kl_loss': 0.008792479522526264,
 'ptl/test_loss': 0.2564104497432709,
 'ptl/test_min_pos_cost': 0.8249645233154297,
 'ptl/test_min_rot_cost': 0.7449467182159424,
 'ptl/test_min_target_pos_cost': 0.37322527170181274,
 'ptl/test_min_target_rot_cost': 0.34755951166152954,
 'ptl/test_pos_loss': 4.4932636455996544e-07,
 'ptl/test_recon_loss': 0.2476179301738739,
 'ptl/test_rot_loss': 12.654857635498047,
 'ptl/test_sum_pos_cost': 9.90969467163086,
 'ptl/test_sum_rot_cost': 8.951436042785645,
 'ptl/test_sum_target_pos_cost': 4.4819440841674805,
 'ptl/test_sum_target_rot_cost': 4.179253101348877,
 'test_loss': 0.25784143805503845}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.15199275314807892,
 'ptl/test_kl_loss': 0.0041326903738081455,
 'ptl/test_loss': 0.09219834208488464,
 'ptl/test_min_pos_cost': 0.8258420825004578,
 'ptl/test_min_rot_cost': 0.7457008361816406,
 'ptl/test_min_target_pos_cost': 0.37322527170181274,
 'ptl/test_min_target_rot_cost': 0.34755951166152954,
 'ptl/test_pos_loss': 5.199253250154356e-10,
 'ptl/test_recon_loss': 0.08806566148996353,
 'ptl/test_rot_loss': 9.319856643676758,
 'ptl/test_sum_pos_cost': 9.911890029907227,
 'ptl/test_sum_rot_cost': 8.956314086914062,
 'ptl/test_sum_target_pos_cost': 4.4819440841674805,
 'ptl/test_sum_target_rot_cost': 4.179253101348877,
 'test_loss': 0.09882960468530655}
--------------------------------------------------------------------------------


In [61]:
vae_name2 = "VAE-R2-Z-In-Reduced"
vae2 = VAE(config=config, input_dims=[pose_dim2])
temp = MLP_ADV(config=config, dimensions=[pose_dim2], h_dim=h_dim, w_dim=w_dim,
                 pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"])

in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]
out_slice = [phase_dim, config["k"], cost_dim]

vae2.active_models[0] = temp
vae2.active_models[0].encoder.load_state_dict(ae2.encoder.state_dict())
vae2.active_models[0].decoder.load_state_dict(ae2.decoder.state_dict())

pose_encoder = vae2.active_models[0]
middle_layer = vae2.cluster_model
pose_encoder.convDiscriminator.load_state_dict(ae2.convDiscriminator.state_dict())

middle_layer.load_state_dict(model1.middle_layer.state_dict())
middle_layer.requires_grad_(False)


model2 = MoGenVAE(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim2, use_advLoss=True,
                                 input_slicers=in_slice2, output_slicers=out_slice,
                                 train_set=train_setR2, val_set=val_setR2, test_set=val_setR2+test_setR2,
                                 name=vae_name2
                                   )

model2.generationModel.gate.load_state_dict(model1.generationModel.gate.state_dict())
model2.generationModel.load_state_dict(model1.generationModel.state_dict())

<All keys matched successfully>

In [64]:
fit(model2, vae_name2, version="0.1", MAX_EPOCHS=100)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | VAE_Layer | 131 K 
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.4 M 
-----------------------------------------------
2.9 M     Trainable params
131 K     Non-trainable params
3.1 M     Total params
12.311    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': nan,
 'ptl/test_kl_loss': nan,
 'ptl/test_loss': nan,
 'ptl/test_min_pos_cost': 0.7270957231521606,
 'ptl/test_min_rot_cost': 0.6450048685073853,
 'ptl/test_min_target_pos_cost': 0.4006081521511078,
 'ptl/test_min_target_rot_cost': 0.3534235954284668,
 'ptl/test_pos_loss': nan,
 'ptl/test_recon_loss': nan,
 'ptl/test_rot_loss': nan,
 'ptl/test_sum_pos_cost': 8.735037803649902,
 'ptl/test_sum_rot_cost': 7.752207279205322,
 'ptl/test_sum_target_pos_cost': 4.815293312072754,
 'ptl/test_sum_target_rot_cost': 4.250393867492676,
 'test_loss': nan}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_adv_loss': 0.17783232033252716,
 'ptl/test_kl_loss': 0.005066482350230217,
 'ptl/test_loss': 0.13565075397491455,
 'ptl/test_min_pos_cost': 0.7270960211753845,
 'ptl/test_min_rot_cost': 0.6454120874404907,
 'ptl/test_min_target_pos_cost': 0.4006081521511078,
 'ptl/test_min_target_rot_cost': 0.3534235954284668,
 'ptl/test_pos_loss': 7.213670771832881e-10,
 'ptl/test_recon_loss': 0.13058426976203918,
 'ptl/test_rot_loss': 5.291316509246826,
 'ptl/test_sum_pos_cost': 8.731210708618164,
 'ptl/test_sum_rot_cost': 7.75564432144165,
 'ptl/test_sum_target_pos_cost': 4.815293312072754,
 'ptl/test_sum_target_rot_cost': 4.250393867492676,
 'test_loss': 0.14767085015773773}
--------------------------------------------------------------------------------


In [65]:
clean_checkpoints(os.path.join(MODEL_PATH,vae_name1))
clean_checkpoints(os.path.join(MODEL_PATH,vae_name2))

model1.save_checkpoint(best_val_loss=0.001)
model2.save_checkpoint(best_val_loss=0.001)



'/home/nuoc/Documents/MEX/models/version_0.3/VAE-R2-Z-In-Reduced/0.001.pbz2'

In [66]:
# generate_animation(model=model1, test_set=test_setR1, feature_dims=feature_dims, use_vae=True,
#                    template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json",
#                    output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+vae_name1)
generate_animation(model=model2, test_set=test_setR2, feature_dims=feature_dims2,use_vae=True,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+vae_name2+"after")






torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])
torch.Size([5, 299, 93]) torch.Size([5, 299, 186]) torch.Size([5, 299, 93])


# DEC - Z as input

In [52]:

config["z_dim"]=128
featureDim = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim,
    "cost_dim": cost_dim,
    "target_dim":target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims["pos"],
    "rot_dim":feature_dims["rotMat2"],
    "vel_dim":feature_dims["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }
featureDim2 = {
    "phase_dim": phase_dim,
    "pose_dim": pose_dim2,
    "cost_dim": cost_dim,
    "target_dim": target_dim,
    "g_input_dim": config["z_dim"] + config["cost_hidden_dim"],
    "g_output_dim":phase_dim + config["k"] + cost_dim,
    "pos_dim":feature_dims2["pos"],
    "rot_dim":feature_dims2["rotMat2"],
    "vel_dim":feature_dims2["velocity"],
    "posCost":feature_dims["posCost"],
    "rotCost":feature_dims["rotCost"]
    }


in_slice = [phase_dim, pose_dim, cost_dim, target_dim]
in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]

out_slice = [phase_dim, config["k"], cost_dim]

dec_name1 = "DEC-R1-Z-IN"
dec1 = DEC(config=config, input_dims=[pose_dim])
temp = MLP_ADV(config=config, dimensions=[pose_dim], h_dim=h_dim, w_dim=w_dim,
               pos_dim=featureDim["pos_dim"], rot_dim=featureDim["rot_dim"], vel_dim=featureDim["vel_dim"],)

dec1.active_models[0] = temp
dec1.active_models[0].encoder.load_state_dict(ae1.encoder.state_dict())
dec1.active_models[0].decoder.load_state_dict(ae1.decoder.state_dict())

pose_encoder = dec1.active_models[0]
middle_layer = dec1.cluster_model
pose_encoder.convDiscriminator.load_state_dict(ae1.convDiscriminator.state_dict())

<All keys matched successfully>

In [53]:
model1 = MoGen(config=config, Model=MoE, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim, use_advLoss=True,
                                 input_slicers=in_slice, output_slicers=out_slice,
                                 train_set=train_setR1, val_set=val_setR1, test_set=test_setR1+val_setR1,
                                 name=dec_name1
                                   )


In [54]:
fit(model1, dec_name1, version="0.1",MAX_EPOCHS=120)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | pose_autoencoder | MLP_ADV   | 455 K 
1 | middle_layer     | DEC_Layer | 32.8 K
2 | cost_encoder     | MLP       | 42.4 K
3 | generationModel  | MoE       | 2.2 M 
-----------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.867    Total estimated model params size (MB)



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 5.432858642961946e-07,
 'ptl/test_adv_loss': 0.7819070816040039,
 'ptl/test_loss': 0.2475130558013916,
 'ptl/test_min_pos_cost': -0.505712628364563,
 'ptl/test_min_rot_cost': 0.3058273494243622,
 'ptl/test_min_target_pos_cost': -0.12412314862012863,
 'ptl/test_min_target_rot_cost': 0.3015017807483673,
 'ptl/test_pos_loss': 2.2392487153410912e-07,
 'ptl/test_rot_loss': 12.80390453338623,
 'ptl/test_sum_pos_cost': -6.046413898468018,
 'ptl/test_sum_rot_cost': 3.7061212062835693,
 'ptl/test_sum_target_pos_cost': -1.4784694910049438,
 'ptl/test_sum_target_rot_cost': 3.63797664642334,
 'test_loss': 0.24882711470127106}
--------------------------------------------------------------------------------


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]





HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'ptl/test_C_loss': 4.594359779730439e-05,
 'ptl/test_adv_loss': 0.23049050569534302,
 'ptl/test_loss': 0.15637226402759552,
 'ptl/test_min_pos_cost': -0.5058517456054688,
 'ptl/test_min_rot_cost': 0.3058273494243622,
 'ptl/test_min_target_pos_cost': -0.12412314862012863,
 'ptl/test_min_target_rot_cost': 0.3015017807483673,
 'ptl/test_pos_loss': 1.6896911647634738e-09,
 'ptl/test_rot_loss': 10.55378246307373,
 'ptl/test_sum_pos_cost': -6.050246238708496,
 'ptl/test_sum_rot_cost': 3.7064802646636963,
 'ptl/test_sum_target_pos_cost': -1.4784694910049438,
 'ptl/test_sum_target_rot_cost': 3.63797664642334,
 'test_loss': 0.15836910903453827}
--------------------------------------------------------------------------------


In [None]:
dec_name2 = "dec-R2-Z-Concat-Reduced"
dec2 = MLP_MIX(config=config, input_dims=[pose_dim2])
temp = MLP_ADV(config=config, dimensions=[pose_dim2], h_dim=h_dim, w_dim=w_dim,)

in_slice2 = [phase_dim, pose_dim2, cost_dim,target_dim]
out_slice = [phase_dim, config["k"], cost_dim]

temp.encoder.load_state_dict(ae2.encoder.state_dict())
temp.decoder.load_state_dict(ae2.decoder.state_dict())

pose_encoder = temp
middle_layer = dec2.cluster_model
pose_encoder.convDiscriminator.load_state_dict(ae2.convDiscriminator.state_dict())

middle_layer.load_state_dict(model1.middle_layer.state_dict())
middle_layer.requires_grad_(False)


model2 = MoGenZ(config=config, Model=MoE_Z, pose_autoencoder=pose_encoder, middle_layer=middle_layer,
                                 feature_dims=featureDim2,use_advLoss=True,
                                 input_slicers=in_slice2, output_slicers=out_slice,
                                 train_set=train_setR2, val_set=val_setR2, test_set=val_setR2+test_setR2,
                                 name=dec_name2
                                   )

model2.generationModel.gate.load_state_dict(model1.generationModel.gate.state_dict())
model2.generationModel.load_state_dict(model1.generationModel.state_dict())

In [None]:
fit(model2, dec_name2, version="0.2", MAX_EPOCHS=100)

In [None]:
clean_checkpoints(os.path.join(MODEL_PATH,dec_name1))
clean_checkpoints(os.path.join(MODEL_PATH,dec_name2))

model1.save_checkpoint(best_val_loss=0.001)
model2.save_checkpoint(best_val_loss=0.001)


In [None]:
generate_animation(model=model1, test_set=test_setR1, feature_dims=feature_dims,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/TestAll_1_R1_One_1/False_2_0.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+dec_name1)
generate_animation(model=model2, test_set=test_setR2, feature_dims=feature_dims2,
                   template_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/Test/R2.json",
                   output_path="/home/nuoc/.config/unity3d/DefaultCompany/Procedural Animation/version0.3/"+dec_name2)


In [62]:
def setVec3(struct, vec):
    struct["x"] = vec[0].item()
    struct["y"] = vec[1].item()
    struct["z"] = vec[2].item()

def setVec6(struct, vec):
    for r, cell in enumerate(["x", "y", "z"]):
        for col, column in enumerate(["c0", "c1"]):
            struct[column][cell] = vec[r, col].item()

def insert_pos(template,
               positions=None, rotations=None, velocity=None,
               tPos=None, tRot=None, name="Replay"):
    shape = positions.shape
    for c in range(shape[0]):
        for f in range(shape[1]):
            t = 0
            for j in range(shape[2]):
                jo = template["frames"][f]["joints"][j]
                if positions is not None:
                    setVec3(jo["position"], positions[c,f,j])
                if rotations is not None:
                    setVec3(jo["velocity"], velocity[c,f,j])
                if velocity is not None:
                    setVec6(jo["rotMat"], rotations[c,f,j])


                if jo["key"]:
                    if tPos is not None:
                        setVec3(jo["cost"]["TargetPosition"], tPos[c,f,t])
                    if tRot is not None:
                        setVec6(jo["cost"]["TargetRotation"], tRot[c,f,t])
                    t+=1
        with open("{}_{}.json".format(name, c), "w") as f:
            js.dump(template, f)

def generate_animation(model, test_set, feature_dims, template_path, output_path, use_vae=False, n=5):
    idx = np.arange(n)
    with torch.no_grad():
        model.eval()

        x = torch.stack([test_set[i][0] for i in idx])
        y = torch.stack([test_set[i][1] for i in idx])
        shape = x.shape
        # x = x.view(-1, config["seq_len"], shape[-1])
        # x_c = x[:, 0, :]
        x = x.view(-1, shape[-1])
        generated = []
        # for i in range(config["seq_len"]):
        if use_vae:
            out, z, mu, logvar = model(x.cuda())
        else:
            out, _ = model(x.cuda())
        # x_c = torch.cat(out,dim=1).detach()
        # generated.append(x_c.unsqueeze(1).cpu())
        generated = out
        # generated = generated.view(shape).to("cpu")
        generated = torch.cat(generated, dim=1).reshape(shape).cpu()

    phase= feature_dims["phase_vec_l2"]
    toPosDim = phase+feature_dims["pos"]
    toRotDim = toPosDim + feature_dims["rotMat2"]
    toVelDim = toRotDim + feature_dims["velocity"]

    gPos = generated[:, :, phase:toPosDim]
    gRot = generated[:, :, toPosDim:toRotDim]
    gVel = generated[:, :, toRotDim:toVelDim]

    oPos = y[:, :, phase:toPosDim]
    oRot = y[:, :, toPosDim:toRotDim]
    oVel = y[:, :, toRotDim:toVelDim]

    tPos = y[:, :, -target_dim:-target_dim+3*4]
    tRot = y[:, :, -target_dim+3*4:]

    print(gPos.shape, gRot.shape, gVel.shape)
    print(oPos.shape, oRot.shape, oVel.shape)

    clip_length = gPos.shape[1]
    gPos_r = gPos.reshape((n, clip_length, -1, 3))
    gRot_r = gRot.reshape((n, clip_length, -1, 3, 2))
    gVel_r = gVel.reshape((n, clip_length, -1, 3))

    oPos_r = oPos.reshape((n, clip_length, -1, 3))
    oRot_r = oRot.reshape((n, clip_length, -1, 3, 2))
    oVel_r = oVel.reshape((n, clip_length, -1, 3))

    tPos_r = tPos.reshape((n, clip_length, -1, 3))
    tRot_r = tRot.reshape((n, clip_length, -1, 3, 3))

    template = js.load(open(template_path))

    insert_pos(template, oPos_r, oRot_r, oVel_r, tPos_r, tRot_r,output_path+"_O")
    insert_pos(template, gPos_r, gRot_r, gVel_r, tPos_r, tRot_r,output_path+"_G")

def generate_animation_ae(model, test_set, feature_dims, template_path, output_path, n=5):
    idx = np.arange(n)
    with torch.no_grad():
        model.eval()
        model.cpu()

        x = torch.stack([test_set[i][0] for i in idx])
        y = torch.stack([test_set[i][1] for i in idx])
        shape = x.shape
        x = x.view(-1, shape[-1])

        x.to("cpu")
        out = model(x)

        generated = out
        generated = generated.view(shape)

    toPosDim = feature_dims["pos"]
    toRotDim = toPosDim + feature_dims["rotMat2"]
    toVelDim = toRotDim + feature_dims["velocity"]

    gPos = generated[:, :, :toPosDim]
    gRot = generated[:, :, toPosDim:toRotDim]
    gVel = generated[:, :, toRotDim:toVelDim]

    oPos = y[:, :, :toPosDim]
    oRot = y[:, :, toPosDim:toRotDim]
    oVel = y[:, :, toRotDim:toVelDim]

    print(gPos.shape, gRot.shape, gVel.shape)
    print(oPos.shape, oRot.shape, oVel.shape)

    clip_length = gPos.shape[1]
    gPos_r = gPos.reshape((n, clip_length, -1, 3))
    gRot_r = gRot.reshape((n, clip_length, -1, 3, 2))
    gVel_r = gVel.reshape((n, clip_length, -1, 3))

    oPos_r = oPos.reshape((n, clip_length, -1, 3))
    oRot_r = oRot.reshape((n, clip_length, -1, 3, 2))
    oVel_r = oVel.reshape((n, clip_length, -1, 3))

    template = js.load(open(template_path))

    insert_pos(template, oPos_r, oRot_r, oVel_r, None, None,output_path+"_O")
    insert_pos(template, gPos_r, gRot_r, gVel_r, None, None, output_path+"_G")