In [None]:
from utils import *

In [None]:
import copy
import json
import os
import random
import sys
import time
import traceback
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import jsonlines
from tqdm import tqdm

<h2 data-lake-id="Zv3rs" id="Zv3rs"><span data-lake-id="u80758df7" id="u80758df7"> Config</span></h2>

In [None]:
from argparse import Namespace

# Experiment settings shared by the cells below: which dataset split to read,
# which model's outputs to use, and the [start_index, end_index) slice of
# samples handled in this run.
config = Namespace(
    dataset_name='GSM8K',
    split='test',
    model_name='LLAMA3-8B',
    start_index=0,
    end_index=768,
)

config

<h2 data-lake-id="L2NZ7" id="L2NZ7"><span data-lake-id="ud4d32cb7" id="ud4d32cb7"> Dataset</span></h2>

In [None]:
# Load the configured dataset split from the data/ directory.
# The path is relative, so it resolves against the notebook's working directory.
dataset_frn = f"data/{config.dataset_name}/{config.split}.jsonl"
dataset = load_data(dataset_frn)

print(f'Dataset: {config.dataset_name}, Length: {len(dataset)}')

In [None]:
# Location of the previously generated initial responses for this model/dataset.
initial_pred_directory=f'/Initial-Generation-List/{config.model_name}/{config.dataset_name}'
initial_pred_path=os.path.join(initial_pred_directory, 'output.jsonl')

# Read the initial responses; they are paired with the dataset rows purely by
# position (zip stops at the shorter of the two sequences).
initial_generation_list = read_jsonl_as_list(initial_pred_path)

# recording_list keeps track of all the intermediate results, one dict per sample.
# Keys at this point: 'id', 'question', 'response' (initial response) and
# 'response-answer' (the answer stored alongside the initial response).
recording_list = []
for example, generation in zip(dataset, initial_generation_list):
    recording_list.append({
        'id': example['id'],
        'question': example['question'],
        'response': generation['completion'],
        'response-answer': generation['answer'],
    })

print(f'size of initial prediction: {len(recording_list)}')

<h2 data-lake-id="BHdZ7" id="BHdZ7"><span data-lake-id="ua587df90" id="ua587df90"> LLM Configuration</span></h2>

In [None]:
from model import Model

# Model Initialization: build the project's Model wrapper from the experiment
# config, labelled with the 'Prepare-Model' stage.
llm = Model(config, cur_stage='Prepare-Model')

# Local model directory handed to prepare_model.
# NOTE(review): presumably this loads the weights/tokenizer — confirm in model.py.
model_id = '/root/Meta-Llama-3-8B-Instruct'
llm.prepare_model(model_id)

<h2 data-lake-id="PAj9K" id="PAj9K"><span data-lake-id="ub677a3bb" id="ub677a3bb"> Main Function</span></h2>

In [None]:
from tqdm import tqdm

def single_run(llm, stage, recording, config, round,
               recording_root='/ossfs/workspace/Faithful-COT-Logic/recording'):
    """Run one stage over every sample in ``recording`` and checkpoint the result.

    Args:
        llm: model wrapper exposing ``refresh_stage(cur_stage=..., cur_round=...)``
            and ``predict(sample)``; ``predict`` must return a mapping whose
            items are merged into the sample.
        stage: stage name. For 'Contrast-Responses-Merge-Memory' and
            'Regeneration-w-Suggestion' the experiment name is prefixed with
            the round number; any other stage uses its own name.
        recording: list of per-sample dicts, updated in place.
        config: object with ``model_name``, ``dataset_name``, ``start_index``
            and ``end_index`` attributes (used to build the output path).
        round: round index forwarded to the model wrapper and used in the
            experiment name.
        recording_root: directory under which
            ``<model_name>/<dataset_name>/<exp_name>-<start>-<end>.json`` is
            written. Defaults to the original hard-coded location.

    Samples that already contain the experiment name as a key are skipped, so
    the function can be re-run on a partially completed recording. When
    ``predict`` raises, the error text is stored under the experiment name
    (which also means that sample is skipped on later calls) and processing
    continues with the next sample.
    """
    # Initialization of LLM Wrapper
    llm.refresh_stage(cur_stage = stage, cur_round = round)

    # Current experiment name
    if stage in ['Contrast-Responses-Merge-Memory', 'Regeneration-w-Suggestion']:
        exp_name = f'{round}-{stage}'
    else:
        exp_name = stage

    for sample in tqdm(recording):
        if exp_name in sample:
            # Already processed (or already failed) for this experiment.
            continue

        try:
            completion = llm.predict(sample)
            for k,v in completion.items():
                sample[k] = v
        except Exception as e:
            sample[exp_name] = str(e)
            print(f'Error at {sample["id"]}-th sample: {str(e)}', file=sys.stderr)

    # Save current recording-List
    recording_path = f'{recording_root}/{config.model_name}/{config.dataset_name}'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(recording_path, exist_ok=True)
    with open(os.path.join(recording_path, f'{exp_name}-{config.start_index}-{config.end_index}.json'), 'w') as f:
        json.dump(recording, f, indent=4)

def complete_run(llm, recording, config, total_EM_rounds):
    """Run the whole pipeline: one initial pass plus ``total_EM_rounds`` EM rounds.

    The initial 'Initial-Regeneration' stage runs as round 0. Each EM step
    ``i`` (1-based) then runs 'Contrast-Responses-Merge-Memory' as round
    ``2*i - 1`` followed by 'Regeneration-w-Suggestion' as round ``2*i``.
    ``get_cur_major_vote`` is called with ``weight_method='average'`` after the
    initial pass and after every EM step.

    Args:
        llm: model wrapper forwarded to ``single_run``.
        recording: list of per-sample dicts, forwarded to every stage.
        config: experiment configuration forwarded to every stage.
        total_EM_rounds: number of EM steps to perform.

    Any exception aborts the remaining stages but is not propagated: the
    message and the full traceback are written to stderr so the failing stage
    can be located, and the function returns normally.
    """
    try:
        single_run(llm=llm, stage='Initial-Regeneration', recording=recording, config=config, round=0)
        get_cur_major_vote(weight_method='average', recording=recording, config=config)
        for EM_step_id in range(1, total_EM_rounds+1):
            single_run(llm=llm, stage='Contrast-Responses-Merge-Memory', recording=recording, config=config, round=2*EM_step_id-1)
            single_run(llm=llm, stage='Regeneration-w-Suggestion', recording=recording, config=config, round=2*EM_step_id)
            get_cur_major_vote(weight_method='average', recording=recording, config=config)
    except Exception as e:
        print(f'Error: {str(e)}', file=sys.stderr)
        # The message alone does not say which stage failed; keep the traceback.
        traceback.print_exc(file=sys.stderr)


<h2 data-lake-id="HrWpH" id="HrWpH"><span data-lake-id="u7d31b026" id="u7d31b026"> Formal Running</span></h2>

In [None]:
# Start from an independent copy of the configured slice of initial responses
# (slicing first avoids deep-copying samples outside the slice).
recording = deepcopy(recording_list[config.start_index:config.end_index])

# Resume from a previously saved checkpoint when it is available. Without the
# file (fresh machine, fresh experiment) keep the slice built above instead of
# failing with FileNotFoundError.
checkpoint_path = '/ossfs/workspace/Faithful-COT-Logic/recording/LLAMA3-8B/GSM8K/17-Contrast-Responses-Merge-Memory-0-768.json'
if os.path.exists(checkpoint_path):
    with open(checkpoint_path, 'r') as f:
        recording = json.load(f)
    print(f'resumed from checkpoint: {checkpoint_path}')
else:
    print(f'checkpoint not found, starting from the initial responses: {checkpoint_path}', file=sys.stderr)

print(f'size of current run: {len(recording)}')

In [None]:
# The main running function: one initial regeneration pass followed by
# total_EM_rounds contrast/regeneration rounds (see complete_run).

complete_run(llm, recording, config, total_EM_rounds=9)

In [None]:
# initial_list = []

# for r in recording:
#     initial_list.append({'id':r['id'], 'completion': r['Initial-Regeneration'], 'answer': r['Initial-Regeneration-answer']})

# initial_pred_directory=f'/ossfs/workspace/Faithful-COT-Logic/Initial-Generation-List/{config.model_name}/{config.dataset_name}'

# if not os.path.exists(initial_pred_directory):
#     os.makedirs(initial_pred_directory)

# path = os.path.join(initial_pred_directory, 'output.jsonl')

# dump_list_as_jsonl(path, initial_list)

<h1 data-lake-id="qz0Nz" id="qz0Nz"><span data-lake-id="u2e00292b" id="u2e00292b"> Evaluating</span></h1>

In [None]:
# Work on a deep copy so the keys added during evaluation never touch `recording`.
recording_backup = deepcopy(recording)

In [None]:
# For every field whose key ends with 'answer', store the parsed prediction
# next to it under '<key>-extracted'.
for sample in recording_backup:
    # Snapshot the matching fields first: new keys are added while looping.
    answer_fields = [(key, raw) for key, raw in sample.items() if key.endswith('answer')]
    for key, raw in answer_fields:
        pred_answer = extract_pred_answer(config.dataset_name, raw)

        # A string result is printed together with its sample for inspection.
        if isinstance(pred_answer, str):
            print(sample)
        sample[f'{key}-extracted'] = pred_answer
            

In [None]:
def evaluate_acc(dataset, predictions, dataset_name, non_empty_only=False, valid_only=False, key4check = 'Initial-Regeneration-answer-extracted'):
    """Compute the accuracy (in percent, one decimal) of one answer field.

    ``dataset`` and ``predictions`` are walked in parallel; each pair must
    carry the same ``id``. Every scored prediction additionally receives a
    ``'<key4check>-correct'`` entry holding the result of ``is_correct``.

    Args:
        dataset: gold examples with ``'id'``, ``'question'`` and ``'answer'``.
        predictions: per-sample dicts; ``{}`` entries and entries lacking
            ``key4check`` are skipped and do not count towards the total.
        dataset_name: forwarded to the answer extraction / comparison helpers.
        non_empty_only: skip predictions whose extracted answer is ``""``.
        valid_only: skip (and print) string answers containing "invalid" or
            "error".
        key4check: key of the prediction field to score.

    Returns:
        ``round(correct / total * 100, 1)``, or ``0.0`` when no prediction was
        scored (instead of dividing by zero).

    Raises:
        AssertionError: a gold id and prediction id differ.
        Exception: whatever ``extract_gold_answer`` or ``is_correct`` raised,
            re-raised after the diagnostic details have been printed (the
            kernel is no longer terminated with ``exit``).
    """
    correct_count, total_count = 0, 0

    for example, prediction in zip(dataset, predictions):
        gold_id = int(example["id"])
        if prediction == {}:
            continue
        pred_id = int(prediction["id"])

        # Explicit check rather than `assert`, which is stripped under -O.
        if gold_id != pred_id:
            raise AssertionError(f"Gold id {gold_id} doesn't match pred id {pred_id}.")

        try:
            gold_answer = extract_gold_answer(dataset_name, example["answer"])
        except SyntaxError as e:
            print("Error: ", e)
            print(gold_id)
            raise

        if key4check not in prediction:
            continue
        pred_answer = extract_pred_answer(dataset_name, prediction[key4check])

        if non_empty_only and pred_answer == "":
            continue

        if valid_only and isinstance(pred_answer, str) and ("invalid" in pred_answer or "error" in pred_answer):
            print(pred_answer, flush=True)
            continue

        total_count += 1
        try:
            correct = is_correct(dataset_name, gold_answer, pred_answer)
        except Exception as e:
            print("Error: ", e)
            print("Example: ", gold_id)
            print("Question: ", example["question"])
            print("Gold answer: ", gold_answer, type(gold_answer))
            print("Pred answer: ", pred_answer, type(pred_answer))
            # .get: a missing 'completion' key must not mask the real error.
            print("Completion: ", prediction.get("completion"))
            print("\n")
            raise

        if correct:
            correct_count += 1

        prediction[key4check+'-correct']=correct
    print(f'correct_count: {correct_count}, total_count: {total_count}')
    if total_count == 0:
        # Nothing was scored (all predictions skipped or filtered out).
        return 0.0
    return round(correct_count / total_count * 100, 1)

In [None]:
# Answer fields to score: the initial regeneration plus the regeneration
# output of rounds 2, 4 and 6.
steps_name_list = ['Initial-Regeneration-answer']
steps_name_list.extend(f'{rnd}-Regeneration-w-Suggestion-answer' for rnd in [2,4,6])

# One accuracy value per step, in the same order as steps_name_list.
acc_list = [
    evaluate_acc(
        dataset=dataset,
        predictions=recording_backup,
        dataset_name=config.dataset_name,
        non_empty_only=True,
        valid_only=True,
        key4check=step_name,
    )
    for step_name in steps_name_list
]

In [None]:
# Accuracies (percent), in the order of steps_name_list.
acc_list

In [None]:
# Accuracy of the initial regeneration only; this repeats the call made for
# the first entry of steps_name_list.
evaluate_acc(dataset=dataset,
               predictions=recording_backup,
               dataset_name=config.dataset_name,
               non_empty_only=True,
               valid_only=True,
               key4check='Initial-Regeneration-answer')