In [37]:
# Run configuration: target protein, ensemble size and class-balance ratio.
prot = 'BRD4'   # matched against the protein_name column of the train/test parquet files
n_models = 1    # number of models trained, then averaged at prediction time
ratio = 5       # negatives kept per positive when down-sampling (passed as r)

### Import Data

In [38]:
import duckdb
import pandas as pd
import numpy as np

In [39]:
train_path = './data/train.parquet'
# Using the connection as a context manager closes it even if the query raises.
# `prot` is bound as a query parameter rather than pasted into the SQL text.
# NOTE(review): ORDER BY random() is unseeded, so row order differs between runs.
with duckdb.connect() as con:
    df = con.execute(
        f"""
        SELECT *
        FROM parquet_scan('{train_path}')
        WHERE protein_name = ?
        ORDER BY random()
        """,
        [prot],
    ).df()

In [40]:
# Fail fast on bad input. An empty selection (e.g. a misspelled `prot`) or non-binary
# labels would otherwise break the down-sampling and sample-weight logic further down.
assert len(df) > 0, f"no training rows found for protein_name == {prot!r}"
assert df["binds"].isin([0, 1]).all(), "expected binary 0/1 values in 'binds'"
df

### Load Data into Chemprop

In [41]:
from lightning import pytorch as pl

from chemprop import data, featurizers, nn
from chemprop.models import multi
from chemprop.models.utils import save_model, load_model

In [42]:
def loadData(df_input, smiles_cols, w, seed=42, num_workers=4):
    """Build train/val/test dataloaders for a multicomponent Chemprop classifier.

    Parameters
    ----------
    df_input : pandas.DataFrame
        Rows to load. Must contain every column in ``smiles_cols`` and a ``binds``
        target column.
    smiles_cols : list[str]
        SMILES columns, one per model component. Targets and sample weights are
        attached only to the datapoints of the first column.
    w : float
        Sample weight for rows with ``binds == 0``. Weights follow
        ``(1 - w) * binds + w``, so rows with ``binds == 1`` get weight 1.
    seed : int, default 42
        Seed passed to ``data.make_split_indices`` for the train/val/test split.
    num_workers : int, default 4
        ``num_workers`` passed to every ``data.build_dataloader`` call.

    Returns
    -------
    list
        ``[train_loader, val_loader, test_loader]``. Only the training loader shuffles.
    """
    target_cols = ['binds']

    smiss = df_input.loc[:, smiles_cols].values
    ys = df_input.loc[:, target_cols].values

    # Each `y` is a length-1 row of the 2-D `ys`. Cast the weight to a plain float so
    # the datapoint gets a scalar weight rather than a 1-element array.
    all_data = [[
        data.MoleculeDatapoint.from_smi(smis[0], y, weight=float((1 - w) * y[0] + w))
        for smis, y in zip(smiss, ys)
    ]]
    all_data += [
        [data.MoleculeDatapoint.from_smi(smis[i]) for smis in smiss]
        for i in range(1, len(smiles_cols))
    ]

    # Indices come from the first component and are applied to every component,
    # which keeps rows aligned across components.
    train_idx, val_idx, test_idx = data.make_split_indices(all_data[0], seed=seed)
    train_data, val_data, test_data = data.split_data_by_indices(all_data, train_idx, val_idx, test_idx)

    featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
    n_components = len(smiles_cols)

    loaders = []
    for split_data, shuffle in ((train_data, True), (val_data, False), (test_data, False)):
        datasets = [data.MoleculeDataset(split_data[i], featurizer) for i in range(n_components)]
        mcdset = data.MulticomponentDataset(datasets)
        loaders.append(data.build_dataloader(mcdset, num_workers=num_workers, shuffle=shuffle))

    return loaders

In [43]:
def loadModel(smiles_cols):
    """Construct an untrained multicomponent MPNN binary classifier.

    One ``BondMessagePassing`` block (depth 4, hidden size 600) is created per entry
    in ``smiles_cols``. These are wrapped in ``MulticomponentMessagePassing`` and
    followed by ``NormAggregation`` and a 4-layer ``BinaryClassificationFFN``
    (dropout 0.2). Batch norm is enabled, and BCE and AUPRC are tracked as metrics.
    """
    n_components = len(smiles_cols)
    blocks = [nn.BondMessagePassing(depth=4, d_h=600) for _ in range(n_components)]
    message_passing = nn.MulticomponentMessagePassing(blocks=blocks, n_components=n_components)

    aggregation = nn.NormAggregation()
    predictor = nn.BinaryClassificationFFN(
        input_dim=message_passing.output_dim, n_layers=4, dropout=0.2
    )
    # Only the first metric is used for training and early stopping
    metrics = [nn.metrics.BCEMetric(), nn.metrics.BinaryAUPRCMetric()]

    return multi.MulticomponentMPNN(
        message_passing, aggregation, predictor, batch_norm=True, metrics=metrics
    )

In [44]:
def downSampler(df, r=1, random_state=None):
    """Down-sample the negative class to roughly ``r`` negatives per positive.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a binary ``binds`` column (1 = positive, 0 = negative).
    r : float, default 1
        Requested negatives-per-positive ratio. The number of negatives kept is
        ``int(n_positives * r)``, capped at the number of negatives available.
    random_state : optional
        Forwarded to ``DataFrame.sample``. Pass an int for reproducible sampling.

    Returns
    -------
    pandas.DataFrame
        The sampled negatives followed by all positives, with the original index kept.
        If there are no positives, no negatives are kept either.
    """
    is_pos = df["binds"] == 1
    negatives = df.loc[df["binds"] == 0, :]
    positives = df.loc[is_pos, :]
    # A boolean-mask count yields 0 instead of value_counts()[1]'s KeyError when
    # there are no positives. Capping avoids sample()'s error when asking for more
    # rows than exist without replacement.
    n_neg = min(int(is_pos.sum() * r), len(negatives))
    return pd.concat([negatives.sample(n=n_neg, random_state=random_state), positives])

### Train the Model

In [45]:
import torch
# Allow reduced-precision internal computation for float32 matmuls ('medium' per the
# PyTorch docs), trading some accuracy for speed on supported hardware.
torch.set_float32_matmul_precision('medium')

# One message-passing component per SMILES column. Order matters: loadData attaches
# the targets and sample weights to the datapoints of the first column.
smiles_cols = ['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles', 'molecule_smiles']
# Single-component alternative (full molecule only):
#smiles_cols = ['molecule_smiles']

In [46]:
def trainModel(df, model_name, smiles_cols, r=1, devices=4, max_epochs=4):
    """Down-sample, train, evaluate and save one multicomponent MPNN.

    Parameters
    ----------
    df : pandas.DataFrame
        Training rows with the SMILES columns and a binary ``binds`` column.
    model_name : str
        Output path for the trained model, written with chemprop's ``save_model``.
    smiles_cols : list[str]
        SMILES columns, one per model component.
    r : float, default 1
        Negatives per positive kept by ``downSampler``. Negatives get sample weight
        ``1 / r`` and positives weight 1, so both classes carry similar total weight.
    devices : int, default 4
        ``devices`` passed to the Lightning ``Trainer``.
    max_epochs : int, default 4
        Number of training epochs.

    Returns
    -------
    list[dict]
        The metrics returned by ``trainer.test`` on the held-out test split.
    """
    train_loader, val_loader, test_loader = loadData(downSampler(df, r=r), smiles_cols, w=1/r)
    mcmpnn = loadModel(smiles_cols)

    trainer = pl.Trainer(
        logger=False,
        enable_checkpointing=True,
        enable_progress_bar=True,
        accelerator="auto",
        devices=devices,
        max_epochs=max_epochs,
    )
    trainer.fit(mcmpnn, train_loader, val_loader)
    results = trainer.test(mcmpnn, test_loader)

    # A single save: save_model writes the hyper-parameters + state_dict checkpoint
    # that MulticomponentMPNN.load_from_file reads back.
    save_model(model_name, mcmpnn)
    return results

In [47]:
# Train and save each ensemble member under a name encoding protein, ratio and index.
for model_idx in range(n_models):
    model_path = f"{prot}_blocks_1v{ratio}-{model_idx}.pt"
    trainModel(df, model_path, smiles_cols, r=ratio)

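### Predict on the Test Set

Average the predictions of the `n_models` saved models over the `test.parquet` rows for `prot` and write them to `pred_<prot>.csv`.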

In [None]:
test_path = './data/test.parquet'
# Using the connection as a context manager closes it even if the query raises.
# `prot` is bound as a query parameter rather than pasted into the SQL text.
with duckdb.connect() as con:
    df_test = con.execute(
        f"""
        SELECT *
        FROM parquet_scan('{test_path}')
        WHERE protein_name = ?
        """,
        [prot],
    ).df()

In [None]:
def loadTest(df_input, smiles_cols, num_workers=4):
    """Build an unshuffled dataloader (no targets) for multicomponent inference.

    Parameters
    ----------
    df_input : pandas.DataFrame
        Must contain every column in ``smiles_cols``. No target column is read.
    smiles_cols : list[str]
        SMILES columns, one per model component, in the same order used for training.
    num_workers : int, default 4
        ``num_workers`` passed to ``data.build_dataloader``.

    Returns
    -------
    The loader from ``data.build_dataloader`` with ``shuffle=False``, so batch order
    follows the row order of ``df_input``.
    """
    smiss = df_input.loc[:, smiles_cols].values
    featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

    # Inference needs neither targets nor weights, so every component, including
    # the first, is built the same way.
    test_datasets = [
        data.MoleculeDataset(
            [data.MoleculeDatapoint.from_smi(smis[i]) for smis in smiss], featurizer
        )
        for i in range(len(smiles_cols))
    ]
    test_mcdset = data.MulticomponentDataset(test_datasets)
    return data.build_dataloader(test_mcdset, num_workers=num_workers, shuffle=False)

In [None]:
# The inference loader does not depend on the model, so build (and featurize) it once.
test_loader = loadTest(df_test, smiles_cols)

preds = []
for ii in range(n_models):
    model = multi.MulticomponentMPNN.load_from_file(f"{prot}_blocks_1v{ratio}-{ii}.pt")

    with torch.inference_mode():
        trainer = pl.Trainer(
            logger=False,  # explicitly disable logging, matching the training Trainer
            enable_progress_bar=True,
            accelerator="auto",
            devices=1
        )
        test_preds = trainer.predict(model, test_loader)
    preds.append(np.concatenate(test_preds, axis=0))

# Ensemble: average the per-model predictions element-wise.
pred = np.mean(np.array(preds), axis=0)
# Presumably shape (n_rows, 1) for the single 'binds' task; flatten to a 1-D column.
df_test['binds'] = pred.ravel()
df_test[['id', 'binds']].to_csv("pred_"+prot+".csv", index=False)