In [1]:
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform
import numba
import argparse
import time
from numba import jit
from tqdm import tqdm

In [2]:
data = pd.read_parquet(f'../../data/features/MASH_remove_21.parquet', engine='pyarrow')  # You can use 'fastparquet' as the engine

# Keep the two bookkeeping columns as Series before they are dropped from the matrices below.
targets, train = data["Target"], data["Train"]

# The comparison ("pair") set is every row whose Train flag is 0.
pair_data = data[data["Train"] == 0]

# Feature matrices: every column except the bookkeeping ones, as float32.
# NOTE(review): if these columns hold integer hash values, float32 keeps only a 24-bit
# mantissa, so distinct large hashes can collapse to the same value and inflate the
# Jaccard estimate — confirm the source dtype.
pair_data = (
    pair_data.drop(columns=["Train", "Target"])
    .to_numpy()
    .astype(np.float32)
)
data = (
    data.drop(columns=["Train", "Target"])
    .to_numpy()
    .astype(np.float32)
)

In [3]:
@jit(nopython=True)
def calculate_jaccard_index(sketch1, sketch2, sketch_size=0):
    """
    Efficiently estimate the Jaccard index between two sorted MinHash sketches.

    The sketches are walked together like the merge step of merge sort. A value
    present in both advances both cursors and counts once towards the
    intersection. Every merge step counts once towards the union. This assumes
    each sketch is sorted ascending and contains no duplicate values (not
    checked here).

    :param sketch1: First sorted MinHash sketch as a numpy array.
    :param sketch2: Second sorted MinHash sketch as a numpy array.
    :param sketch_size: Optional cap on the union size.
        When > 0, the merge stops once the union reaches ``sketch_size`` hashes,
        and the union is capped at that size. The estimate then uses only the
        bottom ``sketch_size`` hashes of the combined sketch (the Mash estimator).
        The default 0 means no cap, i.e. intersection / full union.
    :return: The estimated Jaccard index (0 when both sketches are empty).
    """
    shared_hashes = 0  # Intersection
    total_hashes = 0   # Union
    len1 = len(sketch1)
    len2 = len(sketch2)
    capped = sketch_size > 0
    i, j = 0, 0

    while i < len1 and j < len2 and (not capped or total_hashes < sketch_size):
        if sketch1[i] == sketch2[j]:
            shared_hashes += 1
            i += 1
            j += 1
        elif sketch1[i] < sketch2[j]:
            i += 1
        else:
            j += 1
        total_hashes += 1

    # Hashes left in the sketch that was not exhausted belong to the union only.
    # With a cap, they are counted only while the union is still below the cap.
    if not capped or total_hashes < sketch_size:
        total_hashes += (len1 - i) + (len2 - j)
        if capped and total_hashes > sketch_size:
            total_hashes = sketch_size

    jaccard_index = shared_hashes / total_hashes if total_hashes > 0 else 0

    return jaccard_index

In [4]:
def calculate_mash_distance(sketch1, sketch2, k):
    """
    Calculate the Mash distance between two MinHash sketches.

    Uses D = -(1/k) * ln(2j / (1 + j)), where j is the Jaccard estimate
    returned by ``calculate_jaccard_index``.

    :param sketch1: First MinHash sketch as a numpy array.
    :param sketch2: Second MinHash sketch as a numpy array.
    :param k: k-mer size used to create the MinHash sketches (must be non-zero).
    :return: The Mash distance (0 when the Jaccard estimate is 1).
    """
    jaccard_estimate = calculate_jaccard_index(sketch1, sketch2)

    # Clamp to guard against log(0). Disjoint sketches (j == 0) therefore get a
    # large finite distance of about 22.3 / k rather than an infinite one.
    jaccard_estimate = max(jaccard_estimate, 1e-10)

    return -(1 / k) * np.log((2 * jaccard_estimate) / (1 + jaccard_estimate))

# Example usage: the distance between the first two sketches of the comparison set.
sketch1, sketch2 = pair_data[0], pair_data[1]
k = 21  # k-mer size; assumed to match how the sketches were built (TODO confirm)

distance = calculate_mash_distance(sketch1, sketch2, k)
print(f"The Mash distance is: {distance}")

The Mash distance is: 0.002392438877940318


In [17]:
import dask.array as da
from dask import delayed, compute
import numpy as np
from dask.diagnostics import ProgressBar

def calculate_distances_for_sketch(sketch1, pair_data, k):
    """
    Compute the Mash distance from one sketch to every sketch in ``pair_data``.

    :param sketch1: The MinHash sketch on the "row" side of the comparison.
    :param pair_data: Iterable of MinHash sketches to compare ``sketch1`` against.
    :param k: k-mer size used to create the MinHash sketches.
    :return: A list of Mash distances, in the same order as ``pair_data``.
    """
    distances = []
    for candidate in pair_data:
        distances.append(calculate_mash_distance(sketch1, candidate, k))
    return distances

def calculate_distance_matrix(data, pair_data, k, scheduler='processes'):
    """
    Calculate a matrix of Mash distances between two arrays of sketches.

    Uses Dask for parallel computation, with one delayed task per row of ``data``.

    :param data: First array of MinHash sketches (one sketch per row).
    :param pair_data: Second array of MinHash sketches to which to compare the first array.
    :param k: k-mer size used to create the MinHash sketches.
    :param scheduler: Dask scheduler passed to ``dask.compute``. Defaults to
        'processes'; 'sync' is handy for debugging.
    :return: An already-computed NumPy array of shape (len(data), len(pair_data)).
        Entry [r, c] is the Mash distance between data[r] and pair_data[c].
        Do not call ``.compute()`` on it.
    """
    n_pairs = len(pair_data)
    if len(data) == 0:
        # np.vstack raises on an empty sequence, so return an empty matrix of the right width.
        return np.empty((0, n_pairs))

    # Each task compares one row of `data` against all of pair_data.
    # NOTE(review): pair_data is embedded in every task, so with the 'processes'
    # scheduler it is serialized once per row. Batching rows per task would cut that overhead.
    delayed_rows = [
        delayed(calculate_distances_for_sketch)(sketch1, pair_data, k)
        for sketch1 in data
    ]

    # The context manager alone attaches the progress bar for this computation only.
    # A separate .register() call is unnecessary.
    with ProgressBar():
        rows = compute(*delayed_rows, scheduler=scheduler)

    return np.vstack(rows)

# Full (all rows x comparison set) distance matrix, returned as the NumPy array built by np.vstack.
distance_matrix = calculate_distance_matrix(data=data, pair_data=pair_data, k=k)

[########################################] | 100% Completed | 16m 47s


In [19]:
# calculate_distance_matrix returns the result of np.vstack, i.e. an already-computed
# NumPy array. Calling .compute() on it raises AttributeError.
# np.asarray is a no-op for an ndarray and would still materialize a Dask array.
distance_matrix_np = np.asarray(distance_matrix)

# Now, create a pandas DataFrame from the NumPy array
distance_df = pd.DataFrame(distance_matrix_np)

In [20]:
# distance_df was already built from distance_matrix_np in the previous cell. Rebuilding it
# here was redundant and, on a re-run, silently dropped the Target/Train columns.
# Instead, check for one row per input sketch and one column per comparison sketch.
assert distance_df.shape == (len(targets), len(pair_data)), distance_df.shape

In [22]:
# Attach the bookkeeping columns positionally. Assigning a plain list (via .tolist())
# bypasses pandas' index alignment with the parquet index.
for column_name, labels in {"Target": targets, "Train": train}.items():
    distance_df[column_name] = labels.tolist()

In [23]:
# distance_df mixes integer distance-column names (0..n-1, from pd.DataFrame(ndarray))
# with the string "Target"/"Train" names. pyarrow warns about mixed-type column names
# and stringifies them anyway. Stringify explicitly, on a copy so distance_df itself is
# left unchanged.
distance_df.rename(columns=str).to_parquet('../../data/features/mash_distance.parquet', engine='pyarrow')

  table = self.api.Table.from_pandas(df, **from_pandas_kwargs)
