In [158]:
from __future__ import print_function, division
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn as nn
import torch.optim as optim

from torch.optim import lr_scheduler
from torchvision import transforms

import tqdm
import time
import copy
import numpy as np

from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

In [159]:
# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [160]:
# Seed NumPy and PyTorch (CPU and CUDA generators) with the same fixed value
# so repeated runs draw the same random numbers.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [161]:
# Training hyper-parameters shared by the cells below.
LR = 1e-2         # learning rate handed to the optimizer
EPOCH = 35        # number of epochs passed to train_model
BATCH_SIZE = 256  # mini-batch size of the train / validation / test loaders

# Data Preparation

In [172]:
def _load_mnist(train, transform):
    """Return the MNIST split selected by ``train``, stored under ./data (downloaded if needed)."""
    return datasets.MNIST(
        root='data', train=train, download=True, transform=transform)


# Plain train / test splits: the only preprocessing is conversion to a tensor.
train_dataset = _load_mnist(True, transforms.ToTensor())
test_dataset = _load_mnist(False, transforms.ToTensor())

# Augmentation pipeline: random affine (degrees=30), random perspective with
# its default settings, then conversion to a tensor.
transformers = transforms.Compose([
    transforms.RandomAffine(degrees=30),
    transforms.RandomPerspective(),
    transforms.ToTensor(),
])

# Second view of the same training split, read through the augmentation pipeline.
train_dataset_transformed = _load_mnist(True, transformers)

In [173]:
# Stack the plain and the augmented view of the MNIST training split into one
# dataset: the plain samples come first, followed by the augmented ones.
final_dataset = ConcatDataset([train_dataset, train_dataset_transformed])

In [174]:
# Split by *image*, not by row of final_dataset. final_dataset holds every
# MNIST training image twice (plain view first, augmented view second), so a
# row-level random_split can put one view of an image into training and the
# other into validation. That leaks validation images into training and
# inflates the validation accuracy that is used to pick the best weights.
n_images = len(final_dataset) // 2   # rows i and i + n_images show the same image
n_valid_images = 10_000              # -> 20_000 validation rows, 100_000 training rows
assert len(final_dataset) == 2 * n_images, \
    "final_dataset is expected to hold two equal-length views of the same images"

image_order = torch.randperm(n_images, generator=torch.Generator())
valid_images = image_order[:n_valid_images]
train_images = image_order[n_valid_images:]

# Both views of an image always land on the same side of the split.
train_subset = torch.utils.data.Subset(
        final_dataset, torch.cat([train_images, train_images + n_images]).tolist())
valid_subset = torch.utils.data.Subset(
        final_dataset, torch.cat([valid_images, valid_images + n_images]).tolist())

In [175]:
# Mini-batch loaders. Only the training loader shuffles; validation and test
# keep the dataset order.
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_subset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [176]:
# Fetch a single mini-batch as a sanity check: data[0] is the image tensor,
# and the trailing expression displays its shape as the cell output.
data = next(iter(train_loader))
data[0].shape

torch.Size([256, 1, 28, 28])

In [177]:
# Per-phase lookups consumed by train_model: phase name -> loader, and
# phase name -> number of samples in that phase.
dataloaders_dict = dict(train=train_loader, val=valid_loader)
dataset_sizes_dict = dict(train=len(train_subset), val=len(valid_subset))

# Definitions

In [178]:
class AlexNet(nn.Module):
    """Small AlexNet-style CNN for single-channel 28x28 images.

    ``feature`` stacks five convolutions (each followed by ReLU) and two
    max-pools. For a 28x28 input the spatial size goes
    28 -> 26 (5x5 conv, padding 1) -> 13 (2x2 pool, stride 2)
    -> 12 (2x2 pool, stride 1), giving a 32 x 12 x 12 feature map.
    ``classifier`` is three fully connected layers with dropout (p=0.2)
    that map the flattened features to ``num`` class scores.

    Args:
        num: number of output classes (width of the last linear layer).
    """

    def __init__(self, num=10):
        super().__init__()

        self.feature = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 96, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1))

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(32 * 12 * 12, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num))

    def forward(self, x):
        """Return the raw class scores, one row of ``num`` values per input sample."""
        x = self.feature(x)
        # Flatten per sample. The previous view(-1, 32 * 12 * 12) let the batch
        # axis absorb any size mismatch, so wrongly sized inputs could come
        # back with a different number of rows instead of raising an error.
        x = torch.flatten(x, 1)
        return self.classifier(x)

# Build the network and move its parameters to the selected device; the
# trailing expression prints the layer summary as the cell output.
net = AlexNet()
net = net.to(device)
net

AlexNet(
  (feature): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU(inplace=True)
    (11): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=4608, out_features=2048, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=2048, out_features=1024, bias=True)
    (5): ReLU(in

In [179]:
# net(data[0]).shape

torch.Size([256, 10])

In [180]:
def train_model(model, criterion, optimizer, scheduler, attack, dataloaders, dataset_sizes, num_epochs=25):
    """Train ``model`` and return it loaded with its best-validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. Weights are
    updated only in the 'train' phase, and ``scheduler.step()`` is called once
    after it. Whenever the validation accuracy beats the best value seen so
    far, the weights are copied in memory and also written to
    ``./best_model_wts``.

    Args:
        model: network to optimise. Batches are moved to the module-level
            ``device`` before being fed to it.
        criterion: called as ``criterion(outputs, labels)``. Its value is
            weighted by the batch size when accumulating the epoch loss, i.e.
            it is treated as a per-batch mean.
        optimizer: optimizer stepping the parameters of ``model``.
        scheduler: learning-rate scheduler, stepped once per epoch.
        attack: optional callable. When truthy, every batch (in both phases)
            is replaced by the second element of ``attack(inputs)``.
        dataloaders: mapping with keys 'train' and 'val' to iterables that
            yield ``(inputs, labels)`` batches.
        dataset_sizes: mapping with the same keys to the number of samples,
            used as the denominator of the epoch loss and accuracy.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` after loading the weights with the highest validation
        accuracy (the initial weights if no epoch improved on 0.0).
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            is_train = phase == 'train'

            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            # Plain Python int: the old code accumulated a tensor and later
            # called .double() on it, which fails on the initial int when the
            # loader yields no batch and made best_acc a float or a tensor
            # depending on the run.
            running_corrects = 0

            for inputs, labels in tqdm.tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # The attack runs before (and outside) the gradient context below.
                if attack:
                    _, inputs = attack(inputs)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_train):

                    outputs = model(inputs)
                    _, predictions = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += (predictions == labels).sum().item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = float(running_corrects) / dataset_sizes[phase]

            print('epoch:{} phase:{}'.format(epoch + 1, phase))
            print(
                '{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, './best_model_wts')

    time_elapsed = time.time() - since

    print(
        'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # '{:.4f}' (precision), not '{:4f}' (width): match the per-epoch accuracy format.
    print(
        'Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)

    return model

In [181]:
# Loss, optimizer and learning-rate schedule used to train `net`.
criterion_md = nn.CrossEntropyLoss()
optimizer_md = optim.Adam(params=net.parameters(), lr=LR)
# StepLR with step_size=10 and gamma=0.1: the learning rate is scaled by 0.1
# every 10 scheduler steps.
lr_scheduler_md = lr_scheduler.StepLR(
    optimizer=optimizer_md, step_size=10, gamma=0.1)

In [182]:
# TODO: test eps
# adversary = RSAttack(model_clf, eps=10, verbose=True, n_queries=5_000, loss='ce')

In [None]:
# Train the network without an adversarial attack (attack=None) and keep the
# returned model, which carries the best-validation weights.
model_fn = train_model(
    model=net,
    criterion=criterion_md,
    optimizer=optimizer_md,
    scheduler=lr_scheduler_md,
    attack=None,
    dataloaders=dataloaders_dict,
    dataset_sizes=dataset_sizes_dict,
    num_epochs=EPOCH,
)

# Prediction

In [184]:
def predict(classifier, dataloader):
    """Return the predicted class index for every sample in ``dataloader``.

    Args:
        classifier: ``nn.Module`` producing one row of class scores per
            sample. It is switched to eval mode and left in that mode.
        dataloader: iterable of batches whose first element is the input
            tensor; any further elements (e.g. labels) are ignored.

    Returns:
        A CPU ``LongTensor`` of shape ``(num_samples, 1)`` with the index of
        the highest score per sample, in loader order. An empty
        ``LongTensor`` is returned when the loader yields no batch.
    """
    classifier.eval()

    # Send inputs to wherever the classifier's weights live, instead of
    # gating the move on CUDA availability and a module-level global.
    first_param = next(classifier.parameters(), None)
    target_device = None if first_param is None else first_param.device

    batch_predictions = []

    with torch.no_grad():
        for batch in dataloader:
            images = batch[0]
            if target_device is not None:
                images = images.to(target_device)

            outputs = classifier(images)
            # [1] selects the argmax indices returned by max().
            batch_predictions.append(outputs.cpu().max(1, keepdim=True)[1])

    if not batch_predictions:
        return torch.LongTensor()

    # Concatenate once at the end rather than re-copying the growing result
    # tensor for every batch.
    return torch.cat(batch_predictions, dim=0)

In [None]:
# Run the trained model over the MNIST test loader; `results` holds one
# predicted class index per test sample, in loader order.
results = predict(model_fn, test_loader)