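# Fidelity

Evaluates the fidelity of GNNExplainer edge explanations for a trained GNN on the test split of the configured dataset. Fidelity+ is the drop in the explained class probability when the explanation edges are masked out; Fidelity- is the drop when only the explanation edges are kept.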

In [1]:
import os
import hydra
import torch
import shutil
import warnings
from tqdm import tqdm
import numpy as np
from torch.optim import Adam
from omegaconf import OmegaConf
from utils import check_dir
from gnnNets import *
from dataset import get_dataset, get_dataloader
from plot_functions import concept_gradient_importance
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch_geometric.nn.models.explainer import clear_masks, set_masks
import logging
import torch_geometric
import numpy as np
import captum.attr._utils.common

import networkx as nx
import torch_geometric
import matplotlib.pyplot as plt

from dig.xgraph.method import GNNExplainer, DeepLIFT
from dig.xgraph.evaluation import XCollector
from dig.xgraph.method.subgraphx import PlotUtils

from torch_geometric.nn.models.explainer import (
    Explainer,
    clear_masks,
    set_masks,
)

from rdkit import Chem
from torch_geometric.data import Data
from visualize import visualize_attrs
from dig.xgraph.dataset.mol_dataset import *

from time import sleep

from PIL import Image
import pylab as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Compose the experiment configuration with Hydra's compose API
# (config_path="config", config_name="config", no command-line overrides).
from hydra import compose, initialize
from omegaconf import OmegaConf
# Clear Hydra's global state first so that re-running this cell can call `initialize` again.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="config", job_name="test_app")
cfg = compose(config_name="config", overrides=[])
# Show the fully composed configuration.
print(OmegaConf.to_yaml(cfg))




models:
  gnn_saving_dir: ''
  gnn_name: gat
  n_heads: 3
  param:
    hiv:
      learning_rate: 0.001
      weight_decay: 0.0005
      milestones: None
      gamma: None
      batch_size: 64
      num_epochs: 200
      num_early_stop: 20
      gnn_latent_dim:
      - 128
      - 128
      - 128
      gnn_dropout: 0.0
      add_self_loop: true
      gcn_adj_normalization: true
      gnn_emb_normalization: false
      graph_classification: true
      node_classification: false
      gnn_nonlinear: relu
      readout: sum
      fc_latent_dim:
      - 128
      fc_dropout: 0.0
      fc_nonlinear: relu
    sider:
      learning_rate: 0.0001
      weight_decay: 0.0005
      milestones: None
      gamma: None
      batch_size: 64
      num_epochs: 200
      num_early_stop: 30
      gnn_latent_dim:
      - 128
      - 128
      - 128
      gnn_dropout: 0.0
      add_self_loop: true
      gcn_adj_normalization: false
      gnn_emb_normalization: false
      graph_classification: true
      nod

In [3]:
# Work on a deep copy: the next cell edits `config` in place, and with a plain alias
# those edits would also rewrite `cfg`, so re-running this cell would not reset `config`.
import copy

config = copy.deepcopy(cfg)

In [4]:
# Configure the run (checkpoint dir, dataset-specific hyper-parameters, device),
# load the dataset and build the dataloader arguments.
config.models.gnn_saving_dir = 'gnn_checkpoints'

# `config.models.param` initially maps dataset name -> hyper-parameters; narrow it to the
# selected dataset. The guard keeps the cell re-runnable: once narrowed, the node already
# holds the hyper-parameters and indexing it by dataset name again would raise a KeyError.
if 'graph_classification' not in config.models.param:
    config.models.param = config.models.param[config.datasets.dataset_name]

device = (torch.device('cuda', index=config.device_id) if torch.cuda.is_available()
          else torch.device('cpu'))

dataset = get_dataset(dataset_root=config.datasets.dataset_root,
                      dataset_name=config.datasets.dataset_name)
# Node features as float; labels squeezed (singleton dims dropped) and cast to long.
dataset.data.x = dataset.data.x.float()
dataset.data.y = dataset.data.y.squeeze().long()

if not config.models.param.graph_classification:
    # `dataloader_params` is only defined for graph classification; fail here with a clear
    # message instead of a NameError (or a stale value) when the dataloader is built.
    raise NotImplementedError('This notebook only supports graph classification '
                              '(config.models.param.graph_classification is False).')

dataloader_params = {'batch_size': 1,  # one molecule per batch
                     'stratified': config.stratified,
                     'random_split_flag': config.datasets.random_split_flag,
                     'data_split_ratio': config.datasets.data_split_ratio,
                     'seed': config.datasets.seed}

In [5]:
# Load the configured dataset from its root directory.
# NOTE(review): this reload replaces the dataset object built in the setup cell, so the
# float/long casts applied there to `dataset.data.x` / `dataset.data.y` are discarded.
# The fidelity loop rebuilds each molecule from SMILES, but confirm the reload is intended.
dataset = get_dataset(
    dataset_root=config.datasets.dataset_root,
    dataset_name=config.datasets.dataset_name,
)

In [6]:
# Build the split loaders (keyed by split name) and record the dataset indices of the
# test split; presumably the test loader wraps a `Subset`, which exposes `.indices`.
dataloader = get_dataloader(dataset, **dataloader_params)
test_loader = dataloader['test']
test_indices = test_loader.dataset.indices

In [7]:
# Instantiate the GNN architecture described by `config.models` (weights are loaded later).
model = get_gnnNets(
    dataset.num_node_features,
    dataset.num_classes,
    config.models,
    config.concept_whitening,
    concept_acts=False,
)

# With concept whitening enabled, replace the model's normalisation layers
# (the printed model shows IterNormRotation modules in `norm_layers`).
if config.concept_whitening:
    model.replace_norm_layers()

In [8]:
# Load the trained weights into `model` and switch it to evaluation mode.
# NOTE(review): the directory used to be hard-coded to `bbbp`; it now follows the configured
# dataset, assuming checkpoints live under `trained_models/<dataset_name>/`. The `_cw_max`
# suffix is kept regardless of `config.concept_whitening` — confirm that is intended.
checkpoint_path = os.path.join(
    '/home/michelaproietti/thesis_last/trained_models',
    config.datasets.dataset_name,
    f'{config.models.gnn_name}_cw_max.pth',
)
# `map_location` lets a checkpoint saved on GPU be loaded on a CPU-only machine.
saved = torch.load(checkpoint_path, map_location=device)
state_dict = saved['net']
model.load_state_dict(state_dict)
model.to(device)
model.eval()

GCNNet(
  (readout_layer): GNNPool()
  (convs): ModuleList(
    (0): GATConv(9, 128, heads=3)
    (1): GATConv(384, 128, heads=3)
    (2): GATConv(384, 128, heads=3)
  )
  (norm_layers): ModuleList(
    (0): IterNormRotation(
      384, num_channels=384, T=5, eps=1e-05, momentum=0.2, affine=False
      (topkpool): TopKPooling(384, ratio=0.5, multiplier=1.0)
    )
    (1): IterNormRotation(
      384, num_channels=384, T=5, eps=1e-05, momentum=0.2, affine=False
      (topkpool): TopKPooling(384, ratio=0.5, multiplier=1.0)
    )
    (2): IterNormRotation(
      384, num_channels=384, T=5, eps=1e-05, momentum=0.2, affine=False
      (topkpool): TopKPooling(384, ratio=0.5, multiplier=1.0)
    )
  )
  (mlps): ModuleList(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=2, bias=True)
  )
)

In [9]:
def data_from_smiles(smiles, y):
    """Build a PyG ``Data`` graph from a SMILES string.

    Atom and bond properties are encoded as indices into the ``x_map`` / ``e_map``
    vocabularies (presumably provided by the ``dig.xgraph.dataset.mol_dataset`` star import).

    Args:
        smiles: SMILES string of the molecule.
        y: label stored unchanged on the returned ``Data``.

    Returns:
        ``None`` if RDKit cannot parse ``smiles``; otherwise a ``Data`` with
        ``x`` (num_atoms x 9, long), ``edge_index`` (2 x 2*num_bonds, long, both directions
        of every bond, ordered by source then target), ``edge_attr`` (2*num_bonds x 3, long),
        and ``y``, ``smiles`` and ``mol`` attached.

    Raises:
        An error from ``.index`` (presumably ``ValueError``) when a property value is
        not present in its vocabulary.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    # One row of 9 categorical indices per atom, in a fixed column order.
    atom_rows = [
        [
            x_map['atomic_num'].index(atom.GetAtomicNum()),
            x_map['chirality'].index(str(atom.GetChiralTag())),
            x_map['degree'].index(atom.GetTotalDegree()),
            x_map['formal_charge'].index(atom.GetFormalCharge()),
            x_map['num_hs'].index(atom.GetTotalNumHs()),
            x_map['num_radical_electrons'].index(atom.GetNumRadicalElectrons()),
            x_map['hybridization'].index(str(atom.GetHybridization())),
            x_map['is_aromatic'].index(atom.GetIsAromatic()),
            x_map['is_in_ring'].index(atom.IsInRing()),
        ]
        for atom in molecule.GetAtoms()
    ]
    x = torch.tensor(atom_rows, dtype=torch.long).view(-1, 9)

    # Every bond becomes two directed edges that share the same 3 attributes.
    directed_edges, directed_attrs = [], []
    for bond in molecule.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_row = [
            e_map['bond_type'].index(str(bond.GetBondType())),
            e_map['stereo'].index(str(bond.GetStereo())),
            e_map['is_conjugated'].index(bond.GetIsConjugated()),
        ]
        directed_edges.extend([[begin, end], [end, begin]])
        directed_attrs.extend([bond_row, bond_row])

    edge_index = torch.tensor(directed_edges).t().to(torch.long).view(2, -1)
    edge_attr = torch.tensor(directed_attrs, dtype=torch.long).view(-1, 3)

    # Order edges by (source, target) using a single flattened sort key.
    if edge_index.numel() > 0:
        order = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index = edge_index[:, order]
        edge_attr = edge_attr[order]

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                smiles=smiles, mol=molecule)

In [10]:
# DIG metric collector.
# NOTE(review): `x_collector` is not used in the cells shown here; keep it only if a later cell relies on it.
x_collector = XCollector()

In [11]:
# Fidelity of GNNExplainer explanations on the test split:
#   Fidelity+ = p(G) - p(G with the explanation edges masked out)
#   Fidelity- = p(G) - p(G with only the explanation edges kept)
# where p is the softmax probability of the explained class.
dst = '/home/michelaproietti/thesis_last/fidelity/'
dst_explanations = dst + 'explanations/'
# makedirs also creates any missing parent directories and tolerates existing ones.
os.makedirs(dst_explanations, exist_ok=True)

explainer_type = 'gnnexplainer'  # Other explainer types can be implemented
if explainer_type != 'gnnexplainer':
    # Fail before the long loop rather than using an undefined (or the previous
    # molecule's stale) edge mask inside it.
    raise ValueError(f"Unsupported explainer_type {explainer_type!r}; "
                     "only 'gnnexplainer' is implemented.")

# GNNExplainer settings (same values as before, now named).
EXPLAINER_EPOCHS = 100
EXPLAINER_LR = 0.01
EXPLAINER_SPARSITY = 0.8
# Class whose probability and explanation mask are used. 0 reproduces the original
# behaviour; set to None to use each molecule's predicted class instead. With a fixed
# class, molecules predicted as another class contribute scores of the opposite sign.
TARGET_CLASS = 0

CW = config.concept_whitening
plus_fidelity_scores, minus_fidelity_scores = [], []
skipped_smiles = []  # test molecules whose SMILES RDKit could not parse


def _class_probabilities(graph, edge_mask=None):
    """Return the softmax class probabilities of the single graph `graph`.

    If `edge_mask` is given it is installed on the model's message-passing layers
    (without a sigmoid) for this forward pass only; masks are always cleared afterwards.
    """
    if edge_mask is None:
        clear_masks(model)
    else:
        set_masks(model, torch.nn.Parameter(edge_mask), graph.edge_index, apply_sigmoid=False)
    try:
        with torch.no_grad():  # pure inference: no autograd graph needed
            logits = model(graph)
    finally:
        clear_masks(model)
    return torch.softmax(logits, dim=-1)[0]


for data in tqdm(dataset[test_indices]):
    # Rebuild the molecule graph from its SMILES string.
    data_new = data_from_smiles(data.smiles, data.y)
    if data_new is None:
        skipped_smiles.append(data.smiles)
        continue
    data_new.x = data_new.x.float()
    data_new = data_new.to(device)

    probs = _class_probabilities(data_new)
    explained_class = int(probs.argmax()) if TARGET_CLASS is None else TARGET_CLASS
    pred = probs[explained_class].item()

    # A new explainer is constructed for every molecule, as before.
    gnn_explainer = GNNExplainer(model, epochs=EXPLAINER_EPOCHS, lr=EXPLAINER_LR,
                                 explain_graph=True)
    gnn_explainer.device = device
    # NOTE(review): overrides DIG's `_to_log_prob` hook to return `x[0]` unchanged; confirm intent.
    gnn_explainer._to_log_prob = lambda x: x[0]
    clear_masks(model)
    edge_mask, hard_edge_mask, related_preds = gnn_explainer(
        data_new.x, data_new.edge_index,
        sparsity=EXPLAINER_SPARSITY,
        num_classes=dataset.num_classes)

    # NOTE(review): assumes `edge_mask` is indexed by class label and holds values in [0, 1].
    # The masks are applied without a sigmoid, so if DIG returns logits here, `1 - mask`
    # is not a complement (consider `hard_edge_mask` or applying a sigmoid).
    minus_edge_mask = edge_mask[explained_class].to(device)  # keep only the explanation
    plus_edge_mask = 1.0 - minus_edge_mask                   # mask the explanation out

    pred_plus = _class_probabilities(data_new, plus_edge_mask)[explained_class].item()
    plus_fidelity_scores.append(pred - pred_plus)

    pred_minus = _class_probabilities(data_new, minus_edge_mask)[explained_class].item()
    minus_fidelity_scores.append(pred - pred_minus)

if skipped_smiles:
    print(f'Skipped {len(skipped_smiles)} test molecule(s) whose SMILES could not be parsed.')

	baddbmm(Number beta, Tensor input, Number alpha, Tensor batch1, Tensor batch2, *, Tensor out)
Consider using one of the following signatures instead:
	baddbmm(Tensor input, Tensor batch1, Tensor batch2, *, Number beta, Number alpha, Tensor out) (Triggered internally at  /opt/conda/conda-bld/pytorch_1646756402876/work/torch/csrc/utils/python_arg_parser.cpp:1055.)
  Sigma = torch.baddbmm(eps, P[0], 1. / m, xc, xc.transpose(1, 2)) # In the paper: 1/n *(Z-mu*1^T)(Z-mu*1^T)^T
100%|██████████| 205/205 [18:28<00:00,  5.41s/it]


In [12]:
# Mean and (population) standard deviation of each fidelity score over the evaluated molecules.
plus_mean, plus_std = np.mean(plus_fidelity_scores), np.std(plus_fidelity_scores)
minus_mean, minus_std = np.mean(minus_fidelity_scores), np.std(minus_fidelity_scores)
print(f'Fidelity+: {plus_mean} pm {plus_std}')
print(f'Fidelity-: {minus_mean} pm {minus_std}')

Fidelity+: -0.2630177048926458 pm 0.5601574488888476
Fidelity-: 0.23757027873328637 pm 0.3820786015093287
