In [1]:
import numpy as np
import pandas as pd
import pickle
import sys, os
import torch
from sklearn.metrics import cohen_kappa_score
import matplotlib.pyplot as plt


# Make the parent directory of the notebook importable, so the local
# `datasets` and `explainability_analysis` packages below can be found.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)
print(module_path)
from datasets import dataset_utils
from datasets import sequence_aggregator


from explainability_analysis.visualization_functions import * 
from explainability_analysis.transformer_analysis import *
from explainability_analysis.crop_spectral_signature_analysis import * 

num_classes = 12
shuffle_setting = "original_sequences"
# Seed NumPy's global RNG so the randomly sampled example parcels are reproducible
# across Restart & Run All (parcel sampling uses np.random.choice).
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)


def shuffle_sequences_fn(setting):
    """Return True if ``setting`` names the shuffled-sequences experiment."""
    return setting == "shuffled_sequences"


# NOTE: machine-specific absolute paths; adjust before running on another machine.
model_root_path = "C:/Users/results/{}_classes/{}/right_padding/obs_aq_date/layers=1,heads=1,emb_dim=128/".format(num_classes, shuffle_setting)
predictions_path = os.path.join(model_root_path, "predictions")
# Load the test-set predictions as an unsigned-int array, skipping the header row.
# Later cells use column 0 as the parcel id and column 1 as the integer class id.
predictions = np.loadtxt(os.path.join(predictions_path, "predicted_vs_true.csv"), skiprows=1, delimiter=",", dtype=np.uint)
attn_weights_path = os.path.join(predictions_path, "attn_weights", "postprocessed")

dataset_folder = "C:/Users/datasets/BavarianCrops/"
class_mapping = os.path.join(dataset_folder, "classmapping{}.csv".format(num_classes))
train_set, valid_set, test_set = dataset_utils.get_partitioned_dataset(
    dataset_folder,
    class_mapping,
    sequence_aggregator.SequencePadder(),
    None,
    shuffle_sequences=shuffle_sequences_fn(shuffle_setting))

# Load the per-parcel key/query embedding data required for the explainability analysis.
# pickle.load can execute arbitrary code, so only open files produced by this project.
keys_and_queries_data_path = os.path.join(attn_weights_path, "keys_and_queries.pickle")
with open(keys_and_queries_data_path, 'rb') as handle:
    parcels_keys_queries = pickle.load(handle)

spectral_indices = test_set.calculate_spectral_indices()
validation_spectral_indices = valid_set.calculate_spectral_indices()
# Layer-0 attention summed per time step, one column per parcel id (read via str(parcel_id) below).
total_temporal_attention = summarize_attention_weights_as_feature_embeddings(attn_weights_path, "layer_0", 0, summary_fn="sum")

C:\Users\Ivica Obadic\PycharmProjects\EOExplainability
Initializing BavarianCropsDataset train partition in holl with sequence shuffling = False
read 12 classes
precached dataset files found at C:/Users/datasets/BavarianCrops/npy\12\fallow,grassland,winter wheat,corn,summer wheat,winter spelt,winter rye,winter barley,summer barley,summer oat,winter triticale,rapeseed\blocks\holl\train
Dataset C:/Users/datasets/BavarianCrops/. region holl. partition train. Sequence shuffling = False X:20858x(71, 16), y:(20858,) with 12 classes
Initializing BavarianCropsDataset valid partition in holl with sequence shuffling = False
read 12 classes
precached dataset files found at C:/Users/datasets/BavarianCrops/npy\12\fallow,grassland,winter wheat,corn,summer wheat,winter spelt,winter rye,winter barley,summer barley,summer oat,winter triticale,rapeseed\blocks\holl\valid
Dataset C:/Users/datasets/BavarianCrops/. region holl. partition valid. Sequence shuffling = False X:3909x(144, 16), y:(3909,) with 12 

In [None]:
def get_data_for_parcel_id(parcel_id):
    """Collect the per-parcel data used by the attention visualizations.

    Reads the notebook globals ``attn_weights_path``, ``spectral_indices``,
    ``total_temporal_attention``, ``parcels_keys_queries``, ``predictions``
    and ``crop_types``; the last one must be defined before this is called.

    Args:
        parcel_id: integer parcel id. It is used as the key into the per-parcel
            dicts and matched against column 0 of ``predictions``.

    Returns:
        A 4-tuple ``(title, layer_0_attention, spectral_index, key_query_data)``:
        ``title`` is ``"<crop type> parcel <id>"``; ``spectral_index`` is a copy
        of the parcel's spectral indices with a ``TOTAL TEMPORAL ATTENTION``
        column added and the YEAR/MONTH/DATE/CLASS/NDVI columns dropped;
        ``key_query_data`` has display-friendly column names and the ``-1``
        label replaced by ``""``.

    Raises:
        IndexError: if ``parcel_id`` has no row in ``predictions``.
    """
    attn_weights_file = os.path.join(attn_weights_path, '{}_attn_weights_df.pickle'.format(parcel_id))
    # pickle.load can execute arbitrary code: only load files produced by this project.
    with open(attn_weights_file, 'rb') as handle:
        parcel_attn_weights = pickle.load(handle)

    parcel_spectral_index = spectral_indices[parcel_id].copy()
    parcel_spectral_index["TOTAL TEMPORAL ATTENTION"] = total_temporal_attention[str(parcel_id)].to_numpy().flatten()
    parcel_spectral_index = parcel_spectral_index.drop(["YEAR", "MONTH", "DATE", "CLASS", "NDVI"], axis=1)

    parcel_key_query_data = (
        parcels_keys_queries[parcel_id]
        .rename(columns={"TYPE": "Embedding", "emb_dim_1": "t-SNE dim. 1", "emb_dim_2": "t-SNE dim. 2"})
        # TODO: Remove the -1 label while creating the key and query embeddings
        .replace(-1, "")
    )

    matching_predictions = predictions[predictions[:, 0] == parcel_id]
    if matching_predictions.shape[0] == 0:
        # Same exception type as the bare [0, 1] lookup would raise, but with a useful message.
        raise IndexError("Parcel {} has no row in the test-set predictions".format(parcel_id))
    crop_type_class = matching_predictions[0, 1]
    crop_type_label = crop_types[crop_type_class]
    return ("{} parcel {}".format(crop_type_label, parcel_id),
            parcel_attn_weights["layer_0"],
            parcel_spectral_index,
            parcel_key_query_data)

def get_parcels_data(target_crop_types, rng=None):
    """Draw one random test-set parcel per requested crop type and load its data.

    Args:
        target_crop_types: iterable of crop-type labels, each present in the
            global ``crop_types``. A repeated label gives an independent draw,
            which may pick the same parcel again.
        rng: optional object with a NumPy-style ``choice`` method, e.g.
            ``np.random.default_rng(seed)``, for reproducible draws. When None,
            the global ``np.random`` state is used (the original behavior).

    Returns:
        A list with one ``get_data_for_parcel_id`` tuple per requested label,
        in input order.

    Raises:
        ValueError: if a label is not in ``crop_types``, or no parcel in
            ``predictions`` has that class in column 1.
    """
    sampler = np.random if rng is None else rng
    parcels_data = []

    for crop_type_label in target_crop_types:
        crop_type_id = crop_types.index(crop_type_label)
        relevant_crop_parcels = predictions[predictions[:, 1] == crop_type_id]
        if relevant_crop_parcels.shape[0] == 0:
            raise ValueError("No parcels of class '{}' found in the predictions".format(crop_type_label))

        random_row = sampler.choice(relevant_crop_parcels.shape[0], size=1, replace=False)[0]
        random_crop_parcel_id = int(relevant_crop_parcels[random_row, 0])
        parcels_data.append(get_data_for_parcel_id(random_crop_parcel_id))

    return parcels_data

# Class names, indexed by the integer class ids stored in `predictions`.
crop_types = test_set.classname.flatten().tolist()
# One random corn parcel plus two independent random grassland draws.
parcels_data = get_parcels_data(target_crop_types=["corn", "grassland", "grassland"])

In [2]:
# Aggregate the test-set spectral indices per time frame, keeping each parcel separate.
spectral_indices_per_week_and_parcel = calc_spectral_signature_per_time_frame(
    spectral_indices,
    agg_variable="PARCEL_ID",
)

Aggregating the spectral indices based on the time frame method


In [6]:
# Export each parcel's per-week NDVI to CSV for use outside this notebook.
# NOTE(review): if "NDVI" is a multi-level column (e.g. with a "mean" sub-column),
# every NDVI sub-column is written; confirm this is the intended export.
ndvi_export = spectral_indices_per_week_and_parcel[["PARCEL_ID", "WEEK", "NDVI"]]
ndvi_export.to_csv("average_ndvi_per_week_and_parcel.csv", index=False)

In [None]:
# Plot the attention weights of the randomly sampled parcels in `parcels_data`.
visualize_attention_weights(
    parcels_data,
    "Attention Weights vs Spectral Bands",
)

In [None]:
grassland_parcel_id = 72822622
grassland_parcel_data = get_data_for_parcel_id(grassland_parcel_id)
# Unpack (title, layer-0 attention weights, spectral index, key/query data);
# the title is not needed in the single-parcel figures below.
(
    _grassland_parcel_title,
    grassland_parcel_attn_weights,
    grassland_parcel_spectral_index,
    grassland_parcel_key_and_query_data,
) = grassland_parcel_data

In [None]:
# Paper-style figure settings: LaTeX-rendered serif text at 16 pt.
# The generic family goes in "font.family" and the concrete face in "font.serif".
# A second "font.family" key in the same dict would silently override the first.
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "text.usetex": True,     # render all text, including tick labels, with LaTeX
    "pgf.rcfonts": False,    # pgf backend: do not configure fonts from rcParams
    "font.size": 16,
    })
figure_results_path = "C:/Users/Ivica Obadic/paper_plots/attention_signature"
# Make sure the output folder exists before any savefig call.
os.makedirs(figure_results_path, exist_ok=True)
fig, ax_attn_weights = plt.subplots(figsize=(6, 6))
ax_attn_weights = attn_weights_heatmap(grassland_parcel_attn_weights, ax_attn_weights, "Oranges", "grassland", 11)
fig.tight_layout()
fig.savefig(os.path.join(figure_results_path, 'attn_weights_parcel_distribution.eps'), dpi=400)

In [None]:
fig, ax_key_query = plt.subplots(figsize=(6, 6))
grassland_parcel_temporal_key_attention = grassland_parcel_spectral_index["TOTAL TEMPORAL ATTENTION"].values

# Explicit copy of the KEY rows, so the columns added below are not written into a
# view of grassland_parcel_key_and_query_data (SettingWithCopyWarning).
key_grassland_parcel_data = grassland_parcel_key_and_query_data.loc[
    grassland_parcel_key_and_query_data["Embedding"] == "KEY"
].copy()
# assumes exactly one KEY row per time step, in time-step order (TODO confirm)
key_grassland_parcel_data["TEMPORAL_ATTN"] = grassland_parcel_temporal_key_attention

# Per-cluster mean of the numeric columns. numeric_only skips the string "Embedding"
# column, which recent pandas versions refuse to average.
cluster_avg_temp_attn = (
    key_grassland_parcel_data
    .groupby("CLUSTER")
    .mean(numeric_only=True)
    .sort_values(by="TEMPORAL_ATTN")
)
# Exactly two key clusters are expected. After the ascending sort, the first one has the
# lower mean temporal attention.
cluster_avg_temp_attn["Cluster Label"] = ["Non-attention cluster", "Attention cluster"]
print(cluster_avg_temp_attn)
key_grassland_parcel_data["AVG_CLUSTER_ATTN"] = key_grassland_parcel_data["CLUSTER"].map(
    cluster_avg_temp_attn["TEMPORAL_ATTN"])

# One marker size per plotted row, aligned positionally with grassland_parcel_key_and_query_data.
# Queries keep the base size; keys are scaled by their cluster's mean temporal attention.
# This does not rely on all query rows preceding all key rows.
BASE_MARKER_SIZE = 120
is_key_row = (grassland_parcel_key_and_query_data["Embedding"] == "KEY").to_numpy()
marker_sizes = np.full(len(grassland_parcel_key_and_query_data), float(BASE_MARKER_SIZE))
marker_sizes[is_key_row] = key_grassland_parcel_data["AVG_CLUSTER_ATTN"].to_numpy(dtype=float) * BASE_MARKER_SIZE
query_markers_sizes = marker_sizes.tolist()

color_mapping = {"": "b", 0: "tab:orange", 1: "m"}

# Display-friendly copy for the plot. The source frame is left unchanged: replacing "KEY"
# in place made the KEY filter above match nothing when the cell was re-run.
key_query_plot_data = (
    grassland_parcel_key_and_query_data
    .rename(columns={"CLUSTER": "Cluster"})
    .replace({"QUERY": "Query", "KEY": "Key"})
)

# NOTE(review): `sns` is not imported explicitly in this notebook; it presumably comes
# from one of the star imports. Confirm.
ax_key_query = sns.scatterplot(data=key_query_plot_data,
                               x="t-SNE dim. 1",
                               y="t-SNE dim. 2",
                               style="Embedding",
                               hue="Cluster",
                               palette=color_mapping,
                               s=query_markers_sizes,
                               ax=ax_key_query)
fig.tight_layout()
fig.savefig(os.path.join(figure_results_path, 'key_query_embeddings.eps'), dpi=400)

In [None]:
healthy_vegetation_curve = get_average_spectral_reflectance_curve(validation_spectral_indices)
fig, axs = plt.subplots(figsize=(6, 6))

# Attach each time step's key cluster id to the parcel's spectral data.
# Both frames are given the same index, so the assignment aligns row by row.
key_grassland_parcel_data = key_grassland_parcel_data.set_index(grassland_parcel_spectral_index.index)
grassland_parcel_spectral_index[["CLUSTER"]] = key_grassland_parcel_data[["CLUSTER"]]
grassland_parcel_spectral_index_with_cluster_label = pd.merge(grassland_parcel_spectral_index, cluster_avg_temp_attn, on="CLUSTER").drop(["CLUSTER"], axis=1)
print(grassland_parcel_spectral_index_with_cluster_label.head())

spectral_bands_plot_data = pd.melt(grassland_parcel_spectral_index_with_cluster_label, id_vars="Cluster Label", var_name="BAND", value_name="SPECTRAL REFLECTANCE")
# Take only the band -> wavelength mapping from the reference curve. Merging its reflectance
# column as well would collide with ours and force the "SPECTRAL REFLECTANCE_x" suffix.
# The inner merge still keeps only bands known to the curve.
band_wavelengths = healthy_vegetation_curve[["BAND", "Wavelength (nm)"]].drop_duplicates()
spectral_bands_plot_data = pd.merge(spectral_bands_plot_data, band_wavelengths, on="BAND")

color_mapping = {"Attention cluster": "tab:orange", "Non-attention cluster": "m"}

axs = sns.pointplot(
    data=spectral_bands_plot_data,
    x="Wavelength (nm)",
    y="SPECTRAL REFLECTANCE",
    hue="Cluster Label",
    palette=color_mapping,
    ci="sd",
    ax=axs)

# Display names for the legend and axis label; later cells rely on these renamed columns.
healthy_vegetation_curve = healthy_vegetation_curve.rename(
    columns={'SPECTRAL REFLECTANCE': 'Spectral reflectance', "SPECTRAL SIGNATURE": "Spectral signature"})
axs = sns.pointplot(
    data=healthy_vegetation_curve,
    x="Wavelength (nm)",
    y='Spectral reflectance',
    hue="Spectral signature",
    palette="Greens_r",
    linestyles='--',
    ci="sd",
    ax=axs)
axs.set(ylim=(0, 1))
n = 5  # keep every 5th x tick label so the wavelengths stay readable
for tick_idx, tick_label in enumerate(axs.xaxis.get_ticklabels()):
    if tick_idx % n != 0:
        tick_label.set_visible(False)
fig.tight_layout()
fig.savefig(os.path.join(figure_results_path, 'attention_spectral_reflectance_vs_healthy_vegetation.eps'), dpi=400)

In [None]:
def calc_cluster_class_signature(spectral_indices, parcels_key_queries, total_temporal_attention):
    """Label every time step of every parcel with its key-cluster attention level.

    For each parcel, time steps are grouped by the cluster id of their KEY
    embedding. The cluster with the lower mean total temporal attention is
    labelled ``"NO_ATTN"`` and the other ``"HIGH_ATTN"``.

    Args:
        spectral_indices: dict mapping parcel id -> DataFrame with one row per
            time step, containing ``YEAR``, ``MONTH``, ``DATE``, ``CLASS``,
            ``NDVI`` and the remaining (presumably spectral band) columns.
        parcels_key_queries: dict mapping parcel id -> DataFrame with ``TYPE``
            and ``CLUSTER`` columns. Its ``TYPE == "KEY"`` rows must correspond,
            in order, to the parcel's time steps.
        total_temporal_attention: DataFrame with one column per parcel, named
            ``str(parcel_id)``, holding one attention value per time step.

    Returns:
        One DataFrame stacking all parcels, with the kept columns plus
        ``TOTAL TEMPORAL ATTENTION``, ``KEY_CLUSTER``, ``CLUSTER_LABEL`` and
        ``CLASS``. Returns ``None`` when ``spectral_indices`` is empty.

    Raises:
        ValueError: if a parcel's KEY rows do not match its number of time
            steps, or the parcel does not have exactly two key clusters.
    """
    parcel_frames = []
    for idx, parcel_id in enumerate(spectral_indices.keys()):
        if idx % 2000 == 0:
            print("Calculation for parcel idx {}".format(idx))
        parcel_spectral_index = spectral_indices[parcel_id]
        parcel_class = parcel_spectral_index["CLASS"].values[0]
        # drop() returns a new frame, so the caller's DataFrames are never modified.
        parcel_spectral_index = parcel_spectral_index.drop(["YEAR", "MONTH", "DATE", "CLASS", "NDVI"], axis=1)
        parcel_spectral_index["TOTAL TEMPORAL ATTENTION"] = total_temporal_attention[str(parcel_id)].to_numpy().flatten()

        parcel_key_query_data = parcels_key_queries[parcel_id]
        parcel_key_data = parcel_key_query_data.loc[parcel_key_query_data["TYPE"] == "KEY"]
        # Positional alignment: the i-th KEY row belongs to the i-th time step.
        key_clusters = parcel_key_data["CLUSTER"].to_numpy()
        if len(key_clusters) != len(parcel_spectral_index):
            raise ValueError("Parcel {}: {} KEY rows for {} time steps".format(
                parcel_id, len(key_clusters), len(parcel_spectral_index)))
        parcel_spectral_index["KEY_CLUSTER"] = key_clusters

        cluster_mean_attention = (
            parcel_spectral_index.groupby("KEY_CLUSTER")["TOTAL TEMPORAL ATTENTION"].mean().sort_values()
        )
        if len(cluster_mean_attention) != 2:
            raise ValueError("Parcel {} has {} key clusters; expected exactly 2".format(
                parcel_id, len(cluster_mean_attention)))
        # Ascending sort: the first cluster has the lower mean attention.
        cluster_labels = pd.Series(["NO_ATTN", "HIGH_ATTN"], index=cluster_mean_attention.index)
        parcel_spectral_index["CLUSTER_LABEL"] = parcel_spectral_index["KEY_CLUSTER"].map(cluster_labels)
        parcel_spectral_index["CLASS"] = parcel_class

        parcel_frames.append(parcel_spectral_index)

    if not parcel_frames:
        return None
    # A single concat; concatenating inside the loop was quadratic in the number of parcels.
    return pd.concat(parcel_frames)


# Stack all test parcels, each time step tagged with its attention-cluster label and crop class.
cluster_class_signature = calc_cluster_class_signature(
    spectral_indices=spectral_indices,
    parcels_key_queries=parcels_keys_queries,
    total_temporal_attention=total_temporal_attention,
)
cluster_class_signature.head()

In [None]:
# Mean total temporal attention per crop class, split by attention-cluster label
# (error bars show one standard deviation).
fig, axs = plt.subplots(figsize=(13, 8))
sns.barplot(
    data=cluster_class_signature,
    x="CLASS",
    y="TOTAL TEMPORAL ATTENTION",
    hue="CLUSTER_LABEL",
    ci="sd",
    ax=axs,
)
fig.suptitle("Crop-Type Difference in Attention Weights", fontsize=16)
fig.tight_layout()

In [None]:

# High-attention spectral footprint of selected crop types vs. the reference vegetation curve.
cluster_class_signature_corn_grassland = cluster_class_signature.loc[
    cluster_class_signature["CLASS"].isin(["corn", "grassland", "summer barley", "winter wheat"])
    & (cluster_class_signature["CLUSTER_LABEL"] == "HIGH_ATTN")
].drop(["TOTAL TEMPORAL ATTENTION", "KEY_CLUSTER"], axis=1)

# Long format: one row per (original row, band column) pair.
cluster_class_signature_corn_grassland = pd.melt(
    cluster_class_signature_corn_grassland,
    id_vars=["CLUSTER_LABEL", "CLASS"],
    var_name="BAND",
    value_name="SPECTRAL REFLECTANCE")
# Take only the band -> wavelength mapping from the reference curve. The merge then never
# collides with the curve's reflectance column, whether or not an earlier cell renamed it.
band_wavelengths = healthy_vegetation_curve[["BAND", "Wavelength (nm)"]].drop_duplicates()
cluster_class_signature_corn_grassland = pd.merge(cluster_class_signature_corn_grassland, band_wavelengths, on="BAND")
print(cluster_class_signature_corn_grassland)

# Display names for the reference curve. rename() ignores labels that are already renamed.
healthy_curve_plot_data = healthy_vegetation_curve.rename(
    columns={"SPECTRAL REFLECTANCE": "Spectral reflectance", "SPECTRAL SIGNATURE": "Spectral signature"})

fig, axs = plt.subplots(figsize=(6, 6))
# NOTE(review): crop_type_color_mapping is not defined in this notebook; presumably it comes
# from one of the star imports. Confirm.
axs = sns.pointplot(data=cluster_class_signature_corn_grassland,
                    x="Wavelength (nm)",
                    palette=crop_type_color_mapping,
                    y="SPECTRAL REFLECTANCE",
                    hue="CLASS",
                    ci="sd",
                    ax=axs)
axs = sns.pointplot(data=healthy_curve_plot_data,
                    x="Wavelength (nm)",
                    y='Spectral reflectance',
                    hue="Spectral signature",
                    palette="Greens_r",
                    linestyles='--',
                    ci="sd",
                    ax=axs)
axs.set(ylim=(0, 1))
n = 5  # keep every 5th x tick label so the wavelengths stay readable
for tick_idx, tick_label in enumerate(axs.xaxis.get_ticklabels()):
    if tick_idx % n != 0:
        tick_label.set_visible(False)
fig.tight_layout()
fig.savefig(os.path.join(figure_results_path, 'average_crop_type_footprint.eps'), dpi=400)

In [None]:
# Non-attention spectral footprint of selected crop types vs. the reference vegetation curve.
cluster_class_signature_corn_grassland = cluster_class_signature.loc[
    cluster_class_signature["CLASS"].isin(["corn", "grassland", "summer barley", "winter wheat", "winter barley"])
    & (cluster_class_signature["CLUSTER_LABEL"] == "NO_ATTN")
].drop(["TOTAL TEMPORAL ATTENTION", "KEY_CLUSTER"], axis=1)

# Long format: one row per (original row, band column) pair.
cluster_class_signature_corn_grassland = pd.melt(
    cluster_class_signature_corn_grassland,
    id_vars=["CLUSTER_LABEL", "CLASS"],
    var_name="BAND",
    value_name="SPECTRAL REFLECTANCE")
# Take only the band -> wavelength mapping, so the parcel reflectance keeps its plain
# "SPECTRAL REFLECTANCE" name (no "_x" suffix from a column collision).
band_wavelengths = healthy_vegetation_curve[["BAND", "Wavelength (nm)"]].drop_duplicates()
cluster_class_signature_corn_grassland = pd.merge(cluster_class_signature_corn_grassland, band_wavelengths, on="BAND")
print(cluster_class_signature_corn_grassland)

# Display names for the reference curve. rename() ignores labels that are already renamed,
# so this works whether or not an earlier cell renamed the curve.
healthy_curve_plot_data = healthy_vegetation_curve.rename(
    columns={"SPECTRAL REFLECTANCE": "Spectral reflectance", "SPECTRAL SIGNATURE": "Spectral signature"})

fig, axs = plt.subplots(figsize=(10, 6))
axs = sns.pointplot(
    data=cluster_class_signature_corn_grassland,
    x="Wavelength (nm)",
    palette="plasma",
    y="SPECTRAL REFLECTANCE",
    hue="CLASS",
    ci="sd",
    ax=axs)
axs = sns.pointplot(
    data=healthy_curve_plot_data,
    x="Wavelength (nm)",
    y="Spectral reflectance",
    hue="Spectral signature",
    palette="Greens_r",
    linestyles='--',
    ci="sd",
    ax=axs)
axs.set_title("Crop-Type Non-Attention Footprints", fontsize=12)
n = 5  # keep every 5th x tick label so the wavelengths stay readable
for tick_idx, tick_label in enumerate(axs.xaxis.get_ticklabels()):
    if tick_idx % n != 0:
        tick_label.set_visible(False)
fig.tight_layout()