In [1]:
import os
from pathlib import Path

# Work from the project root: three directory levels above the directory the kernel
# started in. Relative paths used later (e.g. "xai/images/3D") resolve against it, and
# presumably the project-local imports (data, models, base_model) depend on it too.
# The root is computed only once per kernel, so re-running this cell does not climb
# another three levels up the tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path(os.getcwd()).parents[2]
os.chdir(PROJECT_ROOT)
os.getcwd()

'/home/l727n/Projects/Applied Projects/ml_perovskite'

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Input/output locations. The defaults are the original author's machine paths; set
# PEROVSKITE_DATA_DIR / PEROVSKITE_CHECKPOINT_DIR to run elsewhere without editing code.
data_dir = os.environ.get(
    "PEROVSKITE_DATA_DIR",
    "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed",
)
checkpoint_dir = os.environ.get(
    "PEROVSKITE_CHECKPOINT_DIR",
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints",
)

# Trained 3D model checkpoint loaded in the model cell below.
path_to_checkpoint = join(
    checkpoint_dir, "3D-epoch=999-val_MAE=0.000-train_MAE=0.360.ckpt"
)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
#### 3D Model

# Architecture/data options passed to SlowFast.load_from_checkpoint with the weights.
hypparams = {
    "dataset": "Perov_3d",
    "dims": 3,
    "bottleneck": False,
    "name": "SlowFast",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
}

model = SlowFast.load_from_checkpoint(
    path_to_checkpoint, num_classes=1, hypparams=hypparams
)

print("Loaded")
# Inference only: put layers such as dropout / batch-norm into evaluation mode.
model.eval()

# Mean/std of the raw (transform=None) training split; they parameterize normalize_3d.
train_mean, train_std = PerovskiteDataset3d(
    data_dir=data_dir, transform=None, fold=None, split="train", label="PCE_mean"
).get_stats()

trainset_full = PerovskiteDataset3d(
    data_dir=data_dir,
    transform=normalize_3d(train_mean, train_std),
    fold=None,
    split="train",
    label="PCE_mean",
    val=False,
)

# One batch containing the whole training split: it becomes DiCE's tabular dataset.
trainloader_full = DataLoader(
    trainset_full, batch_size=len(trainset_full), shuffle=False
)

images, label = next(iter(trainloader_full))

# Scaler used later to map the dataset targets back to PCE values. trainset_full was
# built with exactly the same arguments, so reuse it instead of constructing (and
# loading) an identical PerovskiteDataset3d a second time.
scaler = trainset_full.get_fitted_scaler()

tensor([0.2687, 0.0188, 0.0056, 0.0215]) tensor([0.1645, 0.0109, 0.0030, 0.0147])
Loaded


In [6]:
import pandas as pd

# DiCE works on tabular data: flatten each sample's (4, 36, 65, 56) volume into one
# row of features, one row per sample in the batch.
data = pd.DataFrame(images.numpy().reshape(len(images), -1))

# Append the target as the last column. `label` (the loader's tensor) is deliberately
# not rebound here, so re-running this cell works.
data["target"] = label.numpy().reshape(-1)

# Every column except the trailing "target" is a feature.
feature_names = data.columns[:-1].to_list()

In [7]:
class Wrapper(nn.Module):
    """Adapt a torch model to DiCE's ``sklearn`` backend.

    DiCE passes candidate instances as a table with one flattened sample per row.
    This wrapper reshapes every row back into the model's input layout, runs the
    model without autograd, and returns NumPy predictions.

    Args:
        model: torch module mapping a batch of shape ``(N, *input_shape)`` to
            predictions; ``predict`` assumes shape ``(N, 1)`` (it squeezes dim 1).
        input_shape: per-sample input shape. The default matches the 3D perovskite
            volumes used in this notebook.
    """

    def __init__(self, model, input_shape=(4, 36, 65, 56)):
        super().__init__()
        self.model = model
        self.input_shape = tuple(input_shape)

    def _to_batch(self, image):
        """Convert a DataFrame or 2D array (rows = flattened samples) into a float32 batch."""
        x = torch.as_tensor(np.asarray(image, dtype=np.float32))
        return x.reshape(-1, *self.input_shape)

    def forward(self, image):
        """Return the raw model output as a NumPy array, shape ``(N, 1)`` for one output."""
        # No gradients are needed for counterfactual search; skip building the graph.
        with torch.no_grad():
            y_pred = self.model(self._to_batch(image))
        return y_pred.detach().numpy()

    def predict(self, image):
        """Return one prediction per row (output dim 1 squeezed away) as a NumPy array."""
        with torch.no_grad():
            y_pred = self.model(self._to_batch(image))
        return y_pred.detach().squeeze(1).numpy()

In [8]:
# trainset_full already has exactly this configuration (data_dir, normalize_3d
# transform, fold=None, train split, PCE_mean label, val=False). Take the fitted
# target scaler from it instead of building and loading an identical dataset again.
scaler = trainset_full.get_fitted_scaler()

In [9]:
import dice_ml
from dice_ml import Dice
from dice_ml.utils import helpers

# Counterfactual search strategy for DiCE; it is also quoted in the figure titles.
cf_methode = "genetic"

# The "sklearn" backend drives the model through `.predict`, which Wrapper provides.
wModel = Wrapper(model)
model_dice = dice_ml.Model(model=wModel, backend="sklearn", model_type="regressor")

# Treat every flattened voxel as a continuous feature; "target" is the outcome column.
data_dice = dice_ml.Data(
    dataframe=data,
    continuous_features=feature_names,
    outcome_name="target",
)

methode = Dice(data_dice, model_dice, method=cf_methode)

In [None]:
# Row of `data` (a training sample) whose prediction is explained.
n = 0
# One-row DataFrame with only the feature columns (target dropped), as DiCE expects.
query_instances = data.iloc[[n], :-1].copy()

# desired_range is in the units of the "target" column (mapped back to PCE by
# `scaler` below); an upper bound of 100 effectively leaves the range open above 0.85.
cf_result_high = methode.generate_counterfactuals(
    query_instances, total_CFs=1, desired_range=[0.85, 100]
)
final_cfs_high = cf_result_high.cf_examples_list[0].final_cfs_df
# NOTE(review): DiCE may leave final_cfs_df None/empty when no counterfactual is found;
# fail with a clear message instead of an obscure indexing error.
if final_cfs_high is None or len(final_cfs_high) == 0:
    raise RuntimeError("DiCE found no counterfactual in desired_range [0.85, 100]")
cfs_high = final_cfs_high.to_numpy()

# final_cfs_df columns: the flattened features followed by the outcome.
cf_high = cfs_high[:, :-1].reshape(4, 36, 65, 56)
x = data.iloc[n, :-1].to_numpy().reshape(4, 36, 65, 56)

# Map target values back to PCE, rounded for display.
y = np.round(
    scaler.inverse_transform(data["target"].iloc[[n]].to_numpy().reshape(-1, 1)), 2
)
y_cf_high = np.round(scaler.inverse_transform(cfs_high[:, -1].reshape(-1, 1)), 2)

In [None]:
# Same query sample. A lower bound of -100 effectively asks for a counterfactual whose
# predicted target (in "target" column units) is at most -0.8. A separate result name
# keeps the "high" search result from being overwritten.
cf_result_low = methode.generate_counterfactuals(
    query_instances, total_CFs=1, desired_range=[-100, -0.8]
)
final_cfs_low = cf_result_low.cf_examples_list[0].final_cfs_df
# NOTE(review): DiCE may leave final_cfs_df None/empty when no counterfactual is found.
if final_cfs_low is None or len(final_cfs_low) == 0:
    raise RuntimeError("DiCE found no counterfactual in desired_range [-100, -0.8]")
cfs_low = final_cfs_low.to_numpy()

# final_cfs_df columns: the flattened features followed by the outcome.
cf_low = cfs_low[:, :-1].reshape(4, 36, 65, 56)
y_cf_low = np.round(scaler.inverse_transform(cfs_low[:, -1].reshape(-1, 1)), 2)

In [None]:
import numpy as np
import plotly.graph_objs as go
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from PIL import Image
import os
import glob

def format_title(title, subtitle=None, font_size=16, subtitle_font_size=14):
    """Return an HTML-styled Plotly title: a bold main line plus an optional subtitle.

    A falsy subtitle (None or "") yields only the bold title span; otherwise the
    subtitle span follows on a new line (joined with ``<br>``).
    """
    heading = f'<span style="font-size: {font_size}px;"><b>{title}</b></span>'
    if subtitle:
        sub = f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>'
        return f"{heading}<br>{sub}"
    return heading

# Filter names of the four image channels, in channel order; one subplot column each.
CHANNEL_NAMES = ("ND", "LP725", "LP780", "SP775")
FRAME_DIR = "xai/images/3D"

# x, cf_high and cf_low were reshaped to (4, 36, 65, 56). They are treated here as
# (channel, frame, height, width), an assumption based on the four filter names and
# the per-frame title. One GIF frame is rendered per index along axis 1.
N = images[0][0].shape[0]

# Rows of the figure: original sample, "high" counterfactual, "low" counterfactual.
volumes = (x, cf_high, cf_low)
n_cols = len(CHANNEL_NAMES)

# Loop-invariant figure pieces, built once.
subplot_titles = tuple(
    [format_title("Original", CHANNEL_NAMES[0])]
    + [format_title("", name) for name in CHANNEL_NAMES[1:]]
    + [format_title("High")] + [None] * (n_cols - 1)
    + [format_title("Low")] + [None] * (n_cols - 1)
)
axis_style = dict(
    showticklabels=False,
    showline=True,
    linewidth=0.5,
    linecolor="grey",
    mirror=True,
)
# str() of the single element of each (1, 1) array, as displayed in the title.
subtitle = (
    "Counterfactual Explanation ("
    + str(cf_methode)
    + ") / True PCE: "
    + str(y[0, 0])
    + " / CF PCE High: "
    + str(y_cf_high[0, 0])
    + " / CF PCE Low: "
    + str(y_cf_low[0, 0])
)

os.makedirs(FRAME_DIR, exist_ok=True)
frame_paths = []
for i in range(N):
    fig1 = make_subplots(
        rows=len(volumes),
        cols=n_cols,
        vertical_spacing=0.1,
        subplot_titles=subplot_titles,
    )
    for row, volume in enumerate(volumes, start=1):
        for col in range(1, n_cols + 1):
            # Channel (col - 1) at frame i: a 2D (height, width) image.
            fig1.add_trace(
                go.Heatmap(z=volume[col - 1][i], colorscale="gray", showscale=False),
                row=row,
                col=col,
            )

    # No row/col selector: style every subplot axis identically.
    fig1.update_xaxes(**axis_style)
    fig1.update_yaxes(**axis_style)

    fig1.update_layout(
        title=format_title("Perovskite 3D Model", f"{subtitle} / Frame {i}"),
        template="plotly_white",
        title_y=0.97,
        title_x=0.095,
        height=700,
        width=800,
    )

    # Zero-padded names sort in frame order; the explicit list below means stray
    # files from earlier runs are never picked up for the GIF.
    frame_path = f"{FRAME_DIR}/frame_{i:03d}.png"
    fig1.write_image(frame_path, scale=2)
    frame_paths.append(frame_path)

# Assemble the frames, in rendering order, into a looping GIF, then delete them.
frames = [Image.open(path) for path in frame_paths]
try:
    frames[0].save(
        fp=f"{FRAME_DIR}/cf_3D.gif",
        format="GIF",
        append_images=frames[1:],
        save_all=True,
        duration=400,
        loop=0,
    )
finally:
    for frame in frames:
        frame.close()

for path in frame_paths:
    os.remove(path)
