In [10]:
import torch
from torch_geometric.nn.models import GNNExplainer
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

In [36]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer
from tqdm import tqdm
from torch_geometric.data import Data, DataListLoader, Batch,DataLoader
from torch.utils.data import random_split,ConcatDataset


class EdgeBlock(torch.nn.Module):
    def __init__(self):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = Seq(Lin(48*2, 128), 
                            BatchNorm1d(128),
                            ReLU(),
                            Lin(128, 128))

    def forward(self, src, dest, edge_attr, u, batch=128):
        out = torch.cat([src, dest], 1)
        return self.edge_mlp(out)

class NodeBlock(torch.nn.Module):
    def __init__(self):
        super(NodeBlock, self).__init__()
        self.node_mlp_1 = Seq(Lin(48+128, 128), 
                              BatchNorm1d(128),
                              ReLU(), 
                              Lin(128, 128))
        self.node_mlp_2 = Seq(Lin(48+128, 128), 
                              BatchNorm1d(128),
                              ReLU(), 
                              Lin(128, 128))

    def forward(self, x, edge_index, edge_attr, u, batch=128):
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)

    
class GlobalBlock(torch.nn.Module):
    def __init__(self):
        super(GlobalBlock, self).__init__()
        self.global_mlp = Seq(Lin(128, 128),                               
                              BatchNorm1d(128),
                              ReLU(), 
                              Lin(128, 2))

    def forward(self, x, edge_index, edge_attr, u, batch=128):
        out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)


class InteractionNetwork(torch.nn.Module):
    def __init__(self):
        super(InteractionNetwork, self).__init__()
        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())
        self.bn = BatchNorm1d(48)
        
    def forward(self, x, edge_index, batch=128):
        
        x = self.bn(x)
        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)
        return u

In [37]:
model=InteractionNetwork()
model.load_state_dict(torch.load("../data/model/IN_best_dec10.pth",map_location=torch.device('cpu')))

<All keys matched successfully>

In [38]:
explainer=GNNExplainer(model,epochs=10)

In [100]:
def _collate(items):
    l = sum(items, [])
    return Batch.from_data_list(l)

data=[torch.load(f"/teams/DSC180A_FA20_A00/b06particlephysics/personal-alcx/data/test/processed/file_ntuple_merged_0/jet_{i}.pth")[0][0] for i in range(1)]
data=DataListLoader(data)
data.collate_fn=_collate

In [101]:
data.dataset

[Data(edge_index=[2, 306], u=[1, 2], x=[18, 48], y=[1, 2])]

In [81]:
x=data.dataset[0].x
y=data.dataset[0].y
edge_index=data.dataset[0].edge_index

data.dataset

[Data(edge_index=[2, 306], u=[1, 2], x=[18, 48], y=[1, 2])]

In [107]:
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=y)
plt.show()

TypeError: 'int' object is not subscriptable