In [None]:
!pip install --use-deprecated legacy-resolver proteinbert-with-pretrain-model/Markdown-3.4.1-py3-none-any.whl" 2>/dev/null
!pip install --ignore-requires-python proteinbert-with-pretrain-model/google_auth-1.6.3-py2.py3-none-any.whl 2>/dev/null
!pip install --ignore-requires-python proteinbert-with-pretrain-model/tensorboard-2.6.0-py3-none-any.whl 2>/dev/null
!pip install --use-deprecated legacy-resolver proteinbert-with-pretrain-model/pyfaidx-0.7.1-py3-none-any.whl 2>/dev/null
!pip install --use-deprecated legacy-resolver proteinbert-with-pretrain-model/typing_extensions-3.10.0.2-py3-none-any.whl 2>/dev/null
!pip install --use-deprecated legacy-resolver proteinbert-with-pretrain-model/protein_bert-1.0.1-py3-none-any.whl 2>/dev/null

In [None]:
import os

import pandas as pd
from IPython.display import display

from tensorflow import keras

from sklearn.model_selection import train_test_split

from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len, log, conv_and_global_attention_model
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve,  confusion_matrix, f1_score, auc
import matplotlib.pyplot as plt
import numpy as np
from proteinbert.model_generation import load_pretrained_model_from_dump

In [None]:
# Directory holding the per-split '<name>.train.csv' / '<name>.valid.csv' files.
BENCHMARKS_DIR = 'Dataset/biogrid_human/biogrid_human_dataset'

# (name, output_type) pairs for the ten splits Biogrid_human_1 .. Biogrid_human_10.
# Every split uses the same OutputType(False, 'binary') specification
# (presumably is_seq=False with a binary label -- confirm against proteinbert.OutputType).
BENCHMARKS = [
    ('Biogrid_human_%d' % split_number, OutputType(False, 'binary'))
    for split_number in range(1, 11)
]

# Fine-tuning hyper-parameters read by run_benchmark.
settings = dict(
    # None keeps every record; an integer caps each of the train/valid/test sets.
    max_dataset_size = None,
    max_epochs_per_stage = 40,
    # Sequence length and batch size used for training and as the evaluation starting point.
    seq_len = 256,
    batch_size = 32,
    # Sequence length and learning rate handed to finetune() for its final epoch.
    final_epoch_seq_len = 1024,
    initial_lr_with_frozen_pretrained_layers = 1e-2,
    initial_lr_with_all_layers = 1e-4,
    final_epoch_lr = 1e-5,
    dropout_rate = 0.5,
    # The same callback instances are passed to every finetune() call.
    training_callbacks = [
        keras.callbacks.ReduceLROnPlateau(patience = 1, factor = 0.25, min_lr = 1e-5, verbose = 1),
        keras.callbacks.EarlyStopping(patience = 2, restore_best_weights = True),
    ],
)

# Debug mode: uncomment the two lines below for a quick smoke run
# (small subsample, a single epoch per stage).
# settings['max_dataset_size'] = 500
# settings['max_epochs_per_stage'] = 1

def run_benchmark(benchmark_name, pretraining_model_generator, input_encoder, pretraining_model_manipulation_function = None):
    """Fine-tune on one benchmark split, report metrics and save test-set predictions.

    Steps:
      1. Load the split's training/validation sets and the shared test CSV
         ('Biogrid_human.test.csv' under BENCHMARKS_DIR).
      2. Optionally subsample every set to settings['max_dataset_size'].
      3. Cast labels (str for sequence/categorical outputs, float otherwise).
      4. Fine-tune via proteinbert's finetune() using the values in `settings`.
      5. Display evaluate_by_len() results for the training, validation and test sets.
      6. Predict the test set once and write 'Result/<benchmark_name>.csv' with the
         columns Protein, label and pred_label.

    Args:
        benchmark_name: Name of an entry in BENCHMARKS; also the CSV file-name prefix.
        pretraining_model_generator: Pretrained model generator handed to
            FinetuningModelGenerator.
        input_encoder: Encoder passed to finetune()/evaluate_by_len() and used
            (via encode_X) to encode the test sequences for prediction.
        pretraining_model_manipulation_function: Optional function forwarded
            unchanged to FinetuningModelGenerator.

    Returns:
        The FinetuningModelGenerator used for this benchmark.
    """

    log('========== %s ==========' % benchmark_name)

    output_type = get_benchmark_output_type(benchmark_name)
    log('Output type: %s' % output_type)

    train_set, valid_set = load_benchmark_dataset(benchmark_name)
    # Every split is scored against the same held-out test file, which lives next
    # to the per-split train/valid CSVs.
    test_set = pd.read_csv(os.path.join(BENCHMARKS_DIR, 'Biogrid_human.test.csv'))
    log(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')

    if settings['max_dataset_size'] is not None:
        log('Limiting the training, validation and test sets to %d records each.' % settings['max_dataset_size'])
        train_set = train_set.sample(min(settings['max_dataset_size'], len(train_set)), random_state = 0)
        valid_set = valid_set.sample(min(settings['max_dataset_size'], len(valid_set)), random_state = 0)
        test_set = test_set.sample(min(settings['max_dataset_size'], len(test_set)), random_state = 0)

    # Sequence and categorical outputs keep their labels as strings; all others are numeric.
    label_dtype = str if (output_type.is_seq or output_type.is_categorical) else float
    for dataset in (train_set, valid_set, test_set):
        dataset['label'] = dataset['label'].astype(label_dtype)

    if output_type.is_categorical:

        if output_type.is_seq:
            # Per-position labels: collect every distinct character across all three sets.
            unique_labels = sorted(set.union(*train_set['label'].apply(set)) | set.union(*valid_set['label'].apply(set)) | \
                    set.union(*test_set['label'].apply(set)))
        else:
            unique_labels = sorted(set(train_set['label'].unique()) | set(valid_set['label'].unique()) | set(test_set['label'].unique()))

        log('%d unique labels.' % len(unique_labels))
    elif output_type.is_binary:
        unique_labels = [0, 1]
    else:
        unique_labels = None

    output_spec = OutputSpec(output_type, unique_labels)

    model_generator = FinetuningModelGenerator(pretraining_model_generator, output_spec, pretraining_model_manipulation_function = \
            pretraining_model_manipulation_function, dropout_rate = settings['dropout_rate'])
    finetune(model_generator, input_encoder, output_spec, train_set['seq'], train_set['label'], valid_set['seq'], valid_set['label'], \
            seq_len = settings['seq_len'], batch_size = settings['batch_size'], max_epochs_per_stage = settings['max_epochs_per_stage'], \
            lr = settings['initial_lr_with_all_layers'], begin_with_frozen_pretrained_layers = True, lr_with_frozen_pretrained_layers = \
            settings['initial_lr_with_frozen_pretrained_layers'], n_final_epochs = 1, final_seq_len = settings['final_epoch_seq_len'], \
            final_lr = settings['final_epoch_lr'], callbacks = settings['training_callbacks'])

    for dataset_name, dataset in [('Training-set', train_set), ('Validation-set', valid_set), ('Test-set', test_set)]:

        log('*** %s performance: ***' % dataset_name)
        # Named confusion_mat so the sklearn.metrics.confusion_matrix import is not shadowed.
        results, confusion_mat = evaluate_by_len(model_generator, input_encoder, output_spec, dataset['seq'], dataset['label'], \
                start_seq_len = settings['seq_len'], start_batch_size = settings['batch_size'])

        with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            display(results)

        if confusion_mat is not None:
            with pd.option_context('display.max_rows', 16, 'display.max_columns', 10):
                log('Confusion matrix:')
                display(confusion_mat)

    # Predict the test set once, after all evaluations. This used to sit inside the
    # loop above, rebuilding the model and rewriting the same CSV for every dataset.
    # NOTE(review): unlike evaluate_by_len above, this encodes every test sequence at
    # one fixed length -- confirm that all test sequences fit within settings['seq_len'].
    seq_len = settings['seq_len']
    X = input_encoder.encode_X(test_set['seq'], seq_len)
    model = model_generator.create_model(seq_len)
    y_pred = model.predict(X, batch_size = settings['batch_size'])

    test_set['pred_label'] = y_pred.flatten()
    # Make sure the output directory exists so a long run is not lost at the final write.
    os.makedirs('Result', exist_ok = True)
    test_set[['Protein','label', 'pred_label']].to_csv(f'Result/{benchmark_name}.csv', index=False)

    return model_generator

def load_benchmark_dataset(benchmark_name):
    """Load the training and validation sets of one benchmark split.

    Reads '<benchmark_name>.train.csv' from BENCHMARKS_DIR and, when it exists,
    '<benchmark_name>.valid.csv'; rows with missing values and duplicate rows are
    dropped from each. If the validation file is absent, 10% of the training set is
    held out instead (stratified on 'label', random_state=0).

    The test set is not loaded here; run_benchmark reads it separately.

    Returns:
        A (train_set, valid_set) tuple of DataFrames.
    """

    def _read_clean(csv_path):
        # Read one CSV, then discard incomplete and repeated rows.
        return pd.read_csv(csv_path).dropna().drop_duplicates()

    train_csv_path = os.path.join(BENCHMARKS_DIR, '%s.train.csv' % benchmark_name)
    valid_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.valid.csv' % benchmark_name)

    full_train_set = _read_clean(train_csv_path)

    if not os.path.exists(valid_set_file_path):
        log(f'Validation set {valid_set_file_path} missing. Splitting training set instead.')
        train_part, valid_part = train_test_split(full_train_set, stratify = full_train_set['label'], test_size = 0.1, \
                random_state = 0)
        return train_part, valid_part

    return full_train_set, _read_clean(valid_set_file_path)

def get_benchmark_output_type(benchmark_name):
    """Return the output type registered for `benchmark_name` in BENCHMARKS.

    The first entry whose name matches wins; None is returned when no entry matches.
    """
    matching_types = (registered_type for registered_name, registered_type in BENCHMARKS if registered_name == benchmark_name)
    return next(matching_types, None)

# Load the pretrained weights from the bundled dump.
# NOTE(review): the dump is a .pkl file -- presumably unpickled by the loader, so
# only use dumps from a trusted source.
pretrained_model_generator, input_encoder = load_pretrained_model_from_dump(
    "proteinbert-with-pretrain-model/epoch_92400_sample_23500000.pkl",
    conv_and_global_attention_model.create_model,
)

# Fine-tune and evaluate each benchmark split in turn, starting every run from
# the same pretrained generator.
for benchmark_name, _ in BENCHMARKS:
    run_benchmark(
        benchmark_name,
        pretrained_model_generator,
        input_encoder,
        pretraining_model_manipulation_function = get_model_with_hidden_layers_as_outputs,
    )

log('Done.')