In [4]:

import os, sys, math, time
import numpy as np
import numpy.linalg as la
import plotly.graph_objects as go
import plotly.express as ex
from plotly.subplots import make_subplots
import pandas as pd

import json as js
import _pickle as pickle
import bz2
import ray

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from collections import OrderedDict

from cytoolz import sliding_window, accumulate, get
import pytorch_lightning as pl
from operator import add
from tabulate import tabulate

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback

import ray
import ray.tune as tune


sys.path.append("../")
sys.path.append("../rig_agnostic_encoding")
sys.path.append("../rig_agnostic_encoding/models")
sys.path.append("../rig_agnostic_encoding/functions")
import func
# from MLP_withLabel import MLP_withLabel
# from MLP import MLP

In [5]:
DATA_PATH = "/home/nuoc/Documents/MEX/data"
MODEL_PATH = "/home/nuoc/Documents/MEX/models"
RESULTS_PATH = "/home/nuoc/Documents/MEX/results"


In [6]:
class MLP_withLabel(pl.LightningModule):
    def __init__(self, config:dict=None, dimensions:list=None, extra_feature_len:int=0,
                 train_set=None, val_set=None, test_set=None,
                 keep_prob:float=.2, name:str="model", load=False,
                 single_module:int=0):

        super(MLP_withLabel, self).__init__()
        self.name = name
        self.dimensions = dimensions
        self.keep_prob = keep_prob
        self.single_module = single_module
        self.extra_feature_len = extra_feature_len
        self.act = nn.ELU
        self.k = 0
        if load:
            self.build()
        else:
            self.hidden_dim = config["hidden_dim"]
            self.k = config["k"]
            self.learning_rate = config["lr"]
            self.act = config["activation"]
            self.loss_fn = config["ae_loss_fn"]
            self.batch_size = config["batch_size"]

            self.dimensions = [self.dimensions[0]-extra_feature_len, self.hidden_dim, self.k]
            self.train_set = train_set
            self.val_set = val_set
            self.test_set = test_set

            self.best_val_loss = np.inf

        self.build()
        self.encoder.apply(self.init_params)
        self.decoder.apply(self.init_params)


    def build(self):
        layer_sizes = list(sliding_window(2, self.dimensions))
        if self.single_module == -1 or self.single_module == 0:
            layers = []
            for i, size in enumerate(layer_sizes):
                layers.append(("fc"+str(i), nn.Linear(size[0], size[1])))
                if i < len(self.dimensions)-2:
                    layers.append(("act"+str(i), self.act()))
                    layers.append(("drop"+str(i+1), nn.Dropout(self.keep_prob)))
            self.encoder = nn.Sequential(OrderedDict(layers))
        else:
            self.encoder = nn.Sequential()

        if self.single_module == 0 or self.single_module == 1:
            layers = []
            layer_sizes[-1] = (layer_sizes[-1][0], layer_sizes[-1][1] + self.extra_feature_len)
            for i, size in enumerate(layer_sizes[-1::-1]):
                layers.append(("fc"+str(i), nn.Linear(size[1], size[0])))
                if i < len(self.dimensions)-2:
                    layers.append(("act"+str(i), self.act()))
                    layers.append(("drop"+str(i+1), nn.Dropout(self.keep_prob)))
            self.decoder = nn.Sequential(OrderedDict(layers))
        else:
            self.decoder = nn.Sequential()

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.decode(*self.encode(x))

    def encode(self, x):
        _x, label = x[:, :-self.extra_feature_len], x[:, -self.extra_feature_len:]
        h = self.encoder(_x)
        return h, label

    def decode(self, h, label):
        hr = torch.cat((h, label), dim=1)
        return self.decoder(hr)

    def training_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        loss = self.loss_fn(prediction, y)

        self.log("ptl/train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        prediction = self(x)
        loss = self.loss_fn(prediction, y)

        self.log('ptl/val_loss', loss, prog_bar=True)
        return {"val_loss":loss}

    def test_step(self, batch, batch_idx):
        x, y = batch

        prediction = self(x)
        loss = self.loss_fn(prediction, y)

        self.log('ptl/test_loss', loss, prog_bar=True)
        return {"val_loss":loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log("avg_val_loss", avg_loss)
        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            self.save_checkpoint(best_val_loss=self.best_val_loss.cpu().numpy())

    def save_checkpoint(self, best_val_loss:float=np.inf, checkpoint_dir=MODEL_PATH):

        model = {"k":self.k, "dimensions":self.dimensions,"keep_prob":self.keep_prob, "name":self.name,
                 "extra_feature_len" : self.extra_feature_len,
                 "encoder":self.encoder.state_dict(),
                 "decoder":self.decoder.state_dict()}

        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        path = os.path.join(checkpoint_dir, self.name)
        if not os.path.exists(path):
            os.mkdir(path)

        filePath = os.path.join(path, str(best_val_loss)+"."+str(self.k)+".pbz2")
        with bz2.BZ2File(filePath, "w") as f:
            pickle.dump(model, f)
        return filePath

    @staticmethod
    def load_checkpoint(filePath):
        with bz2.BZ2File(filePath, "rb") as f:
            obj = pickle.load(f)

        model = MLP_withLabel(name=obj["name"], dimensions=obj["dimensions"], extra_feature_len=obj["extra_feature_len"], keep_prob=obj["keep_prob"], load=True)
        model.encoder.load_state_dict(obj["encoder"])
        model.decoder.load_state_dict(obj["decoder"])
        return model

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def setup_data(self):
        pass
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, pin_memory=True)

    @staticmethod
    def init_params(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(.01)

In [7]:
class MLP(pl.LightningModule):
    def __init__(self, config:dict=None, dimensions:list=None,
                 train_set=None, val_set=None, test_set=None,
                 keep_prob:float=.2, name:str="model", load=False,
                 single_module:int=0):

        super(MLP, self).__init__()
        self.name = name
        self.dimensions = dimensions
        self.keep_prob = keep_prob
        self.single_module = single_module
        self.act = nn.ELU
        self.k = 0
        if load:
            self.build()
        else:
            self.hidden_dim = config["hidden_dim"]
            self.k = config["k"]
            self.learning_rate = config["lr"]
            self.act = config["activation"]
            self.loss_fn = config["loss_fn"]
            self.batch_size = config["batch_size"]

            self.dimensions = dimensions + [self.hidden_dim, self.k]
            self.train_set = train_set
            self.val_set = val_set
            self.test_set = test_set

            self.best_val_loss = np.inf

            self.build()
        self.encoder.apply(self.init_params)
        self.decoder.apply(self.init_params)


    def build(self):
        layer_sizes = list(sliding_window(2, self.dimensions))
        if self.single_module == -1 or self.single_module == 0:
            layers = []
            for i, size in enumerate(layer_sizes):
                layers.append(("fc"+str(i), nn.Linear(size[0], size[1])))
                if i < len(self.dimensions)-2:
                    layers.append(("act"+str(i), self.act()))
                    layers.append(("drop"+str(i+1), nn.Dropout(self.keep_prob)))
            self.encoder = nn.Sequential(OrderedDict(layers))
        else:
            self.encoder = nn.Sequential()

        if self.single_module == 0 or self.single_module == 1:
            layers = []
            for i, size in enumerate(layer_sizes[-1::-1]):
                layers.append(("fc"+str(i), nn.Linear(size[1], size[0])))
                if i < len(self.dimensions)-2:
                    layers.append(("act"+str(i), self.act()))
                    layers.append(("drop"+str(i+1), nn.Dropout(self.keep_prob)))
            self.decoder = nn.Sequential(OrderedDict(layers))
        else:
            self.decoder = nn.Sequential()

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))

    def encode(self, x):
        return self.encoder(x)

    def decode(self, h):
        return self.decoder(h)

    def training_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        loss = self.loss_fn(prediction, y)

        self.log("ptl/train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        prediction = self(x)
        loss = self.loss_fn(prediction, y)

        self.log('ptl/val_loss', loss, prog_bar=True)
        return {"val_loss":loss}

    def test_step(self, batch, batch_idx):
        x, y = batch

        prediction = self(x)
        loss = self.loss_fn(prediction, y)

        self.log('ptl/test_loss', loss, prog_bar=True)
        return {"val_loss":loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log("avg_val_loss", avg_loss)
        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            self.save_checkpoint(best_val_loss=self.best_val_loss.cpu().numpy())

    def save_checkpoint(self, best_val_loss:float=np.inf, checkpoint_dir=MODEL_PATH):

        model = {"k":self.k, "dimensions":self.dimensions,"keep_prob":self.keep_prob, "name":self.name,
                 "single_module":self.single_module,
                 "encoder":self.encoder.state_dict(),
                 "decoder":self.decoder.state_dict()}

        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        path = os.path.join(checkpoint_dir, self.name)
        if not os.path.exists(path):
            os.mkdir(path)

        filePath = os.path.join(path, str(best_val_loss)+"."+str(self.k)+".pbz2")
        with bz2.BZ2File(filePath, "w") as f:
            pickle.dump(model, f)
        return filePath

    @staticmethod
    def load_checkpoint(filePath):
        with bz2.BZ2File(filePath, "rb") as f:
            obj = pickle.load(f)

        model = MLP(name=obj["name"], dimensions=obj["dimensions"], keep_prob=obj["keep_prob"],
                  load=True)
        model.encoder.load_state_dict(obj["encoder"])
        # model.decoder.load_state_dict(obj["decoder"])
        return model

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def setup_data(self):
        pass
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, pin_memory=True)

    @staticmethod
    def init_params(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(.01)

In [8]:
class MoE(nn.Module):
    def __init__(self, config=None, dimensions=None, phase_input_dim:int=0,
                 gate_size=0, k_experts=1, keep_prob=.2,
                 name="model", load=False):
        super().__init__()

        self.phase_input_dim = phase_input_dim
        self.dimensions = dimensions
        self.act_fn = nn.ELU
        self.name = name
        self.config=config
        self.gate_size=gate_size
        self.k_experts = k_experts
        self.keep_prob = .2
        if not load:
            self.k_experts = config["k_experts"]
            self.gate_size = config["gate_size"]
            self.keep_prob = config["keep_prob"]
            self.dimensions = [self.dimensions[0], config["hidden_dim"], config["hidden_dim"], self.dimensions[-1]]

        self.layers = []

        self.build()
        self.gate = nn.Sequential(
            nn.Linear(phase_input_dim, self.gate_size),
            nn.ELU(),
            nn.Linear(self.gate_size, self.gate_size),
            nn.ELU(),
            nn.Linear(self.gate_size, self.k_experts)
        )
        self.init_params()


    def forward(self, x:torch.Tensor, phase) -> torch.Tensor:
        coefficients = F.softmax(self.gate(phase), dim=1)

        layer_out = x
        for (weight, bias, activation) in self.layers:
            if weight is None:
                layer_out = activation(layer_out, p=self.keep_prob)
            else:
                flat_weight = weight.flatten(start_dim=1, end_dim=2)
                mixed_weight = torch.matmul(coefficients, flat_weight).view(
                    coefficients.shape[0], *weight.shape[1:3]
                )

                input = layer_out.unsqueeze(1)
                mixed_bias = torch.matmul(coefficients, bias).unsqueeze(1)
                out = torch.baddbmm(mixed_bias, input, mixed_weight).squeeze(1)
                layer_out = activation(out) if activation is not None else out

        return layer_out

    def build(self):
        layers = []
        for i, size in enumerate(zip(self.dimensions[0:], self.dimensions[1:])):
            if i < len(self.dimensions) - 2:
                layers.append(
                    (
                        nn.Parameter(torch.empty(self.k_experts, size[0], size[1])),
                        nn.Parameter(torch.empty(self.k_experts, size[1])),
                        self.act_fn()
                    )
                )
            else:
                layers.append(
                    (
                        nn.Parameter(torch.empty(self.k_experts, size[0], size[1])),
                        nn.Parameter(torch.empty(self.k_experts, size[1])),
                        None
                    )
                )

            if self.keep_prob > 0:
                layers.append((None, None, F.dropout))

        self.layers = layers

    def init_params(self):
        for i, (w, b, _) in enumerate(self.layers):
            if w is None:
                continue

            i = str(i)
            torch.nn.init.kaiming_uniform_(w)
            b.data.fill_(0.01)
            self.register_parameter("w" + i, w)
            self.register_parameter("b" + i, b)

    def save_checkpoint(self, best_val_loss:float=np.inf, checkpoint_dir=MODEL_PATH):

        model = {"dimensions":self.dimensions,
                 "name":self.name,
                 "gate":self.gate.state_dict(), "phase_input_dim":self.phase_input_dim,
                 "generationNetwork":self.state_dict(),
                 "gate_size":self.gate_size,
                 "k_experts":self.k_experts,
                 }

        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        path = os.path.join(checkpoint_dir, self.name)
        if not os.path.exists(path):
            os.mkdir(path)

        filePath = os.path.join(path, str(best_val_loss)+".pbz2")
        with bz2.BZ2File(filePath, "w") as f:
            pickle.dump(model, f)
        return filePath

    @staticmethod
    def load_checkpoint(filePath):
        with bz2.BZ2File(filePath, "rb") as f:
            obj = pickle.load(f)

        model = MoE(name=obj["name"], dimensions=obj["dimensions"], gate_size=obj["gate_size"],k_experts=obj["k_experts"],
                    phase_input_dim=obj["phase_input_dim"], load=True)
        model.gate.load_state_dict(obj["gate"])
        model.load_state_dict(obj["generationNetwork"])
        return model

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)
        return optimizer


In [9]:
class MotionGenerationModel(pl.LightningModule):
    def __init__(self, config:dict=None, pose_autoencoder=None, cost_input_dimension=None, feature_dims=None,
                 input_slicers:list=None, output_slicers:list=None, train_set=None, val_set=None, name="model", load=False):
        super().__init__()

        if not load:
            self.pose_autoencoder = pose_autoencoder # start with 3
            cost_hidden_dim = config["cost_hidden_dim"]
            cost_output_dim = config["cost_output_dim"]
            self.feature_dims = feature_dims
            self.phase_dim = feature_dims["phase_vec"]
            # self.trajectory_dim = feature_dims["tPos"]
            # self.cost_dim = trajectory_dim + feature_dims["posCost"]
            self.cost_encoder = MLP(dimensions=[cost_dim, cost_hidden_dim, cost_hidden_dim, cost_output_dim],
                                    name="CostEncoder", load=True, single_module=-1)

           # phase_input_dimension = input_slicers[0]
            moe_input_dim = pose_autoencoder.dimensions[-1] # + cost_output_dim
            moe_output_dim = pose_autoencoder.dimensions[-1] +  self.phase_dim # + self.trajectory_dim
            self.generationModel =  MoE(config=config, dimensions=[moe_input_dim, moe_output_dim], phase_input_dim=self.phase_dim + feature_dims["posCost"],
                                        name="MixtureOfExperts")

            self.in_slices = [0] + list(accumulate(add, input_slicers))
            self.out_slices = [0] + list(accumulate(add, output_slicers))
            # self.phase_dim = phase_dim

            self.config=config
            self.batch_size = config["batch_size"]
            self.learning_rate = config["lr"]
            self.loss_fn = config["loss_fn"]
            self.window_size = config["window_size"]
            self.autoregress_chunk_size = config["autoregress_chunk_size"]
            self.autoregress_prob = config["autoregress_prob"]
            self.autoregress_inc = config["autoregress_inc"]
            self.best_val_loss = np.inf
            self.phase_smooth_factor = 0.9

        self.train_set = train_set
        self.val_set = val_set
        self.name = name
        self.epochs = 0
        self.automatic_optimization = False
        self.left_id = 14*3
        self.right_id = 20*3



    def forward(self, x):
        x_tensors = [x[:, d0:d1] for d0, d1 in zip(self.in_slices[:-1], self.in_slices[1:])]
        pose_h, pose_label = self.pose_autoencoder.encode(x_tensors[1])
        phase = x_tensors[0][:, :self.phase_dim]
        targets = x_tensors[0][:, self.phase_dim:]
        # embedding = torch.cat([pose_h, (x_tensors[2])], dim=1)
        # embedding = torch.cat([pose_h)], dim=1)
        # embedding = torch.cat([pose_h, x_tensors[2]], dim=1)

        # posCost = self.computeCost(targets, x_tensors[-1][:, :self.trajectory_dim])
        # out = self.generationModel(embedding, torch.cat([phase, x_tensors[-1][:, self.trajectory_dim:]], dim=1))
        out = self.generationModel(pose_h, torch.cat([phase, x_tensors[-1]], dim=1))

        out_tensors = [out[:, d0:d1] for d0, d1 in zip(self.out_slices[:-1], self.out_slices[1:])] # phase_update, new_pose

        phase = self.update_phase(phase, out_tensors[0]) # phase_0, phase_1, phase_update

        new_pose = self.pose_autoencoder.decode(out_tensors[1], pose_label)

        # out_tensors[-1][:, 9:12] = new_pose[:, self.left_id:self.left_id+3]
        #
        # out_tensors[-1][:, 27:30] = new_pose[:, self.right_id:self.right_id+3]


        posCost = self.computeCost(targets, torch.cat([new_pose[:, self.left_id:self.left_id+3],new_pose[:, self.right_id:self.right_id+3]], dim=1))

        # print(phase.size(), targets.size(), new_pose.size(), pose_label.size(), out_tensors[-1].size(), posCost.size(), rotCost.size())
        # return [phase, targets, new_pose, pose_label, out_tensors[-1], posCost, rotCost]
        return [phase, targets, new_pose, pose_label, posCost]
        # return [phase, targets, new_pose, pose_label, out_tensors[-1]]

    def computeCost(self, targets, trajectory):
        targetPos = targets
        # targetRot = targets[:, self.feature_dims["targetPosition"]:]
        posT = trajectory
        # rotT = trajectory[:, self.feature_dims["tPos"]:]

        targetPos = targetPos.reshape((-1, 2, 3))
        posT = posT.reshape((-1, 2, 3))
        # targetRot = targetRot.reshape((-1, 12, 3,3))
        # rotT = rotT.reshape((-1, 12, 3, 3))

        posCost = torch.sum(((targetPos - posT)**2), axis=2).reshape((-1, self.feature_dims["posCost"]))
        # colLength = torch.sqrt(torch.clip(torch.sum(rotT**2, axis=2), 0))
        # rotT = rotT / colLength[:, :, :, None]

        # rotT = torch.transpose(rotT, dim0=2, dim1=3)
        # trace =torch.diagonal(targetRot @ rotT, offset=0, dim1=2, dim2=3).sum(dim=2)
        # rotCost = torch.abs(torch.arccos((torch.clamp( (trace - 1) / 2.0, -1, 1))))
        # torch.nan_to_num_(rotCost, 0)
        # rotCost = rotCost.reshape((-1, self.feature_dims["rotCost"]))

        return posCost



    def step(self, x, y, validation=False):
        if not validation:
           opt = self.optimizers()
        x = x.squeeze(dim=2)
        y = y.squeeze(dim=2)

        n = x.size()[1]
        tot_loss = 0
        tot_posLoss = 0
        tot_rotLoss = 0
        x_c = x[:,0,:]

        if self.autoregress_prob < 1:
            autoregress_bools = torch.randn(n) < self.autoregress_prob
            for i in range(1, n):
                y_c = y[:,i-1,:]
                # y_c.requires_grad_(True)
                if torch.sum(y_c) == 0:
                    break

                out= self(x_c)
                recon = torch.cat([out[0], out[2], out[4]], dim=1)
                loss = self.loss_fn(recon, y_c)
                posLoss = torch.mean(out[-1])
                # rotLoss = torch.mean(out[-1])
                # rotLoss = 0
                #
                tot_loss += loss.detach()
                tot_posLoss += posLoss.detach() * float(i)/float(n)
                # tot_rotLoss += rotLoss.detach()
                tot_rotLoss += 0

                # elif not recon.requires_grad:
                #     raise ValueError("recon no grad, i : ", i, " \n", recon, y_c)
                # elif not loss.requires_grad:
                #     raise ValueError("loss no grad")
                if not validation:
                    opt.zero_grad()
                    # self.optimizer.zero_grad()
                    self.manual_backward(loss)
                    # self.optimizer.step()
                    opt.step()

                if self.autoregress_prob > 0 and autoregress_bools[i]:
                    x_c = torch.cat(out, dim=1).detach()
                else:
                    x_c = x[:,i,:]

            tot_loss /= float(i+1)
            # tot_posLoss /= float(i+1)
            tot_rotLoss /= float(i+1)
        else:
            for i in range(1, n):
                y_c = y[:,i-1,:]
                if torch.sum(y_c) == 0:
                    break

                out= self(x_c)
                recon = torch.cat([out[0], out[2], out[4]], dim=1)
                loss = self.loss_fn(recon, y_c)
                # posLoss = torch.mean(out[-2])
                posLoss = torch.mean(out[-1])
                # rotLoss = torch.mean(out[-1])
                # rotLoss = 0

                tot_loss += loss.detach()
                tot_posLoss = posLoss.detach()* float(i)/float(n)
                # tot_rotLoss += rotLoss.detach()
                tot_rotLoss += 0
                # self.optimizer.zero_grad()
                # (loss + posLoss + rotLoss).backward()
                if not validation:
                    opt.zero_grad()
                    # self.optimizer.zero_grad()
                    self.manual_backward(loss)
                    # self.optimizer.step()
                    opt.step()

                x_c = torch.cat(out, dim=1).detach()

            tot_loss /= float(i+1)
            # tot_posLoss /= float(i+1)
            tot_rotLoss /= float(i+1)

        return tot_loss, tot_posLoss, tot_rotLoss
    def training_step(self, batch, batch_idx):
        x, y = batch

        loss, posLoss, rotLoss = self.step(x,y, False)

        self.log("ptl/train_loss", loss, prog_bar=True)
        self.log("ptl/train_posLoss", posLoss)
        self.log("ptl/train_rotLoss", rotLoss)
        # return loss#+posLoss+rotLoss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        loss, posLoss, rotLoss = self.step(x,y, True)
        self.log("ptl/val_loss", loss, prog_bar=True)
        self.log("ptl/val_posLoss", posLoss, prog_bar=True)
        self.log("ptl/val_rotLoss", rotLoss, prog_bar=True)
        return {"val_loss":loss}

    def validation_epoch_end(self, outputs):
        if self.epochs > 0 and self.epochs % 20==0:
            self.autoregress_prob = min(1, self.autoregress_prob+self.autoregress_inc)
            self.autoregress_chunk_size = min(120, self.autoregress_chunk_size+self.autoregress_inc)
        elif self.epochs > 0 and self.epochs % 10==0:
            self.scheduler.step()
        self.epochs += 1


        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log("avg_val_loss", avg_loss)
        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            self.save_checkpoint()

    def save_checkpoint(self, checkpoint_dir=MODEL_PATH):
        path = os.path.join(checkpoint_dir, self.name)
        loss = self.best_val_loss.cpu().numpy()

        pose_autoencoder_path = self.pose_autoencoder.save_checkpoint(best_val_loss=loss, checkpoint_dir=path)
        cost_encoder_path = self.cost_encoder.save_checkpoint(best_val_loss=loss, checkpoint_dir=path)
        generationModel_path = self.generationModel.save_checkpoint(best_val_loss=loss, checkpoint_dir=path)

        model = {"name":self.name,
                 "pose_autoencoder_path":pose_autoencoder_path,
                 "cost_encoder_path": cost_encoder_path,
                 "motionGenerationModelPath":generationModel_path,
                 "in_slices":self.in_slices,
                 "out_slices":self.out_slices,
                 }

        if not os.path.exists(path):
            os.mkdir(path)
        with bz2.BZ2File(os.path.join(path,
                                      str(loss)+".pbz2"), "w") as f:
            pickle.dump(model, f)

    @staticmethod
    def load_checkpoint(filename, pose_ae_model, cost_encoder_model, generation_model):
        with bz2.BZ2File(filename, "rb") as f:
            obj = pickle.load(f)

        pose_autoencoder = pose_ae_model.load_checkpoint(obj["pose_autoencoder_path"])
        cost_encoder = cost_encoder_model.load_checkpoint(obj["cost_encoder_path"])
        generationModel = generation_model.load_checkpoint(obj["motionGenerationModelPath"])
        model = MotionGenerationModel(name=obj["name"])
        model.pose_autoencoder = pose_autoencoder
        model.cost_encoder = cost_encoder
        model.generationModel = generationModel
        model.in_slices = obj["in_slices"]
        model.out_slices = obj["out_slices"]

        return model

    def update_phase(self, p1, p2):
        return self.phase_smooth_factor * p2 + (1-self.phase_smooth_factor)*p1

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
        self.scheduler = scheduler
        self.optimizer = optimizer
        return optimizer

    def train_dataloader(self):

        return DataLoader(self.train_set, batch_size=self.batch_size, pin_memory=True)

    def val_dataloader(self):

        return DataLoader(self.val_set, batch_size=self.batch_size, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, pin_memory=True)






In [10]:
data_path = [
             "/home/nuoc/Documents/MEX/data/TWO_R2-default-Two.pbz2",
            # "/home/nuoc/Documents/MEX/data/ONE_R2-default-One.pbz2",
            #  "/home/nuoc/Documents/MEX/data/ONE_R2-default-One-large.pbz2",
            #  "/home/nuoc/Documents/MEX/data/ONE_R2-default-One-small.pbz2",
            #  "/home/nuoc/Documents/MEX/data/TWO_R2-default-Two-small.pbz2",
            #  "/home/nuoc/Documents/MEX/data/TWO_R2-default-Two-large.pbz2",
            #  "/home/nuoc/Documents/MEX/data/TWO_ROT_R2-default-Two.pbz2",
            #  "/home/nuoc/Documents/MEX/data/TWO_ROT_R2-default-Two-large.pbz2",
            #  "/home/nuoc/Documents/MEX/data/TWO_ROT_R2-default-Two-small.pbz2",
             ]



# pose_features = ["pos", "rotMat", "velocity", "isLeft", "chainPos", "geoDistanceNormalised"]
# cost_features = ["tPos", "tRot", "posCost", "rotCost"]
# phase_features = ["phase_vec", "targetPosition", "targetRotation"]

pose_features = ["pos", "rotMat", "velocity", "isLeft", "chainPos", "geoDistanceNormalised"]
cost_features = ["posCost"]
phase_features = ["phase_vec", "targetPosition"]



In [11]:
def load(file_path):
    with bz2.BZ2File(file_path, "rb") as f:
        obj = pickle.load(f)
        return obj

data = [load(path) for path in data_path]

In [13]:
window_size = 3
frame_window = 15
sampling_step = frame_window / window_size

data_tensors = []

data_dims = []
feature_list = []
feature_data = [{}, {}]
features = phase_features + pose_features + cost_features
for f in features:
    feature_data[0][f] = []
    feature_data[1][f] = []

first_row = True
first_time = True
key_joints = []

a = []
b = []
for Data in data:
    for clip in Data:
        d = pickle.loads(clip)
        sequence = []
        n_frames = len(d["frames"])
        if first_time:
            key_joints = [i for i in range(len(d["frames"][0])) if d["frames"][0][i]["key"]]
            first_time = False

        idx = [int(n_frames/2)] * len(key_joints)
        for i, jo in enumerate(key_joints):
            for f_id, frame in enumerate(d["frames"]):
                if frame[jo]["contact"]:
                    idx[i] = f_id
                    break

        max_id = max(idx)
        max_id2 = max_id + 3
        min_id = min(idx)
        min_id2 = min(idx) - 3
        start1 = 0
        end1 = max_id2 if max_id2 <= n_frames else max_id
        end2 = min_id2 if min_id2 >= 0 else min_id
        start2 = n_frames-1
        clip_idx = [(start1, end1), (start2, end2)]


        for i, (start, end) in enumerate(clip_idx):
            frames = d["frames"]
            if start < end:
                idx = np.arange(start+1, end-1)
                intervals = [sorted(np.random.choice(idx, 28, replace=False).tolist()) for i in range(3)]
                # print(i, intervals, sorted(np.random.choice(idx, 28, replace=False).tolist()))
            else:
                idx = np.arange(end+1, start-1)
                intervals = [sorted(np.random.choice(idx, 28, replace=False).tolist(),reverse=True) for i in range(3)]
            # print(intervals[0])
            for interval in intervals:
                if start < end:
                    interval = [idx[0]-1] + interval + [idx[-1]+1]
                else:
                    interval = [idx[-1]+1] + interval + [idx[0]-1]
                for f in interval:
                    row_vec = []
            # if i == 1: continue
                # if i == 0:
                #     n = end-start
                # else:
                #     n = start
                #
                # if i == 0:
                #     f_idx = np.arange(f-frame_window, f+frame_window, sampling_step, dtype=int)
                # else:
                #     f_idx = np.arange(f+frame_window, f-frame_window, -sampling_step, dtype=int)
                #
                # f_idx[f_idx < 0] = 0
                # f_idx[f_idx >= n] = n
                # f_idx = f_idx.tolist()
                    f_idx = [f]
                    for feature in phase_features:
                        if feature == "phase_vec":
                            sin = np.asarray([frames[idx][jj]["phase_vec"] for jj in key_joints for idx in f_idx])
                            vel = np.concatenate([frames[idx][jj]["velocity"] for jj in key_joints for idx in f_idx])
                            vel = np.reshape(vel, (3,-1))
                            vel = np.sqrt(np.sum(vel**2, axis=0))
                            cos = np.cos(np.arcsin(np.asarray([frames[idx][jj]["sin_normalised_contact"] for jj in key_joints for idx in f_idx])))
                            cos = cos * vel
                            row_vec.append(np.concatenate([np.asarray([sin[i], cos[i]]) for i in range(len(sin))]))
                        elif feature == "targetRotation" or feature == "targetPosition":
                            row_vec.append(np.concatenate([frames[idx][jj][feature].ravel() for jj in key_joints for idx in f_idx]))
                        elif feature == "contact":
                            row_vec.append(np.asarray([frames[idx][jj]["contact"] for jj in key_joints for idx in f_idx]))
                        else:
                            row_vec.append(np.concatenate([frames[f][jj][feature] for jj in key_joints]))

                        # feature_data[i][feature].append(row_vec)
                        if first_row:
                            data_dims.append(row_vec[-1].shape)
                            feature_list.append(feature)

                    for feature in pose_features:
                        if feature=="rotMat":
                            joRot = np.concatenate([jo["rotMat"].ravel() for jo in frames[f]])
                            joRot = joRot.reshape((-1, 3, 3))
                            joRot_s = np.sqrt(np.sum(joRot**2, axis=1))
                            joRot = joRot / joRot_s[:, :, None]

                            if np.sum(np.isnan(joRot)) > 0:
                                raise ValueError("RotMat has nan")
                            row_vec.append(joRot.ravel())
                        elif feature == "isLeft" or feature == "chainPos" or feature == "geoDistanceNormalised":
                            row_vec.append(np.concatenate([[jo[feature]] for jo in frames[f]]))
                        else:
                            row_vec.append(np.concatenate([jo[feature] for jo in frames[f]]))

                        # feature_data[i][feature].append(row_vec)
                        if first_row:
                            data_dims.append(row_vec[-1].shape)
                            feature_list.append(feature)
                    for feature in cost_features:
                        if feature == "posCost":
                            targetPos = np.concatenate([frames[idx][jj]["targetPosition"].ravel() for jj in key_joints for idx in f_idx])
                            targetPos = targetPos.reshape((len(key_joints), len(f_idx), 3))
                            joPos = np.concatenate([frames[idx][jj]["pos"] for jj in key_joints for idx in f_idx])
                            joPos = joPos.reshape((len(key_joints), len(f_idx), 3))
                            posCost = np.sum(((targetPos - joPos)**2), axis=2).ravel()
                            row_vec.append(posCost)
                        elif feature == "rotCost":
                            targetRot = np.concatenate([frames[idx][jj]["targetRotation"].ravel() for jj in key_joints for idx in f_idx])
                            targetRot = targetRot.reshape((len(key_joints), len(f_idx), 3, 3))

                            joRot = np.concatenate([frames[idx][jj]["rotMat"].ravel() for jj in key_joints for idx in f_idx])
                            joRot = joRot.reshape((len(key_joints), len(f_idx), 3, 3))
                            joRot_s = np.sqrt(np.sum(joRot**2, axis=2))
                            joRot = joRot / joRot_s[:, :, :, None]

                            joRotT = np.transpose(joRot, (0, 1, 3, 2))
                            rotDiff = np.arccos(np.clip((np.trace(targetRot @ joRotT, axis1=2, axis2=3) - 1) / 2.0, -1, 1).ravel())

                            if np.sum(np.isnan(rotDiff)) > 0:
                                raise ValueError("RotCost has nan")
                            row_vec.append(rotDiff)
                        elif feature == "tPos" or feature == "tRot":
                            feature = "pos" if feature == "tPos" else "rotMat"
                            row_vec.append(np.concatenate([frames[idx][jj][feature].ravel() for jj in key_joints for idx in f_idx]))
                        else:
                            row_vec.append(np.concatenate([frames[f][jj][feature] for jj in key_joints]))

                        # feature_data[i][feature].append(row_vec)
                        if first_row:
                            data_dims.append(row_vec[-1].shape)
                            feature_list.append(feature)

                    if first_row: first_row = False
                # if i == 1:
                #     row_vec = np.flip(np.concatenate(row_vec))
                #     sequence.append(row_vec)
                # else:
                    sequence.append(np.concatenate(row_vec))
                data_tensors.append(np.vstack(sequence))
                sequence = []
        # break

In [14]:
print(key_joints)

[14, 20]


In [15]:
def loss_fn(x, y):
    return nn.functional.mse_loss(x,y)
def loss_fn2(x, y):
    return nn.functional.smooth_l1_loss(x,y)
def normalise(x):
    std = torch.std(x, dim=0)
    std[std==0] = 1
    return (x-torch.mean(x, dim=0)) / std



In [16]:
extra_feature_len = 21 * 3
n_phase_features = len(phase_features)
n_pose_features = len(pose_features)
phase_dim = np.sum(data_dims[0:n_phase_features])
pp_dim = data_dims[0][0]
pose_dim = np.sum(data_dims[n_phase_features:n_phase_features+n_pose_features])
cost_dim = np.sum(data_dims[n_phase_features+n_pose_features:])

table = [feature_list, data_dims]
print(tabulate(table))
print("phase dim: ",phase_dim)
print("pose dim: ", pose_dim)
print("cost dim: ", cost_dim)



---------  --------------  -----  ------  --------  ------  --------  ---------------------  -------
phase_vec  targetPosition  pos    rotMat  velocity  isLeft  chainPos  geoDistanceNormalised  posCost
(4,)       (6,)            (63,)  (189,)  (63,)     (21,)   (21,)     (21,)                  (2,)
---------  --------------  -----  ------  --------  ------  --------  ---------------------  -------
phase dim:  10
pose dim:  378
cost dim:  2


In [17]:
# x_tensors = torch.nn.utils.rnn.pad_sequence([normalise(torch.from_numpy(clip[:-1])).float().unsqueeze(dim=1) for clip in data_tensors], batch_first=True)
# y_tensors = torch.nn.utils.rnn.pad_sequence([normalise(torch.from_numpy(clip[1:])).float().unsqueeze(dim=1) for clip in data_tensors], batch_first=True)

x_tensors = torch.stack([normalise(torch.from_numpy(clip[:-1])).float() for clip in data_tensors])
y_tensors = torch.stack([torch.from_numpy(clip[1:]).float() for clip in data_tensors])


In [18]:
print(x_tensors.size())

torch.Size([384, 29, 390])


In [19]:
dataset = TensorDataset(torch.Tensor(x_tensors), torch.Tensor(y_tensors))
N = len(x_tensors)

train_ratio = int(.7*N)
val_ratio = int((N-train_ratio) / 2.0)
test_ratio = N - train_ratio - val_ratio
train_set, val_set, test_set = random_split(dataset, [train_ratio, val_ratio, test_ratio], generator=torch.Generator().manual_seed(2021))
print(len(train_set), len(val_set), len(test_set))

268 58 58


In [20]:
data_1 = {"data_sets":[train_set, val_set, test_set], "table":table}
with bz2.BZ2File("data_sets_8_30hz_phaes+pos+cost_single_frame.pbz2", "w") as f:
    pickle.dump(data_1, f)

In [21]:
# with bz2.BZ2File("data_sets_7_terminateAtContact_cost.pbz2", "rb") as f:
#     obj = pickle.load(f)
#
# train_set = obj["data_sets"][0]
# val_set = obj["data_sets"][1]
# test_set = obj["data_sets"][2]
# table = obj["table"]
feature_dims = {}
for feat, dim in zip(table[0], table[1]):
    if feat in feature_dims:
        if feat == "pos": feat = "tPos"
        elif feat == "rotMat": feat = "tRot"
    feature_dims[feat] = dim[0]


In [22]:
config = {
    "k_experts": 4,
    "gate_size": 128,
    "keep_prob": 0.2,
    "hidden_dim": 512,
    "cost_hidden_dim" : 256,
    "cost_output_dim" : 256,
    "batch_size": 32,
    "lr": 1e-4,
    "loss_fn": loss_fn2,
    "window_size":1,
    "autoregress_prob" : 0,
    "autoregress_inc" : .2,
    "autoregress_chunk_size": 120
}

model_name = "Test_11_phase+pose+cost_single"
epochs = 200

In [23]:
phase_dim = sum([feature_dims[feat] for feat in phase_features])
pose_dim = sum([feature_dims[feat] for feat in pose_features])
cost_dim = sum([feature_dims[feat] for feat in cost_features])
# trajectory_dim = feature_dims["tPos"]

pose_autoencoder = MLP_withLabel.load_checkpoint("/home/nuoc/Documents/MEX/models/MLP4_withLabel_best/M3/0.00324857.512.pbz2")
pose_encoder_out_dim = pose_autoencoder.dimensions[-1]

input_slices=[phase_dim, pose_dim, cost_dim]
output_slices=[feature_dims["phase_vec"], pose_encoder_out_dim]
print(input_slices)
print(output_slices)

[10, 378, 2]
[4, 512]


In [24]:
def reformLabel(dataset, slices, phase_dim, extra_feature_len):
    set_y = list(dataset[:][1])
    slices = [0] + list(accumulate(add, slices))
    for i, y in enumerate(set_y):
        y = y.squeeze(1)
        y_tensors = [y[:, d0:d1] for d0, d1 in zip(slices[:-1], slices[1:])]
        phase = y_tensors[0][:, :phase_dim]
        pose = y_tensors[1][:, :-extra_feature_len]
        trajectory = y_tensors[2]
        set_y[i] = torch.cat([phase, pose, trajectory], dim=1).unsqueeze(dim=1)
    return [(x, y) for x, y in zip(dataset[:][0], set_y)]

train_set = reformLabel(train_set, slices=input_slices, phase_dim=feature_dims["phase_vec"], extra_feature_len=21*3)
val_set = reformLabel(val_set, slices=input_slices, phase_dim=feature_dims["phase_vec"], extra_feature_len=21*3)
test_set = reformLabel(test_set, slices=input_slices, phase_dim=feature_dims["phase_vec"], extra_feature_len=21*3)


In [25]:
model = MotionGenerationModel(config=config, pose_autoencoder=pose_autoencoder, cost_input_dimension=cost_dim,
                          input_slicers=input_slices, output_slicers=output_slices,feature_dims=feature_dims,
                          train_set=train_set, val_set=val_set, name=model_name)




In [31]:
import pytorch_lightning as pl
# prob
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor="avg_val_loss", save_top_k=3)
# earlystopping = EarlyStopping(monitor="avg_val_loss", patience=20)
logger=TensorBoardLogger(save_dir="logs/", name=model_name, version="0.0")

trainer = pl.Trainer(
    default_root_dir="/home/nuoc/Documents/MEX/src/motion_generation/checkpoints",
    gpus=1, precision=16,
    # callbacks=[checkpoint_callback],
    logger=logger,
    min_epochs=50,
    max_epochs=epochs,
    stochastic_weight_avg=True
)

# train_set = datasets[0][0]
# val_set = datasets[0][1]
train_loader = DataLoader(train_set, batch_size=config["batch_size"], pin_memory=True, num_workers=6)
val_loader = DataLoader(val_set, batch_size=config["batch_size"], pin_memory=True, num_workers=6)
trainer.fit(model,train_loader, val_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type          | Params
---------------------------------------------------
0 | pose_autoencoder | MLP_withLabel | 440 K 
1 | cost_encoder     | MLP           | 132 K 
2 | generationModel  | MoE           | 3.2 M 
---------------------------------------------------
3.8 M     Trainable params
0         Non-trainable params
3.8 M     Total params
15.005    Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






1

In [30]:
# for i in range(10):
    # model.scheduler.step()
trainer.fit(model,train_loader, val_loader)

AssertionError: A GradScaler instance may only be pickled at the beginning of an iteration, or at the end after scaler.update().

In [73]:
def load(path):
    with bz2.BZ2File(path, "rb") as f:
        obj = pickle.load(f)
        return obj
mo = load("/home/nuoc/Documents/MEX/src/motion_generation/Test_9_phase+pose+cost+trajectory_autoregress_prob_curriculum/0.07476503.pbz2")
ae = load(mo["pose_autoencoder_path"])
cost = load(mo["cost_encoder_path"])
gene = load(mo["motionGenerationModelPath"])

In [29]:
# model.pose_autoencoder.encoder.load_state_dict(ae["encoder"])
# model.pose_autoencoder.decoder.load_state_dict(ae["decoder"])
# model.cost_encoder.encoder.load_state_dict(cost["encoder"])
# model.generationModel.gate.load_state_dict(gene["gate"])
# model.generationModel.load_state_dict(gene["generationNetwork"])
# model.name = ""
# pose_autoencoder = pose_ae_model.load_checkpoint(mo["pose_autoencoder_path"])
# cost_encoder = cost_encoder_model.load_checkpoint(obj["cost_encoder_path"])
# generationModel = generation_model.load_checkpoint(obj["motionGenerationModelPath"])
model.save_checkpoint("")
# model = MotionGenerationModel.load_checkpoint(
#     filename=",
# pose_ae_model=MLP_withLabel, cost_encoder_model=MLP, generation_model=MoE)

In [32]:
# test_set2 = reformLabel(test_set, slices=input_slices, phase_dim=feature_dims["phase_vec"], extra_feature_len=21*3, trajectory_dim=trajectory_dim)
model.autoregress_prob = 1
test_loader = DataLoader(test_set, batch_size=config["batch_size"], pin_memory=True, num_workers=6)
model.cpu()
with torch.no_grad():
    loss, posLoss = 0, 0
    for x,y in test_loader:
        x = x.to("cpu")
        y = y.to("cpu")
        l, pll, rl = model.step(x, y, validation=True)
        loss += l
        posLoss += pll
    loss /= float(len(test_loader))
    posLoss /= float(len(test_loader))
    print(loss.item())
    print(posLoss.item())

0.044012121856212616
0.24371576309204102
