### Imports & Paths

In [None]:
import os
from pathlib import Path

# Move three directory levels up from the notebook's working directory
# (presumably the repository root) so the project-relative imports and
# "preprocessed" / "xai/images" paths used below resolve.
os.chdir(Path.cwd().parents[2])
os.getcwd()

# Prediction target: "mth" (meanThickness) or "pce" (PCE_mean).
target = "mth"  # mth, pce


In [None]:
import torch
import torch.nn as nn
import numpy as np
import scipy
import plotly.graph_objects as go
import dice_ml
import pandas as pd

from torch.utils.data import DataLoader
from torch import Tensor
from tqdm import tqdm
from os.path import join
from plotly.subplots import make_subplots
from dice_ml import Dice

from data.perovskite_dataset import PerovskiteDataset2d_time
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from data.augmentations.perov_2d import normalize as normalize_2d
from base_model import seed_worker

# Preprocessed dataset lives under the (already chdir'ed) working directory.
data_dir = f"{os.getcwd()}/preprocessed"

# Placeholder directory: point this at the folder holding the trained checkpoints.
checkpoint_dir = "/add/path/to/model/checkpoints/"

# One checkpoint per target; anything other than "pce" uses the thickness model.
checkpoint_name = (
    "2D_time-epoch=999-val_MAE=0.000-train_MAE=0.725.ckpt"
    if target == "pce"
    else "mT_2Dtime_RN18_full3-epoch=999-val_MAE=0.000-train_MAE=36.879.ckpt"
)
path_to_checkpoint = join(checkpoint_dir, checkpoint_name)


### Model Init

In [None]:
#### 2D Model

# Dataset column the model predicts; shared by the hparams and the dataset below.
label_column = "PCE_mean" if target == "pce" else "meanThickness"

hypparams = {
    "dataset": "Perov_time_2d",
    "dims": 2,
    "bottleneck": False,
    "name": "ResNet18",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    # Only the PCE target is normalized; thickness is kept in raw units.
    "norm_target": target == "pce",
    "target": label_column,
}

# ResNet18 layout: BasicBlock with [2, 2, 2, 2] blocks and a single regression output.
model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[2, 2, 2, 2],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# Training set normalized with the statistics/scaler stored in the checkpoint.
trainset_full = PerovskiteDataset2d_time(
    data_dir,
    transform=normalize_2d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=label_column,
)

# A single batch containing the whole training set.
trainloader_full = DataLoader(trainset_full, batch_size=len(trainset_full), shuffle=False)
images, label = next(iter(trainloader_full))

# Flatten each (4, 65, 72) sample into one row of 4 * 65 * 72 pixel features,
# then append the label as the "target" column expected by DiCE.
data = pd.DataFrame(images.numpy().reshape(-1, 4 * 65 * 72))
label = pd.DataFrame(label.numpy())
data["target"] = label

# Every column except the trailing "target" is a (continuous) feature.
feature_names = data.columns[:-1].tolist()


class Wrapper(nn.Module):
    """Adapt the image regressor to DiCE's sklearn-style interface.

    DiCE supplies flattened feature rows (in this notebook a pandas DataFrame
    whose columns are the 4 * 65 * 72 pixel features). Each row is reshaped
    back into a ``(4, 65, 72)`` image before it is passed to the wrapped model,
    and predictions are returned as NumPy arrays.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def _to_image_batch(image):
        """Convert flattened rows (DataFrame, ndarray or CPU tensor) to a float32 (N, 4, 65, 72) tensor."""
        # np.asarray accepts DataFrames as well as plain arrays, so callers are
        # not limited to objects exposing `.values`.
        x = torch.as_tensor(np.asarray(image), dtype=torch.float32)
        return x.reshape(-1, 4, 65, 72)

    def _infer(self, image):
        """Run the wrapped model on `image` and return a detached (N, 1) tensor."""
        x = self._to_image_batch(image)
        # Inference only: skip building an autograd graph for every DiCE query.
        with torch.no_grad():
            return self.model(x).detach()

    def forward(self, image):
        """Return raw model outputs as a NumPy array (one row per input row)."""
        return self._infer(image).numpy()

    def predict(self, image):
        """Return one prediction per input row as a 1-D NumPy array (used by DiCE)."""
        return self._infer(image).squeeze(1).numpy()


### CF Computation

In [None]:
# Counterfactual search strategy passed to DiCE (also shown in the plot title).
cf_methode = "genetic"

# DiCE sees the model through the sklearn-style Wrapper as a regressor.
wModel = Wrapper(model)
model_dice = dice_ml.Model(model=wModel, backend="sklearn", model_type="regressor")

# All pixel columns are treated as continuous; "target" is the outcome column.
data_dice = dice_ml.Data(
    dataframe=data,
    continuous_features=feature_names,
    outcome_name="target",
)

methode = Dice(data_dice, model_dice, method=cf_methode)


#### CF Higher

In [None]:
# Row of `data` to explain.
n = 0

# Query = row n without its trailing "target" column, as a one-row DataFrame.
query_instances = pd.DataFrame(data.iloc[n][0:-1]).T
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[0.9, 100] if target == "pce" else [1300, 5000],
)

# Counterfactual rows: pixel features followed by the predicted target.
cf_array_high = np.array(genetic.cf_examples_list[0].final_cfs_df)
cf_high = cf_array_high[:, 0:-1].reshape(4, 65, 72)
x = np.array(data.iloc[n][0:-1]).reshape(4, 65, 72)

if target != "pce":
    # Thickness targets are already unscaled (return_unscaled=True above).
    y = np.round(data.iloc[n]["target"].reshape([-1, 1]), 2)
    y_cf_high = np.round(cf_array_high[:, -1].reshape([-1, 1]), 2)
else:
    # PCE targets are scaled; rebuild the training scaler to map them back.
    scaler = PerovskiteDataset2d_time(
        data_dir=data_dir,
        transform=normalize_2d(model.train_mean, model.train_std),
        fold=None,
        split="train",
        no_border=False,
        return_unscaled=False,
        label="PCE_mean",
        val=False,
    ).get_fitted_scaler()

    y = np.round(scaler.inverse_transform(data.iloc[n]["target"].reshape([-1, 1])), 2)
    y_cf_high = np.round(scaler.inverse_transform(cf_array_high[:, -1].reshape([-1, 1])), 2)


#### CF Lower

In [None]:
# Same query as the "higher" cell, now asking for a lower predicted target.
genetic = methode.generate_counterfactuals(
    query_instances,
    total_CFs=1,
    desired_range=[-100, -0.9] if target == "pce" else [0, 700],
)

cf_array_low = np.array(genetic.cf_examples_list[0].final_cfs_df)
cf_low = cf_array_low[:, 0:-1].reshape(4, 65, 72)

# Last column is the predicted target; PCE values must be unscaled with the
# `scaler` fitted in the "CF Higher" cell.
low_prediction = cf_array_low[:, -1].reshape([-1, 1])
if target == "pce":
    low_prediction = scaler.inverse_transform(low_prediction)
y_cf_low = np.round(low_prediction, 2)


## CF Visualization

In [None]:
def format_title(title, subtitle=None, font_size=16, subtitle_font_size=12):
    """Build an HTML title string for plotly.

    The title is bold at ``font_size`` px. A truthy ``subtitle`` is appended on
    a new line (``<br>``) at ``subtitle_font_size`` px. A falsy subtitle (None
    or "") yields the title alone.
    """
    pieces = [f'<span style="font-size: {font_size}px;"><b>{title}</b></span>']
    if subtitle:
        pieces.append(f'<span style="font-size: {subtitle_font_size}px;">{subtitle}</span>')
    return "<br>".join(pieces)


# Column headers for the four input channels (shown only above the first row).
channel_titles = (
    format_title("Original", "ND"),
    format_title("", "LP725"),
    format_title("", "LP780"),
    format_title("", "SP775"),
)
high_title = format_title("High" if target == "pce" else "Thicker")
low_title = format_title("Low" if target == "pce" else "Thinner")

# Rows: original sample, "higher" counterfactual, "lower" counterfactual.
fig = make_subplots(
    rows=3,
    cols=4,
    vertical_spacing=0.1,
    subplot_titles=channel_titles + (high_title, None, None, None) + (low_title, None, None, None),
)

# Channel-major insertion order: for each channel, add rows 1, 2, 3.
for channel in range(4):
    for row, volume in enumerate((x, cf_high, cf_low), start=1):
        fig.add_trace(go.Heatmap(z=volume[channel], colorscale="gray", showscale=False), row=row, col=channel + 1)

fig.update_yaxes(showticklabels=False)
fig.update_xaxes(showticklabels=False)


# Labels surrounding the true value and the two counterfactual predictions.
# Only the thickness target gets a unit suffix on the true value.
if target == "pce":
    true_label, true_unit, high_label, low_label = (
        ") / True PCE: ",
        "",
        " / CF PCE High: ",
        " / CF PCE Low: ",
    )
else:
    true_label, true_unit, high_label, low_label = (
        ") / True mTH: ",
        "[nm]",
        " / CF mTH Thick: ",
        " / CF mTH Thin: ",
    )

# y, y_cf_high and y_cf_low are 2-D arrays from reshape([-1, 1]) in the CF cells;
# `str(*arr[0])` prints the single value in their first row (expects exactly one).
subtitle = "".join(
    [
        "Counterfactual Explanation (",
        str(cf_methode),
        true_label,
        str(*y[0]),
        true_unit,
        high_label,
        str(*y_cf_high[0]),
        low_label,
        str(*y_cf_low[0]),
    ]
)
title = format_title("Perovskite 2D Time Model", subtitle)

fig.update_layout(
    title=title,
    title_y=0.97,
    title_x=0.05,
    height=700,
    width=800,
)

# Thin grey frame around every subplot.
fig.update_xaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)
fig.update_yaxes(showline=True, linewidth=0.5, linecolor="grey", mirror=True)

# Exporting into a missing folder fails, so create the per-target output
# directory first (Path is imported in the first cell).
output_path = Path("xai") / "images" / target / "2D_time" / "2D_time_cf.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(output_path), scale=2)

fig.show()
