In [1]:
from itertools import cycle
import numpy as np
import os
import pandas as pd
from scipy import misc

from warnings import filterwarnings
filterwarnings('ignore')

%run 'multilabel.ipynb'

from keras.utils import to_categorical
from numpy import argmax

In [4]:
class DataSet(object):
    """Serve batches of paired 2D images for training, validation and testing.

    Training data is described by ``<datapath>/infos.csv`` (index column
    ``nr``; the ``scatter`` column holds the input image paths and the
    ``real`` column the target image paths).  Test samples are listed in
    ``<datapath>/submission_format.csv`` (index column ``filename``) and are
    read from ``<datapath>/<dataset_type>/``.

    The three generators (`batches`, `val_batches`, `test_batches`) loop
    forever over whole batches only: the ``num_samples % batch_size``
    trailing samples of each split are never yielded.
    """

    # Supported ``dataset_type`` values mapped to their (height, width).
    _IMAGE_SIZES = {
        'nano_2D': (56, 56),
        'micro_2D': (128, 128),
    }

    def __init__(self, datapath, dataset_type='micro_2D', val_size=0.3, batch_size=16, test=False):
        """
        Load the train/validation split and the test index.

        Parameters
        ----------
        datapath : directory holding ``infos.csv``, ``submission_format.csv``
            and the ``<dataset_type>`` image folder.
        dataset_type : ``'micro_2D'`` (128x128) or ``'nano_2D'`` (56x56).
        val_size : passed as ``size`` to ``multilabel_train_test_split``.
        batch_size : number of samples per yielded batch.
        test : if true, run `_test_batch_generator` once at the end of
            construction (it switches test mode off again when it passes).

        Raises
        ------
        NotImplementedError : for an unsupported ``dataset_type``.
        ValueError : if any split holds fewer samples than one batch.
        """
        self.datapath = datapath
        self.val_size = val_size
        self.batch_size = batch_size
        self.dataset_type = dataset_type

        # boolean for test mode
        self.test = test

        # params based on dataset type
        try:
            size = self._IMAGE_SIZES[dataset_type]
        except (KeyError, TypeError):
            size = None
        if size is None:
            raise NotImplementedError(
                "Please set dataset_type as one of: {0}.".format(
                    ", ".join(sorted(self._IMAGE_SIZES))))
        self.height, self.width = size

        # for tracking errors
        self.bad_images = []

        # training and validation
        self.X_train, self.X_val, self.y_train, self.y_val = self.split_training_into_validation()

        # params of data based on training data
        self.num_samples = self.y_train.shape[0]
        self.num_batches = self.num_samples // self.batch_size

        # test ids and prediction matrix
        self.X_test_ids, self.predictions = self.prepare_test_data_and_prediction()

        # endless batch counters; *_batch_num always holds the batch that
        # the matching generator will serve next
        self.batch_idx = self._batch_cycle(self.num_batches, 'training')
        self.batch_num = next(self.batch_idx)

        self.num_val_samples = self.X_val.shape[0]
        self.num_val_batches = self.y_val.shape[0] // self.batch_size
        self.val_batch_idx = self._batch_cycle(self.num_val_batches, 'validation')
        self.val_batch_num = next(self.val_batch_idx)

        self.num_test_samples = self.X_test_ids.shape[0]
        self.num_test_batches = self.num_test_samples // self.batch_size
        self.test_batch_idx = self._batch_cycle(self.num_test_batches, 'test')
        self.test_batch_num = next(self.test_batch_idx)

        # one row per training sample, flipped to 1 once `batches` served it
        self.train_data_seen = pd.DataFrame(data={'seen': 0}, index=self.y_train.index)

        # test the generator
        if test:
            self._test_batch_generator()

    @staticmethod
    def _batch_cycle(num_batches, split_name):
        """Return an endless counter 0..num_batches-1, rejecting empty splits.

        Cycling over an empty range would make the first ``next()`` raise a
        bare ``StopIteration``; a ``ValueError`` naming the split is clearer.
        """
        if num_batches < 1:
            raise ValueError(
                "The {0} split holds fewer samples than one batch; "
                "use a smaller batch_size.".format(split_name))
        return cycle(range(num_batches))

    def rescaling(self, x):
        """Scale pixel values by 1/255 (TODO: needs to be reworked)."""
        return x / 255  # alternative normalisation once tried: (x - 244.709) / 28.2702

    def binerizer(self, y):
        """
        Turn target images into two-class one-hot arrays (TODO: needs to be reworked).

        Pixels are first mapped to class indices with ``int16(y / 200)``, so
        values below 200 become class 0 and values from 200 up to 399 class 1
        (presumably inputs are 8-bit images, i.e. at most 255 -- confirm).
        Each sample is then encoded with ``to_categorical(..., num_classes=2)``.
        """
        class_ids = np.int16(y / 200)
        return np.asarray([to_categorical(sample, num_classes=2) for sample in class_ids])

    def reconstruct(self, y):
        """
        Convert per-pixel class scores back into label images.

        Each sample of ``y`` is reduced with ``argmax`` over axis 1 and the
        result is reshaped to ``(num_samples, height, width)``.
        """
        labels = np.asarray([sample.argmax(1) for sample in y])
        return labels.reshape(y.shape[0], self.height, self.width)

    def prepare_test_data_and_prediction(self):
        """
        Read ``submission_format.csv`` from the data directory.

        Returns the index of test filenames and the preallocated
        prediction dataframe (the parsed CSV itself).
        """
        predpath = os.path.join(self.datapath, 'submission_format.csv')
        predictions = pd.read_csv(predpath, index_col='filename')
        return predictions.index, predictions

    def split_training_into_validation(self):
        """
        Uses the multilabel_train_test_split function
        to split the filenames listed in ``infos.csv`` into train and validation.

        Returns ``(X_train, X_val, y_train, y_val)`` where X comes from the
        ``scatter`` column and y from the ``real`` column.
        """
        # load training labels
        labelpath = os.path.join(self.datapath, 'infos.csv')
        labels = pd.read_csv(labelpath, index_col='nr')

        # split (fixed seed keeps the split reproducible between runs)
        return multilabel_train_test_split(X=labels['scatter'], Y=labels['real'],
                                           size=self.val_size, seed=42)

    def _paired_batch(self, inputs, targets, batch_num):
        """Load batch ``batch_num`` of input and target images.

        Returns ``(x, y, index)`` with ``index`` being the pandas index of
        the samples in the batch.
        """
        start = self.batch_size * batch_num
        stop = start + self.batch_size

        x_paths = inputs.iloc[start:stop]
        y_paths = targets.iloc[start:stop]

        # inputs and targets must describe the same samples
        assert (x_paths.index == y_paths.index).all()

        x = self._get_image_batch(x_paths)
        y = self._get_image_batch(y_paths)
        assert x.shape[0] == y.shape[0]

        return x, y, y_paths.index

    def batches(self):
        """This method yields the next batch of images for training.

        Endless generator of ``(rescaled inputs, one-hot targets)`` tuples.
        """
        while True:
            # print batch ranges if testing
            if self.test:
                start = self.batch_size * self.batch_num
                print("batch {0}:\t{1} --> {2}".format(self.batch_num, start,
                                                       start + self.batch_size - 1))

            x, y, batch_index = self._paired_batch(self.X_train, self.y_train, self.batch_num)

            # increment batch number
            self.batch_num = next(self.batch_idx)

            # update dataframe of seen training indices for testing
            self.train_data_seen.loc[batch_index.values] = 1

            yield (self.rescaling(x), self.binerizer(y))

    def val_batches(self):
        """This method yields the next batch of images for validation.

        Endless generator of ``(rescaled inputs, one-hot targets)`` tuples.
        """
        while True:
            x, y, _ = self._paired_batch(self.X_val, self.y_val, self.val_batch_num)

            # increment batch number
            self.val_batch_num = next(self.val_batch_idx)

            yield (self.rescaling(x), self.binerizer(y))

    def test_batches(self):
        """This method yields the next batch of images for testing.

        Endless generator of rescaled input batches.  The ids of the batch
        just yielded are kept in ``self.test_batch_ids`` so that
        `update_predictions` can store the matching results.
        """
        test_dir = os.path.join(self.datapath, self.dataset_type)

        while True:
            start = self.batch_size * self.test_batch_num
            stop = start + self.batch_size

            x_ids = self.X_test_ids[start:stop]
            # Hand over the paths themselves: iterating a DataFrame would
            # yield its column labels instead of the file paths.
            x_paths = [os.path.join(test_dir, str(filename)) for filename in x_ids]
            x = self._get_image_batch(x_paths)

            self.test_batch_ids = x_ids.values

            # increment batch number
            self.test_batch_num = next(self.test_batch_idx)

            yield self.rescaling(x)

    @staticmethod
    def _as_path_list(x_paths):
        """Return the file paths held by ``x_paths`` as a plain list.

        Accepts any iterable of paths (list, Series, Index).  A DataFrame is
        reduced to its ``filepath`` column, or to its only column when it has
        exactly one; anything else is ambiguous and raises ``ValueError``.
        """
        if isinstance(x_paths, pd.DataFrame):
            if 'filepath' in x_paths.columns:
                return x_paths['filepath'].tolist()
            if x_paths.shape[1] == 1:
                return x_paths.iloc[:, 0].tolist()
            raise ValueError("Cannot tell which DataFrame column holds the file paths.")
        return list(x_paths)

    def _read_image(self, filepath):
        """Read one image file into an array.

        NOTE: ``scipy.misc.imread`` is not available in current SciPy
        releases; override this single method to plug in another reader.
        """
        return misc.imread(filepath)

    def _get_image_batch(self, x_paths, as_grey=True):
        """
        Returns ndarray of shape (num_paths, height, width, 1) if as_grey,
        otherwise the images stacked exactly as they were read.

        For as_grey only the first channel of multi-channel images is kept;
        images that are already 2D are used unchanged.
        """
        images = np.asarray([self._read_image(path) for path in self._as_path_list(x_paths)])

        if not as_grey:
            return images

        if images.ndim == 4:
            images = images[:, :, :, 0]
        # -1 lets the batch axis follow the number of paths actually given
        return images.reshape(-1, self.height, self.width, 1)

    def _test_batch_generator(self):
        """Run one full pass over the training batches and sanity-check them.

        Raises ``AssertionError`` if a check fails; on success test mode is
        switched off.
        """
        print('Testing train batch generation...')

        batch_source = self.batches()
        for _ in range(self.num_batches):
            if self.batch_num % 10 == 0:
                print("\t\t\tfetching batch {0}/{1}".format(self.batch_num, self.num_batches))

            x, y = next(batch_source)

            # same number of inputs and labels
            assert x.shape[0] == y.shape[0]

            # square images: x is (batch, height, width, 1)
            assert x.shape[1] == x.shape[2]

            # black and white
            assert x.shape[3] == 1

        # assert we've seen all data up to remainder of a batch
        assert (self.y_train.shape[0] - self.train_data_seen.sum().values[0]) < self.batch_size

        # check that batch_num is reset
        assert self.batch_num == 0

        # turn off test mode
        self.test = False

        print('Test passed.')

    def update_predictions(self, results):
        """Store ``results`` in the prediction rows of the last test batch.

        Relies on ``self.test_batch_ids``, which `test_batches` sets each
        time it yields.
        """
        self.predictions.loc[self.test_batch_ids] = results