In [1]:
# Make notebook cells span the full browser width by overriding the width
# of Jupyter's ".container" element with injected CSS.
from IPython.display import (
    HTML,
    display,
)

display(
    HTML("<style>.container { width:100% !important; }</style>")
)

In [140]:
import tensorflow as tf

def refactor_block_name(num_):
    """Return ``num_`` formatted as an English ordinal, e.g. ``1 -> "1st"``.

    Used to build readable layer-name prefixes such as ``"1st_Conv_layer"``.

    Args:
        num_: Integer block index.

    Returns:
        ``f"{num_}"`` followed by its ordinal suffix. Numbers whose last two
        digits are 10-20 always take "th" (11th, 12th, 13th, 111th); all
        others take "st"/"nd"/"rd"/"th" according to their last digit
        (21st, 22nd, 23rd, 101st).
    """
    last_two = abs(num_) % 100
    # The "teens" are irregular in English: 11th/12th/13th, not 11st/12nd/13rd.
    if 10 <= last_two <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(abs(num_) % 10, "th")
    return f"{num_}{suffix}"

def estimateAbundances(inputNodes):
    """Normalise each sample's scores along axis 1 so they sum to one.

    In this notebook it is wrapped in a ``Lambda`` layer and receives the
    ReLU output of the "7_Element_Abundance_Embedding" Dense layer.

    Args:
        inputNodes: Tensor of per-sample scores; the sum is taken over axis 1.

    Returns:
        ``inputNodes`` divided by its per-sample sum over axis 1 (same shape
        as the input). Samples whose sum is exactly zero -- e.g. when every
        ReLU unit for that sample is inactive -- yield zeros instead of NaN.
    """
    sampleWiseSums = tf.keras.backend.sum(inputNodes, axis=1, keepdims=True)
    # A plain "/" turns an all-zero row into 0/0 = NaN, which then propagates
    # into the loss and the gradients. divide_no_nan returns 0 for those rows
    # and is identical to "/" everywhere else.
    return tf.math.divide_no_nan(inputNodes, sampleWiseSums)

def pearson_correlation(x,y):
    """Pearson correlation coefficient between ``x`` and ``y``.

    Means, covariance and standard deviations are taken over *all* elements
    (``tf.reduce_mean`` with no axis), so a batch is treated as one pooled
    sample rather than scored per row. This function is registered under the
    name 'pearson_correlation' as a custom object when the saved abundance
    models are loaded.

    Args:
        x: Tensor of values.
        y: Tensor of values, broadcast-compatible with ``x``.

    Returns:
        Scalar tensor ``cov(x, y) / (std(x) * std(y))``. If either input has
        zero variance the coefficient is undefined; 0 is returned instead of
        NaN.
    """
    x_centered = x - tf.reduce_mean(x)
    y_centered = y - tf.reduce_mean(y)
    covariance = tf.reduce_mean(tf.multiply(x_centered, y_centered))
    x_std = tf.sqrt(tf.reduce_mean(tf.square(x_centered)))
    y_std = tf.sqrt(tf.reduce_mean(tf.square(y_centered)))
    # A constant x or y gives std == 0; avoid the resulting 0/0 = NaN.
    return tf.math.divide_no_nan(covariance, x_std * y_std)

def load_abundance_models(inputLayer, trainable = False, mineral_list = ("K", "U", "Th"),
                            base_dirs = "/Users/msd/Desktop/coderepo/sandeepan/Ml4Sci_GRS_abundance_estimation-Multi_Task_Models"):
    """Load one saved abundance model per mineral and apply each to ``inputLayer``.

    For every mineral ``m`` in ``mineral_list``, the model stored at
    ``{base_dirs}/Models/untrained_{m}_Abundance_Model.h5`` is loaded (and
    compiled, with ``pearson_correlation`` supplied as a custom object). It is
    then renamed to ``"{m.lower()}AbundanceEstimator"``, optionally frozen, and
    called on ``inputLayer``.

    Args:
        inputLayer: Keras tensor fed to every loaded model.
        trainable: If falsy, each loaded model is frozen
            (``model.trainable = False``); otherwise its trainable state is
            left as loaded.
        mineral_list: Mineral symbols to load, in output order. Any iterable
            of strings; the default is a tuple so it cannot be mutated
            between calls.
        base_dirs: Directory containing the ``Models`` sub-folder. The default
            is a machine-specific absolute path.

    Returns:
        List of output tensors, one per mineral, in ``mineral_list`` order.
    """
    mineral_outputs = []
    for mineral_ in mineral_list:
        model_name = f"{mineral_.lower()}AbundanceEstimator"
        model = tf.keras.models.load_model(
            f"{base_dirs}/Models/untrained_{mineral_}_Abundance_Model.h5",
            custom_objects={'pearson_correlation': pearson_correlation},
            compile=True,
        )
        # Give each sub-model a distinct name: the saved models appeared to
        # share a name, and layer names must be unique inside the functional
        # model that nests them.
        model._name = model_name
        if not trainable:
            model.trainable = False
        mineral_outputs.append(model(inputLayer))

    return mineral_outputs

def create_convolution_blocks(inputs, num_filters, num_block, global_seed = 23):
    """Apply one encoder stage: Conv1D -> ReLU -> BatchNorm -> Dropout(0.5).

    Args:
        inputs: Keras tensor the stage is applied to.
        num_filters: Number of Conv1D filters.
        num_block: Stage index; its ordinal form ("1st", "2nd", ...) prefixes
            every layer name in the stage.
        global_seed: Seed passed to the Dropout layer.

    Returns:
        Output tensor of the stage's Dropout layer.
    """
    prefix = refactor_block_name(num_block)
    # Kernel 5, stride 3, 'valid' padding: each stage shortens the sequence
    # roughly threefold (421 -> 139 for the first stage in the summary below).
    x = tf.keras.layers.Conv1D(filters=num_filters, kernel_size=5, strides=3,
                               padding='valid',
                               name=f"{prefix}_Conv_layer")(inputs)
    x = tf.keras.layers.Activation('relu', name=f"{prefix}_Activation")(x)
    x = tf.keras.layers.BatchNormalization(name=f"{prefix}_Batch_Norm")(x)
    x = tf.keras.layers.Dropout(0.5, noise_shape=None, seed=global_seed,
                                name=f"{prefix}_Dropout")(x)
    return x

def create_transposed_blocks(inputs, num_filters, num_block, global_seed  = 23):
    """Apply one decoder stage: Conv1DTranspose -> ReLU -> BatchNorm -> Dropout(0.5).

    Args:
        inputs: Keras tensor the stage is applied to.
        num_filters: Number of Conv1DTranspose filters.
        num_block: Stage index; its ordinal form ("1st", "2nd", ...) prefixes
            every layer name in the stage.
        global_seed: Seed passed to the Dropout layer.

    Returns:
        Output tensor of the stage's Dropout layer. With kernel 5, stride 1
        and Keras' default 'valid' padding, the sequence grows by 4 steps.
    """
    block_name = refactor_block_name(num_block)
    # Honour the caller's filter count; it used to be hard-coded to 16, which
    # silently ignored requests such as the 32 filters of the 3rd stage.
    transposedConvolutionalLayer = tf.keras.layers.Conv1DTranspose(
        filters=num_filters, kernel_size=5, strides=1, activation=None,
        name=f"{block_name}_Trans_Conv_Layer")(inputs)
    decoderReluActivation = tf.keras.layers.Activation(
        'relu', name=f"{block_name}_Decoder_Activation")(transposedConvolutionalLayer)
    decoderBatchNormalization = tf.keras.layers.BatchNormalization(
        name=f"{block_name}_Decoder_Batch_Norm")(decoderReluActivation)
    decoderDropoutLayer = tf.keras.layers.Dropout(
        0.5, noise_shape=None, seed=global_seed,
        name=f"{block_name}_Decoder_Dropout")(decoderBatchNormalization)
    return decoderDropoutLayer

def create_unmixing_autoencoder(input_shape):
    """Build and compile the unmixing autoencoder.

    Encoder: three strided Conv1D stages -> Dense(112, ReLU) -> Dense(7, ReLU),
    followed by ``estimateAbundances``, which rescales the 7 values of each
    sample to sum to one. The outputs of the saved per-mineral abundance
    models (loaded frozen via ``load_abundance_models``) are concatenated with
    that embedding. The result is decoded through Dense(112), three
    Conv1DTranspose stages, Dense(2048) and a final Dense back to
    ``input_shape[0]`` values.

    Args:
        input_shape: Shape of one input sample excluding the batch axis, e.g.
            ``(421, 1)``. ``input_shape[0]`` also sets the width of the
            reconstruction output.

    Returns:
        A compiled ``tf.keras.Model`` named "Unmixing_Autoencoder" with outputs
        ``[abundance embedding, reconstruction]``. The embedding is trained with
        mean squared error; the reconstruction with the cosine-similarity loss.
    """
    # The reconstruction must match the input length; derive it from
    # input_shape instead of relying on a module-level global.
    numChannels = input_shape[0]

    inputLayer = tf.keras.Input(shape=input_shape, name="input_layer")

    # ---- Encoder ----
    firstConvBlock = create_convolution_blocks(inputLayer, 32, 1)
    secondConvBlock = create_convolution_blocks(firstConvBlock, 16, 2)
    thirdConvBlock = create_convolution_blocks(secondConvBlock, 16, 3)

    flattenedFeatures = tf.keras.layers.Flatten()(thirdConvBlock)
    firstFullyConnectedLayer = tf.keras.layers.Dense(
        112, activation='relu', name="1st_Fully_Connected_Layer")(flattenedFeatures)
    secondFullyConnectedLayer = tf.keras.layers.Dense(
        7, activation='relu', name="7_Element_Abundance_Embedding")(firstFullyConnectedLayer)

    abundanceEmbedding = tf.keras.layers.Lambda(estimateAbundances)(secondFullyConnectedLayer)
    # Per-mineral model outputs, with those models frozen (trainable=False).
    abundance_mineral_list = load_abundance_models(inputLayer, False)

    allAbundanceValues = tf.keras.layers.concatenate(
        [*abundance_mineral_list, abundanceEmbedding], axis=-1)

    # ---- Decoder ----
    firstFullyConnectedLayerDecoder = tf.keras.layers.Dense(
        112, activation='relu', name="1st_Fully_Connected_Layer_Decoder")(allAbundanceValues)
    # Add a trailing channel axis, (batch, 112) -> (batch, 112, 1), so the
    # Conv1DTranspose stages receive a 3-D sequence tensor.
    firstFullyConnectedLayerDecoder = tf.expand_dims(firstFullyConnectedLayerDecoder, axis=-1)

    firstTransConvBlock = create_transposed_blocks(firstFullyConnectedLayerDecoder, 16, 1)
    secondTransConvBlock = create_transposed_blocks(firstTransConvBlock, 16, 2)
    thirdTransConvBlock = create_transposed_blocks(secondTransConvBlock, 32, 3)

    decoderFlattenedFeatures = tf.keras.layers.Flatten()(thirdTransConvBlock)
    secondFullyConnectedLayerDecoder = tf.keras.layers.Dense(
        2048, activation='relu', name="2nd_Fully_Connected_Layer_Decoder")(decoderFlattenedFeatures)
    thirdFullyConnectedLayerDecoder = tf.keras.layers.Dense(
        numChannels, activation='relu', name="3rd_Fully_Connected_Layer_Decoder")(secondFullyConnectedLayerDecoder)

    autoencoder = tf.keras.Model(
        inputs=inputLayer,
        outputs=[abundanceEmbedding, thirdFullyConnectedLayerDecoder],
        name="Unmixing_Autoencoder")

    # One loss per output, in the same order as `outputs` above.
    embeddingLossFunction = tf.keras.losses.MeanSquaredError()
    decoderLossFunction = tf.keras.losses.CosineSimilarity()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

    # NOTE(review): 'accuracy' is not meaningful for these continuous outputs;
    # it is kept only so existing training-history keys stay unchanged.
    autoencoder.compile(optimizer=optimizer,
                        loss=[embeddingLossFunction, decoderLossFunction],
                        metrics=['accuracy', 'mse'])

    return autoencoder

import tensorflow.keras.backend as K
import numpy as np
import gc
# Reset Keras' global state (layer-name counters, graph) before building.
K.clear_session()

# Length of each input sample (number of channels).
noOfChannels = 421

unmixing_autoencoder = create_unmixing_autoencoder((noOfChannels, 1))
# Model.summary() prints the table itself and returns None; wrapping it in
# print() would emit an extra "None" line after the summary.
unmixing_autoencoder.summary()

Model: "Unmixing_Autoencoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_layer (InputLayer)       [(None, 421, 1)]     0           []                               
                                                                                                  
 1st_Conv_layer (Conv1D)        (None, 139, 32)      192         ['input_layer[0][0]']            
                                                                                                  
 1st_Activation (Activation)    (None, 139, 32)      0           ['1st_Conv_layer[0][0]']         
                                                                                                  
 1st_Batch_Norm (BatchNormaliza  (None, 139, 32)     128         ['1st_Activation[0][0]']         
 tion)                                                                         