## Setup

In [1]:
# Imports 
import torch
import numpy as np
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
import functools
import torch_geometric as tg
import networkx as nx
import pyarrow
import pyarrow.parquet as pq
import pyarrow.compute as pc

from src.dataset_utils import theta_ds_create
from src.dataset_utils import S_ds_compute

from src.phi import JTFS_forward
from src.jacobian import M_factory
from src.distances import distance_factory
from src.ftm import rectangular_drum
from src.ftm import constants as FTM_constants

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# KNN-G Parameters

n_hubs = 2000
k = 50

## Creating KNN-G

### Dataset

In [10]:
# Choose whether to read the precomputed parameter dataset (together with its
# precomputed S values) or to create a new parameter grid; paths follow that choice.
read_dataset = True

if not read_dataset:
    DatasetPath = "data/default_parameters.csv"
else:
    DatasetPath = "data/precompute_S/param_dataset.csv"
    S_DatasetPath = "data/precompute_S/S_dataset_full.parquet"

In [11]:
# Reading/Creating the dataset
# logscale is also passed to rectangular_drum and to the M workers further down.
logscale = True
if read_dataset:
    DF = torch.from_numpy(pd.read_csv(DatasetPath).to_numpy()).to(torch.float)
else:
    # Parameter names, then one (low, high) pair per parameter, handed to theta_ds_create.
    # NOTE(review): 10e-05 == 1e-4, not 1e-5 — confirm the intended lower bound for alpha.
    bounds = [['omega', 'tau', 'p', 'd', 'alpha'],[(2.4, 3.8),(0.4, 3),(-5, -0.7),(-5, -0.5),(10e-05, 1)]]
    # Reuse DatasetPath (set in the previous cell) rather than repeating the literal path,
    # so the location of the generated parameter file is configured in a single place.
    DF = torch.from_numpy(theta_ds_create(bounds=bounds, subdiv=5, path=DatasetPath).to_numpy()).to(torch.float)

### Hubs

In [None]:
# Choosing the initial hubs: n_hubs row indices spread evenly (then truncated to
# integers) over the whole dataset, from the first row to the last one.

n_dataset = DF.shape[0]
Id_hub = torch.linspace(start=0, end=n_dataset - 1, steps=n_hubs).to(torch.long)

Id_hub

In [None]:
# Read/Compute the S(hubs)  

def S_hub_from_dataset(ds_path, Id_hub):
    """Read the precomputed S rows of the hubs from a parquet file.

    The file must contain a ``row_id`` column (index of the parameter row in the
    dataset); every other column is returned as the S values (presumably the
    precomputed S coefficients). Row groups are streamed one at a time and only the
    rows whose ``row_id`` appears in ``Id_hub`` are kept.

    Args:
        ds_path: path to the parquet file.
        Id_hub: 1-D integer tensor of dataset row indices of the hubs.

    Returns:
        Tensor whose i-th row belongs to dataset row ``Id_hub[i]``. Alignment with
        ``Id_hub`` holds regardless of the order in which rows are stored in the file.

    Raises:
        KeyError: if some hub id has no row in the file.
    """
    id_hub_list = Id_hub.tolist()
    # The lookup set is identical for every row group: build it once.
    hub_ids = pyarrow.array(id_hub_list)
    parquet_file = pq.ParquetFile(ds_path)

    found_ids = []
    S_chunks = []
    for i in tqdm.tqdm(range(parquet_file.num_row_groups), desc="Reading S"):
        table = parquet_file.read_row_group(i)
        filtered_table = table.filter(pc.is_in(table["row_id"], hub_ids))
        found_ids.extend(filtered_table["row_id"].to_pylist())
        S_chunks.append(torch.from_numpy(np.array(filtered_table.drop(["row_id"]))))

    S_found = torch.cat(S_chunks)

    # Rows arrive in file order. Re-index them so row i matches Id_hub[i]; a repeated
    # id in Id_hub gets its row repeated, and absent ids are reported explicitly.
    position = {row_id: pos for pos, row_id in enumerate(found_ids)}
    missing = [h for h in id_hub_list if h not in position]
    if missing:
        raise KeyError(f"{len(missing)} hub ids not found in {ds_path}, e.g. {missing[:5]}")
    order = torch.tensor([position[h] for h in id_hub_list], dtype=torch.long)
    return S_found[order]


# Obtain S for every hub: either computed on the fly, with S(theta) = phi(drum signal),
# or read from the precomputed parquet file.
if not read_dataset:
    phi = JTFS_forward

    def S(theta):
        """Apply phi (JTFS_forward) to the output of rectangular_drum for parameters theta."""
        drum_output = rectangular_drum(theta, logscale, **FTM_constants)
        return phi(drum_output)

    S_hub = S_ds_compute(DF, Id_hub, S)
else:
    S_hub = S_hub_from_dataset(S_DatasetPath, Id_hub)

S_hub

In [None]:
# Compute the M(hub) with multiprocessing

from src.M_multiprocessing import init_worker_M, compute_task_M

phi = JTFS_forward


def run_parallel(num_processes=2):
    """Compute the matrix M of every hub in a pool of worker processes.

    Each worker is initialised once with ``(M_factory, logscale, phi, device)`` through
    ``init_worker_M``. Each task is ``(hub position i, DF[Id_hub[i], :], device)``, and
    ``compute_task_M`` is expected to return ``(hub position, M)``.

    Args:
        num_processes: number of worker processes (default 2, the previously
            hard-coded value).

    Returns:
        CPU tensor of shape (len(Id_hub), DF.size(1), DF.size(1)); entry i holds the
        matrix of hub ``Id_hub[i]``.
    """
    num_tasks = Id_hub.size(0)
    n_params = DF.size(1)

    # Prepare task arguments
    tasks = [(i, DF[Id_hub[i], :], device) for i in range(num_tasks)]

    M_hub = torch.zeros(num_tasks, n_params, n_params)

    # CUDA does not support the 'fork' start method, so workers are started with 'spawn'.
    ctx = torch.multiprocessing.get_context('spawn')

    with ctx.Pool(
        processes=num_processes,
        initializer=init_worker_M,
        initargs=(M_factory, logscale, phi, device)
    ) as pool:
        # imap_unordered yields in completion order, so each result carries its own hub index.
        for idx, result in tqdm.tqdm(pool.imap_unordered(compute_task_M, tasks), total=num_tasks, desc="Computing M"):
            M_hub[idx] = result

    return M_hub


if __name__ == '__main__':
    M_hub = run_parallel()

M_hub

In [None]:
# Persist the hub matrices. NOTE(review): the reload cell further down reads 'M_hub_full.pt'.
torch.save(obj=M_hub, f='data/Knn-G/M_hub.pt')

### Allocation

In [None]:
distance_PNP = distance_factory('PNP')


def F(i, j, h):
    """PNP distance between dataset rows i and j, measured with the matrix of hub h.

    Args:
        i: row index into DF, in [0, DF.size(dim=0) - 1].
        j: row index into DF, in [0, DF.size(dim=0) - 1].
        h: hub position, in [0, Id_hub.size(dim=0) - 1]; selects M_hub[h].
    """
    theta_i = DF[i, :]
    theta_j = DF[j, :]
    hub_matrix = M_hub[h, :, :]
    return distance_PNP(theta_i, theta_j, hub_matrix)

In [None]:
# Allocation of each point: assign every dataset row i to the hub position h that
# minimises F(i, Id_hub[h], h). The strict '<' keeps the first hub on ties.
# The hub loop variable is `h`, NOT `k`: a global `k` here would silently overwrite
# the neighbour count `k` from the parameters cell that Knn_edge(k) uses later.

Allocation = torch.zeros(DF.size(dim=0)).to(int).to(device)

for i in tqdm.tqdm(range(DF.size(dim=0)),desc='Allocating',leave=True):

    dmin = torch.inf
    argmin = None

    for h in range(Id_hub.size(dim=0)):
        d = F(i,Id_hub[h],h)
        if d<dmin:
            dmin = d
            argmin = h

    # Happens only if every distance is NaN or +inf; fail with a clear message
    # instead of an opaque error from assigning None into the tensor.
    if argmin is None:
        raise ValueError(f"Point {i}: no hub gives a finite, non-NaN distance")

    Allocation[i] = argmin

Allocation

In [None]:
# Persist the hub assignment. NOTE(review): the reload cell further down reads 'Allocation_full.pt'.
torch.save(obj=Allocation, f='data/Knn-G/Allocation.pt')

### Graph from KNN

In [8]:
# If reloading stuff: restore M_hub and Allocation from disk instead of recomputing them.
# NOTE(review): the save cells above write 'M_hub.pt' / 'Allocation.pt', not these '_full' names.

M_hub = torch.load('data/Knn-G/M_hub_full.pt')
Allocation = torch.load('data/Knn-G/Allocation_full.pt')
print(M_hub.size(), Allocation.size(), sep='\n')

torch.Size([2000, 5, 5])
torch.Size([100000])


In [12]:
distance_PNP = distance_factory('PNP')

# Everything used by the edge computation lives on the same device.
DF = DF.to(device)
M_hub = M_hub.to(device)
Allocation = Allocation.to(device)

# Per-point matrix: each point uses the matrix of the hub it was allocated to.
M_all = M_hub[Allocation]


def D_vmap(theta_c, M_c, theta_r, M_r):
    """Symmetrised PNP distance between two points.

    Averages the distance measured with each point's own (hub-inherited) matrix, so
    swapping the two (theta, M) pairs gives the same value up to floating-point rounding.
    """
    d1 = distance_PNP(theta_c, theta_r, M_r)
    d2 = distance_PNP(theta_r, theta_c, M_c)
    return (d1 + d2) / 2


# Distances from one reference point (theta_r, M_r) to every point: vmapped over the
# first two arguments (all candidate points), with the last two held fixed.
compute_row = torch.vmap(
    D_vmap,
    in_dims=(0, 0, None, None)
)


def Knn_edge(k, batch_size=128, deduplicate=False):
    """k-NN graph construction over all rows of DF, using the distance D_vmap.

    Args:
        k: number of neighbours kept per point. Self-matches are removed from the k+1
            nearest candidates (the request is capped at the number of nodes).
        batch_size: number of query points handled per outer iteration.
        deduplicate: if True, merge the duplicate directed edges produced when two points
            are in each other's k-NN lists, averaging their weights. The default False
            keeps the previous output, which has one (i, j) and one (j, i) edge per k-NN hit.

    Returns:
        edge_index: (2, E) tensor of source/target node indices.
        weights: (E,) tensor of distances aligned with the columns of edge_index.
    """
    num_nodes = DF.size(0)
    # topk cannot return more entries than there are nodes.
    n_take = min(k + 1, num_nodes)
    sources_list = []
    targets_list = []
    weights_list = []

    for start in tqdm.tqdm(range(0, num_nodes, batch_size), desc='Computing Edges'):
        end = min(start + batch_size, num_nodes)

        batch_theta = DF[start:end]
        batch_M = M_all[start:end]

        # (end - start, num_nodes) distance matrix, one row per query point. Rows are
        # computed one at a time (presumably to bound the memory of the vmapped distance).
        dists_batch = torch.stack([
            compute_row(DF, M_all, batch_theta[b], batch_M[b])
            for b in range(end - start)
        ])

        # k+1 smallest, so that k neighbours remain once the point itself is dropped.
        vals, cols = torch.topk(dists_batch, k=n_take, dim=1, largest=False)

        rows = torch.arange(start, end, device=device).unsqueeze(1).expand_as(cols)

        # Drop self-matches; a row whose own index is not among its candidates keeps all of them.
        mask = rows != cols

        valid_rows = rows[mask]
        valid_cols = cols[mask]
        valid_vals = vals[mask]

        # Both ways to get a symmetric graph
        sources_list += [valid_rows, valid_cols]
        targets_list += [valid_cols, valid_rows]
        weights_list += [valid_vals, valid_vals]

    all_sources = torch.cat(sources_list)
    all_targets = torch.cat(targets_list)
    all_weights = torch.cat(weights_list)

    edge_index = torch.stack([all_sources, all_targets], dim=0)

    if deduplicate:
        # Mutual neighbours yield the same directed edge twice with (near-)identical weights,
        # because D_vmap is symmetric. Merge them into one edge with the mean weight.
        edge_index, all_weights = tg.utils.coalesce(
            edge_index, all_weights, num_nodes=num_nodes, reduce='mean'
        )

    return edge_index, all_weights

In [13]:
# Create Graph Data object: nodes carry their parameter vector (DF rows),
# edges carry the k-NN distance returned by Knn_edge.
edge_index, edge_attr = Knn_edge(k)

graph = tg.data.Data(
    x=DF,
    edge_index=edge_index,
    edge_attr=edge_attr,
)
graph

Computing Edges: 100%|██████████| 782/782 [01:24<00:00,  9.30it/s]


Data(x=[100000, 5], edge_index=[2, 10000000], edge_attr=[10000000])

In [None]:
# Persist the torch_geometric graph. NOTE(review): the commented reload cell reads 'tgGraph_full.pt'.
torch.save(obj=graph, f='data/Knn-G/tgGraph.pt')

## Writing the graph

In [None]:
#If reloading stuff

#graph = torch.load('data/Knn-G/tgGraph_full.pt',weights_only=False)
#graph

Data(x=[100000, 5], edge_index=[2, 10000000], edge_attr=[10000000])

In [None]:
# Gather some statistics about the graph. Every k-NN hit is stored in both
# directions, so the average degree counts each undirected neighbour twice.
print(
    f'Number of nodes: {graph.num_nodes}',
    f'Number of edges: {graph.num_edges}',
    f'Average node degree: {graph.num_edges / graph.num_nodes:.2f}',
    sep='\n',
)
#print(f'Has isolated nodes: {graph.has_isolated_nodes()}')
#print(f'Has self-loops: {graph.has_self_loops()}')

Number of nodes: 100000
Number of edges: 10000000
Average node degree: 100.00


[DataEdgeAttr(edge_type=None, layout=<EdgeLayout.COO: 'coo'>, is_sorted=False, size=None)]

In [None]:
#Only for small graph, else OOM very fast

#G = tg.utils.to_networkx(graph, to_undirected=True, edge_attrs=["edge_attr"], node_attrs=["x"])
#G
#PathGraph = 'data/Knn-G/knnG.gml'
#nx.write_gml(G, PathGraph)

In [None]:
def save_to_graphml_gephi(data, filename):
    """Write a PyG-style graph to a GraphML file that Gephi can open.

    Nodes are named ``n<i>`` for i in range(data.num_nodes). When ``data.x`` is set, each
    node gets a ``features`` attribute holding its feature row as a comma-separated string.
    When ``data.edge_attr`` is set, each edge gets a ``weight`` attribute (one scalar per
    edge). Edges are written exactly as listed in ``data.edge_index`` under
    ``edgedefault="directed"``, so a graph stored with both directions yields two edges per pair.

    Args:
        data: object exposing ``edge_index`` (2, E) and ``num_nodes``; ``x`` and
            ``edge_attr`` are optional (absent or None means the attribute is not written).
        filename: output path; an existing file is overwritten.
    """
    edge_index = data.edge_index.cpu().numpy()
    num_nodes = data.num_nodes

    # Prepare attributes. torch_geometric's Data returns None for unset attributes, so
    # hasattr() is always True there; test the value itself to avoid None.cpu().
    edge_attr = getattr(data, 'edge_attr', None)
    x = getattr(data, 'x', None)
    edge_weights = edge_attr.cpu().numpy() if edge_attr is not None else None
    node_features = x.cpu().numpy() if x is not None else None

    # The XML header declares UTF-8, so write the file with that encoding explicitly.
    with open(filename, 'w', encoding='utf-8') as f:
        # 1. Header and Schema Definitions
        f.write('<?xml version="1.0" encoding="UTF-8"?>\n')
        f.write('<graphml xmlns="http://graphml.graphdrawing.org/xmlns" \n')
        f.write('         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" \n')
        f.write('         xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">\n')

        # Define Attributes (Crucial for Gephi)
        f.write('  <key id="v_feat" for="node" attr.name="features" attr.type="string"/>\n')
        f.write('  <key id="e_weight" for="edge" attr.name="weight" attr.type="double"/>\n')

        f.write('  <graph id="G" edgedefault="directed">\n')

        # 2. Write Nodes
        for i in tqdm.tqdm(range(num_nodes),desc='Nodes'):
            f.write(f'    <node id="n{i}">\n')
            if node_features is not None:
                feat_str = ",".join(map(str, node_features[i]))
                f.write(f'      <data key="v_feat">{feat_str}</data>\n')
            f.write('    </node>\n')

        # 3. Write Edges
        sources = edge_index[0]
        targets = edge_index[1]

        for idx in tqdm.tqdm(range(len(sources)),desc='Edges'):
            # Write edge with optional weight
            if edge_weights is not None:
                w = edge_weights[idx].item() if hasattr(edge_weights[idx], "item") else edge_weights[idx]
                f.write(f'    <edge source="n{sources[idx]}" target="n{targets[idx]}">\n')
                f.write(f'      <data key="e_weight">{w}</data>\n')
                f.write('    </edge>\n')
            else:
                f.write(f'    <edge source="n{sources[idx]}" target="n{targets[idx]}"/>\n')

        f.write('  </graph>\n')
        f.write('</graphml>\n')

In [25]:
# Export the k-NN graph for visualisation in Gephi.
save_to_graphml_gephi(data=graph, filename='data/Knn-G/KnnG_full.graphml')

Nodes: 100%|██████████| 100000/100000 [00:00<00:00, 248690.32it/s]
Edges: 100%|██████████| 10000000/10000000 [00:36<00:00, 270904.38it/s]
