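This notebook fine-tunes the pretrained ProteinBERT model to classify protein sequences from the SCOP benchmark into the seven SCOP classes (labels `a`–`g`). A classification head is trained on top of ProteinBERT in three stages:

- first with the pretrained layers frozen;
- then with the whole model trainable;
- finally for one epoch at a longer sequence length.

The fine-tuned model reaches 0.886 test accuracy, a large improvement over the 0.29 previously obtained with ProtT5.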

In [None]:
!pip install git+https://github.com/nadavbra/protein_bert.git

Collecting git+https://github.com/nadavbra/protein_bert.git
  Cloning https://github.com/nadavbra/protein_bert.git to /tmp/pip-req-build-l3aejpns
  Running command git clone --filter=blob:none --quiet https://github.com/nadavbra/protein_bert.git /tmp/pip-req-build-l3aejpns
  Resolved https://github.com/nadavbra/protein_bert.git to commit 168a4db5aac281ff14165d00e50f862d780a8966
  Running command git submodule update --init --recursive -q
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
import os
import pandas as pd
from IPython.display import display
from tensorflow import keras
from sklearn.model_selection import train_test_split
from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len, log
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs


In [None]:
# Benchmark identifier and the task's output type: one global (non-sequence) categorical label per protein.
BENCHMARK_NAME = "scop"
OUTPUT_TYPE = OutputType(False, "categorical")

In [None]:
# Fine-tuning hyper-parameters read by `run_benchmark`.
settings = dict(
    max_dataset_size = None,                          # None = full splits; an int caps each of train/valid/test
    max_epochs_per_stage = 40,
    seq_len = 512,                                    # training filters out sequences longer than seq_len - 2 (see the log)
    batch_size = 32,
    final_epoch_seq_len = 1024,                       # longer context used for the final epoch
    initial_lr_with_frozen_pretrained_layers = 1e-02,
    initial_lr_with_all_layers = 1e-04,
    final_epoch_lr = 1e-05,
    dropout_rate = 0.5,
    training_callbacks = [
        # Divide the LR by 4 after one epoch without improvement of the monitored metric (Keras default: val_loss),
        # never going below 1e-05.
        keras.callbacks.ReduceLROnPlateau(patience = 1, factor = 0.25, min_lr = 1e-05, verbose = 1),
        # Stop after two epochs without improvement and roll back to the best weights seen.
        keras.callbacks.EarlyStopping(patience = 2, restore_best_weights = True),
    ],
)

In [None]:
def load_benchmark_dataset(benchmark_name = 'scop', valid_size = 0.1, random_state = 0,
        base_url = 'https://raw.githubusercontent.com/nadavbra/protein_bert/master/protein_benchmarks'):
    '''
    Load a ProteinBERT benchmark and carve a stratified validation set out of its training data.

    Reads `<base_url>/<benchmark_name>.train.csv` and `<base_url>/<benchmark_name>.test.csv`, drops rows with
    missing values and exact duplicate rows from each, then splits the training rows into train/validation
    subsets stratified on the `label` column. The defaults reproduce the original SCOP setup.

    Parameters:
        benchmark_name: file prefix of the benchmark CSVs (default 'scop').
        valid_size: fraction of the training rows held out for validation.
        random_state: seed of the train/validation split, so the split is reproducible.
        base_url: URL or directory containing the CSVs (anything `pd.read_csv` can open).

    Returns:
        (train_set, valid_set, test_set) DataFrames; downstream code expects `seq` and `label` columns.
    '''
    base = base_url.rstrip('/')

    def read_split(split):
        # `dropna`/`drop_duplicates` do not reset the index, so rows keep their original CSV row labels.
        return pd.read_csv(f'{base}/{benchmark_name}.{split}.csv').dropna().drop_duplicates()

    train_set = read_split('train')
    test_set = read_split('test')

    # Stratify so every label keeps (roughly) the same proportion in the training and validation subsets.
    train_set, valid_set = train_test_split(train_set, stratify = train_set['label'], test_size = valid_size,
            random_state = random_state)

    return train_set, valid_set, test_set


In [None]:
# Load the benchmark splits at notebook level so they can be inspected below.
(train_set, valid_set, test_set) = load_benchmark_dataset()

In [None]:
# Training split: 90% of the rows of the training CSV (stratified by label); the index holds the original CSV row labels.
train_set


Unnamed: 0,seq,label
7511,LPASIFRAYDIRGVVGDTLTAETAYWIGRAIGSESLARGEPCVAVG...,c
4727,LVPRGSHMNTSELRICRINKESGPCTGGEELYLLCDKVQKEDISVV...,b
6950,MNPDLRKERASATFNPELITHILDGSPENTRRRREIENLILNDPDF...,e
2250,GYFDAHALAMDYRSLGFRECLAEVARYLSIIEGLDASDPLRVRLVS...,a
8075,AKNVVLDHDGNLDDFVAMVLLASNTEKVRLIGALCTDADCFVENGF...,c
...,...,...
7600,PDVKCVCCTEGKECACFGQDCCVTGECCKDGTCCGI,g
8088,MKIISKEFTVKTRSRFDSIDITEQVSEAIKGINNGIAHVIVKHTTC...,d
6254,PHDPLDDIQADPWALWLSGRTRTALELPHYRRAAVLVALTREADPR...,d
15448,MFGNLQGKFIIATPEMDDEYFDRTVIYICEHNDNGTIGVIINTPTD...,d


In [None]:
# Validation split: the remaining 10% of the training CSV, held out by the same stratified split.
valid_set

Unnamed: 0,seq,label
11559,QPLEGYTLFSHRSAPNGFKVAIVLSELGFHYNTIFLDFNLGEHRAP...,c
11712,GICAPFTIPDVALEPGQQVTVPVAVTNQSGIAVPKPSLQLDASPDW...,b
10191,DDRRDALLERINLDIPAAVAQALREDLGGEVDAGNDITAQLLPADT...,d
1721,DIGQVIHPDDFDKAAADDYVLHEDGEKIYFLIKSKTDEYCFTNLAL...,b
5742,GPLSLSVDAFKILEDPKWEFPRKNLVLGKTLGEGEFGKVVKATAFH...,d
...,...,...
15461,EVVASNETLYQVVKEVKPGGLVQIADGTYKDVQLIVSNSGKSGLPI...,b
14052,QKSVLEQLKQVTMVVADTGDFELIKKYKPVDATTNPSLILKAVKEQ...,c
6884,SKMPQVNLRWPREVLDLVRKVAEENGRSVNSEIYQRVMESFKKEGRIGA,a
13128,PRLVALVKGRVQGVGYRAFAQKKALELGLSGYAENLPDGRVEVVAE...,d


In [None]:
# Held-out test set read from the benchmark's test CSV; it is not passed to `finetune`, only to the final evaluation.
test_set

Unnamed: 0,seq,label
0,DPMTCEQAMASCEHTMCGYCQGPLYMTCIGITTDPECGLP,a
1,GSDKIHHHHHHMNIFEAIENRHSVRDFLERKMPERVKDDIENLLVK...,d
2,MPDLNSSTDSAASASAASDVSVESTAEATVCTVTLEKMSAGLGFSL...,b
3,MDFHIRKATNSDAEAIQHVATTSWHHTYQDLIPSDVQDDFLKRFYN...,d
4,DQSGYERGLTLPLRHPSGLFDGETEAVWGLNTAYSVVEKSVSTRDY...,b
...,...,...
3916,MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVG...,d
3917,MFEARLVQGSILKKVLEALKDLINEACWDISSSGVNLQSMDSSHVS...,d
3918,MEGEIDIAKRIEDGINQVQCSVAEYPEAITYLLEQYNRVEAEEARL...,a
3919,KESCKMFIGGLNWDTTEDNLREYFGKYGTVTDLKIMKDPATGRSRG...,d


Fine-tuning and evaluation

In [None]:
def run_benchmark(benchmark_name, pretraining_model_generator, input_encoder, pretraining_model_manipulation_function = None,
        output_type = None, datasets = None):
    '''
    Fine-tune a pretrained ProteinBERT model on a benchmark and display its performance.

    Training follows `settings`:
        1. with the pretrained layers frozen;
        2. with all layers trainable;
        3. `n_final_epochs = 1` epoch at `settings['final_epoch_seq_len']`.
    Afterwards the results returned by `evaluate_by_len` are displayed for the training, validation and
    test sets, together with the confusion matrix when there is one.

    Parameters:
        benchmark_name: name shown in the log header. NOTE: it does not select the data; pass `datasets` for that.
        pretraining_model_generator, input_encoder: the pretrained model generator and its input encoder,
            e.g. as returned by `load_pretrained_model()`.
        pretraining_model_manipulation_function: optional function forwarded to `FinetuningModelGenerator`.
        output_type: `OutputType` of the task; defaults to `OutputType(False, 'categorical')`.
        datasets: optional (train_set, valid_set, test_set) DataFrames with `seq` and `label` columns; defaults
            to `load_benchmark_dataset()`. The frames are copied, so the caller's data is never modified.

    Returns:
        The fine-tuned `FinetuningModelGenerator`.
    '''
    log('========== %s ==========' % benchmark_name)

    if output_type is None:
        output_type = OutputType(False, 'categorical')
    log('Output type: %s' % output_type)

    if datasets is None:
        datasets = load_benchmark_dataset()
    # Work on copies: the label casting below must not mutate DataFrames owned by the caller.
    train_set, valid_set, test_set = (dataset.copy() for dataset in datasets)
    log(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')

    if settings['max_dataset_size'] is not None:
        log('Limiting the training, validation and test sets to %d records each.' % settings['max_dataset_size'])
        train_set, valid_set, test_set = (dataset.sample(min(settings['max_dataset_size'], len(dataset)), random_state = 0) \
                for dataset in (train_set, valid_set, test_set))

    splits = (train_set, valid_set, test_set)

    # Sequence and categorical labels are handled as strings, every other label type as floats.
    label_dtype = str if (output_type.is_seq or output_type.is_categorical) else float
    for dataset in splits:
        dataset['label'] = dataset['label'].astype(label_dtype)

    if output_type.is_categorical:

        # The label vocabulary is the union over all three splits. Set comprehensions (instead of
        # `set.union(*...)`) also work when a split is empty.
        if output_type.is_seq:
            # Sequence outputs: every character of a label string is a separate label.
            unique_labels = sorted({label for dataset in splits for seq_labels in dataset['label'] for label in seq_labels})
        else:
            unique_labels = sorted({label for dataset in splits for label in dataset['label'].unique()})

        log('%d unique labels.' % len(unique_labels))
    elif output_type.is_binary:
        unique_labels = [0, 1]
    else:
        unique_labels = None

    output_spec = OutputSpec(output_type, unique_labels)
    model_generator = FinetuningModelGenerator(pretraining_model_generator, output_spec,
            pretraining_model_manipulation_function = pretraining_model_manipulation_function,
            dropout_rate = settings['dropout_rate'])

    finetune(model_generator, input_encoder, output_spec, train_set['seq'], train_set['label'], valid_set['seq'],
            valid_set['label'], seq_len = settings['seq_len'], batch_size = settings['batch_size'],
            max_epochs_per_stage = settings['max_epochs_per_stage'], lr = settings['initial_lr_with_all_layers'],
            begin_with_frozen_pretrained_layers = True,
            lr_with_frozen_pretrained_layers = settings['initial_lr_with_frozen_pretrained_layers'],
            n_final_epochs = 1, final_seq_len = settings['final_epoch_seq_len'], final_lr = settings['final_epoch_lr'],
            callbacks = settings['training_callbacks'])

    for dataset_name, dataset in zip(('Training-set', 'Validation-set', 'Test-set'), splits):

        log('*** %s performance: ***' % dataset_name)
        results, confusion_matrix = evaluate_by_len(model_generator, input_encoder, output_spec, dataset['seq'],
                dataset['label'], start_seq_len = settings['seq_len'], start_batch_size = settings['batch_size'])

        with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            display(results)

        if confusion_matrix is not None:
            with pd.option_context('display.max_rows', 16, 'display.max_columns', 10):
                log('Confusion matrix:')
                display(confusion_matrix)

    return model_generator

In [None]:
# Pretrained ProteinBERT model generator together with its matching input encoder.
(pretrained_model_generator, input_encoder) = load_pretrained_model()


In [None]:
# Fine-tune and evaluate on the configured benchmark. Keep the returned model generator so the fine-tuned
# model stays available to later cells instead of only being echoed as this cell's output.
finetuned_model_generator = run_benchmark(BENCHMARK_NAME, pretrained_model_generator, input_encoder,
        pretraining_model_manipulation_function = get_model_with_hidden_layers_as_outputs)


[2024_07_19-11:48:47] Output type: global categorical
[2024_07_19-11:48:47] 14112 training set records, 1568 validation set records, 3921 test set records.
[2024_07_19-11:48:47] 7 unique lebels.
[2024_07_19-11:48:47] Training set: Filtered out 6 of 14112 (0.0%) records of lengths exceeding 510.
[2024_07_19-11:48:48] Validation set: Filtered out 1 of 1568 (0.1%) records of lengths exceeding 510.
[2024_07_19-11:48:48] Training with frozen pretrained layers...
Epoch 1/40



Epoch 2/40



Epoch 3/40



Epoch 4/40




Epoch 4: ReduceLROnPlateau reducing learning rate to 0.0024999999441206455.
Epoch 5/40



Epoch 6/40




Epoch 6: ReduceLROnPlateau reducing learning rate to 0.0006249999860301614.
Epoch 7/40



Epoch 8/40



Epoch 9/40



Epoch 10/40




Epoch 10: ReduceLROnPlateau reducing learning rate to 0.00015624999650754035.
Epoch 11/40



Epoch 12/40



Epoch 13/40



Epoch 14/40




Epoch 14: ReduceLROnPlateau reducing learning rate to 3.9062499126885086e-05.
Epoch 15/40



Epoch 16/40



Epoch 17/40




Epoch 17: ReduceLROnPlateau reducing learning rate to 1e-05.
Epoch 18/40



Epoch 19/40



Epoch 20/40



[2024_07_19-12:00:29] Training the entire fine-tuned model...
[2024_07_19-12:00:57] Incompatible number of optimizer weights - will not initialize them.
Epoch 1/40



Epoch 2/40




Epoch 2: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Epoch 3/40




Epoch 3: ReduceLROnPlateau reducing learning rate to 1e-05.
[2024_07_19-12:05:14] Training on final epochs of sequence length 1024...
[2024_07_19-12:05:14] Training set: Filtered out 0 of 14112 (0.0%) records of lengths exceeding 1022.
[2024_07_19-12:05:16] Validation set: Filtered out 0 of 1568 (0.0%) records of lengths exceeding 1022.



[2024_07_19-12:09:11] *** Training-set performance: ***


Unnamed: 0_level_0,# records,Accuracy
Model seq len,Unnamed: 1_level_1,Unnamed: 2_level_1
512,14106,0.948745
1024,6,1.0
All,14112,0.948767


[2024_07_19-12:10:01] Confusion matrix:


Unnamed: 0,a,b,c,d,e,f,g
a,2242,6,25,75,3,1,2
b,4,2814,22,115,4,1,1
c,10,7,4134,42,4,0,0
d,41,171,144,2988,14,2,1
e,2,1,3,1,254,0,0
f,0,0,0,0,0,196,0
g,3,7,1,10,0,0,761


[2024_07_19-12:10:01] *** Validation-set performance: ***


Unnamed: 0_level_0,# records,Accuracy
Model seq len,Unnamed: 1_level_1,Unnamed: 2_level_1
512,1567,0.90619
1024,1,1.0
All,1568,0.90625


[2024_07_19-12:10:13] Confusion matrix:


Unnamed: 0,a,b,c,d,e,f,g
a,248,0,4,9,1,0,0
b,0,306,1,19,2,1,0
c,2,2,448,11,2,1,0
d,4,27,35,304,3,0,0
e,2,0,4,3,20,0,0
f,3,4,0,0,0,15,0
g,3,2,0,1,1,0,80


[2024_07_19-12:10:13] *** Test-set performance: ***


Unnamed: 0_level_0,# records,Accuracy
Model seq len,Unnamed: 1_level_1,Unnamed: 2_level_1
512,3919,0.88594
1024,2,0.5
All,3921,0.885743


[2024_07_19-12:10:30] Confusion matrix:


Unnamed: 0,a,b,c,d,e,f,g
a,585,2,21,37,6,0,3
b,1,731,21,62,4,2,2
c,3,9,1124,27,3,0,0
d,22,91,80,732,5,3,0
e,4,4,4,4,56,1,0
f,4,2,0,1,0,47,1
g,6,5,0,8,0,0,198


<proteinbert.model_generation.FinetuningModelGenerator at 0x7dc9514aae30>

In [None]:
# Mark the end of the run in the log.
log("Done.")


[2024_07_19-12:10:30] Done.


In [None]:
# Label distribution of each split (the `train_set` / `valid_set` / `test_set` frames loaded earlier), in one
# table aligned by label. The classes are imbalanced (`c` is the largest, `e` and `f` are small), so overall
# accuracy is dominated by the large classes.
label_counts = (
    pd.DataFrame({
        'train': train_set['label'].value_counts(),
        'valid': valid_set['label'].value_counts(),
        'test': test_set['label'].value_counts(),
    })
    .fillna(0)       # a label missing from one split would otherwise show up as NaN
    .astype(int)
    .sort_index()    # a..g, same order as the confusion matrices
)
label_counts


Train Set Value Counts:
 label
c    4197
d    3361
b    2961
a    2354
g     782
e     261
f     196
Name: count, dtype: int64

Validation Set Value Counts:
 label
c    466
d    373
b    329
a    262
g     87
e     29
f     22
Name: count, dtype: int64

Test Set Value Counts:
 label
c    1166
d     933
b     823
a     654
g     217
e      73
f      55
Name: count, dtype: int64


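Fine-tuned ProteinBERT reaches a test accuracy of 0.886 (0.906 on validation, 0.949 on training). The previous ProtT5 result was 0.29, from a separate experiment that is not reproduced in this notebook.

Two caveats apply to this comparison:

- The numbers are only directly comparable if the ProtT5 run used the same splits.
- 0.29 is roughly the share of the majority class `c` in the test set (1166 of 3921 ≈ 0.30). That baseline therefore appears to have done no better than always predicting `c`.

Per-class performance is also uneven:

- Recall is highest for class `c` (1124/1166 ≈ 0.96).
- It is lowest for classes `d` (732/933 ≈ 0.78) and `e` (56/73 ≈ 0.77).
- Most misclassified `d` sequences are predicted as `b` or `c`.

Overall, these results indicate that fine-tuned ProteinBERT is well suited to this protein classification task and captures informative patterns in protein sequences.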