# 0 Model for testing

```bash
CUDA_VISIBLE_DEVICES="4,5", python -m vllm.entrypoints.openai.api_server \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--disable-custom-all-reduce \
--enforce-eager \
--download-dir /secure/chiahsuan/hf_cache/ \
--tensor-parallel-size 2 \
--port 8085
```

```bash
CUDA_VISIBLE_DEVICES="0,1,2,3", python -m vllm.entrypoints.openai.api_server \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--download-dir /secure/chiahsuan/hf_cache/ \
--tensor-parallel-size 4 \
--disable-custom-all-reduce \
--enforce-eager
```
mistralai/Mixtral-8x7B-Instruct-v0.1: always port 8000

# 1 Load packages and define class
== The following two cells can be organized into a separate py file as a module.==

In [1]:
from src.prompt import system_instruction
from src.prompt import baseline_prompt_t14 as baseline_prompt
from src.prompt import initial_predict_prompt_t14 as initial_predict_prompt
from src.prompt import subsequent_predict_prompt_t14 as subsequent_predict_prompt
from src.prompt import testing_predict_prompt_t14 as testing_predict_prompt
from src.metrics import n03_performance_report

import openai
from openai import OpenAI
from sklearn.model_selection import KFold
import numpy as np
import pandas as pd
from tqdm import tqdm
import time
import json
from pydantic import BaseModel, Field
from typing import List, Dict, Union
from fuzzywuzzy import fuzz
import re
import ast



## 1.1 Define ChoiceAgent

In [2]:
# define class
class ChoiceAgent:
    """ the simplest agent, which is appropriate for zero-shot prompting
    """
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template: str, choices: dict, label: str) -> None:
        self.client = client
        self.model = model
        self.prompt_template = prompt_template
        self.choices = choices
        self.label = label

    def get_response_from_choices(self, messages: list, temperature:float) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            extra_body={"guided_choice":self.choices},
            temperature = temperature
        )
        return response.choices[0].message.content

    def run(self, dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=dataset.shape[0])
        for idx, row in dataset.iterrows():
            report = row["text"]
            prompt = self.prompt_template.format(report=report)
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]
            answer = self.get_response_from_choices(messages, temperature)
            dataset.loc[idx, f"zs_{self.label}_ans_str"] = answer
            pbar.update(1)
        pbar.close()

        return dataset

## 1.2 Define MemoryAgent

In [3]:
class MemoryAgent:
    """ the implementation of memory agent, which learn memory from training set and
    utilize memory as contexts for testing set.
    """
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict, label: str) -> None:
        self.client = client
        self.model = model
        self.prompt_template_dict = prompt_template_dict
        self.validate_prompt_template()
        self.schema_dict = schema_dict
        self.validate_schema()
        self.memory = ""
        self.label = label

    def validate_prompt_template(self) -> None:
        keys = self.prompt_template_dict.keys()
        initial_prompt_exist = "initialized_prompt" in keys
        learning_prompt_exist = "learning_prompt" in keys
        testing_prompt_exist = "testing_prompt" in keys
        assert True == initial_prompt_exist == learning_prompt_exist == testing_prompt_exist, \
        "You should provide a dict with initialized_prompt, learning_prompt, and testing_prompt as keys."

    def validate_schema(self) -> None:
        keys = self.schema_dict.keys()
        learning_schema_exist = "learning_schema" in keys
        testing_schema_exist = "testing_schema" in keys
        assert True == learning_schema_exist == testing_schema_exist, \
        "You should provide a dict with learning_schema and testing_schema as keys."

    def get_schema_followed_response(self, messages: list, schema:dict, temperature:float) -> Union[Dict, None]:
        num_attempt = 3
        for attempt in range(num_attempt):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    extra_body={"guided_json":schema},
                    temperature = temperature
                )
                return json.loads(response.choices[0].message.content.replace("\\", "\\\\"))
            except json.JSONDecodeError:
                print("Error decoding JSON response")
                return None
            except openai.APITimeoutError:
                if attempt < (num_attempt -1):
                    wait_time = 2 * (attempt + 1)
                    print(f"Request timed out. Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                else:
                    print("Max retries reached. Request faild.")
                    return None
        
    def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=training_dataset.shape[0])
        parsing_error = 0
        for idx, row in training_dataset.iterrows():

            report = row["text"]

            if self.memory == "":
                prompt = self.prompt_template_dict["initialized_prompt"].format(report=report)
            else:
                prompt = self.prompt_template_dict["learning_prompt"].format(memory=self.memory, report=report)
            
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]

            json_output = self.get_schema_followed_response(messages, self.schema_dict["learning_schema"], temperature)

            if not json_output:
                parsing_error += 1
                continue
            
            self.memory = json_output['rules']
            training_dataset.loc[idx, f"mem_{self.label}_reasoning"] = json_output['reasoning']
            training_dataset.loc[idx, f"mem_{self.label}_ans_str"] = json_output['predictedStage']
            
            pbar.update(1)
        pbar.close()
        print(f"Number of parsing errors: {parsing_error}")
        return training_dataset
    
    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        pbar = tqdm(total=testing_dataset.shape[0])
        parsing_error = 0
        for idx, row in testing_dataset.iterrows():

            report = row["text"]

            prompt = self.prompt_template_dict["testing_prompt"].format(memory=self.memory, report=report)
            system_prompt = system_instruction+ "\n" + prompt
            messages = [{"role": "user", "content": system_prompt}]

            json_output = self.get_schema_followed_response(messages, self.schema_dict["testing_schema"], temperature)

            if not json_output:
                parsing_error += 1
                continue
            
            testing_dataset.loc[idx, f"mem_{self.label}_reasoning"] = json_output['reasoning']
            testing_dataset.loc[idx, f"mem_{self.label}_ans_str"] = json_output['predictedStage']
            
            pbar.update(1)
        pbar.close()
        print(f"Number of parsing errors: {parsing_error}")
        return testing_dataset

## 1.3 Define ConditionalMemoryAgent

In [4]:
class ConditionalMemoryAgent(MemoryAgent):
  
  def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict, label: str) -> None:
    # inherit all properties and methods from MemoryAgent
    super().__init__(client, model, prompt_template_dict, schema_dict, label)
  
  def is_updated(self, new_memory, threshold):
    old_str = "\n".join(self.memory)
    new_str = "\n".join(new_memory)
    if fuzz.ratio(old_str, new_str) >= threshold : 
        return True # update memory
    else:
        return False
    
  def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0, threshold: float = 80) -> pd.DataFrame:
    # only overide this function because the rest parts are the same
    pbar = tqdm(total=training_dataset.shape[0])
    parsing_error = 0
    num_update = 0
    for idx, row in training_dataset.iterrows():

        report = row["text"]

        if self.memory == "":
            prompt = self.prompt_template_dict["initialized_prompt"].format(report=report)
        else:
            prompt = self.prompt_template_dict["learning_prompt"].format(memory=self.memory, report=report)
        
        system_prompt = system_instruction+ "\n" + prompt
        messages = [{"role": "user", "content": system_prompt}]

        json_output = self.get_schema_followed_response(messages, self.schema_dict["learning_schema"], temperature)

        if not json_output:
            parsing_error += 1
            continue
        
        if self.memory == "":
           self.memory = json_output['rules']
        else:
          current_memory_str = "\n".join(self.memory)
          new_memory_str = "\n".join(json_output['rules'])
          if fuzz.ratio(current_memory_str, new_memory_str) >= threshold :
            self.memory = json_output['rules']
            num_update += 1

        training_dataset.loc[idx, f"cmem_{self.label}_reasoning"] = json_output['reasoning']
        training_dataset.loc[idx, f"cmem_{self.label}_ans_str"] = json_output['predictedStage']
        training_dataset.loc[idx, f"cmem_{self.label}_num_update"] = num_update
        
        pbar.update(1)
    pbar.close()
    print(f"Number of memory updates: {num_update}")
    print(f"Number of parsing errors: {parsing_error}")
    return training_dataset

  def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
    pbar = tqdm(total=testing_dataset.shape[0])
    parsing_error = 0
    for idx, row in testing_dataset.iterrows():

        report = row["text"]

        prompt = self.prompt_template_dict["testing_prompt"].format(memory=self.memory, report=report)
        system_prompt = system_instruction+ "\n" + prompt
        messages = [{"role": "user", "content": system_prompt}]

        json_output = self.get_schema_followed_response(messages, self.schema_dict["testing_schema"], temperature)

        if not json_output:
            parsing_error += 1
            continue
        
        testing_dataset.loc[idx, f"cmem_{self.label}_reasoning"] = json_output['reasoning']
        testing_dataset.loc[idx, f"cmem_{self.label}_ans_str"] = json_output['predictedStage']
        
        pbar.update(1)
    pbar.close()
    print(f"Number of parsing errors: {parsing_error}")
    return testing_dataset

## 1.4 To-be-Implemented

In [None]:
class MultiTaskMemoryAgent(MemoryAgent):
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict) -> None:
        # inherit all properties and methods from MemoryAgent
        super().__init__(client, model, prompt_template_dict, schema_dict)
  
    def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        # overide this function
        pass

    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        # overide this function
        pass

class MemoryAgentWithVerifier(MemoryAgent): # inherit from MemoryAgent or MultiTaskMemoryAgent
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict) -> None:
        # inherit all properties and methods from MemoryAgent
        super().__init__(client, model, prompt_template_dict, schema_dict)
  
    def train(self, training_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        # overide this function
        pass

# 2 Demonstration

In [4]:
openai_api_key = "Empty"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

brca_report = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/brca_df.csv")
print(len(brca_report))
brca_report = brca_report[brca_report["n"]!=-1]
print(len(brca_report))

df_training_samples = brca_report.iloc[:80, :]
df_testing_samples = brca_report.iloc[80:, :]

print(len(df_training_samples))
print(len(df_testing_samples))

1031
800
80
720


## 2.1 Initialize ChoiceAgent and use its instance

In [5]:
# zs_agent = ChoiceAgent(client=client, model="mistralai/Mistral-7B-Instruct-v0.2",
#                  prompt_template=baseline_prompt,
#                  choices=["T1", "T2", "T3", "T4"])
   
zs_agent = ChoiceAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                 prompt_template=baseline_prompt,
                 choices=["T1", "T2", "T3", "T4"], label = "t")

In [7]:
zs_agent.run(dataset=brca_report).to_csv("brca_df_zs_t14.csv", index=False)

100%|██████████| 800/800 [05:14<00:00,  2.54it/s]


In [8]:
brca_report.to_csv("result/brca_df_zs_n03.csv", index=False)

## 2.2 Initialize MemoryAgent and use its instance

In [None]:
class TrainingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
    rules: List[str] = Field(description="list of rules") 
training_schema = TrainingResponse.model_json_schema()

class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
testing_schema = TestingResponse.model_json_schema()

In [None]:
memory_agent = MemoryAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "learning_prompt":subsequent_predict_prompt,
                                                 "testing_prompt":testing_predict_prompt},
                           schema_dict={"learning_schema":training_schema,
                                        "testing_schema":testing_schema},
                                        label = "n")

In [None]:
memory_agent.train(df_training_samples, temperature=0.001)

In [None]:
memory_agent.memory

In [None]:
print(len(memory_agent.memory))

In [None]:
memory_agent.test(df_testing_samples)

## 2.3 Initialize ConditionalMemoryAgent and use its instance

In [None]:
conditional_memory_agent = ConditionalMemoryAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "learning_prompt":subsequent_predict_prompt,
                                                 "testing_prompt":testing_predict_prompt},
                           schema_dict={"learning_schema":training_schema,
                                        "testing_schema":testing_schema},
                                        label = "n")

In [None]:
conditional_memory_agent.train(df_training_samples)

In [None]:
conditional_memory_agent.memory

In [None]:
len(conditional_memory_agent.memory)

In [None]:
conditional_memory_agent.test(df_testing_samples)

In [None]:
df_training_samples.to_csv("df_training.csv", index=False)
df_testing_samples.to_csv("df_testing.csv", index=False)

## 2.4 Memory Saturation Test

In [6]:
class TrainingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
    rules: List[str] = Field(description="list of rules") 
training_schema = TrainingResponse.model_json_schema()

class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
testing_schema = TestingResponse.model_json_schema()

In [7]:
class FixedTestSizeCV:
    def __init__(self, num_test_points):
        self.num_test_points = num_test_points

    def split(self, X, y=None):
        n_samples = len(X)
        indices = np.arange(n_samples)
        np.random.shuffle(indices)
        test_indices = indices[:self.num_test_points]
        train_indices = indices[self.num_test_points:]
        yield train_indices, test_indices

sorted_df = brca_report.reset_index(drop=True)

for size in range(10, 101, 10):
    cv = FixedTestSizeCV(num_test_points=size)
    for test_idx, train_idx in cv.split(sorted_df):
        memory_agent = ConditionalMemoryAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "learning_prompt":subsequent_predict_prompt,
                                                 "testing_prompt":testing_predict_prompt},
                           schema_dict={"learning_schema":training_schema,
                                        "testing_schema":testing_schema},
                                        label = "n")
        df_train, df_test = sorted_df.iloc[train_idx], sorted_df.iloc[test_idx]
        train_result = memory_agent.train(df_train)
        train_result.to_csv(f"saturation_train_result_{size}.csv", index=False)
        test_result = memory_agent.test(df_test)
        test_result.to_csv(f"saturation_test_result_{size}.csv", index=False)



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, f"cmem_{self.label}_reasoning"] = json_output['reasoning']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, f"cmem_{self.label}_ans_str"] = json_output['predictedStage']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, f"cmem

Number of memory updates: 9
Number of parsing errors: 0


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  testing_dataset.loc[idx, f"cmem_{self.label}_reasoning"] = json_output['reasoning']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  testing_dataset.loc[idx, f"cmem_{self.label}_ans_str"] = json_output['predictedStage']
100%|██████████| 790/790 [54:14<00:00,  4.12s/it]  


Number of parsing errors: 0


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, f"cmem_{self.label}_reasoning"] = json_output['reasoning']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, f"cmem_{self.label}_ans_str"] = json_output['predictedStage']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_dataset.loc[idx, f"cmem

KeyboardInterrupt: 

In [None]:
for size in [10, 20, 30, 40]:
    result_df = pd.read_csv(f"saturation_test_result_{size}.csv")
    n03_performance_report(df=result_df, ans_col="mem_n_ans_str")

# 3 Multi-Turn Conversation

In [None]:
class MultiTurnMemoryAgent:
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template_dict: dict[str, str], schema_dict: dict, label: str) -> None:
        self.client = client
        self.model = model
        self.prompt_template_dict = prompt_template_dict
        self.schema_dict = schema_dict
        self.label = label
        self.memory = ""
    
    def is_updated(self, new_memory):
        old_str = "\n".join(self.memory)
        new_str = "\n".join(new_memory)
        if fuzz.ratio(old_str, new_str) >= 80 : 
            return True # update memory
        else:
            return False
        
    def get_schema_followed_response(self, messages: list, schema:dict, temperature:float) -> Union[Dict, None]:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            extra_body=schema,
            temperature = temperature
        )
        return response.choices[0].message.content.replace("\\", "\\\\")
        
    def train(self, training_dataset: pd.DataFrame, conditional_update: bool, temperature: float = 0.0) -> pd.DataFrame:
        parsing_error = 0
        num_update = 0
        update_condition = True
        for idx, row in training_dataset.iterrows():
            report = row["text"]
            if self.memory == "":
                prediction_prompt = self.prompt_template_dict["initialized_prompt"].format(report=report)
            else:
                prediction_prompt = self.prompt_template_dict["prediction_prompt"].format(memory=self.memory, report=report)

            messages = [{"role": "user", "content": prediction_prompt}]
            ans_format = {"guided_choice":self.schema_dict["predic_choice"]}
            pred = self.get_schema_followed_response(messages, ans_format, temperature)
            training_dataset.loc[idx, f"{self.label}_pred"] = pred
 
            reason_prompt = self.prompt_template_dict["reason_prompt"]
            messages.append({"role": "assistant", "content": pred}, 
                            {"role": "user", "content": reason_prompt})
            ans_format = {"guided_json":self.schema_dict["reason_schema"]}
            try:
                reason = json.loads(self.get_schema_followed_response(messages, ans_format, temperature))['reasoning']
                training_dataset.loc[idx, f"{self.label}_reason"] = reason  
            except json.JSONDecodeError:
                parsing_error += 1
                continue
    
            rule_prompt = self.prompt_template_dict["rule_prompt"]
            messages.append({"role": "assistant", "content": reason}, 
                            {"role": "user", "content": rule_prompt})
            ans_format = {"guided_json":self.schema_dict["rule_schema"]}
            try:
                rule = json.loads(self.get_schema_followed_response(messages, ans_format, temperature))['rules']
                training_dataset.loc[idx, f"{self.label}_rule"] = rule
            except json.JSONDecodeError:
                parsing_error += 1
                continue
                
            if conditional_update == True:
                update_condition = self.is_updated(rule)

            if self.memory == "" or update_condition:
                self.memory = rule
                num_update += 1
                
        print(f"Number of parsing errors: {parsing_error}")
        print(f"Number of memory updates: {num_update}")
        return training_dataset
    
    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0) -> pd.DataFrame:
        parsing_error = 0
        for idx, row in testing_dataset.iterrows():
            report = row["text"]
            prediction_prompt = self.prompt_template_dict["prediction_prompt"].format(memory=self.memory, report=report)
            messages = [{"role": "user", "content": prediction_prompt}]
            ans_format = {"guided_choice":self.schema_dict["predic_choice"]}
            pred = self.get_schema_followed_response(messages, ans_format, temperature)
            testing_dataset.loc[idx, f"{self.label}_pred"] = pred
            
            reason_prompt = self.prompt_template_dict["reason_prompt"]
            messages.append({"role": "assistant", "content": pred}, 
                            {"role": "user", "content": reason_prompt})
            ans_format = {"guided_json":self.schema_dict["reason_schema"]}
            try:
                reason = json.loads(self.get_schema_followed_response(messages, ans_format, temperature))['reasoning']
                testing_dataset.loc[idx, f"{self.label}_reason"] = reason  
            except json.JSONDecodeError:
                parsing_error += 1
                continue
       
        print(f"Number of parsing errors: {parsing_error}")
        return testing_dataset

In [None]:
pred_choices=["N0", "N1", "N2", "N3"]

class ReasonSchema(BaseModel):
    reasoning: str = Field(description="reasoning to support predicted cancer stage")     
reason_schema = ReasonSchema.model_json_schema()

class RuleSchema(BaseModel):
    rules: List[str] = Field(description="list of rules")
rule_schema = RuleSchema.model_json_schema()

In [None]:
multi_turn_agent = MultiTurMemoryAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "prediction_prompt":subsequent_predict_prompt,
                                                 "reason_prompt":reason_prompt,
                                                 "rule_prompt":rule_prompt},
                           schema_dict={"predic_choice":pred_choices,
                                        "reason_schema":reason_schema,
                                        "rule_schema":rule_schema},
                                        label = "n")

# 998 Experiemntal Plan

In [None]:
# run T task only
T_memory_agent = MemoryAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                           prompt_template_dict={"initialized_prompt":initial_predict_prompt,
                                                 "learning_prompt":subsequent_predict_prompt,
                                                 "testing_prompt":testing_predict_prompt},
                           schema_dict={"learning_schema":training_schema,
                                        "testing_schema":testing_schema})

# run N task only
N_memory_agent = MemoryAgent(client=client, model=...,
                           prompt_template_dict=...,
                           schema_dict=...)

# run T and N task simultaneously
TN_memory_agent = MultiTaskMemoryAgent(client=client, model=...,
                           prompt_template_dict=...,
                           schema_dict=...)

```python
""" the following pseudo codes are for checking number of instances are sufficient
"""
def stratified_sampling_function(pd.dataframe, target_column, size):
    pass

def eval():
    pass

def report():
    pass

performances = {}
for size in [5, 10, 20, 25, 30, 35, 40]:
    evaluated_scores = []
    # the process should be evaluated on K different splits and take the average performance
    for i, (train_indexes, test_indexes) in stratified_sampling_function(df): 
        
        # the memory is unique for each split
        N_memory_agent = MemoryAgent(client=client, model=...,
                           prompt_template_dict=...,
                           schema_dict=...)

        df_train = df.iloc[train_indexes,:]
        N_memory_agent.train(df_train)
        df_test = df.iloc[test_indexes,:]
        df_results = N_memory_agent.test(df_test)
        evaluated_scores.append(eval(df_results))

    performances[size] = report(evaluated_scores)
```

# 999 (temporary) Memory saturation test

In [None]:
def find_memory(file_path):
    with open(file_path, 'r') as file:
        content = file.read()
    matches = re.findall(r" memory: (\[.*?\])", content, re.DOTALL)
    memory_lst = [ast.literal_eval(match) for match in matches]
    return memory_lst

file_path = '/home/yl3427/cylab/selfCorrectionAgent/memory.txt'

memories = find_memory(file_path)

print(len(memories))
print(memories)

In [None]:
class TestAgent:
    def __init__(self, client: OpenAI, model: str, 
                 prompt_template: str, schema, memory_lst: List[str]) -> None:
        self.client = client
        self.model = model
        self.prompt_template = prompt_template
        self.schema = schema
        self.memory_lst = memory_lst


    def get_schema_followed_response(self, messages: list, schema:dict, temperature:float) -> Union[Dict, None]:
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                extra_body={"guided_json":schema},
                temperature = temperature
            )
            return json.loads(response.choices[0].message.content.replace("\\", "\\\\"))
        except json.JSONDecodeError:
            return None

    
    def test(self, testing_dataset: pd.DataFrame, temperature: float = 0.0):
        for mem_idx, memory in enumerate(self.memory_lst):
            print(f"memory index: {mem_idx}")
            parsing_error = 0
            correct_count = 0
            incorrect_count = 0
            for report_idx, row in testing_dataset.iterrows():
                # print(report_idx)
                report = row["text"]
                label = row["t"]

                prompt = self.prompt_template.format(memory=memory, report=report)
                system_prompt = system_instruction+ "\n" + prompt
                messages = [{"role": "user", "content": system_prompt}]

                json_output = self.get_schema_followed_response(messages, self.schema, temperature)

                if not json_output:
                    parsing_error += 1
                    continue
                
                if f"T{label+1}" == json_output["predictedStage"]:
                    correct_count += 1
                else:
                    incorrect_count +=1

            print(f"\tcorrect: {correct_count}")
            print(f"\twrong: {incorrect_count}")
            print(f"\tparsing error: {parsing_error}")
            print()

        

In [None]:
openai_api_key = "Empty"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

brca_report = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/brca_df.csv")
df_testing = brca_report.iloc[500:600, :]

In [None]:
len(df_testing)

In [None]:
class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 
testing_schema = TestingResponse.model_json_schema()

memory_agent = TestAgent(client=client, model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                           prompt_template = testing_predict_prompt,
                           schema = testing_schema,memory_lst = memories)
memory_agent.test(df_testing)