<a href="https://colab.research.google.com/github/PurabPatel555/SingleSequenceResnet/blob/main/SingleSequenceResnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Mount Drive
# Colab-only: mounts Google Drive at /content/drive, the root that the
# data, checkpoint and model paths used by this notebook are located under.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#Paths
# All locations live on the mounted Google Drive.
TRAIN_PATH = '/content/drive/My Drive/Sparks/SPOT-1D-single/data/train/train.DSSP'  #Training records (sequence + structure)
VALID_PATH = '/content/drive/My Drive/Sparks/SPOT-1D-single/data/validation/validation.DSSP'  #Validation records
CHECKPOINT_PATH = '/content/drive/My Drive/Sparks/SingleSequenceResnetCheckpoint'  #Weights-only checkpoint written by ModelCheckpoint and read by load_weights
MODEL_PATH = '/content/drive/My Drive/Sparks/SingleSequenceResnetModel'  #Destination of model.save after training
TEST_PATH = '/content/drive/My Drive/Sparks/SPOT-1D-single/data/SPOT/SPOT.dssp'  #Test records used in the evaluation cell

In [None]:
#Imports
import os, sys
import tensorflow as tf
import tensorflow_addons as tfa
import keras
import keras.backend as K
import numpy as np
from numpy import argmax

In [None]:
#Hyperparameters
n_layers = 60  #Number of residual blocks stacked when the model is built
n_filters = 60  #Filters of the stem convolution and of every convolution inside a block
epochs = 100  #Upper bound on training epochs passed to model.fit (early stopping may end sooner)
bs_train=1  #Batch size of the training data generator
bs_valid=1  #Batch size of the validation data generator
# train_mode is forwarded as the `training=` argument of every Dropout call in block().
# NOTE(review): Keras documents training=False as inference mode, so with this
# value dropout would be inactive during fit() as well - confirm before training.
train_mode = False #Changes the dropout behaviour in the model

In [None]:
#Data exploration and loading
"""
The format of this data is:
>
sequence1
structure1
>
sequence2
structure2
.
.
.
"""
def _read_dssp_records(path):
  """Read every (sequence, structure) record from a DSSP-style text file.

  A record is a header line containing '>' followed by one sequence line and
  one structure line. Lines are returned exactly as read (line terminator
  included); a record cut short by end-of-file yields empty strings for the
  missing lines.

  Args:
    path: location of the text file to read.

  Returns:
    Tuple (sequences, structures) of two equally long lists of raw lines.
  """
  record_sequences = []
  record_structures = []
  # The context manager closes the handle even if reading fails part-way.
  with open(path, "r") as handle:
    for line in handle:
      if '>' in line:
        # The two lines after a header are consumed here, so the outer loop
        # resumes at the next header.
        record_sequences.append(next(handle, ''))
        record_structures.append(next(handle, ''))
  return record_sequences, record_structures

sequences, structures = _read_dssp_records(TRAIN_PATH)
sequences_valid, structures_valid = _read_dssp_records(VALID_PATH)

In [None]:
#Create a data generator to load data for training
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence that serves padded, one-hot encoded protein batches.

    Each item is a tuple (X, y):
      X: float32 array (batch, max_length, 21) - 20 amino-acid channels plus a
         21st channel that is set on padded positions.
      y: float32 array (batch, max_length, 8) - 8 secondary-structure classes;
         positions labelled 'X' (unknown) and padded positions are all-zero
         rows.
    max_length is the longest sequence inside the batch.

    Args:
        sequences: list of amino-acid strings; a trailing line terminator is
            ignored.
        structures: list of secondary-structure strings aligned with
            `sequences`; a trailing line terminator is ignored.
        batch_size: number of samples per batch. A final incomplete batch is
            dropped (see __len__).
        shuffle: whether the sample order is reshuffled at construction and
            after every epoch.

    Raises:
        KeyError: from __getitem__ when a residue or structure symbol is not
            in the lookup tables.
    """

    AMINO_ACIDS = 'ARNDCEQGHILKMFPSTWYV'  #Amino acid symbols (input feature categories)
    SECONDARY_STRUCTURES = 'GHITEBSCX'  #Secondary structure options (output categories)
    # Lookup tables are built once instead of on every batch.
    AA_TO_INT = {a: i for i, a in enumerate(AMINO_ACIDS)}
    SS_TO_INT = {s: i for i, s in enumerate(SECONDARY_STRUCTURES)}
    SEQ_PAD_INDEX = 20  #Extra input channel marking padding
    STR_PAD_INDEX = 8  #Index of 'X'; this column is removed from the labels

    def __init__(self, sequences, structures, batch_size=1, shuffle=True):
        # Lets the Keras base class initialise whatever state it keeps.
        super().__init__()
        self.batch_size = batch_size
        self.structures = structures
        self.sequences = sequences
        self.shuffle=shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of complete batches per epoch."""
        return int(np.floor(len(self.structures) / self.batch_size))

    @staticmethod
    def _one_hot(text, table, pad_index, n_classes, max_length):
        """Integer-encode `text`, right-pad to `max_length`, and one-hot encode."""
        codes = np.full(max_length, pad_index, dtype=np.int64)
        codes[:len(text)] = [table[symbol] for symbol in text]
        return np.eye(n_classes)[codes]

    def __getitem__(self, index):
        """Build batch number `index` as a tuple (X, y)."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Strip only line terminators: a blind [:-1] would drop a real residue
        # from a final line without a newline and leave '\r' on CRLF files.
        seq_data = [self.sequences[i].rstrip('\r\n') for i in indexes]
        str_data = [self.structures[i].rstrip('\r\n') for i in indexes]
        max_length = max(len(seq) for seq in seq_data)  #Pad everything in the batch to this length

        X = np.asarray([
            self._one_hot(seq, self.AA_TO_INT, self.SEQ_PAD_INDEX, 21, max_length).astype('float32')
            for seq in seq_data
        ])
        # Dropping the last column turns both 'X' labels and padding into
        # all-zero rows.
        y = np.asarray([
            self._one_hot(str_, self.SS_TO_INT, self.STR_PAD_INDEX, 9, max_length)[:, :-1].astype('float32')
            for str_ in str_data
        ])

        return X, y

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it when shuffling is enabled."""
        self.indexes = np.arange(len(self.structures))
        if self.shuffle:
            np.random.shuffle(self.indexes)

In [None]:
#Define the Resnet model using Keras Functional API
def block(input, n_filters, kernel, dilation):  #Helper function to create blocks
  """Append one residual block to `input` and return the resulting tensor.

  Layout: Conv1D -> InstanceNormalization -> ELU -> Dropout -> Conv1D, then
  the block input is added back and a final ELU is applied. Both
  convolutions share the same filters, kernel size and dilation rate.

  Args:
    input: Keras tensor entering the block; it is added to the output of the
      second convolution.
    n_filters: number of filters of both convolutions.
    kernel: kernel size of both convolutions.
    dilation: dilation rate of both convolutions.

  Returns:
    The Keras tensor produced by the final ELU.
  """
  conv_settings = dict(filters=n_filters, kernel_size=kernel, strides=1, padding='same', dilation_rate=dilation)
  shortcut = input
  hidden = tf.keras.layers.Conv1D(**conv_settings)(input)
  hidden = tfa.layers.InstanceNormalization()(hidden)
  hidden = tf.keras.layers.ELU(alpha=1.0)(hidden)
  # Dropout behaviour is controlled by the module-level train_mode flag.
  hidden = tf.keras.layers.Dropout(rate=0.15)(hidden, training=train_mode)
  hidden = tf.keras.layers.Conv1D(**conv_settings)(hidden)
  return tf.keras.layers.ELU(alpha=1.0)(shortcut+hidden)

# Stem: one convolution + normalisation + activation lifting the 21 input
# channels to n_filters feature maps.
inputs = tf.keras.Input(shape=(None,21), dtype=tf.float32)  #Input layer
x = tf.keras.layers.Conv1D(filters=n_filters, kernel_size=3, strides=1, padding='same')(inputs)
x = tfa.layers.InstanceNormalization()(x)
x = tf.keras.layers.ELU(alpha=1.0)(x)

# Residual tower: the dilation rate cycles 1, 2, 4, 8, 1, 2, ... over the blocks.
dilation_cycle = (1, 2, 4, 8)
for i in range(n_layers):
  x = block(input=x, n_filters=n_filters, kernel=3, dilation=dilation_cycle[i % len(dilation_cycle)])

# Head: per-position scores for the 8 output classes, turned into
# probabilities by the softmax.
outputs = tf.keras.layers.Conv1D(8, kernel_size=3, strides=1, padding='same')(x)
outputs = tf.nn.softmax(outputs)  #Output layer
model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
#Masked loss
"""
'X' in the structure data corresponds to an unknown label at that position, which must be ignored in the loss
"""
def loss(mask_label):
    """Build a categorical cross-entropy that skips masked positions.

    Args:
        mask_label: label vector marking a position to ignore; a position is
            masked when its whole y_true row equals this vector.

    Returns:
        Function masked_cce(y_true, y_pred) returning the mean cross-entropy
        over the unmasked positions of the batch, or 0 when every position is
        masked.
    """
    mask_label = K.variable(mask_label)
    def masked_cce(y_true, y_pred):
        # True where the label row differs from mask_label, i.e. where the
        # position contributes to the loss.
        keep = tf.math.logical_not(K.all(K.equal(y_true, mask_label), axis=-1))
        per_position = K.categorical_crossentropy(
            tf.boolean_mask(y_true, keep), tf.boolean_mask(y_pred, keep))
        n_kept = K.sum(K.cast(keep, K.floatx()))
        # divide_no_nan yields 0 instead of NaN for a fully masked batch, so a
        # single such batch cannot poison the weights.
        return tf.math.divide_no_nan(K.sum(per_position), n_kept)
    return masked_cce

masked_cce = loss(np.array([0, 0, 0, 0, 0, 0, 0, 0]))

In [None]:
#Q8 Accuracy (Masked accuracy)
def get_accuracy(mask_label):
  """Build an accuracy metric that ignores masked positions.

  Args:
    mask_label: label vector marking a position to ignore; a position is
      masked when its whole y_true row equals this vector.

  Returns:
    Function accuracy_fun(y_true, y_pred) returning the fraction of unmasked
    positions whose predicted class (argmax) matches the true class, or 0
    when every position is masked.
  """
  mask_label = K.variable(mask_label)
  def accuracy_fun(y_true, y_pred):
    is_masked = K.all(K.equal(y_true, mask_label), axis=-1)
    weights = 1 - K.cast(is_masked, K.floatx())
    hits = K.cast(
        K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)),
        K.floatx())
    # Weighted mean of the hits, computed with plain tensor ops so that no
    # stateful metric object has to be created on every call.
    return tf.math.divide_no_nan(K.sum(hits * weights), K.sum(weights))
  return accuracy_fun

accuracy = get_accuracy(np.array([0, 0, 0, 0, 0, 0, 0, 0]))

In [None]:
#Compile the model
# Adam optimiser with the masked cross-entropy as loss, so positions whose
# label row is all zeros do not contribute to training.
model.compile(
    optimizer="adam",
    loss=masked_cce,
    metrics = [accuracy] #Comment this out when training to avoid eager mode slowdown
)
#model.summary()

In [None]:
#Initialize the data generators
# Both generators keep DataGenerator's default shuffle=True, so the validation
# samples are reshuffled every epoch as well.
data_generator_train = DataGenerator(sequences=sequences, structures=structures, batch_size=bs_train)
data_generator_valid = DataGenerator(sequences=sequences_valid, structures=structures_valid, batch_size=bs_valid)

In [None]:
#Create an early stopping callback and a model checkpoint callback
# patience=0: training stops after the first epoch whose val_loss does not improve.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=0)
# Writes weights only (no optimizer state or architecture) to CHECKPOINT_PATH,
# and only when val_loss reaches a new minimum.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT_PATH,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

In [None]:
#Load most recent checkpoint if applicable
# A TensorFlow-format weights checkpoint saved under a path prefix is stored as
# "<prefix>.index" plus data shards. Restore only when something was saved, so
# a first run starts from fresh weights instead of failing.
if os.path.exists(CHECKPOINT_PATH + '.index') or os.path.exists(CHECKPOINT_PATH):
  model.load_weights(CHECKPOINT_PATH)

In [None]:
#Train the model
# Runs for at most `epochs` epochs; the callbacks stop training early on a
# non-improving val_loss and checkpoint the best weights to CHECKPOINT_PATH.
model.fit(
    x=data_generator_train,
    epochs=epochs,
    verbose=1,
    validation_data = data_generator_valid,
    callbacks = [early_stopping_callback, model_checkpoint_callback]
    )

#Save the model
# NOTE(review): the EarlyStopping callback is created without
# restore_best_weights, so presumably this saves the weights of the last epoch
# run rather than the best ones kept at CHECKPOINT_PATH - confirm.
model.save(MODEL_PATH)

In [None]:
#Evaluation
bs_test = 1
tf.config.run_functions_eagerly(True)

#Load test data
# Records are a header line containing '>' followed by one sequence line and
# one structure line. Lines are kept as read; DataGenerator handles the line
# terminator.
sequences_test = []
structures_test = []
# The context manager closes the file once all records are read.
with open(TEST_PATH, "r") as f3:
  for line in f3:
    if '>' in line:
      # A record cut short by end-of-file yields empty strings.
      sequences_test.append(next(f3, ''))
      structures_test.append(next(f3, ''))

#Initialize test data generator
data_generator_test = DataGenerator(sequences=sequences_test, structures=structures_test, batch_size=bs_test)

#Load model weights
# Evaluation uses the checkpointed weights, not the model saved at MODEL_PATH.
model.load_weights(CHECKPOINT_PATH)

#Evaluate the model
results = model.evaluate(x=data_generator_test)
print("test loss, test accuracy: ", results)