In [1]:
# Initialization
import os, re, json
from google.colab import userdata
import pandas as pd
import google.generativeai as genai

os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
genai.configure(api_key=userdata.get("GEMINI_API_KEY"))

# Install dependencies
!pip install --upgrade --quiet accelerate bitsandbytes huggingface_hub transformers

In [2]:
# Load prompt template
import json
from huggingface_hub import hf_hub_download

tdc_prompts_filepath = hf_hub_download(
    repo_id="google/txgemma-2b-predict",
    filename="tdc_prompts.json",
)

with open(tdc_prompts_filepath, "r") as f:
    tdc_prompts_json = json.load(f)

In [3]:
# Load model
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

PREDICT_VARIANT = "2b-predict"  # @param ["2b-predict", "9b-predict", "27b-predict"]
CHAT_VARIANT = "9b-chat" # @param ["9b-chat", "27b-chat"]
USE_CHAT = True # @param {type: "boolean"}

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

predict_tokenizer = AutoTokenizer.from_pretrained(f"google/txgemma-{PREDICT_VARIANT}")
predict_model = AutoModelForCausalLM.from_pretrained(
    f"google/txgemma-{PREDICT_VARIANT}",
    device_map="auto",
    quantization_config=quantization_config,
)

if USE_CHAT:
    chat_tokenizer = AutoTokenizer.from_pretrained(f"google/txgemma-{CHAT_VARIANT}")
    chat_model = AutoModelForCausalLM.from_pretrained(
        f"google/txgemma-{CHAT_VARIANT}",
        device_map="auto",
        quantization_config=quantization_config,
    )

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
# Example task and input
task_name = "BBB_Martins"
input_type = "{Drug SMILES}"
drug_smiles = "CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21"
TDC_PROMPT = tdc_prompts_json[task_name].replace(input_type, drug_smiles)

def txgemma_predict(prompt):
    input_ids = predict_tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = predict_model.generate(**input_ids, max_new_tokens=8)
    return predict_tokenizer.decode(outputs[0], skip_special_tokens=True)

def txgemma_chat(prompt):
    input_ids = chat_tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = chat_model.generate(**input_ids, max_new_tokens=32)
    return chat_tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Prediction model response: {txgemma_predict(TDC_PROMPT)}")
if USE_CHAT: print(f"Chat model response: {txgemma_chat(TDC_PROMPT)}")



Prediction model response: Instructions: Answer the following question about drug properties.
Context: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.
Question: Given a drug SMILES string, predict whether it
(A) does not cross the BBB (B) crosses the BBB
Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21
Answer:(B)
Chat model response: Instructions: Answer the following question about drug properties.
Context: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervo

In [5]:
# This will allow us to extract content from inside of ticks
def extract_prompt(text, word):
    code_block_pattern = rf"```{word}(.*?)```"
    code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
    extracted_code = "\n".join(code_blocks).strip()
    return extracted_code

# This class will allow us to inferface with TxGemma
class TxGemmaChatTool:
    def __init__(self):
      self.tool_name = "Chat Tool"

    def use_tool(self, question):
        # Here, we are submitting a question to TxGemma
        response = txgemma_chat(question)
        return response

    def tool_is_used(self, query):
        # This just checks to see if the tool call was evoked
        return "```TxGemmaChat" in query

    def process_query(self, query):
        # Here, we clean to query to remove the tool call
        return extract_prompt(query, word="TxGemmaChat")

    def instructions(self):
        return (
            "=== TX-009 Task: Therapeutic Chat Tool Instructions ===\n"
            "### What This Tool Does\n"
            "The TxGemma Therapeutic Chat Tool allows the agent to ask domain-specific questions to a large language model "
            "fine-tuned on therapeutic and biomedical datasets. It is particularly configured for TX-009 to:\n"
            "- Interpret trial metadata\n"
            "- Extract and rank top 5 contributing factors\n"
            "- Estimate clinical success probabilities (0.0–1.0)\n"
            "- Generate scientific summaries for Markdown or JSON reports\n\n"

            "### How to Use It\n"
            "Wrap your query in triple backticks (```), starting with `TxGemmaChat`. Write your question on the next line.\n"
            "Do NOT include external instructions inside your prompt, only the direct biomedical content.\n\n"

            "### Required Format\n"
            "```TxGemmaChat\n"
            "[your question, such as: What are the top 5 factors influencing success of Bemdaneprocel?]\n"
            "```\n\n"

            "### Example:\n"
            "```TxGemmaChat\n"
            "Estimate clinical success probability for Bemdaneprocel based on trial metadata and literature.\n"
            "```\n"
        )

In [6]:
if USE_CHAT:
    chat_tool = TxGemmaChatTool()
    response = chat_tool.use_tool("Can Aspirin help with headaches? Yes or no?")
    print(response)

Can Aspirin help with headaches? Yes or no? 



In [7]:
# Load data
df = pd.read_csv("Single_Disease_Dataset_with_Index.csv")

row = df[
    df['Stem-Cell Modality'].str.contains("bemdaneprocel", case=False, na=False) &
    df['Single_Disease'].str.contains("Parkinson", case=False, na=False)
].iloc[0]

In [8]:
#PubMed Search

! pip install --upgrade --quiet biopython

from Bio import Medline, Entrez

class PubMedSearch:
    def __init__(self):
        self.tool_name = "PubMed Search"

    def tool_is_used(self, query: str):
        return "```PubMedSearch" in query

    def process_query(self, query: str):
        search_text = extract_prompt(query, word="PubMedSearch")
        return search_text.strip()

    def use_tool(self, search_text):
        handle = Entrez.esearch(db="pubmed", sort="relevance", term=search_text, retmax=3)
        record = Entrez.read(handle)
        pmids = record.get("IdList", [])
        handle.close()

        if not pmids:
            return f"No PubMed articles found for '{search_text}'. Please try a simpler search query."

        fetch_handle = Entrez.efetch(db="pubmed", id=",".join(pmids), rettype="medline", retmode="text")
        records = list(Medline.parse(fetch_handle))
        fetch_handle.close()

        result_str = f"=== PubMed Search Results for: '{search_text}' ===\n"
        for i, record in enumerate(records, start=1):
            pmid = record.get("PMID", "N/A")
            title = record.get("TI", "No title available")
            abstract = record.get("AB", "No abstract available")
            journal = record.get("JT", "No journal info")
            pub_date = record.get("DP", "No date info")
            authors = record.get("AU", [])
            authors_str = ", ".join(authors[:3])
            result_str += (
                f"\n--- Article #{i} ---\n"
                f"PMID: {pmid}\n"
                f"Title: {title}\n"
                f"Authors: {authors_str}\n"
                f"Journal: {journal}\n"
                f"Publication Date: {pub_date}\n"
                f"<abstract_start>{abstract}</abstract_finish>\n"
            )
        return f"Query: {search_text}\nResults: {result_str}"

    def instructions(self):
        return (
            f"{'@' * 10}\n@@@ PubMed Search Tool Instructions @@@\n\n"
            "### What This Tool Does\n"
            "The PubMed Search Tool queries the NCBI Entrez API (PubMed) for a given search phrase, "
            "and retrieves metadata for a few of the top articles (PMID, title, authors, journal, date, abstract).\n\n"
            "### When / Why You Should Use It\n"
            "- To find **scientific literature** on biomedical topics like stem cell therapies.\n"
            "- To get **abstracts** as evidence for TxGemma prediction.\n\n"
            "### Query Format\n"
            "Use triple backticks ``` and start with `PubMedSearch`. Example:\n"
            "```PubMedSearch\nBemdaneprocel Parkinson's Disease\n```\n"
        )

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.3/3.3 MB[0m [31m177.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m91.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [9]:
pubmed_tool = PubMedSearch()
search_results = pubmed_tool.use_tool("Bemdaneprocel Parkinson's disease stem cell therapy")
print(search_results)

            Email address is not specified.

            To make use of NCBI's E-utilities, NCBI requires you to specify your
            email address with each request.  As an example, if your email address
            is A.N.Other@example.com, you can specify it as follows:
               from Bio import Entrez
               Entrez.email = 'A.N.Other@example.com'
            In case of excessive usage of the E-utilities, NCBI will attempt to contact
            a user at the email address provided before blocking access to the
            E-utilities.


Query: Bemdaneprocel Parkinson's disease stem cell therapy
Results: === PubMed Search Results for: 'Bemdaneprocel Parkinson's disease stem cell therapy' ===

--- Article #1 ---
PMID: 40240592
Title: Phase I trial of hES cell-derived dopaminergic neurons for Parkinson's disease.
Authors: Tabar V, Sarva H, Lozano AM
Journal: Nature
Publication Date: 2025 May
<abstract_start>Parkinson's disease is a progressive neurodegenerative condition with a considerable health and economic burden(1). It is characterized by the loss of midbrain dopaminergic neurons and a diminished response to symptomatic medical or surgical therapy as the disease progresses(2). Cell therapy aims to replenish lost dopaminergic neurons and their striatal projections by intrastriatal grafting. Here, we report the results of an open-label phase I clinical trial (NCT04802733) of an investigational cryopreserved, off-the-shelf dopaminergic neuron progenitor cell product (bemdaneprocel) derived from human embryonic stem (hE

In [10]:
#Creating a tool manager
class ToolManager:
    def __init__(self, toolset):
        self.toolset = toolset

    def tool_prompt(self):
        # This will let the agent know what tools it has access to
        tool_names = ", ".join([tool.tool_name for tool in self.toolset])
        return f"You have access to the following tools: {tool_names}\n{self.tool_instructions()}. You can only use one tool at a time. These are the only tools you have access to nothing else."

    def tool_instructions(self):
        # This allows the agent to know how to use the tools
        tool_instr = "\n".join([tool.instructions() for tool in self.toolset])
        return f"The following is a set of instructions on how to use each tool.\n{tool_instr}"

    def use_tool(self, query):
        # This will iterate through all of the tools
        # and find the correct tool that the agent requested
        for tool in self.toolset:
            if tool.tool_is_used(query):
                # use the tool and return the output
                return tool.use_tool(tool.process_query(query))
        return f"No tool match for search: {query}"

if USE_CHAT:
    tools = ToolManager([TxGemmaChatTool(), PubMedSearch()])
else:
    tools = ToolManager([PubMedSearch()])

In [11]:
#Creating a Gemini inference tool

def inference_gemini(prompt, system_prompt, model_str):
  # Check to see that our model string matches
  if model_str == "gemini-2.5-pro":
    model = genai.GenerativeModel(model_name="gemini-2.5-pro-preview-03-25", system_instruction=system_prompt)
    response = model.generate_content(prompt)
    answer = response.text
  return answer

In [12]:
#Creating a therapeutics agent

class AgenticTx:
    def __init__(self, tool_manager, model_str, num_steps=5):
        self.curr_steps = 0
        self.num_steps = num_steps
        self.model_str = model_str
        self.tool_manager = tool_manager
        self.thoughts = list()
        self.actions  = list()
        self.observations = list()

    def reset(self):
        self.curr_steps = 0

    def system_prompt(self, use_tools=True):
        role_prompt = "You are an expert therapeutic agent. You answer accurately and thoroughly."
        prev_actions = f"You can perform a maximum of {self.num_steps} actions. You have performed {self.curr_steps} and have {self.num_steps - self.curr_steps - 1} left."
        if use_tools:
            tool_prompt = "You can use tools to solve problems and answer questions. " + self.tool_manager.tool_prompt()
        else:
            tool_prompt = "You cannot use any tools right now."
        return f"{role_prompt} {prev_actions} {tool_prompt}"

    def prior_information(self, query):
        info_txt = f"Question: {query}\n" if query is not None else ""
        for _i in range(self.curr_steps):
            info_txt += f"### Thought {_i + 1}: {self.thoughts[_i]}\n"
            info_txt += f"### Action {_i + 1}: {self.actions[_i]}\n"
            info_txt += f"### Observation {_i + 1}: {self.observations[_i]}\n\n"
            info_txt += "@"*20
        return info_txt

    def step(self, question):
        for _i in range(self.num_steps):
            if self.curr_steps == self.num_steps - 1:
                return inference_gemini(
                    model_str=self.model_str,
                    prompt=f"{self.prior_information(question)}\nYou must now provide an answer to this question {question}",
                    system_prompt=self.system_prompt(use_tools=False))
            else:
                thought = inference_gemini(
                    model_str=self.model_str,
                    prompt=f"{self.prior_information(question)}\nYou cannot currently use tools but you can think about the problem and what tools you want to use. This was the question, think about plans for how to use tools to answer this {question}. Let's think step by step (respond with only 1-2 sentences).\nThought: ",
                    system_prompt=self.system_prompt(use_tools=False))
                action = inference_gemini(
                    model_str=self.model_str,
                    prompt=f"{self.prior_information(question)}\n{thought}\nNow you must use tools to answer the following user query [{question}], closely following the tool instructions. Tool",
                    system_prompt=self.system_prompt(use_tools=True))
                obs = self.tool_manager.use_tool(action)

                print("Thought:", thought)
                print("Action:", action)
                print("Observation:", obs)

                self.thoughts.append(thought)
                self.actions.append(action)
                self.observations.append(obs)

                self.curr_steps += 1

In [13]:
agentictx = AgenticTx(tool_manager=tools, model_str="gemini-2.5-pro")

trial_metadata = f"""
<trial_metadata_start>
Company: {row['Company (Website)']}
HQ Country: {row['HQ Country']}
Stem-Cell Modality: {row['Stem-Cell Modality']}
Development Stage: {row['Development Stage']}
Experimental vs. Formal: {row['Experimental vs. Formal']}
Latest Funding: {row['Latest Funding (Date, Amount, Lead)']}
Public/Private: {row['Public/Private']}
Clinical Trials (NCT): {row['Clinical Trials (NCT)']}
IP / Technology: {row['IP (Intellectual Property) / Technology']}
Partnerships: {row['Partnerships/Collaborations']}
Lead Disease Areas: {row['Lead Disease Areas']}
Single Disease: {row['Single_Disease']}
<trial_metadata_end>
"""

final_prompt = f"""Estimate the clinical success probability (range: 0.0–1.0) of Bemdaneprocel (BRT‑DA01), list the top 5 contributing factors, and write natural language summary of prediction based on the following trial metadata.

Respond ONLY with:
- A numerical success score (0.0–1.0)
- A list of the top 5 contributing factors
- A concise natural language summary

Do not restate the prompt or metadata.

{trial_metadata}
"""

response = agentictx.step(final_prompt)
print("\nFinal Response:", response)


Thought: 0.50

Top 5 contributing factors:
1.  Positive 18-month Phase 1 clinical data demonstrating initial safety and potential efficacy signals.
2.  Advancement in development to "preparing Phase 3," indicating successful navigation of early clinical hurdles.
3.  Strong corporate backing and substantial financial/developmental resources from Bayer following its acquisition of BlueRock Therapeutics.
4.  The innovative therapeutic modality, allogeneic iPSC-derived dopaminergic neurons, which offers a potentially restorative approach by directly replacing lost cells in Parkinson's Disease.
5.  The focus on Parkinson's Disease, an area of high unmet medical need, which may facilitate regulatory pathways if strong efficacy is demonstrated.

Summary:
Bemdaneprocel (BRT-DA01) has a moderate probability of clinical success, estimated at 0.50. This is supported by promising 18-month data from its completed Phase 1 trial, its progression towards a Phase 3 trial, and the substantial backing fr

In [14]:
import datetime

f_md = open("bemdaneprocel_report.md", "w")
f_md.write(response)
f_md.close()

summary_json = {
    "task": "TX-009",
    "timestamp": str(datetime.datetime.now()),
    "response": response
}

f_json = open("bemdaneprocel_report.json", "w")
json.dump(summary_json, f_json, indent=2)
f_json.close()

from google.colab import files

files.download("bemdaneprocel_report.md")
files.download("bemdaneprocel_report.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>