In [38]:
import numpy as np
import graphtools
import pathlib
from tqdm import tqdm

def sample_non_inf_pairs(distance_matrix, num_pairs=20, seed=None):
    """Randomly sample off-diagonal index pairs whose entry is finite.

    Parameters
    ----------
    distance_matrix : array_like
        Matrix of pairwise distances (e.g. a shortest-path matrix). Entries
        that are not finite (``inf`` or ``nan``) are excluded, as are
        diagonal entries ``(i, i)``.
    num_pairs : int, default 20
        Number of distinct pairs to draw (sampled without replacement).
    seed : int or None, default None
        If given, a private ``np.random.RandomState(seed)`` is used. It yields
        exactly the same draw as seeding the global generator with
        ``np.random.seed(seed)``, but leaves NumPy's global RNG state untouched.
        If None, the global ``np.random`` generator is used.

    Returns
    -------
    list of tuple
        ``num_pairs`` index tuples ``(i, j)`` with ``i != j`` and a finite
        ``distance_matrix[i, j]``.

    Raises
    ------
    ValueError
        If fewer than ``num_pairs`` eligible pairs exist.
    """
    # Legacy RandomState keeps seeded results reproducible with earlier runs
    # (same stream as np.random.seed + np.random.choice) without mutating
    # the process-wide generator.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Row-major indices of all finite entries.
    finite_idx = np.argwhere(np.isfinite(distance_matrix))
    # Drop self-pairs with a vectorized mask rather than a Python loop over
    # every matrix element; the surviving rows keep their original order.
    off_diag = finite_idx[finite_idx[:, 0] != finite_idx[:, 1]]

    if len(off_diag) < num_pairs:
        raise ValueError("Not enough non-infinity pairs to sample from.")

    # Draw `num_pairs` distinct row positions, then return them as tuples.
    chosen = rng.choice(len(off_diag), size=num_pairs, replace=False)
    return [tuple(pair) for pair in off_diag[chosen]]


In [39]:
# Input directory holding the synthetic datasets, and the sub-directory
# where the sampled pairs and distances are written.
root_path = "../../synthetic_data4/"
save_path = f"{root_path}gt/"  # derived from root_path so the two cannot drift apart

# Dataset stems follow the pattern true_<k>_<kind>_17580_3000_1_all.
# Building the list keeps the original order: for each k, "groups" then "paths".
dataset_ids = range(1, 6)
dataset_kinds = ("groups", "paths")
dataset_suffix = "17580_3000_1_all"
data_names = [
    f"true_{k}_{kind}_{dataset_suffix}"
    for k in dataset_ids
    for kind in dataset_kinds
]

pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)

In [40]:
# For each dataset: build a kNN graph, compute shortest-path distances, and
# save a seeded sample of finite-distance pairs with their distances.
for name in tqdm(data_names):
    # np.load on an .npz returns an NpzFile that holds an open file handle;
    # the context manager closes it once the array has been read.
    with np.load(pathlib.Path(root_path) / f"{name}.npz") as archive:
        points = archive["data"]
    G = graphtools.Graph(points, knn=5, decay=None)
    D = G.shortest_path(distance="data")
    # The same seed is used for every dataset.
    pairs = sample_non_inf_pairs(D, seed=432)
    # Gather all sampled distances in one fancy-indexing step.
    idx = np.asarray(pairs).reshape(-1, 2)
    ds = np.asarray(D)[idx[:, 0], idx[:, 1]]
    np.savez(pathlib.Path(save_path) / f"{name}.npz", pairs=pairs, ds=ds)

100%|██████████| 10/10 [01:33<00:00,  9.32s/it]
