In [1]:
from conf_solv.model.model import ConfSolv
from conf_solv.trainer import LitConfSolvModule
from conf_solv.dataloaders.loader import create_pairdata
from conf_solv.dataloaders.features import MolGraph
from conf_solv.dataloaders.collate import Collater

import os
from joblib import Parallel, delayed
from ase import Atoms, io
import numpy as np

import torch
from torch_geometric.data import Batch

In [2]:
def divide_solute_mols(solute_mols, n_anchor_mols=10, n_threshold_mols=50):
    """Split a list of solute conformers into batches that share the same anchors.

    The first ``n_anchor_mols`` entries of ``solute_mols`` are the anchors. The
    remaining entries are cut into consecutive chunks of at most
    ``n_threshold_mols - n_anchor_mols`` items, and the anchors are prepended to
    every chunk, so no batch holds more than ``n_threshold_mols`` entries.

    Args:
        solute_mols: list of conformers (any list works; items are not inspected).
        n_anchor_mols: number of leading entries repeated at the front of every batch.
        n_threshold_mols: maximum number of entries in one batch, anchors included.

    Returns:
        A list of batches (lists). If ``solute_mols`` has no entries beyond the
        anchors, a single batch with all of the input is returned; an empty input
        gives an empty list.

    Raises:
        ValueError: if ``n_anchor_mols`` is negative or ``n_threshold_mols`` does
            not exceed ``n_anchor_mols`` (a batch would have no room for
            non-anchor entries).
    """
    if n_anchor_mols < 0:
        raise ValueError(f"n_anchor_mols must be non-negative, got {n_anchor_mols}")
    if n_threshold_mols <= n_anchor_mols:
        raise ValueError(
            f"n_threshold_mols ({n_threshold_mols}) must be greater than "
            f"n_anchor_mols ({n_anchor_mols})"
        )

    anchor_mols = solute_mols[:n_anchor_mols]
    other_mols = solute_mols[n_anchor_mols:]

    # Nothing beyond the anchors: keep the input as one batch instead of dropping it.
    if len(other_mols) == 0:
        return [anchor_mols] if len(anchor_mols) > 0 else []

    # Number of non-anchor entries that fit in a batch next to the anchors.
    chunk_size = n_threshold_mols - n_anchor_mols
    batch_solute_mols = [
        anchor_mols + other_mols[start:start + chunk_size]
        for start in range(0, len(other_mols), chunk_size)
    ]

    return batch_solute_mols


def load_lightning_model(trained_model_path):
    """Restore one trained module from a checkpoint and return it wrapped in a list.

    Args:
        trained_model_path: path of the checkpoint handed to
            ``LitConfSolvModule.load_from_checkpoint``.

    Returns:
        A one-element list holding the restored module.
    """
    restored_module = LitConfSolvModule.load_from_checkpoint(trained_model_path)
    return [restored_module]


def load_lightning_model_parallel(parent_model_dir, n_models=1):
    """Find ``*best_model.ckpt`` files under a directory and load them in parallel.

    Args:
        parent_model_dir: directory that is walked recursively for checkpoints.
        n_models: maximum number of checkpoints to load; the first ``n_models``
            paths in sorted order are used.

    Returns:
        A flat list with the models returned by ``load_lightning_model`` for each
        selected checkpoint (empty if no checkpoint was found).
    """
    # Collect every matching checkpoint path, then sort so the selection is deterministic.
    checkpoint_paths = sorted(
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(parent_model_dir)
        for filename in filenames
        if filename.endswith("best_model.ckpt")
    )
    selected_paths = checkpoint_paths[:n_models]

    # One joblib task per checkpoint; n_jobs=-1 is joblib's "use all cores" setting.
    load_tasks = (delayed(load_lightning_model)(path) for path in selected_paths)
    per_task_models = Parallel(n_jobs=-1)(load_tasks)

    # Each task returns a list of models; merge them into a single flat list.
    loaded_models = []
    for task_models in per_task_models:
        loaded_models.extend(task_models)
    return loaded_models


# Shared Collater instance used for batch creation: it is called on a list of
# pair-data objects to build the batch that is passed to the models.
# NOTE(review): follow_batch presumably requests batch-assignment vectors for
# "x_solvent" and "x_solute" (PyG-style) -- confirm against Collater.
collater = Collater(follow_batch=["x_solvent", "x_solute"], exclude_keys=None)

In [3]:
# Load data: read the solute conformers from an .xyz file in the data directory.

# Resolved relative to the current working directory of the kernel.
data_dir = os.path.abspath('../data/')
# index=':' asks ase.io.read for all frames in the file (a list), not just one.
all_solute_mols = io.read(os.path.join(data_dir, 'Sorbitol.xyz'), index=':')

In [4]:
# Load models: search parent_model_dir recursively for "*best_model.ckpt"
# checkpoints and load at most n_models of them (first paths in sorted order).

parent_model_dir = '../exps/2022_07_28/random/'
models = load_lightning_model_parallel(parent_model_dir, n_models=4)

In [5]:
# Change solvent_smi to get predictions in a different solvent.
# 'O' is the SMILES string for water.
solvent_smi = 'O'
solvent_molgraph = MolGraph(solvent_smi)

# Maximum number of conformers evaluated in one batch (anchors included).
n_threshold_mols = 50  # change this based on available memory
# Number of leading conformers repeated at the start of every batch when the
# input has to be split.
n_anchor_mols = 5

In [6]:
# Split the input into sections of [n_anchor_mols + (n_threshold_mols - n_anchor_mols)]
# conformers. Splitting should only trigger when n_solute_mols > n_threshold_mols;
# ideally, n_threshold_mols is as large as memory allows.

# Factor applied to the printed values.
# NOTE(review): 0.238846 matches a kJ -> kcal conversion; confirm the model output units.
KJ_TO_KCAL = 0.238846

if len(all_solute_mols) > n_threshold_mols:
    batch_solute_mols = divide_solute_mols(all_solute_mols,
                                           n_anchor_mols=n_anchor_mols,
                                           n_threshold_mols=n_threshold_mols)
else:
    batch_solute_mols = [all_solute_mols]

# Inference only: put every model in eval mode (disables dropout / batch-norm
# updates, if the architecture has any) and skip autograd bookkeeping below.
for model in models:
    model.eval()

batch_outputs = []
with torch.no_grad():
    for batch_idx, solute_mols in enumerate(batch_solute_mols):
        data = create_pairdata(solvent_molgraph, solute_mols, len(solute_mols))
        batch_data = collater([data])

        # One row per model; presumably shape (n_models, n_conformers_in_batch).
        out = torch.stack([model(batch_data, len(solute_mols)) for model in models])
        if batch_idx > 0:
            # Batches after the first start with the same anchor conformers as the
            # first batch; drop those columns so each conformer appears once.
            # NOTE(review): anchors are discarded without re-aligning the batches
            # to each other -- confirm that model outputs are comparable across batches.
            out = out[:, n_anchor_mols:]
        batch_outputs.append(out)

# Concatenate once along the conformer axis instead of growing a tensor per batch.
out_final = torch.cat(batch_outputs, dim=-1)


# Scale each model's predictions relative to its lowest-energy conformer.
out_scaled = out_final - out_final.min(dim=1, keepdim=True).values
# Spread across models per conformer (NaN when only one model is loaded,
# because the default std uses the sample correction).
stds = out_scaled.std(dim=0)
# Ensemble mean, shifted again so the lowest mean prediction is exactly zero.
preds = out_scaled.mean(dim=0)
preds = preds - preds.min()

print(preds * KJ_TO_KCAL)
print(stds * KJ_TO_KCAL)

tensor([10.8977, 12.4077, 12.3912, 10.5550, 12.5749, 11.6368, 11.5799, 10.8865,
        14.1697, 14.0723, 12.2350, 11.6469, 15.4874, 13.3913, 15.0867, 12.6538,
        12.6738, 12.8915, 12.9548, 12.8177, 14.5499, 12.9141, 10.9946, 12.0461,
        13.3579, 13.1725, 12.8395, 14.9351, 13.6939, 17.2657, 14.3632, 13.0742,
        12.8118, 15.9643, 16.0562, 12.7125, 15.2607, 13.4480, 14.8114, 12.2154,
        16.5650, 15.8668, 17.2567, 14.2698, 17.1249, 15.7755, 14.3689, 17.1320,
        13.5963, 14.5115, 14.3937, 16.1479, 17.7627, 15.9031, 13.5128, 15.1298,
        15.3484, 19.0909, 15.5702, 16.9397, 14.6593, 19.0787, 16.4714, 16.3603,
        14.0503, 17.6173, 17.2691, 20.2740, 17.6452, 15.8369, 19.3415, 18.4590,
        15.6575, 21.2221, 17.1644, 12.7323, 10.4717, 11.8764, 12.1912,  0.0000,
        12.0534, 15.3450,  7.6801,  6.3522, 14.9360,  7.1314,  7.9346, 14.4323,
        13.7662, 11.5420, 11.9538, 11.5690, 11.6476,  5.7250, 12.1647, 12.6293,
        11.5747, 12.1295, 14.0970, 13.24