In [1]:
from datasketch import MinHashLSH, MinHash, LeanMinHash
from datasets import load_dataset, load_from_disk
from tqdm import tqdm 
import os 
import pickle

In [2]:
# --- MinHash-LSH configuration (shared by every cell below) ---
# Jaccard-similarity threshold passed to MinHashLSH.
SIMILARITY_THRESHOLD = 0.8
# Number of MinHash permutations; used for both MinHash and MinHashLSH.
NUM_PERMS = 128
# Length, in characters, of each shingle fed into MinHash.
SHINGLE_SIZE = 4

In [3]:
# Build a fresh Cassandra-backed LSH index. drop_keyspace/drop_tables=True ask
# the datasketch Cassandra storage to drop the existing keyspace/tables, so
# re-running this cell discards any previously stored index.
_fresh_cassandra_settings = {
    'seeds': ['127.0.0.1', "cassandra"],
    'keyspace': 'lsh_test',
    'replication': {'class': 'SimpleStrategy', 'replication_factor': '1'},
    'drop_keyspace': True,
    'drop_tables': True,
}
lsh = MinHashLSH(
    threshold=SIMILARITY_THRESHOLD,
    num_perm=NUM_PERMS,
    storage_config={
        'type': 'cassandra',
        'basename': b'base_lsh_cassandra',
        'cassandra': _fresh_cassandra_settings,
    },
)

In [3]:
# Load the "train" split of the dataset stored in the local ./temp directory.
dataset = load_dataset("./temp", split="train")

Resolving data files:   0%|          | 0/201 [00:00<?, ?it/s]

In [4]:
def get_cassandra_lsh(drop_keyspace=False, drop_tables=False):
    """Return a MinHashLSH backed by the shared Cassandra keyspace ``lsh_test``.

    Every call constructs a new ``MinHashLSH`` object pointing at the same
    basename/keyspace. Because it is called inside the ``Dataset.map`` functions,
    each worker gets its own handle to one shared stored index.

    Args:
        drop_keyspace: forwarded to the datasketch Cassandra storage config.
            The default False keeps existing index data; True is destructive.
        drop_tables: forwarded likewise for the index tables (default False).

    Returns:
        MinHashLSH configured with SIMILARITY_THRESHOLD and NUM_PERMS.
    """
    return MinHashLSH(
        threshold=SIMILARITY_THRESHOLD,
        num_perm=NUM_PERMS,
        storage_config={
            'type': 'cassandra',
            'basename': b'base_lsh_cassandra',
            'cassandra': {
                'seeds': ['127.0.0.1', "cassandra"],
                'keyspace': 'lsh_test',
                'replication': {
                    'class': 'SimpleStrategy',
                    'replication_factor': '1',
                },
                'drop_keyspace': drop_keyspace,
                'drop_tables': drop_tables,
            },
        },
    )

def _shingle(string, shingle_size=SHINGLE_SIZE):
    """Return the set of UTF-8 encoded character n-grams ("shingles") of ``string``.

    Args:
        string: text to shingle.
        shingle_size: length, in characters, of each shingle; must be >= 1.

    Returns:
        A set of ``bytes``. It is empty when ``string`` has fewer than
        ``shingle_size`` characters, and callers skip such rows.

    Raises:
        ValueError: if ``shingle_size`` < 1. A size of 0 would silently yield the
            single shingle ``b""``, and negative sizes produce meaningless slices.
    """
    if shingle_size < 1:
        raise ValueError(f"shingle_size must be >= 1, got {shingle_size}")
    # A set comprehension already deduplicates; no extra set() copy is needed.
    return {
        string[start:start + shingle_size].encode("utf8")
        for start in range(len(string) - shingle_size + 1)
    }

def hash_and_insert(batch, indices, session):
    """Insert a MinHash signature for every row of a batch through ``session``.

    Meant for a batched ``Dataset.map(..., with_indices=True)``: the row at
    position ``pos`` is stored under the key ``str(indices[pos])``. Rows whose
    text yields no shingles are not inserted. The batch is returned untouched.
    """
    texts = batch["text"]
    for pos in range(len(texts)):
        shingle_set = _shingle(texts[pos], shingle_size=SHINGLE_SIZE)
        if not shingle_set:
            continue
        signature = MinHash(num_perm=NUM_PERMS)
        for shingle in shingle_set:
            signature.update(shingle)
        session.insert(
            str(indices[pos]),
            LeanMinHash(minhash=signature),
            check_duplication=False,
        )
    return batch

def hash_and_insert_with_basename(batch, indices):
    """Open a private handle to the shared Cassandra index and bulk-insert a batch.

    Each call (i.e. each batch, in each map worker) gets its own ``MinHashLSH``
    handle and its own insertion session. The session is closed before the batch
    is returned, so that batch's buffered inserts are flushed. Per-row hashing is
    the same as in ``hash_and_insert``, which is reused here.
    """
    worker_lsh = get_cassandra_lsh()
    with worker_lsh.insertion_session() as worker_session:
        return hash_and_insert(batch, indices, worker_session)

def hash_insert_mark_with_basename(batch, indices):
    """Insert each row into the shared LSH unless a near-duplicate is already there.

    For every row whose text yields shingles, the index is queried first:
    - no candidate: insert the row under ``str(index)`` and set ``is_duplicate=False``;
    - otherwise: set ``is_duplicate=True``.

    Rows with no shingles keep their existing flag. If the batch has no
    ``is_duplicate`` column it is created here with False, which previously
    raised a KeyError.

    NOTE(review): query-then-insert is not atomic. With ``num_proc > 1`` two
    workers can each miss the other's near-duplicate and both keep their row,
    so results can depend on scheduling.
    """
    lsh = get_cassandra_lsh()
    texts = batch["text"]
    if "is_duplicate" in batch:
        flags = list(batch["is_duplicate"])
    else:
        flags = [False] * len(texts)

    for pos, row in enumerate(texts):
        shingles = _shingle(row, shingle_size=SHINGLE_SIZE)
        if not shingles:
            continue
        minhash = MinHash(num_perm=NUM_PERMS)
        for shing in shingles:
            minhash.update(shing)
        lean = LeanMinHash(minhash=minhash)

        if lsh.query(lean):
            flags[pos] = True
        else:
            lsh.insert(str(indices[pos]), lean, check_duplication=False)
            flags[pos] = False

    # Write the flags back explicitly instead of relying on in-place mutation
    # of whatever object batch["is_duplicate"] returned.
    batch["is_duplicate"] = flags
    return batch

In [None]:
# Sequential baseline: hash every row and insert it through one insertion
# session, keyed by the row's position in `dataset` (the same keys the
# with_indices map cells use).
with lsh.insertion_session() as session:
    for i, sample in tqdm(enumerate(dataset), total=len(dataset)):
        shingles = _shingle(sample["text"], shingle_size=SHINGLE_SIZE)
        if not shingles:
            continue  # text shorter than SHINGLE_SIZE: not indexed

        minhash = MinHash(num_perm=NUM_PERMS)
        for shing in shingles:
            minhash.update(shing)
        session.insert(str(i), LeanMinHash(minhash=minhash), check_duplication=False)

# Pickle only after the session has exited: the insertion session buffers
# inserts until it is closed, so dumping inside the `with` could miss the tail.
with open("./lsh.pkl", "wb") as f:
    pickle.dump(lsh, f)

In [5]:
# Parallel build: every batch opens its own handle and insertion session on the
# shared Cassandra index (hash_and_insert_with_basename). The function is run
# only for its side effect, so the datasets cache must be bypassed: on a cache
# hit the function would not be called and nothing would be inserted into the
# index that the fresh-index cell dropped.
dataset = dataset.map(
    hash_and_insert_with_basename,
    batched=True,
    batch_size=10000,
    num_proc=os.cpu_count(),
    with_indices=True,
    load_from_cache_file=False,
)

Map (num_proc=16):   0%|          | 0/214424 [00:00<?, ? examples/s]

In [None]:
# Single LSH fed through one insertion session. This has to run in-process
# (num_proc=1). With several workers, each would receive its own pickled copy of
# `session` that is never closed, so its buffered inserts might never reach the
# index. For a parallel build, use hash_and_insert_with_basename, which opens
# and closes a session per batch inside each worker.
with lsh.insertion_session() as session:
    dataset = dataset.map(
        hash_and_insert,
        batched=True,
        batch_size=2000,
        num_proc=1,
        with_indices=True,
        fn_kwargs={"session": session},
        # Called only for its side effect: a cache hit would skip the inserts.
        load_from_cache_file=False,
    )

# Pickle after the session has closed so its buffered inserts have been flushed.
with open("./lsh.pkl", "wb") as f:
    pickle.dump(lsh, f)

In [10]:
# Reconnect to the existing Cassandra index. get_cassandra_lsh() is called
# without dropping the keyspace/tables, so the stored index is kept.
lsh = get_cassandra_lsh()

In [6]:
def add_is_duplicate_column(example):
    """Set ``example["is_duplicate"] = True`` and return the same dict.

    Every row starts out flagged as a duplicate. The later marking passes
    (e.g. ``marked_duplicate``) set the flag to False only for rows they keep,
    so rows they never reach stay flagged and are filtered out.
    """
    # Default is True, not False: rows are treated as duplicates until a
    # marking pass keeps them.
    example["is_duplicate"] = True
    return example

# Give every row an `is_duplicate` flag (initially True) before the marking pass.
dataset = dataset.map(
    function=add_is_duplicate_column,
    num_proc=os.cpu_count(),
)

Map (num_proc=16):   0%|          | 0/214424 [00:00<?, ? examples/s]

In [7]:
def marked_duplicate(batch, indices):
    """Greedy near-duplicate marking against the shared Cassandra LSH index.

    Meant for a batched ``Dataset.map(..., with_indices=True)`` run after the
    index holds one key ``str(index)`` per row. For each row that is still in
    the index:
    - query its near-duplicate candidates;
    - set the row's ``is_duplicate`` to False;
    - remove every *other* candidate key from the index, so those rows are
      skipped (and stay flagged) when they are reached.

    Rows whose key is no longer in the index (removed as a near-duplicate of an
    earlier row, or never indexed) or whose text has no shingles keep their
    current flag. If the batch has no ``is_duplicate`` column, rows default to
    True, matching ``add_is_duplicate_column``.

    NOTE(review): with ``num_proc > 1`` the workers mutate one shared index
    concurrently. Which member of a near-duplicate group survives, and rarely
    whether two members both survive, can depend on scheduling.
    """
    lsh = get_cassandra_lsh()
    texts = batch["text"]
    if "is_duplicate" in batch:
        flags = list(batch["is_duplicate"])
    else:
        flags = [True] * len(texts)

    for pos, row in enumerate(texts):
        # Keys were inserted as strings; comparing against the raw int index
        # never matched, so every row used to remove its own key.
        key = str(indices[pos])
        if key not in lsh:
            continue

        shingles = _shingle(row, shingle_size=SHINGLE_SIZE)
        if not shingles:
            continue
        minhash = MinHash(num_perm=NUM_PERMS)
        for shing in shingles:
            minhash.update(shing)

        # Query once, on the complete signature (not after every shingle).
        candidates = lsh.query(minhash)
        if key not in candidates:
            # Removed by another worker since the membership check: leave it flagged.
            continue

        flags[pos] = False
        for other in candidates:
            if other == key:
                continue
            try:
                lsh.remove(other)
            except ValueError:
                # Key already removed (e.g. by another worker); nothing left to do.
                pass

    batch["is_duplicate"] = flags
    return batch

In [19]:
# Sanity check of the Cassandra index: sign the first row, write it under key
# "0", then confirm the key is reported as present (displays True).
# NOTE(review): this writes to the shared index. Key "0" may already exist from
# the bulk build, and check_duplication=False does not guard against that.
test_text = dataset[0]["text"]
shingles = _shingle(test_text)

minhash = MinHash(num_perm=NUM_PERMS)
for shing in shingles:
    minhash.update(shing)
lean = LeanMinHash(minhash=minhash)

lsh.insert("0", lean, check_duplication=False)
"0" in lsh

True

In [None]:
# Marking pass: marked_duplicate flags rows and removes matched keys from the
# shared index.
# NOTE(review): all workers mutate the same index, so with num_proc > 1 the
# outcome can depend on scheduling.
dataset = dataset.map(
    function=marked_duplicate,
    batched=True,
    batch_size=10000,
    with_indices=True,
    num_proc=os.cpu_count(),
)

Map (num_proc=16):   0%|          | 0/214424 [00:00<?, ? examples/s]

In [None]:
# Second pass: for every row still flagged as unique, query the index and flag
# its other near-duplicate neighbours. `dataset[j]["is_duplicate"] = True` has
# no effect, because indexing a Dataset returns a fresh dict. So the positions
# are collected first and written back with one map.
newly_duplicate = set()
for i, sample in tqdm(enumerate(dataset), total=len(dataset)):
    # Skip rows already flagged, including ones flagged earlier in this loop,
    # so two near-duplicates cannot flag each other and both be dropped.
    if sample["is_duplicate"] or i in newly_duplicate:
        continue

    shingles = _shingle(sample["text"], shingle_size=SHINGLE_SIZE)
    if not shingles:
        continue
    minhash = MinHash(num_perm=NUM_PERMS)
    for shing in shingles:
        minhash.update(shing)

    for key in lsh.query(LeanMinHash(minhash=minhash)):
        neighbour = int(key)  # keys are str(row position)
        if neighbour != i:
            newly_duplicate.add(neighbour)


def _flag_newly_duplicate(batch, indices):
    """Set is_duplicate=True for rows whose position is in `newly_duplicate`."""
    batch["is_duplicate"] = [
        flag or idx in newly_duplicate
        for flag, idx in zip(batch["is_duplicate"], indices)
    ]
    return batch


dataset = dataset.map(_flag_newly_duplicate, batched=True, with_indices=True)

In [None]:
# Keep only rows whose flag compares equal to False. Only the flag column is
# passed to the predicate (input_columns). `== False` is kept on purpose so the
# check matches the original exactly.
dataset = dataset.filter(
    lambda is_dup: is_dup == False,  # noqa: E712
    input_columns="is_duplicate",
    num_proc=os.cpu_count(),
)

In [None]:
# Persist the filtered dataset to ./deduplicated_data (reloadable with load_from_disk).
dataset.save_to_disk("deduplicated_data")

In [6]:
# Reload the saved deduplicated dataset so the inspection cells below can run
# without redoing the deduplication.
dataset = load_from_disk("deduplicated_data")

In [None]:
# Display the dataset summary (features and number of rows).
dataset

In [4]:
# Reload the index pickled by the insertion cells. Despite its name, `session`
# holds the pickled MinHashLSH object, not an insertion session.
# SECURITY: pickle.load can execute arbitrary code; only load files you created.
with open("./lsh.pkl", "rb") as pkl_file:
    session = pickle.load(pkl_file)

In [None]:
# Pick one row of the dataset to spot-check against the pickled index.
text = dataset[1]["text"]
print(text)

In [21]:
# Character shingles of the sample text (default SHINGLE_SIZE).
shingles = _shingle(text)

In [22]:
# MinHash signature of the sample text, built from the same shingles as the
# indexed rows (update_batch hashes all shingles in one call).
minhash = MinHash(num_perm=NUM_PERMS)
minhash.update_batch(list(shingles))

In [None]:
# LSH candidate near-duplicate keys for the sample text. Here `session` is the
# MinHashLSH loaded from lsh.pkl.
session.query(minhash)

In [None]:
# Check whether the index loaded from lsh.pkl contains any entries.
session.is_empty()

In [1]:
# Standalone in-memory demo of MinHashLSH bulk insertion (no Cassandra).
from datasketch import MinHash, MinHashLSH
import numpy as np

rng = np.random.default_rng(0)  # seeded so the demo is reproducible

minhashes = []
for i in range(100):
    m = MinHash(num_perm=128)
    m.update_batch(rng.integers(low=0, high=30, size=10))
    minhashes.append(m)

session = MinHashLSH(threshold=0.5, num_perm=128)
# Give the insertion session its own name. `as session` would rebind `session`
# to the insertion-session object, but the following `session.is_empty()` needs
# the MinHashLSH itself.
with session.insertion_session() as insert_session:
    for i, m in enumerate(minhashes):
        insert_session.insert(i, m)

In [None]:
# Check that the demo index is non-empty (expected: False).
# NOTE(review): this requires `session` to be the MinHashLSH itself. If a
# `with ... as session` rebinds it to the insertion session, this call would
# likely fail.
session.is_empty()