In [16]:
!pip install obonet
!pip install owlready2
!pip install rdflib
!pip install torch_geometric
!pip install PyYAML



In [12]:
import os

import torch
import yaml
from src.train import train_model
from src.utils import load_data
from torch import nn

In [2]:
# load config function
def load_config(config_path="../config.yaml"):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path of the YAML file to read. The default,
            ``"../config.yaml"``, is resolved against the current working
            directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        OSError: If the file cannot be opened (for example, it does not exist).
    """
    # Decode explicitly as UTF-8 so that parsing does not depend on the
    # platform's locale encoding, which is what open() uses by default.
    with open(config_path, encoding="utf-8") as config_file:
        return yaml.safe_load(config_file)

In [10]:
# Select the compute device: CUDA when torch reports one as available, CPU otherwise.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Read the experiment configuration; its "data" section drives data loading.
config = load_config()
data_specs = config["data"]
data_directory = data_specs["data_directory"]

# Build the train/test loaders. load_data also returns the edge index and the
# ontology node list, both of which the analysis cells further down reuse.
loader_options = {
    "device": device,
    "batch_size": data_specs["batch_size"],
    "n_samples": data_specs["n_samples"],
}
train_loader, test_loader, edge_index, ontology_node_list = load_data(
    data_directory,
    **loader_options,
)

# Run training. train_model returns three values, bound here as the model,
# the test losses and the modularities.
print("Starting training...")
training_options = {
    "device": device,
    "node_list": ontology_node_list,
    "config": config,
    "print_stats": True,
    "plot": True,  # Set to False if you don't want plots saved
}
model, test_losses, modularities = train_model(
    train_loader,
    test_loader,
    **training_options,
)
print("Training finished.")

Using device: cuda
Loaded ./data/titanic/titanic_with_cabins.ttl with rdflib (Turtle format)
DiGraph with 322 nodes and 853 edges
Loaded ./data/titanic/titanic_with_cabins.ttl with rdflib (Turtle format)
DiGraph with 322 nodes and 853 edges
loaded titanic dataset
Starting training...
Epoch [10/300], Train Loss: 0.2816, Test Loss: 0.5810 Modularity:  0.35824913
Epoch [20/300], Train Loss: 0.2750, Test Loss: 0.5742 Modularity:  0.36040625
Epoch [30/300], Train Loss: 0.2706, Test Loss: 0.5721 Modularity:  0.3617345
Epoch [40/300], Train Loss: 0.2653, Test Loss: 0.5708 Modularity:  0.36282474
Epoch [50/300], Train Loss: 0.2601, Test Loss: 0.5696 Modularity:  0.3638028
Epoch [60/300], Train Loss: 0.2548, Test Loss: 0.5682 Modularity:  0.36471194
Epoch [70/300], Train Loss: 0.2490, Test Loss: 0.5679 Modularity:  0.36557376
Epoch [80/300], Train Loss: 0.2435, Test Loss: 0.5663 Modularity:  0.366396
Epoch [90/300], Train Loss: 0.2380, Test Loss: 0.5656 Modularity:  0.36720616
Epoch [100/300], 

In [4]:
# Score the trained model on both splits; in the recorded run this prints a
# train accuracy and a test accuracy.
print("\nEvaluating model performance...")
model.evaluate_model(
    train_loader,
    test_loader,
    device=device,
)


Evaluating model performance...
Train Accuracy: 0.7500
Test Accuracy: 0.7273


In [7]:
# Analyze communities
print("\nAnalyzing detected communities...")

# The importance computation takes a loss function: build a cross-entropy
# criterion here and move it to the active device. The later analysis cells
# reuse this same `criterion`.
criterion = nn.CrossEntropyLoss().to(device)

print("\nCalculating community importance...")
importance_options = {
    "device": device,
    "num_communities": model.num_communities,
    "print_stats": False,  # Set to True for detailed output per sample
}
comm_importance = model.compute_community_importance(
    test_loader,
    criterion,
    **importance_options,
)
importance_values = comm_importance.tolist()
print(f"Community Importance (per community): {importance_values}")

# The "most important" community is the one with the largest importance value.
most_imp_comm_idx = comm_importance.argmax().item()
print(f"Most important community index: {most_imp_comm_idx}")


Analyzing detected communities...

Calculating community importance...
Community Importance (per community): [-0.0002778438210953027, 0.000256607832852751, 0.0001695385144557804]
Most important community index: 1


In [7]:
print("\nIdentifying important nodes in the most important community...")
node_options = {
    "device": device,
    "criterion": criterion,
    "print_stats": False,  # Set to True for detailed output per sample
}
important_nodes_freq = model.important_community_nodes(test_loader, **node_options)

# Rank (node name, sample count) pairs by count, highest first. The sort is
# stable, so nodes with equal counts keep the order the mapping yields them in.
sorted_nodes = list(important_nodes_freq.items())
sorted_nodes.sort(key=lambda pair: pair[1], reverse=True)

print("Top nodes in the most important community:")
top_nodes = sorted_nodes[:20]  # Print top 20
for node_name, freq in top_nodes:
    print(f"  {node_name}: appears in {freq} samples")

Top nodes in the most important community:
  A_Deck: appears in 55 samples
  A_cabin: appears in 55 samples
  AllDisjointClasses: appears in 55 samples
  B_Deck: appears in 55 samples
  B_cabin: appears in 55 samples
  BoatLocation: appears in 55 samples
  Boat_Deck: appears in 55 samples
  C_Deck: appears in 55 samples
  C_cabin: appears in 55 samples
  Cabin: appears in 55 samples
  Cherbourg: appears in 55 samples
  Class: appears in 55 samples
  D_Deck: appears in 55 samples
  D_cabin: appears in 55 samples
  Deck: appears in 55 samples
  Distance: appears in 55 samples
  E_Deck: appears in 55 samples
  E_cabin: appears in 55 samples
  Entity: appears in 55 samples
  F_Deck: appears in 55 samples


In [8]:
print("\nIdentifying important edges within the most important community...")
# edge_index is the value returned by load_data in the training cell.
imp_edges = model.important_community_edges(
    test_loader,
    edge_index=edge_index,  # Pass the edge_index
    device=device,
    criterion=criterion,
    print_stats=False,  # Set to True for detailed output per sample
)

# Rank edges by how many samples they appear in, highest first.
sorted_edges = sorted(imp_edges.items(), key=lambda item: item[1], reverse=True)
print("Top edges within the most important community:")
if not sorted_edges:
    # Say so explicitly: otherwise the header is followed by nothing at all,
    # which is indistinguishable from truncated output.
    print("  (no edges were reported for the most important community)")
for (i, j), count in sorted_edges[:20]:  # Print top 20
    # Edge keys are pairs of node indices; map them back to names for printing.
    name_i = ontology_node_list[i]
    name_j = ontology_node_list[j]
    print(f"  ({name_i}, {name_j}): appears in {count} samples")


Identifying important edges within the most important community...
Top edges within the most important community:


In [9]:
# Optional debugging aid: dump the gradients currently held by the model, i.e.
# whatever was left after training finished. In the recorded run this prints one
# number per named parameter.
print("\nPrinting model gradients (last state)...")
model.print_gradients()  # method of the model object returned by train_model


Printing model gradients (last state)...
CommunityDetection.conv1.bias: 0.03441854566335678
CommunityDetection.conv1.lin.weight: 0.0010838990565389395
CommunityDetection.norm.weight: 0.010859488509595394
CommunityDetection.norm.bias: 0.011725078336894512
OntologyEncoder.conv1.att_src: 0.017874499782919884
OntologyEncoder.conv1.att_dst: 2.8556962350378967e-11
OntologyEncoder.conv1.bias: 0.06008051708340645
OntologyEncoder.conv1.lin.weight: 0.007949931547045708
fc2.weight: 0.00857439637184143
fc2.bias: 0.027162153273820877


In [18]:
# Optional: write the trained weights to disk when the config asks for it.
save = config["training"]["save_model"]

if save:
    experiment_cfg = config["experiment"]
    save_dir = experiment_cfg["save_dir"]
    # exist_ok=True: an already existing directory is not an error.
    os.makedirs(save_dir, exist_ok=True)

    # The checkpoint is named after the experiment and holds only the state dict.
    checkpoint_name = experiment_cfg["name"] + ".pt"
    save_path = os.path.join(save_dir, checkpoint_name)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

Model saved to checkpoints/exp1_community_gnn.pt
