Tutorial 6: Building a More Complex Model with Self-Attention Interaction

In [3]:
import os

# Switch the working directory to the parent of the directory the kernel was
# started in; the relative paths used later in this notebook (e.g.
# "data/DAVIS/train.txt") are resolved from there.
#
# The launch directory is recorded only once per kernel session, so re-running
# this cell always lands in the same place instead of climbing one level
# higher on every execution (which a bare os.chdir('../') would do).
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

from FlexMol.dataset.loader import load_DAVIS
from FlexMol.encoder import FlexMol
from FlexMol.task import BinaryTrainer

In [4]:
# Read the DAVIS train / validation / test splits.
# Only the first DEMO_ROWS rows of each split are kept, for demonstration purposes.
DEMO_ROWS = 20
split_paths = ("data/DAVIS/train.txt", "data/DAVIS/val.txt", "data/DAVIS/test.txt")
train_df, val_df, test_df = [load_DAVIS(path).head(DEMO_ROWS) for path in split_paths]

In [6]:
# Sizes shared by the architecture below
EMBED_DIM = 128               # feature width requested from every encoder
NUM_POCKETS = 30              # num_pockets passed to the PocketDC encoder
PDB_DIR = "data/DAVIS/pdb/"   # directory handed to both protein-side encoders

# Create the FlexMol builder instance
FM = FlexMol()

# Drug branch: GCN encoder with EMBED_DIM output features
drug_encoder = FM.init_drug_encoder("GCN", output_feats=EMBED_DIM)

# Pocket branch: PocketDC encoder
# Pooling is disabled, so the output shape will be NUM_POCKETS * EMBED_DIM (30 * 128)
pocket_encoder = FM.init_prot_encoder(
    "PocketDC",
    pdb=True,
    data_dir=PDB_DIR,
    num_pockets=NUM_POCKETS,
    output_feats=EMBED_DIM,
    pooling=False,
)

# Protein branch: GCN_ESM encoder with EMBED_DIM output features
protein_encoder = FM.init_prot_encoder(
    "GCN_ESM",
    pdb=True,
    hidden_feats=[EMBED_DIM] * 3,
    data_dir=PDB_DIR,
    output_feats=EMBED_DIM,
)

# Stack the three encoders (drug, pockets, protein) and run self-attention over them
# The interaction output shape will be 32 * 128
stacked_encoders = FM.stack([drug_encoder, pocket_encoder, protein_encoder])
interaction_output = FM.set_interaction(stacked_encoders, "self_attention")

# After the interaction, the drug row (index 0) and the protein row (index 31)
# encapsulate information about the pockets; select each one and flatten it
DRUG_ROW = 0
PROTEIN_ROW = NUM_POCKETS + 1
drug_selected = FM.select(interaction_output, index_start=DRUG_ROW)
drug_final = FM.flatten(drug_selected)
protein_selected = FM.select(interaction_output, index_start=PROTEIN_ROW)
protein_final = FM.flatten(protein_selected)

# Concatenate the flattened drug and protein outputs and pass them through an MLP
drug_protein_cat = FM.cat([drug_final, protein_final])
final_output = FM.apply_mlp(drug_protein_cat, hidden_layers=[512, 512, 256], head=1)

# Build the model
FM.build_model()

In [None]:
# Trainer settings, collected in one place so they are easy to tweak
trainer_settings = dict(
    task="DTI",
    early_stopping="roc-auc",
    test_metrics=["roc-auc", "pr-auc"],
    device="cpu",
    epochs=25,
    patience=6,
    lr=0.0001,
    batch_size=32,
)
trainer = BinaryTrainer(FM, **trainer_settings)

# Turn the three data frames into the train / validation / test datasets
train_data, val_data, test_data = trainer.prepare_datasets(
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
)

# Train the model using the training and validation datasets
trainer.train(train_data, val_data)

# Evaluate the model on the test dataset
trainer.test(test_data)