In [None]:
import pandas as pd 
import numpy as np
import torch
import os.path as osp
import math
import copy

from XASNet.data import QM9_XAS
from XASNet.utils import cam_gnn, cam_graphnet
from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet

from XASNet.utils import GraphDataProducer
from XASNet.utils import (
    GroundTruthGenerator,
    OrcaAnlyser,
    Contributions
)
from XASNet.utils import auc, plot_roc_curve
from XASNet.utils import plot_graph

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid', palette='muted', font_scale=1.5)

from pylab import rc, rcParams
rc('text', usetex=False)
rc('axes', linewidth=2)
rc('font', weight='normal')

# Matplotlib defaults shared by every figure in this notebook.
params = {
    'legend.fontsize': 17,
    'figure.figsize': (8, 6),
    'axes.labelsize': 25,
    'axes.titlesize': 25,
    'xtick.labelsize': 25,
    'ytick.labelsize': 25,
    'figure.dpi': 200,
}
rcParams.update(params)
from matplotlib.ticker import FormatStrFormatter

device = 'cpu'

# Load GNN model 

In [None]:
# Build the GATv2-based network. With the default strict loading of
# load_state_dict, these hyperparameters have to match the saved checkpoint.
xasnet_gnn = XASNet_GNN(
    gnn_name='gatv2',
    in_channels=[11, 128, 256, 512],
    out_channels=[128, 256, 512, 600],
    num_targets=100,
    num_layers=4,
    heads=3
).to(device)

# loading the saved model
# Placeholder: replace with the file name of the checkpoint stored in ./best_model.
# (Previously `model_name` was never defined, so this cell raised a NameError.)
model_name = 'name-of-saved-model-file'
path_to_model = osp.join('./best_model', model_name)

if osp.exists(path_to_model):
    # Load into the model built above (the original referenced an undefined
    # `spectragnn`). map_location remaps the stored tensors onto `device`, so a
    # checkpoint saved on another device still loads.
    xasnet_gnn.load_state_dict(
        torch.load(path_to_model, map_location=device)
    )
else:
    # The notebook continues with randomly initialised weights in this case.
    print('model is not loaded')

# Loading test data 

In [None]:
# Placeholder: point `root` at the local QM9-XAS dataset directory.
root = 'path-to-qm9xas-dataset'

test_qm9xas = QM9_XAS(
    root=root,
    raw_dir='./raw/',
    spectra=[],
)

In [None]:
# Index of the test-set molecule that is analysed in the rest of the notebook.
mol_idx = 9088

# The producer is given the model, the dataset and the index; the selected
# graph is read from its `picked_graph` attribute.
graph_picker = GraphDataProducer(
    model=xasnet_gnn,
    gnn_type="gatv2",
    test_data=test_qm9xas,
    idx_to_pick=mol_idx,
)
graph = graph_picker.picked_graph

In [None]:
# Make a prediction with the loaded model.
# y_true is the spectrum stored on the picked graph (the reference values).
y_true = graph.spectrum
# predictions() returns two values; they are used below as the x axis
# ('energies') and the predicted spectrum ('osc') of the CAM data frame.
x_pred, y_pred = graph_picker.predictions()

In [None]:
# Build one display label per atom: element symbol followed by the atom's
# position in the graph, e.g. "C 0".
atomic_num = graph.z
label_map = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
atom_labels = [
    f"{label_map[z.item()]} {i}"
    for i, z in enumerate(atomic_num)
]

# Visualisation and peak finder

In [None]:
# Import through the installed XASNet package, like every other import in this
# notebook; a bare `utils` package only resolves when the package directory
# itself happens to be on sys.path.
from XASNet.utils.visualisation import plot_prediction

In [None]:
# Plot the prediction against the reference spectrum; nothing is written to
# disk (save=False) and peak markers are turned off (add_peaks=False).
plot_prediction(
    x_pred,
    y_pred,
    y_true,
    normalise=True,
    add_peaks=False,
    save=False,
)

# Calculating CAM data based on the model and input graph

In [None]:
# CAM attributions for the picked graph, computed with the loaded model.
cam_gatv2 = cam_gnn(graph, xasnet_gnn)

In [None]:
# Data frame holding the predicted spectrum next to the CAM attribution of
# every atom: columns are 'energies', 'osc', then one column per atom label.
cam_table = np.c_[x_pred, y_pred, cam_gatv2.T]
cam_columns = ['energies', 'osc', *atom_labels]
all_cam_data = pd.DataFrame(cam_table, columns=cam_columns)

# Obtaining the ground truth 

In [None]:
# Placeholders: replace with the ORCA raw output file and the XAS spectrum
# output file of the molecule picked above.
path_orca_output = 'path-to-orca-raw-output-file'
path_orca_spectrum = 'path-to-xas-spectrum-output-file'

orca_analyzer = OrcaAnlyser(path_orca_output, path_orca_spectrum)

In [None]:
# Excitations obtained from the ORCA files; passed to Contributions below.
excitations = orca_analyzer.give_excitations()

In [None]:
# Energy of the spectral peak to analyse (281 eV for this molecule).
peak_energy = 281

contributions = Contributions(
    excitations, all_cam_data, peak_energy, atom_labels
)

In [None]:
# Obtain the core/virtual orbital contributions of the atoms to the peak,
# in this case 281 eV.
# NOTE(review): the method is named don_acc_contrs but the result is unpacked
# as (acc, don) - confirm the return order against the Contributions class.
acc, don = contributions.don_acc_contrs()
# Obtain the corresponding CAM contributions.
cam_contr = contributions.cam_contrs()

In [None]:
# Map node index -> atom label for drawing on the graph figures.
labels = {
    node_idx: atom_label
    for node_idx, atom_label in zip(np.arange(len(atom_labels)), atom_labels)
}

In [None]:
# Graph of the molecule drawn with the CAM weights.
nx_g_cam = plot_graph(
    graph,
    labels,
    cam_contr['weights'],
    save_fig=False,
    acceptor_orb=False,
)

In [None]:
# Ground truth: graph drawn with the core (donor) orbital weights.
nx_g_don = plot_graph(
    graph,
    labels,
    don['weights'],
    save_fig=False,
    don_orb=True,
)

In [None]:
# Ground truth: graph drawn with the virtual (acceptor) orbital weights.
nx_g_acc = plot_graph(
    graph,
    labels,
    acc['weights'],
    save_fig=False,
    acceptor_orb=True,
)

# AUC-ROC of XAS prediction 

In [None]:
# Same names and package path as the import cell at the top of the notebook;
# the bare `utils.auc_roc` path only resolves when the package directory
# itself is on sys.path.
from XASNet.utils import auc, plot_roc_curve

In [None]:
scores = np.asarray(cam_contr['weights'])
contributions = np.asarray(don['weights'])

In [None]:
auc_score, fpr, tpr = auc(scores, contributions)

In [None]:
# Draw the ROC curve from the false/true positive rates returned by auc().
plot_roc_curve(fpr, tpr)

In [None]:
# Display the AUC value as the cell output.
auc_score