In [4]:
#!pip install datasets

In [8]:
import numpy as np
import re

In [9]:
SEED = 42 # is the answer
NON_ALPHA = re.compile("[^A-Za-z_0-9]")
RNG = np.random.RandomState(SEED)
MAX_HASH = np.uint64((1 << 32) - 1)
MERSENNE_PRIME = np.uint64((1 << 61) - 1)

In [10]:
# max uint32 value
MAX_HASH

4294967295

In [11]:
2**32

4294967296

In [12]:
## Compute the total number of documents from all the datasets

import json 
with open('groups.json') as fd:
    config = json.load(fd)
dataset_list = config['group_1']

In [13]:
#!pip install tqdm

In [14]:
from tqdm.auto import tqdm
import datasets
import os
dataset_description_collection = {}

basedir =  '/fsx/shared/pilev2_local_deduped'
total = 0
for dataset_name in tqdm(dataset_list, desc="loading datasets"):
    dataset_path = os.path.join(basedir, dataset_name)
    ds = datasets.load_from_disk(dataset_path, fs=None, keep_in_memory=None)
    dataset_description_collection[dataset_name] = ds
    total += len(ds)


  from .autonotebook import tqdm as notebook_tqdm
loading datasets: 100%|██████████| 11/11 [01:07<00:00,  6.13s/it]


In [25]:
total

37203245

In [26]:
total **2

1384081438530025

In [17]:
datasets.config.IN_MEMORY_MAX_SIZE

0.0

In [23]:
# Break down the Embed func

In [24]:
from typing import List, Tuple, Any, Dict, Iterable

from itertools import tee


def ngrams(sequence: List[str], n: int) -> Iterable:
    """
    Directly taken from nltk package to avoid dependency.

    Parameters
    ----------
    sequence : list
        The sequence of items to be n-grammed.
    n : int
        The order of the n-grams to be extracted.

    Returns
    -------
    Iterable
        The n-grams generated from the sequence.
    """
    iterables = tee(sequence, n)
    for i, sub_iterable in enumerate(iterables):
        for _ in range(i):
            next(sub_iterable, None)
    return zip(*iterables)

In [None]:
# !pip install nltk
import nltk

In [39]:
# An n-gram is a contiguous sequence of n items
# An n-gram of size 1 is referred to as a unigram
list(nltk.ngrams('1234', 5))

[]

In [40]:
list(nltk.ngrams('123456', 5))

[('1', '2', '3', '4', '5'), ('2', '3', '4', '5', '6')]

In [28]:
def ngrams(text, n):
    n-=1
    return [text[i-n:i+1] for i,char in enumerate(text)][n:] 

In [19]:
keys = list(dataset_description_collection.keys())

In [22]:
keys[0]

'arXiv_ver2'

In [72]:
# WARNING THIS LOADS 130GB IN MEMORY
# content = dataset_description_collection[keys[0]]['text'][0]
ngram_size = 5
import hashlib
import struct

def sha1_hash32(data: bytes) -> int:
    """
    Compute hash32 (int) value of a sequence of bytes

    Parameters
    ----------
    data : bytes

    Returns
    -------
    int
    """
    # In cryptography, SHA-1 (Secure Hash Algorithm 1) is a cryptographically broken 
    # but still widely used hash function which takes an input and 
    # produces a 160-bit (20-byte) hash value known as a message digest 
    # typically rendered as a hexadecimal number, 40 digits long. 
    digest = hashlib.sha1(data).digest()
    top4bytes = digest[:4]
    # given 4 bytes pack it to a single uint32
    hash32bit , *_ = struct.unpack("<I", top4bytes)
    return hash32bit


def convert_chunks_to_hash_list(chunks):
   return [sha1_hash32(chunk) for chunk in chunks]


for ix, row in enumerate(dataset_description_collection['arXiv_ver2']):
    if ix == 3:
        break
    # don't use empty/only whitespace documents 
    document = row['text'].strip()
    # TODO: apply utf8 cleanup step before
    if not document:
        continue
    chunks = NON_ALPHA.split(document)
    content_as_ngram = set(" ".join(ngram).encode("utf-8") # TODO: NORMALIZE UTF8 
                                    for ngram in ngrams(chunks, ngram_size))
    # for each document we have a bunch of chunks. each chunk is UNIQUE within the chunk_set
    hash_values = np.array(convert_chunks_to_hash_list(content_as_ngram), dtype=np.uint64)  # noqa: E501
    # vector of hash32 values of the chunked document
    break

In [73]:
num_perm = 256
SEED = 42
NON_ALPHA = re.compile("[^A-Za-z_0-9]")
MERSENNE_TWISTER_RNG = np.random.RandomState(SEED)
# Container for the slow Mersenne Twister pseudo-random number generator. 
# Consider using a different BitGenerator with the Generator container instead.
# TODO: USE NEW RANDOM GENERATION CODE FOR *SPEED*

MAX_HASH = np.uint64((1 << 32) - 1)
MERSENNE_PRIME = np.uint64((1 << 61) - 1)
DATA_SIZE = len(ds)

In [74]:
def np64array(int_array):
    return np.array(int_array, dtype=np.uint64)

In [75]:
MERSENNE_PRIME

2305843009213693951

In [76]:
# Return random integers from low (inclusive) to high (exclusive).
# Return random integers from the “discrete uniform” distribution of the specified dtype in the “half-open” interval [low, high). If high is None (the default), then results are from [0, low).

In [77]:
partitions = [
    (
        MERSENNE_TWISTER_RNG.randint(1, MERSENNE_PRIME, dtype=np.uint64),
        MERSENNE_TWISTER_RNG.randint(0, MERSENNE_PRIME, dtype=np.uint64),
    )
    for _ in range(num_perm)
]

PERMUTATIONS = np64array(partitions)

In [78]:
PERMUTATIONS.shape

(256, 2)

In [79]:
permutations = PERMUTATIONS.T

In [85]:
a, b = permutations

In [87]:
len(document), len(NON_ALPHA.split(document)), *a.shape, hash_values.size

(63260, 19132, 256, 10186)

In [92]:
table_hash = np.tile(a, (len(hash_values), 1))

In [94]:
table_hash.shape
# for each hash we have a random number

(10186, 256)

In [96]:
b.shape

(256,)

In [99]:
X = (hash_values * table_hash.T)

In [100]:
phv = np.bitwise_and( X % MERSENNE_PRIME, MAX_HASH) 
# np.bitwise_and: Compute the bit-wise AND of two arrays element-wise.


In [103]:
phv.shape
# Number of partitions X number of unique chunks (as hash32)

(256, 10186)

In [108]:
hashvalues = np.ones(num_perm, dtype=np.uint64) * MAX_HASH

In [109]:
hashvalues.shape, phv.shape

((256,), (256, 10186))

In [None]:
threshold = 0.7
num_perm = 256
B, R = optimal_param(threshold, num_perm)

In [None]:
HASH_RANGES = [(i * R, (i + 1) * R) for i in range(B)]

In [114]:
hash_value_for_each_partition = np.vstack([phv.T, hashvalues]).min(axis=0)

# Hs = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
# return {"__signatures__": Hs, "__id__": idx}

In [116]:
!pip install scipy

Collecting scipy
  Downloading scipy-1.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
[K     |████████████████████████████████| 34.5 MB 33.5 MB/s eta 0:00:01
Installing collected packages: scipy
Successfully installed scipy-1.10.0


In [119]:
import scipy
integrate = scipy.integrate.quad

In [122]:
integrate?

[0;31mSignature:[0m
[0mintegrate[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfunc[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0ma[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mb[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0margs[0m[0;34m=[0m[0;34m([0m[0;34m)[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfull_output[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mepsabs[0m[0;34m=[0m[0;36m1.49e-08[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mepsrel[0m[0;34m=[0m[0;36m1.49e-08[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlimit[0m[0;34m=[0m[0;36m50[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpoints[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mweight[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mwvar[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mwopts[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmaxp1[0m[0;34m=[0m[0;36m50[0m[0;34

In [120]:
def probability(b, x, r):
    return (1 - 
        (1 - (x ** r))
            ** b)

How to determine the B, R optimal params given a 
SIMILARITY threshold and a number of permutations

This code defines a function false_positive_probability(threshold: float, b: int, r: int) that calculates the probability of a false positive in the context of Locality-Sensitive Hashing (LSH).

This function takes in three arguments:

- threshold is a float representing a similarity threshold
- b is an integer representing the number of hash tables used
- r is an integer representing the number of hash functions per table

It uses an inner function proba(s) which calculates the probability that any two items that are similar to each other by at least s will be hashed to the same bucket. The inner function returns a probability value, and this probability is used in the outer function.

The outer function then use integrate function which is not provided here and calculates the area under the curve of this probability function between 0.0 and the given threshold and returns the result as the probability of false positive.

This code is based on datasketch library, The datasketch library provides a variety of algorithms for performing approximate nearest neighbor search in high-dimensional spaces, including LSH.

In [126]:
r = 1
def proba_original(s):
    return 1 - (1 - s ** float(r)) ** float(b)

In [127]:
threshold = 0.7

In Locality-Sensitive Hashing (LSH), the idea is to hash similar items to the same "bucket" with high probability.

- To increase the chances of this happening, multiple hash tables are used. 
- Each table uses a different hash function, and when an item is hashed, it is hashed to each of the tables using the corresponding hash function.

- The parameter b in the false_positive_probability function represents the number of hash tables used.

- The larger the value of b, the more hash tables are used, and the higher the probability that similar items will be hashed to the same bucket in at least one of the tables. 

- This means that using a larger value of b will increase the chances of correctly identifying similar items, but it will also increase the number of hash tables needed to be searched, which can increase the running time.

- It is worth noting that it is a trade-off between recall and precision when it comes to the number of hash tables. 

- In general, more hash tables would increase the recall but decrease the precision of the search and vice versa.

In [128]:
integrate(proba_original, 0.0, threshold)

TypeError: only size-1 arrays can be converted to Python scalars

In [None]:
min_error = float("inf")
opt = (0, 0)

for hash_table in range(1, num_perm + 1):
    max_r = int(num_perm / b)
    for r in range(1, max_r + 1):
        pass

a, _ = integrate(probability, 0.0, threshold)

In [115]:
import scipy
# Integrate func from a to b (possibly infinite interval) using a technique from the Fortran library QUADPACK.

ModuleNotFoundError: No module named 'scipy'

In [112]:
hashvalues[:3]

array([896308, 690133,   6451], dtype=uint64)

In [None]:
chunks = NON_ALPHA.split(content)

In [None]:
def embed_func(
    content: str,
    idx: int,
    *,
    num_perm: int,
    ngram_size: int,
    hashranges: List[Tuple[int, int]],
    permutations: np.ndarray,
) -> Dict[str, Any]:
    """
    Combined with some datasketch code to avoid dependency.

    Parameters
    ----------
    content : str
        The content to be embedded.
    idx : int
        The index of the content.
    num_perm : int
        The number of permutations.
    ngram_size : int
        The size of n-grams.
    hashranges : List[Tuple[int, int]]
        The ranges of hash values.
    permutations : np.ndarray
        The permutations for the minhash.

    Returns
    -------
    Dict[str, Any]
        The hash values in each range and the index.
    """
    hashvalues = np.ones(num_perm, dtype=np.uint64) * MAX_HASH
    # [MAX_HASH, MAX_HASH, .... , num_perm]
    # 1. split the content with non alpha numeric char
    # 2. create an ngram from the content_chunks to create a tokens SET
    # 3. FOR EACH token in the SET create a hash (hv)
    # 4.
    chunks = NON_ALPHA.split(content)
    content_as_ngram = {" ".join(ngram) # TODO: are we losing data (punctuaction) here ??
        for ngram in ngrams(chunks, ngram_size)}
    hv = np.array([sha1_hash32(token.encode("utf-8")) for token in content_as_ngram], dtype=np.uint64)  # noqa: E501
    a, b = permutations
    phv = np.bitwise_and(((hv * np.tile(a, (len(hv), 1)).T).T + b) % MERSENNE_PRIME, MAX_HASH)  # noqa: E501
    hashvalues = np.vstack([phv, hashvalues]).min(axis=0)
    Hs = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
    return {"__signatures__": Hs, "__id__": idx}

In [129]:
class UnionFind:
    def __init__(self):
        self.parent: Dict[int, int] = {}

    def find(self, x):
        if x not in self.parent:
            self.parent[x] = x
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        px = self.find(x)
        py = self.find(y)
        self.parent[px] = self.parent[py] = min(px, py)

A Union-find data structure is an algorithm that keeps track of a set of elements partitioned into a number of disjoint (non-overlapping) subsets. The Union-find algorithm is used to keep track of which elements are in the same subset and quickly perform union and find operations on these subsets.

In the provided code, the UnionFind class is used to keep track of subsets of integers, which are represented by instances of the UnionFind class. The class has three methods:

The __init__() method creates an empty dictionary named parent which is used to store the parent-child relationship of the elements in the subsets.

The find(x) method takes an integer x as an input and returns the unique identifier of the subset which element x belongs to. This is done by following the chain of parent pointers up the tree until the parent pointer of x points to itself, which indicates that x is the root element of its subset.

The union(x, y) method takes two integers x and y as input and unite two subsets which the element x and y respectively belong to. The method performs find on x and y to find the unique identifier of the subsets, and then set the root of the smaller set to point to the root of the larger set, effectively merging the two subsets into one.

The parent dictionary is used to store the parent-child relationship of the elements in the subsets, and the find and union method use this dictionary to keep track of which elements are in the same subset and perform union and find operations on these subsets. The parent is an important part of Union-find data structure, it is used to maintain the disjoint sets of element in memory, to perform the union operation and find operation efficiently.