In [1]:
from __future__ import print_function

import csv
import os

import numpy as np

import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import random

import models

from tqdm import tqdm_notebook as tqdm

In [2]:
# Experiment configuration; the last cell edits a few entries before calling main().
# model: name of a constructor in models.__dict__ (e.g. MobileNet, ResNet18)
# mode:  'mixup', 'instahide', 'perturb' or 'normal'
args = dict(
    model='ResNet18',
    data='cifar10',        # 'cifar10' or 'mnist' (see prepare_data)
    nclass=10,
    lr=0.01,
    batch_size=128,
    epoch=100,             # main() trains epochs range(start_epoch, epoch)
    augment=True,          # random crop + horizontal flip on the training set
    decay=1e-4,            # SGD weight decay
    name='cross',          # tag used in checkpoint and log filenames
    seed=0,                # 0 means "do not seed"
    resume=False,
    klam=4,                # number of mixing coefficients per sample
    mode='perturb',
    upper=0.65,            # upper bound on any single mixing coefficient
    trial=3,               # tag used in the log filename
)

In [3]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
criterion = nn.CrossEntropyLoss()  # evaluation loss used by test()
best_acc = 0  # best test accuracy seen so far (updated by test(), restored on resume)

## Helper functions (one-hot labels, mixup sampling and loss)

In [4]:
def chunks(lst, n):
    '''Yield consecutive slices of `lst` of length `n`; the last slice may be shorter.'''
    starts = range(0, len(lst), n)
    yield from (lst[s:s + n] for s in starts)

def label_to_onehot(target, num_classes=args['nclass']):
    '''Return one-hot embeddings of scalar labels.

    `target` is expected to be a 1-D int64 tensor of class indices (scatter_
    requires an int64 index). The result has shape (len(target), num_classes),
    uses torch's default floating dtype, and lives on target's device.
    '''
    index = target.unsqueeze(1)
    onehot = torch.zeros(index.size(0), num_classes, device=index.device)
    return onehot.scatter_(1, index, 1)


def cross_entropy_for_onehot(pred, target):
    '''Soft-label cross entropy.

    Computes, per row, -sum_c target[c] * log_softmax(pred)[c] and returns the
    mean over rows. `target` may hold any (e.g. mixed) label distribution.
    '''
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = (-target * log_probs).sum(dim=1)
    return per_sample.mean()


def mixup_criterion(pred, ys, lam_batch, num_class=args['nclass']):
    '''Return the mixup loss: soft-label cross entropy against the mixed label.

    Args:
        pred: model logits, one row per sample.
        ys: sequence of label tensors; ys[i] holds, for every sample, the class
            index of mixing component i.
        lam_batch: 2-D tensor of mixing coefficients; lam_batch[:, i] weights ys[i].
        num_class: number of classes for the one-hot encoding.

    The number of components is taken from lam_batch.size(1) rather than the
    global args['klam'], so the mixed label always uses exactly the
    coefficients that were passed in. Raises IndexError if `ys` has fewer
    entries than lam_batch has columns.
    '''
    mixy = None
    for i in range(lam_batch.size(1)):
        onehot = label_to_onehot(ys[i], num_classes=num_class)
        term = vec_mul_ten(lam_batch[:, i], onehot)
        mixy = term if mixy is None else mixy + term
    return cross_entropy_for_onehot(pred, mixy)


def vec_mul_ten(vec, tensor):
    '''
    Scale each slice tensor[k] by the scalar vec[k].

    `vec` is reshaped to (-1, 1, ..., 1), with one trailing singleton dimension
    for every dimension of `tensor` after the first, and then multiplied with
    `tensor` by broadcasting. It is used to weight per-sample inputs and one-hot
    labels by their per-sample mixing coefficients. reshape (not view) is used
    because callers pass non-contiguous columns such as lams[:, i].
    '''
    broadcast_shape = [-1] + [1] * (tensor.dim() - 1)
    return vec.reshape(broadcast_shape) * tensor


def mixup_data(x, y, use_cuda=True, perturbed_examples=None, perturb_labl=None):
    '''Return mixed inputs, the list of per-component targets, and the lambdas.

    Each sample i becomes sum_j lams[i, j] * component_j[i]. The weights lams[i]
    are |N(0, 1)| draws normalised to sum to 1. When klam > 1 they are resampled
    until no single weight exceeds args['upper'].

    Components:
      * component 0 is `x` itself (labels `y`);
      * in 'perturb' mode, component 1 is a random selection from
        `perturbed_examples` (labels `perturb_labl`), and components
        2..klam-1 are random permutations of `x`;
      * otherwise, components 1..klam-1 are random permutations of `x`.
    In 'instahide' mode, every element of the mixed batch is then multiplied by
    a random sign.

    Args:
        x: input batch; the first dimension indexes samples.
        y: labels aligned with `x`.
        use_cuda: unused; kept for call compatibility.
        perturbed_examples, perturb_labl: required in 'perturb' mode. The pool
            must contain at least x.size(0) examples.

    Returns:
        (mixed_x, ys, lams), with lams[:, j] the weight of component j, whose
        labels are ys[j]; len(ys) == lams.size(1) == args['klam'].
    '''
    n = x.size(0)
    klam = args['klam']
    lams = np.random.normal(0, 1, size=(n, klam))
    for i in range(n):
        lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i]))
        if klam > 1:
            # Rejection-sample so that no single lambda dominates the mixture.
            while lams[i].max() > args['upper']:
                lams[i] = np.random.normal(0, 1, size=(1, klam))
                lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i]))

    lams = torch.from_numpy(lams).float().to(device)

    mixed_x = vec_mul_ten(lams[:, 0], x)
    ys = [y]

    first_private = 1
    if args['mode'] == 'perturb':
        # Component 1 is drawn from the perturbed pool. The private partners below
        # then start at column 2. The old code also reused lams[:, 1] for the first
        # private partner, so the weights summed to 1 + lam1 and ys had klam+1
        # entries that no longer lined up with the lams columns.
        index = torch.randperm(perturbed_examples.size(0)).to(device)[:n]
        mixed_x += vec_mul_ten(lams[:, 1], perturbed_examples[index, :])
        ys.append(perturb_labl[index])
        first_private = 2

    for i in range(first_private, klam):
        index = torch.randperm(n).to(device)
        mixed_x += vec_mul_ten(lams[:, i], x[index, :])
        ys.append(y[index])         # Only keep the labels for private samples

    if args['mode'] == 'instahide':
        sign = torch.randint(2, size=list(x.shape), device=device) * 2.0 - 1
        mixed_x *= sign.float().to(device)
    return mixed_x, ys, lams


def generate_sample(trainloader):
    '''Mix the whole training set once and return it with the unmixed originals.

    `trainloader` must yield the entire training set as a single batch
    (asserted). In 'perturb' mode, the perturbed pool and its labels are loaded
    from results/perturbed.npy and results/perturbed_y.npy.

    Returns:
        (mix_inputs, mix_targets, inputs, targets, lams), where the first and
        last entries and mix_targets come from mixup_data, and inputs/targets
        are the original batch moved to `device`.
    '''
    assert len(trainloader) == 1        # Load all training data once

    perturbed_examples, perturbed_labels = None, None
    if args['mode'] == 'perturb':
        # Always convert to tensors on `device`. The old code only converted under
        # CUDA, so CPU runs handed numpy arrays to mixup_data, which calls .size().
        perturbed_examples = torch.Tensor(np.load('results/perturbed.npy')).to(device)
        perturbed_labels = torch.Tensor(np.load('results/perturbed_y.npy')).to(device)

    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        mix_inputs, mix_targets, lams = mixup_data(
            inputs, targets.float(), use_cuda, perturbed_examples, perturbed_labels)
    return (mix_inputs, mix_targets, inputs, targets, lams)

## Training, evaluation, checkpointing and data preparation

In [5]:
def train(net, optimizer, inputs_all, mix_targets_all, lams, epoch):
    '''Run one training epoch over pre-generated (mixed) data.

    Args:
        net: model to train (switched to train mode here).
        optimizer: optimizer over net's parameters.
        inputs_all: input tensor; the first dimension indexes samples.
        mix_targets_all: in mixing modes, a list of label tensors, one per mixing
            component; in 'normal' mode, a single label tensor aligned with
            inputs_all.
        lams: per-sample mixing weights; lams[i, j] weights mix_targets_all[j][i].
        epoch: epoch number (printed only).

    Returns:
        (mean loss per mini-batch, accuracy). Accuracy is not tracked for mixed
        labels, so `correct` stays 0 and the second value is 0.
    '''
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss, correct, total = 0, 0, 0

    # Shuffle sample indices with Python's `random`, then split them into mini-batches.
    seq = random.sample(range(len(inputs_all)), len(inputs_all))
    batches = list(chunks(seq, args['batch_size']))

    for b in tqdm(batches):
        inputs = inputs_all[b]
        outputs = net(inputs)
        if args['mode'] == 'normal':
            # Plain labels. The old code left lam_batch undefined here and crashed.
            targets = mix_targets_all[b].long().to(device)
            loss = criterion(outputs, targets)
        else:
            lam_batch = lams[b]
            targets = [mix_targets_all[ik][b].long().to(device)
                       for ik in range(args['klam'])]
            loss = mixup_criterion(outputs, targets, lam_batch)

        train_loss += loss.item()
        total += len(b)  # the last batch may be smaller than batch_size
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Average over the number of batches. The old code divided by the last
    # batch index: off by one, and a ZeroDivisionError with a single batch.
    n_batches = max(len(batches), 1)
    return (train_loss / n_batches, 100. * correct / max(total, 1))


def test(net, optimizer, testloader, epoch, start_epoch):
    '''Evaluate `net` on `testloader`, checkpoint it, and update the global best_acc.

    A checkpoint is written when `epoch` is the last epoch main() runs
    (args['epoch'] - 1, since main() iterates range(start_epoch, args['epoch']))
    or when top-1 accuracy beats `best_acc`. The previous condition,
    start_epoch + args['epoch'] - 1, was never reached after a resume.

    Args:
        optimizer, start_epoch: unused; kept for call compatibility.

    Returns:
        (mean test loss per batch, top-1 accuracy in percent), both Python
        floats, so they log cleanly to CSV.
    '''
    global best_acc
    net.eval()
    test_loss, correct_1, correct_5, total = 0.0, 0.0, 0.0, 0
    n_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            n_batches += 1
            _, pred = outputs.topk(5, 1, largest=True, sorted=True)
            total += targets.size(0)
            correct = pred.eq(targets.view(-1, 1).expand_as(pred)).float()
            correct_1 += correct[:, :1].sum().item()
            correct_5 += correct[:, :5].sum().item()  # top-5 hits: accumulated, not returned

    acc = 100. * correct_1 / max(total, 1)
    if epoch == args['epoch'] - 1 or acc > best_acc:
        save_checkpoint(net, acc, epoch)
    if acc > best_acc:
        best_acc = acc
    # Divide by the number of batches (the old code divided by the last batch index).
    return (test_loss / max(n_batches, 1), acc)


def save_checkpoint(net, acc, epoch):
    """ Save checkpoints.

    Writes {'net', 'acc', 'epoch', 'rng_state'} to
    ./checkpoint/{model}_{data}_{mode}_{klam}_{name}_0.t7, the path main()
    loads when resuming. The whole module is pickled while on the CPU, and
    'rng_state' is torch's CPU RNG state. Afterwards the model is moved back to
    the device its parameters were on, even if saving fails, so the caller's
    live model is not silently left on the CPU.
    """
    print('Saving..')
    params = list(net.parameters())
    orig_device = params[0].device if params else torch.device('cpu')
    os.makedirs('checkpoint', exist_ok=True)
    ckptname = os.path.join(
        './checkpoint/', f'{args["model"]}_{args["data"]}_{args["mode"]}_{args["klam"]}_{args["name"]}_0.t7')
    state = {
        'net': net.cpu(),
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    try:
        torch.save(state, ckptname)
    finally:
        net.to(orig_device)
    
    

def prepare_data():
    '''Build the (trainset, testset) pair selected by args['data'].

    Supported values are 'cifar10' and 'mnist'. Both splits are converted with
    ToTensor plus dataset-specific normalisation. When args['augment'] is set,
    the training split additionally gets RandomCrop(32, padding=4) and a random
    horizontal flip.

    Raises:
        ValueError: for an unsupported args['data'] (previously this surfaced as
            an UnboundLocalError on return).
    '''
    ## --------------- Prepare data --------------- ##
    print('==> Preparing data..')

    if args['data'] == 'cifar10':
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        dataset_cls, root = datasets.CIFAR10, '.Dataset/CIFAR10'
    elif args['data'] == 'mnist':
        # NOTE(review): MNIST digits are 28x28, so with augment the training crop
        # produces 32x32 images while the test split stays 28x28. Confirm the
        # intended input resolution.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        dataset_cls, root = datasets.MNIST, '.Dataset/MNIST'
    else:
        raise ValueError('Unsupported dataset: %r' % (args['data'],))

    augmentation = []
    if args['augment']:
        augmentation = [transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip()]

    transform_train = transforms.Compose(
        augmentation + [transforms.ToTensor(), normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataset_cls(root=root, train=True, download=True,
                           transform=transform_train)
    testset = dataset_cls(root=root, train=False, download=True,
                          transform=transform_test)
    return trainset, testset

In [6]:
def main():
    '''Train and evaluate the model configured by `args`.

    Steps: seed torch/numpy when args['seed'] != 0, build the loaders, build a
    fresh model or resume from the checkpoint written by save_checkpoint(), then,
    for each epoch in range(start_epoch, args['epoch']), re-mix the training set,
    train, evaluate, step the LR schedule and append a tab-separated row to
    results/log_<model>_<data>_<mode>_<klam>_<name>_<trial>.csv.
    '''
    global best_acc
    import warnings

    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args['seed'] != 0:
        torch.manual_seed(args['seed'])
        np.random.seed(args['seed'])

    print('==> Number of lambdas: %g' % args['klam'])

    trainset, testset = prepare_data()
    # A single batch holding the whole training set; generate_sample() asserts this.
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=len(trainset),
                                              shuffle=True,
                                              num_workers=8)

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args['batch_size'],
                                             shuffle=False,
                                             num_workers=8)

    ## --------------- Create the model --------------- ##
    if args['resume']:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'

        ckptname = os.path.join(
            './checkpoint/', f'{args["model"]}_{args["data"]}_{args["mode"]}_{args["klam"]}_{args["name"]}_0.t7')

        # NOTE(review): the checkpoint holds a pickled nn.Module, so only load
        # trusted files. Newer PyTorch releases may need weights_only=False here.
        checkpoint = torch.load(ckptname)
        net = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])

        # net.to(device) instead of an unconditional net.cuda(), which fails on CPU-only hosts.
        net.to(device)
        if use_cuda:
            cudnn.benchmark = True
            print('==> Using CUDA..')
    else:
        print('==> Building model..')
        net = models.__dict__[args['model']](num_classes=args['nclass'])
        if use_cuda:
            net.cuda()
            net = torch.nn.DataParallel(net)
            cudnn.benchmark = True
            print('==> Using CUDA..')

    os.makedirs('results', exist_ok=True)
    logname = f'results/log_{args["model"]}_{args["data"]}_{args["mode"]}_{args["klam"]}_{args["name"]}_{args["trial"]}.csv'

    optimizer = optim.SGD(net.parameters(),
                          lr=args["lr"],
                          momentum=0.9,
                          weight_decay=args["decay"])

    scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
    # A freshly built scheduler counts from epoch 0, so after a resume its
    # milestones would fire start_epoch epochs late. Replay the elapsed steps so
    # the LR matches the resumed epoch. The "scheduler.step() before
    # optimizer.step()" UserWarning this triggers is expected and silenced.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', UserWarning)
        for _ in range(start_epoch):
            scheduler.step()

    ## --------------- Train and Eval --------------- ##
    if not os.path.exists(logname):
        with open(logname, 'w', newline='') as logfile:
            logwriter = csv.writer(logfile, delimiter='\t')
            logwriter.writerow([
                'Epoch', 'Train loss', 'Test loss',
                'Test acc'
            ])

    for epoch in range(start_epoch, args['epoch']):
        mix_inputs_all, mix_targets_all, original_input, original_label, lams = generate_sample(trainloader)

        if args['mode'] == 'normal':
            train_loss, _ = train(
                net, optimizer, original_input, original_label, lams, epoch)
        else:
            train_loss, _ = train(
                net, optimizer, mix_inputs_all, mix_targets_all, lams, epoch)

        test_loss, test_acc1 = test(
            net, optimizer, testloader, epoch, start_epoch)

        scheduler.step()
        # Checkpointing may move the model to the CPU; make sure it is on `device`.
        net.to(device)
        with open(logname, 'a', newline='') as logfile:
            logwriter = csv.writer(logfile, delimiter='\t')
            logwriter.writerow(
                [epoch, train_loss, test_loss, test_acc1])

In [None]:
# Resume from the saved checkpoint and keep training until epoch 200
# (main() trains epochs range(start_epoch, args['epoch'])).
args.update(resume=True, trial=0, epoch=200)
main()

==> Number of lambdas: 4
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Resuming from checkpoint..
==> Using CUDA..

Epoch: 100


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Saving..

Epoch: 101


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



Epoch: 102


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Saving..

Epoch: 103


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Saving..

Epoch: 104


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



Epoch: 105


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



Epoch: 106


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



Epoch: 107


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



Epoch: 108


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Saving..

Epoch: 109


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



Epoch: 110


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Saving..

Epoch: 111


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))