In [1]:
import os
from pathlib import Path

# The notebook is started three directory levels below the project root
# (`parents[2]`); the project-relative imports and paths used later presumably
# need the root as working directory. Resolve the root only once so re-running
# this cell is idempotent instead of climbing three more levels every time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parents[2]
os.chdir(PROJECT_ROOT)

# Target explained by this notebook:
#   "mth" -> label column "meanThickness", "pce" -> label column "PCE_mean".
target = "mth"  # mth, pce
if target not in ("mth", "pce"):
    raise ValueError(f"target must be 'mth' or 'pce', got {target!r}")

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Preprocessed data location; set PEROVSKITE_DATA_DIR to override the default
# instead of editing this absolute path.
data_dir = os.environ.get(
    "PEROVSKITE_DATA_DIR",
    "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed",
)

# Trained 2D checkpoint (directory, file name) per supported target. The common
# base directory can be overridden with PEROVSKITE_CHECKPOINT_BASE.
_checkpoint_base = os.environ.get(
    "PEROVSKITE_CHECKPOINT_BASE",
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022",
)
_checkpoints = {
    "pce": (
        join(_checkpoint_base, "checkpoints"),
        "2D-epoch=999-val_MAE=0.000-train_MAE=0.289.ckpt",
    ),
    "mth": (
        join(_checkpoint_base, "mT_checkpoints", "checkpoints"),
        "mT_2D_RN18_full-epoch=999-val_MAE=0.000-train_MAE=25.299.ckpt",
    ),
}
# Fail loudly on a typo instead of silently loading the thickness model.
if target not in _checkpoints:
    raise ValueError(
        f"Unknown target {target!r}; expected one of {sorted(_checkpoints)}"
    )
checkpoint_dir, _checkpoint_file = _checkpoints[target]
path_to_checkpoint = join(checkpoint_dir, _checkpoint_file)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#### 2D Model

# Derive every target-dependent setting once so the model hyper-parameters and
# the dataset label cannot drift apart.
is_pce = target == "pce"
target_label = "PCE_mean" if is_pce else "meanThickness"

hypparams = {
    "dataset": "Perov_2d",
    "dims": 2,
    "bottleneck": False,
    "name": "ResNet18",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": is_pce,  # only the PCE target is normalised
    "target": target_label,
}

# ResNet18 layout: BasicBlock with two blocks per stage, single regression output.
model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[2, 2, 2, 2],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# Full training set, normalised with the statistics stored in the checkpoint.
trainset_full = PerovskiteDataset2d(
    data_dir,
    transform=normalize_2d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=not is_pce,  # thickness labels are kept unscaled
    label=target_label,
)

# Load the whole set as one batch: DiCE needs all samples as a single table.
trainloader_full = DataLoader(
    trainset_full, batch_size=len(trainset_full), shuffle=False
)

images, label = next(iter(trainloader_full))


[0.23169845 0.00265788 0.00174048 0.00421168] [3.4151509e-02 3.0193795e-04 9.2120092e-05 9.2122407e-04]
Loaded


In [4]:
import pandas as pd

# Flatten every image into one row of pixel features so DiCE can treat the
# samples as a table; the label is appended as the last column, "target".
n_pixel_features = int(np.prod(images.shape[1:]))
data = pd.DataFrame(images.numpy().reshape(-1, n_pixel_features))

# Do not rebind `label`: it must stay a tensor so this cell can be re-run
# (the original rebound it to a DataFrame, which has no `.numpy()`).
data["target"] = label.numpy().reshape(-1)

feature_names = data.columns[0:-1].to_list()


In [5]:
class Wrapper(nn.Module):
    """Adapt an image model to the flat, tabular interface used by DiCE.

    Inputs arrive as rows of flattened images (a pandas DataFrame, a NumPy
    array or a tensor). They are reshaped back to image tensors, passed through
    the wrapped model without gradient tracking, and returned as NumPy arrays.

    Args:
        model: torch module mapping a ``(batch, *input_shape)`` float tensor to
            predictions of shape ``(batch, 1)`` (assumed: single-output regressor).
        input_shape: per-sample shape restored from each flat row. Defaults to
            ``(4, 65, 56)``, the layout used throughout this notebook.
    """

    def __init__(self, model, input_shape=(4, 65, 56)):
        super().__init__()
        self.model = model
        self.input_shape = tuple(input_shape)

    def _to_batch(self, image):
        """Convert flat rows (DataFrame / ndarray / tensor) to a float32 image batch."""
        if isinstance(image, torch.Tensor):
            x = image.to(torch.float32)
        else:
            # np.asarray handles DataFrames via __array__; torch.tensor copies,
            # so read-only pandas buffers are safe to use.
            x = torch.tensor(np.asarray(image, dtype=np.float32))
        return x.reshape(-1, *self.input_shape)

    def forward(self, image):
        """Return the raw model output as a NumPy array, shape ``(batch, 1)``."""
        # Bug fix: the original reshaped `image` (the DataFrame) instead of the
        # converted tensor, which raised AttributeError.
        with torch.no_grad():
            y_pred = self.model(self._to_batch(image))
        return y_pred.detach().numpy()

    def predict(self, image):
        """sklearn-style predict: one scalar prediction per input row, shape ``(batch,)``."""
        with torch.no_grad():
            y_pred = self.model(self._to_batch(image))
        return y_pred.detach().squeeze(1).numpy()


In [6]:
import dice_ml
from dice_ml import Dice
from dice_ml.utils import helpers

# Search strategy for DiCE; the same string is reused in the figure title.
cf_methode = "genetic"

# Expose the network through DiCE's sklearn-style backend as a regressor.
wModel = Wrapper(model)

# Tabular view of the flattened images: every pixel column is continuous and
# the label column is "target".
data_dice = dice_ml.Data(
    dataframe=data,
    continuous_features=feature_names,
    outcome_name="target",
)
model_dice = dice_ml.Model(
    model=wModel,
    backend="sklearn",
    model_type="regressor",
)

methode = Dice(data_dice, model_dice, method=cf_methode)


In [7]:
# Explain one sample: ask DiCE for a counterfactual whose prediction falls in a
# "high" range (higher PCE / thicker film).
n = 0  # row of `data` to explain
IMAGE_SHAPE = (4, 65, 56)  # per-sample layout restored from each flat row

# The genetic search is stochastic. Seeding NumPy's global RNG makes a
# Restart & Run All repeatable, assuming dice_ml draws from np.random
# (NOTE(review): confirm for the installed dice_ml version).
np.random.seed(0)

query_instances = pd.DataFrame(data.iloc[n][0:-1]).T
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    # For "pce" the model was configured with norm_target=True, so this range
    # is in scaled units (hence the inverse_transform below); for "mth" the
    # values are used unscaled.
    desired_range=[0.85, 100] if target == "pce" else [1100, 5000],
)

cfs_high_df = genetic.cf_examples_list[0].final_cfs_df
if cfs_high_df is None or len(cfs_high_df) == 0:
    raise RuntimeError(
        "DiCE returned no counterfactual in the requested high range; "
        "try another desired_range or query instance."
    )
cfs_high_arr = np.array(cfs_high_df)  # feature columns ..., target column last

cf_high = cfs_high_arr[:, 0:-1].reshape(IMAGE_SHAPE)
x = np.array(data.iloc[n][0:-1]).reshape(IMAGE_SHAPE)
y_raw = np.asarray(data.iloc[n]["target"]).reshape([-1, 1])
# astype(float) guards against an object-dtype column, which np.round cannot handle.
y_cf_high_raw = cfs_high_arr[:, -1].astype(float).reshape([-1, 1])

if target == "pce":
    # PCE values are in scaled units; refit the training-split scaler to map
    # them back to real PCE.
    scaler = PerovskiteDataset2d(
        data_dir=data_dir,
        transform=normalize_2d(model.train_mean, model.train_std),
        fold=None,
        split="train",
        no_border=False,
        return_unscaled=False,
        label="PCE_mean",
        val=False,
    ).get_fitted_scaler()

    y = np.round(scaler.inverse_transform(y_raw), 2)
    y_cf_high = np.round(scaler.inverse_transform(y_cf_high_raw), 2)
else:
    y = np.round(y_raw, 2)
    y_cf_high = np.round(y_cf_high_raw, 2)



  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [01:47<00:00, 107.34s/it]100%|██████████| 1/1 [01:47<00:00, 107.34s/it]


In [8]:
# Same query instance, now asking for a counterfactual in the "low" range
# (lower PCE / thinner film).
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[-100, -0.8] if target == "pce" else [0, 600],
)

cfs_low_df = genetic.cf_examples_list[0].final_cfs_df
if cfs_low_df is None or len(cfs_low_df) == 0:
    raise RuntimeError(
        "DiCE returned no counterfactual in the requested low range; "
        "try another desired_range or query instance."
    )
cfs_low_arr = np.array(cfs_low_df)  # feature columns ..., target column last

cf_low = cfs_low_arr[:, 0:-1].reshape(4, 65, 56)
# astype(float) guards against an object-dtype column, which np.round cannot handle.
y_cf_low_raw = cfs_low_arr[:, -1].astype(float).reshape([-1, 1])

if target == "pce":
    # `scaler` is the training-split scaler fitted in the high-range cell.
    y_cf_low = np.round(scaler.inverse_transform(y_cf_low_raw), 2)
else:
    y_cf_low = np.round(y_cf_low_raw, 2)



  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [01:43<00:00, 103.38s/it]100%|██████████| 1/1 [01:43<00:00, 103.38s/it]


In [10]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


def format_title(title, subtitle=None, font_size=16, subtitle_font_size=12):
    """Build a Plotly rich-text title, optionally followed by a smaller subtitle.

    Args:
        title: main text, rendered bold at ``font_size`` px.
        subtitle: optional second line at ``subtitle_font_size`` px; any falsy
            value (None, "") means "no subtitle".
        font_size: pixel size of the main title.
        subtitle_font_size: pixel size of the subtitle.

    Returns:
        An HTML-like string understood by Plotly titles.
    """
    heading = '<span style="font-size: {}px;"><b>{}</b></span>'.format(font_size, title)
    if subtitle:
        caption = '<span style="font-size: {}px;">{}</span>'.format(
            subtitle_font_size, subtitle
        )
        return heading + "<br>" + caption
    return heading


# Target-specific wording (and unit) for row titles and the figure title.
if target == "pce":
    quantity, high_name, low_name, unit = "PCE", "High", "Low", ""
    high_row_title, low_row_title = "High", "Low"
else:
    quantity, high_name, low_name, unit = "mTH", "Thick", "Thin", "[nm]"
    high_row_title, low_row_title = "Thicker", "Thinner"

# 3 x 4 grid: rows are the original sample and its two counterfactuals,
# columns are the four channels (ND, LP725, LP780, SP775).
fig = make_subplots(
    rows=3,
    cols=4,
    vertical_spacing=0.1,
    subplot_titles=(
        format_title("Original", "ND"),
        format_title("", "LP725"),
        format_title("", "LP780"),
        format_title("", "SP775"),
        format_title(high_row_title),
        None,
        None,
        None,
        format_title(low_row_title),
        None,
        None,
        None,
    ),
)

# One grey-scale heatmap per channel; enumerate() yields 1-based grid positions.
for row, sample in enumerate((x, cf_high, cf_low), start=1):
    for col, channel in enumerate(sample, start=1):
        fig.add_trace(
            go.Heatmap(z=channel, colorscale="gray", showscale=False), row=row, col=col
        )

# y / y_cf_high / y_cf_low are (1, 1) arrays; [0, 0] picks the scalar, and str()
# keeps NumPy's short representation. The unit now follows every value, not
# only the true one.
title = format_title(
    "Perovskite 2D Image Model",
    f"Counterfactual Explanation ({cf_methode})"
    f" / True {quantity}: {str(y[0, 0])}{unit}"
    f" / CF {quantity} {high_name}: {str(y_cf_high[0, 0])}{unit}"
    f" / CF {quantity} {low_name}: {str(y_cf_low[0, 0])}{unit}",
)

fig.update_layout(
    title=title,
    title_y=0.97,
    title_x=0.02,
    height=700,
    width=800,
)

# Hide tick labels and draw a thin frame around every panel.
axis_style = dict(
    showticklabels=False, showline=True, linewidth=0.5, linecolor="grey", mirror=True
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

save_figure = False  # set True to export a PNG (static export uses kaleido)
if save_figure:
    fig.write_image("xai/images/" + target + "/2D_image/2D_image_cf.png", scale=2)

fig.show()
