In [None]:
import numpy as np
import torch 
from framework.SMS import import_dataset_fromSMS
import networkx as nx

# Geodesic-distance approximation selected for the distance computations in later cells.
distance_mode = "linear_interpolation" # "linear_interpolation" or "dijkstra"

# Dataset location on disk: data/SMS/<datasetName><datasetSuffix>/
datasetName = "swiss1"
datasetSuffix = "-500"
datasetPath = "data/SMS/" + datasetName  + datasetSuffix + "/"
datasetSMS = import_dataset_fromSMS(datasetPath)

# Number of trailing nodes that receive label 1 in `torch_points_labels` below.
CORRUPTED_NODES = 70

# Work with the first simulation key returned by the loader.
sim = list(datasetSMS.keys())[0]
mat = datasetSMS[sim]['adjacency_matrix']
num_nodes = mat.shape[0]
# NOTE(review): p_vectors_array and dimP are reassigned further down this cell from CSV files.
p_vectors_array = datasetSMS[sim]['p_array']
dimP = p_vectors_array.shape[1]

# Node labels: 0 for the first (num_nodes - CORRUPTED_NODES) nodes, 1 for the last CORRUPTED_NODES.
torch_points_labels=torch.tensor([0]*(num_nodes-CORRUPTED_NODES) + [1]*CORRUPTED_NODES)


#plot_graph_from_adjacency_matrix(mat, node_color_scalars=np.sum(p_vectors_array, axis=1), cmap='plasma')

def read_matrix_from_csv_loadtxt(filepath, delimiter=','):
    """
    Load a numeric matrix from a delimited text file via np.loadtxt().

    Failures are reported on stdout rather than raised, so callers must be
    prepared to receive None.

    Args:
        filepath (str): Location of the CSV file to read.
        delimiter (str): Separator between values on a line (comma by default).

    Returns:
        numpy.ndarray or None: The parsed matrix, or None when the file is
        missing or could not be loaded.
    """
    result = None
    try:
        loaded = np.loadtxt(filepath, delimiter=delimiter)
        print(f"Successfully loaded matrix from {filepath} using np.loadtxt().")
        result = loaded
    except FileNotFoundError:
        print(f"Error: The file '{filepath}' was not found.")
    except Exception as e:
        # Any other failure (e.g. unparsable content) is reported the same way.
        print(f"An error occurred while loading the file: {e}")
    return result


# Per-simulation file prefix: data/SMS/<name><suffix>/sim_<sim>/<name>
# (built from datasetPath so the dataset root is defined in one place only).
path = datasetPath + "sim_" + str(sim) + "/" + datasetName
p_vectors_array = read_matrix_from_csv_loadtxt(path + "_p_matrix.csv")
true_p_vectors_array = read_matrix_from_csv_loadtxt(path + "_true_p_matrix.csv")

# read_matrix_from_csv_loadtxt signals failure by returning None; stop here with
# an explicit message instead of an AttributeError on `.shape` further down.
if p_vectors_array is None or true_p_vectors_array is None:
    raise RuntimeError(f"Could not load the p matrices for prefix '{path}'")

dimP = p_vectors_array.shape[1]

# Every simulation entry receives the matrix loaded for `sim`.
# NOTE(review): with more than one simulation this overwrites the other
# simulations' p_array with the first one's data - confirm this is intended.
for sim_entry in datasetSMS.values():
    sim_entry["p_array"] = p_vectors_array


In [None]:
from framework.trainFct import *
from torch_geometric.data import Data
from framework.visuals import *
from scipy.sparse.csgraph import shortest_path


# Model / loader hyper-parameters.
latent_dim = 2
input_dim = dimP
batch_size = 16

encoder_hidden_dims=[128, 64, 32]
adj_decoder_hidden_dims=[64, 64, 32]
node_decoder_hidden_dims=[64, 64, 32]
gcn_layers=3
fc_layers=2

dataset = []
for sim_data in datasetSMS.values():
    # All-pairs shortest-path distances computed from the adjacency matrix.
    sim_data["distance_matrix"] = shortest_path(sim_data["adjacency_matrix"], directed=False, unweighted=False)
    # Convert once and reuse both outputs (index 0: edge_index, index 1: edge labels).
    edge_conversion = adj_matrix_to_edge_index(sim_data["distance_matrix"])
    # Create PyG data object
    data = Data(x=torch.tensor(sim_data["p_array"], dtype=torch.float),
                edge_index=edge_conversion[0],
                edge_labels=edge_conversion[1],
                adjacency_matrix=torch.tensor(sim_data["distance_matrix"]))
    dataset.append(data)

# Select a single graph to train on
single_graph = dataset[0]

# Wrap in list for compatibility with DataLoader-like expectations
single_graph_list = [single_graph]

# Use the simulation selected by `sim` explicitly. The previous code read the
# loop variable after the loop had finished, i.e. the *last* simulation, which
# is inconsistent with `mat` and `single_graph` when several simulations exist.
dist_mat = shortest_path(datasetSMS[sim]["distance_matrix"], directed=False, unweighted=False)

G = nx.from_numpy_array(dist_mat)
G.remove_edges_from(nx.selfloop_edges(G)) # Remove self-loops


In [None]:
# Training schedule and optimiser settings shared by the phase-1 / phase-2 cells.
phase1_epochs = 1500
phase2_epochs = 200
lr_phase1 = 0.005
latent_dim = 2
device = "cuda" if torch.cuda.is_available() else "cpu"

# Encoder from node features (input_dim) to a latent_dim-dimensional code.
# NOTE(review): MLPEncoder and `nn` are not imported explicitly in this notebook;
# presumably they arrive through the `from framework... import *` lines - confirm.
encoder = MLPEncoder(
    input_dim=input_dim,
    hidden_dims=[16, 16],
    latent_dim=latent_dim,
    mlp_layers=2,
    dropout=0.2,
    activation=nn.ELU()
)


# Decoder whose output dimension equals the node-feature dimension.
node_decoder = NodeAttributeVariationalDecoder(
    latent_dim=latent_dim,
    output_dim=input_dim,
    #hidden_dims=[5000, 128],
    #hidden_dims=[2000, 128],
    hidden_dims=[16],
    dropout=0,
    activation=nn.ELU(),
)

# Create KL annealing scheduler
# The number of annealing steps spans the whole of phase 1
# (epochs x number of graphs in single_graph_list).
kl_scheduler = KLAnnealingScheduler(
    anneal_start=0.0,
    #anneal_end=0.001,
    #anneal_end=0.8,
    anneal_end=2,
    anneal_steps=phase1_epochs * len(single_graph_list),
    anneal_type='sigmoid',
)

# Create initial model with only node decoder
# The latent manifold is switched on later, after phase 1.
model_phase1 = GraphVAE(
    encoder=encoder,
    decoders=[node_decoder],
    kl_scheduler=kl_scheduler,
    compute_latent_manifold=False,
)


In [None]:
import os

# Single source of truth for the phase-1 checkpoint location.
PHASE1_CHECKPOINT = "model_phase1_swissSMS.pth"

if os.path.exists(PHASE1_CHECKPOINT):
    print("Loading pretrained model")
    # map_location="cpu" lets a checkpoint written on a GPU machine be restored
    # on a CPU-only one; load_state_dict copies the values into the model's own
    # parameters, wherever they live.
    # NOTE: history_phase1 is NOT defined on this path (no training happened).
    model_phase1.load_state_dict(torch.load(PHASE1_CHECKPOINT, map_location="cpu"))
else:
    print("=== Starting Phase 1: Training encoder with node feature reconstruction ===")

    # Phase 1 training
    history_phase1 = train_phase1(
        model=model_phase1,
        data_loader=single_graph_list,
        num_epochs=phase1_epochs,
        lr=lr_phase1,
        weight_decay=1e-5,
        verbose=True,
        device=device,
        loss_coefficient=1
    )

    print("\n=== Phase 1 Complete ===")

    torch.save(model_phase1.state_dict(), PHASE1_CHECKPOINT)
    print("\n=== Phase 1 Saved ===")

    # Diagnostics are only available right after training.
    visualize_training(history_phase1)
    visualize_node_features_reconstruction(model_phase1, single_graph, sample_features=dimP)
    visualize_latent_space(model_phase1, [single_graph])


In [None]:
import copy
from framework.torchVersions.distanceApproximations import DistanceApproximations
from framework.boundedManifold import BoundedManifold

# Phase-1 analysis is done on the CPU.
model_phase1 = model_phase1.to('cpu')
model_phase1.encoder.eval()
# Phase 2 starts from a snapshot of the phase-1 model (taken before model_phase1.eval()).
model_phase2 = copy.deepcopy(model_phase1)

model_phase1.eval()

with torch.no_grad():
    # The model was just moved to the CPU, so its inputs must live there too;
    # sending them to `device` raised a device-mismatch error whenever CUDA was available.
    x = single_graph.x.to('cpu')
    edge_index = single_graph.edge_index.to('cpu')
    latent_mu = model_phase1.encode(x, edge_index=edge_index)

# First element of the encoder output is used as the latent coordinates.
latent_points = latent_mu[0]

# Build the latent manifold on a box around the encoded points (10% relative margin).
model_phase1.set_compute_latent_manifold(True)
model_phase1.construct_latent_manifold(bounds=BoundedManifold.hypercube_bounds(latent_points, margin=0.1, relative=True), force=True)
model_phase1.get_latent_manifold().compute_full_grid_metric_tensor()
model_phase1.get_latent_manifold().visualize_manifold_curvature(data_points=latent_points, labels=torch_points_labels)

# Pairwise distances between latent points under the phase-1 metric.
with torch.no_grad():
    if distance_mode == "linear_interpolation":
        dists_phase1 = model_phase1.get_latent_manifold().create_riemannian_distance_matrix(latent_points,
                                                                                            DistanceApproximations.linear_interpolation_distance, num_points=20)
    elif distance_mode == "dijkstra":
        dists_phase1 = model_phase1.get_latent_manifold().get_grid_as_graph().compute_shortest_paths(
                            latent_points,
                            weight_type="geodesic",  # Uses your metric tensors
                            max_grid_neighbors=8,     # Connect to up to 8 nearest grid nodes
                            num_threads=6
                        )
    else:
        # Fail here rather than with a NameError on dists_phase1 several cells later.
        raise ValueError(f"Unknown distance_mode: {distance_mode!r}")


In [None]:
from framework.synthetic_manifold import get_metric

# Ground-truth metric for the synthetic "soft_swiss" manifold.
true_metric = get_metric("soft_swiss", seed=42)

def get_theoretical_metric_tensor(point: torch.tensor):
    """Evaluate `true_metric` at the first two coordinates of `point` and wrap the result in a tensor.

    The third argument passed to `true_metric` is the notebook-level `dimP`.
    """
    return torch.tensor(true_metric(point[0], point[1], dimP))
# Manifold driven by the ground-truth metric, bounded by a box around the true
# p vectors (10% relative margin), then plotted with the node labels.
TheoreticalManifold = BoundedManifold(get_theoretical_metric_tensor, bounds=BoundedManifold.hypercube_bounds(torch.tensor(true_p_vectors_array), margin=0.1, relative=True))
TheoreticalManifold.visualize_manifold_curvature(data_points=torch.tensor(true_p_vectors_array), labels=torch_points_labels)

In [None]:
from framework.synthetic_manifold import ManifoldConfig, generate_manifold_data, approx_jacobian_rank, IMMERSIONS, load_dataset, metric_pullback_alignment_error, geodesic_benchmark

# NOTE(review): this rebinds `true_metric` (previously obtained from get_metric);
# of the six returned values only `true_metric` is used below in this cell.
X, Z_true, meta, Z_vae, immersion, true_metric = load_dataset(datasetPath + "soft_swiss")

def VAE_metric_phase1(point : np.ndarray):
    """Return the phase-1 latent manifold's metric tensor at `point` as a NumPy array."""
    return model_phase1.get_latent_manifold().metric_tensor(torch.tensor(point)).detach().numpy()

# Compare the learned (VAE) metric with the ground-truth metric at the data points.
summary = metric_pullback_alignment_error(
    seeds_uv=true_p_vectors_array,
    codes_z=latent_points.detach().numpy(),
    true_metric_func=true_metric,
    vae_metric_func=VAE_metric_phase1,
    D=20,
    k=12,
)

# Compare geodesic distances under both metrics.
geo = geodesic_benchmark(
    seeds_uv=true_p_vectors_array,
    codes_z=latent_points.detach().numpy(),
    true_metric_func=true_metric,
    vae_metric_func=VAE_metric_phase1,
    D=20,
    k_graph=12,
    subsample=500,  # speed
)
print("Geodesic distance correlation:", geo["pairwise_corr"])
print("Mean pullback metric rel. error:", summary["mean_rel_error"])
print("Volume element corr (sqrt(det G)):", summary["volume_corr"])


In [None]:
true_p_vectors_tensor = torch.tensor(true_p_vectors_array)

# Pairwise distances between the true p vectors under the ground-truth metric,
# using the same approximation as for the learned manifold.
if distance_mode == "linear_interpolation":
    theoretical_distances = TheoreticalManifold.create_riemannian_distance_matrix(true_p_vectors_tensor,
                                                                        DistanceApproximations.linear_interpolation_distance, num_points=20)
elif distance_mode == "dijkstra":
    theoretical_distances = TheoreticalManifold.get_grid_as_graph().compute_shortest_paths(
                        true_p_vectors_tensor,
                        weight_type="geodesic",  # Uses your metric tensors
                        max_grid_neighbors=8,     # Connect to up to 8 nearest grid nodes
                        num_threads=6
                    )
else:
    # Fail here rather than with a NameError on theoretical_distances in the next cell.
    raise ValueError(f"Unknown distance_mode: {distance_mode!r}")

In [None]:
# Side-by-side view of both distance matrices, each scaled by its own maximum.
# NOTE(review): the titles say "Linear estimation" even when distance_mode == "dijkstra".
plot_correlogram(theoretical_distances/torch.max(theoretical_distances), dists_phase1/torch.max(dists_phase1),
                titles=["Theoretical Distances (Linear estimation)", "Phase 1 Distances (Linear estimation)"])

In [None]:
def compute_curvature_change(model1: GraphVAE, model2: GraphVAE):
    """
    Relative change in Gaussian curvature between two models' latent manifolds.

    Both manifolds are sampled on the same 30 x 30 grid spanning the bounds of
    model1's latent manifold; each grid point is clamped to those bounds first.
    A point where a curvature evaluation raises ValueError or RuntimeError is
    reported on stdout and stored as NaN.

    Returns:
        tuple: (curvature_diff, Z1_np, Z2_np), where curvature_diff is the NumPy
        array (curv1 - curv2) / (curv1 + 1e-8) and Z1_np / Z2_np are the
        meshgrid coordinate arrays.
    """
    resolution = 30
    bounds_np = model1.get_latent_manifold().get_bounds().cpu().numpy()
    axis_z1 = np.linspace(bounds_np[0, 0], bounds_np[0, 1], resolution)
    axis_z2 = np.linspace(bounds_np[1, 0], bounds_np[1, 1], resolution)

    Z1_np, Z2_np = np.meshgrid(axis_z1, axis_z2)
    Z1 = torch.from_numpy(Z1_np)
    Z2 = torch.from_numpy(Z2_np)

    curvature_phase1 = torch.zeros((resolution, resolution))
    curvature_phase2 = torch.zeros((resolution, resolution))

    def _store_curvature(model, target, row, col, z, clamped_z):
        # Evaluate one model's curvature at the clamped point; NaN on failure.
        try:
            target[row, col] = model.get_latent_manifold().compute_gaussian_curvature(
                model.get_latent_manifold().metric_tensor(clamped_z)
            )
        except (ValueError, RuntimeError) as e:
            print(f"Error computing curvature at point {z}: {e}. Setting to NaN.")
            target[row, col] = torch.nan

    # Row-major sweep over the grid.
    for i, j in np.ndindex(resolution, resolution):
        z = torch.stack([Z1[i, j], Z2[i, j]])
        clamped_z = model1.get_latent_manifold()._clamp_point_to_bounds(z)
        _store_curvature(model1, curvature_phase1, i, j, z, clamped_z)
        _store_curvature(model2, curvature_phase2, i, j, z, clamped_z)

    relative_change = (curvature_phase1 - curvature_phase2)/(curvature_phase1 + 1e-8)
    curvature_diff = relative_change.detach().numpy()

    return curvature_diff, Z1_np, Z2_np

def intermediary_diagnostics(model, data_loader, epoch):
    """
    Plot heat-kernel mismatch and curvature-change diagnostics for `model`.

    Encodes the first graph of `data_loader`, rebuilds the manifold heat kernels
    from the "adj_decoder" decoder's distance matrix, and plots their squared
    element-wise difference to the decoder's stored graph kernels (`K_graph`),
    five heat times per figure, followed by the accumulated totals and the
    curvature change relative to the notebook-level `model_phase1`.

    `epoch` is accepted but not used in the body. The function also reads the
    notebook-level `device` and `torch_points_labels`.
    """
    with torch.no_grad():
        model.encoder.eval()
        x = data_loader[0].x.to(device)
        edge_index = data_loader[0].edge_index.to(device)
        latent_mu = model.encode(x, edge_index=edge_index)

    # First element of the encoder output is used as the latent coordinates.
    latent_points = latent_mu[0]
    curvature_decoder = model.get_decoder("adj_decoder")
    distances = curvature_decoder.compute_distance_matrix(latent_points)
    L_manifold = compute_manifold_laplacian(distances=distances,
                                            sigma=curvature_decoder.sigma_ema,
                                            laplacian_regularization=curvature_decoder.laplacian_regularization)
    K_manifold = compute_heat_kernel_from_laplacian(L_manifold, curvature_decoder.heat_times)

    #divergence = compute_heat_kernel_divergence(K_manifold, curvature_decoder.K_graph)
    # Number of heat times shown per figure.
    batch_size = 5
    # Running sums over all heat times: raw and per-time max-normalised.
    diffs_mat = None
    diffs_mat_norm = None
    for i in range(0, len(K_manifold), batch_size):
        batch_K = K_manifold[i:i+batch_size]
        batch_G = curvature_decoder.K_graph[i:i+batch_size]
        batch_times = curvature_decoder.heat_times[i:i+batch_size]

        # compute differences
        diffs = []
        names = []
        for M, G, t in zip(batch_K, batch_G, batch_times):
            trace_manifold = torch.trace(M)
            trace_graph    = torch.trace(G)

            # Rescale the manifold kernel so that its trace matches the graph kernel's.
            if trace_manifold > 1e-8:
                K_manifold_norm = M * (trace_graph / trace_manifold)
            else:
                K_manifold_norm = M

            # Element-wise squared difference (not reduced to a scalar norm)
            diff = (K_manifold_norm - G)**2
            if diffs_mat is None:
                diffs_mat = diff
                diffs_mat_norm = diff / torch.max(diff)
            else:
                diffs_mat += diff
                diffs_mat_norm += diff / torch.max(diff)
            diffs.append(diff / torch.max(diff))
            names.append(f"Time={t:.2f}")

        plot_correlogram(*diffs, titles=names, cmap="gist_yarg")
    plot_correlogram(diffs_mat/ torch.max(diffs_mat), diffs_mat_norm/ torch.max(diffs_mat_norm), cmap="gist_yarg", 
                        titles=["Total loss", "Total loss (normalized at each step)"])
    

    # Curvature change is always measured against the notebook-level model_phase1.
    curvature_diff, Z1_np, Z2_np = compute_curvature_change(model_phase1, model)
    model.get_latent_manifold()._plot_manifold_grid(curvature_diff, Z1_np, Z2_np, latent_points=latent_points, labels=torch_points_labels, name="Curvature variation btwn Phases 1 and 2")
    #model.get_latent_manifold().visualize_manifold_curvature(data_points=latent_points)
    #model.get_latent_manifold().visualize_manifold_curvature(data_points=latent_points)

In [None]:
from framework.KLAnnealingScheduler import NoKLScheduler

lr_phase2 = 0.0005

print("=== Starting Phase 2: Freezing encoder and adding adjacency decoder ===")

# Give the phase-2 model its own latent manifold over the same bounds as phase 1
# (box around the phase-1 latent points), then freeze the encoder.
model_phase2.set_compute_latent_manifold(True)
model_phase2.construct_latent_manifold(bounds=BoundedManifold.hypercube_bounds(latent_points, margin=0.1, relative=True), force=True)
model_phase2.set_encoder_freeze(True)


# Decoder registered under the name "adj_decoder"; it uses the notebook-wide distance_mode.
distance_decoder = ManifoldHeatKernelDecoder(
    distance_mode=distance_mode,
    latent_dim=latent_dim,
    num_eigenvalues=500,
    num_integration_points=20,
    name="adj_decoder",
    ema_lag_factor=0.1,
    num_heat_time=50,
)

# Add to your GraphVAE model
model_phase2.add_decoder(distance_decoder)

# Hand the owning model to the new decoder
#model_phase2.get_decoder("adj_decoder").giveManifoldInstance(model_phase2.get_latent_manifold())
model_phase2.get_decoder("adj_decoder").giveVAEInstance(model_phase2)

# Reset KL scheduler for phase 2
model_phase2.kl_scheduler = NoKLScheduler()

# Phase 2 training
# decoder_weights: only "adj_decoder" carries weight; "node_attr_decoder" is weighted 0.
history_phase2 = train_phase2(
    model=model_phase2,
    data_loader=single_graph_list,
    latent_points=latent_points,
    num_epochs=phase2_epochs,
    lr=lr_phase2,
    weight_decay=1e-5,
#    decoder_weights={"adj_decoder": -1, "node_attr_decoder":-1 },
    decoder_weights={"adj_decoder": 1, "node_attr_decoder":0 },
    verbose=True,
    device=device,
    intermediary_diagnostics=intermediary_diagnostics,
)

print("\n=== Phase 2 Complete ===")



In [None]:
# Node-feature reconstruction of the phase-2 model on the training graph.
visualize_node_features_reconstruction(model_phase2, single_graph, sample_features=dimP)
with torch.no_grad():
    # NOTE(review): inputs go to `device`; this presumes model_phase2 lives on
    # `device` after train_phase2 - confirm.
    x = single_graph.x.to(device)
    edge_index = single_graph.edge_index.to(device)
    latent_mu2 = model_phase2.encode(x, edge_index=edge_index)

# The same phase-2 latent points drawn over each model's manifold curvature.
model_phase1.get_latent_manifold().visualize_manifold_curvature(data_points=latent_mu2[0], labels=torch_points_labels)
model_phase2.get_latent_manifold().visualize_manifold_curvature(data_points=latent_mu2[0], labels=torch_points_labels)

In [None]:
visualize_training(history_phase2)

# history_phase1 only exists when phase 1 was actually trained in this kernel;
# when the phase-1 checkpoint was loaded instead, the name is never defined.
try:
    phase1_history = history_phase1
except NameError:
    phase1_history = None

if phase1_history is None:
    print("Phase 1 history unavailable (pretrained checkpoint loaded); skipping merged training curves.")
else:
    # Combine histories
    combined_history = {
        "phase1": phase1_history,
        "phase2": history_phase2
    }

    merged_history = {}
    for key in phase1_history.keys():
        if isinstance(phase1_history[key], list):
            # Plain curves: phase 2 continues where phase 1 stopped.
            merged_history[key] = phase1_history[key] + history_phase2[key]
        else:
            # Nested curves: only series that first appear in phase 2 are kept,
            # left-padded with NaN over the phase-1 epochs.
            # NOTE(review): series present in both phases are dropped here - confirm intended.
            merged_history[key] = dict()
            for key2 in history_phase2[key].keys():
                if key2 not in phase1_history[key].keys():
                    merged_history[key][key2] = [np.nan]*len(phase1_history["kl_loss"]) + history_phase2[key][key2]

    visualize_training(merged_history)


In [None]:
# Pairwise distances between the phase-1 latent points under the phase-2 metric.
with torch.no_grad():
    if distance_mode == "linear_interpolation":
        dists_phase2 = model_phase2.get_latent_manifold().create_riemannian_distance_matrix(latent_points,
                                                                                            DistanceApproximations.linear_interpolation_distance, num_points=20)
    elif distance_mode == "dijkstra":
        dists_phase2 = model_phase2.get_latent_manifold().get_grid_as_graph().compute_shortest_paths(
                            latent_points,
                            weight_type="geodesic",  # Uses your metric tensors
                            max_grid_neighbors=8,     # Connect to up to 8 nearest grid nodes
                            num_threads=6
                        )
    else:
        # Fail here rather than with a NameError on dists_phase2 in a later cell.
        raise ValueError(f"Unknown distance_mode: {distance_mode!r}")
        

In [None]:
# Relative curvature change from phase 1 to phase 2, drawn on the phase-2 manifold grid
# with the phase-1 latent points overlaid.
curvature_diff, Z1_np, Z2_np = compute_curvature_change(model_phase1, model_phase2)
model_phase2.get_latent_manifold()._plot_manifold_grid(curvature_diff, Z1_np, Z2_np, latent_points=latent_points, labels=torch_points_labels, name="Curvature variation btwn Phases 1 and 2")

In [None]:
# Convert both distance matrices to NumPy. `.cpu()` is a no-op for CPU tensors
# and prevents `.numpy()` from failing when a matrix was computed on the GPU.
if isinstance(dists_phase1, torch.Tensor):
    dists_phase1 = dists_phase1.detach().cpu().numpy()
if isinstance(dists_phase2, torch.Tensor):
    dists_phase2 = dists_phase2.detach().cpu().numpy()

local_dists_phase1 = dists_phase1

# Absolute change in pairwise distance between the two phases.
res = np.abs(local_dists_phase1 - dists_phase2)

# Relative change. Entries whose phase-1 distance is 0 become NaN/inf exactly
# as before; only the NumPy runtime warning is silenced.
with np.errstate(divide="ignore", invalid="ignore"):
    res3 = np.abs((local_dists_phase1 - dists_phase2)/local_dists_phase1)

# Each matrix scaled by its own largest positive entry.
max_dist_phase1 = np.max(np.where(local_dists_phase1 > 0, local_dists_phase1, 0))
max_dist_phase2 = np.max(np.where(dists_phase2 > 0, dists_phase2, 0))
dists_phase1_scaled = local_dists_phase1/max_dist_phase1
dists_phase2_scaled = dists_phase2/max_dist_phase2

# Absolute change between the scaled matrices.
res2 = np.abs(dists_phase1_scaled - dists_phase2_scaled)

_ = plot_correlogram(dists_phase1_scaled, dists_phase2_scaled, remove_diagonal=True, triangular=True)

_ = plot_correlogram((res/np.max(res))**2, (res2/np.max(res2))**2)
_ = plot_correlogram(res3, cmap="inferno", titles=["Relative variation of pairwise geodesic distances between phases 1 phase 2"] )

In [None]:
plt.figure(figsize=(10, 6)) # Set the size of the plot for better readability

plt.rcParams['font.family'] = 'sans-serif'
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Arial', 'DejaVu Sans']

# Column sums of `res`, with NaN treated as 0.
# `nan=` must be passed by keyword: the second positional parameter of
# np.nan_to_num is `copy`, so np.nan_to_num(res, 0.0) ran with copy=False and
# rewrote NaN/inf entries of `res` in place.
node_residuals = np.sum(np.nan_to_num(res, nan=0.0), axis=0)
plt.plot(node_residuals, marker='o', linestyle='')

# --- Customize the Plot ---
#plt.savefig("dists_total_var_phase1_2.png", dpi=500, bbox_inches='tight')

plt.xlabel('Nodes') # Label for the x-axis
plt.ylabel('Residuals') # Label for the y-axis
plt.title('Absolute variation at node level between phase 1 and 2 in distances') # Title of the plot
plt.grid(True, linestyle='--', alpha=0.7) # Add a grid for easier reading
plt.tight_layout() # Adjust plot to ensure everything fits without overlapping
plt.show()

In [None]:
# Keep the distance residuals only where the adjacency matrix has a positive entry.
correction_matrix = np.where(mat > 0, res, 0)
# Adjacency matrix next to the max-scaled residuals on its edges.
_ = plot_correlogram(mat, correction_matrix/np.max(correction_matrix))

In [None]:
# Correlation-coefficient matrix between the flattened residuals and the flattened graph distances.
np.corrcoef(res.flatten(), dist_mat.flatten())

In [None]:
from sklearn.feature_selection import mutual_info_regression
# Mutual information with the flattened residuals as the single feature and the
# flattened graph distances as the target.
# NOTE(review): no random_state is passed, so the value may vary between runs - confirm.
mutual_info_regression(res.flatten().reshape(-1, 1), dist_mat.flatten())

In [None]:
from scipy.stats import spearmanr
# Spearman rank correlation between the flattened residuals and the flattened graph distances.
spearmanr(res.flatten(), dist_mat.flatten())

In [None]:
# One point per node pair: residual on the x-axis, graph distance on the y-axis.
plt.figure(figsize=(10, 6))
plt.scatter(res.flatten(), dist_mat.flatten(), alpha=0.6, s=10) # alpha for transparency, s for dot size
plt.title('Scatter Plot of res vs. dist_mat', fontsize=16)
plt.xlabel('Values from res (flattened)', fontsize=14)
plt.ylabel('Values from dist_mat (flattened)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout() # Adjust layout to prevent labels from overlapping
plt.show()

In [None]:
# --- 2. Identify the indices for the trailing (labelled-1) rows/columns ---
# The highlighted block is tied to CORRUPTED_NODES so it cannot drift from the
# labels in torch_points_labels if that constant is changed.
num_rows = res.shape[0]
num_cols = res.shape[1]
rows_to_color = CORRUPTED_NODES
cols_to_color = CORRUPTED_NODES

# Red: pairs where BOTH nodes are among the trailing rows/columns.
red_mask = np.zeros(res.shape, dtype=bool)
red_mask[num_rows - rows_to_color:, num_cols - cols_to_color:] = True

# Grey ("brown" mask): pairs where AT LEAST ONE node is among the trailing rows/columns.
brown_mask = np.zeros(res.shape, dtype=bool)
brown_mask[num_rows - rows_to_color:, :] = True
brown_mask[:, num_cols - cols_to_color:] = True

# Flatten all data
res_flat = res.flatten()
dist_mat_flat = dist_mat.flatten()
red_mask_flat = red_mask.flatten()
brown_mask_flat = brown_mask.flatten()

# Separate data based on the masks
res_red = res_flat[red_mask_flat]
dist_mat_red = dist_mat_flat[red_mask_flat]
res_brown = res_flat[brown_mask_flat]
dist_mat_brown = dist_mat_flat[brown_mask_flat]

# Blue is everything outside the red block (it therefore still contains the
# grey-only pairs, which are simply painted over below).
res_blue = res_flat[~red_mask_flat]
dist_mat_blue = dist_mat_flat[~red_mask_flat]

# --- 3. Create the scatter plot with different colors ---
plt.figure(figsize=(10, 10))

# Draw order matters: blue first, then grey, then red on top.
plt.scatter(res_blue, dist_mat_blue, alpha=0.5, s=10, color='blue', label='Other Values')
plt.scatter(res_brown, dist_mat_brown, alpha=0.7, s=15, color='grey', label=f'Last {rows_to_color} rows | last {cols_to_color} cols')
plt.scatter(res_red, dist_mat_red, alpha=0.7, s=15, color='red', label=f'Last {rows_to_color} rows & {cols_to_color} cols')

plt.title('Scatter Plot of res vs. dist_mat with Highlighted Region', fontsize=16)
plt.xlabel('Values from res (flattened)', fontsize=14)
plt.ylabel('Values from dist_mat (flattened)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12)
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import scipy.sparse as sp

def keep_k_links(A, k=3, mode='or', selection='top', keep_self_loops=False, return_sparse=True):
    """
    Keep only k strongest or weakest links per node in a weighted symmetric adjacency matrix.

    The input is never modified: all work happens on a private float copy.

    Parameters
    ----------
    A : (n,n) array-like
        Weighted adjacency matrix. Should be symmetric (function does not enforce).
    k : int
        Number of links to keep per node (default 3).
    mode : {'or', 'and'}
        'or'  -> keep edge (i,j) if i kept j OR j kept i
        'and' -> keep edge only if i kept j AND j kept i
    selection : {'top', 'bottom'}
        'top'    -> keep largest weights (strongest links)
        'bottom' -> keep smallest weights (weakest links)
    keep_self_loops : bool
        If False, diagonal entries are excluded from the selection and are 0
        in the result.
    return_sparse : bool
        If True, returns a scipy.sparse.csr_matrix. Otherwise, returns a dense numpy array.

    Returns
    -------
    A_pruned : scipy.sparse.csr_matrix or np.ndarray
        Pruned symmetric adjacency matrix.

    Raises
    ------
    ValueError
        If `selection` or `mode` is not one of the documented values, or if
        `A` is not a square 2-D matrix.
    """
    # Validate options before doing any work.
    if selection not in ('top', 'bottom'):
        raise ValueError("selection must be 'top' or 'bottom'")
    if mode not in ('or', 'and'):
        raise ValueError("mode must be 'or' or 'and'")

    # Private copy: np.asarray would alias a float ndarray argument, and the
    # diagonal edits below would then overwrite the caller's data.
    A = np.array(A, dtype=float, copy=True)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("A must be square")
    n = A.shape[0]

    k_eff = min(k, n if keep_self_loops else n - 1)
    if k_eff <= 0:
        return sp.csr_matrix((n, n)) if return_sparse else np.zeros((n, n))

    # Push the diagonal to the losing end of the ranking so self-loops are
    # never among the k selected entries.
    if not keep_self_loops:
        np.fill_diagonal(A, np.inf if selection == 'bottom' else -np.inf)

    # argpartition puts the k_eff smallest keys first; negate for 'top'.
    ranking_keys = A if selection == 'bottom' else -A
    idx_part = np.argpartition(ranking_keys, k_eff - 1, axis=1)[:, :k_eff]

    # Boolean mask of the entries each row selected.
    mask = np.zeros_like(A, dtype=bool)
    mask[np.arange(n)[:, None], idx_part] = True

    # Self-loops carry no weight in the output.
    if not keep_self_loops:
        np.fill_diagonal(A, 0.0)

    # Combine per-row choices symmetrically.
    if mode == 'or':
        sym_mask = np.logical_or(mask, mask.T)
    else:
        sym_mask = np.logical_and(mask, mask.T)

    pruned = np.zeros_like(A)
    pruned[sym_mask] = A[sym_mask]

    # Element-wise maximum with the transpose makes the result exactly symmetric
    # even if A itself was only approximately so.
    pruned = np.maximum(pruned, pruned.T)

    return sp.csr_matrix(pruned) if return_sparse else pruned


# Nearest-neighbour pruning: selection='bottom' keeps the k smallest entries per row,
# mode='or' keeps a pair if either endpoint selected it.
dist1_pruned = keep_k_links(dists_phase1, selection='bottom', k=5, mode='or', return_sparse=False)
dist2_pruned = keep_k_links(dists_phase2, selection='bottom', k=5, mode='or', return_sparse=False)
flop_k_dist = keep_k_links(dist_mat, selection='bottom', k=10, mode='or', return_sparse=False)
# 1 where a positive entry was retained in flop_k_dist, 0 elsewhere.
flop_k_dist_mask = np.where(flop_k_dist > 0, 1, 0)

In [None]:
# Display the retained-entry mask (full array repr).
flop_k_dist_mask

In [None]:
# Heat-map view of the retained-entry mask.
plot_correlogram(flop_k_dist_mask)

In [None]:
# NOTE(review): same display as two cells above - likely a leftover that can be removed.
flop_k_dist_mask

In [None]:
# Distance residuals restricted to the pairs retained in flop_k_dist.
flop_res = np.where(flop_k_dist > 0, res, 0)
plot_correlogram(flop_res)

In [None]:
plt.figure(figsize=(10, 6)) # Set the size of the plot for better readability

plt.rcParams['font.family'] = 'sans-serif'
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Arial', 'DejaVu Sans']

# Column sums of `flop_res`, with NaN treated as 0.
# `nan=` must be passed by keyword: the second positional parameter of
# np.nan_to_num is `copy`, so np.nan_to_num(flop_res, 0.0) ran with copy=False
# and rewrote NaN/inf entries of `flop_res` in place.
node_residuals_knn = np.sum(np.nan_to_num(flop_res, nan=0.0), axis=0)
plt.plot(node_residuals_knn, marker='o', linestyle='')

# --- Customize the Plot ---
#plt.savefig("dists_total_var_phase1_2.png", dpi=500, bbox_inches='tight')

plt.xlabel('Nodes') # Label for the x-axis
plt.ylabel('Residuals') # Label for the y-axis
# NOTE(review): the title is shared with the unrestricted-residual plot although
# this figure only sums residuals over the pairs retained in flop_k_dist.
plt.title('Absolute variation at node level between phase 1 and 2 in distances') # Title of the plot
plt.grid(True, linestyle='--', alpha=0.7) # Add a grid for easier reading
plt.tight_layout() # Adjust plot to ensure everything fits without overlapping
plt.show()

In [None]:
# res4: 1 where exactly one of the two pruned matrices has a positive entry, 0 elsewhere.
res4 = (np.where(dist1_pruned > 0, 1, 0) - np.where(dist2_pruned > 0, 1, 0))**2
# res5: squared difference between the two pruned matrices.
res5 = (dist1_pruned - dist2_pruned)**2

plot_correlogram(res4)

In [None]:
plt.figure(figsize=(10, 6)) # Set the size of the plot for better readability

plt.rcParams['font.family'] = 'sans-serif'
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Arial', 'DejaVu Sans']

# Column sums of `res5`, with NaN treated as 0.
# `nan=` must be passed by keyword: the second positional parameter of
# np.nan_to_num is `copy`, so np.nan_to_num(res5, 0.0) ran with copy=False and
# rewrote NaN/inf entries of `res5` in place.
node_residuals_pruned = np.sum(np.nan_to_num(res5, nan=0.0), axis=0)
plt.plot(node_residuals_pruned, marker='o', linestyle='')

# --- Customize the Plot ---
#plt.savefig("dists_total_var_phase1_2.png", dpi=500, bbox_inches='tight')

plt.xlabel('Nodes') # Label for the x-axis
plt.ylabel('Residuals') # Label for the y-axis
# NOTE(review): the title says "absolute variation" but res5 holds SQUARED
# differences between the pruned distance matrices - consider renaming.
plt.title('Absolute variation at node level between phase 1 and 2 in distances') # Title of the plot
plt.grid(True, linestyle='--', alpha=0.7) # Add a grid for easier reading
plt.tight_layout() # Adjust plot to ensure everything fits without overlapping
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse.csgraph import laplacian
from scipy.linalg import eigh  # For symmetric matrices
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    confusion_matrix,
    f1_score
)
import networkx as nx
import community as community_louvain  # pip install python-louvain


def evaluate_clustering(pred_labels, true_labels):
    """
    Score a binary clustering against reference labels and print a report.

    Cluster ids are arbitrary, so both the prediction and its 0/1 complement
    are compared with the reference; the better-matching one is used for all
    metrics.

    Returns a dict with keys "accuracy", "ari", "nmi", "f1" and
    "confusion_matrix".
    """
    predicted = np.asarray(pred_labels).astype(int)
    truth = np.asarray(true_labels).astype(int)

    # Accuracy of the labelling as given and of its complement.
    complement = 1 - predicted
    acc1 = np.mean(predicted == truth)
    acc2 = np.mean(complement == truth)
    best_accuracy = max(acc1, acc2)
    if acc2 > acc1:
        predicted = complement

    cm = confusion_matrix(truth, predicted)
    ari = adjusted_rand_score(truth, predicted)
    nmi = normalized_mutual_info_score(truth, predicted)
    f1 = f1_score(truth, predicted)

    report_lines = (
        "Evaluation of Spectral Clustering",
        "--------------------------------",
        f"Accuracy               : {best_accuracy:.4f}",
        f"Adjusted Rand Index    : {ari:.4f}",
        f"Normalized Mutual Info : {nmi:.4f}",
        f"F1 Score               : {f1:.4f}",
        "\nConfusion Matrix:",
    )
    for line in report_lines:
        print(line)
    print(cm)

    return {
        "accuracy": best_accuracy,
        "ari": ari,
        "nmi": nmi,
        "f1": f1,
        "confusion_matrix": cm
    }

def analyze_laplacian(adj_matrix, true_labels, plot_fiedler=False, random_state=0):
    """
    Eigen-analysis of a graph Laplacian plus three two-cluster baselines.

    Plots the Laplacian eigenvalue spectrum, then scores (via
    `evaluate_clustering`) the clusterings obtained from the Fiedler vector
    (only if `plot_fiedler` is True), SpectralClustering on `adj_matrix` as a
    precomputed affinity, and Louvain communities grouped into two classes
    with KMeans.

    Parameters:
    - adj_matrix: np.ndarray, symmetric adjacency matrix
    - true_labels: array-like of reference binary labels, one per node
    - plot_fiedler: bool, whether to plot the Fiedler vector (2nd smallest
      eigenvector) and evaluate the sign-based split it induces
    - random_state: seed shared by SpectralClustering, Louvain and KMeans so
      that repeated runs report the same scores (default 0, the seed that
      SpectralClustering already used)

    Returns:
    - (eigenvals, eigenvecs) as returned by scipy.linalg.eigh on the
      unnormalized Laplacian

    Raises:
    - ValueError if adj_matrix is not symmetric
    """

    # Ensure the matrix is symmetric
    if not np.allclose(adj_matrix, adj_matrix.T):
        raise ValueError("Adjacency matrix must be symmetric")

    # Compute unnormalized Laplacian
    L = laplacian(adj_matrix, normed=False)

    # Compute eigenvalues and eigenvectors
    eigenvals, eigenvecs = eigh(L)

    # Plot eigenvalue spectrum
    plt.figure(figsize=(10, 4))
    plt.plot(np.sort(eigenvals), marker='o')
    plt.title("Eigenvalue Spectrum of Laplacian")
    plt.xlabel("Index")
    plt.ylabel("Eigenvalue")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Fiedler vector (2nd smallest eigenvector)
    if plot_fiedler and len(eigenvals) >= 2:
        fiedler_vec = eigenvecs[:, 1]
        plt.figure(figsize=(10, 4))
        plt.plot(fiedler_vec, marker='.')
        plt.title("Fiedler Vector (2nd Smallest Eigenvector)")
        plt.xlabel("Node Index")
        plt.ylabel("Value")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        # Split nodes by the sign of their Fiedler-vector component.
        pred_labels = (fiedler_vec > 0).astype(int)
        print(f"Clustering accuracy (Fiedler): ")
        evaluate_clustering(pred_labels, true_labels)

    sc = SpectralClustering(n_clusters=2, affinity='precomputed', assign_labels='kmeans', random_state=random_state)
    pred_labels = sc.fit_predict(adj_matrix)
    print(f"Clustering accuracy (SpectralClustering): ")
    evaluate_clustering(pred_labels, true_labels)

    print(f"Clustering accuracy (Louvain): ")
    G = nx.from_numpy_array(adj_matrix)  # Use thresholded graph here!
    # Louvain is randomized; seed it so the reported scores are reproducible.
    partition = community_louvain.best_partition(G, random_state=random_state)
    pred_labels = np.array([partition[i] for i in range(len(adj_matrix))])
    # Cluster Louvain labels into 2 meta-classes (seeded for the same reason).
    kmeans = KMeans(n_clusters=2, n_init=10, random_state=random_state)
    meta_labels = kmeans.fit_predict(pred_labels.reshape(-1, 1))

    evaluate_clustering(meta_labels, true_labels)

    # Useful output
    num_components = np.sum(np.isclose(eigenvals, 0))
    print(f"Number of connected components (zero eigenvalues): {num_components}")
    print(f"Smallest eigenvalues: {eigenvals[:5]}")
    return eigenvals, eigenvecs


# Reference labels shared by every analysis: 0 for the leading nodes, 1 for the
# trailing CORRUPTED_NODES nodes.
corrupted_labels = np.array([0]*(num_nodes-CORRUPTED_NODES) + [1]*CORRUPTED_NODES)
separator = "==="*20

print(separator)
print("BASELINE")
# Baselines: the raw adjacency matrix and the graph distance matrix.
for baseline_matrix in (mat, dist_mat):
    eigenvals, eigenvecs = analyze_laplacian(baseline_matrix, corrupted_labels)

print(separator)
print("TEST")
# Matrices derived from the phase-1 / phase-2 distance changes.
for test_matrix in (res4, correction_matrix, flop_res):
    eigenvals, eigenvecs = analyze_laplacian(test_matrix, corrupted_labels)

In [None]:
# Element-wise square of the edge-restricted residuals.
plot_correlogram(correction_matrix**2)