# Gabor-Based Convolutional Neural Network
## trained using supervised contrastive loss
## for protein landmark classification
* [Imports and Utility Functions](#first-bullet)
* [Creating Custom Gabor2D Conv Layer](#second-bullet)
* [Model Definition](#third-bullet)
* [Contrastive Loss Definition](#fourth-bullet)
* [Model Training](#fifth-bullet)
* [Visualize Learned Representation](#sixth-bullet)
* [Confusion Matrix](#seventh-bullet)

## Imports and Utility Functions <a class="anchor" id="first-bullet"></a>

In [5]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Input, Flatten, Lambda, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import nn_ops

In [6]:
# Width and height, in pixels, of a single-cell image.
IMG_WIDTH = 100 
IMG_HEIGHT = 100
# Training batch size.
BATCH_SIZE = 4
# NOTE: the batch size is small because of the memory limitations of the GPU
# that was used (GeForce GTX 1050 Ti).

# utility functions
def get_tfrecords_path_list(tf_record_path):
    """Return the full paths of the files located directly in a folder.

    Only the top level of ``tf_record_path`` is inspected; sub-folders are
    neither listed nor descended into. No filtering on file extension is
    performed, so the folder is expected to contain TFRecord files only.
    """
    # The first item produced by os.walk describes the folder itself:
    # (dirpath, dirnames, filenames).
    _, _, file_names = next(os.walk(tf_record_path))
    return [os.path.join(tf_record_path, name) for name in file_names]

def img_augment(img):
    """Randomly flip and/or rotate an image.

    One integer is drawn from {0, 1, 2, 3} for each transform, and the
    transform is applied only when its draw equals 1. The transforms are
    tried in a fixed order: left/right flip, up/down flip, 90-degree rotation.
    """
    transforms = (
        tf.image.flip_left_right,
        tf.image.flip_up_down,
        tf.image.rot90,
    )
    for transform in transforms:
        # A fresh draw per transform keeps the three decisions independent.
        draw = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
        if draw == 1:
            img = transform(img)
    return img

def decode_image(image, width, height):
    """Decode a JPEG byte string into a float32 RGB image tensor.

    Args:
        image: scalar string tensor holding the encoded JPEG bytes.
        width: expected image width in pixels.
        height: expected image height in pixels.

    Returns:
        A float32 tensor of shape ``[height, width, 3]``.
    """
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32)
    # decode_jpeg produces [height, width, channels]. Reshaping to
    # [width, height, 3] would silently scramble the pixels of any non-square
    # image, so the static shape is pinned in the decoder's own axis order.
    image = tf.reshape(image, [height, width, 3])
    return image

# Schema of the features stored in each TFRecord example.
features = {
    'img_raw': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}


def _parse_function(raw_record):
    """Turn one serialized TFRecord example into an ``(image, label)`` pair.

    The raw image bytes are read as uint8, scaled to [0, 1], reshaped to
    ``(IMG_HEIGHT, IMG_WIDTH, 1)`` and passed through ``img_augment``.
    The label is returned as a float32 scalar class index (not one-hot).
    """
    parsed = tf.io.parse_single_example(raw_record, features)
    pixels = tf.io.decode_raw(parsed['img_raw'], tf.uint8)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    img = img_augment(tf.reshape(scaled, (IMG_HEIGHT, IMG_WIDTH, 1)))
    label = tf.cast(parsed["label"], tf.float32)
    return img, label

AUTOTUNE = tf.data.experimental.AUTOTUNE


def get_dataset(filenames_list, augment=False):
    """Build a shuffled, batched ``tf.data`` pipeline from TFRecord files.

    Args:
        filenames_list: paths of the TFRecord files to read.
        augment: accepted for interface compatibility but currently not read.
            NOTE(review): ``_parse_function`` is what decides whether
            augmentation is applied, regardless of this flag.

    Returns:
        A dataset yielding ``(image, label)`` batches of ``BATCH_SIZE``
        examples (the final batch may be smaller).
    """
    raw_dataset = tf.data.TFRecordDataset(filenames_list)
    dataset = raw_dataset.map(_parse_function)
    dataset = dataset.shuffle(4096)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch last so that complete batches, rather than single examples,
    # are prepared in the background while the model is busy.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

In [7]:
# Paths to the folders holding the TFRecord files:
# 4 repetitions are used for training,
# 1 repetition is used for validation.
train_tfrecord_path1 = r'D:\Phenix\EO_20190317_mCherryColocRep1__2019-03-17T18_04_08-Measurement 1\markers_tf_records'
train_tfrecord_path2 = r'D:\Phenix\EO_20190319_mCherryColocRep2__2019-03-19T16_55_51-Measurement 1\markers_tf_records'
train_tfrecord_path3 = r'D:\Phenix\EO_20190319_mCherryColocRep2__2019-03-19T20_47_51-Measurement 2\markers_tf_records'
train_tfrecord_path4 = r'D:\Phenix\EO_20190321_mCherryColocRep3__2019-03-21T16_17_40-Measurement 1\markers_tf_records'
valid_tfrecord_path = r'D:\Phenix\EO_20190321_mCherryColocRep3__2019-03-21T21_11_35-Measurement 2\markers_tf_records'

In [8]:
# List the TFRecord files of each training repetition, in path order.
train_paths = (
    train_tfrecord_path1,
    train_tfrecord_path2,
    train_tfrecord_path3,
    train_tfrecord_path4,
)
file_list1, file_list2, file_list3, file_list4 = (
    get_tfrecords_path_list(path) for path in train_paths
)
# Single flat list with the TFRecord paths of all 4 training repetitions.
file_list = [*file_list1, *file_list2, *file_list3, *file_list4]
valid_list = get_tfrecords_path_list(valid_tfrecord_path)

In [12]:
# Convert the lists of TFRecord paths into the training and validation datasets.
train_dataset = get_dataset(file_list)
valid_dataset = get_dataset(valid_list)

## Creating Custom Gabor2D Conv Layer <a class="anchor" id="second-bullet"></a>

In [None]:
def _gabor_filter(shape, sigma=1.0, theta=20, lambd=15.0, gamma=0.5):
    """Build a Gabor kernel with OpenCV.

    ``shape`` is handed to ``cv2.getGaborKernel`` as ``ksize``; the remaining
    arguments are forwarded under their own names.
    """
    return cv2.getGaborKernel(
        ksize=shape,
        sigma=sigma,
        theta=theta,
        lambd=lambd,
        gamma=gamma,
    )


class GaborConv2D(Conv2D):
    """Conv2D whose kernel combines a trainable weight with a constant Gabor filter.

    A trainable weight (``kernelA``) is multiplied with a fixed Gabor kernel
    (``kernelB``) inside ``build`` and the product is stored as ``self.kernel``.
    ``call`` is not overridden here, so the forward pass is the inherited one.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        """Initialise the parent Conv2D and precompute the Gabor filter.

        Args:
            filters: forwarded to ``Conv2D``.
            kernel_size: forwarded to ``Conv2D``. A single value is expanded
                to ``(kernel_size, kernel_size)`` for the Gabor filter; any
                other value is passed to ``_gabor_filter`` unchanged.
            **kwargs: forwarded to ``Conv2D``.
        """
        super(GaborConv2D, self).__init__(filters, kernel_size, **kwargs)
        # NOTE(review): looks like leftover debug output; it prints on every
        # layer construction.
        print(self.kernel_size)
        if np.size(kernel_size) == 1:
            self.kernelB_init_weight = _gabor_filter(shape=(kernel_size, kernel_size))
        else:
            self.kernelB_init_weight = _gabor_filter(kernel_size)

    def build(self, input_shape):
        """Create the layer weights and the convolution op for ``input_shape``.

        NOTE(review): this relies on private helpers inherited from the parent
        class (``_get_input_channel``, ``_get_channel_axis``,
        ``_get_padding_op``) and on ``nn_ops.Convolution``; these are
        presumably tied to a specific TensorFlow version - TODO confirm which.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        # Kernel layout: spatial dims, then input channels, then filters.
        kernel_shape = self.kernel_size + (input_channel, self.filters)

        # Trainable part of the kernel.
        self.kernelA = self.add_weight(
            name='kernelA',
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype)

        # Fixed (non-trainable) Gabor part of the kernel.
        self.kernelB = K.constant(self.kernelB_init_weight)
        # Effective kernel: transpose(dot(transpose(kernelA), kernelB)).
        # NOTE(review): presumably the result only matches kernelA's shape for
        # square kernels - TODO confirm. It is also computed once, here in
        # build; verify that gradients still reach kernelA during training in
        # the TensorFlow version in use.
        self.kernel = K.transpose(K.dot(K.transpose(self.kernelA), self.kernelB))

        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype)
        else:
            self.bias = None

        # Constrain future inputs to the rank and channel count seen here.
        channel_axis = self._get_channel_axis()
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_channel})

        # Bookkeeping attributes and the convolution op itself, built from the
        # combined kernel's shape.
        self._build_conv_op_input_shape = input_shape
        self._build_input_channel = input_channel
        self._padding_op = self._get_padding_op()
        self._conv_op_data_format = conv_utils.convert_data_format(
            self.data_format, self.rank + 2)
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=self.kernel.shape,
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self._padding_op,
            data_format=self._conv_op_data_format)
        self.built = True

## Model Definition <a class="anchor" id="third-bullet"></a>

### Encoder Definition

In [None]:
# Number of target classes and shape of a single input image.
num_classes = 5
input_shape = (100,100,1)
# Hyperparameters read as module-level globals by the classifier, the
# projection head and the training cells.
# NOTE(review): these names were referenced but never assigned anywhere in
# this notebook, so a fresh top-to-bottom run failed with NameError. The
# values below are those of the Keras supervised-contrastive-learning example
# this code appears to be adapted from - confirm them against the settings of
# the original experiment.
learning_rate = 0.001
hidden_units = 512
projection_units = 128
dropout_rate = 0.5
temperature = 0.05
def create_encoder():
    """Build the convolutional encoder.

    The network is a Gabor convolution followed by four plain convolutions,
    each with 128 filters, a 3x3 kernel, ReLU activation and 'same' padding,
    and each followed by max pooling. The final feature map is flattened.
    """
    input_x = Input(input_shape)
    x = GaborConv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding='same')(input_x)
    x = MaxPooling2D()(x)
    # Four identical conv + pool stages.
    for _ in range(4):
        x = Conv2D(128, 3, activation='relu', padding='same')(x)
        x = MaxPooling2D()(x)
    x = Flatten()(x)
    return Model(input_x, x, name='convnet_encoder')

### Encoder + Classifier Definition

In [None]:
def create_classifier(encoder, trainable=True):
    """Stack a classification head on top of ``encoder`` and compile the model.

    Args:
        encoder: model mapping an image to a feature vector.
        trainable: value assigned to the ``trainable`` flag of every encoder
            layer; pass False to freeze the encoder.

    Returns:
        A compiled model ending in a softmax over ``num_classes`` classes,
        using Adam, sparse categorical cross-entropy and sparse categorical
        accuracy. ``hidden_units``, ``dropout_rate`` and ``learning_rate``
        are read from module-level globals.
    """
    for encoder_layer in encoder.layers:
        encoder_layer.trainable = trainable

    image_input = Input(shape=input_shape)
    encoded = encoder(image_input)
    hidden = Dense(hidden_units, activation="relu")(encoded)
    hidden = Dropout(dropout_rate)(hidden)
    class_probs = Dense(num_classes, activation="softmax")(hidden)

    classifier_model = Model(inputs=image_input, outputs=class_probs, name="convnet-classifier")
    classifier_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return classifier_model

## Contrastive Loss Definition <a class="anchor" id="fourth-bullet"></a>

In [15]:
class SupervisedContrastiveLoss(tf.keras.losses.Loss):
    """Supervised contrastive loss built on the n-pairs loss.

    Feature vectors are L2-normalised, their pairwise dot products are divided
    by ``temperature`` and the resulting similarity matrix is passed, together
    with the labels, to ``tfa.losses.npairs_loss``.
    """

    def __init__(self, temperature=1, name=None):
        """Store the temperature that scales the similarity logits."""
        super().__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        """Compute the loss for one batch.

        Args:
            labels: class labels of the batch.
            feature_vectors: 2-D tensor with one feature vector per row.
            sample_weight: accepted for interface compatibility; not used.

        Returns:
            The value returned by ``tfa.losses.npairs_loss``.
        """
        # Normalise so that the dot products below are cosine similarities.
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Pairwise similarity of every sample with every other sample.
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized,
                feature_vectors_normalized,
                transpose_b=True,
            ),
            self.temperature,
        )
        # Flatten instead of squeeze: squeeze would collapse a batch holding a
        # single sample (e.g. a trailing partial batch) to a scalar, which
        # npairs_loss cannot handle.
        flat_labels = tf.reshape(labels, [-1])
        return tfa.losses.npairs_loss(flat_labels, logits)

### Adding a projection head

In [None]:
def add_projection_head(encoder):
    """Wrap ``encoder`` with a dense ReLU projection layer.

    The projection width is read from the module-level ``projection_units``.
    The encoder is used by reference, not copied.
    """
    image_input = tf.keras.Input(shape=input_shape)
    encoded = encoder(image_input)
    projection_layer = tf.keras.layers.Dense(projection_units, activation="relu")
    projected = projection_layer(encoded)
    return tf.keras.Model(
        inputs=image_input,
        outputs=projected,
        name="cifar-encoder_with_projection-head",
    )

## Model Training <a class="anchor" id="fifth-bullet"></a>

### 1- Train encoder using Contrastive Loss

In [None]:
# Stage 1: train the encoder (through a projection head) with the
# supervised contrastive loss.
encoder = create_encoder()
encoder_with_projection_head = add_projection_head(encoder)

contrastive_optimizer = tf.keras.optimizers.Adam(learning_rate)
contrastive_loss = SupervisedContrastiveLoss(temperature)
encoder_with_projection_head.compile(
    optimizer=contrastive_optimizer,
    loss=contrastive_loss,
)

history = encoder_with_projection_head.fit(train_dataset, epochs=10)

### 2- Save Encoder 

In [None]:
# Persist the contrastively trained encoder (without the projection head).
encoder.save('trained_encoder.tf',save_format='tf')

### 3- Create Classifier using pre-trained encoder

In [None]:
# Build the classifier on top of the pre-trained encoder; trainable=False
# freezes every encoder layer so only the new dense head is trained.
classifier = create_classifier(encoder, trainable=False)
history = classifier.fit(train_dataset, epochs=10)
# Evaluate on the validation set.
# NOTE(review): the model is compiled with a loss and one metric, so
# `evaluate` presumably returns both values rather than the accuracy alone -
# confirm before treating `accuracy` as a scalar.
accuracy = classifier.evaluate(valid_dataset)

### 4- Save Fully Trained Classifying Model

In [None]:
# Persist the full classifier (encoder plus classification head).
classifier.save('trained_classifier.tf',save_format='tf')

## Visualize Learned Representation <a class="anchor" id="sixth-bullet"></a>

In [None]:
experiment_path = r'D:\Phenix\EO_20190321_mCherryColocRep3__2019-03-21T21_11_35-Measurement 2\experiment_green_tfrecords'
experiment_tf_records = get_tfrecords_path_list(experiment_path)
experiment_dataset = get_dataset(experiment_tf_records)

# Encode every batch exactly once. Seeding the result with a first batch and
# then iterating over the whole dataset again would include that batch twice.
# Batches are collected in lists and concatenated once at the end, which
# avoids re-copying the growing arrays on every iteration.
feature_batches = []
label_batches = []
for batch in experiment_dataset:
    feature_batches.append(np.asarray(encoder(batch[0])))
    label_batches.append(np.asarray(batch[1]))
outputs = np.concatenate(feature_batches, axis=0)
labels = np.concatenate(label_batches, axis=0)
print(outputs.shape, labels.shape)

# Representations and labels are small enough to be saved as numpy arrays.
# The label file name is kept as it was so existing loaders still find it.
np.save('Experiment_features',outputs)
np.save('Experment_labels',labels)

In [None]:
from sklearn.manifold import TSNE

# Project the encoder features to 3 t-SNE components.
tsne = TSNE(n_components=3)
X_embeddins = tsne.fit_transform(outputs)

# Create the figure once, with the requested size. Calling plt.figure(...)
# and then plt.subplots() opened two figures: an empty 30x30 one and a
# default-sized one that actually received the plot.
fig, ax = plt.subplots(figsize=(30, 30))
# Components 0 and 2 are shown, coloured by class label.
test = ax.scatter(X_embeddins[:, 0], X_embeddins[:, 2], c=labels, cmap='Spectral')
fig.colorbar(test, ax=ax)

## Confusion Matrix <a class="anchor" id="seventh-bullet"></a>

In [None]:
import tensorflow_addons as tfa
import seaborn as sns

# Class names, in label-index order, used for both heatmap axes.
keys = ['MitoTracker', 'mChActA', 'mChCb5','ChPTDSS1','mCardinal']

# Collect labels and predictions from the validation set only. Starting from
# whatever `batch` was left over by an earlier cell would mix a batch of a
# different dataset into the matrix.
true_batches = []
pred_batches = []
for batch in valid_dataset:
    current_labels = np.asarray(batch[1])
    current_preds = np.argmax(classifier.predict(batch[0]), axis=1)
    true_batches.append(current_labels)
    pred_batches.append(current_preds)
# Labels are stored as floats by the input pipeline; the confusion matrix
# needs integer class indices.
y_true = np.concatenate(true_batches).astype(np.int64)
y_pred = np.concatenate(pred_batches).astype(np.int64)

# Rows are true classes, columns are predicted classes.
c = tf.math.confusion_matrix(y_true, y_pred, num_classes=len(keys)).numpy()

plt.figure(figsize=(10,10))
cmap = sns.cubehelix_palette(50, hue=0.05, rot=0, light=0.9, dark=0, as_cmap=True)
sns.set(font_scale=2)
ax = sns.heatmap(c, annot=True,xticklabels=keys,yticklabels=keys, fmt="g",cmap=cmap,square=True)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
