In [22]:
# Back-of-the-envelope time estimate, reported in minutes. Presumably: 113 min 9 s
# elapsed for 1297 units of work, extrapolated to the 1280 + 128 - 1297 = 111 units
# still left. NOTE(review): interpretation inferred from the arithmetic only — confirm.
(113 * 60 + 9) / 1297 * (1280+128-1297) / 60

9.68361603700848

# MCQ

In [45]:
import regex
import jsonlines
from collections import defaultdict
from string import punctuation

def load_jsonl(fname):
    """Read the JSON Lines file at `fname` and return all of its records as a list."""
    with jsonlines.open(fname, mode='r') as reader:
        return list(reader)

def _extract_answer(gen):
    raw_gen = gen
    format_flag = False
    if 'QUESTION:' in gen:
        gen = gen.split('QUESTION:')[0]
    if ' answer is' in gen:
        gen = gen.strip().split(' answer is')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif 'Overall, ' in gen:
        gen = gen.strip().split('Overall, ')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif 'Answer: ' in gen:
        gen = gen.strip().split('Answer: ')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif 'Therefore, ' in gen:
        gen = gen.strip().split('Therefore, ')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif 'Thus, ' in gen:
        gen = gen.strip().split('Thus, ')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif 'So, ' in gen:
        gen = gen.strip().split('So, ')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif 'So ' in gen:
        gen = gen.strip().split('So ')[-1].strip().strip(punctuation).strip()
        format_flag = True
    elif '\n' in gen:
        gen = gen.strip().split('\n')[-1].strip().strip(punctuation).strip()
    
    options = regex.findall(r'\([A-Z1-9]\)|[A-Z1-9]\)', gen)
    options_backup = regex.findall(r'[A-Z1-9]', gen)
    if options:
        options = [x for i, x in enumerate(options) if x not in options[i+1:]]
        prediction = options[-1].strip(punctuation)
    elif options_backup and format_flag:
        options = [x for i, x in enumerate(options_backup) if x not in options_backup[i+1:]]
        prediction = options[-1].strip(punctuation)
    else:
        options = regex.findall(r'\([A-Z1-9]\)|[A-Z1-9]\)', raw_gen)
        if options:
            options = [x for i, x in enumerate(options) if x not in options[i+1:]]
            prediction = options[-1].strip(punctuation)
        else:
            prediction = None
    return prediction

def eval_accu(prediction, option, answer):
    """Return whether a prediction selects the gold option.

    Args:
        prediction: one generation (str), or an iterable of generations. For an
            iterable, the extracted labels are majority-voted, and ties go to the label
            extracted first. None is scored as incorrect; `extract_pred_result(...,
            dtype='mcts')` produces it for a record with no generations.
        option: gold option label. Surrounding punctuation is stripped before comparing.
        answer: gold answer text. Not used; kept so existing call sites keep working.

    Returns:
        bool. A vote in which no generation yields a label counts as incorrect.
    """
    if isinstance(prediction, str):
        return _extract_answer(prediction) == option.strip(punctuation)
    if prediction is None:
        # Iterating None used to raise TypeError outside the old try/except.
        return False

    counter = defaultdict(int)
    for g in prediction:
        if g is None:
            # A missing generation casts no vote instead of crashing _extract_answer.
            continue
        label = _extract_answer(g)
        if label is not None:
            counter[label] += 1

    if not counter:
        # Explicit replacement for the bare `except:` that caught max()'s ValueError.
        return False
    # max() returns the first maximal item; dict order is first-extraction order.
    majority = max(counter.items(), key=lambda x: x[1])[0]
    return majority == option.strip(punctuation)

def load_raw_labels(dtype, data_dir='/home/users/nus/e0672129/scratch/csr', split='fulltest'):
    """Map (question, gold answer) -> sub-task label for one MCQ benchmark.

    Args:
        dtype: benchmark name used in the file name (e.g. 'sqa', 'csr').
        data_dir: directory holding the label files. The default is the original
            hard-coded location.
        split: file suffix. 'fulltest' by default; the notebook previously
            switched to 'test' by commenting lines in and out.

    Returns:
        dict keyed by (stripped question text, record['answer']) with record['label']
        values, read from `{data_dir}/mcq_{dtype}_{split}.jsonl`. When a key is
        repeated, the later record wins.
    """
    raw_labels = load_jsonl(f'{data_dir}/mcq_{dtype}_{split}.jsonl')
    return {(dt['question'].strip(), dt['answer']): dt['label'] for dt in raw_labels}


In [40]:
from mcts_rl.configs.constants import COT_INSTRUCTIONS, PROMPT_BEGIN, PROMPT_ASSISTANT, PROMPT_USER, SQA_PROMPT

# Flatten the few-shot SQA prompt: each '</s>' followed by a blank line becomes a single
# space and any other '</s>' is dropped, so extract_pred_result can remove this prefix
# from prompts with str.replace. Re-running the cell is harmless: a second pass changes
# nothing.
SQA_PROMPT = (
    SQA_PROMPT
    .replace('</s>\n\n', ' ')
    .replace('</s>', '')
    .strip()
)

def extract_pred_result(raw_pred, dtype='default'):
    """Index raw MCQ prediction records by their cleaned question prompt.

    The prompt is cleaned by removing SQA_PROMPT, PROMPT_BEGIN and PROMPT_USER and
    keeping only the text before PROMPT_ASSISTANT.

    Args:
        raw_pred: iterable of dicts with 'prompt', 'generated', 'answer' and
            'answer_content' keys.
        dtype: 'mcts' takes generated[-1][-1] (None if nothing was generated). Any
            other value takes the single generation, or the whole list when there
            are several.

    Returns:
        dict prompt -> {'pred': ..., 'gt_answer': (answer, answer_content)}. Only the
        first record seen for each prompt is kept.
    """
    predictions = {}
    for dt in raw_pred:
        prompt = (
            dt['prompt'][0].strip()
            .replace(SQA_PROMPT, '').strip()
            .replace(PROMPT_BEGIN, '')
            .replace(PROMPT_USER, '')
            .split(PROMPT_ASSISTANT)[0].strip()
        )
        if prompt in predictions:
            continue
        generated = dt['generated']
        if dtype == 'mcts':
            pred = generated[-1][-1] if len(generated) else None
        else:
            pred = generated[0] if len(generated) == 1 else generated
        predictions[prompt] = {'pred': pred, 'gt_answer': (dt['answer'], dt['answer_content'])}
    return predictions

def visualize_pred_result(predictions, N=int(1e5), dtype='csr', show_split=False):
    """Print MCQ accuracy overall and, optionally, per sub-task.

    Args:
        predictions: dict produced by `extract_pred_result`
            (prompt -> {'pred', 'gt_answer'}).
        N: stop after this many questions have been scored.
        dtype: 'sqa' and 'csr' track a fixed list of sub-tasks. Any other value
            tracks every label found in the label file; this case previously
            raised NameError.
        show_split: also print one accuracy line per tracked sub-task.

    A prompt is scored only when its (question, gold option) key is in the label
    file and its label is a tracked sub-task. Returns None; results are printed.
    """
    raw_labels_dict = load_raw_labels(dtype)
    if dtype == 'sqa':
        task_names = ['openbook', 'arc_easy', 'arc_hard', 'ai2s_ele', 'ai2s_mid']
    elif dtype == 'csr':
        task_names = ['csqa', 'siqa', 'piqa']
    else:
        # Labels in first-seen order (dict.fromkeys avoids sorting mixed types).
        task_names = list(dict.fromkeys(raw_labels_dict.values()))
    tsk_accu = {x: [] for x in task_names}

    accu = []
    for prompt, gens in predictions.items():
        question = prompt.replace('QUESTION:', '').strip().split(PROMPT_ASSISTANT)[0].strip()
        key = (question, gens['gt_answer'][0])
        if key not in raw_labels_dict or raw_labels_dict[key] not in tsk_accu:
            continue
        # Score only after filtering; eval_accu has no side effects, so skipping it is safe.
        _eval = eval_accu(gens['pred'], gens['gt_answer'][0], gens['gt_answer'][1])
        accu.append(_eval)
        tsk_accu[raw_labels_dict[key]].append(_eval)
        if len(accu) >= N:
            break

    print('* all', sum(accu)/max(1, len(accu)), '({})'.format(len(accu)))
    if not show_split:
        return
    for k, v in tsk_accu.items():
        print(k, sum(v)/max(1, len(v)), '({})'.format(len(v)))


## SQA

In [15]:
N = 33120
dtype = 'sqa'
# Score two prediction files with identical settings, in order.
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/mcq/sqa-noptx/predictions/mcts-mistral-s2048.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/mcq/sqa-noptx/predictions/mcts-mistral-s3072.jsonl',
):
    visualize_pred_result(
        extract_pred_result(load_jsonl(pred_path), dtype=dtype),
        dtype=dtype,
        N=N,
    )

* all 0.833810888252149 (1396)
* all 0.8223495702005731 (1396)


In [17]:
N = 33120
dtype = 'sqa'
sc20_predictions = extract_pred_result(
    load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/sc20-arithmo.jsonl'),
    dtype=dtype,
)
visualize_pred_result(sc20_predictions, dtype=dtype, N=N)


* all 0.7927461139896373 (1737)


In [18]:
N = 40580
dtype = 'sqa'
few_shot_predictions = extract_pred_result(
    load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/few-shot-arithmo.jsonl'),
    dtype=dtype,
)
visualize_pred_result(few_shot_predictions, dtype=dtype, N=N)


* all 0.805087782156933 (2791)


In [50]:
N = 17370
dtype = 'sqa'
# (section header to print first, or None; prediction file; show per-task split)
sqa_runs = [
    ('=== baseline ===', '/home/users/nus/e0672129/scratch/MCTS-DPO/mcq/sqa-noptx-instance/predictions/s2816.jsonl', True),
    ('=== online sc ===', '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-sc-s2048.jsonl', False),
    ('=== online mcts ===', '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-mcts-s1024.jsonl', False),
    (None, '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-mcts-s2048.jsonl', False),
    (None, '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-mcts-s2560.jsonl', False),
    (None, '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-mcts-s3072.jsonl', False),
    (None, '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-mcts-s3456.jsonl', False),
    (None, '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/mistral-online-mcts-s3840.jsonl', False),
    ('=== online mcts (random) ===', '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/rd-mistral-online-mcts-s1024.jsonl', False),
    (None, '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/rd-mistral-online-mcts-s2048.jsonl', False),
    ('======', '/home/users/nus/e0672129/scratch/MCTS-DPO/mcq/sqa-noptx/predictions/mcts-mistral-s2048-fulltest.jsonl', True),
]
for header, pred_path, split_flag in sqa_runs:
    if header is not None:
        print(header)
    visualize_pred_result(
        extract_pred_result(load_jsonl(pred_path), dtype=dtype),
        dtype=dtype,
        N=N,
        show_split=split_flag,
    )


=== baseline ===
* all 0.8265135199808566 (4179)
openbook 0.766 (500)
arc_easy 0.8769462581617278 (1991)
arc_hard 0.733201581027668 (1012)
ai2s_ele 0.8576779026217228 (267)
ai2s_mid 0.8655256723716381 (409)
=== online sc ===
* all 0.7829624312036373 (4179)
=== online mcts ===
* all 0.7793730557549653 (4179)
* all 0.797559224694903 (4179)
* all 0.7939698492462312 (4179)
* all 0.7032782962431203 (4179)
* all 0.7080314009661836 (3312)
* all 0.7429768358797437 (4058)
=== online mcts (random) ===
* all 0.8013878918401531 (4179)
* all 0.7881559942224362 (4154)
* all 0.8368030629337162 (4179)
openbook 0.762 (500)
arc_easy 0.880462079357107 (1991)
arc_hard 0.7490118577075099 (1012)
ai2s_ele 0.9176029962546817 (267)
ai2s_mid 0.8801955990220048 (409)


## CSR

In [4]:
N = 15570
dtype = 'csr'
show_split = True
csr_fulltest_predictions = extract_pred_result(
    load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/mcq/csr-noptx/predictions/mcts-mistral-s5120-fulltest.jsonl'),
    dtype=dtype,
)
visualize_pred_result(csr_fulltest_predictions, dtype=dtype, N=N, show_split=show_split)


* all 0.748195669607057 (4988)
csqa 0.733005733005733 (1221)
siqa 0.6870466321243524 (1930)
piqa 0.8225367446924333 (1837)


In [5]:
N = 15570
dtype = 'csr'
show_split = False
# Each group prints its header once, then one accuracy line per prediction file.
csr_groups = [
    ('=== online mcts ===', [
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-mcts-s1024.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-mcts-s2048.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-mcts-s3072.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-mcts-s3968.jsonl',
    ]),
    ('=== online mcts (random) ===', [
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/rd-mistral-online-mcts-s1024.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/rd-mistral-online-mcts-s2048.jsonl',
    ]),
    ('=== offline mcts ===', [
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-offline-mcts-s1024.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-offline-mcts-s1792.jsonl',
    ]),
    ('=== online sc ===', [
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-sc-s1536.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-sc-s2048.jsonl',
        '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/csr/predictions/mistral-online-sc-s2816.jsonl',
    ]),
]
for header, pred_paths in csr_groups:
    print(header)
    for pred_path in pred_paths:
        visualize_pred_result(
            extract_pred_result(load_jsonl(pred_path), dtype=dtype),
            dtype=dtype,
            N=N,
            show_split=show_split,
        )


=== online mcts ===
* all 0.6988773055332799 (4988)
* all 0.6864474739374499 (4988)
* all 0.6858460304731355 (4988)
* all 0.7139133921411387 (4988)
=== online mcts (random) ===
* all 0.714514835605453 (4988)
* all 0.720729751403368 (4988)
=== offline mcts ===
* all 0.6840417000801925 (4988)
* all 0.6955550591914228 (4477)
=== online sc ===
* all 0.6742181234963913 (4988)
* all 0.6807901517320355 (3493)
* all 0.6926201760324983 (4431)


# ARITHMO

In [25]:
import jsonlines
from string import punctuation

def load_jsonl(fname):
    """Return every record of the JSON Lines file `fname`, in file order.

    Same behaviour as the load_jsonl defined in the MCQ section; this cell redefines it.
    """
    records = []
    with jsonlines.open(fname, mode='r') as reader:
        records.extend(reader)
    return records

In [26]:
from collections import defaultdict
from mcts_rl.utils import extract_answer, math_equal
from mcts_rl.configs.constants import COT_INSTRUCTIONS, PROMPT_BEGIN, PROMPT_ASSISTANT, PROMPT_USER

def extract_pred_result(raw_pred, dtype='default'):
    """Index raw math prediction records by their cleaned prompt.

    This redefines (and shadows) the MCQ `extract_pred_result` earlier in the notebook.
    Here the prompt is cleaned by deleting PROMPT_BEGIN, PROMPT_USER and
    PROMPT_ASSISTANT, rather than splitting on PROMPT_ASSISTANT.

    Args:
        raw_pred: iterable of dicts with 'prompt', 'generated', 'answer' and
            'answer_content' keys.
        dtype: 'mcts' takes generated[-1][-1] (None if nothing was generated). Any
            other value takes the single generation, or the whole list when there
            are several.

    Returns:
        dict prompt -> {'pred': ..., 'gt_answer': (answer, answer_content)}. Only the
        first record seen for each prompt is kept.
    """
    predictions = {}
    for dt in raw_pred:
        prompt = (
            dt['prompt'][0].strip()
            .replace(PROMPT_BEGIN, '')
            .replace(PROMPT_USER, '')
            .replace(PROMPT_ASSISTANT, '')
            .strip()
        )
        if prompt in predictions:
            continue
        generated = dt['generated']
        if dtype == 'mcts':
            pred = generated[-1][-1] if len(generated) else None
        else:
            pred = generated[0] if len(generated) == 1 else generated
        predictions[prompt] = {'pred': pred, 'gt_answer': (dt['answer'], dt['answer_content'])}
    return predictions

def extract_sc_answer(gens, use_code=False):
    """Self-consistency vote: the most frequent extracted answer among `gens`.

    Args:
        gens: iterable of generation strings.
        use_code: forwarded to `extract_answer`.

    Returns:
        The winning answer. Ties go to the answer extracted first, because max()
        returns the first maximal item and the dict keeps insertion order. Whatever
        `extract_answer` returns, possibly None (unverified), counts as a vote.
        Returns None when `gens` is empty; previously max() raised ValueError on the
        empty sequence.
    """
    counter = defaultdict(int)
    for g in gens:
        counter[extract_answer(g, use_code=use_code)] += 1
    if not counter:
        return None
    return max(counter.items(), key=lambda x: x[1])[0]

def visualize_pred_result(predictions, N=int(1e5), use_code=False):
    """Score math predictions, print accuracy on the first N, and return every entry.

    Each value in `predictions` is mutated in place. It gains 'index', 'prompt',
    'pred_rst' and 'correct', plus 'option_content' for multiple-choice prompts. An
    answer also counts as correct when it equals the text of the gold option. All
    entries are scored, but only the first N count toward the printed accuracy.

    Returns:
        list of all scored entries, not only the wrong ones, despite callers naming
        the result `errors`.
    """
    records, hits = [], []
    choices_marker = '\nAnswer Choices: '
    for position, (prompt, entry) in enumerate(predictions.items()):
        generation = entry['pred']
        if isinstance(generation, str):
            pred = extract_answer(generation, use_code=use_code)
        else:
            pred = extract_sc_answer(generation, use_code=use_code)
        gold = entry['gt_answer'][0]
        is_correct = math_equal(pred, gold)
        if choices_marker in prompt:
            choices = prompt.split(choices_marker)[-1].replace('Write a Python program to solve this.', '').strip()
            # Text between "(gold)" and the next " (" is the gold option's content.
            option_text = choices.split(f"({gold})")[-1].split(' (')[0].strip()
            is_correct = is_correct or (math_equal(pred, option_text))
            entry['option_content'] = option_text
        entry.update(index=position, prompt=prompt, pred_rst=pred, correct=is_correct)
        records.append(entry)
        hits.append(is_correct)

    head = hits[:N]
    print('all', sum(head)/max(1, len(head)), '({})'.format(len(head)))
    return records


## MATH

In [37]:
N = 1028
print('=== baseline ===')
base_errors = visualize_pred_result(
    extract_pred_result(load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/math-sft.jsonl')),
    N=N,
)
print('=== online mcts ===')
# `errors` ends up bound to the last file's result, as before.
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/math-wide-mistral-online-mcts-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/math-mistral-online-mcts-s1792.jsonl',
):
    errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N)


=== baseline ===
all 0.316147859922179 (1028)
=== online mcts ===
all 0.3151750972762646 (1028)
all 0.3559322033898305 (177)


## AQuA

In [37]:
N = 1650
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/aqua-sft.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/aqua-mistral-online-sc-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/aqua-wide-mistral-online-mcts-s1024.jsonl',
    # NOTE(review): unlike the runs above, this file lives under experiments/sqa/ — confirm the path.
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/sqa/predictions/aqua-rd-mistral-online-mcts-s1024.jsonl',
):
    errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N)

all 0.48031496062992124 (254)
all 0.468503937007874 (254)
all 0.46062992125984253 (254)
all 0.41338582677165353 (254)


In [6]:
# Set N here instead of relying on whichever cell last assigned it (this cell ran as
# In [6], after the GSM8K cell had set N = 5240). Any N >= 254 (the AQuA test size)
# gives the same numbers.
N = 1650
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/aqua-sft.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/aqua-mistral-online-mcts-s2560.jsonl',
):
    errors = visualize_pred_result(
        extract_pred_result(load_jsonl(pred_path)),
        N=N,
        use_code=True,
    )


all 0.2637795275590551 (254)
all 0.2637795275590551 (254)


## GSM8K

In [4]:
N = 5240
gsm8k_sc_predictions = extract_pred_result(
    load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/sc-gsm8k-sft.jsonl'),
)
errors = visualize_pred_result(gsm8k_sc_predictions, N=N)

all 0.8582259287338894 (1319)


In [28]:
N = 5240
errors = visualize_pred_result(
    extract_pred_result(load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/sft/diymistral-arithmo-lowerlr/predictions/gsm-cot.jsonl')),
    N=N,
)
# After each loop, its variable holds the last file's result, exactly as before.
print('=== online mcts ===')
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-online-mcts-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-online-mcts-s1536.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/wide-mistral-online-mcts-s1408.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-online-mcts-s1792.jsonl',
):
    errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N)
print('=== online sc ===')
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-online-sc-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-online-sc-s1536.jsonl',
):
    sc_errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N)
print('=== offline mcts ===')
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-offline-mcts-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/mathqa/predictions/mistral-offline-mcts-s1280.jsonl',
):
    off_errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N)


all 0.7589082638362395 (1319)
=== online mcts ===
all 0.7626990144048522 (1319)
all 0.7702805155420773 (1319)
all 0.756633813495072 (1319)
all 0.7687642153146323 (1319)
=== online sc ===
all 0.7672479150871873 (1319)
all 0.7647963105303612 (1301)
=== offline mcts ===
all 0.7634571645185747 (1319)
all 0.759666413949962 (1319)


In [19]:
N = 5630
print('=== baseline ===')
base_errors = visualize_pred_result(
    extract_pred_result(load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/sft/diymistral-arithmo-lowerlr/predictions/gsm-pot.jsonl')),
    N=N,
    use_code=True,
)
# Program-of-thought runs: every call uses use_code=True. After each loop, its
# variable holds the last file's result.
print('=== online mcts ===')
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-online-mcts-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-online-mcts-s1536.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-online-mcts-s2048.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-online-mcts-s2560.jsonl',
):
    errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N, use_code=True)
print('=== online sc ===')
for pred_path in (
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-online-sc-s1024.jsonl',
    '/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-online-sc-s1536.jsonl',
):
    sc_errors = visualize_pred_result(extract_pred_result(load_jsonl(pred_path)), N=N, use_code=True)
print('=== offline mcts ===')
off_errors = visualize_pred_result(
    extract_pred_result(load_jsonl('/home/users/nus/e0672129/scratch/MCTS-DPO/outputs/experiments/code/predictions/mistral-offline-mcts-s1024.jsonl')),
    N=N,
    use_code=True,
)


=== baseline ===
all 0.7740712661106899 (1319)
=== online mcts ===
all 0.7672479150871873 (1319)
all 0.7687642153146323 (1319)
all 0.7558756633813495 (1319)
all 0.7498104624715694 (1319)
=== online sc ===
all 0.7748294162244125 (1319)
all 0.7793783169067475 (1319)
=== offline mcts ===
all 0.7604245640636846 (1319)
