In [2]:
import sys
sys.path.insert(1, '../src/')
import torch
from torchvision import datasets, transforms
import helper
import utils
import models
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import torchvision
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import math
import random
import shutil
import copy

In [3]:
# cuDNN settings: let cuDNN benchmark its algorithms and keep the backend on.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

In [4]:
class args:
    """Plain namespace holding the run configuration read by the other cells."""

    # dataset and device
    data = 'cifar10'
    device = 'cuda:0'

    # batching and training length
    bs = 128
    epoch = 50

    # optimiser settings
    lr = 0.01
    moment = 0.9
    wd = 1e-4
    nesterov = True

    # poisoning settings (class indices as listed by train_dataset.classes)
    base_class = 1      # automobile
    target_class = 9    # truck
    poison_frac = 0.1

In [5]:
# Fetch the training and validation datasets for the dataset named in args.data.
train_dataset, val_dataset = utils.get_datasets(args.data)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Display the class names. Their list positions line up with the comments on
# args.base_class (1, automobile) and args.target_class (9, truck).
train_dataset.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [None]:
# Poison the training set using the settings in args. The return value is
# discarded, so this relies on poison_dataset modifying train_dataset in place.
# NOTE(review): if so, re-running this cell poisons already-poisoned data.
utils.poison_dataset(train_dataset, args)

# Positions of every validation sample labelled args.base_class.
# torch.as_tensor keeps the comparison element-wise even when `targets` is a
# plain Python list (torchvision's CIFAR10 default), where `list == int` is
# just False and has no .nonzero(). For an existing tensor it is a no-op.
val_targets = torch.as_tensor(val_dataset.targets)
idxs = (val_targets == args.base_class).nonzero().flatten().tolist()

# Wrap a deep copy so the clean val_dataset is left untouched, then poison the
# copy at those positions.
# NOTE(review): poison_all=True presumably poisons every index in idxs rather
# than a poison_frac sample -- confirm in utils.poison_dataset.
poisoned_val_set = utils.DatasetSplit(copy.deepcopy(val_dataset), idxs)
utils.poison_dataset(poisoned_val_set.dataset, args, idxs, poison_all=True)

In [None]:
# Training batches: args.bs samples each, reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.bs,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Clean validation data, served as a single batch containing every sample.
val_loader = DataLoader(
    val_dataset,
    batch_size=len(val_dataset),
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Poisoned validation subset, loaded in the main process (num_workers=0).
poisoned_val_loader = DataLoader(
    poisoned_val_set,
    batch_size=args.bs,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# Build the network for the chosen dataset and move it to the training device.
model = models.get_model(args.data)
model = model.to(args.device)

# Loss function, placed on the same device as the model.
criterion = nn.CrossEntropyLoss().to(args.device)

# SGD configured entirely from args.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.moment,
    weight_decay=args.wd,
    nesterov=args.nesterov,
)

# Plateau-based learning-rate scheduler with a patience of 10 steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=10,
    verbose=True,
)

In [None]:
# Uncomment to wipe every previous TensorBoard run before logging:
#shutil.rmtree('../logs/')

# Log under a directory named after the dataset actually in use. The path was
# previously hard-coded to '../logs/fmnist', which filed runs on any other
# args.data under the wrong name.
writer = SummaryWriter(f'../logs/{args.data}')

# Pair of CUDA timing events; the start event is recorded here and the end
# event is recorded after training to measure the elapsed time.
start_time = torch.cuda.Event(enable_timing=True)
end_time = torch.cuda.Event(enable_timing=True)
start_time.record()

In [None]:
# training loop
# NOTE(review): `infer` is not imported anywhere in the visible notebook, so the
# two infer.get_loss_n_accuracy calls below raise NameError on a fresh kernel
# unless it is defined elsewhere. Import the module that owns the helper first.
for rnd in tqdm(range(1, args.epoch+1)):
    model.train()
    # Per-epoch running totals. They were previously initialised to 0.0 and
    # never updated, so the logged training loss/accuracy were always zero.
    # The totals stay on the device and are read back once per epoch, which
    # avoids a host sync on every mini-batch.
    loss_sum = torch.zeros((), device=args.device)
    correct_sum = torch.zeros((), dtype=torch.long, device=args.device)
    n_seen = 0
    for inputs, labels in train_loader:
        # pass inputs to device, clear gradients
        inputs = inputs.to(args.device, non_blocking=True)
        labels = labels.to(args.device, non_blocking=True)
        optimizer.zero_grad()

        # forward-backward pass and update
        outputs = model(inputs)
        minibatch_loss = criterion(outputs, labels)
        minibatch_loss.backward()
        optimizer.step()

        # criterion returns the batch mean (default reduction), so weight it by
        # the batch size to keep the epoch average exact for a short last batch.
        # Assumes outputs are (batch, classes) logits and labels are class
        # indices -- TODO confirm against models.get_model.
        n_batch = labels.size(0)
        loss_sum += minibatch_loss.detach() * n_batch
        correct_sum += (outputs.detach().argmax(dim=1) == labels).sum()
        n_seen += n_batch

    # max(..., 1) guards against an empty loader.
    train_loss = loss_sum.item() / max(n_seen, 1)
    # NOTE(review): reported as a fraction in [0, 1]; rescale if
    # get_loss_n_accuracy reports percentages.
    train_acc = correct_sum.item() / max(n_seen, 1)

    # inference after round
    val_loss, (val_acc, val_per_class) = infer.get_loss_n_accuracy(model, criterion, val_loader, args)
    poison_loss, (poison_acc, _) = infer.get_loss_n_accuracy(model, criterion, poisoned_val_loader, args)
    # the plateau scheduler monitors the clean validation loss
    scheduler.step(val_loss)
    # log/print data
    writer.add_scalar('Validation/Loss', val_loss, rnd)
    writer.add_scalar('Validation/Accuracy', val_acc, rnd)
    writer.add_scalar('Training/Loss', train_loss, rnd)
    writer.add_scalar('Training/Accuracy', train_acc, rnd)
    print(f'|Train/Valid Loss: {train_loss:.3f} / {val_loss:.3f}|')
    print(f'|Train/Valid Acc: {train_acc:.3f} / {val_acc:.3f}|')
    print(f'|Poison Loss/Poison Acc: {poison_loss:.3f} / {poison_acc:.3f} |')

In [None]:
# Record the end event, then wait for the GPU so both timing events have
# completed before the interval between them is read.
end_time.record()
torch.cuda.synchronize()

# elapsed_time reports milliseconds; convert to seconds and minutes.
elapsed_ms = start_time.elapsed_time(end_time)
time_elapsed_secs = elapsed_ms / 1000
time_elapsed_mins = time_elapsed_secs / 60
print(f'Training took {time_elapsed_secs:.2f} seconds / {time_elapsed_mins:.2f} minutes')

In [None]:
# Save the trained model under the name of the dataset it was trained on. The
# file name was previously hard-coded to 'fmnist_bd_model.pt', which mislabels
# the checkpoint for any other args.data and can overwrite a real FMNIST one.
# Note: torch.save(model, ...) pickles the whole module object, so loading it
# later requires the same model class definitions to be importable.
torch.save(model, f'{args.data}_bd_model.pt')