Tutorial 2: Building a Simple Model Using a 2-Encoder + MLP Structure

In [1]:
import os
os.chdir('../')

from FlexMol.dataset.loader import load_DAVIS
from FlexMol.encoder import FlexMol
from FlexMol.task import BinaryTrainer

In [2]:
# Load the DAVIS dataset
# We are using a subset of the data (first 50 rows) for demonstration purposes
train = load_DAVIS("data/DAVIS/train.txt").head(50)
val = load_DAVIS("data/DAVIS/val.txt").head(50)
test = load_DAVIS("data/DAVIS/test.txt").head(50)

In [3]:
# Initialize FlexMol instance
FM = FlexMol()
# Initialize drug and protein encoders
drug_encoder = FM.init_drug_encoder("GCN")
protein_encoder = FM.init_prot_encoder("AAC")
# Concatenate the encoders' outputs
combined_output = FM.cat([drug_encoder, protein_encoder])
# Apply a Multi-Layer Perceptron (MLP) to the concatenated outputs
output = FM.apply_mlp(combined_output, head=1)
# Build the model
FM.build_model()

In [4]:
# Set up the trainer with specified parameters and metrics
trainer = BinaryTrainer(
    FM,
    task="DTI",
    test_metrics=["accuracy", "precision", "recall", "f1"],
    device="cpu",
    early_stopping="roc-auc",
    epochs=30,
    patience=10,
    lr=0.0001,
    batch_size=128
)
# Prepare the datasets for training, validation, and testing
train_data, val_data, test_data = trainer.prepare_datasets(train_df=train, val_df=val, test_df=test)

In [5]:
# Train the model
trainer.train(train_data, val_data)

Start training...


Epoch 0:   0%|          | 0/1 [00:00<?, ?batch/s]

Epoch 0: 100%|██████████| 1/1 [00:01<00:00,  1.11s/batch, loss=0.664]


Epoch: 0 	Training Loss: 0.663608
Epoch: 0 	Validation Loss: 0.653195
Epoch: 0 	Validation roc-auc: 0.5957


Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 11.10batch/s, loss=0.658]


Epoch: 1 	Training Loss: 0.658224
Epoch: 1 	Validation Loss: 0.624172
Epoch: 1 	Validation roc-auc: 0.4965


Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 11.93batch/s, loss=0.641]


Epoch: 2 	Training Loss: 0.640507
Epoch: 2 	Validation Loss: 0.613853
Epoch: 2 	Validation roc-auc: 0.3759


Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 10.70batch/s, loss=0.639]


Epoch: 3 	Training Loss: 0.638680
Epoch: 3 	Validation Loss: 0.617383
Epoch: 3 	Validation roc-auc: 0.3546


Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 12.06batch/s, loss=0.638]


Epoch: 4 	Training Loss: 0.638139
Epoch: 4 	Validation Loss: 0.625347
Epoch: 4 	Validation roc-auc: 0.3191


Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 12.16batch/s, loss=0.638]


Epoch: 5 	Training Loss: 0.637545
Epoch: 5 	Validation Loss: 0.641414
Epoch: 5 	Validation roc-auc: 0.3333


Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 11.63batch/s, loss=0.636]


Epoch: 6 	Training Loss: 0.636284
Epoch: 6 	Validation Loss: 0.669903
Epoch: 6 	Validation roc-auc: 0.3546


Epoch 7: 100%|██████████| 1/1 [00:00<00:00, 11.79batch/s, loss=0.636]


Epoch: 7 	Training Loss: 0.635997
Epoch: 7 	Validation Loss: 0.713974
Epoch: 7 	Validation roc-auc: 0.3546


Epoch 8: 100%|██████████| 1/1 [00:00<00:00, 11.87batch/s, loss=0.636]


Epoch: 8 	Training Loss: 0.635855
Epoch: 8 	Validation Loss: 0.764850
Epoch: 8 	Validation roc-auc: 0.3759


Epoch 9: 100%|██████████| 1/1 [00:00<00:00, 12.12batch/s, loss=0.635]


Epoch: 9 	Training Loss: 0.635422
Epoch: 9 	Validation Loss: 0.826206
Epoch: 9 	Validation roc-auc: 0.3617


Epoch 10: 100%|██████████| 1/1 [00:00<00:00, 12.23batch/s, loss=0.635]

Epoch: 10 	Training Loss: 0.634938
Epoch: 10 	Validation Loss: 0.897588
Epoch: 10 	Validation roc-auc: 0.3546
Early stopping triggered after 10 epochs.
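
The log above is consistent with patience-based early stopping on the validation roc-auc: the best score (0.5957) occurs at epoch 0, and training stops once 10 further epochs pass without an improvement. The following is a minimal sketch of that rule, not FlexMol's actual implementation; the function name `early_stopping_epoch` is hypothetical, and the scores are copied from the log.

```python
def early_stopping_epoch(scores, patience):
    """Return the epoch index at which training would stop, or None.

    Assumes a higher score is better (as with roc-auc) and that training
    stops once `patience` consecutive epochs fail to beat the best score.
    """
    best = float("-inf")
    bad_epochs = 0
    for epoch, score in enumerate(scores):
        if score > best:
            best = score
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Validation roc-auc values for epochs 0..10, copied from the log above
val_roc_auc = [0.5957, 0.4965, 0.3759, 0.3546, 0.3191, 0.3333,
               0.3546, 0.3546, 0.3759, 0.3617, 0.3546]

stop_epoch = early_stopping_epoch(val_roc_auc, patience=10)
print(stop_epoch)
```

Under these assumptions the rule fires at epoch 10, which matches the point where the log reports early stopping. Note that the validation loss rises steadily from epoch 3 onward while the training loss barely moves, which is expected with only 50 training rows.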





In [8]:
trainer.test(test_data)

Start testing...
Test Loss: 0.978239
accuracy: 0.060000
precision: 0.040816
recall: 1.000000
f1: 0.078431


In [10]:
# Perform inference on the test data using the trained model
# This returns the total loss, all true labels, and all model predictions
# Note: You can now compute custom metrics using 'all_labels' and 'all_predictions'
total_loss, all_labels, all_predictions = trainer.inference(trainer.create_loader(test_data))
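
As one example of a custom metric, ROC-AUC can be computed directly from labels and scores as the fraction of (positive, negative) pairs in which the positive sample receives the higher score, with ties counted as half. The sketch below is plain Python and does not depend on FlexMol; the helper name `roc_auc` and the toy inputs are illustrative, not part of the library.

```python
def roc_auc(labels, scores):
    """Pairwise ROC-AUC: P(score of a positive > score of a negative).

    Ties between a positive and a negative count as 0.5.
    Runs in O(n_pos * n_neg), which is fine for small evaluation sets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data for illustration only (not taken from DAVIS)
toy_labels = [0.0, 0.0, 1.0, 1.0]
toy_scores = [0.1, 0.4, 0.35, 0.8]
toy_auc = roc_auc(toy_labels, toy_scores)
print(toy_auc)  # 0.75
```

The same function should accept `all_labels` and `all_predictions` as returned above, assuming they are flat lists of floats as the printed output suggests.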


In [13]:
print("Ground Truth Labels:")
print(all_labels)
print("All_Predictions:")
print(all_predictions)

Ground Truth Labels:
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
All_Predictions:
[0.6218882, 0.6287671, 0.6197413, 0.5495148, 0.643523, 0.62342805, 0.5295856, 0.4806529, 0.5538527, 0.63042957, 0.5538419, 0.57894975, 0.72568244, 0.6414241, 0.6840686, 0.56393874, 0.73658043, 0.64223135, 0.61184186, 0.5978523, 0.5393544, 0.74992156, 0.62433743, 0.6433823, 0.52405965, 0.6406623, 0.5763875, 0.57649946, 0.59156555, 0.5964469, 0.57781327, 0.5427243, 0.6857456, 0.60090387, 0.5776537, 0.59034, 0.70655006, 0.7228107, 0.6243686, 0.5771154, 0.70688933, 0.57815194, 0.6854482, 0.70658773, 0.54950184, 0.72565895, 0.59068215, 0.6951169, 0.71074164, 0.7069677]
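
The predictions are scores between 0 and 1, so threshold-based metrics such as accuracy, precision, recall, and F1 require a cutoff. The sketch below assumes a cutoff of 0.5; that value is an assumption and is not confirmed by the FlexMol output shown here. The helper name `binary_metrics` and the toy inputs are illustrative only.

```python
def binary_metrics(labels, scores, threshold=0.5):
    """Compute accuracy, precision, recall, and F1 at a given score threshold.

    A score >= threshold is treated as a positive prediction (assumed convention).
    Undefined ratios (zero denominator) are reported as 0.0.
    """
    preds = [1 if s >= threshold else 0 for s in scores]
    truth = [int(y) for y in labels]
    tp = sum(1 for y, p in zip(truth, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(truth, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(truth, preds) if y == 1 and p == 0)
    tn = sum(1 for y, p in zip(truth, preds) if y == 0 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {
        "accuracy": (tp + tn) / len(truth),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Toy data for illustration only (not taken from DAVIS)
toy_labels = [0.0, 0.0, 1.0, 1.0, 0.0]
toy_scores = [0.9, 0.2, 0.8, 0.4, 0.6]
toy_metrics = binary_metrics(toy_labels, toy_scores)
print(toy_metrics)
```

In the output printed above, only one score (0.4806529) falls below 0.5, so 49 of the 50 samples would be predicted positive at that cutoff, while only 2 labels are positive. That gives precision 2/49 ≈ 0.0408, recall 1.0, and accuracy 3/50 = 0.06, which matches the `trainer.test` output and suggests, without confirming, that 0.5 is the cutoff in use. With this few positives, a ranking metric such as roc-auc is usually more informative than accuracy.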
