In [1]:
import multiprocessing as mp
import os
import pickle
import time

from datasets import load_dataset, load_from_disk, Dataset
from datasketch import MinHashLSH, MinHash, LeanMinHash
from tqdm import tqdm

In [2]:
# Value passed as `threshold=` to every MinHashLSH built in this notebook.
SIMILARITY_THRESHOLD = 0.9
# Value passed as `num_perm=` to both MinHashLSH and MinHash; they share this constant.
NUM_PERMS = 128 
# Default length of each shingle (a slice of the input text) produced by `_shingle`.
SHINGLE_SIZE = 4

# Build MinHash LSH

In [12]:
# Reset cell: builds the Cassandra-backed LSH with both drop flags enabled
# (drop_keyspace / drop_tables), i.e. the configuration used for resetting Cassandra.
lsh = MinHashLSH(
    threshold=SIMILARITY_THRESHOLD,
    num_perm=NUM_PERMS,
    storage_config={
        "type": "cassandra",
        "basename": b"base_lsh_cassandra",
        "cassandra": {
            "seeds": ["127.0.0.1", "cassandra"],
            "keyspace": "lsh_test",
            "replication": {
                "class": "SimpleStrategy",
                "replication_factor": "1",
            },
            "drop_keyspace": True,
            "drop_tables": True,
        },
    },
)

In [3]:
# Load the "train" split from the local ./temp directory.
dataset = load_dataset(
    "./temp",
    split="train",
)

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

In [4]:
# Display the loaded dataset's summary (feature names and row count).
dataset

Dataset({
    features: ['text'],
    num_rows: 21245
})

In [4]:
def get_cassandra_lsh(reset=False):
    """Build a MinHashLSH handle backed by the notebook's Cassandra storage.

    The index uses SIMILARITY_THRESHOLD and NUM_PERMS, the fixed basename
    ``b'base_lsh_cassandra'`` and the ``lsh_test`` keyspace, so every call
    (including calls made from worker processes) addresses the same storage.

    Args:
        reset: Passed as both ``drop_keyspace`` and ``drop_tables`` in the
            storage config. The default (False) leaves existing data in place,
            which is what the map workers rely on. True reproduces the
            configuration of the notebook's "resetting cassandra" cell.

    Returns:
        The configured ``MinHashLSH`` instance.
    """
    drop_existing = bool(reset)
    storage_config = {
        'type': 'cassandra',
        'basename': b'base_lsh_cassandra',
        'cassandra': {
            'seeds': ['127.0.0.1', "cassandra"],
            'keyspace': 'lsh_test',
            'replication': {
                'class': 'SimpleStrategy',
                'replication_factor': '1',
            },
            'drop_keyspace': drop_existing,
            'drop_tables': drop_existing,
        },
    }

    return MinHashLSH(
        threshold=SIMILARITY_THRESHOLD,
        num_perm=NUM_PERMS,
        storage_config=storage_config,
    )

def _shingle(string, shingle_size=SHINGLE_SIZE):
    """Return the set of UTF-8 encoded, overlapping slices of ``string``.

    Every window ``string[start:start + shingle_size]`` is encoded to bytes and
    collected into a set, so repeated shingles appear once. When ``string`` is
    shorter than ``shingle_size`` no window exists and the set is empty.
    """
    last_start = len(string) - shingle_size
    return {
        string[start : start + shingle_size].encode("utf8")
        for start in range(last_start + 1)
    }

def hash_and_insert(batch, indices, session):
    """MinHash every text in ``batch`` and insert it through ``session``.

    Args:
        batch: Mapping with a ``"text"`` entry holding the batch's texts.
        indices: Row indices aligned with ``batch["text"]``; the string form
            of each index is used as the LSH key.
        session: Object exposing ``insert(key, minhash, check_duplication=...)``.

    Returns:
        ``batch`` unchanged.
    """
    for position, text in enumerate(batch["text"]):
        shingle_set = _shingle(text, shingle_size=SHINGLE_SIZE)
        if not shingle_set:
            # No shingles were produced for this text, so there is nothing to hash.
            continue

        signature = MinHash(num_perm=NUM_PERMS)
        for shingle in shingle_set:
            signature.update(shingle)

        session.insert(
            str(indices[position]),
            LeanMinHash(minhash=signature),
            check_duplication=False,
        )

    return batch

def hash_and_insert_with_basename(batch, indices):
    """Open a Cassandra-backed LSH handle and insert the batch's MinHashes.

    Self-contained variant of ``hash_and_insert`` for use with
    ``dataset.map(..., with_indices=True)``: it builds its own LSH handle via
    ``get_cassandra_lsh()`` and delegates the per-row hashing and insertion to
    ``hash_and_insert`` inside a single insertion session, instead of
    duplicating that loop.

    Args:
        batch: Mapping with a ``"text"`` entry holding the batch's texts.
        indices: Row indices aligned with ``batch["text"]``.

    Returns:
        ``batch`` unchanged.
    """
    lsh = get_cassandra_lsh()

    with lsh.insertion_session() as session:
        hash_and_insert(batch, indices, session)

    return batch

def hash_insert_mark_with_basename(batch, indices):
    """Query, insert and flag each text of the batch in a single pass.

    For every text that yields at least one shingle, its MinHash is queried
    against the Cassandra-backed LSH. If nothing matches, the row is inserted
    under ``str(index)`` and flagged ``False``; otherwise it is flagged
    ``True`` and not inserted. Rows that yield no shingles keep their existing
    flag (``False`` when the column is created here).

    The ``"is_duplicate"`` column is created when the batch does not already
    carry it, so this function no longer depends on a separate cell having
    added the column beforehand.

    NOTE(review): query-then-insert is two separate calls; when several
    ``num_proc`` workers share the same storage, near-duplicates handled at the
    same time may both be kept -- confirm this is acceptable.

    Args:
        batch: Mapping with a ``"text"`` entry (and optionally ``"is_duplicate"``).
        indices: Row indices aligned with ``batch["text"]``.

    Returns:
        ``batch`` with an ``"is_duplicate"`` list aligned with ``batch["text"]``.
    """
    lsh = get_cassandra_lsh()

    texts = batch["text"]
    if "is_duplicate" in batch:
        flags = list(batch["is_duplicate"])
    else:
        # Column not created yet: start with every row marked as not duplicate.
        flags = [False] * len(texts)

    for i, row in enumerate(texts):
        shingles = _shingle(row, shingle_size=SHINGLE_SIZE)
        if len(shingles) == 0:
            continue

        minhash = MinHash(num_perm=NUM_PERMS)
        for shing in shingles:
            minhash.update(shing)

        minhash = LeanMinHash(minhash=minhash)
        matches = lsh.query(minhash)

        if len(matches) == 0:
            lsh.insert(str(indices[i]), minhash, check_duplication=False)
            flags[i] = False
        else:
            flags[i] = True

    batch["is_duplicate"] = flags
    return batch

In [12]:
# Insert and query in the same pass, timing the whole map call.
# perf_counter() is a monotonic clock intended for measuring elapsed intervals.
start_time = time.perf_counter()
dataset = dataset.map(
    hash_insert_mark_with_basename,
    batched=True,
    batch_size=3500,
    num_proc=os.cpu_count(),
    with_indices=True,
)
end_time = time.perf_counter()

print(f"method taking {end_time - start_time} seconds")

Map (num_proc=16):   0%|          | 0/21245 [00:00<?, ? examples/s]

method taking 25.935439825057983 seconds


In [None]:
# Parallel insert-only pass: every batch call builds its own LSH handle.
dataset = dataset.map(
    hash_and_insert_with_basename,
    batched=True,
    batch_size=2000,
    num_proc=os.cpu_count(),
    with_indices=True,
)

# Add duplicate status column to dataset 

In [11]:
def add_is_duplicate_column(example):
    """Set ``example["is_duplicate"]`` to False and return the same mapping."""
    # Every row starts out as "not a duplicate"; later passes flip the flag.
    example.update({"is_duplicate": False})
    return example

# Add the is_duplicate column (all False) to every row of the dataset.
dataset = dataset.map(
    add_is_duplicate_column,
    num_proc=os.cpu_count(),
)

Map (num_proc=16):   0%|          | 0/21245 [00:00<?, ? examples/s]

# Remove duplicates

In [9]:
def marked_duplicate(batch, indices):
    """Flag rows whose LSH matches contain an id other than their own as the minimum.

    For each row not yet flagged, the text is shingled and MinHashed, then
    queried against the Cassandra-backed LSH. The returned keys are read as
    integers; if the smallest one differs from the row's own index, the row's
    ``"is_duplicate"`` entry is set to True.

    Args:
        batch: Mapping with ``"text"`` and ``"is_duplicate"`` entries.
        indices: Row indices aligned with ``batch["text"]``.

    Returns:
        ``batch`` with its ``"is_duplicate"`` list updated in place.
    """
    lsh = get_cassandra_lsh()

    for position, text in enumerate(batch["text"]):
        if batch["is_duplicate"][position] != False:
            # Already flagged; skip the query.
            continue

        shingle_set = _shingle(text, shingle_size=SHINGLE_SIZE)
        if not shingle_set:
            continue

        signature = MinHash(num_perm=NUM_PERMS)
        for shingle in shingle_set:
            signature.update(shingle)

        matches = lsh.query(minhash=signature)
        if len(matches) == 0:
            continue

        smallest_id = min(int(key) for key in matches)
        if smallest_id != indices[position]:
            batch["is_duplicate"][position] = True

    return batch


In [None]:
# Flag duplicates batch by batch, passing row indices to marked_duplicate.
dataset = dataset.map(
    marked_duplicate,
    batched=True,
    with_indices=True,
    batch_size=2000,
    num_proc=os.cpu_count(),
)

In [14]:
# Build a handle to the Cassandra-backed LSH with the drop flags left off.
lsh = get_cassandra_lsh()

In [9]:
# Keep only the rows that were not flagged as duplicates.
dataset = dataset.filter(
    lambda example: example["is_duplicate"] == False,
    num_proc=os.cpu_count(),
)

Filter (num_proc=16):   0%|          | 0/21245 [00:00<?, ? examples/s]

In [10]:
# Show the dataset summary after the duplicate rows were filtered out.
dataset

Dataset({
    features: ['text', 'is_duplicate'],
    num_rows: 18447
})

In [None]:
# Persist the deduplicated dataset to the "deduplicated_data" directory.
dataset.save_to_disk(
    "deduplicated_data",
)