In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# For reproducibility
# torch.manual_seed(42)
# np.random.seed(42)


In [None]:
import torch
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='MUTAG')

# Dataset-level overview (leading '' reproduces the blank separator line).
dataset_report = [
    '',
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
]
print('\n'.join(dataset_report))

data = dataset[0]  # First graph of the dataset, used as a running example.

# Per-graph statistics for the example graph.
graph_report = [
    '',
    str(data),
    '=============================================================',
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
]
print('\n'.join(graph_report))


In [None]:
from torch_geometric.utils import to_networkx
import matplotlib.colors as colors
import matplotlib.cm as cmx
import networkx as nx
import matplotlib.pyplot as plt
import networkx as nx
def plotmutag(data):
    """Draw a single MUTAG molecule with nodes coloured by atom type.

    Args:
        data: a PyG graph whose ``x`` rows are node feature vectors; the atom
            type of each node is taken as ``argmax`` over its row. Seven types
            (indices 0-6) are assumed, matching the colour legend below.
    """
    # Colour i of the colormap corresponds to atom-type index i.
    cmap = colors.ListedColormap(['blue', 'black', 'red', 'yellow', 'orange', 'green', 'purple'])
    color_legend = {'Carbon': 0, 'Nitrogen': 1, 'Oxygen': 2, 'Fluorine': 3,
                    'Iodine': 4, 'Chlorine': 5, 'Bromine': 6}
    norm = colors.Normalize(vmin=0, vmax=6)
    scalar_map = cmx.ScalarMappable(norm=norm, cmap=cmap)

    # matplotlib/networkx expect array-likes for colours, so hand them numpy.
    atom_types = torch.argmax(data.x, dim=1).cpu().numpy()
    graph = to_networkx(data, to_undirected=True)

    # Fresh figure per call instead of a fixed figure number, so repeated calls
    # never draw on top of an earlier, still-open figure.
    fig, ax = plt.subplots(figsize=(8, 8))
    # Zero-length proxy lines exist only to populate the legend.
    for atom_name, type_idx in color_legend.items():
        ax.plot([0], [0], color=scalar_map.to_rgba(type_idx), label=atom_name)
    nx.draw_networkx(graph, node_size=150, node_color=atom_types, cmap=cmap,
                     vmin=0, vmax=6, with_labels=False, ax=ax)
    ax.legend(fontsize=12, loc='best')
    plt.show()



In [None]:
# Split dataset into train and test (80% train, 20% test).
# A seeded generator makes the split reproducible across kernel restarts, so
# the held-out graphs (and every metric reported on them) stay the same.
SPLIT_SEED = 42
num_train = int(0.8 * len(dataset))
num_test = len(dataset) - num_train
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [num_train, num_test], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Create DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCNEncoder(torch.nn.Module):
    """Three-layer GCN encoder with mean pooling and batch-norm on the pooled vector.

    ``forward`` returns ``(graph_embedding, node_embeddings)`` where
    ``node_embeddings`` is the raw output of the last convolution and
    ``graph_embedding`` is its per-graph mean, batch-normalised and dropped out.

    NOTE: this class is redefined in a later cell; the later definition is the
    one in effect when the model is instantiated.
    """

    def __init__(self, inputdim, hidden_channels):
        super().__init__()
        # Fixed seed so the weights initialise identically on every run
        # (side effect: this also resets the global torch RNG).
        torch.manual_seed(12345)
        self.conv1 = GCNConv(inputdim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.bn = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(0.5)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x, edge_index, batch, edge_weight=None):
        hidden = x
        # Hidden layers: convolution -> LeakyReLU -> dropout.
        for conv in (self.conv1, self.conv2):
            hidden = self.dropout(self.leaky_relu(conv(hidden, edge_index, edge_weight=edge_weight)))
        # Output layer has no activation; these are the node-level embeddings.
        node_embeddings = self.conv3(hidden, edge_index, edge_weight=edge_weight)

        pooled = self.bn(global_mean_pool(node_embeddings, batch))
        graph_embedding = F.dropout(pooled, p=0.5, training=self.training)
        return graph_embedding, node_embeddings

class LinearClassifier(torch.nn.Module):
    """Single affine layer mapping a graph embedding to class logits.

    NOTE: redefined in a later cell as a two-layer head.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits

class CombinedModel(torch.nn.Module):
    """Graph classifier: ``GCNEncoder`` followed by ``LinearClassifier``.

    ``forward`` returns ``(logits, node_embeddings)``; the pooled graph
    embedding is consumed by the classifier and not returned.

    NOTE: redefined in a later cell.
    """

    def __init__(self, inputdim, hidden_channels, num_classes):
        super().__init__()
        self.encoder = GCNEncoder(inputdim, hidden_channels)
        self.classifier = LinearClassifier(input_dim=hidden_channels, num_classes=num_classes)

    def forward(self, x, edge_index, batch=None, edge_weight=None):
        pooled, per_node = self.encoder(x, edge_index, batch, edge_weight)
        return self.classifier(pooled), per_node


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool  # Changed from mean to add

class GCNEncoder(torch.nn.Module):
    """Three-layer GCN encoder with sum pooling.

    ``forward`` returns ``(graph_embedding, node_embeddings)``: the node
    embeddings are the last convolution's output, and the graph embedding is
    their ``global_add_pool`` per graph in ``batch``.
    """

    def __init__(self, inputdim, hidden_channels):
        super().__init__()
        # Fixed seed so the weights initialise identically on every run
        # (side effect: this also resets the global torch RNG).
        torch.manual_seed(12345)
        self.conv1 = GCNConv(inputdim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, edge_index, batch, edge_weight=None):
        hidden = x
        # Two hidden layers with LeakyReLU; third argument of GCNConv is edge_weight.
        for conv in (self.conv1, self.conv2):
            hidden = self.leaky_relu(conv(hidden, edge_index, edge_weight))
        node_embeddings = self.conv3(hidden, edge_index, edge_weight)
        # Sum pooling (the earlier variant of this encoder used mean pooling).
        return global_add_pool(node_embeddings, batch), node_embeddings

class LinearClassifier(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Dropout(p=0.3) -> Linear.

    Despite the name this is not a single linear layer; the hidden width
    equals ``input_dim``.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        hidden = F.dropout(F.relu(self.fc1(x)), p=0.3, training=self.training)
        return self.fc2(hidden)

class CombinedModel(torch.nn.Module):
    """Graph classifier: ``GCNEncoder`` followed by ``LinearClassifier``.

    ``forward`` returns ``(logits, node_embeddings)``; the pooled graph
    embedding is consumed by the classifier and not returned.
    """

    def __init__(self, inputdim, hidden_channels, num_classes):
        super().__init__()
        self.encoder = GCNEncoder(inputdim, hidden_channels)
        self.classifier = LinearClassifier(input_dim=hidden_channels, num_classes=num_classes)

    def forward(self, x, edge_index, batch=None, edge_weight=None):
        pooled, per_node = self.encoder(x, edge_index, batch, edge_weight)
        return self.classifier(pooled), per_node


In [None]:

# Derive input/output sizes from the dataset instead of hard-coding them
# (previously fixed at 7 features and 2 classes for MUTAG).
num_features = dataset.num_features
inputdim = num_features
model = CombinedModel(inputdim, hidden_channels=64, num_classes=dataset.num_classes)


In [None]:

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Multiply the learning rate by 0.99 after every epoch.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)


def train():
    """Train the global ``model`` on ``train_loader`` for ``num_epochs`` epochs.

    Uses the module-level ``optimizer``, ``criterion`` and ``scheduler``.
    After each epoch it prints the current learning rate and the mean training
    loss over all graphs seen in that epoch (weighted by batch size), rather
    than the loss of whichever mini-batch happened to come last.
    """
    model.train()

    for epoch in range(num_epochs):
        loss_sum = 0.0
        graphs_seen = 0
        for batch_data in train_loader:  # Iterate in mini-batches over the training set.
            out, _ = model(batch_data.x, batch_data.edge_index, batch_data.batch)
            loss = criterion(out, batch_data.y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss averages within a batch; re-weight by batch size so
            # a smaller final batch does not skew the epoch mean.
            loss_sum += loss.item() * batch_data.num_graphs
            graphs_seen += batch_data.num_graphs

        scheduler.step()

        mean_loss = loss_sum / max(graphs_seen, 1)
        print(f"Epoch {epoch + 1}/{num_epochs}, Learning Rate: {scheduler.get_last_lr()[0]}, Loss: {mean_loss}")


# Set the number of epochs
num_epochs = 700

# Call the training loop
train()


# # Set the number of epochs
# num_epochs = 800

# # Call the training loop
# train()



In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(model, dataset, class_dict):
    """
    Evaluate the model on the provided dataset, compute the confusion matrix,
    and plot it with class names.

    Parameters:
    - model: trained GNN whose forward returns ``(logits, node_embeddings)``;
      it is left in eval mode afterwards.
    - dataset: iterable of PyG graph objects (a dataset, a Subset or a
      DataLoader) exposing ``x``, ``edge_index``, ``batch`` and ``y``.
    - class_dict: mapping from integer class label to display name,
      e.g. {0: 'Class A', 1: 'Class B'}.

    The matrix always has one row/column per key of ``class_dict`` (in sorted
    label order), even if a class never appears in the labels or predictions,
    so it always lines up with the tick labels.

    Raises:
    - ValueError: if ``dataset`` yields no graphs.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in dataset:
            out, _ = model(data.x, data.edge_index, data.batch)
            all_preds.append(out.argmax(dim=1).cpu().numpy())
            all_labels.append(data.y.cpu().numpy())

    if not all_preds:
        raise ValueError("plot_confusion_matrix: dataset is empty")

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Pass labels explicitly: otherwise sklearn only includes labels present in
    # y_true/y_pred, and a missing class would shrink the matrix so it no
    # longer matches the tick labels.
    labels = sorted(class_dict)
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=labels)
    class_names = [class_dict[label] for label in labels]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.show()

# Example usage:
# class_dict = {0: 'Non-Mutagenic', 1: 'Mutagenic'}
# plot_confusion_matrix(model, test_loader, class_dict)


In [None]:
class_dict = {0: 'Non-Mutagenic', 1: 'Mutagenic'}
# Evaluate on the held-out split only: the full `dataset` also contains the
# graphs the model was trained on, which would inflate the matrix.
plot_confusion_matrix(model, test_loader, class_dict)

In [None]:
# Persist only the learned parameters; the architecture is rebuilt in code.
checkpoint_path = 'mutag_classifier.pt'
torch.save(model.state_dict(), checkpoint_path)
print("Model weights saved.")


In [None]:
# Reload the trained weights into the `model` instance built earlier; its
# architecture must match the one that produced the checkpoint.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. map_location='cpu' lets a
# checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('mutag_classifier.pt', map_location='cpu', weights_only=True))
model.eval()

print("Model weights loaded and ready for inference.")


In [None]:


import torch
import torch.nn as nn
from src.lap_solvers.hungarian import hungarian
from src.lap_solvers.sinkhorn import Sinkhorn
from itertools import product
from src.spectral_clustering import spectral_clustering
from src.utils.pad_tensor import pad_tensor

import time



class Timer:
    """Minimal stopwatch: ``tic()`` starts it, ``toc()`` reports elapsed seconds via ``print_helper``."""

    def __init__(self):
        self.start_time = 0

    def tic(self):
        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock and can jump (e.g. clock adjustments), giving wrong or
        # even negative intervals.
        self.start_time = time.perf_counter()

    def toc(self, str=""):
        # `str` shadows the builtin but is kept for callers passing it by keyword.
        print_helper('{:.5f}sec {}'.format(time.perf_counter() - self.start_time, str))

DEBUG = False


def print_helper(*args):
    """Debug print: forward ``args`` to ``print`` only while ``DEBUG`` is true."""
    if not DEBUG:
        return
    print(*args)


class GA_GM(nn.Module):
    """
    Graduated Assignment solver for
     Graph Matching, Multi-Graph Matching and Multi-Graph Matching with a Mixture of Modes.

    This operation does not support batched input, and all input tensors should not have the first batch dimension.

    Parameter: maximum iteration mgm_iter                      (tuple, one entry per clustering stage)
               sinkhorn iteration sk_iter
               initial sinkhorn regularization sk_tau0         (tuple, one entry per stage)
               sinkhorn regularization decaying factor sk_gamma
               cluster-mask mixing weights cluster_beta        (tuple, one beta per stage)
               clustering rounds per stage cluster_iter
               minimum tau value min_tau                       (tuple, one entry per stage)
               convergence tolerance converge_tol
               projector of the first matching pass projector0 (tuple, 'sinkhorn' or 'hungarian')
    Input: adjacency matrix A; each graph's block is read as A[start:end, start:end]
               (presumably block-diagonal over all graphs -- confirm with caller)
           multi-graph similarity matrix W (total_nodes x total_nodes)
           initial multi-matching matrix U0 (total_nodes x n_univ)
           number of nodes in each graph ms
           size of universe n_univ
           weights quad_weight / cluster_quad_weight of the quadratic (edge) terms
           number of clusters num_clusters (1 = plain multi-graph matching)
    Output: (U, cluster_v) -- final multi-matching matrix and per-graph cluster ids
            (all zeros when num_clusters == 1)
    """
    def __init__(self, mgm_iter=(200,), cluster_iter=10, sk_iter=20, sk_tau0=(0.5,), sk_gamma=0.5, cluster_beta=(1., 0.), converge_tol=1e-5, min_tau=(1e-2,), projector0=('sinkhorn',)):
        super(GA_GM, self).__init__()
        self.mgm_iter = mgm_iter
        self.cluster_iter = cluster_iter
        self.sk_iter = sk_iter
        self.sk_tau0 = sk_tau0
        self.sk_gamma = sk_gamma
        self.cluster_beta = cluster_beta
        self.converge_tol = converge_tol
        self.min_tau = min_tau
        self.projector0 = projector0

    def forward(self, A, W, U0, ms, n_univ, quad_weight=1., cluster_quad_weight=1., num_clusters=2):
        # gradient is not required for MGM module
        W = W.detach()

        num_graphs = ms.shape[0]
        U = U0
        # Cumulative node counts: graph k owns rows [m_indices[k-1], m_indices[k]).
        m_indices = torch.cumsum(ms, dim=0)

        # History of intermediate matchings / clusterings (collected but not returned).
        Us = []
        clusters = []

        # initialize U with no clusters: an all-ones mask weights every graph pair equally
        cluster_M = torch.ones(num_graphs, num_graphs, device=A.device)
        cluster_M01 = cluster_M

        # First matching pass with stage-0 settings. hung_iter=True (only when no
        # clustering follows) makes gagm finish with iterated Hungarian refinement
        # instead of a single Hungarian rounding.
        U = self.gagm(A, W, U, ms, n_univ, cluster_M, self.sk_tau0[0], self.min_tau[0], self.mgm_iter[0], self.projector0[0],
                      quad_weight=quad_weight, hung_iter=(num_clusters == 1))
        Us.append(U)

        # MGM problem
        if num_clusters == 1:
            return U, torch.zeros(num_graphs, dtype=torch.int)

        # NOTE(review): zip truncates to the shortest tuple. With the defaults
        # (cluster_beta has 2 entries, the others 1) only the beta=1.0 stage runs.
        for beta, sk_tau0, min_tau, max_iter, projector0 in \
                zip(self.cluster_beta, self.sk_tau0, self.min_tau, self.mgm_iter, self.projector0):
            for i in range(self.cluster_iter):
                lastU = U

                # clustering step
                # Alpha[i, j] = unary agreement sum(W_ij * X_ij) plus an edge-consistency
                # term exp(-||X_ij^T A_i X_ij - A_j|| / scale) * qw, where X_ij = U_i U_j^T
                # is the pairwise matching induced by the universe assignment.
                def get_alpha(scale=1., qw=1.):
                    Alpha = torch.zeros(num_graphs, num_graphs, device=A.device)
                    for idx1, idx2 in product(range(num_graphs), repeat=2):
                        if idx1 == idx2:
                            continue
                        start_x = m_indices[idx1 - 1] if idx1 != 0 else 0
                        end_x = m_indices[idx1]
                        start_y = m_indices[idx2 - 1] if idx2 != 0 else 0
                        end_y = m_indices[idx2]
                        A_i = A[start_x:end_x, start_x:end_x]
                        A_j = A[start_y:end_y, start_y:end_y]
                        W_ij = W[start_x:end_x, start_y:end_y]
                        U_i = U[start_x:end_x, :]
                        U_j = U[start_y:end_y, :]
                        X_ij = torch.mm(U_i, U_j.t())
                        # NOTE(review): torch.chain_matmul is deprecated; torch.linalg.multi_dot is the replacement.
                        Alpha_ij = torch.sum(W_ij * X_ij) \
                                   + torch.exp(-torch.norm(torch.chain_matmul(X_ij.t(), A_i, X_ij) - A_j) / scale) * qw
                        Alpha[idx1, idx2] = Alpha_ij
                    return Alpha
                Alpha = get_alpha(qw=cluster_quad_weight)

                last_cluster_M01 = cluster_M01
                cluster_v = spectral_clustering(Alpha, num_clusters, normalized=True)
                # 1 where two graphs share a cluster, 0 otherwise.
                cluster_M01 = (cluster_v.unsqueeze(0) == cluster_v.unsqueeze(1)).to(dtype=Alpha.dtype)
                # Soft mask: beta=1 gives all ones (clusters ignored), beta=0 the hard 0/1 mask.
                cluster_M = (1 - beta) * cluster_M01 + beta

                # NOTE(review): on the very first round cluster_v is appended twice
                # (here and below); harmless because `clusters` is never returned.
                if beta == self.cluster_beta[0] and i == 0:
                    clusters.append(cluster_v)

                # matching step
                # Only the first round of each stage uses the stage's projector0;
                # later rounds project with Hungarian directly.
                U = self.gagm(A, W, U, ms, n_univ, cluster_M, sk_tau0, min_tau, max_iter,
                              projector='hungarian' if i != 0 else projector0, quad_weight=quad_weight,
                              hung_iter=(beta == self.cluster_beta[-1]))

                print_helper('beta = {:.2f}, delta U = {:.4f}, delta M = {:.4f}'.format(beta, torch.norm(lastU - U), torch.norm(last_cluster_M01 - cluster_M01)))

                Us.append(U)
                clusters.append(cluster_v)

                # With beta == 1 the mask is all ones, so further rounds cannot change it.
                if beta == 1:
                    break

                if torch.norm(lastU - U) < self.converge_tol and torch.norm(last_cluster_M01 - cluster_M01) < self.converge_tol:
                    break

        # NOTE(review): cluster_v is unbound here (NameError) if cluster_iter == 0
        # or the zipped stage tuples are empty.
        #return Us, clusters
        return  U, cluster_v

    def gagm(self, A, W, U0, ms, n_univ, cluster_M, init_tau, min_tau, max_iter, projector='sinkhorn', hung_iter=True, quad_weight=1.):
        """Graduated-assignment matching for a fixed cluster mask.

        Repeats V = (2 * quad_weight * A (UU^T * C) A U + (W * C) U) / num_graphs,
        where C is cluster_M expanded to node level, then projects each graph's
        row block of V with Sinkhorn (tau annealed by sk_gamma down to min_tau) or
        Hungarian. When tau reaches min_tau: if hung_iter, keep iterating with the
        Hungarian projector; otherwise round the last Sinkhorn blocks with
        Hungarian once and stop. Returns the matching matrix U.
        """
        num_graphs = ms.shape[0]
        U = U0
        m_indices = torch.cumsum(ms, dim=0)

        # Two-step history: lastU2 lets the convergence test detect a 2-cycle oscillation.
        lastU = torch.zeros_like(U)

        sinkhorn_tau = init_tau
        #beta = 0.9
        # Never set to False: the outer loop is exited only via the `break`s below.
        iter_flag = True

        while iter_flag:
            for i in range(max_iter):
                lastU2 = lastU
                lastU = U

                # compact matrix form update of V
                UUt = torch.mm(U, U.t())
                # Expand the graph-level (num_graphs x num_graphs) mask to node level.
                cluster_weight = torch.repeat_interleave(cluster_M, ms.to(dtype=torch.long), dim=0)
                cluster_weight = torch.repeat_interleave(cluster_weight, ms.to(dtype=torch.long), dim=1)
                # NOTE(review): torch.chain_matmul is deprecated; torch.linalg.multi_dot is the replacement.
                V = torch.chain_matmul(A, UUt * cluster_weight, A, U) * quad_weight * 2 + torch.mm(W * cluster_weight, U)
                V /= num_graphs

                U_list = []
                if projector == 'hungarian':
                    m_start = 0
                    for m_end in m_indices:
                        U_list.append(hungarian(V[m_start:m_end, :n_univ]))
                        m_start = m_end
                elif projector == 'sinkhorn':
                    if torch.all(ms == ms[0]):
                        # Equal-size graphs: one batched Sinkhorn call over (num_graphs, m, n_univ);
                        # transposed when graphs are larger than the universe.
                        if ms[0] <= n_univ:
                            U_list.append(
                                Sinkhorn(max_iter=self.sk_iter, tau=sinkhorn_tau, batched_operation=True) \
                                    (V.reshape(num_graphs, -1, n_univ), dummy_row=True).reshape(-1, n_univ))
                        else:
                            U_list.append(
                                Sinkhorn(max_iter=self.sk_iter, tau=sinkhorn_tau, batched_operation=True) \
                                    (V.reshape(num_graphs, -1, n_univ).transpose(1, 2), dummy_row=True).transpose(1, 2).reshape(-1, n_univ))
                    else:
                        # Unequal sizes: pad the per-graph blocks, run a batched Sinkhorn with the
                        # true row counts n1, then strip the padding rows again.
                        V_list = []
                        n1 = []
                        m_start = 0
                        for m_end in m_indices:
                            V_list.append(V[m_start:m_end, :n_univ])
                            n1.append(m_end - m_start)
                            m_start = m_end
                        n1 = torch.tensor(n1)
                        # Temporarily rebinds U to the padded batched result; overwritten by torch.cat below.
                        U = Sinkhorn(max_iter=self.sk_iter, tau=sinkhorn_tau, batched_operation=True) \
                            (torch.stack(pad_tensor(V_list), dim=0), n1, dummy_row=True)
                        m_start = 0
                        for idx, m_end in enumerate(m_indices):
                            U_list.append(U[idx, :m_end - m_start, :])
                            m_start = m_end
                else:
                    raise NameError('Unknown projecter name: {}'.format(projector))

                U = torch.cat(U_list, dim=0)
                # Two-graph case: pin the first graph to the identity assignment.
                if num_graphs == 2:
                    U[:ms[0], :] = torch.eye(ms[0], n_univ, device=U.device)

                # Converged, or oscillating between two states.
                if torch.norm(U - lastU) < self.converge_tol or torch.norm(U - lastU2) == 0:
                    break

            # NOTE(review): also true when convergence happened exactly on the last
            # iteration; `i` is unbound if max_iter == 0.
            if i == max_iter - 1: # not converged
                if hung_iter:
                    pass
                else:
                    U_list = [hungarian(_) for _ in U_list]
                    U = torch.cat(U_list, dim=0)
                    print_helper(i, 'max iter')
                    break

            # projection control
            if projector == 'hungarian':
                print_helper(i, 'hungarian')
                break
            elif sinkhorn_tau > min_tau:
                # Anneal the Sinkhorn temperature toward a harder assignment.
                print_helper(i, sinkhorn_tau)
                sinkhorn_tau *= self.sk_gamma
            else:
                print_helper(i, sinkhorn_tau)
                if hung_iter:
                    projector = 'hungarian'
                else:
                    U_list = [hungarian(_) for _ in U_list]
                    U = torch.cat(U_list, dim=0)
                    break

        return U


class HiPPI(nn.Module):
    """
    HiPPI solver for multiple graph matching: Higher-order Projected Power Iteration in ICCV 2019

    This operation does not support batched input, and all input tensors should not have the first batch dimension.

    Parameter: maximum iteration max_iter
               sinkhorn iteration sk_iter
               sinkhorn regularization sk_tau
    Input: multi-graph similarity matrix W, square (total_nodes x total_nodes)
           initial multi-matching matrix U0 (total_nodes x d)
           number of nodes in each graph ms
           size of universe d
           (optional) projector to doubly-stochastic matrix ('sinkhorn') or permutation matrix ('hungarian')
    Output: multi-matching matrix U
    Raises: NameError for an unknown projector name.
    """
    def __init__(self, max_iter=50, sk_iter=20, sk_tau=1/200.):
        super(HiPPI, self).__init__()
        self.max_iter = max_iter
        self.sinkhorn = Sinkhorn(max_iter=sk_iter, tau=sk_tau)
        self.hungarian = hungarian

    def forward(self, W, U0, ms, d, projector='sinkhorn'):
        # Row offsets of each graph's block; loop-invariant, so computed once.
        m_indices = torch.cumsum(ms, dim=0)

        U = U0
        for i in range(self.max_iter):
            lastU = U
            # Higher-order power step V = (W U) U^T (W U).
            # torch.linalg.multi_dot replaces the deprecated torch.chain_matmul.
            WU = torch.mm(W, U)
            V = torch.linalg.multi_dot([WU, U.t(), WU])

            # Project each graph's row block of V back onto (relaxed) permutations.
            projected_blocks = []
            m_start = 0
            for m_end in m_indices:
                V_block = V[m_start:m_end, :d]
                if projector == 'sinkhorn':
                    projected_blocks.append(self.sinkhorn(V_block, dummy_row=True))
                elif projector == 'hungarian':
                    projected_blocks.append(self.hungarian(V_block))
                else:
                    raise NameError('Unknown projector {}.'.format(projector))
                m_start = m_end
            U = torch.cat(projected_blocks, dim=0)

            if torch.norm(U - lastU) < 1e-5:
                print_helper(i)
                break

        return U

In [None]:
import torch
from torch_geometric.utils import to_dense_adj

# Label of the class whose predicted members are aligned (0 or 1 for MUTAG).
target_class = 0

# Put the classifier in inference mode here instead of relying on an earlier
# cell having called model.eval() (dropout would otherwise be active).
model = model.to('cpu')
model.eval()

# Keep graphs the classifier assigns to target_class, and cache the node
# embeddings from that same forward pass so the model runs once per graph.
selected_data = []
all_embeddings_sel = []
with torch.no_grad():
    for data in dataset:
        out, node_emb = model(data.x, data.edge_index)
        pred = out.argmax(dim=1).item()  # For a single graph, out is shape [1, num_classes]
        if pred == target_class:
            selected_data.append(data)
            all_embeddings_sel.append(node_emb)

print("Number of graphs classified as target class:", len(selected_data))
if not selected_data:
    raise ValueError(f"No graphs were classified as class {target_class}; nothing to match.")

# --- Prepare global inputs for GA_GM on the selected graphs ---

# 1. Node count of each selected graph.
ms_sel = torch.tensor([data.num_nodes for data in selected_data], dtype=torch.long)

# 2. Dense binary adjacency matrix of each selected graph.
adj_list_sel = [
    (to_dense_adj(g.edge_index, max_num_nodes=g.num_nodes)[0] > 0).float()
    for g in selected_data
]

# 3. Global block-diagonal adjacency, shape (total_nodes, total_nodes).
A_sel = torch.block_diag(*adj_list_sel)

# 4. Stacked node embeddings, shape (total_nodes, hidden_channels).
all_x_sel = torch.cat(all_embeddings_sel, dim=0)

# 5. Node similarity = inner product of embeddings, shape (total_nodes, total_nodes).
W_sel = torch.mm(all_x_sel, all_x_sel.t())

# 6. Universe size (number of columns of the matching matrix). Fixed at 100;
#    alternatives are int(ms_sel.max().item()) or total_nodes_sel.
n_univ_sel = 100
print("n_univ_sel:", n_univ_sel)
total_nodes_sel = int(ms_sel.sum().item())

# 7. Initial soft matching U0_sel (total_nodes_sel x n_univ_sel): near-uniform
#    plus a small perturbation drawn from a seeded generator for reproducibility.
init_gen = torch.Generator().manual_seed(0)
U0_sel = (1.0 / n_univ_sel) * torch.ones(total_nodes_sel, n_univ_sel) \
    + 1e-3 * torch.randn(total_nodes_sel, n_univ_sel, generator=init_gen)

print("Shape of all_x_sel", all_x_sel.shape)
print("Total nodes in selected graphs:", total_nodes_sel)
print("Global A_sel shape:", A_sel.shape)
print("Global W_sel shape:", W_sel.shape)
print("Initial U0_sel shape:", U0_sel.shape)
print("ms_sel:", ms_sel)

# --- Run the GA_GM solver on the selected graphs ---
# num_clusters = 1 solves plain multi-graph matching (GA_GM returns right after
# its first matching pass); values > 1 additionally cluster the graphs (MGMC).
num_clusters = 1

ga_gm_solver = GA_GM()  # Runs on CPU by default

U_final_sel, clusters_sel = ga_gm_solver(
    A_sel,          # Global block-diagonal adjacency matrix
    W_sel,          # Global node similarity matrix computed from GNN embeddings
    U0_sel,         # Initial matching matrix
    ms_sel,         # Node counts per graph
    n_univ_sel,     # Universe size (columns in U0_sel)
    quad_weight=1.0,
    cluster_quad_weight=1.0,
    num_clusters=num_clusters  # >1 enables clustering
)

print("Final matching matrix U_final_sel shape:", U_final_sel.shape)
print("Final clustering vector (clusters_sel):", clusters_sel)


In [None]:
def find_topk_graphs(model, dataset, target_class, k=5):
    """
    Rank the graphs in ``dataset`` by the model's softmax probability for
    ``target_class`` and return the ``k`` highest-scoring ones, best first.

    Side effects: puts ``model`` in eval mode, moves it to the CPU and prints
    one line per returned graph.

    Returns:
        topk_graphs: list of the selected graph objects
        topk_indices: their positions within the ``dataset`` argument (not
            necessarily within the original full dataset)
        topk_scores: their softmax probabilities for ``target_class``
    """
    model.eval()
    model = model.to('cpu')

    ranked = []
    with torch.no_grad():
        for position, graph in enumerate(dataset):
            logits, _ = model(graph.x, graph.edge_index, None)
            probability = F.softmax(logits, dim=1)[0, target_class].item()
            ranked.append((probability, graph, position))

    # list.sort is stable, so equal scores keep their dataset order.
    ranked.sort(key=lambda entry: entry[0], reverse=True)
    best = ranked[:k]

    topk_scores = [probability for probability, _, _ in best]
    topk_graphs = [graph for _, graph, _ in best]
    topk_indices = [position for _, _, position in best]

    for rank, (idx, score) in enumerate(zip(topk_indices, topk_scores)):
        print(f"Graph Rank {rank+1}: Index = {idx}, Class Score = {score:.4f}")

    return topk_graphs, topk_indices, topk_scores


In [None]:
# topk_indices are positions within selected_data, i.e. they index the same
# per-graph blocks as ms_sel and the rows of U_final_sel.
topk_graphs, topk_indices, topk_scores = find_topk_graphs(
    model, selected_data, target_class, k=10
)


In [None]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import subgraph

class SharedGraphExplainer(nn.Module):
    """Node-mask explainer shared across graphs.

    A one-layer GCN scores every node; ``sample_mask`` turns the scores into a
    relaxed (continuous, differentiable) Bernoulli keep-mask with the binary
    Concrete / Gumbel-sigmoid trick, using a temperature that is annealed
    geometrically from ``temp_start`` to ``temp_end`` over ``epochs``.
    """

    def __init__(self, in_channels, hidden_channels=32, temp_start=5.0, temp_end=0.1, epochs=300):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)
        self.temp_start = temp_start
        self.temp_end = temp_end
        self.epochs = epochs
        # Advanced externally by the training loop to drive the temperature schedule.
        self.current_epoch = 0

    def forward(self, x, edge_index):
        """Return one unnormalised keep-logit per node, shape (num_nodes,)."""
        h = F.relu(self.conv1(x, edge_index))
        logits = self.lin(h).squeeze(-1)
        return logits

    def sample_mask(self, logits):
        """Sample a soft node mask in [0, 1] from per-node logits.

        Adds logistic noise log(u) - log(1 - u) (the difference of two Gumbel
        samples), which is the correct noise for a binary Concrete relaxation:
        as the temperature -> 0, P(mask -> 1) = sigmoid(logits). A single
        Gumbel sample is skewed (mean ~0.577) and would bias every mask toward 1.
        """
        temp = self.get_current_temp()
        # Clamp u away from 0 and 1 so both logs stay finite (1 - 1e-20 rounds to 1 in float32).
        eps = 1e-6
        u = torch.rand_like(logits).clamp(eps, 1 - eps)
        logistic_noise = torch.log(u) - torch.log1p(-u)
        mask = torch.sigmoid((logits + logistic_noise) / temp)
        return mask

    def get_current_temp(self):
        """Geometric interpolation from temp_start (epoch 0) to temp_end (epoch == epochs)."""
        progress = self.current_epoch / max(self.epochs, 1)  # guard against epochs == 0
        return self.temp_start * (self.temp_end / self.temp_start) ** progress





In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data


def _masked_forward(model, data, node_mask):
    """Run ``model`` on ``data`` with node features and edges softly masked.

    Node features are scaled by ``node_mask``. Each edge (u, v) gets weight
    ``node_mask[u] * node_mask[v]``. Returns the first element of the model's
    output pair.
    """
    masked_x = data.x * node_mask.unsqueeze(-1)
    src, dst = data.edge_index
    edge_mask = node_mask[src] * node_mask[dst]  # soft edge weights
    out, _ = model(masked_x, data.edge_index, data.batch, edge_weight=edge_mask)
    return out


def train_shared_explainer(model, topk_graphs, topk_indices, U_final_sel, ms_sel, n_univ_sel,
                           target_class, topk_univ_k=5, epochs=300, lr=0.01, lambda_cls=1.0, lambda_align=1.0,
                           lambda_sparsity=0.001, lambda_entropy=0.001, lambda_budget=1.0, budget=10,
                           verbose=False):
    """Train one SharedGraphExplainer jointly on the top-k graphs.

    For each reference graph i (one optimizer step per graph per epoch) the loss is:
      * ``lambda_cls``      * -out_i[0, target_class] on graph i masked with its own mask;
      * ``lambda_align``    * -sum_j out_j[0, target_class], where every other graph j is
                              masked with graph i's mask transferred through the universal
                              space: ``U_j @ (U_i.T @ mask_i)``;
      * ``lambda_sparsity`` * mean(mask);
      * ``lambda_entropy``  * mean binary entropy of the mask;
      * ``lambda_budget``   * relu(sum(mask) - budget).

    Args:
        model: Frozen classifier, called as ``model(x, edge_index, batch, edge_weight=...)``
            and returning a pair whose first element is indexed ``[0, target_class]``.
            NOTE: its parameters are set to ``requires_grad=False`` in place.
        topk_graphs: Graphs to explain (objects with ``x``, ``edge_index``, ``batch``).
        topk_indices: For each graph, its block index in ``ms_sel`` / ``U_final_sel``.
        U_final_sel: Stacked node-to-universal assignment matrices, one row block per graph.
        ms_sel: 1-D integer tensor with the number of rows in each graph's block.
        n_univ_sel, topk_univ_k: Not used; kept for interface compatibility.
        target_class: Class whose output the explanation should preserve.
        epochs, lr: Number of epochs and Adam learning rate.
        budget: Soft upper bound on the total (soft) number of selected nodes.
        verbose: If True, print the sampled mask and class probabilities of every graph
            at every step (very noisy; off by default).

    Returns:
        The trained SharedGraphExplainer.

    Raises:
        ValueError: if a graph's assignment block row count differs from its node count.
    """
    device = torch.device('cpu')
    model = model.to(device).eval()
    for p in model.parameters():
        p.requires_grad = False  # Freeze classifier

    in_channels = topk_graphs[0].x.size(1)
    # Tie the temperature schedule to this run's length; with the class default
    # (300 epochs) a shorter run never anneals anywhere near temp_end.
    explainer = SharedGraphExplainer(in_channels, epochs=epochs).to(device)
    explainer.train()
    optimizer = torch.optim.Adam(explainer.parameters(), lr=lr)

    graph_offsets = torch.cumsum(torch.cat([torch.tensor([0]), ms_sel]), dim=0)

    # Loop-invariant work: move every graph once and slice its assignment block once.
    graphs = [g.to(device) for g in topk_graphs]
    U_blocks = []
    for g, idx in zip(graphs, topk_indices):
        U_g = U_final_sel[graph_offsets[idx]:graph_offsets[idx + 1]]  # shape: [n_g, n_univ]
        if U_g.size(0) != g.x.size(0):
            raise ValueError(
                f"Assignment block for graph index {idx} has {U_g.size(0)} rows "
                f"but the graph has {g.x.size(0)} nodes"
            )
        U_blocks.append(U_g)

    for epoch in range(epochs):
        explainer.current_epoch = epoch
        total_loss = 0.0

        for i, data in enumerate(graphs):
            optimizer.zero_grad()

            # === Node mask ===
            logits = explainer(data.x, data.edge_index)
            mask = explainer.sample_mask(logits)

            # === Classification term on the graph itself ===
            out = _masked_forward(model, data, mask)
            class_score = out[0, target_class]
            loss_cls = -class_score

            if verbose:
                print(f"Mask value of graph {i} is", mask)
                print(f"Class Score achieved is for graph {i} is", F.softmax(out, dim=1))

            # === Sparsity / budget / entropy regularisation ===
            loss_sparsity = mask.mean()  # Encourage fewer nodes
            loss_budget = F.relu(mask.sum() - budget)
            mask_clipped = torch.clamp(mask, min=1e-6, max=1 - 1e-6)
            loss_entropy = -(mask_clipped * torch.log(mask_clipped)
                             + (1 - mask_clipped) * torch.log(1 - mask_clipped)).mean()

            # === Alignment: transfer this mask to every other graph ===
            # Skipped entirely when its weight is 0 (saves k-1 forward passes per step).
            loss_align = 0.0
            if lambda_align != 0:
                univ_mask = torch.matmul(U_blocks[i].T, mask)  # shape: [n_univ]
                for j, data_j in enumerate(graphs):
                    if j == i:
                        continue
                    mask_j_transferred = torch.matmul(U_blocks[j], univ_mask)  # shape: [n_j]
                    out_j = _masked_forward(model, data_j, mask_j_transferred)
                    loss_align = loss_align - out_j[0, target_class]

            # === Total loss ===
            loss = (lambda_cls * loss_cls + lambda_align * loss_align
                    + lambda_sparsity * loss_sparsity + lambda_entropy * loss_entropy
                    + lambda_budget * loss_budget)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if epoch % 20 == 0:
            print(f"[Epoch {epoch:03d}] Total Loss: {total_loss:.4f}")

    return explainer


In [None]:
# Quick sanity check of the class being explained and the chosen reference graphs
print(target_class, topk_indices, sep="\n")

In [None]:

# Main run: train the shared explainer on the top-k graphs with all loss terms active.
explainer = train_shared_explainer(
    model=model,
    topk_graphs=topk_graphs,
    topk_indices=topk_indices,
    U_final_sel=U_final_sel,
    ms_sel=ms_sel,
    n_univ_sel=n_univ_sel,
    target_class=target_class,
    topk_univ_k=10,         # accepted but not used by train_shared_explainer
    epochs=130,
    lr=0.001,
    lambda_cls=2.5,         # weight of the self-classification term
    lambda_align=2.6,       # weight of the cross-graph alignment (generalization) term
    lambda_entropy=0.2,
    lambda_sparsity=4,
    lambda_budget=10.0,
    budget=8                # soft cap on the total (soft) number of selected nodes
)

In [None]:
import torch
import torch.nn.functional as F
import pickle

def evaluate_explanation_generalizability_thresholded(
    explainer, model, topk_graphs, topk_indices, U_final_sel, ms_sel,
    selected_data, selected_indices, score_filter_threshold=0.8, target_class=1,
    save_path="explanation_generalization_results_filtered.pkl"):
    """Measure how well each reference explanation transfers to other graphs.

    For every reference graph in ``topk_graphs``:
      1. its node mask is projected to the universal space: ``U_i.T @ mask``;
      2. that mask is transferred to each graph j in ``selected_data``: ``U_j @ univ_mask``.

    Only graphs whose *unmasked* target-class probability is ``>= score_filter_threshold``
    are evaluated. Among them, a transfer counts as successful when the *masked*
    probability is also ``>= score_filter_threshold``.

    Returns:
        A list with one dict per reference graph, with keys ``ref_index``,
        ``high_score_count``, ``high_score_ratio``, ``deltas``, ``masked_scores``,
        ``original_scores`` and ``univ_mask``. The list is also pickled to ``save_path``.
    """
    device = torch.device('cpu')
    model = model.to(device).eval()
    explainer = explainer.to(device).eval()

    graph_offsets = torch.cumsum(torch.cat([torch.tensor([0]), ms_sel]), dim=0)

    def assignment_block(graph_idx):
        # Rows of U_final_sel belonging to graph block `graph_idx`: shape [n_nodes, n_univ]
        return U_final_sel[graph_offsets[graph_idx]:graph_offsets[graph_idx + 1]]

    results = []
    with torch.no_grad():  # pure inference: no autograd graph needed
        # Unmasked scores do not depend on the reference explanation, so compute them
        # (and apply the confidence filter) once instead of once per reference graph.
        candidates = []
        for j, data_j in enumerate(selected_data):
            data_j = data_j.to(device)
            out_orig, _ = model(data_j.x, data_j.edge_index, data_j.batch)
            orig_score = F.softmax(out_orig, dim=1)[0, target_class].item()
            if orig_score < score_filter_threshold:
                continue  # Skip graphs whose original confidence is too low
            candidates.append((data_j, assignment_block(selected_indices[j]), orig_score))

        for i, ref_graph in enumerate(topk_graphs):
            ref_graph = ref_graph.to(device)

            # Explanation mask of the reference graph, projected to the universal space
            logits_ref = explainer(ref_graph.x, ref_graph.edge_index)
            mask_ref = explainer.sample_mask(logits_ref)
            univ_mask = torch.matmul(assignment_block(topk_indices[i]).T, mask_ref)

            deltas, masked_scores, original_scores = [], [], []
            high_score_count = 0
            for data_j, U_j, orig_score in candidates:
                # Transfer the explanation to graph j and score the masked graph
                mask_j = torch.matmul(U_j, univ_mask)
                masked_x_j = data_j.x * mask_j.unsqueeze(-1)
                src_j, dst_j = data_j.edge_index
                edge_mask_j = mask_j[src_j] * mask_j[dst_j]

                out_masked, _ = model(masked_x_j, data_j.edge_index, data_j.batch, edge_weight=edge_mask_j)
                masked_score = F.softmax(out_masked, dim=1)[0, target_class].item()

                deltas.append(masked_score - orig_score)
                masked_scores.append(masked_score)
                original_scores.append(orig_score)
                if masked_score >= score_filter_threshold:
                    high_score_count += 1

            filtered_total = len(candidates)
            print(f"Explanation {i}:")
            print(f" → Filtered evaluation on {filtered_total} graphs with original score >= {score_filter_threshold}")
            print(f" → High masked score count (>={score_filter_threshold}): {high_score_count}/{filtered_total}")

            results.append({
                'ref_index': i,
                'high_score_count': high_score_count,
                'high_score_ratio': high_score_count / filtered_total if filtered_total > 0 else 0.0,
                'deltas': deltas,
                'masked_scores': masked_scores,
                'original_scores': original_scores,
                'univ_mask': univ_mask.detach().cpu()
            })

    # Save all results
    with open(save_path, "wb") as f:
        pickle.dump(results, f)

    print(f"\n✅ Filtered generalization results saved to {save_path}")
    return results


In [None]:
# Score how well each reference explanation transfers to the selected graphs.
results = evaluate_explanation_generalizability_thresholded(
    explainer=explainer,
    model=model,
    target_class=target_class,
    topk_graphs=topk_graphs,
    topk_indices=topk_indices,
    U_final_sel=U_final_sel,
    ms_sel=ms_sel,
    selected_data=selected_data,
    # selected_data[j] is assumed to be row block j of U_final_sel / ms_sel
    selected_indices=[idx for idx in range(len(selected_data))],
    score_filter_threshold=0.8,
    save_path="generalization_results_filtered.pkl",
)


In [None]:
import pickle

import torch

# Reload the filtered generalization results (so this cell also works on a fresh kernel)
with open("generalization_results_filtered.pkl", "rb") as f:
    results = pickle.load(f)

# Reference explanation that kept the most graphs above threshold
# (max() keeps the earliest entry on ties).
best_result = max(results, key=lambda entry: entry['high_score_count'])
best_ref_index = best_result['ref_index']

# Data object whose explanation transferred best
best_graph = topk_graphs[best_ref_index]

print(f"Best explanation is from graph index: {best_ref_index}\n"
      f"High score count: {best_result['high_score_count']}")


In [None]:
import torch
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
import pickle

# === Step 1: Load results (so this cell also works on a fresh kernel) ===
with open("generalization_results_filtered.pkl", "rb") as f:
    results = pickle.load(f)

# === Step 2: Find the reference graph whose explanation generalized best ===
best_result = max(results, key=lambda x: x['high_score_count'])
best_ref_index = best_result['ref_index']
print(f"Best explanation from graph index {best_ref_index} with high score count = {best_result['high_score_count']}")

# === Step 3: Get the corresponding Data object ===
best_graph = topk_graphs[best_ref_index].to('cpu')

# === Step 4: Get the corresponding node mask ===
# Set eval mode explicitly (as the evaluation function does) instead of relying on
# whichever cell last touched the explainer.
explainer.eval()
with torch.no_grad():
    logits = explainer(best_graph.x, best_graph.edge_index)
    node_mask = explainer.sample_mask(logits).cpu()

# === Step 5: Threshold the mask to get node indices to keep ===
threshold = 0.5  # mask value above which a node is kept
selected_nodes = torch.where(node_mask > threshold)[0]
print(f"Number of selected nodes (mask > {threshold}): {len(selected_nodes)}")

# === Step 6: Extract the induced subgraph ===
# num_nodes is passed explicitly: inferred from edge_index it undercounts when the
# highest-indexed node has no edges.
sub_edge_index, _ = subgraph(
    subset=selected_nodes,
    edge_index=best_graph.edge_index,
    relabel_nodes=True,
    num_nodes=best_graph.num_nodes,
)

# === Step 7: Build the subgraph object and plot it ===
subgraph_data = Data(
    x=best_graph.x[selected_nodes],
    edge_index=sub_edge_index
)
if selected_nodes.numel() == 0:
    print("No node exceeds the threshold; nothing to plot (try a lower threshold).")
else:
    plotmutag(subgraph_data)


In [None]:
import torch
from torch_geometric.utils import subgraph
import pickle
from torch_geometric.data import Data
import matplotlib.pyplot as plt

# === Step 1: Load results (so this cell also works on a fresh kernel) ===
with open("generalization_results_filtered.pkl", "rb") as f:
    results = pickle.load(f)

# === Step 2: Top 3 reference graphs by high_score_count ===
top_results = sorted(results, key=lambda x: x['high_score_count'], reverse=True)[:3]

threshold = 0.5  # mask value above which a node is kept

# Explicit mode instead of relying on cell execution order
explainer.eval()

for rank, result in enumerate(top_results, 1):
    ref_index = result['ref_index']
    print(f"\n[Rank {rank}] Graph index {ref_index} with high score count = {result['high_score_count']}")

    # === Corresponding Data object and node mask ===
    graph = topk_graphs[ref_index].to('cpu')
    with torch.no_grad():
        logits = explainer(graph.x, graph.edge_index)
        node_mask = explainer.sample_mask(logits).cpu()

    # === Threshold the mask ===
    selected_nodes = torch.where(node_mask > threshold)[0]
    print(f" → Number of selected nodes (mask > {threshold}): {len(selected_nodes)}")
    if selected_nodes.numel() == 0:
        print(" → No node exceeds the threshold; skipping plot")
        continue

    # === Extract the induced subgraph (num_nodes given so isolated trailing nodes are valid) ===
    sub_edge_index, _ = subgraph(
        subset=selected_nodes,
        edge_index=graph.edge_index,
        relabel_nodes=True,
        num_nodes=graph.num_nodes,
    )
    subgraph_data = Data(
        x=graph.x[selected_nodes],
        edge_index=sub_edge_index
    )

    print(f" → Plotting subgraph for Rank {rank}")
    plotmutag(subgraph_data)
    # plotmutag draws into the fixed figure number 2; flush it after each rank so the
    # three subgraphs are not overlaid if plotmutag does not display the figure itself.
    plt.show()


In [None]:
import torch
from torch_geometric.datasets import TUDataset
from collections import Counter
import matplotlib.pyplot as plt

# Uses the MUTAG `dataset` loaded at the top of the notebook.

# Column indices of the halogens in MUTAG's one-hot node features, following the
# atom order in plotmutag's legend: 0 C, 1 N, 2 O, 3 F, 4 I, 5 Cl, 6 Br.
halogen_indices = [3, 4, 5, 6]  # F, I, Cl, Br
class_names = {0: "Non-Mutagenic", 1: "Mutagenic"}

# Counters
halogen_counts = {0: 0, 1: 0}   # class -> total halogen atoms
compound_counts = {0: 0, 1: 0}  # class -> number of compounds

# `mol` (not `data`) so the global `data` graph from earlier cells is not overwritten
for mol in dataset:
    cls = int(mol.y.item())
    compound_counts[cls] += 1
    # A node is a halogen atom if any of its halogen columns is set
    is_halogen = mol.x[:, halogen_indices].sum(dim=1) > 0
    halogen_counts[cls] += int(is_halogen.sum().item())

# Average halogen atoms per compound (0.0 for a class with no compounds, not ZeroDivisionError)
avg_halogen_per_compound = {
    cls: (halogen_counts[cls] / compound_counts[cls]) if compound_counts[cls] else 0.0
    for cls in halogen_counts
}

# Display results
print("Total compounds:", compound_counts)
print("Total halogen atoms:", halogen_counts)
print("Average halogen atoms per compound:")
for cls, avg in avg_halogen_per_compound.items():
    print(f"  {class_names[cls]}: {avg:.2f}")

# Plot
fig, ax = plt.subplots()
ax.bar([class_names[0], class_names[1]],
       [avg_halogen_per_compound[0], avg_halogen_per_compound[1]],
       color=["green", "red"])
ax.set_ylabel("Avg. Halogen Atoms per Compound")
ax.set_title("Halogen Frequency in MUTAG Dataset")
plt.show()


In [None]:
# Re-train with a weaker self-classification weight (lambda_cls=1.0 instead of 2.5) and a
# smaller budget penalty (lambda_budget=8.0 instead of 10.0).
# NOTE: this rebinds `explainer`, replacing the explainer trained and evaluated above.
explainer = train_shared_explainer(
    model=model,
    topk_graphs=topk_graphs,
    topk_indices=topk_indices,
    U_final_sel=U_final_sel,
    ms_sel=ms_sel,
    n_univ_sel=n_univ_sel,
    target_class=target_class,
    topk_univ_k=10,         # accepted but not used by train_shared_explainer
    epochs=130,
    lr=0.001,
    lambda_cls=1.0,         # weight of the self-classification term
    lambda_align=2.6,       # weight of the cross-graph alignment (generalization) term
    lambda_entropy=0.2,
    lambda_sparsity=4,
    lambda_budget=8.0,
    budget=8                # soft cap on the total (soft) number of selected nodes
)

In [None]:
import torch
import torch.nn.functional as F
import pickle
import matplotlib.pyplot as plt
import numpy as np
import os

# === ABLATION: shared explainer WITH vs WITHOUT the cross-graph alignment loss ===
# Both runs are seeded identically so lambda_align is the only difference between
# them; otherwise different random initialisations and mask noise confound the comparison.
ABLATION_SEED = 0

# === Shared Training Args (everything except lambda_align) ===
shared_kwargs = dict(
    model=model,
    topk_graphs=topk_graphs,
    topk_indices=topk_indices,
    U_final_sel=U_final_sel,
    ms_sel=ms_sel,
    n_univ_sel=n_univ_sel,
    target_class=target_class,
    topk_univ_k=10,
    epochs=130,
    lr=0.001,
    lambda_cls=1.0,
    lambda_entropy=0.2,
    lambda_sparsity=4,
    lambda_budget=10.0,
    budget=8,
)

# === Shared Evaluation Args (everything except explainer and save_path) ===
eval_kwargs = dict(
    model=model,
    topk_graphs=topk_graphs,
    topk_indices=topk_indices,
    U_final_sel=U_final_sel,
    ms_sel=ms_sel,
    selected_data=selected_data,
    selected_indices=list(range(len(selected_data))),
    target_class=target_class,
    score_filter_threshold=0.8,
)


def _train_and_evaluate(lambda_align, save_path):
    """Train an explainer with the given alignment weight under a fixed seed, then evaluate it.

    Returns ``(explainer, results)``; the results are also pickled to ``save_path``.
    """
    torch.manual_seed(ABLATION_SEED)
    trained = train_shared_explainer(**shared_kwargs, lambda_align=lambda_align)
    run_results = evaluate_explanation_generalizability_thresholded(
        explainer=trained, save_path=save_path, **eval_kwargs
    )
    return trained, run_results


# === Train + evaluate WITH generalization loss ===
print("🔁 Training WITH generalization loss...")
results_with_path = "explanation_generalization_results_with_gen.pkl"
explainer_with, results_with = _train_and_evaluate(2.6, results_with_path)

# === Train + evaluate WITHOUT generalization loss ===
print("\n🔁 Training WITHOUT generalization loss...")
results_without_path = "explanation_generalization_results_without_gen.pkl"
explainer_without, results_without = _train_and_evaluate(0.0, results_without_path)

# === PLOTTING ===

def extract_generalization_scores(results):
    """Return each explanation's ``high_score_ratio`` (its generalization score)."""
    return [r['high_score_ratio'] for r in results]

with_scores = extract_generalization_scores(results_with)
without_scores = extract_generalization_scores(results_without)
indices = list(range(len(with_scores)))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(indices, with_scores, marker='o', label='With Generalization Loss')
ax.plot(indices, without_scores, marker='x', label='Without Generalization Loss')
ax.set_xlabel("Explanation Index")
ax.set_ylabel("Generalization Score")
ax.set_title("Comparison of Generalization Scores")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig("generalization_score_comparison.png", dpi=300)
plt.show()

# === SUMMARY ===
print("\n--- Summary ---")
print(f"Mean Generalization Score (With):     {np.mean(with_scores):.4f}")
print(f"Mean Generalization Score (Without):  {np.mean(without_scores):.4f}")
print("📊 Plot saved as: generalization_score_comparison.png")


In [None]:
# === PLOTTING: Grouped Bar Plot ===
# Generalization score = high_score_ratio of each reference explanation, taken from
# the with/without-alignment results above (no second definition of the extractor).
with_scores = [r['high_score_ratio'] for r in results_with]
without_scores = [r['high_score_ratio'] for r in results_without]
indices = list(range(len(with_scores)))

bar_width = 0.35
x = np.arange(len(indices))  # Explanation indices

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - bar_width/2, with_scores, bar_width, label='With Generalization Loss')
ax.bar(x + bar_width/2, without_scores, bar_width, label='Without Generalization Loss')
ax.set_xlabel("Explanation Index")
ax.set_ylabel("Generalization Score")
ax.set_title("Generalization Score per Explanation")
ax.set_xticks(x)
ax.set_xticklabels([f"{i}" for i in indices])
ax.legend()
ax.grid(True, axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
fig.savefig("generalization_score_comparison_barplot.png", dpi=300)
plt.show()

# === SUMMARY ===
print("\n--- Summary ---")
print(f"Mean Generalization Score (With):     {np.mean(with_scores):.4f}")
print(f"Mean Generalization Score (Without):  {np.mean(without_scores):.4f}")
print("📊 Bar plot saved as: generalization_score_comparison_barplot.png")


In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np

# === Result files written by the with/without-alignment ablation cell ===
with_gen_path = "explanation_generalization_results_with_gen.pkl"
without_gen_path = "explanation_generalization_results_without_gen.pkl"

# === Load generalization scores from saved result files ===
def load_scores(path):
    """Return the per-explanation ``high_score_ratio`` values stored in a results pickle."""
    # NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
    with open(path, "rb") as f:
        results = pickle.load(f)
    return [r["high_score_ratio"] for r in results]

# Load both sets of scores
with_scores = load_scores(with_gen_path)
without_scores = load_scores(without_gen_path)

# NOTE: deliberately not named `colors` / `data`: the global `colors` is the
# matplotlib.colors module that plotmutag uses (rebinding it to a list breaks
# plotmutag), and `data` holds a graph object in earlier cells.
violin_colors = ['#1f77b4', '#ff7f0e']
violin_data = [with_scores, without_scores]

# === Plot Violin ===
fig, ax = plt.subplots(figsize=(8, 6))
parts = ax.violinplot(violin_data, showmeans=True, showextrema=True, showmedians=False)

# Customize violin appearance
for body, face_color in zip(parts['bodies'], violin_colors):
    body.set_facecolor(face_color)
    body.set_edgecolor('black')
    body.set_alpha(0.7)

# Mean markers
means = [np.mean(with_scores), np.mean(without_scores)]
ax.scatter([1, 2], means, color='black', marker='o', label='Mean')

# Axis labels and styling
ax.set_xticks([1, 2])
ax.set_xticklabels(['With Gen Loss', 'Without Gen Loss'])
ax.set_ylabel("Generalization Score")
ax.set_title("Distribution of Generalization Scores Across Explanations for Mutagenic Class")
ax.grid(axis='y', linestyle='--', alpha=0.6)
ax.legend()
fig.tight_layout()

# Save and show
fig.savefig("violin_generalization_score.png", dpi=300)
plt.show()


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

# === Configurable grid of lambda values ===
sparsity_range = [0, 2, 4, 6, 8]
entropy_range = [0.0, 0.1, 0.2, 0.4, 0.6]
# Same seed for every grid cell so cells differ only in the lambdas, not in initialisation
GRID_SEED = 0

# Class label for titles (0 = Non-Mutagenic, 1 = Mutagenic, as in the halogen analysis)
target_class_name = "Mutagenic" if target_class == 1 else "Non-Mutagenic"


def _avg_masked_target_score(explainer_, graphs):
    """Mean target-class probability of ``graphs`` after applying the explainer's own masks."""
    explainer_.eval()
    total_score = 0.0
    with torch.no_grad():  # evaluation only
        for g in graphs:
            g = g.to('cpu')
            mask = explainer_.sample_mask(explainer_(g.x, g.edge_index))
            masked_x = g.x * mask.unsqueeze(-1)
            src, dst = g.edge_index
            edge_weight = mask[src] * mask[dst]
            out, _ = model(masked_x, g.edge_index, g.batch, edge_weight=edge_weight)
            total_score += out.softmax(dim=1)[0, target_class].item()
    return total_score / len(graphs)


# === Store mean class scores (rows: sparsity_range, cols: entropy_range) ===
scores_grid = np.zeros((len(sparsity_range), len(entropy_range)))

# === Run experiments ===
for i, lambda_sparsity in enumerate(sparsity_range):
    for j, lambda_entropy in enumerate(entropy_range):
        print(f"Training with λ_sparsity={lambda_sparsity}, λ_entropy={lambda_entropy}")

        torch.manual_seed(GRID_SEED)
        # Separate name so the grid does not overwrite the main `explainer`
        grid_explainer = train_shared_explainer(
            model=model,
            topk_graphs=topk_graphs,
            topk_indices=topk_indices,
            U_final_sel=U_final_sel,
            ms_sel=ms_sel,
            n_univ_sel=n_univ_sel,
            target_class=target_class,
            topk_univ_k=10,
            epochs=130,
            lr=0.001,
            lambda_cls=2.0,
            lambda_align=2.6,
            lambda_entropy=lambda_entropy,
            lambda_sparsity=lambda_sparsity,
            lambda_budget=8.0,
            budget=8
        )

        # Evaluate: average target class score on the reference topk graphs
        scores_grid[i, j] = _avg_masked_target_score(grid_explainer, topk_graphs)

# === 3D Surface Plot ===
X, Y = np.meshgrid(entropy_range, sparsity_range)
Z = scores_grid

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='k')

ax.set_xlabel("λ_entropy")
ax.set_ylabel("λ_sparsity")
ax.set_zlabel("Avg Target Class Score")
ax.set_title(f"Effect of Sparsity and Entropy on Target Class Score on the {target_class_name} Class")
fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)

fig.tight_layout()
fig.savefig("3d_ablation_entropy_sparsity_target_score.png", dpi=300)
plt.show()


In [None]:
# === 3D Surface Plot of the λ_sparsity × λ_entropy grid, fixed z-range ===
# Class label derived from target_class (0 = Non-Mutagenic, 1 = Mutagenic) so the title
# matches the class the grid was computed for; a hardcoded label contradicted the
# title of the previous plot of the same scores_grid.
target_class_name = "Mutagenic" if target_class == 1 else "Non-Mutagenic"

X, Y = np.meshgrid(entropy_range, sparsity_range)
Z = scores_grid

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='k')

ax.set_xlabel("λ_entropy")
ax.set_ylabel("λ_sparsity")
ax.set_zlabel("Avg Target Class Score")
ax.set_title(f"Effect of Sparsity and Entropy on Target Class Score on the {target_class_name} Class")

# Fix the z-axis to the full probability range (makes plots comparable across runs)
ax.set_zlim(0, 1.0)

# === White background panes ===
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.pane.set_facecolor((1.0, 1.0, 1.0, 1.0))
fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)

fig.tight_layout()
# Distinct file name: the previous cell already saves the auto-scaled surface to
# "3d_ablation_entropy_sparsity_target_score.png".
fig.savefig("3d_ablation_entropy_sparsity_target_score_zlim.png", dpi=300)
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Class label derived from target_class (0 = Non-Mutagenic, 1 = Mutagenic)
target_class_name = "Mutagenic" if target_class == 1 else "Non-Mutagenic"

# Rows of scores_grid follow sparsity_range, columns follow entropy_range.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(scores_grid, annot=True, fmt=".3f", cmap="rocket_r",
            xticklabels=entropy_range, yticklabels=sparsity_range,
            cbar_kws={"label": "Target Class Score"}, ax=ax)
ax.set_xlabel("λ_entropy")
ax.set_ylabel("λ_sparsity")
ax.set_title(f"Target Class Score ({target_class_name} Class)")
fig.tight_layout()
fig.savefig("heatmap_target_score_trendy.png", dpi=300)
plt.show()


In [None]:
import pandas as pd

# Long-form table: one row per (λ_sparsity, λ_entropy) grid cell with its score.
rows = [
    {
        'λ_sparsity': lam_s,
        'λ_entropy': lam_e,
        'Target Score': scores_grid[i, j],
    }
    for i, lam_s in enumerate(sparsity_range)
    for j, lam_e in enumerate(entropy_range)
]

df = pd.DataFrame(rows)


In [None]:
from joypy import joyplot
import matplotlib.pyplot as plt

# Sort by λ_sparsity so the ridges appear in increasing order
df_sorted = df.sort_values(by="λ_sparsity")

# One ridge per λ_sparsity value; each ridge is the distribution of "Target Score"
# over the λ_entropy values. joyplot creates its own figure, so size it via
# `figsize` instead of calling plt.figure() first (that left an extra empty figure
# and had no effect on the ridgeline plot).
fig, axes = joyplot(data=df_sorted, by="λ_sparsity", column="Target Score",
                    figsize=(10, 6), colormap=plt.cm.viridis, fade=True, linewidth=1)

fig.suptitle("Ridgeline Plot of Target Class Score Across λ_entropy (Grouped by λ_sparsity)")
axes[-1].set_xlabel("Target Class Score")
# NOTE(review): tight_layout() may override joypy's own spacing for the overlapping
# ridges; bbox_inches="tight" trims the saved image instead.
fig.savefig("ridgeline_target_score.png", dpi=300, bbox_inches="tight")
plt.show()
