In [None]:
import os
import sys
# Put the parent directory of the notebook's working directory on the import
# path (presumably so the project-local `utils` package resolves).  The
# membership guard keeps repeated runs of this cell from adding duplicates.
module_path = os.path.abspath("..")
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from utils.io_func import load_from_pkl, load_from_pth
from utils.helper import LSTMHelper

In [None]:
# Single helper instance reused by every cell below for data loading,
# normalisation, data-loader creation and model construction.
helper = LSTMHelper()

In [None]:
# ---- Experiment configuration ------------------------------------------
# Site whose data, scaler and trained model are used in this notebook.
BASE_SITE = "Site_1"
# Years are stored as strings because they are substituted into file paths.
TRAIN_YEARS = [str(year) for year in [2015, 2016, 2017]]
TEST_YEARS = [str(year) for year in [2018]]
# Per-site, per-year x / y arrays; `{site}` and `{year}` are filled in later.
X_PATH_TEMPLATE = "../preprocessing/out/{site}/x-corn_soybean-{year}.npy"
Y_PATH_TEMPLATE = "../preprocessing/out/{site}/y-corn_soybean-{year}.npy"
# Artifacts stored alongside the trained model: the feature scaler ...
SCALER_PATH_TEMPLATE = (
    "../experiments/out/end_of_the_season/"
    "AtLSTM-corn_soybean/{site}/scaler.pkl"
)
# ... and the model weights.
MODEL_PATH_TEMPLATE = (
    "../experiments/out/end_of_the_season/"
    "AtLSTM-corn_soybean/{site}/atlstm.pth"
)
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing when the model is moved.
# NOTE(review): on a CPU-only host the checkpoint loader may additionally need
# a `map_location` -- confirm against `load_from_pth`.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Input

In [None]:
def get_paths(path_template, site, years):
    """Build one path per year from a ``str.format`` template.

    Args:
        path_template: Template string formatted with the keyword arguments
            ``site`` and ``year``.
        site: Value substituted for ``{site}``.
        years: Iterable of values substituted, one at a time, for ``{year}``.

    Returns:
        A list of formatted paths, in the same order as ``years``.
    """
    return [path_template.format(site=site, year=y) for y in years]


# Load the x / y arrays of BASE_SITE for the training years ...
x_train, y_train = (
    helper.input_x(get_paths(X_PATH_TEMPLATE, BASE_SITE, TRAIN_YEARS)),
    helper.input_y(get_paths(Y_PATH_TEMPLATE, BASE_SITE, TRAIN_YEARS)),
)
# ... and for the held-out test year(s).
x_test, y_test = (
    helper.input_x(get_paths(X_PATH_TEMPLATE, BASE_SITE, TEST_YEARS)),
    helper.input_y(get_paths(Y_PATH_TEMPLATE, BASE_SITE, TEST_YEARS)),
)

# Normalization

In [None]:
# Reuse the scaler stored next to the trained model so inputs are normalised
# with the same object the experiment saved.
# NOTE(review): this cell rebinds x_train / x_test, so running it twice would
# presumably normalise the data twice -- re-run the loading cell first.
scaler = load_from_pkl(SCALER_PATH_TEMPLATE.format(site=BASE_SITE))
x_train, x_test = (
    helper.normalize_with_scaler(scaler, x_train),
    helper.normalize_with_scaler(scaler, x_test),
)

# Sampling

In [None]:
# Draw a random subset of the test set for the gradient analysis.
N_SAMPLES = 3000   # upper bound on the number of test samples analysed
SAMPLE_SEED = 0    # fixed seed so "Restart & Run All" reproduces the figures

# Cap the request at the number of available samples: `replace=False` raises
# a ValueError when asked for more items than exist.
n_samples = min(N_SAMPLES, x_test.shape[0])
# A dedicated, seeded generator keeps the draw reproducible without touching
# NumPy's global random state.
sample_index = np.random.RandomState(SAMPLE_SEED).choice(
    x_test.shape[0], n_samples, replace=False
)
sample_x = x_test[sample_index]
sample_y = y_test[sample_index]
# shuffle=False keeps batches in `sample_index` order, so gradients collected
# batch by batch later line up row-for-row with `sample_y`.
sample_dataloader = helper.make_data_loader(sample_x, sample_y, shuffle=False)

# Input feature importance analysis

In [None]:
# Rebuild the network, restore the trained weights for BASE_SITE and move the
# model to the analysis device.
net = helper.build_model()
net.load_state_dict(load_from_pth(MODEL_PATH_TEMPLATE.format(site=BASE_SITE)))
net.to(DEVICE);  # semicolon is used for preventing extra output

# Replicate over at most four GPUs (the original setup), but never request
# more devices than this machine exposes: a hard-coded [0, 1, 2, 3] fails on
# hosts with fewer than four GPUs.
MAX_GPUS = 4
gpu_ids = list(range(min(MAX_GPUS, torch.cuda.device_count())))
net = nn.DataParallel(net, device_ids=gpu_ids)

# Freeze the weights: only gradients with respect to the input are needed.
for p in net.parameters():
    p.requires_grad = False

In [None]:
# cudnn RNN backward can only be called in training mode, so the network is
# switched to train() even though no weights are updated here.
# NOTE(review): train() also enables dropout / batch-norm statistic updates if
# the model contains such layers -- confirm this does not affect the result.
net.train()
sample_grad_list = []
for batch in sample_dataloader:
    # Track gradients on the input itself; the parameters are frozen.
    xt_batch = batch["x"].to(DEVICE)
    xt_batch.requires_grad_(True)
    outputs, _ = net(xt_batch)
    # For every sample pick the output belonging to its labelled class, then
    # back-propagate the sum so one backward pass yields all input gradients.
    row_index = np.arange(0, outputs.shape[0])
    labelled_class_outputs = outputs[row_index, batch["y"]]
    labelled_class_outputs.sum().backward()
    sample_grad_list.append(xt_batch.grad.detach().cpu().numpy())
# Stack the per-batch gradients back into one array (first axis = sample).
sample_grad_arr = np.concatenate(sample_grad_list, axis=0)

In [None]:
# Panel titles, one per index of the last axis of `sample_grad_arr`.
band_names = [
    "Blue", "Green", "Red", "Near-infrared",
    "Shortwave infrared 1", "Shortwave infrared 2",
]
# class name -> [label value in `sample_y`, plot colour]
class_dict = {
    "Corn": [0, "blue"],
    "Soybean": [1, "red"],
}

# Values shared by every panel: the x positions (axis 1 of the gradient
# array, labelled as weeks below) and every third position as a tick.
n_steps = sample_grad_arr.shape[1]
steps = range(1, n_steps + 1)
tick_positions = np.arange(1, n_steps + 1, 3)

# One 2x3 figure per class: mean input gradient +/- one standard deviation.
for class_name in ["Corn", "Soybean"]:
    class_value, class_color = class_dict[class_name]
    grad4class = sample_grad_arr[sample_y == class_value]
    mean4class = grad4class.mean(axis=0)
    std4class = grad4class.std(axis=0)
    lower4class = mean4class - std4class
    upper4class = mean4class + std4class
    # Common y-range across the six panels so the bands are comparable.
    vmin = lower4class.min() - 0.1
    vmax = upper4class.max() + 0.1

    fig, axs = plt.subplots(
        nrows=2, ncols=3, figsize=(18, 10.5),
        gridspec_kw={"wspace": 0.5, "hspace": 0.4},
    )
    panels = axs.reshape(-1)
    for band_index, band_name in enumerate(band_names):
        ax = panels[band_index]
        band_mean = mean4class[:, band_index]
        band_lower = lower4class[:, band_index]
        band_upper = upper4class[:, band_index]

        # Zero reference line, then the mean curve and its +/- 1 std band.
        ax.plot(steps, [0] * n_steps, "--", color="orange")
        ax.plot(steps, band_mean, color=class_color)
        ax.fill_between(
            steps, band_lower, band_upper,
            facecolor=class_color, alpha=0.1,
        )
        ax.plot(steps, band_lower, linewidth=1, color=class_color, alpha=0.25)
        ax.plot(steps, band_upper, linewidth=1, color=class_color, alpha=0.25)

        ax.set_ylim([vmin, vmax])
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_positions)
        ax.set_xlabel("Weeks after April 15")
        ax.set_ylabel("Derivative")
        ax.set_title(band_name)
    fig.suptitle(class_name)