In [2]:
from transformers import AutoTokenizer
model_id = "m42-health/med42-70b"
# model_id = "meta-llama/Meta-Llama-3-70B-Instruct"

from pydantic import BaseModel, Field
from typing import List
from langchain.output_parsers import PydanticOutputParser

import pandas as pd
import huggingface_hub
import requests
from huggingface_hub import InferenceClient

from fuzzywuzzy import fuzz
from fuzzywuzzy import process

In [16]:
!pip install --upgrade huggingface_h

Collecting huggingface_hub
  Downloading huggingface_hub-0.23.0-py3-none-any.whl.metadata (12 kB)
Downloading huggingface_hub-0.23.0-py3-none-any.whl (401 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m401.2/401.2 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.20.2
    Uninstalling huggingface-hub-0.20.2:
      Successfully uninstalled huggingface-hub-0.20.2
Successfully installed huggingface_hub-0.23.0


In [8]:
def is_updated(old_memory, new_memory, threshold):
    old_str = "\n".join(old_memory)
    new_str = "\n".join(new_memory)
    if fuzz.ratio(old_str, new_str) >= threshold : 
        return True # update memory
    else:
        return False

def plot_in_box(lines):
    max_length = max(len(line) for line in lines if isinstance(line, str))
    print('-' * (max_length + 4))
    for line in lines:
        if "\n" in line:
            parts = line.split("\n")
            for part in parts:
                print(f"| {part.ljust(max_length)} |")
        else:
            print(f"| {line.ljust(max_length)} |")
    print('-' * (max_length + 4))


class Response(BaseModel):
    predictedStage: str 
    reasoning: str
    rules: List[str]

client = InferenceClient(model="http://127.0.0.1:8081/")

In [9]:
brca_report = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/brca_df.csv")
sample_reports = brca_report.sample(n=50, random_state=123)
Response.schema()

{'properties': {'predictedStage': {'title': 'Predictedstage',
   'type': 'string'},
  'reasoning': {'title': 'Reasoning', 'type': 'string'},
  'rules': {'items': {'type': 'string'}, 'title': 'Rules', 'type': 'array'}},
 'required': ['predictedStage', 'reasoning', 'rules'],
 'title': 'Response',
 'type': 'object'}

In [11]:
system_instruction = "You are an expert at interpreting pathology reports for cancer staging."

initial_predict_prompt = """You are provided with a pathology report for a cancer patient.
Please review this report and determine the pathologic stage of the patient's cancer.

Here is the report:
```
{report}
```

What is the T stage from this report? Ignore any substaging information. Please select from the following four options:  T1, T2, T3, T4.
What is your reasoning to support your stage prediction?
Please induce a list of rules as knowledge that help you predict the next report. Make sure every rule does not contain any report-specific information. Instead, list general guidelines that apply to the specific cancer type and the AJCC staging system.

Please use the following schema: {schema}
"""

subsequent_predict_prompt = """You are provided with a pathology report for a cancer patient.
Here is a list of rules you leanred to correctly predict the cancer stage information:
```
{memory}
```

Please review this report and determine the pathologic stage of the patient's cancer.

Here is the report:
```
{report}
```

What is the T stage from this report? Ignore any substaging information. Please select from the following four options:  T1, T2, T3, T4.
What is your reasoning to support your stage prediction?
What is your updated list of rules that help you predict the next report? You can either modify the original rules or add new rules. Make sure every rule does not contain any report-specific information. Instead, list general guidelines that apply to the specific cancer type and the AJCC staging system.

Please use the following schema: {schema}
"""


prompt_template='''
<|system|>:{system_instruction}
<|prompter|>:{prompt}
<|assistant|>:
'''

In [12]:
for threshold in range(100, -1, -25):
    memory = "" # a list of strings
    correct_count = 0
    incorrect_count = 0

    for idx, row in sample_reports.iterrows():
        report = row["text"]
        label = row["t"]
        
        if memory == "":
            prompt = initial_predict_prompt.format(report=row["text"], schema=Response.schema())
            prompt = prompt_template.format(system_instruction=system_instruction, prompt=prompt)

            response = client.text_generation(
                prompt=prompt, do_sample=False, max_new_tokens=1024,
                grammar={"type": "json", "value": Response.schema()}
                )
            memory = response["rules"]
            print(f"Initial memory: {memory}\n")

        else:
             prompt = subsequent_predict_prompt.format(report=row["text"], schema=Response.schema())
             prompt = prompt_template.format(system_instruction=system_instruction, prompt=prompt)
             
             response = client.text_generation(
                prompt=prompt, do_sample=False, max_new_tokens=1024,
                grammar={"type": "json", "value": Response.schema()}
                )
             if is_updated(memory,response["rules"], threshold) and (f"T{label+1}" == response['predictedStage']):
                print(f"At {idx}, memory is updated")
                memory = response["rules"]
                print(f"New memory: {memory}")

        if f"T{label+1}" == response['predictedStage']:
                result = "Correct prediction"
                correct_count += 1
        else:
            result = f"Wrong prediction\nReasoning: {response['reasoning']}"
            incorrect_count += 1
            lines = [
                f"Report Index: {idx}",
                f"Label: T{label+1}",
                f"Prediction: {response['predictedStage']}",
                result
            ]
            plot_in_box(lines)

    print(f"when threshold is {threshold}")
    print(f"correct: {correct_count}, incorrect: {incorrect_count}")

GenerationError: Request failed during generation: Server error: CANCELLED

In [25]:
response

'{  \n  "predictedStage": "T2",  \n"reasoning": "The tumor is staged as T2 according to the AJCC staging system, to reflect the presence of tumor in the axillary sentinel node and the absence of metastatic disease in the lymph node. The tumor is also characterized by infiltrating ductal carcinoma with a well-differentiated grade and an overall grade of  Elston SBR grade 1."  \n   \n \n\n    \n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n    \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n    \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n    \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n    \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n    \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   \n\n\n\n   

In [24]:
print(response)

{  
  "predictedStage": "T2",  
"reasoning": "The tumor is staged as T2 according to the AJCC staging system, to reflect the presence of tumor in the axillary sentinel node and the absence of metastatic disease in the lymph node. The tumor is also characterized by infiltrating ductal carcinoma with a well-differentiated grade and an overall grade of  Elston SBR grade 1."  
   
 

    

   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



    



   



   



   



   



   



   



    



   



   



   



   



    



   



   



   



   



   



   



    



   



   



   



   



   



   



   



    



   



   



   



   



   



   



   



    



   



   



   



   



   



   



   



   



   



   



   



    



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



   



  

In [13]:
import ast
ast.literal_eval(response)

SyntaxError: '{' was never closed (<unknown>, line 1)

In [13]:
correct_count, incorrect_count

(30, 5, 15)

In [15]:
for idx, report in sample_reports.iterrows():
    print(idx)

134
13
965
779
962
98
291
528
426
1019
138
492
381
978
491
1029
345
235
246
203
909
896
161
85
318
977
145
538
43
379
521
710
626
338
50
171
114
95
988
868
624
182
147
328
378
943
831
929
852
595
