In [1]:
import numpy as np

# add parent directory to path
import os, sys
sys.path.append('..')

from utils.mmap_dataset import MMapIndexedDataset
import dask
import dask.array as da
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from dask.diagnostics import ProgressBar
from dask.distributed import Lock
from tqdm import tqdm
from dask.diagnostics import ProgressBar
from numpy.lib.stride_tricks import sliding_window_view

In [2]:
from dask_jobqueue import SLURMCluster

# Template for one SLURM job: 4 cores and 32GB split across 2 worker
# processes, submitted to the "normal" queue with a 48-hour wall-clock limit.
cluster = SLURMCluster(
    cores=4,
    processes=2,
    memory="32GB",
    walltime="48:00:00",
    # project="fiete",
    queue="normal",
)
# Ask the scheduler for 128 such jobs.
cluster.scale(jobs=128)

In [3]:
from dask.distributed import Client
# Connect this notebook session to the SLURM-backed cluster created above.
client = Client(cluster)

In [4]:
# Display the cluster summary (dashboard address, worker/thread/memory totals).
cluster

0,1
Dashboard: http://172.16.20.133:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://172.16.20.133:46795,Workers: 0
Dashboard: http://172.16.20.133:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [5]:
# Open the memory-mapped, pre-tokenized dataset (path prefix without extension;
# the matching `document.bin` file is read directly by mmap_dask_array below).
# NOTE(review): skip_warmup=True presumably avoids pre-reading the file — confirm in utils.mmap_dataset.
dataset = MMapIndexedDataset('/om/user/sunnyd/data/datasets--EleutherAI--pile-standard-pythia-preshuffled-merged/document', skip_warmup = True)

    reading sizes...
    reading pointers...
    reading document index...
    creating numpy buffer of mmap...
    creating memory view of numpy buffer...


In [6]:
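# One-off preprocessing, kept for reference: it wrote the `matching_data/` npy stack
# (first 64 tokens of each selected document) that is loaded further below.
# indices = np.load('0-1-43-idx.npy')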
# emergent = []
# for idx in tqdm(indices):
#     emergent.append(dataset[int(idx)][:64])
# emergent = da.from_array(np.stack(emergent, axis=0), chunks=(100, 64))
# da.to_npy_stack(
#         'matching_data/',
#         emergent,
#         axis=0)

In [7]:

@dask.delayed
def load_chunk(path, ptr, total_size, dtype, seq_len=2049):
    """Read a run of fixed-length token sequences from a flat binary file.

    The function is wrapped in ``dask.delayed``, so calling it only builds a
    lazy task; the file is touched when that task is computed.

    Parameters
    ----------
    path : str
        Binary token file to memory-map.
    ptr : int
        Byte offset of the first token to read.
    total_size : int
        Number of tokens (items of ``dtype``) to read. Must be a multiple of
        ``seq_len``.
    dtype : numpy dtype
        Token dtype stored in the file.
    seq_len : int, default 2049
        Tokens per row of the returned array.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(total_size // seq_len, seq_len)`` backed by the
        read-only memory map (no copy is made here).

    Raises
    ------
    ValueError
        If ``total_size`` is not a multiple of ``seq_len``.
    """
    total_size = int(total_size)
    if total_size % seq_len != 0:
        raise ValueError(
            f"total_size={total_size} is not a multiple of seq_len={seq_len}"
        )
    # No dtype is given, so the file is mapped as raw bytes (np.memmap's
    # default is uint8); np.frombuffer's `offset` is likewise in bytes while
    # `count` is in items of `dtype`.
    bin_buffer_mmap = np.memmap(path, mode="r", order="C")
    tokens = np.frombuffer(
        memoryview(bin_buffer_mmap),
        dtype=dtype,
        count=total_size,
        offset=int(ptr),
    )
    return tokens.reshape(-1, seq_len)
    

def mmap_dask_array(blocksize=1000, offset=0, max=50000):
    """Expose a range of dataset documents as a lazy 2-D dask array.

    Documents ``offset`` up to (but excluding) ``min(max, len(dataset))`` are
    split into blocks of ``blocksize`` rows. Each block becomes one dask chunk
    that is read on demand by ``load_chunk``.

    Parameters
    ----------
    blocksize : int
        Rows (documents) per dask chunk; the final chunk may be shorter.
    offset : int
        Index of the first document to include.
    max : int
        Exclusive upper bound on the document index (clamped to the dataset
        length).

    Returns
    -------
    dask.array.Array
        Array of shape ``(n_documents, 2049)``.

    Raises
    ------
    ValueError
        If the requested range contains no documents.
    """
    path = '/om/user/sunnyd/data/datasets--EleutherAI--pile-standard-pythia-preshuffled-merged/document.bin'
    # The declared chunk shape assumes every document holds exactly this many
    # tokens; load_chunk reshapes to the same width.
    seq_len = 2049
    index = dataset._index
    # Declare the dtype load_chunk actually produces, instead of reading a
    # document from the dataset just to inspect its dtype.
    dtype = index.dtype

    max_idx = min(max, len(dataset))
    if offset >= max_idx:
        raise ValueError(
            f"empty document range: offset={offset}, upper bound={max_idx}"
        )

    chunks = []
    for start in tqdm(range(offset, max_idx, blocksize)):
        chunk_size = min(blocksize, max_idx - start)
        ptr = index._pointers[start]
        count = np.sum(index._sizes[start:start + chunk_size])
        # load_chunk is already decorated with dask.delayed, so calling it
        # yields the lazy task directly; no second wrapping is needed.
        chunk = da.from_delayed(
            load_chunk(path, ptr, count, dtype),
            shape=(chunk_size, seq_len),
            dtype=dtype,
        )
        chunks.append(chunk)
    return da.concatenate(chunks, axis=0)

def match(a, b):
    """Best sliding-window token overlap between each query row and each data row.

    For every query ``a[i]`` (width ``w``) and every row ``b[j]``, a window of
    width ``w`` is slid along ``b[j]``. At each offset the number of positions
    where window and query hold the same token is counted, and the maximum
    count over all offsets is reported. A value equal to ``w`` therefore means
    ``a[i]`` occurs verbatim somewhere in ``b[j]``.

    Unlike the previous version, the number of rows in ``b`` and the query
    width are taken from the inputs rather than fixed at 100 and 64, so short
    trailing blocks and other block sizes are handled.

    Parameters
    ----------
    a : numpy.ndarray, shape (n_a, w)
        Query token sequences.
    b : numpy.ndarray, shape (n_b, L) with L >= w
        Token sequences to search in.

    Returns
    -------
    numpy.ndarray of int, shape (n_a, n_b, 1, 1)
        Maximum overlap counts. The two trailing singleton axes give the
        4-D block shape used by the ``da.blockwise(..., 'ijab', ...)`` calls
        in this notebook.

    Raises
    ------
    ValueError
        If ``b`` has fewer than ``w`` columns (raised by
        ``sliding_window_view``).

    Notes
    -----
    A boolean intermediate of shape ``(n_a, n_b, L - w + 1, w)`` is
    materialized, so blocks must be kept small.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    width = a.shape[-1]
    queries = a.reshape(-1, width)
    # Zero-copy view of shape (n_b, L - w + 1, w): every width-w window of
    # every row of b. Sliding along axis 1 only keeps the row axis intact
    # regardless of how many rows b has.
    windows = sliding_window_view(b, width, axis=1)
    # Direct equality instead of abs(difference) == 0: same result, and no
    # reliance on unsigned wrap-around for unsigned token dtypes.
    hits = windows[np.newaxis] == queries[:, np.newaxis, np.newaxis, :]
    best = hits.sum(axis=-1).max(axis=-1)
    return best[:, :, np.newaxis, np.newaxis]


# Lazily load the query sequences from the npy stack stored under matching_data/.
# NOTE(review): presumably rows of 64 tokens, as match() reshapes queries to width 64 — confirm.
emergent = da.from_npy_stack('matching_data/')

In [8]:
# Number of query sequences (rows of `emergent`).
emergent.shape[0]

11904

## Implement search

In [10]:

# with ProgressBar():
#     res = da.blockwise(match, 'ijab', emergent, 'ia', x[:2000], 'jb', dtype=int, 
#                        adjust_chunks={'a': 1, 'b': 1}).squeeze().compute()
# Exhaustive search, split into (data slab i) x (query batch j) jobs so that
# each job's result is written to its own directory and finished jobs can be
# skipped when the cell is re-run.
job_size = 200000          # dataset sequences scanned per job
job_size_emergent = 2000   # query rows handled per job
n_sequences = 43000000     # upper bound on scanned sequences (clamped to len(dataset) below)

n_emergent = emergent.shape[0]
# Ceiling division: floor division silently dropped the trailing partial
# batch (e.g. the last 1904 of 11904 query rows were never searched).
n_emergent_jobs = -(-n_emergent // job_size_emergent)
n_data_jobs = -(-n_sequences // job_size)

# Make sure the parent output directory exists before any job writes into it.
os.makedirs("matches-count", exist_ok=True)

for j in range(n_emergent_jobs):
    queries = emergent[j*job_size_emergent:min(n_emergent, (j+1)*job_size_emergent)]
    for i in range(n_data_jobs):
        if i * job_size >= len(dataset):
            # Nothing left to scan; mmap_dask_array would fail on an empty range.
            break
        res_path = f"matches-count/{i}-{j}"
        if os.path.exists(res_path):
            print("skipping "+ res_path)
            continue
        x = mmap_dask_array(100, i * job_size, (i+1) * job_size)
        # Squeeze only the two singleton block axes added by match(), so a
        # genuine axis of length 1 is never dropped by accident.
        counts = da.blockwise(match, 'ijab', queries, 'ia', x, 'jb', dtype=int,
                              adjust_chunks={'a': 1, 'b': 1}).squeeze(axis=(2, 3))
        # Write to a scratch directory and rename only on success. Otherwise an
        # interrupted job leaves a partial `res_path` behind that the existence
        # check above would treat as finished on the next run.
        # NOTE: directories left by interruptions before this change cannot be
        # told apart from complete ones and must be cleaned up manually.
        partial_path = res_path + ".partial"
        da.to_npy_stack(partial_path, counts, axis=1)
        os.rename(partial_path, res_path)

skipping matches-count/0-0
skipping matches-count/1-0
skipping matches-count/2-0
skipping matches-count/3-0
skipping matches-count/4-0
skipping matches-count/5-0
skipping matches-count/6-0
skipping matches-count/7-0
skipping matches-count/8-0
skipping matches-count/9-0
skipping matches-count/10-0
skipping matches-count/11-0
skipping matches-count/12-0
skipping matches-count/13-0
skipping matches-count/14-0
skipping matches-count/15-0
skipping matches-count/16-0
skipping matches-count/17-0
skipping matches-count/18-0
skipping matches-count/19-0
skipping matches-count/20-0
skipping matches-count/21-0
skipping matches-count/22-0
skipping matches-count/23-0
skipping matches-count/24-0
skipping matches-count/25-0
skipping matches-count/26-0
skipping matches-count/27-0
skipping matches-count/28-0
skipping matches-count/29-0
skipping matches-count/30-0
skipping matches-count/31-0
skipping matches-count/32-0
skipping matches-count/33-0
skipping matches-count/34-0
skipping matches-count/35-0
sk

100%|██████████| 2000/2000 [00:00<00:00, 2660.79it/s]
  da.blockwise(match, 'ijab', emergent[j*job_size_emergent:min(emergent.shape[0], (j+1)*job_size_emergent)],


skipping matches-count/0-1
skipping matches-count/1-1
skipping matches-count/2-1
skipping matches-count/3-1
skipping matches-count/4-1
skipping matches-count/5-1
skipping matches-count/6-1
skipping matches-count/7-1
skipping matches-count/8-1


100%|██████████| 2000/2000 [00:00<00:00, 6659.46it/s]
  da.blockwise(match, 'ijab', emergent[j*job_size_emergent:min(emergent.shape[0], (j+1)*job_size_emergent)],


KeyboardInterrupt: 

In [None]:
from pathlib import Path

# Match every query against the dataset in slabs of `rows_per_job` sequences,
# writing one npy stack per slab under matches/<first row index>.
rows_per_job = 500000
for i in range(0, len(dataset), rows_per_job):
    # Build the slab explicitly. The previous code sliced the global `x`,
    # which is only whatever 200000-row window the search loop above assigned
    # last, so every slab beyond it was empty or wrong.
    # mmap_dask_array clamps the upper bound to len(dataset).
    x_slab = mmap_dask_array(100, i, i + rows_per_job)
    Path(f"matches/{i}").mkdir(parents=True, exist_ok=True)
    da.to_npy_stack(
        f"matches/{i}",
        # Squeeze only the two singleton block axes produced by match().
        da.blockwise(match, 'ijab', emergent, 'ia', x_slab, 'jb', dtype=int,
                     adjust_chunks={'a': 1, 'b': 1}).squeeze(axis=(2, 3)),
        axis=1)

In [None]:
# Print out repeats
# tokenizer.decode(x[6111].compute())
# tokenizer.decode(x[1768].compute())

## Use tokenizer to decode

In [None]:
# Load the tokenizer published with the Pythia-70m (deduped) checkpoint at
# revision "step3000", caching downloaded files in the given directory.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    revision="step3000",
    cache_dir="/om/user/sunnyd/transformers_cache",
)

# Smoke test: tokenize a short prompt into PyTorch tensors.
inputs = tokenizer("Hello, I am", return_tensors="pt")
# tokens = model.generate(**inputs)
# tokenizer.decode(tokens[0])

## TODO:
1. Find a way to convert back to natural text
2. Load data into dask


In [23]:
# Sanity check: decode the tokens of dataset entry 15 back into text.
tokenizer.decode(dataset[15])

