In [8]:
from unlearning import training
from unlearning import datasets
import os
import datetime
from pathlib import Path
import wget
from tqdm.auto import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision
import warnings
from unlearning.datasets.cifar10 import get_cifar_dataset, get_dataloader
from trak import TRAKer
from unlearning.eval.nn_evals import evaluate_model


# Suppress every Python warning for the rest of the session so library
# warnings do not flood the notebook output.
# NOTE(review): this also hides deprecation/numerical warnings that could be
# relevant while debugging — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')



def get_trak_features(model, data_train_loader, ckpts, proj_dim=4096):
    """Build a ``TRAKer`` for ``model`` and featurize the training data.

    Every checkpoint in ``ckpts`` is loaded into the TRAKer in turn and the
    whole ``data_train_loader`` is featurized for it. Features are finalized
    once, after all checkpoints have been processed.

    Args:
        model: Network handed to ``TRAKer``; the checkpoints are loaded for it
            through ``traker.load_checkpoint``.
        data_train_loader: Iterable of batches (each batch a sequence of
            tensors) that also exposes ``.dataset``;
            ``len(data_train_loader.dataset)`` is used as the train-set size.
            NOTE(review): TRAK is given only ``num_samples`` per batch, so the
            loader presumably has to yield samples in the same order for every
            checkpoint (no shuffling) — confirm against the TRAK docs.
        ckpts: Iterable of checkpoints in the form accepted by
            ``TRAKer.load_checkpoint``. The position in the iterable is used
            as the TRAK ``model_id``.
        proj_dim: Projection dimension passed to ``TRAKer``.

    Returns:
        The ``TRAKer`` instance with features computed and finalized.
    """
    traker = TRAKer(model=model,
                    task='image_classification',
                    proj_dim=proj_dim,
                    train_set_size=len(data_train_loader.dataset))

    # Compute TRAK features of the training data, one pass per checkpoint.
    for model_id, ckpt in enumerate(tqdm(ckpts)):
        traker.load_checkpoint(ckpt, model_id=model_id)
        for batch in tqdm(data_train_loader):
            # Batches are moved to the GPU; a CUDA device is required.
            batch = [x.cuda() for x in batch]
            traker.featurize(batch=batch, num_samples=batch[0].shape[0])

    # Finalize once, after every checkpoint has been featurized, instead of
    # re-invoking it on each loop iteration.
    traker.finalize_features()
    return traker


def compute_trak_scores(
    traker,
    trak_targets_loader,
    save_name="quickstart",
    ckpts=None,
):
    """Score the batches of ``trak_targets_loader`` with a featurized TRAKer.

    Args:
        traker: TRAKer-like object providing ``start_scoring_checkpoint``,
            ``score`` and ``finalize_scores``.
        trak_targets_loader: Iterable of batches (each batch a sequence of
            tensors) that also exposes ``.dataset``; its length is passed as
            ``num_targets``.
        save_name: Experiment name passed to the TRAKer as ``exp_name``.
        ckpts: Iterable of checkpoints to score with; the position in the
            iterable is used as ``model_id``. When omitted, the module-level
            ``ckpts`` variable is used, which is what this function
            previously read implicitly.

    Returns:
        Whatever ``traker.finalize_scores(exp_name=save_name)`` returns.

    Raises:
        NameError: If ``ckpts`` is not given and no module-level ``ckpts``
            exists.
    """
    if ckpts is None:
        # Backward compatibility: existing callers rely on the notebook-level
        # `ckpts` global instead of passing the checkpoints explicitly.
        try:
            ckpts = globals()["ckpts"]
        except KeyError:
            raise NameError(
                "name 'ckpts' is not defined; pass ckpts=... explicitly"
            ) from None

    # Loop-invariant: the number of targets does not depend on the checkpoint.
    num_targets = len(trak_targets_loader.dataset)

    for model_id, ckpt in enumerate(tqdm(ckpts)):
        traker.start_scoring_checkpoint(exp_name=save_name,
                                        checkpoint=ckpt,
                                        model_id=model_id,
                                        num_targets=num_targets)
        for batch in trak_targets_loader:
            # Batches are moved to the GPU; a CUDA device is required.
            batch = [x.cuda() for x in batch]
            traker.score(batch=batch, num_samples=batch[0].shape[0])
    print("finalize the scores")
    scores = traker.finalize_scores(exp_name=save_name)
    print("returning scores")
    return scores



def load_model_from_checkpoints(ckpt, evaluate=False):
    """Construct the network, load ``ckpt`` into it and return it in eval mode.

    Args:
        ckpt: State dict passed to ``model.load_state_dict``.
        evaluate: When True, additionally run ``evaluate_model`` on the
            non-augmented CIFAR validation dataloader.

    Returns:
        The model built by ``training.construct_rn9()``, converted to
        channels-last memory format, moved to the GPU, and switched to eval
        mode.
    """
    model = training.construct_rn9().to(
        memory_format=torch.channels_last).cuda()
    model.load_state_dict(ckpt)
    model = model.eval()

    if evaluate:
        # Build the validation loader only when it is needed, so plain model
        # loading does not pay for dataset construction.
        val_loader = datasets.get_cifar_dataloader(split='val', augment=False)
        evaluate_model(model, val_loader)
    return model


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# NOTE(review): intended to gate the expensive training runs — confirm that
# the training cells actually consult this flag.
do_training = False
# Timestamp captured once per kernel session; used to build unique save names.
date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Root directory for checkpoints and results. It can be overridden through the
# UNLEARNING_DATA_DIR environment variable; otherwise the cluster scratch path
# is used. If the chosen directory does not exist, fall back to the current
# working directory.
DATA_DIR = Path(
    os.environ.get("UNLEARNING_DATA_DIR",
                   "/n/holyscratch01/vadhan_lab/Lab/rrinberg/unlearning"))

if not DATA_DIR.exists():
    DATA_DIR = Path(os.getcwd())

SAVE_DIR = DATA_DIR / "trak_unlearning_results"
os.makedirs(SAVE_DIR, exist_ok=True)

# Seed both numpy and torch so that runs are repeatable.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

forget_set_size = 1000
batch_size = 128  # 512
fulldata_checkpoints_dir = DATA_DIR / "full_checkpoints"

# Training data: augmented and shuffled, as used for model training.
train_dataset = get_cifar_dataset(split='train', augment=True)
train_full_dataloader = get_dataloader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)

# Validation data: no augmentation and a fixed order. This loader is iterated
# once per checkpoint as attribution targets, so every pass must see the same
# samples in the same order.
val_dataset = get_cifar_dataset(split='val', augment=False)
val_dataloader = get_dataloader(dataset=val_dataset,
                                batch_size=batch_size,
                                shuffle=False)



Files already downloaded and verified
Files already downloaded and verified


In [2]:
from unlearning.training.train import train_model_wrapper

####
## Train the full-data model
####
epochs = 100

# NOTE: we just took the train_cifar10 function and called it from `train_model_wrapper`
# Training is expensive, so only run it when explicitly requested through
# `do_training`, or when there is no checkpoint to load yet.
has_full_ckpts = any(Path(fulldata_checkpoints_dir).rglob('*.pt'))
if do_training or not has_full_ckpts:
    print("training full model")
    train_model_wrapper(model_suffix="full_train",
                        epochs=epochs,
                        data_loader=train_full_dataloader,
                        checkpoints_dir=fulldata_checkpoints_dir)

####
## Acquire the list of checkpoints
####
# NOTE(review): the files are ordered lexicographically by path; confirm the
# checkpoint naming makes the last entry the final training checkpoint.
ckpt_files = sorted(Path(fulldata_checkpoints_dir).rglob('*.pt'))
ckpts = [torch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]
last_ckpt = ckpts[-1]

# load up the last model
model = load_model_from_checkpoints(last_ckpt)

####
## TRAK features on the full training set
####
# Rebinding `batch_size` here does not affect the loaders created earlier; it
# only applies to loaders built from this point on.
batch_size = 16
print("Create TRAK model")

# Featurize with a dedicated loader that is neither shuffled nor augmented:
# the training data is iterated once per checkpoint and every pass must see
# the same samples in the same order. The batch size matches the one the
# training loader was created with.
trak_batch_size = 128
trak_train_dataset = get_cifar_dataset(split='train', augment=False)
trak_train_dataloader = get_dataloader(dataset=trak_train_dataset,
                                       batch_size=trak_batch_size,
                                       shuffle=False)

val_traker = get_trak_features(model,
                               data_train_loader=trak_train_dataloader,
                               ckpts=ckpts)

traker_model = val_traker

####
## TRAK scores for the hold-out (validation) set
####
val_save_name = f"trak_scores__validation_set__{date_str}"
print(f"Compute TRAK on hold out set (validation data)")
val_trak_scores = compute_trak_scores(traker_model,
                                      trak_targets_loader=val_dataloader,
                                      save_name=val_save_name)


FileNotFoundError: [Errno 2] No such file or directory: '/Users/roy/data/unlearning/trak_results/id_1.json'

In [None]:

# load up the last model
model = load_model_from_checkpoints(last_ckpt)
print("retain files!")

####
## Train the retain model
####
# Split the training set into a forget part of `forget_set_size` samples and
# the remaining retain part. `batch_size` is whatever value is currently bound
# at notebook level.
forget_dataloader, retain_dataloader = datasets.get_forget_retain_loader(
    train_dataset,
    forget_set_size,
    shuffle=True,
    num_workers=8,
    batch_size=batch_size,
    seed=seed)

retain_checkpoints_dir = DATA_DIR / "retain_checkpoints"

# Training is expensive, so only run it when explicitly requested through
# `do_training`, or when there is no retain checkpoint to load yet.
has_retain_ckpts = any(Path(retain_checkpoints_dir).rglob('*.pt'))
if do_training or not has_retain_ckpts:
    print("training retain model")
    train_model_wrapper(model_suffix="retain_train",
                        epochs=epochs,
                        data_loader=retain_dataloader,
                        checkpoints_dir=retain_checkpoints_dir)
##
print("retain files!")

# Load the last retain checkpoint once (this was previously done twice in a
# row with identical code).
retain_ckpt_files = sorted(Path(retain_checkpoints_dir).rglob('*.pt'))
retain_last_ckpt = torch.load(retain_ckpt_files[-1], map_location='cpu')
retain_model = load_model_from_checkpoints(retain_last_ckpt)

####
## TRAK features on the forget set (full-data model and checkpoints)
####
# NOTE(review): `forget_dataloader` is created with shuffle=True above, while
# the featurization iterates it once per checkpoint — confirm that a shuffled
# order is acceptable here.
# NOTE(review): this creates a second TRAKer with a different train-set size;
# confirm it does not collide with results stored by an earlier TRAKer.
forget_traker = get_trak_features(model,
                                  data_train_loader=forget_dataloader,
                                  ckpts=ckpts)

traker_model = forget_traker