# Gemini

In [54]:
from PIL import Image
import os
import json
from google import genai
from google.genai import types
import numpy as np
# load API key
GEMINI_API_KEY = "AIzaSyCxuedQMQWQkYSu67h3L5PMXDg3cDmeeBQ"
client = genai.Client(api_key=GEMINI_API_KEY)

from PIL import Image

def gemini_call_single(query, model):
    # import cv2
    import requests
    import time
    import json
    temp=0
    # print(query)

    while True:
        try:
            
            content = client.models.generate_content(
                model=model,
                contents=query,
                # config = types.GenerateContentConfig(
                #     temperature=temp,
                # )
            )
            

        except Exception as e_msg:
            content = '[ERROR] ' + str(e_msg)
 
        if isinstance(content, str):
            content = '[ERROR] ' + content.lower()
            if 'exceeded call rate limit' in content or 'exhausted' in content:
                # retry for unacceptable response
                print('\n(retry later in 5 seconds...) ->', content)
                if "thinking" in model:
                    time.sleep(10)
                else:
                    time.sleep(5)
                continue
            else:
                print('\n(retry later...) ->', content)
        elif content.text is None:
            temp += 0.1
            continue
        else:
            break

    
    ########################################
    
    # print(responseJson["choices"][0]["message"]["content"])
    return content.text


def extract_answer_from_model_response(model_response):
    import re
    match = re.search(r'\\boxed\{.*?\b([A-D]|yes|no)\b.*?\}', model_response)
    
    return match.group(1) if match else "Z"


# convert PIL image to base64
def pil_to_base64(pil_image):
    import io
    import base64
    img_byte_arr = io.BytesIO()
    pil_image.save(img_byte_arr, format='PNG')
    img_encoded_str = base64.b64encode(img_byte_arr.getvalue()).decode('ascii')
    return img_encoded_str



def test_gemini_on_VisSim_va(dataset_name, output_dir, model, index, max_tokens=2048, debug=False):

    from datasets import load_dataset
    dataset = load_dataset(f"VisSim/{dataset_name}")
    dataset = dataset['train']
    d_index = list(range(len(dataset)))

    
    output_path = f"./{model}_{dataset_name}_{index}.json"
    from collections import defaultdict
    acc_by_type = defaultdict(float)
    acc_by_difficulty = defaultdict(float)
    counts_by_difficulty = defaultdict(int)
    counts_by_type = defaultdict(int)

    # create output dir
    answer_dict= {}
    from tqdm import tqdm
    for i in tqdm(d_index, total=len(d_index)):
        example = dataset[i]
        qid = example['qid']
        difficulty_level = example['difficulty_level']
        variant = " ".join(qid.split("_")[1:-1])
        # 'Observe the transformation pattern of Shape A through steps 0 to 1. <question_image> Apply the same transformation sequence to Shape B and determine the final shape at step 3. <image_for_B> For reference, the black dots in each panel of the figures indicate the origin. Select the correct answer choice that matches the expected transformation result. <answer_choices>'
        A_image = example['A_image']
        B_image = example['B_image']
        question_info = json.loads(example['question_info'])
        question = question_info['question']
        choice_image = example['choices']
        query = []
        prefix, question = question.strip().split("<question_image>")
        query.append(prefix)
        query.append(A_image)
        prefix, question = question.split("<image_for_B>")
        query.append(prefix)
        query.append(B_image)
        prefix, question = question.split("<answer_choices>")   
        query.append(prefix)
        query.append(choice_image)
        if len(question) > 0:
            query.append(question)
        query.append("Please first solve the problem step by step, then put your final answer or a single letter (if it is a multiple choice question) in one \"\\boxed{}\"")
        # print(query)
        response = gemini_call_single(query, model=model)

        pred = extract_answer_from_model_response(response)

        gt_ans = example['answer']

        answer_dict[qid] ={
            "pred": pred,
            "gt_ans": gt_ans.lower(),
            "response": response
        }
            
        if "ans" not in variant:
            acc_by_difficulty[difficulty_level]+= pred.lower() == gt_ans.lower()
            counts_by_difficulty[difficulty_level] += 1
        acc_by_type[variant] += pred.lower() == gt_ans.lower()
        counts_by_type[variant] += 1

        if len(answer_dict)%10==0:
            np.save(output_path, answer_dict)

    # print accuracy
    print(dataset, index)
    print("Accuracy by difficulty level:")
    for k, v in acc_by_difficulty.items():
        print(f"{k}: {v/counts_by_difficulty[k]}")

    print("Accuracy by variants:")
    for k, v in acc_by_type.items():
        print(f"{k}: {v/counts_by_type[k]}")
    
    overall_acc = sum(acc_by_difficulty.values())/sum(counts_by_difficulty.values())
    print(f"Overall accuracy: {overall_acc}")
    np.save(output_path, answer_dict)

def test_gemini_on_VisSim_text_inst(dataset_name, output_dir, model, index, max_tokens=2048,debug=False):

    from datasets import load_dataset
    dataset = load_dataset(f"VisSim/{dataset_name}")
    dataset = dataset['train']
    d_index = list(range(len(dataset)))
   


    from collections import defaultdict
    acc_by_type = defaultdict(float)
    acc_by_difficulty = defaultdict(float)
    counts_by_difficulty = defaultdict(int)
    counts_by_type = defaultdict(int)

    output_path  = f"./{model}_{dataset_name}_{index}.json"
    if os.path.exists(output_path+".npy"):
        answer_dict = np.load(output_path+".npy", allow_pickle=True).item()
    else:
        answer_dict={}
    from tqdm import tqdm
    for i in tqdm(d_index, total=len(d_index)):
        example = dataset[i]
        qid = example['qid']
        if qid in answer_dict:
            continue
        difficulty_level = example['difficulty_level']
        variant = " ".join(qid.split("_")[1:-1])
        
        images = example['images'][:-1]
        # question_info = json.loads(example['question_info'])
        question = example['question']
        choice_image = example['choices']

        # use regex to parse the question and place the images in the right spots
        query = []
        for i, image in enumerate(images):
            if i == 0:
                prefix, question = question.strip().split("<shapeB_image>")
            else:
                prefix, question = question.split(f"<shapeB_step_{i-1}>")
            query.append(prefix)
            query.append(image)
        
        # replace the remaining <shapeB_image> with "" using regex
        import re
        # using wildcards to match the <shapeB_step_{i}> and replace it with ""
        query.append(re.sub(r'<shapeB_step_\d+>', '', question))

        query.append(choice_image)
        
        query.append("Please first solve the problem step by step, then put your final answer or a single letter (if it is a multiple choice question) in one \"\\boxed{}\"")
        # print(query)
        response = gemini_call_single(query, model=model)

        pred = extract_answer_from_model_response(response)

        gt_ans = example['answer']

        answer_dict[qid]={
            "pred": pred,
            "gt_ans": gt_ans.lower(),
            "response": response
        }
        if len(answer_dict)%10==0:
            np.save(output_path, answer_dict)
        if "ans" not in variant:
            acc_by_difficulty[difficulty_level]+= pred.lower() == gt_ans.lower()
            counts_by_difficulty[difficulty_level] += 1
        acc_by_type[variant] += pred.lower() == gt_ans.lower()
        counts_by_type[variant] += 1

    # print accuracy
    np.save(output_path, answer_dict)
    print("Accuracy by difficulty level:")
    for k, v in acc_by_difficulty.items():
        print(f"{k}: {v/counts_by_difficulty[k]}")

    print("Accuracy by variants:")
    for k, v in acc_by_type.items():
        print(f"{k}: {v/counts_by_type[k]}")
    
    overall_acc = sum(acc_by_difficulty.values())/sum(counts_by_difficulty.values())
    print(f"Overall accuracy: {overall_acc}")



def test_gemini_on_folding_nets(dataset_name, output_dir, model, index, max_tokens=2048,debug=False):

    from datasets import load_dataset
    dataset = load_dataset(f"VisSim/{dataset_name}")
    dataset = dataset['train']
    d_index = list(range(len(dataset)))


    from collections import defaultdict
    pred_by_type = defaultdict(list)
    pred_by_difficulty = defaultdict(list)

    gt_by_type = defaultdict(list)
    gt_by_difficulty = defaultdict(list)

    
    output_path  = f"./{model}_{dataset_name}_{index}.json"
    if os.path.exists(output_path+".npy"):
        answer_dict = np.load(output_path+".npy", allow_pickle=True).item()
    else:
        answer_dict={}
    from tqdm import tqdm
    for i in tqdm(d_index, total=len(d_index)):
        example = dataset[i]
        qid = example['qid']
        if qid in answer_dict:
            continue
        variant = example['type']
  
        images = example['images']
        # question_info = json.loads(example['question_info'])
        question = example['question']

        # use regex to parse the question and place the images in the right spots
        query = []
        for i, image in enumerate(images):
            prefix, question = question.split(f"<image_{i}>")
            query.append(prefix)
            query.append(image)
        if len(question) > 0:
            query.append(question + "Think step-by-step, and then put your final answer in \"\\boxed{}\".")
        else:
            query.append("Think step-by-step, and then put your final answer in \"\\boxed{}\".")
        
        # check if the query has at least 1 image after parsing

        # print(query)
        response = gemini_call_single(query, model=model)

        pred = extract_answer_from_model_response(response.lower())

        gt_choice = example['answer'].lower()
        answer_choices = example['choices']

        gt_ans = answer_choices[int(ord(gt_choice) - ord('a'))]

        answer_dict[ qid]={
            "pred": pred,
            "gt_choice": gt_choice,
            "gt_ans": gt_ans.lower(),
            "response": response,
            "question": [q  if isinstance(q, str) else '<image>' for q in query]
        }
        if len(answer_dict)%10==0:
            np.save(output_path, answer_dict)
        correct = pred.lower() == gt_ans.lower() or pred.lower() == gt_choice.lower()
        pred_by_type[variant].append(pred.lower())
        gt_by_type[variant].append(gt_ans.lower())

    np.save(output_path, answer_dict)
    print("F1 by variants:")
    for k in pred_by_type.keys():
        from sklearn.metrics import f1_score
        print(f"{k}: {f1_score(gt_by_type[k], pred_by_type[k], average='weighted')}")

    print("Random Chance F1 by variants:")
    for k in pred_by_type.keys():
        from sklearn.metrics import f1_score
        import random
        random_pred = [random.choice(["yes", "no"]) for _ in range(len(gt_by_type[k]))]
        print(f"{k}: {f1_score(gt_by_type[k],random_pred, average='weighted')}")


model = 'gemini-2.0-flash-thinking-exp-01-21'


def test_gemini_flash_thinking():
    model = 'gemini-2.0-flash-thinking-exp-01-21'
    # test_gemini_on_folding_nets('folding_nets_test', 'output_dir/gemini_response', model=model, debug=False)
    # test_gemini_on_folding_nets('tangram_puzzle_test', 'output_dir/gemini_response', model=model, debug=False)
    # test_gemini_on_VisSim_va('2d_va_test', '.', model=model, index=0, debug=False)
    test_gemini_on_VisSim_va('2d_va_test', '.', model=model, index=1, debug=False)
    test_gemini_on_VisSim_va('2d_va_test', '.', model=model, index=2, debug=False)
    # test_gemini_on_VisSim_text_inst('2d_text_instruct_test', 'output_dir/gemini_response', model=model, debug=False)
    # test_gemini_on_folding_nets('folding_nets_vissim_test', 'output_dir/gemini_response', model=model, debug=debug)
    # test_gemini_on_folding_nets('tangram_puzzle_vissim_test', 'output_dir/gemini_response', model=model, debug=debug)
    # test_gemini_on_VisSim_va('2d_va_vissim_test', 'output_dir/gemini_response', model=model, debug=debug)
    # test_gemini_on_VisSim_text_inst('2d_text_instruct_vissim_test', 'output_dir/gemini_response', model=model, debug=debug)


In [None]:
test_gemini_on_VisSim_va('2d_va_test', '.', model=model, index=0, debug=False)
test_gemini_on_VisSim_va('2d_va_test', '.', model=model, index=1, debug=False)
test_gemini_on_VisSim_va('2d_va_test', '.', model=model, index=2, debug=False)

In [23]:
model = 'gemini-2.0-flash-thinking-exp-01-21'
dataset_name = '2d_va_test'
index= 0
import numpy as np
answer_dict = np.load("gemini-2.0-flash-thinking-exp-01-21_2d_va_test_2.json.npy", allow_pickle=True).item() 
from datasets import load_dataset
dataset = load_dataset(f"VisSim/{dataset_name}")
dataset = dataset['train']
d_index = list(range(len(dataset)))


output_path = f"./{model}_{dataset_name}_{index}.json"
from collections import defaultdict
acc_by_type = defaultdict(float)
acc_by_difficulty = defaultdict(float)
counts_by_difficulty = defaultdict(int)
counts_by_type = defaultdict(int)

from tqdm import tqdm
for i in tqdm(d_index, total=len(d_index)):
    example = dataset[i]
    qid = example['qid']
    difficulty_level = example['difficulty_level']
    gt = example['answer']
    pred = answer_dict[qid]['pred']
    variant = " ".join(qid.split("_")[1:-1])
    gt_ans = example['answer']
    if pred=="Z":
        print(answer_dict[qid]['response'])
    if "ans" not in variant:
        acc_by_difficulty[difficulty_level]+= pred.lower() == gt_ans.lower()
        counts_by_difficulty[difficulty_level] += 1
    acc_by_type[variant] += pred.lower() == gt_ans.lower()
    counts_by_type[variant] += 1

    # print accuracy
print(dataset, index)
print("Accuracy by difficulty level:")
for k, v in acc_by_difficulty.items():
    print(f"{k}: {v/counts_by_difficulty[k]}")

print("Accuracy by variants:")
for k, v in acc_by_type.items():
    print(f"{k}: {v/counts_by_type[k]}")

overall_acc = sum(acc_by_difficulty.values())/sum(counts_by_difficulty.values())
print(f"Overall accuracy: {overall_acc}")


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 306/306 [00:09<00:00, 30.60it/s]

Dataset({
    features: ['qid', 'A_image', 'B_image', 'choices', 'answer', 'transformations', 'difficulty_level', 'question_info', 'answer_info'],
    num_rows: 306
}) 0
Accuracy by difficulty level:
easy: 0.6274509803921569
medium: 0.5392156862745098
hard: 0.46078431372549017
Accuracy by variants:
easy no: 0.6274509803921569
medium no: 0.5392156862745098
hard no: 0.46078431372549017
Overall accuracy: 0.5424836601307189





In [53]:
test_gemini_on_folding_nets('folding_nets_test', '.', model=model, index=0, debug=False)
test_gemini_on_folding_nets('folding_nets_test', '.', model=model, index=1, debug=False)
test_gemini_on_folding_nets('folding_nets_test', '.', model=model, index=2, debug=False)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 193/193 [00:00<00:00, 221.09it/s]


F1 by variants:
Random Chance F1 by variants:


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 193/193 [58:39<00:00, 18.24s/it]


F1 by variants:
q_only: 0.5279663665028548
q+steps: 0.5031020439234194
Random Chance F1 by variants:
q_only: 0.4840389962341182
q+steps: 0.5893628865913032


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 193/193 [1:03:04<00:00, 19.61s/it]

F1 by variants:
q_only: 0.4334574942769565
q+steps: 0.4535728811161606
Random Chance F1 by variants:
q_only: 0.5332706974687741
q+steps: 0.38185920030006676





In [59]:
model = 'gemini-2.0-flash-thinking-exp-01-21'
dataset_name = 'folding_nets_test'
index= 0
import numpy as np
answer_dict = np.load("gemini-2.0-flash-thinking-exp-01-21_folding_nets_test_2.json.npy", allow_pickle=True).item() 
print(len(answer_dict))
from datasets import load_dataset
dataset = load_dataset(f"VisSim/{dataset_name}")
dataset = dataset['train']
d_index = list(range(len(dataset)))

from collections import defaultdict
pred_by_type = defaultdict(list)
pred_by_difficulty = defaultdict(list)

gt_by_type = defaultdict(list)
gt_by_difficulty = defaultdict(list)

output_path = output_path = f"./{model}_{dataset_name}_{index}.json"
from tqdm import tqdm
for i in tqdm(d_index, total=len(d_index)):
    example = dataset[i]
    qid = example['qid']
    variant = example['type']

    
    pred = answer_dict[qid]['pred']
    if pred == "Z":
        print(qid, pred)

    gt_choice = example['answer'].lower()
    answer_choices = example['choices']

    gt_ans = answer_choices[int(ord(gt_choice) - ord('a'))]

    ect = pred.lower() == gt_ans.lower() or pred.lower() == gt_choice.lower()
    pred_by_type[variant].append(pred.lower())
    gt_by_type[variant].append(gt_ans.lower())

print("F1 by variants:")
for k in pred_by_type.keys():
    from sklearn.metrics import f1_score
    print(f"{k}: {f1_score(gt_by_type[k], pred_by_type[k], average='weighted')}")

print("Random Chance F1 by variants:")
for k in pred_by_type.keys():
    from sklearn.metrics import f1_score
    import random
    random_pred = [random.choice(["yes", "no"]) for _ in range(len(gt_by_type[k]))]
    print(f"{k}: {f1_score(gt_by_type[k],random_pred, average='weighted')}")





193


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 193/193 [00:00<00:00, 222.44it/s]

F1 by variants:
q_only: 0.4334574942769565
q+steps: 0.4535728811161606
Random Chance F1 by variants:
q_only: 0.4583333333333333
q+steps: 0.5953698795220393





In [None]:
test_gemini_on_VisSim_text_inst('2d_text_instruct_test',  '.', model=model,index=0, debug=False)
test_gemini_on_VisSim_text_inst('2d_text_instruct_test',  '.', model=model,index=1, debug=False)
test_gemini_on_VisSim_text_inst('2d_text_instruct_test',  '.', model=model,index=2, debug=False)

  4%|████▉                                                                                                           | 14/318 [02:48<40:33,  8.01s/it]


(retry later in 5 seconds...) -> [ERROR] [error] 429 resource_exhausted. {'error': {'code': 429, 'message': 'resource has been exhausted (e.g. check quota).', 'status': 'resource_exhausted'}}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [35:19<00:00,  6.67s/it]


Accuracy by difficulty level:
easy: 0.6964285714285714
medium: 0.7758620689655172
hard: 0.7169811320754716
Accuracy by variants:
no: 0.7305389221556886
Overall accuracy: 0.7305389221556886


 31%|██████████████████████████████████▌                                                                             | 98/318 [19:08<27:16,  7.44s/it]


(retry later in 5 seconds...) -> [ERROR] [error] 429 resource_exhausted. {'error': {'code': 429, 'message': 'resource has been exhausted (e.g. check quota).', 'status': 'resource_exhausted'}}

(retry later in 5 seconds...) -> [ERROR] [error] 429 resource_exhausted. {'error': {'code': 429, 'message': 'resource has been exhausted (e.g. check quota).', 'status': 'resource_exhausted'}}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [36:31<00:00,  6.89s/it]


Accuracy by difficulty level:
easy: 0.75
medium: 0.7758620689655172
hard: 0.5660377358490566
Accuracy by variants:
no: 0.7005988023952096
Overall accuracy: 0.7005988023952096


 19%|█████████████████████▍                                                                                          | 61/318 [11:01<55:09, 12.88s/it]

In [76]:

model = 'gemini-2.0-flash-thinking-exp-01-21'
dataset_name = '2d_text_instruct_test'
import numpy as np
answer_dict = np.load(f"gemini-2.0-flash-thinking-exp-01-21_{dataset_name}_2.json.npy", allow_pickle=True).item() 

from datasets import load_dataset
dataset = load_dataset(f"VisSim/{dataset_name}")
dataset = dataset['train']
d_index = list(range(len(dataset)))



from collections import defaultdict
acc_by_type = defaultdict(float)
acc_by_difficulty = defaultdict(float)
counts_by_difficulty = defaultdict(int)
counts_by_type = defaultdict(int)


from tqdm import tqdm
for i in tqdm(d_index, total=len(d_index)):
    example = dataset[i]
    qid = example['qid']
    
    difficulty_level = example['difficulty_level']
    variant = " ".join(qid.split("_")[1:-1])
    
    gt_ans = example['answer']

    pred = answer_dict[qid]['pred']
    if pred == "Z":
        print(qid, pred, answer_dict[qid]['respone'])
        
    
 
    if "ans" not in variant:
        acc_by_difficulty[difficulty_level]+= pred.lower() == gt_ans.lower()
        counts_by_difficulty[difficulty_level] += 1
    acc_by_type[variant] += pred.lower() == gt_ans.lower()
    counts_by_type[variant] += 1

# print accuracy
np.save(output_path, answer_dict)
print("Accuracy by difficulty level:")
for k, v in acc_by_difficulty.items():
    print(f"{k}: {v/counts_by_difficulty[k]}")

print("Accuracy by variants:")
for k, v in acc_by_type.items():
    print(f"{k}: {v/counts_by_type[k]}")

overall_acc = sum(acc_by_difficulty.values())/sum(counts_by_difficulty.values())
print(f"Overall accuracy: {overall_acc}")



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 318/318 [00:05<00:00, 60.22it/s]


Accuracy by difficulty level:
easy: 0.4298245614035088
medium: 0.5277777777777778
hard: 0.4166666666666667
Accuracy by variants:
no: 0.4591194968553459
Overall accuracy: 0.4591194968553459


In [None]:
test_gemini_on_folding_nets('tangram_puzzle_test', '.', model=model, index=0, debug=False)
test_gemini_on_folding_nets('tangram_puzzle_test', '.', model=model, index=1, debug=False)
test_gemini_on_folding_nets('tangram_puzzle_test', '.', model=model, index=2, debug=False)

In [73]:
model = 'gemini-2.0-flash-thinking-exp-01-21'
dataset_name = 'tangram_puzzle_test'
index= 0
import numpy as np
answer_dict = np.load(f"gemini-2.0-flash-thinking-exp-01-21_{dataset_name}_0.json.npy", allow_pickle=True).item() 
print(len(answer_dict))
from datasets import load_dataset
dataset = load_dataset(f"VisSim/{dataset_name}")
dataset = dataset['train']
d_index = list(range(len(dataset)))

from collections import defaultdict
pred_by_type = defaultdict(list)
pred_by_difficulty = defaultdict(list)

gt_by_type = defaultdict(list)
gt_by_difficulty = defaultdict(list)

output_path = output_path = f"./{model}_{dataset_name}_{index}.json"
from tqdm import tqdm
for i in tqdm(d_index, total=len(d_index)):
    example = dataset[i]
    qid = example['qid']
    variant = example['type']

    
    pred = answer_dict[qid]['pred']
    if pred == "Z":
        print(qid, pred, answer_dict[qid]['response'])

    gt_choice = example['answer'].lower()
    answer_choices = example['choices']

    gt_ans = answer_choices[int(ord(gt_choice) - ord('a'))]

    ect = pred.lower() == gt_ans.lower() or pred.lower() == gt_choice.lower()
    pred_by_type[variant].append(pred.lower())
    gt_by_type[variant].append(gt_ans.lower())

print("F1 by variants:")
for k in pred_by_type.keys():
    from sklearn.metrics import f1_score
    print(f"{k}: {f1_score(gt_by_type[k], pred_by_type[k], average='weighted')}")

print("Random Chance F1 by variants:")
for k in pred_by_type.keys():
    from sklearn.metrics import f1_score
    import random
    random_pred = [random.choice(["yes", "no"]) for _ in range(len(gt_by_type[k]))]
    print(f"{k}: {f1_score(gt_by_type[k],random_pred, average='weighted')}")





376


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 376/376 [00:06<00:00, 54.89it/s]

F1 by variants:
q+steps: 0.5319283599776713
q_only: 0.660537790245632
Random Chance F1 by variants:
q+steps: 0.5054152815698955
q_only: 0.510852133005453





# Claude

In [6]:

import anthropic
import json
import matplotlib.pyplot as plt
client = anthropic.Anthropic()
import base64
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm
def pil_to_base64(pil_image):
    img_byte_arr = io.BytesIO()
    pil_image.save(img_byte_arr, format='PNG')
    img_encoded_str = base64.b64encode(img_byte_arr.getvalue()).decode('ascii')
    return img_encoded_str


def claude_call_single(query, model):
    message = client.messages.create(
            model= model,
            max_tokens=1024,
            
            messages=[
                {
                    "role": "user",
                    "content": query
                   
                }
            ],
        )
    return message.content[0].text


def extract_answer_from_model_response(model_response):
    import re
    match = re.search(r'\\boxed\{.*?\b([A-D]|yes|no)\b.*?\}', model_response)
    
    return match.group(1) if match else "Z"


# convert PIL image to base64
def pil_to_base64(pil_image):
    import io
    import base64
    img_byte_arr = io.BytesIO()
    pil_image.save(img_byte_arr, format='PNG')
    img_encoded_str = base64.b64encode(img_byte_arr.getvalue()).decode('ascii')
    return img_encoded_str



def test_claude_on_VisSim_va(dataset_name, output_dir, model, index, max_tokens=2048, debug=False):

    from datasets import load_dataset
    dataset = load_dataset(f"VisSim/{dataset_name}")
    dataset = dataset['train']
    d_index = list(range(len(dataset)))

    
    output_path = f"./{model}_{dataset_name}_{index}.json"
    from collections import defaultdict
    acc_by_type = defaultdict(float)
    acc_by_difficulty = defaultdict(float)
    counts_by_difficulty = defaultdict(int)
    counts_by_type = defaultdict(int)

    # create output dir
    answer_dict= {}
    from tqdm import tqdm
    for i in tqdm(d_index, total=len(d_index)):
        example = dataset[i]
        qid = example['qid']
        difficulty_level = example['difficulty_level']
        variant = " ".join(qid.split("_")[1:-1])
        # 'Observe the transformation pattern of Shape A through steps 0 to 1. <question_image> Apply the same transformation sequence to Shape B and determine the final shape at step 3. <image_for_B> For reference, the black dots in each panel of the figures indicate the origin. Select the correct answer choice that matches the expected transformation result. <answer_choices>'
        A_image = example['A_image']
        B_image = example['B_image']
        question_info = json.loads(example['question_info'])
        question = question_info['question']
        choice_image = example['choices']
        query = []
        prefix, question = question.strip().split("<question_image>")
        query.append({
                            "type": "text",
                            "text": prefix
                        })
        query.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": pil_to_base64(A_image),
                            }})
        prefix, question = question.split("<image_for_B>")
        query.append({
                            "type": "text",
                            "text": prefix
                        })
        query.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": pil_to_base64(B_image),
                            }})
        prefix, question = question.split("<answer_choices>")   
        query.append({
                            "type": "text",
                            "text": prefix
                        })
        query.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": pil_to_base64(choice_image),
                            }})
        if len(question) > 0:
            query.append({
                            "type": "text",
                            "text": question
                        })
        query.append({
                            "type": "text",
                            "text": "Please first solve the problem step by step, then put your final answer or a single letter (if it is a multiple choice question) in one \"\\boxed{}\""
                        })
        
        # print(query)
        response = claude_call_single(query, model=model)

        pred = extract_answer_from_model_response(response)

        gt_ans = example['answer']

        answer_dict[qid] ={
            "pred": pred,
            "gt_ans": gt_ans.lower(),
            "response": response
        }
        if "ans" not in variant:
            acc_by_difficulty[difficulty_level]+= pred.lower() == gt_ans.lower()
            counts_by_difficulty[difficulty_level] += 1
        acc_by_type[variant] += pred.lower() == gt_ans.lower()
        counts_by_type[variant] += 1

        if len(answer_dict)%10==0:
            np.save(output_path, answer_dict)

    # print accuracy
    print(dataset, index)
    print("Accuracy by difficulty level:")
    for k, v in acc_by_difficulty.items():
        print(f"{k}: {v/counts_by_difficulty[k]}")

    print("Accuracy by variants:")
    for k, v in acc_by_type.items():
        print(f"{k}: {v/counts_by_type[k]}")
    
    overall_acc = sum(acc_by_difficulty.values())/sum(counts_by_difficulty.values())
    print(f"Overall accuracy: {overall_acc}")
    np.save(output_path, answer_dict)

def test_claude_on_VisSim_text_inst(dataset_name, output_dir, model, index, max_tokens=2048,debug=False):

    from datasets import load_dataset
    dataset = load_dataset(f"VisSim/{dataset_name}")
    dataset = dataset['train']
    d_index = list(range(len(dataset)))
   


    from collections import defaultdict
    acc_by_type = defaultdict(float)
    acc_by_difficulty = defaultdict(float)
    counts_by_difficulty = defaultdict(int)
    counts_by_type = defaultdict(int)

    output_path  = f"./{model}_{dataset_name}_{index}.json"
    if os.path.exists(output_path+".npy"):
        answer_dict = np.load(output_path+".npy", allow_pickle=True).item()
    else:
        answer_dict={}
    from tqdm import tqdm
    for i in tqdm(d_index, total=len(d_index)):
        example = dataset[i]
        qid = example['qid']
        if qid in answer_dict:
            continue
        difficulty_level = example['difficulty_level']
        variant = " ".join(qid.split("_")[1:-1])
        
        images = example['images'][:-1]
        # question_info = json.loads(example['question_info'])
        question = example['question']
        choice_image = example['choices']

        # use regex to parse the question and place the images in the right spots
        query = []
        for i, image in enumerate(images):
            if i == 0:
                prefix, question = question.strip().split("<shapeB_image>")
            else:
                prefix, question = question.split(f"<shapeB_step_{i-1}>")
            query.append({
                            "type": "text",
                            "text": prefix
                        })
            query.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": pil_to_base64(image),
                            }})
        
        # replace the remaining <shapeB_image> with "" using regex
        import re
        # using wildcards to match the <shapeB_step_{i}> and replace it with ""
        uery.append({
                            "type": "text",
                            "text": re.sub(r'<shapeB_step_\d+>', '', question)
                        })
        query.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": pil_to_base64(choice_image),
                            }})
        
        query.append({
                            "type": "text",
                            "text": "Please first solve the problem step by step, then put your final answer or a single letter (if it is a multiple choice question) in one \"\\boxed{}\""
                        })
        response = claude_call_single(query, model=model)

        pred = extract_answer_from_model_response(response)

        gt_ans = example['answer']

        answer_dict[qid]={
            "pred": pred,
            "gt_ans": gt_ans.lower(),
            "response": response
        }
        if len(answer_dict)%10==0:
            np.save(output_path, answer_dict)
        if "ans" not in variant:
            acc_by_difficulty[difficulty_level]+= pred.lower() == gt_ans.lower()
            counts_by_difficulty[difficulty_level] += 1
        acc_by_type[variant] += pred.lower() == gt_ans.lower()
        counts_by_type[variant] += 1

    # print accuracy
    np.save(output_path, answer_dict)
    print("Accuracy by difficulty level:")
    for k, v in acc_by_difficulty.items():
        print(f"{k}: {v/counts_by_difficulty[k]}")

    print("Accuracy by variants:")
    for k, v in acc_by_type.items():
        print(f"{k}: {v/counts_by_type[k]}")
    
    overall_acc = sum(acc_by_difficulty.values())/sum(counts_by_difficulty.values())
    print(f"Overall accuracy: {overall_acc}")



def test_claude_on_folding_nets(dataset_name, output_dir, model, index, max_tokens=2048,debug=False):

    from datasets import load_dataset
    dataset = load_dataset(f"VisSim/{dataset_name}")
    dataset = dataset['train']
    d_index = list(range(len(dataset)))


    from collections import defaultdict
    pred_by_type = defaultdict(list)
    pred_by_difficulty = defaultdict(list)

    gt_by_type = defaultdict(list)
    gt_by_difficulty = defaultdict(list)

    
    output_path  = f"./{model}_{dataset_name}_{index}.json"
    if os.path.exists(output_path+".npy"):
        answer_dict = np.load(output_path+".npy", allow_pickle=True).item()
    else:
        answer_dict={}
    from tqdm import tqdm
    for i in tqdm(d_index, total=len(d_index)):
        example = dataset[i]
        qid = example['qid']
        if qid in answer_dict:
            continue
        variant = example['type']
  
        images = example['images']
        # question_info = json.loads(example['question_info'])
        question = example['question']

        # use regex to parse the question and place the images in the right spots
        query = []
        for i, image in enumerate(images):
            prefix, question = question.split(f"<image_{i}>")
            query.append({
                            "type": "text",
                            "text": prefix
                        })
            query.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": pil_to_base64(image),
                            }})
        if len(question) > 0:
            query.append({
                            "type": "text",
                            "text": question + "Think step-by-step, and then put your final answer in \"\\boxed{}\"."
                        })
        else:
            query.append({
                            "type": "text",
                            "text": "Think step-by-step, and then put your final answer in \"\\boxed{}\"."
                        })
        
        
        response = claude_call_single(query, model=model)

        pred = extract_answer_from_model_response(response.lower())

        gt_choice = example['answer'].lower()
        answer_choices = example['choices']

        gt_ans = answer_choices[int(ord(gt_choice) - ord('a'))]

        answer_dict[ qid]={
            "pred": pred,
            "gt_choice": gt_choice,
            "gt_ans": gt_ans.lower(),
            "response": response,
            "question": [q  if isinstance(q, str) else '<image>' for q in query]
        }
        if len(answer_dict)%10==0:
            np.save(output_path, answer_dict)
        correct = pred.lower() == gt_ans.lower() or pred.lower() == gt_choice.lower()
        pred_by_type[variant].append(pred.lower())
        gt_by_type[variant].append(gt_ans.lower())

    np.save(output_path, answer_dict)
    print("F1 by variants:")
    for k in pred_by_type.keys():
        from sklearn.metrics import f1_score
        print(f"{k}: {f1_score(gt_by_type[k], pred_by_type[k], average='weighted')}")

    print("Random Chance F1 by variants:")
    for k in pred_by_type.keys():
        from sklearn.metrics import f1_score
        import random
        random_pred = [random.choice(["yes", "no"]) for _ in range(len(gt_by_type[k]))]
        print(f"{k}: {f1_score(gt_by_type[k],random_pred, average='weighted')}")


model = "claude-3-5-sonnet-20241022"
import os


In [7]:
test_claude_on_VisSim_va('2d_va_test', '.', model=model, index=0, debug=False)
test_claude_on_VisSim_va('2d_va_test', '.', model=model, index=1, debug=False)
test_claude_on_VisSim_va('2d_va_test', '.', model=model, index=2, debug=False)

100%|███████████████████████████████████████████████████| 306/306 [39:00<00:00,  7.65s/it]


Dataset({
    features: ['qid', 'A_image', 'B_image', 'choices', 'answer', 'transformations', 'difficulty_level', 'question_info', 'answer_info'],
    num_rows: 306
}) 0
Accuracy by difficulty level:
easy: 0.7549019607843137
medium: 0.6666666666666666
hard: 0.6274509803921569
Accuracy by variants:
easy no: 0.7549019607843137
medium no: 0.6666666666666666
hard no: 0.6274509803921569
Overall accuracy: 0.6830065359477124


100%|███████████████████████████████████████████████████| 306/306 [40:11<00:00,  7.88s/it]


Dataset({
    features: ['qid', 'A_image', 'B_image', 'choices', 'answer', 'transformations', 'difficulty_level', 'question_info', 'answer_info'],
    num_rows: 306
}) 1
Accuracy by difficulty level:
easy: 0.803921568627451
medium: 0.6862745098039216
hard: 0.6176470588235294
Accuracy by variants:
easy no: 0.803921568627451
medium no: 0.6862745098039216
hard no: 0.6176470588235294
Overall accuracy: 0.7026143790849673


100%|███████████████████████████████████████████████████| 306/306 [38:17<00:00,  7.51s/it]

Dataset({
    features: ['qid', 'A_image', 'B_image', 'choices', 'answer', 'transformations', 'difficulty_level', 'question_info', 'answer_info'],
    num_rows: 306
}) 2
Accuracy by difficulty level:
easy: 0.7352941176470589
medium: 0.6470588235294118
hard: 0.6176470588235294
Accuracy by variants:
easy no: 0.7352941176470589
medium no: 0.6470588235294118
hard no: 0.6176470588235294
Overall accuracy: 0.6666666666666666





In [None]:
test_claude_on_folding_nets('tangram_puzzle_test', '.', model=model, index=0, debug=False)
test_claude_on_folding_nets('tangram_puzzle_test', '.', model=model, index=1, debug=False)
test_claude_on_folding_nets('tangram_puzzle_test', '.', model=model, index=2, debug=False)

100%|███████████████████████████████████████████████████| 376/376 [03:46<00:00,  1.66it/s]


F1 by variants:
q_only: 0.7605042016806722
q+steps: 0.4155844155844156
Random Chance F1 by variants:
q_only: 0.6666666666666666
q+steps: 0.253968253968254


100%|███████████████████████████████████████████████████| 376/376 [48:40<00:00,  7.77s/it]


F1 by variants:
q+steps: 0.41326203462125793
q_only: 0.7231833910034601
Random Chance F1 by variants:
q+steps: 0.5091154606588453
q_only: 0.4995475113122172


 96%|█████████████████████████████████████████████████  | 362/376 [47:18<01:55,  8.22s/it]

In [None]:
test_claude_on_VisSim_text_inst('2d_text_instruct_test',  '.', model=model,index=0, debug=False)
test_claude_on_VisSim_text_inst('2d_text_instruct_test',  '.', model=model,index=1, debug=False)
test_claude_on_VisSim_text_inst('2d_text_instruct_test',  '.', model=model,index=2, debug=False)

In [None]:
test_claude_on_folding_nets('folding_nets_test', '.', model=model, index=0, debug=False)
test_claude_on_folding_nets('folding_nets_test', '.', model=model, index=1, debug=False)
test_claude_on_folding_nets('folding_nets_test', '.', model=model, index=2, debug=False)