In [None]:
"""
Set the configuration here
"""
import ipywidgets as widgets
from IPython.display import display

# input your API key here

# openAI - for gpt
API_KEY = "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"

# groq - for llama3
#API_KEY = "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"


'''
Candidate generator
'''
preprocess_mode = widgets.Dropdown(
    options=['None', 'CDGP', 'BERT'], # in ISSR: CDGP
    value='CDGP',
    description='Preprocess Mode: Select your candidate generator model',
)
candidate_generator_top_k = widgets.IntSlider(
    value=2500, # in ISSR: 2500
    min=100,
    max=4000,
    step=10,
    description='candidate_generator_top_k: How much candidates would be generated by candidate generator (this amount is BEFORE filter, make sure to be big)',
)
# if this is set to true, BERT model will not be load as candidate generator in runtime (mocked), instead will load contents in "candidate_set_cache" as candidate set, not affecting performance
use_cache_result = widgets.Checkbox(
    value=False,
    description='Use cache result: Use cached candidate set generated by PLM (same stem+target word tends to generate same result)'
)
# candidate set cache location, used for "use_cache_result"
# CAUTION: Make sure to select the correct file for caching: *BERT_response_cache* contains results generated by BERT, whereas *CDGP_response_cache* contains results generated by CDGP.
candidate_set_cache_path = "./dataset/BERT_response_cache.json"


cheat = widgets.Checkbox(
    value=False, # in ISSR: false
    description='''Enable Cheat: Randomly replace some generated candidates to ground truth
(to ensure ground truth exists in the candidate set)'''
)

candidate_set_size = widgets.IntSlider(
    value=50, # in ISSR: 50
    min=30,
    max=300,
    step=10,
    description='Candidate Set Size: Size of candidate set'
)


'''
Distractor selector
'''
zero_shot = widgets.Checkbox(
    value=False, # in ISSR: false
    description='Zero-Shot Mode: Controls distractor selector zero-shot or few shot (true = zeroshot, false=fewshot)'
)

chain_of_thought = widgets.Checkbox(
    value=False,
    description='Chain of Thought: Whether use CoT on distractor selector or not (set to false)' 
)


pick_distractors_per_round = widgets.IntSlider(
    value=3, # in ISSR: 3
    min=3,
    max=30,
    step=5,
    description='Distractors per Round: Control distractors that picked by distractor selector per round'
)


'''
Self-review
'''

self_review = widgets.Checkbox(
    value=True, # in ISSR: true
    description='Self Answer: Using self-review or not'
)

error_report = widgets.Checkbox(
    value=False,
    description='Error Report: Abandoned (set to false)'
)


'''
Overall
'''

LLM = widgets.Dropdown(
    options=['gpt', 'gemma2', 'llama3_8b', 'llama3_70b'],
    value='gpt', # in ISSR: gpt
    description='''LLM Model: The LLM model used for distractor selector and self-reviewer
groq-API / OpenAI api required for models.'''
)

model_name = widgets.Dropdown(
    options=['gpt-4-turbo-2024-04-09', 'gpt-3.5-turbo-0125', 'llama3-70b-8192',
             'llama3-8b-8192', 'gpt-4o-mini-2024-07-18'],
    value='gpt-3.5-turbo-0125', # in ISSR: gpt-3.5-turbo-0125
    description='LLM Model: Specify LLM.'
)

generate_count = widgets.IntSlider(
    value=30, # in ISSR: 30
    min=10,
    max=100,
    step=10,
    description='Generate Count: Total required distractors from ISSR'
)

device = widgets.Dropdown(
    options=['cuda', 'cpu'],
    value='cuda',
    description='Device: your device'
)


record_bad_distractor = widgets.Checkbox(
    value=True, # for recording bad distractors, not affecting performance
    description='Record bad distractor: Record select history of distractor selector'
)

# location of rule-based reference datas
# CEEC word list
ref_vocabulary_path = "../Dataset/高中英文參考詞彙表v2.xlsx"
# GSAT questions (or questions required to generate distractors)
dataset_path = "../Dataset/processed_gsat_data.json"
# english dictionary list, which records all english vocabulary (used to test whether the generated candidate is a vocabulary or not)
english_dictionary_list_path = "../Dataset/words_alpha.txt"


print("'''HERE IS YOUR CONFIG - MODIFY IT IN CODE, NOT GUI'''")
display(preprocess_mode, cheat, zero_shot, chain_of_thought, candidate_set_size,
        pick_distractors_per_round, generate_count, self_review, error_report, LLM, device)


def get_config():
    """Snapshot the current widget values and path constants into the nested config dict.

    The returned dict is what ``DistractorGenerationModel`` is constructed with.

    Returns:
        dict: nested configuration with one section per pipeline stage
        (``preprocess_function``, ``distractor_generation_function``,
        ``post_processing_function``) plus global settings (API key, LLM, paths, ...).
    """
    return {
        # Candidate generation stage
        "preprocess_function": {
            "mode": preprocess_mode.value,
            "reason": "gpt", # this is abandoned, let it be
            "cheat": cheat.value
        },
        # LLM distractor-selection stage
        "distractor_generation_function": {
            "zero-shot": zero_shot.value,
            "chain_of_thought": chain_of_thought.value,
            "candidate_set_size": candidate_set_size.value,
            "pick_distractors_per_round": pick_distractors_per_round.value,
            "generate_count": generate_count.value,
        },
        # Self-review stage
        "post_processing_function": {
            "self-answer": self_review.value,
            "error-report": error_report.value,
            "generate_count": generate_count.value
        },
        "api_key": API_KEY,
        "use_cache_result": use_cache_result.value,
        "LLM": LLM.value,
        "model_name": model_name.value,
        "device": device.value,
        "ref_vocabulary_path": ref_vocabulary_path,
        "dataset_path": dataset_path,
        "english_dictionary_list_path": english_dictionary_list_path,
        "candidate_set_cache_path": candidate_set_cache_path,
        "candidate_generator_top_k": candidate_generator_top_k.value,
        # Previously hard-coded to True, which silently ignored the widget (whose default is True).
        "record_bad_distractor": record_bad_distractor.value,
    }

In [9]:
import pandas as pd
import json
from models import openAIModel, gemma2, llama3_8b, llama3_70b
from tqdm import tqdm
import os
from transformers import BertTokenizer, BertForMaskedLM, pipeline
import numpy as np
import fasttext
import nltk
from nltk.tokenize import word_tokenize
from utils import *
import nltk
import spacy
from nltk.corpus import wordnet
from nltk import word_tokenize, pos_tag
from nltk.stem import WordNetLemmatizer
from postprocess import self_answer, self_answer_correctness, self_answer_same_meaning
import random
from openai import OpenAI

class DistractorGenerationModel:
    """ISSR pipeline: candidate generation -> LLM distractor selection -> self-review.

    The three stage callables are stored on the instance and invoked with the model
    itself as first argument:

    * ``preprocess_function(self, question)`` returns a dict with a ``'cand_pool'`` list
      (and a ``'reason'`` entry),
    * ``distractor_generation_function(self, question, prompt_pool, sample=...)`` returns
      the distractors picked by the LLM,
    * ``post_processing_function(self, question, distractor_pool)`` performs self-review;
      presumably it fills ``self.good_distractor`` / ``self.bad_distractor``
      (TODO confirm against ``postprocess``).

    Generated distractors are written back into each question dict under ``'generated'``.
    """

    def __init__(
            self,
            config,
            preprocess_function,
            distractor_generation_function,
            post_processing_function,
        ):
        """Load data, reference lists, the LLM client and (if configured) the candidate generator.

        Args:
            config: nested configuration dict as built by ``get_config()``.
            preprocess_function: candidate-generation stage (see class docstring).
            distractor_generation_function: LLM selection stage.
            post_processing_function: self-review stage.

        Raises:
            ValueError: if ``config['LLM']`` names an unsupported model.
        """
        self.config = config
        with open(config['dataset_path'], "r", encoding="utf-8") as f:
            # The first two questions serve as few-shot examples, so they are skipped.
            self.dataset = json.load(f)[2:]
        # CEEC reference vocabulary ('單字' = word, '難度' = difficulty level).
        self.ref_word = pd.read_excel(config['ref_vocabulary_path'])
        # English word list used by extract_response to drop non-words.
        with open(config['english_dictionary_list_path'], "r", encoding="utf-8") as f:
            self.word_list = [line.strip() for line in f]

        # LLM used for distractor selection / self-review; model_name is also used for LLM templates.
        llm_factories = {
            "gpt": openAIModel,
            "llama3_8b": llama3_8b,
            "llama3_70b": llama3_70b,
            "gemma2": gemma2,
        }
        llm = config['LLM']
        if llm not in llm_factories:
            # Fail fast instead of leaving self.model undefined until the first LLM call.
            raise ValueError(f"Unsupported LLM {llm!r}; expected one of {list(llm_factories)}")
        self.model_name = llm
        self.model = llm_factories[llm](config['api_key'])

        self.lemma_model = spacy.load('en_core_web_sm')
        # Candidate generator (fill-mask PLM).
        mode = config['preprocess_function']['mode']
        if mode in ("BERT", "both") and not config['use_cache_result']:
            print("Loading BERT as candidate generator...")
            self._load_fill_mask("bert-base-uncased")
        elif mode in ("CDGP", "both"):
            # NOTE(review): unlike BERT, CDGP is loaded even when use_cache_result is set.
            print("Loading cdgp-csg-bert as candidate generator...")
            self._load_fill_mask("AndyChiang/cdgp-csg-bert-cloth")
        self.preprocess_function = preprocess_function
        self.distractor_generation_function = distractor_generation_function
        self.post_processing_function = post_processing_function

    def _load_fill_mask(self, checkpoint):
        """Load ``checkpoint`` as a BERT masked LM and wrap it in a fill-mask pipeline."""
        self.bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
        self.bert_csg_model = BertForMaskedLM.from_pretrained(checkpoint)
        # NOTE(review): config['device'] is not passed to pipeline(), so the model runs wherever
        # pipeline() places it by default - confirm whether device=self.config['device'] was intended.
        self.unmasker = pipeline('fill-mask', tokenizer=self.bert_tokenizer, model=self.bert_csg_model,
                                 top_k=self.config['candidate_generator_top_k'])

    # filter list of candidates given predefined rules
    def _filter_good_cand(self, cs, question):
        """Keep candidates with the answer's POS tag, a similar length and a similar CEEC difficulty."""
        return [
            c for c in cs
            if self._has_same_postag(c, question)
            and self._has_sim_length(c, question)
            and self._has_sim_difficulty(c, question)
        ]

    def _has_sim_length(self, gen_distractor, question):
        """Return False only when the candidate is more than 2 characters SHORTER than the answer.

        NOTE(review): the check is one-sided - arbitrarily longer candidates pass. Kept as-is to
        avoid changing which candidates are filtered; confirm whether abs() was intended.
        """
        return not len(question['answer']) - len(gen_distractor) > 2

    def _has_same_postag(self, gen_distractor, question):
        """Return True when the candidate gets the same POS tag as the answer in the question sentence.

        Both tags come from ``get_pos_tag_of_word`` (utils).
        """
        sentence = question['sentence']
        ans_pos_tag = get_pos_tag_of_word(sentence, question['answer'])
        return get_pos_tag_of_word(sentence, gen_distractor) == ans_pos_tag

    def _ref_word_by_word(self):
        """Return ``self.ref_word`` indexed by '單字' (word); cached, rebuilt if ``ref_word`` is reassigned."""
        cached = getattr(self, "_ref_word_indexed", None)
        if cached is None or cached[0] is not self.ref_word:
            cached = (self.ref_word, self.ref_word.set_index('單字'))
            self._ref_word_indexed = cached
        return cached[1]

    def _has_sim_difficulty(self, gen_distractor, question):
        """Return True when the candidate's CEEC difficulty ('難度') is within 1 of the answer's.

        Words are lemmatized before lookup. If the answer is not in the reference list every
        candidate passes; if only the candidate is missing, it fails.
        """
        # The indexed frame is cached instead of copying and re-indexing it for every candidate.
        ref_word = self._ref_word_by_word()
        answer = lemmatization(question['answer'], self.lemma_model)
        ans_rows = ref_word[ref_word.index == answer]
        if ans_rows.empty:
            # Answer not in the reference list: difficulty cannot be compared, so accept.
            return True
        ans_dif = ans_rows['難度'].values[0]

        gen_distractor = lemmatization(gen_distractor, self.lemma_model)
        cand_rows = ref_word[ref_word.index == gen_distractor]
        if cand_rows.empty:
            return False
        # Written as "not > 1" so a NaN difficulty still passes, as in the original comparison.
        return not abs(cand_rows['難度'].values[0] - ans_dif) > 1

    def _english_word_set(self):
        """Return ``self.word_list`` as a set for O(1) lookups; cached, rebuilt if the list is reassigned."""
        cached = getattr(self, "_word_set_cache", None)
        if cached is None or cached[0] is not self.word_list:
            cached = (self.word_list, set(self.word_list))
            self._word_set_cache = cached
        return cached[1]

    def extract_response(self, response, cand_pool = None):
        """Parse a numbered LLM answer into a list of lower-cased English words.

        Quote characters and ``"N. "`` list numbering are removed. Each non-empty line is then
        stripped and lower-cased, and duplicates are dropped (first occurrence wins, so the
        order is deterministic). Anything not in the English word list is discarded.

        Args:
            response: raw text returned by the LLM.
            cand_pool: unused - filtering against the candidate pool is disabled; kept for
                backward compatibility with existing callers.

        Returns:
            list[str]: the surviving words, in response order.
        """
        import re

        # The same "N. " numbering pattern applies to every model family (zephyr, vicuna, gpt, ...).
        numbering = re.compile(r"\d+\. ")
        text = response.replace("'", "").replace('"', "").strip()
        lines = numbering.sub("", text).strip().split("\n")
        words = (line.strip().lower() for line in lines if line != "")
        dictionary = self._english_word_set()
        # dict.fromkeys de-duplicates while keeping first-seen order (set() would not).
        return [w for w in dict.fromkeys(words) if w in dictionary]

    def recall_rate_of_top_k(self, k):
        """Fraction of ground-truth distractors found in the top-``k`` candidate pool.

        Args:
            k: number of leading candidates from ``preprocess_function`` to consider.

        Returns:
            float: matched ground-truth distractors / total ground-truth distractors over
            ``self.dataset`` (0.0 when there are none).
        """
        matched = 0
        total = 0
        for question in self.dataset:
            pool = self.preprocess_function(self, question)['cand_pool'][:k]
            if len(pool) < k:
                print(f"Warning: Valid distractor set size is lower than {k}")
            # Count each ground-truth distractor once (previously the running count was
            # appended per distractor and divided by a hard-coded 3).
            matched += sum(1 for d in question['distractors'] if d in pool)
            total += len(question['distractors'])
        return matched / total if total else 0.0

    @staticmethod
    def _inject_ground_truth(prompt_pool, question):
        """"Cheat" mode: truncate the pool to 10 candidates and overwrite 3 random slots with the ground truth."""
        prompt_pool['cand_pool'] = prompt_pool['cand_pool'][:10]
        slots = random.sample(range(len(prompt_pool['cand_pool'])), 3)
        for slot, dis in zip(slots, question['distractors']):
            prompt_pool['cand_pool'][slot] = dis

    def _snapshot_round(self):
        """Return a copy of the current picks: accepted distractors first, then rejected ones."""
        return list(self.good_distractor) + list(self.bad_distractor)

    def run_framework(self):
        """Generate distractors for every question in ``self.dataset``; results are stored in place.

        For each question:

        1. Build the candidate pool from the JSON cache or ``preprocess_function``, truncated to
           ``candidate_set_size``.
        2. Optionally inject the ground truth ("cheat").
        3. Let the LLM pick distractors.
        4. When self-review is enabled, run up to two more rounds with already-judged candidates
           removed from the pool.
        5. If fewer than ``generate_count`` distractors were accepted, pad with earlier rejects.

        The final list goes to ``question['generated']``. One snapshot per self-review round goes
        to ``self.overall_generate_history``.

        Reads ``self.config`` only, not the notebook-global ``config``.
        """
        cfg = self.config
        self.result = []
        self.overall_generate_history = []
        self.bert_result = list()

        mode = cfg['preprocess_function']['mode']
        candidate_size = cfg["distractor_generation_function"]["candidate_set_size"]
        generate_count = cfg['post_processing_function']['generate_count']
        # Pre-computed candidate sets replace the PLM call for BERT / CDGP.
        use_cache = mode in ("BERT", "CDGP") and cfg['use_cache_result']
        if use_cache:
            with open(cfg['candidate_set_cache_path'], "r", encoding="utf-8") as f:
                fastdata = json.load(f)

        for ind, question in enumerate(tqdm(self.dataset)):
            self.good_distractor = []         # accepted by self-review so far
            self.persist_bad_distractor = []  # rejected in an earlier round of this question
            self.bad_distractor = []          # rejected in the current round

            # Candidate generation (BERT / CDGP fill-mask, or the cache).
            if use_cache:
                prompt_pool = {"reason": None, "cand_pool": fastdata[ind][:candidate_size]}
            else:
                prompt_pool = self.preprocess_function(self, question)
                if prompt_pool.get('cand_pool') is not None:
                    prompt_pool['cand_pool'] = prompt_pool['cand_pool'][:candidate_size]

            if cfg['preprocess_function']['cheat']:
                self._inject_ground_truth(prompt_pool, question)

            # Distractor selection; only the first question prints its prompt.
            distractor_pool = self.distractor_generation_function(
                self, question, prompt_pool, sample=(ind == 0))

            if not cfg["post_processing_function"]['self-answer']:
                question['generated'] = distractor_pool
                continue

            # Self-review, then retry at most 2 more rounds while too few are accepted.
            self.post_processing_function(self, question, distractor_pool)
            generate_history = [self._snapshot_round()]
            tries = 0
            while len(self.good_distractor) < generate_count and tries < 2:
                tries += 1
                if mode != "None":
                    # Do not offer already-judged candidates again.
                    for d in self.bad_distractor + self.good_distractor:
                        if d in prompt_pool['cand_pool']:
                            prompt_pool['cand_pool'].remove(d)
                self.persist_bad_distractor.extend(self.bad_distractor)
                self.bad_distractor = []
                distractor_pool = self.distractor_generation_function(self, question, prompt_pool)
                self.post_processing_function(self, question, distractor_pool)
                # Record every round once (previously round 0 was recorded twice and the last never).
                generate_history.append(self._snapshot_round())

            # Pad with earlier rejects: deduplicated in first-seen order so runs are reproducible.
            self.persist_bad_distractor = list(dict.fromkeys(self.persist_bad_distractor))
            for d in self.persist_bad_distractor:
                if len(self.good_distractor) >= generate_count:
                    break
                self.good_distractor.append(d)
            question['generated'] = self.good_distractor[:generate_count]
            self.overall_generate_history.append(generate_history)
        #return self.result

In [None]:
import json
from preprocess import *
from distractor_generation import *
from utils import *

if __name__ == "__main__":
    config = get_config()
    # Echo the effective config, but keep the API key out of the notebook output.
    print({**config, "api_key": "<redacted>" if config.get("api_key") else ""})

    # Candidate-generation stage: "reason" and "None" have dedicated functions;
    # "BERT", "CDGP" and any other value use the PLM candidate pool.
    preprocess_mode_name = config['preprocess_function']['mode']
    if preprocess_mode_name == "reason":
        preprocess_function = reason_generation
    elif preprocess_mode_name == "None":
        preprocess_function = none
    else:
        preprocess_function = pool_generation

    # NLTK data for the tokenizer / POS tagger / WordNet imported earlier; quiet avoids log spam on re-runs.
    for nltk_resource in ('punkt', 'averaged_perceptron_tagger', 'wordnet', 'omw-1.4'):
        nltk.download(nltk_resource, quiet=True)

    # NOTE(review): few_shot is passed for zero-shot runs too; presumably it reads
    # config['distractor_generation_function']['zero-shot'] itself - confirm.
    dis_model = DistractorGenerationModel(config, preprocess_function, few_shot, self_answer)
    result = dis_model.run_framework()

In [None]:
# Result file name, built from the run's config:
#   {LLM}_pickRate{distractors picked per round}_{candidate generator}_{fewshot|zeroshot}_{selfanswer|none}.json
result_path = (
    f"{config['LLM']}"
    f"_pickRate{config['distractor_generation_function']['pick_distractors_per_round']}"
    f"_{config['preprocess_function']['mode']}"
    f"_{'fewshot' if config['distractor_generation_function']['zero-shot'] == False else 'zeroshot'}"
    f"_{'selfanswer' if config['post_processing_function']['self-answer'] == True else 'none'}"
    ".json"
)
result_path

In [None]:
import os

# Save the dataset (each question now carries 'generated') under ./result/, creating the
# directory on first use. File name scheme:
#   {LLM}_pickRate{distractors picked per round}_{candidate generator}_{fewshot|zeroshot}_{selfanswer|none}.json
result_dir = "./result"
os.makedirs(result_dir, exist_ok=True)
result_path = (
    f"{result_dir}/{config['LLM']}"
    f"_pickRate{config['distractor_generation_function']['pick_distractors_per_round']}"
    f"_{config['preprocess_function']['mode']}"
    f"_{'fewshot' if config['distractor_generation_function']['zero-shot'] == False else 'zeroshot'}"
    f"_{'selfanswer' if config['post_processing_function']['self-answer'] == True else 'none'}"
    ".json"
)

with open(result_path, 'w', encoding='utf-8') as f:
    print(f"writing to {result_path}")
    json.dump(dis_model.dataset, f, indent=2)
