In [11]:
from scripts.utils.utils import parse_param, get_device
from scripts.models.nn_auv_v2 import AUVTraj, AUVPROXYDeltaV
from scripts.training.loss_fct import TrajLoss
from scripts.training_v2.training_utils import get_dataloader_tensor
from scripts.utils.utils import read_files

import os

# Folder containing the training CSV files.
# NOTE(review): `dir` shadows the Python builtin of the same name; the name is
# kept because other cells may refer to it.
dir = "data/csv/training_csv"

# Dataset / dataloader settings consumed by the data-loading cell.
data_params = dict(
    samples=5000,      # not referenced by the visible cells
    steps=15,          # forwarded as `tau` to get_dataloader_tensor
    dir=dir,           # root folder of the CSV files and of stats/stats.yaml
    frame="body",      # forwarded as `frame` to get_dataloader_tensor
    batch_size=10,
    shuffle=True,
    num_workers=8,
)

In [12]:

# Collect the regular files of the data folder (sub-directories such as
# "stats" are skipped by the isfile() filter).
# os.listdir() returns entries in arbitrary order, so the listing is sorted to
# make the file order -- and everything derived from it -- reproducible.
files = sorted(
    f for f in os.listdir(data_params["dir"])
    if os.path.isfile(os.path.join(data_params["dir"], f))
)

# Statistics are read from <dir>/stats/stats.yaml and handed to the dataloader
# factory below.
stats_file = os.path.join(data_params['dir'], "stats", "stats.yaml")
stats = parse_param(stats_file)

# Load the listed files; "train" is passed through to read_files unchanged
# (presumably a split/progress label -- TODO confirm against read_files).
dfs = read_files(data_params["dir"], files, "train")

# Build the dataloader from the notebook configuration.
dl = get_dataloader_tensor(datafiles=dfs,
                           tau=data_params["steps"],
                           frame=data_params["frame"],
                           stats=stats,
                           batch_size=data_params["batch_size"],
                           shuffle=data_params["shuffle"],
                           num_workers=data_params["num_workers"])

In [17]:
# Device used for the model and for every tensor fed to it.
device = get_device(0)

# Horizon requested from the dataset. Note that this value (50) differs from
# data_params["steps"] (15) used when the dataloader was built.
tau = 50
X, Y = dl.dataset.get_trajs(tau)

# Trajectory model: empty dict as first argument, 0.1 time step, no limits.
model = AUVTraj(
    {},
    dt=0.1,
    limMax=None,
    limMin=None,
).to(device)

# Swap the step's `dv_pred` for a proxy built from Y[2].
# NOTE(review): Y[2] is also used as the delta-velocity target later on, so the
# proxy presumably replays the ground truth -- confirm in AUVPROXYDeltaV.
dv_proxy = AUVPROXYDeltaV(Y[2].to(device)).to(device)
model.auv_step.dv_pred = dv_proxy

# Trajectory loss; the three positional arguments are passed as-is
# (presumably per-term weights -- TODO confirm against TrajLoss).
loss = TrajLoss(1.0, 1e4, 1e4)

In [18]:
# Feed the first three entries of X (moved to the compute device) to the model,
# then detach each of the three outputs and bring it back to the CPU.
poses_pred, vels_pred, dvs_pred = (
    out.detach().cpu()
    for out in model(*(X[i].to(device) for i in range(3)))
)

# Matching targets: the first three entries of Y, left where they already are.
poses_target, vels_target, dvs_target = Y[0], Y[1], Y[2]

In [19]:
# Evaluate the loss over all trajectories: (prediction, target) pairs for
# poses, velocities and delta-velocities, in that order.
all_l = loss(
    poses_pred, poses_target,
    vels_pred, vels_target,
    dvs_pred, dvs_target,
)

# Separator line followed by the loss in scientific notation.
print("-"*15, "All traj loss: {:.4e}".format(all_l), sep="\n")

---------------
All traj loss: 1.1420e-03
