In [2]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from mutagDataset import MUTAG

In [3]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

data = MUTAG()

num_examples = len(data)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    data, sampler=train_sampler)
test_dataloader = GraphDataLoader(
    data, sampler=test_sampler)

In [4]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv3 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')
        
        
    def forward(self, graph, inputs):
        # inputs is features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.normalize(v) for k, v in h.items()}
        h = {k: F.relu(v) for k, v in h.items()}
        
        h = self.conv2(graph, h)
        h = {k: F.relu(v) for k, v in h.items()}
        
        h = self.conv3(graph, h)
        h = {k: F.dropout(v, p=0.3) for k, v in h.items()}
        return h

class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()
        self.in_dim = in_dim
        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.linear = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, feat=None, eweight=None):
        if not feat:
            feat = {ntype:torch.zeros((g.num_nodes(ntype), self.in_dim)) for ntype in g.ntypes}
    
        h = feat    
        h = self.rgcn(g, h)

        with g.local_scope():
            g.ndata['h'] = h
    
            if eweight is None:
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            else:
                g.edata['w'] = eweight
                g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
                
            # Calculate graph representation by average readout.
            hg = 0
            for ntype in g.ntypes: 
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.linear(hg)


In [5]:
rel_names = set()

for graph in data:
    for edge in graph[0].etypes:
        rel_names.add(str(edge))

rel_names = list(rel_names)

In [6]:
model = HeteroClassifier(10, 40, 2, rel_names)

In [None]:
opt = torch.optim.Adam(model.parameters(), lr=0.015)

training_loss, test_loss = [], []
for epoch in range(50):
    num_correct, num_tests, batchLoss = 0, 0, 0
    for batched_graph, labels in train_dataloader:
        logits = model(batched_graph)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        batchLoss+= loss
        num_correct += (logits.argmax(1) == labels).sum().item()
        num_tests += len(labels)

    training_loss.append(batchLoss/num_tests)
    
    num_correct_test, num_tests_t = 0, 0 
    batchTestLoss = 0
    for batched_graph, labels in test_dataloader:
        pred = model(batched_graph)
        loss = F.cross_entropy(pred, labels)
        num_correct_test += (pred.argmax(1) == labels).sum().item()
        num_tests_t += len(labels)        
        batchTestLoss += loss

    test_loss.append(batchTestLoss/num_tests)
    
    if epoch % 5 == 0:
        print(f'Epochs {epoch}, Train accuracy: {num_correct / num_tests * 100:.2f}%, Test accuracy: {num_correct_test / num_tests_t * 100:.2f}%')
    
    



Epochs 0, Train accuracy: 62.67%, Test accuracy: 68.42%
Epochs 5, Train accuracy: 66.00%, Test accuracy: 68.42%
Epochs 10, Train accuracy: 66.00%, Test accuracy: 68.42%
Epochs 15, Train accuracy: 68.00%, Test accuracy: 68.42%
Epochs 20, Train accuracy: 67.33%, Test accuracy: 68.42%
Epochs 25, Train accuracy: 67.33%, Test accuracy: 68.42%
Epochs 30, Train accuracy: 65.33%, Test accuracy: 68.42%
Epochs 35, Train accuracy: 67.33%, Test accuracy: 68.42%
Epochs 40, Train accuracy: 68.00%, Test accuracy: 68.42%


In [None]:
import matplotlib.pyplot as plt

plt.plot(training_loss, label='train_loss')
plt.plot(test_loss, label='test_loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

In [116]:
# model = torch.save(model, './MUTAG_Trained_Model.pt')
# model = torch.load('./MUTAG_Trained_Model.pt')

In [129]:
from gnnexplainer import HeteroGNNExplainer

g = data[110][0]
feat = {ntype:torch.zeros((g.num_nodes(ntype), 10)) for ntype in g.ntypes}
explainer = HeteroGNNExplainer(model, num_hops=3, log=False)
feat_mask, edge_mask = explainer.explain_graph(g, feat)

In [1]:
model(g)

NameError: name 'model' is not defined