Tutorial 2: Building a Simple Model Using a 2-Encoder + MLP Structure

In [1]:
import os
os.chdir('../')

from FlexMol.dataset.loader import load_DAVIS
from FlexMol.encoder import FlexMol
from FlexMol.task import BinaryTrainer

In [2]:
# Load the DAVIS dataset
# We are using a subset of the data (first 50 rows) for demonstration purposes
train = load_DAVIS("data/DAVIS/train.txt").head(50)
val = load_DAVIS("data/DAVIS/val.txt").head(50)
test = load_DAVIS("data/DAVIS/test.txt").head(50)

In [3]:
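# Initialize a FlexMol instance
FM = FlexMol()
# Initialize the drug encoder (GCN: graph convolutional network over the molecular graph)
# and the protein encoder (AAC: amino acid composition of the sequence)
drug_encoder = FM.init_drug_encoder("GCN")
protein_encoder = FM.init_prot_encoder("AAC")
# Concatenate the two encoders' outputs into a single feature vector
combined_output = FM.cat([drug_encoder, protein_encoder])
# Apply a multi-layer perceptron (MLP) with a single output head (head=1) to the concatenated features
output = FM.apply_mlp(combined_output, head=1)
# Build the model from the encoder, concatenation and MLP steps defined above
FM.build_model()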
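To make the two-encoder + MLP structure concrete, here is a minimal NumPy sketch of a single forward pass. It is not FlexMol's implementation. The drug vector is a hypothetical random 16-dimensional stand-in for the GCN output, `aac` is a simplified amino acid composition (the fraction of each of the 20 standard residues; FlexMol's AAC encoder may compute a different or larger feature set), and the hidden width of 32 and the final sigmoid are assumptions.

```python
import numpy as np

# The 20 standard amino acids, one letter each.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def aac(sequence: str) -> np.ndarray:
    """Simplified amino acid composition: fraction of each standard residue.

    Hypothetical stand-in for FlexMol's "AAC" protein encoder.
    """
    residues = [ch for ch in sequence.upper() if ch in AMINO_ACIDS]
    counts = np.array([residues.count(a) for a in AMINO_ACIDS], dtype=float)
    return counts / max(len(residues), 1)


def mlp_forward(x: np.ndarray, layers) -> np.ndarray:
    """Dense layers with ReLU between them; the last layer stays linear."""
    for i, (w, b) in enumerate(layers):
        x = x @ w + b
        if i < len(layers) - 1:
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)

# Hypothetical drug embedding standing in for the GCN encoder's output.
drug_vec = rng.normal(size=16)
prot_vec = aac("MKTAYIAKQRQISFVKSHFSRQ")  # toy protein sequence -> 20 features

# Analogue of FM.cat: concatenate both encoder outputs into one vector.
combined = np.concatenate([drug_vec, prot_vec])  # shape (36,)

# Analogue of FM.apply_mlp(..., head=1): one hidden layer, then a single output unit.
hidden = 32
layers = [
    (rng.normal(scale=0.1, size=(combined.size, hidden)), np.zeros(hidden)),
    (rng.normal(scale=0.1, size=(hidden, 1)), np.zeros(1)),
]
logit = mlp_forward(combined, layers)
prob = 1.0 / (1.0 + np.exp(-logit))  # assumed sigmoid -> interaction probability
print(combined.shape, prob.shape)  # (36,) (1,)
```

Concatenation keeps the two encoders independent: swapping either one (for example, a different protein encoder) only changes the width of the MLP's input layer.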

In [4]:
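# Set up the trainer: task type, test metrics, device, early stopping and optimization settings
trainer = BinaryTrainer(
    FM,
    task="DTI",                # drug-target interaction
    test_metrics=["accuracy", "precision", "recall", "f1"],  # reported by trainer.test()
    device="cpu",
    early_stopping="roc-auc",  # validation metric monitored for early stopping
    epochs=30,                 # maximum number of training epochs
    patience=10,               # epochs without improvement before stopping
    lr=0.0001,                 # learning rate
    batch_size=128             # exceeds the 50-row subset, so each epoch is one batch (1/1 in the log)
)
# Prepare the datasets for training, validation, and testing
train_data, val_data, test_data = trainer.prepare_datasets(train_df=train, val_df=val, test_df=test)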
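`early_stopping="roc-auc"` selects epochs by validation ROC-AUC. ROC-AUC equals the probability that a randomly chosen positive pair is scored above a randomly chosen negative pair, so it depends only on how predictions are ranked, not on a decision threshold. The sketch below computes it directly from that definition in plain Python; it is an illustration rather than FlexMol's code, and the labels and scores are toy values. Library implementations such as scikit-learn's `roc_auc_score` compute the same quantity for binary labels.

```python
def roc_auc(labels, scores):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct ranking. This equals the Mann-Whitney U
    statistic divided by n_pos * n_neg.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 0, 1, 0, 0]
scores = [0.9, 0.4, 0.35, 0.8, 0.1]
print(roc_auc(labels, scores))  # 4 of 6 pairs ranked correctly
```

The pairwise loop costs O(n_pos * n_neg), which is fine for a 50-row demo; a sort-based rank computation scales better. With so few validation rows, a single swapped pair shifts the score noticeably, which is worth keeping in mind when reading the ROC-AUC values in the training log.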

In [5]:
# Train the model
trainer.train(train_data, val_data)

Start training...


Epoch 0: 100%|██████████| 1/1 [00:02<00:00,  2.00s/batch, loss=0.672]


Epoch: 0 	Training Loss: 0.671772
Epoch: 0 	Validation Loss: 2.133007
Epoch: 0 	Validation roc-auc: 0.7872


Epoch 1: 100%|██████████| 1/1 [00:01<00:00,  1.59s/batch, loss=0.652]


Epoch: 1 	Training Loss: 0.652205
Epoch: 1 	Validation Loss: 1.683086
Epoch: 1 	Validation roc-auc: 0.7801


Epoch 2: 100%|██████████| 1/1 [00:01<00:00,  1.69s/batch, loss=0.643]


Epoch: 2 	Training Loss: 0.642842
Epoch: 2 	Validation Loss: 1.383644
Epoch: 2 	Validation roc-auc: 0.7660


Epoch 3: 100%|██████████| 1/1 [00:01<00:00,  1.19s/batch, loss=0.641]


Epoch: 3 	Training Loss: 0.641445
Epoch: 3 	Validation Loss: 1.163450
Epoch: 3 	Validation roc-auc: 0.7305


Epoch 4: 100%|██████████| 1/1 [00:01<00:00,  1.69s/batch, loss=0.641]


Epoch: 4 	Training Loss: 0.641228
Epoch: 4 	Validation Loss: 1.007522
Epoch: 4 	Validation roc-auc: 0.6738


Epoch 5: 100%|██████████| 1/1 [00:01<00:00,  1.74s/batch, loss=0.641]


Epoch: 5 	Training Loss: 0.641124
Epoch: 5 	Validation Loss: 0.912583
Epoch: 5 	Validation roc-auc: 0.5816


Epoch 6: 100%|██████████| 1/1 [00:01<00:00,  1.86s/batch, loss=0.641]


Epoch: 6 	Training Loss: 0.640793
Epoch: 6 	Validation Loss: 0.864285
Epoch: 6 	Validation roc-auc: 0.5745


Epoch 7: 100%|██████████| 1/1 [00:01<00:00,  1.51s/batch, loss=0.64]


Epoch: 7 	Training Loss: 0.640439
Epoch: 7 	Validation Loss: 0.829402
Epoch: 7 	Validation roc-auc: 0.5816


Epoch 8: 100%|██████████| 1/1 [00:01<00:00,  1.74s/batch, loss=0.64]


Epoch: 8 	Training Loss: 0.640099
Epoch: 8 	Validation Loss: 0.805460
Epoch: 8 	Validation roc-auc: 0.6099


Epoch 9: 100%|██████████| 1/1 [00:01<00:00,  1.63s/batch, loss=0.64]


Epoch: 9 	Training Loss: 0.639802
Epoch: 9 	Validation Loss: 0.795091
Epoch: 9 	Validation roc-auc: 0.4894


Epoch 10: 100%|██████████| 1/1 [00:01<00:00,  1.39s/batch, loss=0.64]


Epoch: 10 	Training Loss: 0.639539
Epoch: 10 	Validation Loss: 0.793096
Epoch: 10 	Validation roc-auc: 0.3404
Early stopping triggered after 10 epochs.
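The stopping point in the log follows from a patience counter. Validation ROC-AUC peaks at epoch 0 (0.7872), and none of epochs 1 to 10 beats it, so the counter reaches `patience=10` at epoch 10; 11 epochs ran in total. The sketch below replays the logged values through a simple patience rule. It assumes that only a strict improvement resets the counter, which is a guess about FlexMol's behavior rather than its documented rule. Note also that validation loss falls steadily over the same epochs (2.133007 to 0.793096) while ROC-AUC drops, so monitoring loss instead would have favored the last epoch.

```python
def early_stop_epoch(scores, patience):
    """Return the epoch at which patience-based early stopping halts,
    or None if training would run through all given epochs.

    Higher scores are better (e.g. ROC-AUC). The counter resets only when
    a score strictly improves on the best seen so far (an assumption).
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores):
        if score > best:
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Validation ROC-AUC per epoch, copied from the training log.
val_auc = [0.7872, 0.7801, 0.7660, 0.7305, 0.6738, 0.5816,
           0.5745, 0.5816, 0.6099, 0.4894, 0.3404]
print(early_stop_epoch(val_auc, patience=10))  # 10
```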


In [6]:
# Evaluate the trained model on the test set
trainer.test(test_data)

Start testing...
Test Loss: 0.800313
accuracy: 0.080000
precision: 0.041667
recall: 1.000000
f1: 0.080000
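The test metrics are easier to read as confusion-matrix counts. Assuming all 50 test rows were scored, recall 1.0 implies no false negatives, precision 0.041667 (1/24) implies 23 false positives per true positive, and accuracy 0.08 then pins the counts to TP = 2, FP = 46, TN = 2, FN = 0. In other words, the model labels 48 of 50 pairs as interacting, so the perfect recall reflects near-constant positive predictions rather than real discrimination. The helper below (a plain-Python sketch, not part of FlexMol) recomputes the four reported numbers from those inferred counts.

```python
def binary_metrics(tp, fp, tn, fn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Counts inferred from the reported test metrics (assumes 50 scored rows).
metrics = binary_metrics(tp=2, fp=46, tn=2, fn=0)
for name, value in metrics.items():
    print(f"{name}: {value:.6f}")
```

Results like these are expected when training on 50-row subsets; meaningful numbers would need the full DAVIS splits.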
