In [1]:
from src.tasks.execution.crux_execution import CruxEnv
from pydantic import BaseModel
from tqdm import tqdm
from py2cfg import CFGBuilder
import base64
import os
from openai import AzureOpenAI
import anthropic

# IternVL2
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import time

class config(BaseModel):
    limit: int = 250
    subset: str = "input"


In [2]:
client = anthropic.Anthropic(api_key="")  
def run(prompt):
    response = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        system="You are a helpful assistant.",
        messages=[
            {"role": "user", "content": prompt}
        ],
        max_tokens=2048
    )

    return response.content[0].text

In [3]:
def code_to_image(code: str, path: str):
    with open(path, 'rb') as f:
        image_data = f.read()
    return base64.b64encode(image_data).decode('utf-8')

def make_visual_cot_input_prompt(s):
    code, output = s
    return f"""You will be given a function `f`, its Control Flow Graph (CFG), and an output in the form `f(??) == output`. Your task is to find any input such that executing `f` on the input leads to the given output. There may be multiple answers, but only output one. First, analyze the function code and use the CFG to guide your reasoning about possible execution paths. You MUST surround the answer with [ANSWER] and [/ANSWER] tags. Express your answer as a passing assertion containing the input and the given output.

[PYTHON]
def f(x):
    while x > 10:
        if x % 2 == 0:
            x -= 2
        else:
            x -= 4
    return x
    
assert f(??) == 3
[/PYTHON]

[THOUGHT]
To determine the input `??` such that `f(??) == 17`, we can use both the plain code and the CFG.

1. **Code Analysis**: The function `f(x)` has a conditional statement that checks whether `x > 10`. If `x` is greater than 10, the function returns `x + 1`. Otherwise, it returns `x - 1`.
  
2. **CFG Insights**: The Control Flow Graph (CFG) illustrates two possible paths:
   - **True Branch (`T`)**: If `x > 10`, the path leads to the operation `return x + 1`.
   - **False Branch (`F`)**: If `x <= 10`, the path leads to the operation `return x - 1`.

3. **Path Consideration**:
   - **True Branch Analysis**: For the condition `x > 10`, the function returns `x + 1`. To satisfy `f(??) == 3`, we need `x + 1 = 17`. Solving for `x`, we get `x = 16`.
   - **False Branch Analysis**: For the condition `x <= 10`, the function returns `x - 1`. To satisfy `f(??) == 3`, we would need `x - 1 = 17`, which gives `x = 18`. However, this contradicts the branch condition because 18 is not less than or equal to 10. Therefore, this branch cannot produce the desired output.

4. **Conclusion**: Based on the CFG and the code analysis, the only valid input is `x = 16`, which lies on the True Branch.
[/THOUGHT]
[ANSWER]
assert f(16) == 17
[/ANSWER]

[PYTHON]
{code}
assert f(??) == {output}
[/PYTHON]

[THOUGHT]
"""

def visual_run(prompt, image_data):
    response = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        system="You are a helpful assistant.",
        messages=[
        { "role": "user", "content": [  
            { 
                "type": "text", 
                "text": "You are given a control flow graph image of a code snippet, utilize them in code execution reasoning process. " + prompt, 
            },
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": image_data
                }
            },
        ] }
    ],
        max_tokens=2048 
    )

    return response.content[0].text

In [None]:
confi = config()
env = CruxEnv(confi)
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    print(f'Index: {idx}')
    print(env.problem['code'])
    print("_______________________")

In [None]:
confi = config()
env = CruxEnv(confi)
fail_case = {}
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    prompt = env.get_problem_statement()
    try:
        solution = run(prompt) 
    except Exception as e:
        print(e)
        error = True
        solution = ""
        time.sleep(20)
        pass
    correct = env.check_solution(solution)
    if not correct:
        print(idx)
        fail_case[idx] = solution
    env.accumulate_result({"is_correct": correct, "solution": solution, "error": error})
env.finalize()
print(env.result)

In [4]:
from __future__ import annotations
import ast
from cfg import * 
confi = config()
env = CruxEnv(confi)
false_case = {}
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    prompt = env.get_problem_statement()
    code = env.get_code()
    # save code to file
    filename = "code.py"
    with open(filename, "w") as f:
        f.write(code)
    try:
        source = open(filename, 'r').read()
        compile(source, filename, 'exec')
    except:
        print('Error in source code')
        exit(1)

    parser = PyParser(source)
    parser.removeCommentsAndDocstrings()
    parser.formatCode()
    cfg = CFGVisitor().build(filename, ast.parse(parser.script))
    cfg.clean()
    cfg.show()
    path = "cfg.png"

    try:
        solution = visual_run(prompt, code_to_image(code, path))
    except Exception as e:
        solution = ""
        error = True
        print(e)
        pass
    correct = env.check_solution(solution)
    if not correct:
        false_case[idx] = solution
    env.accumulate_result({"is_correct": correct, "solution": solution, "error": error})
env.finalize()
print(env.result)

  9%|▉         | 22/250 [01:54<18:46,  4.94s/it]

Error code: 500 - {'type': 'error', 'error': {'type': 'api_error', 'message': 'Internal server error'}}


 92%|█████████▏| 229/250 [20:51<01:50,  5.25s/it]

Error code: 500 - {'type': 'error', 'error': {'type': 'api_error', 'message': 'Internal server error'}}


100%|██████████| 250/250 [22:43<00:00,  5.45s/it]

{'correct': 196, 'total': 250, 'error': 2}





In [None]:
false_case

In [6]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


In [11]:
path = 'OpenGVLab/InternVL2-8B'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=1024, do_sample=False)

ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`

In [3]:

def run_model(prompt):
    response, history = model.chat(tokenizer, None, prompt, generation_config, history=None, return_history=True)
    return response

def code_to_image(code: str, name="ControlFlowGraph"):
    #TODO: quick fix for now, since although the cfg image is saved, the CFGBuilder is not able to finished
    cfg = CFGBuilder().build_from_src(name, code)
    render = cfg.build_visual(name, 'jpeg', show=False)
    return f"{name}.jpeg"

# def visual_run_model(prompt, image_path):
#     pixel_values = load_image(image_path, max_num=12).to(torch.bfloat16).cuda()
#     prompt = "You are given a control flow graph image of a code snippet, utilize them in code execution reasoning process: <image>. Describe the image"
#     return model.chat(tokenizer, pixel_values, prompt, generation_config)

def visual_run_model(prompt, image_path):
    pixel_values = load_image(image_path, max_num=12).to(torch.bfloat16).cuda()
    prompt = f"CFG image: <image> {prompt}"
    return model.chat(tokenizer, pixel_values, prompt, generation_config)

In [None]:

cfg = config()
env = CruxEnv(cfg)
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    prompt = env.get_problem_statement()
    try:
        solution = run_model(prompt)
    except:
        error = True
        solution = ""
        pass
    correct = env.check_solution(solution)
    env.accumulate_result({"is_correct": correct, "solution": solution, "error": error})
env.finalize()
print(env.result)

In [None]:
def make_visual_cot_output_prompt(s):
    code, input = s
    return f"""You are given a Python function and an assertion containing an input to the function. Complete the assertion with a literal (no unsimplified expressions, no function calls) containing the output when executing the provided code on the given input, even if the function is incorrect or incomplete. Do NOT output any extra information. Execute the program step by step before arriving at an answer, and provide the full assertion with the correct output in [ANSWER] and [/ANSWER] tags, following the examples.

[PYTHON]
def f(L, m, start, step): 
    L.insert(start, m) 
    for x in range(start-1, 0, -step): 
        start -= 1
        L.insert(start, L.pop(L.index(m)-1)) 
    return L
assert f(thigh_o_two[:], 3, 3, 2) == ??
[/PYTHON]
[THOUGHT]
Let's execute code step by step:
1. Initial State:
	L = [1, 2, 7, 9]
	m = 3
	start = 3
	step = 2
2. First Operation (L.insert(start, m)):
	This is shown in the control flow graph as the first action after the function begins.
	Insert m (which is 3) at index start (which is 3).
	L = [1, 2, 7, 3, 9]
3. For Loop Initialization (for x in range(start - 1, 0, -step)):
	range(start - 1, 0, -step) becomes range(2, 0, -2) because start is 3.
	The loop will iterate with x taking values 2.
	The control flow graph indicates this loop.
4. First Loop Iteration (x = 2):
	Decrement start by 1: start = start - 1 = 2.
	L.pop(L.index(m) - 1):
	L.index(m) finds the index of m (which is 3) in L. The index of 3 is 3.
	L.index(m) - 1 is 3 - 1 = 2.
	L.pop(2) removes and returns the element at index 2, which is 7.
	L.insert(start, 7):
	Insert 7 at index start (which is 2).
	L becomes [1, 2, 7, 3, 9] after removing 7 and inserting it back at the same position. (No visible change)
5. End of Loop:
	The range range(2, 0, -2) has no more values after x = 2, so the loop ends.

After following the control flow of the function and given input parameters, the final output is: [1, 2, 7, 3, 9]
[/THOUGHT]
[ANSWER]
f(thigh_o_two[:], 3, 3, 2) == [1, 2, 7, 3, 9]
[/ANSWER]

[PYTHON]
{code}
assert f({input}) == ??
[/PYTHON]
[THOUGHT]
"""

def make_visual_cot_input_prompt(s):
    code, output = s
    return f"""You are given a Python function and an assertion containing an input to the function. Complete the assertion with a literal (no unsimplified expressions, no function calls) containing the output when executing the provided code on the given input, even if the function is incorrect or incomplete. Do NOT output any extra information. Execute the program step by step before arriving at an answer, and provide the full assertion with the correct output in [ANSWER] and [/ANSWER] tags, following the examples.

[[PYTHON]
def f(num):
    for i in range(10):
        if num % 2 == 0:
            num -= 2*i
        else:
            num += 2*i
    if num % 3 == 0:
        num -= 3
    elif num % 3 == 1:
        num += 6
    else:
        num += 3
    return num

assert f(??) == 103
[/PYTHON]

[THOUGHT]
Let's execute the code step by step:

1. The function `f` is defined, which takes a single argument `num`.
2. The function is called with an initial value for `num`, which we need to determine.
3. Inside the function, there is a loop that runs 10 times (`for i in range(10)`). During each iteration:
   - If `num` is even, it is decreased by `2 * i`.
   - If `num` is odd, it is increased by `2 * i`.
4. After the loop, there are three conditional checks based on the value of `num % 3`:
   - If `num % 3 == 0`, 3 is subtracted from `num`.
   - If `num % 3 == 1`, 6 is added to `num`.
   - Otherwise, 3 is added to `num`.
5. The goal is to find the initial value of `num` such that the function returns 103 after all operations.

Let's think backward using the CFG:

1. If the function returns 103, we follow the path in the CFG. We notice that the final operation performed was adding 6 to `num` (since `num % 3 == 1`).
   - Therefore, the value of `num` before the final `if` statement must have been `103 - 6 = 97`.
   
2. The loop operations involve adding or subtracting `2 * i` based on whether `num` is even or odd. Importantly, these operations do not affect whether `num % 3` equals 1 (the remainder remains consistent under the transformations).

3. Similarly, the loop preserves the property of whether `num` is even or odd due to the alternating addition and subtraction by even amounts. Thus, we need to determine a starting value of `num` such that it reaches 97 after the loop operations.

4. By analyzing the sequence of operations, we determine that starting with `num = 7` and following through all iterations of the loop, the value transitions correctly to meet all conditions and ultimately results in 103.

Therefore, the correct input that produces 103 is 7.
[/THOUGHT]
[ANSWER]
assert f(7) == 103
[/ANSWER]

[PYTHON]
{code}
assert f(??) == {output}
[/PYTHON]
[THOUGHT]
"""

cfg = config()
env = CruxEnv(cfg)
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    # prompt = env.get_problem_statement()
    prompt = env.get_problem_statement(make_visual_cot_output_prompt)
    code = env.get_code()
    try:
        solution = visual_run_model(prompt, code_to_image(code))
    except Exception as e:
        solution = ""
        error = True
        print(e)
        pass
    correct = env.check_solution(solution)
    env.accumulate_result({"is_correct": correct, "solution": solution, "error": error})
env.finalize()
print(env.result)

In [30]:
print(env.result)

{'correct': 104, 'total': 250, 'error': 2}


In [3]:
# Gemini 
import google.generativeai as genai 
import PIL.Image 
import os  
genai.configure(api_key="AIzaSyBUShLYTXK4agg-o-03R3iSrmq9BGcm_34") 
img = PIL.Image.open('ControlFlowGraph.jpeg')  
client = genai.GenerativeModel(model_name="gemini-1.5-flash") 

one_shot_code = """def f(L, m, start, step): 
    L.insert(start, m) 
    for x in range(start-1, 0, -step): 
        start -= 1
        L.insert(start, L.pop(L.index(m)-1)) 
    return L"""

def run_gemini(prompt):
    return client.generate_content([prompt]).text

def visual_run_gemini(prompt, image_path):
    img = PIL.Image.open(image_path)
    img1 = PIL.Image.open('cfg_example.png')
    prompt = "You are given a control flow graph image of a code snippet, utilize them in code execution reasoning process. " + prompt
    return client.generate_content([img1, img, prompt]).text


In [None]:
import time 

cfg = config()
env = CruxEnv(cfg)
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    prompt = env.get_problem_statement()
    
    # Initialize variables
    attempts = 0
    max_attempts = 5
    while attempts < max_attempts:
        try:
            solution = run_gemini(prompt)
            break 
        except Exception as e:
            print(e)
            error = True
            attempts += 1  # Increment the attempts counter
            if attempts < max_attempts:
                time.sleep(60)  
            else:
                solution = ""  
                print("Failed after 4 attempts")
                break 
    correct = env.check_solution(solution)
    env.accumulate_result({"is_correct": correct, "solution": solution, "error": error})
env.finalize()
print(env.result)

In [6]:
print(env.result)

{'correct': 146, 'total': 250, 'error': 12}


In [4]:
from __future__ import annotations
import time 
import ast
from cfg import * 

def make_visual_cot_output_prompt(s):
    code, input = s
    return f"""You are given a Python function and an assertion containing an input to the function. Complete the assertion with a literal (no unsimplified expressions, no function calls) containing the output when executing the provided code on the given input, even if the function is incorrect or incomplete. Do NOT output any extra information. Execute the program step by step before arriving at an answer, and provide the full assertion with the correct output in [ANSWER] and [/ANSWER] tags, following the examples.

[PYTHON]
def f(L, m, start, step): 
    L.insert(start, m) 
    for x in range(start-1, 0, -step): 
        start -= 1
        L.insert(start, L.pop(L.index(m)-1)) 
    return L
assert f(thigh_o_two[:], 3, 3, 2) == ??
[/PYTHON]
[THOUGHT]
Let's execute code step by step:
1. Initial State:
	L = [1, 2, 7, 9]
	m = 3
	start = 3
	step = 2
2. First Operation (L.insert(start, m)):
	This is shown in the control flow graph as the first action after the function begins.
	Insert m (which is 3) at index start (which is 3).
	L = [1, 2, 7, 3, 9]
3. For Loop Initialization (for x in range(start - 1, 0, -step)):
	range(start - 1, 0, -step) becomes range(2, 0, -2) because start is 3.
	The loop will iterate with x taking values 2.
	The control flow graph indicates this loop.
4. First Loop Iteration (x = 2):
	Decrement start by 1: start = start - 1 = 2.
	L.pop(L.index(m) - 1):
	L.index(m) finds the index of m (which is 3) in L. The index of 3 is 3.
	L.index(m) - 1 is 3 - 1 = 2.
	L.pop(2) removes and returns the element at index 2, which is 7.
	L.insert(start, 7):
	Insert 7 at index start (which is 2).
	L becomes [1, 2, 7, 3, 9] after removing 7 and inserting it back at the same position. (No visible change)
5. End of Loop:
	The range range(2, 0, -2) has no more values after x = 2, so the loop ends.

After following the control flow of the function and given input parameters, the final output is: [1, 2, 7, 3, 9]
[/THOUGHT]
[ANSWER]
f(thigh_o_two[:], 3, 3, 2) == [1, 2, 7, 3, 9]
[/ANSWER]

[PYTHON]
{code}
assert f({input}) == ??
[/PYTHON]
[THOUGHT]
"""

confi = config()
env = CruxEnv(confi)
for idx in tqdm(range(250)):
    error = False
    env.set_problem(idx)
    prompt = env.get_problem_statement(make_visual_cot_output_prompt)
    code = env.get_code()
    
    # Initialize variables
    attempts = 0
    max_attempts = 5
    filename = "code.py"
    with open(filename, "w") as f:
        f.write(code)
    try:
        source = open(filename, 'r').read()
        compile(source, filename, 'exec')
    except:
        print('Error in source code')
        exit(1)

    parser = PyParser(source)
    parser.removeCommentsAndDocstrings()
    parser.formatCode()
    cfg = CFGVisitor().build(filename, ast.parse(parser.script))
    cfg.clean()
    cfg.show()
    path = "cfg.png"
    
    while attempts < max_attempts:
        try:
            solution = visual_run_gemini(prompt, path)
            break 
        except Exception as e:
            print(e)
            error = True
            attempts += 1  # Increment the attempts counter
            if attempts < max_attempts:
                time.sleep(60)  
            else:
                solution = ""  
                print("Failed after 4 attempts")
                break 
    correct = env.check_solution(solution)
    env.accumulate_result({"is_correct": correct, "solution": solution, "error": error})
env.finalize()
print(env.result)

 12%|█▏        | 31/250 [01:47<10:46,  2.95s/it]

429 Resource has been exhausted (e.g. check quota).


 20%|██        | 50/250 [03:52<09:11,  2.76s/it]  

429 Resource has been exhausted (e.g. check quota).


 34%|███▍      | 85/250 [06:59<09:33,  3.48s/it]  

429 Resource has been exhausted (e.g. check quota).


 40%|████      | 101/250 [08:46<06:07,  2.46s/it] 

429 Resource has been exhausted (e.g. check quota).


100%|█████████▉| 249/250 [22:57<00:02,  2.64s/it]  

429 Resource has been exhausted (e.g. check quota).


100%|██████████| 250/250 [24:05<00:00,  5.78s/it]

{'correct': 58, 'total': 250, 'error': 5}





In [26]:
print(env.result)

{'correct': 171, 'total': 250, 'error': 5}
