In [None]:
!pip install -q -U bitsandbytes transformers accelerate torch
!pip install -q -U safetensors xformers langchain
!pip install gdown==v4.6.3

In [None]:
import os
import json
import pandas as pd
import numpy as np

def gdrive_download(file_id, file_name):
    """Download a Google Drive file with the ``gdown`` command-line tool.

    Args:
        file_id: Google Drive file id, passed to ``gdown`` as its positional argument.
        file_name: destination path, passed to ``gdown --output``.

    Raises:
        subprocess.CalledProcessError: if ``gdown`` exits with a non-zero status.
        FileNotFoundError: if the ``gdown`` executable is not on PATH.
    """
    import subprocess

    # An argument list (no shell) passes both values verbatim, so paths with
    # spaces or shell metacharacters are safe. check=True turns a failed
    # download into an exception; IPython's `!gdown ...` ignored the exit status.
    subprocess.run(['gdown', file_id, '--output', file_name], check=True)

In [None]:
# Download the test split once; later runs reuse the local copy.
os.makedirs('SE2024', exist_ok=True)  # portable, shell-free replacement for `!mkdir`
if not os.path.exists('SE2024/test_split'):
    gdrive_download('1JcpBjTXv2OfaG6uYcIJO-Yk69nT9uN8i', './SE2024/test_split')

# Load LLM

In [None]:
import torch
from transformers import BitsAndBytesConfig

# 4-bit bitsandbytes config: NF4 quant type with double quantization enabled,
# and float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

In [None]:
model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # Hub id used by from_pretrained for both model and tokenizer

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load the causal LM with the 4-bit quantization config above; device_map="auto"
# lets Transformers decide device placement. The tokenizer comes from the same checkpoint.
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Build the text-generation pipeline that the LangChain wrapper uses.
# The factory is imported under an alias: assigning the result to `pipeline`
# used to shadow `transformers.pipeline`, so re-running this cell would call the
# pipeline object from the previous run instead of building a new one.
from transformers import pipeline as make_pipeline, set_seed

# Sampling is enabled below (do_sample=True); fix the RNG so a full
# Restart & Run All reproduces the same generations.
set_seed(42)

pipeline = make_pipeline(
    "text-generation",
    model=model_4bit,
    tokenizer=tokenizer,
    use_cache=True,
    device_map="auto",
    max_length=2000,
    do_sample=True,
    top_k=5,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    # Reuse EOS as the padding token.
    pad_token_id=tokenizer.eos_token_id,
)

In [None]:
# Expose the transformers pipeline through LangChain's LLM interface
# (`inference_dataset` calls `llm.invoke(prompt)`).
# The pip cell installs an unpinned `langchain`. Newer releases ship
# HuggingFacePipeline in the separate `langchain_community` package and drop the
# top-level `from langchain import ...` shortcut, so try the new location first
# and fall back to the old one. If both fail: `pip install langchain-community`.
try:
    from langchain_community.llms import HuggingFacePipeline
except ImportError:  # older langchain, before the langchain_community split
    from langchain import HuggingFacePipeline

# PromptTemplate / LLMChain are not used elsewhere in this notebook; import them
# only where this langchain version still provides them at the legacy location.
try:
    from langchain import PromptTemplate, LLMChain
except ImportError:
    PromptTemplate = LLMChain = None  # unavailable in this langchain version

llm = HuggingFacePipeline(pipeline=pipeline)

# LLM Prompts

In [None]:
# Single-pass chain-of-thought prompt. The model sees the riddle and all four
# options at once and is asked to reason about each option before choosing.
# Option 4 is always the synthetic "None of the [...] are correct" choice,
# built from options 1-3.
all_at_once_prompt = """
You are given a riddle and four options to choose the answer amongst them. A riddle is a question or statement intentionally phrased so as to require ingenuity in ascertaining its answer or meaning, typically presented as a game.
Different ideas can be used in these riddles:
    1. Riddles often employ misdirection, leading you away from the actual solution.
    2. They include elements with double meanings, requiring a keen eye for words with dual interpretations.
    3. Metaphorical wordplay adds another layer, urging you to decipher figurative language.
    4. Look out for exaggeration, as riddles may present overly dramatic details to divert your attention.
    5. Common phrases and sayings may hide within the puzzle, demanding familiarity.
    6. Associations and irony play a crucial role, introducing unexpected connections.
    7. Numerical puzzles can also be part of the mystery, requiring you to decode their significance.
    8. Elemental imagery, drawn from nature, might hold key descriptors.
    9. Rhyming and sound clues can add a poetic dimension.
    10. Also, it is important to note you should decode the upcoming riddle using everyday logic and creativity.
Although a clever solution is required, avoid supernatural solutions and keep your answer within the limits of realistic imagination. For example, solutions that rely on superhuman abilities or unusual events or things are generally not preferred unless they are clearly the better solution. Now, which of the following options is the answer to the following riddle:

Riddle: ```
{RIDDLE}
```
Options:
[option 1]: ```{OPTION_1}```
[option 2]: ```{OPTION_2}```
[option 3]: ```{OPTION_3}```
[option 4]: ```None of the [{OPTION_1}, {OPTION_2}, {OPTION_3}] are correct```


Let's think step by step about each option, then at the end, choose the best and the most logical option:
"""

def all_at_once_prompt_processor(ds):
    """Build the single chain-of-thought prompt for one riddle row.

    Args:
        ds: row with 'QUESTION' and 'OPTION 1'..'OPTION 3'.

    Returns:
        A one-element list, matching the ``row -> list of prompts`` shape that
        ``inference_dataset`` iterates over.
    """
    riddle = ds['QUESTION']
    options = {f'OPTION_{n}': ds[f'OPTION {n}'] for n in (1, 2, 3)}
    return [all_at_once_prompt.format(RIDDLE=riddle, **options)]
# Per-option prompt for the "Chain of Thesis" approach. The model judges one
# candidate answer ({OPTION}) at a time; the four responses are later combined
# through `conclusion_prompt`.
detailed_prompt = """
You are given a riddle and one of its four candidate answers. A riddle is a question or statement intentionally phrased so as to require ingenuity in ascertaining its answer or meaning, typically presented as a game.
Different ideas can be used in these riddles:
    1. Riddles often employ misdirection, leading you away from the actual solution.
    2. They include elements with double meanings, requiring a keen eye for words with dual interpretations.
    3. Metaphorical wordplay adds another layer, urging you to decipher figurative language.
    4. Look out for exaggeration, as riddles may present overly dramatic details to divert your attention.
    5. Common phrases and sayings may hide within the puzzle, demanding familiarity.
    6. Associations and irony play a crucial role, introducing unexpected connections.
    7. Numerical puzzles can also be part of the mystery, requiring you to decode their significance.
    8. Elemental imagery, drawn from nature, might hold key descriptors.
    9. Rhyming and sound clues can add a poetic dimension.
    10. Also, it is important to note you should decode the upcoming riddle using everyday logic and creativity.
Although a clever solution is required, avoid supernatural solutions and keep your answer within the limits of realistic imagination. For example, solutions that rely on superhuman abilities or unusual events or things are generally not preferred unless they are clearly the better solution. Now consider the riddle below and tell me whether the provided option could be the answer to the riddle:

Riddle: ```
{RIDDLE}
```
Option:```
{OPTION}
```

Let's think step by step and keep your answer as short as you can:
"""

def detailed_prompt_processor(ds):
    """Build one "could this option be the answer?" prompt per candidate.

    Prompts are returned in option order: options 1-3 from the row, then a
    synthetic fourth option "None of the [...] are correct" built from them.

    Args:
        ds: row with 'QUESTION' and 'OPTION 1'..'OPTION 3'.

    Returns:
        A list of four prompt strings.
    """
    riddle = ds['QUESTION']
    candidates = [ds[key] for key in ('OPTION 1', 'OPTION 2', 'OPTION 3')]
    candidates.append(
        f'None of the [{candidates[0]}, {candidates[1]}, {candidates[2]}] are correct'
    )
    return [detailed_prompt.format(RIDDLE=riddle, OPTION=candidate) for candidate in candidates]

# Answer-selection prompt for the CoT approach: the model's step-by-step
# reasoning is passed in as {CONTEXT}. `entail_result` parses the response,
# looking for an "answer:" marker and the option number/text.
entail_prompt = """
You are given a riddle and four options to choose the answer amongst them. I will provide you with context about each option and how they are related to the riddle. Then, you should choose the best option that is related to the riddle. Now, consider the riddle below and the context provided for you and tell me which option is the best answer to the riddle based on the context.
Your response should be in the format below:
ANSWER: [<answer option such as 1,2,3,4>] <option content>
It is better to choose an option instead of rejecting all options. If you conclude that none of the options is correct, then choose the option that says none of the options are correct.

Riddle: ```
{RIDDLE}
```
Options:
[option 1]: ```{OPTION_1}```
[option 2]: ```{OPTION_2}```
[option 3]: ```{OPTION_3}```
[option 4]: ```{OPTION_4}```

Context:```
{CONTEXT}
```

ANSWER: 
"""

def entail_prompt_processor(ds, context):
    """Fill ``entail_prompt`` for one riddle row.

    Args:
        ds: row with 'QUESTION' and 'OPTION 1'..'OPTION 4'.
        context: list of reasoning responses; only ``context[0]`` is used
            (``all_at_once_prompt_processor`` yields a single prompt, hence one response).

    Returns:
        The formatted prompt string.
    """
    riddle = ds['QUESTION']
    fields = {f'OPTION_{n}': ds[f'OPTION {n}'] for n in range(1, 5)}
    reasoning = context[0]
    return entail_prompt.format(RIDDLE=riddle, CONTEXT=reasoning, **fields)

# Answer-selection prompt for the "Chain of Thesis" approach: one reasoning
# response per option ({THESIS_1}..{THESIS_4}), all judged together.
conclusion_prompt = """
You are given a riddle and four options to choose the answer amongst them. I will provide you with context about each option and how they are related to the riddle. Then, you should choose the best option that is related to the riddle. Now, consider the riddle below and the contexts provided for you and tell me which option is the best answer to the riddle based on the contexts.
Your response should be in the format below:
ANSWER: [<answer option such as 1,2,3,4>] <option content>

Riddle: ```
{RIDDLE}
```
Options:
[option 1]: ```{OPTION_1}```
[option 2]: ```{OPTION_2}```
[option 3]: ```{OPTION_3}```
[option 4]: ```{OPTION_4}```
Contexts:
[context about option 1]: ```{THESIS_1}```
[context about option 2]: ```{THESIS_2}```
[context about option 3]: ```{THESIS_3}```
[context about option 4]: ```{THESIS_4}```

ANSWER: 
"""

def conclusion_prompt_processor(ds, theses):
    """Fill ``conclusion_prompt`` for one riddle row.

    Args:
        ds: row with 'QUESTION' and 'OPTION 1'..'OPTION 4'.
        theses: per-option reasoning responses; ``theses[0]``..``theses[3]``
            fill THESIS_1..THESIS_4 in option order.

    Returns:
        The formatted prompt string.
    """
    fields = {'RIDDLE': ds['QUESTION']}
    for n in range(1, 5):
        fields[f'OPTION_{n}'] = ds[f'OPTION {n}']
    for n in range(1, 5):
        fields[f'THESIS_{n}'] = theses[n - 1]
    return conclusion_prompt.format(**fields)


In [None]:
def entail_result(ds, entail_response: str) -> str:
    """Map a free-text answer-selection response onto one of the four options.

    Matching is case-insensitive and only looks at the text after the first
    ``answer:`` marker (the whole response if there is no marker). Two passes
    are tried in order; the first one that identifies exactly one option wins:

    1. an option counts as mentioned if its text *or* its number (``1``-``4``
       as a standalone token, e.g. ``[2]`` or ``option 2``) appears;
    2. an option counts as mentioned only if its text appears.

    Option texts are matched as whole phrases, so ``"pen"`` does not match
    inside ``"pencil"``. Numbers are matched as standalone tokens, so ``1``
    does not match inside ``10``.

    Args:
        ds: mapping (e.g. a DataFrame row) with keys 'OPTION 1'..'OPTION 4'.
        entail_response: raw text returned by the LLM.

    Returns:
        The selected option's value (``ds['OPTION k']``), or ``"Unknown"``
        when neither pass singles out one option.
    """
    import re

    keys = ('OPTION 1', 'OPTION 2', 'OPTION 3', 'OPTION 4')
    marker_text = 'answer:'

    response = entail_response.lower()
    marker = response.find(marker_text)
    # Search only the text after the marker. An option that is mentioned in the
    # reasoning AND again after the marker must still count as mentioned.
    tail = response[marker + len(marker_text):] if marker != -1 else response

    def text_mentioned(key):
        value = ds[key]
        # Missing options (None / NaN from pandas) must never match, and an
        # empty string would otherwise "match" everywhere.
        if value is None or (isinstance(value, float) and value != value):
            return False
        text = str(value).strip().lower()
        if not text:
            return False
        return re.search(r'(?<!\w)' + re.escape(text) + r'(?!\w)', tail) is not None

    def number_mentioned(number):
        return re.search(r'(?<!\w)' + str(number) + r'(?!\w)', tail) is not None

    text_hits = [text_mentioned(key) for key in keys]
    number_hits = [number_mentioned(n) for n in range(1, len(keys) + 1)]
    loose_hits = [t or n for t, n in zip(text_hits, number_hits)]

    # The looser pass first; if it is ambiguous, the stricter text-only pass
    # may still single out one option.
    for hits in (loose_hits, text_hits):
        if sum(hits) == 1:
            return ds[keys[hits.index(True)]]
    return "Unknown"

In [None]:
def write_log(qid, rid, content, root='inference'):
    """Write ``content`` to ``<root>/<qid>/<rid>.txt`` as UTF-8, overwriting any existing file.

    Missing directories are created. ``qid`` and ``rid`` are passed through
    ``str`` so non-string ids (e.g. ints) work.

    Args:
        qid: sub-directory name (the approach name, e.g. 'CoT', in this notebook).
        rid: file stem (e.g. ``"<row>_<n>"`` or ``"<row>_result"`` in this notebook).
        content: text to write.
        root: top-level log directory; defaults to ``'inference'``.
    """
    directory = os.path.join(root, str(qid))
    # exist_ok creates both levels in one call and avoids check-then-create races.
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, str(rid) + ".txt"), "w", encoding='utf-8') as fp:
        fp.write(content)


In [None]:
from tqdm.notebook import tqdm

In [None]:
def inference_dataset(llm, data, prompt_processor, entailment_prompt_processor, entail_result, result_col, appr):
    """Run a two-stage LLM pipeline over every row of ``data``.

    For each row, ``prompt_processor(row)`` yields reasoning prompts. Each is
    sent to ``llm`` and its response is logged. The responses are then turned
    into one answer-selection prompt by ``entailment_prompt_processor(row,
    responses)``. That prompt's response is parsed by ``entail_result(row,
    response)``, and the parsed answer is stored in ``result_col``.

    Args:
        llm: object with ``invoke(prompt) -> str`` (the LangChain LLM here).
        data: DataFrame of riddles; each row is passed to the processors.
        prompt_processor: ``row -> list of prompt strings``.
        entailment_prompt_processor: ``(row, responses) -> prompt string``.
        entail_result: ``(row, response) -> str`` answer parser.
        result_col: name of the column that receives the parsed answers.
        appr: approach name; used as the log sub-directory for ``write_log``.

    Returns:
        A copy of ``data`` with ``result_col`` filled in. ``data`` itself is
        left unmodified, so re-running a cell reproduces the same result.
    """
    results = data.copy()
    results[result_col] = None
    for idx in tqdm(results.index, total=len(results), desc="(Inference)"):
        row = results.loc[idx]
        responses = []
        for prompt in tqdm(prompt_processor(row), desc="(Inference-Prompt)", leave=False):
            response = llm.invoke(prompt)
            responses.append(response)
            write_log(appr, f"{idx}_{len(responses)}", response)
        # Named `final_prompt` so it does not shadow the module-level `entail_prompt` template.
        final_prompt = entailment_prompt_processor(row, responses)
        final_response = llm.invoke(final_prompt)

        result = entail_result(row, final_response)
        write_log(appr, f"{idx}_result", final_response + "\n\n" + result)

        results.loc[idx, result_col] = result

    return results


# Inference

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

## Chain-of-Thought (CoT)

In [None]:
data = pd.read_csv('SE2024/test_split')
# Fail fast, before any LLM calls, if the split lacks a column the prompt builders read.
missing_cols = {'QUESTION', 'OPTION 1', 'OPTION 2', 'OPTION 3', 'OPTION 4'} - set(data.columns)
assert not missing_cols, f"SE2024/test_split is missing columns: {sorted(missing_cols)}"

In [None]:
# CoT: a single "think step by step about each option" prompt per riddle,
# followed by the answer-selection (entail) prompt.
data_cot_result = inference_dataset(
    llm=llm,
    data=data,
    result_col='CoT',
    appr='CoT',
    prompt_processor=all_at_once_prompt_processor,
    entailment_prompt_processor=entail_prompt_processor,
    entail_result=entail_result,
)

In [None]:
data_cot_result.to_csv(path_or_buf='SE2024/test_split_cot_result.csv', index=True)  # row index is written too

## Chain of Thesis

In [None]:
# Chain of Thesis: one prompt per option, then a conclusion prompt that sees
# all four per-option responses. Runs on the CoT results so both columns end
# up in the same frame.
data_cot_result2 = inference_dataset(
    llm=llm,
    data=data_cot_result,
    result_col='Thesis',
    appr='Thesis',
    prompt_processor=detailed_prompt_processor,
    entailment_prompt_processor=conclusion_prompt_processor,
    entail_result=entail_result,
)

In [None]:
data_cot_result2.to_csv(path_or_buf='SE2024/test_split_cot_result2.csv', index=True)  # row index is written too