In [None]:
import os

from sa_convlstm import SAConvLSTM
from convlstm import ConvLSTM
from utils import *

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import sys
import pickle
from tqdm import tqdm
import numpy as np
import math
import argparse
import json

# Force deterministic cuDNN kernels (and disable autotuning) so repeated
# evaluations of the same checkpoint produce identical numbers.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer GPU 1 as before, but fall back to GPU 0 on a single-GPU machine (a
# hard-coded "cuda:1" fails there as soon as anything is moved to it), and to
# CPU when CUDA is unavailable.
PREFERRED_CUDA_INDEX = 1
if torch.cuda.is_available():
    cuda_index = PREFERRED_CUDA_INDEX if torch.cuda.device_count() > PREFERRED_CUDA_INDEX else 0
    device = torch.device(f"cuda:{cuda_index}")
else:
    device = torch.device("cpu")

In [None]:
# Directory of the trained run: later cells read `args.txt` and `<epoch>_checkpoint.chk` from it.
# Set the SAVE_DIR environment variable to evaluate a different run without editing the notebook.
save_dir = os.environ.get('SAVE_DIR', './save_models/saclstm_pcrps_epoch_50_pinchu_pandamonium/')
# Later code may join paths by plain string concatenation, so guarantee a trailing separator.
save_dir = os.path.join(save_dir, '')

In [None]:
# An argument-less parser only serves to produce an empty Namespace, which the next
# cell fills from the saved run configuration. An explicit empty argv keeps
# argparse from trying to parse the Jupyter kernel's own command-line flags.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

In [None]:
# Restore the hyper-parameters of the trained run; every key in the JSON file becomes
# an attribute of `args` (e.g. args.input_length, args.prob_crps used below).
# JSON is UTF-8, so don't depend on the platform's default text encoding.
with open(os.path.join(save_dir, 'args.txt'), 'r', encoding='utf-8') as args_file:
    args.__dict__ = json.load(args_file)

In [None]:
# Test split, built with the same frame counts and probabilistic flag as the training run.
testFolder = wb_dataset(root=args.data, dataset_type="test", frames_input=args.input_length,
                        frames_output=args.output_length, prob=args.prob_crps)

# Not used by the cells below (they index testFolder directly); kept for batched evaluation.
testLoader = DataLoader(testFolder, batch_size=args.batch_size, shuffle=False)

# `device` comes from the setup cell; it is deliberately not re-created here so the
# two definitions cannot drift apart.

# Rebuild the trained architecture. In probabilistic (CRPS) mode the network reads and
# writes twice as many channels. The selection mirrors the original branch order:
# anything that is not ConvLSTM falls back to SAConvLSTM, with doubled channels unless
# it is a plain (non-CRPS) SAConvLSTM run — keep it this way so it matches the checkpoint.
if args.convlstm:
    channel_scale = 2 if args.prob_crps else 1
    network = ConvLSTM(channel_scale * args.input_dim, args.hidden_dim,
                       channel_scale * args.output_dim, args.kernel_size, device,
                       dropout=args.dropout).to(device)
else:
    channel_scale = 1 if (args.saconvlstm and not args.prob_crps) else 2
    network = SAConvLSTM(channel_scale * args.input_dim, args.hidden_dim,
                         channel_scale * args.output_dim, args.attn_dim,
                         args.kernel_size, device, dropout=args.dropout).to(device)

# Optimizer/scheduler are not stepped anywhere in this evaluation notebook; kept only so
# the training-time objects exist. `verbose` is omitted because newer PyTorch deprecates it.
optimizer = torch.optim.Adam(network.parameters(), lr=args.learn_rate, weight_decay=args.weight_decay)
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=0, min_lr=0.0001)

In [None]:
def find_latest_checkpoint(directory, suffix="_checkpoint.chk"):
    """Return ``(epoch, filename)`` of the highest-epoch ``<epoch><suffix>`` file in *directory*.

    Files whose prefix is not a plain decimal epoch number (e.g. ``best_checkpoint.chk``)
    are ignored instead of crashing ``int()``. The real filename is returned, so
    zero-padded names such as ``05_checkpoint.chk`` still resolve correctly.

    Raises:
        FileNotFoundError: if *directory* holds no matching checkpoint.
    """
    candidates = []
    for name in os.listdir(directory):
        if not name.endswith(suffix):
            continue
        prefix = name[:-len(suffix)]
        if prefix.isdecimal():
            candidates.append((int(prefix), name))
    if not candidates:
        raise FileNotFoundError(f"no '<epoch>{suffix}' files found in {directory!r}")
    return max(candidates)


max_ep, chkpnt = find_latest_checkpoint(save_dir)

In [None]:
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location remaps stored tensors onto `device`, so a checkpoint saved on e.g. cuda:1
# still loads on a single-GPU or CPU-only machine.
chk = torch.load(os.path.join(save_dir, chkpnt), map_location=device)
# Evaluation only: put submodules such as nn.Dropout into inference behaviour.
network.eval()
network.load_state_dict(chk['net'])

In [None]:
# Uncertainty/error agreement: for every test sample and lead day, compare the pixels whose
# predicted variance lies above its 90th percentile with the pixels whose absolute error of
# the predicted mean lies above its 90th percentile, and score their overlap with
# intersection-over-union.
# Assumes a probabilistic (CRPS) run: the output has `output_dim` mean channels followed by
# `output_dim` channels that are treated as a normalised standard deviation.
N_INPUT = args.input_length    # conditioning frames fed to the network (was hard-coded 7)
N_OUTPUT = args.output_length  # lead days scored per sample (was hard-coded 5)
N_VARS = args.output_dim       # mean channels preceding the spread channels (was hard-coded 5)
CHANNEL = 0                    # only the first variable is scored
TOP_QUANTILE = 0.9             # pixels strictly above this quantile form each mask


def _top_quantile_mask(field, q=TOP_QUANTILE):
    """Boolean mask of the entries of *field* strictly greater than its *q*-quantile."""
    return field > np.quantile(field, q)


def _iou(mask_a, mask_b):
    """Intersection-over-union of two boolean masks; NaN when both masks are empty."""
    union = np.logical_or(mask_a, mask_b).sum()
    if union == 0:
        return np.nan
    return np.logical_and(mask_a, mask_b).sum() / union


network.eval()  # `train=False` is still passed to forward; eval() also sets dropout submodules to inference mode
big_iou = []  # one list of per-day IoU scores per test sample
with torch.no_grad():  # inference only: don't build an autograd graph per sample
    for i in tqdm(range(len(testFolder)), desc="IoU"):
        sample = testFolder[i]
        inputs = torch.from_numpy(sample[None, :N_INPUT, ...]).float().to(device)
        pred = network(inputs, train=False).cpu().numpy()[0]

        # Undo the dataset normalisation (presumably back to physical units).
        target = sample[:, :N_VARS, ...] * testFolder.long_std + testFolder.long_mean
        pred_mean = pred[:, :N_VARS, ...] * testFolder.long_std + testFolder.long_mean
        pred_var = (pred[:, N_VARS:, ...] * testFolder.long_std) ** 2

        iou_day = []
        for day in range(N_OUTPUT):
            var_set = _top_quantile_mask(pred_var[day, CHANNEL])
            # Ground truth follows the N_INPUT conditioning frames; predictions start at lead day 0.
            abs_err = np.abs(target[N_INPUT + day, CHANNEL] - pred_mean[day, CHANNEL])
            err_set = _top_quantile_mask(abs_err)
            iou_day.append(_iou(var_set, err_set))
        big_iou.append(iou_day)

In [None]:
# Stack the per-sample score lists into a 2-D array: one row per test sample,
# one column per scored lead day.
arr = np.asarray(big_iou)

In [None]:
# Overall mean IoU. Entries are NaN where neither mask selected a pixel (IoU undefined);
# nanmean skips those so one degenerate field doesn't turn the whole summary into NaN.
np.nanmean(arr)