In [28]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

def generate_df(join_key_domain, max_rows_per_key,
                num_attributes, min_rows_per_key=0, random_seed=None, prefix='V'):
    if random_seed is not None:
        np.random.seed(random_seed)

    rows_per_key = np.random.randint(
        min_rows_per_key, max_rows_per_key + 1, size=len(join_key_domain))

    all_keys = []
    all_attributes = []

    for key, num_rows in zip(join_key_domain, rows_per_key):
        if num_rows > 0:
            all_keys.extend([key] * num_rows)
            attributes = np.random.randn(num_rows, num_attributes)
            all_attributes.append(attributes)

    if all_keys:
        attributes_array = np.vstack(all_attributes)
        df = pd.DataFrame(attributes_array, columns=[
                          f'{prefix}{i+1}' for i in range(num_attributes)])
        df.insert(0, 'join_key', all_keys)
        return df, [f'{prefix}{i+1}' for i in range(num_attributes)]
    else:
        columns = ['join_key'] + \
            [f'{prefix}{i+1}' for i in range(num_attributes)]
        return pd.DataFrame(columns=columns), [
            f'{prefix}{i+1}' for i in range(num_attributes)]


def find_k_closest_points(df, k=3):
    features = df[['V1', 'F1']].values
    distances = cdist(features, features, metric='chebyshev')

    results = {}
    for i in range(len(df)):
        # Get the distances from point i to all other points
        dist_to_i = distances[i]
        
        # Get indices of k+1 closest points (including self)
        closest_indices = np.argsort(dist_to_i)[:k+1]
        
        # Remove self (which should be at index 0 with distance 0)
        if closest_indices[0] == i:
            closest_indices = closest_indices[1:k+1]
        else:
            closest_indices = closest_indices[:k]
        
        # Get the corresponding rows and distances
        closest_rows = df.iloc[closest_indices]
        closest_distances = dist_to_i[closest_indices]
        
        # Store the results
        results[i] = {
            'point': df.iloc[i].to_dict(),
            'closest_points': [
                {
                    'join_key': df.iloc[idx]['join_key'],
                    'V1': df.iloc[idx]['V1'],
                    'F1': df.iloc[idx]['F1'],
                    'distance': dist_to_i[idx]
                }
                for idx in closest_indices
            ]
        }
    
    return results


def compute_pairwise_distances(df):
    key_to_f1 = dict(zip(df['join_key'], df['F1']))
    distances = {}
    for i in key_to_f1:
        for j in key_to_f1:
            distance = abs(key_to_f1[i] - key_to_f1[j])
            distances[(i, j)] = distance
    
    return distances


def compute_custom_distance(x_min, x_max, y_min, y_max):
    # Define the y midpoint
    y_mid = (y_min + y_max) / 2

    if x_min < y_min:
        part1 = abs(y_min - x_min)
    else:
        part1 = 0

    if x_min <= y_max and x_max >= y_min:  # Check if there's overlap
        # Find the closest point to y_mid within [x_min, x_max]
        if x_min <= y_mid <= x_max:
            closest_to_mid = y_mid
        elif y_mid < x_min:
            closest_to_mid = x_min
        else:  # y_mid > x_max
            closest_to_mid = x_max

        if closest_to_mid < y_mid:
            part2 = abs(closest_to_mid - y_min)
        else:
            part2 = abs(y_max - closest_to_mid)
    else:
        part2 = 0

    if x_max > y_max:
        part3 = abs(x_max - y_max)
    else:
        part3 = 0

    return max(part1, part2, part3)

In [29]:
domain = np.arange(20)
max_rows_per_key = 30
df1, _ = generate_df(domain, max_rows_per_key, 1)
df2, _ = generate_df(domain, 1, 1, min_rows_per_key=1, prefix='F')
df1_min_max = df1.groupby('join_key')['V1'].agg(['min', 'max'])
pairwise_dist = compute_pairwise_distances(df2)

In [30]:
distances = {}
min_distances = {}
join_keys = df1_min_max.index.tolist()
for i in join_keys:
    for j in join_keys:
        x_min, x_max = df1_min_max.loc[i, 'min'], df1_min_max.loc[i, 'max']
        y_min, y_max = df1_min_max.loc[j, 'min'], df1_min_max.loc[j, 'max']
        distance = compute_custom_distance(x_min, x_max, y_min, y_max)
        if x_max <= y_min:
            min_distances[(i, j)] = y_min - x_max
        elif y_max <= x_min:
            min_distances[(i, j)] = x_min - y_max
        else:
            min_distances[(i, j)] = 0
        distances[(i, j)] = distance

In [31]:
max_dist = {}
min_dist = {}
for key in distances.keys():
    max_dist[key] = distances[key] + pairwise_dist[key]
    min_dist[key] = pairwise_dist[key] + min_distances[key]

In [32]:
for key, val in min_dist.items():
    if key[0] == 1:
        print(f"{key}: {val}")

(1, 0): 2.224342641451588
(1, 1): 0.0
(1, 2): 2.8781205609311398
(1, 3): 1.561203958092374
(1, 4): 0.3967435898627698
(1, 5): 1.4483967917736378
(1, 6): 2.076859722652179
(1, 7): 2.0071385071147727
(1, 8): 1.8865871151618894
(1, 9): 1.5303247982521762
(1, 10): 0.9820910311095141
(1, 11): 2.841897525681551
(1, 12): 2.500087493165169
(1, 13): 3.3133795486964583
(1, 14): 2.966223343365937
(1, 15): 2.091276896833634
(1, 16): 1.9036916867931641
(1, 17): 1.6805177470145893
(1, 18): 1.5301489960526127
(1, 19): 1.6060839609441244


In [33]:
merge_df = df1.merge(df2, on='join_key')

In [34]:
df1

Unnamed: 0,join_key,V1
0,0,-1.535231
1,0,-0.350493
2,0,-0.695557
3,0,-0.974278
4,0,-0.991246
...,...,...
263,19,0.373415
264,19,-0.967687
265,19,0.589122
266,19,-1.747895


In [35]:
def compute_histogram_by_join_key(df):
    unique_join_keys = df['join_key'].unique()
    histograms = {}
    for key in unique_join_keys:
        subset = df[df['join_key'] == key]
        counts = np.bincount(subset['V1_ind'])
        index_count_pairs = [
            (index, count) for index, count in enumerate(counts) if count > 0]
        histograms[key] = index_count_pairs
    return histograms

In [36]:
x_width = 3.49 * (len(df1)**(-1/3))
# might need to adjust y_width
y_width = 3.49 * (len(df2)**(-1/3))
x_data = df1['V1'].values
x_data = (x_data - np.mean(x_data)) / np.std(x_data, ddof=1)
x_inds = ((x_data - np.min(x_data)) // x_width).astype(int)
df1['V1_ind'] = x_inds

In [37]:
merge_df = df1.merge(df2, on='join_key')
merged_y = merge_df['F1'].values
y_mean, y_std = np.mean(merged_y), np.std(merged_y, ddof=1)
y_data = df2['F1'].values
y_data = (y_data - y_mean) / y_std
y_inds = ((y_data - np.min(y_data)) // x_width).astype(int)
merged_y_std = (merged_y - y_mean) / y_std
merge_df['F1_ind'] = (merged_y_std - np.min(merged_y_std) // x_width).astype(int)

In [38]:
df1['V1_std'] = x_data
buyer_data_std = df1[['join_key', 'V1_std']]
buyer_data_std

Unnamed: 0,join_key,V1_std
0,0,-1.638656
1,0,-0.440980
2,0,-0.789812
3,0,-1.071577
4,0,-1.088731
...,...,...
263,19,0.290834
264,19,-1.064914
265,19,0.508896
266,19,-1.853643


In [39]:
marg_hist_x = compute_histogram_by_join_key(df1)
marg_hist_y = {i: (ele, 1) for i, ele in enumerate(y_inds)}

In [40]:
merge_df

Unnamed: 0,join_key,V1,V1_ind,F1,F1_ind
0,0,-1.535231,1,0.431854,4
1,0,-0.350493,3,0.431854,4
2,0,-0.695557,2,0.431854,4
3,0,-0.974278,2,0.431854,4
4,0,-0.991246,2,0.431854,4
...,...,...,...,...,...
263,19,0.373415,4,-0.186405,3
264,19,-0.967687,2,-0.186405,3
265,19,0.589122,5,-0.186405,3
266,19,-1.747895,0,-0.186405,3


In [41]:
merge_df[(merge_df['V1_ind'] == 2) & (merge_df['F1_ind'] == 4)]

Unnamed: 0,join_key,V1,V1_ind,F1,F1_ind
2,0,-0.695557,2,0.431854,4
3,0,-0.974278,2,0.431854,4
4,0,-0.991246,2,0.431854,4
59,6,-1.106113,2,0.284371,4
71,6,-0.795193,2,0.284371,4
86,8,-0.70348,2,0.094098,4
87,8,-0.842542,2,0.094098,4
195,15,-0.67858,2,0.298788,4
219,16,-0.904439,2,0.111203,4


In [42]:
joint_hist = {}
for key, hist in marg_hist_x.items():
    y_hist = marg_hist_y[key]
    for bin_ind, count in hist:
        if (bin_ind, y_hist[0]) not in joint_hist:
            joint_hist[(bin_ind, y_hist[0])] = count
        else:
            joint_hist[(bin_ind, y_hist[0])] += count
joint_hist

{(1, 4): 5,
 (2, 4): 6,
 (3, 4): 10,
 (1, 0): 4,
 (2, 0): 2,
 (3, 0): 9,
 (4, 0): 9,
 (5, 0): 4,
 (6, 0): 2,
 (7, 0): 1,
 (8, 0): 1,
 (1, 5): 3,
 (6, 5): 7,
 (0, 3): 6,
 (1, 3): 7,
 (2, 3): 15,
 (3, 3): 15,
 (4, 3): 20,
 (6, 3): 12,
 (7, 3): 4,
 (11, 0): 1,
 (3, 2): 1,
 (4, 4): 14,
 (5, 4): 9,
 (6, 4): 5,
 (7, 4): 2,
 (5, 3): 16,
 (0, 1): 1,
 (2, 1): 1,
 (3, 1): 3,
 (4, 1): 2,
 (5, 1): 2,
 (6, 1): 3,
 (7, 1): 1,
 (0, 5): 2,
 (2, 5): 6,
 (3, 5): 5,
 (4, 5): 11,
 (5, 5): 9,
 (7, 5): 4,
 (8, 5): 1,
 (1, 6): 1,
 (2, 6): 4,
 (3, 6): 6,
 (4, 6): 4,
 (5, 6): 6,
 (6, 6): 2,
 (8, 6): 1,
 (0, 4): 1,
 (8, 3): 2}

In [43]:
buyer_data_std[buyer_data_std['join_key'] == 0]

Unnamed: 0,join_key,V1_std
0,0,-1.638656
1,0,-0.44098
2,0,-0.789812
3,0,-1.071577
4,0,-1.088731
5,0,-0.454549


In [44]:
buyer_data_std[buyer_data_std['join_key'] == 1]

Unnamed: 0,join_key,V1_std
6,1,-1.465314
7,1,0.009587
8,1,-1.285032
9,1,-1.278044
10,1,-0.605588
11,1,-0.648766
12,1,0.070713
13,1,-0.254896
14,1,2.246071
15,1,-1.145493


In [45]:
import bisect
import math
import numpy as np
import concurrent.futures


class SortedJoinKeyStructure:
    def __init__(self, df, col, k=3):
        self.data = df
        self.treat = col
        self.k = k

    def batch_find_k_closest(self, sorted_array, query_values, exclude_self=False, threshold=None):
        query_values = np.expand_dims(query_values, 1)
        sorted_array = np.expand_dims(sorted_array, 0)
        
        diffs = np.abs(query_values - sorted_array)
        if exclude_self:
            k = self.k + 1
        else:
            k = self.k
        k = min(k, sorted_array.shape[1])
        if k == 0:
            return np.zeros((len(query_values), 0))

        indices = np.argsort(diffs, axis=1)[:, :k]
        selected_diffs = np.take_along_axis(diffs, indices, axis=1)

        if exclude_self:
            selected_diffs = selected_diffs[:, 1:k+1]
        
        if threshold is not None:
            mask = selected_diffs >= threshold.reshape(-1, 1)
            selected_diffs[mask] = np.nan
        return selected_diffs
    
    def sort_join_key(self, join_key):
        grouped = self.data.groupby(join_key)[self.treat].apply(np.array)
        join_keys = list(grouped.index)
        partitioned_data = list(grouped.values)
        
        with concurrent.futures.ThreadPoolExecutor() as executor:
            sorted_arrays = list(executor.map(np.sort, partitioned_data))
        
        result_dict = {key: sorted_array for key, sorted_array in zip(join_keys, sorted_arrays)}
        
        all_results = {}
        nn_jk_inds = {}
        thresholds = {}
        
        for key, values in grouped.items():
            if len(result_dict[key]) >= self.k + 1:
                distances = self.batch_find_k_closest(result_dict[key], values, exclude_self=True)
                thresholds[key] = distances[:, -1]
            else:
                thresholds[key] = None
        
        for key, values in grouped.items():
            for target_key, sorted_array in result_dict.items():
                if key == target_key:
                    distances = self.batch_find_k_closest(sorted_array, values, exclude_self=True)
                else:
                    distances = self.batch_find_k_closest(sorted_array, values, threshold=thresholds[key])
                if key in all_results:
                    nn_jk_inds[key] = np.concatenate([nn_jk_inds[key], np.full(distances.shape[1], target_key)])
                    all_results[key] = np.concatenate([all_results[key], distances], axis=1)
                else:
                    nn_jk_inds[key] = np.full(distances.shape[1], target_key)
                    all_results[key] = distances
                    
        print('------------------------------Micro Benchmark------------------------------')
        # Count non-NaN values
        total_non_nan = 0
        for key, distances in all_results.items():
            non_nan_count = np.count_nonzero(~np.isnan(distances))
            total_non_nan += non_nan_count
        
        print(f"Total non-NaN values: {total_non_nan}")

        min_vals, max_vals = math.inf, -math.inf
        for _, jks in nn_jk_inds.items():
            if len(jks) < min_vals:
                min_vals = len(jks)
            elif len(jks) > max_vals:
                max_vals = len(jks)

        print(f"Min cols in extended treat: {min_vals}, Max: {max_vals}")
        print('------------------------------Micro Benchmark------------------------------')

        return all_results, nn_jk_inds


ss = SortedJoinKeyStructure(buyer_data_std, 'V1_std')
all_results, nn_jk_inds = ss.sort_join_key('join_key')

------------------------------Micro Benchmark------------------------------
Total non-NaN values: 8034
Min cols in extended treat: 54, Max: 55
------------------------------Micro Benchmark------------------------------


In [46]:
# TODO: we have thresholds and nans representing the values 
# that are guaranteed to be pruned out from the top k values, how to leverage that to optimize runtime?
# Do this last

# TODO: all_results so far is a dictionary of the format: join key i: array of shape n_i * m_i where n_i is the
# number of tuples at join key i and m_i is the sum of top-k closest values over all join keys.
# Use this all_results to derive the radius array, which has size n * m.

# TODO: next we use this radiu vector to query the treatment vector
# How we want to do that is we compute [treatment - radius, treatment + radius], 
# then sort the values in treatment vector, so you can 
# 1. find the index (upper_ind) querying each treatment + radius in the sorted treatment vector
# 2. find the index (lower_ind) querying each treatment - radius in the sorted treatment vector
# 3. the number upper_ind - lower_ind + 1 is n_x in computing mutual information
# 4. do something similar for n_y

In [47]:
all_results

{0: array([[0.54992582, 0.56707904, 0.84884402, 0.04089139, 0.17334225,
         0.35362424, 0.35596577,        nan, 0.09771848, 0.31005548,
         0.54279816,        nan,        nan,        nan,        nan,
         0.04693924, 0.11651447, 0.25287441,        nan,        nan,
                nan, 0.13936941, 0.19287672, 0.35558957,        nan,
         0.5515648 ,        nan,        nan, 0.32782352, 0.37357211,
         0.56430737, 0.05324419,        nan,        nan, 0.35280646,
         0.47132881, 0.49725858, 0.58736212, 0.67505619, 0.68058909,
         0.348366  , 0.38598734,        nan, 0.63768125,        nan,
                nan, 0.06207579,        nan,        nan, 0.27005453,
         0.30259158, 0.42012272, 0.00264474, 0.17008178, 0.21498655],
        [0.01356909, 0.34883208, 0.63059707, 0.16460732, 0.1698015 ,
         0.17532103,        nan,        nan, 0.09587841, 0.32364101,
         0.33818606, 0.21642129, 0.34488631,        nan, 0.22418156,
         0.06203776, 0.0803170

In [48]:
nn_jk_inds[0]

array([ 0,  0,  0,  1,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  6,  6,
        6,  7,  7,  7,  8,  8,  8,  9, 10, 10, 10, 11, 11, 11, 12, 12, 12,
       13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18,
       18, 19, 19, 19])

In [49]:
import torch 

def get_radius(n, all_results, k):
    radius = np.zeros((n,))
    
    a = 0
    for jk, dists in all_results.items():
        dists_tensor = torch.from_numpy(dists)
        kth_values = torch.topk(dists_tensor, k=k, dim=1, largest=False)[0][:,-1]
        
        radius[a:a+len(kth_values)] = kth_values.numpy()
        a += len(kth_values)
        
    return radius


In [50]:
def get_radius_ground(n,all_results,k):
    radius = np.zeros((n,)) 

    a = 0
    for jk,dists in all_results.items():
        for i in range(dists.shape[0]):
            radius[a] = np.sort(dists[i,:])[k-1]
            a += 1
    return radius

radius = get_radius_ground(len(x_data), all_results,4)
radius_torch = get_radius(len(x_data), all_results, 4)
(radius == radius_torch).all()

True

In [51]:
def get_counts(radius, X):
    lower = X - radius
    upper = X + radius
    
    X_sorted = np.sort(X)
    lower_ind = np.searchsorted(X_sorted, lower)
    upper_ind = np.searchsorted(X_sorted, upper)
    return upper_ind - lower_ind + 1

def get_counts_ground(radius, X):
    n_x = np.zeros((len(radius),))
    for i in range(len(radius)):
        n_x[i] = np.sum(X <= radius[i])
    return n_x

get_counts_ground(radius, x_data) == get_counts(radius, x_data)

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,

In [52]:
print(f'n_y:{get_counts(radius,merged_y_std)} \n')
print(f'n_x:{get_counts(radius,x_data)}')

n_y:[  7   7   7   7   7   7  30  30  30  30  30  30  30  30  30  30  30  30
  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30   3
   3  18  18  37  18  18  67  19  67  18  63  63  18  18  18  18  18  18
   5   5   5 145   2  38  38  47  38  38  38  38  16  38  16  38  47  38
  38  38  10  10  10  10  10  10  10  10  10  28  33  33  28  33  28  33
  33  33  33  33  33  33  33  28  28  33  33  33  28  33  28  33  33  33
  33  33  37  14  14  14  14  14  14  14  14  14  14  14  14  14  33  31
  31  31  31  31  31  33  31  31  31  31  31  31  31  31  31  43  33  31
  31  33  33  31  31  43  33  31  31  31   7   7   7   7   7   7  25  25
  25  25  25  25  25  25  25  25  25  25  25  25  25  25  25  25  25  25
  25  25  25  25  11  11  11  11  11  11  11  11  43  11  23  38  38  23
  38  38  38  23  38  38  47  38  38  23  38  38  38  23  23  23  38  23
  82   6   6  33   6   4   4   4  64  20  37  20  67  20  37  37  20  20
  20  37  37  20  20  67  20  20  27  27  27  2

In [53]:
# ADD IMPROVEMENTS AND GROUND TESTING!