In [1]:
import os
from pathlib import Path

os.chdir(Path(os.getcwd()).parents[1])
os.getcwd()

'/home/l727n/Projects/Applied Projects/ml_perovskite'

In [2]:
import torch
import torch.nn as nn
import numpy as np
import kaleido
from torch.utils.data import DataLoader
from data.perovskite_dataset import (
    PerovskiteDataset1d,
    PerovskiteDataset2d,
    PerovskiteDataset3d,
    PerovskiteDataset2d_time,
)
from models.resnet import ResNet152, ResNet, BasicBlock, Bottleneck
from models.slowfast import SlowFast
from data.augmentations.perov_1d import normalize
from data.augmentations.perov_2d import normalize as normalize_2d
from data.augmentations.perov_3d import normalize as normalize_3d
from base_model import seed_worker
from argparse import ArgumentParser
from os.path import join

data_dir = "/home/l727n/Projects/Applied Projects/ml_perovskite/preprocessed"


checkpoint_dir_pce = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/checkpoints"

path_to_checkpoint_pce = join(
    checkpoint_dir_pce, "1D-epoch=999-val_MAE=0.000-train_MAE=0.490.ckpt"
    )

checkpoint_dir_mTh = "/home/l727n/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/KIT-FZJ_2021_Perovskite/data_Jan_2022/mT_checkpoints/checkpoints"

path_to_checkpoint_mTh = join(
    checkpoint_dir_mTh, "mT_1D_RN152_full-epoch=999-val_MAE=0.000-train_MAE=40.332.ckpt"
    )

#### 1D Model PCE

hypparams_pce = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target": True,
    "target": "PCE_mean"
}

model_pce = ResNet.load_from_checkpoint(
    path_to_checkpoint_pce,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams_pce,
)

print("Loaded")
model_pce.eval()

train_set_pce = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model_pce.train_mean, model_pce.train_std),
    scaler=model_pce.scaler,
    no_border=False,
    return_unscaled=True,
    label="PCE_mean",
)

loader_pce = DataLoader(
    train_set_pce,
    batch_size=len(train_set_pce),
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)

# mTH

hypparams_mTh = {
    "dataset": "Perov_1d",
    "dims": 1,
    "bottleneck": False,
    "name": "ResNet152",
    "data_dir": data_dir,
    "no_border": False,
    "resnet_dropout": 0.0,
    "norm_target":  False,
    "target": "meanThickness"
}

model_mTh = ResNet.load_from_checkpoint(
    path_to_checkpoint_mTh,
    block=BasicBlock,
    num_blocks=[4, 13, 55, 4],
    num_classes=1,
    hypparams=hypparams_pce,
)

print("Loaded")
model_mTh.eval()

train_set_mTh = PerovskiteDataset1d(
    data_dir,
    transform=normalize(model_mTh.train_mean, model_mTh.train_std),
    scaler=model_mTh.scaler,
    no_border=False,
    return_unscaled= True,
    label= "meanThickness",
)


loader_mTh = DataLoader(
    train_set_mTh,
    batch_size=len(train_set_mTh),
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    worker_init_fn=seed_worker,
    persistent_workers=True,
)

  from .autonotebook import tqdm as notebook_tqdm


tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded
tensor([0.2697, 0.0191, 0.0057, 0.0216]) tensor([0.1589, 0.0106, 0.0030, 0.0145])
Loaded


In [3]:
# Select observation
n = 0

x_batch_pce, y_batch_pce = next(iter(loader_pce))
x = x_batch_pce[n]
y_pce = np.round(y_batch_pce.detach().numpy(), 2)

x_batch_mTh, y_batch_mTh = next(iter(loader_mTh))
y_mTh = np.round(y_batch_mTh.detach().numpy(), 2)

In [20]:
### Small TS Graphic

import plotly.express as px
import plotly.graph_objects as go

fig = go.Figure()

color = ["#0059A0","#5F3893","#FF8777","#E1462C"]

for i in range(4):
    fig.add_traces(go.Scatter(y=x[i],marker_color=color[i],line=dict(width=4)))

fig.update_layout(
    template="plotly_white",
    height=300,
    width=300,
    showlegend=False,
    
)

fig.update_yaxes(title=None, showticklabels=False)
fig.update_xaxes(title=None, showticklabels=False, zeroline = False)
fig.show()

In [7]:
### Pressure curve

import plotly.graph_objects as go
import pandas as pd

df = np.array(pd.read_csv(r'xai/results/examplary_pressure_curve.csv')).squeeze()

fig = go.Figure()

color = ["#0059A0","#5F3893","#FF8777","#E1462C"]

fig.add_trace(go.Scatter(y=df, name="ND",marker_color=color[2], line=dict(width=3)))

fig.update_yaxes(title=" ", showticklabels = False)
fig.update_yaxes(zeroline=False,title_text=" ", showticklabels = True, showgrid=True,type="log", tickfont= dict(size=14, family="Helvetica", color="rgb(0,0,0)")
)
fig.update_xaxes(zeroline=False, title="Timesteps", showgrid=False,tickvals=[0,100,200,300,400,500,600,700], tickfont= dict(size=14, family="Helvetica", color="rgb(0,0,0)"))

fig.update_layout(
    showlegend=False,
    bargap=0,
    bargroupgap = 0,
    legend_title="Wavelength",
    template="plotly_white",
    height=300,
    width=700,
)

fig.write_image("xai/images/"+ target + "/1D/pressure_curve.png", scale=2)
fig.show()

In [9]:
reg = model.fit(y_mTh.reshape(-1, 1), y_pce)
reg_line = reg.predict(y_mTh.reshape(-1, 1))
reg_line.shape

(780,)

In [12]:
sort = np.sort(y_mTh)
sort

array([ 446.69,  462.17,  499.51,  515.35,  528.34,  532.77,  535.14,
        535.28,  545.53,  547.95,  548.11,  548.97,  549.59,  550.42,
        552.54,  553.  ,  553.82,  554.4 ,  554.71,  554.75,  555.93,
        556.69,  560.93,  564.58,  566.28,  567.93,  569.51,  570.14,
        571.29,  572.35,  573.6 ,  575.62,  577.41,  579.53,  581.12,
        583.05,  583.16,  584.26,  584.75,  585.23,  585.32,  587.43,
        588.18,  588.19,  590.11,  592.02,  593.42,  593.74,  593.93,
        594.03,  594.92,  595.37,  596.28,  597.9 ,  598.13,  598.62,
        598.63,  599.41,  600.04,  600.07,  600.58,  601.12,  601.72,
        603.76,  604.35,  604.57,  607.14,  607.36,  607.88,  608.42,
        608.5 ,  608.62,  611.68,  611.71,  611.77,  611.93,  613.04,
        613.82,  614.14,  614.28,  614.79,  614.92,  615.55,  615.84,
        616.45,  617.25,  617.32,  617.84,  618.89,  619.74,  621.24,
        622.74,  622.76,  622.76,  623.35,  623.49,  623.86,  624.16,
        624.64,  626

In [19]:
### mTh vs PCE Plot
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
import plotly.graph_objects as go
import pandas as pd

model = Pipeline([('poly', PolynomialFeatures(degree=2)), ('linear', LinearRegression(fit_intercept=False))])
df = pd.DataFrame({'X': y_mTh, 'Y':y_pce})


reg = model.fit(np.vstack(df['X']), y_pce)
reg_line = reg.predict(np.sort(y_mTh).reshape(-1, 1))

color = ["#E1462C", "#0059A0", "#5F3893", "#FF8777","#0A2C6E", "#CEDEEB"]
fig = go.Figure()

fig.add_trace(go.Scatter(x = y_mTh, y= y_pce, mode='markers', marker_color = color[1]))
fig.add_trace(go.Scatter(x=np.sort(y_mTh), y=reg_line, mode='lines',line=dict(width=3), marker_color = color[3]))

fig.update_yaxes(title_text=" ", showticklabels = True, zeroline = False,linewidth=3, showline = True, showgrid=True, tickfont= dict(size=12, family="Helvetica", color="rgb(0,0,0)")
)
fig.update_xaxes(title_text=" ", showline = True, showgrid=True, linewidth=3,range = [400,1700],tickvals = [400,800,1200,1600], tickfont= dict(size=12, family="Helvetica", color="rgb(0,0,0)"))

fig.update_layout(
    showlegend=False,
    legend_title=" ",
    template="plotly_white",
    height=500,
    width=500,
)

fig.write_image("xai/images/pce/1D/pcevsmth.png", scale=4)

fig.show()