In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch as ch
from pathlib import Path
from notebooks.utils import get_executor, submit_job
from unlearning.training.train import wrapper_for_train_cifar10_on_subset_submitit

# Training

## Oracles

In [None]:
# Submitit executor for the oracle-training job array (its resources are
# presumably configured inside notebooks.utils.get_executor for the "oracles" key).
executor = get_executor("oracles")

### Build boolean retain masks from forget-set indices

In [None]:
# Root of all MATCHING experiment artifacts on the shared filesystem.
_MATCHING_ROOT = Path("/mnt/xfs/projects/untrak/MATCHING")
# Oracle masks and checkpoints go to BASE_SAVE_PATH/<forget-set name>/.
BASE_SAVE_PATH = _MATCHING_ROOT / "oracles" / "CIFAR10"
# Directory with one forget-set index file per forget set (read with np.load).
FORGET_SETS_PATH = _MATCHING_ROOT / "forget_set_inds" / "CIFAR10"

In [None]:
# Number of oracle models trained by each submitit job; each mask file written
# below also stores this many identical rows.
N_models_per_job = 10

In [None]:
# Build one boolean "retain" mask per forget set. Masks (instead of index lists)
# let us reuse wrapper_for_train_cifar10_on_subset_submitit unchanged.
#
# Artifact: every mask file stores N_models_per_job identical rows, so each job
# can pass idx_start=0 and n_models=N_models_per_job to the wrapper.
# Set recreate_masks = True to (re)write the mask files; otherwise this cell only
# collects the expected mask paths.

MASK_PATHS = []

recreate_masks = False

for forget_path in sorted(FORGET_SETS_PATH.iterdir()):
    set_name = forget_path.stem
    print(f"key: {set_name}")
    print(f"forget_set_path: {forget_path}")

    out_dir = BASE_SAVE_PATH / set_name
    out_path = out_dir / "mask.npy"

    if recreate_masks:
        # True = sample kept for training, False = sample in the forget set.
        # 50_000 is the CIFAR-10 training-set size.
        retain = np.ones(50_000).astype(bool)
        retain[np.load(forget_path)] = False
        retain = np.stack([retain] * N_models_per_job, axis=0)
        print(retain.shape)

    print(out_path)
    MASK_PATHS.append(out_path)

    if recreate_masks:
        out_dir.mkdir(exist_ok=True, parents=True)
        np.save(out_path, retain)



In [None]:
# Oracles: N_models_per_forget_set models per forget set, split into jobs of
# N_models_per_job models each.
N_models_per_forget_set = 250
# Floor division would silently drop the remainder if the counts don't divide evenly.
assert N_models_per_forget_set % N_models_per_job == 0, (
    f"N_models_per_forget_set ({N_models_per_forget_set}) must be a multiple of "
    f"N_models_per_job ({N_models_per_job})"
)
N_jobs_per_forget_set = N_models_per_forget_set // N_models_per_job
print(N_jobs_per_forget_set)

In [None]:
# Display the collected mask paths (one per forget set, in sorted order).
MASK_PATHS

In [None]:
# Sanity check: fraction of training samples each mask keeps (mean of the
# boolean mask). Each line is labelled by the forget-set directory name rather
# than by position. MASK_PATHS is sorted lexicographically, so a name ending in
# "10" sorts before one ending in "2". It may also hold more or fewer than 9 entries.
for mask_path in MASK_PATHS:
    print(mask_path.parent.name, np.load(mask_path).mean())

### Send jobs

In [None]:
# One job per (forget set, block of N_models_per_job oracles).
# idx_start is always 0 because each mask file already holds N_models_per_job
# identical rows. model_id_offset (the job's first model index) presumably keeps
# saved model ids distinct across jobs of the same forget set.
should_save_train_logits = True
should_save_val_logits = True

batch_args = [
    [
        mask_path,
        0,
        N_models_per_job,
        mask_path.parent,  # checkpoints are saved next to the mask
        should_save_train_logits,
        should_save_val_logits,
        job_idx * N_models_per_job,
    ]
    for mask_path in MASK_PATHS
    for job_idx in range(N_jobs_per_forget_set)
]

In [None]:
# Spot-check one oracle job's arguments:
# [mask_path, idx_start, n_models, ckpt_dir, save_train_logits, save_val_logits, model_id_offset]
batch_args[11]

In [None]:
# Launch all oracle jobs as one submitit job array (presumably one task per
# entry of batch_args).
job_array = submit_job(
    executor,
    wrapper_for_train_cifar10_on_subset_submitit,
    batch_args,
    batch=True,
)

## Full models

In [None]:
# Submitit executor for the full-model (train-on-all-samples) job array.
executor = get_executor("full_models")

In [None]:
# Full models (trained on all samples) are all saved directly in this directory.
# NOTE(review): this rebinds BASE_SAVE_PATH, which earlier pointed at the oracle
# root; later cells that read BASE_SAVE_PATH expect this full-models directory.
BASE_SAVE_PATH = Path("/mnt/xfs/projects/untrak/MATCHING") / "full_models" / "CIFAR10"
BASE_SAVE_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Full models: N_models_per_forget_set models in total, split into jobs of
# N_models_per_job models each.
N_models_per_forget_set = 200
N_models_per_job = 10
# Floor division would silently drop the remainder if the counts don't divide evenly.
assert N_models_per_forget_set % N_models_per_job == 0, (
    f"N_models_per_forget_set ({N_models_per_forget_set}) must be a multiple of "
    f"N_models_per_job ({N_models_per_job})"
)
N_jobs_per_forget_set = N_models_per_forget_set // N_models_per_job
print(N_jobs_per_forget_set)

In [None]:
# One job per block of N_models_per_job full models, all saved into BASE_SAVE_PATH.
should_save_train_logits = True
should_save_val_logits = True
DUMMY_MASK_PATH = ""  # train on all samples

batch_args = [
    [
        DUMMY_MASK_PATH,
        0,
        N_models_per_job,
        BASE_SAVE_PATH,
        should_save_train_logits,
        should_save_val_logits,
        job_idx * N_models_per_job,  # model_id_offset
    ]
    for job_idx in range(N_jobs_per_forget_set)
]



In [None]:
# Spot-check one full-model job's arguments:
# [mask_path (empty = no mask), idx_start, n_models, ckpt_dir, save_train_logits, save_val_logits, model_id_offset]
batch_args[2]

In [None]:
# Launch all full-model jobs as one submitit job array (presumably one task per
# entry of batch_args).
job_array = submit_job(
    executor,
    wrapper_for_train_cifar10_on_subset_submitit,
    batch_args,
    batch=True,
)

# Post-process logits and create margins

In [None]:
# Directories whose per-model logits are post-processed below: one per forget
# set (oracles), followed by the full-model directory.
# BASE_SAVE_PATH is bound to the oracle root in the Oracles section and rebound
# to the full-models directory in the Full models section. Fail loudly if the
# cells were run out of order, because the aggregation step classifies any path
# containing "oracle" as an oracle directory.
assert "oracle" not in str(BASE_SAVE_PATH), (
    f"BASE_SAVE_PATH={BASE_SAVE_PATH} is not the full-models directory; "
    "run the 'Full models' cells before this one"
)
CKPT_DIRS = [mask_path.parent for mask_path in MASK_PATHS] + [BASE_SAVE_PATH]
CKPT_DIRS

In [None]:
# Ground-truth labels in dataset order; used to turn logits into margins below.
from unlearning.datasets.cifar10 import get_cifar_dataloader
# tqdm.auto selects the notebook widget when available. Unlike importing
# tqdm.autonotebook directly, it does not emit a TqdmExperimentalWarning.
from tqdm.auto import tqdm

train_ds = get_cifar_dataloader().dataset
# NOTE(review): indexing the dataset presumably runs its image transforms just to
# read the label; if this is a torchvision CIFAR10, `.targets` would be faster — confirm.
train_labels = [train_ds[i][1] for i in range(len(train_ds))]
print(train_labels[0])

val_ds = get_cifar_dataloader(split="val").dataset
val_labels = [val_ds[i][1] for i in range(len(val_ds))]
val_labels[0]

In [None]:
def save_margins(ckpt_dirs, split, labels):
    """Convert saved per-model logits into per-sample margins.

    For every ``{split}_logits_<id>.pt`` directly inside each directory of
    ``ckpt_dirs``, writes ``{split}_margins_<id>.pt`` to the same directory.
    The margin of sample j with label y_j is

        logits[j, y_j] - logsumexp_{k != y_j} logits[j, k]

    Only files whose ``<id>`` is an integer are processed. This skips the stacked
    ``{split}_logits_all.pt`` produced by the aggregation step, so re-running
    this cell is safe. The search is non-recursive because margins are written
    to ``ckpt_dir`` keyed only by ``<id>``, and the aggregation step reads only
    top-level files.

    Args:
        ckpt_dirs: iterable of ``Path`` directories holding the logits files.
        split: file-name prefix, e.g. ``"train"`` or ``"val"``.
        labels: ground-truth label of each sample, in dataset order.

    Raises:
        ValueError: if a logits file does not have one row per label.
    """
    rows = np.arange(len(labels))
    for ckpt_dir in tqdm(ckpt_dirs, desc=f"{split} margins"):
        for logits_path in sorted(ckpt_dir.glob(f"{split}_logits_*.pt")):
            logits_id = logits_path.stem.split("_")[-1]
            if not logits_id.isdigit():
                continue  # e.g. the stacked `{split}_logits_all.pt`
            logits = ch.load(logits_path)
            if logits.shape[0] != len(labels):
                raise ValueError(
                    f"{logits_path}: expected {len(labels)} rows, got {logits.shape[0]}"
                )
            correct = logits[rows, labels].clone()
            # Mask out the true class so logsumexp runs over the other classes only.
            logits[rows, labels] = -ch.inf
            margins = correct - logits.logsumexp(dim=1)
            ch.save(margins, ckpt_dir / f"{split}_margins_{logits_id}.pt")


save_margins(CKPT_DIRS, "val", val_labels)
save_margins(CKPT_DIRS, "train", train_labels)

## Stack per-model margins and logits into single arrays

In [None]:
# For every checkpoint directory, stack the per-model files `{name}_{i}.pt`
# (i = 0 .. n_models-1) along a new leading model axis and save the result as
# `{name}_all.pt`.
for ckpt_dir in tqdm(CKPT_DIRS):
    # NOTE(review): counts are hard-coded to match the 250 oracle / 200 full-model
    # jobs submitted above. Oracle vs. full is inferred from the directory path.
    n_models = 250 if "oracle" in str(ckpt_dir) else 200
    for array_name in ("train_logits", "val_logits", "train_margins", "val_margins"):
        per_model = [ch.load(ckpt_dir / f"{array_name}_{i}.pt") for i in range(n_models)]
        ch.save(ch.stack(per_model), ckpt_dir / f"{array_name}_all.pt")

In [None]:
# Shape check for the first oracle directory's stacked train logits; presumably
# (n_models, n_train_samples, n_classes) — TODO confirm. Loads the full tensor.
ch.load(CKPT_DIRS[0] / "train_logits_all.pt").shape

In [None]:
# Shape check for the full-model directory's stacked val margins; presumably
# (n_models, n_val_samples) — TODO confirm.
ch.load(CKPT_DIRS[-1] / "val_margins_all.pt").shape

In [None]:
# Display the full-model val margins (torch summarizes large tensors when printing).
ch.load(CKPT_DIRS[-1] / "val_margins_all.pt")