In [1]:
from typing import List

import json
from langchain_google_genai import GoogleGenerativeAI
from langchain.prompts import PromptTemplate, MessagesPlaceholder, ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel, Field, validator
import pandas as pd
from tqdm.notebook import tqdm

import util

In [3]:
from getpass import getpass

api_key = getpass()

In [9]:
data_config = util.get_data_config()

In [10]:
llm_config = {
    'temperature': 0.,
}


In [11]:
llm = GoogleGenerativeAI(model='gemini-1.0-pro', google_api_key=api_key, **llm_config)

In [12]:
llm

GoogleGenerativeAI(model='gemini-1.0-pro', google_api_key=SecretStr('**********'), temperature=0.0, client=genai.GenerativeModel(
    model_name='models/gemini-1.0-pro',
    generation_config={},
    safety_settings={},
    tools=None,
))

# Load Data

In [13]:
summaries = pd.read_parquet(data_config['summaries_file_name'])
summaries.shape

(51, 2)

In [14]:
summaries.head()

Unnamed: 0,nct_id,brief_summary
0,NCT00037648,The purpose of this study is to determine the ...
1,NCT00048542,"This is a multicenter, Phase 3 randomized, pla..."
2,NCT00071487,The purpose of this study is to evaluate the s...
3,NCT00071812,The purpose of this study is to evaluate the s...
4,NCT00072839,The purpose of the study is to determine wheth...


In [15]:
synonyms_df = pd.read_parquet(data_config['processed_synonyms_file_name'])
synonyms_df.head()

Unnamed: 0_level_0,preferred_name,synonyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1
4.0,levobupivacaine,"[chirocain, levobupivacaine, levobupivacaine H..."
5.0,(S)-nicardipine,"[(-)-Nicardipine, (-)-nicardipine, (S)-nicardi..."
6.0,(S)-nitrendipine,"[(-)-Nitrendipine, (-)-nitrendipine, (S)-nitre..."
13.0,levdobutamine,"[LY206243, levdobutamine, levdobutamine lactob..."
21.0,aminopterin,"[4-aminofolic acid, Aminofolic acid, 4-, amino..."


In [16]:
with open(data_config['ground_truth_raw_file_name'], 'r') as fin:
    ground_truth = json.loads(fin.read())
    
with open(data_config['ground_truth_cleaned_file_name'], 'r') as fin:
    ground_truth_cleaned = json.loads(fin.read())

# Prompt without examples

In [17]:
class Drug(BaseModel):
    """Information about a drug. Placebo is not a drug."""
    name: str = Field(..., description='The name of the drug')

class Drugs(BaseModel):
    """Identifying information about all drugs in a text."""
    drugs: List[Drug]

class Disease(BaseModel):
    """Information about a disease."""
    name: str = Field(..., description='The name of the disease')

class Data(BaseModel):
    """Identifying information about all drugs and diseases in a text."""
    drugs: List[Drug]
    diseases: List[Disease]

def get_prompt_messages(with_examples: bool = False):
    messages = [
        (
            'system',
            'You are an expert extraction algorithm. '
            'Only extract relevant information from the text. '
            ' The text is a brief summary text of a clinical trial. '
            'If you do not know the value of an attribute asked to extract, '
            "return null for the attribute's value. Wrap the output in `json` tags\n{format_instructions}",
        ),
    ]
    if with_examples:
        messages.append(MessagesPlaceholder('examples'))
    
    messages.append(('human', '{text}'))
    return messages


def score(chain, example_messages=None):
    all_responses = []
    extracted = []
    for _, row in tqdm(summaries.iterrows(), total=summaries.shape[0]):
        text = row['brief_summary']
        
        chain_input = {'text': text}
        if example_messages is not None:
            chain_input['examples'] = example_messages
        response = chain.invoke(chain_input)
        all_responses.append(response)
        drugs = response.drugs
        drug_names = [drug.name for drug in drugs]
        extracted.append(drug_names)
        print(drug_names)
    return all_responses, extracted


def evaluate(summaries_df: pd.DataFrame, extracted: list[list[str]], preferred_name_by_term: dict[str, str]) -> pd.DataFrame:
    df = summaries_df.copy()

    # Clean up extracted text to match synonyms.
    df['extracted_terms_raw'] = extracted
    df['extracted_terms_cleaned'] = df['extracted_terms_raw'].map(lambda terms: [util.clean_up_synonym_term(term) for term in terms])
    
    extractive_terms_preferred_names = []
    for terms in df['extracted_terms_cleaned']:
        terms_preferred_names = []
        for term in terms:
            if term in preferred_name_by_term:
                terms_preferred_names.append(preferred_name_by_term[term])
        extractive_terms_preferred_names.append(terms_preferred_names)
    
    df['extracted_terms'] = extractive_terms_preferred_names
    
    # df['extracted_terms'] = extracted
    df['gt_cleaned'] = df.nct_id.map(lambda x: ground_truth_cleaned[x])
    df['match_count'] = df.apply(lambda row: len(set(row['extracted_terms']).intersection(set(row['gt_cleaned']))), axis=1)
    df['extracted_terms_count'] = df['extracted_terms'].map(len)
    df['gt_cleaned_count'] = df['gt_cleaned'].map(len)
    
    true_positive_count = df['match_count'].sum()
    precision = true_positive_count / df['extracted_terms_count'].sum()
    recall = true_positive_count / df['gt_cleaned_count'].sum()
    print(f'Precision: {precision: .2%}')
    print(f'Recall: {recall: .2%}')
    return df


In [18]:
# Set up a parser and prompt
parser = PydanticOutputParser(pydantic_object=Drugs)

prompt = ChatPromptTemplate.from_messages(
  get_prompt_messages(with_examples=False)
).partial(format_instructions=parser.get_format_instructions())

In [19]:
text = summaries['brief_summary'].iloc[0] 

In [20]:
print(prompt.format_prompt(text=text).to_string())

System: You are an expert extraction algorithm. Only extract relevant information from the text.  The text is a brief summary text of a clinical trial. If you do not know the value of an attribute asked to extract, return null for the attribute's value. Wrap the output in `json` tags
The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.

Here is the output schema:
```
{"description": "Identifying information about all drugs in a text.", "properties": {"drugs": {"title": "Drugs", "type": "array", "items": {"$ref": "#/definitions/Drug"}}}, "required": ["drugs"], "definitions": {"Drug": {"title": "Drug", "description": "Informatio

In [21]:
chain = prompt | llm | parser
response = chain.invoke({'text': text})
response

Drugs(drugs=[Drug(name='anakinra')])

In [22]:
response.drugs

[Drug(name='anakinra')]

In [23]:
all_responses, extracted = score(chain)


  0%|          | 0/51 [00:00<?, ?it/s]

['anakinra']
['adalimumab', 'methotrexate']
['belimumab']
['belimumab']
['ALX-0600']
['rituximab', 'methotrexate', 'placebo']
['etanercept']
['etanercept']
['Omalizumab']
[]
[]
['BMS-188667']
['mepolizumab']
['tocilizumab', 'methotrexate', 'placebo']
['tocilizumab', 'methotrexate', 'placebo']
['tocilizumab', 'methotrexate', 'placebo']
['tocilizumab', 'methotrexate', 'placebo']
['AMN107', 'imatinib']
['anakinra', 'placebo']
['etanercept']
['Abatacept', 'prednisone']
['pimecrolimus cream 1%', 'topical corticosteroids (TCS)']
['nitazoxanide', 'placebo']
['rituximab', 'placebo']
['etanercept']
['prednisone MR', 'prednisone IR']
['imatinib mesylate', 'prednisone', 'hydroxyurea', 'interferon-alpha']
['Certolizumab Pegol']
['adalimumab', 'methotrexate']
['adalimumab', 'placebo', 'methotrexate']
['Leukine']
['sargramostim']
['sargramostim']
['Leukine']
['Golimumab (CNTO 148)']
['CNTO 148 (golimumab)']
['UVADEX']
['Adalimumab', 'Methotrexate', 'Placebo']
['voclosporin', 'placebo']
[]
['CDP870',

In [24]:
# all_responses = []
# for _, row in tqdm(summaries.iterrows(), total=summaries.shape[0]):
#     text = row['brief_summary']
#     response = chain.invoke({'text': text})
#     all_responses.append(response)

In [25]:
# extracted = []
# for response in all_responses:
#     drugs = response.drugs
#     drug_names = [drug.name for drug in drugs]
#     extracted.append(drug_names)


In [26]:
synonym_maps = util.get_synonym_maps(synonyms_df)
preferred_name_by_term = synonym_maps['preferred_name_by_term']

In [27]:
test_df = evaluate(
    summaries_df=summaries,
    extracted=extracted,
    preferred_name_by_term=preferred_name_by_term
)

Precision:  100.00%
Recall:  91.23%


In [28]:
test_df[test_df['match_count'] != test_df['gt_cleaned_count']]

Unnamed: 0,nct_id,brief_summary,extracted_terms_raw,extracted_terms_cleaned,extracted_terms,gt_cleaned,match_count,extracted_terms_count,gt_cleaned_count
21,NCT00120523,The primary purpose of this study is to invest...,"[pimecrolimus cream 1%, topical corticosteroid...","[pimecrolimuscream1%, topicalcorticosteroids(t...",[],[pimecrolimus],0,0,1
25,NCT00146640,The objective of this study is to investigate ...,"[prednisone MR, prednisone IR]","[prednisonemr, prednisoneir]",[],[prednisone],0,0,1
34,NCT00207714,"Multicenter, randomized, double-blind, placebo...",[Golimumab (CNTO 148)],[golimumab(cnto148)],[],[golimumab],0,0,1
35,NCT00207740,The purpose of this study is to evaluate the e...,[CNTO 148 (golimumab)],[cnto148(golimumab)],[],[golimumab],0,0,1
47,NCT00267956,The purpose of this study is to evaluate the e...,[CNTO 1275 (ustekinumab)],[cnto1275(ustekinumab)],[],[ustekinumab],0,0,1


The issue lies on the extracted data cleanup phase. The LLM recognizes all the drugs in the (small) dataset, but it extracts a bit more text, ex. `CNTO 1275 (ustekinumab)`. The nltk approach considers this two instances of the drug name.

Several false positives have been removed by the cleanup phase above.

A Retrieval step on a datastore built on top of the synonyms could help clean up the data in a better fashion.

# Add examples 

In [29]:
examples = [
    (
        'The purpose of this study is to assess the safety and tolerability of CNTO 5825 following a single intravenous (IV) or subcutaneous (SC) dose administration in healthy volunteers.',
        '```json\n{\n  "drugs": [\n    {\n      "name": "CNTO 5825"\n    }\n  ]\n}\n```',
    ),
    (
        'The purpose of Part 1 of this study is to assess the safety and tolerability of 2 dose levels (1.4 and 2.8 mg/kg) of CHO-derived CNTO 328 and Sp2/0-derived CNTO 328. The purpose of Part 2 of this study is to access the pharmacokinetics (what the body does to the study medication) comparability of the 1.4 mg/kg dose of CHO-derived CNTO 328 and Sp2/0-derived CNTO 328.',
        '```json\n{\n  "drugs": [\n    {\n      "name": "CNTO 328 (CHO-derived)"\n    },\n    {\n      "name": "CNTO 328 (Sp2/0-derived)"\n    }\n  ]\n}\n```'
    ),
    (
        '''The purpose of this study is to:

Evaluate the efficacy of Adapalene gel 0.3% compared to Tretinoin Emollient cream 0.05%, reducing signs of cutaneous photoageing, measured trough photonumeric scale evaluation, investigator evaluation of global response to treatment and subject's evaluation of improvement.
Evaluate the safety and tolerability of Adapalene Gel 0.3%, compared to Tretinoin Emollient cream 0.05% during 24 weeks of treatment.
The study has the clinical hypothesis that Adapalene Gel 0.3% is as effective as Tretinoin Emollient cream 0.05% in the treatment of cutaneous photoaging.''',
        '```json\n{\n  "drugs": [\n    {\n      "name": "Adapalene"\n    },\n    {\n      "name": "Tretinoin Emollient cream 0.05%"\n    }\n  ]\n}\n```'
    )
]

In [30]:
def example_to_messages(example):
    messages: List[BaseMessage] = [HumanMessage(content=example[0])]
    messages.append(AIMessage(content=example[1]))
    return messages
    

In [31]:
example_messages = []
for example in examples:
    example_messages.extend(example_to_messages(example))

prompt_with_examples = ChatPromptTemplate.from_messages(
    get_prompt_messages(with_examples=True),
).partial(
    format_instructions=parser.get_format_instructions(),
)
print(prompt_with_examples.format_prompt(text=text, examples=example_messages).to_string())

System: You are an expert extraction algorithm. Only extract relevant information from the text.  The text is a brief summary text of a clinical trial. If you do not know the value of an attribute asked to extract, return null for the attribute's value. Wrap the output in `json` tags
The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.

Here is the output schema:
```
{"description": "Identifying information about all drugs in a text.", "properties": {"drugs": {"title": "Drugs", "type": "array", "items": {"$ref": "#/definitions/Drug"}}}, "required": ["drugs"], "definitions": {"Drug": {"title": "Drug", "description": "Informatio

In [32]:
chain_with_examples = prompt_with_examples | llm | parser
response = chain_with_examples.invoke({'text': text, 'examples': example_messages})
response

Drugs(drugs=[Drug(name='anakinra')])

In [33]:
all_responses_with_examples, extracted_with_examples = score(chain_with_examples, example_messages=example_messages)

  0%|          | 0/51 [00:00<?, ?it/s]

['anakinra']
['adalimumab', 'methotrexate (MTX)']
['belimumab']
['belimumab']
['ALX-0600']
['Rituximab', 'Methotrexate']
['etanercept']
['etanercept']
['Omalizumab']
[]
[]
['BMS-188667 (Abatacept)']
['mepolizumab']
['tocilizumab', 'methotrexate (MTX)']
['tocilizumab', 'methotrexate']
['tocilizumab', 'methotrexate (MTX)']
['tocilizumab', 'methotrexate']
['AMN107', 'Imatinib']
['anakinra']
['etanercept']
['Abatacept', 'Prednisone']
['pimecrolimus cream 1%', 'topical corticosteroids (TCS)']
['nitazoxanide', 'placebo']
['rituximab']
['etanercept']
['prednisone MR formulation', 'prednisone IR']
['imatinib mesylate']
['Certolizumab Pegol (CZP)']
['adalimumab', 'methotrexate']
['adalimumab', 'methotrexate']
['Leukine']
['sargramostim']
['sargramostim']
['Leukine']
['Golimumab (CNTO 148)', 'Methotrexate (MTX)']
['CNTO 148 (golimumab)']
['UVADEX']
['Adalimumab', 'Methotrexate']
['voclosporin', 'placebo']
[]
['CDP870', 'placebo']
['Abatacept']
['golimumab', 'methotrexate']
['golimumab', 'methotr

In [34]:
test_with_examples_df = evaluate(
    summaries_df=summaries,
    extracted=extracted_with_examples,
    preferred_name_by_term=preferred_name_by_term
)

Precision:  100.00%
Recall:  75.44%


# [incomplete] Cleaning up responses with RAG

Well, the retrieval part, not generation.


Adapted from [langchain's RAG cookbook](https://python.langchain.com/docs/expression_language/cookbook/retrieval)

In [39]:
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import math

In [36]:
embeddings_llm = GoogleGenerativeAIEmbeddings(model='models/embedding-001', google_api_key=api_key)


In [37]:
embeddings_llm.embed_query(text)

[0.023084601,
 -0.025059354,
 -0.07961606,
 0.008955923,
 0.04864266,
 0.007839439,
 -0.0045868554,
 -0.025806641,
 0.014736501,
 0.048185743,
 0.004881629,
 0.03529224,
 0.026073176,
 -0.003535608,
 0.014341052,
 -0.015081712,
 0.022542626,
 0.007935388,
 -0.016002657,
 0.019658072,
 0.014963381,
 0.034099385,
 -0.019226383,
 -0.026446614,
 -0.040287122,
 -0.031387016,
 0.03283562,
 -0.0710565,
 -0.014579742,
 0.042835716,
 -0.03685568,
 0.027751768,
 -0.0043345797,
 -0.0010353795,
 -0.059268977,
 -0.084676735,
 -0.018681535,
 -0.030509226,
 -0.012053325,
 0.03248099,
 -0.0020808352,
 -0.020201702,
 -0.052452028,
 0.0051273117,
 0.02359672,
 0.016149832,
 -0.012028812,
 0.04197049,
 0.041959573,
 -0.063283674,
 0.04432846,
 0.036107812,
 0.054297153,
 -0.012818668,
 0.019672962,
 -0.03972611,
 0.05740015,
 -0.020961974,
 -0.0196247,
 0.046559334,
 -0.044270094,
 0.070880875,
 -0.03571292,
 0.042567857,
 -0.0047877873,
 -0.05406059,
 0.012660801,
 -0.049162626,
 0.06778261,
 0.01251893

In [168]:
# limit = 100
limit = math.inf
texts = []
metadata = []
for index, (term, preferred_name) in enumerate(preferred_name_by_term.items()):
    texts.append(term)
    metadata.append({'preferred_name': preferred_name})
    # if index + 1  > limit:
    if index > limit - 1:
        break


In [169]:
len(texts)

23140

In [170]:
data_store = FAISS.from_texts(
    texts=texts,
    metadatas=metadata,
    embedding=embeddings_llm,
)

In [174]:
data_store.save_local('data/faiss_index')

In [173]:
# data_store.save_to_file('data/FAISS_data_store')

In [185]:
# CNTO 1275 (ustekinumab)
# cnto1275(ustekinumab)

res = data_store.search(
    # query='cnto1275(ustekinumab)',
    query='cnto1275',
    # query='CNTO 1275 (ustekinumab)',
    search_type='similarity',
    # search_type='mmr',
    k=20,
)

In [186]:
res

[Document(page_content='cnto1275', metadata={'preferred_name': 'ustekinumab'}),
 Document(page_content='cnto-1275', metadata={'preferred_name': 'ustekinumab'}),
 Document(page_content='cnto4424', metadata={'preferred_name': 'amivantamab'}),
 Document(page_content='cnto148', metadata={'preferred_name': 'golimumab'}),
 Document(page_content='nn1250', metadata={'preferred_name': 'insulin degludec'}),
 Document(page_content='cnto328', metadata={'preferred_name': 'siltuximab'}),
 Document(page_content='cm6912', metadata={'preferred_name': 'ethyl loflazepate'}),
 Document(page_content='cns7056', metadata={'preferred_name': 'remimazolam'}),
 Document(page_content='ono1206', metadata={'preferred_name': 'limaprost'}),
 Document(page_content='cnto-328', metadata={'preferred_name': 'siltuximab'}),
 Document(page_content='cp65703', metadata={'preferred_name': 'ampiroxicam'}),
 Document(page_content='cm-8282', metadata={'preferred_name': 'omoconazole'}),
 Document(page_content='ono1078', metadata={

In [184]:
'cnto1275' in preferred_name_by_term

True

In [158]:
res

[Document(page_content='acamprosate', metadata={'preferred_name': 'acamprosate'}),
 Document(page_content='5-fluorouracil', metadata={'preferred_name': 'fluorouracil'}),
 Document(page_content='fluorouracil', metadata={'preferred_name': 'fluorouracil'}),
 Document(page_content='methoxsalen', metadata={'preferred_name': 'methoxsalen'}),
 Document(page_content='ladakamycin', metadata={'preferred_name': 'azacitidine'}),
 Document(page_content='ledakamycin', metadata={'preferred_name': 'azacitidine'}),
 Document(page_content='azacitidine', metadata={'preferred_name': 'azacitidine'}),
 Document(page_content='abacavirsulfate', metadata={'preferred_name': 'abacavir'}),
 Document(page_content='abacavirsuccinate', metadata={'preferred_name': 'abacavir'}),
 Document(page_content='methoxsalene', metadata={'preferred_name': 'methoxsalen'})]

In [155]:
raw_synonym_maps = util.get_synonym_maps(synonyms_df, clean_up_terms=False)
raw_terms_by_preferred_name = raw_synonym_maps['terms_by_preferred_name']

In [156]:
raw_terms_by_preferred_name['ustekinumab']


array(['CNTO 1275', 'CNTO-1275', 'cnto 1275', 'cnto-1275', 'stelara',
       'ustekinumab'], dtype=object)

In [159]:
raw_terms_by_preferred_name['acamprosate']

array(['N-acetylhomotaurine', 'acamprosaic acid', 'acamprosate',
       'acamprosate calcium', 'calcium acetylhomotaurinate',
       'calcium acetylhomotaurine', 'campral', 'n-acetylhomotaurine',
       'zulex'], dtype=object)