# Model for testing

```bash
CUDA_VISIBLE_DEVICES="4,5", python -m vllm.entrypoints.openai.api_server \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--download-dir /secure/chiahsuan/hf_cache/ \
--tensor-parallel-size 2
```

# 1 Load packages and define class
== The following two cells can be organized into a separate py file as a module.==

In [22]:
from openai import OpenAI
import pandas as pd
from tqdm import tqdm
from prompt import system_instruction, baseline_prompt, initial_predict_prompt, \
subsequent_predict_prompt, testing_predict_prompt
import json
from pydantic import BaseModel, Field
from typing import List, Dict, Union
import re
import ast

# define class
class ChoiceAgent:
    """ the simplest agent, which is appropriate for zero-shot prompting
    """
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template: str, choices: dict) -> None:
        self.client = client
        self.model = model
        self.prompt_template = prompt_template
        self.choices = choices

    def get_response_from_choices(self, messages: list, temperature:float) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            extra_body={"guided_choice":self.choices},
            temperature = temperature
        )
        return response.choices[0].message.content

    def run(self, dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:

        pbar = tqdm(total=dataset.shape[0])
        for idx, row in dataset.iterrows():
            report = row["text"]
            prompt = self.prompt_template.format(report=report)
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]
            answer = self.get_response_from_choices(messages, temperature)
            dataset.loc[idx, "ans_str"] = answer
            pbar.update(1)
        pbar.close()

        return dataset

In [14]:
class MemoryAgent:
    """ the implementation of memory agent, which learn memory from training set and
    utilize memory as contexts for testing set.
    """
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict) -> None:
        self.client = client
        self.model = model
        self.prompt_template_dict = prompt_template_dict
        self.validate_prompt_template()
        self.schema_dict = schema_dict
        self.validate_schema()
        self.memory = ""

    def validate_prompt_template(self) -> None:
        keys = self.prompt_template_dict.keys()
        initial_prompt_exist = "initialized_prompt" in keys
        learning_prompt_exist = "learning_prompt" in keys
        testing_prompt_exist = "testing_prompt" in keys
        assert True == initial_prompt_exist == learning_prompt_exist == testing_prompt_exist, \
        "You should provide a dict with initialized_prompt, learning_prompt, and testing_prompt as keys."

    def validate_schema(self) -> None:
        keys = self.schema_dict.keys()
        learning_schema_exist = "learning_schema" in keys
        testing_schema_exist = "testing_schema" in keys
        assert True == learning_schema_exist == testing_schema_exist, \
        "You should provide a dict with learning_schema and testing_schema as keys."

    def get_schema_followed_response(self, messages: list, schema:dict, temperature:float) -> Union[Dict, None]:
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                extra_body={"guided_json":schema},
                temperature = temperature
            )
            return json.loads(response.choices[0].message.content.replace("\\", "\\\\"))
        except json.JSONDecodeError:
            return None
        
    def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=training_dataset.shape[0])
        parsing_error = 0
        for idx, row in training_dataset.iterrows():

            report = row["text"]

            if self.memory == "":
                prompt = self.prompt_template_dict["initialized_prompt"].format(report=report)
            else:
                prompt = self.prompt_template_dict["learning_prompt"].format(memory=self.memory, report=report)
            
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]

            json_output = self.get_schema_followed_response(messages, self.schema_dict["learning_schema"], temperature)

            if not json_output:
                parsing_error += 1
                continue
            
            self.memory = json_output['rules']
            training_dataset.loc[idx, "reasoning"] = json_output['reasoning']
            training_dataset.loc[idx, "ans_str"] = json_output['predictedStage']
            
            pbar.update(1)
        pbar.close()
        return training_dataset
    
    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=testing_dataset.shape[0])
        parsing_error = 0
        for idx, row in testing_dataset.iterrows():

            report = row["text"]

            prompt = self.prompt_template_dict["testing_prompt"].format(memory=self.memory, report=report)
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]

            json_output = self.get_schema_followed_response(messages, self.schema_dict["testing_schema"], temperature)

            if not json_output:
                parsing_error += 1
                continue
            
            testing_dataset.loc[idx, "reasoning"] = json_output['reasoning']
            testing_dataset.loc[idx, "ans_str"] = json_output['predictedStage']
            
            pbar.update(1)
        pbar.close()
        return testing_dataset

In [15]:
class ConditionalMemoryAgent(MemoryAgent):
  
  def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict) -> None:
    # inherit all properties and methods from MemoryAgent
    super().__init__(client, model, prompt_template_dict, schema_dict)
  
  def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
    # only overide this function because the rest parts are the same
    pass

# 2 Demonstration

In [16]:
openai_api_key = "Empty"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

brca_report = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/brca_df.csv")
df_training = brca_report.iloc[:500, :]
df_testing = brca_report.iloc[500:, :]

In [17]:
zs_agent = ChoiceAgent(client=client, model="mistralai/Mistral-7B-Instruct-v0.2",
                 prompt_template=baseline_prompt,
                 choices=["T1", "T2", "T3", "T4"])

In [18]:
df_zs = zs_agent.run(df_testing)

  0%|          | 0/531 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset.loc[idx, "ans_str"] = answer
100%|██████████| 531/531 [02:17<00:00,  3.87it/s]


In [19]:
class TrainingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
    rules: List[str] = Field(description="list of rules") 
training_schema = TrainingResponse.model_json_schema()

class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
testing_schema = TestingResponse.model_json_schema()

memory_agent = MemoryAgent(client=client, model="mistralai/Mistral-7B-Instruct-v0.2",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "learning_prompt":subsequent_predict_prompt,
                                                 "testing_prompt":testing_predict_prompt},
                           schema_dict={"learning_schema":training_schema,
                                        "testing_schema":testing_schema})

In [20]:
_ = memory_agent.train(df_training)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, "reasoning"] = json_output['reasoning']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, "ans_str"] = json_output['predictedStage']
 26%|██▌       | 130/500 [21:52<1:45:16, 17.07s/it]

KeyboardInterrupt: 

In [11]:
memory_agent.memory

['T stage is determined by the size and extension of the primary tumor.',
 'T1 stage indicates a tumor size of less than 2 cm.',
 'T stage can also be determined by the presence or absence of certain clinical findings such as ulceration, fixation to the chest wall or skin, and distant metastasis.',
 'Invasion of the skeletal muscle is a clinical finding that contributes to a T4 stage.',
 'A tumor size of 2 cm or greater is indicative of a T2 or higher stage.',
 'Direct invasion into the dermis is a clinical finding that contributes to a higher T stage.',
 'Presence of vascular invasion and perineural invasion are clinical findings that contribute to a higher T stage.',
 'Presence of extracapsular extension in lymph nodes is a clinical finding that contributes to a higher T stage.']

In [None]:
print(len(memory_agent.memory))

In [10]:
df_memory = memory_agent.test(df_testing)

100%|██████████| 5/5 [00:11<00:00,  2.25s/it]


Unnamed: 0.1,Unnamed: 0,patient_filename,t,text,type,n,reasoning,ans_str
858,1303,TCGA-E9-A228.B0E8F3C1-E996-4E70-9630-AC95AF6E4EDC,1,"aterality: Right, lower inner quadrant. Path R...",BRCA,1,The report states that the tumor size is 2.5x2...,T2
291,736,TCGA-AC-A8OP.F4F5C477-30BB-41EE-B188-B20DE019F30A,0,Gender: F. Patient Location: Date of Service: ...,BRCA,-1,The report states that the invasive tumor size...,T1
845,1290,TCGA-E9-A1R6.8A865E33-082A-454F-A6DF-89E994206E65,1,OC ID: Gross Description: Lump with the tumor ...,BRCA,0,The report states that the tumor size is 2.2 x...,T2
671,1116,TCGA-D8-A1J9.D80E9389-AAD8-4EEB-9DE0-5DE57B5E5F6B,0,page 1 / 2. copy No. Examination: Histopatholo...,BRCA,0,The report states that the tumor size is 1.8 x...,T1
800,1245,TCGA-E2-A1IO.A9D36308-DAEF-48B9-A4EC-6B9B24EE3DC2,0,SPECIMENS: A. SLN #1 RIGHT AXILLA. B. RIGHT BR...,BRCA,0,The report states that the tumor size is 1.2 c...,T1


#### (temporary) Memory saturation test

In [29]:
def find_memory(file_path):
    with open(file_path, 'r') as file:
        content = file.read()
    matches = re.findall(r" memory: (\[.*?\])", content, re.DOTALL)
    memory_lst = [ast.literal_eval(match) for match in matches]
    return memory_lst

file_path = '/home/yl3427/cylab/selfCorrectionAgent/memory.txt'

memories = find_memory(file_path)

print(len(memories))
print(memories)

32
[['T stage is determined by the size of the primary tumor.', 'T1 tumors are 2 cm or less in size.', 'T2 tumors are more than 2 cm but not more than 5 cm in size.', 'T3 tumors are more than 5 cm in size.', 'T4 tumors involve the chest wall or skin.'], ['T stage is determined by the size of the primary tumor.', 'T1 tumors are 2 cm or less in size.', 'T2 tumors are more than 2 cm but not more than 5 cm in size.', 'T3 tumors are more than 5 cm in size.', 'T4 tumors involve the chest wall or skin.'], ['T stage is determined by the size of the primary tumor.', 'T1 tumors are 2 cm or less in size.', 'T2 tumors are more than 2 cm but not more than 5 cm in size.', 'T3 tumors are more than 5 cm in size.', 'T4 tumors involve the chest wall or skin.'], ['T stage is determined by the size of the primary tumor.', 'T1 tumors are 2 cm or less in size.', 'T2 tumors are more than 2 cm but not more than 5 cm in size.', 'T3 tumors are more than 5 cm in size.', 'T4 tumors involve the chest wall or skin.

In [30]:
class TestAgent:
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template: str, schema, memory_lst: List[str]) -> None:
        self.client = client
        self.model = model
        self.prompt_template = prompt_template
        self.schema = schema
        self.memory_lst = memory_lst


    def get_schema_followed_response(self, messages: list, schema:dict, temperature:float) -> Union[Dict, None]:
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                extra_body={"guided_json":schema},
                temperature = temperature
            )
            return json.loads(response.choices[0].message.content.replace("\\", "\\\\"))
        except json.JSONDecodeError:
            return None

    
    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0):
        for mem_idx, memory in enumerate(self.memory_lst):
            print(f"memory index: {mem_idx}")
            parsing_error = 0
            correct_count = 0
            incorrect_count = 0
            for report_idx, row in testing_dataset.iterrows():
                # print(report_idx)
                report = row["text"]
                label = row["t"]

                prompt = self.prompt_template.format(memory=memory, report=report)
                system_prompt = system_instruction+ "\n" + prompt
                messages = [{"role": "user", "content": system_prompt}]

                json_output = self.get_schema_followed_response(messages, self.schema, temperature)

                if not json_output:
                    parsing_error += 1
                    continue
                
                if f"T{label+1}" == json_output["predictedStage"]:
                    correct_count += 1
                else:
                    incorrect_count +=1

            print(f"\tcorrect: {correct_count}")
            print(f"\twrong: {incorrect_count}")
            print(f"\tparsing error: {parsing_error}")
            print()

        

In [33]:
openai_api_key = "Empty"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

brca_report = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/brca_df.csv")
df_testing = brca_report.iloc[500:600, :]

In [35]:
len(df_testing)

100

In [34]:
class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
testing_schema = TestingResponse.model_json_schema()

memory_agent = TestAgent(client=client, model="mistralai/Mistral-7B-Instruct-v0.2",
                           prompt_template = testing_predict_prompt,
                           schema = testing_schema,memory_lst = memories)
memory_agent.test(df_testing)

memory index: 0
	correct: 79
	wrong: 21
	parsing error: 0

memory index: 1
	correct: 79
	wrong: 21
	parsing error: 0

memory index: 2
	correct: 79
	wrong: 21
	parsing error: 0

memory index: 3
	correct: 80
	wrong: 20
	parsing error: 0

memory index: 4
	correct: 80
	wrong: 20
	parsing error: 0

memory index: 5
	correct: 79
	wrong: 21
	parsing error: 0

memory index: 6
	correct: 79
	wrong: 21
	parsing error: 0

memory index: 7
	correct: 78
	wrong: 22
	parsing error: 0

memory index: 8
	correct: 78
	wrong: 22
	parsing error: 0

memory index: 9
	correct: 78
	wrong: 22
	parsing error: 0

memory index: 10
	correct: 77
	wrong: 23
	parsing error: 0

memory index: 11
	correct: 76
	wrong: 24
	parsing error: 0

memory index: 12
	correct: 78
	wrong: 22
	parsing error: 0

memory index: 13
	correct: 77
	wrong: 23
	parsing error: 0

memory index: 14
	correct: 78
	wrong: 22
	parsing error: 0

memory index: 15
	correct: 79
	wrong: 21
	parsing error: 0

memory index: 16
	correct: 79
	wrong: 21
	parsing 