# Basic Setting

In [1]:
import os
import pandas as pd
from tqdm.auto import tqdm
import pickle
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pytz import timezone

# OpenAI + Langchain
from openai import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, LLMChain

import warnings
warnings.filterwarnings('ignore')

# API Key Configuration
OPENAI_API_KEY = "your_openai_api_key_here"
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# Default data path
DATA_PATH = "/path/to/data/directory"
DB_DIRECTORY = "/path/to/database/directory"
MODEL_NAME = "gpt-4o"  # or 'gpt-3.5-turbo', etc.


class DataRepository:
    def __init__(self, default_path):
        self.default_path = default_path
    
    def load_data(self, filename):
        file_path = os.path.join(self.default_path, filename)
        with open(file_path, 'rb') as file:
            return pickle.load(file)

    def save_data(self, data, filename):
        file_path = os.path.join(self.default_path, filename)
        with open(file_path, 'wb') as file:
            pickle.dump(data, file)

## RAGAssistant

In [2]:
class RAGAssistant:
    def __init__(self, llm, retriever, checkpoint_path):
        self.llm = llm
        self.retriever = retriever
        self.checkpoint_path = checkpoint_path

    def inference(self, df, prompt, iteration, checkpoint=False, init=False):
        df_temp = df.copy()
        result_col = 'base_answer' if init else f'rag_answer_{iteration}'
        source_col = f'rag_source_{iteration}'
        
        rag_template = prompt + """    
        Retrieved Context: {context}
        Question: {question}
        Answer:"""

        QA_CHAIN_PROMPT = PromptTemplate(template=rag_template, input_variables=['context', 'question'])
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm, 
            retriever=self.retriever, 
            return_source_documents = True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        
        for question in tqdm(df_temp['question']):
            response = qa_chain(question)
            df_temp.loc[df_temp['question'] == question, result_col] = response['result']
            df_temp.loc[df_temp['question'] == question, source_col] = response['source_documents'][0].metadata['source']
            
            if checkpoint:
                df_temp.to_csv(self.checkpoint_path, index=False)
                
        return df_temp

## RAGEvaluator

In [3]:
class RAGEvaluator:
    def __init__(self, llm, retriever, checkpoint_path):
        self.llm = llm
        self.retriever = retriever
        self.criteria = [
            'Scientific Accuracy and Depth', 'Alignment with Research Objectives',
            'Source Clarity and Reliability', 'Professional Engagement', 'Information Fidelity'
        ]
        self.checkpoint_path = checkpoint_path

    def eval_answer(self, llm_question, llm_ans):
        eval_prompt =  f"""
        As a leading researcher in plant physiology with a focus on photosynthesis, you are considering using a language model to refine research ideas and guide projects. To ensure responses are of the highest quality and relevance, and to avoid score inflation, assess each response according to the following criteria:
        
        **Scoring Metrics**:
        1. **Scientific Accuracy and Depth** : Evaluate how accurately the response reflects current scientific understanding and its depth concerning photosynthesis. Score from 1 to 10, where 10 indicates perfect accuracy and depth, and 1 indicates major inaccuracies or superficial information.
        2. **Alignment with Research Objectives** : Determine how well the response aligns with specific research objectives within plant physiology and photosynthesis. Score from 1 to 10, with 10 being perfectly aligned and 1 being irrelevant.
        3. **Source Clarity and Reliability** : Assess the clarity and accuracy of the source citations used in the response, such as specific studies, data sets, or scientific theories. Rate from 1 to 10, where 10 signifies sources are cited with exceptional clarity and reliability, and 1 where citations are poor or missing.
        4. **Professional Engagement** : Evaluate the professional tone of the response and its suitability for academic discussions among experts in photosynthesis. A score of 10 suggests an expert-level dialogue, while 1 suggests a non-professional or irrelevant tone.
        5. **Information Fidelity** : Check the fidelity of the information provided by evaluating its accuracy and the absence of fabricated content not supported by source material. Score from 1 to 10, where 10 indicates high fidelity and 1 indicates significant errors or fabrications.
        
        **Evaluation Instructions**:
        1. Focus on accuracy, clarity, depth, fidelity, and professional engagement without considering response length.
        2. A score above 5 is considered good; a score of 10 indicates a flawless response.

        **Feedback Guidance**:
        After evaluating each response, provide specific suggestions for improvements based on the metrics. Focus exclusively on areas that need enhancement to help achieve higher scores in future evaluations.
        Avoid reiterating strengths or commendable aspects of the response. Concentrate on identifying and suggesting practical steps for improvement in each scoring area.
        Ensure your feedback aims to enhance the quality, clarity, and depth of future interactions with the subject to mitigate any bias and to drive meaningful advancement in the field of photosynthesis research.
            
        Question toward LLM : {llm_question}
        """
        
        eval_template = eval_prompt + """
        Retrieved Context : {context}
        LLM's Answer : {question}
        
        The format of answer will be as below: You must follow this format strictly
        ``
        Scientific Accuracy and Depth : []
        Alignment with Research Objectives : []
        Source Clarity and Reliability : []
        Professional Engagement : []
        Information Fidelity : []
        Feedback : []
        ``
        `"""

        
        QA_CHAIN_PROMPT = PromptTemplate(
            template = eval_template,
            input_variables = [
                'context',
                'question'
            ])
        
        qa_chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever, chain_type_kwargs = {"prompt" : QA_CHAIN_PROMPT})
    
        response = qa_chain(llm_ans)
    
        return response['result']


    def score_parsing(self, response, criterion):
        try:
            score = int(response.split(criterion+' : [')[1][0])
        except:
            score = int(response.split(criterion+': [')[1][0])
        return score

    
    def feedback_parsing(self, response):
        try:
            feedback = response[response.index('Feedback : ')+len('Feedback : '):].strip()
        except:
            try:
                feedback = response[response.index('Feedback: ')+len('Feedback: '):].strip()
            except:
                feedback = 'none'
        return feedback


    def eval_process(self, df, iteration, check_point = False, init=False):
        df_temp = df.copy()
        col = 'base_score' if init else f'rag_score_{iteration}'
        ans_col = 'base_answer' if init else f'rag_answer_{iteration}'
      
        df_temp[col] = 'n'
        df_temp[col] = df_temp[col].astype(object)
        
        indices = list(df_temp[df_temp[col] == 'n'].index)
        
        while indices:
            for i in tqdm(indices):
                eval_result = self.eval_answer(df_temp.iloc[i]['question'], df_temp.iloc[i][ans_col])
                score_dict = {}
                try:
                    for crit in self.criteria:
                        score_dict[crit] = self.score_parsing(eval_result, crit)
                    score_dict['feedback'] = self.feedback_parsing(eval_result)
                    df_temp.at[i, col] = score_dict
                except Exception as err:
                    print(err)
                    
            indices = list(df_temp[df_temp[col] == 'n'].index)

        if check_point:
            df_temp.to_csv(self.checkpoint_path, index = False)
        else:
            pass
        
        return df_temp

    def summarize_scores(self, df, iteration, init = False):
        col = 'base_score' if init else f'rag_score_{iteration}'
        feedbacks = ''
    
        summaries = {crit: pd.Series([df[col].iloc[i][crit] for i in range(len(df))]).describe() for crit in self.criteria}
        summaries['Average Score'] = pd.Series(sum([df[col].map(lambda x: x[crit]) for crit in self.criteria])/len(self.criteria)).describe()
    
        # Filter feedback where the score is less than or equal to Q1 (which is 1 or below)
        crit_q1s = {crit : summaries[crit]['25%'] for crit in self.criteria}
        under_q1 = {crit: df[df[col].apply(lambda x: x.get(crit, 0) < crit_q1s[crit])].index.tolist() for crit in self.criteria}
        feedback_indices = set(num for sublist in under_q1.values() for num in sublist)

        print('N of Fail :',len(feedback_indices))
        for i in feedback_indices:
            feedbacks += df[col].iloc[i]['feedback']
            

        return summaries, feedbacks

    def check_iter_scores(self, df, iteration, init = False):
        col = 'base_score' if init else f'rag_score_{iteration}'

        iter_score = {crit: [df.iloc[i][col][crit] for i in range(len(df))] for crit in self.criteria}
        iter_score['Average Score'] = sum([df[col].map(lambda x: x[crit]) for crit in self.criteria])/len(self.criteria)

        print(f'Score Distribution ..................')
        print(pd.DataFrame(iter_score).describe())
        print()

## PromptRevisor

In [4]:
class PromptRevisionManager:
    def __init__(self, llm):
        self.llm = llm
        self.history = []

    def revise_prompt(self, prev_prompt, prev_scores, current_scores, feedbacks):
        comparison_lines = []
        
        for criterion in prev_scores.keys():
            line = f"[{criterion}] \nPrev (mean: {prev_scores[criterion]['mean']}, std: {prev_scores[criterion]['std']}, min: {prev_scores[criterion]['min']}, q2: {prev_scores[criterion]['50%']}, max: {prev_scores[criterion]['max']}) \n-> Current (mean: {current_scores[criterion]['mean']}, std: {current_scores[criterion]['std']}, min: {current_scores[criterion]['min']}, q2: {current_scores[criterion]['50%']}, max: {current_scores[criterion]['max']})"
            comparison_lines.append(line)
        
        score_info = "Previous Guideline: " + prev_prompt + "\n" + "\n".join(comparison_lines) + "\nFeedbacks: " + feedbacks
        #print(score_info)
        
        revise_template = """
        Evaluation Results : {question}
        
        You have received feedback and performance scores from an AI research assistant's responses on a photosynthesis study. Your task is to develop guidelines that will enhance the assistant's performance for future interactions. 
        Examine areas of weakness closely and propose specific, actionable steps for overall improvement. Detailed guidance should be provided for any aspects that are particularly lacking.

        Final Format:
        'Overall Guideline:'
        """
        
        CHAIN_PROMPT = PromptTemplate(template=revise_template, input_variables=["question"])
        llm_chain = LLMChain(llm = self.llm, prompt = CHAIN_PROMPT)
        response = llm_chain(score_info)
    
        try:
            new_prompt = response['text'][len('Overall Guideline: '):].strip()
        except:
            new_prompt = response['text'][len('Overall Guideline : '):].strip()

        self.history.append(new_prompt)
        
        return new_prompt

    def get_history(self):
        return self.history

## Prompt Optimization

In [7]:
class PromptOptimization:
    def __init__(self, check_point = False):
        self.data_repo = DataRepository(DATA_PATH)
        
        llm = ChatOpenAI(model_name=MODEL_NAME, temperature=0)
        embeddings = OpenAIEmbeddings()
        vectordb = Chroma(persist_directory=DB_DIRECTORY, embedding_function=embeddings)
        retriever = vectordb.as_retriever()

        self.Assistant = RAGAssistant(llm, retriever, checkpoint_path=f"{DATA_PATH}/temp_checkpoint.csv")
        self.Evaluator = RAGEvaluator(llm, retriever, checkpoint_path=f"{DATA_PATH}/temp_checkpoint.csv")
        self.PromptRevisor = PromptRevisionManager(llm)
        
        self.current_prompt = ''
        self.current_scores = {}
        self.prompt_history = [self.current_prompt, ]  # Initialize the history with the initial prompt

        self.check_point = check_point


    def Initialize_BaseQA(self, df):
        df_temp = df.copy()
        
        print(f"########################\n  Initialize : Base QA\n########################")

        #Evaluation
        print('Evaluation ..................')
        df_temp = self.Evaluator.eval_process(df_temp, '', self.check_point, True)
        self.current_scores, _ = self.Evaluator.summarize_scores(df_temp, '', True)
        self.Evaluator.check_iter_scores(df_temp, '', True)

        df_train = df_temp[df_temp['train_test'] == 'train'].reset_index(drop = True)
        df_test = df_temp[df_temp['train_test'] == 'test'].reset_index(drop = True)

        return df_train, df_test


    def run_iterations(self, df, num_iterations):
        df_temp = df.copy()
        
        for iteration in range(1, num_iterations + 1):
            print(f"########################\n       Iteration {iteration}\n########################")
            print(f"Prompt : \n{self.current_prompt}\n")
            
            prev_scores = self.current_scores

            # Inference
            print('Inference ..................')
            df_temp = self.Assistant.inference(df_temp, self.current_prompt, iteration, self.check_point)

            # Evaluation
            print('Evaluation ..................')
            df_temp = self.Evaluator.eval_process(df_temp, iteration, self.check_point)
            self.current_scores, feedbacks = self.Evaluator.summarize_scores(df_temp, iteration)
            self.Evaluator.check_iter_scores(df_temp, iteration)

            # Prompt Revision and History Update
            if iteration < num_iterations:
                new_prompt = self.PromptRevisor.revise_prompt(self.current_prompt, prev_scores, self.current_scores, feedbacks)
                self.current_prompt = new_prompt
                self.prompt_history.append(new_prompt)  # Save each new prompt to history
                
            # Update previous scores for the next iteration
            prev_scores = self.current_scores

        return df_temp


    def final_evaluation(self, df_test):
        df_temp = df_test.copy()
        print(f"########################\n    Final Evaluation\n########################")
        print(f"Prompt : \n{self.current_prompt}\n")

        # Inference
        print('Inference ..................')
        df_temp = self.Assistant.inference(df_temp, self.current_prompt, 'eval', self.check_point)

        # Evaluation
        print('Evaluation ..................')
        df_temp = self.Evaluator.eval_process(df_temp, 'eval')
        current_scores, feedbacks = self.Evaluator.summarize_scores(df_temp, 'eval')
        self.Evaluator.check_iter_scores(df_temp, 'eval')

        return df_temp

    def get_current_prompt(self):
        return self.current_prompt

    def get_prompt_history(self):
        return self.prompt_history
        
    def run(self, df, num_iterations, save_data = False):
        df_train, df_test = self.Initialize_BaseQA(df)
        final_train_df = self.run_iterations(df_train, num_iterations)
        final_test_df = self.final_evaluation(df_test)
        prompt_history = self.get_prompt_history()

        runtime = datetime.now(timezone('Asia/Seoul')).strftime('%m%d_%H%M')
        #print(runtime)
        if save_data:
            self.data_repo.save_data(final_train_df, filename=f'photosyn_qa_dataset_train_final_{runtime}.pickle')
            self.data_repo.save_data(final_test_df, filename=f'photosyn_qa_dataset_test_final_{runtime}.pickle')
            self.data_repo.save_data(prompt_history, filename=f'photosyn_qa_prompt_history.pickle')

        return final_train_df, final_test_df, prompt_history

# Main

In [10]:
if __name__ == "__main__":
    # Initialize data repository and load initial dataset
    data_repo = DataRepository(DATA_PATH)
    df = data_repo.load_data(file_name = 'photosyn_qa_dataset.pickle')

    # Initialize prompt optimization and run the pipeline
    prompt_optimizer = PromptOptimization()
    final_train_df, final_test_df, prompt_history = prompt_optimizer.run(df, num_iterations=10, save_data = True)

########################
  Initialize : Base QA
########################
Evaluation ..................


  0%|          | 0/150 [00:00<?, ?it/s]

N of Fail : 37
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     150.000000                          150.000000   
mean                        7.953333                            8.286667   
std                         0.423102                            1.031950   
min                         4.000000                            3.000000   
25%                         8.000000                            7.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      150.000000               150.000000   
mean                         6.993333                 7.786667   
std                          0.915905                 0.457402  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

list index out of range


  0%|          | 0/1 [00:00<?, ?it/s]

N of Fail : 25
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.141667                            8.150000   
std                         0.350170                            1.585915   
min                         8.000000                            1.000000   
25%                         8.000000                            7.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.433333                 7.950000   
std                          1.157608                 0.313559  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 28
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.216667                            8.083333   
std                         0.433538                            1.799082   
min                         7.000000                            1.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                       120.00000               120.000000   
mean                          7.42500                 7.950000   
std                           1.12767                 0.285504  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 23
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                           120.00000   
mean                        8.141667                             8.37500   
std                         0.373398                             1.39665   
min                         7.000000                             1.00000   
25%                         8.000000                             8.00000   
50%                         8.000000                             9.00000   
75%                         8.000000                             9.00000   
max                         9.000000                             9.00000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.341667                 7.983333   
std                          1.096180                 0.258741  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 20
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.175000                            8.408333   
std                         0.402983                            1.212649   
min                         7.000000                            1.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.475000                 7.966667   
std                          0.969818                 0.257112  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 28
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.200000                            8.225000   
std                         0.401677                            1.416959   
min                         8.000000                            1.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.483333                 7.983333   
std                          1.173828                 0.258741  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 25
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.166667                            8.291667   
std                         0.396059                            1.546999   
min                         7.000000                            1.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.325000                 7.966667   
std                          0.918187                 0.287946  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 26
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.141667                            8.308333   
std                         0.373398                            1.413000   
min                         7.000000                            1.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.508333                 7.983333   
std                          0.952624                 0.289402  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 22
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.158333                            8.525000   
std                         0.388832                            0.755512   
min                         7.000000                            7.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.433333                 7.966667   
std                          0.932453                 0.257112  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 30
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                      120.00000                          120.000000   
mean                         8.17500                            8.258333   
std                          0.38156                            1.558365   
min                          8.00000                            1.000000   
25%                          8.00000                            8.000000   
50%                          8.00000                            9.000000   
75%                          8.00000                            9.000000   
max                          9.00000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.500000                 7.983333   
std                          1.130182                 0.258741  

  0%|          | 0/120 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/120 [00:00<?, ?it/s]

N of Fail : 28
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                     120.000000                          120.000000   
mean                        8.166667                            8.283333   
std                         0.396059                            1.415303   
min                         7.000000                            1.000000   
25%                         8.000000                            8.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                      120.000000               120.000000   
mean                         7.441667                 7.975000   
std                          0.941979                 0.329566  

  0%|          | 0/30 [00:00<?, ?it/s]

Evaluation ..................


  0%|          | 0/30 [00:00<?, ?it/s]

list index out of range


  0%|          | 0/1 [00:00<?, ?it/s]

N of Fail : 6
Score Distribution ..................
       Scientific Accuracy and Depth  Alignment with Research Objectives  \
count                      30.000000                           30.000000   
mean                        8.166667                            8.133333   
std                         0.379049                            2.431593   
min                         8.000000                            1.000000   
25%                         8.000000                            9.000000   
50%                         8.000000                            9.000000   
75%                         8.000000                            9.000000   
max                         9.000000                            9.000000   

       Source Clarity and Reliability  Professional Engagement  \
count                       30.000000                30.000000   
mean                         7.333333                 8.066667   
std                          0.660895                 0.365148   