In [4]:
import os
import time
import pandas as pd
# Show 4 decimal places in pandas output. The fully qualified key is required:
# the bare 'precision' pattern is ambiguous in newer pandas (it also matches
# 'styler.format.precision') and raises OptionError.
pd.set_option('display.precision', 4)

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, utils

from tools.args_tools import args
from tools.trajGRU import model
from tools.datasetGRU import TyDataset, ToTensor, Normalize

In [23]:
# Dump the full run configuration (the text form of `args`) for the record.
print(str(args))

{'disable_cuda': False, 'gpu': 0, 'max_epochs': 50, 'lr': 0.0001, 'lr_scheduler': False, 'weight_decay': 0.1, 'clip': False, 'clip_max_norm': 100, 'batch_norm': False, 'normalize_target': False, 'input_frames': 6, 'output_frames': 18, 'input_with_grid': True, 'channel_factor': 2, 'I_lat_l': 23.9125, 'I_lat_h': 26.15, 'I_lon_l': 120.4, 'I_lon_h': 122.6375, 'F_lat_l': 24.6625, 'F_lat_h': 25.4, 'F_lon_l': 121.15, 'F_lon_h': 121.8875, 'res_degree': 0.0125, 'origin_lat_l': 20, 'origin_lat_h': 27, 'origin_lon_l': 118, 'origin_lon_h': 123.5, 'working_folder': '/home/jack/Onedrive/01_IIS/04_TY_research', 'root_dir': '/home/jack/Onedrive/01_IIS/04_TY_research/01_Radar_data/02_numpy_files', 'ty_list_file': '/home/jack/Onedrive/01_IIS/04_TY_research/ty_list.xlsx', 'result_dir': '/home/jack/Onedrive/01_IIS/04_TY_research/04_results', 'params_dir': '/home/jack/Onedrive/01_IIS/04_TY_research/05_params', 'I_x_left': 193, 'I_x_right': 372, 'I_y_low': 314, 'I_y_high': 493, 'F_x_left': 253, 'F_x_right':

In [40]:
args.input_with_grid = True
args.lr_scheduler = True
args.clip = True
args.batch_norm = True
args.value_dtype = torch.float

In [41]:
## Create dataset
# Normalization statistics. One (mean, std) entry is used per input frame; when
# args.normalize_target is set, one entry per output frame is appended so the
# targets are normalized too.
# NOTE(review): these look like precomputed dataset statistics for the 'RAD'
# inputs and 'QPE' targets -- TODO document where they were computed.
RAD_MEAN, RAD_STD = 7.044, 12.180
QPE_MEAN, QPE_STD = 1.122, 3.858

mean = [RAD_MEAN] * args.input_frames
std = [RAD_STD] * args.input_frames
if args.normalize_target:
    mean += [QPE_MEAN] * args.output_frames
    std += [QPE_STD] * args.output_frames

# The name `transfrom` (sic) is kept because later cells may reference it.
transfrom = transforms.Compose([ToTensor(), Normalize(mean=mean, std=std)])


def build_ty_dataset(train):
    """Build one TyDataset split; all settings except `train` are shared.

    Args:
        train: True for the training split, False for the test split.

    Returns:
        A TyDataset configured from `args` and the `transfrom` pipeline.
    """
    return TyDataset(ty_list_file=args.ty_list_file,
                     root_dir=args.root_dir,
                     input_frames=args.input_frames,
                     output_frames=args.output_frames,
                     train=train,
                     with_grid=args.input_with_grid,
                     transform=transfrom)


traindataset = build_ty_dataset(train=True)
testdataset = build_ty_dataset(train=False)

In [42]:
## return the shape of data
print(traindataset[0]['RAD'].shape)
print(testdataset[0]['QPE'].shape)

torch.Size([6, 3, 180, 180])
torch.Size([18, 60, 60])


In [43]:
# Channels per input frame, read from the first training sample
# (e.g. 3 in the shape printed above, where input_with_grid is True).
inputs_channels = traindataset[0]['RAD'].shape[1]
# set the factor of cnn channels: every channel width below is scaled by c
c = args.channel_factor

## construct Traj GRU
# initialize the parameters of the encoders and forecasters

# Link count per trajGRU stage (in the printed model, 13 links show up as
# 26 flow channels in the first stage's subnetwork).
rnn_link_size = [13, 13, 9]

# Encoder: per-stage downsampling CNN widths and trajGRU hidden widths.
encoder_input_channel = inputs_channels
encoder_downsample_channels = [2*c, 32*c, 96*c]
encoder_rnn_channels = [32*c, 96*c, 96*c]

# Decoder (forecaster): receives no external input, only hidden states.
decoder_input_channel = 0
decoder_upsample_channels = [96*c, 96*c, 4*c]
decoder_rnn_channels = [96*c, 96*c, 32*c]

# The first encoder downsampling layer is chosen from the relation between the
# input grid (I_shape) and forecast grid (F_shape) sizes:
#   * F == I/3: kernel 5, stride 3, padding 1 shrinks by 3 (e.g. 180 -> 60);
#   * F == I  : kernel 3, stride 1, padding 1 preserves the size.
# Any other relation is unsupported. Fail loudly rather than leave the
# encoder_downsample_* names undefined (or stale from an earlier run).
if int(args.I_shape[0]/3) == args.F_shape[0]:
    encoder_downsample_k = [5, 4, 4]
    encoder_downsample_s = [3, 2, 2]
    encoder_downsample_p = [1, 1, 1]
elif args.I_shape[0] == args.F_shape[0]:
    encoder_downsample_k = [3, 4, 4]
    encoder_downsample_s = [1, 2, 2]
    encoder_downsample_p = [1, 1, 1]
else:
    raise ValueError(
        "Unsupported grid sizes: I_shape[0]={}, F_shape[0]={}; expected "
        "F_shape[0] == I_shape[0] // 3 or F_shape[0] == I_shape[0]".format(
            args.I_shape[0], args.F_shape[0]))

encoder_rnn_k = [3, 3, 3]
encoder_rnn_s = [1, 1, 1]
encoder_rnn_p = [1, 1, 1]
# NOTE(review): presumably 3 downsample + 3 trajGRU layers -- TODO confirm.
encoder_n_layers = 6

decoder_upsample_k = [4, 4, 3]
decoder_upsample_s = [2, 2, 1]
decoder_upsample_p = [1, 1, 1]

decoder_rnn_k = [3, 3, 3]
decoder_rnn_s = [1, 1, 1]
decoder_rnn_p = [1, 1, 1]
decoder_n_layers = 6

# Output head: a single conv layer producing 1 channel per output frame.
decoder_output = 1
decoder_output_k = 3
decoder_output_s = 1
decoder_output_p = 1
decoder_output_layers = 1

In [44]:
## Build the TrajGRU encoder-forecaster from the hyper-parameters above,
## then cast it to the configured device and dtype.
# NOTE(review): args.device is not assigned in this notebook section;
# presumably tools.args_tools sets it -- TODO confirm.
Net = model(
    # sequence lengths and trajGRU links
    n_encoders=args.input_frames,
    n_decoders=args.output_frames,
    rnn_link_size=rnn_link_size,
    # encoder
    encoder_input_channel=encoder_input_channel,
    encoder_downsample_channels=encoder_downsample_channels,
    encoder_rnn_channels=encoder_rnn_channels,
    encoder_downsample_k=encoder_downsample_k,
    encoder_downsample_s=encoder_downsample_s,
    encoder_downsample_p=encoder_downsample_p,
    encoder_rnn_k=encoder_rnn_k,
    encoder_rnn_s=encoder_rnn_s,
    encoder_rnn_p=encoder_rnn_p,
    encoder_n_layers=encoder_n_layers,
    # decoder / forecaster
    decoder_input_channel=decoder_input_channel,
    decoder_upsample_channels=decoder_upsample_channels,
    decoder_rnn_channels=decoder_rnn_channels,
    decoder_upsample_k=decoder_upsample_k,
    decoder_upsample_s=decoder_upsample_s,
    decoder_upsample_p=decoder_upsample_p,
    decoder_rnn_k=decoder_rnn_k,
    decoder_rnn_s=decoder_rnn_s,
    decoder_rnn_p=decoder_rnn_p,
    decoder_n_layers=decoder_n_layers,
    # output head
    decoder_output=decoder_output,
    decoder_output_k=decoder_output_k,
    decoder_output_s=decoder_output_s,
    decoder_output_p=decoder_output_p,
    decoder_output_layers=decoder_output_layers,
    batch_norm=args.batch_norm,
).to(args.device, dtype=args.value_dtype)
print(Net)

model(
  (Encoder_01): Encoder(
    (Downsample_00): CNN2D_cell(
      (layer): Sequential(
        (0): Conv2d(3, 4, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
        (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (trajGRUCell_00): trajGRUCell(
      (subnetwork): subCNN(
        (layer): Sequential(
          (0): Conv2d(68, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (1): ReLU()
          (2): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (3): ReLU()
        )
      )
      (reset_gate_input): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (update_gate_input): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (out_gate_input): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (reset_gate_warp): warp_CNN(
        (warpnet): Conv2d(832, 64, kernel_size=(1, 1), stride=(1, 1))
      )
    