In [1]:
import torch
import shap
from torch import nn, triangular_solve
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import lightning.pytorch as pl
import os
from os import path
from tqdm import tqdm
import math
from pathlib import Path
from collections import defaultdict
from root import ROOT_DIR
from utils import data_loader_precip, dataset_precip, data_loader_precip, dataset_hybrid
from utils.model_classes import get_model_class
from models import unet_precip_regression_lightning as unet_regr
import glob
NORM = 47.83


In [2]:
class ModelWrapper(nn.Module):
    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.split_sizes = [
        12 * 22 * 8,  # Size of tensor1 after flattening
        12 * 64 * 64,  # Size of tensor2 after flattening
        64 * 64,              # Size of tensor3 after flattening
    ]

    def forward(self, input):
        input = torch.tensor(input).to('cuda')
        print(input.shape)
        b = input.shape[0]
        flat_nodes, flat_images, flat_target = torch.split(input, self.split_sizes, dim=1)
        print(flat_nodes.size)
        print(flat_images.size)
        print(flat_target.size)

        input_nodes = flat_nodes.reshape((b, 12, 22, 8))
        input_images = flat_images.reshape((b, 12, 64, 64))
        y_true = flat_target.reshape((b, 64, 64))
        # Forward pass through the model to get predictions
        y_pred = self.model(input_images, input_nodes)

        # Compute the loss: MSE loss with per-sample losses
        loss = nn.functional.mse_loss(y_pred.squeeze(), y_true.squeeze(), reduction="none")
        mean_loss = loss.mean(dim=1)
        return mean_loss.cpu().detach().numpy()

In [3]:
dataset = dataset_hybrid.precipitation_maps_h5_nodes(
    in_file= ROOT_DIR / "data" / "precipitation" / "Node_2016-2019.h5",
    num_input_images=12,
    num_output_images=6, 
    train=False,
    return_timestamp = False)

indices = np.random.choice(len(dataset), 100, replace=False)
subset = Subset(dataset, indices)

input_images = []
input_nodes = []
target_images = []

for i in subset:
    input_img, input_2, target_img, target_2 = i
    input_images.append(torch.tensor(input_img, dtype=torch.float32).to('cuda'))
    input_nodes.append(torch.tensor(input_2).to('cuda'))
    target_images.append(torch.tensor(target_img, dtype=torch.float32).to('cuda'))

input_images = torch.stack(input_images)
input_nodes = torch.stack(input_nodes)
target_images = torch.stack(target_images)
print(input_images.shape)

torch.Size([100, 12, 64, 64])


In [4]:
model_name = "Bridge"
model_folder = ROOT_DIR / "comparison" / model_name
model, model_name = get_model_class(model_name)
models = [m for m in os.listdir(model_folder) if "small" in m]
model_file = models[0]
model = model.load_from_checkpoint(f"{model_folder}/{model_file}")

model = ModelWrapper(model)
model.eval()
#FLATTEN INPUT
flat_nodes = input_nodes.view(input_nodes.size(0), -1)  # Flatten from the 2nd dimension onward
flat_images = input_images.view(input_images.size(0), -1)
flat_target = target_images.view(target_images.size(0), -1)

flattened = torch.cat((flat_nodes, flat_images, flat_target), axis=1)[:5]

# Create the masker for input_nodes
explainer = shap.KernelExplainer(model, flattened[:2].cpu().numpy())

torch.Size([2, 55360])
<built-in method size of Tensor object at 0x7f473232ff60>
<built-in method size of Tensor object at 0x7f473232fc40>
<built-in method size of Tensor object at 0x7f473232fce0>


In [None]:
print(input_nodes.shape)
shap_values = explainer.shap_values(flattened.cpu().numpy()[3:4])
print("SHAP values shape:", shap_values.shape)
node_shap_values = shap_values[:,:2112]
shap.summary_plot(node_shap_values)

torch.Size([100, 12, 22, 8])


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 55360])
<built-in method size of Tensor object at 0x7f47322a6c50>
<built-in method size of Tensor object at 0x7f47336bbb50>
<built-in method size of Tensor object at 0x7f4733983b50>


In [None]:
print(input_nodes.shape)
print(input_images.shape)
print(target_images.shape)