In [1]:
import numpy as np
import tqdm

from allib.datasets import load_uci, AVAIL_DATASETS
from allib.metrics import distance
from allib.utils import ensure_path
from sklearn.metrics.pairwise import pairwise_distances, check_pairwise_arrays

In [4]:
def normalize(d):
    # normalize to 0 1
    d = (d - d.min()) / (d.max() - d.min())
    return d

def get_dist_mat(dsn: str):
    print("processing dataset ", dsn)
    ds = load_uci(dsn)
    ds.with_preprocess(steps=["sample_n", "continuous_to_categorical", "remove_constant_columns"],  params_list=[{"n": 10000, "random_state": 0}, {"encode": "ordinal"}, {}], in_place=True)
    data, label = ds._data, ds._label
    cache_dir = f"./dist_cache/{dsn}"
    # check if local cache is available
    ensure_path(cache_dir)
    nks = []
    freq = []
    N = data.shape[0]
    for col in data.columns:
        nks.append(data[col].unique().shape[0])
        freq.append(dict(data[col].value_counts()))
    prob = [{k: v/N for k, v in f.items()} for f in freq]
    prob2 = [{k: (v * (v - 1))/(N * (N-1)) for k, v in f.items()} for f in freq]
    nks = np.array(nks)
    params = {"prob": prob, "prob2": prob2, "nks": nks, "N": N, "freq": freq}
    # use tqdm to show progress
    for dm in tqdm.tqdm(distance.AVAIL_DIST_METRICS):
        fn = f"{cache_dir}/{dm}_ordinal.npy"
        if ensure_path(fn, False):
            continue
        d = pairwise_distances(data, metric=distance.get_dist_metric(dm, params))
        d = normalize(d)
        np.save(fn, d)

In [None]:
for dsn in AVAIL_DATASETS:
    if dsn == "adult":
        continue
    get_dist_mat(dsn)

processing dataset  iris


100%|██████████| 16/16 [00:00<00:00, 16120.31it/s]


processing dataset  yeast


100%|██████████| 16/16 [24:40<00:00, 92.55s/it] 


processing dataset  letter-recognition


 44%|████▍     | 7/16 [4:36:13<9:11:23, 3675.96s/it]

In [12]:
ds = load_uci("adult")
ds._data.shape

(32537, 14)

In [14]:
# estimate file have 149*149 float
# 149*149*4 = 88.5KB
10000 * 10000 * 8 / 1024 / 1024 / 1024 * 16

11.920928955078125