In [1]:
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx


In [2]:
def generate_cSBM(n = 5000, f = 2000, d = 5, l = 2.06, mu = 2.0, censor_fraction=0.1, seed=None):
    """
    Generate a two-community contextual stochastic block model (cSBM) dataset.

    Node features follow
        x_i = sqrt(mu / n) * v_i * u + Z_i / sqrt(f)
    where v_i in {-1, +1} is the community sign of node i, u ~ N(0, I_f / f)
    is ONE latent direction shared by every node, and Z_i ~ N(0, I_f) is
    independent per-node noise.

    Parameters:
    - n: Number of nodes in the graph (split as evenly as possible over two blocks).
    - f: Number of features for each node.
    - d: Average degree of each node (before censoring).
    - l: Graph signal strength (lambda). Intra-/inter-block edge probabilities
         are (d +/- l * sqrt(d)) / n, so |l| <= sqrt(d) keeps both non-negative.
    - mu: Feature signal strength.
    - censor_fraction: Fraction of edges to remove uniformly at random, in [0, 1].
    - seed: Optional integer seed. When given, it seeds the networkx block-model
            sampler and a private NumPy RandomState used for censoring and
            features. When None, the global `random` / `np.random` state is
            used, exactly as before this parameter existed.

    Returns:
    - data: PyTorch Geometric Data object for the censored graph; `data.x` is a
            float32 tensor of shape (n, f).
    - G: The censored networkx graph; every node carries its 'block' id and its
         feature vector 'x'.

    Raises:
    - ValueError: If censor_fraction is outside [0, 1].
    """
    if not 0 <= censor_fraction <= 1:
        raise ValueError("censor_fraction must be in [0, 1]")

    # seed=None keeps the legacy behaviour of drawing from NumPy's global RNG.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Generate SBM graph (second block absorbs the extra node when n is odd)
    sizes = [n // 2, n - n // 2]
    p_in = (d + l * np.sqrt(d)) / n
    p_out = (d - l * np.sqrt(d)) / n
    p_matrix = np.full((2, 2), p_out)
    np.fill_diagonal(p_matrix, p_in)
    G = nx.stochastic_block_model(sizes, p_matrix, seed=seed)

    # Censor edges
    all_edges = list(G.edges())
    rng.shuffle(all_edges)
    n_censored = int(censor_fraction * len(all_edges))
    G.remove_edges_from(all_edges[:n_censored])

    # Build node features for all nodes at once
    nodes = list(G.nodes())
    blocks = np.array([G.nodes[node]['block'] for node in nodes])
    signs = 1 - 2 * blocks  # block 0 -> +1, block 1 -> -1
    # The latent direction is shared by all nodes; the second argument of
    # normal() is a standard deviation, hence 1 / sqrt(f) for variance 1 / f.
    u = rng.normal(0, 1 / np.sqrt(f), f)
    Z = rng.normal(0, 1, (len(nodes), f))
    features = np.sqrt(mu / n) * signs[:, None] * u + Z / np.sqrt(f)

    # Convert to PyTorch Geometric format before attaching 'x' to the nodes,
    # so from_networkx does not have to convert a list of per-node arrays
    # that would be overwritten below anyway.
    data = from_networkx(G)
    data.x = torch.from_numpy(features).float()

    # Keep the per-node feature vectors available on the returned graph
    for node, row in zip(nodes, features):
        G.nodes[node]['x'] = row

    return data, G

# Seed the global RNGs first so Restart & Run All regenerates the same graph:
# networkx draws the block-model edges from Python's `random` module when no
# seed is passed, while edge censoring and node features use NumPy's global state.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # makes later torch-side randomness repeatable too

# Generate cSBM data and the networkx graph
cSBM_data, G = generate_cSBM()
print(cSBM_data)

# Visualization
def visualize_graph(G):
    """Draw G with every node coloured by its 'block' (community) attribute."""
    # One colour value per node, collected in G's node iteration order
    block_ids = [attrs['block'] for _, attrs in G.nodes(data=True)]

    plt.figure(figsize=(10, 8))
    nx.draw(
        G,
        node_color=block_ids,
        cmap=plt.cm.jet,
        node_size=50,
        with_labels=False,
    )
    plt.show()

# Render the censored graph returned by generate_cSBM, coloured by community block.
visualize_graph(G)

In [3]:
# Sanity check: one row per node and one column per feature,
# i.e. (5000, 2000) with generate_cSBM's default n and f.
print(cSBM_data.x.shape)

In [4]:
from models.discriminator import Discriminator
from models.PolyGCL_model import PolyGCL
from loss import contrastive_loss

In [5]:
# Quick check whether a CUDA device is usable; the next cell uses the same
# test to choose `device`.
torch.cuda.is_available()

In [6]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

In [27]:
# Move the graph to the selected device and pull out the tensors the model consumes.
cSBM_data = cSBM_data.to(device)
edge_index = cSBM_data.edge_index
x = cSBM_data.x

# Read the input width from the data instead of repeating the generator's
# default f = 2000, so changing f upstream cannot desynchronise the model.
in_size = x.shape[1]
hidden_size = 2000
out_size = 2000
model = PolyGCL(in_size = in_size, hidden_size = hidden_size, out_size = out_size, K = 10).to(device)
# NOTE(review): Discriminator's argument is presumably the embedding width
# (out_size) — confirm against models/discriminator.py.
discriminator = Discriminator(out_size).to(device)

# Confirm that the model and its inputs ended up on the same device.
print(next(model.parameters()).device)
print(edge_index.device)
print(x.device)

In [30]:
# training
# Optimise the encoder and the discriminator jointly. The discriminator is a
# separate object that the loss calls into, so passing only model.parameters()
# would leave any weights it owns frozen at their initial values.
# NOTE(review): assumes Discriminator is an nn.Module; if it has no parameters
# this adds nothing to the optimizer.
trainable_params = list(model.parameters()) + list(discriminator.parameters())
optimizer = optim.Adam(trainable_params, lr=0.1)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Fresh negative sample each epoch (presumably a corrupted copy of x —
    # see PolyGCL.get_negative_example).
    x_tilde = PolyGCL.get_negative_example(x)
    # Each forward pass returns two embeddings, named H and L by the model.
    pos_Z_H, pos_Z_L = model(x, edge_index)
    neg_Z_H, neg_Z_L = model(x_tilde, edge_index)
    # Global summary built from the positive embeddings only.
    g = model.get_global_summary(pos_Z_H, pos_Z_L)
    loss = contrastive_loss(pos_Z_H, neg_Z_H, pos_Z_L, neg_Z_L, g, discriminator)
    loss.backward()
    optimizer.step()
    print('Epoch:', epoch, 'Loss:', loss.item())