In [None]:
import os
from argparse import ArgumentParser
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import clear_output
from PIL import Image
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import random_split
from torchvision import datasets
from torchvision import transforms
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet50
from tqdm import tqdm_notebook

from model import AdvProgram

In [None]:
# --- Image geometry ---------------------------------------------------------
img_size = (28, 28)       # size of the un-padded input image
pimg_size = (224, 224)    # size of the padded image handed to the network
mask_size = pimg_size     # the program mask covers the full padded image
num_channels = 3

# --- Run settings -----------------------------------------------------------
model_name = 'resnet50'
log_interval = 10         # write a TensorBoard point every `log_interval` batches
train_ratio = 0.9         # fraction of the training set kept for training
batch_size = test_batch_size = 100

# --- Filesystem layout ------------------------------------------------------
data_dir, models_dir, logs_dir = 'data/', 'models/', 'logs/'

# The current timestamp is appended so every execution logs to its own run
# directory instead of overwriting an earlier one.
writer = SummaryWriter(
    "{}{}-{}".format(logs_dir, model_name, time())
)

In [None]:
# Border widths that centre an `img_size` image on a `pimg_size` canvas. When
# the size difference is odd, the "+1" gives the extra pixel to the left/top.
l_pad = int((pimg_size[0] - img_size[0] + 1) / 2)
r_pad = int((pimg_size[0] - img_size[0]) / 2)

# Preprocessing applied to every sample, in order:
#   1. pad up to the network input size (Pad takes left, top, right, bottom),
#   2. convert to a tensor,
#   3. stack the tensor three times along dim 0,
#   4. normalise with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Pad(padding=(l_pad, l_pad, r_pad, r_pad)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.cat((x, x, x))),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

dataset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)

# Split the official training set into train/validation parts according to
# `train_ratio`. (For quick experiments the split can instead be made into
# [train, valid, unused] chunks of a smaller `dataset_size`.)
num_train = int(train_ratio * len(dataset))
num_valid = len(dataset) - num_train
train_dataset, valid_dataset = random_split(dataset, [num_train, num_valid])

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=False, transform=transform),
    batch_size=test_batch_size,
    shuffle=False,
)

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The pretrained network is used as a fixed classifier: eval() switches its
# layers to inference behaviour, and disabling requires_grad on every weight
# ensures that only the adversarial program is ever optimised.
model = resnet50(pretrained=True).to(device)
model.eval()

# Number of trainable parameter tensors before freezing ...
print(sum(param.requires_grad for param in model.parameters()))
for param in model.parameters():
    param.requires_grad_(False)
# ... and after freezing (expected to be 0).
print(sum(param.requires_grad for param in model.parameters()))

In [None]:
# Build an AdvProgram (imported from model.py) with the sizes configured above.
pr = AdvProgram(
    img_size,
    pimg_size,
    mask_size,
    normalization='imagenet',
    device=device,
)

In [None]:
# Show every parameter exposed by `pr.program`.
[param for param in pr.program.parameters()]

In [None]:
# Trainable program tensor, initialised uniformly at random over the whole
# padded image.
program = torch.rand(num_channels, *pimg_size, device=device, requires_grad=True)

# Border widths that centre an `img_size` window inside `mask_size`.
l_pad = int((mask_size[0] - img_size[0] + 1) / 2)
r_pad = int((mask_size[0] - img_size[0]) / 2)

# Mask that is 0 over the central `img_size` window and 1 on the border
# around it, so `program * mask` leaves the central window untouched.
# F.pad's tuple is (left, right, top, bottom) for the last two dimensions.
mask = F.pad(
    torch.zeros(num_channels, *img_size, device=device),
    (l_pad, r_pad, l_pad, r_pad),
    value=1,
)

# Only the program is optimised; the learning rate is multiplied by 0.96
# every 2 scheduler steps.
optimizer = optim.Adam([program], lr=0.05, weight_decay=0.0)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.96)

loss_criterion = nn.CrossEntropyLoss()

In [None]:
def run_epoch(mode, data_loader, num_classes=10, optimizer=None, epoch=None, steps_per_epoch=None, loss_criterion=None):
    """Run one pass over ``data_loader`` in training or evaluation mode.

    Each batch has ``tanh(program * mask)`` added to it and is passed through
    ``model``; only the first ``num_classes`` logits are kept. The function
    relies on the notebook-level globals ``program``, ``mask``, ``model``,
    ``device``, ``writer`` and ``log_interval``.

    Args:
        mode: ``'train'`` updates ``program`` through ``optimizer``; any other
            value (e.g. ``'valid'`` or ``'test'``) evaluates without gradients
            and tracks the error rate. Also used as the prefix of the logged
            scalar names.
        data_loader: iterable yielding batches whose first two items are the
            inputs and the integer labels.
        num_classes: number of leading logits treated as class scores.
        optimizer: optimizer stepping ``program``; required in train mode.
        epoch: zero-based epoch index used for the progress-bar title and the
            global logging step. ``None`` is treated as epoch 0 for logging.
        steps_per_epoch: maximum number of batches to process; defaults to
            ``len(data_loader)``.
        loss_criterion: callable ``(logits, labels) -> loss``. Required in
            train mode; when omitted during evaluation the loss stays 0.

    Returns:
        In train mode, the mean batch loss. Otherwise a tuple
        ``(mean_loss, {'error_rate': value})`` where ``value`` is the fraction
        of misclassified samples (NaN if no sample was seen).

    Raises:
        ValueError: if ``mode == 'train'`` and ``optimizer`` or
            ``loss_criterion`` is missing.
    """
    is_train = mode == 'train'
    if is_train and (optimizer is None or loss_criterion is None):
        raise ValueError("mode='train' requires both an optimizer and a loss_criterion")

    # Gradients w.r.t. the program are only needed while training.
    program.requires_grad = is_train

    if steps_per_epoch is None:
        steps_per_epoch = len(data_loader)

    # Converts the in-epoch batch index into a global step for TensorBoard.
    step_offset = (epoch if epoch is not None else 0) * steps_per_epoch

    # zip() with a range caps the pass at exactly `steps_per_epoch` batches
    # without fetching (and discarding) an extra batch from the loader.
    progress = tqdm_notebook(
        zip(range(steps_per_epoch), data_loader),
        total=steps_per_epoch,
        desc=None if epoch is None else 'Epoch {}: '.format(epoch),
    )

    loss = 0.0
    num_batches = 0
    num_wrong = 0
    num_seen = 0
    error_rate = float('nan')

    for i, data in progress:
        x = data[0].to(device)
        y = data[1].to(device)

        if is_train:
            optimizer.zero_grad()

        # Autograd is only recorded in train mode; evaluation runs grad-free.
        with torch.set_grad_enabled(is_train):
            logits = model(x + torch.tanh(program * mask))[:, :num_classes]
            batch_loss = None if loss_criterion is None else loss_criterion(logits, y)

        if batch_loss is not None:
            if is_train:
                batch_loss.backward()
                optimizer.step()
            loss += batch_loss.item()
        num_batches += 1

        if not is_train:
            # Running counts give the same cumulative error rate as comparing
            # all predictions so far, without re-concatenating tensors.
            # (argmax of the logits equals argmax of their softmax.)
            num_wrong += (logits.argmax(dim=1) != y).sum().item()
            num_seen += y.shape[0]
            if num_seen:
                error_rate = num_wrong / num_seen

        if i % log_interval == 0:
            step = step_offset + i
            writer.add_scalar("{}_loss".format(mode), loss/(i+1), step)
            if not is_train:
                writer.add_scalar("{}_error_rate".format(mode), error_rate, step)

            print("\rLoss at Step {} : {}".format(step, loss/(i+1)), end='')

    # Average over the batches actually processed (guards the empty case).
    mean_loss = loss / num_batches if num_batches else 0.0

    if not is_train:
        return mean_loss, {'error_rate': error_rate}
    return mean_loss

In [None]:
num_epochs = 20
best_error_rate = 1

# Make sure the checkpoint directory exists before training starts, so a
# missing folder cannot abort the run after a full epoch of work.
if models_dir:
    os.makedirs(models_dir, exist_ok=True)

for epoch in range(num_epochs):
    train_loss = run_epoch('train', train_loader, 10, optimizer, epoch=epoch, loss_criterion=loss_criterion)
    # The scheduler is stepped after the epoch's optimizer updates; stepping
    # it first (the pre-1.1 PyTorch convention) skips the initial learning
    # rate value and triggers a warning on current PyTorch.
    lr_scheduler.step()

    valid_loss, val_metrics = run_epoch('valid', valid_loader, 10, epoch=epoch, loss_criterion=loss_criterion)
    error_rate = val_metrics['error_rate']

    # Keep only the program with the lowest validation error seen so far.
    if error_rate < best_error_rate:
        torch.save({'program':program, 'mask':mask}, "{}{}.pt".format(models_dir, model_name))
        best_error_rate = error_rate

    # The test pass is reported for monitoring only; it does not influence
    # which checkpoint is kept.
    _, test_metrics = run_epoch('test', test_loader, 10, epoch=epoch)

    print('\rTrain loss : {}, Validation Loss : {}, Validation_ER : {}, Test Metrics : {}'.format(train_loss, valid_loss, error_rate, str(test_metrics)), end='')

In [None]:
# Reload the best checkpoint from the same path the training loop writes to,
# so changing `models_dir` / `model_name` in the config keeps save and load in
# sync. `map_location` places the tensors on the current device even if the
# checkpoint was written on a different one.
state = torch.load("{}{}.pt".format(models_dir, model_name), map_location=device)
program = state['program']
mask = state['mask']

In [None]:
# Render the program tensor (detached and moved to the CPU) as a PIL image.
to_pil = transforms.ToPILImage()
to_pil(program.detach().cpu())

In [None]:
# Take the first training sample: `x` is the image after `transform` has been
# applied, `y` is its label.
x,y = dataset[0]

In [None]:
# Largest value in the transformed sample.
x.max()

In [None]:
# Dimensions of the transformed sample.
x.size()

In [None]:
imshow(program)

In [None]:
def imshow(img):
    return transforms.ToPILImage()(img.detach().cpu())