# Testing Local Llama3-8B Agent with LangChain for Molecule Generation and Neighbor Calculation

This notebook aims to be a very simple implementation of Recursion's LOWE. It uses LangChain to empower an LLM agent to use a molecular generation tool and a neighborhood calculation tool.

- The molecular generation tool uses the trained Phi-SAFE model.
- The neighborhood calculator ranks the molecules of a public SMILES dataset (https://www.kaggle.com/datasets/yanmaksi/big-molecules-smiles-dataset) by their Tanimoto similarity to a query molecule.

The notebook tests two models as agents:

1. Llama3-8B-Q4

2. Cohere's Command-R

Of course, Command-R is much more precise and accurate in its tool usage.

---


## Llama3-8B

In [252]:
from langchain_community.llms import LlamaCpp

# Number of layers to offload to the GPU; -1 asks llama.cpp to offload all of
# them (with Metal, even a value of 1 is enough to enable GPU use).
n_gpu_layers = -1
# Prompt-processing batch size: keep it between 1 and n_ctx and size it to the
# RAM of your Apple Silicon chip.
n_batch = 512

# Make sure the model path is correct for your system!
llm = LlamaCpp(
    model_path="./Llama-3-8B-Instruct-64k.Q4_K_M.gguf",
    # Hardware / memory settings.
    n_gpu_layers=n_gpu_layers,
    n_batch=n_batch,
    n_ctx=2048,
    f16_kv=True,  # MUST set to True, otherwise you will run into problem after a couple of calls
    # Decoding settings: short completions at temperature 0.0 so tool-call JSON
    # is as reproducible as possible, stopping at Llama 3's end-of-turn token.
    temperature=0.0,
    max_tokens=128,
    stop=["<|eot_id|>"],
    verbose=False,
)

In [253]:
from langchain_core.tools import tool
from transformers import AutoModelForCausalLM
import safe as sf
from safe.tokenizer import SAFETokenizer
import pandas as pd
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from langchain.tools.render import render_text_description
from langchain_core.output_parsers import JsonOutputParser
import os
import json
from operator import itemgetter
from langchain_core.messages import AIMessage
from langchain_core.runnables import (
    Runnable,
)
# Disable Hugging Face `tokenizers` parallelism (commonly done to avoid its
# fork-related warnings when tokenizers are used inside notebooks/agents).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Location of the fine-tuned SAFE-Phi checkpoint used by `molecular_generation`.
SAFE_CHECKPOINT_PATH = ".saved_model/phi1_5-safmol_0528/checkpoint-29600"
# SAFEDesign objects keyed by checkpoint path, so the model and tokenizer are
# read from disk once per session instead of on every tool call.
_SAFE_DESIGNERS = {}


def _get_safe_designer(checkpoint_path=SAFE_CHECKPOINT_PATH):
    """Return a SAFEDesign for `checkpoint_path`, loading model and tokenizer on first use."""
    designer = _SAFE_DESIGNERS.get(checkpoint_path)
    if designer is None:
        model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
        model.eval()  # inference only
        tokenizer = SAFETokenizer().load(checkpoint_path + "/tokenizer.json")
        designer = sf.SAFEDesign(model=model, tokenizer=tokenizer)
        _SAFE_DESIGNERS[checkpoint_path] = designer
    return designer


@tool
def molecular_generation(n_samples: int) -> list:
    """Generate n_samples molecules using the trained SAFE-Phi model."""
    # The docstring above is used as the tool description shown to the agents,
    # so it is deliberately kept unchanged.
    # With sanitize=True invalid generations are dropped, so the returned list
    # can contain fewer than n_samples SMILES.
    if n_samples <= 0:
        return []
    return _get_safe_designer().de_novo_generation(sanitize=True, n_samples_per_trial=n_samples)

# Reference SMILES and their fingerprints keyed by CSV path, so the dataset is
# parsed and fingerprinted once per session rather than on every tool call.
# (Re-run this cell to pick up changes to the CSV.)
_REFERENCE_FINGERPRINTS = {}


def _morgan_fingerprint(smiles):
    """Return the radius-2, 2048-bit Morgan fingerprint of `smiles`, or None if it cannot be parsed."""
    if not isinstance(smiles, str):
        return None  # e.g. empty/NaN cells read from the CSV
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None (it does not raise) for invalid SMILES
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)


def _load_reference_fingerprints(csv_path):
    """Return parallel lists (smiles, fingerprints) for every parsable SMILES in `csv_path`."""
    if csv_path not in _REFERENCE_FINGERPRINTS:
        smiles_list, fingerprints = [], []
        for smiles in pd.read_csv(csv_path)["SMILES"]:
            fp = _morgan_fingerprint(smiles)
            if fp is not None:
                smiles_list.append(smiles)
                fingerprints.append(fp)
        _REFERENCE_FINGERPRINTS[csv_path] = (smiles_list, fingerprints)
    return _REFERENCE_FINGERPRINTS[csv_path]


@tool
def retrieve_closest_neighbors(query_smiles: str, n: int) -> list:
    """Find the n closest neighbors to the query_smiles in the SMILES_Big_Data_Set.csv dataset."""
    # The docstring above is used as the tool description shown to the agents,
    # so it is deliberately kept unchanged.
    #
    # Returns up to `n` dataset SMILES ordered by descending Tanimoto similarity
    # of Morgan fingerprints; rows that RDKit cannot parse are skipped rather
    # than being returned as zero-similarity "neighbors".
    # Raises ValueError if `query_smiles` is not a parsable SMILES string.
    query_fp = _morgan_fingerprint(query_smiles)
    if query_fp is None:
        raise ValueError(f"Could not parse query SMILES: {query_smiles!r}")
    if n <= 0:
        # Guard explicitly: DataFrame.head(-k) would have returned all but k rows.
        return []

    smiles_list, fingerprints = _load_reference_fingerprints("SMILES_Big_Data_Set.csv")
    similarities = DataStructs.BulkTanimotoSimilarity(query_fp, fingerprints)

    # sorted() is stable even with reverse=True, so ties keep CSV order.
    ranked = sorted(range(len(smiles_list)), key=similarities.__getitem__, reverse=True)
    return [smiles_list[i] for i in ranked[:n]]

# Every tool the agents may call: rendered into the Llama3 system prompt and
# bound to Command-R further below.
tools = [
    molecular_generation,
    retrieve_closest_neighbors,
]

def is_valid_json(json_string):
    """Return True if `json_string` parses as JSON, otherwise False."""
    try:
        json.loads(json_string)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return False
    return True
def tool_chain(model_output, tool_list=None):
    """Dispatch the Llama3 agent's reply to a tool, or pass plain text through.

    The system prompt asks the model to answer tool requests with a bare JSON
    object of the form {"name": <tool name>, "arguments": {...}}. If
    `model_output` parses as such an object, the named tool is invoked with
    those arguments and its result is returned. Any other reply (non-JSON
    text, or JSON that is not an object with a "name" key) is returned
    unchanged.

    Args:
        model_output: raw text produced by the LLM.
        tool_list: tools to dispatch to; defaults to the notebook-level `tools`.

    Returns:
        The tool's result, or `model_output` itself when no tool call was made.

    Raises:
        KeyError: if the JSON names a tool that is not available.
    """
    try:
        parsed_output = json.loads(model_output)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return model_output

    # Valid JSON is not necessarily a tool call (e.g. a bare number, string or
    # list); indexing those with ["name"] used to raise TypeError.
    if not isinstance(parsed_output, dict) or "name" not in parsed_output:
        return model_output

    available_tools = tools if tool_list is None else tool_list
    tool_map = {t.name: t for t in available_tools}
    tool_name = parsed_output["name"]
    if tool_name not in tool_map:
        raise KeyError(f"model requested unknown tool {tool_name!r}; available: {sorted(tool_map)}")

    # Invoke with the already-parsed arguments instead of building a
    # JsonOutputParser | itemgetter | tool pipeline that re-parses the same text.
    return tool_map[tool_name].invoke(parsed_output.get("arguments", {}))

# Plain-text description of every tool, interpolated into the Llama3 system prompt.
rendered_tools = render_text_description(tools=tools)

In [254]:
from langchain_core.prompts import PromptTemplate

# Llama 3 Instruct chat format: a system turn, a user turn, then an open
# assistant header (followed by a blank line) for the model to complete.
# The string opens with exactly three quotes; a fourth quote would become a
# literal '"' at the very start of every prompt.
# NOTE(review): llama.cpp may also prepend its own BOS token, in which case the
# explicit <|begin_of_text|> below is duplicated — confirm for your llama-cpp version.
prompt_template = PromptTemplate.from_template(
"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
)

# System prompt for the Llama3 agent. The f-string interpolates
# `rendered_tools` (the text rendering of `tools`); the doubled braces `{{ }}`
# produce literal `{ }` so the few-shot examples show real JSON. The model is
# asked to reply with a bare {"name": ..., "arguments": ...} object, which is
# the shape `tool_chain` looks up ("name") and unpacks ("arguments").
system_prompt = f"""You are an assistant that has access to the following set of tools. Here are the names and descriptions for each tool:

{rendered_tools}

Given the user question, return the name and input of the tool to use. If you decide to use a tool, return your response as a JSON blob with 'name' and 'arguments' keys. Nothing else, just the JSON.
Here are a few examples of correct responses:

Example 1:
user: Generate 10 new molecules.
assistant: {{ "name": "molecular_generation", "arguments": {{"n_samples": 10}} }}

Example 2:
user: Find the 5 closest neighbors to the SMILES molecule 'CCO':
assistant: {{ "name": "retrieve_closest_neighbors", "arguments": {{ "query_smiles": "CCO", "n": 5 }} }}

If you want to use a tool, RETURN ONLY THE JSON. DONT WRITE ANY NORMAL TEXT!!
"""

In [255]:

# Llama3 agent: fill the chat template, generate, then let tool_chain run the
# requested tool (or pass plain text through).
chain = (
    prompt_template
    | llm
    | tool_chain
)
molecules = chain.invoke(
    dict(system_prompt=system_prompt, prompt="Generate 10 molecules")
)
print(molecules)

[32m2024-06-04 16:51:41.671[0m | [1mINFO    [0m | [36msafe.sample[0m:[36mde_novo_generation[0m:[36m577[0m - [1mAfter sanitization, 2 / 10 (20.00 %) generated molecules are valid ![0m


['C1COCCN1.CC(C)CN[CH]CCN=C(O)CNC(C)(C)C.CCC(C)C', 'CCCOc1ccccc1C(C)=O']


In [257]:
# Query with the second SMILES from the Llama3 generation run above.
# tool_chain returns the model's raw text when it does not emit a tool call,
# and indexing a string would silently send a single character as the query;
# sanitization can also leave fewer than two molecules. Check before indexing.
assert isinstance(molecules, list) and len(molecules) >= 2, (
    f"expected at least 2 generated SMILES, got {molecules!r}"
)
neighbors = chain.invoke({
    "system_prompt": system_prompt,
    "prompt": f"Find the 10 closest neighbors of the SMILES molecule '{molecules[1]}' ",
})
print(neighbors)

['CC(=O)c1ccccc1OCC(=O)O', 'COc1ccccc1C(C)=O', 'CCCOC(=O)c1ccccc1C(=O)OCCC', 'CC(=O)c1ccccc1OC(=O)N(C)C', 'CCNC(=O)c1ccccc1OCC', 'CCOc1ccccc1C(=O)N(CC)CC', 'CCCC(=O)OCOc1ccccc1C(N)=O', 'CC(=O)OCOc1ccccc1C(N)=O', 'CCCOC(=O)c1ccccc1O', 'CCCOc1cc(N)ccc1C(=O)O']


In [258]:
# Off-topic question: the model is expected to answer in prose, which
# tool_chain passes through unchanged instead of calling a tool.
random_chat = chain.invoke(
    dict(system_prompt=system_prompt, prompt="What is AI for drug discovery?")
)
print(random_chat)

Artificial Intelligence (AI) has revolutionized the field of drug discovery by providing new tools and techniques to accelerate the process.

Here are some ways AI is being used in drug discovery:

1. **Virtual screening**: AI algorithms can quickly scan large databases of potential compounds against a target protein or receptor, identifying those that bind with high affinity.
2. **Structure prediction**: AI can predict the three-dimensional structure of proteins and other molecules from their amino acid sequence.
3.assistant


---

## Cohere's Command-R

In [259]:
# getpass is not among the notebook's top-level imports, so import it here;
# without it this cell fails with NameError on a fresh kernel.
import getpass

from langchain_cohere import ChatCohere

# Only prompt when the key is not already set, so the cell can be re-run (or
# run non-interactively with COHERE_API_KEY exported beforehand).
if not os.environ.get("COHERE_API_KEY"):
    os.environ["COHERE_API_KEY"] = getpass.getpass("COHERE_API_KEY: ")

llm = ChatCohere(model="command-r")

# Expose the same tools to Command-R through its native tool-calling API.
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool.name: tool for tool in tools}


def call_tools(msg: "AIMessage", tool_list=None) -> list:
    """Execute every tool call requested in `msg`, sequentially.

    Args:
        msg: chat-model message whose `tool_calls` entries are dicts with at
            least "name" and "args" keys.
        tool_list: tools to dispatch to; defaults to the notebook-level `tools`.

    Returns:
        One dict per tool call, in order: a copy of the original call plus an
        "output" key holding that tool's return value.

    Raises:
        KeyError: if a call names a tool that is not available.
    """
    available_tools = tools if tool_list is None else tool_list
    tool_map = {t.name: t for t in available_tools}
    # Build fresh dicts rather than writing "output" into msg.tool_calls:
    # list.copy() is shallow, so doing that would mutate the message itself.
    return [
        {**tool_call, "output": tool_map[tool_call["name"]].invoke(tool_call["args"])}
        for tool_call in msg.tool_calls
    ]


# Command-R agent: the bound model emits structured tool calls, which
# call_tools then executes in order.
chain = (
    llm_with_tools
    | call_tools
)

In [260]:
# call_tools returns one dict per tool call; the first call's "output" holds
# the SMILES produced by molecular_generation.
molecules = chain.invoke(
    "Generate 10 molecules"
)
print(molecules[0]["output"])

[32m2024-06-04 16:53:27.481[0m | [1mINFO    [0m | [36msafe.sample[0m:[36mde_novo_generation[0m:[36m577[0m - [1mAfter sanitization, 3 / 10 (30.00 %) generated molecules are valid ![0m


['NC(=O)N1CCC(C2CCCCC2)CC1', 'CC=CC=C(C)C.Cc1nc(C)c(C)c(C(C)(C)NCC(C)C)n1', 'CC(=O)Oc1ccccc1Cc1ccccc1F.N#CCO']


In [262]:
# Look up neighbors of the first SMILES returned by the Command-R generation call.
neighbors = chain.invoke(
    f"Find the 10 closest neighbors of the SMILES molecule '{molecules[0]['output'][0]}' "
)
print(neighbors[0]["output"])

['O=C(CC1CCCCC1)N1CCCCC1', 'O=C(O)C1CCCCC1', 'NC(=S)N1CCCCC1', 'NNC(=O)NC1CCCCC1', 'O=C1CCC2CCCCC2C1', 'NC(=O)CN1CCCC1=O', 'C1CCC(C2CO2)CC1', 'NC1CCCCC1', 'OC1CCCCC1', 'C=CC(=O)OC1CCCCC1']
