In [7]:
import torch
import numpy as np

# Dummy data
n_nodes = 4
n_clusters = 2

# Cluster assignment matrix C: each node belongs fully to one cluster (one-hot)
C = torch.tensor([
    [1, 0],  # Node 0 in cluster 0
    [1, 0],  # Node 1 in cluster 0
    [0, 1],  # Node 2 in cluster 1
    [0, 1],  # Node 3 in cluster 1
], dtype=torch.float32)

# Node features matrix X: simple 2D features
X = torch.tensor([
    [1, 0],  # Node 0 feature
    [1, 0],  # Node 1 feature
    [0, 1],  # Node 2 feature
    [0, 1],  # Node 3 feature
], dtype=torch.float32)

# Adjacency matrix for a small graph (4 nodes)
adj = np.array([
    [0, 1, 0, 0],  # Node 0 connected to Node 1
    [1, 0, 0, 0],  # Node 1 connected to Node 0
    [0, 0, 0, 1],  # Node 2 connected to Node 3
    [0, 0, 1, 0],  # Node 3 connected to Node 2
], dtype=np.float32)

edge_index = torch.tensor([
    [0, 1, 2, 3],  # source nodes
    [1, 0, 3, 2]   # target nodes
], dtype=torch.long)

# The expected:
# - Nodes are split evenly into 2 clusters: first 2 nodes in cluster 0, last 2 in cluster 1
# - Adjacency consists of two disconnected components (node 0-1 and node 2-3)
# - Features match the cluster assignment perfectly (cluster 0 nodes have feature [1,0], cluster 1 have [0,1])
# This should give us a high modularity score and low loss.

print("C:", C)
print("X:", X)
print("Adjacency matrix:", adj)


C: tensor([[1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.]])
X: tensor([[1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.]])
Adjacency matrix: [[0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]]


In [8]:
import torch
import torch.nn as nn

class ModularityLoss(nn.Module):
    def __init__(self, n_clusters, initial_alpha=0.5):
        super(ModularityLoss, self).__init__()
        # Make alpha a learnable parameter
        self.n_clusters = n_clusters
        self.alpha = nn.Parameter(torch.tensor(initial_alpha, dtype=torch.float32))

    def forward(self, C, X, adj):
        device = C.device
        n = adj.shape[0]
        k = self.n_clusters
        # Convert adjacency matrix to torch tensor on same device
        if not torch.is_tensor(adj):
            adj_tensor = torch.tensor(adj.toarray(), dtype=torch.float32, device=device)
        else:
            adj_tensor = adj.to(device)
        deg = torch.sum(adj_tensor, dim=1)
        m = torch.sum(deg) / 2
        d_outer = torch.outer(deg, deg) / (2 * m)
        B = adj_tensor - d_outer
        modularity_term = torch.trace(C.t() @ B @ C)
        mod_loss = -(1 / (2 * m)) * modularity_term
        x_norm = X / (X.norm(dim=1, keepdim=True) + 1e-8)
        W = torch.mm(x_norm, x_norm.t()).clamp(min=0)
        s = torch.sum(W, dim=1)
        w = torch.sum(s) / 2
        s_outer = torch.outer(s, s) / (2 * w + 1e-8)
        B_attr = W - s_outer
        attr_mod_term = torch.trace(C.t() @ B_attr @ C)
        attr_loss = -(1 / (2 * w + 1e-8)) * attr_mod_term
        ones = torch.ones((n, 1), device=device)
        cluster_distribution = C.t() @ ones
        collapse_term = torch.norm(cluster_distribution - 1, p=1)
        collapse_reg = (torch.sqrt(torch.tensor(k, device=device, dtype=torch.float32)) / n) * collapse_term
        # Use self.alpha (learnable)
        loss = self.alpha * mod_loss + (1 - self.alpha) * attr_loss + collapse_reg
        return loss


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNModel(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_clusters, dropout=0.5, leaky_relu_negative_slope=0.2):
        super(GCNModel, self).__init__()
        self.gcn1 = GCNConv(in_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, n_clusters)
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(leaky_relu_negative_slope)

    def forward(self, x, edge_index):
        x = self.gcn1(x, edge_index)
        x = self.leaky_relu(x)
        x = self.dropout(x)
        x = self.gcn2(x, edge_index)
        C = F.softmax(x, dim=1)
        return C

In [10]:
in_dim = X.shape[1]
hidden_dim = 64  # or desired size
dropout = 0.5
num_epochs = 100

In [11]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gcn_model = GCNModel(in_dim, hidden_dim, n_clusters, dropout).to(device)
loss_fn = ModularityLoss(n_clusters, initial_alpha=0.5).to(device)

In [12]:
optimizer = torch.optim.Adam(
    list(gcn_model.parameters()) + list(loss_fn.parameters()),
    lr=1e-3
)

In [34]:
adj_tensor = torch.tensor(adj, dtype=torch.float32)

# Ensure norm_adj is tensor and on device
if not torch.is_tensor(adj_tensor):
    adj_norm = torch.tensor(adj_tensor.toarray(), dtype=torch.float32, device=device)
else:
    adj_norm = adj_tensor.to(device)

In [36]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move tensors to the device
X = X.to(device)
C = C.to(device)
adj_tensor = torch.tensor(adj, dtype=torch.float32).to(device)
edge_index = edge_index.to(device)

# Instantiate model and loss on device
gcn_model = gcn_model.to(device)
loss_fn = loss_fn.to(device)

In [37]:
for epoch in range(num_epochs):
    gcn_model.train()
    optimizer.zero_grad()
    C = gcn_model(X, edge_index)
    loss = loss_fn(C, X, adj_tensor)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch}: Loss={loss.item()}, Alpha={loss_fn.alpha.item()}')

Epoch 0: Loss=0.705234706401825, Alpha=0.4997046887874603
Epoch 1: Loss=0.7053388357162476, Alpha=0.49955153465270996
Epoch 2: Loss=0.7024539113044739, Alpha=0.4993745982646942
Epoch 3: Loss=0.705615758895874, Alpha=0.49914586544036865
Epoch 4: Loss=0.6980979442596436, Alpha=0.4989511966705322
Epoch 5: Loss=0.7070960402488708, Alpha=0.49878063797950745
Epoch 6: Loss=0.7053557634353638, Alpha=0.49864742159843445
Epoch 7: Loss=0.7070155739784241, Alpha=0.4985295832157135
Epoch 8: Loss=0.7041813731193542, Alpha=0.49841776490211487
Epoch 9: Loss=0.7025058269500732, Alpha=0.498278945684433
Epoch 10: Loss=0.7071056962013245, Alpha=0.4982041120529175
Epoch 11: Loss=0.6802989840507507, Alpha=0.4981391131877899
Epoch 12: Loss=0.7047457695007324, Alpha=0.49803948402404785
Epoch 13: Loss=0.707088828086853, Alpha=0.497982382774353
Epoch 14: Loss=0.7025809288024902, Alpha=0.49796923995018005
Epoch 15: Loss=0.7066896557807922, Alpha=0.4980018138885498
Epoch 16: Loss=0.7064625024795532, Alpha=0.49800

## Testing batching mechanism

In [33]:
import numpy as np

def mini_batch_sampler(adj, batch_size, num_batches, 
                                walk_length=4, num_walks=20):
    """
    Efficient mini-batch sampling of nodes using random walk neighborhood expansion.
    - adj: normalized adjacency matrix (scipy sparse csr matrix)
    - batch_size: number of nodes per batch
    - num_batches: number of batches to sample
    - walk_length: number of steps per walk (default 4)
    - num_walks: number of walks per node (default 20)
    Returns a list of numpy arrays for each batch, each array containing the batch nodes + expanded neighbors.
    """
    nnodes = adj.shape[0]
    all_nodes = np.arange(nnodes)
    batches = []
    for _ in range(num_batches):
        # 1. Sample batch nodes
        batch_nodes = np.random.choice(all_nodes, size=batch_size, replace=False)
        # 2. Expand neighborhood by random walks
        visited = set(batch_nodes)
        for node in batch_nodes:
            for _ in range(num_walks):
                curr = node
                for _ in range(walk_length):
                    neighbors = adj[curr].nonzero()[1]
                    # print("neighbors", neighbors, neighbors[1])
                    if len(neighbors) == 0:
                        break
                    next_node = np.random.choice(neighbors)
                    visited.add(next_node)
                    curr = next_node
        subgraph_nodes = np.array(list(visited))
        print("Nodes in current batch: ", subgraph_nodes)
        batches.append(subgraph_nodes)
    return batches

# Example usage:
# norm_adj is your normalized adjacency matrix from collect_data
# batches = mini_batch_sampler(norm_adj, batch_size=32, num_batches=10)


In [27]:
from src.utils import collect_data

norm_adj, data, nclasses = collect_data('PubMed')
features = data.x.to(device)
edge_index = data.edge_index.to(device)
labels = data.y.tolist()

Dataset downloaded successfully!


In [34]:
batches = mini_batch_sampler(norm_adj, 64, 64)

Nodes in current batch:  [    0 16396  8205 ... 16379  8188  8189]
Nodes in current batch:  [    6  8199    12 ... 16373 16375 16376]
Nodes in current batch:  [ 8193 16386  8196 ... 16354 16355 16357]
Nodes in current batch:  [16384  8193 16389 ... 16365 16370  8186]
Nodes in current batch:  [16387     7  8201 ...  8165 16373 16379]
Nodes in current batch:  [ 8193     2  8197 ...  8183 16380 16382]
Nodes in current batch:  [   17    22    27 ... 16373 16378 16379]
Nodes in current batch:  [    6     7  8203 ...  8183 16380 16381]
Nodes in current batch:  [16384  8192 16387 ... 16380 16381  8191]
Nodes in current batch:  [ 8192     1 16387 ...  8173  8175 16380]
Nodes in current batch:  [    0  8192  8198 ... 16376 16378  8188]
Nodes in current batch:  [16389     6  8198 ... 16342  8153 16355]
Nodes in current batch:  [    0     1 16387 ... 16370 16380  8191]
Nodes in current batch:  [ 8192 16387     6 ... 16362  8176 16372]
Nodes in current batch:  [16384  8196     4 ...  8172 16370  8

In [35]:
batches

[array([    0, 16396,  8205, ..., 16379,  8188,  8189]),
 array([    6,  8199,    12, ..., 16373, 16375, 16376]),
 array([ 8193, 16386,  8196, ..., 16354, 16355, 16357]),
 array([16384,  8193, 16389, ..., 16365, 16370,  8186]),
 array([16387,     7,  8201, ...,  8165, 16373, 16379]),
 array([ 8193,     2,  8197, ...,  8183, 16380, 16382]),
 array([   17,    22,    27, ..., 16373, 16378, 16379]),
 array([    6,     7,  8203, ...,  8183, 16380, 16381]),
 array([16384,  8192, 16387, ..., 16380, 16381,  8191]),
 array([ 8192,     1, 16387, ...,  8173,  8175, 16380]),
 array([    0,  8192,  8198, ..., 16376, 16378,  8188]),
 array([16389,     6,  8198, ..., 16342,  8153, 16355]),
 array([    0,     1, 16387, ..., 16370, 16380,  8191]),
 array([ 8192, 16387,     6, ..., 16362,  8176, 16372]),
 array([16384,  8196,     4, ...,  8172, 16370,  8186]),
 array([ 8192, 16384,     2, ..., 16375, 16378, 16383]),
 array([16384,     0,  8193, ...,  8182, 16378,  8191]),
 array([16387, 16389,  8197, ..

In [17]:
nnodes = norm_adj.shape[0]
nnodes

19717

In [18]:
all_nodes = np.arange(nnodes)

In [19]:
all_nodes

array([    0,     1,     2, ..., 19714, 19715, 19716])