# Analyzing the MDA trajectory

In particular, we compare the contributions with the Coulomb potential.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import schnetpack as sp
import torch.nn as nn
import numpy as np

from copy import copy, deepcopy
import networkx as nx
from scipy.spatial.distance import cdist

import torch, numpy
import matplotlib.pyplot as plt

from symbxai.lrp.symbolic_xai import SchNetSymbXAI

from tqdm import tqdm

from numpy import genfromtxt


In [None]:
# Global lookup tables and per-dataset registries shared by the cells below.

# Atomic number -> element symbol.
atom_names_dict = dict(zip((1, 6, 7, 8, 9), ("H", "C", "N", "O", "F")))

# Registries keyed by a short dataset/model tag (e.g. 'mda').
models, datasets, target_props, cutoff = {}, {}, {}, {}

# NOTE(review): presumably the kcal/mol-per-eV factor (1 eV ~ 23.06 kcal/mol);
# confirm the direction of use before relying on the name.
kcal2eV_scal = 23.060541945329334

# Load MDA Model

In [None]:
# Load the trained MDA model onto the CPU.
# The checkpoint is a fully pickled nn.Module (its `.representation` is read
# below), so it must be unpickled with weights_only=False: torch>=2.6 defaults
# to weights_only=True and rejects such files. Only do this for trusted
# checkpoints. (The `weights_only` keyword requires torch>=1.13.)
mdamodel_file = '../saved_models/mda_schnorb_model_v2/best_model'
mdamodel = torch.load(mdamodel_file,
                      map_location=torch.device('cpu'),
                      weights_only=False)

# Cutoff of the model's representation; reused for the ASE neighbour lists
# when converting structures.
cutoff['mda'] = mdamodel.representation.cutoff_fn.cutoff.item()
models['mda'] = mdamodel

# Name of the model output that is explained.
target_props['mda'] = 'energy'
model = mdamodel


# Load the Data - MDA Trajectory

In [None]:
from ase.io import read

from schnetpack.interfaces.ase_interface import AtomsConverter

# All frames of the extracted MD window (frames 88300-88800, judging by the
# file name).
ats = read('data/mda_extracted_88300_88800.xyz', index=":")

# Turns ASE Atoms into SchNetPack input dicts; the neighbour list uses the
# model's own cutoff.
converter = AtomsConverter(
    neighbor_list=sp.transform.ASENeighborList(cutoff=cutoff['mda']),
    device="cpu",
    dtype=torch.float32,
)

mdatraj = list(map(converter, ats))
datasets['mdatraj'] = mdatraj

# Load MDA in its equilibrium state

In [None]:
from ase.io import read

from schnetpack.interfaces.ase_interface import AtomsConverter

# Structure(s) of MDA at equilibrium.
ats = read('data/equi_mda.extxyz', index=":")

# Same conversion settings as for the trajectory: CPU, float32, and a
# neighbour list built with the model cutoff.
converter = AtomsConverter(
    neighbor_list=sp.transform.ASENeighborList(cutoff=cutoff['mda']),
    device="cpu",
    dtype=torch.float32,
)

emda = list(map(converter, ats))
datasets['equi_mda'] = emda

## Visualize one MDA molecule

In [None]:
from symbxai.visualization.qc_utils import vis_mol_2d

# Draw the first trajectory frame in 2D, labelled with atom indices, as a
# reference for the per-atom results below.
sample = datasets['mdatraj'][0]
anum = sample['_atomic_numbers'].data.numpy()
pos = sample['_positions'].data.numpy()

fig, ax = plt.subplots(figsize=(7, 7))
vis_mol_2d(ax, anum, pos, projdim=0, with_atom_id=True)
plt.tight_layout()
plt.show()

In [None]:
def find_const_approx(vals):
    """Return the split index of the best two-piece constant approximation.

    Every split ``k`` partitions ``vals`` into ``vals[:k]`` and ``vals[k:]``.
    Each part is approximated by its mean. The returned split minimises the
    total sum of squared errors, ``len(part) * var(part)`` summed over both
    parts. Both parts are non-empty, so ``k`` ranges over
    ``1 .. len(vals) - 1``. On ties the earliest split wins.

    Args:
        vals: 1-D sequence of numbers (a list or a numpy array).

    Returns:
        The optimal split index, or 0 when ``vals`` has fewer than two
        elements (no split into two non-empty parts exists).
    """
    best_split, best_cost = 0, float('inf')
    # The range end is len(vals), not len(vals) - 1, so that a single trailing
    # element may form the right part. This mirrors the single leading element
    # already allowed for the left part at split_id == 1.
    for split_id in range(1, len(vals)):
        left, right = vals[:split_id], vals[split_id:]
        cost = len(left) * numpy.var(left) + len(right) * numpy.var(right)
        if cost < best_cost:
            # Found a new optimum.
            best_split, best_cost = split_id, cost
    return best_split

## Total energy

In [None]:
# Predict the energy of every trajectory frame and plot it across the frames.
# The shallow copy presumably protects the stored frame from entries the model
# writes into its input dict.
all_outs = [model(copy(sample))['energy'].detach().numpy() for sample in mdatraj]

# Scalar mean, computed once. Previously sum(...)/len(...) produced a
# shape-(1,) ndarray that was passed to plt.ylim.
mean_out = float(np.mean(all_outs))

plt.figure(figsize=(4, 2))
plt.plot(all_outs, lw=12, color='black')
# Fixed window of height 2 around the mean prediction.
plt.ylim(mean_out - 1, mean_out + 1)
plt.xticks([])
plt.yticks([])
# plt.savefig('pics/qc_prediction_change_fig1.svg', transparent=True)
plt.show()

In [None]:
from symbxai.visualization.utils import vis_barh_query

# Size of the jump between the two constant pieces that best fit the energy
# curve: |mean(before split) - mean(after split)|.
osplit_id = find_const_approx(all_outs)
before, after = all_outs[:osplit_id], all_outs[osplit_id:]
split_dist = abs(np.mean(before) - np.mean(after))
vis_barh_query({'energy': split_dist}, xlim=(0, 1), filename=None)

# Experiment 1 - Visualize all first-order contributions (classic XAI)

# Node contributions

In [None]:

# Number of top-ranked items shown in the bar plots.
stop_top = 3

model_mode = 'mda'
model = models[model_mode]
# Gamma parameter forwarded to SchNetSymbXAI.
gamma = .0

# Resolve atom index -> element symbol against a fixed reference frame. The
# previous lambda read the global `sample`, which each `for sample in ...`
# loop rebinds. Assumes every frame is the same molecule; torch.stack below
# already requires equal atom counts.
ref_atomic_numbers = mdatraj[0]['_atomic_numbers']


def ida2atnnum_str(ida):
    """Return the element symbol (e.g. 'C') of atom index ``ida``."""
    return atom_names_dict[ref_atomic_numbers[ida].item()]


# First-order node relevances of the target, one row per frame.
all_lrp_contr = []
for sample in mdatraj:
    explainer = SchNetSymbXAI(copy(sample),
                              models[model_mode],
                              target_props[model_mode],
                              gamma=gamma)
    all_lrp_contr.append(explainer.node_relevance())

all_lrp_contr = torch.stack(all_lrp_contr)

# For each atom (column i), measure the jump of its relevance series across
# the best two-piece constant fit.
node_cp_dist = {}
for i in range(all_lrp_contr.shape[1]):
    rels = all_lrp_contr[:, i].numpy()
    osplit_id = find_const_approx(rels)
    split_dist = abs(np.mean(rels[:osplit_id]) - np.mean(rels[osplit_id:]))
    node_cp_dist[ida2atnnum_str(i) + f'$_{i}$'] = split_dist

print('Distance of the constant approx')
node_cp_dist = dict(sorted(node_cp_dist.items(), key=lambda item: item[1], reverse=True))
vis_barh_query(dict(list(node_cp_dist.items())[:stop_top]), xlim=(0, 1),
               filename=None)


In [None]:

# Plot the relevance series of the `stop_top` highest-ranked atoms.
topkeys = list(node_cp_dist)[:stop_top]
for i in range(all_lrp_contr.shape[1]):
    label = ida2atnnum_str(i) + f'$_{i}$'
    if label not in topkeys:
        continue

    series = all_lrp_contr[:, i]
    lo, hi = min(series), max(series)

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))
    print(label)
    plt.plot(series, lw=12, color='black')
    # y-window of fixed height 2 centred on the data range.
    margin = (2 - (hi - lo)) / 2
    plt.ylim([lo - margin, hi + margin])
    plt.xticks([])
    plt.yticks([])
    # plt.savefig(f'pics/qc_node_change_fig1_{i}.svg', transparent=True)
    plt.show()

# Experiment 2 - Find the reaction variable using SymbXAI

## Step 1: Compute all Harsanyi Dividends

In [None]:
from symbxai.utils import powerset

# Harsanyi dividends of every atom subset up to size `max_order`, per frame.
max_order = 1

# Take the atom count from the data instead of a hard-coded 9. The subsets are
# materialised as a list because they are iterated once per frame here and
# again when aggregating queries (in case `powerset` returns a one-shot
# iterator, it would otherwise be exhausted after the first frame).
n_atoms = len(mdatraj[0]['_atomic_numbers'])
all_sets = list(powerset(range(n_atoms), K=max_order))

all_hars_div = []
for sample in tqdm(mdatraj):
    explainer = SchNetSymbXAI(copy(sample),
                              models[model_mode],
                              target_props[model_mode],
                              gamma=gamma)
    hars_div = [explainer.harsanyi_div(S) for S in all_sets]
    all_hars_div.append(torch.tensor(hars_div))

# Shape: (n_frames, len(all_sets)).
all_hars_div = torch.stack(all_hars_div)

In [None]:
import pickle
# pickle.dump(all_hars_div, open(f'intermediate_results/query_search_algo/hars_mda_traj_max_order{max_order}.pkl', 'wb'))

In [None]:
from symbxai.query_search.utils import setup_queries

max_setsize = 1
max_and_order = 1
max_indexdist = float('inf')  # irrelevant anyway when max_setsize == 1
query_mode = 'conj. disj. reasonably mixed'  # alternative: 'conj. disj. (neg. disj.) reasonably mixed'

# Atom tokens such as 'C0'. `explainer` is the last one created in the
# Harsanyi-dividend loop; presumably all frames share the same node domain.
tokens = [ida2atnnum_str(i) + f'{i}' for i in explainer.node_domain]

all_queries = setup_queries(explainer.node_domain,
                            tokens,
                            max_and_order,
                            max_setsize=max_setsize,
                            max_indexdist=max_indexdist,
                            mode=query_mode,
                            repres_style='Latex')

# A query's attribution in a frame is the sum of the Harsanyi dividends of the
# subsets it selects. query(S) depends only on S (assumed side-effect free),
# so match each query against the subsets once and reuse the indices for
# every frame instead of re-evaluating it per frame.
all_attributions = []
for query in all_queries:
    matched = [i for i, S in enumerate(all_sets) if query(S)]
    all_attributions.append(torch.tensor(
        [sum([hars_div[i] for i in matched]) for hars_div in all_hars_div]))

# Shape: (n_queries, n_frames).
all_attributions = torch.stack(all_attributions)

In [None]:
# Rank the queries by the jump in their attribution series and show the top
# `stop_top`.
stop_top = 3
query_cp_dist = {}
for query, attr_series in zip(all_queries, all_attributions):
    rels = attr_series.numpy()
    osplit_id = find_const_approx(rels)
    split_dist = abs(np.mean(rels[:osplit_id]) - np.mean(rels[osplit_id:]))
    query_cp_dist[query.str_rep] = split_dist

ranked = sorted(query_cp_dist.items(), key=lambda item: abs(item[1]), reverse=True)
query_cp_dist = dict(ranked)
top_query_dists = dict(ranked[:stop_top])
vis_barh_query(top_query_dists, xlim=(0, 1),
               filename=None)

In [None]:

# Plot the attribution series of the top-ranked queries.
for i, query in enumerate(all_queries):
    if query.str_rep not in top_query_dists:
        continue

    series = all_attributions[i]
    lo, hi = min(series), max(series)

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))
    print(query.str_rep)
    plt.plot(series, lw=12, color='black')
    # y-window of fixed height 2 centred on the data range.
    margin = (2 - (hi - lo)) / 2
    plt.ylim([lo - margin, hi + margin])
    plt.xticks([])
    plt.yticks([])
    # plt.savefig(f'pics/qc_query_change_fig1_{query.str_rep}.svg', transparent=True)
    plt.show()

In [None]:
from itertools import combinations

# Pairwise interatomic distances for every frame.
# `atom_pairs` lists each unordered pair (i, j) with i < j in row-major order.
# Column k of `all_dists` holds the distance of atom_pairs[k] across the
# frames. The atom count comes from the first frame rather than from whatever
# `sample` an earlier loop left behind.
n_atoms = len(mdatraj[0]['_atomic_numbers'])
atom_pairs = list(combinations(range(n_atoms), 2))
pair_rows, pair_cols = torch.tensor(atom_pairs, dtype=torch.long).reshape(-1, 2).T

all_dists = []
for sample in mdatraj:
    # Detach in case an earlier model call enabled grad on the shared
    # positions tensor; the columns are converted with .numpy() later.
    pos = sample['_positions'].detach()
    dists = torch.cdist(pos, pos)
    all_dists.append(dists[pair_rows, pair_cols])

# Shape: (n_frames, n_pairs).
all_dists = torch.stack(all_dists)


In [None]:
# Rank atom pairs by the jump in their distance series and show the top
# `show_top`.
show_top = 10
apair_dist_cp_dist = {}
for i, (a1, a2) in enumerate(atom_pairs):
    rels = all_dists[:, i].numpy()
    osplit_id = find_const_approx(rels)
    split_dist = abs(np.mean(rels[:osplit_id]) - np.mean(rels[osplit_id:]))
    str_rep = f'{ida2atnnum_str(a1)}$_{a1}$ - {ida2atnnum_str(a2)}$_{a2}$'
    apair_dist_cp_dist[str_rep] = split_dist

ranked_pairs = sorted(apair_dist_cp_dist.items(), key=lambda item: item[1], reverse=True)
apair_dist_cp_dist = dict(ranked_pairs)
top_dist_dists = dict(ranked_pairs[:show_top])
vis_barh_query(top_dist_dists)