In [None]:
# Do this step first, if your work-directory isn't the main folder (nuclide-identification)
import os

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# The notebook imports the `src` package, so a directory containing `src/` is
# treated as the project root. Guarding on it makes the cell idempotent:
# re-running it no longer climbs one directory further each time.
if os.path.isdir(os.path.join(current_dir, "src")):
    print(f"Already in project root: {os.getcwd()}")
else:
    os.chdir(parent_dir)
    print(f"Directory changed to: {os.getcwd()}")


In [None]:
import json
import os
import re

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import seaborn as sns
import torch

import src.measurements.api as mpi
import src.peaks.api as ppi
from config.loader import load_config
from src.cnn.training import Training

# Route all matplotlib text through LaTeX (requires a LaTeX installation on this machine).
plt.rcParams['text.usetex'] = True

In [None]:
# --- MLflow / MinIO connection ------------------------------------------------
# Read the project configuration once instead of re-loading it for every lookup.
cfg = load_config()
for env_key in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "MLFLOW_S3_ENDPOINT_URL"):
    os.environ[env_key] = cfg["minio"][env_key]

model_uri = cfg["mlflow"]["uri"]
model_name = "CNN_CPU"
model_version = "latest"
mlflow.set_tracking_uri(uri=model_uri)
model = mlflow.pytorch.load_model(f"models:/{model_name}/{model_version}").to("cpu")

client = mlflow.tracking.MlflowClient(tracking_uri=model_uri)

# The analysis is pinned to one specific run. Set PINNED_RUN_ID to None to
# follow the run behind the latest registered model version instead.
# NOTE(review): the model above is loaded as "latest", which is not guaranteed
# to be the model produced by the pinned run - confirm they belong together.
PINNED_RUN_ID = "e78098da07bb482aa6b451bd7c6fc310"
run_id = PINNED_RUN_ID or client.get_latest_versions(model_name)[0].run_id
run = client.get_run(run_id)

# Make sure the download target exists before asking MLflow to write into it.
os.makedirs("tmp", exist_ok=True)
client.download_artifacts(run_id=run_id, path="artifacts.json", dst_path="tmp")
mlb_classes = run.data.params["mlb_classes"].split(",")

In [None]:
# Fetch the logged metric series of the selected run from MLflow.
training_macro_loss = client.get_metric_history(run_id=run_id, key="training_macro_loss")
training_micro_loss = client.get_metric_history(run_id=run_id, key="training_micro_loss")
validation_macro_loss = client.get_metric_history(run_id=run_id, key="validation_macro_loss")

# The training macro series determines how many entries are read from every
# series; a shorter sibling series raises IndexError.
n_entries = len(training_macro_loss)
training_mac_loss = [training_macro_loss[pos].value for pos in range(n_entries)]
training_mic_loss = [training_micro_loss[pos].value for pos in range(n_entries)]
validation_mac_loss = [validation_macro_loss[pos].value for pos in range(n_entries)]
validation_mic_loss = []  # stays empty: no validation micro series is fetched

# One row per entry; the row position doubles as the epoch number.
data_mic_mac_loss = (
    pd.DataFrame([training_mac_loss, training_mic_loss, validation_mac_loss])
    .T
    .rename(columns={0: "training_macro_loss", 1: "training_micro_loss",
                     2: "validation_macro_loss"})
)
data_mic_mac_loss["epoch"] = data_mic_mac_loss.index
data_mic_mac_loss

In [None]:
# Load the per-epoch artifacts that were downloaded from the MLflow run.
with open("tmp/artifacts.json") as f:
    artifacts = json.load(f)

# ROC points: one small frame per (epoch, nuclide), concatenated ONCE at the end
# instead of growing a DataFrame inside the loop (quadratic, and concatenating
# onto an empty frame is deprecated in recent pandas).
roc_frames = []
for epoch, tpr_by_nuclide in enumerate(artifacts["validation_tpr"]):
    for nuclide, tpr_values in tpr_by_nuclide.items():
        roc_frames.append(
            pd.DataFrame(tpr_values, columns=["validation_tpr"]).assign(
                validation_fpr=artifacts["validation_fpr"][epoch][nuclide],
                nuclide=nuclide,
                epoch=epoch,
            )
        )
if roc_frames:
    data_tpr_fpr = pd.concat(roc_frames, axis=0, ignore_index=True)
else:
    # Keep the expected columns even when the artifact holds no epochs.
    data_tpr_fpr = pd.DataFrame(columns=["validation_tpr", "validation_fpr", "nuclide", "epoch"])

# AUC values: one record per (epoch, nuclide), turned into a frame in one go.
auc_records = [
    {
        "training_auc": training_auc,
        "validation_auc": artifacts["validation_auc"][epoch][nuclide],
        "nuclide": nuclide,
        "epoch": epoch,
    }
    for epoch, auc_by_nuclide in enumerate(artifacts["training_auc"])
    for nuclide, training_auc in auc_by_nuclide.items()
]
data_auc = pd.DataFrame(auc_records, columns=["training_auc", "validation_auc", "nuclide", "epoch"])

In [None]:
# Faint per-nuclide validation AUC curves; the two macro series are drawn on top
# of the axes that relplot leaves active.
sns.relplot(data=data_auc, x="epoch", y="validation_auc", hue="nuclide", kind="line", palette=sns.color_palette(),
            alpha=0.2)
for loss_column, line_color in (("training_macro_loss", "black"), ("validation_macro_loss", "red")):
    plt.plot(data_mic_mac_loss["epoch"], data_mic_mac_loss[loss_column], color=line_color,
             label=loss_column)
plt.legend()

In [None]:
# Keep only the keys tagged as CNN validation data and load their measurements.
validation_rows = mpi.API().re_splitted_keys().loc[lambda keys: keys["type"] == "cnn_validation"]
re_keys = validation_rows["datetime"].tolist()
validation_set = ppi.API().re_measurement(re_keys)

In [None]:
def format_isotope(isotope):
    """Turn an isotope id such as ``"cs137"`` into a LaTeX label ``"$^{137}Cs$"``.

    Parameters
    ----------
    isotope
        Usually a string made of letters followed by digits. Only the leading
        ``letters+digits`` part is used; anything after it is dropped.

    Returns
    -------
    The LaTeX label when the pattern matches. Strings that do not match and
    non-string values (e.g. ``None`` / NaN cells of a DataFrame column) are
    returned unchanged instead of raising ``TypeError`` in ``re.match``.
    """
    if not isinstance(isotope, str):
        return isotope
    match = re.match(r"([a-zA-Z]+)(\d+)", isotope)
    if match is None:
        return isotope
    element, mass = match.groups()
    return f"$^{{{mass}}}{element.capitalize()}$"


def __identify_background(data, wndw: int = 5, scale: float = 1.5) -> np.ndarray:
    """Estimate a background curve for a spectrum by interpolating over flat regions.

    The absolute first difference of the counts is smoothed with a moving
    average; samples whose smoothed slope lies below ``mean * scale`` are kept
    as background support points and every other sample is linearly
    interpolated between them.

    Parameters
    ----------
    data
        DataFrame with a ``count`` column; rows are processed in index order.
        The caller's frame is not modified.
    wndw
        Length of the moving-average window. It is clamped to the number of
        available slopes so that short spectra no longer produce a mask whose
        length differs from the data.
    scale
        Factor applied to the mean smoothed slope to obtain the threshold.

    Returns
    -------
    np.ndarray
        Background estimate with one float per row of ``data``. Inputs with
        fewer than two rows are returned as-is (as floats).
    """
    counts = data.sort_index()["count"].astype(float).to_numpy()
    if counts.size < 2:
        # No slope can be computed; the signal itself is the only estimate.
        return counts.copy()

    slopes = np.abs(np.diff(counts))
    window = max(1, min(int(wndw), slopes.size))
    moving_avg = np.convolve(slopes, np.ones(window) / window, mode="same")
    threshold = np.mean(moving_avg) * scale

    # diff() yields one value less than there are samples; the last sample is
    # always kept as a support point, which also guarantees a non-empty mask.
    background_mask = np.append(moving_avg < threshold, True)

    channels = np.arange(counts.size)
    return np.interp(channels, channels[background_mask], counts[background_mask])


# Same LaTeX text setting as in the import cell (presumably repeated so this cell can run on its own).
plt.rcParams['text.usetex'] = True


def make_subplot(ax, one_meas_processed_measurement, limitation, idx, title, scaler):
    """Draw one measured spectrum and mark its identified peaks on ``ax``.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    one_meas_processed_measurement
        DataFrame of a single measurement with the columns ``energy``,
        ``count``, ``peak`` and ``identified_isotope``. It is not modified.
    limitation
        Upper y-limit of the panel; peak labels are anchored at this height.
    idx
        Unused; kept so existing calls keep working.
    title
        Panel title, written rotated along the right edge of the axes.
    scaler
        Factor multiplied with the peak energy to shift a label to the right
        when the peak lies less than 100 energy units after the previous one.

    Notes
    -----
    Each isotope gets one colour and is labelled only at its first peak. The
    palette has a fixed number of colours, so it is cycled rather than indexed
    past its end. The background estimate that used to be computed here was
    never plotted and has been dropped.
    """
    spectrum = one_meas_processed_measurement.loc[one_meas_processed_measurement["energy"] > 0]
    # Raw string: "\g" is not a valid escape sequence in a normal string literal.
    ax.plot(spectrum["energy"], spectrum["count"], label=r"Gemessene $\gamma$-Spektren",
            color=sns.color_palette()[0],
            alpha=1,
            linewidth=0.5)

    palette = sns.color_palette("dark")
    isotope_colors = {}
    previous_energy = 0
    peak_rows = spectrum[spectrum["peak"].eq(True)]
    for _, row in peak_rows.iterrows():
        energy = row["energy"]
        label = row["identified_isotope"]

        first_occurrence = label not in isotope_colors
        if first_occurrence:
            isotope_colors[label] = palette[len(isotope_colors) % len(palette)]
        color = isotope_colors[label]

        # Nudge the label sideways when it would collide with the previous peak.
        crowded = previous_energy > 0 and energy - previous_energy < 100
        space = energy * scaler if crowded else 0
        previous_energy = energy

        ax.vlines(
            x=energy,
            ymin=0,
            ymax=1_000_000,
            color=color,
            alpha=0.5,
            linewidth=0.3,
            linestyle=(0, (5, 5))
        )
        if first_occurrence:
            ax.annotate(
                text=label,
                xy=(energy, limitation),
                xytext=(energy + space, limitation * 1.15),
                textcoords='data',
                fontsize=11,
                color=color,
                rotation=90,
                zorder=0,
                ha='left',
                va='bottom',
                arrowprops=dict(
                    arrowstyle='-',
                    connectionstyle='arc,angleA=90,angleB=0,rad=0,armA=10,armB=0',
                    color=color,
                    linewidth=0.3,
                    zorder=100,
                    linestyle=(0, (5, 5)),
                    alpha=0.5
                )
            )

    ax.text(1.02, 0.6, title,
            transform=ax.transAxes,
            rotation=270,
            va='center',
            zorder=2000,
            ha='left',
            fontsize=12)
    ax.set_ylim(0, limitation)
    ax.set_xlim(0, 2000)
    ax.tick_params(axis='x', labelsize=12, bottom=True)
    ax.tick_params(axis='y', labelsize=12, left=True)
    ax.grid(False)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_xticks([0, 500, 1000, 1500, 2000])


# 2x3 grid of example spectra, one panel per hand-picked validation measurement.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(10, 5), sharex=True)
plt.subplots_adjust(hspace=0.5, wspace=0.5)
good_examples = [0, 20, 40, 58, 75, 109]  # positions in re_keys
axs = axs.flatten()

fig.text(0.5, 0.0, "Energie [keV]", size=14, ha='center')
fig.text(0.05, 0.5, "Zählwert", size=14, va='center', rotation='vertical')
# Per-panel settings, aligned with good_examples.
limitations = [50, 300, 500, 1000, 60, 2000]
scaler = [10, 10, 10, 0.0, 0.2, 10]
titles = ["(A) 1 Nuklid", "(B) 2 Nuklide", "(C) 3 Nuklide",
          "(D) 4 Nuklide", "(E) 5 Nuklide", "(F) 6 Nuklide"]
for idx, sample in enumerate(good_examples):
    date = re_keys[sample]
    limitation = limitations[idx]
    # Work on a copy: assigning a column to a .loc slice of validation_set
    # triggers SettingWithCopyWarning.
    one_meas_processed_measurement = validation_set.loc[validation_set["datetime"] == date].copy()
    one_meas_processed_measurement["identified_isotope"] = one_meas_processed_measurement["identified_isotope"].apply(
        format_isotope)
    make_subplot(axs[idx], one_meas_processed_measurement, limitation, idx, titles[idx], scaler[idx])

# Build one figure-level legend, listing every distinct label only once.
handles_labels = [ax.get_legend_handles_labels() for ax in axs]
handles, labels = [], []
for panel_handles, panel_labels in handles_labels:
    for handle, label in zip(panel_handles, panel_labels):
        if label not in labels:
            handles.append(handle)
            labels.append(label)
leg = fig.legend(
    handles, labels,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.05),
    ncol=2,
    fontsize=12,
    frameon=False
)
for line in leg.get_lines():
    line.set_linewidth(3)

# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/background_meas_example555.pdf", bbox_inches='tight')

In [None]:

import plotly.graph_objects as go

# Interactive view of a single validation measurement with its detected peaks.
fig = go.Figure()

for key in re_keys[8:9]:
    spectrum_rows = validation_set.loc[validation_set["datetime"] == key].reset_index(drop=True)
    spectrum_trace = go.Scatter(
        x=spectrum_rows["energy"],
        y=spectrum_rows["count"],
        mode='lines',
        name=str(key),
        line={"color": 'black'},
    )
    fig.add_trace(spectrum_trace)

    # One marker trace per identified isotope so each gets its own legend entry.
    peak_rows = spectrum_rows.loc[spectrum_rows["peak"] == True].reset_index(drop=True)
    for iso in peak_rows["identified_isotope"].unique():
        iso_rows = peak_rows.loc[peak_rows["identified_isotope"] == iso].reset_index(drop=True)
        marker_style = {"color": 'red', "size": 8, "symbol": 'circle'}
        fig.add_trace(go.Scatter(x=iso_rows["energy"], y=iso_rows["count"],
                                 mode='markers', name=iso,
                                 marker=marker_style))

fig.update_layout(height=1000, title="Signal mit Peaks")
fig.update_xaxes(title="Energie [keV]")
fig.update_yaxes(title="Zählwert")
fig.show()


In [None]:
plt.rcParams['text.usetex'] = True


def _peak_value_and_epoch(column):
    """Return the maximum of ``column`` and the first epoch at which it occurs."""
    peak_value = data_mic_mac_loss[column].max()
    rows_at_peak = data_mic_mac_loss.loc[data_mic_mac_loss[column] == peak_value]
    return peak_value, rows_at_peak["epoch"].values[0]


# NOTE(review): the names say "min_loss" but the code takes the MAXIMUM of the
# "*_macro_loss" series, and the annotations below call it Macro-AUC. Also,
# "*_x" holds the metric value and "*_y" the epoch. Confirm what the logged
# metric really is; later cells select their epoch via validation_min_loss_y.
training_min_loss_x, training_min_loss_y = _peak_value_and_epoch("training_macro_loss")
validation_min_loss_x, validation_min_loss_y = _peak_value_and_epoch("validation_macro_loss")

plt.figure(figsize=(10, 5))
for column, line_color in {"training_macro_loss": "black", "validation_macro_loss": "red"}.items():
    plt.plot(data_mic_mac_loss["epoch"], data_mic_mac_loss[column], color=line_color,
             label=column)

# Arrow from the top of the plot (y = 1) down to each curve's maximum.
peak_annotations = (
    (f'MAX(Macro-AUC) Training = {round(training_min_loss_x, 2)}',
     training_min_loss_y, training_min_loss_x),
    (f'MAX(Macro-AUC) Validierung = {round(validation_min_loss_x, 2)}',
     validation_min_loss_y, validation_min_loss_x),
)
for text, epoch, value in peak_annotations:
    plt.annotate(text,
                 ha='center', va='bottom',
                 size='large',
                 xytext=(epoch, 1), xy=(epoch, value),
                 arrowprops={'facecolor': 'darkgrey'}, alpha=0.5)
plt.ylim(0, 1)

In [None]:
# AUC rows of the epoch selected in the previous cell.
aucs = data_auc.loc[data_auc["epoch"] == validation_min_loss_y]


# NOTE: a function with this name is also defined in an earlier cell of this
# notebook; whichever cell ran last wins.
def format_isotope(isotope):
    """Build a legend label: LaTeX isotope symbol plus its training/validation AUC.

    Parameters
    ----------
    isotope
        Usually a string made of letters followed by digits (e.g. ``"cs137"``),
        looked up in the ``nuclide`` column of the module-level ``aucs`` frame.

    Returns
    -------
    ``"$^{mass}El$\\nAUC-Training=..\\nAUC-Validation=.."`` with both values
    rounded to two decimals. If ``aucs`` has no row for the nuclide, only the
    isotope symbol is returned (previously a ``KeyError``). Values that are
    not strings or do not match the pattern are returned unchanged.
    """
    if not isinstance(isotope, str):
        return isotope
    match = re.match(r"([a-zA-Z]+)(\d+)", isotope)
    if match is None:
        return isotope
    element, mass = match.groups()
    symbol = f"$^{{{mass}}}{element.capitalize()}$"

    # Single lookup instead of filtering the frame once per metric.
    nuclide_rows = aucs.loc[aucs["nuclide"] == isotope]
    if nuclide_rows.empty:
        return symbol
    training_auc = nuclide_rows["training_auc"].iloc[0]
    validation_auc = nuclide_rows["validation_auc"].iloc[0]
    return f"{symbol}\nAUC-Training={round(training_auc, 2)}\nAUC-Validation={round(validation_auc, 2)}"


# ROC curves of every nuclide at the selected epoch.
filtered_data = data_tpr_fpr[data_tpr_fpr["epoch"] == validation_min_loss_y].reset_index(drop=True)
filtered_data["nuclide"] = filtered_data["nuclide"].apply(format_isotope)

# relplot is figure-level and always creates its own figure, so the separate
# plt.figure(figsize=(10, 5)) that preceded it only produced an empty output.
# NOTE(review): if a 10x5 figure is wanted, pass height=5, aspect=2 to relplot.
rel = sns.relplot(filtered_data, x="validation_fpr", y="validation_tpr", hue="nuclide", kind="line",
                  drawstyle="steps-pre", palette=sns.color_palette())
fig = rel.figure
ax = rel.ax
ax.plot([0, 1], [0, 1], color="lightgrey", linestyle="--", linewidth=2, label="Basislinie")
ax.set_xlabel("False Positive Rate (FPR)", size=14)
ax.set_ylabel("True Positive Rate (TPR)", size=14)
ax.tick_params(axis='x', labelsize=12, bottom=True)
ax.tick_params(axis='y', labelsize=12, left=True)
ax.grid(True, alpha=0.2)

# Replace seaborn's figure-level legend with an axes legend above the plot.
# The earlier sns.move_legend call was dropped: it repositioned a legend that
# was removed immediately afterwards.
handles, labels = ax.get_legend_handles_labels()
if rel.legend is not None:
    rel.legend.remove()
ax.legend(handles, labels, loc="lower center", bbox_to_anchor=(0.5, 1.02), ncol=3, frameon=False)

In [None]:
import src.measurements.api as mpi

# Same selection as for validation_set above, kept under separate names.
splitted_keys = mpi.API().re_splitted_keys()
is_validation = splitted_keys["type"] == "cnn_validation"
validation_keys = splitted_keys.loc[is_validation, "datetime"].tolist()
validation_measurements = ppi.API().re_measurement(validation_keys)

In [None]:
# Build the training helper from the CNN section of the config and keep only
# its validation dataset.
# NOTE(review): bool() on a string such as "False" is True - confirm the config
# stores real booleans.
cnn_config = load_config()["cnn"]
validation_cnn_pm = Training(
    use_processed_synthetics=bool(cnn_config["use_processed_synthetics"]),
    # Keyword spelled exactly as in the original call; it has to match the
    # Training signature, which is not visible here.
    use_processed_measuremnets=bool(cnn_config["use_processed_measurements"]),
    use_re_processed_data=True,
).validation_cnn_pm

In [None]:
# Label binarizer stored on the validation dataset; display its classes
# (presumably the order of the model outputs - TODO confirm).
fitted_mlb = validation_cnn_pm.fitted_mlb
fitted_mlb.classes_

In [None]:
# One row per entry of labels_by_datetime: after transpose + reset_index the key
# ends up in column "index" and the stored value in column 0
# (assumes labels_by_datetime is a dict - TODO confirm).
isos_ind = pd.DataFrame([validation_cnn_pm.labels_by_datetime]).T.reset_index()
import re


def count_ones(val):
    """Return the sum over the first element (row) of ``val``."""
    first_row = val[0]
    return sum(first_row)


# Sum of the first row of each stored label array (presumably the number of
# nuclides present in that measurement), then the number of measurements per sum.
isos_ind["ones_count"] = isos_ind[0].apply(count_ones)
isos_ind.groupby("ones_count")["index"].count()

In [None]:
from torch.utils.data import DataLoader

# Loader yielding one shuffled validation sample per batch.
# NOTE(review): no random seed is set, so the order differs between runs; this
# name is reassigned with batch_size=50 in the SHAP cell further down.
validation_cnn_pm_loader = DataLoader(
    validation_cnn_pm, batch_size=1, shuffle=True)

In [None]:
# Fetch one dataset entry and shape its first element as a (1, 1, ...) model input.
item = validation_cnn_pm[10]
test = item[0].float().to("cpu")[None, None]
item

In [None]:
# mlb_classes
# Print the sigmoid outputs for the first dataset entries next to the entry's
# own fields (item[0] is the model input; item[2] is used as the target vector
# in the saliency cell below).
model.eval()
# Never index past the end of the dataset.
for i in range(min(30, len(validation_cnn_pm))):
    item = validation_cnn_pm[i]
    test = item[0].float().to("cpu").unsqueeze(0).unsqueeze(0)
    # Pure inference: skip building the autograd graph. `test` is created
    # outside the no_grad block so later cells can still enable gradients on it.
    with torch.no_grad():
        output = model(test)
    print(torch.sigmoid(output))
    print(item[0])
    print(item[2])
    print(item[1])



In [None]:
# x axis: channel index scaled by a fixed step
# (presumably keV per channel - TODO confirm the calibration constant).
N_CHANNELS = 8160
ENERGY_STEP = 0.34507313512321336
# np.arange with a non-integer step and a computed stop can come out one element
# too long through floating-point rounding; scaling an integer range always
# yields exactly N_CHANNELS points with the same values.
x = np.arange(N_CHANNELS) * ENERGY_STEP
plt.plot(x, test[0].T)

plt.ylim(0, 0.1)
plt.xlim(0, 500)

In [None]:
import torch
import matplotlib.pyplot as plt
from scipy.signal import find_peaks, savgol_filter

# Gradient-based saliency for the sample currently held in `item` / `test`
# (whatever the most recently executed cell left in these names).
target = torch.tensor([item[2].T])

# Track gradients with respect to the input, then run one forward pass whose
# graph is reused for every class below.
test.requires_grad_()
output = model(test)

# Indices where the target vector equals 1.
positive_classes = (target[0] == 1).nonzero(as_tuple=True)[0]

for class_idx in positive_classes:
    model.zero_grad()
    print(class_idx)

    # Backpropagate the raw model output (no sigmoid applied) of this class to
    # the input; retain_graph keeps the forward graph for the next class.
    class_score = output[0, class_idx]
    class_score.backward(retain_graph=True)

    # Absolute input gradient, reduced over dim 1 (the dimension added by the
    # second unsqueeze when `test` was built).
    saliency = test.grad.data.abs()
    saliency, _ = saliency.max(dim=1)
    saliency_map = saliency.squeeze().cpu().numpy()

    # Min-max normalise; the epsilon avoids division by zero for a flat map.
    saliency_map -= saliency_map.min()
    saliency_map /= saliency_map.max() + 1e-8

    # NOTE(review): window_length=100 is even; some SciPy versions require an
    # odd window for savgol_filter - confirm against the installed version.
    plt.plot(x, savgol_filter(saliency_map, window_length=100, polyorder=1))
    plt.title(f"Saliency Map for Class {class_idx.item()}")
    plt.show()

    # Gradients accumulate across backward calls; reset before the next class.
    test.grad.zero_()


In [None]:
import shap

# Collect the first dataset entries as background batch for SHAP's DeepExplainer
# and print the model outputs for them along the way.
model.eval()
background = []
# Never index past the end of the dataset.
for i in range(min(30, len(validation_cnn_pm))):
    item = validation_cnn_pm[i]
    test = item[0].float().to("cpu").unsqueeze(0).unsqueeze(0)
    background.append(test)
    # Pure inference: no autograd graph is needed for the printed scores.
    with torch.no_grad():
        output = model(test)
    print(torch.sigmoid(output))
    print(item[0])
    print(item[2])
    print(item[1])
# Stack the single-sample tensors along the batch dimension.
background = torch.cat(background, dim=0)
explainer = shap.DeepExplainer(model, background)

In [None]:
import torch
from torch.utils.data import DataLoader
import shap
import numpy as np


class ModelWrapper:
    """Expose a torch model as a NumPy-in / NumPy-out callable (e.g. for SHAP).

    Parameters
    ----------
    model
        The torch module to wrap. It is switched to eval mode and moved to
        ``device``. ``nn.Module.to`` moves the module in place, so the object
        passed in is affected as well.
    device
        Target device. Defaults to ``"cuda"`` when CUDA is available and to
        ``"cpu"`` otherwise, so the wrapper also works on CPU-only machines
        instead of failing on a hard-coded ``"cuda"``.
    """

    def __init__(self, model, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = model.eval().to(device)

    def __call__(self, x_numpy):
        """Return the sigmoid of the model output for ``x_numpy`` as a NumPy array.

        A dimension is inserted at position 1 before the forward pass, so an
        input of shape ``(n, length)`` reaches the model as ``(n, 1, length)``.
        """
        x_tensor = torch.from_numpy(x_numpy).float().unsqueeze(1).to(self.device)
        with torch.no_grad():
            return torch.sigmoid(self.model(x_tensor)).cpu().numpy()


# Draw one shuffled batch of 50 samples and keep only the first element of the
# batch (the same element the cells above feed to the model).
# NOTE(review): no random seed is set, so background and test sample change on
# every run; datasets with fewer than 22 samples make batch[21] fail.
validation_cnn_pm_loader = DataLoader(validation_cnn_pm, batch_size=50, shuffle=True)
batch = next(iter(validation_cnn_pm_loader))[0]

# First 20 samples serve as SHAP background; the sample at position 21 is
# explained (position 20 is unused).
background = batch[:20].numpy()
test_sample = batch[21].unsqueeze(0).numpy()
# ModelWrapper moves `model` to its device in place, which also affects the
# cells above that use `model` on the CPU.
wrapped_model = ModelWrapper(model)
explainer = shap.Explainer(wrapped_model, background, algorithm="partition")
shap_values = explainer(test_sample)

In [None]:
import matplotlib.pyplot as plt

# SHAP values of the explained sample for one model output.
CLASS_IDX = 0
shap_vals_class0 = shap_values.values[0, :, CLASS_IDX]

plt.figure(figsize=(14, 4))
# Scale an integer range instead of np.arange with a float step: the axis then
# always has exactly one point per SHAP value.
x = np.arange(shap_vals_class0.shape[0]) * 0.34507313512321336
plt.plot(x, shap_vals_class0)
# The title now names the index that is actually plotted; it used to say
# "class 1" while showing index 0.
plt.title(f"SHAP values for class index {CLASS_IDX}")
# NOTE(review): x is the feature index scaled by the step above (presumably
# keV), so "Feature index" is not an exact description - confirm and relabel.
plt.xlabel("Feature index")
plt.ylabel("SHAP value")
plt.grid(True)
plt.tight_layout()
plt.show()
