# Neural Anisotropy Directions

**Authors**: Guillermo Ortiz-Jimenez, Apostolos Modas, Seyed-Mohsen Moosavi-Dezfooli and Pascal Frossard

## Requirements

To run the code, make sure that you meet the following requirements:

* python (Successfully tested on v3.8.3)
* [PyTorch](https://pytorch.org/get-started/previous-versions/) (Successfully tested on v1.5.0 with CUDA v10.0.130)
* [Torchvision](https://pytorch.org/get-started/previous-versions/) (Successfully tested on v0.6.0 with CUDA v10.0.130)
* [NumPy](https://numpy.org/) (Successfully tested on v1.18.1)
* [Matplotlib](https://matplotlib.org/) (Successfully tested on v3.1.3)
* [Seaborn](http://seaborn.pydata.org/) (Successfully tested on v0.10.1)
* [Scikit-Learn](https://scikit-learn.org/stable/) (Successfully tested on v0.22.1)

In our experiments, every package was installed through a Conda environment. Assuming CUDA v10.0.130 and Conda v4.8.1 (installed through [Miniconda3](https://docs.conda.io/en/latest/miniconda.html) on CentOS Linux 7), these are the corresponding commands:

```conda create -n myenv python==3.8.3```  
```conda activate myenv```  
```conda install numpy==1.18.1```  
```conda install pytorch=1.5.0 torchvision=0.6.0 cudatoolkit=10.1 -c pytorch```  
```conda install matplotlib==3.1.3```  
```conda install seaborn==0.10.1```  
```conda install scikit-learn==0.22.1```

## Table of contents

- [General training setup](#training_setup)
- [Linearly separable dataset](#gen_linear_data)
- [NAD Computation](#vis_NADs)
- [Poisoning CIFAR-10 dataset](#poison)

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

from sklearn.utils import shuffle
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from models import TransformLayer
from models import LogReg, LeNet, VGG11_bn, ResNet18, DenseNet121

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# Device used throughout the notebook: CUDA whenever PyTorch can see a GPU,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

### <a name=training_setup>General training setup</a>

We first give the implementation of our main training procedure. Specifically, we optimize all networks with standard SGD, using a learning rate that decays linearly from its maximum value to zero over the course of training and clipping the gradient norm at 0.5. The main hyperparameters are:
- Number of training epochs: `epochs`
- Maximum learning rate: `max_lr`
- Momentum: `momentum`
- Weight decay: `weight_decay`

In [None]:
def train(model, trans, trainloader, testloader, epochs, max_lr, momentum, weight_decay,
          clip_norm=0.5):
    """Train ``model`` with SGD and a linearly decaying learning rate.

    The learning rate is interpolated linearly from ``max_lr`` at the start of
    training to 0 at the end of the last epoch, and is updated after every
    batch. Before each optimizer step the global gradient norm is clipped to
    ``clip_norm``. After every epoch the model is evaluated on ``testloader``
    with ``test`` and a summary line is printed.

    Args:
        model: network to optimise. It must live on ``DEVICE``, because every
            batch is moved there.
        trans: callable applied to each input batch before the forward pass
            (e.g. a normalisation layer).
        trainloader: iterable of ``(inputs, targets)`` training batches that
            supports ``len``.
        testloader: iterable of ``(inputs, targets)`` evaluation batches.
        epochs: number of passes over ``trainloader``.
        max_lr: initial (peak) learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        clip_norm: maximum global L2 norm of the gradients (default 0.5).

    Returns:
        The same ``model`` object, trained in place.
    """
    def lr_schedule(t):
        # ``t`` is measured in (fractional) epochs: max_lr at t=0, 0 at t=epochs.
        return np.interp([t], [0, epochs], [max_lr, 0])[0]

    opt = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=momentum, weight_decay=weight_decay)
    loss_fun = nn.CrossEntropyLoss()
    num_batches = len(trainloader)

    print('Starting training...')
    print()

    for epoch in range(epochs):
        print('Epoch', epoch)
        train_loss_sum = 0
        train_acc_sum = 0
        train_n = 0

        # ``test`` leaves the model in eval mode, so switch back every epoch.
        model.train()

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # Per-batch update so the learning rate decays smoothly within an epoch.
            lr = lr_schedule(epoch + (batch_idx + 1) / num_batches)
            opt.param_groups[0].update(lr=lr)

            output = model(trans(inputs))
            loss = loss_fun(output, targets)

            opt.zero_grad()
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()

            # Sample-weighted running sums, so the averages are per sample.
            train_loss_sum += loss.item() * targets.size(0)
            train_acc_sum += (output.max(1)[1] == targets).sum().item()
            train_n += targets.size(0)

            if batch_idx % 100 == 0:
                print('Batch idx: %d(%d)\tTrain Acc: %.3f%%\tTrain Loss: %.3f' %
                      (batch_idx, epoch, 100. * train_acc_sum / train_n, train_loss_sum / train_n))

        print('\nTrain Summary\tEpoch: %d | Train Acc: %.3f%% | Train Loss: %.3f' %
              (epoch, 100. * train_acc_sum / train_n, train_loss_sum / train_n))

        test_acc, test_loss = test(model, trans, testloader)
        print('Test  Summary\tEpoch: %d | Test Acc: %.3f%% | Test Loss: %.3f\n' % (epoch, test_acc, test_loss))

    return model

def test(model, trans, testloader):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    The model is put into eval mode and is left in that mode afterwards.

    Args:
        model: network to evaluate. It must live on ``DEVICE``.
        trans: callable applied to each input batch before the forward pass.
        testloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(test_acc, test_loss)``: the accuracy in percent and the mean
        cross-entropy loss, both averaged over all evaluated samples.

    Raises:
        ValueError: if ``testloader`` yields no samples, in which case the
            averages are undefined.
    """
    loss_fun = nn.CrossEntropyLoss()
    test_loss_sum = 0
    test_acc_sum = 0
    test_n = 0

    model.eval()

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            output = model(trans(inputs))
            loss = loss_fun(output, targets)

            # Weight the batch-mean loss by the batch size to get a per-sample average.
            test_loss_sum += loss.item() * targets.size(0)
            test_acc_sum += (output.max(1)[1] == targets).sum().item()
            test_n += targets.size(0)

    if test_n == 0:
        raise ValueError('testloader yielded no samples; cannot compute test metrics')

    test_loss = test_loss_sum / test_n
    test_acc = 100 * test_acc_sum / test_n
    return test_acc, test_loss

### <a name=gen_linear_data>Linearly separable dataset</a>


In our experiments we make extensive use of the family of linearly separable datasets $\mathcal{D}(v)$, parameterized by a unit vector $v\in\mathbb{S}^{D-1}$. We now give the code to generate a dataset from this distribution. In particular, the main hyperparameters of the dataset are:
- Number of samples: `num_samples`
- Noise standard deviation: `sigma`
- Size of discriminative feature: `epsilon`
- Shape of the data: `shape`

In [2]:
class DirectionalLinearDataset(data.Dataset):
    """Two-class, linearly separable dataset whose only discriminative feature lies along ``v``.

    Each sample is isotropic Gaussian noise (standard deviation ``sigma``) with
    its component along ``v`` removed. Label 0 samples then get
    ``+epsilon / 2 * v`` added, and label 1 samples get ``-epsilon / 2 * v``.
    When ``v`` has unit norm, the removal is an orthogonal projection, so every
    sample's projection onto ``v`` is exactly ``±epsilon / 2``.

    Args:
        v: array with the same shape as ``shape`` (a NumPy array in this
            notebook). It should have unit L2 norm; the callers here normalise
            it before constructing the dataset.
        num_samples: total number of samples. Label 0 gets the extra sample
            when the count is odd.
        sigma: standard deviation of the Gaussian noise.
        epsilon: distance between the two class means along ``v``.
        shape: per-sample shape, given as a list or a tuple.
    """

    def __init__(self,
                 v,
                 num_samples=10000,
                 sigma=3,
                 epsilon=1,
                 shape=(1, 32, 32)
                ):
        super().__init__()
        self.v = v
        self.num_samples = num_samples
        self.sigma = sigma
        self.epsilon = epsilon
        self.shape = shape
        self.data, self.targets = self._generate_dataset(self.num_samples)

    def __getitem__(self, index):
        """Return ``(sample, label)``, with the label as a Python ``int``."""
        img, target = self.data[index], int(self.targets[index])

        return img, target

    def __len__(self):
        return self.num_samples

    def _generate_dataset(self, n_samples):
        """Return ``(samples, labels)`` as float32 and int64 tensors.

        All label 0 samples come first, followed by the label 1 samples.
        """
        n_plus = n_samples // 2 + n_samples % 2
        n_minus = n_samples // 2
        # Class 0 is drawn first, so the random stream is consumed in the same order as before.
        data_plus = self._generate_samples(n_plus, 0).astype(np.float32)
        data_minus = self._generate_samples(n_minus, 1).astype(np.float32)
        samples = np.concatenate([data_plus, data_minus])
        # ``np.long`` was removed in NumPy 1.24; int64 is the label dtype CrossEntropyLoss expects.
        labels = np.concatenate([np.zeros(n_plus, dtype=np.int64),
                                 np.ones(n_minus, dtype=np.int64)])

        return torch.from_numpy(samples), torch.from_numpy(labels)

    def _generate_samples(self, n_samples, label):
        """Draw ``n_samples`` samples of class ``label``.

        Class 0 is centred at ``+epsilon/2 * v``; any other label is centred at ``-epsilon/2 * v``.
        """
        noise = self._generate_noise_floor(n_samples)
        sign = 1 if label == 0 else -1
        return sign * self.epsilon / 2 * self.v[np.newaxis] + self._project_orthogonal(noise)

    def _generate_noise_floor(self, n_samples):
        """Return ``sigma``-scaled standard normal noise of shape ``[n_samples] + shape``."""
        # ``list`` makes a tuple ``shape`` (the default) work as well as a list.
        noise_shape = [n_samples] + list(self.shape)
        return self.sigma * np.random.randn(*noise_shape)

    def _project(self, x):
        """Return ``<x_i, v> v`` for every sample ``x_i`` in the batch ``x``."""
        v_flat = np.reshape(self.v, [-1])
        # Explicit feature count (not -1) so that empty batches reshape cleanly.
        coeffs = np.reshape(x, [x.shape[0], v_flat.size]) @ v_flat
        # Broadcast the per-sample coefficients against v, whatever v's rank.
        coeffs = np.reshape(coeffs, [-1] + [1] * np.ndim(self.v))
        return coeffs * self.v[np.newaxis]

    def _project_orthogonal(self, x):
        """Remove from every sample in ``x`` its component along ``v``."""
        return x - self._project(x)
    

def generate_synthetic_data(v,
                            num_train=10000,
                            num_test=10000,
                            sigma=3,
                            epsilon=1,
                            shape=(1, 32, 32),
                            batch_size=128):
    """Build independent train/test ``DirectionalLinearDataset``s along ``v`` and their loaders.

    Args:
        v: discriminative direction, with the same shape as ``shape``.
        num_train: number of training samples.
        num_test: number of test samples.
        sigma: noise standard deviation.
        epsilon: distance between the two class means along ``v``.
        shape: per-sample shape (list or tuple).
        batch_size: batch size of both loaders.

    Returns:
        ``(trainloader, testloader, trainset, testset)``. Only the training
        loader shuffles.
    """
    # The dataset builds ``[n_samples] + shape``, so always hand it a list.
    shape = list(shape)

    trainset = DirectionalLinearDataset(v,
                                        num_samples=num_train,
                                        sigma=sigma,
                                        epsilon=epsilon,
                                        shape=shape)

    # The test set is sized by ``num_test``, independently of the training set.
    testset = DirectionalLinearDataset(v,
                                       num_samples=num_test,
                                       sigma=sigma,
                                       epsilon=epsilon,
                                       shape=shape)

    trainloader = DataLoader(trainset,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=2,
                             batch_size=batch_size)

    testloader = DataLoader(testset,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=2,
                            batch_size=batch_size
                            )

    return trainloader, testloader, trainset, testset

In [18]:
# Direction given by a single 2-D Fourier mode. We set a unit imaginary part at
# frequency (3, 4) of the half-spectrum of a 1x32x32 image and invert it.
# NumPy's FFT is used because torch.rfft / torch.irfft were removed in
# PyTorch 1.8. The overall scale is irrelevant after the normalisation below.
v_fft = np.zeros([1, 32, 32 // 2 + 1], dtype=np.complex128)
v_fft[0, 3, 4] = 1j  # Select coordinate in fourier space
v = np.fft.irfft2(v_fft, s=[32, 32])
v = v / np.linalg.norm(v)
trainloader, testloader, trainset, testset = generate_synthetic_data(v,
                                                                     num_train=10000,
                                                                     num_test=10000,
                                                                     sigma=3,
                                                                     epsilon=1,
                                                                     shape=[1, 32, 32],
                                                                     batch_size=128)

Let's now train a LeNet on a `DirectionalLinearDataset` whose single discriminative feature points in a random direction. We should expect poor test accuracy, since this direction is most likely not aligned with the directional inductive bias of the architecture.

In [None]:
# Random unit-norm direction for the single discriminative feature.
v = np.random.randn(1, 32, 32)
v /= np.linalg.norm(v)

trainloader, testloader, trainset, testset = generate_synthetic_data(
    v,
    num_train=10000,
    num_test=10000,
    sigma=3,
    epsilon=1,
    shape=[1, 32, 32],
    batch_size=128,
)

# Other architectures from ``models`` can be swapped in here:
#   LogReg(input_dim=32 * 32, num_classes=2)
#   VGG11_bn(num_channels=1, num_classes=2)
#   ResNet18(num_channels=1, num_classes=2)
#   DenseNet121(num_channels=1, num_classes=2)
net = LeNet(num_channels=1, num_classes=2).to(DEVICE)

# mean 0 / std 1: presumably an identity normalisation (TransformLayer is
# defined in ``models`` — confirm there).
trained_model = train(
    model=net,
    trans=TransformLayer(mean=torch.tensor(0., device=DEVICE),
                         std=torch.tensor(1., device=DEVICE)),
    trainloader=trainloader,
    testloader=testloader,
    epochs=20,
    max_lr=0.5,
    momentum=0,
    weight_decay=0,
)

### <a name=vis_NADs>NAD Computation</a>

We now give the code to compute the NADs using the eigendecomposition of the gradient covariance. 

In [None]:
def input_numerical_jacobian(fn, x, scale, device, batch_size=None):
    """Central finite-difference gradient of ``fn(.)[:, 0]`` with respect to ``x``.

    For every coordinate ``i`` of ``x`` this computes
    ``(fn(x + scale * e_i) - fn(x - scale * e_i)) / (2 * scale)`` and keeps the
    first output column. Here ``e_i`` is the i-th standard basis vector.
    The perturbed copies of ``x`` are evaluated ``batch_size`` at a time.

    Args:
        fn: callable mapping a batch of shape ``(B, *x.shape)`` to an output
            that can be indexed as ``[:, 0]`` (e.g. ``(B, 1)`` logits).
        x: a single input point, without a batch dimension, on ``device``.
        scale: finite-difference step size.
        device: device on which the perturbed batches are built.
        batch_size: number of coordinates perturbed per call of ``fn``.
            ``None`` means all coordinates in one call.

    Returns:
        A float32 CPU tensor with the same shape as ``x``.
    """
    shape = list(x.shape)
    n_dims = int(np.prod(shape))
    batch_size = max(n_dims, 1) if batch_size is None else batch_size
    jac = torch.zeros(n_dims)

    # Finite differences need no autograd graph; this saves activation memory.
    with torch.no_grad():
        for start in range(0, n_dims, batch_size):
            stop = min(start + batch_size, n_dims)
            # Rows start..stop-1 of the identity, built directly on ``device``
            # instead of materialising the full n_dims x n_dims identity.
            rows = torch.arange(stop - start, device=device)
            basis = torch.zeros(stop - start, n_dims, device=device)
            basis[rows, rows + start] = 1
            step = scale * basis.view([-1] + shape)

            diff = (fn(x[None] + step) - fn(x[None] - step)) / (2 * scale)
            jac[start:stop] = diff.detach().cpu()[:, 0]

    return jac.view(shape)


class GradientCovarianceAnisotropyFinder:
    """Estimate neural anisotropy directions (NADs) from input-gradient statistics.

    ``num_networks`` models are produced by ``model_gen_fun``. Each one is
    differentiated by central finite differences with respect to its input at
    ``eval_point``. PCA of the stacked, flattened gradients then gives the
    principal directions of the gradient covariance.

    Args:
        model_gen_fun: zero-argument callable returning a new model; the first
            output column is differentiated.
        num_networks: number of models, i.e. number of gradient samples.
        eval_point: input (without batch dimension) at which gradients are
            taken. It must be set before sampling.
        k: ``n_components`` passed to sklearn's ``PCA``.
        scale: finite-difference step passed to ``input_numerical_jacobian``.
        device: device on which models are evaluated.
        batch_size: perturbations evaluated per forward pass (``None`` = all).
    """

    def __init__(self, 
                 model_gen_fun,
                 num_networks,
                 eval_point=None,
                 k=None,
                 scale=1,
                 device='cpu',
                 batch_size=None):

        self.model_gen_fun = model_gen_fun
        self.num_networks = num_networks
        self.eval_point = eval_point
        self.scale = scale
        self.k = k
        self.device = device
        self.batch_size = batch_size
        # Cached list of flattened per-network gradients; None until sampled.
        self._gradients = None

    def _numerical_input_derivative(self, model, v0):
        """Finite-difference gradient of ``-model(.)[:, 0]`` at ``v0``."""
        model = model.to(self.device)
        # The sign flip only flips the sign of the resulting PCA directions.
        fn = lambda x: -model(x)
        with torch.no_grad():
            return input_numerical_jacobian(fn, v0, self.scale, self.device, batch_size=self.batch_size)

    @property
    def sample_gradients(self):
        """Return a ``(num_networks, D)`` NumPy array of flattened gradients.

        The gradients are computed on first access and cached afterwards.
        """
        if self._gradients is None:
            if self.eval_point is None:
                raise ValueError('eval_point must be set before sampling gradients')
            # Collect into a local list and cache it only once complete, so an
            # interrupted run is not silently reused as a truncated sample.
            gradients = [
                self._numerical_input_derivative(self.model_gen_fun(), self.eval_point).cpu().view([-1])
                for _ in range(self.num_networks)
            ]
            self._gradients = gradients
        return torch.stack(self._gradients).numpy()

    def estimate_NADs(self):
        """Fit PCA to the sampled gradients.

        Returns:
            ``(pca.singular_values_, pca.components_)`` from sklearn's PCA.
            The first element holds the singular values of the centred
            gradient matrix; the covariance eigenvalues are
            ``singular_values**2 / (num_networks - 1)``. The second element
            holds the unit-norm principal directions (the NADs) as rows.
        """
        pca = PCA(n_components=self.k)
        pca.fit(self.sample_gradients)
        return pca.singular_values_, pca.components_

Let's estimate the NADs of a LeNet.

In [None]:
def model_gen_fun():
    """Build a new single-channel LeNet with one output and switch it to eval mode."""
    lenet = LeNet(num_classes=1, num_channels=1)
    lenet.eval()
    return lenet


# 10,000 LeNets from ``model_gen_fun``, all differentiated at the same random
# input with a finite-difference step of 100; PCA keeps 1024 components.
anisotropy_finder = GradientCovarianceAnisotropyFinder(
    model_gen_fun=model_gen_fun,
    num_networks=10000,
    eval_point=torch.randn([1, 32, 32], device=DEVICE),
    k=1024,
    scale=100,
    device=DEVICE,
    batch_size=None,
)

# ``eigenvalues`` receives sklearn's PCA ``singular_values_``; ``NADs`` holds
# the principal directions, one per row.
eigenvalues, NADs = anisotropy_finder.estimate_NADs()

Next, we visualize the first five NADs together with their Fourier power spectra.

In [None]:
indices = list(range(5))

# Each group of up to five NADs uses two rows: spatial patterns on top, power
# spectra below. The row count must be an int; newer Matplotlib versions
# reject float grid sizes.
n_cols = 5
n_rows = 2 * int(np.ceil(len(indices) / n_cols))
cmap = sns.cubehelix_palette(8, start=.5, rot=-.75, as_cmap=True, reverse=True)

plt.figure(figsize=(15,5))

for n, index in enumerate(indices):
    x = NADs[index].reshape([32, 32])

    # Symmetric colour range, so that zero sits at the centre of the diverging map.
    vmax = np.abs(x).max()
    vmin = -vmax

    x_fft = np.fft.fftshift(np.fft.fft2(x))
    # Subplot slot of this NAD's spatial plot; its spectrum goes one row below.
    spatial_pos = n + n_cols * (n // n_cols) + 1

    plt.subplot(n_rows, n_cols, spatial_pos)
    plt.imshow(x, cmap='BrBG', vmin=vmin, vmax=vmax)
    plt.title(r'Index %d'%index)
    plt.axis('off')

    plt.subplot(n_rows, n_cols, spatial_pos + n_cols)
    plt.imshow(np.abs(x_fft)**2, cmap=cmap)
    plt.axis('off')

The next cells assume that precomputed NADs are stored in `./NADs/`, one file per architecture named `{architecture}_NADs.npy`, and let you visualize them.

In [None]:
# Folder with the precomputed NAD files, named '{architecture}_NADs.npy'.
NAD_dir = './NADs/'
# Any of: 'LogReg', 'LeNet', 'VGG11', 'ResNet18', 'DenseNet121'.
architecture = 'ResNet18'

# Which NADs to plot
indices = [0, 1, 2, 3, 4]

In [None]:
NAD_path = NAD_dir + architecture + '_NADs.npy'

# Reshaped so that NADs[i] is the i-th NAD as a 32x32 image.
NADs = np.load(NAD_path).reshape([-1, 32, 32])

# Two rows per group of up to five NADs (spatial, then spectrum). The row
# count must be an int; newer Matplotlib versions reject float grid sizes.
n_cols = 5
n_rows = 2 * int(np.ceil(len(indices) / n_cols))
cmap = sns.cubehelix_palette(8, start=.5, rot=-.75, as_cmap=True, reverse=True)

plt.figure(figsize=(15,5))

for n, index in enumerate(indices):
    x = NADs[index]

    # Symmetric colour range, so that zero sits at the centre of the diverging map.
    vmax = np.abs(x).max()
    vmin = -vmax

    x_fft = np.fft.fftshift(np.fft.fft2(x))
    # Subplot slot of this NAD's spatial plot; its spectrum goes one row below.
    spatial_pos = n + n_cols * (n // n_cols) + 1

    plt.subplot(n_rows, n_cols, spatial_pos)
    plt.imshow(x, cmap='BrBG', vmin=vmin, vmax=vmax)
    plt.title(r'Index %d'%index)
    plt.axis('off')

    plt.subplot(n_rows, n_cols, spatial_pos + n_cols)
    plt.imshow(np.abs(x_fft)**2, cmap=cmap)
    plt.axis('off')

### <a name=poison>Poisoning CIFAR-10 dataset</a>

Finally, we give the implementation of the poisoning experiment of Fig. 7.

First, we load CIFAR-10 and modify its training set to include a poisonous carrier.

In [None]:
def load_cifar_data(path,
                    batch_size=128):
    """Load CIFAR-10 from ``path``, downloading it if necessary.

    Both splits are only converted with ``ToTensor``; normalisation is left to
    the caller. For that purpose the caller receives the per-channel ``mean``
    and ``std`` as ``(1, 3, 1, 1)`` tensors on ``DEVICE``.

    Args:
        path: root directory for the torchvision CIFAR-10 files.
        batch_size: batch size of both loaders.

    Returns:
        ``(trainloader, testloader, trainset, testset, mean, std)``. Only the
        training loader shuffles.
    """
    trainset, testset = (
        torchvision.datasets.CIFAR10(root=path, download=True, train=is_train,
                                     transform=transforms.Compose([transforms.ToTensor()]))
        for is_train in (True, False)
    )

    loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)
    testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

    # Per-channel CIFAR-10 statistics, shaped to broadcast over (N, C, H, W) batches.
    mean, std = (
        torch.as_tensor(stats, dtype=torch.float, device=DEVICE)[None, :, None, None]
        for stats in ([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
    )

    return trainloader, testloader, trainset, testset, mean, std


def poison_with_NADs(trainset, NAD_idx, epsilon, NAD_path, num_classes=10, num_channels=3, batch_size=128):
    x = torch.from_numpy(trainset.data.transpose([0, 3, 1, 2])).type(torch.float) / 255.
    y = torch.tensor(trainset.targets, dtype=torch.long)
    shape = x.shape[1:]
    
    V = np.load(NAD_path)
    V = torch.from_numpy(V)
    poison_indices = (NAD_idx, NAD_idx + 1)

    x_poison = x.clone()
    for t in range(num_classes):
        idx = poison_indices[t // (num_channels * 2)]
        channel_idx = t % num_channels
        sign = 2 * (t % 2) - 1
        carrier = torch.zeros_like(x[0])
        carrier[channel_idx] = V[idx].view([1, shape[-2], shape[-1]])
        x_bias = torch.einsum('bi, i->b', x[y == t].view([-1, np.prod(shape)]), carrier.view(-1))
        x_poison[y == t] += (epsilon * sign - x_bias[:, None, None, None]) * carrier[None, :, :, :]

    poisonset = torch.utils.data.TensorDataset(x_poison, y)
    poisonloader = torch.utils.data.DataLoader(poisonset, batch_size=batch_size, shuffle=True, num_workers=2,
                                               pin_memory=True)

    return poisonloader

We then train on the poisoned training set and test on the original CIFAR-10 test set.

In [None]:
architecture = 'ResNet18'

net = ResNet18(num_channels=3, num_classes=10).to(DEVICE)

CIFAR_path = './'
NAD_path = NAD_dir + architecture + '_NADs.npy'

# Carriers start at NAD number ``poison_idx``; ``epsilon`` is the planted projection size.
poison_idx = 0
epsilon = 0.05

trainloader, testloader, trainset, testset, mean, std = load_cifar_data(CIFAR_path)

# Train on the poisoned images, but evaluate on the clean CIFAR-10 test set.
poisonloader = poison_with_NADs(trainset, NAD_idx=poison_idx, epsilon=epsilon, NAD_path=NAD_path)

trained_model = train(
    model=net,
    trans=TransformLayer(mean=mean, std=std),
    trainloader=poisonloader,
    testloader=testloader,
    epochs=50,
    max_lr=0.21,
    momentum=0.9,
    weight_decay=5e-4,
)