In [None]:
import os
import sys

# Make the parent directory importable so that local packages used below
# (e.g. ``utils``) can be resolved from inside this notebook's folder.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
import torch
import torch.nn as nn
from utils.io_func import load_from_pkl, load_from_pth
from utils.helper import TransformerHelper

In [None]:
# Shared helper from the local ``utils`` package; used below for loading and
# normalizing data, building data loaders, building the model and predicting.
helper = TransformerHelper()

In [None]:
# Site and seasons to analyse. Years are kept as strings because they are
# substituted into the path templates below.
BASE_SITE = "Site_1"
TRAIN_YEARS = [str(year) for year in [2015, 2016, 2017]]
TEST_YEARS = [str(year) for year in [2018]]

# Preprocessed x / y arrays, one .npy file per site and year.
X_PATH_TEMPLATE = "../preprocessing/out/{site}/x-corn_soybean-{year}.npy"
Y_PATH_TEMPLATE = "../preprocessing/out/{site}/y-corn_soybean-{year}.npy"

# Artifacts of the trained end-of-season Transformer for a given site.
SCALER_PATH_TEMPLATE = (
    "../experiments/out/end_of_the_season/"
    "Transformer-corn_soybean/{site}/scaler.pkl"
)
MODEL_PATH_TEMPLATE = (
    "../experiments/out/end_of_the_season/"
    "Transformer-corn_soybean/{site}/transformer.pth"
)

# Fall back to the CPU when no CUDA device is available so the notebook can
# still run (slowly) on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Input

In [None]:
def get_paths(path_template, site, years):
    """Expand ``path_template`` once per year.

    Args:
        path_template: ``str.format`` template with ``{site}`` and ``{year}``
            placeholders. A template may omit either one, because
            ``str.format`` ignores unused keyword arguments.
        site: value substituted for ``{site}``.
        years: iterable of values substituted for ``{year}``, in order.

    Returns:
        list[str]: one formatted path per element of ``years``.
    """
    return [path_template.format(site=site, year=yr) for yr in years]


# Load the training and test splits for the base site; each helper call gets
# the list of per-year file paths for that split.
x_train, y_train = (
    helper.input_x(get_paths(X_PATH_TEMPLATE, BASE_SITE, TRAIN_YEARS)),
    helper.input_y(get_paths(Y_PATH_TEMPLATE, BASE_SITE, TRAIN_YEARS)),
)
x_test, y_test = (
    helper.input_x(get_paths(X_PATH_TEMPLATE, BASE_SITE, TEST_YEARS)),
    helper.input_y(get_paths(Y_PATH_TEMPLATE, BASE_SITE, TEST_YEARS)),
)

# Normalization

In [None]:
# Scale both splits with the scaler stored next to the trained model instead
# of fitting a new one here.
# NOTE(review): ``load_from_pkl`` presumably unpickles the file, which can
# execute arbitrary code; only load scalers from trusted locations.
scaler = load_from_pkl(SCALER_PATH_TEMPLATE.format(site=BASE_SITE))
x_train, x_test = (
    helper.normalize_with_scaler(scaler, split) for split in (x_train, x_test)
)

# Sampling

In [None]:
# Draw a reproducible random subset of the test set for the analyses below.
# The draw is capped at the test-set size so sampling without replacement
# cannot fail on a small test split.
N_SAMPLES = 3000
SAMPLE_SEED = 0
sample_index = np.random.default_rng(SAMPLE_SEED).choice(
    x_test.shape[0], min(N_SAMPLES, x_test.shape[0]), replace=False
)
sample_x = x_test[sample_index]
sample_y = y_test[sample_index]
# shuffle=False matters: later cells align per-batch results with
# ``sample_y`` purely by position.
sample_dataloader = helper.make_data_loader(sample_x, sample_y, shuffle=False)

# Self-attention weight analysis

In [None]:
# Rebuild the architecture and restore the trained weights for the base site.
net = helper.build_model()
net.load_state_dict(
    load_from_pth(MODEL_PATH_TEMPLATE.format(site=BASE_SITE))
)
net = net.to(DEVICE)  # nn.Module.to returns the module itself; no cell output

In [None]:
# Filled by the forward hook below: one NumPy array per hooked self-attention
# call, i.e. per encoder layer for every batch pushed through the network.
attn_weights_list = []


def record_attn_weights(module, attn_in, attn_out):
    """Forward hook that stores the attention weights of a self-attention module.

    Args:
        module: the hooked module (unused).
        attn_in: the module's positional inputs (unused).
        attn_out: the module's output; element 1 is expected to hold the
            attention-weight tensor.

    The weights are detached, copied to the CPU and appended to
    ``attn_weights_list`` as a NumPy array.

    NOTE(review): stock ``nn.TransformerEncoderLayer`` calls ``self_attn`` with
    ``need_weights=False``, in which case element 1 is ``None``. This presumably
    relies on a custom encoder layer; confirm.
    """
    weights = attn_out[1]
    attn_weights_list.append(weights.detach().cpu().numpy())


# Attach the recording hook to the self-attention sub-module of every encoder
# layer. Keep the returned handles so the hooks can be removed afterwards.
handles = [
    layer.self_attn.register_forward_hook(record_attn_weights)
    for layer in net.encoder.layers
]

In [None]:
# Run the sampled data through the network; the hooks record attention
# weights along the way. The hooks are removed in ``finally`` so a failed run
# does not leave them attached. Otherwise re-registering them on a retry
# would record every layer more than once.
try:
    sample_y_soft_pred, sample_y_hard_pred = helper.predict(
        net, sample_dataloader, DEVICE
    )
finally:
    for handle in handles:
        handle.remove()

In [None]:
class_dict = {"Corn": 0, "Soybean": 1}

# Recording order (assuming the encoder runs its layers in order):
#   batch 0: layer 0 .. L-1, then batch 1: layer 0 .. L-1, and so on.
# Stacking the raw list therefore only works when all samples fit in one
# batch. Regroup by layer and concatenate over batches instead, so each layer's
# array covers every sample in ``sample_y`` order whatever the batch size is.
# NOTE(review): assumes each recording has the batch on axis 0, i.e.
# (batch, target_len, source_len); confirm against the encoder layers.
n_layers = len(net.encoder.layers)
assert attn_weights_list and len(attn_weights_list) % n_layers == 0, (
    "expected one attention recording per encoder layer per batch"
)
attn_weights_by_layer = [
    np.concatenate(attn_weights_list[layer_idx::n_layers], axis=0)
    for layer_idx in range(n_layers)
]

for class_name in ["Corn", "Soybean"]:
    class_value = class_dict[class_name]
    class_mask = sample_y == class_value
    # One heatmap per encoder layer; squeeze=False keeps ``axs`` 2-D even when
    # there is a single layer.
    fig, axs = plt.subplots(
        figsize=(6.5 * n_layers, 5), nrows=1, ncols=n_layers, squeeze=False,
        gridspec_kw={"wspace": 0.4, "hspace": 0.5}
    )
    for i, layer_weights in enumerate(attn_weights_by_layer):
        # Mean attention map over all samples of this class.
        attn_weights_mean = layer_weights[class_mask].mean(axis=0)
        ax = axs.reshape(-1)[i]
        plt.sca(ax)
        # Round the colour range outwards to 3 decimals for tidy colourbar ticks.
        vmin = np.floor(attn_weights_mean.min()*1000)/1000
        vmax = np.ceil(attn_weights_mean.max()*1000)/1000
        sns.heatmap(
            attn_weights_mean, vmin=vmin, vmax=vmax,
            cbar=False, cmap=sns.cubehelix_palette(light=0.95, as_cmap=True)
        )
        # Heatmap cell k spans [k, k + 1], so a tick at t - 0.5 labelled t
        # marks the centre of the t-th (1-based) cell.
        ticks = np.arange(4, attn_weights_mean.shape[0], 4)
        plt.xticks(ticks - 0.5, ticks)
        plt.yticks(ticks - 0.5, ticks)
        plt.ylabel("Weeks after April 15\n(output high-level features)")
        plt.xlabel("Weeks after April 15\n(input low-level features)")
        plt.title("Layer {:d}".format(i + 1))
        cbar = plt.colorbar(
            ax.collections[0],  # the QuadMesh drawn by sns.heatmap
            ticks=np.linspace(vmin, vmax, num=5),
            format=ticker.StrMethodFormatter("{x:.4f}"),
            extend="both"
        )
        cbar.ax.tick_params(labelsize=plt.rcParams["font.size"] - 4)
        cbar.outline.set_visible(False)
    # remove all tick markers on the axis of x, y and colorbar
    for ax in fig.axes:
        ax.tick_params(length=0)
    fig.suptitle(class_name)
    fig.subplots_adjust(top=0.8)

# Input feature importance analysis

In [None]:
# Rebuild the trained model for the input-gradient analysis.
net = helper.build_model()
net.load_state_dict(load_from_pth(MODEL_PATH_TEMPLATE.format(site=BASE_SITE)))
net = net.to(DEVICE)  # nn.Module.to returns the module itself
# Spread batches over all visible GPUs (restrict them with CUDA_VISIBLE_DEVICES)
# instead of hard-coding four device ids, which fails on machines with fewer
# GPUs. Without CUDA, DataParallel simply wraps the module.
net = nn.DataParallel(net)

# Only gradients with respect to the inputs are needed. Freeze all weights and
# switch to eval mode so train-only behaviour (e.g. dropout, if present) is off.
net.eval()
for p in net.parameters():
    p.requires_grad = False

In [None]:
# Saliency: the gradient of each sample's true-class output with respect to
# its (normalized) input. Backpropagating the batch sum yields per-sample
# gradients, assuming samples do not interact inside the model in eval mode.
sample_grad_list = []
for batch in sample_dataloader:
    xt_batch = batch["x"].to(DEVICE).requires_grad_()
    outputs = net(xt_batch)
    # Select each sample's true-class entry. Build the index on the outputs'
    # device as integer tensors rather than mixing NumPy and torch indices.
    rows = torch.arange(outputs.shape[0], device=outputs.device)
    targets = batch["y"].to(device=outputs.device, dtype=torch.long)
    outputs[rows, targets].sum().backward()
    sample_grad_list.append(xt_batch.grad.detach().cpu().numpy())
sample_grad_arr = np.concatenate(sample_grad_list, axis=0)

In [None]:
band_names = [
    "Blue", "Green", "Red", "Near-infrared",
    "Shortwave infrared 1", "Shortwave infrared 2",
]
# class name -> [label value in ``sample_y``, plot colour]
class_dict = {
    "Corn": [0, "blue"],
    "Soybean": [1, "red"],
}

# x-axis shared by every panel: 1-based week index.
n_weeks = sample_grad_arr.shape[1]
weeks = range(1, n_weeks + 1)
week_ticks = np.arange(1, n_weeks + 1, 3)

for class_name in ["Corn", "Soybean"]:
    class_value, class_color = class_dict[class_name]
    grad4class = sample_grad_arr[sample_y==class_value]
    # Per-week, per-band gradient statistics over this class's samples.
    mean4class, std4class = grad4class.mean(axis=0), grad4class.std(axis=0)
    fig, axs = plt.subplots(
        figsize=(18, 10.5), nrows=2, ncols=3,
        gridspec_kw={"wspace": 0.5, "hspace": 0.4}
    )
    # Envelope bounds. A single y-range (plus a 0.1 margin) is shared by all
    # bands so the six panels are directly comparable.
    lower, upper = mean4class - std4class, mean4class + std4class
    vmin = lower.min() - 0.1
    vmax = upper.max() + 0.1
    for band_idx, (ax, band_name) in enumerate(zip(axs.reshape(-1), band_names)):
        band_lower, band_upper = lower[:, band_idx], upper[:, band_idx]
        plt.sca(ax)
        # Zero reference line, class mean, then the mean +/- std band.
        plt.plot(weeks, [0] * n_weeks, "--", color="orange")
        plt.plot(weeks, mean4class[:, band_idx], color=class_color)
        plt.fill_between(
            weeks, band_lower, band_upper,
            facecolor=class_color, alpha=0.1
        )
        for edge in (band_lower, band_upper):
            plt.plot(weeks, edge, linewidth=1, color=class_color, alpha=0.25)
        plt.ylim([vmin, vmax])
        plt.xticks(week_ticks, week_ticks)
        plt.xlabel("Weeks after April 15")
        plt.ylabel("Derivative")
        plt.title(band_name)
    fig.suptitle(class_name)