In [1]:
# HACK: use project root as the working directory 
from pathlib import Path

while Path.cwd().name != 'language-model-toxicity':
    %cd ..

/homes/gws/sgehman/language-model-toxicity


In [38]:
import logging
from typing import List
import pickle
import re

from joblib import Memory, Parallel, delayed, dump, load
from lsh import cache, minhash
import numpy as np
from itertools import chain, islice
from tqdm.auto import tqdm
import multiprocessing as mp

from utils.constants import DATA_DIR, OUTPUT_DIR

# Load WebText fingerprints and add them to the LSH cache

In [4]:
# Precomputed MinHash fingerprints and their doc ids from an earlier LSH
# de-duplication run (the run's settings are encoded in the directory name).
# NOTE: pickle.load can execute arbitrary code; only point this at trusted,
# locally produced files.
dups_dir = OUTPUT_DIR / 'lsh_duplicates' / 'char_ngram_5_seeds_100_bands_20_str'
doc_ids_path = dups_dir / 'doc_ids.pkl'

fingerprints = np.load(dups_dir / 'fingerprints.npy')
with open(doc_ids_path, 'rb') as doc_ids_file:
    doc_ids = pickle.load(doc_ids_file)

In [5]:
# Every fingerprint row needs exactly one doc id; later cells slice and zip the
# two in parallel, which would silently misalign or truncate otherwise.
assert len(fingerprints) == len(doc_ids), (
    f'{len(fingerprints)} fingerprints but {len(doc_ids)} doc ids')

In [6]:
# Number of WebText documents. They are assumed to occupy the first
# `webtext_len` rows of `fingerprints` / `doc_ids` (that prefix is sliced off next).
webtext_len: int = 8_282_020

In [7]:
# Keep only the WebText portion (assumed to be the first `webtext_len` rows).
wt_fingerprints, wt_doc_ids = fingerprints[:webtext_len], doc_ids[:webtext_len]
# Slicing past the end truncates silently; make a short input fail loudly instead.
assert len(wt_fingerprints) == len(wt_doc_ids) == webtext_len, (
    f'expected {webtext_len} WebText rows, got {len(wt_fingerprints)} fingerprints '
    f'and {len(wt_doc_ids)} doc ids')

In [10]:
# MinHash / LSH settings.
# NOTE(review): `seeds` and `char_ngram` agree with the `char_ngram_5_seeds_100`
# directory the stored fingerprints were loaded from, but `bands` (10) differs from
# its `bands_20`. `bands` is only passed to the LSH cache, not to the hasher, so this
# changes which pairs become candidates rather than the fingerprints — confirm intended.
# `hashbytes` and `random_state` must also match the settings that produced the
# stored fingerprints; that is not visible here — verify.
seeds, char_ngram, hashbytes, bands, random_state = 100, 5, 4, 10, 42

In [40]:
# Hasher for new documents, plus an empty LSH cache that buckets fingerprints by band.
# NOTE(review): generation fingerprints are only comparable with the stored WebText
# fingerprints if these hasher settings match the ones used to create them.
hasher = minhash.MinHasher(seeds=seeds, char_ngram=char_ngram, hashbytes=hashbytes, random_state=random_state)

# Each fingerprint of `seeds` values is split into `bands` equal-width bands.
if seeds % bands:
    raise ValueError('Seeds has to be a multiple of bands. {} % {} != 0'.format(seeds, bands))

lshcache = cache.Cache(hasher=hasher, num_bands=bands)

In [41]:
def add_fingerprints(cache, doc_ids, fingerprints):
    """Insert precomputed fingerprints into an LSH cache.

    Args:
        cache: object with an ``add_fingerprint(fingerprint, doc_id=...)`` method,
            e.g. an ``lsh.cache.Cache``.
        doc_ids: iterable of document ids, aligned element-wise with ``fingerprints``.
        fingerprints: iterable of fingerprints, one per doc id.

    Pairs are formed with ``zip``, so insertion stops at the shorter input.
    A progress bar is shown over ``doc_ids``.
    """
    insert = cache.add_fingerprint
    for key, fp in zip(tqdm(doc_ids), fingerprints):
        insert(fp, doc_id=key)
        
# Index every WebText fingerprint so generations can be looked up against it.
add_fingerprints(cache=lshcache, doc_ids=wt_doc_ids, fingerprints=wt_fingerprints)

HBox(children=(FloatProgress(value=0.0, max=8282020.0), HTML(value='')))




# Fingerprint GPT-2 generations and check for near-duplicates in WebText

In [44]:
# GPT-2 generations to compare against WebText.
# NOTE: joblib.load unpickles the file; only load trusted, locally produced data.
gpt2_generations_path = DATA_DIR / 'gpt2' / 'gpt2_generations.joblib'
gpt2_generations = load(gpt2_generations_path)

In [48]:
# (doc_id, text) pairs for every generation; doc ids are `(index, 'gen')`.
# Materialised as a list rather than a one-shot `zip` iterator so that the
# fingerprinting cell can be re-run without re-executing this cell first
# (an exhausted iterator would silently yield no documents).
document_feed = [((i, 'gen'), generation) for i, generation in enumerate(gpt2_generations)]

In [16]:
class Fingerprinter:
    """Maps ``(doc_id, doc)`` pairs to ``(doc_id, fingerprint)`` pairs.

    The hasher lives on an instance so that the bound method
    ``fingerprinter.fingerprint`` can be handed to ``multiprocessing.Pool.imap``;
    the method is pickled together with the hasher it carries when tasks are
    sent to worker processes.
    """

    def __init__(self, hasher):
        # Any object with a ``fingerprint(doc)`` method, e.g. ``minhash.MinHasher``.
        self.hasher = hasher

    def fingerprint(self, item):
        """Fingerprint one ``(doc_id, doc)`` pair, passing ``doc_id`` through unchanged."""
        (doc_id, text) = item
        signature = self.hasher.fingerprint(text)
        return (doc_id, signature)

In [49]:
# Fingerprint all generations in parallel. `imap` yields results in input order,
# so `gen_doc_ids[i]` is the id of `gen_fingerprints[i]`.
num_processes = min(96, mp.cpu_count())  # up to 96 workers, never more than the available CPUs
fingerprinter = Fingerprinter(hasher)
gen_doc_ids = []
gen_fingerprints = []
with mp.Pool(processes=num_processes) as pool:
    results = pool.imap(fingerprinter.fingerprint, document_feed, chunksize=1_000)
    for doc_id, fingerprint in tqdm(results, desc='Fingerprinting', dynamic_ncols=True,
                                    total=len(gpt2_generations)):
        gen_doc_ids.append(doc_id)
        gen_fingerprints.append(fingerprint)

# `document_feed` may be a one-shot iterator; if it was already consumed (e.g. this
# cell was re-run on its own) nothing gets fingerprinted, so fail loudly here.
assert len(gen_doc_ids) == len(gpt2_generations), (
    f'fingerprinted {len(gen_doc_ids)} of {len(gpt2_generations)} generations; '
    're-run the cell that builds document_feed')

HBox(children=(FloatProgress(value=0.0, description='Fingerprinting', layout=Layout(flex='2'), max=2500000.0, …




In [51]:
# Join the per-generation fingerprints along a new leading axis; row order
# follows `gen_doc_ids`, which was filled in the same loop.
gen_fingerprints = np.stack(gen_fingerprints, axis=0)

In [55]:
# np.save(OUTPUT_DIR / 'gpt2_generation_fingerprints.npy', gen_fingerprints)

In [58]:
# with open(OUTPUT_DIR / 'gpt2_generation_doc_ids.pkl', 'wb') as f:
#     pickle.dump(gen_doc_ids, f)

In [61]:
def get_duplicates_of(cache, fingerprint):
    """Return ids of cached documents sharing at least one LSH band with ``fingerprint``.

    Works from a precomputed fingerprint instead of a raw document. The bucket key
    is ``hash(tuple(band))``, presumably the same key the cache uses when adding
    fingerprints (TODO: confirm against ``lsh.cache.Cache.add_fingerprint``).

    Args:
        cache: ``lsh.cache.Cache``-like object exposing ``bins_(fingerprint)``, which
            yields ``(band_index, band)`` pairs, and ``bins``, with one mapping per band
            from bucket key to a collection of doc ids.
        fingerprint: the query fingerprint.

    Returns:
        A new ``set`` of candidate doc ids. These are LSH candidates only; no
        similarity threshold is applied.
    """
    candidates = set()
    for band_index, band in cache.bins_(fingerprint):
        bucket_key = hash(tuple(band))
        # Use .get() rather than []: with defaultdict bins, indexing would insert an
        # empty bucket for every miss, so a read-only query would grow the cache
        # (and a plain dict would raise KeyError).
        candidates.update(cache.bins[band_index].get(bucket_key, ()))
    return candidates

In [69]:
# Look up every generation fingerprint in the WebText LSH cache.
# Results are LSH *candidates* (at least one identical band), not verified
# near-duplicates; no similarity threshold is applied here.
duplicate_candidates = set()   # union of candidate WebText doc ids across all generations
gen_duplicate_candidates = {}  # generation doc id -> its candidate WebText doc ids (non-empty only)
with tqdm(gen_fingerprints) as pbar:
    # `gen_doc_ids[i]` is the id of `gen_fingerprints[i]` (both filled in the same loop).
    for gen_fingerprint, gen_doc_id in zip(pbar, gen_doc_ids):
        candidates = get_duplicates_of(lshcache, gen_fingerprint)
        if candidates:
            duplicate_candidates.update(candidates)
            gen_duplicate_candidates[gen_doc_id] = candidates
        # refresh=False: let tqdm's own rate limiting decide when to redraw instead
        # of forcing a widget redraw on each of the millions of iterations.
        pbar.set_description(f'Duplicates ({len(duplicate_candidates)} found)', refresh=False)

HBox(children=(FloatProgress(value=0.0, max=2500000.0), HTML(value='')))

KeyboardInterrupt: 