In [None]:
#| default_exp experimental.multi_processing

In [None]:
#| export
from __future__ import print_function, division, annotations
from jax_dataloader.imports import *
from jax_dataloader.datasets import ArrayDataset, JAXDataset
from jax_dataloader.loaders import BaseDataLoader
from jax_dataloader.utils import get_config, asnumpy
from jax_dataloader.tests import *
import jax_dataloader as jdl
from threading import Thread, Event
from queue import Queue, Full
import multiprocessing as mp
try:
    import numba
except ModuleNotFoundError:
    numba = None

In [None]:
class Compiler:
    """Optional numba JIT front-end configured through class-level settings.

    Compilation is disabled and execution is single-threaded until changed via
    `set_enabled` / `set_num_threads`.
    """

    # Defaults so `compile` / `get_iterator` work before any setter is called;
    # without them the first call raised AttributeError.
    is_enabled: bool = False
    num_threads: int = 1

    @classmethod
    def set_enabled(cls, is_enable):
        """Turn numba compilation in `compile` on or off."""
        cls.is_enabled = is_enable

    @classmethod
    def set_num_threads(cls, n):
        """Set the thread count; any value < 1 means "use every CPU".

        The count is also forwarded to numba when numba is installed.
        """
        if n < 1:
            n = mp.cpu_count()
        cls.num_threads = n
        # numba is optional at import time (it may be None); only configure it
        # when it is actually available.
        if numba is not None:
            numba.set_num_threads(n)

    @classmethod
    def compile(cls, code, signature=None):
        """JIT-compile `code` with `numba.njit` when enabled, else return it unchanged.

        If `code` has an `is_parallel` attribute, numba's parallel mode is
        requested only when it is truthy and more than one thread is configured.

        Raises:
            ImportError: compilation is enabled but numba is not installed.
        """
        if not cls.is_enabled:
            return code
        if numba is None:
            raise ImportError(
                "Compiler is enabled but numba is not installed.")
        parallel = bool(getattr(code, 'is_parallel', False)) and cls.num_threads > 1
        return numba.njit(
            signature, fastmath=True, nogil=True,
            error_model='numpy', parallel=parallel)(code)

    @classmethod
    def get_iterator(cls):
        """Return `numba.prange` for multi-threaded runs, else the builtin `range`."""
        if cls.num_threads > 1 and numba is not None:
            return numba.prange
        return range


In [None]:
#| export
def chunk(seq: Sequence, size: int) -> List[Sequence]:
    """Split `seq` into consecutive slices of length `size`.

    The last slice holds the remainder and may be shorter. Pieces are produced
    by slicing, so for lists, strings, arrays, etc. each piece typically has
    the same type as `seq`.

    Raises:
        ValueError: if `size` < 1. A negative step would otherwise silently
            yield an empty list, and zero would fail inside `range`.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size}")
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]

In [None]:
#| export
class EpochIterator(Thread):
    """[WIP] Epoch iterator that prefetches batches on a background thread.

    `indices` is split into batches of `batch_size` (via `chunk`). The worker
    thread fetches `data[batch_idx]` for each batch in order and puts it on a
    bounded queue, which iteration drains. An exception raised while fetching
    is re-raised in the consuming thread by `__next__`.
    """

    class _Failure:
        """Carries an exception from the worker thread to the consumer."""
        __slots__ = ('exc',)

        def __init__(self, exc: BaseException):
            self.exc = exc

    # End-of-epoch marker. It is a private object rather than None, so a batch
    # that is itself None cannot be mistaken for the end.
    _END = object()

    def __init__(self, data, batch_size: int, indices: Sequence[int],
                 queue_size: int = 5):
        # daemon=True: an abandoned iterator whose worker waits on a full queue
        # must not keep the interpreter from exiting.
        super().__init__(daemon=True)
        # Termination state goes first, so close()/__del__ still work if a
        # later line of __init__ raises (e.g. an invalid batch_size).
        self.terminate_event = Event()
        self._exhausted = False
        self.data = data
        batches = chunk(indices, batch_size)
        self.iter_idx = iter(batches)
        # Number of prefetched batches held at most; <= 0 means unbounded
        # (queue.Queue semantics).
        self.output_queue = Queue(queue_size)
        self.start()

    def _put_item(self, item) -> bool:
        """Queue `item`, returning False if `close()` was called first.

        Uses a 0.5 s put timeout so termination is noticed while the queue
        is full.
        """
        while not self.terminate_event.is_set():
            try:
                self.output_queue.put(item, block=True, timeout=0.5)
                return True
            except Full:
                continue
        return False

    def run(self):
        """Worker loop: fetch and queue every batch, then queue `_END`."""
        while not self.terminate_event.is_set():
            try:
                result = self.get_data()
            except StopIteration:
                # All batches produced; tell the consumer the epoch is over.
                self._put_item(self._END)
                return
            except Exception as exc:
                # Forward the error instead of letting the thread die, which
                # would leave the consumer blocked forever on get().
                self._put_item(self._Failure(exc))
                return
            if not self._put_item(result):
                return

    def __next__(self):
        if self._exhausted:
            # Keep raising StopIteration once finished or closed, instead of
            # blocking on a queue that nothing will fill again.
            raise StopIteration()
        result = self.output_queue.get()
        if result is self._END:
            self.close()
            raise StopIteration()
        if isinstance(result, self._Failure):
            self.close()
            raise result.exc
        return result

    def __iter__(self):
        return self

    def __del__(self):
        self.close()

    def close(self):
        """Signal the worker to stop and mark this iterator exhausted."""
        self._exhausted = True
        self.terminate_event.set()

    def get_data(self):
        """Return `data` indexed by the next batch of indices.

        Raises StopIteration when no batches remain.
        """
        batch_idx = next(self.iter_idx)
        batch = self.data[batch_idx]
        return batch
