In [1]:
import torch
from torch.optim import lr_scheduler

In [2]:
import os, sys
sys.path.insert(0, os.path.realpath(os.path.pardir))

from utils.config import cfg
from utils.blocks.forecaster import Forecaster
from utils.blocks.encoder import Encoder
from collections import OrderedDict
from utils.blocks.module import EF
from torch.optim import lr_scheduler
from utils.loss import Weighted_mse_mae
from utils.blocks.trajGRU import TrajGRU
from utils.train_and_test import train_and_test
from net_params import encoder_params, forecaster_params

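## Train / validation / test split

Partition the full dataset into training, validation and test sets in an 80 / 5 / 15 ratio.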

In [3]:
# Partition everything under cfg.ONM_PD.FOLDER_ALL into train / validation / test
# sets in an 80% / 5% / 15% ratio.
# Explicit import (instead of `import *`) so readers can see where the name comes from
# and the module cannot silently shadow other notebook globals.
# NOTE(review): confirm whether re-running this cell re-shuffles / rewrites an existing split.
from utils.tools.train_test_split import train_test_split

train_test_split(cfg.ONM_PD.FOLDER_ALL, ratio=(0.8, 0.05, 0.15))

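## Experiment: TrajGRU with B-MSE + B-MAE loss

Starts from the weights in `encoder_forecaster_45000.pth` and trains for 5,000 more iterations with the weighted MSE + MAE loss (`Weighted_mse_mae`). Only the model weights are restored; the optimizer starts from scratch.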

In [4]:
# Experiment configuration: build the TrajGRU encoder-forecaster, load pretrained
# weights, and set up the loss, optimizer and LR schedule used by train_and_test below.
batch_size = cfg.GLOBAL.BATCH_SIZE
max_iterations = 5000
test_iteration_interval = 2500
test_and_save_checkpoint_iterations = 2500

LR = 1e-4
# Multiply the learning rate by LR_GAMMA at each milestone.
# NOTE(review): presumably train_and_test steps the scheduler once per iteration, so these
# are iteration counts (both inside max_iterations) — confirm.
LR_MILESTONES = [2000, 4000]
LR_GAMMA = 0.1

# Pretrained weights to start from.
CHECKPOINT_PATH = os.path.join(cfg.GLOBAL.MODEL_SAVE_DIR, 'encoder_forecaster_45000.pth')

folder_name = "trajGRU_BMSE_BMAE"

criterion = Weighted_mse_mae().to(cfg.GLOBAL.DEVICE)

encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1]).to(cfg.GLOBAL.DEVICE)
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

# map_location remaps tensors saved on another device (e.g. a GPU) onto the configured
# device, so the checkpoint also loads on a CPU-only machine. The loaded dict is not kept
# in a global, so its tensors can be freed once copied into the model.
encoder_forecaster.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=cfg.GLOBAL.DEVICE))

# Only the model weights come from the checkpoint; the optimizer state starts fresh.
optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
mult_step_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=LR_MILESTONES, gamma=LR_GAMMA)

### Inspect the model architecture

In [6]:
# Display the module tree to check layer types and channel sizes of the loaded model.
encoder_forecaster

EF(
  (encoder): Encoder(
    (rnn1): TrajGRU(
      (i2h): Conv2d(8, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (i2f_conv1): Conv2d(8, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (h2f_conv1): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (flows_conv): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (ret): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
    )
    (stage1): Sequential(
      (conv1_leaky_1): Conv2d(1, 8, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
      (leaky_conv1_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (rnn2): TrajGRU(
      (i2h): Conv2d(192, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (i2f_conv1): Conv2d(192, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (h2f_conv1): Conv2d(192, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (flows_conv): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 

### Verifying input/output shape

In [5]:
# Smoke test: push one random batch through the network and check the output shape.
# NOTE(review): the layout looks like (time, batch, channel, height, width), judging from
# the saved output where dim 0 goes 5 -> 20 — confirm against EF.forward.
# Create the input on the model's device; a CPU tensor fails against a GPU model.
data = torch.randn(5, 4, 1, 480, 480, device=cfg.GLOBAL.DEVICE)
# A shape check needs no autograd graph; no_grad avoids holding activations in memory.
with torch.no_grad():
    output = encoder_forecaster(data)
# Every dimension except the first must match the input.
assert output.shape[1:] == data.shape[1:], f"unexpected output shape {tuple(output.shape)}"
output.shape

torch.Size([20, 4, 1, 480, 480])


### Train and Test

In [7]:
# Run the project's train/test loop with the model, optimizer, loss, scheduler and
# iteration settings configured above.
# NOTE(review): the saved progress bar reads 0/100000 although max_iterations is 5000.
# The output looks stale (or train_and_test does not use max_iterations for the bar);
# re-run to confirm.
train_and_test(
    encoder_forecaster,
    optimizer,
    criterion,
    mult_step_scheduler,
    batch_size,
    max_iterations,
    test_iteration_interval,
    test_and_save_checkpoint_iterations,
    folder_name,
)

  0%|          | 0/100000 [00:00<?, ?it/s]