Test Steps:
1. Get list of all text files. This will probably be in the millions
2. Grab one text file and run it through python text_processing

In [131]:
import sys
import os
import time
sys.path.append(os.path.abspath("/home/arxiv/doc_intel_etl"))
import config
import src.blob_data_transfer as blob_transfer
from torch.utils.data import Dataset, IterableDataset, DataLoader
from transformers import DistilBertTokenizer
'''
When loading large data via CPU in the DataLoader and need to push it
to GPU during training then we should set pin_memory to True as it will
speed up the host device transfer
source: https://discuss.pytorch.org/t/when-to-set-pin-memory-to-true/19723
'''
from torch.utils.data._utils.pin_memory import pin_memory
import threading
from torch._six import queue, container_abcs, string_classes
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from itertools import cycle, islice, chain
import random
import torch
import multiprocessing as mp

In [132]:
prefix = 'arxiv_training_data/pdfplumber/text/1991'

In [133]:
def get_text_list():
    return blob_transfer.get_blob_list(prefix)
def get_text_stream(file):
    return blob_transfer.stream_blob(file).decode()

In [112]:
class TextIterableDataset(IterableDataset):
    '''
    TO DO:
    1. Create another file that has the tokenizer. This should probably
    be the model.py file itself
    '''
    def __init__(self, data_list, seq_length=256, batch_size=2):
        self.data_list = data_list
        self.seq_length = seq_length
        self.batch_size = batch_size
        # set max sequence length to 30,0000 so that we can tokenize
        # all text in a document and not have to run it more than
        # once per document
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased',
            model_max_length=30000)
     
    @property
    def shuffle_data_list(self):
        return random.sample(self.data_list, len(self.data_list))
    
    def get_text_stream(self, file):
        return blob_transfer.stream_blob(file).decode()
    
    def tokenize_stream(self, stream):
        return self.tokenizer.tokenize(stream.replace('\n',''))
    
    def encode_seq(self, tokens):
        '''
        Take input tokens and convert it into form needed for model
        '''
        return tokenizer(tokens,
                         max_length=self.seq_length,
                         padding=True,
                         truncation=True,
                         is_pretokenized=True,
                         return_tensors='pt',
                        )
    
    def parse_file(self, file):
        '''
        1. stream the given text file from blob
        2. run the text file through pre-processing then tokenizer
        3. break it up into seq_length sizes and yield those
        '''
        stream = self.get_text_stream(file)
        worker = torch.utils.data.get_worker_info()
        worker_id = worker.id if worker is not None else -1
        
        out = []
        token_stream = self.tokenize_stream(stream)
        
        for token_num, token in enumerate(token_stream):
            if token is None:
                break
            out.append(token)
            # seq_length - 2 so we can add cls and eos tokens
            if len(out) == (self.seq_length - 2):
                y = self.encode_seq(out)
                out = []
                yield y

    def get_stream(self, data_list):
        '''
        This function will continue to pull tokens from the opened file
        via parse files until it has filled a full batch size
        '''
        print("number of files to stream: ", len(data_list))
        tmp = map(self.parse_file, iter(data_list))
        out = chain.from_iterable(tmp)
        return out

    def __iter__(self):
        chunk_size = len(self.data_list) // self.batch_size
        return zip(
                *[self.get_stream(self.data_list[i*chunk_size:
                                                 (i+1)*chunk_size]) 
                  for i in range(self.batch_size)]
        )
    
    @classmethod
    def split_datasets(cls, data_list, seq_length, batch_size,
                       max_workers):
        
        for n in range(max_workers, 0, -1):
            if batch_size % n == 0:
                num_workers = n
                break
                
        split_size = batch_size // num_workers
        num_files_per_worker = len(data_list) // num_workers
        out = []
        for i in range(num_workers):
            start = i * num_files_per_worker
            end = (i + 1) * num_files_per_worker
            stream_files = data_list[start:end]
            item = cls(stream_files, batch_size=split_size,
                       seq_length=seq_length)
            out.append(item)
        return out

            

class MultiStreamDataLoader:
    
    def __init__(self, datasets, pin_memory=True):
        self.datasets = datasets
        self.pin_memory = pin_memory
            
    def get_stream_loaders(self):
        dataloaders = [
            DataLoader(dataset, num_workers=1, batch_size=None,
                       pin_memory=True) 
            for dataset in self.datasets
        ]
        return zip(*dataloaders)
            
    def join_streams_thread(self, out_queue, device_id, done_event):
        '''
        additional thread putting data into a queue to be collected
        from __iter__
        '''
        torch.set_num_threads(1)
        torch.cuda.set_device(device_id)
        
        for idx, batch_parts in enumerate(self.get_stream_loaders()):
            data = list(chain(*batch_parts))
            
            data = torch.cat([item[:, None] for item in data], dim=1)
            if (
                not done_event.is_self()
                and not isinstance(data, ExceptionWrapper)
            ):
                data = pin_memory(data)
            
            out_queue.put(data, timeout=MP_STATUS_CHECK_INTERVAL)
            
        self._join_memory_thread_done_even.set()
            
    def __iter__(self):
        # thread for collation and memory pinning
        if self.pin_memory:
            self._join_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            self.join_memory_thread = threading.Thread(
                target=self.join_streams_thread,
                args=(
                    self._data_queue,
                    torch.cuda.current_device(),
                    self._join_memory_thread_done_event,
                ),
            )
            self.join_memory_thread.daemon = True
            self.join_memory_thread.start()
            
            while not self._join_memory_thread_done_event.is_set():
                batch = self._data_queue.get(timeout=100000)
                batch = {'data':batch}
                yield batch
            self.join_memory_thread.join()
        else:
            # Single process
            for batch_parts in self.get_stream_loaders():
                batch = list(chain(*batch_parts))
                yield batch

In [5]:
text_list = get_text_list()

In [6]:
test_list = text_list[:30]

In [113]:
datasets = TextIterableDataset.split_datasets(test_list, seq_length=500, batch_size=4, max_workers=4)
loader = MultiStreamDataLoader(datasets, pin_memory=False)
test = []
for batch in islice(loader, 30):
    test.append(batch)

number of files to stream: 7 
number of files to stream:  7
number of files to stream:  7
number of files to stream:  7


In [116]:
tokenizer.decode(test[18][1]['input_ids'][0])

'[CLS] # # al ( so z # # φ = λ # # φ for some λ ) such that ∆ # # φ = 0 and φ is a solution of β = 0, then w ( φ ) is also a solution of β = 0 for any function w. the simplest non # # con # # stan # # tf # # un # # ction obey # # ing the stated conditions is φ = ( a a + a a ) / 2. consequently ( since # # 1 2 3 4 # # in fact φ = a a on q ) 1 2 # # ∞ # # x # # n # # φ = − o ( a a ) ( 2. 49 ) n 1 2 # # n = 0 # # is a per # # tur # # bation with vanishing second order beta function, for any values of the # # con # # stan # # ts o. na # # cco # # rdi # # ng to our an # # sat # # z for the meaning of φ, the marginal operator ( 2. 49 ) corresponds to def # # or # # ming the quad # # ric q to a more general hyper # # sur # # face # # ∞ # # x # # na a = a a + o ( a a ). ( 2. 50 ) 3 4 1 2 n 1 2 # # n = 0 # # in comparing to matrix models, as we will see, the e # # igen # # val # # ue phase space will # # cor # # res # # pon # # d to the a − a plane, and the curve a a = 0 in the a − a plane will

Take the text file and run it through pytorch custom iterabledataset and dataloader. A Dataset in our instance is an entire arxiv paper which can have multiple batches.
We can't use the built in dataloader multiprocessing becuse it will then just send back multiple copies of the same dataset. We want to make the dataloader itself be parallized so probably just use python multiprocessing

In [128]:
doc = text_list[0]
output = parse_file(doc)

In [130]:
doc

{'name': 'arxiv_training_data/pdfplumber/text/1991/hep-lat9107001.txt', 'container': 'arxiv', 'snapshot': None, 'blob_type': <BlobType.BlockBlob: 'BlockBlob'>, 'metadata': {}, 'encrypted_metadata': None, 'last_modified': datetime.datetime(2020, 7, 15, 13, 20, 11, tzinfo=datetime.timezone.utc), 'etag': '0x8D828C1CD6CC37C', 'size': 28866, 'content_range': None, 'append_blob_committed_block_count': None, 'page_blob_sequence_number': None, 'server_encrypted': True, 'copy': {'id': None, 'source': None, 'status': None, 'progress': None, 'completion_time': None, 'status_description': None, 'incremental_copy': None, 'destination_snapshot': None}, 'content_settings': {'content_type': 'application/octet-stream', 'content_encoding': None, 'content_language': None, 'content_md5': bytearray(b'W \x16\xbab\xd5\xdd\xa3.\xcc\xc3\xcb\xe4\xf0\x05k'), 'content_disposition': None, 'cache_control': None}, 'lease': {'status': 'unlocked', 'state': 'available', 'duration': None}, 'blob_tier': 'Hot', 'blob_tier

In [78]:
text

['–',
 '91',
 '–',
 '31',
 '##fs',
 '##u',
 '-',
 'sc',
 '##ri',
 '-',
 '91',
 '-',
 '94',
 '§',
 'how',
 'to',
 'put',
 'a',
 'heavier',
 'hi',
 '##ggs']