## Setup

In [1]:
# Imports
import functools
from pathlib import Path

import torch
import numpy as np
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
import torch_geometric as tg
import networkx as nx
import pyarrow
import pyarrow.parquet as pq
import pyarrow.compute as pc

from src.dataset_utils import theta_ds_create
from src.dataset_utils import S_ds_compute

from src.phi import JTFS_forward
from src.jacobian import M_factory
from src.distances import distance_factory
from src.ftm import rectangular_drum
from src.ftm import constants as FTM_constants

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [2]:
# KNN-G parameters
#   n_hubs : number of hub rows for which a matrix M is computed
#   k      : number of nearest neighbours kept per node when building the graph
# NOTE: `k` is a notebook-level global read later by `Knn_edge`; never reuse the
# name `k` as a loop variable in another cell.
n_hubs, k = 2000, 50

## Creating the k-nearest-neighbour graph (KNN-G)

### Dataset

In [3]:
# Choose whether to read the precomputed parameter dataset (and its precomputed S
# values) or to create a new parameter dataset, and set the paths accordingly.
read_dataset = True

if not read_dataset:
    DatasetPath = "data/default_parameters.csv"
else:
    DatasetPath = "data/precompute_S/param_dataset.csv"
    S_DatasetPath = "data/precompute_S/S_dataset_full.parquet"

In [None]:
# Load the parameter dataset (one theta per row) from CSV, or generate a fresh one.
logscale = True  # passed to rectangular_drum (S computation) and to the M workers

if not read_dataset:
    # First list: parameter names; second list: the matching (low, high) bounds.
    # NOTE(review): 10e-05 == 1e-04 — confirm this is the intended lower bound for alpha.
    bounds = [
        ['omega', 'tau', 'p', 'd', 'alpha'],
        [(2.4, 3.8), (0.4, 3), (-5, -0.7), (-5, -0.5), (10e-05, 1)],
    ]
    DF = torch.from_numpy(
        theta_ds_create(bounds=bounds, subdiv=5, path='data/default_parameters.csv').to_numpy()
    ).to(torch.float)
else:
    DF = torch.from_numpy(pd.read_csv(DatasetPath).to_numpy()).to(torch.float)

### Hubs

In [None]:
# Choose n_hubs hub row indices spread evenly over the dataset, first and last row
# included (linspace values are truncated to integers).
n_dataset = DF.shape[0]
Id_hub = torch.linspace(0, n_dataset - 1, steps=n_hubs).to(torch.long)

Id_hub

tensor([    0,    50,   100,  ..., 99898, 99948, 99999], device='cuda:0')

In [None]:
# Read/Compute the S(hubs)  

def S_hub_from_dataset(ds_path, Id_hub):
    """Read the precomputed S rows of the hubs from a parquet file.

    The file is scanned one row group at a time. In each group only the rows whose
    ``row_id`` is one of the hub indices are kept, and the ``row_id`` column is
    dropped from the values.

    Args:
        ds_path: path of the parquet file; it must contain a ``row_id`` column.
        Id_hub: 1-D tensor of dataset row indices (the hubs).

    Returns:
        Tensor whose row ``h`` holds the S values of dataset row ``Id_hub[h]``.
        Rows are returned in ``Id_hub`` order regardless of how the file is sorted,
        and a hub index repeated in ``Id_hub`` yields a repeated row.

    Raises:
        KeyError: if some hub index has no row in the file.
    """
    id_hub_list = Id_hub.tolist()
    # Build the lookup array once instead of once per row group.
    wanted = pyarrow.array(id_hub_list)
    parquet_file = pq.ParquetFile(ds_path)

    found_ids = []  # row_id of every kept row, in file order
    S_batches = []
    for i in tqdm.tqdm(range(parquet_file.num_row_groups), desc="Reading in batch"):
        table = parquet_file.read_row_group(i)
        filtered_table = table.filter(pc.is_in(table["row_id"], wanted))
        found_ids.extend(filtered_table["row_id"].to_pylist())
        S_batches.append(torch.from_numpy(np.array(filtered_table.drop(["row_id"]))))
    S_found = torch.cat(S_batches)

    # Filtering keeps file order, so map each row_id to its position in S_found and
    # gather in Id_hub order: row h of the result must belong to hub h.
    position = {row_id: pos for pos, row_id in enumerate(found_ids)}
    missing = [row_id for row_id in id_hub_list if row_id not in position]
    if missing:
        raise KeyError(f"{len(missing)} hub row(s) not found in {ds_path}, e.g. {missing[:5]}")
    order = torch.tensor([position[row_id] for row_id in id_hub_list], dtype=torch.long)
    return S_found[order]


if not read_dataset:
    # No precomputed S available: evaluate S = phi(rectangular_drum(theta)) per hub.
    phi = JTFS_forward

    def S(theta):
        """Apply JTFS_forward to the rectangular_drum output for parameters theta."""
        return phi(rectangular_drum(theta, logscale, **FTM_constants))

    S_hub = S_ds_compute(DF, Id_hub, S)
else:
    # Precomputed S values: read only the hub rows from the parquet file.
    S_hub = S_hub_from_dataset(S_DatasetPath, Id_hub)

S_hub

Reading in batch: 100%|██████████| 20/20 [01:48<00:00,  5.41s/it]


tensor([[6.1219, 5.9250, 5.3349,  ..., 3.5671, 2.7955, 2.1095],
        [6.2008, 6.0038, 5.4136,  ..., 3.6870, 2.9022, 2.2149],
        [6.2217, 6.0247, 5.4344,  ..., 3.6131, 2.8381, 2.1647],
        ...,
        [2.6765, 2.4917, 1.9642,  ..., 2.2821, 2.3521, 2.5454],
        [2.2671, 2.0901, 1.5930,  ..., 1.9963, 1.7524, 1.7293],
        [2.9475, 2.7602, 2.2185,  ..., 2.2782, 2.1937, 2.2938]])

In [None]:
# Compute the M(hub) with multiprocessing

from src.M_multiprocessing import init_worker, compute_task

phi = JTFS_forward


def run_parallel(num_processes=4):
    """Compute M at every hub in a pool of worker processes.

    Each task is ``(hub position h, theta of hub h, device)``. Workers are set up
    once with ``init_worker(M_factory, logscale, phi, device)`` and evaluate tasks
    with ``compute_task``, which is expected to yield ``(hub position, M)``.

    Args:
        num_processes: number of worker processes (default 4).

    Returns:
        Tensor of shape ``(n_hubs, d, d)`` with ``d = DF.size(1)``; entry ``h`` is
        the result for dataset row ``Id_hub[h]``.
    """
    num_tasks = Id_hub.size(0)
    n_params = DF.size(1)

    tasks = [(h, DF[Id_hub[h], :], device) for h in range(num_tasks)]
    M_hub = torch.zeros(num_tasks, n_params, n_params)

    # 'spawn' start method — presumably chosen because workers may use CUDA.
    ctx = torch.multiprocessing.get_context('spawn')
    with ctx.Pool(
        processes=num_processes,
        initializer=init_worker,
        initargs=(M_factory, logscale, phi, device),
    ) as pool:
        # Results arrive in completion order; each one carries its hub position.
        results = pool.imap_unordered(compute_task, tasks)
        for idx, result in tqdm.tqdm(results, total=num_tasks, desc="Computing M"):
            M_hub[idx] = result

    return M_hub


# Guard kept so the same code also works if moved into a script started with 'spawn'.
if __name__ == '__main__':
    M_hub = run_parallel()

M_hub

### Allocation

In [None]:
distance_PNP = distance_factory('PNP')


def F(i, j, h):
    """PNP distance between dataset rows i and j, measured with the M of hub h.

    i, j: dataset row indices in [0, DF.size(dim=0) - 1]
    h:    hub position in [0, Id_hub.size(dim=0) - 1] (index into M_hub)
    """
    theta_i = DF[i, :]
    theta_j = DF[j, :]
    M_h = M_hub[h, :, :]
    return distance_PNP(theta_i, theta_j, M_h)

In [None]:
# Allocation of each point: every dataset row i is assigned to the hub h minimising
# F(i, Id_hub[h], h) = distance_PNP(DF[i], DF[Id_hub[h]], M_hub[h]).
# NOTE: do not use `k` as a loop variable here — `k` is the neighbour count from the
# KNN-G parameters cell and is read later by Knn_edge.

# Distances from one point to all hubs in a single call: theta_i is shared (None),
# hub thetas and hub matrices are batched along dim 0 (same batching of
# distance_PNP's 2nd/3rd arguments as in D/all_D).
dist_to_hubs = torch.func.vmap(distance_PNP, in_dims=(None, 0, 0))
hub_theta = DF[Id_hub.to(DF.device)]

Allocation = torch.zeros(DF.size(dim=0), dtype=torch.long, device=device)

for i in tqdm.tqdm(range(DF.size(dim=0)), desc='Allocating', leave=True):
    d = dist_to_hubs(DF[i, :], hub_theta, M_hub)
    # NaN distances must never win (a `d < dmin` comparison would skip them).
    d = torch.where(torch.isnan(d), torch.full_like(d, torch.inf), d)
    if not bool((d < torch.inf).any()):
        raise ValueError(f"Point {i}: no hub at a finite distance")
    # torch.argmin returns the first minimal index, so ties go to the lowest hub.
    Allocation[i] = torch.argmin(d)

Allocation

### Graph from KNN

In [None]:
# Distance with hubs
# (distance_PNP = distance_factory('PNP') is defined in the Allocation section.)


def D(theta_c, M_c, theta_r, M_r):
    """Symmetrised PNP distance between two points that each carry their own M.

    distance_PNP uses a single matrix, so it is evaluated both ways — once with the
    reference point's M_r and once with the candidate's M_c — and the two results
    are averaged. Swapping (theta_c, M_c) with (theta_r, M_r) gives the same two
    terms, so D is symmetric.
    """
    return (distance_PNP(theta_c, theta_r, M_r) + distance_PNP(theta_r, theta_c, M_c)) / 2


def all_D(i):
    """Distances D between dataset row i and every dataset row.

    Args:
        i: row index in [0, DF.size(dim=0) - 1].

    Returns:
        1-D tensor of length DF.size(dim=0); entry j is
        D(DF[j], M of j's hub, DF[i], M of i's hub), where a row's hub is given by
        Allocation.
    """
    M_i = M_hub[Allocation[i]]
    # M of the hub each row is allocated to, gathered in one indexing op rather than
    # a Python loop over all rows (Knn_edge calls this function once per row).
    T_M_j = M_hub[Allocation.to(M_hub.device)].to(device=device, dtype=torch.get_default_dtype())

    all_D_vmap = torch.func.vmap(functools.partial(D, theta_r=DF[i, :], M_r=M_i), in_dims=(0, 0))
    return all_D_vmap(DF, T_M_j)

In [None]:
# Edge index 

def Knn_edge(n_neighbours=None):
    """Build the symmetric KNN graph over all dataset rows.

    For every row i, its neighbours are the ``n_neighbours`` rows j != i with the
    smallest D(i, j), as returned by ``all_D(i)``. An undirected edge {i, j} exists
    if j is a neighbour of i OR i is a neighbour of j. Each such edge is stored
    exactly once per direction in COO form, as the columns (i, j) and (j, i).

    Args:
        n_neighbours: neighbours kept per row; defaults to the notebook-level ``k``.

    Returns:
        edge_index: long tensor of shape (2, E) on ``device``.
        edge_attr:  float tensor of shape (E,) on ``device`` with the edge weight
                    w(i, j) = 1 / D(i, j) (inf when D(i, j) == 0).
    """
    if n_neighbours is None:
        n_neighbours = k
    n_nodes = DF.size(dim=0)
    # One extra candidate so that dropping i itself still leaves n_neighbours rows.
    n_candidates = min(n_neighbours + 1, n_nodes)

    # Unordered pair (min, max) -> weight. A pair found from both endpoints is kept
    # once; D is symmetric, so both endpoints give the same weight up to rounding.
    weights = {}
    for i in tqdm.tqdm(range(n_nodes), desc='Finding neighbours', leave=True):
        T_dist = all_D(i)
        candidates = torch.topk(T_dist, n_candidates, largest=False).indices
        neighbours = candidates[candidates != i][:n_neighbours]  # no self connections
        # T_dist is indexed by node id, so look the neighbour's own distance up there.
        inv_dist = (1 / T_dist[neighbours]).tolist()
        for j, w in zip(neighbours.tolist(), inv_dist):
            weights.setdefault((min(i, j), max(i, j)), w)

    src, dst, attr = [], [], []
    for (a, b), w in weights.items():
        src += [a, b]
        dst += [b, a]
        attr += [w, w]

    edge_index = torch.tensor([src, dst], dtype=torch.long).to(device)
    edge_attr = torch.tensor(attr, dtype=torch.float).to(device)
    return edge_index, edge_attr



In [None]:
# Assemble the KNN graph as a PyTorch Geometric Data object; the node features x
# are the rows of DF (one parameter vector per node).
edge_index, edge_attr = Knn_edge()
graph = tg.data.Data(
    x=DF,
    edge_index=edge_index,
    edge_attr=edge_attr,
)

## Visualisation of the graph

In [None]:
# Gather some statistics about the graph.
# NOTE(review): Knn_edge appends every edge as both (i, j) and (j, i); confirm how
# num_edges counts those before reading the average degree as an undirected degree.
print(f'Number of nodes: {graph.num_nodes}')
print(f'Number of edges: {graph.num_edges}')
print(f'Average node degree: {graph.num_edges / graph.num_nodes:.2f}')
# Knn_edge skips i == j, so no self-loops are expected here.
print(f'Has isolated nodes: {graph.has_isolated_nodes()}')
print(f'Has self-loops: {graph.has_self_loops()}')
print(f'Is undirected: {graph.is_undirected()}')

In [None]:
def visualize_graph(G, pos):
    """Draw graph G at the node positions pos on a 7x7 figure, without ticks or node labels."""
    _, ax = plt.subplots(figsize=(7, 7))
    ax.set_xticks([])
    ax.set_yticks([])
    nx.draw_networkx(G, pos=pos, ax=ax, with_labels=False, node_size=3)
    plt.show()

In [None]:
# Convert to an undirected networkx graph, save it as GraphML and plot it.
graphml_path = 'data/Knn-G/knnG.graphml'

G = tg.utils.to_networkx(graph, to_undirected=True)

# The output directory may not exist yet; write_graphml only opens the file.
Path(graphml_path).parent.mkdir(parents=True, exist_ok=True)
nx.write_graphml(G, graphml_path)

pos = nx.forceatlas2_layout(G, gravity=0.1)
visualize_graph(G, pos)