<a href="https://colab.research.google.com/github/avivyossef29/mixing-mechs-nlp/blob/nitzan%2Fbinding-tests/bloomz_7b_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install core libs (No IPython upgrade)
# `-q` keeps pip's output short. No versions are pinned on this line, so the
# resolved versions depend on when the cell is run.
!pip install -q bitsandbytes transformers accelerate ipywidgets peft
# protobuf and tensorboard are capped below the listed versions; the reason is
# not recorded here (presumably a compatibility issue -- TODO confirm).
!pip install -q "protobuf<6.0.0" "tensorboard<2.20.0" sentencepiece tqdm tabulate

# Clone the project repo into the current working directory. The setup cell
# later expects it at /content/mixing-mechs-nlp, i.e. this assumes the cell
# runs from /content. If the directory already exists git reports an error and
# leaves the existing checkout untouched.
!git clone https://github.com/NitzanZacharia/mixing-mechs-nlp.git

In [None]:
import sys
import importlib
import importlib.util
from types import ModuleType

# Provide a minimal stand-in for the stdlib `imp` module (removed in Python 3.12)
# so code that still does `import imp` / `imp.reload(...)` keeps working.
# The shim is installed only when no real `imp` can be found, so an interpreter
# that still ships `imp` keeps the genuine module instead of this stub.
# The `sys.modules` check must come first: once the shim is registered it has no
# `__spec__`, and `find_spec` would raise ValueError for it on a re-run.
if 'imp' not in sys.modules and importlib.util.find_spec('imp') is None:
    imp = ModuleType('imp')
    # Only `reload` is provided; it simply forwards to importlib's implementation.
    imp.reload = importlib.reload
    sys.modules['imp'] = imp
    print("✅ 'imp' module shimmed successfully. Autoreload should now work.")

In [None]:
# Install pyvene (quietly, unpinned). It is not imported directly in this
# notebook -- presumably the cloned repo's modules need it (TODO confirm).
!pip install -q pyvene

In [None]:
# Install nnsight (quietly, unpinned). It is not imported directly in this
# notebook -- presumably the cloned repo's modules need it (TODO confirm).
!pip install -q nnsight

In [None]:
# Upgrade traitlets to the latest release. The reason is not recorded here --
# presumably a version conflict with the packages installed above (TODO confirm).
!pip install -U traitlets

In [None]:
# Setup (run once)
%load_ext autoreload
%autoreload 2

import os
import sys

# Move to repo and add to path
repo_dir = '/content/mixing-mechs-nlp'
os.chdir(repo_dir)
sys.path.append(repo_dir)
sys.path.append(os.path.join(repo_dir, "CausalAbstraction"))

# Your logging and imports
import logging
logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)
import transformers
transformers.logging.set_verbosity_error()

import torch
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from grammar.schemas import SCHEMA_BOXES
from grammar.task_to_causal_model import multi_order_multi_schema_task_to_lookbacks_generic_causal_model
from training import get_counterfactual_datasets, sample_answerable_question_template

print("ðŸš€ Script initialized successfully!")

In [None]:
# Configuration -- values a reader may want to tune before running the cells below.

# Hugging Face checkpoint to evaluate and the task schema to draw prompts from.
model_id = "bigscience/bloomz-7b1"
schema = SCHEMA_BOXES

# Which category is queried and which one is the answer.
# Query by Object
cat_indices_to_query = [0]
# Answer is Box
cat_to_query = 1

# Dataset size.
# Number of binding groups (n=20 as in paper)
num_instances = 20
# Number of test samples (use 3000 for full test)
num_samples = 100


In [None]:
# The custom widget manager only exists inside Google Colab. Outside Colab the
# import fails, so skip the call instead of aborting a top-to-bottom run.
try:
    from google.colab import output
except ImportError:
    output = None
    print("[!] google.colab is not available; skipping the custom widget manager.")
else:
    output.enable_custom_widget_manager()

In [None]:
from transformers import BitsAndBytesConfig
import torch

# 1. Configure 4-bit quantization (the original author notes this is essential
#    for running on a T4 GPU).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16  # compute in float16 rather than float32
)

print(f"[+] Loading model: {model_id} on GPU...")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# 2. Load the quantized causal LM.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",              # automatic placement; check the print below for the actual device
    low_cpu_mem_usage=True,         # keeps system RAM usage down while loading
    attn_implementation="eager"
)

print(f"✅ Model loaded on: {model.device}")


In [None]:
# Show the GPU status report (this cell runs right after the model is loaded).
!nvidia-smi

In [None]:
# `num_hidden_layers` is the architecture-independent config attribute (BLOOM
# maps it onto `n_layer`), so this keeps working if `model_id` is switched to a
# model family whose config has no `n_layer` field.
print(f"    Model has {model.config.num_hidden_layers} layers")

In [None]:
# Generate Test Dataset
# Build the causal model for the configured schema, then sample the
# counterfactual datasets from it. No pipeline is passed and filtering and
# assertions are switched off, so the samples are returned as generated.
print(f"[+] Generating test dataset with {num_instances} binding groups...")
causal_model = multi_order_multi_schema_task_to_lookbacks_generic_causal_model(
    [schema],
    num_instances,
    num_fillers_per_item=0,
    fillers=False,
)
causal_models = {schema.name: causal_model}

dataset_options = dict(
    num_samples=num_samples,
    num_instances=num_instances,
    cat_indices_to_query=cat_indices_to_query,
    answer_cat_id=cat_to_query,
    do_assert=False,
    do_filter=False,
    causal_models=causal_models,
    sample_an_answerable_question=sample_answerable_question_template,
)
# First positional argument is the pipeline used for filtering; None = no pipeline.
train_ds, test_ds, _ = get_counterfactual_datasets(None, [schema], **dataset_options)

# The returned mapping is indexed twice by the schema name to reach the samples.
train = train_ds[schema.name][schema.name]
print(f"    Generated {len(train)} samples")


In [None]:
import re

def format_bloomz_binding_prompt(prompt: str) -> str:
    """Recast a raw binding-task prompt as a Context / Question / Answer prompt.

    Args:
        prompt: The raw task prompt.

    Returns:
        A three-line prompt. The context is everything before the first
        ' Respond'. The question is the whitespace-stripped text that follows
        the first "nothing else:" and precedes "Box Answer:". If
        "nothing else:" does not occur, a generic question is used instead.
        The prompt ends with "Answer: Box".
    """
    # Context: the part of the prompt that precedes the instruction text.
    context = prompt.partition(' Respond')[0]

    # Question: the segment right after the instruction suffix, cut at the answer cue.
    segments = prompt.split("nothing else:")
    if len(segments) > 1:
        question_part = segments[1].partition("Box Answer:")[0].strip()
    else:
        # The instruction suffix is missing; use the generic question.
        question_part = "Which box is the object in?"

    return f"Context: {context}\nQuestion: {question_part}\nAnswer: Box"

def naive_box_checker(response: str, expected: str) -> bool:
    """Report whether `expected` occurs in `response` as a standalone word.

    The comparison is case-insensitive and ignores whitespace around
    `expected`. A label glued to other word characters (e.g. 'BoxH' for an
    expected 'H') is deliberately NOT counted as a match.

    Args:
        response: Text generated by the model.
        expected: The expected answer label.

    Returns:
        True if any word in the upper-cased response equals the stripped,
        upper-cased expected label, otherwise False.
    """
    # Split the upper-cased response into runs of word characters.
    tokens = re.findall(r'\w+', response.upper())
    target = expected.strip().upper()
    return any(token == target for token in tokens)

def generate_response(text, max_new_tokens=10):
    """Greedily generate a continuation for `text` and return only the new text.

    Uses the notebook-level `tokenizer` and `model`.

    Args:
        text: Prompt string to feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation with special tokens removed; the positions
        covered by the prompt are dropped before decoding.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_len = encoded.input_ids.shape[1]

    # No gradients are needed for generation.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding keeps the output deterministic
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only what comes after the prompt tokens in the first (only) sequence.
    continuation = generated[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)

In [None]:
# Test Single Sample (for debugging)
# Runs the first sample end-to-end and prints every intermediate value.
sample = train[0]
raw_input = sample["input"]["raw_input"]
queried_object = sample["input"].get("Object.0.Query")

print("Raw input", raw_input, sep="\n")
print(f"\nQueried object: {queried_object}")

# Expected answer from the causal model; when it comes back as a dict,
# take the key that has the largest value.
forward_res = causal_model.run_forward(sample["input"])
expected = forward_res.get("answer", "")
if isinstance(expected, dict):
    expected = max(expected, key=expected.get)
print(f"Expected answer: {expected}")

# Reformat the prompt for BLOOMZ and generate a response
prompt = format_bloomz_binding_prompt(raw_input)
print("\nFormatted prompt")
print(prompt)

response = generate_response(prompt)
print(f"\nModel response: '{response}'")

# Score the response with both checkers for comparison.
strict_ok = naive_box_checker(response, expected)
print(f"Correct (strict): {strict_ok}")
original_ok = schema.checker(response, expected)
print(f"Correct (original): {original_ok}")


In [None]:
from tqdm import tqdm  # note: rebinds `tqdm` to the console progress bar
# Run Full Evaluation
# For every sample: compute the expected answer with the causal model, generate
# a response, score it with `naive_box_checker`, and keep overall and
# per-query-position tallies plus a per-sample record in `results`.
print(f"[+] Evaluating {len(train)} samples...")

correct = 0
total = 0
position_correct = {}
position_total = {}
results = []

for i in tqdm(range(len(train))):
    sample = train[i]
    raw_input = sample["input"]["raw_input"]
    queried_object = sample["input"].get("Object.0.Query", None)

    # Get expected answer; when it comes back as a dict, take the key that has
    # the largest value.
    forward_res = causal_model.run_forward(sample["input"])
    expected = forward_res.get("answer", "")
    if isinstance(expected, dict):
        expected = max(expected, key=lambda k: expected[k])

    # Get query position (-1 when the sample carries no positional metadata)
    query_position = sample["input"].get("metadata", {}).get("src_positional_index", -1)

    # Format and generate
    prompt = format_bloomz_binding_prompt(raw_input)
    response = generate_response(prompt)

    # Check correctness with the standalone-word checker
    is_correct = naive_box_checker(response, expected)
    if is_correct:
        correct += 1
    total += 1

    # Track per-position accuracy (samples without a position are skipped here)
    if query_position >= 0:
        position_correct[query_position] = position_correct.get(query_position, 0) + int(is_correct)
        position_total[query_position] = position_total.get(query_position, 0) + 1

    results.append({
        "expected": expected,
        "response": response,
        "correct": is_correct,
        "position": query_position,
        "queried_object": queried_object,
    })

# Print results
# Derive the label from `model_id` so the header always names the model that
# was actually evaluated (the old hard-coded text said "bloomz-1b1").
model_label = model_id.split("/")[-1]
print(f"\n{'='*50}")
print(f"RESULTS: {model_label} on Binding Task (n={num_instances})")
print(f"{'='*50}")
if total:
    print(f"Overall Accuracy: {correct/total:.2%} ({correct}/{total})")
else:
    # An empty dataset would otherwise raise ZeroDivisionError.
    print("Overall Accuracy: n/a (no samples were evaluated)")

if position_total:
    print(f"\nPer-Position Accuracy:")
    for pos in sorted(position_total.keys()):
        pos_acc = position_correct[pos] / position_total[pos]
        print(f"  Position {pos}: {pos_acc:.2%} ({position_correct[pos]}/{position_total[pos]})")


In [None]:
import pandas as pd

# Convert the per-sample evaluation records to a DataFrame
df = pd.DataFrame(results)

# Keep only the rows the checker marked as correct. The boolean column is used
# directly as the mask instead of comparing it to True.
correct_ans = df[df['correct']]

print(f"Total successes: {len(correct_ans)}")
print("\n--- Examples of correct Answers ---")
# Display up to the first 20 correct rows; 20 is enough to see some variety.
display(correct_ans[['queried_object', 'expected', 'response', 'position']].head(20))