# MUTAG Model

- Reference: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=mHSP6-RBOqCE

In [1]:
import sys
sys.path.append('..')

import numpy as np 

import torch 
from torch.nn import Linear
import torch.nn.functional as F

import torch_geometric as pyg
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

from egg_models.egg_generic_losses import activation_hook

In [2]:
# Reproducibility: seed the torch and numpy RNGs used by this notebook and
# ask cuDNN for deterministic behaviour.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
# benchmark=True lets cuDNN auto-tune and possibly select different kernels
# from run to run, which works against the deterministic flag above, so it
# must stay disabled for reproducible results.
torch.backends.cudnn.benchmark = False

In [3]:
# Download the MUTAG dataset and shuffle it in one go.
data_raw = TUDataset(root='data/TUDataset', name='MUTAG').shuffle()

# Split: the first 150 shuffled graphs are used for training, the rest for
# testing.
train_data, test_data = data_raw[:150], data_raw[150:]

# Materialise each split as a plain Python list of graphs.
train_data_list = list(train_data)
test_data_list = list(test_data)

# Per-label views of each split. They are built from the lists above so the
# per-label lists hold the very same graph objects as the full lists.
train_data_list_0 = [g for g in train_data_list if g.y.item() == 0]
train_data_list_1 = [g for g in train_data_list if g.y.item() == 1]
test_data_list_0 = [g for g in test_data_list if g.y.item() == 0]
test_data_list_1 = [g for g in test_data_list if g.y.item() == 1]

In [4]:
# Define Explainee Model 
class GCN(torch.nn.Module):
    """Graph classifier: three GCNConv layers, mean-pool readout, linear head.

    Args:
        hidden_channels: Output width of every GCNConv layer.
        in_channels: Number of input node features. Defaults to 7, the value
            this class previously hard-coded.
        num_classes: Number of output scores per graph. Defaults to 2, the
            value this class previously hard-coded.

    Note:
        Constructing the model calls ``torch.manual_seed(12345)`` so the
        initial weights are the same on every construction. This reseeds the
        global torch RNG as a side effect.
    """

    def __init__(self, hidden_channels: int, in_channels: int = 7,
                 num_classes: int = 2):
        super().__init__()
        # Fixed seed => identical weight initialisation on every construction.
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, data):
        """Return one row of unnormalised class scores per graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes.

        Returns:
            The output of the final linear layer applied to the pooled graph
            embeddings.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # 1. Obtain node embeddings (no non-linearity after the last conv).
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier; dropout is only active in train mode.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [5]:
# Train explainee model: loss, network and optimiser shared by the
# train / model_accuracy helpers in this cell.
criterion = torch.nn.CrossEntropyLoss()
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

def train(data_loader):
    """Run one optimisation pass over every batch of ``data_loader``.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``; the
    model is left in train mode. Returns nothing.
    """
    model.train()

    for graphs in data_loader:
        logits = model(graphs)
        criterion(logits, graphs.y).backward()

        # Apply the update, then clear gradients ready for the next batch.
        optimizer.step()
        optimizer.zero_grad()

def model_accuracy(data_loader):
    """Return the fraction of graphs in ``data_loader`` classified correctly.

    Uses the module-level ``model`` and leaves it in eval mode. The result is
    the number of correct predictions divided by
    ``len(data_loader.dataset)``.
    """
    model.eval()

    correct = 0
    # Evaluation never back-propagates, so disable autograd to avoid building
    # (and holding in memory) a computation graph for every batch.
    with torch.no_grad():
        for batch in data_loader:
            pred = model(batch).argmax(dim=1)
            correct += int((pred == batch.y).sum())

    return correct / len(data_loader.dataset)

# Mini-batch loaders: the training loader reshuffles on every pass, the test
# loader keeps a fixed order.
train_data_loader = pyg.loader.DataLoader(
    train_data_list, batch_size=16, shuffle=True
)
test_data_loader = pyg.loader.DataLoader(
    test_data_list, batch_size=16, shuffle=False
)

# Train for 201 epochs (0..200), reporting both accuracies after each one.
# The call order is kept as-is because the shuffled loader and dropout both
# draw from the global RNG.
for epoch in range(201):
    train(train_data_loader)
    train_accuracy = model_accuracy(train_data_loader)
    test_accuracy = model_accuracy(test_data_loader)

    print(f"Epoch: {epoch} Train Accuracy: {train_accuracy} "
          f"Test Accuracy: {test_accuracy}")


Epoch: 0 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 1 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 2 Train Accuracy: 0.6733333333333333 Test Accuracy: 0.7368421052631579
Epoch: 3 Train Accuracy: 0.74 Test Accuracy: 0.8157894736842105
Epoch: 4 Train Accuracy: 0.7066666666666667 Test Accuracy: 0.8157894736842105
Epoch: 5 Train Accuracy: 0.7066666666666667 Test Accuracy: 0.8157894736842105
Epoch: 6 Train Accuracy: 0.7466666666666667 Test Accuracy: 0.7631578947368421
Epoch: 7 Train Accuracy: 0.7466666666666667 Test Accuracy: 0.7631578947368421
Epoch: 8 Train Accuracy: 0.74 Test Accuracy: 0.7894736842105263
Epoch: 9 Train Accuracy: 0.7466666666666667 Test Accuracy: 0.7894736842105263
Epoch: 10 Train Accuracy: 0.7533333333333333 Test Accuracy: 0.7631578947368421
Epoch: 11 Train Accuracy: 0.74 Test Accuracy: 0.7894736842105263
Epoch: 12 Train Accuracy: 0.72 Test Accuracy: 0.8157894736842105
Epoch: 13 Train Accuracy: 0.74 Test Acc

In [6]:
# Extract average embeddings.
# For each class label (0 and 1) and each hooked layer, mean-pool the layer's
# activations per training graph and then average those pooled embeddings
# over all training graphs carrying that label.
activation_names = ["conv2", "conv3"]
pool_func = pyg.nn.global_mean_pool

# label -> layer name -> list of per-graph pooled embeddings
pooled_by_label = {
    label: {name: [] for name in activation_names} for label in (0, 1)
}

# The mode does not change inside the loop, so switch to eval once up front.
model.eval()

# No batch_size is passed: every batch must hold exactly one graph, because
# `batch.y.item()` below only works on a single-element tensor.
train_data_loader_1 = pyg.loader.DataLoader(train_data_list)
for batch in train_data_loader_1:
    acts, remove_hooks = activation_hook(model, activation_names)
    try:
        model(batch)
    finally:
        # Always detach the hooks, even if the forward pass raises, so they
        # do not stay attached to the model for the rest of the session.
        remove_hooks()

    label = batch.y.item()
    if label not in pooled_by_label:
        # Labels other than 0 / 1 are ignored, as before.
        continue

    for name in activation_names:
        embed = pool_func(acts[name], batch.batch)
        pooled_by_label[label][name].append(embed.detach().to('cpu'))

# Average the stacked per-graph embeddings over the graph dimension.
train_avg_embedding_0 = {
    name: torch.stack(pooled_by_label[0][name]).mean(dim=0)
    for name in activation_names
}
train_avg_embedding_1 = {
    name: torch.stack(pooled_by_label[1][name]).mean(dim=0)
    for name in activation_names
}

print(train_avg_embedding_0)
print(train_avg_embedding_1)

{'conv2': tensor([[-0.2915, -1.8095, -0.2189, -0.1180, -0.2793, -0.2710, -2.0111, -0.3490,
         -0.9480, -1.5090, -0.4723, -0.1447, -0.9340, -0.9859, -0.1834, -0.3930,
         -1.3659, -0.3043, -0.7467, -0.2003, -0.1981, -0.9850, -0.1815, -0.4189,
         -0.9779, -0.3608, -0.2256, -0.2999, -0.1910, -0.2666, -1.5029, -1.8962,
         -0.5124, -1.0498, -0.0932, -0.4114, -0.1367, -0.4498, -0.1757, -0.1389,
         -0.3753, -0.1467, -2.1132, -0.1634, -0.7426, -0.2079, -0.3401, -0.3812,
         -0.2202, -1.8615, -0.8407, -0.3386, -0.2537, -1.7592, -0.3878, -0.5448,
         -0.2084, -0.4950, -2.0607, -0.3693, -1.4028, -1.0697, -0.4403, -0.4440]]), 'conv3': tensor([[ 0.1202,  0.0301, -0.1539,  0.0836, -0.3402, -0.1375, -0.1699,  0.0640,
         -0.1472, -0.0222, -0.1640,  0.3020, -0.2556,  0.0745,  0.1479, -0.1212,
          0.0636,  0.0140, -0.0713,  0.1265, -0.1106, -0.0566,  0.0307,  0.0153,
          0.0102,  0.1602, -0.0802,  0.0850,  0.1408,  0.1291,  0.0837, -0.0738,
      

In [7]:
# Save the data. 
import os

import dill as pickle 
base_path = "../data/explainees/MUTAG/"

# Create the output directory up front so the first write cannot fail just
# because the directory is missing.
os.makedirs(base_path, exist_ok=True)

# Save model weights. 
torch.save(model.state_dict(), os.path.join(base_path, "gcn_200.pt"))

# File name -> object to pickle.
pickled_objects = {
    "MUTAG_train_data_list.pkl": train_data_list,
    "MUTAG_train_data_list_0.pkl": train_data_list_0,
    # This file was previously written as "MUTAG_train_data_list.pkl_1"
    # (suffix after the extension); renamed to follow its siblings.
    # NOTE(review): update any loader that still reads the old name.
    "MUTAG_train_data_list_1.pkl": train_data_list_1,
    "MUTAG_test_data_list.pkl": test_data_list,
    "MUTAG_test_data_list_0.pkl": test_data_list_0,
    "MUTAG_test_data_list_1.pkl": test_data_list_1,
    "MUTAG_train_avg_embedding_dict_0.pkl": train_avg_embedding_0,
    "MUTAG_train_avg_embedding_dict_1.pkl": train_avg_embedding_1,
}

for file_name, obj in pickled_objects.items():
    with open(os.path.join(base_path, file_name), 'wb') as f:
        pickle.dump(obj, f)

## Debugging

In [9]:
import sys 
sys.path.append("..")

from utils import mutag_helper
from egg_models.egg_generic import EggGeneric

2024-04-19 14:00:10.182825: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [15]:
# Reload the already-imported mutag_helper module so that edits to its source
# file are picked up without restarting the kernel.
import importlib
importlib.reload(mutag_helper)

<module 'utils.mutag_helper' from '/endosome/work/DPDS/s224833/Dissertation/gnn-egg/notebooks/../utils/mutag_helper.py'>

In [16]:
# Build an EggGeneric generator, then replace its adjacency probabilities
# with an all-zero parameter of the same shape and dtype.
# NOTE(review): the meaning of the positional arguments (5, None, (7,), None,
# (4, )) is not visible from this notebook; confirm against EggGeneric's
# signature.
generator = EggGeneric(5, None, (7,), None, (4, ), batch_size=16)
generator.AdjacencyMatrix.probs = torch.nn.Parameter(
    torch.zeros_like(generator.AdjacencyMatrix.probs)
)

In [20]:
# Call the generator, convert its output with mutag_helper.egg_to_ex and feed
# the result to the trained explainee model. The bare call on the last line
# is what makes the notebook display the model output.
# NOTE(review): egg_to_ex presumably builds an object with the x /
# edge_index / batch attributes that GCN.forward reads; confirm in
# utils/mutag_helper.py.
gen_dict = generator()
gen_ex = mutag_helper.egg_to_ex(gen_dict)
model(gen_ex)

tensor([[ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986],
        [ 0.1681, -0.3986]], grad_fn=<AddmmBackward0>)

In [1]:
import torch 

# Scratch check: broadcast the most frequent value of `test` to a tensor of
# the same shape and dtype.
test = torch.tensor([2., 2., 1.])
test2 = test.mode().values * torch.ones_like(test)
test2

tensor([2., 2., 2.])