In [None]:
# HACK: use project root as the working directory 
from pathlib import Path

while Path.cwd().name != 'language-model-toxicity':
    %cd ..

In [2]:
from typing import List
import logging
from pathlib import Path
import tempfile

import dask
import dask.array as da
from joblib import Memory
import numpy as np
from tqdm.auto import tqdm

from utils.constants import DATA_DIR, OUTPUT_DIR

# Disable logging from transformers
logging.disable(logging.CRITICAL)

# Create joblib memory
mem = Memory(OUTPUT_DIR / 'cache' / 'webtext_overlap')

In [3]:
EOS = 50256
vocab_size = EOS + 1

def load_meta(bpe_dir: Path):
    files = [file for file in bpe_dir.iterdir() if file.suffix == '.npy']
    meta = [(np.count_nonzero(array == EOS) - 1, array.dtype)
            for array 
            in tqdm(map(np.load, files), total=len(files), desc='Loading meta')]
    shapes, dtypes = zip(*meta)
    return files, shapes, dtypes[0]

# Cache calls to load_meta
load_meta = mem.cache(load_meta)

## Load metadata

In [4]:
wt_dir = DATA_DIR / 'webtext'
wt_meta = load_meta(wt_dir)

In [5]:
owtc_dir = DATA_DIR / 'openwebtext_bpe'
owtc_meta = load_meta(owtc_dir)

## Load corpus

In [5]:
def split_docs(tokens: np.array) -> np.array:
    idx = np.nonzero(tokens == EOS)[0]
    docs = np.split(tokens, idx)
    docs = [doc[1:] for doc in docs if len(doc) > 1]
    return np.array(docs)

def load_corpus_into_memory(files: List[Path]):
    corpus = []
    for shard in tqdm(map(np.load, files), total=len(files)):
        corpus.extend(split_docs(shard))
    return corpus

delayed_load = dask.delayed(lambda f: split_docs(np.load(f)))

def load_corpus(meta):
    files, shapes, dtype = meta
    
    # Create delayed arrays
    delayed_arrays = list(map(delayed_load, files))
        
    # Concatenate arrays
    corpus = da.concatenate([da.from_delayed(array, shape=(shape,), dtype=dtype) 
                             for array, shape in zip(delayed_arrays, shapes)])

    return corpus

In [6]:
# Load OWTC into memory
# owtc_corpus = []
# for shard in tqdm(map(np.load, owtc_files), total=len(owtc_files)):
#     owtc_corpus.extend(split_docs(shard))

In [7]:
wt_corpus = load_corpus(wt_meta)
wt_corpus

Unnamed: 0,Array,Chunk
Bytes,33.13 MB,1.66 MB
Shape,"(8282020,)","(414101,)"
Count,60 Tasks,20 Chunks
Type,int32,numpy.ndarray
"Array Chunk Bytes 33.13 MB 1.66 MB Shape (8282020,) (414101,) Count 60 Tasks 20 Chunks Type int32 numpy.ndarray",8282020  1,

Unnamed: 0,Array,Chunk
Bytes,33.13 MB,1.66 MB
Shape,"(8282020,)","(414101,)"
Count,60 Tasks,20 Chunks
Type,int32,numpy.ndarray


In [9]:
owtc_corpus = load_corpus(owtc_meta)
owtc_corpus

Unnamed: 0,Array,Chunk
Bytes,16.01 MB,803.01 kB
Shape,"(8003003,)","(401504,)"
Count,60 Tasks,20 Chunks
Type,uint16,numpy.ndarray
"Array Chunk Bytes 16.01 MB 803.01 kB Shape (8003003,) (401504,) Count 60 Tasks 20 Chunks Type uint16 numpy.ndarray",8003003  1,

Unnamed: 0,Array,Chunk
Bytes,16.01 MB,803.01 kB
Shape,"(8003003,)","(401504,)"
Count,60 Tasks,20 Chunks
Type,uint16,numpy.ndarray


## Features

In [9]:
doc

array([38248,    12, 32117,   402,   562, 22161,    25,   198,   198,
        2215,   262,  4196, 11745,  1628,  2067,    11,   262,  1994,
        2551,   373,   262,  3572,   286,  3788,  3113,    13, 10358,
        4196,  1949,   284,   787,   257,   705, 36890,     6,  2196,
         286,  7294,  1395,   357,   292,   340,   373,   788,  1900,
       19427,  1514,   287,   257,  3190,   649,  4571,    30,   198,
         198,  1026,  3568,   326,   257,   649,  4571,   743,   423,
         587, 29850,    13,  1629,   262,   640,   326,  4196,   338,
       11745,  1628,  2540,    11,   281,  4196,  6538,   290,  1966,
        1355, 11949,  4438, 18358,  3457,    13,   720,  7410,    42,
         329,   257,  1355,  2640,   366,  8189, 10285,     1,   532,
         655,   262,  2438,    11,   645,  1104,    11,   645, 36506,
          13,   383, 11949,   373,  4047, 14462,   329,   465,  5032,
         287, 31993,  3788,   284, 22594,  6890,    26,  1355,  2640,
         373,   257,

In [11]:
b = 'hello'.encode('utf-8')

In [12]:
b2 = doc.tobytes()

In [13]:
len(b2)

620

In [14]:
len(b2) // len(doc)

4

In [10]:
from lsh import cache, minhash

In [19]:
hasher = minhash.MinHasher(seeds=80, char_ngram=5, hashbytes=4)  # since we're using 32-byte ints, we can use 4 bytes
lshcache = cache.Cache(bands=10, hasher=hasher)

In [20]:
hasher.fingerprint(b2)

array([2150305125, 2160392694, 2158075182, 2150469502, 2152799544,
       2147858006, 2154660614, 2160704365, 2156196008, 2149272090,
       2148068488, 2152160651, 2152305272, 2171110878, 2159990308,
       2157636046, 2156996108, 2147948018, 2162669319, 2151293792,
       2147677155, 2149692153, 2156417286, 2152144853, 2158103450,
       2157017166, 2150197974, 2157353449, 2155566940, 2152023485,
       2166517483, 2155764073, 2152995929, 2148231255, 2149398974,
       2151033421, 2166144573, 2157676932, 2154821761, 2150564906,
       2148510238, 2157834002, 2157552631, 2158241615, 2154368359,
       2167131450, 2161318122, 2160361237, 2152173231, 2161572934,
       2148841526, 2157222636, 2152806184, 2161536730, 2155921600,
       2181214496, 2161523500, 2154681551, 2150251178, 2151492337,
       2154772495, 2155525684, 2156898603, 2149267936, 2156676763,
       2160752189, 2147618819, 2154071920, 2184345986, 2150133564,
       2155667299, 2166296121, 2150779422, 2152638220, 2155156

In [21]:
out = _

In [22]:
len(out)

80

In [27]:
from joblib import Parallel, delayed

# def candidate_duplicates(document_feed, bpe_ngram=3, seeds=100, bands=5, hashbytes=4, n_jobs=1):
        
        
#     for i_line, line in enumerate(document_feed):
#         line = line.decode('utf8')
#         docid, headline_text = line.split('\t', 1)
#         fingerprint = hasher.fingerprint(headline_text.encode('utf8'))
        
#         # in addition to storing the fingerpring store the line
#         # number and document ID to help analysis later on
#         lshcache.add_fingerprint(fingerprint, doc_id=(i_line, docid))

#     candidate_pairs = set()
#     for b in lshcache.bins:
#         for bucket_id in b:
#             if len(b[bucket_id]) > 1:
#                 pairs_ = set(itertools.combinations(b[bucket_id], r=2))
#                 candidate_pairs.update(pairs_)
    
#     return candidate_pairs

In [32]:
# corpus = wt_corpus.compute()

In [40]:
def train(document_feed, char_ngram=3, seeds=100, bands=5, hashbytes=4, n_jobs=1):
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=char_ngram, hashbytes=hashbytes)
    if seeds % bands != 0:
        raise ValueError('Seeds has to be a multiple of bands. {} % {} != 0'.format(seeds, bands))

    lshcache = cache.Cache(num_bands=bands, hasher=hasher)
    
    hash_bytes = lambda x: hasher.fingerprint(x.tobytes())
    fingerprints = Parallel(n_jobs=n_jobs, verbose=1)(
        delayed(hash_bytes)(doc) for doc in document_feed
    )

    # TODO: use document ids
    for i, fingerprint in enumerate(fingerprints):
        lshcache.add_fingerprint(fingerprint, doc_id=i)
    
    return hasher, lshcache

In [42]:
train(corpus, n_jobs=80)

[Parallel(n_jobs=80)]: Using backend LokyBackend with 80 concurrent workers.
[Parallel(n_jobs=80)]: Done  41 tasks      | elapsed:    0.4s
[Parallel(n_jobs=80)]: Done 422 tasks      | elapsed:    1.0s
[Parallel(n_jobs=80)]: Done 1122 tasks      | elapsed:    1.8s
[Parallel(n_jobs=80)]: Done 2022 tasks      | elapsed:    2.9s
[Parallel(n_jobs=80)]: Done 3122 tasks      | elapsed:    4.2s
[Parallel(n_jobs=80)]: Done 4422 tasks      | elapsed:    5.7s
[Parallel(n_jobs=80)]: Done 5922 tasks      | elapsed:    7.8s
[Parallel(n_jobs=80)]: Done 7622 tasks      | elapsed:    9.9s
[Parallel(n_jobs=80)]: Done 9522 tasks      | elapsed:   12.1s
[Parallel(n_jobs=80)]: Done 11622 tasks      | elapsed:   14.7s
[Parallel(n_jobs=80)]: Done 13922 tasks      | elapsed:   17.5s
[Parallel(n_jobs=80)]: Done 16422 tasks      | elapsed:   20.8s
[Parallel(n_jobs=80)]: Done 19122 tasks      | elapsed:   24.0s
[Parallel(n_jobs=80)]: Done 22022 tasks      | elapsed:   27.5s
[Parallel(n_jobs=80)]: Done 25122 task

KeyboardInterrupt: 