In [44]:
"""Code assumes the ability to train using a GPU with CUDA.
"""
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from advertorch.attacks import GradientSignAttack, CarliniWagnerL2Attack, PGDAttack
import matplotlib.pyplot as plt
import utils.adv_ex_utils as aus
from utils.models import LeNet
import utils.interp_generators as igs


# Makes the default tensor type a CUDA float tensor, so tensors created
# without an explicit device are allocated on the GPU.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Index of the CUDA device used throughout the notebook (also the target of
# the .to(device) calls in train/test).
# NOTE(review): hard-coded; presumably matches the author's machine -- a host
# with fewer than three GPUs needs a different index here.
device = 2
# Select that device as the current CUDA device.
torch.cuda.set_device(device)

In [45]:
tr_batch_size, te_batch_size = 64, 50

# the mean of mnist pixel data is .1307 and the stddev is .3081
data_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])


def _mnist_loader(train, batch_size, shuffle):
    """Build a DataLoader over the MNIST train or test split stored in ./data
    (downloaded on demand), applying data_preprocess to every image.
    """
    split = torchvision.datasets.MNIST('./data', train=train, download=True,
                                       transform=data_preprocess)
    return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=shuffle)


# training batches are reshuffled every epoch; test batches keep dataset order
train_loader = _mnist_loader(train=True, batch_size=tr_batch_size, shuffle=True)
test_loader = _mnist_loader(train=False, batch_size=te_batch_size, shuffle=False)

### Functions for Jacobian regularization

In [46]:
def bp_matrix(batch_size, n_outputs):
    """Builds the (batch_size * n_outputs, n_outputs) one-hot matrix that is
    passed as grad_outputs to compute the Jacobian for a whole batch at once.

    Rows are grouped by output index: the first batch_size rows select
    output 0, the next batch_size rows select output 1, and so on.
    """
    # output index selected by each row: 0,...,0,1,...,1,...
    row_to_output = torch.arange(n_outputs).repeat_interleave(batch_size)
    selector = torch.zeros(row_to_output.numel(), n_outputs)
    selector.scatter_(1, row_to_output.view(-1, 1), 1.)
    return selector

In [47]:
def avg_norm_jacobian(x, n_outputs, bp_mat, for_loss=True):
    """Returns squared frobenius norm of the input-output Jacobian of the
    global model `net`, averaged over the entire batch of inputs in x.

    Args:
        x: batch of inputs. It is tiled along dim 0 with a 4-argument repeat,
            so it must be a 4-d tensor -- presumably
            (batch, channels, height, width); TODO confirm against LeNet.
        n_outputs: number of outputs (columns) produced by `net`.
        bp_mat: one-hot matrix as built by bp_matrix(batch_size, n_outputs),
            used as grad_outputs. If it is None or its shape does not match
            the current batch, a fresh one is built here.
        for_loss: if True the gradient is taken with create_graph=True so the
            returned value can itself be backpropagated through.

    Returns:
        0-d tensor: sum of squared Jacobian entries divided by the batch size.
    """
    batch_size = x.shape[0]
    # needed because some edge-case batches are not standard size (and callers
    # such as my_loss may pass bp_mat=None)
    if bp_mat is None or tuple(bp_mat.shape) != (batch_size * n_outputs, n_outputs):
        bp_mat = bp_matrix(batch_size, n_outputs)
    # one copy of the batch per output; detached leaf with requires_grad so
    # that we can get the gradient of the output w.r.t. the input
    x_tiled = x.repeat(n_outputs, 1, 1, 1).detach().requires_grad_(True)
    y = net(x_tiled)
    x_grad = torch.autograd.grad(y, x_tiled, grad_outputs=bp_mat, create_graph=for_loss)[0]
    # sum of squared gradient values, averaged over the original batch
    return x_grad.pow(2).sum() / batch_size

### Functions for interpretation regularization

In [48]:
def norm_diff_interp(x, labels, ig=igs.simple_gradient):
    """Compares the interpretation of x with that of a randomly perturbed copy.

    `ig` is called as ig(net, inputs, labels, used=False) on x and on
    aus.perturb_randomly(x). Returns a tuple of
    (norm of the absolute difference between the two interpretations,
     both interpretations concatenated along dim 0).
    """
    interp_clean = ig(net, x, labels, used=False)
    x_perturbed = aus.perturb_randomly(x)
    interp_perturbed = ig(net, x_perturbed, labels, used=False)

    gap = torch.abs(interp_clean - interp_perturbed)
    both_interps = torch.cat([interp_clean, interp_perturbed], dim=0)
    return torch.norm(gap), both_interps

### Define loss function


In [53]:
def my_loss(output, labels, alpha_wd=0, alpha_jr=0, x=None, bp_mat=None, alpha_ir1=0, alpha_ir2=0.0001):
    """Cross-entropy loss with optional L2, Jacobian and interpretation
    regularization terms. Check https://arxiv.org/abs/1908.02729
    for alpha_wd, alpha_jr suggestions.

    Args:
        output: model outputs for the batch; output.shape[1] is taken as the
            number of model outputs.
        labels: target class indices for F.cross_entropy.
        alpha_wd: weight of the L2 penalty over all parameters of the global
            `net`; 0 disables it.
        alpha_jr: weight of the input-output Jacobian penalty (applied as
            alpha_jr / 2 times avg_norm_jacobian); 0 disables it.
        x: the input batch that produced `output`; required whenever alpha_jr
            or alpha_ir1 is non-zero.
        bp_mat: matrix from bp_matrix for the Jacobian term; built on the fly
            when None.
        alpha_ir1: weight of the norm of the difference between the
            interpretations of x and of a randomly perturbed x; 0 disables
            both interpretation terms.
        alpha_ir2: weight of the l0-style penalty sum(|ix| / (|ix| + 1e-4))
            over both interpretations; only used when alpha_ir1 != 0.

    Returns:
        0-d tensor with the total loss.

    Raises:
        ValueError: if alpha_jr or alpha_ir1 is non-zero but x is None.

    Side effects:
        Calls zero_grad() on the global `optimizer` after each input-gradient
        based term.
    """
    if (alpha_jr != 0 or alpha_ir1 != 0) and x is None:
        raise ValueError('x (the input batch) is required when alpha_jr or alpha_ir1 is non-zero')

    # standard cross-entropy loss base
    loss = F.cross_entropy(output, labels)

    # add l2 regularization to loss
    if alpha_wd != 0:
        l2 = sum(p.pow(2).sum() for p in net.parameters())
        loss = loss + alpha_wd * l2

    # add input-output jacobian regularization formulation
    if alpha_jr != 0:
        n_outputs = output.shape[1]
        if bp_mat is None:
            bp_mat = bp_matrix(x.shape[0], n_outputs)
        j = avg_norm_jacobian(x, n_outputs, bp_mat)
        loss = loss + (alpha_jr / 2) * j
        # needed so gradients don't accumulate in leaf variables when calling loss.backward in train function
        optimizer.zero_grad()

    # add interpretation regularization
    if alpha_ir1 != 0:
        norm, ix = norm_diff_interp(x, labels)
        loss = loss + alpha_ir1 * norm
        # same reason as above; the l0 term below only combines existing
        # tensors, so one reset after the interpretations are computed suffices
        optimizer.zero_grad()

        # add l0 interpretation regularization: |ix| / (|ix| + eps) is ~1 for
        # non-zero entries and 0 for zero entries
        if alpha_ir2 != 0:
            loss = loss + alpha_ir2 * torch.sum(torch.abs(ix / (torch.abs(ix) + .0001)))

    return loss

### Define train and test functions 

In [54]:
def train(alpha_wd, alpha_jr, alpha_ir, adversary=None):
    """Runs one epoch of training of the global `net` over `train_loader`.

    Args:
        alpha_wd: L2 regularization weight forwarded to my_loss.
        alpha_jr: Jacobian regularization weight forwarded to my_loss.
        alpha_ir: interpretation regularization weight, forwarded to my_loss
            as alpha_ir1.
        adversary: optional attack object; when given, each batch is extended
            with the adversarial examples returned by aus.generate_adv_exs.

    Relies on the module-level `net`, `optimizer`, `train_loader`, `device`,
    `tr` (bp_matrix for the training batch size) and `log_interval`.
    Every `log_interval` batches it prints the loss, the average Jacobian norm
    and the norm of the interpretation difference for the current batch.
    """
    net.train()

    for batch_idx, (samples, labels) in enumerate(train_loader):
        # sends to GPU, i.e. essentially converts from torch.FloatTensor to torch.cuda.FloatTensor
        samples, labels = samples.to(device), labels.to(device)

        # expand dataset with adversarial examples if adversary specified
        # (identity check: an attack object could define its own __ne__)
        if adversary is not None:
            adv_samples, adv_labels = aus.generate_adv_exs(samples, labels, adversary)
            samples = torch.cat([samples, adv_samples], 0)
            labels = torch.cat([labels, adv_labels], 0)

        optimizer.zero_grad()

        output = net(samples)

        loss = my_loss(output, labels, alpha_wd=alpha_wd, alpha_jr=alpha_jr, x=samples, bp_mat=tr, alpha_ir1=alpha_ir)
        loss.backward()

        optimizer.step()

        if batch_idx % log_interval == 0:
            # diagnostics only: no graph is needed for the Jacobian norm here
            jac_norm = avg_norm_jacobian(samples, output.shape[1], tr, for_loss=False)
            interp_diff, _ = norm_diff_interp(samples, labels)
            print(f'\tLoss: {loss.item():.6f} Average norm of Jacobian: {jac_norm:6f} Norm of difference in interpretations: {interp_diff:6f}')

In [55]:
def test(alpha_wd, alpha_jr, alpha_ir):
    """Evaluates the global `net` on `test_loader` and prints the accuracy.

    Args:
        alpha_wd: L2 regularization weight forwarded to my_loss.
        alpha_jr: Jacobian regularization weight forwarded to my_loss.
        alpha_ir: interpretation regularization weight, forwarded to my_loss
            as alpha_ir1.

    Returns:
        (test_loss, test_accuracy): the regularized loss averaged over the
        test batches, and the accuracy in percent.

    Relies on the module-level `net`, `test_loader`, `device` and `te`
    (bp_matrix for the test batch size).
    """
    net.eval()
    test_loss = 0
    correct = 0

    for samples, labels in test_loader:
        samples, labels = samples.to(device), labels.to(device)
        output = net(samples)
        test_loss += my_loss(output, labels, alpha_wd=alpha_wd, alpha_jr=alpha_jr, x=samples, bp_mat=te, alpha_ir1=alpha_ir).item()
        # index of the highest valued output is the predicted class
        preds = output.detach().max(1)[1]
        correct += preds.eq(labels).sum().item()

    # my_loss returns one value per batch, so the sum is averaged over the
    # number of batches rather than the number of samples
    test_loss /= max(len(test_loader), 1)
    test_accuracy = 100. * float(correct / len(test_loader.dataset))

    print(f'\tTest set accuracy: ({test_accuracy:.2f}%)')
    return test_loss, test_accuracy

### Training with interpretation regularization

In [56]:
# training details
n_epochs = 30
log_interval = 200
training_round = 2
torch.manual_seed(training_round)

# regularization strengths, defined once so that train() and test() always
# evaluate the same objective
alpha_wd = 0
alpha_jr = 0.001
alpha_ir = 200

# instantiate model and make it CUDA enabled *before* building the optimizer,
# as recommended by the torch.optim documentation
learning_rate = 0.01
momentum = 0.9
n_classes = 10
net = LeNet()
net.cuda()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
# learning rate is multiplied by 0.1 every 10 epochs
lr_decayer = StepLR(optimizer, step_size=10, gamma=0.1)

# grad_outputs matrices for the standard train / test batch sizes
tr = bp_matrix(tr_batch_size, n_classes)
te = bp_matrix(te_batch_size, n_classes)
for epoch in range(1, n_epochs + 1):
    print(f'Epoch #{epoch}')
    train(alpha_wd, alpha_jr, alpha_ir)
    test(alpha_wd, alpha_jr, alpha_ir)
    lr_decayer.step()

Epoch #1
	Loss: 4.847697 Average norm of Jacobian: 0.015538 Norm of difference in interpretations: 0.005444
	Loss: 2.661771 Average norm of Jacobian: 17.968513 Norm of difference in interpretations: 0.004571
	Loss: 2.449969 Average norm of Jacobian: 27.894032 Norm of difference in interpretations: 0.003753
	Loss: 2.397987 Average norm of Jacobian: 45.971424 Norm of difference in interpretations: 0.003798
	Loss: 2.159640 Average norm of Jacobian: 43.929634 Norm of difference in interpretations: 0.003663
	Test set accuracy: (96.71%)
Epoch #2
	Loss: 2.172888 Average norm of Jacobian: 47.361488 Norm of difference in interpretations: 0.003795
	Loss: 2.308616 Average norm of Jacobian: 60.109856 Norm of difference in interpretations: 0.003741
	Loss: 2.121992 Average norm of Jacobian: 51.011551 Norm of difference in interpretations: 0.003691
	Loss: 2.336782 Average norm of Jacobian: 40.847168 Norm of difference in interpretations: 0.003297
	Loss: 2.037063 Average norm of Jacobian: 37.391800 No

In [57]:
# Regularization strengths of the run being saved; they are interpolated into
# the checkpoint file name (the original f-string had empty `{}` fields, which
# is a SyntaxError).
# NOTE(review): these mirror the values passed to train()/test() in the
# training cell and my_loss's default alpha_ir2 -- keep them in sync.
saved_alpha_jr, saved_alpha_ir1, saved_alpha_ir2 = 0.001, 200, 0.0001
torch.save(net.state_dict(),
           f'trained_models/interp_reg_tests/jr{saved_alpha_jr}_ir1{saved_alpha_ir1}_ir2{saved_alpha_ir2}')