In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import networkx as nx
from rdflib import Graph

# 1. Load the data
def load_nt_file(file_path, rdf_format=None):
    """Parse an RDF file into an rdflib ``Graph``.

    Args:
        file_path: Location of the RDF document, passed straight to
            ``Graph.parse``.
        rdf_format: Optional rdflib parser name (for example ``"nt"`` or
            ``"turtle"``). When omitted, the format is guessed from the
            file extension; ``"turtle"`` is used if the extension is not
            recognised.

    Returns:
        The populated ``Graph``.
    """
    # Function-scope import: rdflib is already a dependency of this file.
    from rdflib.util import guess_format

    if rdf_format is None:
        # A ".nt" path should be read by the N-Triples parser rather than the
        # Turtle one; unknown extensions keep the previous "turtle" default.
        rdf_format = guess_format(str(file_path)) or "turtle"

    g = Graph()
    g.parse(file_path, format=rdf_format)
    return g

# Convert RDF graph to NetworkX graph
def rdf_to_networkx(rdf_graph):
    """Build an undirected NetworkX graph from the triples in ``rdf_graph``.

    Each (subject, predicate, object) triple yielded by ``rdf_graph``
    contributes one edge between the string forms of its subject and object.
    The predicate is not stored.
    """
    undirected = nx.Graph()
    undirected.add_edges_from(
        (str(subj), str(obj)) for subj, _pred, obj in rdf_graph
    )
    return undirected

# 2. Create a PyG graph object
def create_pyg_graph(nx_graph):
    """Convert ``nx_graph`` with ``from_networkx`` and return the result."""
    pyg_data = from_networkx(nx_graph)
    return pyg_data

# 3. Define the model
class SimpleGraphEmbedding(nn.Module):
    """Embedding table plus a linear layer applied to summed edge endpoints."""

    def __init__(self, num_nodes, embedding_dim):
        """Create the lookup table and the projection layer.

        Args:
            num_nodes: Number of rows in the embedding table.
            embedding_dim: Width of each embedding row; also the input and
                output width of the linear layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, embedding_dim)
        self.linear = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, edge_index):
        """Return the linear projection of the summed endpoint embeddings.

        ``edge_index[0]`` and ``edge_index[1]`` are each used as indices into
        the embedding table; the two looked-up tensors are added element-wise
        and passed through the linear layer.
        """
        sources, targets = edge_index[0], edge_index[1]
        edge_features = self.embedding(sources) + self.embedding(targets)
        return self.linear(edge_features)

# 4. Train the embeddings
def train_embeddings(model, data, num_epochs=100, lr=0.01, log_every=10):
    """Optimise ``model`` with Adam against a zero-target MSE objective.

    Every epoch performs one optimisation step: the model is called on
    ``data.edge_index`` and the mean squared value of its output is minimised.

    Args:
        model: Module whose ``forward`` accepts ``data.edge_index``.
        data: Object exposing an ``edge_index`` attribute.
        num_epochs: Number of optimisation steps to run.
        lr: Learning rate handed to Adam.
        log_every: Print the loss every this many epochs. ``None`` or a
            non-positive value turns progress output off.

    Returns:
        The same ``model`` object that was passed in, after training.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    # Guard against log_every == 0, which would otherwise divide by zero below.
    logging_enabled = log_every is not None and log_every > 0

    for epoch in range(1, num_epochs + 1):
        optimizer.zero_grad()
        out = model(data.edge_index)

        # The target is all zeros, so this pushes the model's output toward
        # zero. It does not directly compare the embeddings of the two
        # endpoints of an edge with each other.
        loss = F.mse_loss(out, torch.zeros_like(out))
        loss.backward()
        optimizer.step()

        if logging_enabled and epoch % log_every == 0:
            print(f'Epoch {epoch}/{num_epochs}, Loss: {loss.item():.4f}')

    return model


In [4]:
# Relative path of the RDF file to load; resolved against the kernel's working directory.
file_path = "14_graph.nt"
# Parse the file into an rdflib graph; the result is consumed by the conversion cell below.
rdf_graph = load_nt_file(file_path)

In [5]:
# Main execution
# Seed PyTorch's RNG so the randomly initialised model weights are the same
# on every run of this cell.
torch.manual_seed(42)

# RDF triples -> NetworkX graph -> PyTorch Geometric object.
nx_graph = rdf_to_networkx(rdf_graph)
pyg_graph = create_pyg_graph(nx_graph)

num_nodes = pyg_graph.num_nodes
embedding_dim = 128
model = SimpleGraphEmbedding(num_nodes, embedding_dim)

trained_model = train_embeddings(model, pyg_graph)

# Get embeddings: export the learned embedding table as a NumPy array.
# .cpu() leaves CPU tensors untouched and keeps the conversion valid if the
# model is ever moved to another device.
embeddings = trained_model.embedding.weight.detach().cpu().numpy()

Epoch 10/100, Loss: 0.0334


KeyboardInterrupt: 