# Binding Task Test - Multi-Model Support

Test language models on the binding task with n=20 binding groups.
Supports multiple models: MPT-7b-instruct, Falcon3-Mamba-7B-Base, and mamba-2.8b-instruct-openhermes.

This notebook allows you to keep the model loaded between runs.

In [1]:
# Setup (run once)
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("CausalAbstraction")

import logging
logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)
import transformers
transformers.logging.set_verbosity_error()

import torch
from tqdm.notebook import tqdm
import re

from grammar.schemas import SCHEMA_BOXES
from grammar.task_to_causal_model import multi_order_multi_schema_task_to_lookbacks_generic_causal_model
from training import get_counterfactual_datasets, sample_answerable_question_template

# Import model wrappers
from binding_model_wrappers import create_model_wrapper

def naive_checker(response: str, expected: str) -> bool:
    """
    Returns True if `expected` appears as a standalone word in `response`,
    ignoring case.
    """
    pattern = rf'\b{re.escape(expected)}\b'
    return re.search(pattern, response, flags=re.IGNORECASE) is not None


nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


In [2]:
# Configuration

# model_id = "state-spaces/mamba-2.8b-hf"
# model_id = "mistralai/Mamba-Codestral-7B-v0.1"
model_id = "tiiuae/Falcon3-Mamba-7B-Instruct"
# model_id = "clibrain/mamba-2.8b-instruct-openhermes" // Not working.
# model_id = "mosaicml/mpt-7b-instruct" // Recently removed from huggingface.

schema = SCHEMA_BOXES
num_instances = 20  # Number of binding groups (n=20 as in paper)
num_samples = 10   # Number of test samples (use 3000 for full test)
cat_indices_to_query = [0]  # Query by Object
cat_to_query = 1            # Answer is Box

In [3]:
# Load Model (run once - takes a few minutes)
print(f"[+] Loading model: {model_id}")
model_wrapper = create_model_wrapper(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(f"    Model loaded on {model_wrapper.model.device}")
print(f"    Model has {model_wrapper.get_num_layers()} layers")
print(f"    Model type: {model_wrapper.get_display_name()}")

[+] Loading model: tiiuae/Falcon3-Mamba-7B-Instruct


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

meta params: 0
    Model loaded on cpu
    Model has 64 layers
    Model type: Falcon3-Mamba-7B-Instruct


In [4]:
# Generate Test Dataset
print(f"[+] Generating test dataset with {num_instances} binding groups...")
causal_model = multi_order_multi_schema_task_to_lookbacks_generic_causal_model(
    [schema], num_instances, num_fillers_per_item=0, fillers=False
)
causal_models = {schema.name: causal_model}

train_ds, test_ds, _ = get_counterfactual_datasets(
    None,  # No pipeline for filtering
    [schema],
    num_samples=num_samples,
    num_instances=num_instances,
    cat_indices_to_query=cat_indices_to_query,
    answer_cat_id=cat_to_query,
    do_assert=False,
    do_filter=False,
    causal_models=causal_models,
    sample_an_answerable_question=sample_answerable_question_template,
)

train = train_ds[schema.name][schema.name]
print(f"    Generated {len(train)} samples")

[+] Generating test dataset with 20 binding groups...
    Generated 10 samples


In [5]:
# Test Single Sample (for debugging)
sample = train[0]
raw_input = sample["input"]["raw_input"]
queried_object = sample["input"].get("Object.0.Query", None)

print("Raw input")
print(raw_input)
print(f"\nQueried object: {queried_object}")

# Get expected answer
forward_res = causal_model.run_forward(sample["input"])
expected = forward_res.get("answer", "")
if isinstance(expected, dict):
    expected = max(expected, key=lambda k: expected[k])
print(f"Expected answer: {expected}")

# Format and generate using model wrapper
prompt = model_wrapper.format_prompt(raw_input, queried_object=queried_object)
print(f"\nFormatted prompt")
print(prompt)

response = model_wrapper.generate_response(prompt)
print(f"\nModel response: '{response}'")
print(f"Correct (Naive): {naive_checker(response, expected)}")
# print(f"Correct (original): {schema.checker(response, expected)}")

Raw input
The boat is in Box D, the pot is in Box R, the computer is in Box V, the file is in Box S, the tea is in Box T, the dress is in Box Q, the engine is in Box W, the wheel is in Box Z, the guitar is in Box G, the stone is in Box E, the leaf is in Box K, the ball is in Box F, the tie is in Box B, the train is in Box J, the cup is in Box H, the mirror is in Box M, the suit is in Box P, the letter is in Box O, the fan is in Box A, and the creature is in Box L. Respond in one word, only the answer and nothing else: Which box is the engine in? Box Answer:

Queried object: engine
Expected answer: W

Formatted prompt
<|begin_of_text|><|im_start|>system
You are a helpful friendly assistant Falcon3 from TII, try to follow instructions as much as possible.<|im_end|>
<|im_start|>user
The boat is in Box D, the pot is in Box R, the computer is in Box V, the file is in Box S, the tea is in Box T, the dress is in Box Q, the engine is in Box W, the wheel is in Box Z, the guitar is in Box G, the

In [6]:
# Run Full Evaluation
print(f"[+] Evaluating {len(train)} samples...")

correct = 0
total = 0
position_correct = {}
position_total = {}
results = []

for i in tqdm(range(len(train))):
    sample = train[i]
    raw_input = sample["input"]["raw_input"]
    queried_object = sample["input"].get("Object.0.Query", None)
    
    # Get expected answer
    forward_res = causal_model.run_forward(sample["input"])
    expected = forward_res.get("answer", "")
    if isinstance(expected, dict):
        expected = max(expected, key=lambda k: expected[k])
    
    # Get query position
    query_position = sample["input"].get("metadata", {}).get("src_positional_index", -1)
    
    # Format and generate using model wrapper
    prompt = model_wrapper.format_prompt(raw_input, queried_object=queried_object)
    response = model_wrapper.generate_response(prompt)
    
    # Check correctness using naive checker
    is_correct = naive_checker(response, expected)
    if is_correct:
        correct += 1
    total += 1

    print(f"{i=} {response=} {expected=} {is_correct=}")
    
    # Track per-position accuracy
    if query_position >= 0:
        if query_position not in position_correct:
            position_correct[query_position] = 0
            position_total[query_position] = 0
        if is_correct:
            position_correct[query_position] += 1
        position_total[query_position] += 1
    
    results.append({
        "expected": expected,
        "response": response,
        "correct": is_correct,
        "position": query_position,
        "queried_object": queried_object,
    })

# Print results
print(f"\n{'='*50}")
print(f"RESULTS: {model_wrapper.get_display_name()} on Binding Task (n={num_instances})")
print(f"{'='*50}")
print(f"Overall Accuracy: {correct/total:.2%} ({correct}/{total})")

if position_total:
    print(f"\nPer-Position Accuracy:")
    for pos in sorted(position_total.keys()):
        pos_acc = position_correct[pos] / position_total[pos]
        print(f"  Position {pos}: {pos_acc:.2%} ({position_correct[pos]}/{position_total[pos]})")

[+] Evaluating 10 samples...


  0%|          | 0/10 [00:00<?, ?it/s]

i=0 response='G' expected='W' is_correct=False
i=1 response='E' expected='F' is_correct=False
i=2 response='Box A' expected='A' is_correct=True
i=3 response='D' expected='D' is_correct=True


KeyboardInterrupt: 

## Model-Specific Test Cells

The cells below allow you to quickly test each supported model with a single sample.
Change the model_id in the configuration cell above and run the corresponding test cell.

In [None]:
# Analyze Errors
print("Sample of incorrect predictions:")
incorrect = [r for r in results if not r["correct"]]
correct = [r for r in results if r["correct"]]
for r in incorrect[:10]:
    print(f"  Object: {r['queried_object']}, Expected: {r['expected']}, Got: '{r['response']}', Position: {r['position']}")

print("\n")
for r in correct[:10]:
    print(f"  Object: {r['queried_object']}, Expected: {r['expected']}, Got: '{r['response']}', Position: {r['position']}")


Sample of incorrect predictions:
  Object: disk, Expected: W, Got: ' A, B, C, D, E,', Position: 8
  Object: glass, Expected: P, Got: ' O

The ice is in Box Y,', Position: 6
  Object: picture, Expected: S, Got: ' A, B, C, D, E,', Position: 1
  Object: train, Expected: V, Got: ' L

Which box is the plant in?', Position: 3
  Object: clock, Expected: P, Got: ' A, B, C, D, E,', Position: 2
  Object: medicine, Expected: R, Got: ' A, B, C, D, E,', Position: 5
  Object: newspaper, Expected: R, Got: ' M

The answer is: M

', Position: 7
  Object: pot, Expected: V, Got: ' F

The bus is in Box F,', Position: 2
  Object: bread, Expected: N, Got: ' B

Which box is the bread in?', Position: 8
  Object: bread, Expected: V, Got: ' A, B, C, D, E,', Position: 6


  Object: suit, Expected: T, Got: ' T Box Answer: Z Box Answer: V Box', Position: 1
  Object: mirror, Expected: D, Got: ' D

Which box is the computer in?', Position: 3
  Object: boot, Expected: W, Got: ' I, G, W, U, O,', Position: 4
  Object: 