In [2]:
import numpy as np
import pandas as pd
import torch
from torch.optim import lr_scheduler

In [3]:
import os, sys
sys.path.insert(0, os.path.realpath(os.path.pardir))

from utils.config import cfg
from utils.blocks.forecaster import Forecaster
from utils.blocks.encoder import Encoder
from utils.tools.ordered_easydict import OrderedDict
from utils.blocks.module import EF
from utils.loss import Weighted_mse_mae
from utils.train_and_test import train_and_test
from net_params import conv2d_params
from utils.blocks.module import Predictor

In [4]:
# --- Training schedule (iteration counts handed to train_and_test) ---------
batch_size = cfg.GLOBAL.BATCH_SIZE
max_iterations = 80000
test_iteration_interval = 10000
test_and_save_checkpoint_iterations = 10000

# --- Loss, model and optimisation ----------------------------------------
LR = 1e-4
device = cfg.GLOBAL.DEVICE

# Loss and model are both moved to the configured device.
criterion = Weighted_mse_mae().to(device)
model = Predictor(conv2d_params).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-6)
# StepLR multiplies the learning rate by `gamma` every `step_size` calls to
# exp_lr_scheduler.step().
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=20000, gamma=0.7)

# Output folder name passed to train_and_test.
folder_name = "conv2d"

In [5]:
# Smoke test: push one random batch through the (untrained) model and check
# the output shape before launching the long training run.
# NOTE(review): the layout (5, 4, 1, 480, 480) is presumably
# (IN_LEN, batch, channel, H, W); the first Conv2d takes 5 input channels,
# matching the leading 5 — confirm against cfg.BENCHMARK.IN_LEN.
with torch.no_grad():  # inference only: no autograd graph for 480x480 frames
    # Create the input directly on the model's device so the check also works
    # when cfg.GLOBAL.DEVICE is a GPU.
    dummy_input = torch.randn(5, 4, 1, 480, 480, device=cfg.GLOBAL.DEVICE)
    dummy_output = model(dummy_input)

# Every axis except the leading one is preserved (recorded output:
# torch.Size([20, 4, 1, 480, 480])).
assert dummy_output.shape[1:] == dummy_input.shape[1:], dummy_output.shape
print(dummy_output.size())

torch.Size([20, 4, 1, 480, 480])


In [6]:
# Render the layer-by-layer architecture of the Predictor.
display(model)

Predictor(
  (model): Sequential(
    (conv1_relu_1): Conv2d(5, 64, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
    (relu_conv1_relu_1): ReLU(inplace=True)
    (conv2_relu_1): Conv2d(64, 192, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (relu_conv2_relu_1): ReLU(inplace=True)
    (conv3_relu_1): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (relu_conv3_relu_1): ReLU(inplace=True)
    (deconv1_relu_1): ConvTranspose2d(192, 192, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (relu_deconv1_relu_1): ReLU(inplace=True)
    (deconv2_relu_1): ConvTranspose2d(192, 64, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (relu_deconv2_relu_1): ReLU(inplace=True)
    (deconv3_relu_1): ConvTranspose2d(64, 64, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
    (relu_deconv3_relu_1): ReLU(inplace=True)
    (conv3_relu_2): Conv2d(64, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_conv3_relu_2): ReLU(inplace=True)
    (conv3_3

In [None]:
# Launch training with the objects and schedule defined in the config cell.
# The two interval arguments presumably control how often train_and_test
# evaluates and checkpoints — see utils.train_and_test.
train_and_test(
    model,
    optimizer,
    criterion,
    exp_lr_scheduler,
    batch_size,
    max_iterations,
    test_iteration_interval,
    test_and_save_checkpoint_iterations,
    folder_name,
)

In [7]:
from utils.tools.evaluation import *
from utils.tools.dataloader import BKKIterator
import copy
import time
import pickle

# Benchmark window: the evaluation loop below feeds the first IN_LEN frames to
# the model and scores its prediction against the next OUT_LEN frames.
IN_LEN, OUT_LEN = cfg.BENCHMARK.IN_LEN, cfg.BENCHMARK.OUT_LEN

# Key for the timing dict and base name of the stats pickle.
name = 'conv2d'

In [None]:
# Run the model over the rainy test split, accumulate skill metrics in an
# Evaluation object and record the mean per-sample inference time.
with torch.no_grad():
    # isinstance also recognises models that subclass nn.Module indirectly; the
    # previous `torch.nn.Module in model.__class__.__bases__` check only matched
    # direct subclasses and would silently skip eval()/tensor conversion.
    is_deeplearning_model = isinstance(model, torch.nn.Module)
    if is_deeplearning_model:
        model.eval()

    # CUDA kernels execute asynchronously: without a synchronize the timer would
    # only measure kernel-launch overhead, not the actual forward pass.
    sync_cuda = is_deeplearning_model and torch.device(cfg.GLOBAL.DEVICE).type == 'cuda'

    evaluator = Evaluation(seq_len=OUT_LEN, use_central=False)
    bkk_iter = BKKIterator(pd_path=cfg.ONM_PD.RAINY_TEST,
                           sample_mode="sequent",
                           seq_len=IN_LEN + OUT_LEN,
                           stride=cfg.BENCHMARK.STRIDE)
    model_run_avarage_time = dict()
    model_run_avarage_time[name] = 0.0
    valid_time = 0  # number of samples evaluated so far
    while not bkk_iter.use_up:
        valid_batch, valid_mask, sample_datetimes, _ = bkk_iter.sample(batch_size=1)
        if valid_batch.shape[1] == 0:
            break
        if not cfg.EVALUATION.VALID_DATA_USE_UP and valid_time > cfg.EVALUATION.VALID_TIME:
            break

        # Rescale by 1/255 (presumably 8-bit pixel values -> [0, 1]); the first
        # IN_LEN frames are the model input, the next OUT_LEN the target.
        valid_batch = valid_batch.astype(np.float32) / 255.0
        valid_data = valid_batch[:IN_LEN, ...]
        valid_label = valid_batch[IN_LEN:IN_LEN + OUT_LEN, ...]
        mask = valid_mask[IN_LEN:IN_LEN + OUT_LEN, ...].astype(int)

        if is_deeplearning_model:
            valid_data = torch.from_numpy(valid_data).to(cfg.GLOBAL.DEVICE)

        if sync_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, high-resolution clock
        output = model(valid_data)
        if sync_cuda:
            torch.cuda.synchronize()
        model_run_avarage_time[name] += time.perf_counter() - start

        if is_deeplearning_model:
            output = output.cpu().numpy()

        output = np.clip(output, 0.0, 1.0)

        evaluator.update(valid_label, output, mask)

        valid_time += 1

    # Fail with an actionable message instead of a bare ZeroDivisionError when
    # the iterator yielded nothing.
    if valid_time == 0:
        raise RuntimeError('No test samples were evaluated; check cfg.ONM_PD.RAINY_TEST '
                           'and the cfg.EVALUATION settings.')
    model_run_avarage_time[name] /= valid_time
    evaluator.save_pkl(os.path.join(cfg.BENCHMARK.STAT_PATH, name + '.pkl'))

In [None]:
# Persist the per-model mean inference time next to the evaluation stats.
timing_path = os.path.join(cfg.BENCHMARK.STAT_PATH, 'model_run_avarage_time.pkl')
with open(timing_path, 'wb') as timing_file:
    pickle.dump(model_run_avarage_time, timing_file)

In [8]:
# Print the per-threshold CSI / HSS and the (balanced) MSE / MAE stored by the
# evaluation cell.
# Only Evaluation pickles are considered: the timing dict saved in the same
# directory has no calculate_stat(). os.listdir returns entries in arbitrary
# order, so sort to make the `[:1]` selection deterministic.
stat_files = sorted(
    p for p in os.listdir(os.path.abspath(cfg.BENCHMARK.STAT_PATH))
    if p.endswith('.pkl') and p != 'model_run_avarage_time.pkl'
)
for p in stat_files[:1]:
    # NOTE: pickle.load can execute arbitrary code; only load files produced by
    # this notebook. The `with` block closes the handle (it previously leaked).
    with open(os.path.join(cfg.BENCHMARK.STAT_PATH, p), 'rb') as f:
        e = pickle.load(f)
    _, _, csi, hss, _, mse, mae, balanced_mse, balanced_mae, _ = e.calculate_stat()
    print(os.path.splitext(p)[0])
    for i, thresh in enumerate(cfg.EVALUATION.THRESHOLDS):
        print('thresh %.1f csi: average %.4f, last frame %.4f; hss: average %.4f, last frame %.4f;'
              % (thresh, csi[:, i].mean(), csi[-1, i], hss[:, i].mean(), hss[-1, i]))

    print(('mse: average %.2f, last frame %.2f\n' +
           'mae: average %.2f, last frame %.2f\n' +
           'bmse: average %.2f, last frame %.2f\n' +
           'bmae: average %.2f, last frame %.2f\n') % (mse.mean(), mse[-1], mae.mean(), mae[-1],
                                                        balanced_mse.mean(), balanced_mse[-1],
                                                        balanced_mae.mean(), balanced_mae[-1]))

conv2d
thresh 0.5 csi: average 1.0000, last frame 1.0000; hss: average 0.5000, last frame 0.5000;
thresh 5.5 csi: average 0.2113, last frame 0.2100; hss: average 0.0000, last frame -0.0000;
thresh 10.0 csi: average 0.2022, last frame 0.2010; hss: average 0.0000, last frame -0.0000;
thresh 15.0 csi: average 0.1996, last frame 0.1930; hss: average 0.0125, last frame -0.0000;
thresh 20.0 csi: average 0.1607, last frame 0.1226; hss: average 0.0739, last frame 0.0000;
thresh 25.0 csi: average 0.2099, last frame 0.1122; hss: average 0.1869, last frame 0.0000;
thresh 30.0 csi: average 0.2377, last frame 0.2374; hss: average 0.2505, last frame 0.2626;
thresh 35.0 csi: average 0.2604, last frame 0.2579; hss: average 0.3170, last frame 0.3139;
thresh 40.0 csi: average 0.2378, last frame 0.2519; hss: average 0.3033, last frame 0.3230;
thresh 45.0 csi: average 0.2083, last frame 0.2170; hss: average 0.2720, last frame 0.2844;
thresh 50.0 csi: average 0.1793, last frame 0.1931; hss: average 0.2403,