In [None]:
import torch
from torch import nn
from torch.optim import lr_scheduler
import pandas as pd
import numpy as np

import os, sys
sys.path.insert(0, os.path.realpath(os.path.pardir))

from utils.config import cfg
from utils.blocks.forecaster import Forecaster
from utils.blocks.encoder import Encoder
from utils.tools.ordered_easydict import OrderedDict
from utils.blocks.module import EF
from utils.loss import WeightedCrossEntropyLoss
from utils.blocks.trajGRU import TrajGRU
from utils.blocks.module import Predictor
from utils.train_and_test import train_and_test
from net_params import encoder_params, forecaster_params
from utils.utils import *
from utils.blocks.probToPixel import ProbToPixel

In [2]:
batch_size = 2
max_iterations = 5000
test_iteration_interval = 100
test_and_save_checkpoint_iterations = 100

LR = 1e-5

encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)

forecaster = Forecaster(forecaster_params[0], forecaster_params[1]).to(cfg.GLOBAL.DEVICE)


encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

thresholds = cfg.EVALUATION.THRESHOLDS
# encoder_forecaster.forecaster.stage1.conv3_3 = nn.Conv2d(8, len(thresholds)+1, kernel_size=(1, 1), stride=(1, 1)).to(cfg.GLOBAL.DEVICE)


thresholds = dBZ_to_rainfall(thresholds)
weights = np.ones_like(thresholds)
balancing_weights = cfg.EVALUATION.BALANCING_WEIGHTS
for i, threshold in enumerate(cfg.EVALUATION.THRESHOLDS):
    weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (thresholds >= threshold)
weights = weights + 1
weights = np.array([1] + weights.tolist())
weights = torch.from_numpy(weights).to(cfg.GLOBAL.DEVICE).float()
criterion = WeightedCrossEntropyLoss(thresholds, weights).to(cfg.GLOBAL.DEVICE)

ts = rainfall_to_dBZ(thresholds).tolist()
middle_value_dbz = np.array([-10.0] + [(x+y)/2 for x, y in zip(ts, ts[1:]+[60.0])])
middle_value = dBZ_to_pixel(middle_value_dbz).astype(np.float32)
probToPixel = ProbToPixel(middle_value, requires_grad=False)

optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.7)

folder_name = 'trajGRU_CE'

In [3]:
torch.cuda.is_available()

False

In [4]:
data = torch.randn(5, 4, 1, 480, 480)
output = encoder_forecaster(data.cuda())
print(output.size())

AssertionError: Torch not compiled with CUDA enabled

In [4]:
encoder_forecaster

EF(
  (encoder): Encoder(
    (rnn1): TrajGRU(
      (i2h): Conv2d(8, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (i2f_conv1): Conv2d(8, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (h2f_conv1): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (flows_conv): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (ret): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
    )
    (stage1): Sequential(
      (conv1_leaky_1): Conv2d(1, 8, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
      (leaky_conv1_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (rnn2): TrajGRU(
      (i2h): Conv2d(192, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (i2f_conv1): Conv2d(192, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (h2f_conv1): Conv2d(192, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (flows_conv): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 

In [None]:
train_and_test(encoder_forecaster, optimizer, criterion, exp_lr_scheduler, batch_size, max_iterations, test_iteration_interval, test_and_save_checkpoint_iterations, folder_name, probToPixel)

In [None]:
# from utils.tools.evaluation import *
# from utils.tools.dataloader import BKKIterator
# import copy
# import time
# import pickle

# IN_LEN = cfg.BENCHMARK.IN_LEN
# OUT_LEN = cfg.BENCHMARK.OUT_LEN
# name = 'conv2d'

In [None]:
# with torch.no_grad():
#     is_deeplearning_model = (torch.nn.Module in model.__class__.__bases__)
#     if is_deeplearning_model:
#         model.eval()
#     evaluator = Evaluation(seq_len=OUT_LEN, use_central=False)
#     bkk_iter = BKKIterator(pd_path=cfg.ONM_PD.RAINY_TEST,
#                                     sample_mode="sequent",
#                                     seq_len=IN_LEN + OUT_LEN,
#                                     stride=cfg.BENCHMARK.STRIDE)
#     model_run_avarage_time = dict()
#     model_run_avarage_time[name] = 0.0
#     valid_time = 0
#     while not bkk_iter.use_up:
#         valid_batch, valid_mask, sample_datetimes, _ = bkk_iter.sample(batch_size=1)
#         if valid_batch.shape[1] == 0:
#             break
#         if not cfg.EVALUATION.VALID_DATA_USE_UP and valid_time > cfg.EVALUATION.VALID_TIME:
#             break

#         valid_batch = valid_batch.astype(np.float32) / 255.0
#         valid_data = valid_batch[:IN_LEN, ...]
#         valid_label = valid_batch[IN_LEN:IN_LEN + OUT_LEN, ...]
#         mask = valid_mask[IN_LEN:IN_LEN + OUT_LEN, ...].astype(int)

#         if is_deeplearning_model:
#             valid_data = torch.from_numpy(valid_data).to(cfg.GLOBAL.DEVICE)

#         start = time.time()
#         output = model(valid_data)
#         model_run_avarage_time[name] += time.time() - start

#         if is_deeplearning_model:
#             output = output.cpu().numpy()

#         output = np.clip(output, 0.0, 1.0)

#         evaluator.update(valid_label, output, mask)

#         valid_time += 1
#     model_run_avarage_time[name] /= valid_time
#     evaluator.save_pkl(os.path.join(cfg.BENCHMARK.STAT_PATH, name + '.pkl'))

In [None]:
# with open(os.path.join(cfg.BENCHMARK.STAT_PATH, 'model_run_avarage_time.pkl'), 'wb') as f:
#     pickle.dump(model_run_avarage_time, f)

In [None]:
# for p in os.listdir(os.path.abspath(cfg.BENCHMARK.STAT_PATH))[:1]:
#     e = pickle.load(open(os.path.join(cfg.BENCHMARK.STAT_PATH, p), 'rb'))
#     _, _, csi, hss, _, mse, mae, balanced_mse, balanced_mae, _ = e.calculate_stat()
#     print(p.split('.')[0])
#     for i, thresh in enumerate(cfg.EVALUATION.THRESHOLDS):
#         print('thresh %.1f csi: average %.4f, last frame %.4f; hss: average %.4f, last frame %.4f;'
#               % (thresh, csi[:, i].mean(), csi[-1, i], hss[:, i].mean(), hss[-1, i]))

#     print(('mse: average %.2f, last frame %.2f\n' +
#         'mae: average %.2f, last frame %.2f\n'+
#         'bmse: average %.2f, last frame %.2f\n' +
#         'bmae: average %.2f, last frame %.2f\n') % (mse.mean(), mse[-1], mae.mean(), mae[-1],
#               balanced_mse.mean(), balanced_mse[-1], balanced_mae.mean(), balanced_mae[-1]))