<a href="https://colab.research.google.com/github/ajkohl/HyperbolicGCN-Optimization/blob/main/Copy_of_HyperBolicGCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio
!pip install geoopt
!pip install networkx


Collecting geoopt
  Downloading geoopt-0.5.0-py3-none-any.whl.metadata (6.7 kB)
Downloading geoopt-0.5.0-py3-none-any.whl (90 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: geoopt
Successfully installed geoopt-0.5.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
from torch import optim
import geoopt
from geoopt import ManifoldParameter

##############################################
# Data Loading (Example)
##############################################

def load_fb15k_data(train_path, valid_path, test_path):
    """Load FB15k-style triple files and build the training graph.

    Every file is expected to hold one ``head<TAB>relation<TAB>tail`` triple
    per line.  Entity and relation names are mapped to consecutive integer
    ids in order of first appearance; the mappings are shared across the
    three splits (train is read first, then valid, then test).

    Args:
        train_path: Path of the training triples file.
        valid_path: Path of the validation triples file.
        test_path: Path of the test triples file.

    Returns:
        A tuple ``(ent2id, rel2id, G, train_triples, valid_triples,
        test_triples)`` where ``ent2id``/``rel2id`` are name -> id dicts,
        ``G`` is an undirected ``networkx.Graph`` containing every entity as
        a node and one edge per training triple (relations are ignored), and
        each ``*_triples`` is a list of ``(head_id, relation_id, tail_id)``.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields.
        OSError: If one of the files cannot be opened.
    """
    def read_triples(path, ent2id, rel2id):
        """Parse one triples file, extending ``ent2id``/``rel2id`` in place."""
        triples = []
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                fields = line.split('\t')
                if len(fields) != 3:
                    raise ValueError(
                        f"{path}:{lineno}: expected 3 tab-separated fields, "
                        f"got {len(fields)}"
                    )
                h, r, t = fields
                # setdefault evaluates len() before inserting, so an unseen
                # name receives the next free id; head is registered before
                # tail to keep id assignment order stable.
                h_id = ent2id.setdefault(h, len(ent2id))
                t_id = ent2id.setdefault(t, len(ent2id))
                r_id = rel2id.setdefault(r, len(rel2id))
                triples.append((h_id, r_id, t_id))
        return triples

    ent2id = {}
    rel2id = {}
    train_triples = read_triples(train_path, ent2id, rel2id)
    valid_triples = read_triples(valid_path, ent2id, rel2id)
    test_triples = read_triples(test_path, ent2id, rel2id)

    # Only training triples contribute edges, so validation/test links do not
    # leak into the graph; entities seen only there remain isolated nodes.
    G = nx.Graph()
    G.add_nodes_from(range(len(ent2id)))
    G.add_edges_from((h, t) for h, _, t in train_triples)

    return ent2id, rel2id, G, train_triples, valid_triples, test_triples

# Default locations of the tab-separated triple files (point these at real data)
train_path = 'train.txt'
valid_path = 'valid.txt'
test_path = 'test.txt'

# If no training file is present, write a small ring-shaped toy dataset
# (100 entities, 10 relations) so the rest of the script can still run.
import os
if not os.path.exists(train_path):
    for _path, _start, _stop in (
        (train_path, 0, 100),
        (valid_path, 100, 120),
        (test_path, 120, 140),
    ):
        with open(_path, 'w') as f:
            for i in range(_start, _stop):
                f.write(f"entity_{i%100}\trelation_{i%10}\tentity_{(i+1)%100}\n")

ent2id, rel2id, G, train_triples, valid_triples, test_triples = load_fb15k_data(
    train_path, valid_path, test_path
)

num_nodes, num_relations = len(ent2id), len(rel2id)

# Build the dense normalised adjacency  D^{-1/2} (A + 1e-5 * I) D^{-1/2}.
adj = nx.to_scipy_sparse_array(G, nodelist=range(num_nodes))
_asym_mask = adj.T > adj
adj = adj + adj.T.multiply(_asym_mask) - adj.multiply(_asym_mask)  # symmetrize
adj = torch.tensor(adj.toarray(), dtype=torch.float32) + 1e-5*torch.eye(num_nodes)
degrees = adj.sum(dim=1)
inv_sqrt_deg = torch.sqrt(degrees).reciprocal()
# Guard against division by a zero degree.
inv_sqrt_deg[torch.isinf(inv_sqrt_deg)] = 0
D_inv_sqrt = torch.diag(inv_sqrt_deg)
norm_adj = torch.matmul(torch.matmul(D_inv_sqrt, adj), D_inv_sqrt)
norm_adj = norm_adj.to(torch.float32)

# One-hot (identity) node features, created on the CPU first
features = torch.eye(num_nodes, device='cpu')

# Move the model inputs to the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
features, norm_adj = features.to(device), norm_adj.to(device)

##############################################
# Hyperbolic GCN Layers
##############################################

class HyperbolicLinear(nn.Module):
    """
    Linear layer for points on the Poincaré ball.

    The input is pulled back to the tangent space at the origin, an ordinary
    affine map (weight and optional bias) is applied there, and the result is
    pushed back onto the manifold.
    """
    def __init__(self, manifold, in_features, out_features, bias=True):
        super().__init__()
        self.manifold = manifold
        self.in_features = in_features
        self.out_features = out_features
        # The weight lives in Euclidean space; it is wrapped as a
        # ManifoldParameter tagged with the Euclidean manifold.
        initial_weight = torch.randn(out_features, in_features) * 0.01
        self.weight = ManifoldParameter(initial_weight, manifold=geoopt.Euclidean())
        if not bias:
            # Register the name so `self.bias` exists and is None.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # logmap0 -> affine map in the tangent space -> expmap0
        tangent = self.manifold.logmap0(x, dim=-1)
        transformed = F.linear(tangent, self.weight, self.bias)
        return self.manifold.expmap0(transformed, dim=-1)

class HyperbolicGCNLayer(nn.Module):
    """
    Graph-convolution layer: neighbourhood aggregation followed by a
    HyperbolicLinear transform.

    Aggregation multiplies the adjacency matrix with the node embeddings in
    the tangent space at the origin; the aggregate is mapped back onto the
    manifold before the linear layer is applied.
    """
    def __init__(self, manifold, in_features, out_features):
        super().__init__()
        self.manifold = manifold
        self.lin = HyperbolicLinear(manifold, in_features, out_features)

    def forward(self, x, adj):
        ball = self.manifold
        # x is on the manifold: aggregate neighbours in the tangent space
        aggregated_tangent = torch.matmul(adj, ball.logmap0(x, dim=-1))
        aggregated = ball.expmap0(aggregated_tangent, dim=-1)
        return self.lin(aggregated)

class HyperbolicGCN(nn.Module):
    """
    Two-layer hyperbolic GCN on a Poincaré ball created with `learnable=True`.

    `forward(x, adj)` maps the input features onto the ball, applies the
    first graph layer, a tanh non-linearity evaluated in the tangent space at
    the origin, and then the second graph layer.
    """
    def __init__(self, num_features, hidden_dim, num_classes, c=1.0):
        super().__init__()
        # The manifold is owned by the model so that its parameters are part
        # of model.parameters().
        self.manifold = geoopt.PoincareBall(c=c, learnable=True)
        self.layer1 = HyperbolicGCNLayer(self.manifold, num_features, hidden_dim)
        self.layer2 = HyperbolicGCNLayer(self.manifold, hidden_dim, num_classes)

    def forward(self, x, adj):
        ball = self.manifold
        # Place the input on the manifold, then run the first layer
        hidden = self.layer1(ball.expmap0(x, dim=-1), adj)
        # Hyperbolic activation: tanh applied in the tangent space
        activated_tangent = torch.tanh(ball.logmap0(hidden, dim=-1))
        hidden = ball.expmap0(activated_tangent, dim=-1)
        return self.layer2(hidden, adj)

##############################################
# Link Prediction Utilities
##############################################

def get_positive_negative_samples(train_triples, num_nodes, num_neg_samples=5, num_rels=None):
    """Return the positive triples together with uniformly sampled negatives.

    Negatives are ``(head, relation, tail)`` tuples drawn uniformly at random
    (with replacement, so duplicates are possible) and rejected only when they
    appear in ``train_triples``.

    Args:
        train_triples: Sequence of hashable ``(head, relation, tail)`` tuples.
        num_nodes: Number of entities; heads and tails are drawn from
            ``[0, num_nodes)``.
        num_neg_samples: Number of negatives generated per positive triple.
        num_rels: Number of relations; relations are drawn from
            ``[0, num_rels)``.  When ``None`` the module-level
            ``num_relations`` is used, which is what the function always
            relied on before this parameter existed.

    Returns:
        ``(train_triples, negatives)`` where ``train_triples`` is the input
        object unchanged and ``negatives`` is a list of exactly
        ``len(train_triples) * num_neg_samples`` tuples of Python ints.

    Raises:
        ValueError: If negatives are requested but cannot exist, i.e. the
            sampling ranges are empty or every in-range triple is positive
            (the rejection loop would otherwise never terminate).
    """
    if num_rels is None:
        # Backward-compatible fallback to the module-level relation count.
        num_rels = num_relations

    positive = set(train_triples)
    target = len(train_triples) * num_neg_samples
    negatives = []
    if target <= 0:
        return train_triples, negatives

    if num_nodes <= 0 or num_rels <= 0:
        raise ValueError("cannot sample negatives from an empty id range")

    total = num_nodes * num_rels * num_nodes
    if len(positive) >= total:
        # Only positives inside the sampling ranges can block a candidate.
        blocked = sum(
            1
            for h, r, t in positive
            if 0 <= h < num_nodes and 0 <= r < num_rels and 0 <= t < num_nodes
        )
        if blocked >= total:
            raise ValueError("every possible triple is positive; no negatives exist")

    # Draw candidates in bulk rather than one scalar at a time; a minimum
    # batch size keeps the tail of the rejection loop from crawling.
    min_batch = 1024
    while len(negatives) < target:
        missing = target - len(negatives)
        batch = max(missing, min_batch)
        heads = torch.randint(0, num_nodes, (batch,)).tolist()
        rels = torch.randint(0, num_rels, (batch,)).tolist()
        tails = torch.randint(0, num_nodes, (batch,)).tolist()
        accepted = [cand for cand in zip(heads, rels, tails) if cand not in positive]
        # Never overshoot the requested number of negatives.
        negatives.extend(accepted[:missing])
    return train_triples, negatives

def train_epoch(model, optimizer, criterion, pos_triples, neg_triples):
    """Run one full-batch optimisation step and return the loss as a float.

    `pos_triples` and `neg_triples` are 2-D index tensors whose column 0
    holds head ids and column 2 holds tail ids.  Positives are labelled 1 and
    negatives 0.  The model is evaluated on the module-level `features` and
    `norm_adj`, and each triple is scored by the dot product of its head and
    tail rows of the model output; the relation column is not used in the
    score.
    """
    model.train()
    optimizer.zero_grad()

    # Targets: ones for the positives followed by zeros for the negatives,
    # shaped as a column so they line up with the keepdim=True scores.
    targets = torch.cat(
        [
            torch.ones(pos_triples.size(0), device=device),
            torch.zeros(neg_triples.size(0), device=device),
        ],
        dim=0,
    ).unsqueeze(1)

    batch = torch.cat([pos_triples, neg_triples], dim=0)
    head_idx, tail_idx = batch[:, 0], batch[:, 2]

    # One forward pass yields an embedding row for every node
    node_out = model(features, norm_adj)

    # Simple dot-product score between head and tail rows
    scores = torch.sum(node_out[head_idx] * node_out[tail_idx], dim=1, keepdim=True)

    loss = criterion(scores, targets)
    loss.backward()
    optimizer.step()

    return loss.item()

##############################################
# Model and Training
##############################################

# Build the model: the input width equals the number of nodes (features are an
# identity matrix), the hidden width is 64, and the output width equals the
# number of relations. The initial curvature is c=0.06.
model = HyperbolicGCN(num_features=num_nodes, hidden_dim=64, num_classes=num_relations, c=0.06).to(device)

# Initialize input: ensure features are on the manifold
# Here, features are identity; map them to the manifold
# NOTE(review): `x` is not referenced again in the visible code, and
# HyperbolicGCN.forward applies expmap0 to its input itself -- confirm whether
# this assignment is still needed.
x = model.manifold.expmap0(features, dim=-1)

# Define optimizer using Geoopt's RiemannianAdam
optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=1e-2)

# Define loss function for link prediction: Binary Cross Entropy with logits
criterion = nn.BCEWithLogitsLoss()

# Prepare training samples
# Negatives are sampled once here and the same set is reused in every epoch.
train_pos, train_neg = get_positive_negative_samples(train_triples, num_nodes)

# Convert to tensors
# Integer (long) dtype because the triples are used as row indices.
train_pos = torch.tensor(train_pos, dtype=torch.long, device=device)
train_neg = torch.tensor(train_neg, dtype=torch.long, device=device)

# Enable anomaly detection to help debug further errors
# NOTE(review): PyTorch documents this mode as a debugging aid that slows
# execution -- consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(True)

# Full-batch training: one optimisation step per epoch; the loss and the
# current value of the manifold's curvature `c` are printed after each step.
num_epochs = 200
for epoch in range(num_epochs):
    loss = train_epoch(model, optimizer, criterion, train_pos, train_neg)
    print(f"Epoch {epoch}, Loss: {loss:.4f}, Curvature: {model.manifold.c.item():.6f}")


Epoch 0, Loss: 0.6931, Curvature: 0.060000
Epoch 1, Loss: 0.7010, Curvature: 0.060434
Epoch 2, Loss: 0.6932, Curvature: 0.060789
Epoch 3, Loss: 0.7013, Curvature: 0.060452
Epoch 4, Loss: 0.6936, Curvature: 0.060166
Epoch 5, Loss: 0.6960, Curvature: 0.059859
Epoch 6, Loss: 0.6949, Curvature: 0.059521
Epoch 7, Loss: 0.6927, Curvature: 0.059216
Epoch 8, Loss: 0.6943, Curvature: 0.058950
Epoch 9, Loss: 0.6950, Curvature: 0.058716
Epoch 10, Loss: 0.6934, Curvature: 0.058506
Epoch 11, Loss: 0.6918, Curvature: 0.058247
Epoch 12, Loss: 0.6928, Curvature: 0.057889
Epoch 13, Loss: 0.6916, Curvature: 0.057471
Epoch 14, Loss: 0.6900, Curvature: 0.057036
Epoch 15, Loss: 0.6897, Curvature: 0.056608
Epoch 16, Loss: 0.6868, Curvature: 0.056151
Epoch 17, Loss: 0.6839, Curvature: 0.055664
Epoch 18, Loss: 0.6818, Curvature: 0.055149
Epoch 19, Loss: 0.6768, Curvature: 0.054625
Epoch 20, Loss: 0.6744, Curvature: 0.054129
Epoch 21, Loss: 0.6712, Curvature: 0.053841
Epoch 22, Loss: 0.6736, Curvature: 0.05387

In [1]:
!pip install geoopt

Collecting geoopt
  Downloading geoopt-0.5.0-py3-none-any.whl.metadata (6.7 kB)
Downloading geoopt-0.5.0-py3-none-any.whl (90 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/90.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: geoopt
Successfully installed geoopt-0.5.0
