In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader as DL
import os
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [55]:
## Loading node embedding data
node_embeddings = torch.load('./Output/node_embeddings_initial.pt')
node_embeddings = node_embeddings.to(device)

## Loading adjacency matrix
adj = torch.load('./Output/sub_adjacency_matrix.pt')
adj = adj.to(device)

  node_embeddings = torch.load('./Output/node_embeddings_initial.pt')
  adj = torch.load('./Output/sub_adjacency_matrix.pt')


In [56]:
# Define the model
class F3GNNClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_classes, num_gcn_layers=2):
        """
        Initialize the GNN Classifier (F3).
        
        Args:
            input_dim (int): Dimension of the input graph representations (h).
            hidden_dim (int): Dimension of the hidden layers in GCN.
            output_dim (int): Dimension of the final representation before classification.
            num_classes (int): Number of output classes.
            num_gcn_layers (int): Number of graph convolutional layers.
        """
        super(F3GNNClassifier, self).__init__()
        
        # Graph Convolution Layers
        self.gcn_layers = nn.ModuleList()
        self.gcn_layers.append(GCNConv(input_dim, hidden_dim))  # First GCN layer
        for _ in range(num_gcn_layers - 1):
            self.gcn_layers.append(GCNConv(hidden_dim, hidden_dim))  # Hidden GCN layers
        
        # Linear Layers and ReLU Activation
        self.linear1 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(output_dim, num_classes)
        
        # Softmax for Output Probabilities
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, h, Ap):
        """
        Forward pass of F3.

        Args:
            h (torch.Tensor): Node embeddings from F1 (shape: [num_nodes, input_dim]).
            Ap (torch.Tensor): Population graph adjacency matrix (dense format, shape: [num_nodes, num_nodes]).

        Returns:
            torch.Tensor: Class probability vectors for each input graph (shape: [num_nodes, num_classes]).
        """
        # Convert dense adjacency matrix to sparse edge index format
        edge_index, edge_weight = dense_to_sparse(Ap)
        
        # Apply GCN layers
        for gcn_layer in self.gcn_layers:
            h = gcn_layer(h, edge_index, edge_weight)
            h = self.relu(h)
        
        # Linear Layers
        h = self.linear1(h)
        h = self.relu(h)
        logits = self.linear2(h)

        # Softmax for class probabilities
        probabilities = self.softmax(logits)
        return probabilities


# Main script
if __name__ == "__main__":
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load data
    node_embeddings = torch.load('./Output/node_embeddings_initial.pt').to(device)  # h
    adj = torch.load('./Output/sub_adjacency_matrix.pt').to(device)  # Ap

    # Example labeled nodes (indices and labels)
    # Replace this with actual indices and labels from your dataset
    labeled_node_indices = torch.tensor([0, 1, 2, 3, 4, 5]).to(device)  # Indices of labeled nodes
    labels = torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.long).to(device)  # Corresponding labels

    # Model parameters
    input_dim = node_embeddings.size(1)  # Dimension of node embeddings (h)
    hidden_dim = 32  # Hidden dimension for GCN
    output_dim = 16  # Output dimension for the first linear layer
    num_classes = 3  # Number of classes for classification
    num_gcn_layers = 2  # Number of GCN layers

    # Initialize the F3 model
    model = F3GNNClassifier(input_dim, hidden_dim, output_dim, num_classes, num_gcn_layers).to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Training loop
    num_epochs = 100
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        output_probabilities = model(node_embeddings, adj)  # Forward pass

        # Filter predictions for labeled nodes
        filtered_predictions = output_probabilities[labeled_node_indices]
        filtered_labels = labels  # Ensure labels correspond to the same indices

        # Compute loss
        loss = criterion(filtered_predictions, filtered_labels)
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights

        # Print loss every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        output_probabilities = model(node_embeddings, adj)
        filtered_predictions = output_probabilities[labeled_node_indices]
        predicted_classes = torch.argmax(filtered_predictions, dim=1)
        print("\nPredicted Classes:", predicted_classes)
        print("Class Probabilities:\n", filtered_predictions)


  node_embeddings = torch.load('./Output/node_embeddings_initial.pt').to(device)  # h
  adj = torch.load('./Output/sub_adjacency_matrix.pt').to(device)  # Ap


Epoch [10/100], Loss: 0.5559
Epoch [20/100], Loss: 0.5514
Epoch [30/100], Loss: 0.5514
Epoch [40/100], Loss: 0.5514
Epoch [50/100], Loss: 0.5514
Epoch [60/100], Loss: 0.5514
Epoch [70/100], Loss: 0.5514
Epoch [80/100], Loss: 0.5514
Epoch [90/100], Loss: 0.5514
Epoch [100/100], Loss: 0.5514

Predicted Classes: tensor([0, 1, 2, 0, 1, 2], device='cuda:0')
Class Probabilities:
 tensor([[1.0000e+00, 1.0597e-22, 4.9235e-25],
        [6.8061e-23, 1.0000e+00, 2.8070e-28],
        [7.4626e-15, 4.8197e-11, 1.0000e+00],
        [1.0000e+00, 1.3722e-18, 3.8647e-29],
        [1.5025e-20, 1.0000e+00, 3.8885e-23],
        [8.4359e-17, 3.7674e-10, 1.0000e+00]], device='cuda:0')
