In [None]:
import torch
from src.models.gnn import GCNRegressor, GATRegressor, GCNClassifier
import os
from tqdm import tqdm
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from generate_data import *
from src.datasets.dataset import ManifoldGraphDataset

# Registry mapping the `arch` configuration string to the model class to build.
architectures = dict(
    gcn=GCNRegressor,
    gcn_clf=GCNClassifier,
    gat=GATRegressor,
)

In [None]:
# Configuration: model hyper-parameters (must match the saved checkpoint),
# data-loading options, and output locations.
hidden_channels = 100
features_max_k = 21
num_layers = 5

# data parameters
subgraph_k = 5
degree_features = False
scale_features = False
data_dir = 'data/poincare_est_geodesic_distances/train'
data_files = os.listdir(data_dir)
print("Data files: ", data_files)

model_name = "poincare_est_geodesic_features"
arch = 'gcn'
task = 'regression'
if task == 'classification':
    # Only the classifier architecture produces class scores.
    assert arch == 'gcn_clf'

model_path = f'outputs/{model_name}/nn/best_val.pt'
save_path = f'outputs/{model_name}/plots'
os.makedirs(save_path, exist_ok=True)

# Load the model
model = architectures[arch](
    features_max_k,
    hidden_channels,
    num_layers,
    dropout=0.0
)
# The model is never moved off the CPU in this notebook, so load the checkpoint onto
# the CPU as well; without map_location a checkpoint saved from a CUDA run fails to
# load on a CPU-only machine. torch.load unpickles: only load trusted checkpoints.
save_dict = torch.load(model_path, map_location='cpu')
state_dict = save_dict['model']
model.load_state_dict(state_dict)
model.eval()

In [None]:
def get_data(data_dir, file, scale_features, subgraph_k, degree_features):
    """Load one saved graph file and wrap it in a ``ManifoldGraphDataset``.

    Args:
        data_dir: Directory that contains ``file``.
        file: Name of a ``torch.save``-d dict that has a ``'coords'`` entry.
        scale_features: Forwarded to ``ManifoldGraphDataset``.
        subgraph_k: Forwarded to ``ManifoldGraphDataset``.
        degree_features: Forwarded to ``ManifoldGraphDataset``.

    Returns:
        ``(X, dataset)``. ``X`` is the file's ``'coords'`` entry. ``dataset`` is built
        from ``{file: <loaded dict>}`` with ``subsample_pctg=1.0`` and
        ``avoid_boundary=False`` (presumably: keep every node, including boundary
        ones; confirm in ManifoldGraphDataset).

    Note:
        The dataset task comes from the notebook-level ``task`` variable, not from
        an argument. ``torch.load`` unpickles the file, so only use trusted data.
    """
    raw = torch.load(os.path.join(data_dir, file))
    coords = raw['coords']
    dataset = ManifoldGraphDataset(
        task,
        {file: raw},
        subgraph_k=subgraph_k,
        degree_features=degree_features,
        subsample_pctg=1.0,
        avoid_boundary=False,
        scale_features=scale_features,
    )
    return coords, dataset

def predict_curvature(model, subgraph_dataset, scale_features=False):
    """Run ``model`` on every subgraph of ``subgraph_dataset`` and score it with MSE.

    Args:
        model: Callable ``model(x, edge_index, edge_attr, batch)`` returning the
            prediction for one subgraph.
        subgraph_dataset: Object whose ``.data`` attribute iterates over subgraph
            records exposing ``x``, ``edge_index``, ``edge_attr``, ``batch``, ``y``
            (and ``scale`` when ``scale_features`` is true).
        scale_features: If true, the prediction and the target are both multiplied
            by ``subgraph.scale ** 2`` before the loss is computed.

    Returns:
        Four lists with one entry per subgraph: node features (``subgraph.x``),
        targets (rescaled if ``scale_features``), predictions (as returned by the
        model, rescaled if ``scale_features``), and MSE loss tensors.

    The dataset is not modified. Rescaled targets live in local tensors, so calling
    this function repeatedly on the same dataset returns the same results.
    """
    features = []
    curvatures = []
    predicted_curvatures = []
    losses = []

    with torch.no_grad():
        for subgraph in tqdm(subgraph_dataset.data, desc="Predicting curvatures"):
            pred = model(
                subgraph.x.float(),
                subgraph.edge_index,
                subgraph.edge_attr.float(),
                subgraph.batch,
            )
            target = subgraph.y
            if scale_features:
                factor = subgraph.scale ** 2
                pred = pred * factor
                # Rescale a local copy: writing back to subgraph.y would compound the
                # scaling on every call made with the same dataset.
                target = target * factor

            # Align e.g. a (N, 1) prediction with a (N,) target so mse_loss compares
            # element-wise instead of broadcasting to an (N, N) matrix.
            if pred.numel() == target.numel():
                pred_for_loss = pred.reshape(target.shape)
            else:
                pred_for_loss = pred
            loss = torch.nn.functional.mse_loss(pred_for_loss, target)

            losses.append(loss)
            features.append(subgraph.x)
            curvatures.append(target)
            # Keep the model's output shape; callers squeeze the last dim themselves.
            predicted_curvatures.append(pred)
    return features, curvatures, predicted_curvatures, losses

def predict_class(model, subgraph_dataset, scale_features=False):
    """Classify every subgraph of ``subgraph_dataset`` and report per-subgraph accuracy.

    Args:
        model: Callable ``model(x, edge_index, edge_attr, batch)`` returning class
            scores with the class dimension at index 1.
        subgraph_dataset: Object whose ``.data`` attribute iterates over subgraph
            records exposing ``x``, ``edge_index``, ``edge_attr``, ``batch`` and ``y``.
        scale_features: Unused. Accepted so the signature matches
            ``predict_curvature`` and either can be called through ``predict_fn``.

    Returns:
        Four lists with one entry per subgraph: node features, true labels,
        predicted labels (argmax over dim 1), and accuracy as a Python float.
    """
    feature_list, label_list, prediction_list, accuracy_list = [], [], [], []

    with torch.no_grad():
        for sg in tqdm(subgraph_dataset.data, desc="Predicting curvatures"):
            scores = model(sg.x.float(), sg.edge_index, sg.edge_attr.float(), sg.batch)
            predicted = torch.max(scores, 1).indices

            n_correct = (predicted == sg.y).sum().item()
            accuracy_list.append(n_correct / sg.y.shape[0])
            feature_list.append(sg.x)
            label_list.append(sg.y)
            prediction_list.append(predicted)

    return feature_list, label_list, prediction_list, accuracy_list


def visualize_subgraph(subgraph_data, X, dim=3, n=3, path=None, rng=None):
    """Plot the neighbourhood subgraphs of ``n`` randomly chosen, distinct nodes.

    Args:
        subgraph_data: Dataset forwarded to the 2-D / 3-D plotting helper.
        X: Node coordinates, one row per node.
        dim: ``3`` selects ``visualize_subgraph_3d``; any other value selects
            ``visualize_subgraph_2d``.
        n: Number of nodes to plot, capped at ``X.shape[0]``.
        path: Prefix for saved figure files, or ``None`` to skip saving.
        rng: Optional ``numpy.random.Generator`` for reproducible picks. When
            ``None``, the global ``np.random`` state is used.
    """
    fn = visualize_subgraph_3d if dim == 3 else visualize_subgraph_2d
    num_nodes = X.shape[0]
    sampler = np.random if rng is None else rng
    # Sample without replacement: with replacement the same node could be drawn
    # twice, producing a duplicate figure saved over the same file name.
    indices = sampler.choice(num_nodes, min(n, num_nodes), replace=False)
    fn(subgraph_data, X, indices, path)

def visualize_subgraph_3d(subgraph_data, X, indices, path):
    """Draw the subgraph for each index in ``indices`` on a 3-D point cloud.

    Each index gets its own figure with two 3-D panels. The left panel shows the
    subgraph's nodes in black plus the whole point cloud ``X`` with small markers.
    The right panel shows only the subgraph's nodes and its edges.

    Args:
        subgraph_data: Dataset indexable by ``index`` (each item is converted with
            ``to_networkx``) and exposing ``subgraph_node_indices``, whose
            ``[index]`` entry selects that subgraph's rows of ``X``.
        X: Point coordinates with at least three columns.
        indices: Subgraph indices to draw.
        path: Prefix for ``{path}_subgraph_{index}.png``; ``None`` disables saving.
    """
    node_indices = subgraph_data.subgraph_node_indices
    for index in indices:
        # Row i of `pos` is the position of local node i of the networkx graph below
        # (assumes the dataset keeps these orderings aligned — TODO confirm).
        pos = X[node_indices[index]]

        fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
        axs[0].scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='black', s=20)
        # No positional 'o' here: the 4th positional argument of Axes3D.scatter is
        # `zdir`, not the marker, and 'o' is already the default scatter marker.
        axs[0].scatter(X[:, 0], X[:, 1], X[:, 2], s=1)

        g = to_networkx(subgraph_data[index], to_undirected=True)
        edge_xyz = np.array([(pos[u], pos[v]) for u, v in g.edges()])

        axs[1].scatter(*pos.T, s=60, ec="w")
        # Plot the edges
        for vizedge in edge_xyz:
            axs[1].plot(*vizedge.T, color="tab:gray")

        if path is not None:
            fig.savefig(path + f'_subgraph_{index}.png')
            
def visualize_subgraph_2d(subgraph_data, X, indices, path):
    """Draw the subgraph for each index in ``indices`` on a 2-D point cloud.

    Each index gets a two-panel figure. The left panel shows the subgraph's nodes in
    black, then the full point cloud ``X`` with small markers. The right panel shows
    the subgraph's nodes and its edges. When ``path`` is not ``None`` the figure is
    saved as ``{path}_subgraph_{index}.png``.

    ``subgraph_data.subgraph_node_indices[index]`` selects the subgraph's rows of
    ``X``, and ``subgraph_data[index]`` is converted with ``to_networkx``.
    """
    for index in indices:
        local_coords = X[subgraph_data.subgraph_node_indices[index]]

        fig, (ax_cloud, ax_graph) = plt.subplots(1, 2, figsize=(12, 6))
        ax_cloud.scatter(local_coords[:, 0], local_coords[:, 1], c='black', s=20)
        ax_cloud.scatter(X[:, 0], X[:, 1], s=1)

        nx_graph = to_networkx(subgraph_data[index], to_undirected=True)
        segments = np.array(
            [(local_coords[a], local_coords[b]) for a, b in nx_graph.edges()]
        )

        ax_graph.scatter(*local_coords.T, s=60, ec="w")
        for segment in segments:
            ax_graph.plot(*segment.T, color="tab:gray")

        if path is not None:
            fig.savefig(path + f'_subgraph_{index}.png')


# Choose the evaluation routine for the configured task. Both routines return
# (features, targets, predictions, per-subgraph metric).
if task == 'classification':
    predict_fn = predict_class
else:
    predict_fn = predict_curvature

In [None]:
# Torus: the file is read from the parent directory of the train split.
torus_file = 'torus_inrad_1_outrad_2_nodes_10000_k_10.pt'
torus_name = torus_file.replace(".pt", "")
torus_X, torus_subgraph_data = get_data(
    data_dir.replace('train', ''),
    torus_file,
    scale_features=scale_features,
    subgraph_k=subgraph_k,  # previously hard-coded to 5; use the configured value
    degree_features=degree_features,
)

features, curvatures, predicted_curvatures, metric = predict_fn(model, torus_subgraph_data, scale_features=scale_features)
if task == 'classification':
    print(f"Accuracy: {torch.tensor(metric).mean()}")
else:
    print(f"Mean loss: {torch.mean(torch.stack(metric))}")
curvatures = torch.cat(curvatures, dim=0).numpy()
predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

# Ground truth vs. prediction on two adjacent 3-D panels with the same colour limits.
fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
axs[0].set_title('Ground Truth')
axs[1].set_title('Predicted')
axs[0].set_zlim(-3, 3)
axs[1].set_zlim(-3, 3)
# No positional 'o': the 4th positional argument of Axes3D.scatter is `zdir`.
p1 = axs[0].scatter(torus_X[:, 0], torus_X[:, 1], torus_X[:, 2], c=curvatures.tolist())
p2 = axs[1].scatter(torus_X[:, 0], torus_X[:, 1], torus_X[:, 2], c=predicted_curvatures.tolist())
fig.colorbar(p1, ax=axs.ravel().tolist())
fig.colorbar(p2, ax=axs.ravel().tolist())
p1.set_clim(-2, 1)
p2.set_clim(-2, 1)
os.makedirs(f'{save_path}/torus/pred', exist_ok=True)
os.makedirs(f'{save_path}/torus/subgraph', exist_ok=True)
fig.savefig(f'{save_path}/torus/pred/{torus_name}.png')
visualize_subgraph(torus_subgraph_data, torus_X, dim=3, path=f'{save_path}/torus/subgraph/{torus_name}', n=5)

In [None]:
# Overlay the distributions of ground-truth and predicted values from the torus cell.
hist_fig, hist_ax = plt.subplots()
for values, series_label in ((curvatures, 'Ground Truth'), (predicted_curvatures, 'Predicted')):
    hist_ax.hist(values, bins=50, alpha=0.5, label=series_label)
hist_ax.legend()
hist_fig.savefig(f'{save_path}/torus/pred/{torus_file.replace(".pt", "")}_hist.png')

In [None]:
# Training-split spheres: print the loss for each radius, then the mean over radii.
# Uses predict_curvature directly, so this cell assumes task == 'regression'.
losses = []  # one mean loss per radius
for r in [2.82, 2, 1.633, 1.414, 1.265, 1.15, 1.069, 1]:
    sphere_file = f'sphere_dim_2_rad_{r}_nodes_10000_k_10.pt'
    sphere_name = sphere_file.replace(".pt", "")
    sphere_X, sphere_subgraph_data = get_data(data_dir, sphere_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)
    # Per-subgraph losses get their own name: unpacking them into `losses` used to
    # replace the per-radius accumulator on every iteration, so "Total loss" only
    # reflected the last radius. scale_features is forwarded as in the torus cell.
    features, curvatures, predicted_curvatures, subgraph_losses = predict_curvature(
        model, sphere_subgraph_data, scale_features=scale_features
    )
    average_loss = torch.mean(torch.stack(subgraph_losses))
    losses.append(average_loss)
    print(f"Loss for sphere with radius {r}: {average_loss.item()}")
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    # Ground truth vs. prediction on two adjacent 3-D panels.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
    axs[0].set_title('Ground Truth')
    axs[1].set_title('Predicted')
    axs[0].set_zlim(-3, 3)
    axs[1].set_zlim(-3, 3)
    # No positional 'o': the 4th positional argument of Axes3D.scatter is `zdir`.
    p1 = axs[0].scatter(sphere_X[:, 0], sphere_X[:, 1], sphere_X[:, 2], c=curvatures.tolist())
    p2 = axs[1].scatter(sphere_X[:, 0], sphere_X[:, 1], sphere_X[:, 2], c=predicted_curvatures.tolist())
    fig.colorbar(p1, ax=axs.ravel().tolist())
    fig.colorbar(p2, ax=axs.ravel().tolist())
    p1.set_clim(-3, 3)
    p2.set_clim(-3, 3)
    os.makedirs(f'{save_path}/sphere/pred', exist_ok=True)
    fig.savefig(f'{save_path}/sphere/pred/{sphere_name}.png')

    # Histogram of predictions. The ground truth is drawn as a single line at
    # curvatures[0], which assumes it is constant over the sphere.
    plt.figure()
    plt.hist(predicted_curvatures, bins=100, alpha=0.5, label='Predicted')
    plt.axvline(curvatures[0], color='r', linestyle='dashed', label='Ground Truth')
    plt.legend(loc='upper right')
    os.makedirs(f'{save_path}/sphere/hist', exist_ok=True)
    plt.savefig(f'{save_path}/sphere/hist/{sphere_name}.png')
    os.makedirs(f'{save_path}/sphere/subgraph', exist_ok=True)
    visualize_subgraph(sphere_subgraph_data, sphere_X, dim=3, path=f'{save_path}/sphere/subgraph/{sphere_name}')

total_loss = torch.mean(torch.stack(losses))
print(f"Total loss: {total_loss.item()}")

In [None]:
# Validation-split spheres.
for r in [2.31, 1.032]:
    sphere_file = f'sphere_dim_2_rad_{r}_nodes_10000_k_10.pt'
    sphere_name = sphere_file.replace(".pt", "")
    sphere_X, sphere_subgraph_data = get_data(data_dir.replace('train','val'), sphere_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)
    # Forward scale_features (as the torus cell does) so predictions and targets are
    # rescaled whenever the dataset was built with scaled features.
    features, curvatures, predicted_curvatures, metric = predict_fn(model, sphere_subgraph_data, scale_features=scale_features)
    if task == 'classification':
        print(f"Accuracy: {torch.tensor(metric).mean()}")
    else:
        print(f"Mean loss: {torch.mean(torch.stack(metric))}")
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    # Ground truth vs. prediction on two adjacent 3-D panels.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
    axs[0].set_title('Ground Truth')
    axs[1].set_title('Predicted')
    axs[0].set_zlim(-3, 3)
    axs[1].set_zlim(-3, 3)
    # No positional 'o': the 4th positional argument of Axes3D.scatter is `zdir`.
    p1 = axs[0].scatter(sphere_X[:, 0], sphere_X[:, 1], sphere_X[:, 2], c=curvatures.tolist())
    p2 = axs[1].scatter(sphere_X[:, 0], sphere_X[:, 1], sphere_X[:, 2], c=predicted_curvatures.tolist())
    fig.colorbar(p1, ax=axs.ravel().tolist())
    fig.colorbar(p2, ax=axs.ravel().tolist())
    p1.set_clim(-3, 3)
    p2.set_clim(-3, 3)
    os.makedirs(f'{save_path}/sphere/pred', exist_ok=True)
    fig.savefig(f'{save_path}/sphere/pred/{sphere_name}.png')

    # Histogram of predictions with the (assumed constant) ground truth as a line.
    plt.figure()
    plt.hist(predicted_curvatures, bins=100, alpha=0.5, label='Predicted')
    plt.axvline(curvatures[0], color='r', linestyle='dashed', label='Ground Truth')
    plt.legend(loc='upper right')
    os.makedirs(f'{save_path}/sphere/hist', exist_ok=True)
    os.makedirs(f'{save_path}/sphere/subgraph', exist_ok=True)
    plt.savefig(f'{save_path}/sphere/hist/{sphere_name}.png')
    visualize_subgraph(sphere_subgraph_data, sphere_X, dim=3, path=f'{save_path}/sphere/subgraph/{sphere_name}')

In [None]:
# Training-split Poincaré disks, one file per K value in the list below (Rh fixed).
# Uses predict_curvature directly, so this cell assumes task == 'regression'.
Rh = 2
for k in [-0.25, -0.5, -0.75, -1.0, -1.25, -1.5, -1.75, -2.0]:
    poincare_file = f'poincare_K_{k}_nodes_10000_Rh_{Rh}_k_10.pt'
    poincare_name = poincare_file.replace(".pt", "")
    # data_dir already points at the train split (joining another 'train' produced
    # .../train/train), and get_data requires subgraph_k and degree_features.
    poincare_X, poincare_subgraph_data = get_data(data_dir, poincare_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)
    features, curvatures, predicted_curvatures, losses = predict_curvature(model, poincare_subgraph_data, scale_features=scale_features)
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    # Ground truth vs. prediction on two adjacent 2-D panels.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].set_title('Ground Truth')
    axs[1].set_title('Predicted')
    p1 = axs[0].scatter(poincare_X[:, 0], poincare_X[:, 1], c=curvatures.tolist())
    p2 = axs[1].scatter(poincare_X[:, 0], poincare_X[:, 1], c=predicted_curvatures.tolist())
    fig.colorbar(p1, ax=axs.ravel().tolist())
    fig.colorbar(p2, ax=axs.ravel().tolist())
    p1.set_clim(-3, 3)
    p2.set_clim(-3, 3)
    os.makedirs(f'{save_path}/poincare/pred', exist_ok=True)
    fig.savefig(f'{save_path}/poincare/pred/{poincare_name}.png')

    # Histogram of predictions with the (assumed constant) ground truth as a line.
    plt.figure()
    plt.hist(predicted_curvatures, bins=100, alpha=0.5, label='Predicted')
    plt.axvline(curvatures[0], color='r', linestyle='dashed', label='Ground Truth')
    plt.legend(loc='upper right')
    os.makedirs(f'{save_path}/poincare/hist', exist_ok=True)
    os.makedirs(f'{save_path}/poincare/subgraph', exist_ok=True)
    plt.savefig(f'{save_path}/poincare/hist/{poincare_name}.png')
    visualize_subgraph(poincare_subgraph_data, poincare_X, dim=2, path=f'{save_path}/poincare/subgraph/{poincare_name}')


In [None]:
# Validation-split Poincaré disks.
Rh = 2
for k in [-0.375, -1.875]:
    poincare_file = f'poincare_K_{k}_nodes_10000_Rh_{Rh}_k_10.pt'
    poincare_name = poincare_file.replace(".pt", "")
    poincare_X, poincare_subgraph_data = get_data(data_dir.replace('train','val'), poincare_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)
    # Forward scale_features (as the torus cell does) so predictions and targets are
    # rescaled whenever the dataset was built with scaled features.
    features, curvatures, predicted_curvatures, metric = predict_fn(model, poincare_subgraph_data, scale_features=scale_features)
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    if task == 'classification':
        print(f"Accuracy: {torch.tensor(metric).mean()}")
    else:
        print(f"Mean loss: {torch.mean(torch.stack(metric))}")

    # Ground truth vs. prediction on two adjacent 2-D panels.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].set_title('Ground Truth')
    axs[1].set_title('Predicted')
    p1 = axs[0].scatter(poincare_X[:, 0], poincare_X[:, 1], c=curvatures.tolist())
    p2 = axs[1].scatter(poincare_X[:, 0], poincare_X[:, 1], c=predicted_curvatures.tolist())
    fig.colorbar(p1, ax=axs.ravel().tolist())
    fig.colorbar(p2, ax=axs.ravel().tolist())
    p1.set_clim(-3, 3)
    p2.set_clim(-3, 3)
    os.makedirs(f'{save_path}/poincare/pred', exist_ok=True)
    fig.savefig(f'{save_path}/poincare/pred/{poincare_name}.png')

    # Histogram of predictions with the (assumed constant) ground truth as a line.
    plt.figure()
    plt.hist(predicted_curvatures, bins=100, alpha=0.5, label='Predicted')
    plt.axvline(curvatures[0], color='r', linestyle='dashed', label='Ground Truth')
    plt.legend(loc='upper right')
    os.makedirs(f'{save_path}/poincare/hist', exist_ok=True)
    os.makedirs(f'{save_path}/poincare/subgraph', exist_ok=True)
    plt.savefig(f'{save_path}/poincare/hist/{poincare_name}.png')
    visualize_subgraph(poincare_subgraph_data, poincare_X, dim=2, path=f'{save_path}/poincare/subgraph/{poincare_name}')



In [None]:
# Flat 2-D (Euclidean) graph. Uses predict_curvature directly (regression task).
euclidean_file = 'euclidean_dim_2_rad_1_nodes_10000_k_10.pt'
euclidean_name = euclidean_file.replace(".pt", "")
# data_dir already points at the train split (joining another 'train' produced
# .../train/train), and get_data requires subgraph_k and degree_features.
euclidean_X, euclidean_subgraph_data = get_data(data_dir, euclidean_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)
features, curvatures, predicted_curvatures, losses = predict_curvature(model, euclidean_subgraph_data, scale_features=scale_features)
curvatures = torch.cat(curvatures, dim=0).numpy()
predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

# Ground truth vs. prediction on two adjacent 2-D panels.
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].set_title('Ground Truth')
axs[1].set_title('Predicted')
p1 = axs[0].scatter(euclidean_X[:, 0], euclidean_X[:, 1], c=curvatures.tolist())
p2 = axs[1].scatter(euclidean_X[:, 0], euclidean_X[:, 1], c=predicted_curvatures.tolist())
fig.colorbar(p1, ax=axs.ravel().tolist())
fig.colorbar(p2, ax=axs.ravel().tolist())
p1.set_clim(-3, 3)
p2.set_clim(-3, 3)
os.makedirs(f'{save_path}/euclidean/pred', exist_ok=True)
fig.savefig(f'{save_path}/euclidean/pred/{euclidean_name}.png')

# Histogram of predictions with the (assumed constant) ground truth as a line.
plt.figure()
plt.hist(predicted_curvatures, bins=100, alpha=0.5, label='Predicted')
plt.axvline(curvatures[0], color='r', linestyle='dashed', label='Ground Truth')
plt.legend(loc='upper right')
os.makedirs(f'{save_path}/euclidean/hist', exist_ok=True)
plt.savefig(f'{save_path}/euclidean/hist/{euclidean_name}.png')
os.makedirs(f'{save_path}/euclidean/subgraph', exist_ok=True)
visualize_subgraph(euclidean_subgraph_data, euclidean_X, dim=2, path=f'{save_path}/euclidean/subgraph/{euclidean_name}')



In [None]:
# Hyperboloid: evaluate a saved graph (no data is generated here).
# Uses predict_curvature directly, so this cell assumes task == 'regression'.
hyperboloid_file = 'hyperbolic_nodes_10000_k_10.pt'
hyperboloid_name = hyperboloid_file.replace(".pt", "")
# get_data requires subgraph_k and degree_features; omitting them raised a TypeError.
hyperboloid_X, hyperboloid_subgraph_data = get_data(data_dir, hyperboloid_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)

features, curvatures, predicted_curvatures, losses = predict_curvature(model, hyperboloid_subgraph_data, scale_features=scale_features)
curvatures = torch.cat(curvatures, dim=0).numpy()
predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

# Ground truth vs. prediction on two adjacent 3-D panels.
fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
axs[0].set_title('Ground Truth')
axs[1].set_title('Predicted')
axs[0].set_zlim(-3, 3)
axs[1].set_zlim(-3, 3)
# No positional 'o': the 4th positional argument of Axes3D.scatter is `zdir`.
p1 = axs[0].scatter(hyperboloid_X[:, 0], hyperboloid_X[:, 1], hyperboloid_X[:, 2], c=curvatures.tolist())
p2 = axs[1].scatter(hyperboloid_X[:, 0], hyperboloid_X[:, 1], hyperboloid_X[:, 2], c=predicted_curvatures.tolist())
fig.colorbar(p1, ax=axs.ravel().tolist())
fig.colorbar(p2, ax=axs.ravel().tolist())
p1.set_clim(-3, 0.5)
p2.set_clim(-3, 0.5)
os.makedirs(f'{save_path}/hyperboloid/pred', exist_ok=True)
fig.savefig(f'{save_path}/hyperboloid/pred/{hyperboloid_name}.png')

# Predicted and ground-truth values plotted against the third coordinate
# (a scatter plot, although it is saved under hist/).
plt.figure()
plt.scatter(hyperboloid_X[:,2], predicted_curvatures, label='Predicted')
plt.scatter(hyperboloid_X[:,2], curvatures, label='Ground Truth')
plt.legend(loc='upper right')
os.makedirs(f'{save_path}/hyperboloid/hist', exist_ok=True)
# Strip '.pt' like every other output path (this plot was previously saved as '*.pt.png').
plt.savefig(f'{save_path}/hyperboloid/hist/{hyperboloid_name}.png')
os.makedirs(f'{save_path}/hyperboloid/subgraph', exist_ok=True)
visualize_subgraph(hyperboloid_subgraph_data, hyperboloid_X, dim=3, path=f'{save_path}/hyperboloid/subgraph/{hyperboloid_name}')


In [None]:
# Paraboloid: evaluate a saved graph read from the parent directory of the train split.
# Uses predict_curvature directly, so this cell assumes task == 'regression'.
paraboloid_file = 'paraboloid_nodes_a_1_b_-1_10000_k_10.pt'
paraboloid_name = paraboloid_file.replace(".pt", "")
paraboloid_X, paraboloid_subgraph_data = get_data(data_dir.replace('train', ''), paraboloid_file, scale_features=scale_features, subgraph_k=subgraph_k, degree_features=degree_features)

features, curvatures, predicted_curvatures, losses = predict_curvature(model, paraboloid_subgraph_data, scale_features=scale_features)
curvatures = torch.cat(curvatures, dim=0).numpy()
predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

# Ground truth vs. prediction on two adjacent 3-D panels.
fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
axs[0].set_title('Ground Truth')
axs[1].set_title('Predicted')
axs[0].set_zlim(-3, 3)
axs[1].set_zlim(-3, 3)
# No positional 'o': the 4th positional argument of Axes3D.scatter is `zdir`.
p1 = axs[0].scatter(paraboloid_X[:, 0], paraboloid_X[:, 1], paraboloid_X[:, 2], c=curvatures.tolist())
p2 = axs[1].scatter(paraboloid_X[:, 0], paraboloid_X[:, 1], paraboloid_X[:, 2], c=predicted_curvatures.tolist())
fig.colorbar(p1, ax=axs.ravel().tolist())
fig.colorbar(p2, ax=axs.ravel().tolist())
p1.set_clim(-5, 3)
p2.set_clim(-5, 3)
os.makedirs(f'{save_path}/paraboloid/pred', exist_ok=True)
fig.savefig(f'{save_path}/paraboloid/pred/{paraboloid_name}.png')

os.makedirs(f'{save_path}/paraboloid/subgraph', exist_ok=True)
visualize_subgraph(paraboloid_subgraph_data, paraboloid_X, dim=3, path=f'{save_path}/paraboloid/subgraph/{paraboloid_name}')


In [None]:
# Overlay predicted and ground-truth value distributions from the most recently run
# evaluation cell (the paraboloid cell when the notebook is run in order).
plt.figure()  # start a fresh figure instead of drawing onto whatever is current
plt.hist(predicted_curvatures, bins=100, alpha=0.5, label='Predicted')
plt.hist(curvatures, bins=100, alpha=0.5, label='Ground Truth')
# The labels above only appear once a legend is drawn.
plt.legend(loc='upper right');