In [1]:
from transformers import ConvNextV2ForImageClassification
from auto_circuit.utils.graph_utils import patch_mode, patchable_model
from dataclasses import dataclass
from torch import Tensor
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import numpy as np
import random
from datasets import ClassLabel, Dataset, DatasetDict, Features, Image, load_dataset
from einops import rearrange
from torch.distributions import MultivariateNormal
from pathlib import Path
from typing import Callable
import pickle


from auto_circuit.utils.ablation_activations import batch_src_ablations
from auto_circuit.data import PromptDataset, PromptDataLoader
from auto_circuit.types import AblationType

from concept_erasure import QuadraticEditor, QuadraticFitter
from concept_erasure.utils import assert_type
from concept_erasure.quantile import QuantileNormalizer


In [2]:
# Drop into the post-mortem debugger automatically whenever a cell raises
%pdb on
# NOTE(review): absolute, machine-specific checkpoint path — consider making it configurable
model_path = "/mnt/ssd-1/lucia/features-across-time/img-ckpts/cifar10/convnext-tiny/checkpoint-32768"

# Load the fine-tuned classifier onto the GPU, then wrap it so its edges can be patched
model = ConvNextV2ForImageClassification.from_pretrained(model_path).cuda()
model = patchable_model(model, factorized=True, device=model.device)

# Dataset identifier later passed to `prepare_dataset`
dataset_str = "cifar10"


Automatic pdb calling has been turned ON


In [3]:
def infer_columns(feats: Features) -> tuple[str, str]:
    """Locate the single image column and the single class-label column.

    Columns are recognised by the type of their feature object, not by name.

    Args:
        feats: Mapping from column name to feature descriptor.

    Returns:
        A ``(image_column, label_column)`` pair of column names.

    Raises:
        AssertionError: If there is not exactly one ``Image`` column or not
            exactly one ``ClassLabel`` column.
    """
    img_cols: list[str] = []
    label_cols: list[str] = []

    # Single pass over the columns, bucketing each one by its feature type
    for column in feats:
        feature = feats[column]
        if isinstance(feature, Image):
            img_cols.append(column)
        if isinstance(feature, ClassLabel):
            label_cols.append(column)

    assert len(img_cols) == 1, f"Found {len(img_cols)} image columns"
    assert len(label_cols) == 1, f"Found {len(label_cols)} label columns"

    image_column, label_column = img_cols[0], label_cols[0]
    return image_column, label_column

@dataclass
class QuantileNormalizedDataset:
    """Dataset whose items are stored inputs remapped from their class to another.

    Each access takes a stored example, draws a target class different from
    the example's own class, and maps every coordinate through the editor's
    lookup tables: the coordinate's insertion position in the source-class
    table selects the entry that is read from the target-class table.
    """

    # Sampling weights over classes, used to draw the target class
    class_probs: Tensor
    # Object exposing `lut`, indexed by class along its first dimension;
    # presumably the last dimension holds sorted per-coordinate values — TODO confirm
    editor: QuantileNormalizer
    # Inputs, indexed along the first dimension
    X: Tensor
    # Class label of each input
    Y: Tensor

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Return example ``idx`` remapped to a randomly drawn different class.

        Returns:
            A dict with the remapped input under ``"pixel_values"`` and the
            sampled target class under ``"label"``.
        """
        source_x = self.X[idx]
        source_y = self.Y[idx]

        # Zero the weight of the true class so it can never be drawn
        candidate_probs = self.class_probs.clone()
        candidate_probs[source_y] = 0
        target_y = torch.multinomial(candidate_probs, 1).squeeze()

        source_lut = self.editor.lut[source_y]
        target_lut = self.editor.lut[target_y]

        # Position of each coordinate within the source table, kept in range
        last_slot = source_lut.shape[-1] - 1
        positions = torch.searchsorted(source_lut, source_x[..., None])
        positions = positions.clamp(0, last_slot)

        # Read the value at the same position from the target table
        remapped = target_lut.gather(-1, positions).squeeze(-1)

        return {
            "pixel_values": remapped,
            "label": target_y,
        }

    def __len__(self) -> int:
        """Number of stored examples."""
        return len(self.Y)

class GaussianMixture:
    """Synthetic dataset that draws each item from a class-conditional Gaussian.

    Every access samples a class from ``class_probs`` and then a vector from
    that class's multivariate normal, reshapes it to ``shape`` and passes it
    through ``trf``. Items are random; the index is only bounds-checked.
    """

    def __init__(
        self,
        means: Tensor,
        covs: Tensor,
        class_probs: Tensor,
        size: int,
        shape: tuple[int, int, int] = (3, 32, 32),
        trf: Callable = lambda x: x,
    ):
        """Build one multivariate normal per (mean, covariance) pair.

        Args:
            means: Per-class mean vectors, iterated along the first dimension.
            covs: Per-class covariance matrices, iterated along the first dimension.
            class_probs: Sampling weights over classes.
            size: Nominal length reported by ``len()``.
            shape: Shape each flat sample is reshaped to.
            trf: Transform applied to the reshaped sample before returning it.
        """
        components = []
        for mu, sigma in zip(means, covs):
            components.append(MultivariateNormal(mu, sigma))

        self.dists = components
        self.class_probs = class_probs
        self.size = size
        self.shape = shape
        self.trf = trf

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Draw a fresh (sample, class) pair.

        Raises:
            IndexError: If ``idx`` is at or beyond ``size``.
        """
        if idx >= self.size:
            raise IndexError(f"Index {idx} out of bounds for size {self.size}")

        # Pick the mixture component first, then sample from it
        component = torch.multinomial(self.class_probs, 1).squeeze()
        flat_sample = self.dists[component].sample()
        image = flat_sample.reshape(self.shape)

        return {
            "pixel_values": self.trf(image),
            "label": component,
        }

    def __len__(self) -> int:
        """Nominal number of items."""
        return self.size

@dataclass
class ConceptEditedDataset:
    """Dataset whose items are stored inputs transported to a different class.

    Each access takes a stored example, draws a target class other than the
    example's own, and asks the editor to transport the example from its class
    to the target class.
    """

    # Sampling weights over classes, used to draw the target class
    class_probs: Tensor
    # Object whose `transport(x, source_class, target_class)` performs the edit
    editor: QuadraticEditor
    # Inputs, indexed along the first dimension
    X: Tensor
    # Class label of each input
    Y: Tensor

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Return example ``idx`` transported to a randomly drawn different class.

        Returns:
            A dict with the edited input under ``"pixel_values"`` and the
            sampled target class under ``"label"``.
        """
        source_x = self.X[idx]
        source_y = int(self.Y[idx])

        # Zero the weight of the true class so it can never be drawn
        candidate_probs = self.class_probs.clone()
        candidate_probs[source_y] = 0
        target_y = torch.multinomial(candidate_probs, 1).squeeze()

        # The editor is called on a batch of one, which is unwrapped afterwards
        batched = source_x[None]
        edited = self.editor.transport(batched, source_y, int(target_y))

        return {
            "pixel_values": edited.squeeze(0),
            "label": target_y,
        }

    def __len__(self) -> int:
        """Number of stored examples."""
        return len(self.Y)

@dataclass
class IndependentCoordinateSampler:
    """Synthetic dataset that samples every coordinate independently.

    Each access draws a class from ``class_probs`` and then, for every
    coordinate of that class's lookup table, picks one entry along the last
    axis uniformly at random. The requested index is ignored.
    """

    # Sampling weights over classes
    class_probs: Tensor
    # Object exposing `lut`, indexed by class along its first dimension
    editor: QuantileNormalizer
    # Nominal length reported by `len()`
    size: int

    def __getitem__(self, _: int) -> dict[str, Tensor]:
        """Draw a fresh (sample, class) pair.

        Returns:
            A dict with the sampled input under ``"pixel_values"`` and the
            sampled class under ``"label"``.
        """
        label = torch.multinomial(self.class_probs, 1).squeeze()
        table = self.editor.lut[label]

        # One random slot per coordinate, i.e. per position excluding the last axis
        num_slots = table.shape[-1]
        coord_shape = table.shape[:-1]
        picks = torch.randint(0, num_slots, coord_shape, device=table.device)

        sample = table.gather(-1, picks[..., None]).squeeze(-1)
        return {
            "pixel_values": sample,
            "label": label,
        }

    def __len__(self) -> int:
        """Nominal number of items."""
        return self.size

In [4]:
def create_prompt_dataset_from_quantile_normalized(
    qn_dataset: QuantileNormalizedDataset,
    original_dataset: Dataset,
    img_col: str = "pixel_values",
    label_col: str = "label"
) -> PromptDataset:
    """Pair original examples with their edited counterparts as a PromptDataset.

    For every index of ``qn_dataset``, the item at the same index of
    ``original_dataset`` supplies the clean prompt and the answer, while the
    ``qn_dataset`` item supplies the corrupt prompt and the wrong answer.

    Args:
        qn_dataset: Edited dataset; its length decides how many pairs are built.
        original_dataset: Unedited dataset, indexed in lockstep with ``qn_dataset``.
        img_col: Key under which each item stores its input.
        label_col: Key under which each item stores its label.

    Returns:
        A ``PromptDataset`` built from the collected prompts and labels.
    """
    # Fetch the original item before the edited one at every index
    pairs = [
        (original_dataset[idx], qn_dataset[idx])
        for idx in range(len(qn_dataset))
    ]

    # Each label is wrapped in a length-1 tensor (assumes a single integer label)
    return PromptDataset(
        clean_prompts=[clean[img_col] for clean, _ in pairs],
        corrupt_prompts=[corrupt[img_col] for _, corrupt in pairs],
        answers=[torch.tensor([clean[label_col]]) for clean, _ in pairs],
        wrong_answers=[torch.tensor([corrupt[label_col]]) for _, corrupt in pairs],
    )

In [29]:
def prepare_dataset(dataset_str: str):
    """Load an image dataset and build the validation-set variants used downstream.

    The training split is used to fit class-conditional statistics (a
    ``QuadraticFitter`` and a ``QuantileNormalizer``); the validation split is
    then exposed in several forms.

    Side effects: reseeds the NumPy, ``random`` and torch RNGs with a fixed
    seed, prints progress, runs the fitter on CUDA, and reads or writes
    ``editor-cache/<dataset_str>.pkl`` under the current working directory.

    Args:
        dataset_str: Dataset path, optionally followed by ``":<config name>"``
            (e.g. ``"svhn:cropped_digits"``).

    Returns:
        A dict with the keys ``"independent"``, ``"got"``, ``"gaussian"``,
        ``"real"`` and ``"cqn"``, each mapping to a dataset-like object.
    """
    seed = 42

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Allow specifying load_dataset("svhn", "cropped_digits") as "svhn:cropped_digits"
    # We don't use the slash because it's a valid character in a dataset name
    path, _, name = dataset_str.partition(":")
    ds = load_dataset(path, name or None)
    assert isinstance(ds, DatasetDict)

    # Infer columns and class labels
    img_col, label_col = infer_columns(ds["train"].features)
    labels = ds["train"].features[label_col].names
    print(f"Classes in '{dataset_str}': {labels}")

    # Convert to RGB so we don't have to think about it
    ds = ds.map(lambda x: {img_col: x[img_col].convert("RGB")})

    # Infer the image size from the first image.
    # PIL reports `size` as (width, height), so unpack in that order; unpacking
    # it as (h, w) silently swaps the two for non-square images.
    example = ds["train"][0][img_col]
    c = len(example.mode)
    w, h = example.size
    print(f"Image size: {h} x {w}")

    # NOTE(review): the augmented training split built from this transform is
    # not part of the returned dict.
    train_trf = T.Compose(
        [
            T.RandAugment(),
            T.RandomHorizontalFlip(),
            # Crop back to the full (height, width) after padding
            T.RandomCrop((h, w), padding=h // 8),
            T.ToTensor(),
        ]
    )

    train = ds["train"].with_format("torch")
    X = assert_type(Tensor, train[img_col]).div(255)
    # X = rearrange(X, "n h w c -> n c h w")
    Y = assert_type(Tensor, train[label_col])

    print("Computing statistics...")
    fitter = QuadraticFitter.fit(X.flatten(1).cuda(), Y.cuda())
    normalizer = QuantileNormalizer(X, Y)
    print("Done.")

    def preprocess(batch):
        """Convert a batch of images to tensors and its labels to one tensor."""
        return {
            "pixel_values": [TF.to_tensor(x) for x in batch[img_col]],
            "label": torch.tensor(batch[label_col]),
        }

    # Prefer a dedicated validation split; otherwise carve one out of the test split
    if val := ds.get("validation"):
        test = ds["test"].with_transform(preprocess) if "test" in ds else None
        val = val.with_transform(preprocess)
    else:
        nontrain = ds["test"].train_test_split(train_size=1024, seed=seed)
        val = nontrain["train"].with_transform(preprocess)
        test = nontrain["test"].with_transform(preprocess)

    # Unnormalized class counts from the training labels; used as sampling weights
    class_probs = torch.bincount(Y).float()
    gaussian = GaussianMixture(
        fitter.mean_x.cpu(), fitter.sigma_xx.cpu(), class_probs, len(val), (c, h, w)
    )

    train = (
        ds["train"].with_transform(
            lambda batch: {
                "pixel_values": [train_trf(x) for x in batch[img_col]],
                "label": batch[label_col],
            },
        )
    )

    cache = Path.cwd() / "editor-cache" / f"{dataset_str}.pkl"
    if cache.exists():
        # NOTE: unpickling executes arbitrary code; only load cache files that
        # were written by this notebook.
        with open(cache, "rb") as f:
            editor = pickle.load(f)
    else:
        print("Computing optimal transport maps...")

        editor = fitter.editor("cpu")
        cache.parent.mkdir(exist_ok=True)

        with open(cache, "wb") as f:
            pickle.dump(editor, f)

    # From here on X and Y hold the validation inputs and labels
    with val.formatted_as("torch"):
        X = assert_type(Tensor, val[img_col]).div(255)
        # X = rearrange(X, "n h w c -> n c h w")
        Y = assert_type(Tensor, val[label_col])

    val_sets = {
        "independent": IndependentCoordinateSampler(class_probs, normalizer, len(val)),
        "got": ConceptEditedDataset(class_probs, editor, X, Y),
        "gaussian": gaussian,
        "real": val,
        "cqn": QuantileNormalizedDataset(class_probs, normalizer, X, Y),
    }

    return val_sets

In [30]:
# Build every validation-set variant for the configured dataset
val_sets = prepare_dataset(dataset_str)

Classes in 'cifar10': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Image size: 32 x 32
Computing statistics...
Done.


AttributeError: 'Dataset' object has no attribute 'items'

> [0;32m/tmp/ipykernel_3370526/217489223.py[0m(96)[0;36mprepare_dataset[0;34m()[0m
[0;32m     94 [0;31m        [0;34m"got"[0m[0;34m:[0m [0mConceptEditedDataset[0m[0;34m([0m[0mclass_probs[0m[0;34m,[0m [0meditor[0m[0;34m,[0m [0mX[0m[0;34m,[0m [0mY[0m[0;34m)[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     95 [0;31m        [0;34m"gaussian"[0m[0;34m:[0m [0mgaussian[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 96 [0;31m        [0;34m"real"[0m[0;34m:[0m [0;34m{[0m[0mk[0m[0;34m:[0m [0mv[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m [0;32mfor[0m [0mk[0m[0;34m,[0m [0mv[0m [0;32min[0m [0mval[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m}[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     97 [0;31m        [0;34m"cqn"[0m[0;34m:[0m [0mQuantileNormalizedDataset[0m[0;34m([0m[0mclass_probs[0m[0;34m,[0m [0mnormalizer[0m[0;34m,[0m [0mX[0m[0;34m,[0m [0mY[0m[0;34m

In [9]:
# Display the edges registered on the patchable model, grouped by dictionary key
model.edge_dict

defaultdict(list,
            {None: [Stage0.Layer0.24->Stage2.Layer0.149,
              Stage0.Layer0.24->Stage3.Layer2.460,
              Stage0.Layer0.24->Stage2.Layer3.192,
              Stage0.Layer0.24->Stage2.Layer6.322,
              Stage0.Layer0.24->Stage2.Layer6.92,
              Stage0.Layer0.24->Stage1.Layer2.60,
              Stage0.Layer0.24->Stage2.Layer8.80,
              Stage0.Layer0.24->Stage3.Layer0.7,
              Stage0.Layer0.24->Stage2.Layer3.380,
              Stage0.Layer0.24->Stage2.Layer4.364,
              Stage0.Layer0.24->Stage2.Layer0.2,
              Stage0.Layer0.24->Stage3.Layer1.625,
              Stage0.Layer0.24->Stage3.Layer0.202,
              Stage0.Layer0.24->Stage2.Layer5.321,
              Stage0.Layer0.24->Stage2.Layer1.359,
              Stage0.Layer0.24->Stage2.Layer5.125,
              Stage0.Layer0.24->Layernorm,
              Stage0.Layer0.24->Stage2.Layer1.179,
              Stage0.Layer0.24->Layernorm,
              Stage0.Layer0.24

In [10]:
# List the names of the available validation-set variants
val_sets.keys()

dict_keys(['independent', 'got', 'gaussian', 'real', 'cqn'])

In [31]:
# Pair each real validation item (clean) with its quantile-normalized counterpart (corrupt)
pdset = create_prompt_dataset_from_quantile_normalized(val_sets["cqn"], val_sets["real"])
# NOTE(review): the positional `None, 0` are presumably the sequence length and
# divergence index expected by PromptDataLoader — confirm against its signature
ploader = PromptDataLoader(pdset, None, 0, batch_size=16, shuffle=True)

In [32]:
# Drop into the post-mortem debugger automatically whenever a cell raises
%pdb on
# Compute resample ablations for the model over the loader, using the corrupt inputs
ablations = batch_src_ablations(model, ploader, AblationType.RESAMPLE, 'corrupt')

Automatic pdb calling has been turned ON


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

> [0;32m/mnt/ssd-1/david/auto-circuit/auto_circuit/data.py[0m(93)[0;36mcollate_fn[0;34m()[0m
[0;32m     91 [0;31m    [0mkey[0m [0;34m=[0m [0mhash[0m[0;34m([0m[0;34m([0m[0mstr[0m[0;34m([0m[0mclean[0m[0;34m.[0m[0mtolist[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0mstr[0m[0;34m([0m[0mcorrupt[0m[0;34m.[0m[0mtolist[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     92 [0;31m[0;34m[0m[0m
[0m[0;32m---> 93 [0;31m    [0mdiverge_idxs[0m [0;34m=[0m [0;34m([0m[0;34m~[0m[0;34m([0m[0mclean[0m [0;34m==[0m [0mcorrupt[0m[0;34m)[0m[0;34m)[0m[0;34m.[0m[0mint[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0margmax[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     94 [0;31m    [0mbatch_dvrg_idx[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0mint[0m[0;34m([0m[0mdiverge_idxs[0m[0;34m.[0m[0mmin[0m[0;34m([0m[0;34m)[0m[0;3

In [None]:
# `model.edge_dict` is a defaultdict mapping a key (only `None` for this model)
# to a list of edges, so slicing the dict itself does not select any edges.
# Take the first 100 edges from the list stored under the `None` key instead.
patch_edges = model.edge_dict[None][:100]

In [None]:
# Take the first value stored in `ablations`
# NOTE(review): presumably one entry per batch — confirm it matches the batches iterated later
ablations_batch = next(iter(ablations.values()))


In [None]:
# Run the clean inputs through the model while patching is active for `patch_edges`
with patch_mode(model, ablations_batch, patch_edges):
    for batch in ploader:
        # NOTE(review): `patched_out` is overwritten on every iteration, so only
        # the last batch's output is kept after the loop
        patched_out = model(batch.clean.to(model.wrapped_model.device))