In [77]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

def generate_df(join_key_domain, max_rows_per_key,
                num_attributes, min_rows_per_key=0, random_seed=None, prefix='V'):
    if random_seed is not None:
        np.random.seed(random_seed)

    rows_per_key = np.random.randint(
        min_rows_per_key, max_rows_per_key + 1, size=len(join_key_domain))

    all_keys = []
    all_attributes = []

    for key, num_rows in zip(join_key_domain, rows_per_key):
        if num_rows > 0:
            all_keys.extend([key] * num_rows)
            attributes = np.random.randn(num_rows, num_attributes)
            all_attributes.append(attributes)

    if all_keys:
        attributes_array = np.vstack(all_attributes)
        df = pd.DataFrame(attributes_array, columns=[
                          f'{prefix}{i+1}' for i in range(num_attributes)])
        df.insert(0, 'join_key', all_keys)
        return df, [f'{prefix}{i+1}' for i in range(num_attributes)]
    else:
        columns = ['join_key'] + \
            [f'{prefix}{i+1}' for i in range(num_attributes)]
        return pd.DataFrame(columns=columns), [
            f'{prefix}{i+1}' for i in range(num_attributes)]


def find_k_closest_points(df, k=3):
    features = df[['V1', 'F1']].values
    distances = cdist(features, features, metric='chebyshev')

    results = {}
    for i in range(len(df)):
        # Get the distances from point i to all other points
        dist_to_i = distances[i]
        
        # Get indices of k+1 closest points (including self)
        closest_indices = np.argsort(dist_to_i)[:k+1]
        
        # Remove self (which should be at index 0 with distance 0)
        if closest_indices[0] == i:
            closest_indices = closest_indices[1:k+1]
        else:
            closest_indices = closest_indices[:k]
        
        # Get the corresponding rows and distances
        closest_rows = df.iloc[closest_indices]
        closest_distances = dist_to_i[closest_indices]
        
        # Store the results
        results[i] = {
            'point': df.iloc[i].to_dict(),
            'closest_points': [
                {
                    'join_key': df.iloc[idx]['join_key'],
                    'V1': df.iloc[idx]['V1'],
                    'F1': df.iloc[idx]['F1'],
                    'distance': dist_to_i[idx]
                }
                for idx in closest_indices
            ]
        }
    
    return results


def compute_pairwise_distances(df):
    key_to_f1 = dict(zip(df['join_key'], df['F1']))
    distances = {}
    for i in key_to_f1:
        for j in key_to_f1:
            distance = abs(key_to_f1[i] - key_to_f1[j])
            distances[(i, j)] = distance
    
    return distances


def compute_custom_distance(x_min, x_max, y_min, y_max):
    # Define the y midpoint
    y_mid = (y_min + y_max) / 2

    if x_min < y_min:
        part1 = abs(y_min - x_min)
    else:
        part1 = 0

    if x_min <= y_max and x_max >= y_min:  # Check if there's overlap
        # Find the closest point to y_mid within [x_min, x_max]
        if x_min <= y_mid <= x_max:
            closest_to_mid = y_mid
        elif y_mid < x_min:
            closest_to_mid = x_min
        else:  # y_mid > x_max
            closest_to_mid = x_max

        if closest_to_mid < y_mid:
            part2 = abs(closest_to_mid - y_min)
        else:
            part2 = abs(y_max - closest_to_mid)
    else:
        part2 = 0

    if x_max > y_max:
        part3 = abs(x_max - y_max)
    else:
        part3 = 0

    return max(part1, part2, part3)

In [78]:
domain = np.arange(20)
max_rows_per_key = 30
df1, _ = generate_df(domain, max_rows_per_key, 1)
df2, _ = generate_df(domain, 1, 1, min_rows_per_key=1, prefix='F')
df1_min_max = df1.groupby('join_key')['V1'].agg(['min', 'max'])
pairwise_dist = compute_pairwise_distances(df2)

In [79]:
distances = {}
min_distances = {}
join_keys = df1_min_max.index.tolist()
for i in join_keys:
    for j in join_keys:
        x_min, x_max = df1_min_max.loc[i, 'min'], df1_min_max.loc[i, 'max']
        y_min, y_max = df1_min_max.loc[j, 'min'], df1_min_max.loc[j, 'max']
        distance = compute_custom_distance(x_min, x_max, y_min, y_max)
        if x_max <= y_min:
            min_distances[(i, j)] = y_min - x_max
        elif y_max <= x_min:
            min_distances[(i, j)] = x_min - y_max
        else:
            min_distances[(i, j)] = 0
        distances[(i, j)] = distance

In [80]:
max_dist = {}
min_dist = {}
for key in distances.keys():
    max_dist[key] = distances[key] + pairwise_dist[key]
    min_dist[key] = pairwise_dist[key] + min_distances[key]

In [81]:
for key, val in min_dist.items():
    if key[0] == 1:
        print(f"{key}: {val}")

In [82]:
merge_df = df1.merge(df2, on='join_key')

In [83]:
df1

Unnamed: 0,join_key,V1
0,0,1.346975
1,0,1.304603
2,0,0.990835
3,0,-1.261790
4,0,0.210198
...,...,...
291,19,0.159372
292,19,0.581100
293,19,0.589475
294,19,-1.353387


In [84]:
def compute_histogram_by_join_key(df):
    unique_join_keys = df['join_key'].unique()
    histograms = {}
    for key in unique_join_keys:
        subset = df[df['join_key'] == key]
        counts = np.bincount(subset['V1_ind'])
        index_count_pairs = [
            (index, count) for index, count in enumerate(counts) if count > 0]
        histograms[key] = index_count_pairs
    return histograms

In [85]:
x_width = 3.49 * (len(df1)**(-1/3))
# might need to adjust y_width
y_width = 3.49 * (len(df2)**(-1/3))
x_data = df1['V1'].values
x_data = (x_data - np.mean(x_data)) / np.std(x_data, ddof=1)
x_inds = ((x_data - np.min(x_data)) // x_width).astype(int)
df1['V1_ind'] = x_inds

In [86]:
merge_df = df1.merge(df2, on='join_key')
merged_y = merge_df['F1'].values
y_mean, y_std = np.mean(merged_y), np.std(merged_y, ddof=1)
y_data = df2['F1'].values
y_data = (y_data - y_mean) / y_std
y_inds = ((y_data - np.min(y_data)) // x_width).astype(int)
merged_y_std = (merged_y - y_mean) / y_std
merge_df['F1_ind'] = (merged_y_std - np.min(merged_y_std) // x_width).astype(int)

In [87]:
df1['V1_std'] = x_data
buyer_data_std = df1[['join_key', 'V1_std']]
buyer_data_std

Unnamed: 0,join_key,V1_std
0,0,1.212680
1,0,1.169741
2,0,0.851767
3,0,-1.431043
4,0,0.060670
...,...,...
291,19,0.009163
292,19,0.436543
293,19,0.445029
294,19,-1.523868


In [88]:
marg_hist_x = compute_histogram_by_join_key(df1)
marg_hist_y = {i: (ele, 1) for i, ele in enumerate(y_inds)}

In [89]:
merge_df

Unnamed: 0,join_key,V1,V1_ind,F1,F1_ind
0,0,1.346975,8,0.057519,4
1,0,1.304603,7,0.057519,4
2,0,0.990835,7,0.057519,4
3,0,-1.261790,3,0.057519,4
4,0,0.210198,5,0.057519,4
...,...,...,...,...,...
291,19,0.159372,5,-0.106661,4
292,19,0.581100,6,-0.106661,4
293,19,0.589475,6,-0.106661,4
294,19,-1.353387,2,-0.106661,4


In [90]:
merge_df[(merge_df['V1_ind'] == 2) & (merge_df['F1_ind'] == 4)]

Unnamed: 0,join_key,V1,V1_ind,F1,F1_ind
12,2,-1.298469,2,0.392107,4
17,2,-1.576322,2,0.392107,4
204,12,-1.504519,2,0.101605,4
294,19,-1.353387,2,-0.106661,4


In [91]:
joint_hist = {}
for key, hist in marg_hist_x.items():
    y_hist = marg_hist_y[key]
    for bin_ind, count in hist:
        if (bin_ind, y_hist[0]) not in joint_hist:
            joint_hist[(bin_ind, y_hist[0])] = count
        else:
            joint_hist[(bin_ind, y_hist[0])] += count
joint_hist

{(3, 4): 5,
 (4, 4): 8,
 (5, 4): 17,
 (6, 4): 12,
 (7, 4): 11,
 (8, 4): 8,
 (1, 4): 1,
 (2, 4): 3,
 (3, 8): 2,
 (4, 8): 6,
 (5, 8): 4,
 (6, 8): 4,
 (7, 8): 5,
 (9, 8): 1,
 (11, 8): 1,
 (2, 2): 7,
 (3, 2): 9,
 (4, 2): 17,
 (5, 2): 16,
 (6, 2): 12,
 (7, 2): 6,
 (8, 2): 11,
 (9, 2): 2,
 (3, 0): 3,
 (4, 0): 3,
 (5, 0): 6,
 (6, 0): 7,
 (7, 0): 1,
 (8, 0): 2,
 (2, 0): 1,
 (10, 0): 1,
 (9, 4): 3,
 (11, 4): 1,
 (2, 1): 4,
 (3, 1): 11,
 (4, 1): 8,
 (5, 1): 15,
 (6, 1): 11,
 (7, 1): 9,
 (8, 1): 4,
 (2, 3): 2,
 (3, 3): 1,
 (4, 3): 3,
 (5, 3): 5,
 (6, 3): 6,
 (7, 3): 4,
 (8, 3): 1,
 (10, 3): 1,
 (1, 1): 1,
 (9, 1): 1,
 (1, 2): 2,
 (3, 5): 1,
 (4, 5): 4,
 (5, 5): 1,
 (6, 5): 1,
 (7, 5): 1,
 (9, 5): 2,
 (0, 2): 1}

In [92]:
buyer_data_std[buyer_data_std['join_key'] == 0]

Unnamed: 0,join_key,V1_std
0,0,1.21268
1,0,1.169741
2,0,0.851767
3,0,-1.431043
4,0,0.06067
5,0,-0.740841
6,0,0.351733
7,0,-0.234329


In [93]:
buyer_data_std[buyer_data_std['join_key'] == 1]

Unnamed: 0,join_key,V1_std


In [94]:
import bisect
import math
import numpy as np
import concurrent.futures


class SortedJoinKeyStructure:
    def __init__(self, df, col, k=3):
        self.data = df
        self.treat = col
        self.k = k

    def batch_find_k_closest(self, sorted_array, query_values, exclude_self=False, threshold=None):
        query_values = np.expand_dims(query_values, 1)
        sorted_array = np.expand_dims(sorted_array, 0)
        
        diffs = np.abs(query_values - sorted_array)
        if exclude_self:
            k = self.k + 1
        else:
            k = self.k
        k = min(k, sorted_array.shape[1])
        if k == 0:
            return np.zeros((len(query_values), 0))

        indices = np.argsort(diffs, axis=1)[:, :k]
        selected_diffs = np.take_along_axis(diffs, indices, axis=1)

        if exclude_self:
            selected_diffs = selected_diffs[:, 1:k+1]
        
        if threshold is not None:
            mask = selected_diffs >= threshold.reshape(-1, 1)
            selected_diffs[mask] = np.nan
        return selected_diffs
    
    def sort_join_key(self, join_key):
        grouped = self.data.groupby(join_key)[self.treat].apply(np.array)
        join_keys = list(grouped.index)
        partitioned_data = list(grouped.values)
        
        with concurrent.futures.ThreadPoolExecutor() as executor:
            sorted_arrays = list(executor.map(np.sort, partitioned_data))
        
        result_dict = {key: sorted_array for key, sorted_array in zip(join_keys, sorted_arrays)}
        
        all_results = {}
        nn_jk_inds = {}
        thresholds = {}
        
        for key, values in grouped.items():
            if len(result_dict[key]) >= self.k + 1:
                distances = self.batch_find_k_closest(result_dict[key], values, exclude_self=True)
                thresholds[key] = distances[:, -1]
            else:
                thresholds[key] = None
        
        for key, values in grouped.items():
            for target_key, sorted_array in result_dict.items():
                if key == target_key:
                    distances = self.batch_find_k_closest(sorted_array, values, exclude_self=True)
                else:
                    distances = self.batch_find_k_closest(sorted_array, values, threshold=thresholds[key])
                if key in all_results:
                    nn_jk_inds[key] = np.concatenate([nn_jk_inds[key], np.full(distances.shape[1], target_key)])
                    all_results[key] = np.concatenate([all_results[key], distances], axis=1)
                else:
                    nn_jk_inds[key] = np.full(distances.shape[1], target_key)
                    all_results[key] = distances
                    
        print('------------------------------Micro Benchmark------------------------------')
        # Count non-NaN values
        total_non_nan = 0
        for key, distances in all_results.items():
            non_nan_count = np.count_nonzero(~np.isnan(distances))
            total_non_nan += non_nan_count
        
        print(f"Total non-NaN values: {total_non_nan}")

        min_vals, max_vals = math.inf, -math.inf
        for _, jks in nn_jk_inds.items():
            if len(jks) < min_vals:
                min_vals = len(jks)
            elif len(jks) > max_vals:
                max_vals = len(jks)

        print(f"Min cols in extended treat: {min_vals}, Max: {max_vals}")
        print('------------------------------Micro Benchmark------------------------------')

        return all_results, nn_jk_inds


ss = SortedJoinKeyStructure(buyer_data_std, 'V1_std')
all_results, nn_jk_inds = ss.sort_join_key('join_key')

------------------------------Micro Benchmark------------------------------
Total non-NaN values: 9053
Min cols in extended treat: 51, Max: 51
------------------------------Micro Benchmark------------------------------


In [95]:
# TODO: we have thresholds and nans representing the values 
# that are guaranteed to be pruned out from the top k values, how to leverage that to optimize runtime?
# Do this last

# TODO: all_results so far is a dictionary of the format: join key i: array of shape n_i * m_i where n_i is the
# number of tuples at join key i and m_i is the sum of top-k closest values over all join keys.
# Use this all_results to derive the radius array, which has size n * m.

# TODO: next we use this radiu vector to query the treatment vector
# How we want to do that is we compute [treatment - radius, treatment + radius], 
# then sort the values in treatment vector, so you can 
# 1. find the index (upper_ind) querying each treatment + radius in the sorted treatment vector
# 2. find the index (lower_ind) querying each treatment - radius in the sorted treatment vector
# 3. the number upper_ind - lower_ind + 1 is n_x in computing mutual information
# 4. do something similar for n_y

In [96]:
all_results

{0: array([[4.29395179e-02, 3.60912595e-01, 8.60947073e-01, 3.93377992e-01,
         6.56766185e-01, 6.70875070e-01, 2.33963122e-01, 3.46074381e-01,
         4.59177762e-01, 2.03587380e-02, 3.53277539e-02, 1.42486946e-01,
         2.03798085e-01, 4.16296423e-01,            nan, 1.87225608e-01,
         2.26531710e-01, 4.23274058e-01, 7.59017562e-01,            nan,
                    nan, 8.17532316e-03, 1.97118839e-01, 3.89021839e-01,
         7.84759717e-02, 9.72907770e-02, 1.96579361e-01, 2.59559491e-01,
         3.77475070e-01, 3.95173271e-01, 2.08712449e-02, 1.20718302e-01,
         1.59953854e-01, 5.85362393e-02, 9.65991985e-02, 1.15558640e-01,
         5.50534509e-03, 8.43874402e-02, 3.25929319e-01, 3.27445078e-01,
         5.43010409e-01, 7.19745599e-01, 5.17572267e-01,            nan,
                    nan, 5.79104101e-02, 7.31497290e-02, 1.48554789e-01,
         1.70420517e-01, 3.51599292e-01, 7.67650959e-01],
        [4.29395179e-02, 3.17973077e-01, 8.18007555e-01, 4.3631

In [97]:
nn_jk_inds[0]

array([ 0,  0,  0,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  6,  6,
        6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11, 12,
       12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 19, 19, 19])

In [98]:
def get_radius(all_results,k):
    n = 0
    for _,vals in all_results.items():
        n += vals.shape[0]
    radius = np.zeros((n,)) 

    a = 0
    for jk,dists in all_results.items():
        for i in range(dists.shape[0]):
            radius[a] = np.sort(dists[i,:])[k-1]
            a += 1
    return radius

radius = get_radius(all_results,3)
radius.shape

(296,)

In [99]:
def get_n_y(radius, Y):
    lower = Y - radius
    upper = Y + radius
    
    Y_sorted = np.sort(Y)
    lower_ind = np.searchsorted(Y_sorted, lower)
    upper_ind = np.searchsorted(Y_sorted, upper)
    return upper_ind - lower_ind + 1

def get_n_x(radius, X):
    lower = X - radius
    upper = X + radius
    
    X_sorted = np.sort(X)
    lower_ind = np.searchsorted(X_sorted, lower)
    upper_ind = np.searchsorted(X_sorted, upper)
    return upper_ind - lower_ind + 1

    

In [100]:
print(f'n_y:{get_n_y(radius,merged_y_std)} \n')
print(f'n_x:{get_n_x(radius,x_data)}')

n_y:[  9   9   9   9   9   9   9   9  17  17  17  17  17  17  17  17  17  17
  17  17  17  17  17  17  24  24  24  24  24  24  24  24  24  24  24  24
  24  24  24  24  24  24  24  24  24  24  24  30  50  30  30  30  30  63
  50  50  50  50  30  30  50  30  50  30  50  30  30  30  30  30  30  30
  30  30  30  50   6   6   6   6   6  20  20  20  20  20  20  20  20  20
  20  20  20  20  20  20  20  20  20  20   6   6   6   6   6  20  20  20
  20  46  20 152  20  20  20  20  20  20  20  20  20  20  20  62  29  29
  29  29  29  29  29  29  29  29  29  29  29  29  29  29  29  29  29  29
  29  29  29  29  29  29  29  29  14  14  14  14  14  14  14  14  14  14
  14  14  14  24  24  24  24  24  24  24  24  24  24  24  89  24  24  24
  24  24  24  24  24  24  24  24  27  27  27  27  27  27  35  27  27  27
  27  27  27  27  35  35  27  27  27  27  27  27  27  27  27  27  30  30
  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30  30
  30  30  30  30  30  30  30  30  30  11  11  1

In [101]:
# ADD IMPROVEMENTS AND GROUND TESTING!