In [1]:
import os 
import argparse, time
import numpy as np
import torch
from datetime import datetime
from model.build_model import build_model
from model.model_misc import train_model 
from model.misc import io_utils
from model.misc.torch_utils import seed_everything, count_params
from data.data_utils import load_data

In [2]:
ar = argparse.Namespace(
        task='ais',
        noise=None,
        Nobj=1,
        num_workers=0,
        data_root='AIS/',
        shuffle=True,
        
        # setting for ODE model
        # set baseline model to latent node
        model='node',
        ode_latent_dim=6,
        # Num of hidden layers in MLP diff func
        de_L=2,
        # Num of hidden neurons in MLP diff func
        de_H=128,

        # modulator
        # Invariant function
        inv_fnc='MLP',
        # dim of dynamic modulator
        modulator_dim=0,
        # dim of static modulator
        content_dim=0,
        # Time frames to select for RNN based Encoder for Invariance
        T_inv=0,
        # Nfilt invariant encoder cnn
        cnn_filt_inv = 16,

        # ode solver setting
        order = 1,
        solver = 'euler',
        dt = 0.1,
        use_adjoint = 'no_adjoint',

        # VAE (encoder decoder part)
        T_in = 50,
        cnn_filt_enc = 16,
        cnn_filt_de = 16,
        rnn_hidden = 10,
        dec_H = 128,
        dec_L = 2,
        dec_act = 'relu',
        enc_H = 50,

        # training setting
        Nepoch = 2000,
        # Number of sequential increments of the sequence length
        Nincr = 3,
        batch_size = 16,
        lr = 0.002,
        seed = 13,
        continue_training = False,
        plot_every = 250,
        plotL = 1,
        # num of forecast steps for plotting
        forecast_tr = 2,
        forecast_vl = 2,
        exp_id = 100,

        # save
        save = 'results/',
    )

In [3]:
def perform(args):
    ######### setup output directory and logger ###########
    args.save = os.path.join(os.getcwd(), \
        args.save+args.task+'/'+args.model+'/'+str(args.exp_id), '')

    ############################
    io_utils.makedirs(args.save)
    io_utils.makedirs(os.path.join(args.save, 'plots'))
    io_utils.makedirs(os.path.join(args.save, 'plots', 'fit'))
    io_utils.makedirs(os.path.join(args.save, 'plots', 'latents'))
    logger = io_utils.get_logger(logpath=os.path.join(args.save, 'logs.txt'))
    logger.info('Results stored in {}'.format(args.save))

    ########## set global random seed ###########
    if args.seed==-1:
        args.seed = int(time.time()*np.random.random()/1000)
    seed_everything(args.seed)

    ########## dtype #########
    dtype = torch.float64
    logger.info('********** Float type is {} ********** '.format(dtype))

    ########## plotter #######
    from model.misc.plot_utils import Plotter
    save_path = os.path.join(args.save, 'plots')
    plotter   = Plotter(save_path, args.task)

    ########### device #######
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info('********** Running model on {} ********** '.format(device))

    ########### data ############
    trainset, validset, testset, params = load_data(args, device, dtype)
    logger.info('********** {} dataset with loaded ********** '.format(args.task))
    logger.info('data params: {}'.format(params[args.task]))

    ########### model ###########
    model = build_model(args, device, dtype)
    model.to(device)
    model.to(dtype)

    logger.info('********** Built {} model with dynamics modulator dim {} and  content variable dim {}**********'.format(args.model, args.modulator_dim, args.content_dim))
    logger.info('********** Number of parameters: {} **********'.format(count_params(model)))
    logger.info('********** Augmented Dynamics: {} **********'.format(model.aug))
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)
    logger.info(model)

    if args.continue_training:
        fname = os.path.join(os.path.abspath(os.path.dirname(__file__)), args.save, 'model.pth')
        model.load_state_dict(torch.load(fname,map_location=torch.device(device)))
        logger.info('********** Resume training for model {} ********** '.format(fname))

    train_model(args, model, plotter, trainset, validset, testset, logger, params[args.task])

In [4]:
def buildArgs(t_in, t_inv, ode_latent_dim, dynamic_dim, static_dim, exp_id):
    args = argparse.Namespace(
        task='ais',
        noise=None,
        Nobj=1,
        num_workers=0,
        data_root='AIS/',
        shuffle=True,
        
        # setting for ODE model
        # set baseline model to latent node
        model='node',
        ode_latent_dim=ode_latent_dim,
        # Num of hidden layers in MLP diff func
        de_L=2,
        # Num of hidden neurons in MLP diff func
        de_H=128,

        # modulator
        # Invariant function
        inv_fnc='MLP',
        # dim of dynamic modulator
        modulator_dim=dynamic_dim,
        # dim of static modulator
        content_dim=static_dim,
        # Time frames to select for RNN based Encoder for Invariance
        T_inv=t_inv,
        # Nfilt invariant encoder cnn
        cnn_filt_inv = 16,

        # ode solver setting
        order = 1,
        solver = 'euler',
        dt = 0.1,
        use_adjoint = 'no_adjoint',

        # VAE (encoder decoder part)
        T_in = t_in,
        cnn_filt_enc = 16,
        cnn_filt_de = 16,
        rnn_hidden = 10,
        dec_H = 128,
        dec_L = 2,
        dec_act = 'relu',
        enc_H = 50,

        # training setting
        Nepoch = 2000,
        # Number of sequential increments of the sequence length
        Nincr = 3,
        batch_size = 16,
        lr = 0.002,
        seed = 13,
        continue_training = False,
        plot_every = 250,
        plotL = 1,
        # num of forecast steps for plotting
        forecast_tr = 2,
        forecast_vl = 2,
        exp_id = exp_id,

        # save
        save = 'results/',
    )

    perform(args)

In [5]:
def match(setting):
    buildArgs(setting[0], setting[1], setting[2], setting[3], setting[4], setting[5])

In [None]:
# experiment model explanation: exp_id (can be found from results/ais/node)
settings = [
    # NODE with dim 6: 0
    [50, 0, 6, 0, 0, 0],
    # NODE with dim 12: 1
    [50, 0, 12, 0, 0, 1],
    # NODE with dim 24: 2
    [50, 0, 24, 0, 0, 2],
    # MoNODE with static modulator and dim 3: 10
    [50, 50, 3, 0, 3, 10],
    # MoNODE with static modulator and dim 6: 11
    [50, 50, 6, 0, 6, 11],
    # MoNODE with static modulator and dim 12: 12
    [50, 50, 12, 0, 12, 12],
    # MoNODE with dynamic modulator and dim 3: 20
    [50, 50, 3, 3, 0, 20],
    # MoNODE with dynamic modulator and dim 6: 21
    [50, 50, 3, 3, 0, 21],
    # MoNODE with dynamic modulator and dim 12: 22
    [50, 50, 3, 3, 0, 22],
    # MoNODE with both modulator and dim 2: 30
    [50, 50, 2, 2, 2, 30],
    # MoNODE with both modulator and dim 4: 30
    [50, 50, 4, 4, 4, 31],
    # MoNODE with both modulator and dim 8: 30
    [50, 50, 8, 8, 8, 32],
]

In [7]:
from multiprocessing import *

pool = Pool(processes=12)
pool.map(func=match, iterable=settings)

Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/2/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/10/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/0/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/21/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/22/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/20/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/1/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/11/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/12/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/30/
********** Float type is torch.float64 ********** 
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/32/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/31/
********** Float type is torch.float64 ********** 
********** Float type is 

ais

********** Running model on cuda ********** 
********** Running model on cuda ********** 


aisais

TrueTrue



  Xtr = torch.load(data_path_tr)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xvl = torch.load(data_path_vl)
********** Running model on cuda ********** 
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']ais
ais

********** Running model on cuda ********** 
********** Running model on cuda ********** 
********** Running model on cuda ********** 
********** Running model on cuda ********** 


ais
aisaisaisais


True
aisTrueTrue
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

********** Running model on cuda ********** 


True




  Xtr = torch.load(data_path_tr)




True

ais
TrueTrue


  Xtr = torch.load(data_path_tr)


True

  Xvl = torch.load(data_path_vl)





  Xvl = torch.load(data_path_vl)





  Xtr = torch.load(data_path_tr)


True



  Xtr = torch.load(data_path_tr)
  Xtr = torch.load(data_path_tr)
  Xte = torch.load(data_path_te)


True

  Xte = torch.load(data_path_te)
  Xvl = torch.load(data_path_vl)





  Xtr = torch.load(data_path_tr)



['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xvl = torch.load(data_path_vl)
  Xtr = torch.load(data_path_tr)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']


  Xvl = torch.load(data_path_vl)
  Xtr = torch.load(data_path_tr)
  Xtr = torch.load(data_path_tr)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)





  Xvl = torch.load(data_path_vl)
  Xvl = torch.load(data_path_vl)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']


  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']





['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xtr = torch.load(data_path_tr)






  Xvl = torch.load(data_path_vl)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xtr = torch.load(data_path_tr)



['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xvl = torch.load(data_path_vl)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xte = torch.load(data_path_te)



['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']


  Xtr = torch.load(data_path_tr)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xvl = torch.load(data_path_vl)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)





  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xtr = torch.load(data_path_tr)



['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']




  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xtr = torch.load(data_path_tr)
  Xtr = torch.load(data_path_tr)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xvl = torch.load(data_path_vl)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xvl = torch.load(data_path_vl)
  Xvl = torch.load(data_path_vl)






['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xtr = torch.load(data_path_tr)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)


['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']

  Xte = torch.load(data_path_te)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)






  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)
  Xte = torch.load(data_path_te)


Train data:  torch.Size([10, 100, 2])
Train data:  torch.Size([10, 100, 2])Val   data: 
 torch.Size([5, 100, 2])Val   data: 
 torch.Size([5, 100, 2])
Test  data: Test  data:   torch.Size([8, 100, 2])torch.Size([8, 100, 2])



********** ais dataset with loaded ********** 


Train data: 

data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}
********** ais dataset with loaded ********** 


Train data:  

data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}


 torch.Size([10, 100, 2])torch.Size([10, 100, 2])
Train data: 
Val   data:  torch.Size([5, 100, 2])Val   data: 
 torch.Size([5, 100, 2])Test  data:  
Test  data: torch.Size([8, 100, 2])
 

********** ais dataset with loaded ********** 


torch.Size([8, 100, 2])

data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}


 Train data: 
torch.Size([10, 100, 2])

********** ais dataset with loaded ********** 


 


data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}


Train data: Train data: torch.Size([10, 100, 2])Train data: Val   data:   Train data: 
   torch.Size([10, 100, 2])Val   data: torch.Size([10, 100, 2])torch.Size([10, 100, 2])torch.Size([10, 100, 2])torch.Size([5, 100, 2])
 


Val   data: Val   data: 
torch.Size([5, 100, 2])Test  data: Val   data:   Val   data: 
 torch.Size([5, 100, 2])torch.Size([5, 100, 2])Test  data:    torch.Size([8, 100, 2])
torch.Size([5, 100, 2])
torch.Size([8, 100, 2])
Test  data: Test  data: torch.Size([5, 100, 2])
 Test  data: 
 

********** ais dataset with loaded ********** 


torch.Size([8, 100, 2])
 

********** ais dataset with loaded ********** 


torch.Size([8, 100, 2])


********** Built node model with dynamics modulator dim 3 and  content variable dim 0**********


torch.Size([8, 100, 2])

data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}






data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}


Test  data: 

********** Number of parameters: 37478 **********
********** ais dataset with loaded ********** 


 

********** ais dataset with loaded ********** 
********** Augmented Dynamics: True **********
data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}


torch.Size([8, 100, 2])

Argument Nepoch: 2000
data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}





Argument Nincr: 3
Argument Nobj: 1
********** ais dataset with loaded ********** 
********** ais dataset with loaded ********** 
********** Built node model with dynamics modulator dim 0 and  content variable dim 12**********
data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}
Argument T_in: 50
data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}
********** Built node model with dynamics modulator dim 8 and  content variable dim 8**********
Argument T_inv: 50
********** Number of parameters: 43472 **********
********** Built node model with dynamics modulator dim 0 and  content variable dim 0**********
Argument batch_size: 16
********** Augmented Dynamics: False **********
********** Number of parameters: 42240 **********
Argument cnn_filt_de: 16
Argument Nepoch: 2000
********** Augmented Dynamics: True **********
Argument cnn_filt_enc: 16
Argument Nincr: 3
********** Numb

EOFError: Ran out of input

manually performed remain ones, setting[-3], setting[-5]

In [17]:
key = True
key2 = True
while key or key2:
    try:
        if key:
            match(settings[-3])
        key = False
    except:
        pass
    try:
        if key2:
            match(settings[-5])
        key2 = False
    except:
        pass


Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/30/
Results stored in /media/usr/SSD/yongmin/NDE/MONODE/results/ais/node/30/
********** Float type is torch.float64 ********** 
********** Float type is torch.float64 ********** 
********** Running model on cuda ********** 
********** Running model on cuda ********** 
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)
  Xtr = torch.load(data_path_tr)
  Xvl = torch.load(data_path_vl)
  Xte = torch.load(data_path_te)
********** ais dataset with loaded ********** 
********** ais dataset with loaded ********** 
data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}
data params: {'train': {'N': 33, 'T': 50}, 'valid': {'N': 5, 'T': 50}, 'test': {'N': 8, 'T': 50}, 'dt': 0.05}
********** Built node model with dynamics modulator dim 2 and  content variable dim 2**********
********** Built node model with dynamics modulator di

ais
True
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
Train data:  torch.Size([10, 100, 2])
Val   data:  torch.Size([5, 100, 2])
Test  data:  torch.Size([8, 100, 2])


Epoch:   0/2000 | tr_loss:8971320.03(8971320.03) | valid_mse T=50 :10206.092 | valid_mse T=100 :9376.916 
Epoch:   0/2000 | tr_loss:8971320.03(8971320.03) | valid_mse T=50 :10206.092 | valid_mse T=100 :9376.916 
********** Current Best Model based on validation error ***********
********** Current Best Model based on validation error ***********
Epoch:   0/2000
Epoch:   0/2000
T=50 test_mse 8242.764(0.000)
T=50 test_mse 8242.764(0.000)
T=100 test_mse 8538.203(0.000)
T=100 test_mse 8538.203(0.000)
Epoch:   1/2000 | tr_loss:9186860.35(8977786.24) | valid_mse T=50 :10158.069 | valid_mse T=100 :9265.820 
Epoch:   1/2000 | tr_loss:9186860.35(8977786.24) | valid_mse T=50 :10158.069 | valid_mse T=100 :9265.820 
********** Current Best Model based on validation error ***********
********** Current Best Model based on validation error ***********
Epoch:   1/2000
Epoch:   1/2000
T=50 test_mse 8094.010(0.000)
T=50 test_mse 8094.010(0.000)
T=100 test_mse 8173.765(0.000)
T=100 test_mse 8173.765(0.0

ais
True
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
['ais', 'ais-vl-data.pkl', 'ais-tr-data.pkl', 'ais-te-data.pkl', 'data.pickle']
Train data:  torch.Size([10, 100, 2])
Val   data:  torch.Size([5, 100, 2])
Test  data:  torch.Size([8, 100, 2])


Epoch:   0/2000 | tr_loss:9271869.14(9271869.14) | valid_mse T=50 :10220.052 | valid_mse T=100 :9411.558 
Epoch:   0/2000 | tr_loss:9271869.14(9271869.14) | valid_mse T=50 :10220.052 | valid_mse T=100 :9411.558 
Epoch:   0/2000 | tr_loss:9271869.14(9271869.14) | valid_mse T=50 :10220.052 | valid_mse T=100 :9411.558 
********** Current Best Model based on validation error ***********
********** Current Best Model based on validation error ***********
********** Current Best Model based on validation error ***********
Epoch:   0/2000
Epoch:   0/2000
Epoch:   0/2000
T=50 test_mse 8225.771(0.000)
T=50 test_mse 8225.771(0.000)
T=50 test_mse 8225.771(0.000)
T=100 test_mse 8554.326(0.000)
T=100 test_mse 8554.326(0.000)
T=100 test_mse 8554.326(0.000)
Epoch:   1/2000 | tr_loss:8772059.75(9256874.86) | valid_mse T=50 :10198.253 | valid_mse T=100 :9345.072 
Epoch:   1/2000 | tr_loss:8772059.75(9256874.86) | valid_mse T=50 :10198.253 | valid_mse T=100 :9345.072 
Epoch:   1/2000 | tr_loss:8772059.7

Check the results by access the log files in the result folder.