In [None]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
#import visualkeras
import read_data_tensorflow as read_data
import threading
import tqdm.notebook as tqdm
%matplotlib inline

# TF1-compat session whose GPU options request allow_growth=True
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)
)
sess = tf.compat.v1.Session(config=config)

#############################################
# Dataset selection: Fashion-MNIST is active; swap the two loader lines to use MNIST
org_train_images, train_labels, test_images, test_labels = read_data.load_fashion_mnist_dataset()
#org_train_images, train_labels, test_images, test_labels = read_data.load_mnist_dataset()

# Model inputs: append an axis at position 3, cast to float32, then divide by 255
train_images = np.expand_dims(org_train_images, 3).astype(np.float32) / 255.0   #train x
test_images = np.expand_dims(test_images, 3).astype(np.float32) / 255.0         #test x

# The un-expanded training images are kept as float32 but are NOT divided by 255
org_train_images = org_train_images.astype(np.float32)

# Labels are stored as float32 as well
train_labels = train_labels.astype(np.float32)     #train y
test_labels = test_labels.astype(np.float32)       #test y


In [None]:
#compose a multi-filter CNN layer out of several single-filter CNN layers

#ideas for reducing resolution by factor of 2
#1. use 4x4 filter and stride 2
#2. use 2x2 filter and stride 2
#3. use a flatten layer for computation (?)

#customized activation function
def bell_tanh_activation(x):
    """Bell-shaped activation built from two tanh ramps.

    Returns the element-wise minimum of ``tanh(10 - 5*x)`` (falling ramp) and
    ``tanh(5*x)`` (rising ramp). The two ramps cross at x = 1, where the
    output is close to 1; the output falls towards -1 on both sides.
    """
    falling_ramp = tf.tanh(10-5*x)
    rising_ramp = tf.tanh(5*x)
    return tf.minimum(falling_ramp, rising_ramp)
    
def gaussian_activation(x, tilt_level = 0.85): #last best -- 1.1
    """Gaussian bump rescaled to the range (-1, 1].

    Computes ``2*exp(-(t - t*x)**2) - 1`` with ``t = tilt_level``; for a
    non-zero ``tilt_level`` the maximum value 1 is reached at x = 1.
    """
    scaled_offset = tilt_level-tilt_level*x
    bump = tf.exp(-(scaled_offset**2))
    return 2*bump - 1 #range -1 to 1


def gaussian_activation_01(x, tilt_level = 1):#0.85): #last best -- 1.1
    """Plain Gaussian bump ``exp(-x**2)`` with range (0, 1], peaking at x = 0.

    ``tilt_level`` is accepted for signature parity with
    ``gaussian_activation`` but is not used by the computation.
    """
    squared_input = x**2
    return tf.exp(-squared_input) #range 0 to 1


#sigmoid
#1/(1+e^-x)

#======================================================================
# Notes:
# Experiments so far indicate that sigmoid works better; the reason is not confirmed.
# It is probably because its gradient is non-zero for all values of x.
# In contrast, the bell activation has a gradient approaching 0 not only at both ends but also in the middle (at its peak), and hence is harder to train
#======================================================================



class extensible_CNN_layer_multi_module_3D(tf.keras.Model):
    """Growable 3D convolution layer assembled from single-filter Conv3D modules.

    Every filter is its own ``tf.keras.layers.Conv3D`` with one output channel.
    ``add_filter`` seeds a new filter from one sample patch and trains it to
    respond to that patch while staying quiet on the patches stored for the
    other filters. ``call`` runs every filter and concatenates the resulting
    feature maps along the depth axis (axis 3).

    Attributes set in ``__init__``:
        filter_list   -- the single-filter Conv3D layers, in creation order
        sample_space  -- dict: filter index -> sample patch (1, kx, ky, kz, 1)
        kernel_size   -- (kx, ky, kz); kz grows through ``update_depth``
        stride, padding, activation -- forwarded to every Conv3D that is created
        threshold     -- activation-dependent constant; not read inside this class
        optimizer     -- optimizer shared by all ``add_filter`` training steps
        aggregated_conv -- last layer built by ``get_aggregated_conv`` (or None)
    """
    #author's tuning notes:
    # last parameter: kernel_size = (4,4), stride = 2, activation = 'gaussian_bell', padding = 'valid', optimizer = 'adam'
    # best: gaussian bell of 1 tilt level, with reg on weight and bias, and 3x3 filter with stride 1
    def __init__(self, kernel_size = (4,4,1), stride = 2, activation = 'sigmoid', padding = 'valid', optimizer = 'adam'): #best -- gaussian_bell, 4x4, stride 2
        """Create an empty layer (no filters yet).

        Args:
            kernel_size: (kx, ky, kz) size of every filter.
            stride: stride forwarded to every Conv3D / Conv3DTranspose.
            activation: Keras activation name, or one of the custom names
                'bell_tanh', 'gaussian_bell', 'gaussian_bell_01'.
            padding: padding mode forwarded to every convolution.
            optimizer: only 'adam' is implemented. Any other value prints a
                message and leaves ``self.optimizer`` set to the given value.
        """
        super(extensible_CNN_layer_multi_module_3D, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.channels = 1

        # All mutable state is created BEFORE the optimizer is resolved, so the
        # object is fully initialised even on the unsupported-optimizer path.
        self.filter_list = []
        self.bias_list = []
        self.sample_space = {}
        self.aggregated_conv = None

        # Threshold per activation name; anything else falls back to 0.5
        thresholds = {
            "gaussian_bell": 0.4,     #0.34 final -- 0.4 one layer
            "sigmoid": 0.8,           #from 0.5 to 0.7, unstable
            "relu": 0.5,
            "gaussian_bell_01": 0.5,  #last 0.5
        }
        self.threshold = thresholds.get(activation, 0.5)

        # Custom activation names map to the functions defined in this notebook;
        # every other value is handed to Keras unchanged.
        custom_activations = {
            'bell_tanh': bell_tanh_activation,
            'gaussian_bell': gaussian_activation,
            'gaussian_bell_01': gaussian_activation_01,
        }
        self.activation = custom_activations.get(activation, activation)

        self.optimizer = optimizer
        if optimizer == 'adam':
            #use legacy optimizer
            self.optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.1)
        else:
            print("\noptimizer not implemented\n")

    def call(self, input_x):
        """Apply every filter to ``input_x`` and concatenate the results on axis 3.

        A Conv3D feature map is (batch, height, width, depth, channels); the
        per-filter maps are stacked along the depth axis.
        """
        feature_maps = [conv_filter(input_x) for conv_filter in self.filter_list]
        return tf.concat(feature_maps, axis = 3)

    def update_depth(self, new_depth):
        """Grow the kernel depth of the layer to ``new_depth``.

        Used when the previous layer gained units, so the inputs of this layer
        became deeper. Every filter is replaced by a deeper Conv3D that keeps
        the old weights and bias; each newly added weight is set to
        ``-2/(new_depth - old_depth)``. The stored samples are zero-padded to
        the new depth.

        Only growth is supported: if any filter is already deeper than
        ``new_depth`` a message is printed and nothing is modified.
        """
        kx, ky = self.kernel_size[0], self.kernel_size[1]

        if len(self.filter_list) == 0:
            # No filters to migrate yet: only the kernel size has to change.
            self.kernel_size = (kx, ky, new_depth)
            print("no filters yet, kernel depth set to: ", new_depth)
            return

        # Validate before touching any state, so a failure leaves the layer intact.
        if any(existing.kernel_size[2] > new_depth for existing in self.filter_list):
            #this might be caused by an error
            print("filter depth is larger than new depth\nOnly increase is allowed at this version")
            return

        print("original depth: ", self.filter_list[0].get_weights()[0].shape)
        self.kernel_size = (kx, ky, new_depth) #update kernel size for new depth

        for filter_index in range(len(self.filter_list)):
            old_filter = self.filter_list[filter_index]
            filter_depth = old_filter.kernel_size[2]
            if filter_depth == new_depth:
                #no need to update
                continue
            old_weights = old_filter.get_weights()[0]
            old_bias = old_filter.get_weights()[1]

            new_filter = tf.keras.layers.Conv3D(filters = 1, kernel_size = self.kernel_size, strides = self.stride, padding = self.padding, activation = self.activation)
            new_filter.build(input_shape = (1, kx, ky, new_depth, 1))

            new_filter_weights = np.zeros((kx, ky, new_depth, 1, 1)).astype('float32')
            # old depth slices: copy the trained weights
            new_filter_weights[:, :, :filter_depth, 0, 0] += old_weights[:, :, :, 0, 0]
            # new depth slices: constant negative weight
            #a different value here will result in a different mode. Current mode --> inclusive, may recognize some new patterns as its own class
            new_filter_weights[:, :, filter_depth:, 0, 0] -= 2/(new_depth - filter_depth)

            new_filter.set_weights([new_filter_weights, old_bias])
            self.filter_list[filter_index] = new_filter

            # zero-pad the stored sample to the new depth
            sample_i = self.sample_space[filter_index]
            padded_sample = np.zeros((1, kx, ky, new_depth, 1)).astype('float32')
            padded_sample[0, :, :, :filter_depth, 0] += sample_i[0, :, :, :, 0]
            self.sample_space[filter_index] = padded_sample

        #take the first filter's shape as example
        first_filter_shape = self.filter_list[0].get_weights()[0].shape
        print("update of layer's depth done, new depth: ", first_filter_shape)
        print("new sample space shape: ", self.sample_space[0].shape)

    def add_filter(self, x, epochs = 10, refit = False, regularization = True, image_x = None):
        """Create, initialise and train a new single-channel filter from patch ``x``.

        The filter is paired with a Conv3DTranspose decoder (an autoencoder)
        and trained so that it reconstructs ``x``, outputs 1 on ``x`` and
        outputs 0 on the samples stored for all earlier filters.

        Args:
            x: sample patch of shape (kx, ky, kz, 1), matching ``kernel_size``.
            epochs: number of gradient steps.
            refit: when True, ``refit_all`` is run afterwards.
            regularization: kept for interface compatibility; it currently has
                no effect (the weight penalty depends on the kernel depth only).
            image_x: optional full input volume. When given, the stored sample
                of each of the last five filters is replaced by the patch of
                ``image_x`` at that filter's strongest response.
                NOTE(review): the reshape below appears to require a batch of
                size 1 and a NumPy array - confirm against callers.

        Returns:
            The new Conv3D layer, or None when ``x`` has the wrong shape or is
            within 0.2 (summed absolute difference) of an already stored sample.
        """
        kx, ky, kz = self.kernel_size[0], self.kernel_size[1], self.kernel_size[2]
        expected_shape = (kx, ky, kz, 1)
        if x.shape != expected_shape:
            print("x is not fit for filter size")
            print("x shape: ", x.shape)
            print("expected shape: ", expected_shape)
            return

        x = x.reshape(1, kx, ky, kz, 1) #for 3D conv
        #check if x is equivalent to any existing filter's sample
        for stored_sample in self.sample_space.values():
            if np.sum(abs(stored_sample - x)) < 0.2:
                print("x is already in sample space")
                return

        new_filter = tf.keras.layers.Conv3D(1, self.kernel_size,
                                            padding=self.padding, activation=self.activation,strides=self.stride)
        decoder = tf.keras.layers.Conv3DTranspose(1, self.kernel_size,
                                                  padding=self.padding, activation=self.activation,strides=self.stride)
        print("new filter initialized, id = ", len(self.filter_list))

        # ---- initial weights derived from x ---------------------------------
        zero_pixel_threshold = 0.1

        # Smallest threshold t in {0.00, 0.01, ..., 0.99} for which
        # sum(x > t) - sum(x <= t) <= 1; stays 0 if no such t exists.
        threshold_same_sum = 0
        for step in range(100):
            threshold = step/100
            if (np.sum(x[x > threshold]) - np.sum(x[x <= threshold])) <= 1:
                threshold_same_sum = threshold
                break
        print("threshold_same_sum: ", threshold_same_sum)

        # +x where x is above the threshold, -x elsewhere ("mode 2" of the original experiments)
        signed_x = np.where(x > threshold_same_sum, 1.0, -1.0) * x

        if (np.max(x) > zero_pixel_threshold):
            # The bias cancels the raw response of the kernel on x itself.
            bias = np.array([-(np.sum(np.multiply(signed_x, x)))])
            characterization = signed_x
            # bias[0]: float() on a 1-element array is deprecated in NumPy
            print("characterization shape = ", characterization.shape, "bias = ", float(bias[0]))
            print(np.sum(characterization))
        else:
            # (near-)empty patch: constant kernel of x - 2 with bias 1
            bias = np.array([1])
            characterization = x - 2*(bias)

        weight = np.asarray(characterization).astype(np.float64).reshape(kx, ky, kz, 1, 1)
        new_filter.build(input_shape = (1, kx, ky, kz, 1))
        new_filter.set_weights([weight, bias])
        decoder.build(input_shape = (1,1,1,1,1))
        decoder.set_weights([weight, bias])

        # The samples of all previously added filters act as negatives.
        neg_samples = [self.sample_space[i] for i in range(len(self.filter_list))]
        self.filter_list.append(new_filter)

        # ---- training ("combined training v3") ------------------------------
        def combined_fit_v3(target, negative_samples):
            """One gradient step on encoder and decoder; returns the loss."""
            with tf.GradientTape(persistent=True) as tape:
                # decoder output is scaled by max(target) before comparison
                y = decoder(new_filter(target)) * np.max(target)
                #use sum of square error as loss function
                loss = tf.reduce_sum(tf.square(y - target))
                # the filter should answer 1 on its own sample ...
                loss += (tf.reduce_sum(tf.square(new_filter(target) - 1)))
                # ... and 0 on every negative (averaged over len + 1)
                for neg_sample in negative_samples:
                    z = new_filter(neg_sample)
                    target_z = 0
                    loss += tf.reduce_sum(tf.square(z - target_z))/(len(negative_samples)+1)
                #if the depth is greater than 3, then it is after another layer. Use regularization
                if self.kernel_size[2] > 3:
                    # Must be a TensorFlow op: a NumPy reduction here is not
                    # recorded by the tape and contributes no gradient.
                    loss += tf.reduce_sum(tf.square(new_filter.weights[0]))
            grad_filter = tape.gradient(loss, new_filter.trainable_variables)
            grad_decoder = tape.gradient(loss, decoder.trainable_variables)
            del tape # release the persistent tape
            self.optimizer.apply_gradients(zip(grad_filter, new_filter.trainable_variables))
            self.optimizer.apply_gradients(zip(grad_decoder, decoder.trainable_variables))
            return loss

        progress_bar = tf.keras.utils.Progbar(epochs)
        print("\ncombined training v3")
        for epoch in range(epochs):
            loss = combined_fit_v3(x, neg_samples)
            # epoch + 1 so that the bar reaches its target on the last step
            progress_bar.update(epoch + 1, values=[("loss", loss)])
        # Resets Keras' global state once per new filter. This does not free a
        # GradientTape; the tape is released inside combined_fit_v3.
        tf.keras.backend.clear_session()

        self.sample_space[len(self.filter_list) - 1] = x #add to sample space

        if image_x is not None:
            print("\nUpdating the sample space")
            first_updated = max(len(self.filter_list) - 5,0)
            for filter_i in range(first_updated, len(self.filter_list)):
                feature_map = self.filter_list[filter_i](image_x)
                #get one of the max value's location
                max_loc_fm = np.unravel_index(np.argmax(feature_map, axis=None), feature_map.shape)
                #map back to image_x coordinates
                org_x = max_loc_fm[1]*self.stride
                org_y = max_loc_fm[2]*self.stride
                org_z = max_loc_fm[3]*self.stride
                #cut the patch out of image_x
                self.sample_space[filter_i] = image_x[:,
                                                      (org_x):(org_x + kx),
                                                      (org_y):(org_y + ky),
                                                      (org_z):(org_z + kz),
                                                      :].reshape(1, kx, ky, kz, 1)

        if refit:
            #fit all the filters
            self.refit_all(epochs = epochs)

        return new_filter

    def refit_all(self, epochs = 100, image_x = None):
        """Retrain the later filters on the whole sample space.

        Each refitted filter is pushed towards 1 on its own stored sample and
        towards -1 on the samples of every other filter. Filters from index
        ``max(len(filter_list)//2 - 1, 0)`` onwards are refitted.

        Args:
            epochs: gradient steps per filter.
            image_x: accepted but currently unused.
        """
        def refit_filter(idx, epochs):
            print("\nrefitting filter ", idx)
            progress_bar = tf.keras.utils.Progbar(epochs)
            target_filter = self.filter_list[idx]
            # One optimizer per filter, created once, so Adam keeps its moment
            # estimates across epochs instead of restarting every step.
            filter_optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
            for epoch in range(epochs):
                loss = 0
                with tf.GradientTape() as tape:
                    for i in range(len(self.filter_list)):
                        expected = 1 if i == idx else -1
                        # Always evaluate the filter being refitted; evaluating
                        # filter i here would give zero gradient for negatives.
                        response = target_filter(self.sample_space[i])
                        loss += tf.reduce_sum(tf.square(response - expected))
                grad_filter = tape.gradient(loss, target_filter.trainable_variables)
                filter_optimizer.apply_gradients(zip(grad_filter, target_filter.trainable_variables))
                progress_bar.update(epoch + 1, values=[("loss", loss)])
            # reset Keras' global state after each refitted filter
            tf.keras.backend.clear_session()

        first_refit = max(int(len(self.filter_list)/2)-1,0)
        for i in range(first_refit, len(self.filter_list)):
            refit_filter(i, epochs)

    #TODO: following functions are needed potentially when the learning covers more images
    def collapse_check(self, example_img):
        """Placeholder: check whether filters overlap. Not implemented."""
        return

    def collapse_overlapping(self, example_img):
        """Placeholder: merge overlapping filters. Not implemented."""
        return

    def get_index_map(self, x):
        """Return, per spatial position, the arg-max index along the depth axis.

        The result has shape (height, width). It is reshaped to the two
        spatial dimensions only, because reshaping to (height, width, depth)
        fails as soon as the concatenated depth is larger than 1.
        NOTE(review): assumes batch size 1 and a single channel.
        """
        feature_maps = self.call(x)
        return np.argmax(feature_maps, axis = 3).reshape(np.shape(feature_maps)[1:3])

    def call_seperated_fm(self, x):
        """Return the list of per-filter feature maps (no concatenation)."""
        return [self.filter_list[i](x) for i in range(len(self.filter_list))]

    def get_aggregated_conv(self):
        """Build one multi-filter Conv3D carrying the weights of all filters.

        Filter ``i`` becomes output channel ``i`` of the returned layer, which
        is also stored in ``self.aggregated_conv``. The layer uses the same
        stride, padding and activation as the individual filters.
        """
        kx, ky, kz = self.kernel_size[0], self.kernel_size[1], self.kernel_size[2]
        n_filters = len(self.filter_list)
        # kernel layout: (kx, ky, kz, in_channels = 1, out_channels = n_filters)
        weights = np.zeros((kx, ky, kz, 1, n_filters))
        biases = np.zeros((n_filters))
        for i in tqdm.tqdm(range(n_filters)):
            weight_i = self.filter_list[i].get_weights()[0].reshape(kx, ky, kz, 1, 1)
            bias_i = self.filter_list[i].get_weights()[1].reshape(1)
            weights[:,:,:,0,i] = weight_i[:,:,:,0,0]
            biases[i] = bias_i[0]
        new_conv = tf.keras.layers.Conv3D(n_filters, self.kernel_size, strides = self.stride, padding = self.padding, activation =self.activation)
        new_conv.build((None, kx, ky, kz, 1))
        # Without this call the layer keeps its random initial weights.
        new_conv.set_weights([weights, biases])
        self.aggregated_conv = new_conv
        return new_conv
    
    
        
            
#new_filter = customized_CNN_kernel(4, 2, 'relu')

def show_indedx_act_map_2D(model, img):
    """Plot the index map and the activation map of ``model`` for one 2D image.

    ``img`` is reshaped to (1, height, width, 1) and passed to ``model.call``.
    Left panel: arg-max index along axis 3 of the feature maps.
    Right panel: maximum value along axis 3; its min and max are also printed.
    """
    feature_maps = model.call(img.reshape(1, img.shape[0], img.shape[1], 1))
    map_shape = np.shape(feature_maps)[1:-1]
    index_map = np.argmax(feature_maps, axis = 3).reshape(map_shape)
    mapping = np.max(feature_maps, axis = 3).reshape(map_shape)

    fig = plt.figure(figsize=(10, 10))

    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(index_map, cmap = "gist_ncar")
    ax.set_title("index map")

    # Draw on the explicit axes object instead of relying on pyplot's current axes
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(mapping, cmap = "gray")
    ax2.set_title("max activation")
    print("min and max", np.min(mapping), np.max(mapping))

    # plt.show() alone renders the figure; an extra fig.show() only triggers a
    # warning on non-GUI backends such as the inline notebook backend.
    plt.show()
    
def show_all_index_maps_2D(model, data_0_9):
    """Plot every image in ``data_0_9`` together with its index map from ``model``.

    Images are laid out five per row; the row below each row of images shows
    the matching index maps (arg-max along axis 3 of ``model.call(image)``).
    Each image is reshaped to 28x28 for display.

    The function uses its own arguments only; it no longer reads the notebook
    globals ``my_model`` / ``images_x_0_9``.
    """
    plt.figure(figsize=(20, 20))

    n_rows = (len(data_0_9) * 2) // 5
    for i in range(len(data_0_9)):
        # every block of five images is followed by a row of five index maps
        row_offset = (i // 5) * 5
        feature_maps = tf.convert_to_tensor(model.call(data_0_9[i]).numpy().astype('float64'))

        plt.subplot(n_rows, 5, i+1 + row_offset)
        plt.imshow(data_0_9[i].reshape(28,28), cmap='gray')
        plt.title("label "+ str(i))

        plt.subplot(n_rows, 5, i+6 + row_offset)
        index_map = np.argmax(feature_maps, axis = 3).reshape(np.shape(feature_maps)[1:-1])
        plt.imshow(index_map, cmap='gist_ncar')

    plt.show()

def show_all_index_maps_3D(model, data_0_9):
    """Plot the index map of ``model`` for every entry of ``data_0_9``.

    Five entries per row: the upper row of each pair shows the reference
    images taken from the notebook global ``images_x_0_9`` (reshaped to
    28x28), the lower row shows the arg-max along axis 3 of
    ``model.call(entry)``, reshaped to the shape with batch and the last two
    axes dropped.
    """
    plt.figure(figsize=(20, 20))
    n_rows = (len(data_0_9) * 2) // 5

    for i, entry in enumerate(data_0_9):
        row_offset = (i // 5) * 5
        response = model.call(entry).numpy().astype('float64')
        feature_maps = tf.convert_to_tensor(response)

        # reference image
        plt.subplot(n_rows, 5, i+1 + row_offset)
        plt.imshow(images_x_0_9[i].reshape(28,28), cmap='gray')
        plt.title("label "+ str(i))

        # index map directly below it
        plt.subplot(n_rows, 5, i+6 + row_offset)
        map_shape = np.shape(feature_maps)[1:-2]
        index_map = np.argmax(feature_maps, axis = 3).reshape(map_shape)
        plt.imshow(index_map, cmap='gist_ncar')
    plt.show()
    

# One example image per label 0..9: the first training image carrying that label
images_x_0_9 = []
for i in range(10):
    first_match = np.nonzero(train_labels == i)[0][0]
    images_x_0_9.append(train_images[first_match])

# Placeholder for the feature maps produced by the first layer
fm_0_9 = []




def generate_filter(image, model, image_shape, fm_0_9):
    """Add one filter to ``model`` at a location of ``image`` that no filter responds to.

    Steps:
      1. If the image is deeper than the model's kernel, call
         ``model.update_depth``.
      2. If the model has no filter yet, bootstrap it with a filter built from
         an all-zero patch.
      3. Run the model on the image, split the output into one feature map per
         filter and count, per position, how many filters exceed
         ``model.threshold``.
      4. Pick a random position where the count is zero, cut the matching
         kernel-sized patch out of the image and pass it to
         ``model.add_filter``.
      5. Recompute the inactive ratio and plot index maps for ``fm_0_9``.

    Args:
        image: array holding a single image; reshaped here to
            (1, x, y, z, 1) using ``image_shape``.
        model: extensible layer object providing ``kernel_size``, ``stride``,
            ``threshold``, ``filter_list``, ``update_depth``, ``add_filter``,
            ``call`` and ``call_seperated_fm``.
        image_shape: sequence whose first three entries are (x, y, z).
        fm_0_9: images forwarded to ``show_all_index_maps_3D`` for the
            progress plot.

    Returns:
        ``(model, inactive_ratio)``. If every position was already covered,
        no filter is added in step 4 and the ratio from step 3 (0.0) is
        returned; otherwise the ratio measured after adding the filter.
    """
    if image_shape[2] > model.kernel_size[2]:
        model.update_depth(image_shape[2])
    print("image shape", image_shape)
    kernel_size = model.kernel_size
    stride = model.stride
    #image_shape ---> (x, y, z)
    image = image.reshape(1, image_shape[0], image_shape[1], image_shape[2], 1)
    # Output size of a strided, unpadded sliding window, computed per axis.
    # (y uses its own extent/kernel size, matching the patch extraction below.)
    feature_map_size = (
        (image_shape[0] - kernel_size[0]) // stride + 1,
        (image_shape[1] - kernel_size[1]) // stride + 1,
        (image_shape[2] - kernel_size[2]) // stride + 1,
    )

    if len(model.filter_list) == 0:
        # Bootstrap: the model needs at least one filter before it can be called.
        null_input = np.zeros((kernel_size[0], kernel_size[1], kernel_size[2], 1))
        if model.add_filter(null_input, epochs = 100) is None:
            print("ERROR null filter")

    feature_maps = model.call(image)
    #feature map size --  1, x, y, z, 1
    #separate to multiple feature maps corresponding back to different filters
    #(filter i occupies the i-th depth-slab of size feature_map_size[2])
    active_filter_count = np.zeros(feature_map_size)
    for i in range(len(model.filter_list)):
        depth_start = i * feature_map_size[2]
        depth_end = (i+1) * feature_map_size[2]
        fm = feature_maps[:, :, :, depth_start:depth_end, :]
        fm = (fm>model.threshold).numpy().astype(np.float32)
        active_filter_count += fm.reshape(feature_map_size)

    #check how many points are inactive (no filter above threshold)
    inactive_points = np.where(active_filter_count == 0)
    inactive_ratio_0 = len(inactive_points[0])/np.prod(np.shape(active_filter_count))
    print ("inactive ratio ", inactive_ratio_0)

    if len(inactive_points[0]) == 0:
        return model, inactive_ratio_0

    selected = np.random.choice(len(inactive_points[0]))
    x = inactive_points[0][selected]
    y = inactive_points[1][selected]
    z = inactive_points[2][selected]

    print("selected point", x, y, z)

    #get the patch of the image that corresponds to the selected point
    org_x_start = x * stride
    org_y_start = y * stride
    org_z_start = z * stride
    org_x_end = org_x_start + kernel_size[0]
    org_y_end = org_y_start + kernel_size[1]
    org_z_end = org_z_start + kernel_size[2]

    patch = image[:, org_x_start:org_x_end, org_y_start:org_y_end, org_z_start:org_z_end, :]
    patch = patch.reshape(kernel_size[0], kernel_size[1], kernel_size[2], 1)
    model.add_filter(patch, epochs = 100, image_x = image)

    # Same measurement as get_inactive_ratio(); reuse it instead of repeating the loop.
    inactive_ratio_1 = get_inactive_ratio(model, image)
    print("inactive ratio after adding filter", inactive_ratio_1)

    show_all_index_maps_3D(model, data_0_9=fm_0_9)
    return model, inactive_ratio_1
    



def generate_model(img_x, model = None, fm_0_9 = None):
    """Grow a model by running ``generate_filter`` 20 times on a single image.

    Args:
        img_x: array whose first two axes are the image's x and y extent; it
            is reshaped to (1, x, y, 1), so the shape passed on to
            ``generate_filter`` is (x, y, 1).
        model: existing extensible layer to extend; when ``None`` a new
            ``extensible_CNN_layer_multi_module_3D`` is created.
        fm_0_9: optional images forwarded to ``generate_filter`` for its
            index-map plot. Defaults to an empty list (nothing to plot).

    Returns:
        The (possibly newly created) model after the 20 generation rounds.
    """
    image_x = img_x.reshape(1, img_x.shape[0], img_x.shape[1], 1)

    my_model = model
    if my_model is None:
        my_model = extensible_CNN_layer_multi_module_3D()

    # generate_filter requires the preview images as its fourth argument.
    preview_images = [] if fm_0_9 is None else fm_0_9

    print("activation and threshold")
    print(my_model.activation, my_model.threshold)
    for _ in range(20): #generate a series of filters
        generate_filter(image_x, my_model, image_x.shape[1:], preview_images)

    #print(my_model.activation, len(my_model.filter_list))
    return my_model
    

def get_inactive_ratio(model, image):
    """Return the fraction of feature-map positions where no filter fires.

    ``model.call_seperated_fm(image)`` supplies one feature map per filter.
    A position counts as inactive when none of those maps exceeds
    ``model.threshold`` there.

    Args:
        model: object providing ``call_seperated_fm(image)`` and ``threshold``.
        image: input handed to ``call_seperated_fm`` unchanged.

    Returns:
        Number of inactive positions divided by the total number of positions.
    """
    per_filter_maps = model.call_seperated_fm(image)
    map_shape = per_filter_maps[0].shape

    # hits[p] = how many filters are above threshold at position p
    hits = np.zeros(map_shape)
    for single_map in per_filter_maps:
        fired = (single_map > model.threshold).numpy().astype(np.float32)
        hits = hits + fired.reshape(map_shape)

    dead_positions = np.where(hits == 0)[0]
    return len(dead_positions) / np.prod(np.shape(hits))
        
    
def get_inactive_ratio_list(model, images):
    """Compute ``get_inactive_ratio(model, img)`` for every image, with a progress bar.

    Args:
        model: forwarded to ``get_inactive_ratio``.
        images: sized iterable of images (``len(images)`` sets the bar total).

    Returns:
        List of inactive ratios in the same order as ``images``.
    """
    inactive_ratios = []
    # The context manager closes the bar even if a ratio computation raises.
    with tqdm.tqdm(total=len(images)) as progress_bar:
        for img in images:
            inactive_ratios.append(get_inactive_ratio(model, img))
            progress_bar.update(1)

    return inactive_ratios

def generate_model_on_images(images, model, images_0_9, inactive_ratio_threshold = 0.1, n = 3):
    """Run one generation round: add filters for the images the model covers worst.

    Procedure:
      * Create a new ``extensible_CNN_layer_multi_module_3D`` when ``model``
        is ``None``.
      * If the model has no filter yet, call ``generate_filter`` on ``n``
        randomly drawn images first.
      * Measure the inactive ratio of every image. If the maximum is below
        ``inactive_ratio_threshold`` the round ends without adding filters.
      * Otherwise call ``generate_filter`` on the ``n`` images with the
        highest inactive ratio.

    Args:
        images: sequence of images; each must expose ``.shape`` whose entries
            after the first describe (x, y, z, ...) for ``generate_filter``.
        model: existing layer to extend, or ``None``.
        images_0_9: images forwarded to ``generate_filter`` for plotting.
        inactive_ratio_threshold: stop criterion on the maximum inactive ratio.
        n: number of images to generate a filter on per round.

    Returns:
        ``(model, True)`` when filters were generated for the worst images,
        ``(model, False)`` when the stop criterion was met (or ``images`` is
        empty) and the caller should stop iterating.
    """
    if model is None:
        model = extensible_CNN_layer_multi_module_3D()

    # Nothing to learn from: report "stop" instead of failing in randint / np.max.
    if len(images) == 0:
        return model, False

    #If the model has no filter, first round, generate 1 filter on randomly selected n images
    if len(model.filter_list) == 0:
        for _ in range(n):
            image_idx = np.random.randint(len(images))
            print("\nimage ", image_idx)
            model, _ratio = generate_filter(images[image_idx], model, images[image_idx].shape[1:], fm_0_9=images_0_9)

    #then generate 1 filter on each of the images with the highest inactive ratio
    inactive_ratios = get_inactive_ratio_list(model, images)

    mean_inactive_ratio = np.mean(inactive_ratios)
    max_inactive_ratio = np.max(inactive_ratios)
    print("inactive ratios mean", mean_inactive_ratio, "max", max_inactive_ratio)

    if max_inactive_ratio < inactive_ratio_threshold:
        return model, False

    #get top n images with highest inactive ratio (ascending order, worst last)
    ranked_idx = np.argsort(inactive_ratios)
    # Explicit start index: a plain [-n:] would select everything for n == 0.
    top_n_inactive_ratio_idx = ranked_idx[max(len(ranked_idx) - n, 0):]
    for i in top_n_inactive_ratio_idx:
        print("\nimage ", i)
        model, _ratio = generate_filter(images[i], model, images[i].shape[1:], fm_0_9=images_0_9)

    print(model.activation, len(model.filter_list))

    return model, True
        

def examin_aggregated_conv(images_x_0_9, layer):
    """Draw a 2x5 grid with the arg-max index map of ``layer``'s output for ten images.

    The ten images are stacked into one batch of shape
    (10, image_size, image_size, 1, 1) -- ``image_size`` is read from module
    scope -- and passed through ``layer`` in a single call. For every sample
    the output is reshaped to (rows, cols, last-axis size of the batch output)
    and the arg-max along that last axis is shown with the ``gist_ncar``
    colormap.

    Args:
        images_x_0_9: ten images that together reshape to the batch above.
        layer: callable returning an indexable tensor whose items expose
            ``.numpy()`` and ``.shape``.
    """
    grid_fig = plt.figure(figsize=(10, 10))

    batch = np.array(images_x_0_9).reshape(10, image_size, image_size, 1, 1)
    batch_maps = layer(batch)
    depth = batch_maps.shape[-1]

    for slot in range(10):
        sample_map = batch_maps[slot]
        stacked = sample_map.numpy().reshape(sample_map.shape[0], sample_map.shape[1], depth)
        winner_map = np.argmax(stacked, axis = 2)
        panel = grid_fig.add_subplot(2, 5, slot + 1)
        panel.imshow(winner_map, cmap = "gist_ncar")
    
      

#get the images 0 - 9
# Builds two parallel lists with n_shot examples per class:
#   images_x_0_9 -- images reshaped to (1, image_size, image_size, 1, 1), float64
#   images_y_0_9 -- matching one-hot labels of shape (1, 10)
images_x_0_9 = []
images_y_0_9 = []
image_size = len(train_images[0])  # length of the first axis of one training image
n_shot = 1                         # examples taken per class
for label_idx in range(10):
    for i in range(n_shot):
        # indices of all training images carrying this label; the i-th one is used
        image_label_i = np.where(train_labels == label_idx)[0]
        image_x_i = train_images[image_label_i][i].reshape(1,image_size,image_size,1, 1).astype('float64')
        image_y_i = label_idx
        one_hot_y_i = tf.one_hot(image_y_i, 10).numpy().reshape(1,10)
        #train_one_shot(image_x_i, one_hot_y_i)
        images_x_0_9.append(image_x_i)
        images_y_0_9.append(one_hot_y_i)
    
# sanity check: expected to print 10 * n_shot
print(len(images_x_0_9))

      

In [None]:
def get_images_class_n(images, labels, class_count = 5, labelled_img_count = 1000, class_start = 0):
    """Collect the first ``labelled_img_count`` images of ``class_count`` consecutive classes.

    Description:
        Returns images and labels grouped by class, in the order
        class_start * labelled_img_count, class_start+1 * labelled_img_count, ...
        Every image is reshaped to (1, image_size, image_size, 1, 1), where
        ``image_size`` is read from module scope.

    Args:
        images: array of images, indexable by integer position.
        labels: array of class labels aligned with ``images``.
        class_count (int, optional): number of consecutive classes to
            collect. Defaults to 5.
        labelled_img_count (int, optional): images taken per class.
            Defaults to 1000.
        class_start (int, optional): first class label to collect, so the
            selected classes are ``class_start .. class_start + class_count - 1``.
            Defaults to 0.

    Returns:
        tuple(list, list): the reshaped images and their integer class labels.

    Raises:
        IndexError: if a class has fewer than ``labelled_img_count`` images.
    """
    images_class_n = []
    labels_class_n = []
    for class_label in range(class_start, class_start + class_count):
        # Locate the class once per class rather than once per image.
        class_indices = np.where(labels == class_label)[0]
        if len(class_indices) < labelled_img_count:
            raise IndexError(
                "class %d has only %d images, %d requested"
                % (class_label, len(class_indices), labelled_img_count)
            )
        for img_idx in class_indices[:labelled_img_count]:
            images_class_n.append(images[img_idx].reshape(1, image_size, image_size, 1, 1))
            labels_class_n.append(class_label)
    return images_class_n, labels_class_n


#classes 0-4
images_1000_0_4, labels_1000_0_4 = get_images_class_n(train_images, train_labels, class_count = 5, labelled_img_count = 1000)
#classes 5-9 (class_start selects the second half of the label range)
images_1000_5_9, labels_1000_5_9 = get_images_class_n(train_images, train_labels, class_count = 5, labelled_img_count = 1000, class_start = 5)

In [None]:
#generate l0 model on images_1000_0_4
# Runs at most `epochs` rounds of generate_model_on_images. Each round returns
# the (possibly extended) model plus a flag; once the flag is False the loop
# stops at the start of the next iteration.
epochs = 10
model_l0 = None  # None -> generate_model_on_images creates a new layer object
continue_generation = True
for epoch in range(epochs):
    if not continue_generation:
        break
    print("epoch", epoch)
    # images_x_0_9 is only forwarded for the index-map plots.
    model_l0, continue_generation = generate_model_on_images(images_1000_0_4, model_l0, images_x_0_9, inactive_ratio_threshold = 0.01, n = 3)

In [None]:
#generate l1 model on images_1000_5_9
# NOTE(review): despite the "l1" heading, this loop keeps extending model_l0
# (the layer built in the previous cell); no separate l1 object is created here.

epochs = 10
continue_generation = True
for epoch in range(epochs):
    if not continue_generation:
        break
    print("epoch", epoch)
    # images_x_0_9 is only forwarded for the index-map plots.
    model_l0, continue_generation = generate_model_on_images(images_1000_5_9, model_l0, images_x_0_9, inactive_ratio_threshold = 0.01, n = 3)

In [None]:
#get the mnist dataset
mnist = tf.keras.datasets.mnist
(train_images_mnist, train_labels_mnist), (test_images_mnist, test_labels_mnist) = mnist.load_data()
# Reshape to (N, 28, 28, 1, 1), cast to float64 and scale by 1/255.
train_images_mnist = train_images_mnist.reshape(train_images_mnist.shape[0], 28, 28, 1, 1).astype('float64')
train_images_mnist = train_images_mnist / 255.0
test_images_mnist = test_images_mnist.reshape(test_images_mnist.shape[0], 28, 28, 1, 1).astype('float64')
test_images_mnist = test_images_mnist / 255.0
#visualize the 0-9 images of mnist
# mnist_x_0_9: the first training image of each digit, each shaped (1, 28, 28, 1, 1)
mnist_x_0_9 = []
for i in range(10):
    mnist_x_0_9.append(train_images_mnist[np.where(train_labels_mnist == i)[0][0]].reshape(1,28,28,1,1).astype('float64'))
# Index maps of the current model_l0 for those ten digits.
show_all_index_maps_3D(data_0_9=mnist_x_0_9, model=model_l0)

In [None]:
#generate more filters on mnist
# Reuses `epochs` from the earlier generation cells (it is not set in this cell).
# The filters are generated from mnist_x_0_9, i.e. the ten single-digit examples,
# while images_x_0_9 is still what gets passed along for the index-map plots.
continue_generation = True
for epoch in range(epochs):
    if not continue_generation:
        break
    print("epoch", epoch)
    model_l0, continue_generation = generate_model_on_images(mnist_x_0_9, model_l0, images_x_0_9, inactive_ratio_threshold = 0.01, n = 3)

In [None]:
#build a multi-module model for classification
