In [1]:
import torch,torchvision,os,pyhessian,time
import torchvision.transforms as transforms
import numpy as np
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)

Using cache found in /home/zihao/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [2]:
def get_loader(dset,batch_size,test_size=None,test_batch_size=500):
    if dset == 'MNIST':
        transf = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
        # MNIST Dataset (Images and Labels)
        train_dataset = torchvision.datasets.MNIST(root =os.path.abspath('../data'),train=True,transform=transf, download=True)
        test_dataset = torchvision.datasets.MNIST(root =os.path.abspath('../data'),train=False,transform=transf, download=False)
        if test_size is not None and test_size < 10000:
            np.random.seed(42)
            sampled_index=np.random.choice(10000,test_size)
            test_dataset.data = torch.tensor(np.array(test_dataset.data)[sampled_index])
            test_dataset.targets = torch.tensor(np.array(test_dataset.targets)[sampled_index])


        # Dataset Loader (Input Pipeline)
        train_loader = torch.utils.data.DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle = True)
        test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size=test_batch_size, shuffle = False)

    elif dset == 'CIFAR10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False,download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,shuffle=False)
    
    elif dset == 'CIFAR100':
        CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
        CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
        ])
        train_dataset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        test_dataset = torchvision.datasets.CIFAR100(root='../data', train=False,download=True, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,shuffle=False)
        
    
    else:
        print('unrecogonized dataset')
        exit(1)

    return train_loader,test_loader

def test(val_loader,model,noise_std,repeat,device,imgSize=28,imgFlat=True,debug=False,lossfunc=torch.nn.CrossEntropyLoss()):
    model.eval()
    loss_list = []
    acc_list = []
    start = time.time()
    for test in range(repeat):
        if noise_std > 0:
            model.generate_variation(noise_std=noise_std)
        # performance on testset
        correct = 0
        total = 0
        accumulative_loss = 0
        count = 0

        for t_images, t_labels in val_loader:
            count += 1
            if imgFlat:
                t_images = t_images.view(-1,imgSize**2)
            t_images = t_images.to(device)
            t_outputs = model(t_images)
            t_labels = t_labels.to(device)
            t_loss = lossfunc(t_outputs,t_labels)
            accumulative_loss += t_loss.data.item()
            _, t_predicted = torch.max(t_outputs.data, 1)
            total += t_labels.size(0)
            correct += (t_predicted == t_labels).sum()
        acc = (correct.data.item()/ total)
        loss_list.append(accumulative_loss/count)
        acc_list.append(acc)

        if debug:
            print("test %s/%s [%s batches %.4f seconds]:"%(test+1,repeat,count,time.time()-start))
            start = time.time()
            print("loss %.4f acc %.4f"%(accumulative_loss/count,acc))
    end = time.time()
    loss_list = np.array(loss_list)
    acc_list = np.array(acc_list)
    # statistics
    qtl_loss = np.quantile(loss_list,0.95)
    mean_loss = loss_list.mean()
    qtl_acc = np.quantile(acc_list,0.05)
    mean_acc = acc_list.mean()

    return {'mean_acc':mean_acc,'qtl_acc':qtl_acc,'mean_loss':mean_loss,'qtl_loss':qtl_loss,
            'test time':end-start,'acc_list':acc_list,'loss_list':loss_list}

In [3]:
train_loader,test_loader = get_loader('cifar10'.upper(),batch_size=1000,test_batch_size=1000)
train_loader.num_workers = 4
test_loader.num_workers = 4
model.cuda()

Files already downloaded and verified
Files already downloaded and verified


CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [4]:
test(test_loader,model,repeat=1,noise_std=0,device='cuda',imgSize=32,imgFlat=False)

{'mean_acc': 0.926,
 'qtl_acc': 0.926,
 'mean_loss': 0.2815205633640289,
 'qtl_loss': 0.2815205633640289,
 'test time': 0.848832368850708,
 'acc_list': array([0.926]),
 'loss_list': array([0.28152056])}

In [5]:
def get_params_grad(model):
    """
    get model parameters and corresponding gradients
    """
    params = []
    grads = []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        params.append(param)
        grads.append(0. if param.grad is None else param.grad + 0.)
    return params, grads

def hessian_vector_product(gradsH, params, v):
    """
    compute the hessian vector product of Hv, where
    gradsH is the gradient at the current point,
    params is the corresponding variables,
    v is the vector.
    """
    hv = torch.autograd.grad(gradsH,
                             params,
                             grad_outputs=v,
                             only_inputs=True,
                             retain_graph=True)
    return hv

class hessian():
    """
    The class used to compute :
        i) the top 1 (n) eigenvalue(s) of the neural network
        ii) the trace of the entire neural network
        iii) the estimated eigenvalue density
    """

    def __init__(self, model, criterion=torch.nn.CrossEntropyLoss(), data=None, dataloader=None, cuda=True):
        """
        model: the model that needs Hessain information
        criterion: the loss function
        data: a single batch of data, including inputs and its corresponding labels
        dataloader: the data loader including bunch of batches of data
        """

        # make sure we either pass a single batch or a dataloader
        assert (data != None and dataloader == None) or (data == None and
                                                         dataloader != None)

        self.model = model.eval()  # make model is in evaluation model
        self.criterion = criterion

        if data != None:
            self.data = data
            self.full_dataset = False
        else:
            self.data = dataloader
            self.full_dataset = True

        if cuda:
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # pre-processing for single batch case to simplify the computation.
        if not self.full_dataset:
            self.inputs, self.targets = self.data
            if self.device == 'cuda':
                self.inputs, self.targets = self.inputs.cuda(
                ), self.targets.cuda()

            # if we only compute the Hessian information for a single batch data, we can re-use the gradients.
            outputs = self.model(self.inputs)
            loss = self.criterion(outputs, self.targets)
            loss.backward(create_graph=True)

        # this step is used to extract the parameters from the model
        params, gradsH = get_params_grad(self.model)
        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

In [6]:
for x,y in train_loader:
    break

In [7]:
hcp = hessian(model,data=(x,y))

  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag


In [9]:
hessian_vector_product(hcp.gradsH, hcp.params, [torch.randn(p.size()).cuda() for p in hcp.params])

(tensor([[[[ 3.3758e-02,  6.0048e-01,  8.3159e-01],
           [-1.6497e-01,  3.3124e-01,  5.3094e-01],
           [-3.3000e-01, -3.1138e-02,  1.6916e-01]],
 
          [[-1.0327e-01,  4.7779e-01,  7.4169e-01],
           [-1.9932e-01,  2.9301e-01,  5.6546e-01],
           [-2.9235e-01, -1.2050e-02,  2.4859e-01]],
 
          [[-1.5909e-01,  3.0704e-01,  5.4434e-01],
           [-7.3827e-02,  2.8298e-01,  4.6310e-01],
           [-1.4334e-01,  4.8607e-02,  2.1971e-01]]],
 
 
         [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]],
 
 
         [[[-8.3868e-01, -1.5898e+00, -1

In [10]:
hcp.gradsH

[tensor([[[[-1.3633e-03, -1.3342e-03, -1.6661e-03],
           [-1.4026e-03, -1.2013e-03, -1.4200e-03],
           [-9.3290e-04, -8.7458e-04, -1.2753e-03]],
 
          [[-2.0535e-04, -2.0291e-04, -5.5659e-04],
           [-5.2124e-05,  8.9158e-05, -2.4677e-04],
           [ 2.4134e-04,  1.6613e-04, -3.8825e-04]],
 
          [[-3.6658e-04, -3.0894e-04, -5.1479e-04],
           [-2.9498e-04, -1.5276e-04, -5.1798e-04],
           [-3.0475e-05, -3.2652e-04, -1.1257e-03]]],
 
 
         [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]],
 
 
         [[[ 2.5894e-04,  3.5049e-04,  1

In [16]:
import torch.nn as nn

class myLinear(nn.Linear):
    def __init__(self,layer):
        super().__init__(in_features=layer.in_features,out_features=layer.out_features)
        self.weight.data = layer.weight
        

In [18]:
l = myLinear(nn.Linear(10,20))

In [19]:
dir(l)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__constants__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_version',
 'add_m