<a href="https://colab.research.google.com/github/NBK-code/ARC/blob/main/ARC_TTT_Task_Eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Pipeline

In [None]:
!pip -q install -U transformers accelerate datasets peft bitsandbytes tqdm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m35.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m56.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import json
import math
import random
from copy import deepcopy
from typing import List, Tuple, Dict, Any

from tqdm.auto import tqdm
import ast
import gc

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from peft import LoraConfig, get_peft_model, PeftModel

# Fix the Python and PyTorch RNG seeds so repeated runs start from the same state.
random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [None]:
# Location of the JSON file holding the tasks and their TTT examples
# (presumably the ARC evaluation split with pre-generated augmentations — confirm).
TTT_JSON_PATH = "/content/arc_eval_with_ttt_aug.json"

with open(TTT_JSON_PATH, "r") as f:
    ttt_data = json.load(f)

# The top-level "tasks" entry is a sequence with one record per task.
tasks = ttt_data["tasks"]

# Quick sanity check of what was loaded.
print("Loaded tasks:", len(tasks))
print("Example task keys:", tasks[0].keys())

Loaded tasks: 371
Example task keys: dict_keys(['task_id', 'original_messages', 'ttt_examples'])


In [None]:
# Hugging Face model id of the base checkpoint passed to AutoModelForCausalLM.from_pretrained.
BASE_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# Directory from which both the tokenizer and the "arc" LoRA adapter are loaded.
ADAPTER_DIR   = "/content/adaptors"

In [None]:
# Load the tokenizer stored in the adapter directory (not the one from the base model id).
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, use_fast=True)

# Padding is needed for batched training; reuse EOS when no pad token is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# If the loaded chat template has no {% generation %} markers, replace it with a
# ChatML-style template (<|im_start|>role ... <|im_end|>) that wraps assistant
# content in {% generation %}...{% endgeneration %}.
# NOTE(review): the markers are presumably meant for assistant-token masking;
# build_labels_from_messages in this notebook locates spans by string search instead.
if "{% generation %}" not in (tokenizer.chat_template or ""):
    tokenizer.chat_template = """
{% for message in messages %}
{% if message['role'] == 'system' %}
<|im_start|>system
{{ message['content'] }}<|im_end|>
{% elif message['role'] == 'user' %}
<|im_start|>user
{{ message['content'] }}<|im_end|>
{% elif message['role'] == 'assistant' %}
<|im_start|>assistant
{% generation %}{{ message['content'] }}{% endgeneration %}<|im_end|>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
<|im_start|>assistant
{% endif %}
""".strip()

In [None]:
# 4-bit NF4 quantization with double quantization; compute in bfloat16 when the
# GPU supports it, otherwise float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=(
        torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    ),
)

# Load the quantized base checkpoint.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
)

# Attach the previously trained LoRA adapter from ADAPTER_DIR under the name "arc".
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, adapter_name="arc")

model.set_adapter("arc")
model.eval()

# Freeze every parameter (base weights and the "arc" adapter alike).
for _, p in model.named_parameters():
    p.requires_grad = False

print("✅ Base model + global ARC LoRA loaded and frozen")

In [None]:
def make_ttt_lora_config(
    rank,
    alpha,
    target_modules=("q_proj", "k_proj", "v_proj", "o_proj"),
    dropout=0.0,
):
    """
    Build the LoraConfig used for the per-task test-time-training adapter.

    Args:
        rank: LoRA rank, passed as ``r``.
        alpha: LoRA alpha, passed as ``lora_alpha``.
        target_modules: Module names to wrap with LoRA. Defaults to the four
            attention projections; pass e.g. ``["o_proj"]`` or
            ``["q_proj", "k_proj", "v_proj"]`` to try other subsets. A plain
            string is forwarded unchanged.
        dropout: Value for ``lora_dropout`` (default 0.0).

    Returns:
        A ``LoraConfig`` with ``bias="none"`` and ``task_type="CAUSAL_LM"``.
    """
    # Copy sequences to a list so the caller's object is never shared with the
    # config; strings are passed through as-is rather than split into characters.
    if not isinstance(target_modules, str):
        target_modules = list(target_modules)

    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=target_modules,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

# Configuration for the per-task TTT adapter.
ttt_cfg = make_ttt_lora_config(rank = 128, alpha = 16)

# Register a second adapter named "ttt" next to "arc", then keep "arc" active.
model.add_adapter("ttt", ttt_cfg)
model.set_adapter("arc")

# Explicitly set requires_grad on every parameter: only the LoRA tensors that
# belong to the "ttt" adapter are trainable, everything else is frozen.
for name, p in model.named_parameters():
    p.requires_grad = ("lora_" in name and ".ttt." in name)

print("Adapters present:", model.peft_config.keys())

Adapters present: dict_keys(['arc', 'ttt'])


In [None]:
def check_gradients(model):
    """
    Print which parameter groups currently have ``requires_grad=True``.

    Trainable parameters are grouped by name: names without ``"lora_"`` count
    as base-model weights, LoRA names containing ``".arc."`` as the ARC
    adapter, and LoRA names containing ``".ttt."`` as the TTT adapter. Any
    other trainable LoRA parameter is reported with a warning line.
    """
    trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
    lora_names = [n for n in trainable_names if "lora_" in n]

    # Report trainable LoRA tensors that belong to neither known adapter.
    for n in lora_names:
        if ".arc." not in n and ".ttt." not in n:
            print("⚠️ Unknown trainable parameter:", n)

    # A trainable name outside the LoRA list must be a base-model weight.
    base_on = len(lora_names) != len(trainable_names)
    arc_on = any(".arc." in n for n in lora_names)
    # ".arc." takes precedence when a name matches both patterns.
    ttt_on = any(".ttt." in n and ".arc." not in n for n in lora_names)

    print("Gradient status:")
    for label, flag in (
        ("  Base model trainable :", base_on),
        ("  ARC LoRA trainable   :", arc_on),
        ("  TTT LoRA trainable   :", ttt_on),
    ):
        print(label, flag)

# Sanity check: at this point only the TTT LoRA weights should report as trainable.
check_gradients(model)

Gradient status:
  Base model trainable : False
  ARC LoRA trainable   : False
  TTT LoRA trainable   : True


In [None]:
# ANSI escape codes used to highlight unmasked text in the mask visualization.
GREEN = "\033[92m"
RESET = "\033[0m"

def build_labels_from_messages(tokenizer, messages, show_mask=False):
    """
    Render ``messages`` with the chat template and build causal-LM labels.

    Compute loss on:
      1) Demo OUTPUT grids inside the prompt (text AFTER each 'OUTPUT:\\n',
         up to and including the next '\\n]')
      2) Assistant output (text between the newline that follows
         '<|im_start|>assistant' and the next '<|im_end|>'; if the terminator
         is missing, the span runs to the end of the prompt)

    Every other position is labelled -100.

    Character positions are converted to token indices by tokenizing the
    prompt prefix and counting tokens.
    NOTE(review): this assumes tokenizing a prefix yields a prefix of the full
    tokenization at these boundaries — confirm for the tokenizer in use.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        messages: Chat messages (dicts with 'role' and 'content').
        show_mask: If True, print the prompt with unmasked regions highlighted
            in light green.

    Returns:
        ``(input_ids, labels)``, each with a leading batch dimension of 1.
    """

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids[0]

    # Start fully masked; spans are unmasked one by one below.
    labels = input_ids.clone()
    labels[:] = -100

    text = prompt

    # Keep track of character spans that are unmasked
    unmasked_char_spans = []

    def _unmask(start, end):
        """Copy input ids into labels for the tokens covering text[start:end]."""
        start_tok = len(tokenizer(text[:start]).input_ids)
        end_tok = len(tokenizer(text[:end]).input_ids)
        labels[start_tok:end_tok] = input_ids[start_tok:end_tok]
        unmasked_char_spans.append((start, end))

    # -----------------------------
    # 1) Demo OUTPUT grids
    # -----------------------------
    offset = 0
    while True:
        marker = text.find("OUTPUT:\n", offset)
        if marker == -1:
            break

        start = marker + len("OUTPUT:\n")
        close = text.find("\n]", start)
        if close == -1:
            break
        end = close + 2  # include "\n]"

        _unmask(start, end)
        offset = end

    # -----------------------------
    # 2) Assistant output
    # -----------------------------
    offset = 0
    while True:
        header = text.find("<|im_start|>assistant", offset)
        if header == -1:
            break

        newline = text.find("\n", header)
        if newline == -1:
            # Header is the last thing in the prompt: there is no content to
            # unmask. Without this guard start would become 0 and the scan
            # could revisit the same header forever.
            break
        start = newline + 1

        end = text.find("<|im_end|>", start)
        if end == -1:
            # Unterminated assistant turn: take everything up to the end of
            # the prompt instead of slicing with -1 (which drops a character).
            end = len(text)

        _unmask(start, end)
        offset = end

    # -----------------------------
    # 3) Optional visualization
    # -----------------------------
    if show_mask:
        colored = []
        last = 0

        for s, e in sorted(unmasked_char_spans):
            colored.append(text[last:s])
            colored.append(GREEN + text[s:e] + RESET)
            last = e

        colored.append(text[last:])

        print("\n===== MASK VISUALIZATION (green = loss computed) =====\n")
        print("".join(colored))
        print("\n=====================================================\n")

    return input_ids.unsqueeze(0), labels.unsqueeze(0)

In [None]:
# Visual sanity check of the loss mask on the first TTT example of one task.
task_idx = 7
messages = tasks[task_idx]['ttt_examples'][0]['messages']
input_ids, labels = build_labels_from_messages(tokenizer, messages, show_mask = True)


===== MASK VISUALIZATION (green = loss computed) =====

<|im_start|>system
You are an ARC puzzle solver. You will be shown a few example input/output pairs and then a new input. Return only the output grid as a list of lists.<|im_end|>
<|im_start|>user
Demonstrations:
1) INPUT:
[
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [6, 8, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]
   OUTPUT:
[92m[
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [6, 8, 0, 1, 0, 0, 6, 0, 0, 0, 8, 0, 0, 0, 0, 1, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
][0m
2) INPUT:
[
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]
   OUTPUT:
[92m[
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  [1, 2, 0, 1, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0,

In [None]:
def run_ttt_inner_loop(
    model,
    tokenizer,
    ttt_examples,
    steps=10,
    lr=5e-5,
):
    """
    Single-example TTT inner loop for ONE ARC task.

    Builds input_ids / labels ONCE per example and reuses them, cycling
    through the examples one per optimisation step.

    NOTE: a later cell in this notebook defines another ``run_ttt_inner_loop``
    (batched); when both cells are run in order, that later definition is the
    one in effect.

    Args:
        model: Model whose ``requires_grad`` parameters are optimised.
        tokenizer: Tokenizer forwarded to ``build_labels_from_messages``.
        ttt_examples: Sequence of dicts, each with a "messages" entry.
        steps: Number of optimiser steps.
        lr: AdamW learning rate.

    Returns:
        None. The model is updated in place and left in eval mode.
    """
    if not ttt_examples:
        # Nothing to adapt on; avoid the modulo-by-zero in the loop below and
        # keep the "model ends in eval mode" post-condition.
        model.eval()
        return

    model.train()

    cached_batches = []
    for ex in ttt_examples:
        input_ids, labels = build_labels_from_messages(tokenizer, ex["messages"])
        cached_batches.append((
            input_ids.to(model.device),
            labels.to(model.device),
        ))

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
    )

    for step in range(steps):
        input_ids, labels = cached_batches[step % len(cached_batches)]

        out = model(input_ids=input_ids, labels=labels)
        loss = out.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()

    del optimizer
    del cached_batches
    # Collect garbage first so tensors kept alive only by reference cycles are
    # freed before the CUDA cache is released.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def run_ttt_inner_loop(
    model,
    tokenizer,
    ttt_examples,
    steps=5,
    lr=1e-5,
    batch_size=2,
):
    """
    Batched TTT inner loop for ONE ARC task.

    Tokenises every example once, then runs ``steps`` AdamW updates on
    right-padded mini-batches taken cyclically from the cached examples.

    NOTE: an earlier cell in this notebook defines a non-batched function with
    the same name; this definition replaces it when the cells run in order.

    Args:
        model: Model whose ``requires_grad`` parameters are optimised.
        tokenizer: Tokenizer; ``pad_token_id`` is used to pad input ids.
        ttt_examples: Sequence of dicts, each with a "messages" entry.
        steps: Number of optimiser steps.
        lr: AdamW learning rate.
        batch_size: Examples per step; capped at ``len(ttt_examples)`` so one
            batch never contains the same example twice.

    Returns:
        None. The model is updated in place and left in eval mode.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if not ttt_examples:
        # Nothing to adapt on; avoid the modulo-by-zero in the loop below and
        # keep the "model ends in eval mode" post-condition.
        model.eval()
        return

    model.train()

    # ---------------------------------
    # Build cached tensors once
    # ---------------------------------
    cached = []
    for ex in ttt_examples:
        input_ids, labels = build_labels_from_messages(tokenizer, ex["messages"])
        cached.append((
            input_ids.squeeze(0),
            labels.squeeze(0),
        ))

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
    )

    # Never put more examples in a batch than exist.
    per_step = min(batch_size, len(cached))

    # ---------------------------------
    # Training loop
    # ---------------------------------
    for step in range(steps):
        # Take the next `per_step` examples cyclically.
        batch = [
            cached[(step * per_step + i) % len(cached)]
            for i in range(per_step)
        ]

        input_ids_list = [x[0] for x in batch]
        labels_list = [x[1] for x in batch]

        # Right-pad inputs with the pad token and labels with -100.
        input_ids = pad_sequence(
            input_ids_list,
            batch_first=True,
            padding_value=tokenizer.pad_token_id,
        )
        labels = pad_sequence(
            labels_list,
            batch_first=True,
            padding_value=-100,
        )
        # 1 for real tokens, 0 for padding, so the model is told explicitly
        # which positions are padding.
        attention_mask = pad_sequence(
            [torch.ones_like(ids) for ids in input_ids_list],
            batch_first=True,
            padding_value=0,
        )

        input_ids = input_ids.to(model.device)
        labels = labels.to(model.device)
        attention_mask = attention_mask.to(model.device)

        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = out.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()

    del optimizer
    del cached
    # Collect garbage first so tensors kept alive only by reference cycles are
    # freed before the CUDA cache is released.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def exact_grid_match(pred, gold):
    """
    Return whether the predicted grid equals the gold grid.

    A prediction of ``None`` never matches; otherwise the result is
    ``pred == gold``.
    """
    return pred is not None and pred == gold

In [None]:
@torch.no_grad()
def infer_from_messages(model, tokenizer, messages, max_new_tokens=1024):
    """
    Greedily generate a reply for ``messages`` and parse it as a Python literal.

    Args:
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer with a chat template.
        messages: Chat messages WITHOUT the final assistant answer; the
            generation prompt is appended by the template.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The value produced by ``ast.literal_eval`` on the decoded, stripped
        completion, or ``None`` when the completion cannot be parsed.
    """
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Greedy decoding. `temperature` is a sampling-only setting (and 0.0 is
    # not a valid sampling temperature), so it is deliberately not passed.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Drop the prompt tokens; keep only the newly generated continuation.
    gen = outputs[0][inputs.input_ids.shape[1]:]
    text = tokenizer.decode(gen, skip_special_tokens=True).strip()

    # Best effort: any completion that is not a valid literal counts as
    # "no prediction" rather than aborting the evaluation.
    try:
        return ast.literal_eval(text)
    except Exception:
        return None

In [None]:
def cleanup_ttt(model):
    """
    Drop gradients and release cached memory after a TTT round.

    Only gradients are cleared; adapter weights are left untouched.

    Returns:
        The same ``model`` object, for call chaining.
    """
    # clear gradients only
    model.zero_grad(set_to_none=True)
    # Collect garbage first so tensors kept alive only by reference cycles are
    # freed before the CUDA cache is released.
    gc.collect()
    torch.cuda.empty_cache()
    return model

In [None]:
def reset_ttt_weights(model):
    """
    Re-initialise the "ttt" LoRA adapter so a new task starts from scratch.

    ``lora_A`` tensors are re-drawn with Kaiming-uniform init and all other
    "ttt" LoRA tensors (in particular ``lora_B``) are zeroed. The adapter's
    contribution B @ A is therefore exactly zero after the reset, while the
    adapter can still learn. Zeroing BOTH factors would make the gradient of
    each factor zero (each one is proportional to the other factor), so
    training could never move the weights.

    Gradients on the whole model are cleared afterwards.
    """
    with torch.no_grad():
        for name, p in model.named_parameters():
            # Same name filter the notebook uses when marking TTT weights trainable.
            if "lora_" not in name or ".ttt." not in name:
                continue
            if "lora_A" in name:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            else:
                p.zero_()
    model.zero_grad(set_to_none=True)

In [None]:
def evaluate_selected_tasks(
    model,
    tokenizer,
    tasks,
    ttt_steps=25,
    lr=5e-5,
    TARGET_TASKS=(7,),
):
    """
    Compare baseline and test-time-trained accuracy on selected tasks.

    For every index in ``TARGET_TASKS``:
      1. Baseline: activate the "arc" adapter and predict from the task's
         messages without the last one.
      2. TTT: reset the "ttt" adapter, activate it, train on the task's
         ``ttt_examples`` and predict again.
    The "arc" adapter is re-activated and gradients are cleaned up after each
    task, even if the TTT phase raises.

    The gold grid is parsed from the content of the task's last message.

    Args:
        model: PEFT model holding the "arc" and "ttt" adapters.
        tokenizer: Tokenizer used for training and inference.
        tasks: Sequence of task dicts with "task_id", "original_messages"
            and "ttt_examples".
        ttt_steps: Optimiser steps per task.
        lr: Learning rate for the inner loop.
        TARGET_TASKS: Iterable of indices into ``tasks`` to evaluate.

    Returns:
        Dict mapping task index to
        ``{"task_id": ..., "baseline": bool, "ttt": bool}``.
    """
    results = {}

    n_baseline = 0
    n_ttt = 0
    n_total = 0

    for idx in TARGET_TASKS:
        task = tasks[idx]
        task_id = task["task_id"]

        original_messages = task["original_messages"]
        ttt_examples = task["ttt_examples"]
        gold = ast.literal_eval(original_messages[-1]["content"])
        # The model is prompted with everything except the gold answer.
        query_messages = original_messages[:-1]

        print("\n" + "=" * 80)
        print(f"Task {idx} | task_id = {task_id}")

        # ---- BASELINE ----
        model.set_adapter("arc")
        pred_base = infer_from_messages(model, tokenizer, query_messages)
        base_ok = exact_grid_match(pred_base, gold)

        print("Baseline solved:", base_ok)

        if base_ok:
            n_baseline += 1

        # ---- TTT ----
        # NOTE(review): PeftModel.set_adapter activates a single adapter, so
        # training and inference in this phase presumably run WITHOUT the
        # "arc" LoRA — confirm that this is intended.
        try:
            reset_ttt_weights(model)
            model.set_adapter("ttt")

            run_ttt_inner_loop(
                model,
                tokenizer,
                ttt_examples,
                steps=ttt_steps,
                lr=lr,
            )

            pred_ttt = infer_from_messages(model, tokenizer, query_messages)
        finally:
            # Always hand the model back with the global adapter active.
            model.set_adapter("arc")
            cleanup_ttt(model)

        ttt_ok = exact_grid_match(pred_ttt, gold)

        print("TTT solved:", ttt_ok)

        if ttt_ok:
            n_ttt += 1

        n_total += 1

        print("Baseline accuracy:", n_baseline / n_total)
        print("TTT accuracy    :", n_ttt / n_total)

        results[idx] = {
            "task_id": task_id,
            "baseline": base_ok,
            "ttt": ttt_ok,
        }

    return results

In [None]:
# Run baseline vs. TTT on the selected task index and collect per-task outcomes.
results = evaluate_selected_tasks(
    model,
    tokenizer,
    tasks,
    ttt_steps=10,
    lr=5e-5,
    TARGET_TASKS = [7]
)

# One line per evaluated task: index -> {"task_id", "baseline", "ttt"}.
print("\nFINAL RESULTS")
for k, v in results.items():
    print(k, v)

In [None]:
# Number of TTT examples available for the first task.
len(tasks[0]['ttt_examples'])

14