In [1]:
import os
from pathlib import Path

# Change the working directory to the grandparent (parents[1]) of the current
# one and echo the new location as the cell output.
# NOTE(review): the jump is relative to the *current* working directory, so
# running this cell again climbs another two levels.
notebook_dir = Path(os.getcwd())
os.chdir(notebook_dir.parents[1])
os.getcwd()

'/home/l727n/Projects/Applied Projects/ml_perovskite'

In [46]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

# Absolute, machine-specific locations used throughout the notebook.

# Folder with the preprocessed dataset.
data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"

# Both checkpoint folders live under the same data_Jan_2022 directory.
shared_checkpoint_root = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022"

# Checkpoint loaded for the PCE model.
checkpoint_dir_pce = join(shared_checkpoint_root, "checkpoints")
path_to_checkpoint_pce = join(
    checkpoint_dir_pce,
    "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt",
)

# Checkpoint loaded for the mean-thickness (mTh) model.
checkpoint_dir_mTh = join(shared_checkpoint_root, "mT_checkpoints", "checkpoints")
path_to_checkpoint_mTh = join(
    checkpoint_dir_mTh,
    "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt",
)

#### 1D Model PCE

# Hyperparameters handed to the checkpoint loader for the PCE model.
hypparams_pce = dict(
    dataset="Perov_1d",
    dims=1,
    bottleneck=False,
    name="ResNet152",
    data_dir=data_dir,
    no_border=False,
    resnet_dropout=0.0,
    norm_target=True,
    target="PCE_mean",
)

# Restore the network from the PCE checkpoint and put it in eval mode.
model_pce = ResNet.load_from_checkpoint(
    path_to_checkpoint_pce,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams_pce,
)
print("Loaded")
model_pce.eval()

# Arguments common to the PCE train and test datasets; both reuse the
# scaler stored on the loaded model.
pce_dataset_kwargs = dict(
    scaler=model_pce.scaler,
    no_border=False,
    return_unscaled=True,
    label="PCE_mean",
)

# Dataset built with the constructor's default fold/split/val arguments.
train_set_pce = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model_pce.train_mean, model_pce.train_std),
    **pce_dataset_kwargs,
)

# Same configuration, but explicitly requesting the test split.
test_set_pce = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model_pce.train_mean, model_pce.train_std),
    **pce_dataset_kwargs,
    fold=None,
    split='test',
    val=False,
)

# A single, unshuffled batch as large as the whole train_set_pce.
pce_loader_kwargs = dict(
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)
loader_pce = DataLoader(
    train_set_pce,
    batch_size=len(train_set_pce),
    **pce_loader_kwargs,
)

# mTH

# Hyperparameters handed to the checkpoint loader for the mean-thickness
# (mTh) model: same architecture settings as the PCE model, but with the
# target set to "meanThickness" and target normalisation switched off.
hypparams_mTh = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": False,
    "target": "meanThickness",
}

# Restore the network from the mTh checkpoint using the mTh hyperparameters
# (not the PCE ones), then put it in eval mode.
model_mTh = ResNet.load_from_checkpoint(
    path_to_checkpoint_mTh,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams_mTh,
)

print("Loaded")
model_mTh.eval()

# Dataset built with the constructor's default fold/split/val arguments.
train_set_mTh = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model_mTh.train_mean, model_mTh.train_std),
    scaler=model_mTh.scaler,
    no_border=False,
    return_unscaled=True,
    label="meanThickness",
)

# Test split. The scaler must be the model's scaler, exactly as for the
# training set; train_mean is an input-normalisation statistic and is only
# meant for `normalize`.
# NOTE(review): the lower-case "mth" in this name differs from the other mTh
# variables; it is kept because a later cell references `test_set_mth`.
test_set_mth = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model_mTh.train_mean, model_mTh.train_std),
    scaler=model_mTh.scaler,
    no_border=False,
    return_unscaled=True,
    label="meanThickness",
    fold=None,
    split='test',
    val=False,
)

# A single, unshuffled batch as large as the whole train_set_mTh.
loader_mTh = DataLoader(
    train_set_mTh,
    batch_size=len(train_set_mTh),
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)

Lightning automatically upgraded your loaded checkpoint from v1.6.3 to v2.0.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints/1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt`


tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


Lightning automatically upgraded your loaded checkpoint from v1.6.3 to v2.0.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints/mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt`


tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


In [54]:
# Sanity check: shape of the label array of the mean-thickness test split.
test_set_mth.labels.shape

(349,)

In [61]:
import plotly.figure_factory as ff
import numpy as np

# Overlay the PCE label distributions of the train and test datasets.
group_labels = ['Train', 'Test']
hist_data = [train_set_pce.labels, test_set_pce.labels]
colors = ["#FF8777", "#0059A0", "#E1462C"]

# curve_type is not passed here, so create_distplot uses its default curve.
fig = ff.create_distplot(
    hist_data,
    group_labels,
    colors=colors,
    bin_size=0.5,
    show_rug=False,
)

# Tick label font shared by both axes.
tick_font = dict(size=12, family="Helvetica", color="rgb(0,0,0)")

# NOTE(review): the 0.3 tick lies outside the fixed y-range [0, 0.25].
fig.update_yaxes(
    title_text=" ",
    showticklabels=True,
    zeroline=False,
    linewidth=3,
    showline=False,
    showgrid=True,
    range=[0, 0.25],
    tickvals=[0, 0.1, 0.2, 0.3],
    tickfont=tick_font,
)
fig.update_xaxes(
    title_text=" ",
    showline=True,
    showgrid=True,
    linewidth=3,
    tickfont=tick_font,
)

fig.update_layout(template="plotly_white", height=400, width=650, showlegend=False)

# Export a PNG at 2x scale, then render inline.
fig.write_image("xai/images/pce/1D/distr_pce.png", scale=2)
fig.show()

In [209]:
# Select observation
n = 0

x_batch_pce, y_batch_pce = next(iter(loader_pce))
x = x_batch_pce[n]
y_pce = np.round(y_batch_pce.detach().numpy(), 2)

x_batch_mTh, y_batch_mTh = next(iter(loader_mTh))
y_mTh = np.round(y_batch_mTh.detach().numpy(), 2)

In [20]:
### Small TS Graphic

import plotly.express as px
import plotly.graph_objects as go

color = ["#0059A0", "#5F3893", "#FF8777", "#E1462C"]

fig = go.Figure()

# Draw x[0]..x[3] as separate lines, one color per entry.
for i, trace_color in enumerate(color):
    fig.add_trace(
        go.Scatter(y=x[i], marker_color=trace_color, line=dict(width=4))
    )

# Thumbnail-style figure: no axis titles or tick labels, no legend.
fig.update_yaxes(title=None, showticklabels=False)
fig.update_xaxes(title=None, showticklabels=False, zeroline=False)
fig.update_layout(template="plotly_white", height=300, width=300, showlegend=False)

fig.show()

In [216]:
### Concepts

import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import math
from scipy import stats
import plotly.figure_factory as ff

# NOTE(review): the array returned by this draw is neither stored nor used.
np.random.gumbel(loc=0.0, scale=1.0, size=(1, 719))

# One million samples from a right-skewed Gumbel distribution (loc=2.5,
# scale=1). In this cell `r` is only referenced by the disabled histogram
# trace kept in the comment block below.
gumbel_dist = stats.gumbel_r(loc=2.5, scale=1)
r = gumbel_dist.rvs(size=1000000)

color = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]

fig = go.Figure()

# Disabled variant: histogram outline of `r` with a baseline underneath.
# fig.add_traces(go.Scatter(y = np.histogram(r, bins=500)[0], x=np.histogram(r, bins=500)[1],marker_color=color[5],line=dict(width=5)))
# fig.add_shape(type='line',
#                 x0=np.min(np.histogram(r, bins=100)[1]),
#                 y0=-1800,
#                 x1=np.max(np.histogram(r, bins=100)[1]),
#                 y1=-1800,
#                 line=dict(color='grey',width=7),
#                 xref='x',
#                 yref='y'
# )

# Active variant: x[1] drawn in light grey ...
fig.add_trace(go.Scatter(y=x[1], marker_color="lightgrey", line=dict(width=4)))

# ... with a horizontal grey baseline at y=-2 spanning x=0..719.
baseline = dict(
    type='line',
    xref='x',
    yref='y',
    x0=0,
    x1=719,
    y0=-2,
    y1=-2,
    line=dict(color='grey', width=4),
)
fig.add_shape(**baseline)

# Bare figure: no titles, tick labels, grid or zero lines on either axis.
hidden_axis = dict(title=None, showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(**hidden_axis)
fig.update_xaxes(**hidden_axis)
fig.update_layout(template="plotly_white", height=300, width=300, showlegend=False)

fig.write_image("xai/images/pce/1D/concepts.png", scale=2)

fig.show()

In [4]:
### Pressure curve

import plotly.graph_objects as go
import pandas as pd
import numpy as np

# Load the example pressure curve from CSV and squeeze it to drop singleton
# dimensions. Despite the name, `df` is a NumPy array, not a DataFrame.
# NOTE(review): presumably the CSV holds a single column, giving a 1-D
# array -- confirm.
df = np.array(pd.read_csv(r'xai/results/examplary_pressure_curve.csv')).squeeze()

fig = go.Figure()

color = ["#0059A0", "#5F3893", "#FF8777", "#E1462C"]

# Single spline-smoothed line for the curve.
fig.add_trace(
    go.Scatter(
        y=df,
        name="ND",
        marker_color=color[2],
        line=dict(width=3),
        line_shape='spline',
    )
)

# Tick label font shared by both axes.
tick_font = dict(size=16, family="Helvetica", color="rgb(0,0,0)")

# One y-axis update carrying the effective settings: blank title, visible
# tick labels, logarithmic scale. (A preceding call that hid the tick labels
# was immediately overridden and has been folded into this one.)
fig.update_yaxes(
    zeroline=False,
    title_text=" ",
    showticklabels=True,
    showgrid=True,
    type="log",
    tickfont=tick_font,
)
fig.update_xaxes(
    zeroline=False,
    title="Timesteps",
    showgrid=False,
    tickvals=[0, 100, 200, 300, 400, 500, 600, 700],
    tickfont=tick_font,
)

fig.update_layout(
    showlegend=False,
    bargap=0,
    bargroupgap=0,
    legend_title="Wavelength",
    template="plotly_white",
    height=300,
    width=700,
)

# Export a PNG at 2x scale, then render inline.
fig.write_image("xai/images/pce/1D/pressure_curve.png", scale=2)
fig.show()

In [55]:
### Avg. abs. Attribution

import plotly.express as px
import plotly.graph_objects as go

fig = go.Figure()

color = ["#E1462C", "#FF8777", "#0059A0", "#5F3893"]

# Bar labels and their values, written in ND -> SP775 order and reversed for
# plotting. The labels live in `filter_names` rather than `filter`, so the
# builtin `filter` is not shadowed for the rest of the kernel session.
# Reversed copies (instead of in-place .reverse()) keep the cell idempotent.
# NOTE(review): `color` is not reversed, so colors are assigned in plotted
# (reversed) order -- confirm this is the intended filter/color mapping.
# Alternative values (disabled):
# v = [159.49,  297.79, 455.65, 248.11]
filter_names = list(reversed(["ND", "LP725", "LP780", "SP775"]))
v = list(reversed([2.03, 3.91, 3.79, 3.66]))

# Label font shared by the bar annotations and the y-axis ticks.
label_font = dict(size=16, family="Helvetica", color="rgb(0,0,0)")

# Horizontal bars, each annotated with its value outside the bar.
fig.add_trace(
    go.Bar(
        y=filter_names,
        x=v,
        text=v,
        textposition="outside",
        textfont=label_font,
        marker_color=color,
        orientation="h",
    )
)

fig.update_layout(
    template="plotly_white",
    height=300,
    width=400,
    showlegend=False,
)

# ticksuffix adds a small gap between the labels and the bars.
fig.update_yaxes(title=None, tickfont=label_font, ticksuffix="  ")
# Fixed x-range leaves room for the value labels drawn outside the bars.
fig.update_xaxes(title=None, showticklabels=False, zeroline=True, range=[0, 8], showgrid=False)
fig.write_image("xai/images/pce/1D/attr_per_filter.png", scale=4)
fig.show()

In [19]:
### mTh vs PCE Plot
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
import plotly.graph_objects as go
import pandas as pd

# Degree-2 polynomial regression of PCE (y_pce) on mean thickness (y_mTh).
# The linear step is fitted without its own intercept.
model = Pipeline(
    steps=[
        ('poly', PolynomialFeatures(degree=2)),
        ('linear', LinearRegression(fit_intercept=False)),
    ]
)
df = pd.DataFrame({'X': y_mTh, 'Y': y_pce})

# Fit on the thickness values stacked into a column, then evaluate the fitted
# curve on the sorted thickness values so it can be drawn as one line.
reg = model.fit(np.vstack(df['X']), y_pce)
thickness_sorted = np.sort(y_mTh)
reg_line = reg.predict(thickness_sorted.reshape(-1, 1))

color = ["#E1462C", "#0059A0", "#5F3893", "#FF8777", "#0A2C6E", "#CEDEEB"]
fig = go.Figure()

# Raw observations as markers, fitted curve as a line on top.
fig.add_trace(
    go.Scatter(x=y_mTh, y=y_pce, mode='markers', marker_color=color[1])
)
fig.add_trace(
    go.Scatter(
        x=thickness_sorted,
        y=reg_line,
        mode='lines',
        line=dict(width=3),
        marker_color=color[3],
    )
)

# Tick label font shared by both axes.
tick_font = dict(size=12, family="Helvetica", color="rgb(0,0,0)")
fig.update_yaxes(
    title_text=" ",
    showticklabels=True,
    zeroline=False,
    linewidth=3,
    showline=True,
    showgrid=True,
    tickfont=tick_font,
)
fig.update_xaxes(
    title_text=" ",
    showline=True,
    showgrid=True,
    linewidth=3,
    range=[400, 1700],
    tickvals=[400, 800, 1200, 1600],
    tickfont=tick_font,
)

fig.update_layout(
    showlegend=False,
    legend_title=" ",
    template="plotly_white",
    height=500,
    width=500,
)

fig.write_image("xai/images/pce/1D/pcevsmth.png", scale=4)

fig.show()