# LinearProbes_brightness

Pairwise-controlled linear probing for **brightness** (−50% vs +50%) using COCO val2017 and PaliGemma.

> Labels: 0 = darker (−50%), 1 = brighter (+50%).

In [1]:
# %% [markdown]
# ## 0. Config & Setup
# Adjust paths as needed. This notebook uses COCO **val2017** only.

import os, random, json, io
from pathlib import Path

SEED = 1337
random.seed(SEED)

# --- Paths (edit these to match your local files) ---
ANNO_DIR = '../data/annotations_trainval2017/annotations'
IMG_DIR  = '../data/val2017'               # COCO val2017 images
OUT_IMG_DIR = '../data/brightness_pairs'   # directory to save brightness-perturbed images
OUT_CSV = '../data/brightness_dataset.csv' # CSV with variants & labels
OUTPUT_DIR = '../output/brightness_probe_pairwise'
os.makedirs(OUT_IMG_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Probe params
N_GROUPS = 200                # number of base images to sample (each yields 2 variants)
PAD_TO_MAX = 64               # text max length when extracting LM activations
MODE = "lm"                   # 'lm' (1152 vs 1152) or 'raw' (2304 vs 1152); for brightness we usually use LM
MODEL_NAME = 'google/paligemma2-3b-pt-224'

print('Config loaded.')

Config loaded.


In [2]:
# %% [markdown]
# ## 1. Environment Check

import sys, subprocess

def pip_install(pkg):
    try:
        __import__(pkg.split('==')[0].split('[')[0].replace('-', '_'))
    except Exception:
        print(f'Installing {pkg} ...')
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

# Ensure deps (comment out if you manage env separately)
for pkg in [
    "pycocotools",
    "transformers>=4.41.0",
    "torch",
    "pandas",
    "scikit-learn",
    "matplotlib",
    "Pillow",
]:
    try:
        __import__(pkg.split('>=')[0].split('==')[0])
    except Exception as e:
        pip_install(pkg)

print('Environment ready.')

  from .autonotebook import tqdm as notebook_tqdm


Installing scikit-learn ...


[0m

Installing Pillow ...
Environment ready.


[0m

In [8]:
# # Already creataed data set in "create_datasets.ipyb"
# # ## 2. Create Brightness-Controlled Dataset (Pairs per Base Image)
# # - Picks N_GROUPS images from COCO val2017 with at least one caption
# # - Creates two variants per base image: 0.5x (label=0) and 1.5x (label=1)
# # - Saves modified images to OUT_IMG_DIR and records a CSV with captions

# from pycocotools.coco import COCO
# from PIL import Image, ImageEnhance
# import pandas as pd

# coco = COCO(f"{ANNO_DIR}/instances_val2017.json")
# cap  = COCO(f"{ANNO_DIR}/captions_val2017.json")

# # map: image_id -> [captions]
# caps_by_img = {}
# for a in cap.dataset["annotations"]:
#     caps_by_img.setdefault(a["image_id"], []).append(a["caption"])

# # choose base images with captions
# img_ids = [iid for iid in coco.getImgIds() if caps_by_img.get(iid)]
# random.shuffle(img_ids)
# img_ids = img_ids[:N_GROUPS]

# rows = []
# BRIGHT_FACTORS = [(0.5, "dark", 0), (1.5, "bright", 1)]  # (factor, suffix, label)

# for img in coco.loadImgs(img_ids):
#     base_id = Path(img["file_name"]).stem  # e.g., 000000123456
#     src_path = Path(IMG_DIR) / img["file_name"]
#     if not src_path.exists():
#         continue

#     caption = random.choice(caps_by_img.get(img["id"], [""]))

#     with Image.open(src_path).convert("RGB") as im:
#         enhancer = ImageEnhance.Brightness(im)
#         for factor, suffix, label in BRIGHT_FACTORS:
#             out_name = f"{base_id}_{suffix}.jpg"
#             out_path = Path(OUT_IMG_DIR) / out_name
#             enhancer.enhance(factor).save(out_path, quality=95)
#             rows.append({
#                 "base_id": base_id,
#                 "orig_file_name": img["file_name"],
#                 "variant_file_name": out_name,
#                 "variant_path": str(out_path),
#                 "variant": suffix,
#                 "label": label,
#                 "caption": caption
#             })

# df_pairs = pd.DataFrame(rows)
# df_pairs.to_csv(OUT_CSV, index=False)
# print(f"Saved {len(df_pairs)} rows ({len(df_pairs)//2} groups) -> {OUT_CSV}")
# df_pairs.head()
import pandas as pd

df_pairs = pd.read_csv(OUT_CSV)
df_pairs

Unnamed: 0,base_id,orig_file_name,variant_file_name,variant_path,variant,label,caption
0,364884,000000364884.jpg,000000364884_dark.jpg,../data/brightness_pairs/000000364884_dark.jpg,dark,0,A person in a snow sporting event is going ra...
1,364884,000000364884.jpg,000000364884_bright.jpg,../data/brightness_pairs/000000364884_bright.jpg,bright,1,A person in a snow sporting event is going ra...
2,140840,000000140840.jpg,000000140840_dark.jpg,../data/brightness_pairs/000000140840_dark.jpg,dark,0,Various kites near the ground in a field.
3,140840,000000140840.jpg,000000140840_bright.jpg,../data/brightness_pairs/000000140840_bright.jpg,bright,1,Various kites near the ground in a field.
4,353096,000000353096.jpg,000000353096_dark.jpg,../data/brightness_pairs/000000353096_dark.jpg,dark,0,A computer with an image of lighting on the sc...
...,...,...,...,...,...,...,...
395,546976,000000546976.jpg,000000546976_bright.jpg,../data/brightness_pairs/000000546976_bright.jpg,bright,1,A man riding on the back of a motorcycle.
396,262895,000000262895.jpg,000000262895_dark.jpg,../data/brightness_pairs/000000262895_dark.jpg,dark,0,A fairly curmudgeonly looking old gentleman gr...
397,262895,000000262895.jpg,000000262895_bright.jpg,../data/brightness_pairs/000000262895_bright.jpg,bright,1,A fairly curmudgeonly looking old gentleman gr...
398,474881,000000474881.jpg,000000474881_dark.jpg,../data/brightness_pairs/000000474881_dark.jpg,dark,0,The elk have horns and are eating grass.


In [3]:
# %% [markdown]
# ## 3. Load PaliGemma & Helpers

import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModel
from typing import List, Optional

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

# autocast / amp context
class _amp_ctx:
    def __init__(self, device='cuda', use_amp=True):
        self.device = device
        self.use_amp = use_amp and (device == 'cuda')
    def __enter__(self):
        if self.use_amp:
            self.ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16)
            self.ctx.__enter__()
        else:
            self.ctx = None
        return self
    def __exit__(self, exc_type, exc, tb):
        if self.ctx is not None:
            self.ctx.__exit__(exc_type, exc, tb)

model = AutoModel.from_pretrained(MODEL_NAME).to(device).eval()
print('Model loaded:', MODEL_NAME)

Using device: cuda


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s]


Model loaded: google/paligemma2-3b-pt-224


In [9]:
# %% [markdown]
# ## 4. Activation Extraction (with pad_to_max + 'lm'/'raw' modes)

# Returns list of arrays per layer: [N, seq_len, D]

import numpy as np
from PIL import Image

def get_acts_paligemma(
    model, device,
    model_name=MODEL_NAME,
    *, filenames: Optional[List[str]] = None, text: Optional[List[str]] = None,
    batch_size=32, use_amp=True, mode="lm", pad_to_max=None
):
    if (text is not None) and (filenames is not None):
        raise ValueError("Provide either text or image, not both.")

    feats = []
    model.eval()

    # IMAGE branch
    if filenames is not None:
        proc = AutoProcessor.from_pretrained(model_name)
        if mode == "raw":
            model.vision_tower.config.output_hidden_states = True
        else:
            model.language_model.config.output_hidden_states = True

        with torch.inference_mode(), _amp_ctx(device, use_amp):
            for i in range(0, len(filenames), batch_size):
                fbatch = filenames[i:i+batch_size]
                imgs = [Image.open(fp).convert("RGB") for fp in fbatch]

                if mode == "raw":
                    enc = proc(images=imgs, text=["<image>"]*len(imgs), return_tensors="pt")
                    px = enc["pixel_values"].to(device, non_blocking=True)
                    vout = model.vision_tower(pixel_values=px, output_hidden_states=True, return_dict=True)
                    hs = vout.hidden_states  # tuple of layers: [B, seq, 2304]
                else:
                    enc = proc(images=imgs, text=["<image>"]*len(imgs), return_tensors="pt").to(device)
                    out = model.language_model(**enc, output_hidden_states=True, return_dict=True)
                    hs = out.hidden_states   # tuple: [B, seq, 1152]

                feats.append([h.detach().cpu().float().numpy() for h in hs])
                del hs, enc, imgs
                torch.cuda.empty_cache()

    # TEXT branch
    elif text is not None:
        tok = AutoTokenizer.from_pretrained(model_name)
        model.language_model.config.output_hidden_states = True

        with torch.inference_mode(), _amp_ctx(device, use_amp):
            for i in range(0, len(text), batch_size):
                tbatch = text[i:i+batch_size]
                enc = tok(
                    tbatch, return_tensors="pt",
                    padding="max_length" if pad_to_max else True,
                    truncation=True, max_length=pad_to_max
                ).to(device)

                out = model.language_model(**enc, output_hidden_states=True, return_dict=True)
                hs = out.hidden_states  # tuple: [B, seq, 1152]

                feats.append([h.detach().cpu().float().numpy() for h in hs])
                del hs, enc, out
                torch.cuda.empty_cache()
    else:
        raise ValueError("Must provide either filenames or text.")

    # concatenate across batches per layer
    n_layers = len(feats[0])
    layerwise = []
    for l in range(n_layers):
        arrs = [batch[l] for batch in feats]    # list of [B, seq, D]
        layerwise.append(np.concatenate(arrs, axis=0))  # [N, seq, D] (consistent seq if padded)

    return layerwise

In [10]:
# %% [markdown]
# ## 5. Pairwise-Controlled Probing for Brightness

# Split by base_id (group), extract activations for each variant, and run a linear probe per layer.

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

df = pd.read_csv(OUT_CSV)  # columns: base_id, variant_path, label, caption, ...

# group-wise split on base_id (prevents identity leakage)
unique_ids = sorted(df['base_id'].unique())
train_ids, test_ids = train_test_split(unique_ids, test_size=0.2, random_state=SEED)

df_tr = df[df['base_id'].isin(train_ids)].reset_index(drop=True)
df_te = df[df['base_id'].isin(test_ids)].reset_index(drop=True)

print(f"Train groups: {len(train_ids)}, Test groups: {len(test_ids)}")
print(f"Train rows: {len(df_tr)}, Test rows: {len(df_te)}")

# extract (images only for brightness), LM space (so D=1152), mean-pool tokens per sample
with torch.inference_mode():
    img_layers_tr = get_acts_paligemma(
        model, device, filenames=df_tr['variant_path'].tolist(),
        mode=MODE, pad_to_max=None
    )
    img_layers_te = get_acts_paligemma(
        model, device, filenames=df_te['variant_path'].tolist(),
        mode=MODE, pad_to_max=None
    )

n_layers = len(img_layers_tr)
layer_ix = list(range(n_layers))

all_rows = []
for layer in layer_ix:
    X_tr = img_layers_tr[layer].mean(axis=1)  # [N_train, D]
    y_tr = df_tr['label'].to_numpy()
    X_te = img_layers_te[layer].mean(axis=1)  # [N_test, D]
    y_te = df_te['label'].to_numpy()

    clf = LogisticRegression(max_iter=1000, random_state=SEED).fit(X_tr, y_tr)
    yhat_tr = clf.predict(X_tr)
    yhat_te = clf.predict(X_te)

    tr_acc = accuracy_score(y_tr, yhat_tr)
    te_acc = accuracy_score(y_te, yhat_te)
    te_f1  = f1_score(y_te, yhat_te, average="macro")

    print(f"Layer {layer:2d} | Train {tr_acc:.4f} | Test {te_acc:.4f} | F1 {te_f1:.4f}")
    all_rows.append({"layer": layer, "train_acc": tr_acc, "test_acc": te_acc, "test_f1": te_f1})

# Save results and plot
res_df = pd.DataFrame(all_rows)
res_csv = str(Path(OUTPUT_DIR) / "results.csv")
res_plot = str(Path(OUTPUT_DIR) / "accuracy_f1_curve.png")
res_df.to_csv(res_csv, index=False)

plt.figure(figsize=(9,5))
plt.plot(res_df["layer"], res_df["train_acc"], label="Train Acc")
plt.plot(res_df["layer"], res_df["test_acc"],  label="Test Acc")
plt.plot(res_df["layer"], res_df["test_f1"],   label="Test F1", linestyle="--", marker="o")
# highlight layer 0
if 0 in res_df["layer"].values:
    i0 = res_df.index[res_df["layer"]==0][0]
    plt.scatter([0], [res_df.loc[i0, "test_acc"]], s=60, edgecolors="k", label="Layer 0 (Test Acc)")
plt.xlabel("Layer"); plt.ylabel("Score"); plt.title("Brightness Probe (Pairwise-Controlled)")
plt.legend(); plt.grid(True)
plt.savefig(res_plot, dpi=150); plt.close()

print(f"Saved results -> {res_csv}\nSaved plot -> {res_plot}")

Train groups: 160, Test groups: 40
Train rows: 320, Test rows: 80


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Layer  0 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  1 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  2 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  3 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  4 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  5 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  6 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  7 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  8 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer  9 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 10 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 11 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 12 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 13 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 14 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 15 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 16 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 17 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 18 | Train 0.5000 | Test 0.5000 | F1 0.3333
Layer 19 | Train 0.5000 | Test 0.5000 | F1 0.3333


In [13]:
df_tr

Unnamed: 0,base_id,orig_file_name,variant_file_name,variant_path,variant,label,caption
0,364884,000000364884.jpg,000000364884_dark.jpg,../data/brightness_pairs/000000364884_dark.jpg,dark,0,A person in a snow sporting event is going ra...
1,364884,000000364884.jpg,000000364884_bright.jpg,../data/brightness_pairs/000000364884_bright.jpg,bright,1,A person in a snow sporting event is going ra...
2,140840,000000140840.jpg,000000140840_dark.jpg,../data/brightness_pairs/000000140840_dark.jpg,dark,0,Various kites near the ground in a field.
3,140840,000000140840.jpg,000000140840_bright.jpg,../data/brightness_pairs/000000140840_bright.jpg,bright,1,Various kites near the ground in a field.
4,353096,000000353096.jpg,000000353096_dark.jpg,../data/brightness_pairs/000000353096_dark.jpg,dark,0,A computer with an image of lighting on the sc...
...,...,...,...,...,...,...,...
315,571943,000000571943.jpg,000000571943_bright.jpg,../data/brightness_pairs/000000571943_bright.jpg,bright,1,A picture of a street light and sign showing t...
316,262895,000000262895.jpg,000000262895_dark.jpg,../data/brightness_pairs/000000262895_dark.jpg,dark,0,A fairly curmudgeonly looking old gentleman gr...
317,262895,000000262895.jpg,000000262895_bright.jpg,../data/brightness_pairs/000000262895_bright.jpg,bright,1,A fairly curmudgeonly looking old gentleman gr...
318,474881,000000474881.jpg,000000474881_dark.jpg,../data/brightness_pairs/000000474881_dark.jpg,dark,0,The elk have horns and are eating grass.
