Tutorial 4: Building a Model with Cross-Attention Interaction Using Two Transformer Encoders

In [1]:
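import os
# Move up one directory so the FlexMol package and the data/ paths used later resolve
# (this assumes the notebook lives one level below the repository root)
os.chdir('../')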

from FlexMol.dataset.loader import load_DAVIS
from FlexMol.encoder import FlexMol
from FlexMol.task import BinaryTrainer

In [6]:
# Load the DAVIS train, validation, and test splits
# Only the first 20 rows of each split are kept so the demo runs quickly;
# metrics computed on such a small sample are not representative
train = load_DAVIS("data/DAVIS/train.txt").head(20)
val = load_DAVIS("data/DAVIS/val.txt").head(20)
test = load_DAVIS("data/DAVIS/test.txt").head(20)

In [4]:
# Initialize FlexMol instance
FM = FlexMol()

# Initialize two Transformer encoders for drugs and proteins
# Drug encoder: Transformer
drug_encoder_transformer = FM.init_drug_encoder("Transformer", pooling=False)

# Protein encoder: Transformer
protein_encoder_transformer = FM.init_prot_encoder("Transformer", pooling=False)

# Set up the cross_attention interaction layer
# It expects an unpooled, per-token (2D) embedding from each encoder,
# which is why both encoders above use pooling=False; it fails on pooled embeddings
interaction_output = FM.set_interaction(
    [drug_encoder_transformer, protein_encoder_transformer],
    "cross_attention"
)
# Apply a Multi-Layer Perceptron (MLP) to the interaction output
output = FM.apply_mlp(interaction_output, head=1)
# Build the model
FM.build_model()
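The requirement for unpooled embeddings is easier to see with a small example. The following is a minimal, dependency-free sketch of single-head scaled dot-product cross-attention, where drug tokens act as queries and protein tokens act as keys and values. It is not FlexMol's implementation: the function names, the sequence lengths, and the embedding size are hypothetical, and a real layer would typically add learned projections and multiple heads.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap the rows and columns of a matrix."""
    return [list(col) for col in zip(*m)]


def softmax_rows(m):
    """Apply a numerically stable softmax to each row."""
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(v - mx) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def cross_attention(queries, keys_values):
    """Attend from `queries` (Lq x d) over `keys_values` (Lk x d).

    Returns the context matrix (Lq x d) and the attention weights (Lq x Lk).
    No learned projections are used; this only illustrates the shapes.
    """
    d = len(queries[0])
    scores = matmul(queries, transpose(keys_values))           # Lq x Lk
    scaled = [[s / math.sqrt(d) for s in row] for row in scores]
    weights = softmax_rows(scaled)                              # rows sum to 1
    context = matmul(weights, keys_values)                      # Lq x d
    return context, weights


# Hypothetical sizes: 5 drug tokens, 12 protein tokens, 8-dimensional embeddings
random.seed(0)
d_model = 8
drug_tokens = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(5)]
protein_tokens = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(12)]

context, weights = cross_attention(drug_tokens, protein_tokens)
print(len(context), len(context[0]))   # 5 8
print(len(weights), len(weights[0]))   # 5 12
```

Each drug token receives its own weighting over the protein tokens. A pooled encoder output is a single vector per molecule with no token axis, so there is nothing for the attention weights to range over.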

In [7]:
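# Configure the trainer for binary drug-target interaction (DTI) classification
# Early stopping monitors the validation ROC-AUC with a patience of 10 epochs
trainer = BinaryTrainer(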
    FM,
    task="DTI",
    test_metrics=["accuracy", "precision", "recall", "f1"],
    device="cpu",
    early_stopping="roc-auc",
    epochs=30,
    patience=10,
    lr=0.0001,
    batch_size=128
)
# Prepare the datasets for training, validation, and testing
train_data, val_data, test_data = trainer.prepare_datasets(train_df=train, val_df=val, test_df=test)
# Train the model
trainer.train(train_data, val_data)
# Test the model
trainer.test(test_data)

Start training...


Epoch 0: 100%|██████████| 1/1 [00:13<00:00, 13.56s/batch, loss=0.695]


Epoch: 0 	Training Loss: 0.694806
Epoch: 0 	Validation Loss: 0.786196
Epoch: 0 	Validation roc-auc: 0.5000


Epoch 1: 100%|██████████| 1/1 [00:14<00:00, 14.05s/batch, loss=0.684]


Epoch: 1 	Training Loss: 0.684389
Epoch: 1 	Validation Loss: 0.772495
Epoch: 1 	Validation roc-auc: 0.5000


Epoch 2: 100%|██████████| 1/1 [00:15<00:00, 15.37s/batch, loss=0.679]


Epoch: 2 	Training Loss: 0.679326
Epoch: 2 	Validation Loss: 0.756629
Epoch: 2 	Validation roc-auc: 0.5000


Epoch 3: 100%|██████████| 1/1 [00:14<00:00, 14.64s/batch, loss=0.664]


Epoch: 3 	Training Loss: 0.664419
Epoch: 3 	Validation Loss: 0.743121
Epoch: 3 	Validation roc-auc: 0.5278


Epoch 4: 100%|██████████| 1/1 [00:13<00:00, 13.77s/batch, loss=0.679]


Epoch: 4 	Training Loss: 0.678608
Epoch: 4 	Validation Loss: 0.727770
Epoch: 4 	Validation roc-auc: 0.4722


Epoch 5: 100%|██████████| 1/1 [00:14<00:00, 14.86s/batch, loss=0.68]


Epoch: 5 	Training Loss: 0.679626
Epoch: 5 	Validation Loss: 0.716419
Epoch: 5 	Validation roc-auc: 0.4722


Epoch 6: 100%|██████████| 1/1 [00:12<00:00, 12.85s/batch, loss=0.694]


Epoch: 6 	Training Loss: 0.694460
Epoch: 6 	Validation Loss: 0.703493
Epoch: 6 	Validation roc-auc: 0.4722


Epoch 7: 100%|██████████| 1/1 [00:13<00:00, 13.97s/batch, loss=0.668]


Epoch: 7 	Training Loss: 0.668321
Epoch: 7 	Validation Loss: 0.689191
Epoch: 7 	Validation roc-auc: 0.4444


Epoch 8: 100%|██████████| 1/1 [00:12<00:00, 12.16s/batch, loss=0.672]


Epoch: 8 	Training Loss: 0.671593
Epoch: 8 	Validation Loss: 0.679974
Epoch: 8 	Validation roc-auc: 0.4167


Epoch 9: 100%|██████████| 1/1 [00:12<00:00, 12.40s/batch, loss=0.698]


Epoch: 9 	Training Loss: 0.697948
Epoch: 9 	Validation Loss: 0.667212
Epoch: 9 	Validation roc-auc: 0.3889


Epoch 10: 100%|██████████| 1/1 [00:13<00:00, 13.06s/batch, loss=0.681]


Epoch: 10 	Training Loss: 0.681225
Epoch: 10 	Validation Loss: 0.656406
Epoch: 10 	Validation roc-auc: 0.4167


Epoch 11: 100%|██████████| 1/1 [00:14<00:00, 14.98s/batch, loss=0.679]


Epoch: 11 	Training Loss: 0.678944
Epoch: 11 	Validation Loss: 0.645484
Epoch: 11 	Validation roc-auc: 0.3611


Epoch 12: 100%|██████████| 1/1 [00:15<00:00, 15.56s/batch, loss=0.664]


Epoch: 12 	Training Loss: 0.663980
Epoch: 12 	Validation Loss: 0.637467
Epoch: 12 	Validation roc-auc: 0.4167


Epoch 13: 100%|██████████| 1/1 [00:11<00:00, 11.59s/batch, loss=0.685]


Epoch: 13 	Training Loss: 0.684512
Epoch: 13 	Validation Loss: 0.629846
Epoch: 13 	Validation roc-auc: 0.3889
Early stopping triggered after 13 epochs.
Start testing...
Test Loss: 0.622106
accuracy: 0.950000
precision: 0.000000
recall: 0.000000
f1: 0.000000
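
An accuracy of 0.95 next to a precision, recall, and F1 of 0 usually means the model predicted almost no positives on an imbalanced test set. The notebook does not print the confusion matrix, so this sketch assumes one scenario that reproduces the reported numbers: 20 test rows, exactly one positive label, and a model that predicts negative for every row. The `binary_metrics` helper is hypothetical and written in plain Python; it is not part of FlexMol. It follows the common convention of reporting an undefined 0/0 ratio as 0.

```python
def binary_metrics(y_true, y_pred):
    """Compute accuracy, precision, recall, and F1 for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)

    accuracy = (tp + tn) / len(pairs)
    # Undefined ratios (0/0) are reported as 0.0 by convention
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Assumed scenario: one positive among 20 rows, model predicts all negatives
y_true = [1] + [0] * 19
y_pred = [0] * 20

metrics = binary_metrics(y_true, y_pred)
for name, value in metrics.items():
    print(f"{name}: {value:.6f}")
```

Under this assumption, precision is an undefined 0/0 (no positive predictions) and recall is 0/1. A different split, for example a single false positive with no positive labels at all, would yield the same four numbers, so treat this as an illustration rather than a diagnosis. With only 20 rows per split, the model should be evaluated on the full dataset before any conclusion is drawn.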

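The log stops at epoch 13 even though `epochs=30` was requested. A minimal sketch of patience-based early stopping, replayed on the validation ROC-AUC values from the log, shows why. The `early_stop_epoch` function is hypothetical; it assumes that only a strict improvement resets the counter, which is consistent with this log but may not match FlexMol's implementation in every detail.

```python
def early_stop_epoch(scores, patience):
    """Return (stop_epoch, best_epoch) for a metric where higher is better.

    stop_epoch is None if the patience budget is never exhausted.
    """
    best = float("-inf")
    best_epoch = None
    stale = 0  # consecutive epochs without a strict improvement
    for epoch, score in enumerate(scores):
        if score > best:
            best, best_epoch, stale = score, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Validation ROC-AUC per epoch, copied from the training log
val_auc = [
    0.5000, 0.5000, 0.5000, 0.5278, 0.4722, 0.4722, 0.4722,
    0.4444, 0.4167, 0.3889, 0.4167, 0.3611, 0.4167, 0.3889,
]

stop_epoch, best_epoch = early_stop_epoch(val_auc, patience=10)
print(stop_epoch, best_epoch)  # 13 3
```

The best score (0.5278) occurs at epoch 3, and epochs 4 through 13 are the ten non-improving epochs that exhaust `patience=10`. The validation loss keeps falling during those epochs, so the stopping decision depends on which metric is monitored.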