### Imports & Paths

In [None]:
import os
from pathlib import Path

# Run from the directory two levels above where the kernel started; relative
# paths used later ("preprocessed", "xai/images/...") are resolved from there.
# The root is resolved only once, so re-running this cell does not keep
# climbing further up the directory tree.
if "REPO_ROOT" not in globals():
    REPO_ROOT = Path.cwd().parents[1]
os.chdir(REPO_ROOT)

target = "pce"  # prediction target: "pce" (label PCE_mean) or "mth" (label meanThickness)
# Every later cell branches on `target == "pce"`; any other value would silently act as "mth".
assert target in {"pce", "mth"}, f"unknown target {target!r}; expected 'pce' or 'mth'"


In [None]:
import torch
import torch.nn as nn
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Preprocessed data is expected in a "preprocessed" folder under the current working directory.
data_dir = f"{os.getcwd()}/preprocessed"


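## Init Models and Dataloader

Run exactly one of the four model cells below. Each one rebinds `model`, `hypparams` and `test_set`, and the prediction and plotting cells use whichever model cell ran last.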

#### Init 1D Model

In [None]:
#### 1D Model (ResNet152, no_border=False)

# Label column for the chosen target, shared by the hyperparameters and the
# test dataset so the two cannot drift apart.
label_col = "PCE_mean" if target == "pce" else "meanThickness"

# TODO: set to the directory holding the 1D checkpoints (placeholder path).
checkpoint_dir = "/add/path/to/model/checkpoints/"
checkpoint_name = (
    "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt"
    if target == "pce"
    else "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt"
)
path_to_checkpoint = join(checkpoint_dir, checkpoint_name)

hypparams = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "stochastic_depth": 0.0,
    "norm_target": target == "pce",
    "target": label_col,
}

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# return_unscaled is requested only for thickness; PCE values stay as the
# dataset returns them and are inverse-transformed in the plot cells instead.
test_set = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=label_col,
    fold=None,
    split="test",
    val=False,
)


#### Init 2D Model

In [None]:
#### 2D Model (ResNet18)

# Label column for the chosen target, shared by the hyperparameters and the
# test dataset so the two cannot drift apart.
label_col = "PCE_mean" if target == "pce" else "meanThickness"

# Root of the checkpoint tree. Set PEROVSKITE_CHECKPOINT_ROOT to point elsewhere
# instead of editing this user-specific absolute path.
checkpoint_root = os.environ.get(
    "PEROVSKITE_CHECKPOINT_ROOT",
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022",
)

if target == "pce":
    checkpoint_dir = join(checkpoint_root, "checkpoints")
    path_to_checkpoint = join(checkpoint_dir, "2D-epoch=999-val_MAE=0.000-train_MAE=0.289.ckpt")
else:
    checkpoint_dir = join(checkpoint_root, "mT_checkpoints", "checkpoints")
    path_to_checkpoint = join(checkpoint_dir, "mT_2D_RN18_full-epoch=999-val_MAE=0.000-train_MAE=25.299.ckpt")

hypparams = {
    "dataset": "Perov_2d",
    "dims": 2,
    "bottleneck": False,
    "name": "ResNet18",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "stochastic_depth": 0.0,
    "norm_target": target == "pce",
    "target": label_col,
}

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[2, 2, 2, 2],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# return_unscaled is requested only for thickness; PCE values stay as the
# dataset returns them and are inverse-transformed in the plot cells instead.
test_set = PerovskiteDataset2d(
    data_dir,
    transform=normalize_2d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=label_col,
    fold=None,
    split="test",
    val=False,
)


#### Init 2D_time Model

In [None]:
#### 2D_time Model (ResNet18)

# Label column for the chosen target, shared by the hyperparameters and the
# test dataset so the two cannot drift apart.
label_col = "PCE_mean" if target == "pce" else "meanThickness"

# Root of the checkpoint tree. Set PEROVSKITE_CHECKPOINT_ROOT to point elsewhere
# instead of editing this user-specific absolute path.
checkpoint_root = os.environ.get(
    "PEROVSKITE_CHECKPOINT_ROOT",
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022",
)

if target == "pce":
    checkpoint_dir = join(checkpoint_root, "checkpoints")
    path_to_checkpoint = join(checkpoint_dir, "2D_time-epoch=999-val_MAE=0.000-train_MAE=0.725.ckpt")
else:
    checkpoint_dir = join(checkpoint_root, "mT_checkpoints", "checkpoints")
    path_to_checkpoint = join(
        checkpoint_dir, "mT_2Dtime_RN18_full3-epoch=999-val_MAE=0.000-train_MAE=36.879.ckpt"
    )

hypparams = {
    "dataset": "Perov_time_2d",
    "dims": 2,
    "bottleneck": False,
    "name": "ResNet18",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "stochastic_depth": 0.0,
    "norm_target": target == "pce",
    "target": label_col,
}

model = ResNet.load_from_checkpoint(
    path_to_checkpoint,
    block=BasicBlock,
    num_blocks=[2, 2, 2, 2],
    num_classes=1,
    hypparams=hypparams,
)

print("Loaded")
model.eval()

# return_unscaled is requested only for thickness; PCE values stay as the
# dataset returns them and are inverse-transformed in the plot cells instead.
test_set = PerovskiteDataset2d_time(
    data_dir,
    transform=normalize_2d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=label_col,
    fold=None,
    split="test",
    val=False,
)


#### Init 3D Model

In [None]:
#### 3D Model (SlowFast)

# Label column for the chosen target, shared by the hyperparameters and the
# test dataset so the two cannot drift apart.
label_col = "PCE_mean" if target == "pce" else "meanThickness"

# Root of the checkpoint tree. Set PEROVSKITE_CHECKPOINT_ROOT to point elsewhere
# instead of editing this user-specific absolute path.
checkpoint_root = os.environ.get(
    "PEROVSKITE_CHECKPOINT_ROOT",
    "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022",
)

if target == "pce":
    checkpoint_dir = join(checkpoint_root, "checkpoints")
    path_to_checkpoint = join(checkpoint_dir, "3D-epoch=999-val_MAE=0.000-train_MAE=0.360.ckpt")
else:
    checkpoint_dir = join(checkpoint_root, "mT_checkpoints", "checkpoints")
    path_to_checkpoint = join(checkpoint_dir, "mT_3D_SF_full-epoch=999-val_MAE=0.000-train_MAE=20.877.ckpt")

hypparams = {
    "dataset": "Perov_3d",
    "dims": 3,
    "bottleneck": False,
    "name": "SlowFast",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": target == "pce",
    "target": label_col,
}

model = SlowFast.load_from_checkpoint(path_to_checkpoint, num_classes=1, hypparams=hypparams)

print("Loaded")
model.eval()

# return_unscaled is requested only for thickness; PCE values stay as the
# dataset returns them and are inverse-transformed in the plot cells instead.
test_set = PerovskiteDataset3d(
    data_dir,
    transform=normalize_3d(model.train_mean, model.train_std),
    scaler=model.scaler,
    no_border=False,
    return_unscaled=target != "pce",
    label=label_col,
    fold=None,
    split="test",
    val=False,
)


In [None]:
batch_size = 32

# Evaluation only: iterate the test set in its stored order so the returned
# y_true / y_pred arrays are reproducible between runs (shuffling buys nothing
# when no training happens).
loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)


def pytorch_predict(model, test_loader, device):
    """
    Run ``model`` over every batch of ``test_loader`` and collect targets and predictions.

    Each batch is expected to be a sequence whose last element holds the labels
    and whose remaining elements are passed positionally to ``model``.

    Args:
        model: torch module; it is moved to ``device`` and switched to eval mode (in place).
        test_loader: iterable of batches, e.g. a ``DataLoader``.
        device: device the model and the inputs are moved to, e.g. ``"cpu"``.

    Returns:
        (y_true, y_pred): flattened 1-D numpy arrays in loader order.
        Both are empty if the loader yields no batches.
    """
    # Keep model and inputs on the same device; otherwise device != "cpu" fails.
    model.to(device)
    model.eval()

    label_batches = []
    output_batches = []

    # Inference only: disabling autograd saves memory and time.
    with torch.no_grad():
        for batch in test_loader:
            inputs = [tensor.to(device) for tensor in batch[:-1]]
            outputs = model(*inputs)
            # Flatten per batch so (B,) and (B, 1) labels/outputs line up element-wise;
            # an unflattened (N, 1) y_true would broadcast to (N, N) in y_pred - y_true.
            label_batches.append(batch[-1].flatten())
            output_batches.append(outputs.flatten())

    if not output_batches:
        return np.array([]), np.array([])

    # Concatenate once at the end instead of re-allocating a growing tensor per batch.
    y_true = torch.cat(label_batches).cpu().numpy()
    y_pred = torch.cat(output_batches).cpu().numpy()
    return y_true, y_pred


# Collect ground truth and predictions for the whole test loader on the CPU.
y_true, y_pred = pytorch_predict(model=model, test_loader=loader, device="cpu")


### Parity Plot

In [None]:
color = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]

# Undo the target scaling on copies rather than rebinding y_true / y_pred, so
# re-running this cell does not un-scale the values a second time.
if target == "pce":
    y_true_plot = test_set.scaler.inverse_transform(y_true.reshape(1, -1))[0]
    y_pred_plot = test_set.scaler.inverse_transform(y_pred.reshape(1, -1))[0]
else:
    y_true_plot, y_pred_plot = y_true, y_pred

fig = px.scatter(
    x=y_true_plot,
    y=y_pred_plot,
    trendline="lowess",
    trendline_color_override=color[3],
    trendline_options=dict(frac=0.9),
)
fig.update_traces(marker=dict(color=color[1]), line=dict(width=3))

# Dashed identity line (perfect prediction) over the ground-truth range.
fig.add_shape(
    type="line",
    line=dict(dash="dash", width=3),
    x0=y_true_plot.min(),
    y0=y_true_plot.min(),
    x1=y_true_plot.max(),
    y1=y_true_plot.max(),
)

fig.update_yaxes(
    title_text="Predicted",
    showticklabels=True,
    zeroline=False,
    linewidth=3,
    showline=True,
    showgrid=True,
    tickfont=dict(size=12, family="Helvetica", color="rgb(0,0,0)"),
)
fig.update_xaxes(
    title_text="Ground Truth",
    showline=True,
    showgrid=True,
    linewidth=3,
    tickfont=dict(size=12, family="Helvetica", color="rgb(0,0,0)"),
)

fig.update_layout(
    showlegend=False,
    template="plotly_white",
    height=400,
    width=400,
    font=dict(family="Helvetica", color="#000000", size=14),
)

# Save under the folder of the model that produced these predictions
# (hypparams comes from whichever model cell ran last) instead of always "1D".
# NOTE(review): the "2D" / "2D_time" folder names mirror the checkpoint prefixes — confirm.
model_plot_dirs = {"Perov_1d": "1D", "Perov_2d": "2D", "Perov_time_2d": "2D_time", "Perov_3d": "3D"}
plot_dir = join("xai", "images", target, model_plot_dirs[hypparams["dataset"]])
os.makedirs(plot_dir, exist_ok=True)
fig.write_image(join(plot_dir, "parity_plot.png"), scale=4)
fig.show()


### Residual Plot

In [None]:
color = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]

# Undo the target scaling on copies rather than rebinding y_true / y_pred, so
# re-running this cell does not un-scale the values a second time.
if target == "pce":
    y_true_plot = test_set.scaler.inverse_transform(y_true.reshape(1, -1))[0]
    y_pred_plot = test_set.scaler.inverse_transform(y_pred.reshape(1, -1))[0]
else:
    y_true_plot, y_pred_plot = y_true, y_pred

# Residual = predicted - ground truth, plotted against the prediction.
fig = px.scatter(
    x=y_pred_plot,
    y=y_pred_plot - y_true_plot,
)
fig.update_traces(marker=dict(color=color[1]), line=dict(width=3))

# Dashed zero-residual reference line.
fig.add_hline(y=0, line=dict(dash="dash", width=3))

fig.update_yaxes(
    title_text="Residual",
    showticklabels=True,
    # Fixed residual window per target; points outside it are not visible.
    range=[-15, 15] if target == "pce" else [-1000, 1000],
    zeroline=False,
    linewidth=3,
    showline=True,
    showgrid=True,
    tickfont=dict(size=12, family="Helvetica", color="rgb(0,0,0)"),
)
fig.update_xaxes(
    title_text="Predicted",
    showline=True,
    showgrid=True,
    zeroline=False,
    linewidth=3,
    tickfont=dict(size=12, family="Helvetica", color="rgb(0,0,0)"),
)

fig.update_layout(
    showlegend=False,
    template="plotly_white",
    height=400,
    width=400,
    font=dict(family="Helvetica", color="#000000", size=14),
)

# Save under the folder of the model that produced these predictions
# (hypparams comes from whichever model cell ran last) instead of always "3D".
# NOTE(review): the "2D" / "2D_time" folder names mirror the checkpoint prefixes — confirm.
model_plot_dirs = {"Perov_1d": "1D", "Perov_2d": "2D", "Perov_time_2d": "2D_time", "Perov_3d": "3D"}
plot_dir = join("xai", "images", target, model_plot_dirs[hypparams["dataset"]])
os.makedirs(plot_dir, exist_ok=True)
fig.write_image(join(plot_dir, "residual_plot.png"), scale=4)
fig.show()
