In [None]:
# Goal: a function whose input is the reactant SMILES
# and whose outputs are the predicted SMILES, the mol_graph and the bond-difference (diff) graph.


## Tool - NERF

In [None]:
# Convert the input SMILES into a form the model accepts: a DataLoader with batch size 1.
# Also return the source (reactant) element, mask, bond, aroma and charge features.
import os

import torch
from torch.utils.data import DataLoader
from rdkit import Chem
from langchain.tools import BaseTool
from transformers import AutoTokenizer, T5ForConditionalGeneration

from dataset import TransformerDataset
from preprocess import molecule
from utils import result2mol
from model import *

class NERF(BaseTool):
    """LangChain tool wrapping the NERF reaction-outcome model.

    Given reactant SMILES, it predicts the product SMILES and describes which
    bonds differ between the predicted product and the reactants.
    """
    name = 'NERF'
    description = ('Predict the product of a chemical reaction, input SMILES, output the change of bonds and the SMILES of predicted products')

    def map_atoms_in_smiles(self, smiles):
        """Return canonical `smiles` with every atom's map number set to its index + 1.

        Returns None when RDKit cannot parse `smiles`.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(atom.GetIdx() + 1)

        return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)

    def process_smiles(self, smiles, map_atoms=False):
        """Featurise reactant SMILES into a single-sample DataLoader.

        Args:
            smiles: reactant SMILES; '.'-separated fragments are featurised together.
            map_atoms: if True, atom-map the SMILES first (see map_atoms_in_smiles).

        Returns:
            (data_loader, element, mask, bond, aroma, charge): the DataLoader holds one
            sample; the other five are the raw reactant features from `molecule`.

        Raises:
            ValueError: if `smiles` cannot be parsed by RDKit.
        """
        if map_atoms:
            mapped = self.map_atoms_in_smiles(smiles)
            if mapped is None:
                raise ValueError(f"Invalid SMILES: {smiles}")
            smiles = mapped

        full_mol = Chem.MolFromSmiles(smiles)
        if full_mol is None:
            raise ValueError(f"Invalid SMILES: {smiles}")
        reactant_mols = [Chem.MolFromSmiles(item) for item in smiles.split(".")]
        reactant_len = full_mol.GetNumAtoms()

        reactant_features = molecule(reactant_mols, reactant_len)

        # The dataset expects reactant features under 'src_*' keys, except
        # 'element' and 'reactant', which keep their names.
        input_data = {
            (key if key in ("element", "reactant") else 'src_' + key): reactant_features[key]
            for key in reactant_features
        }

        full_dataset = TransformerDataset(False, [input_data])
        # Only one sample: worker processes would add start-up cost and nothing else.
        data_loader = DataLoader(full_dataset,
                                 batch_size=1,
                                 num_workers=0, collate_fn=TransformerDataset.collate_fn)

        return (data_loader,
                reactant_features['element'],
                reactant_features['mask'],
                reactant_features['bond'],
                reactant_features['aroma'],
                reactant_features['charge'])

    def init_model(self, save_path, checkpoint, load_weights=False):
        """Build the MoleculeVAE on GPU 0 and read the checkpoint file.

        Args:
            save_path: directory containing checkpoint files.
            checkpoint: checkpoint filename inside `save_path`.
            load_weights: when True, apply the checkpoint's 'model_state_dict'.
                Defaults to False to preserve the previous behaviour, where loading
                was commented out.

        NOTE(review): with load_weights=False the model keeps its freshly
        initialised weights, so predictions come from an untrained network —
        confirm the intended default.
        """
        # Tensors saved on cuda:0 are loaded back onto cuda:0.
        saved = torch.load(os.path.join(save_path, checkpoint),
                           map_location={'cuda:0': 'cuda:0'})

        # TODO: constructor hyper-parameters are hard-coded; confirm they match the checkpoint.
        model = MoleculeVAE(None, 100, 192, 6).to(0)
        if load_weights:
            model.load_state_dict(saved['model_state_dict'])

        return model

    def predict(self, data_loader,
                save_path='./CKPT/no_reactant_mask/', checkpoint="epoch-2-loss-4.190030106406974", temperature=1,
                load_weights=False):
        """Run the model on the first batch of `data_loader`.

        Returns:
            (mol, smile, valid, mol_graph, pred_bond, pred_aroma, pred_charge): the first
            four are `result2mol`'s outputs for sample 0; the last three are the
            model's predicted tensors moved to CPU.

        Raises:
            ValueError: if `data_loader` yields no batches.
        """
        model = self.init_model(save_path, checkpoint, load_weights)
        model.eval()

        with torch.no_grad():  # inference only; no autograd graph is needed
            for data in data_loader:  # process_smiles builds exactly one batch
                data_gpu = {key: data[key].to(0) for key in data}

                predicted_dict = model('sample', data_gpu, temperature)

                element = data['element']
                src_mask = data['src_mask']
                pred_bond = predicted_dict['bond'].cpu()
                pred_aroma = predicted_dict['aroma'].cpu()
                pred_charge = predicted_dict['charge'].cpu()

                # batch_size == 1, so only sample 0 is decoded.
                decoded = result2mol((element[0], src_mask[0], pred_bond[0],
                                      pred_aroma[0], pred_charge[0], None))
                mol, smile, valid, mol_graph = decoded[0], decoded[1], decoded[2], decoded[3]

                return mol, smile, valid, mol_graph, pred_bond, pred_aroma, pred_charge

        raise ValueError("data_loader yielded no batches")

    def pred_from_smiles(self, smiles):
        """Predict the product of `smiles` and the bond-change matrix.

        Returns:
            (pred_smile, element, diff_adj): diff_adj is the predicted adjacency
            matrix minus the reactant adjacency matrix, so +1 marks a bond present
            only in the prediction and -1 a bond present only in the reactants.
        """
        dl, element, src_mask, src_bond, src_aroma, src_charge = self.process_smiles(smiles, False)

        # Decode the reactants through result2mol as well, so both adjacency
        # matrices come from the same decoding path.
        src_mol = result2mol((element, src_mask, src_bond, src_aroma, src_charge, None))[0]
        src_mol_adj = Chem.GetAdjacencyMatrix(src_mol)

        pred_mol, pred_smile = self.predict(dl)[:2]
        pred_mol_adj = Chem.GetAdjacencyMatrix(pred_mol)

        # NOTE(review): assumes both molecules keep the same atoms in the same
        # order, so the matrices have equal shape — confirm for result2mol.
        diff_adj = pred_mol_adj - src_mol_adj

        return pred_smile, element, diff_adj

    def get_reaction_info(self, elements, diff):
        """Describe each bond change in `diff` as an English sentence.

        Args:
            elements: per-atom element codes used to index `element_symbols`.
            diff: square bond-change matrix (+1 formed, -1 broken); only the upper
                triangle is read.

        Returns:
            list of sentences, one per changed bond; atom numbers are 1-based.
        """
        # Element symbols in periodic-table order, starting with H at index 0.
        # NOTE(review): if `elements` holds atomic numbers (H == 1) rather than
        # 0-based indices, this lookup is off by one — confirm against preprocess.molecule.
        element_symbols = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"]

        def describe(kind, i, j):
            sym_i = element_symbols[int(elements[i])]
            sym_j = element_symbols[int(elements[j])]
            return (f"{kind} of bond between atom {i+1} ({sym_i} element) "
                    f"and atom {j+1} ({sym_j} element)")

        reaction_info = []
        n_atoms = len(diff)
        for i in range(n_atoms):
            for j in range(i + 1, n_atoms):
                if diff[i][j] == 1:
                    reaction_info.append(describe("Formation", i, j))
                elif diff[i][j] == -1:
                    reaction_info.append(describe("Breaking", i, j))

        return reaction_info

    def _run(self, reactants_smile: str) -> str:
        """Return the predicted product SMILES followed by one line per bond change."""
        if Chem.MolFromSmiles(reactants_smile) is None:
            return "Wrong argument. Please input a valid reactant SMILES."

        pred_smile, src_element, diff_ = self.pred_from_smiles(reactants_smile)
        explain = self.get_reaction_info(src_element, diff_)
        # get_reaction_info returns a list of sentences; it must be joined into one
        # string before being combined with the product SMILES.
        explain_text = '\n'.join(explain) or 'No bond changes predicted.'
        return '\n\n'.join([pred_smile, explain_text])

    async def _arun(self, reactants_smile: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()

## Other Tools

In [None]:
from rdkit import Chem
from langchain.agents import tool

@tool
def check_no_extra_element_or_invalid(reactants_reagents_products_smiles: str):
    """This function should be used finally to check answer.
    You need to input one string in which the SMILES of the reactants+reagents and the SMILES of the possible product (answer) are concatenated using '.'. The molecules are separated by '.', and the last molecule is treated as the product.\n
    This function will return true or false depending on whether the product contains extra elements that are not in the reactants and reagents, or whether the product is an invalid SMILES. \
    If the product contains extra elements, or the product is an invalid SMILES, return true, otherwise return false. True means the product is invalid and you need to generate another product. \
    Overall, for the return value of this function, false is good, true is bad. \
    Use this function to check before outputting the final result.\n

    For example, if you predict that the possible product of reactants+reagents C/C=C/C=O.Nc1ccccc1 is CC,
    you need to input C/C=C/C=O.Nc1ccccc1.CC to this function, to check whether CC is a valid answer to the reactants+reagents C/C=C/C=O.Nc1ccccc1. If the return value is false, good; if the return value is true, bad.
    """
    # (Docstring above is the tool description shown to the LLM.)
    # Returns a message string when the input has no '.', True for a bad product
    # (unparseable/empty, or containing elements absent from the reactants/reagents),
    # otherwise False.

    if "." not in reactants_reagents_products_smiles:
        return "Invalid input. You need to input a list of SMILES in the format SMILE of reactants+reagents.SMILE of products concated using '.'. The molecules are splitted by ."

    # Everything before the last '.' is reactants/reagents; the last fragment is the product.
    *reactants_reagents, product = reactants_reagents_products_smiles.split(".")

    reactants_elements = set()
    for r in reactants_reagents:
        try:
            mol = Chem.MolFromSmiles(r)
        except Exception:
            return True
        # Unparseable reactant/reagent fragments are skipped (best effort).
        if mol is not None:
            reactants_elements.update(atom.GetSymbol() for atom in mol.GetAtoms())

    # Parse the product as one SMILES string. An empty product (input ending in '.')
    # or an unparseable one is reported as invalid.
    product = product.strip()
    product_mol = Chem.MolFromSmiles(product) if product else None
    if product_mol is None:
        return True

    products_elements = {atom.GetSymbol() for atom in product_mol.GetAtoms()}

    # True if the product has any element not present in the reactants/reagents.
    return bool(products_elements - reactants_elements)



In [None]:
class FuncGroups(BaseTool):
    """LangChain tool that lists the functional groups matched in a molecule via SMARTS."""
    name = "FunctionalGroups"
    description = "Input SMILES, return list of functional groups in the molecule. Analysis the observation of the functional groups to predict the reaction type."
    dict_fgs: dict = None

    def __init__(
        self,
    ):
        super().__init__()

        # Display name -> SMARTS pattern.
        # List obtained from https://github.com/rdkit/rdkit/blob/master/Data/FunctionalGroups.txt
        # Keys must be unique: a repeated key silently overwrites the earlier pattern
        # (previously the generic "=O" pattern replaced the ketone pattern).
        self.dict_fgs = {
            "furan": "o1cccc1",
            "aldehydes": " [CX3H1](=O)[#6]",
            "esters": " [#6][CX3](=O)[OX2H0][#6]",
            "ketones": " [#6][CX3](=O)[#6]",
            "amides": " C(=O)-N",
            "thiol groups": " [SH]",
            "alcohol groups": " [OH]",
            "methylamide": "*-[N;D2]-[C;D3](=O)-[C;D1;H3]",
            "carboxylic acids": "*-C(=O)[O;D1]",
            "carbonyl methylester": "*-C(=O)[O;D2]-[C;D1;H3]",
            "terminal aldehyde": "*-C(=O)-[C;D1]",
            "amide": "*-C(=O)-[N;D1]",
            "carbonyl methyl": "*-C(=O)-[C;D1;H3]",
            "isocyanate": "*-[N;D2]=[C;D2]=[O;D1]",
            "isothiocyanate": "*-[N;D2]=[C;D2]=[S;D1]",
            "nitro": "*-[N;D3](=[O;D1])[O;D1]",
            "nitroso": "*-[N;R0]=[O;D1]",
            "oximes": "*=[N;R0]-[O;D1]",
            "Imines": "*-[N;R0]=[C;D1;H2]",
            "terminal azo": "*-[N;D2]=[N;D2]-[C;D1;H3]",
            "hydrazines": "*-[N;D2]=[N;D1]",
            "diazo": "*-[N;D2]#[N;D1]",
            "cyano": "*-[C;D2]#[N;D1]",
            "primary sulfonamide": "*-[S;D4](=[O;D1])(=[O;D1])-[N;D1]",
            "methyl sulfonamide": "*-[N;D2]-[S;D4](=[O;D1])(=[O;D1])-[C;D1;H3]",
            "sulfonic acid": "*-[S;D4](=O)(=O)-[O;D1]",
            "methyl ester sulfonyl": "*-[S;D4](=O)(=O)-[O;D2]-[C;D1;H3]",
            "methyl sulfonyl": "*-[S;D4](=O)(=O)-[C;D1;H3]",
            "sulfonyl chloride": "*-[S;D4](=O)(=O)-[Cl]",
            "methyl sulfinyl": "*-[S;D3](=O)-[C;D1]",
            "methyl thio": "*-[S;D2]-[C;D1;H3]",
            "thiols": "*-[S;D1]",
            "thio carbonyls": "*=[S;D1]",
            "halogens": "*-[#9,#17,#35,#53]",
            "t-butyl": "*-[C;D4]([C;D1])([C;D1])-[C;D1]",
            "tri fluoromethyl": "*-[C;D4](F)(F)F",
            "acetylenes": "*-[C;D2]#[C;D1;H]",
            "cyclopropyl": "*-[C;D3]1-[C;D2]-[C;D2]1",
            "ethoxy": "*-[O;D2]-[C;D2]-[C;D1;H3]",
            "methoxy": "*-[O;D2]-[C;D1;H3]",
            "side-chain hydroxyls": "*-[O;D1]",
            "oxo groups": "*=[O;D1]",
            "primary amines": "*-[N;D1]",
            "nitriles": "*#[N;D1]",
        }

    def _is_fg_in_mol(self, mol, fg):
        """Return True if the SMARTS pattern `fg` matches `mol`.

        `mol` may be a SMILES string (parsed here) or an already-parsed RDKit Mol,
        so callers can parse a molecule once and test many patterns.
        """
        # Some patterns carry a leading space; strip it before parsing.
        fgmol = Chem.MolFromSmarts(fg.strip())
        if isinstance(mol, str):
            mol = Chem.MolFromSmiles(mol.strip())
        return mol.HasSubstructMatch(fgmol)

    def _run(self, smiles: str) -> str:
        """
        Input a molecule SMILES.
        Returns a sentence listing the functional groups identified by their common name (in natural language), which is helpful to predict reaction type.
        """
        mol = Chem.MolFromSmiles(smiles.strip()) if isinstance(smiles, str) else None
        if mol is None:
            return "Wrong argument. Please input a valid molecular SMILES."

        # Parse the molecule once and test every pattern against the same Mol.
        fgs_in_molec = [
            name
            for name, fg in self.dict_fgs.items()
            if self._is_fg_in_mol(mol, fg)
        ]

        if not fgs_in_molec:
            return "This molecule contains no recognized functional groups."
        if len(fgs_in_molec) == 1:
            return f"This molecule contains {fgs_in_molec[0]}."
        return f"This molecule contains {', '.join(fgs_in_molec[:-1])}, and {fgs_in_molec[-1]}."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()

In [None]:
from langchain.tools import BaseTool
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the ReactionT5 product-prediction model used by the ReactionT5 tool below.
# Hugging Face Hub id of the ReactionT5 forward (product) prediction checkpoint.
REACTION_T5_CHECKPOINT = 'sagawa/ReactionT5-product-prediction'
# Loaded once here and read as module-level globals by ReactionT5._run.
tokenizer = AutoTokenizer.from_pretrained(REACTION_T5_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(REACTION_T5_CHECKPOINT)

class ReactionT5(BaseTool):
    """LangChain tool that predicts a reaction product with ReactionT5.

    Uses the module-level `tokenizer` and `model` globals defined in this cell.
    """
    name = 'reaction_t5'
    description = ('Predict the product of a chemical reaction')

    def __init__(self, model_name="t5-large", verbose=False):
        # `model_name` is accepted for backward compatibility but not used:
        # the checkpoint is fixed by the module-level `tokenizer`/`model`.
        super().__init__(verbose=verbose)

    def _run(self, input_smiles: str) -> str:
        """Return the predicted product SMILES for reactant SMILES `input_smiles`."""
        # ReactionT5 prompt format; the reagent slot is left empty.
        inp = tokenizer(f'REACTANT:{input_smiles}REAGENT:', return_tensors='pt')
        # Single beam, single sequence. Only the token ids are used, so per-step
        # scores are not collected (generate then returns the sequences tensor).
        sequences = model.generate(**inp, min_length=6, max_length=109, num_beams=1, num_return_sequences=1)
        # The decoded text may contain spaces and a trailing '.'; strip both.
        return tokenizer.decode(sequences[0], skip_special_tokens=True).replace(' ', '').rstrip('.')

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()

# Agent

In [None]:
from langchain.tools import WikipediaQueryRun
from langchain.chat_models import ChatOpenAI
# Imports

import openai
import os
from langchain.llms import OpenAI
from langchain.agents import load_tools
from langchain.agents import initialize_agent
import getpass

# Debugging aid: make CUDA kernel launches synchronous so errors point at the
# failing call. It only takes effect as an environment variable set before CUDA
# is initialised (a plain Python variable does nothing). Remove for faster runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

torch.backends.cudnn.enabled = True  # previously misspelled `enable`, which set an unused attribute
torch.backends.cudnn.benchmark = True

# Read API keys from the environment (e.g. exported in the shell or loaded from a
# .env file with python-dotenv), prompting only if they are missing. Never
# hardcode keys in the notebook. NOTE: keys previously committed here must be revoked.
for key_name in ("OPENAI_API_KEY", "SERPER_API_KEY"):
    if not os.environ.get(key_name):
        os.environ[key_name] = getpass.getpass(f"{key_name}: ").strip()

# NOTE(review): OpenAI retired text-davinci-003 in January 2024; the documented
# completions replacement is "gpt-3.5-turbo-instruct".
LLM_MODEL_NAME = "text-davinci-003"
llm = OpenAI(model_name=LLM_MODEL_NAME, temperature=0.2)

tools = load_tools(['arxiv'], llm=llm)
agent = initialize_agent(
    tools + [check_no_extra_element_or_invalid, ReactionT5(), FuncGroups(), NERF()],
    llm,
    agent="zero-shot-react-description",
    verbose=True,
    handle_parsing_errors=True,
)

In [None]:
# Implicit string concatenation keeps the prompt free of the indentation that a
# backslash continuation inside the literal would embed.
QUERY = (
    "What are the products of the chemical reaction with reactants: "
    "COC(=O)c1cc(CCCc2cc3c(=O)[nH]c(N)nc3[nH]2)cs1? "
    "Explain the possible reaction mechanisms."
)
agent.run(QUERY)