<a href="https://colab.research.google.com/github/PurabPatel555/SingleSequenceResnetAugmentations/blob/main/SingleSequenceResnetAugmentations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Build, for every sample, a list of exactly `max_augmentations` augmentation
# indices. Indices are drawn as consecutive random permutations of
# range(num_augmentations), so each augmentation is used about equally often
# before any of them is repeated.
sample_names_indices = []
for sample_name in tqdm(sample_names):
  # The per-sample .a3m files live under AUG_PATH. TRAIN_PATH is the DSSP text
  # file itself (it is opened with open() elsewhere in this notebook), so it
  # cannot be used as the directory for the line count.
  num_augmentations = get_num_lines(sample_name, AUG_PATH)
  if num_augmentations == 0:
    raise ValueError("No augmentations available for sample " + sample_name)

  full_passes, remainder = divmod(max_augmentations, num_augmentations)
  # Start from an empty list of chunks. Seeding the result with
  # np.zeros(max_augmentations) would prepend spurious zero indices and double
  # the length of the output.
  chunks = []
  for pass_number in range(full_passes + 1):
    augmentation_indices_batch = np.arange(num_augmentations)
    np.random.shuffle(augmentation_indices_batch)
    if pass_number == full_passes:
      # Final, partial pass: only take what is needed to reach max_augmentations
      augmentation_indices_batch = augmentation_indices_batch[:remainder]
    chunks.append(augmentation_indices_batch)
  augmentation_indices = np.concatenate(chunks)

  sample_names_indices.append([sample_name, augmentation_indices.tolist()])

In [None]:
def alignment(seq_name, sample_num, label):
  """Read one augmented sequence and project the reference labels onto it.

  Opens ``<AUG_PATH>/<seq_name>.a3m`` and takes line number ``sample_num``
  (0-based). Each character of that line is handled as follows:

  * ``'-'``        : skipped, nothing is emitted.
  * uppercase      : emits ``label[position - insertions_so_far]``.
  * anything else  : counted as an insertion and labelled ``'X'``.

  Args:
    seq_name: file stem of the .a3m file inside AUG_PATH.
    sample_num: 0-based line index of the augmentation to use.
    label: label string for the reference sequence (no trailing newline).

  Returns:
    Tuple ``(sequence, labels)``: the line with ``'-'`` removed and
    upper-cased, and the label string built for it.

  Raises:
    IndexError: if the file has no line ``sample_num``.
  """
  augmented_seq = None
  with open(os.path.join(AUG_PATH, (seq_name + '.a3m'))) as fp:
    for i, line in enumerate(fp):
      if i == sample_num:
        augmented_seq = line
        break  # no need to read the rest of the file
  if augmented_seq is None:
    raise IndexError("%s.a3m has no line %r" % (seq_name, sample_num))

  # Strip only the line terminator. Slicing off the last character would lose
  # a residue when the final line of the file has no trailing newline.
  augmented_seq = augmented_seq.rstrip('\r\n')
  print(augmented_seq)

  label_chars = []
  insertions = 0  # insertion characters seen so far
  for position, aa in enumerate(augmented_seq):
    if aa == "-":
      continue
    if aa.isupper():
      label_chars.append(label[position - insertions])
    else:
      insertions += 1
      label_chars.append('X')
  return augmented_seq.replace('-', '').upper(), "".join(label_chars)

In [None]:
#Mount Drive
# Attach Google Drive at /content/drive. Every path constant defined in this
# notebook points below that mount point, so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def get_num_lines(sample_name, train_path):
  """Return the number of lines in ``<train_path>/<sample_name>.a3m``.

  The file is streamed line by line rather than loaded with ``readlines()``,
  so memory use does not grow with file size.

  Args:
    sample_name: file stem (without the ``.a3m`` extension).
    train_path: directory that contains the file.

  Returns:
    The line count as an int (0 for an empty file).
  """
  full_path = os.path.join(train_path, sample_name + '.a3m')
  with open(full_path) as file:
    return sum(1 for _ in file)

In [None]:
#Paths
# All locations are on the mounted Google Drive.
# Training records; opened as a single text file by the data-loading cell.
TRAIN_PATH = '/content/drive/My Drive/Sparks/SPOT-1D-single/data/train/train.DSSP'
# Directory holding one '<name>.a3m' file per training sequence (the augmentations).
AUG_PATH = '/content/drive/My Drive/Sparks/datasets/train_aug/a3m-sample'
# Validation records; opened as a single text file by the data-loading cell.
VALID_PATH = '/content/drive/My Drive/Sparks/SPOT-1D-single/data/validation/validation.DSSP'
# Weights checkpoint written by ModelCheckpoint and read back by model.load_weights.
CHECKPOINT_PATH = '/content/drive/My Drive/Sparks/SingleSequenceResnetAugmentationsCheckpoint'
# Destination of model.save() after training.
MODEL_PATH = '/content/drive/My Drive/Sparks/SingleSequenceResnetAugmentationsModel'
# Test records; opened as a single text file by the evaluation cell.
TEST_PATH = '/content/drive/My Drive/Sparks/SPOT-1D-single/data/SPOT/SPOT.dssp'

In [None]:
#Imports
import os, sys
import tensorflow as tf
import tensorflow_addons as tfa
import keras
import keras.backend as K
import numpy as np
from numpy import argmax
import math
import pickle
from random import choice
from tqdm.notebook import tqdm

In [None]:
#Hyperparameters
# Number of residual blocks stacked in the model-definition cell.
n_layers = 60
# Number of filters of the Conv1D layers in the stem and in every residual block.
n_filters = 60
# Upper bound on training epochs passed to model.fit (early stopping may end sooner).
epochs = 100
# NOTE(review): bs_train / bs_valid are not referenced anywhere else in this
# notebook; the data generators are created with a literal batch_size=1.
bs_train=1
bs_valid=1

In [None]:
#Data exploration and loading
# Each record in the data files spans three consecutive lines:
#   a header line containing '>' (for the training file the text after '>'
#   up to the first space is the sequence name)
#   the amino-acid sequence
#   the secondary-structure string
def _read_dssp_records(path):
  """Return a list of (header, sequence, structure) raw lines read from `path`.

  A record starts at any line containing '>'; the two lines after it are
  taken as sequence and structure. Lines are returned unmodified (including
  their trailing newline). A record truncated by end-of-file yields '' for
  the missing lines. The file is closed before returning.
  """
  records = []
  with open(path, "r") as handle:
    for header in handle:
      if '>' in header:
        sequence_line = next(handle, '')
        structure_line = next(handle, '')
        records.append((header, sequence_line, structure_line))
  return records

# Training set: keep the sequence *name* (used to locate '<name>.a3m') and the
# raw structure line. The sequence line itself is not used because training
# sequences are read from the augmentation files.
sequences = []
structures = []
for header, _, structure in _read_dssp_records(TRAIN_PATH):
  # strip() guards against a header without any space, where the name would
  # otherwise keep its trailing newline and break the file lookup
  sequences.append(header[1:].split(" ")[0].strip())
  structures.append(structure)

# Validation set: keep the raw sequence and structure lines. The trailing
# newline is preserved because the validation generator removes it itself.
sequences_valid = []
structures_valid = []
for _, sequence_valid, structure_valid in _read_dssp_records(VALID_PATH):
  sequences_valid.append(sequence_valid)
  structures_valid.append(structure_valid)

In [None]:
# Build one record per training sequence name:
#   [name, number of lines in '<name>.a3m' under AUG_PATH, empty list]
# DataGeneratorTrain uses the third slot to remember which augmentation
# indices it has already drawn for that sequence.
sequence_length_exclude = [[sequence, get_num_lines(sequence, AUG_PATH), []] for sequence in tqdm(sequences)]
# Persist the records on Drive (presumably so the per-file line counts do not
# have to be recomputed in later sessions).
with open('/content/drive/My Drive/Sparks/sequence_length_exclude.pkl', 'wb') as f:
  pickle.dump(sequence_length_exclude, f)

In [None]:
# Reload the cached [name, num_augmentations, used_indices] records and print
# how many there are.
# NOTE: pickle.load executes arbitrary code from the file; only load a pickle
# that this notebook wrote itself.
with open('/content/drive/My Drive/Sparks/sequence_length_exclude.pkl', 'rb') as f:
  sequence_length_exclude = pickle.load(f)
  print(len(sequence_length_exclude))

In [None]:
#Create a data generator to load data for training (with augmentations)
class DataGeneratorTrain(keras.utils.Sequence):
    """Keras Sequence yielding one-hot (X, y) batches from augmented sequences.

    Every item of ``sequences`` is a mutable record
    ``[name, num_augmentations, used_indices]``. For each sample in a batch a
    not-yet-used augmentation (a line of ``<AUG_PATH>/<name>.a3m``) is drawn at
    random, the reference labels are projected onto it, and the draw is
    repeated until the result is usable (only standard amino acids and at
    least one label other than 'X').

    X has shape (batch, max_len, 21): 20 amino acids plus a padding class.
    y has shape (batch, max_len, 8): padded / 'X' positions are all-zero rows.
    """

    #Amino acid symbols (input feature categories); code 20 is used for padding
    AMINO_ACIDS = 'ARNDCEQGHILKMFPSTWYV'
    #Secondary structure options (output categories); 'X' (code 8) also pads
    STRUCTURE_STATES = 'GHITEBSCX'
    AA_TO_INT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
    SS_TO_INT = dict(zip(STRUCTURE_STATES, range(len(STRUCTURE_STATES))))

    def alignment(self, seq_name, sample_num, label):
        """Return (sequence, labels) for line ``sample_num`` of ``<seq_name>.a3m``.

        '-' characters are dropped; an uppercase character receives
        ``label[position - insertions_so_far]``; any other character counts as
        an insertion and is labelled 'X'. The returned sequence is upper-cased.

        Raises:
            IndexError: if the file has no line ``sample_num``.
        """
        augmented_seq = None
        with open(os.path.join(AUG_PATH, (seq_name + '.a3m'))) as fp:
            for i, line in enumerate(fp):
                if i == sample_num:
                    augmented_seq = line
                    break  # stop reading once the requested line is found
        if augmented_seq is None:
            raise IndexError("%s.a3m has no line %r" % (seq_name, sample_num))
        # Strip the terminator only; slicing off the last character would drop
        # a residue when the last line of the file has no trailing newline.
        augmented_seq = augmented_seq.rstrip('\r\n')

        label_chars = []
        insertions = 0
        for position, aa in enumerate(augmented_seq):
            if aa == "-":
                continue
            if aa.isupper():
                label_chars.append(label[position - insertions])
            else:
                insertions += 1
                label_chars.append('X')
        return augmented_seq.replace('-', '').upper(), "".join(label_chars)

    def __init__(self, sequences, structures, batch_size=1, shuffle=True, verbose=False):
        """Store the data and reset the per-sequence "already used" lists.

        Args:
            sequences: list of ``[name, num_augmentations, used_indices]``
                records; the third slot of every record is reset in place.
            structures: raw label lines, parallel to ``sequences``.
            batch_size: number of samples per batch.
            shuffle: reshuffle the sample order at the end of every epoch.
            verbose: print the sample index, name and used-augmentation list
                for every draw (debugging aid, off by default).
        """
        super().__init__()
        self.batch_size = batch_size
        self.structures = structures
        self.sequences = sequences
        self.shuffle = shuffle
        self.verbose = verbose
        self.on_epoch_end()
        for record in self.sequences:
            record[2] = []

    def __len__(self):
        return int(np.floor(len(self.structures) / self.batch_size))

    def _is_usable(self, seq, labels):
        """True if `seq` has only standard residues and `labels` is not all 'X'."""
        if labels.count('X') >= len(labels):
            return False
        return all(aa in self.AA_TO_INT for aa in seq)

    def _draw_augmentation(self, sample_id):
        """Draw a usable augmentation for one sample and return (seq, labels).

        Raises:
            ValueError: if the sample has no augmentations, or none is usable.
        """
        record = self.sequences[sample_id]
        name = record[0]
        num_augmentations = int(record[1])
        if num_augmentations <= 0:
            raise ValueError("No augmentations available for " + str(name))
        label = self.structures[sample_id].rstrip('\r\n')

        # Draws are without replacement until the used-list is full, then it is
        # reset. After at most 2 * num_augmentations draws every augmentation
        # has been tried at least once, so the loop can stop instead of
        # spinning forever on a sample with no usable augmentation.
        for _ in range(2 * num_augmentations):
            if len(record[2]) >= num_augmentations:
                record[2] = []
            used = set(record[2])
            aug_choice = choice([i for i in range(num_augmentations) if i not in used])
            record[2].append(aug_choice)
            seq, labels = self.alignment(seq_name=name, sample_num=aug_choice, label=label)
            if self.verbose:
                print(sample_id)
                print(name)
                print(record[2])
            if self._is_usable(seq, labels):
                return seq, labels
        raise ValueError("No usable augmentation found for " + str(name))

    @staticmethod
    def _one_hot(codes, pad_code, n_classes, length):
        """One-hot encode `codes`, right-padded with `pad_code` to `length`."""
        padded = np.full(length, pad_code, dtype=int)
        padded[:len(codes)] = codes
        return np.eye(n_classes)[padded]

    def __getitem__(self, index):
        """Return batch number `index` as a tuple of float32 arrays (X, y)."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Use every index of the batch slice, not only the first one, so that
        # batch_size > 1 does not silently drop samples.
        seq_data = []
        str_data = []
        for sample_id in indexes:
            seq, labels = self._draw_augmentation(sample_id)
            seq_data.append(seq)
            str_data.append(labels)

        #Find the maximum length in the batch (all sequences are padded to this length)
        max_length = max(len(seq) for seq in seq_data)

        seq_data_formatted = []
        for seq in seq_data:
            codes = [self.AA_TO_INT[aa] for aa in seq]  #Encode each amino acid as an integer
            seq_one_hot = self._one_hot(codes, 20, 21, max_length)
            seq_data_formatted.append(seq_one_hot.astype('float32'))
        X = np.asarray(seq_data_formatted)

        str_data_formatted = []
        for str_ in str_data:
            codes = [self.SS_TO_INT[ss] for ss in str_]  #Encode each secondary structure element as an integer
            str_one_hot = self._one_hot(codes, 8, 9, max_length)
            # Drop the 'X' column: unknown and padded positions become all-zero
            # rows, which the masked loss and accuracy recognise and skip.
            str_one_hot = str_one_hot[:, :-1]
            str_data_formatted.append(str_one_hot.astype('float32'))
        y = np.asarray(str_data_formatted)

        return X, y

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when `shuffle` is set."""
        self.indexes = np.arange(len(self.structures))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

In [None]:
#Create a data generator for the validation and testing (without augmentations)
class DataGeneratorValidTest(keras.utils.Sequence):
    """Keras Sequence yielding one-hot (X, y) batches from raw sequence lines.

    ``sequences`` and ``structures`` are parallel lists of text lines as read
    from the data files (a trailing line terminator is removed here).

    X has shape (batch, max_len, 21): 20 amino acids plus a padding class.
    y has shape (batch, max_len, 8): padded / 'X' positions are all-zero rows.
    """

    #Amino acid symbols (input feature categories); code 20 is used for padding
    AMINO_ACIDS = 'ARNDCEQGHILKMFPSTWYV'
    #Secondary structure options (output categories); 'X' (code 8) also pads
    STRUCTURE_STATES = 'GHITEBSCX'
    AA_TO_INT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
    SS_TO_INT = dict(zip(STRUCTURE_STATES, range(len(STRUCTURE_STATES))))

    def __init__(self, sequences, structures, batch_size=1, shuffle=True):
        """Store the data and build the initial sample order."""
        super().__init__()
        self.batch_size = batch_size
        self.structures = structures
        self.sequences = sequences
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.structures) / self.batch_size))

    @staticmethod
    def _one_hot(codes, pad_code, n_classes, length):
        """One-hot encode `codes`, right-padded with `pad_code` to `length`."""
        padded = np.full(length, pad_code, dtype=int)
        padded[:len(codes)] = codes
        return np.eye(n_classes)[padded]

    def __getitem__(self, index):
        """Return batch number `index` as a tuple of float32 arrays (X, y)."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Strip the line terminator only; slicing off the last character would
        # truncate a final record that has no trailing newline.
        seq_data = [self.sequences[i].rstrip('\r\n') for i in indexes]
        str_data = [self.structures[i].rstrip('\r\n') for i in indexes]

        #Find the maximum length in the batch (all sequences are padded to this length)
        max_length = max(len(seq) for seq in seq_data)

        seq_data_formatted = []
        for seq in seq_data:
            codes = [self.AA_TO_INT[aa] for aa in seq]  #Encode each amino acid as an integer
            seq_one_hot = self._one_hot(codes, 20, 21, max_length)
            seq_data_formatted.append(seq_one_hot.astype('float32'))
        X = np.asarray(seq_data_formatted)

        str_data_formatted = []
        for str_ in str_data:
            codes = [self.SS_TO_INT[ss] for ss in str_]  #Encode each secondary structure element as an integer
            str_one_hot = self._one_hot(codes, 8, 9, max_length)
            # Drop the 'X' column: unknown and padded positions become all-zero rows
            str_one_hot = str_one_hot[:, :-1]
            str_data_formatted.append(str_one_hot.astype('float32'))
        y = np.asarray(str_data_formatted)

        return X, y

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when `shuffle` is set."""
        self.indexes = np.arange(len(self.structures))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

In [None]:
#Define the Resnet model using Keras Functional API
def block(input, n_filters, kernel, dilation):  #Helper function to create blocks
  """Build one residual block and return its output tensor.

  Layout: Conv1D -> InstanceNormalization -> ELU -> Dropout(0.15) -> Conv1D,
  then the block input is added back and a final ELU is applied. Both
  convolutions use 'same' padding and the given dilation rate, so the
  sequence length is preserved.

  Args:
    input: input tensor; its channel count must equal `n_filters` for the
      residual addition. (The name shadows the builtin but is kept because
      callers pass it by keyword.)
    n_filters: number of filters of both convolutions.
    kernel: kernel size of both convolutions.
    dilation: dilation rate of both convolutions.
  """
  shortcut = input
  output = tf.keras.layers.Conv1D(filters=n_filters, kernel_size=kernel, strides=1, padding='same', dilation_rate=dilation)(input)
  output = tfa.layers.InstanceNormalization()(output)
  output = tf.keras.layers.ELU(alpha=1.0)(output)
  output = tf.keras.layers.Dropout(rate=0.15)(output)
  output = tf.keras.layers.Conv1D(filters=n_filters, kernel_size=kernel, strides=1, padding='same', dilation_rate=dilation)(output)
  output = tf.keras.layers.ELU(alpha=1.0)(shortcut+output)
  return output

# Stem: variable-length input with 21 channels -> n_filters channels
inputs = tf.keras.Input(shape=(None,21), dtype=tf.float32)  #Input layer
x = tf.keras.layers.Conv1D(filters=n_filters, kernel_size=3, strides=1, padding='same')(inputs)
x = tfa.layers.InstanceNormalization()(x)
x = tf.keras.layers.ELU(alpha=1.0)(x)

# Residual tower: the dilation rate cycles through 1, 2, 4, 8
dilation = 1
for _ in range(n_layers):
  x = block(input=x, n_filters=n_filters, kernel=3, dilation=dilation)
  dilation *= 2
  if dilation == 16:
    dilation = 1

# Head: per-position scores for 8 classes, normalised with softmax
outputs = tf.keras.layers.Conv1D(8, kernel_size=3, strides=1, padding='same')(x)
outputs = tf.nn.softmax(outputs)  #Output layer
model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
class MyModel(keras.Model):
    """keras.Model whose train_step also reports statistics of one gradient.

    The step is the standard one (forward pass, compiled loss, gradients,
    optimizer update, compiled metrics); in addition the max, min and mean of
    the gradient of the first trainable variable are printed on every step.
    """

    def train_step(self, data):
        """Run one optimisation step on `data` = (inputs, targets).

        Returns:
            Dict mapping each metric name to its current value.
        """
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape1:
            preds = self(inputs, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(targets, preds)
        # Compute first-order gradients
        dl_dw = tape1.gradient(loss, trainable_vars)

        # tf.print is used instead of Python %-formatting: formatting a tensor
        # with "%.4f" only works for eager tensors and fails once train_step
        # is traced into a graph. A variable that does not influence the loss
        # has gradient None, hence the guard.
        if dl_dw and dl_dw[0] is not None:
            tf.print("Max of dl_dw[0]:", tf.reduce_max(dl_dw[0]), output_stream=sys.stdout)
            tf.print("Min of dl_dw[0]:", tf.reduce_min(dl_dw[0]), output_stream=sys.stdout)
            tf.print("Mean of dl_dw[0]:", tf.reduce_mean(dl_dw[0]), output_stream=sys.stdout)

        # Update weights
        self.optimizer.apply_gradients(zip(dl_dw, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(targets, preds)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

# Rebuild the network from the same functional-API `inputs` / `outputs`
# tensors as a MyModel, so that fit() runs the custom train_step, then print
# the layer summary.
model = MyModel(inputs=inputs, outputs=outputs)
model.summary()

Model: "my_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, None, 21)]   0                                            
__________________________________________________________________________________________________
conv1d_122 (Conv1D)             (None, None, 60)     3840        input_2[0][0]                    
__________________________________________________________________________________________________
instance_normalization_61 (Inst (None, None, 60)     120         conv1d_122[0][0]                 
__________________________________________________________________________________________________
elu_121 (ELU)                   (None, None, 60)     0           instance_normalization_61[0][0]  
___________________________________________________________________________________________

In [None]:
#Masked loss
"""
'X' in the structure data corresponds to an unknown label at that position, which must be ignored in the loss
"""
def loss(mask_label):
    """Build a categorical cross-entropy that ignores masked positions.

    Args:
        mask_label: target vector that marks a position as masked; a position
            is skipped when its y_true row equals this vector element-wise.

    Returns:
        A Keras loss ``masked_cce(y_true, y_pred)`` that returns the mean
        cross-entropy over the unmasked positions, or 0 when every position
        is masked.
    """
    mask_label = K.variable(mask_label)
    def masked_cce(y_true, y_pred):
        is_masked = K.all(K.equal(y_true, mask_label), axis=-1)
        keep = tf.math.logical_not(is_masked)
        # Keep only the unmasked positions before computing the loss
        y_true_kept = tf.boolean_mask(y_true, keep)
        y_pred_kept = tf.boolean_mask(y_pred, keep)
        per_position = K.categorical_crossentropy(y_true_kept, y_pred_kept)
        n_kept = K.sum(tf.cast(keep, per_position.dtype))
        # divide_no_nan returns 0 instead of NaN when nothing is left after
        # masking, so one fully masked batch cannot poison the weights.
        return tf.math.divide_no_nan(K.sum(per_position), n_kept)
    return masked_cce

masked_cce = loss(np.array([0, 0, 0, 0, 0, 0, 0, 0]))

In [None]:
#Q8 Accuracy (Masked accuracy)
def get_accuracy(mask_label):
  """Build a categorical-accuracy metric that ignores masked positions.

  Args:
    mask_label: target vector that marks a position as masked; a position is
      skipped when its y_true row equals this vector element-wise.

  Returns:
    A metric function ``accuracy_fun(y_true, y_pred)`` returning the fraction
    of unmasked positions whose argmax prediction matches the argmax target
    (0 when every position is masked).
  """
  mask_label = K.variable(mask_label)
  def accuracy_fun(y_true, y_pred):
    is_masked = K.all(K.equal(y_true, mask_label), axis=-1)
    weights = 1 - K.cast(is_masked, K.floatx())
    # Stateless equivalent of CategoricalAccuracy with sample_weight=weights.
    # Building a new metric object (and its variables) on every call would
    # prevent this function from being traced into a graph.
    hits = K.cast(K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())
    return tf.math.divide_no_nan(K.sum(hits * weights), K.sum(weights))
  return accuracy_fun

accuracy = get_accuracy(np.array([0, 0, 0, 0, 0, 0, 0, 0]))

In [None]:
#Compile the model
# Adam optimizer with the masked cross-entropy as loss and the masked Q8
# accuracy as the only metric.
model.compile(
    optimizer="adam",
    loss=masked_cce,
    metrics = [accuracy] # Author's note: comment this line out when training to avoid the eager-mode slowdown
)
#model.summary()
#model.summary()

In [None]:
#Initialize the data generators
# Training: the [name, num_augmentations, used_indices] records, with the
# structure list truncated to the same number of entries.
data_generator_train = DataGeneratorTrain(sequences=sequence_length_exclude, structures=structures[:len(sequence_length_exclude)], batch_size=1)
# Validation: raw sequence / structure lines, no augmentation.
data_generator_valid = DataGeneratorValidTest(sequences=sequences_valid, structures=structures_valid, batch_size=1)

In [None]:
# Sanity check: fetch validation batch number 3 and print its (X, y) arrays.
print(data_generator_valid[3])

In [None]:
#Create an early stopping callback and a model checkpoint callback
# patience=0: training stops after the first epoch in which val_loss does not improve.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=0)
# Write the weights (not the full model) to CHECKPOINT_PATH, but only when
# val_loss reaches a new minimum.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT_PATH,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

In [None]:
#Load most recent checkpoint if applicable
# NOTE(review): there is no existence check here, so this cell presumably
# fails when no checkpoint has been written to CHECKPOINT_PATH yet; skip it
# on a first run.
model.load_weights(CHECKPOINT_PATH)

In [None]:
#Train the model
# Fit on the augmenting training generator for at most `epochs` epochs,
# validating on the validation generator after each epoch. The callbacks stop
# training early and checkpoint the weights with the best val_loss.
model.fit(
    x=data_generator_train,
    epochs=epochs,
    verbose=1,
    validation_data = data_generator_valid,
    callbacks = [early_stopping_callback, model_checkpoint_callback]
    )

#Save the model
# Saves the model as it is after the last epoch; the best-val_loss weights
# are the ones stored at CHECKPOINT_PATH.
model.save(MODEL_PATH)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
1OVX_1_A
[38, 76, 50, 60]
1ZGO_1_A
[39, 83, 22, 31, 47]
1JER_1_A
[25, 9, 6, 63, 85]
1NPS_1_A
[82, 51, 87, 74]
2AZW_1_A
[92, 90, 74, 31]
1R9C_1_A
[51, 66, 87, 4]
1LJ2_1_A
[2, 12, 6, 9]
2BZW_2_B
[13, 10, 0, 8]
3U80_1_A
[84, 92, 71, 46]
2KZ9_1_A
[74, 23, 7, 93]
2P1M_d2p1ma2
[7, 24, 32, 15]
4GY7_1_A
[37, 8, 81, 71]
2KSL_1_A
[0]
1K4T_d1k4ta1
[11, 70, 95, 46]
4KUJ_1_A
[51, 31, 57, 87]
1X2I_1_A
[33, 4, 44, 27]
4ZAC_1_A
[90, 82, 88, 79]
3EAB_2_G
[78, 85, 98, 28]
2HY5_2_B
[69, 89, 57, 27]
3FOC_1_A
[24, 5, 13, 28]
1VP7_d1vp7b-
[71, 85, 41, 23]
3TVJ_1_A
[13, 34, 45, 73]
4WFB_5_C
[54, 96, 89, 99]
3BO6_1_A
[6, 23, 18, 7]
2ROR_1_A
[75, 96, 19, 28]
3K9C_1_A
[90, 63, 77, 85]
3ARX_1_A
[27, 63, 94, 32]
3MDF_1_A
[78, 75, 32, 79]
4GHU_1_A
[94, 46, 65, 64]
2C6U_1_A
[57, 66, 71, 19]
2Z0X_1_A
[86, 46, 67, 12]
2ZU6_d2zu6f2
[38, 30, 48, 29]
1FVI_d1fvia1
[78, 10, 63, 66]
2LPB_2_B
[8, 25, 23, 30]
3ECQ_1_A
3ECQ_1_A
3ECQ_1_A
[8, 59, 34, 50, 85, 63, 7

In [None]:
#Evaluation
bs_test = 1
# Run tf.functions eagerly while evaluating
tf.config.run_functions_eagerly(True)

#Load test data
# Each record is a header line containing '>' followed by the sequence line
# and the structure line. Lines are stored unmodified (the generator strips
# the line terminator itself). The `with` block closes the file when done.
sequences_test = []
structures_test = []
with open(TEST_PATH, "r") as f3:
  for line in f3:
    if '>' in line:
      # next(..., '') mirrors readline(): a record cut off by end-of-file
      # yields empty strings instead of raising
      sequences_test.append(next(f3, ''))
      structures_test.append(next(f3, ''))

#Initialize test data generator
data_generator_test = DataGeneratorValidTest(sequences=sequences_test, structures=structures_test, batch_size=bs_test)

#Load model weights
# Evaluate the weights with the best validation loss, as stored by the checkpoint callback
model.load_weights(CHECKPOINT_PATH)

#Evaluate the model
results = model.evaluate(x=data_generator_test)
print("test loss, test accuracy: ", results)