# Setup
Import necessary libraries and set up common functions that will be used throughout the analysis.

In [None]:
# Setting Up the Environment

# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch
from sklearn.metrics import balanced_accuracy_score
from ast import literal_eval

# Set random seed for reproducibility
# (seeds NumPy's global RNG, which the np.random.permutation calls further
# down draw from when shuffling test indices)
np.random.seed(0)

# Define common functions
def leaf_classification(margin, shape, texture):
    """Map predicted leaf attributes to a species name.

    Implements a fixed decision tree: the margin is checked first; for an
    'entire' margin the shape and then the texture decide; for any other
    margin only the shape is consulted.

    Args:
        margin: leaf margin label (e.g. 'serrate', 'entire').
        shape: leaf shape label (e.g. 'ovate', 'lanceolate').
        texture: leaf texture label (e.g. 'leathery', 'glossy').

    Returns:
        The species name as a string. Unrecognised labels fall through to
        the default of the branch they reach.
    """
    def first_match(value, table, fallback):
        # Linear scan with ``==`` so comparison semantics match a plain
        # if/elif chain (no hashing of the inputs is required).
        for candidate, species in table:
            if value == candidate:
                return species
        return fallback

    # Margins that identify the species on their own.
    decisive_margins = (
        ('serrate', 'Ocimum basilicum'),
        ('indented', 'Jatropha curcas'),
        ('lobed', 'Platanus orientalis'),
        ('serrulate', "Citrus limon"),
    )
    species = first_match(margin, decisive_margins, None)
    if species is not None:
        return species

    if margin == 'entire':
        entire_by_shape = (
            ('ovate', 'Pongamia Pinnata'),
            ('lanceolate', 'Mangifera indica'),
            ('oblong', 'Syzygium cumini'),
            ('obovate', "Psidium guajava"),
        )
        species = first_match(shape, entire_by_shape, None)
        if species is not None:
            return species

        # Remaining shapes with an entire margin are separated by texture.
        entire_by_texture = (
            ('leathery', "Alstonia Scholaris"),
            ('rough', "Terminalia Arjuna"),
            ('glossy', "Citrus limon"),
        )
        return first_match(texture, entire_by_texture, "Punica granatum")

    # Any other margin: decided by shape alone.
    other_by_shape = (
        ('elliptical', 'Terminalia Arjuna'),
        ('lanceolate', "Mangifera indica"),
    )
    return first_match(shape, other_by_shape, 'Syzygium cumini')

def clutrr_function(facts, query):
    """Derive the relation for ``query`` by composing known kinship facts.

    Repeatedly joins pairs of facts ``(a, b): r1`` and ``(b, c): r2`` into a
    new fact ``(a, c): rules[(r2, r1)]`` until the query is answered or a
    full pass adds nothing new.

    Args:
        facts: dict mapping a 2-item key ``(a, b)`` to a relation string.
            NOTE: this dict is updated in place with every derived fact.
        query: key to look up, in the same form as the keys of ``facts``.

    Returns:
        The relation string stored for ``query``, or ``"Uncertain"`` when it
        cannot be derived.
    """
    # Composition table keyed by (relation of second fact, relation of first
    # fact).
    # NOTE(review): the original literal listed several keys twice with
    # different values; Python keeps the last one, so only those effective
    # entries are retained (marked below). Confirm that the "-in-law"
    # readings are the intended ones.
    rules = {
        ("daughter", "daughter"): "granddaughter",
        ("daughter", "sister"): "daughter",
        ("daughter", "son"): "grandson",
        ("daughter", "aunt"): "sister",
        ("daughter", "father"): "husband",
        ("daughter", "husband"): "son-in-law",
        ("daughter", "brother"): "son",
        ("daughter", "mother"): "wife",
        ("daughter", "uncle"): "brother",
        ("daughter", "grandfather"): "father-in-law",  # shadowed "father"
        ("daughter", "grandmother"): "mother-in-law",  # shadowed "mother"
        ("sister", "daughter"): "niece",
        ("sister", "sister"): "sister",
        ("sister", "son"): "nephew",
        ("sister", "aunt"): "aunt",
        ("sister", "father"): "father",
        ("sister", "brother"): "brother",
        ("sister", "mother"): "mother",
        ("sister", "uncle"): "uncle",
        ("sister", "grandfather"): "grandfather",
        ("sister", "grandmother"): "grandmother",
        ("son", "daughter"): "granddaughter",
        ("son", "sister"): "daughter",
        ("son", "son"): "grandson",
        ("son", "aunt"): "sister",
        ("son", "father"): "husband",
        ("son", "brother"): "son",
        ("son", "mother"): "wife",
        ("son", "uncle"): "brother",
        ("son", "grandfather"): "father-in-law",  # shadowed "father"
        ("son", "grandmother"): "mother-in-law",  # shadowed "mother"
        ("aunt", "sister"): "aunt",
        ("aunt", "father"): "grandfather",
        ("aunt", "brother"): "uncle",
        ("aunt", "mother"): "grandmother",
        ("father", "daughter"): "sister",
        ("father", "sister"): "aunt",
        ("father", "son"): "brother",
        ("father", "father"): "grandfather",
        ("father", "brother"): "uncle",
        ("father", "mother"): "grandmother",
        ("father", "wife"): "mother",
        ("husband", "daughter"): "daughter",
        ("husband", "son"): "son",
        ("husband", "father"): "father-in-law",
        ("husband", "granddaughter"): "granddaughter",
        ("husband", "mother"): "mother-in-law",
        ("husband", "grandson"): "grandson",
        ("granddaughter", "sister"): "granddaughter",
        ("granddaughter", "brother"): "grandson",
        ("brother", "daughter"): "niece",
        ("brother", "sister"): "sister",
        ("brother", "son"): "nephew",
        ("brother", "aunt"): "aunt",
        ("brother", "father"): "father",
        ("brother", "brother"): "brother",
        ("brother", "mother"): "mother",
        ("brother", "uncle"): "uncle",
        ("brother", "grandfather"): "grandfather",
        ("brother", "grandmother"): "grandmother",
        ("nephew", "sister"): "niece",
        ("nephew", "brother"): "nephew",
        ("mother", "daughter"): "sister",
        ("mother", "sister"): "aunt",
        ("mother", "son"): "brother",
        ("mother", "father"): "grandfather",  # was listed twice, same value
        ("mother", "husband"): "father",
        ("mother", "brother"): "uncle",
        ("mother", "mother"): "grandmother",  # was listed twice, same value
        ("uncle", "sister"): "aunt",
        ("uncle", "father"): "grandfather",
        ("uncle", "brother"): "uncle",
        ("uncle", "mother"): "grandmother",
        ("grandfather", "wife"): "grandmother",
        ("wife", "daughter"): "daughter",
        ("wife", "son"): "son",
        ("wife", "father"): "father-in-law",
        ("wife", "granddaughter"): "granddaughter",
        ("wife", "mother"): "mother-in-law",
        ("wife", "grandson"): "grandson",
        ("wife", "son-in-law"): "son-in-law",
        ("wife", "father-in-law"): "father",
        ("wife", "daughter-in-law"): "daughter-in-law",
        ("wife", "mother-in-law"): "mother",
        ("grandmother", "husband"): "grandfather",
        ("grandson", "sister"): "granddaughter",
        ("grandson", "brother"): "grandson",
    }

    while query not in facts:
        # Facts derived during this pass; applied only after the pass so the
        # dict is not modified while it is being iterated.
        added_facts = {}
        for key1, rel1 in facts.items():
            for key2, rel2 in facts.items():
                start, end = key1[0], key2[1]
                if (
                    start != end
                    and key1[1] == key2[0]
                    and (rel2, rel1) in rules
                    and (start, end) not in facts
                ):
                    added_facts[(start, end)] = rules[(rel2, rel1)]
        if not added_facts:
            # Fixpoint reached: another pass could not derive anything new.
            break
        facts.update(added_facts)

    if query in facts:
        return facts[query]
    return "Uncertain"

def clevr_function(data, program):
    """Execute a CLEVR-style functional program over a symbolic scene.

    Args:
        data: dict whose keys are attribute tuples indexed as
            (color, shape, material, size) and whose values are the object's
            bbox (indexed 0..3; the midpoints of items 0/2 and 1/3 are used
            as the object's centre).
        program: list of instruction dicts with a "function" name, optional
            "inputs" (indices of earlier results) and optional
            "value_inputs" (literal arguments).

    Returns:
        The result of the last instruction, or "-999" when ``data`` is empty.

    Raises:
        ValueError: for an unknown function or relation name, or when
            ``unique`` receives anything other than exactly one object.
    """
    if not data:
        return "-999"

    # Turn the attribute->bbox mapping into object records with stable ids.
    scene_objs = [
        {
            "color": attrs[0],
            "shape": attrs[1],
            "material": attrs[2],
            "size": attrs[3],
            "bbox": bbox,
            "id": index,
        }
        for index, (attrs, bbox) in enumerate(data.items())
    ]

    def make_filter(attribute):
        # filter_<attribute>(objects, wanted): keep objects with that value.
        def apply(objects, wanted):
            return [o for o in objects if o[attribute] == wanted]
        return apply

    def make_query(attribute):
        # query_<attribute>(obj): read a single attribute.
        def apply(obj):
            return obj[attribute]
        return apply

    def make_same(attribute):
        # same_<attribute>(obj): other scene objects sharing the value.
        def apply(ref):
            return [
                o for o in scene_objs
                if o[attribute] == ref[attribute] and o["id"] != ref["id"]
            ]
        return apply

    def scene_fn():
        return scene_objs

    def unique(objects):
        if len(objects) != 1:
            raise ValueError(f"unique() expected exactly one object, but got {len(objects)} objects.")
        return objects[0]

    def relate(obj, relation):
        def midpoint(o):
            box = o["bbox"]
            return ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)

        anchor = midpoint(obj)
        # Strict comparisons, so the reference object never matches itself.
        if relation == "left":
            keep = lambda c: c[0] < anchor[0]
        elif relation == "right":
            keep = lambda c: c[0] > anchor[0]
        elif relation == "front":
            keep = lambda c: c[1] > anchor[1]
        elif relation == "behind":
            keep = lambda c: c[1] < anchor[1]
        else:
            raise ValueError("Unknown relation: " + relation)
        return [o for o in scene_objs if keep(midpoint(o))]

    def union(list1, list2):
        # Order-preserving, de-duplicated by object id.
        merged = []
        known_ids = set()
        for candidate in list1 + list2:
            if candidate["id"] in known_ids:
                continue
            known_ids.add(candidate["id"])
            merged.append(candidate)
        return merged

    def intersect(list1, list2):
        ids_in_first = {o["id"] for o in list1}
        return [o for o in list2 if o["id"] in ids_in_first]

    def count_fn(objects):
        return str(len(objects))

    # Dispatch table from instruction name to implementation.
    function_map = {
        "scene": scene_fn,
        "unique": unique,
        "relate": relate,
        "union": union,
        "intersect": intersect,
        "count": count_fn,
    }
    for attribute in ("color", "size", "material", "shape"):
        function_map["filter_" + attribute] = make_filter(attribute)
        function_map["query_" + attribute] = make_query(attribute)
        function_map["same_" + attribute] = make_same(attribute)

    # Run the program; results[i] holds the output of instruction i.
    results = []
    for step in program:
        args = [results[ref] for ref in step.get("inputs", [])]
        name = step["function"]
        literals = step.get("value_inputs", [])
        handler = function_map.get(name)
        if handler is None:
            raise ValueError("Unknown function: " + name)
        results.append(handler(*args, *literals))

    return results[-1]

# Leaf Dataset

In [None]:
# Analyzing Leaf Dataset Results
from src.dataset import LeafDataset
import torchvision
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import pickle
from ast import literal_eval

# Attribute vocabularies: list position is the class index that the Scallop
# logits are decoded to (via argmax) below.
margin = ['entire', 'indented', 'lobed', 'serrate', 'serrulate', 'undulate']
shape = ['elliptical', 'lanceolate', 'oblong', 'obovate', 'ovate']
texture = ['glossy', 'leathery', 'smooth', 'rough']

# Load Leaf Dataset
# NOTE(review): unlike the other dataset cells, this one does not reseed
# NumPy before shuffling, so the permutation depends on the RNG state left by
# earlier cells - confirm this is intended.
data = LeafDataset(root="../")
test_data_ids = list(range(min(200, len(data))))
shuf = np.random.permutation(test_data_ids)
test_data = [data[int(i)] for i in shuf[:200]]
gt = [sample[1] for sample in test_data]

# Load Gemini Predictions (one Python literal per line)
preds = []
with open("../model_outputs/gemini-2.0-flash/leaf/llm_symbolic_fs.txt") as f:
    for line in f:
        preds.append(literal_eval(line))

# Load Scallop Predictions
# The context manager closes the pickle file once it has been read.
with open("../model_outputs/scallop/leaf.pkl", "rb") as pkl_file:
    scallop_intermediate = pickle.load(pkl_file)
intermediate = []
# Each entry unpacks as (margin, texture, shape) logits; they are stored as
# (margin, shape, texture) to match leaf_classification's argument order.
for m, t, s in scallop_intermediate:
    m = margin[torch.argmax(m[0]).item()]
    s = shape[torch.argmax(s[0]).item()]
    t = texture[torch.argmax(t[0]).item()]
    intermediate.append((m, s, t))

scallop_preds = [leaf_classification(*s) for s in intermediate]

# Shuffle predictions to match test data
intermediate = [intermediate[i] for i in shuf[:200]]
scallop_preds = [scallop_preds[i] for i in shuf[:200]]

# Classify the Gemini symbols once and reuse the labels below.
gemini_labels = [leaf_classification(*p[1]) for p in preds]

# Calculate Accuracy
gemini_correct = sum(gemini_labels[i] == gt[i] for i in range(len(preds)))
scallop_correct = sum(scallop_preds[i] == gt[i] for i in range(len(scallop_preds)))

print("Gemini Accuracy:", gemini_correct / len(preds))
print("Scallop Accuracy:", scallop_correct / len(scallop_preds))

# Get cases where scallop is correct and gemini is wrong
scallop_correct_gemini_wrong = [i for i in range(len(scallop_preds))
                                if scallop_preds[i] == gt[i] and gemini_labels[i] != gt[i]]
print(f"Found {len(scallop_correct_gemini_wrong)} examples where Scallop is correct and Gemini is wrong")

# Get cases where gemini is correct and scallop is wrong
gemini_correct_scallop_wrong = [i for i in range(len(scallop_preds))
                                if scallop_preds[i] != gt[i] and gemini_labels[i] == gt[i]]
print(f"Found {len(gemini_correct_scallop_wrong)} examples where Gemini is correct and Scallop is wrong")

# Create interactive widget
def plot_example(idx):
    """Show the leaf image for test example ``idx`` with both systems' outputs.

    The figure title and the printed summary list the ground truth, the
    Gemini symbols with the species they classify to, and the Scallop
    symbols with their prediction.
    """
    plt.figure(figsize=(8, 8))
    plt.imshow(test_data[idx][0][0])

    truth = gt[idx]
    gemini_symbols = preds[idx][1]
    gemini_label = leaf_classification(*gemini_symbols)
    scallop_symbols = intermediate[idx]
    scallop_label = scallop_preds[idx]

    plt.title(
        f"Ground Truth: {truth}\n"
        f"Gemini: {tuple(gemini_symbols)} -> {gemini_label}\n"
        f"Scallop: {scallop_symbols} -> {scallop_label}"
    )
    plt.axis('off')
    plt.show()

    print("Ground Truth:", truth)
    print("Scallop symbols:", scallop_symbols)
    print("Scallop prediction:", scallop_label)
    print("Gemini symbols:", gemini_symbols)
    print("Gemini prediction:", gemini_label)

def on_dropdown_change(change):
    """Repopulate the example dropdown when the selected category changes.

    Only 'change' notifications for the 'value' trait are handled; any other
    notification, or an unrecognised category, leaves the options untouched.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    category = change['new']
    if category == 'scallop_correct_gemini_wrong':
        indices = scallop_correct_gemini_wrong
    elif category == 'gemini_correct_scallop_wrong':
        indices = gemini_correct_scallop_wrong
    elif category == 'all':
        indices = range(len(test_data))
    else:
        return

    example_dropdown.options = [(f"Example #{i}", i) for i in indices]

def on_example_change(change):
    """Redraw the output area when a different example is selected.

    Only 'change' notifications for the 'value' trait are handled. When the
    new value is None (the example dropdown has no options, e.g. an empty
    category) a short message is shown instead of calling plot_example.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    selected = change['new']
    with output:
        output.clear_output(wait=True)
        if selected is None:
            # Nothing to plot; plot_example(None) would raise.
            print("No examples in this category.")
            return
        plot_example(selected)

# Create widgets
category_dropdown = widgets.Dropdown(
    options=[
        ('All examples', 'all'),
        ('Scallop correct, Gemini wrong', 'scallop_correct_gemini_wrong'),
        ('Gemini correct, Scallop wrong', 'gemini_correct_scallop_wrong')
    ],
    description='Category:'
)

example_dropdown = widgets.Dropdown(
    options=[(f"Example #{i}", i) for i in range(len(test_data))],
    description='Example:'
)

output = widgets.Output()

# Set up event handlers
# Subscribe to the 'value' trait only, so the handlers are not invoked for
# unrelated trait changes (options, index, label, ...).
category_dropdown.observe(on_dropdown_change, names='value')
example_dropdown.observe(on_example_change, names='value')

# Display widgets
display(widgets.VBox([widgets.HBox([category_dropdown, example_dropdown]), output]))

# Initialize with first example
with output:
    plot_example(0)

Gemini Accuracy: 0.39090909090909093
Scallop Accuracy: 0.8818181818181818
Found 59 examples where Scallop is correct and Gemini is wrong
Found 5 examples where Gemini is correct and Scallop is wrong


VBox(children=(HBox(children=(Dropdown(description='Category:', options=(('All examples', 'all'), ('Scallop co…

# Sum5 Dataset

In [None]:
# Analyzing Sum5 Dataset Results

# Load Sum5 Dataset
from src.dataset import MNISTSumKOrigDataset
import torchvision
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import pickle
from ast import literal_eval

# ToTensor followed by normalisation with mean 0.1307 / std 0.3081.
mnist_img_transform = torchvision.transforms.Compose([
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

# Reseed so the shuffle below is reproducible for this cell on its own.
np.random.seed(0)
data = MNISTSumKOrigDataset(root="../data", train=False, download=True, k=5, noise=0.00)
test_data_ids = list(range(200))
shuf = np.random.permutation(test_data_ids)
test_data = [data[int(i)] for i in shuf[:200]]
gt = [sample[1] for sample in test_data]
gt_c = [sample[2] for sample in test_data]

# Load Gemini Predictions
# NOTE(review): the file is read from a "sum2" directory although this cell
# analyses the k=5 dataset - confirm this is the right output file.
preds = []
with open("../model_outputs/gemini-2.0-flash/sum2/llm_symbolic_fs.txt") as f:
    for line in f:
        preds.append(literal_eval(line))

preds_ans = np.array([p[0] for p in preds])
print("Gemini Accuracy:", np.sum(preds_ans == np.array(gt)) / len(gt))

# Load Scallop Predictions
# The context manager closes the pickle file once it has been read.
with open("../model_outputs/scallop/mnist.pkl", "rb") as pkl_file:
    scallop_intermediate = pickle.load(pkl_file)
intermediate = []
scallop_preds = []
for i in scallop_intermediate:
    # Decode each digit by argmax; the prediction is the sum of the digits.
    s = [torch.argmax(dig).item() for dig in i]
    intermediate.append(s)
    scallop_preds.append(sum(s))

print("Scallop Accuracy:", np.sum(np.array(scallop_preds) == np.array(gt)) / len(gt))

# Get cases where scallop is correct and gemini is wrong
scallop_correct_gemini_wrong = [i for i in range(len(scallop_preds)) if scallop_preds[i] == gt[i] and preds_ans[i] != gt[i]]
print(f"Found {len(scallop_correct_gemini_wrong)} examples where Scallop is correct and Gemini is wrong")

# Get cases where gemini is correct and scallop is wrong
gemini_correct_scallop_wrong = [i for i in range(len(scallop_preds)) if scallop_preds[i] != gt[i] and preds_ans[i] == gt[i]]
print(f"Found {len(gemini_correct_scallop_wrong)} examples where Gemini is correct and Scallop is wrong")

# Create interactive widget
def plot_example(idx):
    """Show the five digit images of test example ``idx`` side by side.

    The figure title and the printed summary list the ground truth, the
    Gemini symbols and answer, and the Scallop digits and their sum.
    """
    # concatenate the 5 images along the width axis
    to_tensor = torchvision.transforms.functional.to_tensor
    digit_images = test_data[idx][0]
    img = torch.cat([to_tensor(digit_images[i]) for i in range(5)], dim=2)

    plt.figure(figsize=(12, 3))
    plt.imshow(img.permute(1, 2, 0).numpy())

    truth = gt[idx]
    scallop_symbols = intermediate[idx]
    scallop_answer = scallop_preds[idx]
    plt.title(
        f"Ground Truth: {truth}\n"
        f"Gemini: {tuple(preds[idx][1])} -> {preds_ans[idx]}\n"
        f"Scallop: {scallop_symbols} -> {scallop_answer}"
    )
    plt.axis('off')
    plt.show()

    print("Ground Truth:", truth)
    print("Scallop symbols:", scallop_symbols)
    print("Scallop prediction:", scallop_answer)
    print("Gemini symbols:", preds[idx][1])
    print("Gemini prediction:", preds[idx][0])

def on_dropdown_change(change):
    """Repopulate the example dropdown when the selected category changes.

    Only 'change' notifications for the 'value' trait are handled; any other
    notification, or an unrecognised category, leaves the options untouched.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    category = change['new']
    if category == 'scallop_correct_gemini_wrong':
        indices = scallop_correct_gemini_wrong
    elif category == 'gemini_correct_scallop_wrong':
        indices = gemini_correct_scallop_wrong
    elif category == 'all':
        indices = range(len(test_data))
    else:
        return

    example_dropdown.options = [(f"Example #{i}", i) for i in indices]

def on_example_change(change):
    """Redraw the output area when a different example is selected.

    Only 'change' notifications for the 'value' trait are handled. When the
    new value is None (the example dropdown has no options, e.g. an empty
    category) a short message is shown instead of calling plot_example.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    selected = change['new']
    with output:
        output.clear_output(wait=True)
        if selected is None:
            # Nothing to plot; plot_example(None) would raise.
            print("No examples in this category.")
            return
        plot_example(selected)

# Create widgets
category_dropdown = widgets.Dropdown(
    options=[
        ('All examples', 'all'),
        ('Scallop correct, Gemini wrong', 'scallop_correct_gemini_wrong'),
        ('Gemini correct, Scallop wrong', 'gemini_correct_scallop_wrong')
    ],
    description='Category:'
)

example_dropdown = widgets.Dropdown(
    options=[(f"Example #{i}", i) for i in range(len(test_data))],
    description='Example:'
)

output = widgets.Output()

# Set up event handlers
# Subscribe to the 'value' trait only, so the handlers are not invoked for
# unrelated trait changes (options, index, label, ...).
category_dropdown.observe(on_dropdown_change, names='value')
example_dropdown.observe(on_example_change, names='value')

# Display widgets
display(widgets.VBox([widgets.HBox([category_dropdown, example_dropdown]), output]))

# Initialize with first example
with output:
    plot_example(0)

Gemini Accuracy: 0.815
Scallop Accuracy: 0.975
Found 33 examples where Scallop is correct and Gemini is wrong
Found 1 examples where Gemini is correct and Scallop is wrong


VBox(children=(HBox(children=(Dropdown(description='Category:', options=(('All examples', 'all'), ('Scallop co…

# HWF5 Dataset

In [None]:
# Analyzing HWF5 Dataset Results
from src.dataset import HWFDataset
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import pickle
from ast import literal_eval

# Reseed so the shuffle below is reproducible for this cell on its own.
np.random.seed(0)
data = HWFDataset(root="../data", split="test", length=5)
test_data_ids = list(range(200))
shuf = np.random.permutation(test_data_ids)
test_data = [data[int(i)] for i in shuf[:200]]
gt = [sample[1] for sample in test_data]
gt_c = [sample[2] for sample in test_data]

# Load Gemini Predictions
preds = []
with open("../model_outputs/gemini-2.0-flash/hwf/llm_symbolic_fs.txt") as f:
    for line in f:
        preds.append(literal_eval(line))

preds_ans = np.array([p[0] for p in preds])
print("Gemini Accuracy:", np.sum(preds_ans == np.array(gt)) / len(gt))

# Load Scallop Predictions
# The context manager closes the file once torch.load has read it.
with open("../model_outputs/scallop/hwf.pkl", "rb") as pkl_file:
    scallop_intermediate = torch.load(pkl_file)
intermediate = []
# Class index -> symbol: indices 0-9 are digits, 10-13 the four operators.
symbol_str = list(range(10)) + ["+", "-", "*", "/"]
for i in scallop_intermediate:
    s = [str(symbol_str[torch.argmax(dig).item()]) for dig in i]
    intermediate.append(s)

# Generate scallop predictions with proper error handling
scallop_preds = []
for symbols in intermediate:
    expr = "".join(symbols)
    try:
        # eval is applied only to strings built from symbol_str above
        # (digits and + - * /), never to raw external text.
        res = eval(expr)
    except Exception:
        # Malformed expressions (e.g. two operators in a row, division by
        # zero) are recorded as "error"; KeyboardInterrupt/SystemExit are
        # deliberately not swallowed.
        res = "error"
    scallop_preds.append(res)

scallop_correct = sum(scallop_preds[i] == gt[i] for i in range(len(scallop_preds)))
print("Scallop Accuracy:", scallop_correct / len(scallop_preds))

# Get cases where scallop is correct and gemini is wrong
scallop_correct_gemini_wrong = [i for i in range(len(scallop_preds))
                                if scallop_preds[i] == gt[i] and preds_ans[i] != gt[i]]
print(f"Found {len(scallop_correct_gemini_wrong)} examples where Scallop is correct and Gemini is wrong")

# Get cases where gemini is correct and scallop is wrong
gemini_correct_scallop_wrong = [i for i in range(len(scallop_preds))
                                if (scallop_preds[i] != gt[i]) and preds_ans[i] == gt[i]]
print(f"Found {len(gemini_correct_scallop_wrong)} examples where Gemini is correct and Scallop is wrong")

# Create interactive widget
def plot_example(idx):
    """Show the five symbol images of test example ``idx`` side by side.

    The figure title and the printed summary list the ground truth, the
    Gemini symbols and answer, and the Scallop symbols, expression and result.
    """
    img = np.concatenate([np.array(test_data[idx][0][i]) for i in range(5)], axis=1)
    plt.figure(figsize=(12, 3))
    plt.imshow(img)

    # Reuse the result computed when scallop_preds was built (the evaluated
    # value, or "error") instead of evaluating the expression a second time.
    scallop_pred = scallop_preds[idx]

    plt.title(f"Ground Truth: {gt[idx]}\nGemini: {tuple(preds[idx][1])} -> {preds_ans[idx]}\nScallop: {intermediate[idx]} -> {scallop_pred}")
    plt.axis('off')
    plt.show()

    print("Ground Truth:", gt[idx])
    print("Scallop symbols:", intermediate[idx])
    print("Scallop expression:", "".join(intermediate[idx]))
    print("Scallop prediction:", scallop_preds[idx])
    print("Gemini symbols:", preds[idx][1])
    print("Gemini prediction:", preds[idx][0])

def on_dropdown_change(change):
    """Repopulate the example dropdown when the selected category changes.

    Only 'change' notifications for the 'value' trait are handled; any other
    notification, or an unrecognised category, leaves the options untouched.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    category = change['new']
    if category == 'scallop_correct_gemini_wrong':
        indices = scallop_correct_gemini_wrong
    elif category == 'gemini_correct_scallop_wrong':
        indices = gemini_correct_scallop_wrong
    elif category == 'all':
        indices = range(len(test_data))
    else:
        return

    example_dropdown.options = [(f"Example #{i}", i) for i in indices]

def on_example_change(change):
    """Redraw the output area when a different example is selected.

    Only 'change' notifications for the 'value' trait are handled. When the
    new value is None (the example dropdown has no options, e.g. an empty
    category) a short message is shown instead of calling plot_example.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    selected = change['new']
    with output:
        output.clear_output(wait=True)
        if selected is None:
            # Nothing to plot; plot_example(None) would raise.
            print("No examples in this category.")
            return
        plot_example(selected)

# Create widgets
category_dropdown = widgets.Dropdown(
    options=[
        ('All examples', 'all'),
        ('Scallop correct, Gemini wrong', 'scallop_correct_gemini_wrong'),
        ('Gemini correct, Scallop wrong', 'gemini_correct_scallop_wrong')
    ],
    description='Category:'
)

example_dropdown = widgets.Dropdown(
    options=[(f"Example #{i}", i) for i in range(len(test_data))],
    description='Example:'
)

output = widgets.Output()

# Set up event handlers
# Subscribe to the 'value' trait only, so the handlers are not invoked for
# unrelated trait changes (options, index, label, ...).
category_dropdown.observe(on_dropdown_change, names='value')
example_dropdown.observe(on_example_change, names='value')

# Display widgets
display(widgets.VBox([widgets.HBox([category_dropdown, example_dropdown]), output]))

# Initialize with first example
with output:
    plot_example(0)

Gemini Accuracy: 0.69
Scallop Accuracy: 0.96
Found 58 examples where Scallop is correct and Gemini is wrong
Found 4 examples where Gemini is correct and Scallop is wrong


  scallop_intermediate = torch.load(open("../baseline_outputs/scallop/hwf.pkl", "rb"))


VBox(children=(HBox(children=(Dropdown(description='Category:', options=(('All examples', 'all'), ('Scallop co…

# CLUTRR Dataset

In [None]:
# Analyzing CLUTRR Dataset Results
from src.dataset import ClutrrDataset
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import pickle
from ast import literal_eval

# Reseed so the shuffle below is reproducible for this cell on its own.
np.random.seed(0)
data = ClutrrDataset(train=False, varied_complexity=False, root="../")
test_data_ids = list(range(min(200, len(data))))
shuf = np.random.permutation(test_data_ids)
test_data = [data[int(i)] for i in shuf[:200]]
gt = [sample[1] for sample in test_data]

# Load Gemini Predictions
gemini_preds = []
with open("../model_outputs/gemini-2.0-flash/clutrr/llm_symbolic_fs.txt") as f:
    for line in f:
        gemini_preds.append(literal_eval(line))

# Load Scallop Predictions
# The context manager closes the file once torch.load has read it.
with open("../model_outputs/scallop/clutrr.pkl", "rb") as pkl_file:
    scallop_intermediate = torch.load(pkl_file)
intermediate = []
predictions = []
# Class index -> relation label; "nothing" marks pairs with no relation.
relations = [
    'daughter', 'sister', 'son', 'aunt', 'father', 'husband', 'granddaughter', 'brother', 'nephew', 'mother', 'uncle',
    'grandfather', 'wife', 'grandmother', 'niece', 'grandson', 'son-in-law', 'father-in-law', 'daughter-in-law',
    'mother-in-law', 'nothing'
]
for idx, i in enumerate(scallop_intermediate):
    facts = {}
    for names, rel in i[0]:
        relation = relations[torch.argmax(rel).item()]
        if relation != "nothing":
            # The name pair is stored reversed as the fact key.
            facts[names[::-1]] = relation
    intermediate.append(facts)
    # clutrr_function updates `facts` in place (it calls facts.update), so
    # the dict stored in `intermediate` also ends up holding derived facts.
    predictions.append(clutrr_function(facts, test_data[idx][0][1]))

# Calculate Accuracy
gemini_correct = sum(gemini_preds[i][0] == gt[i] for i in range(len(gemini_preds)))
scallop_correct = sum(predictions[i] == gt[i] for i in range(len(predictions)))

print("Gemini Accuracy:", gemini_correct / len(gemini_preds))
print("Scallop Accuracy:", scallop_correct / len(predictions))

# Get cases where scallop is correct and gemini is wrong
scallop_correct_gemini_wrong = [i for i in range(len(predictions))
                                if predictions[i] == gt[i] and gemini_preds[i][0] != gt[i]]
print(f"Found {len(scallop_correct_gemini_wrong)} examples where Scallop is correct and Gemini is wrong")

# Get cases where gemini is correct and scallop is wrong
gemini_correct_scallop_wrong = [i for i in range(len(predictions))
                                if predictions[i] != gt[i] and gemini_preds[i][0] == gt[i]]
print(f"Found {len(gemini_correct_scallop_wrong)} examples where Gemini is correct and Scallop is wrong")

# Helper functions for formatting
def scallop_symbols_print(symbols):
    """Format a fact dict as one "<a> is <relation> of <b>" line per entry.

    Args:
        symbols: dict mapping a pair key ``(a, b)`` to a relation string.

    Returns:
        The lines joined by newlines, with surrounding whitespace stripped.
    """
    lines = [f"{pair[0]} is {relation} of {pair[1]}" for pair, relation in symbols.items()]
    return "\n".join(lines).strip()

def gemini_symbols_print(symbols):
    """Format relation items as one "<a> is <relation> of <b>" line each.

    Args:
        symbols: iterable whose items unpack as ``(pair, relation)``.
            NOTE(review): iterating a dict here would yield only its keys -
            confirm the structure of the Gemini records passed in.

    Returns:
        The lines joined by newlines, with surrounding whitespace stripped.
    """
    lines = [f"{pair[0]} is {relation} of {pair[1]}" for pair, relation in symbols]
    return "\n".join(lines).strip()

# Create interactive widget
def plot_example(idx):
    """Print the story, query, ground truth and both systems' outputs.

    Despite the name nothing is plotted: the CLUTRR example ``idx`` is
    rendered as text only.
    """
    sample_input = test_data[idx][0]
    story = sample_input[0]
    query = sample_input[1]

    # Put each sentence of the story on its own line.
    wrapped_story = ".\n".join(story.split(". "))
    print("Story:", wrapped_story)
    print(f"Query: What is the relation between {query[0]} and {query[1]}?")
    print("-" * 80)
    print("Ground Truth:", gt[idx])

    print("\nScallop inferred relations:")
    print(scallop_symbols_print(intermediate[idx]))
    print("Scallop prediction:", predictions[idx])

    gemini_record = gemini_preds[idx]
    print("\nGemini inferred relations:")
    print(gemini_symbols_print(gemini_record[1][0]))
    print("Gemini prediction:", gemini_record[0])

def on_dropdown_change(change):
    """Repopulate the example dropdown when the selected category changes.

    Only 'change' notifications for the 'value' trait are handled; any other
    notification, or an unrecognised category, leaves the options untouched.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    category = change['new']
    if category == 'scallop_correct_gemini_wrong':
        indices = scallop_correct_gemini_wrong
    elif category == 'gemini_correct_scallop_wrong':
        indices = gemini_correct_scallop_wrong
    elif category == 'all':
        indices = range(len(test_data))
    else:
        return

    example_dropdown.options = [(f"Example #{i}", i) for i in indices]

def on_example_change(change):
    """Redraw the output area when a different example is selected.

    Only 'change' notifications for the 'value' trait are handled. When the
    new value is None (the example dropdown has no options, e.g. an empty
    category) a short message is shown instead of calling plot_example.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    selected = change['new']
    with output:
        output.clear_output(wait=True)
        if selected is None:
            # Nothing to plot; plot_example(None) would raise.
            print("No examples in this category.")
            return
        plot_example(selected)

# Create widgets
category_dropdown = widgets.Dropdown(
    options=[
        ('All examples', 'all'),
        ('Scallop correct, Gemini wrong', 'scallop_correct_gemini_wrong'),
        ('Gemini correct, Scallop wrong', 'gemini_correct_scallop_wrong')
    ],
    description='Category:'
)

example_dropdown = widgets.Dropdown(
    options=[(f"Example #{i}", i) for i in range(len(test_data))],
    description='Example:'
)

output = widgets.Output()

# Set up event handlers
# Subscribe to the 'value' trait only, so the handlers are not invoked for
# unrelated trait changes (options, index, label, ...).
category_dropdown.observe(on_dropdown_change, names='value')
example_dropdown.observe(on_example_change, names='value')

# Display widgets
display(widgets.VBox([widgets.HBox([category_dropdown, example_dropdown]), output]))

# Initialize with first example
with output:
    plot_example(0)

Number of samples: 190
Complexity histogram:
(array([100,   0,   0,   0,   0,   0,   0]), array([ 4,  5,  6,  7,  8,  9, 10, 11]))
Gemini Accuracy: 0.8
Scallop Accuracy: 0.38
Found 1 examples where Scallop is correct and Gemini is wrong
Found 43 examples where Gemini is correct and Scallop is wrong


  scallop_intermediate = torch.load(open("../baseline_outputs/scallop/clutrr.pkl", "rb"))


VBox(children=(HBox(children=(Dropdown(description='Category:', options=(('All examples', 'all'), ('Scallop co…

In [28]:
# Print the first entry of gemini_preds for manual inspection.
print(gemini_preds[0])

('0', [{('gray', 'cube', 'rubber', 'large'): (95, 350, 276, 659), ('brown', 'cylinder', 'rubber', 'small'): (279, 403, 354, 543), ('purple', 'sphere', 'metal', 'large'): (556, 275, 689, 475), ('blue', 'cylinder', 'metal', 'large'): (595, 165, 720, 384), ('yellow', 'sphere', 'rubber', 'small'): (362, 700, 456, 850)}, [{'inputs': [], 'function': 'scene', 'value_inputs': []}, {'inputs': [0], 'function': 'filter_color', 'value_inputs': ['yellow']}, {'inputs': [1], 'function': 'filter_shape', 'value_inputs': ['sphere']}, {'inputs': [2], 'function': 'unique', 'value_inputs': []}, {'inputs': [3], 'function': 'same_material', 'value_inputs': []}, {'inputs': [4], 'function': 'filter_size', 'value_inputs': ['small']}, {'inputs': [5], 'function': 'filter_shape', 'value_inputs': ['sphere']}, {'inputs': [6], 'function': 'count', 'value_inputs': []}]])


# CLEVR Dataset

In [None]:
# Analyzing CLEVR Dataset Results
from src.dataset import ClevrDataset
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pickle
from ast import literal_eval
from PIL import Image

# Seed NumPy's global RNG so the shuffled evaluation order is reproducible.
np.random.seed(0)

data = ClevrDataset(
    max_samples=500,
    scene_path="../data/CLEVR_v1.0/scenes/CLEVR_val_scenes.json",
    images_path="../data/CLEVR_v1.0/images/val/",
    questions_path="../data/CLEVR_v1.0/questions/CLEVR_val_questions.json",
)

# Evaluate on at most the first 200 loaded samples, visited in a seeded random order.
test_data_ids = list(range(min(len(data), 200)))
shuf = np.random.permutation(test_data_ids)
test_data = [data[int(sample_id)] for sample_id in shuf[:200]]

# Element 1 of every sample is what predictions are scored against below.
gt = [sample[1] for sample in test_data]

# Load Gemini Predictions
# The file holds one Python literal per line; element 0 of each parsed record is
# used as the predicted answer.
with open("../model_outputs/gemini-2.0-flash/clevr/llm_symbolic_fs.txt") as f:
    gemini_preds = [literal_eval(line) for line in f]

preds_ans = np.array([p[0] for p in gemini_preds])

# Score with the same str()-normalised comparison that the Scallop accuracy and the
# per-category index lists below use, so the printed accuracy always agrees with
# those categories (a raw element-wise `==` on arrays can disagree when the answer
# types differ).
gemini_correct = sum(str(pred) == str(truth) for pred, truth in zip(preds_ans, gt))
print("Gemini Accuracy:", gemini_correct / len(gt))

# Load Scallop Predictions
# Attribute vocabularies: the argmax of each per-object attribute output below is
# used as an index into the matching list.
shapes = ["cube", "cylinder", "sphere"]
colors = ["gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"]
sizes = ["large", "small"]
mats = ["metal", "rubber"]

# SECURITY NOTE: torch.load unpickles the file, and unpickling can execute arbitrary
# code -- only load files that come from a trusted source.
# The context manager closes the file handle; passing a bare open(...) to torch.load
# left it open.
with open("../model_outputs/scallop/clevr.pkl", "rb") as f:
    scallop_intermediate = torch.load(f)

# One dict per example, mapping (color, shape, material, size) -> bounding box.
intermediate = []
for inter in scallop_intermediate:
    # NOTE(review): objects predicted with identical attributes share a key, so a
    # later object overwrites an earlier one -- confirm this is acceptable for
    # questions that count objects.
    symbols = {}
    for o, box in enumerate(inter["boxes"][0]):
        s = shapes[torch.argmax(inter["shape"][o]).item()]
        c = colors[torch.argmax(inter["color"][o]).item()]
        z = sizes[torch.argmax(inter["size"][o]).item()]
        m = mats[torch.argmax(inter["texture"][o]).item()]
        symbols[(c, s, m, z)] = box
    intermediate.append(symbols)

# Calculate Scallop predictions and accuracy
# Each example's detected symbols are passed to clevr_function together with
# test_data[i][0][1] (presumably the question's program -- TODO confirm).
scallop_preds = []
for i, scene_symbols in enumerate(intermediate):
    try:
        out = clevr_function(scene_symbols, test_data[i][0][1])
    except Exception:
        # Best effort: a failed example is recorded as "error" and simply counts as
        # wrong. `Exception` (not a bare `except`) keeps Ctrl-C / kernel interrupts
        # working while the loop runs.
        out = "error"
    scallop_preds.append(out)

# Answers are compared as strings so that e.g. 3 and "3" count as a match.
scallop_correct = sum(str(pred) == str(gt[i]) for i, pred in enumerate(scallop_preds))
print("Scallop Accuracy:", scallop_correct / len(scallop_preds))

# Indices of examples that Scallop answers correctly but Gemini does not
# (answers are compared after str() conversion).
scallop_correct_gemini_wrong = [
    idx
    for idx, scallop_answer in enumerate(scallop_preds)
    if str(scallop_answer) == str(gt[idx]) and str(preds_ans[idx]) != str(gt[idx])
]
print(f"Found {len(scallop_correct_gemini_wrong)} examples where Scallop is correct and Gemini is wrong")

# Indices of examples that Gemini answers correctly but Scallop does not.
gemini_correct_scallop_wrong = [
    idx
    for idx, scallop_answer in enumerate(scallop_preds)
    if str(scallop_answer) != str(gt[idx]) and str(preds_ans[idx]) == str(gt[idx])
]
print(f"Found {len(gemini_correct_scallop_wrong)} examples where Gemini is correct and Scallop is wrong")

def extract_gemini_objects(pred):
    """Return the detected-objects entry of one Gemini prediction record.

    The entry is the first element of ``pred[1]``. An empty list is returned
    when the record has no second element, when that element is not a list,
    or when that list is empty.
    """
    if len(pred) <= 1:
        return []

    details = pred[1]
    if not isinstance(details, list) or len(details) == 0:
        return []

    return details[0]

# Plotting helpers used by the interactive widget below
def _draw_labeled_box(ax, x, y, width, height, label, text_color):
    """Outline one detection on ``ax`` in red and write ``label`` just above its top-left corner."""
    ax.add_patch(patches.Rectangle((x, y), width, height,
                                   linewidth=2, edgecolor='r', facecolor='none'))
    ax.text(x, y - 5, label, color=text_color, fontsize=9,
            bbox=dict(facecolor='white', alpha=0.7))


def plot_example(idx):
    """Show Scallop's and Gemini's detections for test example ``idx`` side by side.

    Draws the example image twice (left: Scallop boxes from ``intermediate``,
    right: Gemini boxes from ``gemini_preds``), then prints the ground truth,
    both predictions and the objects each system detected.

    Args:
        idx: Position of the example in ``test_data`` (and in the parallel
            ``gt``, ``intermediate``, ``preds_ans``, ``scallop_preds`` and
            ``gemini_preds`` sequences).
    """
    # Pixel size the boxes are scaled to. Hard-coded: assumes every image is
    # 480x320 -- TODO confirm against the dataset.
    img_w, img_h = 480, 320
    # Gemini box coordinates are divided by this before scaling to pixels.
    gemini_scale = 1000

    img = test_data[idx][0][0]
    question = test_data[idx][2]

    # Create side-by-side comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    fig.suptitle(f"Question: {question[0]}\nGround Truth: {gt[idx]}", fontsize=14)

    # Left subplot - Scallop detection. Titled like the right-hand panel so the
    # two images can be told apart at a glance.
    ax1.imshow(img)
    ax1.set_title(f"Scallop Detection (Prediction: {scallop_preds[idx]})")
    # Scallop box coordinates are multiplied by the image size, i.e. they are
    # treated as fractions of the width/height, in (x0, y0, x1, y1) order.
    for attrs, box in intermediate[idx].items():
        _draw_labeled_box(
            ax1,
            box[0]*img_w, box[1]*img_h,
            (box[2]-box[0])*img_w, (box[3]-box[1])*img_h,
            f"{attrs[0]}\n{attrs[1]}\n{attrs[2]}\n{attrs[3]}",
            'r',
        )

    # Right subplot - Gemini detection
    ax2.imshow(img)
    ax2.set_title(f"Gemini Detection (Prediction: {preds_ans[idx]})")

    # Extracted once and reused for both the drawing and the printed summary.
    gemini_objects = extract_gemini_objects(gemini_preds[idx])
    print("Gemini objects:", gemini_objects)

    if gemini_objects:
        # NOTE(review): boxes are read as (x0, y0, x1, y1) on a 0-1000 scale --
        # confirm this matches the coordinate order Gemini was prompted to emit.
        for attrs, box in gemini_objects.items():
            _draw_labeled_box(
                ax2,
                box[0]/gemini_scale*img_w, box[1]/gemini_scale*img_h,
                (box[2]-box[0])/gemini_scale*img_w, (box[3]-box[1])/gemini_scale*img_h,
                "\n".join(attrs),
                'blue',
            )

    ax1.axis('off')
    ax2.axis('off')
    # Leave room at the top of the figure for the two-line suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # Print results
    print(f"Ground Truth: {gt[idx]}")
    print(f"Gemini prediction: {preds_ans[idx]}")
    print(f"Scallop prediction: {scallop_preds[idx]}")

    print("\nScallop detected objects:")
    for attrs in intermediate[idx]:
        print(f"  {attrs[0]} {attrs[1]} {attrs[2]} {attrs[3]}")

    print("\nGemini detected objects:")
    if gemini_objects:
        for attrs in gemini_objects:
            print(f"  {attrs}")
    elif len(gemini_preds[idx]) > 1:
        # Fall back to printing raw gemini inference data
        print(gemini_preds[idx][1])

def on_dropdown_change(change):
    """Repopulate the example dropdown when a different category is selected.

    Only ``'change'`` events for the ``'value'`` trait are acted on; any other
    event, and any unrecognised category value, is ignored.

    Args:
        change: Change dictionary passed by the widget's ``observe`` mechanism;
            the ``'type'``, ``'name'`` and ``'new'`` keys are read.
    """
    if not (change['type'] == 'change' and change['name'] == 'value'):
        return

    # Each category resolves lazily to the example indices it should list.
    index_sources = {
        'scallop_correct_gemini_wrong': lambda: scallop_correct_gemini_wrong,
        'gemini_correct_scallop_wrong': lambda: gemini_correct_scallop_wrong,
        'all': lambda: range(len(test_data)),
    }
    resolve = index_sources.get(change['new'])
    if resolve is None:
        return

    example_dropdown.options = [(f"Example #{i}", i) for i in resolve()]

def on_example_change(change):
    """Redraw the output area when a different example is selected.

    Only ``'change'`` events for the ``'value'`` trait are acted on. A dropdown
    whose options list is empty has no selection and reports ``None``; in that
    case the output area is cleared and a short notice is shown instead of
    passing ``None`` on to ``plot_example``.

    Args:
        change: Change dictionary passed by the widget's ``observe`` mechanism;
            the ``'type'``, ``'name'`` and ``'new'`` keys are read.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    selected = change['new']
    with output:
        # wait=True defers the clear until new output arrives, avoiding flicker.
        output.clear_output(wait=True)
        if selected is None:
            print("No examples in this category.")
            return
        plot_example(selected)

# Create widgets
# Output area that the example plots and printed details are rendered into.
output = widgets.Output()

category_dropdown = widgets.Dropdown(
    description='Category:',
    options=[
        ('All examples', 'all'),
        ('Scallop correct, Gemini wrong', 'scallop_correct_gemini_wrong'),
        ('Gemini correct, Scallop wrong', 'gemini_correct_scallop_wrong'),
    ],
)

example_dropdown = widgets.Dropdown(
    description='Example:',
    options=[(f"Example #{idx}", idx) for idx, _ in enumerate(test_data)],
)

# Wire each dropdown to its callback.
category_dropdown.observe(on_dropdown_change)
example_dropdown.observe(on_example_change)

# Layout: the two dropdowns side by side, with the output area underneath.
display(
    widgets.VBox([
        widgets.HBox([category_dropdown, example_dropdown]),
        output,
    ])
)

# Render the first example up front so the output area is not empty.
with output:
    plot_example(0)

Gemini Accuracy: 0.755
Scallop Accuracy: 0.75
Found 28 examples where Scallop is correct and Gemini is wrong
Found 29 examples where Gemini is correct and Scallop is wrong


  scallop_intermediate = torch.load(open("../baseline_outputs/scallop/clevr.pkl", "rb"))


VBox(children=(HBox(children=(Dropdown(description='Category:', options=(('All examples', 'all'), ('Scallop co…