<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Complete_GCN_Code_with_Normalized_Adjacency.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

# --- GCN Layer ---
class GCNLayer(nn.Module):
    """One graph-convolution layer: a linear projection followed by aggregation.

    The layer owns a single ``nn.Linear`` (with bias) mapping ``in_features``
    to ``out_features``; it applies no activation of its own.
    """

    def __init__(self, in_features, out_features):
        """Create the learnable projection ``in_features -> out_features``."""
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        """Project the node features, then mix them with ``adj``.

        Args:
            x: Node feature tensor whose last dimension is ``in_features``.
            adj: Matrix multiplied on the left of the projected features;
                presumably a (num_nodes, num_nodes) normalized adjacency
                matrix -- the layer does not validate or normalize it.

        Returns:
            ``adj @ linear(x)``.
        """
        projected = self.linear(x)
        # Aggregation step: each output row is a weighted sum of the rows
        # of `projected`, with the weights taken from that row of `adj`.
        return adj @ projected

# --- GCN Model ---
class GCN(nn.Module):
    """Two stacked ``GCNLayer`` modules producing per-node log-probabilities.

    The first layer maps ``in_features -> hidden_features`` and is followed
    by a ReLU; the second maps ``hidden_features -> out_features`` and is
    followed by a log-softmax over dimension 1.
    """

    def __init__(self, in_features, hidden_features, out_features):
        """Build the two graph-convolution layers."""
        super().__init__()
        self.gcn1 = GCNLayer(in_features, hidden_features)
        self.gcn2 = GCNLayer(hidden_features, out_features)

    def forward(self, x, adj):
        """Run both layers using the same ``adj`` for each aggregation.

        Args:
            x: Node feature tensor whose last dimension is ``in_features``.
            adj: Aggregation matrix handed unchanged to both layers.

        Returns:
            Log-softmax (taken along ``dim=1``) of the second layer's output.
        """
        hidden = F.relu(self.gcn1(x, adj))
        logits = self.gcn2(hidden, adj)
        return F.log_softmax(logits, dim=1)

# --- Adjacency Normalization Function ---
def normalize_adjacency(adj):
    """
    Applies A_hat = D^{-1/2} (A + I) D^{-1/2}

    Self-loops are added to ``adj`` and the result is scaled symmetrically
    by the inverse square root of each node's degree, where the degree is
    the row sum of ``A + I``.

    Args:
        adj: Square 2-D adjacency tensor. Integer or boolean input is
            promoted to floating point by the addition of the identity.

    Returns:
        A new tensor with the same shape as ``adj``, on the same device;
        ``adj`` itself is not modified. Nodes whose degree (after adding
        the self-loop) is not strictly positive get all-zero rows and
        columns rather than ``inf``/``nan`` entries.
    """
    # Build the identity on adj's device so that a CUDA adjacency matrix
    # does not fail with a CPU/GPU device mismatch on the addition below.
    identity = torch.eye(adj.size(0), device=adj.device)
    adj = adj + identity  # Add self-loops (out of place: caller's tensor is untouched)
    degree = adj.sum(dim=1)

    # pow(-0.5) is inf for a zero degree and nan for a negative one; use 0
    # for those nodes so the bad value does not spread through the product.
    inv_sqrt_degree = torch.where(
        degree > 0, degree.pow(-0.5), torch.zeros_like(degree)
    )

    # Scaling rows and columns by broadcasting is equivalent to
    # diag(d) @ adj @ diag(d) without building two dense n x n matrices.
    return inv_sqrt_degree.unsqueeze(1) * adj * inv_sqrt_degree.unsqueeze(0)

# --- Example Graph ---
# Node features: 6 nodes with 10-dimensional features
# (drawn from torch.rand, so the values differ on every run unless a seed is set)
node_features = torch.rand(6, 10)

# Example adjacency matrix (6 nodes): an undirected path graph 0-1-2-3-4-5.
# The diagonal is zero on purpose: normalize_adjacency() adds the self-loops
# itself, so putting 1s on the diagonal here would count every self-loop twice.
adjacency_matrix = torch.tensor([
    [0, 1, 0, 0, 0, 0],
    [1, 0, 1, 0, 0, 0],
    [0, 1, 0, 1, 0, 0],
    [0, 0, 1, 0, 1, 0],
    [0, 0, 0, 1, 0, 1],
    [0, 0, 0, 0, 1, 0],
], dtype=torch.float)

# Normalize the adjacency matrix
adj_norm = normalize_adjacency(adjacency_matrix)

# Instantiate and run the model (untrained: the weights are randomly initialised)
model = GCN(in_features=10, hidden_features=16, out_features=3)
output = model(node_features, adj_norm)

# Output: one row per node, one column per class; the values are
# log-probabilities because GCN.forward ends with log_softmax.
print("Graph Output Shape:", output.shape)
print("Node class scores:\n", output)