In [19]:
import pandas as pd

# Load the curated human ratings and expand the short typicality codes
# ("AT"/"T") into labels that also carry the matching Likert anchor.
plausibility = pd.read_csv('./newformat_curated_human_ratings.csv').assign(
    Typicality=lambda ratings: ratings['Typicality'].map(
        {"AT":"Atypical (1)", "T":"Typical (7)"}
    )
)
plausibility


Unnamed: 0,Sentence,Typicality,Rating,Item
0,The actor won the battle.,Atypical (1),2.60,1
1,The actor won the award.,Typical (7),5.80,1
2,The anchorman told the parable.,Atypical (1),3.00,2
3,The anchorman told the news.,Typical (7),6.75,2
4,The animal found the map.,Atypical (1),2.00,3
...,...,...,...,...
789,The woman carried the bag.,Typical (7),6.25,395
790,The woman opened the manhole.,Atypical (1),2.40,396
791,The woman opened the bag.,Typical (7),6.45,396
792,The woman painted the sign.,Atypical (1),3.55,397


# Setup Common Functions

In [65]:
import time
from enum import Enum
from pathlib import Path

from langchain.llms import Ollama
from langchain.output_parsers.enum import EnumOutputParser
from langchain.output_parsers.fix import OutputFixingParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from tqdm import tqdm
# from langchain.globals import set_verbose
# set_verbose(True)
    

class ChainManager:
    """Run a prompt over a batch of sentences for several Ollama models.

    For every sentence the configured ``prompt`` is sent to each model in
    ``model_list``; the free-text reply is then reduced to a single character
    by a second "classifier" prompt (always served by the ``mistral`` model)
    that is asked for a 1-7 rating. The extracted answers and the number of
    classifier attempts are stored as new columns on ``self.df``.

    Attributes:
        prompt: Template used for the main query. The query methods format it
            with a ``Sentence`` variable, so the placeholder set in
            ``__init__`` (which expects ``input``) must be replaced before
            running any query.
        output_parser: Parser applied to every model reply.
        df: Working DataFrame; initialised from the module-level
            ``plausibility`` frame.
        model_list: Ollama model names to evaluate.
        logs: In-memory buffer of log lines used when ``verbose`` is false.
    """

    # Divider written before every per-query log entry.
    _LOG_SEPARATOR = "----------------------------------------------------------------------"

    def __init__(self):
        # Placeholder template only: it declares ``{input}``, while the query
        # methods call ``self.prompt.format(Sentence=...)``. Callers are
        # expected to overwrite ``self.prompt`` before running queries.
        self.prompt = PromptTemplate.from_template("Tell me a short joke about {input}")
        self.output_parser = StrOutputParser()
        self.df = plausibility
        self.model_list = ["llama2","mistral", "orca-mini:7b", "qwen:7b"]
        self.logs = []

    def _emit(self, lines, verbose):
        """Print each line when ``verbose`` is true, otherwise buffer it in ``self.logs``."""
        for line in lines:
            if verbose:
                print(line)
            else:
                self.write_string_to_buffer(line)

    def run_single_query(self, inputs, model_name, verbose, output_file_name=""):
        """Query one model with ``self.prompt`` and classify its reply.

        Args:
            inputs: Record for one sentence; must provide the ``Sentence`` and
                ``Typicality`` keys plus whatever variables the prompt uses.
            model_name: Name of the Ollama model to query.
            verbose: Print progress when true; otherwise buffer it and flush
                the buffer to ``output_file_name`` at the end.
            output_file_name: Log file to append to when ``verbose`` is false.
                With the default empty string nothing is written to disk.

        Returns:
            The ``(answer, attempts)`` tuple from ``run_retry_classifier``.
        """
        full_prompt = self.prompt.format(Sentence=inputs["Sentence"])
        self._emit(
            [
                self._LOG_SEPARATOR,
                f"model_name: {model_name}",
                f"prompt: {full_prompt}",
            ],
            verbose,
        )

        chain = (
            self.prompt
            | Ollama(model=model_name)
            | self.output_parser)

        chain_of_thought = chain.invoke(inputs)
        classifier_output = self.run_retry_classifier(chain_of_thought, 3, verbose)
        final_answer = classifier_output[0]

        # The printed and buffered variants differ slightly in wording; both
        # are kept exactly as they were so existing log files stay comparable.
        if verbose:
            self._emit(
                [
                    "Classifier finished...\n",
                    f"chain_of_thought: {chain_of_thought}\n",
                    f"final_answer: {final_answer}",
                    f"correct_answer: {inputs['Typicality']}",
                ],
                True,
            )
        else:
            self._emit(
                [
                    "classifier finished....\n",
                    f"chain_of_thought: {chain_of_thought}\n",
                    f"final_answer: {final_answer}",
                    f"correct_answer: {inputs['Typicality']}",
                ],
                False,
            )
            self.write_buffer_to_file(output_file_name)
        return classifier_output

    def run_batch_query(self, verbose, batch_size, output_file_name="", use_self_review=True):
        """Run every model in ``model_list`` over a slice of ``self.df``.

        ``self.df`` is replaced by a copy of rows ``2 .. 2 + batch_size`` of
        its current contents, and two columns are added per model:
        ``<model>`` (extracted answer) and ``<model>_retries`` (classifier
        attempts). Because ``self.df`` is re-sliced on every call, calling
        this twice on the same instance narrows the frame further.

        Args:
            verbose: Forwarded to the per-sentence query method.
            batch_size: Number of rows to process.
            output_file_name: Log file; truncated first when non-empty.
            use_self_review: When true (the default, matching the previous
                hard-wired behaviour) each sentence goes through
                ``run_single_self_review``; when false ``run_single_query``
                is used instead.
        """
        # NOTE(review): the first two rows are skipped — presumably because
        # they are the "actor" sentences used as few-shot examples in the
        # prompts; confirm.
        self.df = self.df.iloc[2:2+batch_size].copy()
        input_list = self.df.to_dict('records')
        if output_file_name != "":
            self.clear_existing_file(output_file_name)

        if use_self_review:
            run_single = self.run_single_self_review
        else:
            run_single = self.run_single_query

        for model_name in self.model_list:
            results = []
            num_retries = []
            for item in tqdm(input_list, desc="Processing queries"):
                result = run_single(item, model_name, verbose, output_file_name)
                results.append(result[0])
                num_retries.append(result[1])

            self.df[model_name] = results
            self.df[model_name+"_retries"] = num_retries

    def evaluate_order(self, label_column="label"):
        """Print, per model, the share of rows whose answer equals the label.

        The first non-blank character of each model answer is compared for
        equality with ``self.df[label_column]``.

        Args:
            label_column: Column holding the expected answer. Defaults to
                ``"label"``, the column this method always used.

        Raises:
            KeyError: If ``label_column`` is not a column of ``self.df``.
        """
        if label_column not in self.df.columns:
            raise KeyError(
                f"column {label_column!r} not found in self.df; "
                f"available columns: {list(self.df.columns)}"
            )
        for model_name in self.model_list:
            # astype(str) keeps the .str accessor usable even when a column
            # holds only the integer fallback returned after failed retries.
            binary_results = self.df[model_name].astype(str).str.strip().str[0]
            correct_predictions = (self.df[label_column] == binary_results).sum()
            total_predictions = len(self.df)
            if total_predictions:
                accuracy = correct_predictions / total_predictions
            else:
                accuracy = float("nan")
            print(f"{model_name}: {accuracy}")

    def verify_output(self, mcq_choice):
        """Return True when ``mcq_choice`` converts to an integer from 1 to 7."""
        try:
            choice_int = int(mcq_choice)
        except (ValueError, TypeError):
            # TypeError covers inputs such as None that int() rejects outright.
            return False
        return 1 <= choice_int <= 7

    def run_classifier(self, initial_answer):
        """Ask the ``mistral`` model to reduce ``initial_answer`` to a 1-7 rating.

        Returns:
            The raw text produced by the classifier chain.
        """
        classifier_prompt = PromptTemplate.from_template("""
        You are recieving an explanation from a language model about its Likert scale rating.
        You are a numerical classifier, designed to reply numerically from 1-7. 
        You are not to explain or mention anything other than provide the numerical choice!
        If there is not enough information, select 4. 
        Please provide only the numerical rating as your response!

        Initial answer: {initial_answer}
        Numerical answer:
        """)

        chain = (
            classifier_prompt
            | Ollama(model="mistral")
            | self.output_parser
        )
        return chain.invoke({"initial_answer": initial_answer})

    def run_retry_classifier(self, initial_answer, max_tries, verbose, output_file_name=""):
        """Call ``run_classifier`` until it yields a valid 1-7 answer.

        Only the first non-blank character of each classifier reply is
        considered. Blank replies count as a failed attempt instead of
        raising ``IndexError``.

        Args:
            initial_answer: Text to classify.
            max_tries: Maximum number of classifier calls.
            verbose: Print progress when true, otherwise buffer it.
            output_file_name: Unused; kept for interface compatibility.

        Returns:
            ``(answer, attempts)``. On success ``answer`` is the accepted
            character and ``attempts`` the 1-based attempt that produced it.
            When every attempt fails, ``answer`` is the last rejected
            character (or the integer ``0`` if no reply had any content) and
            ``attempts`` equals ``max_tries``.
        """
        self._emit(["Running classifier...."], verbose)

        mcq_choice = 0
        for attempt in range(1, max_tries + 1):
            classifier_output = self.run_classifier(initial_answer)
            self._emit([f"retry_classifier_{attempt}: {classifier_output}"], verbose)

            stripped = classifier_output.strip()
            if not stripped:
                # Nothing to inspect in an empty reply; try again.
                continue
            mcq_choice = stripped[0].upper()
            if self.verify_output(mcq_choice):
                return (mcq_choice, attempt)
        return (mcq_choice, max_tries)

    def run_self_review(self, question, answer, model):
        """Ask ``model`` to critique and improve its earlier ``answer``.

        Args:
            question: The prompt text that produced ``answer``.
            answer: The model's previous reply.
            model: Runnable LLM placed between the review prompt and the
                output parser.

        Returns:
            The parsed text of the review.
        """
        review_prompt = PromptTemplate.from_template("""
        Question: {question}
        
        Previous answer: {answer}
        
        Review your previous answer for any problems, and improve it based on your critique.
        """)

        chain = (
            review_prompt
            | model
            | self.output_parser
        )
        return chain.invoke({"question": question, "answer": answer})

    def run_single_self_review(self, inputs, model_name, verbose, output_file_name=""):
        """Query one model, have it review its own reply, then classify the review.

        Args:
            inputs: Record for one sentence; must provide the ``Sentence`` and
                ``Typicality`` keys plus whatever variables the prompt uses.
            model_name: Name of the Ollama model used for both the answer and
                the self-review.
            verbose: Print progress when true; otherwise buffer it and flush
                the buffer to ``output_file_name`` at the end.
            output_file_name: Log file to append to when ``verbose`` is false.
                With the default empty string nothing is written to disk.

        Returns:
            The ``(answer, attempts)`` tuple from ``run_retry_classifier``,
            computed on the self-review text.
        """
        llm_model = Ollama(model=model_name)

        full_prompt = self.prompt.format(Sentence=inputs["Sentence"])
        self._emit(
            [
                self._LOG_SEPARATOR,
                f"model_name: {model_name}",
                f"prompt: {full_prompt}",
            ],
            verbose,
        )

        chain = (
            self.prompt
            | llm_model
            | self.output_parser)

        chain_of_thought = chain.invoke(inputs)

        critique = self.run_self_review(full_prompt, chain_of_thought, llm_model)

        classifier_output = self.run_retry_classifier(critique, 3, verbose)
        final_answer = classifier_output[0]

        # The printed and buffered variants differ slightly in wording; both
        # are kept exactly as they were so existing log files stay comparable.
        if verbose:
            self._emit(
                [
                    "Classifier finished...\n",
                    f"chain_of_thought: {chain_of_thought}\n",
                    f"critique: {critique}\n",
                    f"final_answer: {final_answer}",
                    f"correct_answer: {inputs['Typicality']}",
                ],
                True,
            )
        else:
            self._emit(
                [
                    "classifier finished....\n",
                    f"chain_of_thought: {chain_of_thought}",
                    f"critique: {critique}\n",
                    f"final_answer: {final_answer}",
                    f"correct_answer: {inputs['Typicality']}",
                ],
                False,
            )
            self.write_buffer_to_file(output_file_name)
        return classifier_output

    def write_string_to_buffer(self, input_string):
        """Append one log line to the in-memory buffer."""
        self.logs.append(input_string)

    def write_buffer_to_file(self, filename):
        """Append the buffered log lines to ``filename`` and empty the buffer.

        Each line is written preceded by a newline. When ``filename`` is
        empty nothing is written and the buffer is left untouched, so running
        non-verbosely without a log file no longer fails on ``open("")``.
        """
        if not filename:
            return
        # Model output may contain non-ASCII text, so the encoding is pinned.
        with open(filename, 'a', encoding='utf-8') as file:
            file.writelines("\n" + log for log in self.logs)
        self.logs = []

    def clear_existing_file(self, filename):
        """Truncate ``filename`` (creating it if needed); no-op for an empty name."""
        if not filename:
            return
        with open(filename, 'w', encoding='utf-8'):
            # Opening in write mode is what truncates the file.
            pass

### Single Prompt + classifier

In [44]:
# Single prompt + classifier: the model is asked for a rating directly,
# with two worked examples and no reasoning step in the prompt.
simple_QA_chain = ChainManager()

simple_QA_chain.prompt = PromptTemplate.from_template("""
You're tasked with evaluating the typicality of a given sentence using a Likert scale. 
1 (the sentence is very atypical or uncommon), 7 (the sentence is very typical or common). 

For example:
Sentence: The actor won the battle
Typicality rating: 3

Sentence: The actor won the award
Typicality rating: 6

Please evaluate the following:
Sentence: {Sentence}
Typicality rating:""")

# Resolve outputs under the current user's Desktop rather than a hard-coded
# home directory, so the cell also runs on other machines/accounts.
output_dir = Path.home() / "Desktop"

# Positional arguments are verbose=False and batch_size=350.
# NOTE(review): run_batch_query dispatches to the self-review path by default;
# confirm that is the intended pipeline for this "single prompt" experiment.
simple_QA_chain.run_batch_query(False, 350, str(output_dir / "plausibility_simple_QA.txt"))
simple_QA_chain.df.to_csv(output_dir / "plausibility_simple_QA.csv", index=False)

Processing queries: 100%|███████████████████| 350/350 [1:13:10<00:00, 12.54s/it]
Processing queries: 100%|█████████████████████| 350/350 [49:20<00:00,  8.46s/it]
Processing queries: 100%|█████████████████████| 350/350 [57:53<00:00,  9.92s/it]
Processing queries: 100%|███████████████████| 350/350 [1:17:56<00:00, 13.36s/it]


In [None]:
# NOTE(review): evaluate_order() compares model answers against a 'label'
# column, but the frame displayed after loading shows only Sentence,
# Typicality, Rating and Item — as written this call likely raises KeyError.
# The cell has no execution count; confirm the intended label column.
simple_QA_chain.evaluate_order()

### Chain of thought + classifier

In [45]:
# Chain of thought + classifier: the prompt asks the model to explain its
# reasoning before giving the rating, with two worked examples.
chain = ChainManager()

chain.prompt = PromptTemplate.from_template("""
You're tasked with evaluating the typicality of a given sentence using a Likert scale. 
1 (the sentence is very atypical or uncommon), 7 (the sentence is very typical or common). 
You are to explain your chain of thought before coming up with a typicality rating.

For example:
Sentence: The actor won the battle
Chain of thought: The phrase implies a situation where an actor, typically known for performing in films or theater, is involved in a "battle," which is less common in the context of acting. The term "battle" might metaphorically refer to overcoming personal challenges or competition in the industry, but it's less typical than winning awards or recognition for acting. Hence, a rating of 3 indicates that it's somewhat atypical, considering the unconventional use of "battle" in relation to an actor's professional achievements.
Typicality rating: 3

Sentence: The actor won the award
Chain of thought: This statement aligns closely with common scenarios within the entertainment industry, where actors are frequently recognized for their performances through various awards. Winning an award is a typical outcome for actors who have delivered exceptional performances in their roles. Therefore, a rating of 6 is justified as it reflects a highly typical event in the context of an actor's career.
Typicality rating: 6

Please evaluate the following:
Sentence: {Sentence}
Chain of thought:
Typicality rating:""")

# Resolve outputs under the current user's Desktop rather than a hard-coded
# home directory, so the cell also runs on other machines/accounts.
output_dir = Path.home() / "Desktop"

# Positional arguments are verbose=False and batch_size=350.
# NOTE(review): run_batch_query dispatches to the self-review path by default;
# confirm that is the intended pipeline for this chain-of-thought experiment.
chain.run_batch_query(False, 350, str(output_dir / "plausibility_COT.txt"))
chain.df.to_csv(output_dir / "plausibility_COT.csv", index=False)

Processing queries: 100%|███████████████████| 350/350 [1:24:41<00:00, 14.52s/it]
Processing queries: 100%|███████████████████| 350/350 [1:12:25<00:00, 12.41s/it]
Processing queries: 100%|███████████████████| 350/350 [1:19:24<00:00, 13.61s/it]
Processing queries: 100%|███████████████████| 350/350 [1:22:02<00:00, 14.06s/it]


In [None]:
# NOTE(review): this saves the chain-of-thought frame a second time, under
# 'order_results.csv' in the working directory — a name unrelated to the
# plausibility outputs written above. Presumably a leftover (the cell has no
# execution count); confirm before running.
chain.df.to_csv('order_results.csv', index=False)

### Self Critique + classifier

In [66]:
# Self critique + classifier: same chain-of-thought prompt, but each answer
# is reviewed by the same model before being classified. Run on 200 rows.
critique_chain = ChainManager()

critique_chain.prompt = PromptTemplate.from_template("""
You're tasked with evaluating the typicality of a given sentence using a Likert scale. 
1 (the sentence is very atypical or uncommon), 7 (the sentence is very typical or common). 
You are to explain your chain of thought before coming up with a typicality rating.

For example:
Sentence: The actor won the battle
Chain of thought: The phrase implies a situation where an actor, typically known for performing in films or theater, is involved in a "battle," which is less common in the context of acting. The term "battle" might metaphorically refer to overcoming personal challenges or competition in the industry, but it's less typical than winning awards or recognition for acting. Hence, a rating of 3 indicates that it's somewhat atypical, considering the unconventional use of "battle" in relation to an actor's professional achievements.
Typicality rating: 3

Sentence: The actor won the award
Chain of thought: This statement aligns closely with common scenarios within the entertainment industry, where actors are frequently recognized for their performances through various awards. Winning an award is a typical outcome for actors who have delivered exceptional performances in their roles. Therefore, a rating of 6 is justified as it reflects a highly typical event in the context of an actor's career.
Typicality rating: 6

Please evaluate the following:
Sentence: {Sentence}
Chain of thought:
Typicality rating:
""")

# Resolve outputs under the current user's Desktop rather than a hard-coded
# home directory, so the cell also runs on other machines/accounts.
output_dir = Path.home() / "Desktop"

# Positional arguments are verbose=False and batch_size=200.
critique_chain.run_batch_query(False, 200, str(output_dir / "plausibility_self_critique.txt"))
critique_chain.df.to_csv(output_dir / "plausibility_self_critique.csv", index=False)

Processing queries: 100%|███████████████████| 200/200 [1:37:24<00:00, 29.22s/it]
Processing queries: 100%|█████████████████████| 200/200 [54:08<00:00, 16.24s/it]
Processing queries: 100%|███████████████████| 200/200 [1:10:44<00:00, 21.22s/it]
Processing queries: 100%|███████████████████| 200/200 [1:17:36<00:00, 23.28s/it]
