# Model for testing

```bash
CUDA_VISIBLE_DEVICES="4,5", python -m vllm.entrypoints.openai.api_server \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--download-dir /secure/chiahsuan/hf_cache/ \
--tensor-parallel-size 2
```

# 1 Load packages and define class
== The following two cells can be organized into a separate py file as a module.==

In [1]:
from openai import OpenAI
import pandas as pd
from tqdm import tqdm
from prompt import system_instruction, baseline_prompt, initial_predict_prompt, \
subsequent_predict_prompt, testing_predict_prompt
import json
from pydantic import BaseModel, Field
from typing import List

# define class
class ChoiceAgent:
    """ the simplest agent, which is appropriate for zero-shot prompting
    """
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template: str, choices: dict) -> None:
        self.client = client
        self.model = model
        self.prompt_template = prompt_template
        self.choices = choices

    def get_response_from_choices(self, messages: list, temperature:float) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            extra_body={"guided_choice":self.choices},
            temperature = temperature
        )
        return response.choices[0].message.content

    def run(self, dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:

        pbar = tqdm(total=dataset.shape[0])
        for idx, row in dataset.iterrows():
            report = row["text"]
            prompt = self.prompt_template.format(report=report)
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]
            answer = self.get_response_from_choices(messages, temperature)
            dataset.loc[idx, "ans_str"] = answer
            pbar.update(1)
        pbar.close()

        return dataset

In [2]:
class MemoryAgent:
    """ the implementation of memory agent, which learn memory from training set and
    utilize memory as contexts for testing set.
    """
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict) -> None:
        self.client = client
        self.model = model
        self.prompt_template_dict = prompt_template_dict
        self.validate_prompt_template()
        self.schema_dict = schema_dict
        self.validate_schema()
        self.memory = ""

    def validate_prompt_template(self) -> None:
        keys = self.prompt_template_dict.keys()
        initial_prompt_exist = "initialized_prompt" in keys
        learning_prompt_exist = "learning_prompt" in keys
        testing_prompt_exist = "testing_prompt" in keys
        assert True == initial_prompt_exist == learning_prompt_exist == testing_prompt_exist, \
        "You should provide a dict with initialized_prompt, learning_prompt, and testing_prompt as keys."

    def validate_schema(self) -> None:
        keys = self.schema_dict.keys()
        learning_schema_exist = "learning_schema" in keys
        testing_schema_exist = "testing_schema" in keys
        assert True == learning_schema_exist == testing_schema_exist, \
        "You should provide a dict with learning_schema and testing_schema as keys."

    def get_schema_followed_response(self, messages: list, schema:dict, temperature:float) -> dict | None:
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                extra_body={"guided_json":schema},
                temperature = temperature
            )
            return json.loads(response.choices[0].message.content.replace("\\", "\\\\"))
        except json.JSONDecodeError:
            return None
        
    def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=training_dataset.shape[0])
        parsing_error = 0
        for idx, row in training_dataset.iterrows():

            report = row["text"]

            if self.memory == "":
                prompt = self.prompt_template_dict["initialized_prompt"].format(report=report)
            else:
                prompt = self.prompt_template_dict["learning_prompt"].format(memory=self.memory, report=report)
            
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]

            json_output = self.get_schema_followed_response(messages, self.schema_dict["learning_schema"], temperature)

            if not json_output:
                parsing_error += 1
                continue
            
            self.memory = json_output['rules']
            training_dataset.loc[idx, "reasoning"] = json_output['reasoning']
            training_dataset.loc[idx, "ans_str"] = json_output['predictedStage']
            
            pbar.update(1)
        pbar.close()
        return training_dataset
    
    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=testing_dataset.shape[0])
        parsing_error = 0
        for idx, row in testing_dataset.iterrows():

            report = row["text"]

            prompt = self.prompt_template_dict["testing_prompt"].format(memory=self.memory, report=report)
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]

            json_output = self.get_schema_followed_response(messages, self.schema_dict["testing_schema"], temperature)

            if not json_output:
                parsing_error += 1
                continue
            
            testing_dataset.loc[idx, "reasoning"] = json_output['reasoning']
            testing_dataset.loc[idx, "ans_str"] = json_output['predictedStage']
            
            pbar.update(1)
        pbar.close()
        return testing_dataset

In [None]:
class ConditionalMemoryAgent(MemoryAgent):
  
  def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict) -> None:
    # inherit all properties and methods from MemoryAgent
    super().__init__(client, model, prompt_template_dict, schema_dict)
  
  def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
    # only overide this function because the rest parts are the same
    pass

# 2 Demonstration

In [4]:
openai_api_key = "Empty"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

brca_report = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/brca_df.csv")
df_samples = brca_report.sample(10)
df_testing_samples = brca_report.sample(5)

In [5]:
zs_agent = ChoiceAgent(client=client, model="mistralai/Mistral-7B-Instruct-v0.2",
                 prompt_template=baseline_prompt,
                 choices=["T1", "T2", "T3", "T4"])

In [12]:
zs_agent.run(df_testing_samples)

100%|██████████| 5/5 [00:00<00:00,  5.04it/s]


Unnamed: 0.1,Unnamed: 0,patient_filename,t,text,type,n,reasoning,ans_str
858,1303,TCGA-E9-A228.B0E8F3C1-E996-4E70-9630-AC95AF6E4EDC,1,"aterality: Right, lower inner quadrant. Path R...",BRCA,1,The report states that the tumor size is 2.5x2...,T2
291,736,TCGA-AC-A8OP.F4F5C477-30BB-41EE-B188-B20DE019F30A,0,Gender: F. Patient Location: Date of Service: ...,BRCA,-1,The report states that the invasive tumor size...,T1
845,1290,TCGA-E9-A1R6.8A865E33-082A-454F-A6DF-89E994206E65,1,OC ID: Gross Description: Lump with the tumor ...,BRCA,0,The report states that the tumor size is 2.2 x...,T2
671,1116,TCGA-D8-A1J9.D80E9389-AAD8-4EEB-9DE0-5DE57B5E5F6B,0,page 1 / 2. copy No. Examination: Histopatholo...,BRCA,0,The report states that the tumor size is 1.8 x...,T1
800,1245,TCGA-E2-A1IO.A9D36308-DAEF-48B9-A4EC-6B9B24EE3DC2,0,SPECIMENS: A. SLN #1 RIGHT AXILLA. B. RIGHT BR...,BRCA,0,The report states that the tumor size is 1.2 c...,T1


In [7]:
class TrainingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
    rules: List[str] = Field(description="list of rules") 
training_schema = TrainingResponse.model_json_schema()

class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
testing_schema = TestingResponse.model_json_schema()

memory_agent = MemoryAgent(client=client, model="mistralai/Mistral-7B-Instruct-v0.2",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "learning_prompt":subsequent_predict_prompt,
                                                 "testing_prompt":testing_predict_prompt},
                           schema_dict={"learning_schema":training_schema,
                                        "testing_schema":testing_schema})

In [8]:
memory_agent.train(df_samples)

100%|██████████| 10/10 [00:53<00:00,  5.34s/it]


Unnamed: 0.1,Unnamed: 0,patient_filename,t,text,type,n,ans_str,reasoning
197,642,TCGA-A8-A086.D4E86E20-B75E-4CDA-83F2-22C481D4B9E8,0,Diagnosis: 1. Moderately differentiated invasi...,BRCA,1,T1,The T stage in this report indicates the size ...
125,570,TCGA-A7-A0CH.F70DA7E5-2AB3-487F-8B2C-72EEE91E58E0,1,SPECIMEN. A. Sentinel node left breast. B. Lef...,BRCA,-1,T2,The T stage is determined by the size of the p...
596,1041,TCGA-BH-A0HY.E57B4989-102F-44FD-8FB2-0627143FA904,0,FINAL DIAGNOSIS: PART 1: LEFT AXILLARY SENTINE...,BRCA,0,T1,The report states that the maximum dimension o...
311,756,TCGA-AN-A0FF.9B4BA15D-A071-4213-AD2C-075A04ED48CD,0,Sex: Female. Diagnosis: Breast Cancer Histolog...,BRCA,0,T1,The pathology report states that the T stage i...
1023,1509,TCGA-UU-A93S.821ED144-DF12-4E49-ADC7-27FA5E422B83,3,Sex: F. Account: Date Collected: Date Received...,BRCA,3,T4,The report indicates that the tumor size is 19...
378,823,TCGA-AO-A1KR.FF188295-E139-4AEE-8EE8-364536F2BBE8,1,Clinical Diagnosis & History: y/o female with ...,BRCA,0,T2,The T stage is determined by the size and exte...
560,1005,TCGA-BH-A0DS.E4CBDC29-63E2-4CF0-B02F-16CA9C69F26E,1,P.15/33. GSIS;. TAXILLARY SENTINEL LYMPH NODE ...,BRCA,2,T2,The report states that the main tumor mass mea...
250,695,TCGA-AC-A23G.810663DD-1718-4FD6-B9C2-77B4C091B3F1,0,Acct/Res. DIAGNOSIS. DIAGNOSIS. A. Right axili...,BRCA,1,T1,The report states that the tumor size is up to...
911,1356,TCGA-EW-A1PC.4A0524B4-2A0B-4442-91B6-0E2FEA90CBC6,2,Pathologic Interpretation: A. Sentinel node #1...,BRCA,-1,T3,The T stage is determined by the size and exte...
646,1091,TCGA-BH-A208.4F943D12-E769-45F3-86BE-75193786DD4E,1,DIABNOSIS: 16 H&E STAINED SLIDES. PREVIOUS REP...,BRCA,1,T3,"The report describes the tumor size as 6.0 cm,..."


In [11]:
memory_agent.memory

['T stage is determined by the size and extension of the primary tumor.',
 'T1 stage indicates a tumor size of less than 2 cm.',
 'T stage can also be determined by the presence or absence of certain clinical findings such as ulceration, fixation to the chest wall or skin, and distant metastasis.',
 'Invasion of the skeletal muscle is a clinical finding that contributes to a T4 stage.',
 'A tumor size of 2 cm or greater is indicative of a T2 or higher stage.',
 'Direct invasion into the dermis is a clinical finding that contributes to a higher T stage.',
 'Presence of vascular invasion and perineural invasion are clinical findings that contribute to a higher T stage.',
 'Presence of extracapsular extension in lymph nodes is a clinical finding that contributes to a higher T stage.']

In [10]:
memory_agent.test(df_testing_samples)

100%|██████████| 5/5 [00:11<00:00,  2.25s/it]


Unnamed: 0.1,Unnamed: 0,patient_filename,t,text,type,n,reasoning,ans_str
858,1303,TCGA-E9-A228.B0E8F3C1-E996-4E70-9630-AC95AF6E4EDC,1,"aterality: Right, lower inner quadrant. Path R...",BRCA,1,The report states that the tumor size is 2.5x2...,T2
291,736,TCGA-AC-A8OP.F4F5C477-30BB-41EE-B188-B20DE019F30A,0,Gender: F. Patient Location: Date of Service: ...,BRCA,-1,The report states that the invasive tumor size...,T1
845,1290,TCGA-E9-A1R6.8A865E33-082A-454F-A6DF-89E994206E65,1,OC ID: Gross Description: Lump with the tumor ...,BRCA,0,The report states that the tumor size is 2.2 x...,T2
671,1116,TCGA-D8-A1J9.D80E9389-AAD8-4EEB-9DE0-5DE57B5E5F6B,0,page 1 / 2. copy No. Examination: Histopatholo...,BRCA,0,The report states that the tumor size is 1.8 x...,T1
800,1245,TCGA-E2-A1IO.A9D36308-DAEF-48B9-A4EC-6B9B24EE3DC2,0,SPECIMENS: A. SLN #1 RIGHT AXILLA. B. RIGHT BR...,BRCA,0,The report states that the tumor size is 1.2 c...,T1
