In [1]:
import threading
from time import sleep
import numpy
import numpy as np

test_fold = 1

import time as time_simple
    
import Queue
from collections import defaultdict

from sklearn.cluster import MiniBatchKMeans, KMeans
    
    
import multiprocessing
from multiprocessing import Pool
from multiprocessing import Process

import multiprocessing
from itertools import product
from contextlib import contextmanager

@contextmanager
def poolcontext(*args, **kwargs):
    pool = multiprocessing.Pool(*args, **kwargs)
    yield pool
    pool.terminate()
    
from scipy.special import logsumexp

def softmax(x,axis=1):
    return np.exp(x - logsumexp(x, axis=axis, keepdims=True))

def clustering(model,X,i,idx):
    dist = model.fit_transform(X)
    prob = softmax(-dist)[:,0]
    th = np.median(prob)
    label = prob>=th
#     return i,idx[label==0],idx[label==1],model.cluster_centers_[0],model.cluster_centers_[1]
    return i,idx[label==0],idx[label==1],np.mean(X[label==0,:],0),np.mean(X[label==1,:],0)

def clustering_unpack(args):
    return clustering(*args)

from sklearn.metrics import pairwise_distances
    
def int2bin(a,bin_size):
    return bin(int(a))[2:].zfill(bin_size)

def idx_in_subset(subset,y):
    fullset = np.unique(y)
    # diffset = np.setdiff1d(fullset,subset)
    dic_binary = {e:0 for e in fullset}
    for e in subset:
        dic_binary[e] = 1
    idx = np.array(map(lambda x: dic_binary[x],y))
    return idx == 1

class Reader(threading.Thread):
    """ This class is designed to automatically feed mini-batches.
        The reader constantly monitors the state of the variable 'data_buffer'.
        When finding the 'data_buffer' is None, the reader will fill a mini-batch into it.
        This is done in the backend, i.e., the reader is in an independent thread.
        For users, they only need to call iterate_batch() to get a new mini-batch.
    """

    # Initialize the super class
    def __init__(self, x_train, y_train,ratio_subset = 1., rng_seed=123, batch_size=32,n_sample_per_class=2, flag_shuffle=True,flag_x_float32=True,x_train_clustering=None,kmeans_batch_size=45,flag_minibatch_kmeans=True,tree_depth=10,num_core=10,supp_pos_batch_size=32):
        """

        Parameters:
        ----------
        :param test_fold   : int, for testing
        :param valid_folds : list of ints, validation set
        :param rng_seed    : random seed to make our code more controllable
        :param n_nearby    : number of time stamps to consider
        :param batch_size  : mini-batch
        :param center_pad  : center pad the mls feature with 0's
                             Actually, we should pad with -80.
        :param valid_fold  : int, for validation
        """

        threading.Thread.__init__(self)

        # Settings
        self.rng = numpy.random.RandomState(seed=rng_seed)
        self.batch_size = batch_size
   

        # train_data : {'data', 'label', 'file', 'fold', 'salience', 'sound'}
        # train_data['data'] is an ndarray, T x D,
        # each row a sample for a certain time-stamp

        self.x_train = x_train
        self.y_train = y_train
        self.dim_time = len(x_train[0])
        self.dim_class_num = len(y_train[0])
        self.ratio_subset = ratio_subset
        self.flag_shuffle = flag_shuffle
        self.n_sample = y_train.shape[0]
        self.n_sample_per_class = n_sample_per_class
        self.flag_x_float32 = flag_x_float32
        self.x_train_clustering = x_train_clustering
        self.kmeans_batch_size = kmeans_batch_size
        self.flag_minibatch_kmeans = flag_minibatch_kmeans
        self.num_core = num_core
        self.supp_pos_batch_size = supp_pos_batch_size
        self.tree_depth = tree_depth
        

        # Shuffle the data
        # We just need to shuffle 'query_index'
        # at each beginning of a new epoch
        # 'shuffle_index' is a list indicating
        # all the positions in 'query_index'
        
#         start_time = time_simple.time()
        self.shuffle_index = self.gen_idx_by_h_clustering()#range(len(self.x_train))
#         self.shuffle_index = self.gen_idx()#range(len(self.x_train))
#         end_time = time_simple.time()
#         print 'process time: %.3f' % (end_time - start_time)

        if self.flag_shuffle:
            self.rng.shuffle(self.shuffle_index)

        self.index_start = 0

        # Initialization
        self.running = True
        self.data_buffer = None
        self.lock = threading.Lock()

        # Start thread
        self.start()

    def run(self):
        """ Overwrite the 'run' method of threading.Thread
        """
        while self.running:
            if self.data_buffer is None:
                if self.index_start + self.batch_size <= len(self.shuffle_index):
                    # This case means we are still in this epoch
                    batch_index = self.shuffle_index[self.index_start: self.index_start + self.batch_size]
                    self.index_start += self.batch_size

                elif self.index_start < len(self.shuffle_index):
                    # This case means we've come to the
                    # end of this epoch, take all the rest data
                    # and shuffle the training data again
                    batch_index = self.shuffle_index[self.index_start:]

                    # Now, we've finished this epoch
                    # let's shuffle it again.
                    self.shuffle_index = self.gen_idx()#range(len(self.x_train))
                    if self.flag_shuffle:
                        self.rng.shuffle(self.shuffle_index)
                    self.index_start = 0
                    
                else:
                    # This case means index_start == len(shuffle_index)
                    # Thus, we've finished this epoch
                    # let's shuffle it again.
                    self.shuffle_index = self.gen_idx()#range(len(self.x_train))
                    if self.flag_shuffle:
                        self.rng.shuffle(self.shuffle_index)
                    batch_index = self.shuffle_index[0: self.batch_size]
                    self.index_start = self.batch_size
 
#                 data = self.x_train[:len(batch_index)].copy()
#                 label = self.y_train[:len(batch_index)].copy()
#                 for i in range(len(batch_index)):
               

#                     data[i] = self.x_train[batch_index[i]] 

#                     label[i] = self.y_train[batch_index[i]]

                batch_index_supp_pos = self.supp_pos_idx(batch_index)
                batch_index = np.hstack([batch_index,batch_index_supp_pos[:self.supp_pos_batch_size]])
   
                data = self.x_train[batch_index] 
                label = self.y_train[batch_index] 
 

                with self.lock:
                    self.data_buffer = data, label
            sleep(0.0001)

    def iterate_batch(self):
        while self.data_buffer is None:
            sleep(0.0001)

        data, label = self.data_buffer
        if self.flag_x_float32:
            data = numpy.asarray(data, dtype=numpy.float32)
#         label = numpy.asarray(label, dtype=numpy.int32)
        with self.lock:
            self.data_buffer = None

        return data, label

    def close(self):
        self.running = False
        self.join()
        
    def change_x_train_clustering(self,x_train_clustering):
        self.x_train_clustering = x_train_clustering
        self.shuffle_index = self.gen_idx_by_h_clustering()#range(len(self.x_train))
        self.index_start = 0
        
    def change_dataset(self,x_train,y_train,x_train_clustering):
        self.x_train = x_train
        self.y_train = y_train
        self.dim_time = len(x_train[0])
        self.dim_class_num = len(y_train[0])
        self.n_sample = y_train.shape[0]
        self.x_train_clustering = x_train_clustering
        self.shuffle_index = self.gen_idx_by_h_clustering()#range(len(self.x_train))
        self.index_start = 0
        
    def gen_idx_by_h_clustering(self):
        tree_depth = self.tree_depth
        
        clustering_model_list = []
        for l in range(tree_depth):
            clustering_model_list_sub = []
            for split in range(2**l):
                if self.flag_minibatch_kmeans:  
                    clustering_model = MiniBatchKMeans(init='k-means++', n_clusters=2, batch_size=self.kmeans_batch_size,
                                  n_init=10, max_no_improvement=10, verbose=0)
                else:
                    clustering_model = KMeans(init='k-means++', n_clusters=2, n_init=10)
                clustering_model_list_sub.append(clustering_model)
            clustering_model_list.append(clustering_model_list_sub)
            
        sample_indices_list = [[np.arange(self.x_train_clustering.shape[0])]]
        label_code = np.empty((self.x_train_clustering.shape[0],),dtype='|S%d' % tree_depth)
        label_code[:]=''
        for l in range(tree_depth):
 
            clustering_model_list_sub = clustering_model_list[l]
            sample_indices_list_sub = sample_indices_list[l]
            num_core = self.num_core
            
#             results = []
#             for i_split in range(len(clustering_model_list_sub)):
# #                 print(len(sample_indices_list_sub[i_split]))
#                 result_i = clustering(clustering_model_list_sub[i_split],self.x_train_clustering[sample_indices_list_sub[i_split]],i_split,sample_indices_list_sub[i_split])
#                 results.append(result_i)    
                
            with poolcontext(processes=num_core) as pool:
                results = pool.map(clustering_unpack, [(clustering_model_list_sub[i_split],self.x_train_clustering[sample_indices_list_sub[i_split]],i_split,sample_indices_list_sub[i_split]) for i_split in range(len(clustering_model_list_sub))])
                
            dic = {}
            for e in results:
                dic[e[0]] = [e[1],e[2],e[3],e[4]]
                
            center_list_sub = []
            sample_indices_list_sub_next = []
            
            for i_split in range(len(clustering_model_list_sub)):
                e = dic[i_split]
                sample_indices_list_sub_next.append(e[0])
                sample_indices_list_sub_next.append(e[1])
                center_list_sub.append(e[2])
                center_list_sub.append(e[3])
                
            sample_indices_list.append(sample_indices_list_sub_next)
            
            if l == 0:
                center_list = [center_list_sub]
                code_list_sub = np.arange(2)
                code_list = [code_list_sub]
            else:
                center_list.append(center_list_sub)
                
                code_list_sub_pre = code_list[l-1]
                code_list_sub = np.zeros(len(center_list_sub)).astype(int)
                for i_split in range(len(center_list_sub)/4):
                    c01,c23 = np.array([center_list_sub[i_split*4],center_list_sub[i_split*4+1]]),np.array([center_list_sub[i_split*4+2],center_list_sub[i_split*4+3]])
                    dist = pairwise_distances(c01,c23)
                    min_idx = np.unravel_index(np.argmin(dist, axis=None), dist.shape)
                    code_list_sub[i_split*4+min_idx[0]] = 1 - code_list_sub_pre[i_split*2]
                    code_list_sub[i_split*4+1-min_idx[0]] = code_list_sub_pre[i_split*2]
                    code_list_sub[i_split*4+2+min_idx[1]] = 1 - code_list_sub_pre[i_split*2+1]
                    code_list_sub[i_split*4+2+1-min_idx[1]] = code_list_sub_pre[i_split*2+1]
                code_list.append(code_list_sub)
            
            cont_len = 0
            for i_indices,e_indices in enumerate(sample_indices_list_sub_next):
                cont_len += len(e_indices)
                label_code[e_indices] = np.core.defchararray.add(label_code[e_indices], str(code_list_sub[i_indices]))
        idx_ret = []
        for e_int in range(2**tree_depth):
            code = int2bin(e_int,tree_depth)
            idx = list(np.where(label_code == code)[0])
            idx_ret.extend(idx)
            
        return idx_ret
    
    def supp_pos_idx(self,indices):
        y_train = self.y_train[:,0]
        n = len(y_train)
        idx_all = np.arange(n)
        uniid = np.unique(y_train[indices])
        idx_uniid = idx_in_subset(uniid,y_train)
        idx_all = idx_all[idx_uniid]
        y_train = y_train[idx_uniid]
        idx_in_idx = self.gen_idx_fixpk(y_train)
        idx_ret = idx_all[idx_in_idx]
        
        return idx_ret
            
    def gen_idx_fixpk(self,y_train):
        
        n = len(y_train)
 
        idx = np.arange(n)
        class_set,class_num = np.unique(y_train,return_counts=True)
        dic_user2sample_num = {class_set[i]: class_num[i] for i in range(len(class_set))}
        dic_user2indices = defaultdict(list)

        n_sample_per_class = self.n_sample_per_class
  
        idx_ret = []
        self.rng.shuffle(idx)
        for i in range(n):
            item = idx[i]
            cur_id = y_train[item]
            
            if dic_user2sample_num[cur_id] == 1:
                idx_ret.append(item)
            elif dic_user2sample_num[cur_id]<n_sample_per_class:
                cur_list = dic_user2indices[cur_id]
                if len(cur_list) == dic_user2sample_num[cur_id]-1:
                    idx_ret.extend(cur_list)
                    idx_ret.append(item)
                    dic_user2indices[cur_id] = []
                else:
                    cur_list.append(item)
            else:
                cur_list = dic_user2indices[cur_id]
                if len(cur_list) == n_sample_per_class-1:
                    idx_ret.extend(cur_list)
                    idx_ret.append(item)
                    dic_user2indices[cur_id] = []
                    dic_user2sample_num[cur_id] -= n_sample_per_class
                else:
                    cur_list.append(item)
 
        return idx_ret
        
    def gen_idx(self):
        
        y_train = self.y_train[:,0]
        n = len(y_train)
 
        idx = np.arange(n)
        class_set,class_num = np.unique(y_train,return_counts=True)
        dic_user2sample_num = {class_set[i]: class_num[i] for i in range(len(class_set))}
        dic_user2indices = defaultdict(list)

        n_sample_per_class = self.n_sample_per_class
  
        idx_ret = []
        self.rng.shuffle(idx)
        for i in range(n):
            item = idx[i]
            cur_id = y_train[item]
            
            if dic_user2sample_num[cur_id] == 1:
                idx_ret.append(item)
            elif dic_user2sample_num[cur_id]<n_sample_per_class:
                cur_list = dic_user2indices[cur_id]
                if len(cur_list) == dic_user2sample_num[cur_id]-1:
                    idx_ret.extend(cur_list)
                    idx_ret.append(item)
                    dic_user2indices[cur_id] = []
                else:
                    cur_list.append(item)
            else:
                cur_list = dic_user2indices[cur_id]
                if len(cur_list) == n_sample_per_class-1:
                    idx_ret.extend(cur_list)
                    idx_ret.append(item)
                    dic_user2indices[cur_id] = []
                    dic_user2sample_num[cur_id] -= n_sample_per_class
                else:
                    cur_list.append(item)
 
        return idx_ret
            
if __name__ == '__main__':
    """ Let's just write the test function
        for reader in the same file
    """
    test_fold = 1
    valid_folds = [2]
    rng_seed = 123
    n = 10000
    x_train = np.random.randn(n, 191)
    y_train = np.random.randint(0,10,size=(n,1))
    x_train_clustering = np.random.randn(n,128)


    model = Reader(x_train, y_train,ratio_subset = 1., rng_seed=123, batch_size=32,n_sample_per_class=4, 
                   flag_shuffle=False,flag_x_float32=False,x_train_clustering=x_train_clustering,kmeans_batch_size=10,
                   flag_minibatch_kmeans=True,tree_depth=10,num_core=10,supp_pos_batch_size=16)
      


    for i in range(10):
        print('Loading %d-th mini-batch ...' % i)
        data, label = model.iterate_batch()

        print('max: %f, min: %f' % (data.max(), data.min()))
        print (label.shape)
        print (data.shape)

    model.close()






Loading 0-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 1-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 2-th mini-batch ...
max: 5.087900, min: -5.041768
(8078, 1)
(8078, 191)
Loading 3-th mini-batch ...
max: 5.073649, min: -5.041768
(9019, 1)
(9019, 191)
Loading 4-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 5-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 6-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 7-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 8-th mini-batch ...
max: 5.087900, min: -5.041768
(10032, 1)
(10032, 191)
Loading 9-th mini-batch ...
max: 5.087900, min: -5.041768
(9011, 1)
(9011, 191)
