In [None]:
import os

from sa_convlstm import SAConvLSTM
from convlstm import ConvLSTM
from utils import *

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import sys
import pickle
from tqdm import tqdm
import numpy as np
import math
import argparse
import json

# Ask cuDNN for deterministic kernels and disable its auto-tuner so repeated
# runs pick the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer the second GPU, but fall back to the first one when only a single
# device is visible (a bare "cuda:1" would raise an invalid-ordinal error
# there), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Directory holding the pre-trained run (must end with a path separator,
# because file names are appended to it by plain string concatenation).
load_dir = './save_models/saclstm_orth_2_epoch_50_pinchu_pandamonium/'
# Output directory for the fine-tuned run: same name with the trailing
# separator dropped and a "_transfer/" suffix added.
save_dir = "{}_transfer/".format(load_dir[:-1])

In [None]:
# Settings for the transfer (fine-tuning) stage; they replace the epoch count
# and data path stored with the pre-trained run.
transfer_data = "./WeatherBenchData/wthrbnch_air_pv_PUV_sh_u_v_5.625deg_24.npy"
transfer_epochs = 1

In [None]:
# Build an empty argparse namespace; its attributes are filled from the saved
# args file rather than from the command line.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

In [None]:
# Restore the hyper-parameters stored next to the pre-trained checkpoints and
# install them as the attributes of `args`.
with open(load_dir + 'args.txt', 'r') as args_file:
    saved_args = json.load(args_file)
args.__dict__ = saved_args

In [None]:
# Override the stored epoch count and dataset path with the transfer settings.
args.num_epochs, args.data = transfer_epochs, transfer_data

In [None]:
# Seed NumPy and PyTorch so the run is repeatable.
random_seed = 1234
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Create the output directory if needed. `exist_ok=True` replaces the separate
# isdir() check, which could race with another process creating the directory.
os.makedirs(save_dir, exist_ok=True)

# Record the effective (overridden) hyper-parameters alongside the new
# checkpoints so the evaluation cells can reload them.
with open(save_dir+'args.txt', 'w') as f:
    json.dump(args.__dict__, f, indent=2)

In [None]:
# Arguments shared by both dataset splits.
# `wb_dataset` is not defined in this notebook (presumably provided by the
# `utils` star import).
dataset_kwargs = dict(
    root=args.data,
    frames_input=args.input_length,
    frames_output=args.output_length,
    prob=args.prob_crps,
)

trainFolder = wb_dataset(dataset_type="train", **dataset_kwargs)
validFolder = wb_dataset(dataset_type="eval", **dataset_kwargs)

# Training batches are shuffled; validation order is kept fixed.
trainLoader = DataLoader(trainFolder, batch_size=args.batch_size, shuffle=True)
validLoader = DataLoader(validFolder, batch_size=args.batch_size, shuffle=False)

In [None]:
# Channel sizing: the plain input/output dims are used only when one of the
# model flags is set AND CRPS training is off; every other combination doubles
# both (presumably two values per variable for the probabilistic loss -- confirm).
deterministic = (args.convlstm or args.saconvlstm) and not args.prob_crps
channel_mult = 1 if deterministic else 2

if args.convlstm:
    network = ConvLSTM(channel_mult * args.input_dim, args.hidden_dim,
                       channel_mult * args.output_dim,
                       args.kernel_size, device, dropout=args.dropout)
else:
    # NOTE(review): this branch is also taken when neither `convlstm` nor
    # `saconvlstm` is set; in that case the channels are doubled even if
    # `prob_crps` is false. Confirm that fall-through is intended.
    network = SAConvLSTM(channel_mult * args.input_dim, args.hidden_dim,
                         channel_mult * args.output_dim, args.attn_dim,
                         args.kernel_size, device, dropout=args.dropout)

network = network.to(device)

In [None]:
# Adam over every parameter of the network, using the stored learning rate and
# weight decay.
optimizer = torch.optim.Adam(
    network.parameters(),
    lr=args.learn_rate,
    weight_decay=args.weight_decay,
)
# Multiply the learning rate by 0.3 as soon as the monitored loss fails to
# improve for one epoch (patience=0), never going below 1e-4.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases -- confirm.
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.3,
    patience=0,
    verbose=True,
    min_lr=0.0001,
)

In [None]:
# Find the checkpoint with the highest epoch number in `load_dir`.
# Only files named "<epoch>_checkpoint.chk" (exactly one underscore, numeric
# prefix) are considered; anything else, e.g. "best_checkpoint.chk", is
# skipped instead of crashing int().
latest = None  # (epoch, file name) of the best candidate so far
for fname in os.listdir(load_dir):
    prefix, sep, suffix = fname.partition("_")
    if suffix != "checkpoint.chk" or not prefix.isdecimal():
        continue
    epoch = int(prefix)
    if latest is None or epoch > latest[0]:
        latest = (epoch, fname)

if latest is None:
    # Fail here with a clear message rather than later on a guessed file name.
    raise FileNotFoundError("no '<epoch>_checkpoint.chk' file found in " + load_dir)

# Keep the real file name so zero-padded prefixes ("007_...") still resolve.
max_ep, chkpnt = latest

In [None]:
# Load the pre-trained weights. `map_location` remaps the stored tensors onto
# the active device, so a checkpoint written on a GPU can still be restored
# when this notebook runs on the CPU fallback or a different GPU index.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
chk = torch.load(load_dir + chkpnt, map_location=device)
network.load_state_dict(chk['net'])

In [None]:
# Freeze the whole network first ...
for frozen in network.parameters():
    frozen.requires_grad = False

# ... then re-enable gradients only for the parts that are fine-tuned:
# `layers[3]` (presumably the last recurrent layer -- confirm against the
# model definition) and the output convolution.
for trainable_module in (network.layers[3], network.conv_output):
    for weight in trainable_module.parameters():
        weight.requires_grad = True

In [None]:
def transfer_train():
    """Fine-tune the global ``network`` and checkpoint every improvement.

    Runs up to ``args.num_epochs`` epochs over ``trainLoader``. The scheduled
    sampling ratio starts at 1 and is lowered by ``args.ssr_decay_rate`` on
    every batch, never going below 0. After each epoch the model is evaluated
    with ``infer`` and ``infer_WB`` on the validation data, the learning-rate
    scheduler is stepped on the ``infer`` loss, and a checkpoint named
    ``<epoch>_checkpoint.chk`` is written to ``save_dir`` whenever that loss
    improves. Training stops early once the loss has failed to improve for
    ``args.patience`` consecutive epochs.

    Reads the notebook-level globals ``network``, ``optimizer``,
    ``lr_scheduler``, ``trainLoader``, ``validLoader``, ``validFolder``,
    ``args``, ``device`` and ``save_dir``.

    Side effects: updates the network and optimizer in place, writes
    checkpoints, and saves the per-epoch validation losses to
    ``eval_loss.npy`` and ``eval_wb_loss.npy`` in ``save_dir``.
    """
    epoch_eval_loss = []
    epoch_eval_wb_loss = []
    count = 0          # consecutive epochs without improvement
    best = math.inf    # best validation loss seen so far
    ssr_ratio = 1
    for i in range(args.num_epochs):
        print('\nepoch: {0}'.format(i))
        network.train()
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for j, mc in enumerate(t):
            # Decay the scheduled-sampling ratio once per batch, floored at 0.
            if ssr_ratio > 0:
                ssr_ratio = max(ssr_ratio - args.ssr_decay_rate, 0)

            mc_pred = network(mc.float(), teacher_forcing=True, scheduled_sampling_ratio=ssr_ratio, train=True)
            optimizer.zero_grad()

            # Predictions are scored against the sequence shifted by one frame.
            target = mc[:, 1:].to(device)
            if args.prob_crps:
                loss = loss_prob(mc_pred, target, args.output_dim)
            else:
                loss = loss_mc(mc_pred, target)

            loss.backward()
            if args.gradient_clipping:
                nn.utils.clip_grad_norm_(network.parameters(), args.clipping_threshold)
            optimizer.step()

            if j % 2500 == 0:
                print('batch training loss: {:.5f}, ssr ratio: {:.4f}'.format(loss.item(), ssr_ratio))

        # evaluation
        loss_mc_eval = infer(validLoader, args.input_length, network, args.output_dim, args.prob_crps)
        epoch_eval_loss.append(loss_mc_eval)
        loss_wb_eval = infer_WB(validFolder, validLoader, args.input_length, network, args.output_dim, args.prob_crps)
        epoch_eval_wb_loss.append(loss_wb_eval.detach().cpu().numpy())
        print('epoch eval loss:\nmc loss: {:.5f}'.format(loss_mc_eval))
        lr_scheduler.step(loss_mc_eval)
        if loss_mc_eval >= best:
            count += 1
            print('eval loss is not improved for {} epoch'.format(count))
        else:
            count = 0
            print('eval loss is improved from {:.5f} to {:.5f}, saving model'.format(best, loss_mc_eval))
            save_model(save_dir + str(i) + "_checkpoint.chk")
            best = loss_mc_eval

        if count == args.patience:
            # Fixed format spec: '{:5f}' set a field width, not a precision.
            print('early stopping reached, best loss is {:.5f}'.format(best))
            break
    np.save(save_dir + "eval_loss.npy", np.array(epoch_eval_loss))
    np.save(save_dir + "eval_wb_loss.npy", np.array(epoch_eval_wb_loss))

def save_model(path):
    """Write the state dicts of the global network and optimizer to ``path``."""
    checkpoint = {
        'net': network.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)

In [None]:
# Run the fine-tuning loop; checkpoints and loss curves are written to `save_dir`.
transfer_train()

In [None]:
# Start again from an empty namespace for the evaluation stage; it is filled
# from the args file written during fine-tuning.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

In [None]:
# Reload the hyper-parameters saved with the fine-tuned run and install them
# as the attributes of `args`.
with open(save_dir + 'args.txt', 'r') as args_file:
    saved_args = json.load(args_file)
args.__dict__ = saved_args

In [None]:
# Drop the training/validation datasets and loaders; only the test split is
# used from here on.
del trainFolder, validFolder, trainLoader, validLoader

In [None]:
# Test split, built with the same windowing settings as training.
testFolder = wb_dataset(
    root=args.data,
    dataset_type="test",
    frames_input=args.input_length,
    frames_output=args.output_length,
    prob=args.prob_crps,
)

# Keep the test order fixed so individual samples can be indexed reproducibly.
testLoader = torch.utils.data.DataLoader(testFolder,
                                         batch_size=args.batch_size,
                                         shuffle=False)

# Prefer the second GPU, but fall back to the first one when only a single
# device is visible (a bare "cuda:1" would raise an invalid-ordinal error
# there), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Channel sizing: the plain input/output dims are used only when one of the
# model flags is set AND CRPS training is off; every other combination doubles
# both (presumably two values per variable for the probabilistic loss -- confirm).
deterministic = (args.convlstm or args.saconvlstm) and not args.prob_crps
channel_mult = 1 if deterministic else 2

if args.convlstm:
    network = ConvLSTM(channel_mult * args.input_dim, args.hidden_dim,
                       channel_mult * args.output_dim,
                       args.kernel_size, device, dropout=args.dropout)
else:
    # NOTE(review): this branch is also taken when neither `convlstm` nor
    # `saconvlstm` is set; in that case the channels are doubled even if
    # `prob_crps` is false. Confirm that fall-through is intended.
    network = SAConvLSTM(channel_mult * args.input_dim, args.hidden_dim,
                         channel_mult * args.output_dim, args.attn_dim,
                         args.kernel_size, device, dropout=args.dropout)

network = network.to(device)

# The optimizer and scheduler are rebuilt for the fresh network so that the
# notebook-level names refer to it.
optimizer = torch.optim.Adam(
    network.parameters(),
    lr=args.learn_rate,
    weight_decay=args.weight_decay,
)
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.3,
    patience=0,
    verbose=True,
    min_lr=0.0001,
)

In [None]:
# Find the checkpoint with the highest epoch number in `save_dir`.
# Only files named "<epoch>_checkpoint.chk" (exactly one underscore, numeric
# prefix) are considered; anything else, e.g. "best_checkpoint.chk", is
# skipped instead of crashing int().
latest = None  # (epoch, file name) of the best candidate so far
for fname in os.listdir(save_dir):
    prefix, sep, suffix = fname.partition("_")
    if suffix != "checkpoint.chk" or not prefix.isdecimal():
        continue
    epoch = int(prefix)
    if latest is None or epoch > latest[0]:
        latest = (epoch, fname)

if latest is None:
    # Fail here with a clear message rather than later on a guessed file name.
    raise FileNotFoundError("no '<epoch>_checkpoint.chk' file found in " + save_dir)

# Keep the real file name so zero-padded prefixes ("007_...") still resolve.
max_ep, chkpnt = latest

In [None]:
# Load the fine-tuned weights. `map_location` remaps the stored tensors onto
# the active device, so a checkpoint written on a GPU can still be restored
# when this notebook runs on the CPU fallback or a different GPU index.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
chk = torch.load(save_dir + chkpnt, map_location=device)
network.load_state_dict(chk['net'])

In [None]:
# Overall loss of the fine-tuned model on the test split.
loss_mc_test = infer(
    testLoader,
    args.input_length,
    network,
    args.output_dim,
    args.prob_crps,
)
print("test rmse: ", loss_mc_test)

In [None]:
# Second test metric, computed by `infer_WB` (printed below as "lat rmse").
loss_WB_test = infer_WB(
    testFolder,
    testLoader,
    args.input_length,
    network,
    args.output_dim,
    args.prob_crps,
)
# Header and value on separate lines, exactly as two consecutive prints would.
print("test lat rmse: ", loss_WB_test, sep="\n")

In [None]:
# Pick one test sample (index 367) for the qualitative plots below.
item = testFolder[367]

In [None]:
# Add a leading batch axis, keep the first 7 frames as model input
# (NOTE(review): presumably equal to args.input_length -- confirm), and run
# the network in inference mode.
sample_input = torch.from_numpy(item[None, :7, ...]).float().to(device)
output = network(sample_input, train=False).detach().cpu().numpy()

In [None]:
import numpy as np
from matplotlib.colors import Normalize
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point

In [None]:
# Coordinate axes for the plots: 64 longitudes at 5.625 degree spacing and 32
# evenly spaced latitudes from pole to pole.
# NOTE(review): WeatherBench grids at this resolution are usually
# cell-centred (about +/-87.19 degrees); confirm the latitudes match the data.
grid_step_deg = 5.625
longs = np.arange(0, 360, grid_step_deg)
lats = np.linspace(-90, 90, num=32)

In [None]:
# Undo the dataset normalisation using the statistics stored on the dataset.
# NOTE: this rebinds `item` in place, so re-running the cell applies the
# scaling twice; re-run the cell that fetches `item` first.
item = item*testFolder.long_std+testFolder.long_mean

In [None]:
# Drop the batch axis of the prediction and undo the dataset normalisation.
# NOTE: this rebinds `output`, so re-running the cell indexes and rescales the
# already-converted array; re-run the inference cell first.
output = output[0]*testFolder.long_std+testFolder.long_mean

In [None]:
# Plot selectors: `channel` picks which variable is drawn in the cells below;
# `day_delta` is not referenced by any cell visible here.
day_delta, channel = 0, 0

In [None]:
# Ground truth: frame 7 of the sample, selected channel.
labels = item[7][channel]
# Append a wrap-around longitude column so the map has no seam.
wrap_data, wrap_lon = add_cyclic_point(labels, coord=longs, axis=1)

plt.figure(figsize=(20, 9))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines()
# Fixed colour range so all four maps are directly comparable.
colour_norm = Normalize(vmin=240, vmax=305)
ax.contourf(wrap_lon, lats, wrap_data, 100,
            transform=ccrs.PlateCarree(), norm=colour_norm)
ax.set_global()
plt.show()

In [None]:
# Ground truth: frame 11 of the sample, selected channel.
labels = item[11][channel]
# Append a wrap-around longitude column so the map has no seam.
wrap_data, wrap_lon = add_cyclic_point(labels, coord=longs, axis=1)

plt.figure(figsize=(20, 9))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines()
# Fixed colour range so all four maps are directly comparable.
colour_norm = Normalize(vmin=240, vmax=305)
ax.contourf(wrap_lon, lats, wrap_data, 100,
            transform=ccrs.PlateCarree(), norm=colour_norm)
ax.set_global()
plt.show()

In [None]:
# --------------------- ground truth (`item`) above, model output (`output`) below ---------------------

In [None]:
# Model output: frame 0 of the prediction, selected channel.
labels = output[0][channel]
# Append a wrap-around longitude column so the map has no seam.
wrap_data, wrap_lon = add_cyclic_point(labels, coord=longs, axis=1)

plt.figure(figsize=(20, 9))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines()
# Fixed colour range so all four maps are directly comparable.
colour_norm = Normalize(vmin=240, vmax=305)
ax.contourf(wrap_lon, lats, wrap_data, 100,
            transform=ccrs.PlateCarree(), norm=colour_norm)
ax.set_global()
plt.show()

In [None]:
# Model output: frame 4 of the prediction, selected channel.
labels = output[4][channel]
# Append a wrap-around longitude column so the map has no seam.
wrap_data, wrap_lon = add_cyclic_point(labels, coord=longs, axis=1)

plt.figure(figsize=(20, 9))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines()
# Fixed colour range so all four maps are directly comparable.
colour_norm = Normalize(vmin=240, vmax=305)
ax.contourf(wrap_lon, lats, wrap_data, 100,
            transform=ccrs.PlateCarree(), norm=colour_norm)
ax.set_global()
plt.show()