In [1]:
import os
from pathlib import Path

# Move the working directory three levels up from the current one
# (parents[0] is the direct parent, parents[2] the third ancestor).
# NOTE(review): re-running this cell climbs another three levels each time.
workdir = Path.cwd().parents[2]
os.chdir(workdir)
os.getcwd()

# Selects which target the notebook explains; the rest of the notebook
# branches on `target == "pce"`.
target = "mth" # mth, pce

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Directory handed to the dataset class and to the model hyper-parameters.
data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"

# Pick checkpoint directory and file name for the selected target. Every value
# of `target` other than "pce" uses the mT checkpoint.
if target != "pce":
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"
    checkpoint_file = "mT_3D_SF_full-epoch=999-val_MAE=0.000-train_MAE=20.877.ckpt"
else:
    checkpoint_dir = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"
    checkpoint_file = "3D-epoch=999-val_MAE=0.000-train_MAE=0.360.ckpt"

path_to_checkpoint = join(checkpoint_dir, checkpoint_file)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#### 3D Model

# Everything below switches between the two supported targets on this flag.
is_pce = target == "pce"
label_name = "PCE_mean" if is_pce else "meanThickness"

# Hyper-parameters passed to the checkpoint loader.
hypparams = dict(
    dataset="Perov_3d",
    dims=3,
    bottleneck=False,
    name="SlowFast",
    data_dir=data_dir,
    no_border=False,
    resnet_dropout=0.0,
    norm_target=is_pce,
    target=label_name,
)

model = SlowFast.load_from_checkpoint(
    path_to_checkpoint,
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# Dataset normalised with the statistics stored on the loaded model; targets
# are returned unscaled for every target except "pce".
trainset_full = PerovskiteDataset3d(
    data_dir,
    transform=normalize_3d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=not is_pce,
    label=label_name,
)

# One un-shuffled batch whose size equals the dataset length, i.e. the whole
# dataset is materialised at once.
trainloader_full = DataLoader(
    trainset_full,
    batch_size=len(trainset_full),
    shuffle=False,
)

images, label = next(iter(trainloader_full))


tensor([0.2687, 0.0188, 0.0056, 0.0215]) tensor([0.1645, 0.0109, 0.0030, 0.0147])
Loaded


In [18]:
import pandas as pd

# Flatten every sample into one row so that each element becomes a tabular
# feature column. The row width is derived from the batch itself instead of a
# hard-coded 4 * 36 * 65 * 56.
n_features = int(np.prod(images.shape[1:]))
data = pd.DataFrame(images.numpy().reshape(-1, n_features))

# Append the target as the last column. `label` stays bound to the tensor, so
# this cell can be re-run without failing on `label.numpy()`.
data["target"] = label.numpy().reshape(-1)

# All columns except the trailing "target" column are model inputs.
feature_names = data.columns[0:-1].to_list()


In [19]:
class Wrapper(nn.Module):
    """Adapter exposing a torch model through a table-in / NumPy-out interface.

    Each row of the incoming table is one flattened sample. Rows are converted
    to a float32 tensor and reshaped to ``(-1, *input_shape)`` before being
    passed to the wrapped model.

    Args:
        model: Callable module applied to the reshaped tensor.
        input_shape: Per-sample shape used for the reshape. Defaults to
            ``(4, 36, 65, 56)``, the shape used throughout this notebook.
    """

    def __init__(self, model, input_shape=(4, 36, 65, 56)):
        super().__init__()
        self.model = model
        self.input_shape = tuple(input_shape)

    def _to_batch(self, image):
        """Convert a DataFrame / array / tensor into a reshaped float32 tensor."""
        if isinstance(image, torch.Tensor):
            x = image.to(torch.float32)
        else:
            # pandas objects expose ``to_numpy``; plain arrays are used as-is.
            values = image.to_numpy() if hasattr(image, "to_numpy") else image
            x = torch.tensor(values, dtype=torch.float32)
        return x.reshape(-1, *self.input_shape)

    def _run(self, image):
        """Run the wrapped model without tracking gradients."""
        with torch.no_grad():
            return self.model(self._to_batch(image)).detach()

    def forward(self, image):
        """Return the raw model output as a NumPy array.

        The reshape is applied to the converted tensor; calling ``reshape`` on
        the DataFrame itself raises ``AttributeError``.
        """
        return self._run(image).numpy()

    def predict(self, image):
        """Return predictions with dimension 1 squeezed, as a NumPy array."""
        return self._run(image).squeeze(1).numpy()


In [20]:
import dice_ml
from dice_ml import Dice
from dice_ml.utils import helpers

# Name of the counterfactual search method handed to DiCE.
cf_methode = "genetic"

# The wrapper provides a DataFrame-based `predict`, which is why the model is
# registered with the "sklearn" backend.
# NOTE(review): confirm this is the intended backend for a torch model.
wModel = Wrapper(model)
model_dice = dice_ml.Model(
    model=wModel,
    backend="sklearn",
    model_type="regressor",
)

# Every feature column is declared continuous; "target" is the outcome column.
data_dice = dice_ml.Data(
    dataframe=data,
    continuous_features=feature_names,
    outcome_name="target",
)

methode = Dice(data_dice, model_dice, method=cf_methode)


In [None]:
# Row of `data` that is explained.
n = 0

# Single-row query frame holding all feature columns (target column dropped).
query_instances = pd.DataFrame(data.iloc[n][0:-1]).T

# Requested outcome interval for the "high" counterfactual.
desired_high = [0.9, 100] if target == "pce" else [1300, 5000]
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=desired_high,
)

# Final counterfactual table: feature columns followed by the outcome column.
cf_table = np.array(genetic.cf_examples_list[0].final_cfs_df)
cf_high = cf_table[:, 0:-1].reshape(4, 36, 65, 56)
cf_outcome = cf_table[:, -1].reshape([-1, 1])

# Original sample in the same 4-D layout, plus its stored target value.
x = np.array(data.iloc[n][0:-1]).reshape(4, 36, 65, 56)
true_outcome = data.iloc[n]["target"].reshape([-1, 1])

if target == "pce":
    # Fitted scaler used to map the targets back through inverse_transform.
    scaler = PerovskiteDataset3d(
        data_dir=data_dir,
        transform=normalize_3d(model.train_mean, model.train_std),
        fold=None,
        split="train",
        no_border=False,
        return_unscaled=False,
        label="PCE_mean",
        val=False,
    ).get_fitted_scaler()

    y = np.round(scaler.inverse_transform(true_outcome), 2)
    y_cf_high = np.round(scaler.inverse_transform(cf_outcome), 2)
else:
    # Non-PCE targets are used as stored; only rounding is applied.
    y = np.round(true_outcome, 2)
    y_cf_high = np.round(cf_outcome, 2)


  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [29:10:27<00:00, 105027.38s/it]100%|██████████| 1/1 [29:10:27<00:00, 105027.38s/it]


In [None]:
# Requested outcome interval for the "low" counterfactual.
desired_low = [-100, -0.9] if target == "pce" else [0, 700]
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=desired_low,
)

# Final counterfactual table: feature columns followed by the outcome column.
cf_table_low = np.array(genetic.cf_examples_list[0].final_cfs_df)
cf_low = cf_table_low[:, 0:-1].reshape(4, 36, 65, 56)
cf_low_outcome = cf_table_low[:, -1].reshape([-1, 1])

# For "pce" the outcome is mapped back with `scaler`, which is created by the
# cell that generates the "high" counterfactual.
if target == "pce":
    y_cf_low = np.round(scaler.inverse_transform(cf_low_outcome), 2)
else:
    y_cf_low = np.round(cf_low_outcome, 2)




  0%|          | 0/1 [00:00<?, ?it/s]100%|██████████| 1/1 [29:25:57<00:00, 105957.45s/it]100%|██████████| 1/1 [29:25:57<00:00, 105957.46s/it]


In [None]:
# Euclidean distance between the sample and each counterfactual: squared
# differences are summed over axes 0, 2 and 3, leaving one value per index
# of axis 1.
dist_high, dist_low = (
    np.sqrt(np.square(x - counterfactual).sum((0, 2, 3)))
    for counterfactual in (cf_high, cf_low)
)


In [10]:
# Persist the sample, both counterfactuals, their targets and the distance
# profiles. The arrays are passed positionally, so they are stored under the
# keys arr_0 ... arr_7 in this order.
results_path = "./xai/results/" + target + "_cf_3D_results.npz"
arrays_to_save = (
    x,
    y,
    cf_high,
    y_cf_high,
    cf_low,
    y_cf_low,
    dist_high,
    dist_low,
)
np.savez(results_path, *arrays_to_save)


In [4]:
# Reload the stored counterfactual results. The archive is bound to its own
# name so the feature DataFrame `data` used by the counterfactual cells is not
# overwritten, and the context manager closes the file once the arrays are read.
results_path = "./xai/results/" + target + "_cf_3D_results.npz"

with np.load(results_path) as cf_results:
    # Keys follow the positional order used when the archive was written.
    x = cf_results["arr_0"]
    y = cf_results["arr_1"]
    cf_high = cf_results["arr_2"]
    y_cf_high = cf_results["arr_3"]
    cf_low = cf_results["arr_4"]
    y_cf_low = cf_results["arr_5"]
    dist_high = cf_results["arr_6"]
    dist_low = cf_results["arr_7"]

In [17]:
# Display the outcome value stored for the "high" counterfactual.
# NOTE(review): the saved output (1262.09) lies below the lower bound of the
# [1300, 5000] range requested for the "mth" target — confirm whether the
# search actually reached the desired range.
y_cf_high

array([[1262.09]])

In [14]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

def format_title(title, subtitle=None, font_size=16, subtitle_font_size=12):
    """Build an HTML title string with an optional subtitle line.

    Args:
        title: Text rendered bold inside a span of ``font_size`` pixels.
        subtitle: Optional text rendered in a span of ``subtitle_font_size``
            pixels on a new line; skipped when falsy.
        font_size: Pixel size of the title span.
        subtitle_font_size: Pixel size of the subtitle span.

    Returns:
        The title markup, followed by ``<br>`` and the subtitle markup when a
        subtitle is given.
    """
    heading = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if subtitle:
        sub_heading = (
            f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        )
        return f"{heading}<br>{sub_heading}"
    return heading

# Colour palette, the frame indices shown as images, and the index along the
# first axis of the volumes that is plotted.
cd = ["#E1462C", "#0059A0", "#5F3893", "#FF8777","#0A2C6E", "#CEDEEB"]
frames = [0,6,12,30]
wl = 3

# 3 x 6 grid; in rows 2 and 3 the fifth cell spans two columns (bar charts).
fig1 = make_subplots(
    rows=3,
    cols=6,
    vertical_spacing=0.1,
    specs=[
        [{}, {}, {}, {}, {}, {}],
        [{}, {}, {}, {}, {"colspan": 2}, None],
        [{}, {}, {}, {}, {"colspan": 2}, None],
    ],
)

# One colour per bar (36 bars): the selected frames get the highlight colour,
# every other bar the last palette entry.
bars = ["#E3AF5F" if idx in frames else cd[5] for idx in range(36)]

# Rows 1-3 show the sample, the "high" and the "low" counterfactual at the
# selected frames.
for row, volume in enumerate((x, cf_high, cf_low), start=1):
    for i, frame in enumerate(frames):
        fig1.add_trace(
            go.Heatmap(z=volume[wl][frame], colorscale="gray", showscale=False),
            row=row,
            col=i + 1,
        )

# Bar charts: distance to the sample summed over the last two axes, taken at
# index `wl` of the first axis.
for row, volume in ((2, cf_high), (3, cf_low)):
    fig1.add_trace(
        go.Bar(
            y=np.sqrt(np.square(x - volume).sum((2, 3)))[wl],
            marker_color=bars,
            opacity=1,
            showlegend=False,
            marker_line_width=0,
        ),
        row=row,
        col=5,
    )

# Image columns: thin grey frame, no tick labels.
image_axis_style = dict(
    showticklabels=False,
    showline=True,
    linewidth=0.5,
    linecolor="grey",
    mirror=True,
)
for i in range(4):
    fig1.update_yaxes(col=i + 1, **image_axis_style)
    fig1.update_xaxes(col=i + 1, **image_axis_style)

# Bar-chart column: fixed y range and custom x tick labels.
tick_font = dict(size=16, family="Helvetica", color="rgb(0,0,0)")
fig1.update_yaxes(tickfont=dict(tick_font), range=[0, 202], col=5)
fig1.update_xaxes(
    ticktext=[0, 0, 120, 240, 600, 719],
    tickvals=[0, frames[0], frames[1], frames[2], frames[3], 35],
    tickfont=dict(tick_font),
    col=5,
)

fig1.update_layout(
    template="plotly_white",
    title_y=0.97,
    title_x=0.085,
    height=700,
    width=1120,
)

fig1.write_image("xai/images/"+ target + "/3D/3D_cf_paper.png", scale=2)
fig1.show()


In [15]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from PIL import Image
import os
import glob


def format_title(title, subtitle=None, font_size=16, subtitle_font_size=12):
    """Return HTML markup for a bold title and an optional subtitle.

    NOTE(review): an identical ``format_title`` is defined in an earlier cell
    of this notebook; this definition shadows it.

    Args:
        title: Title text, wrapped in ``<b>`` inside a ``font_size`` px span.
        subtitle: Optional subtitle text in a ``subtitle_font_size`` px span;
            omitted when falsy.
        font_size: Title font size in pixels.
        subtitle_font_size: Subtitle font size in pixels.

    Returns:
        The markup pieces joined with ``<br>``.
    """
    pieces = [f'<span style="font-size: {font_size}px;"><b>{title}</b></span>']
    if subtitle:
        pieces.append(
            f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        )
    return "<br>".join(pieces)


# Number of frames to render. Taken from the sample `x` (indexed below as
# x[channel][frame]) so this cell does not need the full `images` batch.
N = x.shape[1]

# Folder receiving the temporary frame images and the final GIF.
frame_dir = "xai/images/" + target + "/3D/"

# Subplot headings are identical for every frame, so build them once.
subplot_titles = (
    format_title("Original", "ND"),
    format_title("", "LP725"),
    format_title("", "LP780"),
    format_title("", "SP775"),
    format_title("", "\u2225" + "\u03B4" + "\u2225" + "\u2082"),
    format_title("High" if target == "pce" else "Thicker"),
    None,
    None,
    None,
    None,
    format_title("Low" if target == "pce" else "Thinner"),
    None,
    None,
    None,
    None,
)

# The figure title is also frame-independent.
if target == "pce":
    title = format_title(
        "Perovskite 3D Model",
        "Counterfactual Explanation ("
        + str(cf_methode)
        + ") / True PCE: "
        + str(y[0][0])
        + " / CF PCE High: "
        + str(y_cf_high[0][0])
        + " / CF PCE Low: "
        + str(y_cf_low[0][0]),
    )
else:
    title = format_title(
        "Perovskite 3D Model",
        "Counterfactual Explanation ("
        + str(cf_methode)
        + ") / True mTH: "
        + str(y[0][0]) + "[nm]"
        + " / CF mTH Thick: "
        + str(y_cf_high[0][0])
        + " / CF mTH Thin: "
        + str(y_cf_low[0][0]),
    )

# Image columns: thin grey frame, no tick labels.
image_axis_style = dict(
    showticklabels=False,
    showline=True,
    linewidth=0.5,
    linecolor="grey",
    mirror=True,
)

# Frame files in rendering order; used for both GIF assembly and clean-up so
# neither depends on directory listing order or file modification times.
frame_paths = []

for i in range(0, N):
    fig1 = make_subplots(
        rows=3,
        cols=5,
        vertical_spacing=0.1,
        subplot_titles=subplot_titles,
    )

    # Rows 1-3: sample, "high" and "low" counterfactual; columns 1-4: the four
    # entries along the first axis at frame i.
    for row, volume in enumerate((x, cf_high, cf_low), start=1):
        for channel in range(4):
            fig1.add_trace(
                go.Heatmap(z=volume[channel][i], colorscale="gray", showscale=False),
                row=row,
                col=channel + 1,
            )

    # Distance profiles with the current frame highlighted. The highlight is
    # selected by position, so equal distance values at other frames are not
    # highlighted as well.
    for row, dist in ((2, dist_high), (3, dist_low)):
        fig1.add_trace(
            go.Bar(
                y=dist,
                marker_color=np.where(np.arange(len(dist)) == i, "red", "#042940"),
                opacity=0.5,
                showlegend=False,
                marker_line_width=0,
            ),
            row=row,
            col=5,
        )

    for col in range(1, 5):
        fig1.update_yaxes(col=col, **image_axis_style)
        fig1.update_xaxes(col=col, **image_axis_style)

    # Bar charts: label only the first/last frame and the maximum distance.
    fig1.update_xaxes(
        ticktext=["0", str(N - 1)], tickvals=[0, N - 1], tickfont=dict(size=10), col=5
    )
    for row, dist in ((3, dist_low), (2, dist_high)):
        fig1.update_yaxes(
            ticktext=[np.round(dist.max())],
            tickvals=[dist.max()],
            tickfont=dict(size=10),
            col=5,
            row=row,
        )

    fig1.update_layout(
        title=title,
        template="plotly_white",
        title_y=0.97,
        title_x=0.085,
        height=650,
        width=900,
    )

    frame_path = frame_dir + "frame_0" + str(i) + "_.png"
    fig1.write_image(frame_path, scale=2)
    frame_paths.append(frame_path)

if frame_paths:
    # Open the frames in rendering order and close every handle before the
    # temporary files are removed.
    opened_frames = [Image.open(path) for path in frame_paths]
    try:
        opened_frames[0].save(
            fp=frame_dir + "3D_cf.gif",
            format="GIF",
            append_images=opened_frames[1:],
            save_all=True,
            duration=400,
            loop=0,
        )
    finally:
        for frame_image in opened_frames:
            frame_image.close()

    for path in frame_paths:
        os.remove(path)
