In [1]:
# HACK: use project root as the working directory 
from pathlib import Path

while Path.cwd().name != 'language-model-toxicity':
    %cd ..

/homes/gws/sgehman/language-model-toxicity


In [2]:
import logging
from typing import List

from joblib import Memory, Parallel, delayed, dump
from lsh import cache, minhash
import numpy as np
from itertools import chain, islice

from utils.constants import DATA_DIR, OUTPUT_DIR
from utils.webtext import load_meta, delayed_corpus, split_docs

# On-disk joblib memoization store: functions wrapped with `mem.cache(...)` below
# keep their results in this directory.
WEBTEXT_OVERLAP_CACHE_DIR = OUTPUT_DIR / 'cache' / 'webtext_overlap'
mem = Memory(WEBTEXT_OVERLAP_CACHE_DIR)

In [3]:
# Memoized metadata loader: repeated runs reuse the joblib on-disk cache.
cached_meta = mem.cache(load_meta)

# Load metadata for WebText, then OpenWebText (BPE). Element 0 of each result is
# used below as the list of token-shard files fed to `corpus_iter`.
wt_meta, owtc_meta = (
    cached_meta(DATA_DIR / subdir) for subdir in ('webtext', 'openwebtext_bpe')
)
wt_files, owtc_files = wt_meta[0], owtc_meta[0]

## Find duplicates with LSH

In [4]:
def train(document_feed, char_ngram=3, seeds=100, bands=5, hashbytes=4, n_jobs=1):
    """Fingerprint every document with MinHash and index the fingerprints in an LSH cache.

    Parameters
    ----------
    document_feed : iterable of (doc_id, doc)
        Documents to index. ``doc`` is passed unchanged to ``MinHasher.fingerprint``.
        ``doc_id`` is stored with the resulting fingerprint.
    char_ngram, seeds, hashbytes : int
        Forwarded to ``lsh.minhash.MinHasher``.
    bands : int
        Number of LSH bands. Must be positive and divide ``seeds`` evenly.
    n_jobs : int
        Number of joblib worker threads used to compute fingerprints.

    Returns
    -------
    tuple
        ``(hasher, lshcache)``: the ``MinHasher`` that was used and the populated
        ``lsh.cache.Cache``.

    Raises
    ------
    ValueError
        If ``bands`` is not positive or ``seeds`` is not a multiple of ``bands``.
    """
    # Validate first so bad settings fail fast, before any hasher is built
    # (and before `seeds % 0` could raise ZeroDivisionError).
    if bands <= 0:
        raise ValueError('bands must be a positive integer, got {}'.format(bands))
    if seeds % bands != 0:
        raise ValueError('Seeds has to be a multiple of bands. {} % {} != 0'.format(seeds, bands))

    # Import locally: the notebook rebinds the global name `cache` to a Cache
    # instance (`hasher, cache = train(...)`), which would break `cache.Cache`
    # on a re-run.
    from lsh.cache import Cache

    hasher = minhash.MinHasher(seeds=seeds, char_ngram=char_ngram, hashbytes=hashbytes)

    def _fingerprint(doc_id, doc):
        # Keep the id attached so results can be matched back to documents.
        return doc_id, hasher.fingerprint(doc)

    # Threading backend: workers share `hasher` in-process rather than pickling it.
    out = Parallel(n_jobs=n_jobs, verbose=1, backend='threading')(
        delayed(_fingerprint)(doc_id, doc) for doc_id, doc in document_feed
    )

    lshcache = Cache(num_bands=bands, hasher=hasher)
    for doc_id, fingerprint in out:
        lshcache.add_fingerprint(fingerprint, doc_id=doc_id)

    return hasher, lshcache

In [5]:
def corpus_iter(files: List[Path], name: str):
    """Lazily yield every document from a list of token-shard files.

    Each shard is loaded with ``np.load`` and split into documents by
    ``split_docs``. Documents are numbered consecutively across all shards,
    starting at 0.

    Yields
    ------
    ((index, name), payload)
        ``index`` is the running document number within this call. ``payload``
        is the document converted to int32 and serialized with ``tobytes()``.
    """
    def _documents():
        # Load shards one at a time, announcing each before it is read.
        for shard_path in files:
            print("Loading file:", shard_path)
            yield from split_docs(np.load(shard_path))

    for position, document in enumerate(_documents()):
        yield (position, name), document.astype(np.int32).tobytes()

In [None]:
# Stream all WebText ('wt') documents first, then all OpenWebText ('owtc')
# documents, as (doc_id, token_bytes) pairs.
corpus = chain.from_iterable(
    corpus_iter(shard_files, name=corpus_name)
    for shard_files, corpus_name in ((wt_files, 'wt'), (owtc_files, 'owtc'))
)
# NOTE(review): this rebinds `cache`, shadowing the `lsh.cache` module imported
# at the top. Any later module-level use of `cache.Cache` will see this instance.
hasher, cache = train(corpus, n_jobs=96)  # 96 fingerprinting threads

[Parallel(n_jobs=96)]: Using backend ThreadingBackend with 96 concurrent workers.


Loading file: /homes/gws/sgehman/data/language-model-toxicity/data/webtext/x00_tokens.npy


[Parallel(n_jobs=96)]: Done  11 tasks      | elapsed:    4.6s
[Parallel(n_jobs=96)]: Done 261 tasks      | elapsed:    4.7s
[Parallel(n_jobs=96)]: Done 611 tasks      | elapsed:    4.8s
[Parallel(n_jobs=96)]: Done 1061 tasks      | elapsed:    5.1s
[Parallel(n_jobs=96)]: Done 1611 tasks      | elapsed:    5.2s
[Parallel(n_jobs=96)]: Done 2261 tasks      | elapsed:    5.4s
[Parallel(n_jobs=96)]: Done 3011 tasks      | elapsed:    5.7s
[Parallel(n_jobs=96)]: Done 3861 tasks      | elapsed:    6.0s
[Parallel(n_jobs=96)]: Done 4811 tasks      | elapsed:    6.2s
[Parallel(n_jobs=96)]: Done 5861 tasks      | elapsed:    6.6s
[Parallel(n_jobs=96)]: Done 7011 tasks      | elapsed:    7.1s
[Parallel(n_jobs=96)]: Done 8261 tasks      | elapsed:    7.5s
[Parallel(n_jobs=96)]: Done 9611 tasks      | elapsed:    8.1s
[Parallel(n_jobs=96)]: Done 11061 tasks      | elapsed:    8.5s
[Parallel(n_jobs=96)]: Done 12611 tasks      | elapsed:    9.0s
[Parallel(n_jobs=96)]: Done 14261 tasks      | elapsed: 

In [None]:
# Collect every candidate duplicate pair from the LSH index built above.
# NOTE(review): presumably each pair is (doc_id_a, doc_id_b) with ids shaped
# (index, corpus_name), i.e. documents that collided in some LSH band — confirm
# against lsh.cache.Cache.get_all_duplicates.
all_duplicates = cache.get_all_duplicates()

In [None]:
# Persist the raw candidate pairs. The path is relative to the current working
# directory, which the first cell moves to the project root.
dump(all_duplicates, filename='webtext_dups.joblib')

## Filter candidates: keep cross-corpus pairs above the Jaccard threshold

In [None]:
# Only pairs spanning the two corpora are interesting. Doc ids are
# (index, corpus_name), so compare the corpus-name element. The surviving pairs
# are then passed to the LSH cache, which presumably keeps those whose Jaccard
# similarity meets MIN_JACCARD.
MIN_JACCARD = 0.9
candidate_duplicates = [(x, y) for x, y in all_duplicates if x[1] != y[1]]
filtered_duplicates = cache.filter_candidates(candidate_duplicates, min_jaccard=MIN_JACCARD)

In [44]:
# Lazily indexable views of each corpus. `load_example` indexes these by
# document number and calls `.compute()` to materialize one document.
wt_corpus, owtc_corpus = (delayed_corpus(meta) for meta in (wt_meta, owtc_meta))

In [49]:
def load_example(x):
    """Materialize the document for a document id produced by ``corpus_iter``.

    Parameters
    ----------
    x : tuple
        Document id ``(index, corpus_name)``. ``corpus_name`` must be ``'wt'``
        or ``'owtc'``.

    Returns
    -------
    The result of ``.compute()`` on ``wt_corpus[index]`` or ``owtc_corpus[index]``.

    Raises
    ------
    RuntimeError
        If ``corpus_name`` is neither ``'wt'`` nor ``'owtc'``.
    """
    idx, corpus_name = x[0], x[1]
    if corpus_name == 'wt':
        return wt_corpus[idx].compute()
    if corpus_name == 'owtc':
        return owtc_corpus[idx].compute()
    raise RuntimeError('Unknown corpus name {!r} in document id {!r}'.format(corpus_name, x))

# Spot-check the first surviving pair: print both ids, then the first 20 tokens
# of each document so they can be compared by eye.
for first_id, second_id in filtered_duplicates:
    print(first_id, second_id)
    first_doc = load_example(first_id)
    second_doc = load_example(second_id)
    print(first_doc[:20], '\n', second_doc[:20])
    break

(67320, 'owtc') (94085, 'wt')
[  464 16009   286  2947  6736   874    11   543  5839   262  8919    12
  3106 16009   286 23729    12    46    12    44] 
 [  464 16009   286  2947  6736   874    11   543  5839   262  8919    12
  3106 16009   286 23729    12    46    12    44]
