Commit

Merge pull request #68 from amzn/simindex

Added similarity index folder

adamian committed Mar 19, 2020
2 parents ef5c201 + 59685c7 commit 94c2416
Showing 8 changed files with 748 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -9,7 +9,8 @@ Each folder in this repository corresponds to a method or tool for transfer/meta-learning

In more detail:
- [xfer-ml](xfer-ml): A library that allows quick and easy transfer of knowledge stored in deep neural networks implemented in MXNet. xfer-ml can be used with data of arbitrary numeric format, and can be applied to the common cases of image or text data. It can be used as a pipeline that spans from extracting features to training a repurposer. The repurposer is then an object that carries out predictions in the target task. You can also use individual components of the library as part of your own pipeline. For example, you can leverage the feature extractor to extract features from deep neural networks, or use ModelHandler, which allows for quick building of neural networks even if you are not an MXNet expert.
- [leap](leap): MXNet implementation of "leap", the meta-gradient path learner published in ICLR 2019: [(link)](https://arxiv.org/abs/1812.01054) by S. Flennerhag, P. G. Moreno, N. Lawrence, A. Damianou.
- [nn_similarity_index](nn_similarity_index): PyTorch code for comparing trained neural networks using both feature and gradient information.


Navigate to the corresponding folder for more details.
58 changes: 58 additions & 0 deletions nn_similarity_index/README.md
@@ -0,0 +1,58 @@
## Similarity of Neural Networks with Gradients
--------------------------------------------------------------------------------

## Introduction

This folder contains code for comparing trained neural networks using both feature and gradient information. The implementation relies on the following three files:

*sketched_kernels.py* computes the sketched kernel matrices of individual residual blocks based on a pretrained ImageNet model and a given dataset.

*sim_indices.py* computes the similarity scores between two residual blocks.

*utils.py* provides two helper functions: *load_model* for loading an ImageNet model and *load_dataset* for creating a dataloader object.
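
For orientation, here is a minimal sketch of how these pieces fit together programmatically; it simply mirrors what *cwt_kernel_mat.py* and *compute_similarity.py* do, and the dataset path, model name and sketching hyperparameters below are placeholders:

```
import torch
from utils import load_model, load_dataset
from sketched_kernels import SketchedKernels
from sim_indices import SimIndex

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataloader over CIFAR-10 test images resized to the ImageNet input size (224)
loader = load_dataset('cifar10', 'test', 256, 'data/', 224)
net = load_model(device, 'resnet18', pretrained=True)
net.eval()  # fix batchnorm statistics and disable dropout

# One sketched kernel matrix per residual block
# (128 buckets, each sample sketched into 1 bucket, print progress every 10 batches)
csm = SketchedKernels(net, loader, 224, device, 128, 1, 10)
csm.compute_sketched_kernels()

# CKA similarity between the first two residual blocks of the same network
score = SimIndex().cka(csm.kernel_matrices[0], csm.kernel_matrices[1])
print(score)
```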

## Requirements
```
python >= 3.5
torch >= 1.0
torchvision
numpy
```

## Example
Generate our proposed kernel matrices for individual residual blocks
given a pretrained ImageNet model and a dataset (cifar10 below):
```
CUDA_VISIBLE_DEVICES=0 python -u cwt_kernel_mat.py \
--datapath data/ \
--modelname resnet18 \
--pretrained \
--seed 1111 \
--task cifar10 \
--split test \
--bsize 256 \
--num-buckets-sketching 128 \
--num-buckets-per-sample 1
```

Given sketched kernel matrices calculated on one dataset (cifar10 below),
compute a heatmap in which each entry is the similarity score between two residual blocks:
```
python -u compute_similarity.py \
--loadpath sketched_kernel_mat/ \
--filename1 resnet18_test_cifar10_1111.npy \
--simindex cka
```

Given sketched kernel matrices calculated on two datasets (cifar10 and cifar100 below),
compute a heatmap in which each entry is the similarity score between two residual blocks:
```
python -u compute_similarity.py \
--loadpath sketched_kernel_mat/ \
--filename1 resnet18_test_cifar10_1111.npy \
--filename2 resnet18_test_cifar100_1111.npy \
--simindex cka
```
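
Both invocations write the matrix of scores to *heatmap.npy* inside `--loadpath`. A small sketch for inspecting it (matplotlib is optional here and not listed in the requirements):

```
import numpy as np
import matplotlib.pyplot as plt

# rows and columns index the residual blocks being compared
heatmap = np.load('sketched_kernel_mat/heatmap.npy')
print(heatmap.shape)

plt.imshow(heatmap)
plt.colorbar()
plt.savefig('heatmap.png')
```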

## Authors
Shuai Tang
54 changes: 54 additions & 0 deletions nn_similarity_index/compute_similarity.py
@@ -0,0 +1,54 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import os
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import argparse
from sim_indices import SimIndex


if __name__ == "__main__":

# Get arguments from the command line
    parser = argparse.ArgumentParser(description='Compute similarity scores between sketched kernel matrices')

    parser.add_argument('--loadpath', type=str,
                        help='path to the folder that contains the kernel matrix files')
    parser.add_argument('--filename1', type=str,
                        help='name of the file that contains kernel matrices for the first network')
    parser.add_argument('--filename2', type=str, default=None,
                        help='name of the file that contains kernel matrices for the second network (optional; defaults to filename1)')
    parser.add_argument('--simindex', type=str, choices=['euclidean', 'cka', 'nbs'], default='cka',
                        help='similarity index to use in computing the scores')

args = parser.parse_args()

# load the file that contains kernel matrices of individual residual blocks
kernel_matrices_1 = np.load(args.loadpath + args.filename1, allow_pickle=True).item()
kernel_matrices_2 = np.load(args.loadpath + args.filename2, allow_pickle=True).item() if args.filename2 else kernel_matrices_1

n_resblocks_1 = len(kernel_matrices_1)
n_resblocks_2 = len(kernel_matrices_2)
sim_scores = np.zeros((n_resblocks_1, n_resblocks_2))

simindices = SimIndex()
func_ = getattr(simindices, args.simindex)

for layer_id1 in range(n_resblocks_1):
for layer_id2 in range(n_resblocks_2):
sim_scores[layer_id1, layer_id2] = func_(kernel_matrices_1[layer_id1], kernel_matrices_2[layer_id2])

np.save(args.loadpath + 'heatmap.npy', sim_scores)
103 changes: 103 additions & 0 deletions nn_similarity_index/cwt_kernel_mat.py
@@ -0,0 +1,103 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import os
os.environ["OMP_NUM_THREADS"] = "1"

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
from abc import ABC
import os
import argparse
from sketched_kernels import SketchedKernels

from utils import *


if __name__ == "__main__":

# Get arguments from the command line
parser = argparse.ArgumentParser(description='PyTorch CWT sketching kernel matrices')

parser.add_argument('--datapath', type=str,
help='absolute path to the dataset')
parser.add_argument('--modelname', type=str,
help='model name')
parser.add_argument('--pretrained', action='store_true',
help='whether to load a pretrained ImageNet model')

parser.add_argument('--seed', default=0, type=int,
help='random seed for sketching')
parser.add_argument('--task', default='cifar10', type=str, choices=['cifar10', 'cifar100', 'svhn', 'stl10'],
help='the name of the dataset, cifar10 or cifar100 or svhn or stl10')
parser.add_argument('--split', default='train', type=str,
help='split of the dataset, train or test')
parser.add_argument('--bsize', default=512, type=int,
help='batch size for computing the kernel')

parser.add_argument('--M', '--num-buckets-sketching', default=512, type=int,
help='number of buckets in Sketching')
parser.add_argument('--T', '--num-buckets-per-sample', default=1, type=int,
help='number of buckets each data sample is sketched to')

parser.add_argument('--freq_print', default=10, type=int,
help='frequency for printing the progress')

args = parser.parse_args()

# Set the backend and the random seed for running our code
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(args.seed)
if device == 'cuda':
cudnn.benchmark = True
torch.cuda.manual_seed(args.seed)

# The size of images for training and testing ImageNet models
imgsize = 224

# Generate a dataloader that iteratively reads data
# Load a model, either pretrained or not
loader = load_dataset(args.task, args.split, args.bsize, args.datapath, imgsize)
    net = load_model(device, args.modelname, pretrained=args.pretrained)

    # Set the model to evaluation mode. VERY IMPORTANT!
    # This step fixes the running statistics of the batchnorm layers
    # and disables dropout layers
net.eval()

csm = SketchedKernels(net, loader, imgsize, device, args.M, args.T, args.freq_print)
csm.compute_sketched_kernels()

    # Print the norm of the kernel mean embedding of each residual block
for layer_id in range(len(csm.kernel_matrices)):
nkme = (csm.kernel_matrices[layer_id].sum() ** 0.5) / csm.n_samples
print("The norm of the kernel mean embedding of layer {:d} is {:.4f}".format(layer_id, nkme))

del net, loader
torch.cuda.empty_cache()

# Save the sketched kernel matrices
savepath = 'sketched_kernel_mat/'
if not os.path.isdir(savepath):
os.mkdir(savepath)

save_filename = '{}_{}_{}_{}.npy'.format(args.modelname, args.split, args.task, args.seed)
np.save(savepath + save_filename, csm.kernel_matrices)
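
Note: the saved `.npy` file holds a Python dict with one sketched kernel matrix per residual block, which is why *compute_similarity.py* loads it with `allow_pickle=True` and `.item()`. A minimal sketch for inspecting such a file (the filename below matches the README example and is only illustrative):

```
import numpy as np

# '{modelname}_{split}_{task}_{seed}.npy', as written by cwt_kernel_mat.py
kmats = np.load('sketched_kernel_mat/resnet18_test_cifar10_1111.npy',
                allow_pickle=True).item()

for layer_id in range(len(kmats)):
    print(layer_id, kmats[layer_id].shape)
```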
4 changes: 4 additions & 0 deletions nn_similarity_index/requirements.txt
@@ -0,0 +1,4 @@
python >= 3.5
torch >= 1.0
torchvision
numpy
65 changes: 65 additions & 0 deletions nn_similarity_index/sim_indices.py
@@ -0,0 +1,65 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import numpy as np
from abc import ABC
import os
import argparse


class SimIndex(ABC):

r"""
The class that supports three similarity indices.
Notes:
Currently supports Euclidean distance, Centred Kernel Alignment
and Normalised Bures Similarity between two kernel matrices.
"""

def __init__(self):
...

def centering(self, kmat):
r"""
        Centre the kernel matrix by subtracting row and column means and adding back
        the grand mean (equivalent to H @ kmat @ H with H = I - (1/n) * ones((n, n)))
"""
return kmat - kmat.mean(axis=0, keepdims=True) - kmat.mean(axis=1, keepdims=True) + kmat.mean()

def euclidean(self, kmat_1, kmat_2):
r"""
Compute the Euclidean distance between two kernel matrices
"""
return np.linalg.norm(kmat_1 - kmat_2)

def cka(self, kmat_1, kmat_2):
r"""
Compute the Centred Kernel Alignment between two kernel matrices.
\rho(K_1, K_2) = \Tr (K_1 @ K_2) / ||K_1||_F / ||K_2||_F
"""
kmat_1 = self.centering(kmat_1)
kmat_2 = self.centering(kmat_2)
return np.trace(kmat_1 @ kmat_2) / np.linalg.norm(kmat_1) / np.linalg.norm(kmat_2)

def nbs(self, kmat_1, kmat_2):
r"""
Compute the Normalised Bures Similarity between two kernel matrices.
        \rho(K_1, K_2) = \Tr( (K_1^{1/2} @ K_2 @ K_1^{1/2})^{1/2} ) / \sqrt{ \Tr(K_1) \Tr(K_2) }
"""
kmat_1 = self.centering(kmat_1)
kmat_2 = self.centering(kmat_2)
return sum(np.real(np.linalg.eigvals(kmat_1 @ kmat_2)).clip(0.) ** 0.5) / ((np.trace(kmat_1) * np.trace(kmat_2)) ** 0.5)
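
A quick sanity check of the three indices on a synthetic positive semi-definite kernel matrix (not part of the commit; the random features below are placeholders):

```
import numpy as np
from sim_indices import SimIndex

rng = np.random.RandomState(0)
feats = rng.randn(100, 32)
kmat = feats @ feats.T  # a positive semi-definite kernel matrix

sim = SimIndex()
print(sim.euclidean(kmat, kmat))  # 0.0 for identical matrices
print(sim.cka(kmat, kmat))        # 1.0 (up to floating-point error)
print(sim.nbs(kmat, kmat))        # 1.0 (up to floating-point error)
```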
