In [6]:
import cv2
import shutil
import numpy as np
import pandas as pd
import torch
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from PIL import Image

import os, sys
sys.path.insert(0, '/content/Data-Science-Project-2-2021-2-Nowcasting/')

from utils.tools.dataloader import BKKIterator
from utils.config import cfg
from utils.config import cfg
from utils.blocks.forecaster import Forecaster
from utils.blocks.encoder import Encoder
from collections import OrderedDict
from utils.blocks.module import EF, Predictor
from utils.loss import Weighted_mse_mae
from utils.blocks.trajGRU import TrajGRU
from utils.train_and_test import train_and_test
from utils.tools.evaluation import *
from experiment.net_params import encoder_params, forecaster_params, conv2d_params
from utils.tools import image, mask
from utils.tools.evaluation import Evaluation

## Train-Valid-Test Split

In [7]:
from utils.tools.train_test_split import *
from utils.utils import *

# Fractions of the full folder assigned to (train, valid, test); they sum to 1.
split_ratio = (0.8, 0.05, 0.15)

# rebuild_bkk_pkl() presumably regenerates the BKK pickle metadata used by the split — see utils.
rebuild_bkk_pkl()
train_test_split(cfg.ONM_PD.FOLDER_ALL, ratio=split_ratio)

## Create the model and load its weights from a `.pth` checkpoint

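`map_location` in `torch.load` controls which device the checkpoint's tensors are loaded onto. It must refer to a device that exists on this machine. Use `torch.device('cpu')` when CUDA is not available. When CUDA is available, the tensors can be loaded straight onto the GPU.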

In [None]:
# Number of observed input frames and of frames to forecast, taken from the benchmark config.
IN_LEN = cfg.BENCHMARK.IN_LEN
OUT_LEN = cfg.BENCHMARK.OUT_LEN

# Alternative model: the TrajGRU encoder-forecaster (EF) checkpoint from 'trajGRU_BMSE_BMAE'.
# To use it, uncomment these lines and comment out the Predictor lines below.
# encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
# forecaster = Forecaster(forecaster_params[0], forecaster_params[1])
# model = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)
# model.load_state_dict(torch.load(os.path.join(cfg.GLOBAL.MODEL_SAVE_DIR, 'trajGRU_BMSE_BMAE', 'models', 'encoder_forecaster_1000.pth'), map_location=cfg.GLOBAL.DEVICE))

checkpoint_path = os.path.join(cfg.GLOBAL.MODEL_SAVE_DIR, 'conv2d', 'models', 'encoder_forecaster_100.pth')

model = Predictor(conv2d_params).to(cfg.GLOBAL.DEVICE)
# Map the checkpoint tensors onto the same device the model lives on. The cell then works
# unchanged on CPU-only and CUDA machines, with no hard-coded 'cpu' mapping to remove.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(checkpoint_path, map_location=cfg.GLOBAL.DEVICE))
# Switch to inference mode before the model is used for prediction. This affects layers such
# as dropout / batch-norm, if the network contains any. The trailing ';' suppresses the repr.
model.eval();

## Predict `OUT_LEN` (20) future frames from `IN_LEN` (5) input frames

In [None]:
# Iterator over the rainy test split that yields consecutive sequences of IN_LEN + OUT_LEN frames.
bkk_iter = BKKIterator(pd_path=cfg.ONM_PD.RAINY_TEST,
                       sample_mode="sequent",
                       seq_len=IN_LEN + OUT_LEN,
                       stride=cfg.BENCHMARK.STRIDE)

valid_batch, valid_mask, sample_datetimes, _ = bkk_iter.sample(batch_size=1)

# Scale intensities to [0, 1]. This assumes the iterator returns 0-255 values — TODO confirm.
valid_batch = valid_batch.astype(np.float32) / 255.0
# The first IN_LEN frames are the model input; the following OUT_LEN frames are the ground truth.
valid_data = valid_batch[:IN_LEN, ...]
valid_label = valid_batch[IN_LEN:IN_LEN + OUT_LEN, ...]
torch_valid_data = torch.from_numpy(valid_data).to(cfg.GLOBAL.DEVICE)

# eval() is idempotent. Calling it here guarantees inference mode even if an earlier
# cell left the model in training mode.
model.eval()
with torch.no_grad():
    output_tensor = model(torch_valid_data)

# Keep the tensor and the numpy prediction under different names; clip to the label range.
output = np.clip(output_tensor.cpu().numpy(), 0.0, 1.0)
# Prediction and ground truth must line up frame-for-frame before they are compared/plotted.
assert output.shape == valid_label.shape, (output.shape, valid_label.shape)

base_dir = '.'  # NOTE(review): not used by any visible cell
# Arrays are laid out S*B*1*H*W. Keep batch item 0 and channel 0, giving S*H*W.
label = valid_label[:, 0, 0, :, :]
output = output[:, 0, 0, :, :]
# Select the OUT_LEN target frames and cast to uint8 in a single step.
# NOTE: this rebinds `mask`, shadowing the `utils.tools.mask` module imported at the top.
mask = valid_mask[IN_LEN:IN_LEN + OUT_LEN, 0, 0, :, :].astype(np.uint8)

In [None]:
# Index of the forecast frame to visualise; must satisfy 0 <= i < OUT_LEN.
i = 0
assert 0 <= i < OUT_LEN, f"i must be in [0, {OUT_LEN}), got {i}"

In [None]:
# `label` is S*H*W after the [:, 0, 0, :, :] slicing, so label[i] is already a 2-D image.
# (The former label[i][0][0] picked a single pixel, which imshow rejects.)
# vmin/vmax pin the colour scale to [0, 1] so this plot is comparable with the prediction plot.
fig, ax = plt.subplots()
image_artist = ax.imshow(label[i], vmin=0.0, vmax=1.0)
fig.colorbar(image_artist, ax=ax)
ax.set_title(f"Ground truth - output frame {i}")
plt.show()

In [None]:
# `output` is S*H*W after the [:, 0, 0, :, :] slicing, so output[i] is already a 2-D image.
# (The former output[i][0][0] picked a single pixel, which imshow rejects.)
# vmin/vmax match the clip range applied to the prediction, so the colours match the label plot.
fig, ax = plt.subplots()
image_artist = ax.imshow(output[i], vmin=0.0, vmax=1.0)
fig.colorbar(image_artist, ax=ax)
ax.set_title(f"Prediction - output frame {i}")
plt.show()