In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import sys
from torch.utils.data import DataLoader
import argparse

from social_utils import *
import yaml

from model_sdd import Goal_example_model
import numpy as np
import pdb
from gmm2d import *

from metrics import *
from utils import *

In [4]:
from dataclasses import dataclass
@dataclass
class GoalExample:
    """Run settings for this notebook, used in place of parsed command-line arguments.

    The constructor is positional-friendly: arguments are accepted in the
    order the fields are declared below.
    """

    # Declared alongside the other settings; not read by any visible cell.
    num_workers: int
    # CUDA device index used when a GPU is available.
    gpu_index: int
    # File name of the YAML hyper-parameter file, opened under ./config/.
    config_filename: str
    # Declared alongside the other settings; not read by any visible cell.
    save_file: str
    # Forwarded to SocialDataset as its `verbose` argument.
    verbose: bool
    # Learning rate handed to the Adam optimizer.
    lr: float
    # Forwarded to Goal_example_model as `input_feat`.
    input_feat: int
    # Forwarded to Goal_example_model as `output_feat`.
    output_feat: int
    # Directory into which model state dicts are saved.
    checkpoint: str

In [7]:
# Settings for this run, spelled out by field name so each value is self-describing.
args = GoalExample(
    num_workers=0,
    gpu_index=0,
    config_filename="optimal.yaml",
    save_file="PECNET_social_model.pt",
    verbose=True,
    lr=0.0003,
    input_feat=2,
    output_feat=128,
    checkpoint="./checkpoint_sdd_abs2",
)

In [8]:
# Checkpoints for this run go here, replacing the directory given when
# `args` was constructed.
args.checkpoint = "./sdd_wo_goal"

# Work in double precision: floating-point tensors created without an
# explicit dtype will be float64 from here on.
dtype = torch.float64
torch.set_default_dtype(dtype)

# Use the configured GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda", index=args.gpu_index)
    torch.cuda.set_device(args.gpu_index)
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [9]:
def batch_bivariate_loss_ssd(V_pred, V_trgt):
    """Mean negative log-likelihood of targets under a predicted bivariate Gaussian.

    Args:
        V_pred: tensor whose last axis is indexed as
            (mu_x, mu_y, log sigma_x, log sigma_y, raw correlation).
            The two scale channels go through ``exp`` and the correlation
            channel through ``tanh`` before use.
        V_trgt: tensor whose last axis holds the target (x, y) in its first
            two channels. Leading axes must broadcast against ``V_pred``.

    Returns:
        A 0-dim tensor: the mean over all leading positions of
        ``-log(max(pdf, 1e-20))``.
    """
    # Residuals between target and predicted mean.
    dx = V_trgt[..., 0] - V_pred[..., 0]
    dy = V_trgt[..., 1] - V_pred[..., 1]

    # Map the raw network outputs to valid distribution parameters.
    sigma_x = V_pred[..., 2].exp()
    sigma_y = V_pred[..., 3].exp()
    rho = V_pred[..., 4].tanh()

    sigma_xy = sigma_x * sigma_y
    one_minus_rho_sq = 1 - rho ** 2

    # Quadratic form in the exponent of the bivariate normal density.
    quad = (
        (dx / sigma_x) ** 2
        + (dy / sigma_y) ** 2
        - 2 * ((rho * dx * dy) / sigma_xy)
    )

    # Density = exp(-quad / (2 (1 - rho^2))) / (2 pi sigma_x sigma_y sqrt(1 - rho^2)).
    normaliser = 2 * np.pi * (sigma_xy * torch.sqrt(one_minus_rho_sq))
    density = torch.exp(-quad / (2 * one_minus_rho_sq)) / normaliser

    # Floor the density before the log so it stays finite; any density below
    # the floor contributes the constant -log(1e-20).
    floored = torch.clamp(density, min=1e-20)
    return torch.neg(torch.log(floored)).mean()

In [10]:
def graph_loss(V_pred, V_target):
    """Training loss: forwards both tensors to ``batch_bivariate_loss_ssd``.

    Args:
        V_pred: predicted distribution parameters.
        V_target: target coordinates.

    Returns:
        Whatever ``batch_bivariate_loss_ssd(V_pred, V_target)`` returns.
    """
    loss = batch_bivariate_loss_ssd(V_pred, V_target)
    return loss

In [12]:
# Load the hyper-parameter dictionary from ./config/<config_filename>.
#
# Older PyYAML releases have no FullLoader attribute. Look it up explicitly
# instead of wrapping the parse in a bare `except:` — that pattern hid every
# error (missing keys, malformed YAML, KeyboardInterrupt) and then re-parsed a
# stream that might already have been consumed.
# NOTE(review): yaml.load with a non-safe loader should only be pointed at
# trusted files; consider yaml.safe_load if the config holds only plain
# scalars and lists.
yaml_loader = getattr(yaml, "FullLoader", None)
with open("./config/" + args.config_filename, "r") as file:
    if yaml_loader is not None:
        hyper_params = yaml.load(file, Loader=yaml_loader)
    else:
        hyper_params = yaml.load(file)
# The with-statement closes the file; no explicit close() is needed.
print(hyper_params)

{'adl_reg': 1, 'data_scale': 1.86, 'dataset_type': 'image', 'dec_size': [1024, 512, 1024], 'dist_thresh': 100, 'enc_dest_size': [8, 16], 'enc_latent_size': [8, 50], 'enc_past_size': [512, 256], 'non_local_theta_size': [256, 128, 64], 'non_local_phi_size': [256, 128, 64], 'non_local_g_size': [256, 128, 64], 'non_local_dim': 128, 'fdim': 16, 'future_length': 12, 'gpu_index': 0, 'kld_reg': 1, 'learning_rate': 0.0003, 'mu': 0, 'n_values': 20, 'nonlocal_pools': 3, 'normalize_type': 'shift_origin', 'num_epochs': 650, 'num_workers': 0, 'past_length': 8, 'predictor_hidden_size': [1024, 512, 256], 'sigma': 1.3, 'test_b_size': 4096, 'time_thresh': 0, 'train_b_size': 512, 'zdim': 16}


In [14]:
# Training and evaluation splits use the same time/distance thresholds from
# the config; only the split name and the batch size differ.
train_dataset = SocialDataset(
    set_name="train",
    b_size=hyper_params["train_b_size"],
    t_tresh=hyper_params["time_thresh"],
    d_tresh=hyper_params["dist_thresh"],
    verbose=args.verbose,
)

test_dataset = SocialDataset(
    set_name="test",
    b_size=hyper_params["test_b_size"],
    t_tresh=hyper_params["time_thresh"],
    d_tresh=hyper_params["dist_thresh"],
    verbose=args.verbose,
)

# Place the model on `device` instead of calling .cuda() unconditionally, so
# the CPU fallback chosen when `device` was built also works on machines
# without CUDA.
model = Goal_example_model(
    input_feat=args.input_feat,
    output_feat=args.output_feat,
    config=hyper_params,
    non_local_loop=0,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

./social_pool_data/train_all_512_0_100.pickle


  traj_new = np.array(traj_new)
  masks_new = np.array(masks_new)
  self.initial_pos_batches = np.array(


Initialized social dataloader...
./social_pool_data/test_all_4096_0_100.pickle
Initialized social dataloader...


In [15]:
"""Prepare some data for this batch of data"""
# Shift every trajectory so it starts at the origin, then scale it.
#
# The arrays are modified in place, so each dataset object is tagged once it
# has been processed. Without the tag, re-running this cell would multiply by
# the scale factor a second time and silently shrink the coordinates further.
# Rebuilding the datasets yields fresh, untagged objects that are processed
# again as expected.
TRAJ_SCALE = 0.2  # reported metrics are multiplied by 5.0 (= 1 / 0.2) to undo this
for dataset in (train_dataset, test_dataset):
    if getattr(dataset, "_origin_shifted", False):
        continue
    for traj in dataset.trajectory_batches:
        traj -= traj[:, :1, :]
        traj *= TRAJ_SCALE
    dataset._origin_shifted = True

## Testing

In [16]:
def test(test_dataset, best_of_n=20):
    """Evaluate the global ``model`` with best-of-N sampling.

    For each batch the model predicts bivariate-Gaussian parameters for the
    future steps and ``best_of_n`` trajectories are sampled from that
    distribution. For every ground-truth trajectory the smallest average
    displacement error (ADE) and the smallest final displacement error (FDE)
    over the samples are kept; the two minima are taken independently.

    Reads the notebook globals ``model``, ``hyper_params``, ``device`` and
    ``GMM2D``. The model is left in eval mode.

    Args:
        test_dataset: object exposing ``trajectory_batches``, ``mask_batches``
            and ``initial_pos_batches``, which are iterated in lockstep.
        best_of_n: number of trajectories sampled per prediction.

    Returns:
        ``(ade, fde)``: two 0-dim tensors, each the mean over all evaluated
        trajectories of the per-trajectory minimum. Values are in the
        dataset's (scaled) coordinate frame.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    model.eval()
    past_length = hyper_params["past_length"]

    # One 1-D tensor of per-trajectory minima per batch.
    min_ade_batches = []
    min_fde_batches = []

    # Evaluation needs no gradients; skipping graph construction saves memory
    # on the large test batches.
    with torch.no_grad():
        for traj, mask, _initial_pos in zip(
            test_dataset.trajectory_batches,
            test_dataset.mask_batches,
            test_dataset.initial_pos_batches,
        ):
            traj = torch.DoubleTensor(traj).to(device)
            mask = torch.DoubleTensor(mask).to(device)

            # The model input is the observed past expressed relative to the
            # final point of the full trajectory (which is also the last
            # ground-truth future point). Broadcasting over the time axis
            # replaces the former hard-coded repeat(1, 8, 1).
            dest = traj[:, -1:]
            input_traj = traj[:, :past_length] - dest
            V_tr = traj[:, past_length:, :]

            V_pred, _ = model(input_traj, mask)
            # NOTE(review): squeeze() drops every size-1 axis, including the
            # batch axis if a batch ever holds a single trajectory - confirm.
            V_pred = V_pred.squeeze()

            # Single-component mixture. ones_like keeps the weights on the
            # same device and dtype as the predictions.
            log_pis = torch.ones_like(V_pred[..., -2:-1])
            gmm2d = GMM2D(
                log_pis,
                V_pred[..., 0:2],
                V_pred[..., 2:4],
                torch.tanh(V_pred[..., -1]).unsqueeze(-1),
            )

            num_traj = traj.shape[0]
            ade_samples = []
            fde_samples = []
            for _ in range(best_of_n):
                sample = gmm2d.rsample().squeeze()

                # Per-step Euclidean error, averaged over everything but the
                # trajectory axis.
                step_err = torch.norm(sample - V_tr, dim=-1)
                ade_samples.append(step_err.reshape(num_traj, -1).mean(dim=1))

                # Error at the last predicted step.
                final_diff = (sample[:, -1] - V_tr[:, -1]).reshape(num_traj, -1)
                fde_samples.append(torch.norm(final_diff, dim=1))

            # Best sample per trajectory, independently for ADE and FDE.
            min_ade_batches.append(torch.stack(ade_samples).min(dim=0)[0])
            min_fde_batches.append(torch.stack(fde_samples).min(dim=0)[0])

    if not min_ade_batches:
        raise ValueError("test_dataset produced no batches to evaluate")

    ade_ = torch.cat(min_ade_batches).mean()
    fde_ = torch.cat(min_fde_batches).mean()
    return ade_, fde_

## Training

In [17]:
def train(train_dataset, epoch):
    """Run one optimisation pass over every batch of ``train_dataset``.

    Reads the notebook globals ``model``, ``optimizer``, ``hyper_params`` and
    ``device``. Each batch performs one forward pass, one backward pass and
    one optimizer step, then prints its loss.

    Args:
        train_dataset: object exposing ``trajectory_batches``, ``mask_batches``
            and ``initial_pos_batches``, which are iterated in lockstep.
        epoch: epoch index, used only in the printed progress line.

    Returns:
        The mean of the per-batch losses as a float, or ``None`` when the
        dataset yields no batches.
    """
    model.train()
    past_length = hyper_params["past_length"]
    batch_losses = []

    for traj, mask, _initial_pos in zip(
        train_dataset.trajectory_batches,
        train_dataset.mask_batches,
        train_dataset.initial_pos_batches,
    ):
        optimizer.zero_grad()

        traj = torch.DoubleTensor(traj).to(device)
        mask = torch.DoubleTensor(mask).to(device)

        # Targets are the absolute future positions.
        V_tr = traj[:, past_length:]

        # The model input is the observed past expressed relative to the final
        # point of the full trajectory. Broadcasting over the time axis
        # replaces the former hard-coded repeat(1, 8, 1).
        dest = traj[:, -1:]
        input_traj = traj[:, :past_length] - dest

        V_pred, _ = model(input_traj, mask)
        # NOTE(review): squeeze() drops every size-1 axis, including the batch
        # axis if a batch ever holds a single trajectory - confirm.
        V_pred = V_pred.squeeze()

        loss = graph_loss(V_pred, V_tr)
        loss.backward()
        optimizer.step()

        # Metrics
        loss_batch = loss.item()
        batch_losses.append(loss_batch)
        print("TRAIN:", "\t Epoch:", epoch, "\t Loss:", loss_batch)

    if not batch_losses:
        return None
    return sum(batch_losses) / len(batch_losses)

In [18]:
# Number of passes over the training set run by the training loop.
# NOTE(review): the loaded config also carries a `num_epochs` entry (650 in the
# printed dictionary); this literal is the value the loop actually uses -
# confirm which one is intended.
num_epochs = 450

In [None]:
# `os` is used for checkpoint paths but is not imported explicitly by the
# import cell; import it here so this cell does not depend on star imports.
import os

# Evaluation starts on the first epoch strictly greater than this index.
EVAL_START_EPOCH = 20
# Trajectories were multiplied by 0.2 during preprocessing; 5.0 = 1 / 0.2
# converts the metrics back to the original coordinate scale.
METRIC_RESCALE = 5.0

# Best metrics seen so far. These must live outside the epoch loop: resetting
# them every epoch made the "is this the best?" check always succeed, so a
# `val_best_*` checkpoint was written on every evaluated epoch.
best_ade = float("inf")
best_fde = float("inf")

# torch.save does not create missing directories.
os.makedirs(args.checkpoint, exist_ok=True)

for epoch in range(num_epochs):

    train(train_dataset, epoch)

    if epoch > EVAL_START_EPOCH:
        ad, fd = test(test_dataset, 20)
        ade_now = ad.item() * METRIC_RESCALE
        fde_now = fd.item() * METRIC_RESCALE

        # A checkpoint is kept only when both metrics beat the best saved
        # values, matching the original two-sided condition.
        if ade_now < best_ade and fde_now < best_fde:
            best_ade = ade_now
            best_fde = fde_now
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.checkpoint,
                    "val_best_{}_{}_{}.pth".format(epoch, ade_now, fde_now),
                ),
            )

        # Report the metrics of the epoch that was just evaluated.
        print("ADE:", ade_now, " FDE:", fde_now)

TRAIN: 	 Epoch: 0 	 Loss: 4.976104639542424
TRAIN: 	 Epoch: 0 	 Loss: 4.119008614846586
TRAIN: 	 Epoch: 0 	 Loss: 5.17392846187933
TRAIN: 	 Epoch: 0 	 Loss: 5.268663329851114
TRAIN: 	 Epoch: 0 	 Loss: 5.317826901853049
TRAIN: 	 Epoch: 0 	 Loss: 5.231563614459357
TRAIN: 	 Epoch: 0 	 Loss: 5.702281257661526
TRAIN: 	 Epoch: 0 	 Loss: 5.652488873885025
TRAIN: 	 Epoch: 0 	 Loss: 4.244416648440456
TRAIN: 	 Epoch: 0 	 Loss: 4.272580113987758
TRAIN: 	 Epoch: 0 	 Loss: 4.55361896224454
TRAIN: 	 Epoch: 0 	 Loss: 4.416958872445815
TRAIN: 	 Epoch: 0 	 Loss: 6.488948156736695
TRAIN: 	 Epoch: 0 	 Loss: 6.518836780830007
TRAIN: 	 Epoch: 0 	 Loss: 5.6975577504131865
TRAIN: 	 Epoch: 0 	 Loss: 5.796957234079277
TRAIN: 	 Epoch: 0 	 Loss: 7.236629328871521
TRAIN: 	 Epoch: 0 	 Loss: 7.165961565620172
TRAIN: 	 Epoch: 0 	 Loss: 7.506950978945658
TRAIN: 	 Epoch: 0 	 Loss: 7.442430360161001
TRAIN: 	 Epoch: 0 	 Loss: 6.907208124759866
TRAIN: 	 Epoch: 0 	 Loss: 6.95904282533692
TRAIN: 	 Epoch: 0 	 Loss: 6.035491

TRAIN: 	 Epoch: 5 	 Loss: 4.450963650178811
TRAIN: 	 Epoch: 5 	 Loss: 4.38573975198695
TRAIN: 	 Epoch: 5 	 Loss: 3.0135117736057486
TRAIN: 	 Epoch: 5 	 Loss: 2.9888636916535978
TRAIN: 	 Epoch: 5 	 Loss: 3.62889752974633
TRAIN: 	 Epoch: 5 	 Loss: 3.568586621103368
TRAIN: 	 Epoch: 5 	 Loss: 5.815552296984096
TRAIN: 	 Epoch: 5 	 Loss: 5.77977491531818
TRAIN: 	 Epoch: 5 	 Loss: 4.855262846489159
TRAIN: 	 Epoch: 5 	 Loss: 4.991170063598151
TRAIN: 	 Epoch: 5 	 Loss: 6.451738264617149
TRAIN: 	 Epoch: 5 	 Loss: 6.602131266500959
TRAIN: 	 Epoch: 5 	 Loss: 6.758210685268232
TRAIN: 	 Epoch: 5 	 Loss: 6.8096979494662
TRAIN: 	 Epoch: 5 	 Loss: 5.8915748724918675
TRAIN: 	 Epoch: 5 	 Loss: 5.96498075931627
TRAIN: 	 Epoch: 5 	 Loss: 5.671491249351885
TRAIN: 	 Epoch: 5 	 Loss: 5.697043493474173
TRAIN: 	 Epoch: 5 	 Loss: 5.019925443073145
TRAIN: 	 Epoch: 5 	 Loss: 4.87898635830234
TRAIN: 	 Epoch: 5 	 Loss: 6.103506753551817
TRAIN: 	 Epoch: 5 	 Loss: 6.234110542311567
TRAIN: 	 Epoch: 5 	 Loss: 6.17942776

TRAIN: 	 Epoch: 10 	 Loss: 6.025657094453281
TRAIN: 	 Epoch: 10 	 Loss: 6.059471954771451
TRAIN: 	 Epoch: 10 	 Loss: 5.366661489399132
TRAIN: 	 Epoch: 10 	 Loss: 5.561970897041175
TRAIN: 	 Epoch: 10 	 Loss: 6.847239995719478
TRAIN: 	 Epoch: 10 	 Loss: 6.810365411506763
TRAIN: 	 Epoch: 10 	 Loss: 7.071890601807789
TRAIN: 	 Epoch: 10 	 Loss: 7.125538462171187
TRAIN: 	 Epoch: 10 	 Loss: 6.480687853030211
TRAIN: 	 Epoch: 10 	 Loss: 6.584172194282559
TRAIN: 	 Epoch: 10 	 Loss: 5.8737859027813055
TRAIN: 	 Epoch: 10 	 Loss: 5.83104589072857
TRAIN: 	 Epoch: 10 	 Loss: 5.527515365959557
TRAIN: 	 Epoch: 10 	 Loss: 5.402510939139146
TRAIN: 	 Epoch: 10 	 Loss: 6.203285663856521
TRAIN: 	 Epoch: 10 	 Loss: 6.285144580030246
TRAIN: 	 Epoch: 10 	 Loss: 6.341968915789712
TRAIN: 	 Epoch: 10 	 Loss: 6.204831037416037
TRAIN: 	 Epoch: 10 	 Loss: 6.264779131282818
TRAIN: 	 Epoch: 10 	 Loss: 6.292674300515696
TRAIN: 	 Epoch: 10 	 Loss: 6.087111231346379
TRAIN: 	 Epoch: 10 	 Loss: 6.123545481920674
TRAIN: 	 E

TRAIN: 	 Epoch: 15 	 Loss: 4.605060651560944
TRAIN: 	 Epoch: 15 	 Loss: 4.918653360112188
TRAIN: 	 Epoch: 15 	 Loss: 6.232705180632735
TRAIN: 	 Epoch: 15 	 Loss: 6.278124144560232
TRAIN: 	 Epoch: 15 	 Loss: 6.612066255875715
TRAIN: 	 Epoch: 15 	 Loss: 6.661280147866034
TRAIN: 	 Epoch: 15 	 Loss: 5.815610239306986
TRAIN: 	 Epoch: 15 	 Loss: 5.765322018808267
TRAIN: 	 Epoch: 15 	 Loss: 5.330027793122047
TRAIN: 	 Epoch: 15 	 Loss: 5.335722028275343
TRAIN: 	 Epoch: 15 	 Loss: 4.871195443174952
TRAIN: 	 Epoch: 15 	 Loss: 4.7385004223564025
TRAIN: 	 Epoch: 15 	 Loss: 5.764532502813357
TRAIN: 	 Epoch: 15 	 Loss: 5.7776985421521605
TRAIN: 	 Epoch: 15 	 Loss: 5.8840144194933615
TRAIN: 	 Epoch: 15 	 Loss: 5.881761738548135
TRAIN: 	 Epoch: 15 	 Loss: 5.904970460555941
TRAIN: 	 Epoch: 15 	 Loss: 5.9460857882467755
TRAIN: 	 Epoch: 15 	 Loss: 5.684225843994417
TRAIN: 	 Epoch: 15 	 Loss: 5.792804588399486
TRAIN: 	 Epoch: 15 	 Loss: 4.443758003305709
TRAIN: 	 Epoch: 15 	 Loss: 4.51939561810507
TRAIN: 

TRAIN: 	 Epoch: 20 	 Loss: 6.1497580522286714
TRAIN: 	 Epoch: 20 	 Loss: 5.989313982939369
TRAIN: 	 Epoch: 20 	 Loss: 6.333107081110839
TRAIN: 	 Epoch: 20 	 Loss: 6.365083409739265
TRAIN: 	 Epoch: 20 	 Loss: 5.501738803282141
TRAIN: 	 Epoch: 20 	 Loss: 5.525911709599141
TRAIN: 	 Epoch: 20 	 Loss: 5.182468619811947
TRAIN: 	 Epoch: 20 	 Loss: 5.232794433629017
TRAIN: 	 Epoch: 20 	 Loss: 4.5486568262847005
TRAIN: 	 Epoch: 20 	 Loss: 4.401057350290429
TRAIN: 	 Epoch: 20 	 Loss: 5.6164667336213725
TRAIN: 	 Epoch: 20 	 Loss: 5.647139734747079
TRAIN: 	 Epoch: 20 	 Loss: 5.706442154217402
TRAIN: 	 Epoch: 20 	 Loss: 5.665852471830067
TRAIN: 	 Epoch: 20 	 Loss: 5.607528989310623
TRAIN: 	 Epoch: 20 	 Loss: 5.721747272032137
TRAIN: 	 Epoch: 20 	 Loss: 5.473632019521112
TRAIN: 	 Epoch: 20 	 Loss: 5.539806783515711
TRAIN: 	 Epoch: 20 	 Loss: 4.081743011672749
TRAIN: 	 Epoch: 20 	 Loss: 4.205417625129765
TRAIN: 	 Epoch: 21 	 Loss: 3.0620625280145033
TRAIN: 	 Epoch: 21 	 Loss: 3.935606060199022
TRAIN:

TRAIN: 	 Epoch: 25 	 Loss: 4.1843130890022655
TRAIN: 	 Epoch: 25 	 Loss: 4.52283762722356
TRAIN: 	 Epoch: 25 	 Loss: 6.000041527037242
TRAIN: 	 Epoch: 25 	 Loss: 5.932856744757314
TRAIN: 	 Epoch: 25 	 Loss: 6.276429293751143
TRAIN: 	 Epoch: 25 	 Loss: 6.342422890789225
TRAIN: 	 Epoch: 25 	 Loss: 5.425722354462498
TRAIN: 	 Epoch: 25 	 Loss: 5.427130073006158
TRAIN: 	 Epoch: 25 	 Loss: 4.991492010913867
TRAIN: 	 Epoch: 25 	 Loss: 4.9804601717152455
TRAIN: 	 Epoch: 25 	 Loss: 4.3692712140211185
TRAIN: 	 Epoch: 25 	 Loss: 4.290654366558588
TRAIN: 	 Epoch: 25 	 Loss: 5.444595120269982
TRAIN: 	 Epoch: 25 	 Loss: 5.495594290908223
TRAIN: 	 Epoch: 25 	 Loss: 5.592469589038775
TRAIN: 	 Epoch: 25 	 Loss: 5.539950136892133
TRAIN: 	 Epoch: 25 	 Loss: 5.48068801319345
TRAIN: 	 Epoch: 25 	 Loss: 5.64614658766731
TRAIN: 	 Epoch: 25 	 Loss: 5.387433078366934
TRAIN: 	 Epoch: 25 	 Loss: 5.481594172364047
TRAIN: 	 Epoch: 25 	 Loss: 3.8945280049315096
TRAIN: 	 Epoch: 25 	 Loss: 4.0526741008372085
ADE: 98.

TRAIN: 	 Epoch: 30 	 Loss: 3.149829672558234
TRAIN: 	 Epoch: 30 	 Loss: 5.159700441988932
TRAIN: 	 Epoch: 30 	 Loss: 5.234983689573426
TRAIN: 	 Epoch: 30 	 Loss: 4.390758652146978
TRAIN: 	 Epoch: 30 	 Loss: 4.623517492115174
TRAIN: 	 Epoch: 30 	 Loss: 6.194670207728921
TRAIN: 	 Epoch: 30 	 Loss: 6.11073432374198
TRAIN: 	 Epoch: 30 	 Loss: 6.358436872407319
TRAIN: 	 Epoch: 30 	 Loss: 6.402118812869565
TRAIN: 	 Epoch: 30 	 Loss: 5.650915582984951
TRAIN: 	 Epoch: 30 	 Loss: 5.61074149326487
TRAIN: 	 Epoch: 30 	 Loss: 5.219423390756178
TRAIN: 	 Epoch: 30 	 Loss: 5.2090973618259016
TRAIN: 	 Epoch: 30 	 Loss: 4.630152384669593
TRAIN: 	 Epoch: 30 	 Loss: 4.538924885767205
TRAIN: 	 Epoch: 30 	 Loss: 5.511568155212281
TRAIN: 	 Epoch: 30 	 Loss: 5.544045194564953
TRAIN: 	 Epoch: 30 	 Loss: 5.710193502310157
TRAIN: 	 Epoch: 30 	 Loss: 5.64141716240268
TRAIN: 	 Epoch: 30 	 Loss: 5.669759100979808
TRAIN: 	 Epoch: 30 	 Loss: 5.690679167337635
TRAIN: 	 Epoch: 30 	 Loss: 5.392187097452464
TRAIN: 	 Epo

TRAIN: 	 Epoch: 35 	 Loss: 2.2547659722855067
TRAIN: 	 Epoch: 35 	 Loss: 2.266666799091162
TRAIN: 	 Epoch: 35 	 Loss: 3.0685397746545124
TRAIN: 	 Epoch: 35 	 Loss: 2.944005337728176
TRAIN: 	 Epoch: 35 	 Loss: 4.929768799255047
TRAIN: 	 Epoch: 35 	 Loss: 4.950717438475124
TRAIN: 	 Epoch: 35 	 Loss: 3.982562847667835
TRAIN: 	 Epoch: 35 	 Loss: 4.284342717318831
TRAIN: 	 Epoch: 35 	 Loss: 5.858436801575285
TRAIN: 	 Epoch: 35 	 Loss: 5.7418496981817855
TRAIN: 	 Epoch: 35 	 Loss: 6.128607005603792
TRAIN: 	 Epoch: 35 	 Loss: 6.1106113959237875
TRAIN: 	 Epoch: 35 	 Loss: 5.205434221733551
TRAIN: 	 Epoch: 35 	 Loss: 5.220259715883828
TRAIN: 	 Epoch: 35 	 Loss: 5.045750419478606
TRAIN: 	 Epoch: 35 	 Loss: 4.9275553830934635
TRAIN: 	 Epoch: 35 	 Loss: 4.326127414052778
TRAIN: 	 Epoch: 35 	 Loss: 6.492934787309264
TRAIN: 	 Epoch: 35 	 Loss: 6.512472061377636
TRAIN: 	 Epoch: 35 	 Loss: 6.083756234547363
TRAIN: 	 Epoch: 35 	 Loss: 6.034595032586556
TRAIN: 	 Epoch: 35 	 Loss: 6.122918989655709
TRAIN

TRAIN: 	 Epoch: 40 	 Loss: 3.5880846105421966
TRAIN: 	 Epoch: 40 	 Loss: 3.876688118075036
TRAIN: 	 Epoch: 40 	 Loss: 3.8432648954542383
TRAIN: 	 Epoch: 40 	 Loss: 2.3321821744767446
TRAIN: 	 Epoch: 40 	 Loss: 2.4857602495852573
TRAIN: 	 Epoch: 40 	 Loss: 3.0359371907713
TRAIN: 	 Epoch: 40 	 Loss: 2.9289247784507744
TRAIN: 	 Epoch: 40 	 Loss: 5.017131280159488
TRAIN: 	 Epoch: 40 	 Loss: 5.059259612891239
TRAIN: 	 Epoch: 40 	 Loss: 4.194082160754455
TRAIN: 	 Epoch: 40 	 Loss: 4.372829892809085
TRAIN: 	 Epoch: 40 	 Loss: 5.9491996977182655
TRAIN: 	 Epoch: 40 	 Loss: 5.905539367271524
TRAIN: 	 Epoch: 40 	 Loss: 6.164405803998563
TRAIN: 	 Epoch: 40 	 Loss: 6.171468502479877
TRAIN: 	 Epoch: 40 	 Loss: 5.414891684968907
TRAIN: 	 Epoch: 40 	 Loss: 5.378808000384088
TRAIN: 	 Epoch: 40 	 Loss: 4.901102139955793
TRAIN: 	 Epoch: 40 	 Loss: 4.989384718576709
TRAIN: 	 Epoch: 40 	 Loss: 4.414004449835284
TRAIN: 	 Epoch: 40 	 Loss: 4.38872110796833
TRAIN: 	 Epoch: 40 	 Loss: 5.362112644019014
TRAIN: 

TRAIN: 	 Epoch: 45 	 Loss: 3.8454729954019795
TRAIN: 	 Epoch: 45 	 Loss: 3.5568396805443703
TRAIN: 	 Epoch: 45 	 Loss: 3.64090570122735
TRAIN: 	 Epoch: 45 	 Loss: 3.7351560909508668
TRAIN: 	 Epoch: 45 	 Loss: 4.20896215854425
TRAIN: 	 Epoch: 45 	 Loss: 4.095331626728663
TRAIN: 	 Epoch: 45 	 Loss: 2.4265789434948433
TRAIN: 	 Epoch: 45 	 Loss: 2.519031475190281
TRAIN: 	 Epoch: 45 	 Loss: 3.06650943580141
TRAIN: 	 Epoch: 45 	 Loss: 3.0509658067566923
TRAIN: 	 Epoch: 45 	 Loss: 4.999733047584713
TRAIN: 	 Epoch: 45 	 Loss: 5.063486016114747
TRAIN: 	 Epoch: 45 	 Loss: 4.0881090679069185
TRAIN: 	 Epoch: 45 	 Loss: 4.263098254547446
TRAIN: 	 Epoch: 45 	 Loss: 5.7586104028452905
TRAIN: 	 Epoch: 45 	 Loss: 5.749442797937001
TRAIN: 	 Epoch: 45 	 Loss: 6.102593598456968
TRAIN: 	 Epoch: 45 	 Loss: 6.129412800195036
TRAIN: 	 Epoch: 45 	 Loss: 5.26022828648231
TRAIN: 	 Epoch: 45 	 Loss: 5.230692567584899
TRAIN: 	 Epoch: 45 	 Loss: 4.786108750286816
TRAIN: 	 Epoch: 45 	 Loss: 4.908581697145974
TRAIN: 

ADE: 91.32723095989262  FDE: 22.34939843299844
TRAIN: 	 Epoch: 50 	 Loss: 2.8676264833931695
TRAIN: 	 Epoch: 50 	 Loss: 2.940758117585187
TRAIN: 	 Epoch: 50 	 Loss: 3.6748996869560964
TRAIN: 	 Epoch: 50 	 Loss: 3.6347494457041294
TRAIN: 	 Epoch: 50 	 Loss: 3.594058706885601
TRAIN: 	 Epoch: 50 	 Loss: 3.5973687988840344
TRAIN: 	 Epoch: 50 	 Loss: 4.009687343747505
TRAIN: 	 Epoch: 50 	 Loss: 3.9506402248821817
TRAIN: 	 Epoch: 50 	 Loss: 2.3729796732738837
TRAIN: 	 Epoch: 50 	 Loss: 2.532921686434058
TRAIN: 	 Epoch: 50 	 Loss: 3.1268786980701044
TRAIN: 	 Epoch: 50 	 Loss: 3.1008322154656103
TRAIN: 	 Epoch: 50 	 Loss: 4.984071553048568
TRAIN: 	 Epoch: 50 	 Loss: 5.07535772861992
TRAIN: 	 Epoch: 50 	 Loss: 4.271831666104172
TRAIN: 	 Epoch: 50 	 Loss: 4.482431323274055
TRAIN: 	 Epoch: 50 	 Loss: 5.959573460808658
TRAIN: 	 Epoch: 50 	 Loss: 5.870960032012508
TRAIN: 	 Epoch: 50 	 Loss: 6.1502224275138175
TRAIN: 	 Epoch: 50 	 Loss: 6.235635604962427
TRAIN: 	 Epoch: 50 	 Loss: 5.531978321390598


TRAIN: 	 Epoch: 54 	 Loss: 5.222930281452051
TRAIN: 	 Epoch: 54 	 Loss: 3.679325993444594
TRAIN: 	 Epoch: 54 	 Loss: 3.808258167842034
ADE: 88.17193864618031  FDE: 21.435916791966907
TRAIN: 	 Epoch: 55 	 Loss: 2.7235570174491976
TRAIN: 	 Epoch: 55 	 Loss: 5.913149657896448
TRAIN: 	 Epoch: 55 	 Loss: 6.8867157072905645
TRAIN: 	 Epoch: 55 	 Loss: 4.530179114686998
TRAIN: 	 Epoch: 55 	 Loss: 4.932952879713207
TRAIN: 	 Epoch: 55 	 Loss: 5.238409827108862
TRAIN: 	 Epoch: 55 	 Loss: 5.548862503073306
TRAIN: 	 Epoch: 55 	 Loss: 5.545375707995516
TRAIN: 	 Epoch: 55 	 Loss: 4.789381268503249
TRAIN: 	 Epoch: 55 	 Loss: 4.8125890632850705
TRAIN: 	 Epoch: 55 	 Loss: 4.981103578139453
TRAIN: 	 Epoch: 55 	 Loss: 4.887766403152184
TRAIN: 	 Epoch: 55 	 Loss: 5.972124572763809
TRAIN: 	 Epoch: 55 	 Loss: 6.072990255392982
TRAIN: 	 Epoch: 55 	 Loss: 5.565921600306086
TRAIN: 	 Epoch: 55 	 Loss: 5.7158160714160635
TRAIN: 	 Epoch: 55 	 Loss: 6.89933718463858
TRAIN: 	 Epoch: 55 	 Loss: 6.679755597674243
TRAI

TRAIN: 	 Epoch: 59 	 Loss: 5.545190271386897
TRAIN: 	 Epoch: 59 	 Loss: 5.606094673840112
TRAIN: 	 Epoch: 59 	 Loss: 5.263646242319835
TRAIN: 	 Epoch: 59 	 Loss: 5.404852887465779
TRAIN: 	 Epoch: 59 	 Loss: 4.280925982120456
TRAIN: 	 Epoch: 59 	 Loss: 4.391033271994302
ADE: 89.63649325640642  FDE: 21.496124385395966
TRAIN: 	 Epoch: 60 	 Loss: 2.7379957959939762
TRAIN: 	 Epoch: 60 	 Loss: 2.688717536359354
TRAIN: 	 Epoch: 60 	 Loss: 3.565141440656951
TRAIN: 	 Epoch: 60 	 Loss: 3.575258398805175
TRAIN: 	 Epoch: 60 	 Loss: 3.628653072012322
TRAIN: 	 Epoch: 60 	 Loss: 3.6301766569552587
TRAIN: 	 Epoch: 60 	 Loss: 3.830735787335532
TRAIN: 	 Epoch: 60 	 Loss: 3.8414883889459666
TRAIN: 	 Epoch: 60 	 Loss: 2.3448686288453575
TRAIN: 	 Epoch: 60 	 Loss: 2.4001720639118416
TRAIN: 	 Epoch: 60 	 Loss: 2.919429062607716
TRAIN: 	 Epoch: 60 	 Loss: 2.9107631257572826
TRAIN: 	 Epoch: 60 	 Loss: 4.926645396073965
TRAIN: 	 Epoch: 60 	 Loss: 4.999108529470757
TRAIN: 	 Epoch: 60 	 Loss: 4.18156290083805
TR

TRAIN: 	 Epoch: 64 	 Loss: 5.168004570617978
TRAIN: 	 Epoch: 64 	 Loss: 5.169732603998892
TRAIN: 	 Epoch: 64 	 Loss: 5.239687139180404
TRAIN: 	 Epoch: 64 	 Loss: 5.277333937970253
TRAIN: 	 Epoch: 64 	 Loss: 5.365762448218061
TRAIN: 	 Epoch: 64 	 Loss: 5.0447420623780825
TRAIN: 	 Epoch: 64 	 Loss: 5.193002164761936
TRAIN: 	 Epoch: 64 	 Loss: 3.774990462895523
TRAIN: 	 Epoch: 64 	 Loss: 3.8605695116668683
ADE: 83.77351143089653  FDE: 20.661962197878445
TRAIN: 	 Epoch: 65 	 Loss: 2.391349212374796
TRAIN: 	 Epoch: 65 	 Loss: 2.540027743827867
TRAIN: 	 Epoch: 65 	 Loss: 3.7287087349321624
TRAIN: 	 Epoch: 65 	 Loss: 3.6044844179622326
TRAIN: 	 Epoch: 65 	 Loss: 3.3406353941170126
TRAIN: 	 Epoch: 65 	 Loss: 3.5545985058829315
TRAIN: 	 Epoch: 65 	 Loss: 3.793088440293621
TRAIN: 	 Epoch: 65 	 Loss: 3.7475711708835298
TRAIN: 	 Epoch: 65 	 Loss: 2.246655191047405
TRAIN: 	 Epoch: 65 	 Loss: 2.2900583911149752
TRAIN: 	 Epoch: 65 	 Loss: 2.940845568239396
TRAIN: 	 Epoch: 65 	 Loss: 2.921293015576603

TRAIN: 	 Epoch: 69 	 Loss: 4.751628535121606
TRAIN: 	 Epoch: 69 	 Loss: 3.993761192779168
TRAIN: 	 Epoch: 69 	 Loss: 3.986757570018599
TRAIN: 	 Epoch: 69 	 Loss: 5.015666514607831
TRAIN: 	 Epoch: 69 	 Loss: 5.10358481807462
TRAIN: 	 Epoch: 69 	 Loss: 5.11832905299771
TRAIN: 	 Epoch: 69 	 Loss: 5.157713920511767
TRAIN: 	 Epoch: 69 	 Loss: 5.10825865770805
TRAIN: 	 Epoch: 69 	 Loss: 5.290865027181759
TRAIN: 	 Epoch: 69 	 Loss: 4.9039392457797675
TRAIN: 	 Epoch: 69 	 Loss: 5.1003072601243025
TRAIN: 	 Epoch: 69 	 Loss: 3.621944048736823
TRAIN: 	 Epoch: 69 	 Loss: 3.723398645672115
ADE: 81.3831795269448  FDE: 20.59158924906979
TRAIN: 	 Epoch: 70 	 Loss: 2.23429997898866
TRAIN: 	 Epoch: 70 	 Loss: 2.331245564576839
TRAIN: 	 Epoch: 70 	 Loss: 3.1363136280504746
TRAIN: 	 Epoch: 70 	 Loss: 3.2456175409153363
TRAIN: 	 Epoch: 70 	 Loss: 3.0818314927332127
TRAIN: 	 Epoch: 70 	 Loss: 3.2350419694082313
TRAIN: 	 Epoch: 70 	 Loss: 3.416996317343004
TRAIN: 	 Epoch: 70 	 Loss: 3.4086381577435882
TRAIN:

TRAIN: 	 Epoch: 74 	 Loss: 4.924378881322987
TRAIN: 	 Epoch: 74 	 Loss: 4.975713328805202
TRAIN: 	 Epoch: 74 	 Loss: 4.560748135163775
TRAIN: 	 Epoch: 74 	 Loss: 4.6163374021646835
TRAIN: 	 Epoch: 74 	 Loss: 3.785031134713503
TRAIN: 	 Epoch: 74 	 Loss: 3.790609082204317
TRAIN: 	 Epoch: 74 	 Loss: 4.835586657412545
TRAIN: 	 Epoch: 74 	 Loss: 4.925033869774498
TRAIN: 	 Epoch: 74 	 Loss: 4.905856100066987
TRAIN: 	 Epoch: 74 	 Loss: 4.956499830769064
TRAIN: 	 Epoch: 74 	 Loss: 5.005944397205821
TRAIN: 	 Epoch: 74 	 Loss: 5.144340167169082
TRAIN: 	 Epoch: 74 	 Loss: 4.720022451334191
TRAIN: 	 Epoch: 74 	 Loss: 4.993606115656205
TRAIN: 	 Epoch: 74 	 Loss: 3.4375035447360127
TRAIN: 	 Epoch: 74 	 Loss: 3.5494220634763636
ADE: 78.31220175912551  FDE: 19.178861694051765
TRAIN: 	 Epoch: 75 	 Loss: 2.198455671307765
TRAIN: 	 Epoch: 75 	 Loss: 2.3342582231875744
TRAIN: 	 Epoch: 75 	 Loss: 3.243642843791543
TRAIN: 	 Epoch: 75 	 Loss: 3.225868301623131
TRAIN: 	 Epoch: 75 	 Loss: 3.0522906949247024
TR