In [None]:
import torch
from src.models.gnn import GCNRegressor, GATRegressor
import os
from tqdm import tqdm
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data
from generate_data import *

# Registry mapping the `arch` string used in the config cell to a model class.
architectures = dict(
    gcn=GCNRegressor,
    gat=GATRegressor,
)

In [None]:
# Load the model
hidden_channels = 256
features_max_k = 2500

# data parameters
N = 5000          # passed as N to every create_*_dataset call below
k = 10            # passed to the torus / Poincare / Euclidean generators (presumably kNN neighbours — confirm)
subgraph_k = 1    # hops in subgraph

model_name = "gcn_2_22_no_edge"
# NOTE(review): model_name says "gcn" but arch is 'gat'. If the checkpoint really
# is a GCN, load_state_dict (strict by default) would likely fail on mismatched
# keys — confirm which architecture is intended.
arch = 'gat'

model_path = f'outputs/{model_name}/nn/best_val.pt'
save_path = f'outputs/{model_name}/plots'
os.makedirs(save_path, exist_ok=True)

model = architectures[arch](features_max_k, hidden_channels)
# The model is never moved to a GPU and all data below is built on the CPU, so
# map the checkpoint's tensors to CPU even if it was saved from a CUDA run.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Inference only: switch any dropout / batch-norm layers to evaluation behaviour.
model.eval()
print(model)

In [None]:
def predict_curvature(model, data, subgraph_k):
    """Run ``model`` on the ``subgraph_k``-hop subgraph around every node of ``data``.

    The subgraphs are built by ``create_subgraphs``. Inference runs with
    gradient tracking disabled and with the model temporarily in eval mode.
    The model's previous train/eval mode is restored before returning.

    Args:
        model: a torch module that accepts a single subgraph ``Data`` object.
        data: the full graph. ``create_subgraphs`` reads its ``x``,
            ``edge_index``, ``edge_attr`` and ``y``.
        subgraph_k: number of hops used to build each node's subgraph.

    Returns:
        Tuple ``(features, curvatures, predicted_curvatures)`` of lists with one
        entry per node: the subgraph's node features (``subgraph.x``), its
        target (``subgraph.y``) and the model output.
    """
    subgraph_list = create_subgraphs(data, subgraph_k=subgraph_k)
    features = []
    curvatures = []
    predicted_curvatures = []

    # Inference must not run with training-mode dropout / batch-norm. Restore the
    # caller's mode afterwards so the helper leaves no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for subgraph in subgraph_list:
                pred = model(subgraph)
                features.append(subgraph.x)
                curvatures.append(subgraph.y)
                predicted_curvatures.append(pred)
    finally:
        model.train(was_training)
    return features, curvatures, predicted_curvatures
    
    
def create_subgraphs(full_graph, subgraph_k=1):
    """Build one ``subgraph_k``-hop subgraph ``Data`` object per node of ``full_graph``.

    Each subgraph holds:
      - the features of the nodes it contains;
      - the relabelled edge_index returned by ``k_hop_subgraph``;
      - the ``edge_attr`` rows of the edges kept (``None`` if the graph has no
        edge attributes);
      - the target ``full_graph.y[i]`` of the centre node ``i``.

    Args:
        full_graph: graph with ``x``, ``edge_index``, ``y`` and optionally ``edge_attr``.
        subgraph_k: number of hops to include around each node.

    Returns:
        List of ``Data`` subgraphs, in node order.
    """
    # NOTE(review): the transpose implies full_graph.edge_index is stored as
    # (num_edges, 2), while k_hop_subgraph expects (2, num_edges) — verify
    # against generate_data. It is loop-invariant, so it is computed once here.
    edge_index = full_graph.edge_index.transpose(0, 1).long()
    edge_attr = full_graph.edge_attr

    subgraph_list = []
    for i in tqdm(range(full_graph.x.shape[0]), desc='Processing Graph'):
        subgraph_nodes, subgraph_edges, _mapping, edge_mask = k_hop_subgraph(
            i, subgraph_k, edge_index, relabel_nodes=True, directed=False)
        # edge_mask marks which original edges survive in the subgraph. The
        # relabelled subgraph_edges hold *node* ids, not edge positions, so they
        # must not be used to index edge_attr.
        # Assumes edge_attr rows are aligned with edges — TODO confirm.
        sub_edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
        # NOTE(review): if full_graph.y is 1-D, y[i] is 0-dim and the later
        # torch.cat over targets would fail — presumably y is (N, 1); confirm.
        subgraph = Data(x=full_graph.x[subgraph_nodes], edge_index=subgraph_edges,
                        edge_attr=sub_edge_attr, y=full_graph.y[i])
        subgraph_list.append(subgraph)
    return subgraph_list

In [None]:
# Plot ground-truth vs. predicted curvature for tori.
# Each entry of rads is an (inner_radius, outer_radius) pair.
rads = [(1, 2)]
for inner_radius, outer_radius in rads:
    torus_data, X = create_torus_dataset(inner_radius, outer_radius, N, k, features_max_k=features_max_k)

    # predict curvature
    features, curvatures, predicted_curvatures = predict_curvature(model, torus_data, subgraph_k)
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    # One colour scale shared by both panels, so the colours are directly comparable.
    vmin = float(min(curvatures.min(), predicted_curvatures.min()))
    vmax = float(max(curvatures.max(), predicted_curvatures.max()))

    # Two adjacent 3D scatter plots: ground truth (left) and prediction (right).
    fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
    for ax, title, values in zip(axs, ('Ground Truth', 'Predicted'), (curvatures, predicted_curvatures)):
        ax.set_title(title)
        ax.set_zlim(-3, 3)
        # marker must be a keyword: Axes3D.scatter's 4th positional slot is `zdir`.
        scatter = ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=values.tolist(),
                             vmin=vmin, vmax=vmax)
    # Both panels share vmin/vmax, so a single colorbar describes both.
    fig.colorbar(scatter, ax=axs.ravel().tolist())
    fig.savefig(f'{save_path}/{model_name}_model_torus_inrad_{inner_radius}_outrad_{outer_radius}_pred.png')
    plt.show()

In [None]:
# Plot the results for the sphere: ground truth vs. prediction in 3D, plus a
# histogram of the predictions, for each radius R.
Rs = [1, 2]
for R in Rs:
    sphere_data, X = create_sphere_dataset(R, N, features_max_k=features_max_k)
    # predict curvature
    features, curvatures, predicted_curvatures = predict_curvature(model, sphere_data, subgraph_k)
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    # Two adjacent 3D scatter plots. Fixed colour limits keep every radius on one scale.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
    for ax, title, values in zip(axs, ('Ground Truth', 'Predicted'), (curvatures, predicted_curvatures)):
        ax.set_title(title)
        ax.set_zlim(-3, 3)
        # marker must be a keyword: Axes3D.scatter's 4th positional slot is `zdir`.
        scatter = ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=values.tolist(), vmin=-3, vmax=3)
    # Both panels use vmin=-3 / vmax=3, so one colorbar covers both.
    fig.colorbar(scatter, ax=axs.ravel().tolist())
    fig.savefig(f'{save_path}/{model_name}_model_sphere_R_{R}_pred.png')
    plt.show()

    # Histogram of predicted curvatures for this sphere. The dashed line marks
    # the mean ground-truth curvature for reference.
    fig_hist, ax_hist = plt.subplots()
    ax_hist.hist(predicted_curvatures, bins=100)
    ax_hist.axvline(float(curvatures.mean()), color='k', linestyle='--', label='mean ground truth')
    ax_hist.set(title=f'Predicted Curvatures for Sphere, R = {R}', xlabel='curvature', ylabel='count')
    ax_hist.legend()
    fig_hist.savefig(f'{save_path}/{model_name}_model_sphere_R_{R}_hist.png')
    plt.show()

In [None]:
# generate data for poincare disk
# Uses `k` from the configuration cell rather than re-assigning it here.
Rh = 1
Ks = [-2, -1]
for K in Ks:
    poincare_data, X = create_poincare_dataset(N, K, k, Rh, features_max_k=features_max_k)
    # predict curvature
    features, curvatures, predicted_curvatures = predict_curvature(model, poincare_data, subgraph_k=subgraph_k)
    curvatures = torch.cat(curvatures, dim=0).numpy()
    predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

    # plot results
    fig, ax = plt.subplots()
    # NOTE: range=(-10, 10) silently drops any prediction outside that interval.
    ax.hist(predicted_curvatures, bins=100, range=(-10, 10))
    ax.axvline(float(curvatures.mean()), color='k', linestyle='--', label='mean ground truth')
    ax.set(title=f'Predicted Curvatures for Poincare Disc, K = {K}', xlabel='curvature', ylabel='count')
    ax.legend()
    fig.savefig(f'{save_path}/{model_name}_model_poincare_K_{K}_hist.png')
    plt.show()

In [None]:
# generate euclidean data
# `k` comes from the configuration cell. d and rad are passed positionally to
# create_euclidean_dataset (presumably dimension and radius — confirm).
d = 2
rad = 1
euclidean_data, X = create_euclidean_dataset(N, d, rad, k, features_max_k=features_max_k)
# predict curvature
features, curvatures, predicted_curvatures = predict_curvature(model, euclidean_data, subgraph_k)
curvatures = torch.cat(curvatures, dim=0).numpy()
predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

# plot results on a fresh figure. A bare plt.hist() would draw onto whatever
# figure is still current (e.g. the last Poincare histogram) and save the overlay.
fig, ax = plt.subplots()
# NOTE: range=(-10, 10) silently drops any prediction outside that interval.
ax.hist(predicted_curvatures, bins=100, range=(-10, 10))
ax.axvline(float(curvatures.mean()), color='k', linestyle='--', label='mean ground truth')
ax.set(title='Predicted Curvatures for Euclidean Space', xlabel='curvature', ylabel='count')
ax.legend()
fig.savefig(f'{save_path}/{model_name}_model_euclidean_hist.png')
plt.show()

In [None]:
# Plot the results for the hyperboloid: ground truth vs. prediction in 3D.

# generate new data
hyperboloid_data, X = create_hyperbolic_dataset(N, features_max_k=features_max_k)
# predict curvature
features, curvatures, predicted_curvatures = predict_curvature(model, hyperboloid_data, subgraph_k)
curvatures = torch.cat(curvatures, dim=0).numpy()
predicted_curvatures = torch.cat(predicted_curvatures, dim=0).squeeze(-1).numpy()

# One colour scale shared by both panels, so the colours are directly comparable.
vmin = float(min(curvatures.min(), predicted_curvatures.min()))
vmax = float(max(curvatures.max(), predicted_curvatures.max()))

# Two adjacent 3D scatter plots: ground truth (left) and prediction (right).
fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
for ax, title, values in zip(axs, ('Ground Truth', 'Predicted'), (curvatures, predicted_curvatures)):
    ax.set_title(title)
    ax.set_zlim(-3, 3)
    # marker must be a keyword: Axes3D.scatter's 4th positional slot is `zdir`.
    scatter = ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=values.tolist(), vmin=vmin, vmax=vmax)
# Both panels share vmin/vmax, so a single colorbar describes both.
fig.colorbar(scatter, ax=axs.ravel().tolist())
fig.savefig(f'{save_path}/{model_name}_model_hyperboloid_pred.png')
plt.show()