In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

import os
import argparse

import numpy as np
import matplotlib.pyplot as plt

from model import *
from dataset import *
from utils import *
from trainer import *
from tqdm import tqdm
from matchloss import *

In [37]:
# Run configuration shared by the cells below.
batch_size = 256      # samples per classifier-training batch
device = 'cuda'       # device string passed to .to(...) / tensor constructors
learning_rate = 0.1   # initial learning rate handed to the classifier's SGD optimizer
start_epoch = 0       # start from epoch 0 or last checkpoint epoch

In [38]:
# Load the data in two forms. Both helpers arrive through the star imports above, so
# their exact behaviour is not visible in this notebook.
# NOTE(review): presumably CIFAR-10 train/test splits, as the names and the
# "Files already downloaded and verified" output suggest — confirm in the helper module.
# dst_train is later indexed sample-by-sample as (image, label) pairs;
# clean_test_loader is later passed to test_() for clean-accuracy evaluation.
dst_train, dst_test= load_cifar10_data()
clean_train_loader, clean_test_loader= load_cifar10()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [39]:
num_classes = 10

# Materialise the whole training set as two device tensors and build a
# per-class index so that samples of one class can be drawn quickly.
#
# The dataset is walked exactly once: every `dst_train[i]` access runs the
# dataset's transform pipeline, so reading images and labels in separate
# passes would double that cost for no benefit.
indices_class = [[] for _ in range(num_classes)]   # indices_class[c] -> training indices of class c
_image_list = []
_label_list = []
for i in range(len(dst_train)):
    img, lab = dst_train[i]
    _image_list.append(img)
    _label_list.append(lab)
    indices_class[lab].append(i)

# Stack along a new leading dimension: one row per training sample.
images_all = torch.stack(_image_list, dim=0).to(device)
labels_all = torch.tensor(_label_list, dtype=torch.long, device=device)

# The Python lists are no longer needed once the tensors exist.
del _image_list, _label_list



In [61]:
# get images for certain class
def get_images(n, c = 0):
    """Draw up to ``n`` distinct random training images of class ``c``.

    Args:
        n: maximum number of images to draw; fewer are returned when the
            class has fewer than ``n`` entries in ``indices_class``.
        c: class id used to look up ``indices_class``.

    Returns:
        A pair ``(indices, images)``: the chosen training-set indices (a
        NumPy array) and the matching rows of ``images_all``.
    """
    shuffled = np.random.permutation(indices_class[c])
    chosen = shuffled[:n]
    picked = images_all[chosen]
    return chosen, picked

# get noises at selected indexes
def get_noises(idxs, noise):
    """Gather the rows of ``noise`` at ``idxs`` into one tensor on ``device``.

    Args:
        idxs: iterable of indices into ``noise``; the output keeps this order.
        noise: indexable collection of equally shaped tensors.

    Returns:
        The selected entries stacked along a new leading dimension and moved
        to the notebook-level ``device``.
    """
    selected = []
    for k in idxs:
        selected.append(noise[k])
    stacked = torch.stack(selected)
    return stacked.to(device)

In [41]:
# Classifier on the GPU, wrapped in DataParallel and switched to training mode.
clsmodel = torch.nn.DataParallel(ResNet18().cuda())
clsmodel.train()

# Cross-entropy objective, SGD with momentum and weight decay, and a cosine
# learning-rate schedule with a period parameter of 150 scheduler steps.
criterion = nn.CrossEntropyLoss()
clsoptimizer = optim.SGD(
    clsmodel.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(clsoptimizer, T_max=150)

In [42]:
# Using a checkpoint? Uncomment the lines below.

# checkpoint = torch.load('./checkpoint/ckpt.pth')
# clsmodel.load_state_dict(checkpoint['net'])

In [43]:
# target class for seed image
target = 0
ipc = 1  # number of seed images

# Initialize the seed image(s) as all-zero leaf tensors that receive gradients.
# The tensors are allocated directly on the device: wrapping an existing tensor
# in torch.tensor(...) copy-constructs it and makes PyTorch emit a UserWarning.
image_syn = torch.zeros(ipc, 3, 32, 32, dtype=torch.float, device=device, requires_grad=True)
# One label per seed image, all equal to the target class.
label_syn = torch.full((ipc,), target, dtype=torch.long, device=device)

# Optimizer bound to the seed image(s).
# NOTE(review): the optimisation cell updates image_syn with manual signed-gradient
# steps and never calls optimizer_img.step(); kept because other code may reference it.
optimizer_img = torch.optim.SGD([image_syn, ], lr=0.5, momentum=0.5) # optimizer_img for synthetic data
optimizer_img.zero_grad()

# Flat list of classifier parameters, used as the `inputs` of torch.autograd.grad.
net_parameters = list(clsmodel.parameters())

  This is separate from the ipykernel package so we can avoid doing imports until


In [44]:
# subset for perturbation
# Randomly pick up to 512 training images of the target class; only their noise
# entries are optimised in the matching phase.
perturb_idx, perturb_images = get_images(n=512, c=target)

In [45]:
# Per-sample additive noise, kept on the CPU and initialised to zero. It is sized from
# images_all (one entry per training image, same per-image shape) rather than a
# hard-coded [50000, 3, 32, 32], so it always lines up with the data actually loaded.
noise = torch.zeros(images_all.shape)

In [63]:
# Optimization algorithm. See overleaf for reference

# ---------------------------------------------------------------------------
# Hyper-parameters of the alternating optimisation
# ---------------------------------------------------------------------------
condition = True        # loop guard; never cleared below, so stop the cell with a kernel interrupt
step_size = 0.001       # signed-gradient step applied to the perturbed images
step_size_sync = 0.01   # signed-gradient step applied to the seed image(s); halved every 3rd epoch
epsilon = 8/255         # L-infinity bound on the stored noise
epoch = 0
I = 200                 # matching iterations per epoch
J= 2                    # not read anywhere in this cell
N = len(images_all)

# Ceil division: covers a trailing partial batch without producing an empty
# batch when N is an exact multiple of batch_size.
num_batches = (N + batch_size - 1) // batch_size

# Clean images / labels of the perturbation subset.
# NOTE(review): the previous version read `images_real` / `labels_real` here, names
# that no cell in this notebook defines (NameError on a fresh kernel). The loop uses
# them as the clean reference for `perturb_images + noise` and as the targets of the
# loss on those images, so they are taken to be the selected subset and its labels.
perturb_index = torch.as_tensor(perturb_idx, dtype=torch.long)
clean_images = perturb_images.cuda()
clean_labels = labels_all[perturb_index.to(labels_all.device)].cuda()

while condition:

    # Every third epoch (but not the first): checkpoint the noise and the seed
    # image(s), and halve the seed step size down to a floor of 0.0005.
    if epoch != 0 and epoch % 3 == 0:
        np.save( 'noise5',noise.numpy())
        np.save( 'synimg5', image_syn.detach().cpu().numpy())
        step_size_sync = max(step_size_sync / 2, 0.0005)

    # ------------------------------------------------------------------
    # Phase 1: one epoch of classifier training on (image + noise)
    # ------------------------------------------------------------------
    clsmodel.train()
    correct = 0
    total = 0
    for b in tqdm(range(num_batches)):
        leftl = batch_size * b
        rightl = min(leftl + batch_size, N)
        batch_images = images_all[leftl:rightl].cuda()
        batch_labels = labels_all[leftl:rightl].cuda()
        # Noise rows are stored in training-set order, so a slice selects the
        # same entries as gathering them one index at a time.
        batch_noise = noise[leftl:rightl].cuda()

        # Neither operand requires grad, so this is a plain input tensor.
        perturb_img = torch.clamp(batch_images + batch_noise, 0, 1)

        # Also clears the parameter gradients left behind by the matching
        # phase's backward() call in the previous epoch.
        clsoptimizer.zero_grad()
        output = clsmodel(perturb_img)
        clsloss = criterion(output, batch_labels)
        clsloss.backward()
        clsoptimizer.step()

        _, predicted = output.max(1)
        total += batch_labels.size(0)
        correct += predicted.eq(batch_labels).sum().item()
    print('Training Acc: %.3f%% (%d/%d)'% (100.*correct/total, correct, total))
    scheduler.step()

    # ------------------------------------------------------------------
    # Phase 2: jointly update the subset noise and the seed image(s)
    # ------------------------------------------------------------------
    clsmodel.eval()

    for c in range(I):

        # Parameter gradients induced by the seed image(s). Recomputed every
        # iteration because image_syn itself is updated at the end of the body.
        pred_ = clsmodel(image_syn)
        loss_ = criterion(pred_, label_syn)
        gw_syn = torch.autograd.grad(loss_, net_parameters, create_graph=True)

        # Current perturbed subset as a fresh leaf, so backward() fills its .grad.
        perturb_img = torch.clamp(clean_images + get_noises(perturb_idx, noise), 0, 1)
        perturb_img = perturb_img.detach().requires_grad_(True)

        # Parameter gradients induced by the perturbed subset. create_graph=True
        # keeps them differentiable with respect to perturb_img.
        pred = clsmodel(perturb_img)
        loss = criterion(pred, clean_labels)
        gw_real = torch.autograd.grad(loss, net_parameters, create_graph = True)

        # Objective: gradient-matching term plus the seed classification loss.
        matchloss = match_loss(gw_syn, gw_real, 'ours') + 1 * loss_

        progress_bar(c, I, "Total Loss: {}  Classification loss: {}".format(matchloss, loss_))

        matchloss.backward()

        with torch.no_grad():
            # Signed-gradient descent step on the perturbed images, then projection
            # onto the epsilon-ball around the clean images and onto [0, 1].
            stepped = perturb_img - step_size * perturb_img.grad.sign()
            eta = torch.clamp(stepped - clean_images, -epsilon, epsilon)
            projected = torch.clamp(clean_images + eta, 0, 1)
            eta = torch.clamp(projected - clean_images, -epsilon, epsilon)

            # Write the updated noise back into the CPU buffer for the subset.
            noise[perturb_index] = eta.cpu()

            # Signed-gradient descent step on the seed image(s), clamped to [0, 1].
            new_syn = torch.clamp(image_syn - step_size_sync * image_syn.grad.sign(), 0, 1)

        # Rebind as a new leaf so the next iteration starts with an empty .grad.
        image_syn = new_syn.detach().requires_grad_(True)

    # ------------------------------------------------------------------
    # Phase 3: evaluation
    # ------------------------------------------------------------------
    clsmodel.eval()
    total = image_syn.shape[0]
    with torch.no_grad():
        syn_pred = clsmodel(image_syn)
        correct = int((syn_pred.argmax(1) == target).sum().item())
    print(f"{correct}/{total} sync images are in target class")

    # Clean test evaluation on every second epoch.
    if epoch % 2 == 0:
        test_(clsmodel, clean_test_loader, criterion)

    epoch += 1
    
    
 



KeyboardInterrupt: 

In [None]:
# Save the optimized noise and seed image

# np.save takes the destination first and the array second; with the arguments
# swapped the array is treated as the file and the call fails. Use the .npy
# extension so the files can be loaded back directly with np.load.
np.save('noise.npy', noise.numpy())
np.save('synimg.npy', image_syn.detach().cpu().numpy())

## testing
To test the optimized seed image, load the saved `.npy` files (noise and seed image) and retrain a model from scratch on the perturbed dataset.
Use either the same architecture or a different one for validation.