In [1]:
import os
import sys
sys.path.insert(0, './')

import pickle
import argparse
import numpy as np

import torch
import torch.nn as nn

from util.models import MLP
from util.dataset import load_pkl, load_mnist, load_fmnist, load_svhn
from util.device_parser import config_visible_gpu
from util.param_parser import DictParser, ListParser, IntListParser

In [2]:
def var_init(mode, batch_size, in_dim, init_value, device):
    '''
    >>> initialize the bounds \\epsilon
    '''

    if mode.lower() in ['uniform',]:
        var = torch.zeros([batch_size, 1], device = device, requires_grad = True)
        var.data.fill_(init_value)
        var_list = [var,]
    elif mode.lower() in ['nonuniform',]:
        var = torch.zeros([batch_size, in_dim], device = device, requires_grad = True)
        var.data.fill_(init_value)
        var_list = [var,]
    else:
        raise ValueError('Unrecognized mode: %s' % mode)

    return var_list

def var_calc(mode, batch_size, in_dim, var_list, device):

    if mode.lower() in ['uniform',]:
        var, = var_list
        eps = var * var * torch.ones([batch_size, in_dim], device = device)
    elif mode.lower() in ['nonuniform',]:
        var, = var_list
        eps = var * var
    elif mode.lower() in ['asymmetric',]:
        var1, var2 = var_list
        eps = torch.cat((var1, var2), 0)
    else:
        raise ValueError('Unrecognized mode: %s' % mode)

    return eps

def clip_gradient(grad, length):
    '''
    >>> grad: tensor of shape [batch_size, in_dim]
    >>> length: the maximum length allowed
    '''
    grad_norm = torch.norm(grad, dim = 1).view(-1, 1) + 1e-8
    clipped_grad_norm = torch.clamp(grad_norm, max = length)

    return grad / grad_norm * clipped_grad_norm

In [3]:
device = torch.device('cuda:0')
device_ids = 'cuda'

In [4]:
data_loader = load_mnist(batch_size = 10, dset = 'test')

In [5]:
model = MLP(in_dim = 784, hidden_dims = [], out_dim = 10, nonlinearity = 'relu')
model = model.cuda(device)
ckpt = torch.load('./output/mnist.ckpt')
model.load_state_dict(ckpt)
model.eval()

MLP(
  (main_block): Sequential()
  (output): FCLayer(
    (layer): Linear(in_features=784, out_features=10, bias=True)
  )
)

In [12]:
batch_size = 5
out_dim = 10
optim = 'adam'
learnrate = 5.
max_iter = 400
modes = 'nonuniform'
delta = 1e-4
update_dual_freq = 5
inc_freq = 80
inc_rate = 5.
final_decay = 0.99

for batch_idx in range(batch_size):

        print('batch %d / %d' % (batch_idx, 2))

        data_batch, label_batch = next(data_loader) # data_loader 里面存放了batch的大小
        
        print(len(data_batch))
        
        data_batch = data_batch.cuda(device) # 加载到gpu上
        label_batch = label_batch.cuda(device) 
        data_batch = data_batch.view(data_batch.size(0), -1) # 按第一列的值标准化，防止出问题

        logits = model(data_batch) # 对 data_batch 进行预测，得到预测结果，这里是概率
        _, predict = torch.max(logits, dim = 1) # 得到预测结果
        result_mask = (predict == label_batch).float() # 这里用一个掩码来做预测了，得到了对应一个 batch 的结果
        label_mask = torch.ones([batch_size, out_dim], device = device).scatter_(dim = 1, index = label_batch.view(batch_size, 1), value = 0) # 也是返回一个关于 label 的 mask

        # Reinitialize the variable
        [p.data.fill_(init_value) for p in var_list] # 这个用法还是第一次看到，不过大概能猜到，这个意思应该是将 var_list 用 init_value 填充
        beta = 1. # 这个 beta 给个了初始化，未来是随着迭代不断更新
        grad_clip = None # 选择是否用这个方法，防止梯度爆炸
        lam = torch.zeros([batch_size, out_dim], device = device, requires_grad = False)

        # 选择优化器
        optim = torch.optim.Adam(var_list, lr = learnrate)

        # 开始训练
        for iter_idx in range(max_iter):

            # 首先根据变量来计算 eps
            eps = var_calc(mode = args.mode, batch_size = args.batch_size, in_dim = args.in_dim, var_list = var_list, device = device)

            low_bound, up_bound = model.bound(x = data_batch, ori_perturb_norm = norm, ori_perturb_eps = eps)
            low_true = low_bound.gather(1, label_batch.view(-1, 1)) # 获取真实标签

            err = low_true - up_bound - delta
            err = torch.min(err, - lam / beta) * label_mask

            eps_loss = - torch.sum(torch.log(eps), dim = 1)
            err_loss = torch.sum(lam * err, dim = 1) + beta / 2. * torch.norm(err, dim = 1) ** 2

            loss = torch.sum((eps_loss + err_loss) * result_mask) / torch.sum(result_mask)
            eps_v = torch.sum(eps_loss * result_mask) / torch.sum(result_mask)
            if iter_idx % 10 == 0:
                print(batch_idx, iter_idx, beta, eps_v.data.cpu().numpy(), (loss - eps_v).data.cpu().numpy())

            optim.zero_grad() #把梯度置0，重新算梯度
            loss.backward()
            # Gradient Clip
            if grad_clip is not None:
                for var in var_list:
                    var.grad.data = clip_gradient(var.grad.data, length = grad_clip)
            optim.step()

            if (iter_idx + 1) % update_dual_freq == 0:
                lam.data = lam.data + beta * err #这个不太确定，可能要改

            if iter_idx + 1 > args.inc_min and (iter_idx + 1 - args.inc_min) % inc_freq == 0:
                beta *= inc_rate
                if grad_clip is not None:
                    grad_clip /= np.sqrt(inc_rate)

        # Small adjustment in the end
        eps = var_calc(mode = args.mode, batch_size = args.batch_size, in_dim = args.in_dim, var_list = var_list, device = device)
        shrink_times = 0
        while shrink_times < 1000:

            low_bound, up_bound = model.bound(x = data_batch, ori_perturb_norm = norm, ori_perturb_eps = eps)
            low_true = low_bound.gather(1, label_batch.view(-1, 1))
            err = low_true - up_bound - delta

            err_min, _ = torch.min(err * label_mask + 1e-10, dim = 1, keepdim = True)
            err_min = err_min * result_mask.view(-1, 1) + 1e-10

            if float(torch.min(err_min).data.cpu().numpy()) > 0:
                break

            shrink_times += 1
            err_sign = torch.sign(err_min)
            coeff = (1. - final_decay) / 2. * err_sign + (1. + final_decay) / 2.
            eps.data = eps.data * coeff

        print('Shrink time = %d' % shrink_times)

batch 0 / 2
10


RuntimeError: shape '[5, 1]' is invalid for input of size 10