Tutorial 5: Building a Model Using a Custom Featurizer and Encoder Layer

In [2]:
import os
# Step up one directory before importing FlexMol and before the relative
# "data/DAVIS/..." paths used later in this notebook are opened.
# NOTE(review): presumably the notebook lives one level below the repository
# root -- confirm. The path is relative, so re-running this cell in the same
# kernel moves up one more level each time; restart the kernel before re-running.
os.chdir('../')

# Dataset loader, the FlexMol model builder, and the binary-classification trainer.
from FlexMol.dataset.loader import load_DAVIS
from FlexMol.encoder import FlexMol
from FlexMol.task import BinaryTrainer

# Base classes subclassed below by the custom encoder layer and featurizer.
from FlexMol.encoder.enc_layer import EncoderLayer
from FlexMol.encoder.featurizer import Featurizer

In [7]:

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# Custom Featurizer: MyMorganFeaturizer
# This class converts SMILES strings to Morgan fingerprints (bit vectors).
class MyMorganFeaturizer(Featurizer):
    """Featurizer that encodes a SMILES string as a Morgan fingerprint bit vector.

    The fingerprint is computed with radius 2 and 1024 bits.
    """

    def transform(self, s):
        """Convert one SMILES string into a fingerprint array.

        Args:
            s: SMILES string describing a single molecule.

        Returns:
            A 1-D float64 NumPy array of length 1024 with one entry per
            fingerprint bit.

        Raises:
            ValueError: If RDKit cannot parse ``s`` into a molecule.
        """
        radius = 2
        n_bits = 1024

        mol = Chem.MolFromSmiles(s)
        if mol is None:
            # MolFromSmiles reports a parse failure by returning None rather
            # than raising; stop here and name the offending input instead of
            # letting the fingerprint call fail with an opaque argument error.
            raise ValueError(f"Could not parse SMILES string: {s!r}")

        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        # ConvertToNumpyArray writes the bits into the array passed to it, so
        # allocate the destination at its final size up front.
        features = np.zeros((n_bits,))
        DataStructs.ConvertToNumpyArray(features_vec, features)
        return features

# Custom Encoder Layer: MyMLP
# This class defines a simple multi-layer perceptron (MLP) with configurable hidden layers.
class MyMLP(EncoderLayer):
    """Multi-layer perceptron encoder built from a stack of ``nn.Linear`` layers.

    Layer widths run ``input_dim -> hidden_dims_lst[0] -> ... -> output_dim``
    and a ReLU is applied after every layer, including the last one.

    Args:
        input_dim: Width of each input feature vector.
        output_dim: Width of the encoded output.
        hidden_dims_lst: Widths of the hidden layers, in order. Any sequence
            of ints (list or tuple) is accepted.
    """

    def __init__(self, input_dim=1024, output_dim=128, hidden_dims_lst=(1024, 256, 64)):
        super().__init__()
        # The default is an immutable tuple so it cannot be shared and mutated
        # between instances; unpacking also lets callers pass a list or a tuple.
        dims = [input_dim, *hidden_dims_lst, output_dim]
        self.output_shape = output_dim
        # One Linear layer per consecutive pair of widths.
        self.predictor = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(dims[:-1], dims[1:])]
        )

    def get_output_shape(self):
        """Return the width of the encoded output (``output_dim``)."""
        return self.output_shape

    def forward(self, v):
        """Pass ``v`` through every linear layer, applying ReLU after each.

        Because ReLU also follows the final layer, the returned values are
        non-negative.
        """
        for layer in self.predictor:
            v = F.relu(layer(v))
        return v


In [5]:
# Load the three DAVIS splits, keeping only the first 20 rows of each.
train, val, test = (
    load_DAVIS(split_path).head(20)
    for split_path in (
        "data/DAVIS/train.txt",
        "data/DAVIS/val.txt",
        "data/DAVIS/test.txt",
    )
)

# Create the FlexMol model builder.
FM = FlexMol()

# Make the custom pair available as a drug method named "my_method":
# MyMLP is the encoder layer and MyMorganFeaturizer is the featurizer.
FM.register_method("drug", "my_method", MyMLP, MyMorganFeaturizer)

# Drug side: the custom method registered above.
drug_encoder = FM.init_drug_encoder("my_method")
# Protein side: the default "AAC" method.
protein_encoder = FM.init_prot_encoder("AAC")

# Join both encoder outputs, put an MLP on top of the result, then build the model.
combined_output = FM.cat([drug_encoder, protein_encoder])
output = FM.apply_mlp(combined_output, head=1)
FM.build_model()

In [8]:
# Trainer settings, collected in one place so they are easy to tweak.
trainer_config = {
    "task": "DTI",
    "test_metrics": ["accuracy", "precision", "recall", "f1"],
    "device": "cpu",
    "early_stopping": "roc-auc",
    "epochs": 30,
    "patience": 10,
    "lr": 0.0001,
    "batch_size": 128,
}
trainer = BinaryTrainer(FM, **trainer_config)

# Convert the train / validation / test data frames into the trainer's datasets.
train_data, val_data, test_data = trainer.prepare_datasets(
    train_df=train,
    val_df=val,
    test_df=test,
)

# Fit on the training split, using the validation split during training.
trainer.train(train_data, val_data)

# Report the configured test metrics on the held-out split.
trainer.test(test_data)

Start training...


Epoch 0: 100%|██████████| 1/1 [00:01<00:00,  1.06s/batch, loss=0.757]


Epoch: 0 	Training Loss: 0.757339
Epoch: 0 	Validation Loss: 0.661793
Epoch: 0 	Validation roc-auc: 0.5556


Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.43batch/s, loss=0.791]


Epoch: 1 	Training Loss: 0.791020
Epoch: 1 	Validation Loss: 0.662584
Epoch: 1 	Validation roc-auc: 0.4167


Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  1.34batch/s, loss=0.737]


Epoch: 2 	Training Loss: 0.737460
Epoch: 2 	Validation Loss: 0.663926
Epoch: 2 	Validation roc-auc: 0.3611


Epoch 3: 100%|██████████| 1/1 [00:00<00:00,  1.35batch/s, loss=0.728]


Epoch: 3 	Training Loss: 0.728215
Epoch: 3 	Validation Loss: 0.666695
Epoch: 3 	Validation roc-auc: 0.2778


Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  1.30batch/s, loss=0.723]


Epoch: 4 	Training Loss: 0.722957
Epoch: 4 	Validation Loss: 0.669520
Epoch: 4 	Validation roc-auc: 0.3056


Epoch 5: 100%|██████████| 1/1 [00:00<00:00,  1.12batch/s, loss=0.723]


Epoch: 5 	Training Loss: 0.722743
Epoch: 5 	Validation Loss: 0.672304
Epoch: 5 	Validation roc-auc: 0.4167


Epoch 6: 100%|██████████| 1/1 [00:00<00:00,  1.21batch/s, loss=0.722]


Epoch: 6 	Training Loss: 0.722018
Epoch: 6 	Validation Loss: 0.675070
Epoch: 6 	Validation roc-auc: 0.4167


Epoch 7: 100%|██████████| 1/1 [00:00<00:00,  1.33batch/s, loss=0.721]


Epoch: 7 	Training Loss: 0.721421
Epoch: 7 	Validation Loss: 0.677907
Epoch: 7 	Validation roc-auc: 0.4444


Epoch 8: 100%|██████████| 1/1 [00:00<00:00,  1.66batch/s, loss=0.721]


Epoch: 8 	Training Loss: 0.721201
Epoch: 8 	Validation Loss: 0.680776
Epoch: 8 	Validation roc-auc: 0.4167


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  1.43batch/s, loss=0.721]


Epoch: 9 	Training Loss: 0.720855
Epoch: 9 	Validation Loss: 0.683647
Epoch: 9 	Validation roc-auc: 0.4167


Epoch 10: 100%|██████████| 1/1 [00:00<00:00,  1.66batch/s, loss=0.72]


Epoch: 10 	Training Loss: 0.720391
Epoch: 10 	Validation Loss: 0.686617
Epoch: 10 	Validation roc-auc: 0.4167
Early stopping triggered after 10 epochs.
Start testing...
Test Loss: 0.686417
accuracy: 0.950000
precision: 0.000000
recall: 0.000000
f1: 0.000000


  _warn_prf(average, modifier, msg_start, len(result))
