In [16]:
import torch
from scripts.utils.utils import parse_param, get_device
from scripts.models.nn_auv_v2 import AUVTraj, AUVPROXYDeltaV
from scripts.training.loss_fct import TrajLoss
from scripts.training_v2.training_utils import get_dataloader_tensor
from scripts.utils.utils import read_files, to_euler, gen_img_3D_v2

import matplotlib.pyplot as plt

import os

# Dataset configuration for this evaluation notebook.
# The directory is bound to `data_dir` (not `dir`) so the built-in `dir()`
# is not shadowed for the rest of the kernel session.
data_dir = "data/csv/tests"
data_params = {
    "samples": 10,
    "steps": 15,          # passed as `tau` when building the dataloader
    "dir": data_dir,      # also the parent of stats/stats.yaml
    "frame": "body",
    "batch_size": 10,
    "shuffle": True,
    "num_workers": 8,
}

In [17]:

# Collect the regular files directly inside the data directory. The isfile
# check skips sub-directories such as stats/. The listing is sorted because
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# keeps the file order (and hence dataset indexing) reproducible across machines.
files = sorted(
    f for f in os.listdir(data_params["dir"])
    if os.path.isfile(os.path.join(data_params["dir"], f))
)

# Dataset statistics read from <dir>/stats/stats.yaml and handed to the dataloader.
stats_file = os.path.join(data_params['dir'], "stats", "stats.yaml")
stats = parse_param(stats_file)

# Read the files (mode "train") and wrap them in a dataloader using the
# settings from `data_params`.
dfs = read_files(data_params["dir"], files, "train")
dl = get_dataloader_tensor(datafiles=dfs,
                           tau=data_params["steps"],
                           frame=data_params["frame"],
                           stats=stats,
                           batch_size=data_params["batch_size"],
                           shuffle=data_params["shuffle"],
                           num_workers=data_params["num_workers"])

In [33]:
# Evaluation setup: extract `tau`-step trajectories from the dataset, build the
# trajectory model on the selected device, and install a delta-v proxy.
tau = 50
X, Y = dl.dataset.get_trajs(tau)
device = get_device(0)

# Trajectory model with dt=0.1; limMax/limMin are left unset (None).
model = AUVTraj({}, dt=0.1, limMax=None, limMin=None)
model = model.to(device)

# Replace the model's delta-v predictor with a proxy built from the target
# delta-v sequence Y[2].
# NOTE(review): presumably the proxy replays these targets so the rollout
# isolates integration error — confirm in scripts.models.nn_auv_v2.
dv_proxy = AUVPROXYDeltaV(Y[2].to(device))
dv_proxy = dv_proxy.to(device)
model.auv_step.dv_pred = dv_proxy

# NOTE(review): the three TrajLoss arguments are assumed to weight the
# pose / velocity / delta-v terms — confirm in scripts.training.loss_fct.
loss = TrajLoss(1.0, 0.0, 0.0)

In [34]:
# Roll the model out over the evaluation trajectories. This cell only
# evaluates, so the forward pass runs under torch.no_grad() and builds no
# autograd graph. Predictions are moved to CPU to be compared with the
# (CPU) targets.
with torch.no_grad():
    poses_pred, vels_pred, dvs_pred = model(X[0].to(device), X[1].to(device), X[2].to(device))
poses_pred = poses_pred.detach().cpu()
vels_pred = vels_pred.detach().cpu()
dvs_pred = dvs_pred.detach().cpu()

# Ground-truth targets, unpacked in the same order as the predictions.
poses_target = Y[0]
vels_target = Y[1]
dvs_target = Y[2]

In [35]:
# Per-trajectory and aggregate loss of the rollout against the targets.
# Set `plot_trajs = True` to also render pose / velocity / delta-v comparison
# images for each trajectory (this replaces the old commented-out plotting code).
plot_trajs = False

for i in range(X[0].shape[0]):
    traj_loss = loss(poses_pred[i], poses_target[i],
                     vels_pred[i], vels_target[i],
                     dvs_pred[i], dvs_target[i])

    print("-"*15)
    # float() accepts both a scalar tensor and a plain Python number.
    print("loss traj {}: {:.4e}".format(i, float(traj_loss)))

    if plot_trajs:
        p_dict = {
            "model": to_euler(poses_pred[i].data),
            "gt": to_euler(poses_target[i].data),
        }
        v_dict = {"model": vels_pred[i], "gt": vels_target[i]}
        dv_dict = {"model": dvs_pred[i], "gt": dvs_target[i]}

        p_img, v_img, dv_img = gen_img_3D_v2(p_dict, v_dict, dv_dict, tau=tau)
        for title, img in (("poses", p_img), ("velocities", v_img), ("delta-v", dv_img)):
            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.set_title("traj {}: {}".format(i, title))
        plt.show()
        # Release figures so long loops do not accumulate open figures.
        plt.close("all")

# Loss over the whole batch of trajectories at once.
all_l = loss(poses_pred, poses_target,
             vels_pred, vels_target,
             dvs_pred, dvs_target)
print("-"*15)
print("All traj loss: {:.4e}".format(float(all_l)))

---------------
loss traj 0: 8.9813e-04
---------------
loss traj 1: 2.0443e-04
---------------
loss traj 2: 7.5247e-04
---------------
All traj loss: 6.1834e-04
