In [1]:
#Notebook to evaluate BASF's prediction script.
#Import statements
import argparse
from pathlib import Path
from typing import Any, List, Optional, Tuple
import os
import time
import random

import pandas as pd
#from ase import io
from ase.atoms import Atoms
from joblib import Parallel, delayed
import numpy as np
from tqdm import tqdm
import torch
from torch_geometric.data import Batch

from conf_solv.trainer import LitConfSolvModule
from conf_solv.dataloaders.loader import create_pairdata
from conf_solv.dataloaders.features import MolGraph
from conf_solv.model.model import ConfSolv

In [2]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Parameters
    ----------
    seed : int, optional
        Value handed to every seeding routine, by default 42.
    """
    # NOTE(review): changing PYTHONHASHSEED at runtime presumably only affects
    # processes started afterwards, not the running interpreter - confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [3]:
#Read the xyz file.
def read_xyz(fileobj):
    """Parse every frame of a (multi-frame) XYZ file and yield ``(Atoms, name)`` pairs.

    A frame is an atom-count line, a comment line and then that many
    ``symbol x y z`` lines; the number of atoms may differ between frames.
    The frame name is taken from the first space-separated token of the
    comment line that starts with ``name=`` or ``id=``. When there is no such
    token (or its value is empty) the zero-based frame index is used, as a
    string.

    The whole file is parsed before the first pair is yielded.

    Parameters
    ----------
    fileobj : text file object
        Open handle supporting ``readlines()``.

    Yields
    ------
    tuple
        ``(Atoms, name)`` for each frame, in file order.

    Raises
    ------
    ValueError
        If an atom-count line is not an integer, a coordinate is not a float,
        an atom line has fewer than four fields, or the file ends in the
        middle of a frame.
    """
    lines = fileobj.readlines()
    n_lines = len(lines)
    name_prefixes = ("name=", "id=")
    images = []
    # Walk the list with a cursor instead of lines.pop(0), which shifts the
    # whole list on every call and makes parsing quadratic in file length.
    cursor = 0
    while cursor < n_lines:
        # Blank lines between frames or at the end of the file are skipped
        # rather than being fed to int().
        if not lines[cursor].strip():
            cursor += 1
            continue
        natoms = int(lines[cursor])
        frame_end = cursor + 2 + natoms
        if frame_end > n_lines:
            raise ValueError(
                f"Truncated XYZ frame starting at line {cursor + 1}: "
                f"expected a comment line and {natoms} atom lines."
            )

        # The comment line is searched for a "name=..." / "id=..." token.
        comment = None
        for item in lines[cursor + 1].strip().split(' '):
            for prefix in name_prefixes:
                if item.startswith(prefix):
                    comment = item.replace(prefix, "")
                    break
            if comment:
                break
        if not comment:
            # The number of frames parsed so far is this frame's index.
            comment = str(len(images))

        symbols = []
        positions = []
        for line in lines[cursor + 2:frame_end]:
            symbol, x, y, z = line.split()[:4]
            # Normalise the case of the element symbol, e.g. "HE" -> "He".
            symbols.append(symbol.lower().capitalize())
            positions.append([float(x), float(y), float(z)])
        images.append((Atoms(symbols=symbols, positions=positions), comment))
        cursor = frame_end
    yield from images


In [4]:
def divide_solute_mols(solute_mols, n_anchor_mols=10, n_threshold_mols=50):
    """Split ``solute_mols`` into batches that all start with the same anchor molecules.

    The first ``n_anchor_mols`` entries are the anchors. The remaining entries
    are cut into consecutive chunks of ``n_threshold_mols - n_anchor_mols``
    items, and every returned batch is ``anchors + chunk``, so no batch holds
    more than ``n_threshold_mols`` molecules.

    Parameters
    ----------
    solute_mols : list
        Molecules to batch; sliced and concatenated as a list.
    n_anchor_mols : int, optional
        Number of leading molecules repeated in every batch, by default 10.
    n_threshold_mols : int, optional
        Maximum batch size including the anchors, by default 50.

    Returns
    -------
    list of list
        The batches. If there is nothing beyond the anchors, the anchors form
        one batch so that no molecule is dropped; empty input gives ``[]``.

    Raises
    ------
    ValueError
        If ``n_threshold_mols`` is not larger than ``n_anchor_mols``, because
        no room would be left for non-anchor molecules.
    """
    chunk_size = n_threshold_mols - n_anchor_mols
    if chunk_size <= 0:
        raise ValueError(
            f"n_threshold_mols ({n_threshold_mols}) must be larger than "
            f"n_anchor_mols ({n_anchor_mols})."
        )
    anchor_mols = solute_mols[:n_anchor_mols]
    other_mols = solute_mols[n_anchor_mols:]
    if len(other_mols) == 0:
        return [anchor_mols] if len(anchor_mols) > 0 else []
    return [
        anchor_mols + other_mols[start:start + chunk_size]
        for start in range(0, len(other_mols), chunk_size)
    ]

In [5]:
def ase_atoms_from_xyz(xyz_path):
    """Open the XYZ file at ``xyz_path`` and return everything ``read_xyz`` yields as a list."""
    with open(xyz_path, 'r') as handle:
        # Consume the generator while the file is still open.
        return [frame for frame in read_xyz(handle)]

In [6]:
def flatten(xss):
    """Concatenate the items of every iterable in ``xss`` into one new list (one level deep)."""
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat

In [7]:
def load_lightning_model(trained_model_dir,i):
    """Load ``<trained_model_dir>/ensemble_<i>/best_model.ckpt`` and return it in a one-element list."""
    checkpoint_path = os.path.join(trained_model_dir, f'ensemble_{i}', 'best_model.ckpt')
    module = LitConfSolvModule.load_from_checkpoint(checkpoint_path)
    return [module]

def load_lightning_model_parallel(trained_model_dir,ensemble_nos):
    """Load several ensemble members concurrently, one joblib worker per member.

    Parameters
    ----------
    trained_model_dir : str
        Directory containing the ``ensemble_<i>`` sub-directories.
    ensemble_nos : iterable of int
        Indices of the ensemble members to load.

    Returns
    -------
    list
        The loaded modules in the order of ``ensemble_nos``; ``[]`` when
        ``ensemble_nos`` is empty.
    """
    # Materialise once so generators are accepted and len() is always valid.
    ensemble_ids = list(ensemble_nos)
    if not ensemble_ids:
        # joblib rejects n_jobs=0, so there is nothing to dispatch here.
        return []
    per_member = Parallel(n_jobs=len(ensemble_ids))(
        [delayed(load_lightning_model)(trained_model_dir, i) for i in ensemble_ids]
    )
    # Each worker returns a one-element list; merge them into a flat list.
    return flatten(per_member)

In [8]:
class ConfsolvPrediction:
    """Ensemble inference for a set of solute conformers across a table of solvents.

    Intended call order: ``load_models`` -> ``load_solutes`` ->
    ``load_solvents`` -> ``predict_all_solutes_all_solvents``.
    """

    def __init__(self, n_threshold_mols: int = 50, n_anchor_mols: int = 0, num_cores: int = -1, silent=False) -> None:
        """
        Set up batching options, pick the device and configure torch threads.

        Parameters
        ----------
        n_threshold_mols : int, optional
            Solute lists longer than this are split into batches, by default 50
        n_anchor_mols : int, optional
            number of leading solutes repeated in every batch, by default 0
        num_cores : int, optional
            passed to ``torch.set_num_threads`` unless it is -1, by default -1
        silent : bool, optional
            if True, the solvent loop runs without a tqdm progress bar
        """
        self.n_threshold_mols: int = n_threshold_mols
        self.n_anchor_mols: int = n_anchor_mols
        self.models: Optional[List[ConfSolv]] = None
        self.model_parameters: Optional[Any] = None
        self.solute_mols: Optional[List[Atoms]] = None
        self.solute_names: Optional[List[str]] = None
        self.solvent_df: Optional[pd.DataFrame] = None
        self.trained_model_dir = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f'Device being used is {self.device}')
        self.silent = silent
        print(f'Number of cores requested is {num_cores}')
        if num_cores != -1:
            torch.set_num_threads(num_cores)
        threads = torch.get_num_threads()
        print(f'Number of threads used by torch is {threads}')

    @staticmethod
    def _find_ensemble_indices(model_dir: Path) -> List[int]:
        """Return the sorted ``<n>`` of every ``ensemble_<n>`` sub-directory of ``model_dir``."""
        prefix = 'ensemble_'
        indices = []
        for entry in model_dir.iterdir():
            if not entry.is_dir() or not entry.name.startswith(prefix):
                continue
            suffix = entry.name[len(prefix):]
            # Only canonical decimal suffixes are accepted, so that
            # f'ensemble_{n}' reproduces the directory name exactly.
            if suffix.isdecimal() and str(int(suffix)) == suffix:
                indices.append(int(suffix))
        return sorted(indices)

    def _require_loaded(self) -> None:
        """Raise ``RuntimeError`` unless models, solutes and solvents have all been loaded."""
        missing = []
        if not self.models:
            missing.append('models (call load_models)')
        if not self.solute_mols:
            missing.append('solutes (call load_solutes)')
        if self.solvent_df is None:
            missing.append('solvents (call load_solvents)')
        if missing:
            raise RuntimeError('Cannot predict; not loaded: ' + ', '.join(missing))

    def load_models(self, trained_model_dir: str = 'confsolv_models/nonionic_solvents_scaffold') -> None:
        """
        Load pretrained models.

        Only sub-directories named ``ensemble_<n>`` are treated as ensemble
        members, so unrelated folders (e.g. ``.ipynb_checkpoints``) are ignored.

        Parameters
        ----------
        trained_model_dir : str, optional
            path to the trained model, by default 'confsolv_models/nonionic_solvents_scaffold'
            use 'confsolv_models/ionic_solvents_scaffold' for ionic solvents

        Raises
        ------
        FileNotFoundError
            if the directory contains no ``ensemble_<n>`` sub-directory
        """
        self.trained_model_dir = Path(trained_model_dir)
        ensemble_nos = self._find_ensemble_indices(self.trained_model_dir)
        if not ensemble_nos:
            raise FileNotFoundError(
                f"No 'ensemble_<n>' directories found in {self.trained_model_dir}"
            )
        print(f"Loading model {self.trained_model_dir.name} with {len(ensemble_nos)} ensembles...")
        self.models = load_lightning_model_parallel(trained_model_dir, ensemble_nos)

    def load_solutes(self, xyz_file: str) -> None:
        """
        Load solutes from xyz file.

        Parameters
        ----------
        xyz_file : str
            path to the xyz file containing all molecules to be evaluated

        """
        solute_mols_and_names = ase_atoms_from_xyz(xyz_file)
        self.solute_mols = [item[0] for item in solute_mols_and_names]
        self.solute_names = [item[1] for item in solute_mols_and_names]

    def load_solvents(self, solvent_file: str) -> None:
        """
        Load solvents as pandas dataframe from csv file. Headers must be SOLVENT_NAME and SMILES.

        Parameters
        ----------
        solvent_file : str
            path to solvent csv file (Headers must be SOLVENT_NAME and SMILES).

        Raises
        ------
        ValueError
            if either required column is missing
        """
        solvent_df = pd.read_csv(solvent_file)
        if "SOLVENT_NAME" not in solvent_df.columns or "SMILES" not in solvent_df.columns:
            raise ValueError("Solvent file does not have necessary columns (SOLVENT_NAME, SMILES).")
        self.solvent_df = solvent_df[['SOLVENT_NAME', 'SMILES']]

    def predict_all_solutes_all_solvents(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Predict every loaded solute in every loaded solvent.

        Returns
        -------
        Tuple[pd.DataFrame, pd.DataFrame]
            ``(result_df, std_df)``: one column per solvent name holding the
            ensemble mean and the ensemble standard deviation respectively.
            A ``CONF_NAME`` column with the solute names is added only when
            ``n_anchor_mols == 0``.

        Raises
        ------
        RuntimeError
            if models, solutes or solvents have not been loaded
        """
        self._require_loaded()
        print("Running prediction...")
        result_df = pd.DataFrame()
        std_df = pd.DataFrame()
        conf_name_col = 'CONF_NAME'
        if self.n_anchor_mols == 0:
            result_df[conf_name_col] = self.solute_names
            std_df[conf_name_col] = self.solute_names
        solvent_data_iterator = self.solvent_df.index if self.silent else tqdm(self.solvent_df.index)
        for index in solvent_data_iterator:
            solvent_smi = self.solvent_df.at[index, 'SMILES']
            solvent_name = self.solvent_df.at[index, 'SOLVENT_NAME']
            preds, stds = self._predict_all_solutes_single_solvent(solvent_smi=solvent_smi)
            result_df[solvent_name] = preds
            std_df[solvent_name] = stds
        print("Prediction completed.")
        return result_df, std_df

    def _predict_all_solutes_single_solvent(self, solvent_smi: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run the whole ensemble on all solutes for one solvent.

        Parameters
        ----------
        solvent_smi : str
            SMILES of the solvent

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            ``(preds, stds)``: ensemble mean (shifted so its minimum is 0) and
            ensemble standard deviation of the per-model shifted outputs.

        Notes
        -----
        NOTE(review): when batching is triggered with ``n_anchor_mols > 0`` the
        anchors are part of every batch, so their outputs appear once per batch
        in the concatenated result - confirm this is the intended layout.
        """
        self._require_loaded()
        solvent_molgraph = MolGraph(solvent_smi)
        # Batching only triggers when there are more solutes than
        # n_threshold_mols; each batch is [anchors + next chunk of the rest].
        if len(self.solute_mols) > self.n_threshold_mols:
            batch_solute_mols = divide_solute_mols(self.solute_mols,
                                                   n_anchor_mols=self.n_anchor_mols,
                                                   n_threshold_mols=self.n_threshold_mols)
        else:
            batch_solute_mols = [self.solute_mols]
        # The atom count of the first solute is used for every solute.
        # NOTE(review): assumes all solutes have the same number of atoms
        # (conformers of one molecule) - confirm.
        n_atoms = batch_solute_mols[0][0].get_global_number_of_atoms()

        # Model placement, eval mode and the diagnostic prints do not depend on
        # the batch, so they are done once per call rather than once per batch.
        use_cuda = self.device == 'cuda'
        for model in self.models:
            if use_cuda:
                model.cuda()
            model.eval()
            print(f"Envelope_a {model.model.solute_model.rbf.envelope.a}")
            print(f"Envelope_b {model.model.solute_model.rbf.envelope.b}")
            print(f"Envelope_c {model.model.solute_model.rbf.envelope.c}")
            print(f"Envelope_p {model.model.solute_model.rbf.envelope.p}")

        batch_outputs = []
        for solute_mols in batch_solute_mols:
            n_confs = len(solute_mols)
            data = create_pairdata(solvent_molgraph, solute_mols, n_confs)
            # Conformer index of every atom: [0]*n_atoms + [1]*n_atoms + ...
            data.solute_confs_batch = torch.arange(n_confs, dtype=torch.int64).repeat_interleave(n_atoms)
            batch_data = Batch.from_data_list([data], follow_batch=['x_solvent', 'x_solute'])
            if use_cuda:
                batch_data = batch_data.cuda()
            with torch.no_grad():
                out = torch.stack([model(batch_data, n_confs) for model in self.models])
            batch_outputs.append(out.cpu())
        # Join the per-batch outputs along the last dimension in one call.
        out_final = torch.cat(batch_outputs, dim=-1)

        # Shift each row by its minimum along dim 1.
        # NOTE(review): presumably dim 0 is the ensemble member and dim 1 the
        # conformer, making this "relative to the lowest-energy conformer" - confirm.
        out_scaled = out_final - out_final.min(dim=1, keepdim=True).values
        stds = out_scaled.std(dim=0)
        preds = out_scaled.mean(dim=0)
        preds = preds - preds.min()
        return preds.cpu().numpy(), stds.cpu().numpy()

## Run Inference

In [9]:
#We'll define an inference function to help us.
def inference(xyz_file, solvent_file, out_path, out_path_uncertainty, n_threshold_mols=1000, n_anchor_mols=0, num_cores=-1, silent=False, model_path='../sample_trained_models/'):
    """Run the full prediction pipeline and write the results to two CSV files.

    Loads the models, the solute conformers and the solvent table, predicts
    every solute in every solvent, and reports the elapsed time of each stage.

    Parameters
    ----------
    xyz_file : str
        Path to the XYZ file with the solutes.
    solvent_file : str
        Path to the solvent CSV file (columns SOLVENT_NAME and SMILES).
    out_path : str
        Destination CSV for the predictions.
    out_path_uncertainty : str
        Destination CSV for the ensemble standard deviations.
    n_threshold_mols, n_anchor_mols, num_cores, silent
        Forwarded unchanged to ``ConfsolvPrediction``.
    model_path : str, optional
        Directory holding the trained ensemble, by default
        '../sample_trained_models/'.
    """
    prediction = ConfsolvPrediction(n_threshold_mols=n_threshold_mols, n_anchor_mols=n_anchor_mols, num_cores=num_cores, silent=silent)

    def _timed(task, func, *args):
        # Run one pipeline stage and print how long it took.
        start = time.perf_counter()
        result = func(*args)
        print(f'Elapsed time for task {task} is {time.perf_counter() - start:.2f} seconds')
        return result

    _timed('load models', prediction.load_models, model_path)
    _timed('load conformers', prediction.load_solutes, xyz_file)
    _timed('load solvents', prediction.load_solvents, solvent_file)
    result_df, std_df = _timed('make predictions', prediction.predict_all_solutes_all_solvents)

    result_df.to_csv(out_path)
    std_df.to_csv(out_path_uncertainty)

In [10]:
# Working directory of the kernel; input and output paths are built from it.
maindir = str(Path.cwd())
maindir

'/data1/groups/RMG/Projects/BASF/conf_solv/BASF_predict'

In [11]:
# Input and output files, all located in the working directory.
xyzfile, solventfile, outfile, uncertfile = (
    os.path.join(maindir, file_name)
    for file_name in ('example.xyz', 'solvents_example.act', 'outpreds.csv', 'outstd.csv')
)
# Passed to inference() as n_threshold_mols.
thresholdmols = 100



In [None]:
# Seed all RNGs, then run the full pipeline on the example inputs.
seed_everything(seed=10608)
inference(
    xyz_file=xyzfile,
    solvent_file=solventfile,
    out_path=outfile,
    out_path_uncertainty=uncertfile,
    n_threshold_mols=thresholdmols,
    num_cores=20,
)
