In [1]:
import sys
import os
from nbfinder import NotebookFinder
sys.meta_path.append(NotebookFinder())
import numpy as np
import threading

    

importing Jupyter notebook from ../../../notebooks_src/load_data/configs.ipynb


In [3]:
class GenThreadSafe(object):
    """Thread-safe batch iterator that crops images (and labels) to a target shape.

    Wraps a dataset object and pulls one batch per call to :meth:`next`.
    The fetch-and-crop runs under a lock, so several consumer threads can
    share one instance.

    Args:
        dataset: object exposing ``next_batch(batch_size=...)`` (returning
            ``(ims, lbls)``), ``num_examples`` and ``images``.
        shape: target ``(nch, xdim, ydim)`` for each image. Defaults to the
            last three dimensions of ``dataset.images`` (presumably the
            per-example shape if ``images`` carries a leading example axis --
            TODO confirm against the dataset class).
        batch_size: number of examples requested per batch.
        typ: accepted for interface compatibility; currently unused.
        tf_mode: if True, image batches are transposed from ``(N, C, X, Y)``
            to ``(N, X, Y, C)``.
        num_ex: accepted for interface compatibility; currently unused.
        make_label_fxn: optional callable applied to the raw label batch
            before the labels are cropped.
    """

    def __init__(self, dataset, shape=None, batch_size=128, typ="tr", tf_mode=False, num_ex=-1, make_label_fxn=None):
        self.data = dataset
        self.tf_mode = tf_mode
        self.batch_size = batch_size
        # Serialises access to the shared dataset across consumer threads.
        self.lock = threading.Lock()
        self.make_label_fxn = make_label_fxn
        if shape:
            self.input_shape = shape
        else:
            # ``next`` unpacks input_shape into exactly (nch, xdim, ydim).
            # Keep only the trailing three dims so a 4-D image array (with a
            # leading example axis) does not break the default path.
            self.input_shape = tuple(self.data.images.shape[-3:])

    def __iter__(self):
        return self

    @property
    def num_ims(self):
        """Total number of examples reported by the wrapped dataset."""
        return self.data.num_examples

    def next(self):
        """Return the next ``(ims, lbls)`` batch, cropped to ``input_shape``.

        Images are cropped to ``(nch, xdim, ydim)``. They are transposed to
        channels-last when ``tf_mode`` is set. Label axes 1 and 2 are cropped
        by the same fraction that was applied to the image spatial axes.
        """
        nch, xdim, ydim = self.input_shape

        # The whole fetch-and-crop runs under the lock so concurrent callers
        # never interleave access to the dataset.
        with self.lock:
            ims, lbls = self.data.next_batch(batch_size=self.batch_size)

            # Original spatial size, needed to scale the labels below.
            _raw_nch, raw_xdim, raw_ydim = ims.shape[1:]

            # Crop to the requested size.
            ims = ims[:, :nch, :xdim, :ydim]

            if self.tf_mode:
                ims = np.transpose(ims, axes=(0, 2, 3, 1))
            if self.make_label_fxn:
                lbls = self.make_label_fxn(lbls)

            # Shrink the label grid by the same ratio as the image. Integer
            # arithmetic avoids float truncation (int(100 * (29 / 100.)) == 28).
            lblx, lbly = lbls.shape[1:3]
            lbls = lbls[:, :lblx * xdim // raw_xdim, :lbly * ydim // raw_ydim]

            return ims, lbls

    # Python 3 iterator protocol; ``next`` is kept for existing callers.
    __next__ = next

In [4]:
# Thread-safe: calls to next() are serialised by a lock.
class SemisupWrapper(object):
    """Adapter that pairs each ``(ims, lbls)`` batch with a reconstruction target.

    Yields ``(ims, {"box_score": lbls, "reconstruction": ims})``. The input
    images are reused as the ``"reconstruction"`` target.

    Args:
        generator: the wrapped batch source. Objects exposing a ``next()``
            method (such as ``GenThreadSafe``) and standard Python 3
            iterators/generators are both supported. ``num_ims`` is read
            from it.
    """

    def __init__(self, generator):
        self.generator = generator
        # Serialises calls into the wrapped generator.
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    @property
    def num_ims(self):
        """Number of examples reported by the wrapped generator."""
        return self.generator.num_ims

    def next(self):
        """Return ``(ims, {"box_score": lbls, "reconstruction": ims})``."""
        with self.lock:
            # Prefer the legacy ``.next()`` method (what the original code
            # called); fall back to the builtin for Python 3 iterators.
            advance = getattr(self.generator, "next", None)
            if advance is not None:
                ims, lbls = advance()
            else:
                ims, lbls = next(self.generator)
            return ims, {"box_score": lbls, "reconstruction": ims}

    # Python 3 iterator protocol; ``next`` is kept for existing callers.
    __next__ = next
        