# Concrete Autoencoders dMRI for Tensorflow

In [7]:
import project_path # Always import this first

In [8]:
from pathlib import Path

import tensorflow as tf
import pandas as pd
import os
import numpy as np
import h5py
from datetime import datetime

from utils.env import DATA_PATH
from utils.logger import logger

In [9]:
logger.info('Tensorflow versionL: %s', tf.__version__)
logger.info("Num GPUs Available: %s", len(tf.config.list_physical_devices('GPU')))

[38;21m2021-06-12 21:02:29,332 - geometric-dl - INFO - Tensorflow versionL: 2.6.0-dev20210611 (<ipython-input-9-9638a48e88ea>:1)[0m
[38;21m2021-06-12 21:02:29,437 - geometric-dl - INFO - Num GPUs Available: 1 (<ipython-input-9-9638a48e88ea>:2)[0m


In [10]:
PROJECT_PATH = Path().cwd().parent

## MUDI data

In [22]:
# split dataset into batches
class MRISelectorSubjDataset(tf.keras.utils.Sequence):
    """MRI dataset to select features from."""

    def __init__(self, root_dir, dataf, headerf, subj_list, batch_size=100, shuffle=False):
        """
        Args:
            root_dir (string): Directory with the .csv files
            data (string): Data .csv file
            header (string): Header .csv file
            subj_list (list): list of all the subjects to include
        """
        
        self.root_dir = root_dir
        self.dataf = dataf
        self.headerf = headerf
        self.subj_list = subj_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        
        # load the header
        subj = self.subj_list[0]
        self.header = pd.read_csv(os.path.join(self.root_dir,
                                             self.headerf), index_col=0).to_numpy()
        self.ind = self.header[np.isin(self.header[:,1],self.subj_list),0]
#         print(self.ind)
        
        self.indexes = np.arange(len(self.ind)) 

    def __len__(self):
        return int(np.ceil(len(self.ind) / float(self.batch_size)))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
#         print(indexes)

        # Find list of IDs
        list_IDs_temp = [self.ind[k] for k in indexes]
#         print(list_IDs_temp)

        # Generate data
        signals = self.__data_generation(list_IDs_temp)

        return signals, signals

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.ind)) 
        'Updates indexes after each epoch'
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
            
    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Generate data
        # X = pd.read_csv(os.path.join(self.root_dir, self.dataf), index_col=0, skiprows=lambda x: x not in list_IDs_temp).to_numpy()
        h5f = h5py.File(os.path.join(self.root_dir, self.dataf), 'r')
        X = h5f.get('data1')
        X = X[list_IDs_temp,:]

        return X

## Autoencoder

In [23]:
from tensorflow.keras.layers import Dense, LeakyReLU
import numpy as np

def decoder(x):
    x = Dense(1344)(x)
    return x

def decoder_2l(x):
    x = Dense(800)(x)
    x = LeakyReLU(0.2)(x)
#    x = Dropout(0.1)(x)
    x = Dense(1000)(x)
    x = LeakyReLU(0.2)(x)
#    x = Dropout(0.1)(x)
    x = Dense(1344)(x)
    return x

def decoder_1l(x):
    x = Dense(1000)(x)
    x = LeakyReLU(0.2)(x)
#    x = Dropout(0.1)(x)
    x = Dense(1344)(x)
    return x

In [38]:
import math
from tensorflow.keras import backend as K
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer, Softmax, Input
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.initializers import Constant, glorot_normal
from tensorflow.keras.optimizers import Adam

class ConcreteSelect(Layer):
    
    def __init__(self, output_dim, start_temp = 10.0, min_temp = 0.1, alpha = 0.99999, **kwargs):
        self.output_dim = output_dim
        self.start_temp = start_temp
        self.min_temp = K.constant(min_temp)
        self.alpha = K.constant(alpha)
        super(ConcreteSelect, self).__init__(**kwargs)
        
    def build(self, input_shape):
        self.temp = self.add_weight(name = 'temp', shape = [], initializer = Constant(self.start_temp), trainable = False)
        self.logits = self.add_weight(name = 'logits', shape = [self.output_dim, input_shape[1]], initializer = glorot_normal(), trainable = True)
        super(ConcreteSelect, self).build(input_shape)
        
    def call(self, X, training = None):
        uniform = K.random_uniform(self.logits.shape, K.epsilon(), 1.0)
        gumbel = -K.log(-K.log(uniform))
        temp = K.update(self.temp, K.maximum(self.min_temp, self.temp * self.alpha))
        noisy_logits = (self.logits + gumbel) / temp
        samples = K.softmax(noisy_logits)
        
        discrete_logits = K.one_hot(K.argmax(self.logits), self.logits.shape[1])
        
        self.selections = K.in_train_phase(samples, discrete_logits, training)
        Y = K.dot(X, K.transpose(self.selections))
        
        return Y
    
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
    
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'output_dim': self.output_dim,
            'start_temp': self.start_temp,
            'min_temp': self.min_temp.numpy(),
            'alpha': self.alpha.numpy()
        })
        return config

class StopperCallback(EarlyStopping):
    
    def __init__(self, mean_max_target = 0.998):#, writer=None):
        self.mean_max_target = mean_max_target
        #self.writer = writer
        super(StopperCallback, self).__init__(monitor = '', patience = float('inf'), verbose = 1, mode = 'max')#, baseline = self.mean_max_target)
    
    def on_epoch_begin(self, epoch, logs = None):
        logger.info('mean max of probabilities: %.8f %s %.8f', self.get_monitor_value(logs), '- temperature', K.get_value(self.model.get_layer('concrete_select').temp))
        #s1 = self.get_monitor_value(logs)
        #print(s1)
        #s2 = K.get_value(self.model.get_layer('concrete_select').temp)
        #s1 = tf.summary.scalar('mean_max', s1)
        #s2 = tf.summary.scalar('temperature', s2)
        #print(s1)
        #self.writer.add_summary(s1)
        #self.writer.add_summary(s2)
        #print( K.get_value(K.max(K.softmax(self.model.get_layer('concrete_select').logits), axis = -1)))
        #print(K.get_value(K.max(self.model.get_layer('concrete_select').selections, axis = -1)))
    
    def get_monitor_value(self, logs):
        monitor_value = K.get_value(K.mean(K.max(K.softmax(self.model.get_layer('concrete_select').logits), axis = -1)))
        return monitor_value

class ConcreteAutoencoderFeatureSelector():
    
    def __init__(self, K, output_function, num_epochs = 100, learning_rate = 0.001, start_temp = 10.0, min_temp = 0.1, tryout_limit = 5, input_dim = 1344, callback=[]):#, writer=None): #batch_size = None, 
        self.K = K
        self.output_function = output_function
        self.num_epochs = num_epochs
#         self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.start_temp = start_temp
        self.min_temp = min_temp
        self.tryout_limit = tryout_limit
        self.input_dim = input_dim
        self.callback = callback
        #self.writer = writer
        
    def fit(self, X, val_X=None):
#         if self.batch_size is None:
#             self.batch_size = max(len(X) // 256, 16)
        
        num_epochs = self.num_epochs
        steps_per_epoch = X.__len__()#(len(X) + self.batch_size - 1) // self.batch_size
        
        for i in range(self.tryout_limit):
            K.set_learning_phase(1)
            
            inputs = Input(shape = (self.input_dim,))#X.shape[1:])
            
            alpha = math.exp(math.log(self.min_temp / self.start_temp) / (num_epochs * steps_per_epoch))
            
            self.concrete_select = ConcreteSelect(self.K, self.start_temp, self.min_temp, alpha, name = 'concrete_select')
            
            selected_features = self.concrete_select(inputs)
            outputs = self.output_function(selected_features)

            self.model = Model(inputs, outputs)
            self.model.compile(Adam(self.learning_rate), loss = 'mean_squared_error')
    
            logger.info('%s', self.model.summary())
            
            stopper_callback = StopperCallback()#writer=self.writer)
                        
            # hist = self.model.fit(X, Y, self.batch_size, num_epochs, verbose = 1, callbacks = [stopper_callback, tensorboard_callback], validation_data = validation_data)#, validation_freq = 10)
            hist = self.model.fit_generator(X, epochs=num_epochs, callbacks = [stopper_callback] + self.callback, validation_data=val_X, verbose=0)#, validation_freq = 10) workers=8, use_multiprocessing=True, 
            #fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, validation_freq=1, 
            #              class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
            
            if K.get_value(K.mean(K.max(K.softmax(self.concrete_select.logits, axis = -1)))) >= stopper_callback.mean_max_target:
                break
            
            num_epochs *= 2
        
        self.probabilities = K.get_value(K.softmax(self.model.get_layer('concrete_select').logits))
        self.indices = K.get_value(K.argmax(self.model.get_layer('concrete_select').logits))
            
        return self
    
    def get_indices(self):
        return K.get_value(K.argmax(self.model.get_layer('concrete_select').logits))
    
    def get_mask(self):
        return K.get_value(K.sum(K.one_hot(K.argmax(self.model.get_layer('concrete_select').logits), self.model.get_layer('concrete_select').logits.shape[1]), axis = 0))
    
#     def transform(self, X):
#         return X[self.get_indices()]
    
#     def fit_transform(self, X, y):
#         self.fit(X, y)
#         return self.transform(X)
    
    def get_support(self, indices = False):
        return self.get_indices() if indices else self.get_mask()
    
    def get_params(self):
        return self.model

## Experiments

In [39]:
def run_model(train_subject, 
              test_subject, 
              n_features = 500, 
              batch_size = 256, 
              num_epochs = 2000, 
              learning_rate = .001,
              decoder_architecture = decoder_2l, 
              decoder_name = 'l2'):
    """
    Trains the ConcreteAutoencoderFeatureSelector
    
    Parameters:
        train_subject (List): subjects to train on
        test_subject (List): subjects to test on
        n_features (int): number of features for the encoder output/decoder input
        batch_size (int): batch size for the ConcreteAutoencoderFeatureSelector
        num_epochs (int): number of epochs
        learning_rate (int): learning rate
        decoder_architecture (keras layer): decoder network architecture for the ConcreteAutoencoderFeatureSelector
        decoder_name (str): name to describe the `decoder_architecture`. Used for logging
        
    Returns:
        indices (List[int]): indices of the features selected by the autoencoder
    """
    data_file = 'data_.hdf5'
    header_file = 'header_.csv'
    date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_info_str = f'{date_str}_K={n_features}_epoch={num_epochs}_test={test_subject[0]}_dec={decoder_name}'
 
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=Path(PROJECT_PATH, 'logs', run_info_str))
    monitor_callback = tf.keras.callbacks.ModelCheckpoint(Path(PROJECT_PATH, 'runs', 'models', f'{run_info_str}_runtime.h5'), 
                                                          monitor='val_loss', 
                                                          verbose=0)
    
    trainset = MRISelectorSubjDataset(root_dir=DATA_PATH, 
                                      dataf=data_file, 
                                      headerf ='header_.csv',
                                      subj_list=np.array(train_subject), 
                                      batch_size=batch_size)
    
    testset = MRISelectorSubjDataset(root_dir=DATA_PATH, 
                                     dataf=data_file, 
                                     headerf=header_file,
                                     subj_list=np.array(test_subject), 
                                     batch_size=batch_size)

    selector = ConcreteAutoencoderFeatureSelector(K=n_features, 
                                                  output_function=decoder_architecture, 
                                                  num_epochs=num_epochs, 
                                                  learning_rate=learning_rate, 
                                                  start_temp=10.0, 
                                                  min_temp=0.1, 
                                                  tryout_limit=5, 
                                                  input_dim=1344, 
                                                  callback=[tensorboard_callback, monitor_callback])
    
    selector.fit(X=trainset, val_X=testset)
    
    model = selector.get_params()
    indices = selector.get_indices()

    model.save(Path(PROJECT_PATH, 'runs', 'models', f'{run_info_str}.h5'))
    
    indices_save_path = Path(PROJECT_PATH, 'runs', 'textfiles', f'{run_info_str}.txt')
    indices_save_path.parent.mkdir(parents=True, exist_ok=True) 
    np.savetxt(indices_save_path, np.array(indices, dtype=int), fmt='%d') 
    
    return indices

### Test Subject 15, default values
Training on subjects 11,12,13,14 and testing on 15. Default values for the `run_model` function.

In [40]:
result_run_0 = run_model([11, 12, 13, 14], [15])

[38;21m2021-06-11 22:53:25,521 - geometric-dl - INFO - None (<ipython-input-38-0e838eedc6e0>:118)[0m
[38;21m2021-06-11 22:53:25,574 - geometric-dl - INFO - mean max of probabilities: 0.00080113 - temperature 10.00000000 (<ipython-input-38-0e838eedc6e0>:60)[0m


Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 1344)]            0         
_________________________________________________________________
concrete_select (ConcreteSel (None, 500)               672001    
_________________________________________________________________
dense_15 (Dense)             (None, 800)               400800    
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 800)               0         
_________________________________________________________________
dense_16 (Dense)             (None, 1000)              801000    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 1000)              0         
_________________________________________________________________
dense_17 (Dense)             (None, 1344)              1345

[38;21m2021-06-11 22:53:33,415 - geometric-dl - INFO - mean max of probabilities: 0.00085819 - temperature 9.97203064 (<ipython-input-38-0e838eedc6e0>:60)[0m
[38;21m2021-06-11 22:53:40,438 - geometric-dl - INFO - mean max of probabilities: 0.00096117 - temperature 9.94406128 (<ipython-input-38-0e838eedc6e0>:60)[0m
[38;21m2021-06-11 22:53:48,573 - geometric-dl - INFO - mean max of probabilities: 0.00110152 - temperature 9.91609192 (<ipython-input-38-0e838eedc6e0>:60)[0m
[38;21m2021-06-11 22:53:55,490 - geometric-dl - INFO - mean max of probabilities: 0.00130152 - temperature 9.88812256 (<ipython-input-38-0e838eedc6e0>:60)[0m
[38;21m2021-06-11 22:54:03,433 - geometric-dl - INFO - mean max of probabilities: 0.00160927 - temperature 9.86015320 (<ipython-input-38-0e838eedc6e0>:60)[0m
[38;21m2021-06-11 22:54:11,075 - geometric-dl - INFO - mean max of probabilities: 0.00238046 - temperature 9.83218384 (<ipython-input-38-0e838eedc6e0>:60)[0m
[38;21m2021-06-11 22:54:18,925 - geomet

FileNotFoundError: [Errno 2] No such file or directory: '/home/maarten/Workspace/School/uu/thesis/geometric-dl-dmri/runs/textfiles/20210611-22:53:25_K=500_epoch=2000_test=15_dec=l2.txt'

In [None]:
logger.info('%s', result_run_0)