In [4]:
import pandas as pd
import numpy as np

import lightgbm as lgb
import xgboost as xgb

from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.linear_model import LinearRegression  # type: ignore
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

import time
import matplotlib.pyplot as plt

import json

from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

import dask.dataframe as dd
import os

# Per-dataset settings, keyed by dataset name.
# Each entry holds the same five fields, in this order: M, efC, efS, li, label.
#   * M, efC, efS -- read by the data-loading cell and interpolated into the
#     training-data file name (presumably the index build/search parameters
#     the data was generated with -- TODO confirm).
#   * li, label   -- not read in the part of the notebook reviewed here;
#     `label` looks like a display name for plots/tables -- TODO confirm.
# NOTE(review): the "GLOVE100" entry carries the label "GLOVE1M"; confirm the
# key/label mismatch is intentional.
dataset_params = {
    "SIFT100M": dict(M=32, efC=500, efS=500, li=1, label="SIFT100M"),
    "GIST1M": dict(M=32, efC=500, efS=1000, li=1, label="GIST1M"),
    "GLOVE100": dict(M=16, efC=500, efS=500, li=1, label="GLOVE1M"),
    "DEEP100M": dict(M=32, efC=500, efS=750, li=1, label="DEEP100M"),
    "T2I100M": dict(M=80, efC=1000, efS=2500, li=2, label="T2I100M"),
}

# Notebook-wide random seed (intended for any stochastic step -- TODO confirm
# every such step actually passes it).
SEED = 42

# Column-name groups of the training CSV. The data-loading cell concatenates
# them (plus a few standalone columns) into `columns_to_load`, so they must
# stay plain lists of strings.

# Presumably per-step search counters -- TODO confirm.
index_metric_feats = [
    "step",
    "dists",
    "inserts",
]

# Presumably distances to selected neighbours in the result set -- TODO confirm.
neighbor_distances_feats = [
    "first_nn_dist",
    "nn_dist",
    "furthest_dist",
]

# Summary statistics over neighbour distances (original feature set).
neighbor_stats_feats = [
    "avg_dist",
    "variance",
    "percentile_25",
    "percentile_50",
    "percentile_75",
]

# "dim_*" columns; presumably statistics over the query vector's
# dimensions -- TODO confirm.
data_dims_feats = [
    "dim_l2_norm", "dim_l1_norm",
    "dim_mean", "dim_median", "dim_std", "dim_var",
    "dim_min", "dim_max", "dim_range",
    "dim_energy", "dim_skewness", "dim_kurtosis",
    "dim_perc_25", "dim_perc_75", "dim_perc_95",
]

# Additional neighbour-distance statistics (the "new" feature set).
neighbor_stats_feats_new = [
    "std",
    "range",
    "energy",
    "skewness",
    "kurtosis",
    "percentile_95",
]

In [5]:
# --- Load the per-step training data for one dataset / k configuration ------
ds_name = "SIFT100M"
k = 50
# NOTE(review): `li` is hard-coded to 5 here, while dataset_params[ds_name]["li"]
# is 1 for SIFT100M. The value below is the one that ends up in the file name;
# confirm the two are meant to differ.
li = 5
queries = 10000

# These three settings come from the shared per-dataset table and are only
# used below to build the input file name.
M = dataset_params[ds_name]["M"]
efC = dataset_params[ds_name]["efC"]
efS = dataset_params[ds_name]["efS"]

# Restrict parsing to the columns the notebook works with (passed as `usecols`).
columns_to_load = (
    ["qid", "elaps_ms"]
    + index_metric_feats
    + neighbor_distances_feats
    + neighbor_stats_feats
    + data_dims_feats
    + neighbor_stats_feats_new
    + ["r", "feats_collect_time_ms"]
)

# Root directory of the training data. It defaults to the original location but
# can be redirected with the ET_TRAINING_DATA_ROOT environment variable, so the
# notebook also runs on machines where the data lives elsewhere. An unset or
# empty variable falls back to the default.
DATA_ROOT = (
    os.environ.get("ET_TRAINING_DATA_ROOT")
    or "/data/mchatzakis/et_training_data/early-stop-training"
)
datapath = os.path.join(
    DATA_ROOT,
    ds_name,
    f"k{k}",
    f"M{M}_efC{efC}_efS{efS}_qs{queries}_li{li}_imp.txt",
)

# Read lazily with dask, then materialise the whole table as an in-memory
# pandas DataFrame.
# NOTE(review): dask reads the file in partitions and does not build a global
# row index, so index labels in the computed frame are likely repeated across
# partitions -- confirm before relying on `.loc` / index-based joins.
all_queries_dask = dd.read_csv(datapath, usecols=columns_to_load)
all_queries_data = all_queries_dask.compute()

In [6]:
# Quick look at the first five rows of the loaded table.
all_queries_data.head(n=5)

Unnamed: 0,qid,step,dists,elaps_ms,inserts,first_nn_dist,nn_dist,avg_dist,furthest_dist,percentile_25,...,dim_max,dim_range,dim_energy,dim_skewness,dim_kurtosis,dim_perc_25,dim_perc_75,dim_perc_95,feats_collect_time_ms,r
0,0,0,50,0.451088,51,105106.0,102107.0,135078.984375,165557.0,124807.0,...,157.0,157.0,258253.0,2.075178,3.977462,0.0,34.0,101.0,0.377893,0.0
1,0,0,55,0.872135,56,105106.0,97570.0,132397.875,153798.0,123862.0,...,157.0,157.0,258253.0,2.075178,3.977462,0.0,34.0,101.0,0.099897,0.0
2,0,1,60,1.021147,60,105106.0,97570.0,129950.15625,149828.0,123205.0,...,157.0,157.0,258253.0,2.075178,3.977462,0.0,34.0,101.0,0.094891,0.0
3,0,1,65,1.155138,62,105106.0,97570.0,129732.9375,149222.0,123205.0,...,157.0,157.0,258253.0,2.075178,3.977462,0.0,34.0,101.0,0.095844,0.0
4,0,1,70,1.287937,65,105106.0,97534.0,127782.460938,147794.0,118932.0,...,157.0,157.0,258253.0,2.075178,3.977462,0.0,34.0,101.0,0.094891,0.0


In [8]:
# Show 10 randomly chosen rows. Seeding the draw with the notebook-wide SEED
# makes the same rows come back on every run (Restart & Run All reproducible).
all_queries_data.sample(10, random_state=SEED)

Unnamed: 0,qid,step,dists,elaps_ms,inserts,first_nn_dist,nn_dist,avg_dist,furthest_dist,percentile_25,...,dim_max,dim_range,dim_energy,dim_skewness,dim_kurtosis,dim_perc_25,dim_perc_75,dim_perc_95,feats_collect_time_ms,r
98783,8307,9,565,17.302036,206,82328.0,47425.0,66128.078125,72886.0,63738.0,...,117.0,117.0,257445.0,1.203448,0.478845,5.0,44.0,102.0,0.101089,0.42
124437,144,43,1640,46.103001,254,42975.0,28616.0,33626.859375,36121.0,32784.0,...,137.0,137.0,258871.0,1.858197,2.337068,0.0,26.0,136.0,0.093937,0.86
44402,9480,485,8180,217.138052,177,5248.0,3970.0,4956.240234,5415.0,4702.0,...,206.0,206.0,260949.0,3.374529,11.116528,0.0,7.0,111.0,0.08893,1.0
120431,253,234,11615,318.962097,309,99760.0,60975.0,70772.320312,74953.0,69202.0,...,121.0,121.0,257721.0,1.436238,1.451099,6.0,47.0,121.0,0.095844,0.98
12314,1096,93,3810,103.178978,242,40509.0,30693.0,41181.019531,44167.0,39512.0,...,125.0,125.0,258295.0,1.604243,1.554233,2.0,37.0,125.0,0.091076,0.98
117015,630,187,9610,258.877039,268,85451.0,55664.0,71444.257812,76161.0,69020.0,...,128.0,128.0,258432.0,1.477446,1.293972,4.0,45.0,113.0,0.097036,0.86
71426,8893,451,14325,321.671963,207,43948.0,31908.0,43908.558594,49109.0,42698.0,...,136.0,136.0,258572.0,1.8125,2.287206,0.0,34.0,136.0,0.056028,1.0
69875,8180,300,8720,234.745026,273,40096.0,22423.0,28330.779297,30477.0,27489.0,...,182.0,182.0,258762.0,2.578554,6.937395,0.0,29.0,95.0,0.090837,1.0
127178,3376,216,5210,141.470194,185,20926.0,13026.0,16736.679688,18228.0,16136.0,...,148.0,148.0,259921.0,2.243611,4.005967,1.0,27.0,148.0,0.09203,1.0
99311,5782,29,1245,36.614895,200,47796.0,31868.0,37836.761719,41553.0,36325.0,...,134.0,134.0,259419.0,1.776762,2.221609,3.0,32.0,134.0,0.09799,0.76


In [7]:
# Summary statistics (count, mean, std, min, quartiles, max) for every loaded
# numeric column; the quartiles are spelled out rather than left implicit.
all_queries_data.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,qid,step,dists,elaps_ms,inserts,first_nn_dist,nn_dist,avg_dist,furthest_dist,percentile_25,...,dim_max,dim_range,dim_energy,dim_skewness,dim_kurtosis,dim_perc_25,dim_perc_75,dim_perc_95,feats_collect_time_ms,r
count,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,...,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0,31996160.0
mean,4964.227,234.5386,8810.88,239.5019,235.2847,63658.77,40685.68,52771.59,56956.05,50914.49,...,130.1244,130.1213,258658.2,1.66567,1.973213,2.678172,37.28851,120.9748,0.0932484,0.9206981
std,2974.072,146.0507,5693.646,153.1011,46.2511,24348.73,15317.18,17727.82,18787.97,17309.84,...,14.99027,14.99319,663.7837,0.3784789,1.570998,2.658499,8.612538,11.81026,0.1021507,0.1705124
min,0.0,0.0,50.0,0.247955,50.0,1213.0,775.0,1443.9,1625.0,1323.0,...,105.0,105.0,256682.0,0.81953,-0.484514,0.0,3.0,76.0,0.050783,0.0
25%,2276.0,105.0,4045.0,112.1721,205.0,47526.0,30287.0,41743.32,45507.0,40032.0,...,120.0,120.0,258177.0,1.407713,0.971851,1.0,32.0,113.0,0.09203,0.94
50%,5101.0,227.0,8155.0,222.6391,235.0,64898.0,41638.0,54845.62,59195.0,52872.0,...,127.0,127.0,258619.0,1.602661,1.573328,2.0,38.0,120.0,0.093937,1.0
75%,7603.0,360.0,12940.0,349.2939,266.0,79825.0,51698.0,65415.44,70146.0,63336.0,...,137.0,137.0,259110.0,1.849003,2.49886,4.0,43.0,128.0,0.096083,1.0
max,9999.0,523.0,27005.0,971.365,486.0,165738.0,162061.0,208576.2,271922.0,188834.0,...,222.0,222.0,261515.0,3.874847,14.72042,21.0,58.0,177.0,50.46582,1.0


In [11]:
# Mean of the `feats_collect_time_ms` column over all loaded rows
# (milliseconds, going by the column name).
feats_collect_time = all_queries_data["feats_collect_time_ms"]
feats_collect_time.mean()

0.09324840195345321