## Setup

In [1]:
# Imports 

import torch
import numpy as np
import pandas as pd
import tqdm
import matplotlib.pyplot as plt

from src.dataset_utils import theta_ds_create
from src.dataset_utils import S_ds_read_given_rows_batch, S_ds_compute

from src.phi import JTFS_forward
from src.ftm import rectangular_drum
from src.ftm import constants as FTM_constants

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## KNN-G

In [2]:
# Decide whether to read the precomputed parameter dataset (and, with it, the
# precomputed S dataset) or to create a fresh parameter dataset; the paths follow
# that choice.

read_dataset = False

if not read_dataset:
    DatasetPath = "data/default_parameters.csv"
else:
    DatasetPath = "data/precompute_S/param_dataset.csv"
    S_DatasetPath = "data/precompute_S/S_dataset_full.parquet"

In [3]:
# Reading/Creating the parameter dataset

# `logscale` must exist in both branches: the S(hubs) cell passes it to
# `rectangular_drum` whenever S is computed, including when the parameter dataset
# was read from disk.
# NOTE(review): assumes the precomputed param_dataset.csv stores parameters in the
# same (log) scale as the freshly created dataset — TODO confirm.
logscale = True

if read_dataset:
    theta_df = pd.read_csv(DatasetPath)
else:
    # One (low, high) range per parameter name. The negative bounds for p and d are
    # presumably log-scale values, consistent with logscale=True — TODO confirm.
    # NOTE(review): 10e-05 == 1e-4; confirm 1e-5 was not intended.
    bounds = [['omega', 'tau', 'p', 'd', 'alpha'],[(2.4, 3.8),(0.4, 3),(-5, -0.7),(-5, -0.5),(10e-05, 1)]]
    # Write to the path chosen in the config cell rather than a duplicated literal.
    theta_df = theta_ds_create(bounds=bounds, subdiv=5, path=DatasetPath)

# Single device+dtype conversion: avoids materialising a float64 copy on the device
# before casting to float32.
DF = torch.from_numpy(theta_df.to_numpy()).to(device=device, dtype=torch.float)

In [4]:
# Choose the initial hubs: `n_hubs` row indices spread evenly from the first to
# the last row of the parameter dataset (float positions truncated to integers).

n_hubs = 100
n_dataset = DF.shape[0]
hub_positions = torch.linspace(0, n_dataset - 1, n_hubs, device=device)
Id_hub = hub_positions.to(torch.long)

Id_hub

tensor([   0,   31,   63,   94,  126,  157,  189,  220,  252,  284,  315,  347,
         378,  410,  441,  473,  504,  536,  568,  599,  631,  662,  694,  725,
         757,  788,  820,  852,  883,  915,  946,  978, 1009, 1041, 1072, 1104,
        1136, 1167, 1199, 1230, 1262, 1293, 1325, 1356, 1388, 1420, 1451, 1483,
        1514, 1546, 1577, 1609, 1640, 1672, 1704, 1735, 1767, 1798, 1830, 1861,
        1893, 1924, 1956, 1988, 2019, 2051, 2082, 2114, 2145, 2177, 2208, 2240,
        2272, 2303, 2335, 2366, 2398, 2429, 2461, 2492, 2524, 2556, 2587, 2619,
        2650, 2682, 2713, 2745, 2776, 2808, 2840, 2871, 2903, 2934, 2966, 2997,
        3029, 3060, 3092, 3124], device='cuda:0')

In [None]:
# Read/Compute the S(hubs)

# The precomputed S dataset has to be read in batches of rows, because reading many
# rows at once fills the RAM very quickly. Reading one batch takes ~35 s, so with
# fewer than ~70 rows per batch, reading is slower than simply recomputing S.
enough_ram = False

if not (read_dataset and enough_ram):
    phi = JTFS_forward

    def S(theta):
        """Return phi(rectangular_drum(theta, logscale, **FTM_constants)).

        `logscale` is looked up in the notebook's global namespace at call time.
        """
        drum_output = rectangular_drum(theta, logscale, **FTM_constants)
        return phi(drum_output)

    S_hub = S_ds_compute(DF, Id_hub, S)
else:
    batch_size = 70
    S_hub = S_ds_read_given_rows_batch(S_DatasetPath, Id_hub, batch_size)

S_hub

Computing S: 100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


tensor([[6.1219, 5.9250, 5.3349,  ..., 3.5671, 2.7955, 2.1095],
        [5.8615, 5.6647, 5.0753,  ..., 3.1641, 2.4342, 1.7679],
        [5.4886, 5.2921, 4.7040,  ..., 2.9272, 2.2112, 1.5357],
        ...,
        [3.0145, 2.8241, 2.2784,  ..., 2.0083, 2.1885, 2.4690],
        [2.5644, 2.3827, 1.8630,  ..., 2.2431, 2.3357, 2.5162],
        [2.9475, 2.7602, 2.2185,  ..., 2.2782, 2.1937, 2.2938]],
       device='cuda:0')

In [None]:
# Compute the M(hubs)