# MPT-7b ALiBi Binding Task Test

Evaluate the `mosaicml/mpt-7b-instruct` model on the binding task with n=20 binding groups.
The model is loaded once, in its own cell, so the later cells can be re-run without reloading it.


In [1]:
# Setup (run once)
# autoreload picks up edits to imported project modules without a kernel
# restart, so the loaded model can stay in memory while local code changes.
%load_ext autoreload
%autoreload 2

import sys
# Relative path: it is resolved against the kernel's working directory.
# NOTE(review): presumably this is what makes the `grammar` / `training`
# imports below resolvable - confirm against the repository layout.
sys.path.append("CausalAbstraction")

import logging
# Silence everything below ERROR from transformers' configuration utilities
# and from the transformers library as a whole.
logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)
import transformers
transformers.logging.set_verbosity_error()

import torch
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from grammar.schemas import SCHEMA_BOXES
from grammar.task_to_causal_model import multi_order_multi_schema_task_to_lookbacks_generic_causal_model
from training import get_counterfactual_datasets, sample_answerable_question_template


nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


In [2]:
# Configuration
# All tunable values live in this cell; the cells below only read these names.
model_id = "mosaicml/mpt-7b-instruct"  # Checkpoint id handed to the tokenizer and model loaders
schema = SCHEMA_BOXES  # Task schema used to build both the causal model and the dataset
num_instances = 20  # Number of binding groups (n=20 as in paper)
num_samples = 100   # Number of test samples (use 3000 for full test)
cat_indices_to_query = [0]  # Query by Object
cat_to_query = 1            # Answer is Box (passed to the dataset builder as `answer_cat_id`)


In [3]:
# Load the tokenizer and the model.
# Run this cell once per kernel session: loading takes a few minutes and every
# later cell reuses the resulting `tokenizer` and `model` objects.
print(f"[+] Loading model: {model_id}")

model_load_options = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # Avoid flash attention issues
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_options)

# Quick sanity report: where the model ended up and how deep it is.
print(f"    Model loaded on {model.device}")
print(f"    Model has {model.config.n_layers} layers")


[+] Loading model: mosaicml/mpt-7b-instruct


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

    Model loaded on cpu
    Model has 32 layers


In [57]:
# Generate Test Dataset
print(f"[+] Generating test dataset with {num_instances} binding groups...")

# One causal model for the single schema under test, registered under the
# schema's name so the dataset builder can look it up.
causal_model = multi_order_multi_schema_task_to_lookbacks_generic_causal_model(
    [schema],
    num_instances,
    num_fillers_per_item=0,
    fillers=False,
)
causal_models = {schema.name: causal_model}

# Filtering and assertions are switched off, and no pipeline is supplied
# (the first positional argument is None).
dataset_options = dict(
    num_samples=num_samples,
    num_instances=num_instances,
    cat_indices_to_query=cat_indices_to_query,
    answer_cat_id=cat_to_query,
    do_assert=False,
    do_filter=False,
    causal_models=causal_models,
    sample_an_answerable_question=sample_answerable_question_template,
)
train_ds, test_ds, _ = get_counterfactual_datasets(None, [schema], **dataset_options)

# NOTE(review): the samples evaluated below are taken from the *train* split;
# `test_ds` is not read by any cell visible here - confirm this is intended.
train = train_ds[schema.name][schema.name]
print(f"    Generated {len(train)} samples")


[+] Generating test dataset with 20 binding groups...
    Generated 100 samples


In [59]:
import re

def format_mpt_instruct_prompt(prompt: str, queried_object: str = None) -> str:
    """Wrap a raw task prompt in the instruction/response template.

    The prompt is inserted verbatim under the ``### Instruction:`` header:
    nothing is stripped from it, and no answer prefix is pre-filled after
    ``### Response:`` (that header is followed only by a newline and a
    single space).

    Args:
        prompt: Raw task text to place in the instruction section.
        queried_object: Accepted so call sites can pass it, but not used.

    Returns:
        The complete prompt string.
    """
    preamble = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
    )
    instruction_section = f"### Instruction:\n{prompt}\n\n"
    # Note the single space after the newline; it is part of the prompt.
    response_header = "### Response:\n "
    return preamble + instruction_section + response_header

def naive_box_checker(response: str, expected: str) -> bool:
    """
    Report whether ``expected`` occurs in ``response`` as a whole word.

    Both strings are upper-cased, so the comparison ignores case, and
    ``expected`` has surrounding whitespace stripped. A word is a maximal run
    of word characters, so a label glued to other characters (the 'BoxH'
    no-space case) is NOT treated as a match.
    """
    # Split the upper-cased response into its word tokens
    tokens = re.findall(r'\w+', response.upper())
    target = expected.strip().upper()

    # Match only when some token equals the expected label exactly
    return any(token == target for token in tokens)

def generate_response(prompt, max_new_tokens=10):
    """Generate a continuation of ``prompt`` and return only the new text.

    Uses the notebook-level ``tokenizer`` and ``model``. Sampling is disabled
    and the tokenizer's EOS token id is supplied as the padding token id.

    Args:
        prompt: Text to tokenize and feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text that follows the prompt, with special tokens skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded.input_ids.shape[1]

    generation_options = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    # The first returned sequence starts with the prompt tokens; drop them so
    # only the newly generated part is decoded.
    continuation_ids = generated[0][prompt_length:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True)


In [61]:
# Test Single Sample (for debugging)
# Walks the first sample through the same steps the full evaluation uses and
# prints every intermediate value.
sample = train[0]
sample_input = sample["input"]
raw_input = sample_input["raw_input"]
queried_object = sample_input.get("Object.0.Query", None)

print("Raw input", raw_input, sep="\n")
print(f"\nQueried object: {queried_object}")

# Expected answer comes from the causal model; when it is a dict, take the
# key holding the largest value.
forward_res = causal_model.run_forward(sample_input)
expected = forward_res.get("answer", "")
if isinstance(expected, dict):
    expected = max(expected, key=expected.get)
print(f"Expected answer: {expected}")

# Build the model prompt and show it
prompt = format_mpt_instruct_prompt(raw_input, queried_object=queried_object)
print("\nFormatted prompt", prompt, sep="\n")

# Query the model and score the answer with both checkers
response = generate_response(prompt)
print(f"\nModel response: '{response}'")
print(f"Correct (strict): {naive_box_checker(response, expected)}")
print(f"Correct (original): {schema.checker(response, expected)}")


Raw input
The medicine is in Box Z, the ticket is in Box R, the chemical is in Box N, the plane is in Box D, the boat is in Box F, the block is in Box T, the drink is in Box S, the wheel is in Box O, the television is in Box U, the dress is in Box K, the train is in Box C, the crown is in Box H, the map is in Box E, the cross is in Box Q, the camera is in Box I, the branch is in Box X, the bread is in Box J, the shell is in Box M, the letter is in Box P, and the fig is in Box G. Respond in one word, only the answer and nothing else: Which box is the crown in? Box Answer:

Queried object: crown
Expected answer: H

Formatted prompt (last 500 chars):
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
The medicine is in Box Z, the ticket is in Box R, the chemical is in Box N, the plane is in Box D, the boat is in Box F, the block is in Box T, the drink is in Box S, the wheel is in Box O, the television is in Box U, th

In [63]:
# Run Full Evaluation
# Scores every sample in `train` with `naive_box_checker` and reports overall
# and per-query-position accuracy. The cell resets all of its counters, so it
# can be re-run safely.
print(f"[+] Evaluating {len(train)} samples...")

correct = 0
total = 0
position_correct = {}  # query position -> number of correct answers
position_total = {}    # query position -> number of samples seen
results = []

for i in tqdm(range(len(train))):
    sample = train[i]
    sample_input = sample["input"]
    raw_input = sample_input["raw_input"]
    queried_object = sample_input.get("Object.0.Query", None)

    # Expected answer from the causal model; when it is a dict, take the key
    # holding the largest value.
    forward_res = causal_model.run_forward(sample_input)
    expected = forward_res.get("answer", "")
    if isinstance(expected, dict):
        expected = max(expected, key=expected.get)

    # Query position; -1 means "unknown" and is left out of the breakdown
    query_position = sample_input.get("metadata", {}).get("src_positional_index", -1)

    # Format and generate
    prompt = format_mpt_instruct_prompt(raw_input, queried_object=queried_object)
    response = generate_response(prompt)

    # Check correctness using STRICT checker
    is_correct = naive_box_checker(response, expected)
    correct += int(is_correct)
    total += 1

    # Track per-position accuracy. A stored None is treated like a missing
    # position instead of crashing the comparison.
    if query_position is not None and query_position >= 0:
        position_total[query_position] = position_total.get(query_position, 0) + 1
        position_correct[query_position] = (
            position_correct.get(query_position, 0) + int(is_correct)
        )

    results.append({
        "expected": expected,
        "response": response,
        "correct": is_correct,
        "position": query_position,
        "queried_object": queried_object,
    })

# Print results
print(f"\n{'='*50}")
print(f"RESULTS: MPT-7b-instruct on Binding Task (n={num_instances})")
print(f"{'='*50}")
if total:
    print(f"Overall Accuracy: {correct/total:.2%} ({correct}/{total})")
else:
    # An empty dataset would otherwise end the cell with a ZeroDivisionError
    print("Overall Accuracy: n/a (0/0)")

if position_total:
    print("\nPer-Position Accuracy:")
    for pos in sorted(position_total):
        pos_acc = position_correct[pos] / position_total[pos]
        print(f"  Position {pos}: {pos_acc:.2%} ({position_correct[pos]}/{position_total[pos]})")


[+] Evaluating 100 samples...


  0%|          | 0/10 [00:00<?, ?it/s]


RESULTS: MPT-7b-instruct on Binding Task (n=20)
Overall Accuracy: 100.00% (10/10)

Per-Position Accuracy:
  Position 1: 100.00% (1/1)
  Position 3: 100.00% (1/1)
  Position 7: 100.00% (1/1)
  Position 8: 100.00% (1/1)
  Position 10: 100.00% (1/1)
  Position 11: 100.00% (3/3)
  Position 18: 100.00% (1/1)
  Position 19: 100.00% (1/1)


In [64]:
# Analyze Errors
# Split the per-sample records by outcome. The correct records get their own
# name: rebinding `correct` here would overwrite the integer counter that the
# evaluation cell uses for the accuracy figure.
incorrect = [r for r in results if not r["correct"]]
correct_results = [r for r in results if r["correct"]]

# Responses are shown with !r because they can contain newlines, which would
# otherwise break each report line in two.
print("Sample of incorrect predictions:")
for r in incorrect[:10]:
    print(f"  Object: {r['queried_object']}, Expected: {r['expected']}, Got: {r['response']!r}, Position: {r['position']}")

print("\n")
print("Sample of correct predictions:")
for r in correct_results[:10]:
    print(f"  Object: {r['queried_object']}, Expected: {r['expected']}, Got: {r['response']!r}, Position: {r['position']}")

print("\n")
print("First raw responses:")
for r in results[:10]:
    print(f"{r['response']}")


Sample of incorrect predictions:


  Object: crown, Expected: H, Got: '
The crown is in Box H.', Position: 11
  Object: rock, Expected: Q, Got: '
The rock is in Box Q.', Position: 11
  Object: pot, Expected: A, Got: '
The pot is in Box A.', Position: 10
  Object: suit, Expected: L, Got: '
The suit is in Box L.', Position: 8
  Object: clock, Expected: K, Got: '
The clock is in Box K.', Position: 1
  Object: camera, Expected: E, Got: '
The camera is in Box E.', Position: 7
  Object: ticket, Expected: M, Got: '
The ticket is in Box M.', Position: 19
  Object: painting, Expected: M, Got: '
The painting is in Box M.', Position: 18
  Object: game, Expected: S, Got: '
The game is in Box S.', Position: 11
  Object: machine, Expected: K, Got: '
The machine is in Box K.', Position: 3



The crown is in Box H.

The rock is in Box Q.

The pot is in Box A.

The suit is in Box L.

The clock is in Box K.

The camera is in Box E.

The ticket is in Box M.

The painting is in Box M.

The game is in Box 