In [1]:
import pyrootutils
import os
# Locate the project root starting from the current working directory and let
# pyrootutils configure the session around it.
# setup_root() combines find_root() and set_root() into one method.
search_start = os.getcwd()
root = pyrootutils.setup_root(
    search_from=search_start,
    indicator=".gitignore",  # search for this file to find the root
    project_root_env_var=True,  # per pyrootutils docs: export PROJECT_ROOT
    dotenv=True,  # per pyrootutils docs: load a .env file from the root if present
    pythonpath=True,  # per pyrootutils docs: make the root importable
    cwd=True,  # per pyrootutils docs: change the working directory to the root
)

In [2]:
from pathlib import Path
import argparse

import numpy as np

import datasets
datasets.disable_caching()
import torch

from attack_utils.sig import SigTriggerAttack
from model_zoo import get_model
from utils.data import DataModule, build_transform
from utils.train import get_optimizer, SecondSplitTrainer
from utils import set_seed, add_comm_arguments
from utils.ssft import get_first_epoch_where_we_learn_forever, get_first_epoch_where_we_forget_forever

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def create_parser():
    """Build the command-line parser used by this notebook.

    The parser contains the shared project arguments registered by
    ``add_comm_arguments`` plus the two options that configure the SIG
    (sinusoidal signal) backdoor trigger.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="Argument parser for the script")
    add_comm_arguments(parser)
    # Poisoning parameters
    parser.add_argument('--trigger_delta', type=float, default=20,
                        help='Amplitude (delta) of the sinusoidal SIG trigger')
    # The original help text was a copy of the --trigger_delta one.
    parser.add_argument('--trigger_f', type=float, default=6,
                        help='Frequency (f) of the sinusoidal SIG trigger')

    return parser

# Fixed experiment configuration, written as an explicit argv list so each
# option sits next to its value.
cli_args = [
    "--trainset_portion", "1.0",
    "--epochs", "50",
    "--poisoning_rate", "0.01",
    "--model_name", "resnet18",
    "--lr", "0.1",
    "--optimizer_name", "sgd",
    "--seed", "42",
    "--ood_dataset", "nnheui/cifake10",
    "--ood_percent", "1.0",
]
args = create_parser().parse_args(cli_args)

In [4]:
# Seed the run from the parsed --seed value before any dataset or model is
# created. Which RNGs get seeded is decided inside utils.set_seed (not visible
# here) — NOTE(review): confirm it covers numpy, since the plotting cell draws
# jitter from np.random.
set_seed(args.seed)

In [5]:
# Build the project data module from the parsed arguments.
dm = DataModule(args)
# SIG trigger, configured with the target label, the dataset's image shape and
# the two trigger options (--trigger_delta, --trigger_f).
# NOTE(review): args.trigger_label and args.shuffle_seed are not defined in
# create_parser; presumably they are registered by add_comm_arguments — confirm.
# NOTE(review): `triggle_handler` looks like a typo for `trigger_handler`; it is
# left unchanged because the name is a notebook-global.
triggle_handler = SigTriggerAttack(
    args.trigger_label,
    dm.hparams.image_shape, 
    args.trigger_delta, 
    args.trigger_f)
# Creates the poisoned splits that are read later as
# dm.base_ds['train_poisoned'] and dm.base_ds['test_poisoned'].
dm.setup_poisoned_sets(triggle_handler)
# Loads the out-of-distribution dataset that is read later as dm.ood_ds['train'].
dm.setup_ood(args.ood_dataset, shuffle_seed=args.shuffle_seed)

# NOTE(review): transforms are attached per-dataset with set_transform in a
# later cell, so the call below is left disabled.
# dm.apply_transform()

Map: 100%|██████████| 50000/50000 [00:00<00:00, 69283.68 examples/s]


Poison_size 500


Map: 100%|██████████| 50000/50000 [00:01<00:00, 30418.90 examples/s]
Map: 100%|██████████| 10000/10000 [00:02<00:00, 3831.60 examples/s]


Poisoned Trainset
[('airplane', 5459), ('automobile', 4945), ('bird', 4949), ('cat', 4941), ('deer', 4939), ('dog', 4951), ('frog', 4941), ('horse', 4963), ('ship', 4957), ('truck', 4955)]


Map: 100%|██████████| 50000/50000 [00:00<00:00, 64526.74 examples/s]

OOD set size: 50000





In [6]:
# build_transform returns three objects: the training transform (requested with
# augmentation), the evaluation transform, and `detransform`.
# NOTE(review): `detransform` is not used in the visible cells; presumably it
# inverts the normalisation for visualisation — confirm or drop.
apply_train_transform, apply_transform, detransform = build_transform("CIFAR10", has_augmentation=True)
# Handles to the splits used below: poisoned train/test, clean test, and the
# OOD training split that serves as the second-split (fine-tuning) data.
train_poisoned_ds = dm.base_ds['train_poisoned']
val_poisoned_ds = dm.base_ds['test_poisoned']
val_clean_ds = dm.base_ds['test']
ft_ds = dm.ood_ds['train']

def reid(e, i):
    """Tag example ``e`` with its index ``i`` under the ``"reid"`` key.

    The example is modified in place and the same object is returned, which is
    the form expected by ``Dataset.map(..., with_indices=True)``.
    """
    e["reid"] = i
    return e

# Add a "reid" column holding each row's position in its own split, so that
# collate_fn can expose per-example ids alongside inputs and labels.
train_poisoned_ds = train_poisoned_ds.map(reid, with_indices=True)
val_poisoned_ds = val_poisoned_ds.map(reid, with_indices=True)
val_clean_ds = val_clean_ds.map(reid, with_indices=True)
ft_ds = ft_ds.map(reid, with_indices=True)

# Attach transforms: the two training sets (poisoned train and OOD fine-tuning)
# get the augmenting train transform, the two test sets get the eval transform.
train_poisoned_ds.set_transform(apply_train_transform)
ft_ds.set_transform(apply_train_transform)
val_poisoned_ds.set_transform(apply_transform)
val_clean_ds.set_transform(apply_transform)

def collate_fn(examples):
    """Collate a list of example dicts into one batch dict.

    Each example must provide ``"inputs"`` (a tensor), ``"label"`` and
    ``"reid"``. The returned dict has the keys ``"inputs"`` (the input tensors
    stacked along a new leading dimension), ``"labels"`` and ``"ids"``.
    """
    def column(key):
        # Gather one field across the batch, preserving example order.
        return [row[key] for row in examples]

    return {
        "inputs": torch.stack(column("inputs")),
        "labels": torch.tensor(column("label")),
        "ids": torch.tensor(column("reid")),
    }

def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the settings shared by this notebook.

    Batch size comes from ``args.batch_size``; batches are assembled by
    ``collate_fn`` using two worker processes.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=2,
        collate_fn=collate_fn,
    )


# Loaders used for training are shuffled; evaluation loaders keep dataset order.
train_poisoned_dl = _make_loader(train_poisoned_ds, shuffle=True)
val_poisoned_dl = _make_loader(val_poisoned_ds, shuffle=False)
val_clean_dl = _make_loader(val_clean_ds, shuffle=False)
ft_dl = _make_loader(ft_ds, shuffle=True)


Map: 100%|██████████| 50000/50000 [00:01<00:00, 47823.53 examples/s]
Map: 100%|██████████| 10000/10000 [00:00<00:00, 60582.44 examples/s]
Map: 100%|██████████| 10000/10000 [00:00<00:00, 70724.41 examples/s]
Map: 100%|██████████| 50000/50000 [00:01<00:00, 32765.26 examples/s]


In [7]:
# Expose the data module's hyper-parameters on args for downstream project code.
args.data_hparams = vars(dm.hparams)

### Stage 1: train a fresh model on the poisoned training set
model = get_model(args, num_classes=len(dm.hparams.target_names))
optimizer = get_optimizer(args, model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
criterion = torch.nn.CrossEntropyLoss()

trainer = SecondSplitTrainer()
# Evaluation after every epoch runs on the poisoned test loader.
ret_pre = trainer.train(
    args,
    train_poisoned_dl, val_poisoned_dl,
    model, criterion, optimizer, scheduler,
    eval_every=1
)

### Stage 2: keep training the same model on the OOD (second-split) data
# The optimizer and cosine schedule are rebuilt so stage 2 starts from the
# initial learning rate again; the model weights carry over from stage 1.
optimizer = get_optimizer(args, model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
criterion = torch.nn.CrossEntropyLoss()
# Stage 2 trains on ft_dl (transformed, id-tagged, custom collate). The unused
# `ft_loader` built directly from dm.ood_ds['train'] was dead code and is removed.
trainer = SecondSplitTrainer()
ret_ft = trainer.train(
    args,
    ft_dl, val_poisoned_dl,
    model, criterion, optimizer, scheduler,
    eval_every=1
)

Start training for 50 epochs


                                                           

# EPOCH 0   loss: 2.2244 Test Acc: 0.0710



                                                           

# EPOCH 1   loss: 1.5775 Test Acc: 0.0765



                                                           

# EPOCH 2   loss: 1.2951 Test Acc: 0.1222



                                                           

# EPOCH 3   loss: 1.0983 Test Acc: 0.3214



                                                           

# EPOCH 4   loss: 0.9762 Test Acc: 0.6860



                                                           

# EPOCH 5   loss: 0.8529 Test Acc: 0.9210



                                                           

# EPOCH 6   loss: 0.7790 Test Acc: 0.9844



                                                           

# EPOCH 7   loss: 0.7222 Test Acc: 0.9223



Training:  49%|████▊     | 190/391 [00:08<00:08, 22.40it/s]

In [None]:
# "acc_mask" entries returned by the two training stages.
# NOTE(review): presumably per-example, per-epoch correctness masks — confirm
# against SecondSplitTrainer.
masks_pre = ret_pre["acc_mask"]
masks_ft = ret_ft["acc_mask"]

learn_epochs = get_first_epoch_where_we_learn_forever(masks_pre)
forget_epochs = get_first_epoch_where_we_forget_forever(masks_ft)

# Work on float copies. The original aliased `forget` to `forget_epochs`, so the
# jitter below overwrote `forget_epochs` in place, and with an integer array the
# fractional jitter would have been truncated on assignment.
learn = np.array(learn_epochs, dtype=float)
forget = np.array(forget_epochs, dtype=float)

# Jitter the learning epochs so points sharing an epoch do not overlap.
learn += np.random.uniform(-0.5, 0.5, size=learn.shape)

# Points at the maximum forgetting epoch are left un-jittered so they stay in a
# single row above the dashed line (presumably the "never forgotten" examples —
# TODO confirm against utils.ssft).
fg = forget.max()
below_max = forget != fg
forget[below_max] += np.random.uniform(-0.4, 0.4, size=forget[below_max].shape)

plt.xlabel("Learning Time")
plt.ylabel("Second Split Forgetting Time")

# Dashed separator just below the row of maximum-epoch points.
plt.axhline(y=fg-0.5, color='r', linestyle='--')
plt.scatter(learn, forget, s= 1.5, c = "b")