In [8]:
import pandas as pd
import numpy as np
import keras
from keras.models import Sequential
from sklearn.utils import shuffle
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from Bio import SeqIO
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Embedding, Conv1D, Dropout, MaxPooling1D, Flatten, Dense
import os
import random

In [9]:
def plot(history):
    # learning curves of model accuracy
    plt.plot(history.history['accuracy'], label='train_acc')
    plt.plot(history.history['val_accuracy'], label='val_acc')
    plt.plot(history.history['loss'], label='train_loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.legend()
    plt.show()

In [10]:
from sklearn.metrics import accuracy_score, matthews_corrcoef, roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
def evaluate_model(model, X_val = None, X_val_pt5 = None, y_val=None):
    y_true = y_val
    # Predict probabilities (or logits if using `from_logits=True`).
    if X_val_pt5 is None:
        y_pred_probs = model.predict(X_val)
    elif X_val is None:
        y_pred_probs = model.predict(X_val_pt5)
    else:
        y_pred_probs = model.predict([X_val, X_val_pt5])

    # Convert probabilities/logits to binary predictions (threshold = 0.5).
    y_pred = (y_pred_probs > 0.5).astype(int)

    # If y_true is one-hot encoded, convert it to binary format
    if len(y_true.shape) > 1 and y_true.shape[1] > 1:  # Check if y_true is one-hot encoded
        y_true = np.argmax(y_true, axis=1)  # Convert one-hot encoded y_true to binary labels

    # Ensure y_pred is also 1D
    if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
        y_pred = np.argmax(y_pred, axis=1)  # Convert y_pred to binary labels if necessary

    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    auc = roc_auc_score(y_true, y_pred)
    auprc = average_precision_score(y_true, y_pred_probs)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    # Compute Specificity
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    specificity = tn / (tn + fp)

    # Print the results
    print(f'Accuracy: {accuracy}')
    print(f'MCC: {mcc}')
    print(f'AUC: {auc}')
    print(f'AUPRC: {auprc}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'Specificity: {specificity}')
    print(f'F1: {f1}')

    return accuracy, mcc, auc, auprc, precision, recall, specificity, f1

In [12]:
from keras.layers import Input, Embedding, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy

def create_conv_branch(input_shape_conv):
    conv_input = Input(shape=input_shape_conv, name='conv_input')

    # Embedding layer
    x = Embedding(input_dim=256, output_dim=21, input_length=input_shape_conv[0])(conv_input)

    x = Lambda(lambda x: tf.expand_dims(x, 3))(x)

    # Convolutional layers
    x = Conv2D(32, kernel_size=(17, 3), activation='relu',
               kernel_initializer='he_normal', padding='VALID')(x)
    x = Dropout(0.2)(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)

    x = Dense(16, activation='relu', kernel_initializer='he_normal')(x)
    x = Dropout(0.2)(x)

    # Output of convolutional branch
    conv_output = Dense(16, activation='relu', name='conv_output')(x)

    conv_output = Dense(1, activation='sigmoid')(conv_output)

    model = Model(inputs=conv_input, outputs=conv_output, name='conv_branch')

    model.compile(optimizer=Adam(learning_rate=0.001),
                  loss=BinaryCrossentropy(),
                  metrics=['accuracy'])

    return model
# # Instantiate the convolutional branch
# conv_branch = create_conv_branch((33,))

# # Train the convolutional branch
# conv_history = conv_branch.fit(
#     X_train, y_train,
#     epochs=100,
#     batch_size=256,
#     verbose=1,
#     validation_data=(X_val, y_val),
#     callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
# )

# # Optionally, save the trained weights
# conv_branch.save_weights('conv_branch.weights.h5')


In [5]:
from keras.layers import Input, Dense, Dropout
from keras.models import Model

def create_ann_branch(input_shape_ann):
    ann_input = Input(shape=(input_shape_ann,), name='ann_input')

    x = Dense(256, activation='relu')(ann_input)
    x = Dropout(0.4)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.4)(x)

    # Output of ANN branch
    ann_output = Dense(1, activation='sigmoid', name='ann_output')(x)

    model = Model(inputs=ann_input, outputs=ann_output, name='ann_branch')

    model.compile(optimizer=Adam(learning_rate=0.001),
                  loss=BinaryCrossentropy(),
                  metrics=['accuracy'])

    return model

# # Instantiate the ANN branch
# ann_branch = create_ann_branch(1024)

# # Train the ANN branch
# ann_history = ann_branch.fit(
#     X_train_pt5, y_train,
#     epochs=100,
#     batch_size=256,
#     verbose=1,
#     validation_data=(X_val_pt5, y_val),
#     callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
# )

# # Optionally, save the trained weights
# ann_branch.save_weights('ann_branch.weights.h5')


In [6]:
from keras.layers import Concatenate, Dense
from keras.models import Model

def create_combined_model(conv_branch, ann_branch):
    # Freeze the branches if you don't want to train them initially
    conv_branch.trainable = False
    ann_branch.trainable = False

    # Define inputs
    conv_input = conv_branch.input
    ann_input = ann_branch.input

    # Get outputs from the branches
    conv_output = conv_branch.get_layer(index=6).output
    ann_output = ann_branch.get_layer(index=4).output

    # Concatenate the outputs
    combined = Concatenate()([conv_output, ann_output])

    # Add combined layers
    x = Dense(16, activation='relu')(combined)
    x = Dense(4, activation='relu')(x)
    output = Dense(1, activation='sigmoid', name='output')(x)

    # Define the combined model
    combined_model = Model(inputs=[conv_input, ann_input], outputs=output, name='combined_model')

    # Compile the combined model
    combined_model.compile(optimizer=Adam(learning_rate=0.001),
                           loss=BinaryCrossentropy(),
                           metrics=['accuracy'])

    return combined_model

# # If you saved the weights separately, load them
# conv_branch.load_weights('conv_branch.weights.h5')
# ann_branch.load_weights('ann_branch.weights.h5')

# # Create the combined model
# combined_model = create_combined_model(conv_branch, ann_branch)

# # View the summary
# combined_model.summary()

# # Train the combined model
# combined_history = combined_model.fit(
#     [X_train, X_train_pt5], y_train,
#     epochs=100,
#     batch_size=256,
#     verbose=1,
#     validation_data=([X_val, X_val_pt5], y_val),
#     callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
# )


# # Evaluate the model
# evaluate_model(combined_model, X_val, X_val_pt5, y_val)
# evaluate_model(combined_model, X_test, X_test_pt5, y_test)




Load Data, Train and Save Final Model:

In [17]:
train = pd.read_csv('../Embeddings/Prot_t5/train_t5.csv')
val = pd.read_csv('../Embeddings/Prot_t5/val_t5.csv')
test = pd.read_csv('../Embeddings/Prot_t5/test_t5.csv')

print(train.shape)
print(val.shape)
print(test.shape)

# Convert the embedding strings to numpy arrays
X_train_embeddings = train['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))
X_val_embeddings = val['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))
X_test_embeddings = test['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))

# Convert to a numpy array if needed
X_train_embeddings = np.stack(X_train_embeddings.values)
X_val_embeddings = np.stack(X_val_embeddings.values)
X_test_embeddings = np.stack(X_test_embeddings.values)

# Extract sequences
X_train = train['sequence'].values
X_val = val['sequence'].values
X_test = test['sequence'].values

# Extract labels
y_train = train['label'].values
y_val = val['label'].values
y_test = test['label'].values

# Create a dictionary to map amino acids to integers
amino_acids = 'ACDEFGHIKLMNPQRSTVWY-'

aa_to_int = {aa: i for i, aa in enumerate(amino_acids)}

# Convert the sequences to a numerical format and convert to numpy arrays
X_train_num = [[aa_to_int[aa] for aa in seq] for seq in X_train]
X_val_num = [[aa_to_int[aa] for aa in seq] for seq in X_val]
X_test_num = [[aa_to_int[aa] for aa in seq] for seq in X_test]

X_train_num = np.array(X_train_num)
X_val_num = np.array(X_val_num)
X_test_num = np.array(X_test_num)

set_seed(4)

conv_branch = create_conv_branch((33,))

# Train the convolutional branch
conv_history = conv_branch.fit(
    X_train_num, y_train,
    epochs=100,
    batch_size=256,
    verbose=0,
    validation_data=(X_val_num, y_val),
    callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
)

# # Instantiate the ANN branch
ann_branch = create_ann_branch(1024)

# Train the ANN branch
ann_history = ann_branch.fit(
    X_train_embeddings, y_train,
    epochs=100,
    batch_size=256,
    verbose=0,
    validation_data=(X_val_embeddings, y_val),
    callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
)

# Create the combined model
combined_model = create_combined_model(conv_branch, ann_branch)

# View the summary
# combined_model.summary()

# Train the combined model
combined_history = combined_model.fit(
    [X_train_num, X_train_embeddings], y_train,
    epochs=100,
    batch_size=256,
    verbose=1,
    validation_data=([X_val_num, X_val_embeddings], y_val),
    callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
)

# Evaluate the model
evaluate_model(combined_model, X_val=X_val_num, X_val_pt5=X_val_embeddings, y_val=y_val)
evaluate_model(combined_model, X_val=X_test_num, X_val_pt5=X_test_embeddings, y_val=y_test)

#save the model
combined_model.save('Models/LMSuccSite.h5')
combined_model.save_weights('Models/LMSuccSite.weights.h5')

(8411, 5)
(935, 5)
(3226, 5)




Epoch 1/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 33ms/step - accuracy: 0.6297 - loss: 0.6589 - val_accuracy: 0.7668 - val_loss: 0.5667
Epoch 2/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 14ms/step - accuracy: 0.7879 - loss: 0.5236 - val_accuracy: 0.7701 - val_loss: 0.4994
Epoch 3/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step - accuracy: 0.8115 - loss: 0.4320 - val_accuracy: 0.7754 - val_loss: 0.4804
Epoch 4/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8162 - loss: 0.4008 - val_accuracy: 0.7786 - val_loss: 0.4752
Epoch 5/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8257 - loss: 0.3846 - val_accuracy: 0.7807 - val_loss: 0.4717
Epoch 6/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8362 - loss: 0.3724 - val_accuracy: 0.7872 - val_loss: 0.4697
Epoch 7/100
[1m33/33[0m [




MCC: 0.33446819290404073
AUC: 0.7807686836335983
AUPRC: 0.3155356056713083
Precision: 0.22089227421109903
Recall: 0.8023715415019763
Specificity: 0.7591658257652203
F1: 0.3464163822525597


Permutation test :

In [7]:
def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

#read train and test datasets

train = pd.read_csv('../Embeddings/Prot_t5/train_t5.csv')
val = pd.read_csv('../Embeddings/Prot_t5/val_t5.csv')
test = pd.read_csv('../Embeddings/Prot_t5/test_t5.csv')

print(train.shape)
print(val.shape)
print(test.shape)

# Convert the embedding strings to numpy arrays
X_train_embeddings = train['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))
X_val_embeddings = val['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))
X_test_embeddings = test['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))

# Convert to a numpy array if needed
X_train_embeddings = np.stack(X_train_embeddings.values)
X_val_embeddings = np.stack(X_val_embeddings.values)
X_test_embeddings = np.stack(X_test_embeddings.values)

# Extract sequences
X_train = train['sequence'].values
X_val = val['sequence'].values
X_test = test['sequence'].values

# Extract labels
y_train = train['label'].values
y_val = val['label'].values
y_test = test['label'].values

# Create a dictionary to map amino acids to integers
amino_acids_perm = [
''.join(np.random.permutation(list('ACDEFGHIKLMNPQRSTVWY-'))) for _ in range(20)
]

accuracys = []
mccs = []
aucs = []
auprcs = []
precisions = []
recalls = []
specificitys = []
f1s = []

for amino_acids in amino_acids_perm:
    aa_to_int = {aa: i for i, aa in enumerate(amino_acids)}

    # Convert the sequences to a numerical format and convert to numpy arrays
    X_train_num = [[aa_to_int[aa] for aa in seq] for seq in X_train]
    X_val_num = [[aa_to_int[aa] for aa in seq] for seq in X_val]
    X_test_num = [[aa_to_int[aa] for aa in seq] for seq in X_test]

    X_train_num = np.array(X_train_num)
    X_val_num = np.array(X_val_num)
    X_test_num = np.array(X_test_num)

    set_seed(4)

    # Instantiate the convolutional branch
    conv_branch = create_conv_branch((33,))

    # Train the convolutional branch
    conv_history = conv_branch.fit(
        X_train_num, y_train,
        epochs=100,
        batch_size=256,
        verbose=0,
        validation_data=(X_val_num, y_val),
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
    )

    # # Instantiate the ANN branch
    ann_branch = create_ann_branch(1024)

    # Train the ANN branch
    ann_history = ann_branch.fit(
        X_train_embeddings, y_train,
        epochs=100,
        batch_size=256,
        verbose=0,
        validation_data=(X_val_embeddings, y_val),
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
    )

    # Create the combined model
    combined_model = create_combined_model(conv_branch, ann_branch)

    # View the summary
    # combined_model.summary()

    # Train the combined model
    combined_history = combined_model.fit(
        [X_train_num, X_train_embeddings], y_train,
        epochs=100,
        batch_size=256,
        verbose=1,
        validation_data=([X_val_num, X_val_embeddings], y_val),
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
    )

    # Evaluate the model
    acc, mcc, auc, auprc, precision, recall, specificity, f1 = evaluate_model(combined_model, X_val=X_val_num, X_val_pt5=X_val_embeddings, y_val=y_val)

    evaluate_model(combined_model, X_val=X_test_num, X_val_pt5=X_test_embeddings, y_val=y_test)

    accuracys.append(acc)
    mccs.append(mcc)
    aucs.append(auc)
    auprcs.append(auprc)
    precisions.append(precision)
    recalls.append(recall)
    specificitys.append(specificity)
    f1s.append(f1)

# Print the results mean and std
print(f'Accuracy: {np.mean(accuracys)} +/- {np.std(accuracys)}')
print(f'MCC: {np.mean(mccs)} +/- {np.std(mccs)}')
print(f'AUC: {np.mean(aucs)} +/- {np.std(aucs)}')
print(f'AUPRC: {np.mean(auprcs)} +/- {np.std(auprcs)}')
print(f'Precision: {np.mean(precisions)} +/- {np.std(precisions)}')
print(f'Recall: {np.mean(recalls)} +/- {np.std(recalls)}')
print(f'Specificity: {np.mean(specificitys)} +/- {np.std(specificitys)}')
print(f'F1: {np.mean(f1s)} +/- {np.std(f1s)}')

(8411, 5)
(935, 5)
(3226, 5)





Epoch 1/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.6262 - loss: 0.6860 - val_accuracy: 0.7668 - val_loss: 0.6512
Epoch 2/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - accuracy: 0.7829 - loss: 0.6158 - val_accuracy: 0.7786 - val_loss: 0.5484
Epoch 3/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.7929 - loss: 0.4949 - val_accuracy: 0.7743 - val_loss: 0.4895
Epoch 4/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8019 - loss: 0.4331 - val_accuracy: 0.7754 - val_loss: 0.4842
Epoch 5/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8146 - loss: 0.4101 - val_accuracy: 0.7754 - val_loss: 0.4853
Epoch 6/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.8190 - loss: 0.3995 - val_accuracy: 0.7690 - val_loss: 0.4860
Epoch 7/100
[1m33/33[0m 



Epoch 1/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 20ms/step - accuracy: 0.6391 - loss: 0.6855 - val_accuracy: 0.7583 - val_loss: 0.6536
Epoch 2/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7700 - loss: 0.6252 - val_accuracy: 0.7679 - val_loss: 0.5617
Epoch 3/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step - accuracy: 0.7845 - loss: 0.5155 - val_accuracy: 0.7668 - val_loss: 0.4958
Epoch 4/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7968 - loss: 0.4479 - val_accuracy: 0.7690 - val_loss: 0.4854
Epoch 5/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - accuracy: 0.8110 - loss: 0.4197 - val_accuracy: 0.7668 - val_loss: 0.4881
Epoch 6/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8174 - loss: 0.4075 - val_accuracy: 0.7722 - val_loss: 0.4898
Epoch 7/100
[1m33/33[0m [



Epoch 1/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 20ms/step - accuracy: 0.6492 - loss: 0.6864 - val_accuracy: 0.7594 - val_loss: 0.6582
Epoch 2/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 14ms/step - accuracy: 0.7815 - loss: 0.6298 - val_accuracy: 0.7679 - val_loss: 0.5696
Epoch 3/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.7891 - loss: 0.5204 - val_accuracy: 0.7690 - val_loss: 0.4983
Epoch 4/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8015 - loss: 0.4479 - val_accuracy: 0.7690 - val_loss: 0.4845
Epoch 5/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8090 - loss: 0.4194 - val_accuracy: 0.7690 - val_loss: 0.4862
Epoch 6/100
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8188 - loss: 0.4082 - val_accuracy: 0.7701 - val_loss: 0.4859
Epoch 7/100
[1m33/33[0m [



KeyboardInterrupt: 

Cross validation :

In [13]:
# 10 fold cross validation
kf = KFold(n_splits=10, shuffle=True, random_state=4)

accuracys = []
mccs = []
aucs = []
auprcs = []
precisions = []
recalls = []
specificitys = []
f1s = []

for train_index, val_index in kf.split(X_train_num):
    X_train_num_fold, X_val_num_fold = X_train_num[train_index], X_train_num[val_index]
    X_train_embeddings_fold, X_val_embeddings_fold = X_train_embeddings[train_index], X_train_embeddings[val_index]
    y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]

    set_seed(4)

    # Instantiate the convolutional branch
    conv_branch = create_conv_branch((33,))

    # Train the convolutional branch
    conv_history = conv_branch.fit(
        X_train_num_fold, y_train_fold,
        epochs=100,
        batch_size=256,
        verbose=0,
        validation_data=(X_val_num_fold, y_val_fold),
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
    )

    # # Instantiate the ANN branch
    ann_branch = create_ann_branch(1024)

    # Train the ANN branch
    ann_history = ann_branch.fit(
        X_train_embeddings_fold, y_train_fold,
        epochs=100,
        batch_size=256,
        verbose=0,
        validation_data=(X_val_embeddings_fold, y_val_fold),
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
    )

    # Create the combined model
    combined_model = create_combined_model(conv_branch, ann_branch)

    # View the summary
    # combined_model.summary()

    # Train the combined model
    combined_history = combined_model.fit(
        [X_train_num_fold, X_train_embeddings_fold], y_train_fold,
        epochs=100,
        batch_size=256,
        verbose=1,
        validation_data=([X_val_num_fold, X_val_embeddings_fold], y_val_fold),
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)]
    )

    # Evaluate the model
    acc, mcc, auc, auprc, precision, recall, specificity, f1 = evaluate_model(combined_model, X_val=X_val_num_fold, X_val_pt5=X_val_embeddings_fold, y_val=y_val_fold)

    evaluate_model(combined_model, X_val=X_test_num, X_val_pt5=X_test_embeddings, y_val=y_test)

    accuracys.append(acc)
    mccs.append(mcc)
    aucs.append(auc)
    auprcs.append(auprc)
    precisions.append(precision)
    recalls.append(recall)
    specificitys.append(specificity)
    f1s.append(f1)

# Print the results mean and std
print(f'Accuracy: {np.mean(accuracys)} +/- {np.std(accuracys)}')
print(f'MCC: {np.mean(mccs)} +/- {np.std(mccs)}')
print(f'AUC: {np.mean(aucs)} +/- {np.std(aucs)}')
print(f'AUPRC: {np.mean(auprcs)} +/- {np.std(auprcs)}')
print(f'Precision: {np.mean(precisions)} +/- {np.std(precisions)}')
print(f'Recall: {np.mean(recalls)} +/- {np.std(recalls)}')
print(f'Specificity: {np.mean(specificitys)} +/- {np.std(specificitys)}')
print(f'F1: {np.mean(f1s)} +/- {np.std(f1s)}')




Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 19ms/step - accuracy: 0.5915 - loss: 0.6690 - val_accuracy: 0.7055 - val_loss: 0.6183
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7626 - loss: 0.5626 - val_accuracy: 0.7185 - val_loss: 0.5582
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.7878 - loss: 0.4844 - val_accuracy: 0.7328 - val_loss: 0.5256
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7985 - loss: 0.4509 - val_accuracy: 0.7399 - val_loss: 0.5187
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7929 - loss: 0.4384 - val_accuracy: 0.7280 - val_loss: 0.5165
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8094 - loss: 0.4233 - val_accuracy: 0.7352 - val_loss: 0.5119
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 19ms/step - accuracy: 0.6343 - loss: 0.6882 - val_accuracy: 0.7432 - val_loss: 0.6644
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.7935 - loss: 0.6386 - val_accuracy: 0.7432 - val_loss: 0.5886
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - accuracy: 0.8016 - loss: 0.5338 - val_accuracy: 0.7491 - val_loss: 0.5205
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - accuracy: 0.8099 - loss: 0.4450 - val_accuracy: 0.7562 - val_loss: 0.5109
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.8142 - loss: 0.4087 - val_accuracy: 0.7574 - val_loss: 0.5179
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.8280 - loss: 0.3920 - val_accuracy: 0.7586 - val_loss: 0.5237
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 21ms/step - accuracy: 0.5824 - loss: 0.6899 - val_accuracy: 0.6849 - val_loss: 0.6699
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7435 - loss: 0.6462 - val_accuracy: 0.7170 - val_loss: 0.6006
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 22ms/step - accuracy: 0.7788 - loss: 0.5509 - val_accuracy: 0.7265 - val_loss: 0.5375
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.7891 - loss: 0.4739 - val_accuracy: 0.7408 - val_loss: 0.5233
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - accuracy: 0.7974 - loss: 0.4437 - val_accuracy: 0.7467 - val_loss: 0.5257
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8058 - loss: 0.4290 - val_accuracy: 0.7491 - val_loss: 0.5278
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 17ms/step - accuracy: 0.6136 - loss: 0.6894 - val_accuracy: 0.7134 - val_loss: 0.6694
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7574 - loss: 0.6470 - val_accuracy: 0.7384 - val_loss: 0.6081
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7800 - loss: 0.5538 - val_accuracy: 0.7396 - val_loss: 0.5441
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7847 - loss: 0.4704 - val_accuracy: 0.7360 - val_loss: 0.5268
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8018 - loss: 0.4354 - val_accuracy: 0.7408 - val_loss: 0.5254
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8055 - loss: 0.4217 - val_accuracy: 0.7432 - val_loss: 0.5240
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.5920 - loss: 0.6876 - val_accuracy: 0.7479 - val_loss: 0.6579
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7897 - loss: 0.6261 - val_accuracy: 0.7432 - val_loss: 0.5667
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7974 - loss: 0.5133 - val_accuracy: 0.7551 - val_loss: 0.5025
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8153 - loss: 0.4340 - val_accuracy: 0.7669 - val_loss: 0.4891
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8189 - loss: 0.4075 - val_accuracy: 0.7669 - val_loss: 0.4866
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8225 - loss: 0.3940 - val_accuracy: 0.7646 - val_loss: 0.4868
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.5908 - loss: 0.6641 - val_accuracy: 0.7360 - val_loss: 0.5918
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7960 - loss: 0.5230 - val_accuracy: 0.7467 - val_loss: 0.5288
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8090 - loss: 0.4403 - val_accuracy: 0.7503 - val_loss: 0.5230
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8224 - loss: 0.4044 - val_accuracy: 0.7539 - val_loss: 0.5220
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8221 - loss: 0.3943 - val_accuracy: 0.7586 - val_loss: 0.5110
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8295 - loss: 0.3845 - val_accuracy: 0.7622 - val_loss: 0.5067
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.6090 - loss: 0.6523 - val_accuracy: 0.7622 - val_loss: 0.5777
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7978 - loss: 0.5194 - val_accuracy: 0.7669 - val_loss: 0.5135
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8175 - loss: 0.4287 - val_accuracy: 0.7669 - val_loss: 0.4997
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.8273 - loss: 0.3921 - val_accuracy: 0.7693 - val_loss: 0.4988
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8364 - loss: 0.3721 - val_accuracy: 0.7693 - val_loss: 0.4945
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8385 - loss: 0.3591 - val_accuracy: 0.7729 - val_loss: 0.4936
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 17ms/step - accuracy: 0.6328 - loss: 0.6454 - val_accuracy: 0.7348 - val_loss: 0.5949
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.8116 - loss: 0.4989 - val_accuracy: 0.7562 - val_loss: 0.5330
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8307 - loss: 0.4064 - val_accuracy: 0.7574 - val_loss: 0.5240
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8424 - loss: 0.3635 - val_accuracy: 0.7634 - val_loss: 0.5301
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.8430 - loss: 0.3539 - val_accuracy: 0.7622 - val_loss: 0.5315
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.8512 - loss: 0.3446 - val_accuracy: 0.7622 - val_loss: 0.5304
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.6099 - loss: 0.6628 - val_accuracy: 0.7170 - val_loss: 0.5840
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7941 - loss: 0.5258 - val_accuracy: 0.7301 - val_loss: 0.5248
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.8103 - loss: 0.4393 - val_accuracy: 0.7444 - val_loss: 0.5163
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8202 - loss: 0.4050 - val_accuracy: 0.7408 - val_loss: 0.5154
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8188 - loss: 0.3976 - val_accuracy: 0.7444 - val_loss: 0.5076
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8336 - loss: 0.3822 - val_accuracy: 0.7515 - val_loss: 0.5066
Epoch 7/100
[1m30/30[0m [



Epoch 1/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.6037 - loss: 0.6905 - val_accuracy: 0.7479 - val_loss: 0.6754
Epoch 2/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7726 - loss: 0.6595 - val_accuracy: 0.7527 - val_loss: 0.6216
Epoch 3/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.7768 - loss: 0.5848 - val_accuracy: 0.7515 - val_loss: 0.5471
Epoch 4/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.7875 - loss: 0.5013 - val_accuracy: 0.7515 - val_loss: 0.5075
Epoch 5/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7933 - loss: 0.4558 - val_accuracy: 0.7539 - val_loss: 0.5004
Epoch 6/100
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7977 - loss: 0.4408 - val_accuracy: 0.7527 - val_loss: 0.5000
Epoch 7/100
[1m30/30[0m [

20 random balanced test set:

In [18]:
from sklearn.preprocessing import PowerTransformer

#20 randonly balanced test set
test = pd.read_csv('../Embeddings/Prot_t5/test_t5_pssm.csv')
test_pos = test[test['label'] == 1]
test_neg = test[test['label'] == 0]

accuracies = []
mccs = []
aucs = []
auprcs = []
precisions = []
recalls = []
specificities = []
f1s = []

for i in range(20):
    test_neg = test_neg.sample(n=253, random_state=np.random.randint(0, 10000))

    test= pd.concat([test_pos, test_neg], axis=0)


    X_test_embeddings = test['embedding'].apply(lambda x: np.array([float(i) for i in x.strip('[]').split()]))
    X_test_PSSM = test['PSSM'].apply(lambda x: np.array([float(i) for i in x.strip("[]").split()]))
    X_test_PSSM = np.stack(X_test_PSSM.values)
    X_test_embeddings = np.stack(X_test_embeddings.values)
    scaler = PowerTransformer()
    X_test_PSSM = scaler.fit_transform(X_test_PSSM)
    X_test = test['sequence'].values
    y_test = test['label'].values

    # Create a dictionary to map amino acids to integers
    amino_acids = 'ACDEFGHIKLMNPQRSTVWY-'
    aa_to_int = {aa: i for i, aa in enumerate(amino_acids)}
    X_test_num = [[aa_to_int[aa] for aa in seq] for seq in X_test]
    X_test_num = np.array(X_test_num)

    # Evaluate the models
    acc, mcc, auc, auprc, precision, recall, specificity, f1 = evaluate_model(combined_model, X_test_num, X_test_embeddings, y_val=y_test)
    accuracies.append(acc)
    mccs.append(mcc)
    aucs.append(auc)
    auprcs.append(auprc)
    precisions.append(precision)
    recalls.append(recall)
    specificities.append(specificity)
    f1s.append(f1)

print("Results for Residual Model:")
print(f'Accuracy: {np.mean(accuracies)} +- {np.std(accuracies)}')
print(f'MCC: {np.mean(mccs)} +- {np.std(mccs)}')
print(f'AUC: {np.mean(aucs)} +- {np.std(aucs)}')
print(f'AUPRC: {np.mean(auprcs)} +- {np.std(auprcs)}')
print(f'Precision: {np.mean(precisions)} +- {np.std(precisions)}')
print(f'Recall: {np.mean(recalls)} +- {np.std(recalls)}')
print(f'Specificity: {np.mean(specificities)} +- {np.std(specificities)}')
print(f'F1: {np.mean(f1s)} +- {np.std(f1s)}')   

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Accuracy: 0.7865612648221344
MCC: 0.573409265656776
AUC: 0.7865612648221343
AUPRC: 0.8301515093010355
Precision: 0.7777777777777778
Recall: 0.8023715415019763
Specificity: 0.7707509881422925
F1: 0.7898832684824902
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Accuracy: 0.7865612648221344
MCC: 0.573409265656776
AUC: 0.7865612648221343
AUPRC: 0.8301515093010355
Precision: 0.7777777777777778
Recall: 0.8023715415019763
Specificity: 0.7707509881422925
F1: 0.7898832684824902
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Accuracy: 0.7865612648221344
MCC: 0.573409265656776
AUC: 0.7865612648221343
AUPRC: 0.8301515093010355
Precision: 0.7777777777777778
Recall: 0.8023715415019763
Specificity: 0.7707509881422925
F1: 0.7898832684824902
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Accuracy: 0.7865612648221344
MCC: 0.573409265656776
AUC: 0.786