In [None]:
import os

from atmap_sa_convlstm import SAConvLSTM
from convlstm import ConvLSTM
from utils import *

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import sys
import pickle
from tqdm import tqdm
import numpy as np
import math
import argparse
import json

# Deterministic cuDNN kernels (and no autotuning) so repeated inference runs match.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer GPU 1 as originally configured, but fall back to GPU 0 on single-GPU machines
# (where ordinal 1 does not exist and every `.to(device)` would fail) and to the CPU
# when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Run directory: holds args.txt and the "<epoch>_checkpoint.chk" files read below.
save_dir = "./save_models/saclstm_epoch_50_pinchu_pandamonium/"

In [None]:
parser = argparse.ArgumentParser()
# Parse an explicit empty argv: inside a kernel, sys.argv carries Jupyter's own flags.
# The resulting empty namespace is filled from args.txt afterwards.
args = parser.parse_args([])

In [None]:
# Restore the saved run configuration: the JSON object in args.txt becomes the
# namespace's attribute dict wholesale.
with open(save_dir + 'args.txt', 'r') as args_file:
    saved_config = json.load(args_file)
args.__dict__ = saved_config

In [None]:
# Test split, built with the input/output lengths from the saved run configuration.
testFolder = wb_dataset(root=args.data, dataset_type="test", frames_input=args.input_length,
                        frames_output=args.output_length, prob=args.prob_crps)

testLoader = DataLoader(testFolder, batch_size=args.batch_size, shuffle=False)

# `device` comes from the configuration cell and is deliberately not redefined here,
# so there is a single place that picks the GPU/CPU.

# Probabilistic (CRPS) runs use doubled input/output channel counts.
dim_scale = 2 if args.prob_crps else 1

if args.convlstm:
    network = ConvLSTM(dim_scale * args.input_dim, args.hidden_dim, dim_scale * args.output_dim,
                       args.kernel_size, device, dropout=args.dropout)
elif args.saconvlstm:
    network = SAConvLSTM(dim_scale * args.input_dim, args.hidden_dim, dim_scale * args.output_dim,
                         args.attn_dim, args.kernel_size, device, dropout=args.dropout)
else:
    # Fail loudly instead of guessing an architecture whose shapes may not match the checkpoint.
    raise ValueError("args.txt must enable either 'convlstm' or 'saconvlstm'")
network = network.to(device)

# Training-time optimizer/scheduler recreated from the saved hyper-parameters
# (inference itself does not step them).
optimizer = torch.optim.Adam(network.parameters(), lr=args.learn_rate, weight_decay=args.weight_decay)
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=0, verbose=True, min_lr=0.0001)

In [None]:
# Pick the highest-epoch checkpoint among files named "<epoch>_checkpoint.chk".
# Files whose prefix is not a plain number (e.g. "best_checkpoint.chk") are skipped
# instead of crashing int(), and an empty directory raises a clear error rather than
# pointing at a non-existent "0_checkpoint.chk".
ckpt_suffix = "_checkpoint.chk"
checkpoint_epochs = [
    int(fname[: -len(ckpt_suffix)])
    for fname in os.listdir(save_dir)
    if fname.endswith(ckpt_suffix) and fname[: -len(ckpt_suffix)].isdecimal()
]
if not checkpoint_epochs:
    raise FileNotFoundError(f"no '<epoch>{ckpt_suffix}' files found in {save_dir!r}")
max_ep = max(checkpoint_epochs)
chkpnt = str(max_ep) + "_checkpoint.chk"

In [None]:
# Switch to eval mode before inference so layers such as dropout/batch-norm behave
# deterministically (load_state_dict does not change the mode).
network.eval()
# map_location remaps tensors saved on another GPU (or loaded on a CPU-only machine)
# onto the active device instead of their original one.
chk = torch.load(save_dir + chkpnt, map_location=device)
network.load_state_dict(chk['net'])

In [None]:
# Sample 367 of the test split (the same index appears in the output folder name below).
item = testFolder[367]

In [None]:
# Forward pass on the first 7 entries of the sample's second axis, with a leading
# batch axis of 1 added.
# NOTE(review): 7 is presumably the input window (args.input_length) — confirm.
# no_grad: inference only, so skip building the autograd graph (saves memory).
with torch.no_grad():
    model_input = torch.from_numpy(item[None, :7, ...]).float().to(device)
    output, att_h_maps, att_Z_maps = network(model_input, train=False)
output = output.detach().cpu().numpy()

In [None]:
# Shape of all Z attention maps stacked along dim 0 (torch.cat's default axis).
torch.cat(att_Z_maps).shape

In [None]:
# Stack the attention maps along dim 0, view them as (11, 4, C, 32, 64) and average
# over the leading axis of size 11.
# NOTE(review): 11 x 4 presumably are time steps x layers — confirm against SAConvLSTM.
day_h_maps = (
    torch.cat(att_h_maps, dim=0)
    .reshape(11, 4, 128, 32, 64)
    .mean(dim=0)
)
day_Z_maps = (
    torch.cat(att_Z_maps, dim=0)
    .reshape(11, 4, 32, 32, 64)
    .mean(dim=0)
)

In [None]:
# Expected (4, 32, 32, 64) given the reshape-and-mean that produced it.
day_Z_maps.size()

In [None]:
import numpy as np
from matplotlib.colors import Normalize
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point

In [None]:
# Regular 5.625° grid matching the (32, 64) maps: 64 longitudes x 32 latitudes.
GRID_STEP_DEG = 5.625
longs = np.arange(0, 360, GRID_STEP_DEG)  # 0, 5.625, ..., 354.375
# 32 rows of 5.625° exactly span [-90, 90], so the latitudes are the cell centres
# -87.1875 ... 87.1875. np.linspace(-90, 90, 32) would use a 5.806° step, inconsistent
# with the longitude spacing.
# NOTE(review): assumes the data rows run south-to-north, as the original ascending axis did.
lats = np.arange(-90 + GRID_STEP_DEG / 2, 90, GRID_STEP_DEG)

In [None]:
channel = 0  # leading-axis index into day_Z_maps to inspect (0 = first entry)

In [None]:
# Save one global map per attention slice of day_Z_maps[channel], each rescaled to [-1, 1].
out_dir = "photos/367_atmap_saclstm_pcrps/"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

# Convert once, outside the loop; expected (n_maps, 32, 64) given the reshape that built it.
z_maps = day_Z_maps[channel].detach().cpu().numpy()

for i, raw in enumerate(z_maps):
    lo, hi = np.min(raw), np.max(raw)
    span = hi - lo
    # Min-max rescale to [-1, 1]; a constant map would divide by zero, so draw it as 0.
    labels = 2 * (raw - lo) / span - 1 if span > 0 else np.zeros_like(raw)
    # Repeat the first longitude column at 360° so the contour closes across the seam.
    wrap_data, wrap_lon = add_cyclic_point(labels, coord=longs, axis=1)

    fig, ax = plt.subplots(figsize=(8, 4), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.contourf(wrap_lon, lats, wrap_data, 100, transform=ccrs.PlateCarree(), cmap='twilight_shifted')
    ax.set_global()
    fig.savefig(os.path.join(out_dir, f"{i}.png"))
    plt.close(fig)  # avoid keeping dozens of open figures in memory