In [21]:
import logging

import networkx as nx
import numpy as np
import torch

import geom.hyperboloid as hyperboloid
import geom.poincare as poincare
from learning.frechet import Frechet
from learning.pca import TangentPCA, EucPCA, PGA, HoroPCA, BSA
from utils.data import load_graph, load_embeddings
from utils.metrics import avg_distortion_measures, compute_metrics, format_metrics, aggregate_metrics
from utils.sarkar import sarkar, pick_root
import matplotlib.pyplot as plt

In [22]:
# Notebook-wide logging: timestamped lines at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Make float64 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float64)

In [23]:
def load_graph_data(dataset="smalltree"):
    """Load a graph and its all-pairs shortest-path distance matrix.

    Args:
        dataset: Name of the dataset, forwarded to ``load_graph``.

    Returns:
        tuple: ``(graph, graph_dist)`` where ``graph`` is the object returned
        by ``load_graph`` and ``graph_dist`` is a torch tensor wrapping the
        Floyd-Warshall distance matrix.
    """
    graph = load_graph(dataset)
    # The distance matrix is requested for node ids 0..n_nodes-1, so the graph
    # is assumed to label its nodes that way -- TODO confirm against load_graph.
    node_ids = np.arange(graph.number_of_nodes())
    dist_matrix = nx.floyd_warshall_numpy(graph, nodelist=node_ids)
    return graph, torch.from_numpy(dist_matrix)

In [24]:
def get_hyperbolic_embeddings(graph, dataset="smalltree", dim=10, use_sarkar=False, sarkar_scale=3.5):
    """Return Poincare embeddings of a graph and their pairwise distances.

    Args:
        graph: Graph to embed; only used when ``use_sarkar`` is true.
        dataset: Dataset name used to look up pre-trained embeddings when
            ``use_sarkar`` is false.
        dim: Embedding dimension. Pre-trained embeddings require 2, 10 or 50.
        use_sarkar: If true, build the embedding with ``sarkar``; otherwise
            load pre-trained embeddings with ``load_embeddings``.
        sarkar_scale: Passed to ``sarkar`` as ``tau``; the Sarkar pairwise
            distances are divided by this value.

    Returns:
        tuple: ``(z, z_dist)`` -- the embeddings as a torch tensor and the
        matrix returned by ``poincare.pairwise_distance``.

    Raises:
        AssertionError: If ``use_sarkar`` is false and ``dim`` is not one of
            2, 10, 50.
    """
    if not use_sarkar:
        # Pre-trained, optimization-based embeddings.
        logging.info("Using optimization-based embeddings")
        assert dim in [2, 10, 50], "pretrained embeddings are only for 2, 10 and 50 dimensions"
        points = torch.from_numpy(load_embeddings(dataset, dim=dim))
        return points, poincare.pairwise_distance(points)

    # Combinatorial Sarkar embedding rooted at the node chosen by pick_root.
    logging.info("Using sarkar embeddings")
    root_node = pick_root(graph)
    points = torch.from_numpy(sarkar(graph, tau=sarkar_scale, root=root_node, dim=dim))
    # Distances are divided by the same scale that was handed to sarkar as tau.
    return points, poincare.pairwise_distance(points) / sarkar_scale

In [27]:
def run(dataset="smalltree",
        model_type="horopca",
        metrics_final=None,
        dim=10,
        n_components=2,
        n_runs=5,
        use_sarkar=False,
        sarkar_scale=3.5, lr=5e-2):
    """Run one dimensionality-reduction experiment on a graph dataset.

    The graph is embedded in the Poincare ball (Sarkar construction or
    pre-trained embeddings), the embeddings are centered using their Frechet
    mean, and the centered points are reduced to ``n_components`` dimensions
    with the requested model.

    Args:
        dataset: Dataset name forwarded to the graph and embedding loaders.
        model_type: One of ``"pca"``, ``"tpca"``, ``"pga"``, ``"bsa"``,
            ``"horopca"``. Any other value runs the hMDS baseline; a warning
            is logged when that value is not ``"hmds"``.
        metrics_final: Metric names handed to ``format_metrics`` for the
            final report. Defaults to ``["distortion", "frechet_var"]``.
        dim: Dimension of the input hyperbolic embeddings.
        n_components: Number of components passed to the reduction model.
        n_runs: Number of repetitions for ``"pga"``, ``"bsa"`` and
            ``"horopca"``; ``"pca"`` and ``"tpca"`` always run once.
        use_sarkar: Use the Sarkar embedding instead of pre-trained ones.
        sarkar_scale: Scale forwarded to the Sarkar embedding.
        lr: Learning rate passed to the reduction model constructor.

    Returns:
        tuple: ``(metrics, config)`` where ``metrics`` is the output of
        ``aggregate_metrics`` (PCA-style models) or ``compute_metrics``
        (hMDS), and ``config`` is a dict echoing the run parameters.
    """
    # A list literal in the signature would be a single object shared by every
    # call and stored in every returned config; build a fresh one instead.
    if metrics_final is None:
        metrics_final = ["distortion", "frechet_var"]

    config = {
        "metrics_final": metrics_final,
        "dim": dim,
        "n_components": n_components,
        "n_runs": n_runs,
        "use_sarkar": use_sarkar,
        "sarkar_scale": sarkar_scale,
        "lr": lr
        }

    logging.info(f"Running experiments for {dataset} dataset.")

    graph, graph_dist = load_graph_data(dataset)
    z, z_dist = get_hyperbolic_embeddings(graph, dataset, dim, use_sarkar, sarkar_scale)

    # compute embeddings' distortion
    distortion = avg_distortion_measures(graph_dist, z_dist)[0]
    logging.info("Embedding distortion in {} dimensions: {:.4f}".format(dim, distortion))

    # Compute the mean and center the data
    logging.info("Computing the Frechet mean to center the embeddings")
    frechet = Frechet(lr=1e-2, eps=1e-5, max_steps=5000)
    mu_ref, has_converged = frechet.mean(z, return_converged=True)
    logging.info(f"Mean computation has converged: {has_converged}")
    x = poincare.reflect_at_zero(z, mu_ref)

    # Closed-form models (optim=False) are run once; optimization-based models
    # are repeated n_runs times and their metrics aggregated.
    pca_models = {
        'pca': {'class': EucPCA, 'optim': False, 'iterative': False, "n_runs": 1},
        'tpca': {'class': TangentPCA, 'optim': False, 'iterative': False, "n_runs": 1},
        'pga': {'class': PGA, 'optim': True, 'iterative': True, "n_runs": n_runs},
        'bsa': {'class': BSA, 'optim': True, 'iterative': False, "n_runs": n_runs},
        'horopca': {'class': HoroPCA, 'optim': True, 'iterative': False, "n_runs": n_runs},
    }

    if model_type in pca_models:
        model_params = pca_models[model_type]
        run_metrics = []
        for _ in range(model_params["n_runs"]):
            model = model_params['class'](dim=dim, n_components=n_components, lr=lr, max_steps=500)
            model.fit(x, iterative=model_params['iterative'], optim=model_params['optim'])
            run_metrics.append(model.compute_metrics(x))
        metrics = aggregate_metrics(run_metrics)
    else:
        # Every unrecognised name lands here, so make typos visible while
        # keeping the historical fallback behaviour.
        if str(model_type).lower() != "hmds":
            logging.warning("Unknown model_type %r; falling back to the hMDS baseline", model_type)
        # run hMDS baseline
        logging.info("Running hMDS")
        # hMDS only needs the pairwise Poincare distances of the centered points.
        D_p = poincare.pairwise_distance(x)
        x_h = hyperboloid.mds(D_p, d=n_components)
        x_proj = hyperboloid.to_poincare(x_h)
        metrics = compute_metrics(x, x_proj)

    logging.info(f"Experiments for {dataset} dataset completed.")
    logging.info("Computing evaluation metrics")
    results = format_metrics(metrics, metrics_final)
    for line in results:
        logging.info(line)
    return metrics, config


In [28]:
# Datasets to evaluate, in the order they are run.
list_dataset = ["smalltree", "phylo-tree", "ca-CSphd", "bio-diseasome"]
# Model identifiers; this cell does not iterate over them and only runs "horopca".
list_model = ["pca", "horopca", "pga", "bsa", "tpca", "hmds"]

# One record per dataset: its name plus the (metrics, config) pair from run().
results = []

for dataset in list_dataset:
    print("Dataset:", dataset)
    metrics, config = run(dataset=dataset, model_type="horopca")
    results.append(dict(dataset=dataset, metrics=metrics, config=config))

2024-01-11 10:49:22 INFO     Running experiments for smalltree dataset.
2024-01-11 10:49:22 INFO     Using optimization-based embeddings
2024-01-11 10:49:22 INFO     Embedding distortion in 10 dimensions: 0.0222
2024-01-11 10:49:22 INFO     Computing the Frechet mean to center the embeddings
2024-01-11 10:49:22 INFO     Mean computation has converged: True


Dataset: smalltree


KeyboardInterrupt: 