In [1]:
from mdl import Sweep, MlpProbe, QuadraticProbe
from concept_erasure import QuadraticFitter, OracleFitter
from datasets import load_dataset
import torch
from typing import Literal
# autoreload
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random_seed = None  # None means not random
ds_name = "atmallen/amazon_polarity_embeddings" + (f"_random{random_seed}" if random_seed else "")
ds_dict = load_dataset(ds_name)
ds_dict = ds_dict.with_format("torch", columns=["embedding", "label"])

In [24]:
device = "cuda"
n_train = 2**14
erasure: Literal["Linear", "Q-LEACE", "none"] = "Q-LEACE"
seed = 0

In [25]:
num_classes = ds_dict["train"].features["label"].num_classes
X_train = ds_dict["train"]["embedding"][:n_train]
X_train = X_train / X_train.norm(dim=-1, keepdim=True)
Y_train = ds_dict["train"]["label"][:n_train]

In [None]:
fitter = QuadraticFitter.fit(X_train, Y_train)
eraser = fitter.eraser
X_train = eraser(X_train, Y_train)

In [26]:

sweep = Sweep(
    num_features=X_train.shape[1],
    num_classes=num_classes,
    num_chunks=5,  # TODO: change to 10
    # probe_cls=QuadraticProbe,
    probe_cls=MlpProbe,
    val_frac=0.2,
    device=device,
    probe_kwargs=dict(
        num_layers=2,
    )
)
result = sweep.run(X_train.to(device), Y_train.to(device).to(float), seed=seed)

 25%|██▌       | 1/4 [00:00<00:00,  8.04scales/s, loss=1.0000]

100%|██████████| 4/4 [02:26<00:00, 36.64s/scales, loss=0.4694]


In [27]:
result

MdlResult(mdl=1.9073904481607726, ce_curve=[1.0000008462251695, 0.5488809365661911, 0.5157688135785602, 0.4694291334366037], sample_sizes=[768, 2069, 4273, 8008, 14336], total_trials=0)

In [20]:
result_no_erase

MdlResult(mdl=1.733752111798162, ce_curve=[0.5765370022805786, 0.5018573319388311, 0.4810841569008844, 0.441685225059591, 0.4772105041717546, 0.4166916619506471, 0.4115986014027328, 0.40568870439379623, 0.40814509824604545], sample_sizes=[768, 1984, 3909, 6957, 11783, 19424, 31523, 50678, 81006, 129024], total_trials=0)

In [11]:
X_train.shape

torch.Size([16384, 384])

In [17]:
probe = MlpProbe(
    num_features=X_train.shape[1],
    num_classes=num_classes,
    device=device,
    dtype=torch.float32,
    num_layers=3,
)
n = 2**14 - 2**12
probe.fit(
    X_train[:n].to(device),
    Y_train[:n].to(device).to(float),
    x_val=X_train[n:].to(device),
    y_val=Y_train[n:].to(device).to(float),
    verbose=True,
    max_epochs=100,
)

Epoch:   0%|          | 0/100 [00:00<?, ?it/s, loss=488]

Epoch: 100%|██████████| 100/100 [00:15<00:00,  6.36it/s, loss=460]


In [18]:
probe.evaluate(X_test.to(device), Y_test.to(device).to(float), batch_size=128)

419.8186772225507