In [1]:
import os
from pathlib import Path

# Remember the directory the kernel started in exactly once, so that
# re-running this cell does not climb another three levels up every time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path(os.getcwd())

# Make the directory three levels above the notebook the working directory so
# that the project-relative imports in the following cells resolve.
os.chdir(NOTEBOOK_DIR.parents[2])

# Regression target used throughout the notebook.
target = "pce" # mth, pce

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
import xgboost as xgb
from torchmetrics import MeanAbsoluteError, MeanSquaredError
from data.augmentations.perov_1d import normalize as normalize_1d
from argparse import ArgumentParser
from os.path import join

import dice_ml
from dice_ml import Dice

from alibi.explainers import CounterFactual

# Location of the preprocessed dataset.
data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"

# Checkpoint directory for the selected target ("pce" vs. anything else).
checkpoint_dir = (
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"
    if target == "pce"
    else "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"
)

# Full path of the 1D network checkpoint that is loaded further below.
path_to_checkpoint = join(
    checkpoint_dir,
    "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt"
    if target == "pce"
    else "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt",
)


  from .autonotebook import tqdm as notebook_tqdm


#### Import of model and computation of counterfactuals (higher and lower than the response confidence interval)

In [3]:
from data.augmentations.perov_1d import Normalize1D


# Per-fold validation metrics of the boosted-tree baseline.
val_mse_folds = []
val_mae_folds = []

for fold_nb in range(5):
    print("Fold: ", fold_nb)

    # Arguments shared by every dataset that is built for this fold.
    fold_kwargs = dict(
        scaler=None,
        fold=fold_nb,
        no_border=False,
        return_unscaled=target != "pce",
        label="PCE_mean" if target == "pce" else "meanThickness",
    )

    # Normalisation statistics of this fold.
    train_mean, train_std = PerovskiteDataset1d(
        data_dir, transform=None, **fold_kwargs
    ).get_stats()

    trainset = PerovskiteDataset1d(
        data_dir,
        transform=Normalize1D(train_mean, train_std),
        split="train",
        val=False,
        **fold_kwargs,
    )
    trainloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False)

    # NOTE(review): the validation set uses `normalize_1d` while the training
    # set uses `Normalize1D` -- confirm that both apply the same transform.
    valset = PerovskiteDataset1d(
        data_dir=data_dir,
        transform=normalize_1d(train_mean, train_std),
        split="train",
        val=True,
        **fold_kwargs,
    )
    valloader = DataLoader(valset, batch_size=len(valset), shuffle=False)

    # The batch size equals the dataset size, so the loader yields one batch.
    timeseries, label = next(iter(trainloader))

    # Flatten every sample to one feature vector of length 4 * 719.
    time = timeseries.numpy().reshape(-1, 4 * 719)  # [:,3,:]

    xgbr = xgb.XGBRegressor(
        n_estimators=100,
        tree_method="gpu_hist",
        n_jobs=20,
        max_depth=6,
        learning_rate=0.3 if target == "pce" else 0.15,
        booster="gbtree",
        num_parallel_tree=1,
        objective="reg:squarederror",
    )
    xgbr.fit(time, label.numpy())

    # Score the fitted model on the held-out part of the fold (again one batch).
    timeseries, label = next(iter(valloader))

    time = timeseries.numpy().reshape(-1, 4 * 719)  # [:,3,:]
    scores = xgbr.predict(time)

    mse = MeanSquaredError()
    mae = MeanAbsoluteError()

    val_mse = mse(torch.from_numpy(scores), label)
    print("MSE: ", val_mse)
    val_mse_folds.append(val_mse)

    val_mae = mae(torch.from_numpy(scores), label)
    print("MAE:", val_mae)
    val_mae_folds.append(val_mae)


print("Val Mean MSE: ", np.mean(val_mse_folds))
print("Val Mean MAE: ", np.mean(val_mae_folds))



Fold:  0
MSE:  tensor(1.0629)
MAE: tensor(0.7347)
Fold:  1
MSE:  tensor(0.4804)
MAE: tensor(0.4906)
Fold:  2
MSE:  tensor(0.6694)
MAE: tensor(0.5768)
Fold:  3
MSE:  tensor(0.6397)
MAE: tensor(0.5380)
Fold:  4
MSE:  tensor(0.6528)
MAE: tensor(0.5759)
Val Mean MSE:  0.7010284
Val Mean MAE:  0.5831931


In [4]:
# Fit the boosted-tree model that is explained below on the dataset built
# without any fold / split arguments.
# NOTE(review): `train_mean` / `train_std` are the values left behind by the
# last iteration of the cross-validation loop -- confirm this is intended.
trainset_full = PerovskiteDataset1d(
    data_dir,
    transform=Normalize1D(train_mean, train_std),
    scaler=None,
    no_border=False,
    return_unscaled=target != "pce",
    label="PCE_mean" if target == "pce" else "meanThickness",
)

# One batch containing the complete dataset.
trainloader_full = DataLoader(
    trainset_full, batch_size=len(trainset_full), shuffle=False
)

xgbr = xgb.XGBRegressor(
    n_estimators=100,
    tree_method="gpu_hist",
    n_jobs=20,
    max_depth=6,
    learning_rate=0.3 if target == "pce" else 0.15,
    booster="gbtree",
    num_parallel_tree=1,
    objective="reg:squarederror",
)

timeseries, label = next(iter(trainloader_full))

# Flatten every sample to one feature vector of length 4 * 719.
time = timeseries.numpy().reshape(-1, 4 * 719)  # [:,3,:]

xgbr.fit(time, label.numpy())

# Sanity check on `valloader`, i.e. the loader left over from the last fold of
# the cross-validation loop.
# NOTE(review): these samples are presumably part of `trainset_full`, so the
# scores are not an estimate of generalisation -- confirm. For that reason
# they are only printed and no longer appended to the per-fold metric lists.
# After the loop `timeseries` / `label` hold this validation batch; the next
# cell builds the DiCE dataframe from them.
for timeseries, label in valloader:

    time = timeseries.numpy().reshape(-1, 4 * 719)  # [:,3,:]
    scores = xgbr.predict(time)

    mse = MeanSquaredError()
    mae = MeanAbsoluteError()

    val_mse = mse(torch.from_numpy(scores), label)
    print("MSE: ", val_mse)

    val_mae = mae(torch.from_numpy(scores), label)
    print("MAE:", val_mae)


MSE:  tensor(0.0056)
MAE: tensor(0.0508)


In [5]:
import pandas as pd

# One row per sample: the 4 * 719 flattened signal values plus the target.
# NOTE(review): `label` is rebound from a tensor to a DataFrame here, so this
# cell cannot be re-run without re-running the cell that produced `label`.
data = pd.DataFrame(timeseries.numpy().reshape(-1, 4 * 719))

label = pd.DataFrame(label.numpy())
data["target"] = label

# Every column except the trailing "target" column is a model feature.
feature_names = data.columns[:-1].tolist()

# DiCE description of the dataset; all features are treated as continuous.
data_dice = dice_ml.Data(
    dataframe=data,
    continuous_features=feature_names,
    outcome_name="target",
)


In [6]:
import dice_ml
from dice_ml.utils import helpers

# Counterfactual search strategy handed to DiCE.
cf_methode = "genetic"

# Expose the fitted XGBoost regressor to DiCE through its sklearn backend.
model_dice = dice_ml.Model(
    model=xgbr,
    backend="sklearn",
    model_type="regressor",
)

# Explainer used to generate the counterfactuals in the following cells.
methode = Dice(data_dice, model_dice, method=cf_methode)


In [7]:
# Row of `data` that is explained.
n = 0

# Features of the explained sample as a one-row DataFrame.
query_instances = pd.DataFrame(data.iloc[n][0:-1]).T
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[0.9, 100] if target == "pce" else [1100, 5000],
)

# Counterfactual as an array: feature columns first, predicted target last.
cf_high_table = np.array(genetic.cf_examples_list[0].final_cfs_df)

# Bring the counterfactual and the original sample back to a (4, 719) layout.
cf_high = cf_high_table[:, 0:-1].reshape(4, 719)
x = np.array(data.iloc[n][0:-1]).reshape(4, 719)

# Target of the sample and of the counterfactual as column vectors.
y = data.iloc[n]["target"].reshape([-1, 1])
y_cf_high = cf_high_table[:, -1].reshape([-1, 1])

if target == "pce":
    # The PCE values are mapped back through the dataset's fitted scaler.
    scaler = PerovskiteDataset1d(
        data_dir=data_dir,
        transform=normalize_1d(train_mean, train_std),
        fold=None,
        split="train",
        no_border=False,
        return_unscaled=False,
        label="PCE_mean",
        val=False,
    ).get_fitted_scaler()

    y = scaler.inverse_transform(y)
    y_cf_high = scaler.inverse_transform(y_cf_high)

y = np.round(y, 2)
y_cf_high = np.round(y_cf_high, 2)



  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:07<00:00,  7.40s/it]100%|██████████| 1/1 [00:07<00:00,  7.40s/it]


In [8]:
# Counterfactual for the same query sample, this time towards a low target.
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[-100, -0.9] if target == "pce" else [0, 600],
)

# Counterfactual as an array: feature columns first, predicted target last.
cf_low_table = np.array(genetic.cf_examples_list[0].final_cfs_df)
cf_low = cf_low_table[:, 0:-1].reshape(4, 719)

y_cf_low = cf_low_table[:, -1].reshape([-1, 1])
if target == "pce":
    # Uses the scaler fetched in the cell for the high counterfactual.
    y_cf_low = scaler.inverse_transform(y_cf_low)
y_cf_low = np.round(y_cf_low, 2)


  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:07<00:00,  7.14s/it]100%|██████████| 1/1 [00:07<00:00,  7.14s/it]


#### Plot

In [9]:
import plotly.graph_objects as go


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML title string for a plotly figure or subplot.

    Args:
        title: Text that is rendered in bold.
        subtitle: Optional text placed on a new line below the title.
        subtitle_font_size: Font size of the subtitle in pixels.

    Returns:
        The bold title alone when ``subtitle`` is falsy, otherwise the bold
        title, a line break, and the subtitle wrapped in a sized ``<span>``.
    """
    heading = f"<b>{title}</b>"
    if subtitle:
        sub = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub}"
    return heading


from plotly.subplots import make_subplots

# Channel label and colour of the counterfactual trace, in subplot-column order.
channels = (
    ("ND", "#042940"),
    ("LP725", "#005C53"),
    ("LP780", "#9FC131"),
    ("SP775", "#DBF227"),
)

# Row 1 shows the high / thicker counterfactual, row 2 the low / thinner one.
fig = make_subplots(
    rows=2,
    cols=4,
    subplot_titles=(
        format_title("High" if target == "pce" else "Thicker", "ND"),
        format_title("", "LP725"),
        format_title("", "LP780"),
        format_title("", "SP775"),
        format_title("Low" if target == "pce" else "Thinner"),
        None,
        None,
        None,
    ),
)

# For every subplot draw the original signal in grey first and the
# counterfactual in the channel colour on top of it.
for row, counterfactual in enumerate((cf_high, cf_low), start=1):
    for col, (channel, colour) in enumerate(channels, start=1):
        fig.add_trace(
            go.Scatter(
                y=x[col - 1],
                name=channel,
                marker_color="grey",
                opacity=0.3,
                showlegend=False,
            ),
            row=row,
            col=col,
        )
        fig.add_trace(
            go.Scatter(
                y=counterfactual[col - 1],
                # The SP775 trace used to be named "cf NSP775D" by mistake.
                name=f"cf {channel}",
                marker_color=colour,
                showlegend=False,
            ),
            row=row,
            col=col,
        )

# Clear all axis titles, then label only the first column.
fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None)
fig.update_xaxes(title=None)

fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title="Intensity", row=2, col=1)
fig.update_xaxes(title="Timesteps", row=2, col=1)

# y, y_cf_high and y_cf_low each hold a single value in their first row.
true_value, cf_high_value, cf_low_value = (
    str(*values[0]) for values in (y, y_cf_high, y_cf_low)
)

if target == "pce":
    subtitle = (
        f"Counterfactual Explanation ({cf_methode}) / True PCE: {true_value}"
        f" / CF PCE High: {cf_high_value}"
        f" / CF PCE Low: {cf_low_value}"
    )
else:
    subtitle = (
        f"Counterfactual Explanation ({cf_methode})"
        f" / True mean Thickness (mTH): {true_value}[nm]"
        f" / CF mTH Thick: {cf_high_value}"
        f" / CF mTH Thin: {cf_low_value}"
    )

title = format_title("Perovskite 1D Boosted Regression Tree Model", subtitle)

fig.update_layout(
    title=title,
    legend_title=None,
    title_y=0.95,
    title_x=0.035,
    template="plotly_white",
    height=400,
    width=1200,
)

#fig.write_image("xai/images/"+ target + "/1D/1D_cf.png", scale=2)

fig.show()


## PyTorch 1D Model

In [10]:
#### 1D Model

# Hyper-parameters handed to the model when it is restored from the checkpoint.
hypparams = dict(
    dataset="Perov_1d",
    dims=1,
    bottleneck=False,
    name="ResNet152",
    data_dir=data_dir,
    no_border=False,
    resnet_dropout=0.0,
    stochastic_depth=0.0,
    norm_target=target == "pce",
    target="PCE_mean" if target == "pce" else "meanThickness",
)

# Restore the trained 1D network from `path_to_checkpoint`.
model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams,
)

# Switch to evaluation mode before the model is used for predictions.
model.eval()
print("Loaded")


tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


In [11]:
class Wrapper(nn.Module):
    """Adapter that lets a torch model be called with a pandas DataFrame.

    The wrapped model is fed tensors of shape ``(-1, 4, 719)``; the wrapper
    converts a DataFrame with ``4 * 719`` columns per row into that layout and
    returns the predictions as numpy arrays.
    """

    def __init__(self, model):
        """Store the torch module that produces the predictions."""
        super().__init__()
        self.model = model

    def _to_input(self, image):
        """Convert a DataFrame of flattened samples to a (-1, 4, 719) tensor."""
        x = torch.Tensor(image.values)
        return x.reshape(-1, 4, 719)

    def forward(self, image):
        """Return the raw model output for ``image`` as a numpy array.

        The previous implementation reshaped the DataFrame instead of the
        tensor built from it, which raised an AttributeError.
        """
        # No gradients are needed for inference.
        with torch.no_grad():
            y_pred = self.model(self._to_input(image))

        return y_pred.detach().numpy()

    def predict(self, image):
        """Return the model output for ``image`` with dimension 1 squeezed."""
        with torch.no_grad():
            y_pred = self.model(self._to_input(image))
        y_pred = y_pred.detach().squeeze(1).numpy()

        return y_pred


In [12]:
import dice_ml
from dice_ml.utils import helpers

# Counterfactual search strategy handed to DiCE.
cf_methode = "genetic"

# Wrap the network so that it accepts DataFrames and exposes `predict`.
wModel = Wrapper(model)

# Expose the wrapped network to DiCE through its sklearn backend.
model_dice = dice_ml.Model(
    model=wModel,
    backend="sklearn",
    model_type="regressor",
)

# Explainer used to generate the counterfactuals in the following cells.
methode = Dice(data_dice, model_dice, method=cf_methode)


In [13]:
# Row of `data` that is explained.
n = 0

# Features of the explained sample as a one-row DataFrame.
query_instances = pd.DataFrame(data.iloc[n][0:-1]).T
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[0.9, 100] if target == "pce" else [1100, 5000],
)

# Counterfactual as an array: feature columns first, predicted target last.
cf_high_table = np.array(genetic.cf_examples_list[0].final_cfs_df)

# Bring the counterfactual and the original sample back to a (4, 719) layout.
cf_high = cf_high_table[:, 0:-1].reshape(4, 719)
x = np.array(data.iloc[n][0:-1]).reshape(4, 719)

# Target of the sample and of the counterfactual as column vectors.
y = data.iloc[n]["target"].reshape([-1, 1])
y_cf_high = cf_high_table[:, -1].reshape([-1, 1])

if target == "pce":
    # The PCE values are mapped back through the dataset's fitted scaler.
    scaler = PerovskiteDataset1d(
        data_dir=data_dir,
        transform=normalize_1d(train_mean, train_std),
        fold=None,
        split="train",
        no_border=False,
        return_unscaled=False,
        label="PCE_mean",
        val=False,
    ).get_fitted_scaler()

    y = scaler.inverse_transform(y)
    y_cf_high = scaler.inverse_transform(y_cf_high)

y = np.round(y, 2)
y_cf_high = np.round(y_cf_high, 2)


  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:11<00:00, 11.60s/it]100%|██████████| 1/1 [00:11<00:00, 11.60s/it]


In [14]:
# Counterfactual for the same query sample, this time towards a low target.
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[-100, -0.9] if target == "pce" else [0, 600],
)

# Counterfactual as an array: feature columns first, predicted target last.
cf_low_table = np.array(genetic.cf_examples_list[0].final_cfs_df)
cf_low = cf_low_table[:, 0:-1].reshape(4, 719)

y_cf_low = cf_low_table[:, -1].reshape([-1, 1])
if target == "pce":
    # Uses the scaler fetched in the cell for the high counterfactual.
    y_cf_low = scaler.inverse_transform(y_cf_low)
y_cf_low = np.round(y_cf_low, 2)

  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [00:11<00:00, 11.51s/it]100%|██████████| 1/1 [00:11<00:00, 11.51s/it]


In [16]:
import plotly.graph_objects as go


def format_title(title, subtitle=None, subtitle_font_size=14):
    """Build an HTML title string for a plotly figure or subplot.

    Args:
        title: Text that is rendered in bold.
        subtitle: Optional text placed on a new line below the title.
        subtitle_font_size: Font size of the subtitle in pixels.

    Returns:
        The bold title alone when ``subtitle`` is falsy, otherwise the bold
        title, a line break, and the subtitle wrapped in a sized ``<span>``.
    """
    heading = f"<b>{title}</b>"
    if subtitle:
        sub = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub}"
    return heading


from plotly.subplots import make_subplots

# Channel label and colour of the counterfactual trace, in subplot-column order.
channels = (
    ("ND", "#042940"),
    ("LP725", "#005C53"),
    ("LP780", "#9FC131"),
    ("SP775", "#DBF227"),
)

# Row 1 shows the high / thicker counterfactual, row 2 the low / thinner one.
fig = make_subplots(
    rows=2,
    cols=4,
    subplot_titles=(
        format_title("High" if target == "pce" else "Thicker", "ND"),
        format_title("", "LP725"),
        format_title("", "LP780"),
        format_title("", "SP775"),
        format_title("Low" if target == "pce" else "Thinner"),
        None,
        None,
        None,
    ),
)

# For every subplot draw the original signal in grey first and the
# counterfactual in the channel colour on top of it.
for row, counterfactual in enumerate((cf_high, cf_low), start=1):
    for col, (channel, colour) in enumerate(channels, start=1):
        fig.add_trace(
            go.Scatter(
                y=x[col - 1],
                name=channel,
                marker_color="grey",
                opacity=0.3,
                showlegend=False,
            ),
            row=row,
            col=col,
        )
        fig.add_trace(
            go.Scatter(
                y=counterfactual[col - 1],
                # The SP775 trace used to be named "cf NSP775D" by mistake.
                name=f"cf {channel}",
                marker_color=colour,
                showlegend=False,
            ),
            row=row,
            col=col,
        )

# Clear all axis titles, then label only the first column.
fig.update_yaxes(title=None)
fig.update_yaxes(title_text=None)
fig.update_xaxes(title=None)

fig.update_yaxes(title="Intensity", row=1, col=1)
fig.update_yaxes(title="Intensity", row=2, col=1)
fig.update_xaxes(title="Timesteps", row=2, col=1)

# y, y_cf_high and y_cf_low each hold a single value in their first row.
true_value, cf_high_value, cf_low_value = (
    str(*values[0]) for values in (y, y_cf_high, y_cf_low)
)

if target == "pce":
    subtitle = (
        f"Counterfactual Explanation ({cf_methode}) / True PCE: {true_value}"
        f" / CF PCE High: {cf_high_value}"
        f" / CF PCE Low: {cf_low_value}"
    )
else:
    subtitle = (
        f"Counterfactual Explanation ({cf_methode})"
        f" / True mean Thickness (mTH): {true_value}[nm]"
        f" / CF mTH Thick: {cf_high_value}"
        f" / CF mTH Thin: {cf_low_value}"
    )

title = format_title("Perovskite 1D Model", subtitle)

fig.update_layout(
    title=title,
    legend_title=None,
    title_y=0.95,
    title_x=0.035,
    template="plotly_white",
    height=400,
    width=1200,
)

#fig.write_image("xai/images/"+ target + "/1D/1D_cf_nn.png", scale=2)

fig.show()
