diff --git a/finite_ntk/README.md b/finite_ntk/README.md
index e34cec9..9b1668c 100644
--- a/finite_ntk/README.md
+++ b/finite_ntk/README.md
@@ -1,8 +1,6 @@
 # Transfer Learning via Linearized Neural Networks
-This repository contains a GPyTorch implementation of finite width neural tangent kernels from the paper [(link)](https://arxiv.org/abs/2103.01439)
-
-*Fast Adaptation with Linearized Neural Networks*
+This repository contains a GPyTorch implementation of finite width neural tangent kernels from the paper [Fast Adaptation with Linearized Neural Networks](https://arxiv.org/abs/2103.01439) by Wesley Maddox, Shuai Tang, Pablo Garcia Moreno, Andrew Gordon Wilson, and Andreas Damianou,
diff --git a/finite_ntk/experiments/README.md b/finite_ntk/experiments/README.md
new file mode 100644
index 0000000..5395297
--- /dev/null
+++ b/finite_ntk/experiments/README.md
@@ -0,0 +1,8 @@
+
+
+## Olivetti
+
+First, run `cd dataset; python adaptation_dataset_maker.py` to download Olivetti and construct the dataset.
+
+Then run `python run_ntk.py --prop=XXX` (note that the `--fisher` flag is untested).
+The defaults are the learning rates, etc. that we used for each proportion.
diff --git a/finite_ntk/experiments/olivetti/README.md b/finite_ntk/experiments/olivetti/README.md
new file mode 100644
index 0000000..5395297
--- /dev/null
+++ b/finite_ntk/experiments/olivetti/README.md
@@ -0,0 +1,8 @@
+
+
+## Olivetti
+
+First, run `cd dataset; python adaptation_dataset_maker.py` to download Olivetti and construct the dataset.
+
+Then run `python run_ntk.py --prop=XXX` (note that the `--fisher` flag is untested).
+The defaults are the learning rates, etc. that we used for each proportion.
diff --git a/finite_ntk/experiments/olivetti/dataset/adaptation_dataset_maker.py b/finite_ntk/experiments/olivetti/dataset/adaptation_dataset_maker.py
new file mode 100644
index 0000000..476b448
--- /dev/null
+++ b/finite_ntk/experiments/olivetti/dataset/adaptation_dataset_maker.py
@@ -0,0 +1,90 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
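+#
+# Overview: each of the 400 Olivetti faces is rotated `repeat_images` times by
+# a uniformly random angle in [-90, 90] degrees, cropped to a centred 45x45
+# square that stays inside the rotated image, and split by subject into 20
+# training people and 20 test people before being saved to
+# "rotated_faces_data_withids.npz".
+#
+# A minimal sketch of loading the saved arrays (the names match the np.savez
+# call at the end of this script):
+#
+#   data = np.load("rotated_faces_data_withids.npz")
+#   train_x, train_y = data["train_images"], data["train_angles"]
+#   test_x, test_y = data["test_images"], data["test_angles"]
+#   test_ids = data["test_people_ids"]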
+# ============================================================================== + +import math +import numpy as np +import sklearn.datasets as datasets +from rotators import * + +def main(): + olivetti = datasets.fetch_olivetti_faces() + inputs = olivetti['data'] + images = olivetti['images'] + targets = olivetti['target'] + + repeat_images = 30 # the number of times we want to repeat each image + n_images = images.shape[0] * repeat_images + image_sz = images.shape[1] + + image_keep = int(math.sqrt(2) * image_sz/2) + + rotated_faces = np.zeros((n_images, image_keep, image_keep)) + all_targets = np.zeros(n_images) + angles = np.pi * np.random.rand(n_images) - np.pi/2 + degrees = 180/np.pi * angles + + img_ind = 0 + for rpt in range(repeat_images): + for face_ind in range(images.shape[0]): + rotated_image = rotate_image(images[face_ind, :, :], degrees[img_ind]) + cropped_image = crop_around_center(rotated_image, image_keep, image_keep) + all_targets[img_ind] = targets[face_ind] + + rotated_faces[img_ind, :, :] = cropped_image + img_ind += 1 + + ## separate training data out ## + n_people = 40 + + n_train_people = 20 + + ## randomly select the 20 people for training and 10 for test + shuffle_people = np.random.permutation(n_people) + train_people = shuffle_people[:n_train_people] + test_people = shuffle_people[n_train_people:] + + train_images = np.zeros((1, image_keep, image_keep)) + train_angles = np.zeros((1)) + for tp in train_people: + keepers = np.where(all_targets == tp)[0] + keep_imgs = rotated_faces[np.ix_(keepers), :, :].squeeze() + keep_angles = degrees[np.ix_(keepers)] + train_images = np.concatenate((train_images, keep_imgs), 0) + train_angles = np.concatenate((train_angles, keep_angles)) + + train_images = np.expand_dims(train_images[1:, :, :], 1) + train_angles = train_angles[1:] + + + test_images = np.zeros((1, image_keep, image_keep)) + test_angles = np.zeros((1)) + test_people_ids = np.zeros((1)) + for i, tp in enumerate(test_people): + keepers = np.where(all_targets == tp)[0] + keep_imgs = rotated_faces[np.ix_(keepers), :, :].squeeze() + keep_angles = degrees[np.ix_(keepers)] + test_images = np.concatenate((test_images, keep_imgs), 0) + test_angles = np.concatenate((test_angles, keep_angles)) + test_people_ids = np.concatenate((test_people_ids, i* np.ones(keep_imgs.shape[0]))) + + test_images = np.expand_dims(test_images[1:, :, :], 1) + test_angles = test_angles[1:] + test_people_ids = test_people_ids[1:] + + np.savez("./rotated_faces_data_withids.npz", + train_images=train_images, train_angles=train_angles, + test_images=test_images, test_angles=test_angles, test_people_ids=test_people_ids) + +if __name__ == '__main__': + main() diff --git a/finite_ntk/experiments/olivetti/dataset/get_faces_loaders.py b/finite_ntk/experiments/olivetti/dataset/get_faces_loaders.py new file mode 100644 index 0000000..0467414 --- /dev/null +++ b/finite_ntk/experiments/olivetti/dataset/get_faces_loaders.py @@ -0,0 +1,42 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import torch +import numpy as np + + +def get_faces_loaders(batch_size=128, test=True, data_path="./data/"): + """ + returns the train (and test if selected) loaders for the olivetti + rotated faces dataset + """ + + dat = np.load(data_path + "rotated_faces_data.npz") + train_images = torch.FloatTensor(dat['train_images']) + train_targets = torch.FloatTensor(dat['train_angles']) + + traindata = torch.utils.data.TensorDataset(train_images, train_targets) + trainloader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, + shuffle=True) + + if test: + test_images = torch.FloatTensor(dat['test_images']) + test_targets = torch.FloatTensor(dat['test_angles']) + + testdata = torch.utils.data.TensorDataset(test_images, test_targets) + testloader = torch.utils.data.DataLoader(testdata, batch_size=batch_size) + + return trainloader, testloader + + return trainloader diff --git a/finite_ntk/experiments/olivetti/dataset/rotators.py b/finite_ntk/experiments/olivetti/dataset/rotators.py new file mode 100644 index 0000000..df31da8 --- /dev/null +++ b/finite_ntk/experiments/olivetti/dataset/rotators.py @@ -0,0 +1,144 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import math +import cv2 +import numpy as np + +def rotate_image(image, angle): + """ + Rotates an OpenCV 2 / NumPy image about it's centre by the given angle + (in degrees). 
The returned image will be large enough to hold the entire + new image, with a black background + """ + + # Get the image size + # No that's not an error - NumPy stores image matricies backwards + image_size = (image.shape[1], image.shape[0]) + image_center = tuple(np.array(image_size) / 2) + + # Convert the OpenCV 3x2 rotation matrix to 3x3 + rot_mat = np.vstack( + [cv2.getRotationMatrix2D(image_center, angle, 1.0), [0, 0, 1]] + ) + + rot_mat_notranslate = np.matrix(rot_mat[0:2, 0:2]) + + # Shorthand for below calcs + image_w2 = image_size[0] * 0.5 + image_h2 = image_size[1] * 0.5 + + # Obtain the rotated coordinates of the image corners + rotated_coords = [ + (np.array([-image_w2, image_h2]) * rot_mat_notranslate).A[0], + (np.array([ image_w2, image_h2]) * rot_mat_notranslate).A[0], + (np.array([-image_w2, -image_h2]) * rot_mat_notranslate).A[0], + (np.array([ image_w2, -image_h2]) * rot_mat_notranslate).A[0] + ] + + # Find the size of the new image + x_coords = [pt[0] for pt in rotated_coords] + x_pos = [x for x in x_coords if x > 0] + x_neg = [x for x in x_coords if x < 0] + + y_coords = [pt[1] for pt in rotated_coords] + y_pos = [y for y in y_coords if y > 0] + y_neg = [y for y in y_coords if y < 0] + + right_bound = max(x_pos) + left_bound = min(x_neg) + top_bound = max(y_pos) + bot_bound = min(y_neg) + + new_w = int(abs(right_bound - left_bound)) + new_h = int(abs(top_bound - bot_bound)) + + # We require a translation matrix to keep the image centred + trans_mat = np.matrix([ + [1, 0, int(new_w * 0.5 - image_w2)], + [0, 1, int(new_h * 0.5 - image_h2)], + [0, 0, 1] + ]) + + # Compute the tranform for the combined rotation and translation + affine_mat = (np.matrix(trans_mat) * np.matrix(rot_mat))[0:2, :] + + # Apply the transform + result = cv2.warpAffine( + image, + affine_mat, + (new_w, new_h), + flags=cv2.INTER_LINEAR + ) + + return result + + +def largest_rotated_rect(w, h, angle): + """ + Given a rectangle of size wxh that has been rotated by 'angle' (in + radians), computes the width and height of the largest possible + axis-aligned rectangle within the rotated rectangle. 
+ + Original JS code by 'Andri' and Magnus Hoff from Stack Overflow + + Converted to Python by Aaron Snoswell + """ + + quadrant = int(math.floor(angle / (math.pi / 2))) & 3 + sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle + alpha = (sign_alpha % math.pi + math.pi) % math.pi + + bb_w = w * math.cos(alpha) + h * math.sin(alpha) + bb_h = w * math.sin(alpha) + h * math.cos(alpha) + + gamma = math.atan2(bb_w, bb_w) if (w < h) else math.atan2(bb_w, bb_w) + + delta = math.pi - alpha - gamma + + length = h if (w < h) else w + + d = length * math.cos(alpha) + a = d * math.sin(alpha) / math.sin(delta) + + y = a * math.cos(gamma) + x = y * math.tan(gamma) + + return ( + bb_w - 2 * x, + bb_h - 2 * y + ) + + +def crop_around_center(image, width, height): + """ + Given a NumPy / OpenCV 2 image, crops it to the given width and height, + around it's centre point + """ + + image_size = (image.shape[1], image.shape[0]) + image_center = (int(image_size[0] * 0.5), int(image_size[1] * 0.5)) + + if(width > image_size[0]): + width = image_size[0] + + if(height > image_size[1]): + height = image_size[1] + + x1 = int(image_center[0] - width * 0.5) + x2 = int(image_center[0] + width * 0.5) + y1 = int(image_center[1] - height * 0.5) + y2 = int(image_center[1] + height * 0.5) + + return image[y1:y2, x1:x2] diff --git a/finite_ntk/experiments/olivetti/run_ntk.py b/finite_ntk/experiments/olivetti/run_ntk.py new file mode 100644 index 0000000..1513b9c --- /dev/null +++ b/finite_ntk/experiments/olivetti/run_ntk.py @@ -0,0 +1,264 @@ +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
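+#
+# Overview: this script trains a small CNN with a heteroscedastic Gaussian
+# likelihood on the adaptation split of the rotated-faces test people, then
+# wraps the trained network in an exact GP whose covariance is the finite-width
+# NTK (finite_ntk.lazy.NTK), and compares the GP's predictive RMSE against
+# last-layer retraining and full fine-tuning baselines.
+#
+# Example invocation (a sketch: the flags are defined in the argparse block at
+# the bottom of this file, the output filename is a placeholder, and --fisher
+# is untested per the README):
+#
+#   python run_ntk.py --prop=0.5 --epochs=150 --adapt_epochs=15 \
+#       --seed=10 --output_file=ntk_olivetti.pkl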
+# ============================================================================== + +import numpy as np +import torch +import copy +import torch.nn as nn +import argparse +import sys + +import gpytorch +import pickle + +import finite_ntk + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + corr_shape = [x.shape[0], *self.shape] + return x.view(*corr_shape) + +class NTKGP(gpytorch.models.ExactGP): + def __init__(self, train_x, train_y, likelihood, model, fisher=False): + super(NTKGP, self).__init__(train_x, train_y, likelihood) + + self.mean_module = gpytorch.means.ConstantMean() + self.covar_module = finite_ntk.lazy.NTK( + model=model, use_linearstrategy=fisher, used_dims=0 + ) + + def forward(self, x): + mean = self.mean_module(x) + covar = self.covar_module(x) + return gpytorch.distributions.MultivariateNormal(mean, covar) + +def make_network(): + return torch.nn.Sequential( + Reshape(1, 45, 45), + torch.nn.Conv2d(1, 20, kernel_size=5), + torch.nn.ReLU(True), + torch.nn.MaxPool2d(kernel_size=2), + torch.nn.Conv2d(20, 50, kernel_size=5), + torch.nn.ReLU(True), + torch.nn.MaxPool2d(kernel_size=2), + torch.nn.Flatten(), + torch.nn.Linear(3200, 500), + torch.nn.ReLU(True), + torch.nn.Linear(500, 2) + ) + +def gaussian_loglikelihood(input, target, eps=1e-5): + r""" + heteroscedastic Gaussian likelihood where we parameterize the variance + with the 1e-5 + softplus(network) + input: tensor (batch + two-d, presumed to be output from model) + target: tensor + eps (1e-5): a nugget style term to ensure that the variance doesnt go to 0 + """ + dist = torch.distributions.Normal( + input[:, 0], torch.nn.functional.softplus(input[:, 1]) + eps + ) + res = dist.log_prob(target.view(-1)) + return res.mean() + +def main(args): + torch.random.manual_seed(args.seed) + + ## SET UP DATA ## + dataset = np.load("./dataset/rotated_faces_data_withids.npz") + train_images = dataset['train_images'] + train_targets = dataset['train_angles'] + + test_images = dataset['test_images'] + test_targets = dataset['test_angles'] + + train_images = torch.from_numpy(train_images).float() + train_targets = torch.from_numpy(train_targets).float() + + ### prepare adaptation and validation dataset + num_to_keep = int(args.prop * test_images.shape[0]) + adapt_indices = np.random.permutation(test_images.shape[0]) + adapt_people = adapt_indices[:num_to_keep] + val_people = adapt_indices[num_to_keep:] + + adapt_images = torch.from_numpy(test_images[adapt_people]).float() + adapt_targets = torch.from_numpy(test_targets[adapt_people]).float() + + val_images = torch.from_numpy(test_images[val_people]).float() + val_targets = torch.from_numpy(test_targets[val_people]).float() + + ##### standardize targets + train_mean = train_targets.mean() + train_std = train_targets.std() + + train_targets = (train_targets - train_mean) / train_std + val_targets = (val_targets - train_mean) / train_std + adapt_targets = (adapt_targets - train_mean) / train_std + + ##### set up data loaders + # we have to reshape the inputs so that gpytorch internals can stack + adaptloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + adapt_images.reshape(adapt_images.shape[0], -1), + adapt_targets, + ), + batch_size=32, + shuffle=False, + ) + + ###### make the network and set up optimizer + net = make_network() + net.cuda() + + optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3, amsgrad=True) + + lossfn = gaussian_loglikelihood + + ###### now train the 
network + for i in range(args.epochs): + for input, target in adaptloader: + input, target = input.cuda(), target.cuda() + + outputs = net(input) + loss = -lossfn(outputs, target).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if i % 10 is 0: + with torch.no_grad(): + val_inputs = val_images.reshape(val_images.shape[0], -1).cuda() + outputs = net(val_inputs)[:,0].cpu() + rmse = torch.sqrt(torch.mean((outputs - val_targets)**2)) + print('Epoch: ', i, 'Test RMSE: ', rmse.item()) + + ###### construct the GP model + # we have to reshape the inputs so that gpytorch internals can stack + likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda() + likelihood.noise = rmse.item()**2 + model = NTKGP(adapt_images.reshape(adapt_images.shape[0], -1), adapt_targets, + likelihood, net, fisher=args.fisher).cuda() + + # set in eval mode + likelihood.eval() + model.eval() + + # compute predictive mean + with gpytorch.settings.fast_pred_var(True): + results = model(val_images.reshape(val_images.shape[0], -1).cuda()) + gp_predictions = results.mean.detach().cpu() + gp_rmse = torch.sqrt(torch.mean((gp_predictions - val_targets)**2)) + print('GP RMSE: ', gp_rmse) + + # testing RMSE: + output = net(val_images.reshape(val_images.shape[0], -1).cuda())[:,0].detach() + network_predictions = output.cpu() + net_rmse = torch.sqrt(torch.mean((network_predictions - val_targets)**2)) + print('Final Net RMSE: ', net_rmse) + + ########## now we create a copy to re-train the last layer + # which is a poor baseline here + frozen_net = copy.deepcopy(net) + parlength = len(list(frozen_net.named_parameters())) + for i, (n, p) in enumerate(frozen_net.named_parameters()): + if i < (parlength - 2): + p.requires_grad = False + else: + # print out the name of the layer to verify we are only using the last layer + print(n) + + #### train the last layer + for i in range(args.adapt_epochs): + for input, target in adaptloader: + input, target = input.cuda(), target.cuda() + + outputs = net(input) + loss = -lossfn(outputs, target).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if i % 10 is 0: + with torch.no_grad(): + val_inputs = val_images.reshape(val_images.shape[0], -1).cuda() + outputs = frozen_net(val_inputs)[:,0].cpu() + rmse = torch.sqrt(torch.mean((outputs - val_targets)**2)) + print('Epoch: ', i, 'Test RMSE: ', rmse.item()) + + # testing RMSE: + output = frozen_net(val_images.reshape(val_images.shape[0], -1).cuda())[:,0].detach() + frozen_network_predictions = output.cpu() + fnet_rmse = torch.sqrt(torch.mean((frozen_network_predictions - val_targets)**2)) + print('Frozen Net RMSE: ', fnet_rmse) + + ##### and finally we "fine-tune" the whole net + for i in range(args.adapt_epochs): + for input, target in adaptloader: + input, target = input.cuda(), target.cuda() + + outputs = net(input) + loss = -lossfn(outputs, target).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if i % 10 is 0: + with torch.no_grad(): + val_inputs = val_images.reshape(val_images.shape[0], -1).cuda() + outputs = net(val_inputs)[:,0].cpu() + rmse = torch.sqrt(torch.mean((outputs - val_targets)**2)) + print('Epoch: ', i, 'Test RMSE: ', rmse.item()) + + # testing RMSE: + output = net(val_images.reshape(val_images.shape[0], -1).cuda())[:,0].detach() + rt_network_predictions = output.cpu() + rt_net_rmse = torch.sqrt(torch.mean((rt_network_predictions - val_targets)**2)) + print('Retrained Net RMSE: ', rt_net_rmse) + + output_dict = {'targets': val_targets, + 'init_rmse': net_rmse, + 
'gp_rmse': gp_rmse, + 'frozen_rmse': fnet_rmse, + 'retrain_rmse': rt_net_rmse, + 'gp_preds': gp_predictions, + 'rt_preds': rt_network_predictions, + 'frozen_preds': frozen_network_predictions + } + with open(args.output_file, "wb") as handle: + pickle.dump(output_dict, handle, pickle.HIGHEST_PROTOCOL) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--epochs", help="(int) number of epochs to train for", default=150, type=int + ) + parser.add_argument( + "--adapt_epochs", help="(int) number of epochs to train for", default=15, type=int + ) + parser.add_argument( + "--prop", help="(float) proportion of validation points", default=0.5, type=float + ) + parser.add_argument( + "--fisher", action="store_true", help="fisher basis usage flag (default: off)" + ) + parser.add_argument("--seed", help="random seed", type=int, default=10) + parser.add_argument("--output_file", type=str) + args = parser.parse_args() + main(args) + diff --git a/finite_ntk/experiments/run_ntk.py b/finite_ntk/experiments/run_ntk.py new file mode 100644 index 0000000..5d73fd4 --- /dev/null +++ b/finite_ntk/experiments/run_ntk.py @@ -0,0 +1,259 @@ +import numpy as np +import torch +import copy +import torch.nn as nn +import argparse +import sys + +import gpytorch +import pickle + +import finite_ntk + +sys.path.append('../cifar') +from utils import train_epoch + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + corr_shape = [x.shape[0], *self.shape] + return x.view(*corr_shape) + +class NTKGP(gpytorch.models.ExactGP): + def __init__(self, train_x, train_y, likelihood, model, fisher=False): + super(NTKGP, self).__init__(train_x, train_y, likelihood) + + self.mean_module = gpytorch.means.ConstantMean() + self.covar_module = finite_ntk.lazy.NTK( + model=model, use_linearstrategy=fisher, used_dims=0 + ) + + def forward(self, x): + mean = self.mean_module(x) + covar = self.covar_module(x) + return gpytorch.distributions.MultivariateNormal(mean, covar) + +def make_network(): + return torch.nn.Sequential( + Reshape(1, 45, 45), + torch.nn.Conv2d(1, 20, kernel_size=5), + torch.nn.ReLU(True), + torch.nn.MaxPool2d(kernel_size=2), + torch.nn.Conv2d(20, 50, kernel_size=5), + torch.nn.ReLU(True), + torch.nn.MaxPool2d(kernel_size=2), + torch.nn.Flatten(), + torch.nn.Linear(3200, 500), + torch.nn.ReLU(True), + torch.nn.Linear(500, 2) + ) + +def gaussian_loglikelihood(input, target, eps=1e-5): + r""" + heteroscedastic Gaussian likelihood where we parameterize the variance + with the 1e-5 + softplus(network) + input: tensor (batch + two-d, presumed to be output from model) + target: tensor + eps (1e-5): a nugget style term to ensure that the variance doesnt go to 0 + """ + dist = torch.distributions.Normal( + input[:, 0], torch.nn.functional.softplus(input[:, 1]) + eps + ) + res = dist.log_prob(target.view(-1)) + return res.mean() + +def main(args): + torch.random.manual_seed(args.seed) + + ## SET UP DATA ## + dataset = np.load("./dataset/rotated_faces_data_withids.npz") + train_images = dataset['train_images'] + train_targets = dataset['train_angles'] + + test_images = dataset['test_images'] + test_targets = dataset['test_angles'] + #test_ids = dataset['test_people_ids'] + + train_images = torch.from_numpy(train_images).float() + train_targets = torch.from_numpy(train_targets).float() + + ### prepare adaptation and validation dataset + num_to_keep = int(args.prop * test_images.shape[0]) + adapt_indices = 
np.random.permutation(test_images.shape[0]) + adapt_people = adapt_indices[:num_to_keep] + val_people = adapt_indices[num_to_keep:] + + adapt_images = torch.from_numpy(test_images[adapt_people]).float() + adapt_targets = torch.from_numpy(test_targets[adapt_people]).float() + + val_images = torch.from_numpy(test_images[val_people]).float() + val_targets = torch.from_numpy(test_targets[val_people]).float() + + ##### standardize targets + train_mean = train_targets.mean() + train_std = train_targets.std() + + train_targets = (train_targets - train_mean) / train_std + val_targets = (val_targets - train_mean) / train_std + adapt_targets = (adapt_targets - train_mean) / train_std + + ##### set up data loaders + # we have to reshape the inputs so that gpytorch internals can stack + trainloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(train_images.reshape(train_images.shape[0], -1), + train_targets), batch_size=32, + shuffle=True) + + adaptloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(adapt_images.reshape(adapt_images.shape[0], -1), + adapt_targets), batch_size=32, + shuffle=False) + + ###### make the network and set up optimizer + net = make_network() + net.cuda() + + optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3, amsgrad=True) + + lossfn = gaussian_loglikelihood + + ###### now train the network + for i in range(args.epochs): + #train_epoch(trainloader, net, gaussian_loglikelihood, optimizer) + for input, target in adaptloader: + input, target = input.cuda(), target.cuda() + + outputs = net(input) + loss = -lossfn(outputs, target).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if i % 10 is 0: + with torch.no_grad(): + val_inputs = val_images.reshape(val_images.shape[0], -1).cuda() + outputs = net(val_inputs)[:,0].cpu() + rmse = torch.sqrt(torch.mean((outputs - val_targets)**2)) + print('Epoch: ', i, 'Test RMSE: ', rmse.item()) + + ###### construct the GP model + # we have to reshape the inputs so that gpytorch internals can stack + likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda() + likelihood.noise = rmse.item()**2 + model = NTKGP(adapt_images.reshape(adapt_images.shape[0], -1), adapt_targets, + likelihood, net, fisher=args.fisher).cuda() + + # set in eval mode + likelihood.eval() + model.eval() + + # compute predictive mean + with gpytorch.settings.fast_pred_var(True): + results = model(val_images.reshape(val_images.shape[0], -1).cuda()) + gp_predictions = results.mean.detach().cpu() + gp_rmse = torch.sqrt(torch.mean((gp_predictions - val_targets)**2)) + print('GP RMSE: ', gp_rmse) + + # testing RMSE: + output = net(val_images.reshape(val_images.shape[0], -1).cuda())[:,0].detach() + network_predictions = output.cpu() + net_rmse = torch.sqrt(torch.mean((network_predictions - val_targets)**2)) + print('Final Net RMSE: ', net_rmse) + + ########## now we create a copy to re-train the last layer + # which is a poor baseline here + frozen_net = copy.deepcopy(net) + parlength = len(list(frozen_net.named_parameters())) + for i, (n, p) in enumerate(frozen_net.named_parameters()): + if i < (parlength - 2): + p.requires_grad = False + else: + # print out the name of the layer to verify we are only using the last layer + print(n) + + #### train the last layer + f_optimizer = torch.optim.Adam(frozen_net.parameters(), lr = 1e-3, amsgrad=True) + for i in range(args.adapt_epochs): + #train_epoch(adaptloader, frozen_net, gaussian_loglikelihood, f_optimizer) + for input, target in adaptloader: + input, target = 
input.cuda(), target.cuda() + + outputs = net(input) + loss = -lossfn(outputs, target).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if i % 10 is 0: + with torch.no_grad(): + val_inputs = val_images.reshape(val_images.shape[0], -1).cuda() + outputs = frozen_net(val_inputs)[:,0].cpu() + rmse = torch.sqrt(torch.mean((outputs - val_targets)**2)) + print('Epoch: ', i, 'Test RMSE: ', rmse.item()) + + # testing RMSE: + output = frozen_net(val_images.reshape(val_images.shape[0], -1).cuda())[:,0].detach() + frozen_network_predictions = output.cpu() + fnet_rmse = torch.sqrt(torch.mean((frozen_network_predictions - val_targets)**2)) + print('Frozen Net RMSE: ', fnet_rmse) + + ##### and finally we "fine-tune" the whole net + for i in range(args.adapt_epochs): + #train_epoch(trainloader, net, gaussian_loglikelihood, optimizer) + for input, target in adaptloader: + input, target = input.cuda(), target.cuda() + + outputs = net(input) + loss = -lossfn(outputs, target).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if i % 10 is 0: + with torch.no_grad(): + val_inputs = val_images.reshape(val_images.shape[0], -1).cuda() + outputs = net(val_inputs)[:,0].cpu() + rmse = torch.sqrt(torch.mean((outputs - val_targets)**2)) + print('Epoch: ', i, 'Test RMSE: ', rmse.item()) + + # testing RMSE: + output = net(val_images.reshape(val_images.shape[0], -1).cuda())[:,0].detach() + rt_network_predictions = output.cpu() + rt_net_rmse = torch.sqrt(torch.mean((rt_network_predictions - val_targets)**2)) + print('Retrained Net RMSE: ', rt_net_rmse) + + output_dict = {'targets': val_targets, + 'init_rmse': net_rmse, + 'gp_rmse': gp_rmse, + 'frozen_rmse': fnet_rmse, + 'retrain_rmse': rt_net_rmse, + 'gp_preds': gp_predictions, + 'rt_preds': rt_network_predictions, + 'frozen_preds': frozen_network_predictions + } + with open(args.output_file, "wb") as handle: + pickle.dump(output_dict, handle, pickle.HIGHEST_PROTOCOL) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--epochs", help="(int) number of epochs to train for", default=150, type=int + ) + parser.add_argument( + "--adapt_epochs", help="(int) number of epochs to train for", default=15, type=int + ) + parser.add_argument( + "--prop", help="(float) proportion of validation points", default=0.5, type=float + ) + parser.add_argument( + "--fisher", action="store_true", help="fisher basis usage flag (default: off)" + ) + parser.add_argument("--seed", help="random seed", type=int, default=10) + parser.add_argument("--output_file", type=str) + args = parser.parse_args() + main(args) + diff --git a/finite_ntk/experiments/unsup/README.md b/finite_ntk/experiments/unsup/README.md new file mode 100644 index 0000000..44bea3e --- /dev/null +++ b/finite_ntk/experiments/unsup/README.md @@ -0,0 +1,9 @@ +## unsupervised + +Here, we copied the main src folder of https://github.com/fmu2/gradfeat20 (MIT License) and modified it. + +To run the MAP results, run `python benchmark_rop.py --config=configs/....` while the vi ones are at +`python vi_benchmark_rop.py --config=configs/....` + +We moved the hard paths in the config file and replaced them with $LOC. +You should change that after cloning their repo. \ No newline at end of file diff --git a/finite_ntk/experiments/unsup/benchmark_rop.py b/finite_ntk/experiments/unsup/benchmark_rop.py new file mode 100644 index 0000000..cc1e31a --- /dev/null +++ b/finite_ntk/experiments/unsup/benchmark_rop.py @@ -0,0 +1,294 @@ +# Copyright 2021 Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import time, configargparse, torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +from model import Net +from util import load_data +import torch.backends.cudnn as cudnn + +from finite_ntk.lazy import flatten, unflatten_like +from finite_ntk.lazy.utils import Rop + +def parser_args(): + parser = configargparse.ArgParser('BiGAN experiments') + parser.add('-c', '--config', required=True, is_config_file=True, + help='config file') + + parser.add_argument('--gpu', type=int, default=0, + help='gpu instance to use (default: 0)') + parser.add_argument('--seed', type=int, default=1818, + help='random seed (default: 1818)') + + # dataset + parser.add_argument('--dataset', type=str, default='cifar10', + choices={'cifar10', 'cifar100', 'svhn'}, + help='dataset (default: cifar10)') + parser.add_argument('--data_path', type=str, + help='path to load dataset') + parser.add_argument('--nclass', type=int, default=10, + help='number of classes (default: 10)') + parser.add_argument('--batchsize', type=int, default=128, + help='batch size (default: 128)') + parser.add_argument('--normalize', action='store_true', + help='whether to normalize data') + + # optimization + parser.add_argument('--optim', type=str, default='sgd', + help='optimizer (default: sgd)') + parser.add_argument('--lr', type=float, default=1e-3, + help='learning rate (default: 1e-3)') + parser.add_argument('--wd', type=float, default=1e-6, + help='weight decay (default: 5e-5)') + parser.add_argument('--niter', type=int, default=80000, + help='number of training iterations (default: 80000)') + parser.add_argument('--stepsize', type=int, default=20000, + help='by which learning rate is halved (default: 20000)') + + # network + parser.add_argument('--mode', type=str, default='full', + choices={'full_plus_gradients', 'gradfull', 'full', 'actv', 'grad'}, + help='features to use (default: full)') + parser.add_argument('--fnet_path', type=str, + help='path to load fnet') + parser.add_argument('--hnet_path', type=str, + help='path to load hnet') + parser.add_argument('--clf_path', type=str, + help='path to load clf') + parser.add_argument('--model_path', type=str, + help='path to save model') + parser.add_argument('--freeze_hnet', nargs='+', type=int, default=0, + help='hnet layers to freeze') + parser.add_argument('--linearize_hnet', nargs='+', type=int, default=0, + help='hnet layers to linearize') + parser.add_argument('--linearize_clf', action='store_true', + help='whether to linearize the classifier') + + return parser.parse_args() + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def train(device, loader, model, 
fnet, mode, + optimizer, niter, stepsize, losses, it=0): + batch_time = AverageMeter() + data_time = AverageMeter() + end = time.time() + + curr_iter = it + model.train() + + for (x, y) in loader: + data_time.update(time.time() - end) + + # update learning rate + if curr_iter != 0 and curr_iter % stepsize == 0: + for param_group in optimizer.param_groups: + param_group['lr'] = param_group['lr'] * 0.5 + print('iter %d learning rate is %.5f' % (curr_iter, param_group['lr'])) + + x, y = x.to(device), y.to(device) + #print('mode is: ', mode) + if mode == 'full': + # proposed model + logits, jvp = model(x) + logits = logits + jvp + elif mode == 'grad': + # gradient baseline (second term in proposed model) + _, jvp = model(x) + logits = jvp + elif mode == 'gradfull': + _, jvp = model(x) + fnet_jvp = Rop(fnet(x), fnet.parameters(), + unflatten_like(model.fnet_vector, fnet.parameters()), + create_graph=True)[0] + fnet_jvp = fnet_jvp.view(fnet_jvp.shape[0], -1) / x.shape[0] + logits = jvp + fnet_jvp @ model.fixed_projection + elif mode == 'full_plus_gradients': + logits, jvp = model(x) + fnet_jvp = Rop(fnet(x), fnet.parameters(), + unflatten_like(model.fnet_vector, fnet.parameters()), + create_graph=True)[0] + fnet_jvp = fnet_jvp.view(fnet_jvp.shape[0], -1) / x.shape[0] + logits = logits + jvp + fnet_jvp @ model.fixed_projection + else: + # activation baseline or fine-tuning (first term in proposed model) + logits = model(x) + + loss = nn.CrossEntropyLoss()(logits, y) + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), 10) # clip gradient + optimizer.step() + + losses.update(loss.item(), x.size(0)) + batch_time.update(time.time() - end) + end = time.time() + + if curr_iter % 50 == 0: + #print(model.clf.fc.weight.norm(), model.fixed_projection.norm()) + print('Iteration[{0}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( + curr_iter, + batch_time=batch_time, data_time=data_time, loss=losses)) + curr_iter += 1 + + if curr_iter == niter: + break + + return curr_iter + + +def evaluate(device, loader, model, fnet, mode): + model.eval() + ncorr = 0 + + for i, (x, y) in enumerate(loader): + x, y = x.to(device), y.to(device) + + with torch.no_grad(): + if mode == 'full': + # proposed model + logits, jvp = model(x) + logits = logits + jvp + elif mode == 'grad': + # gradient baseline (second term in proposed model) + _, jvp = model(x) + logits = jvp + elif mode == 'gradfull': + _, jvp = model(x) + # what is full_v -- i think it needs to be trainable? + with torch.enable_grad(): + fnet_jvp = Rop(fnet(x), fnet.parameters(), + unflatten_like(model.fnet_vector, fnet.parameters()), + create_graph=True)[0] + fnet_jvp = fnet_jvp.view(fnet_jvp.shape[0], -1) / x.shape[0] + logits = jvp + fnet_jvp @ model.fixed_projection + elif mode == 'full_plus_gradients': + logits, jvp = model(x) + # what is full_v -- i think it needs to be trainable? 
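+                # Rop(fnet(x), fnet.parameters(), v) forms the Jacobian-vector
+                # product J_theta fnet(x) @ v for the feature network, which is
+                # why gradients are re-enabled inside this no_grad evaluation
+                # loop; the result is divided by the batch size and mapped to
+                # logit space with the fixed random projection
+                # `model.fixed_projection`.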
+ with torch.enable_grad(): + fnet_jvp = Rop(fnet(x), fnet.parameters(), + unflatten_like(model.fnet_vector, fnet.parameters()), + create_graph=True)[0] + fnet_jvp = fnet_jvp.view(fnet_jvp.shape[0], -1) / x.shape[0] + logits = logits + jvp + fnet_jvp @ model.fixed_projection + else: + # activation baseline or fine-tuning (first term in proposed model) + logits = model(x) + + pred = torch.argmax(logits.detach_(), dim=1) + ncorr += (pred == y).sum() + + acc = ncorr.float() / len(loader) + print(acc.item()) + + +def main(): + args = parser_args() + print('Batch size: %d' % args.batchsize) + print('Initial learning rate: %.5f' % args.lr) + print('Weight decay: %.6f' % args.wd) + + device = torch.device('cuda:' + str(args.gpu) + if torch.cuda.is_available() else 'cpu') + cudnn.benchmark = True + + # fix random seed + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + + net = Net(nclasses=args.nclass) + fnet = torch.load(args.fnet_path) # feature net (theta_1) + hnet = torch.load(args.hnet_path) # head net (theta_2) + clf = torch.load(args.clf_path) # classifier (omega) + + # load parameters (random vs. pre-trained) as appropriate + net.load_fnet(fnet, freeze=True) + net.load_hnet(hnet, reinit_idx=(), + freeze_idx=args.freeze_hnet, linearize_idx=args.linearize_hnet) + net.load_clf(clf, reinit=False, linearize=args.linearize_clf) + + # construct + net.fnet_vector = torch.nn.Parameter(torch.zeros(flatten(net.fnet.parameters()).shape[0], 1, requires_grad=True)) + fixed_proj = 1e-4 * torch.randn(10816, 10) + net.fixed_projection = torch.nn.Parameter(fixed_proj) + net.fixed_projection.requires_grad = False + + net.to(device) + for p in fnet.parameters(): + p.requires_grad = True + + fnet.to(device) + params = list(filter(lambda p: p.requires_grad, net.parameters())) + if args.optim == 'sgd': + optimizer = optim.SGD( + params, lr=args.lr, weight_decay=args.wd, momentum=0.9) + elif args.optim == 'adam': + optimizer = optim.Adam( + params, lr=args.lr, betas=(0.5, 0.999), weight_decay=args.wd) + + # check trainable parameters + for p in params: + print(p.size()) + + # load training and test data + train_loader, test_loader = load_data( + args.dataset, args.data_path, args.batchsize, args.normalize) + + print('----- Training phase -----') + it = 0 + losses = AverageMeter() + + while it < args.niter: + it = train( + device, train_loader, net, fnet, args.mode, optimizer, + args.niter, args.stepsize, losses, it=it) + + print(int(it / int(50000 / args.batchsize))) + if int(it / int(50000 / args.batchsize)) % 10 == 0: + print('----- Evaluation phase -----') + print('> test accuracy:') + evaluate(device, test_loader, net, fnet, args.mode) + + print('----- Final Evaluation phase -----') + print('> test accuracy:') + evaluate(device, test_loader, net, fnet, args.mode) + + torch.save(net.cpu(), args.model_path) + + +if __name__ == '__main__': + main() diff --git a/finite_ntk/experiments/unsup/configs/finetune_all.config b/finite_ntk/experiments/unsup/configs/finetune_all.config new file mode 100644 index 0000000..eaee9ce --- /dev/null +++ b/finite_ntk/experiments/unsup/configs/finetune_all.config @@ -0,0 +1,16 @@ +dataset cifar10 +data_path ./data/cifar10 +nclass 10 +batchsize 128 +optim sgd +lr 1e-2 +wd 1e-6 +niter 80000 +stepsize 20000 +mode actv +fnet_path $LOC/gradfeat20/models/pretrained/cifar10/ali/fnet1.pt +hnet_path $LOC/gradfeat20/models/pretrained/cifar10/ali/std_hnet1.pt +clf_path $LOC/gradfeat20/models/pretrained/cifar10/ali/std_clf0.pt +freeze_hnet [0] 
+linearize_hnet [0] +model_path ./actv_all.net diff --git a/finite_ntk/experiments/unsup/configs/full_plus_gradients.config b/finite_ntk/experiments/unsup/configs/full_plus_gradients.config new file mode 100644 index 0000000..dee1641 --- /dev/null +++ b/finite_ntk/experiments/unsup/configs/full_plus_gradients.config @@ -0,0 +1,17 @@ +dataset cifar10 +data_path ./data/cifar10 +nclass 10 +batchsize 128 +optim sgd +lr 1e-1 +wd 5e-5 +niter 50000 +stepsize 10000 +mode full_plus_gradients +fnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/fnet1.pt +hnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_hnet1.pt +clf_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_clf1.pt +model_path full_plus_gradients.net +freeze_hnet [1, 2, 3] +linearize_hnet [1, 2, 3] +linearize_clf diff --git a/finite_ntk/experiments/unsup/configs/fullgrad.config b/finite_ntk/experiments/unsup/configs/fullgrad.config new file mode 100644 index 0000000..53fd6b0 --- /dev/null +++ b/finite_ntk/experiments/unsup/configs/fullgrad.config @@ -0,0 +1,17 @@ +dataset cifar10 +data_path ./data/cifar10 +nclass 10 +batchsize 128 +optim sgd +lr 1e-1 +wd 5e-5 +niter 50000 +stepsize 10000 +mode gradfull +fnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/fnet1.pt +hnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_hnet1.pt +clf_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_clf1.pt +model_path gradfull.net +freeze_hnet [1, 2, 3] +linearize_hnet [1, 2, 3] +linearize_clf diff --git a/finite_ntk/experiments/unsup/configs/grad_conv123.config b/finite_ntk/experiments/unsup/configs/grad_conv123.config new file mode 100644 index 0000000..010c46a --- /dev/null +++ b/finite_ntk/experiments/unsup/configs/grad_conv123.config @@ -0,0 +1,17 @@ +dataset cifar10 +data_path ./data/cifar10 +nclass 10 +batchsize 128 +optim sgd +lr 1e-1 +wd 5e-5 +niter 50000 +stepsize 10000 +mode grad +fnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/fnet1.pt +hnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_hnet1.pt +clf_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_clf1.pt +model_path $LOC/gradfeat20//models/benchmark/cifar10/ali/gradfull_conv123.net +freeze_hnet [1, 2, 3] +linearize_hnet [1, 2, 3] +linearize_clf diff --git a/finite_ntk/experiments/unsup/configs/vi_grad.config b/finite_ntk/experiments/unsup/configs/vi_grad.config new file mode 100644 index 0000000..f4bb872 --- /dev/null +++ b/finite_ntk/experiments/unsup/configs/vi_grad.config @@ -0,0 +1,17 @@ +dataset cifar10 +data_path ./data/cifar10 +nclass 10 +batchsize 128 +optim sgd +lr 1e-1 +wd 5e-5 +niter 50000 +stepsize 10000 +mode gradfull +fnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/fnet1.pt +hnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_hnet1.pt +clf_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_clf1.pt +model_path vi_gradfull.net +freeze_hnet [1, 2, 3] +linearize_hnet [1, 2, 3] +linearize_clf diff --git a/finite_ntk/experiments/unsup/configs/vi_linearization.config b/finite_ntk/experiments/unsup/configs/vi_linearization.config new file mode 100644 index 0000000..b082785 --- /dev/null +++ b/finite_ntk/experiments/unsup/configs/vi_linearization.config @@ -0,0 +1,17 @@ +dataset cifar10 +data_path ./data/cifar10 +nclass 10 +batchsize 128 +optim sgd +lr 1e-2 +wd 5e-5 +niter 50000 +stepsize 10000 +mode full_plus_gradients +fnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/fnet1.pt +hnet_path $LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_hnet1.pt +clf_path 
$LOC/gradfeat20//models/pretrained/cifar10/ali/ntk_clf1.pt +model_path full_plus_gradients.net +freeze_hnet [1, 2, 3] +linearize_hnet [1, 2, 3] +linearize_clf diff --git a/finite_ntk/experiments/unsup/losses.py b/finite_ntk/experiments/unsup/losses.py new file mode 100644 index 0000000..efe3ab5 --- /dev/null +++ b/finite_ntk/experiments/unsup/losses.py @@ -0,0 +1,238 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import torch + +from gpytorch.utils.lanczos import lanczos_tridiag + +from finite_ntk.lazy import Jacobian, FVP_FD, flatten, unflatten_like, Rop + + +def map_crossentropy( + model, + num_classes=10, + bias=True, + wd=1, + current_pars=None, + num_data=1, + eval_mode=False, +): + r""" + constructor for MAP estimation, returns a function + model (nn.module): torch model + num_classes (int): number of classes + bias (bool): whether to include the bias parameters list in the loss. + current_pars (list/iterable): parameter list, only used when bias=True + num_data (int): number of data points + eval_mode (bool): whether to include the regularizer in the model definition. overloaded + to more generally mean if the loss function is in train (regularizer included) or eval_mode mode + (regularizer not included for test LL computation) + """ + if bias and current_pars is None: + Warning("Nothing will be returned because current_pars is none") + + def criterion(current_pars, input_data, target, return_predictions=True): + r""" + Loss function for MAP + + current_pars (list/iterable): parameter list + input_data (tensor): input data for model + target (tensor): response + return_predictions (bool):if predictions should be returned as well as loss + """ + rhs = current_pars[0] - flatten(model.parameters()).view(-1,1) + + with torch.enable_grad(): + features = Rop(model(input_data), model.parameters(), unflatten_like(rhs, model.parameters()))[0] + + if bias: + features = wd * features + model(input_data) + + predictions = features @ current_pars[1] + + loss = ( + torch.nn.functional.cross_entropy(predictions, target) + * target.shape[0] + ) + + if eval_mode: + output = loss + else: + output = loss + + return output, predictions + + return criterion + + +def laplace_crossentropy( + model, + num_classes=10, + bias=True, + wd=1e-4, + current_pars=None, + num_data=1, + eval_mode=False, +): + r""" + constructor for Laplace approximation, returns a loss function + model (nn.module): torch model + num_classes (int): number of classes + bias (bool): whether to include the bias parameters list in the loss. + current_pars (list/iterable): parameter list, only used when bias=True + num_data (int): number of data points + eval_mode: whether to include the regularizer in the model definition. 
overloaded + to more generally mean if the loss function is in train (regularizer included) or eval_mode mode + (regularizer not included and we perform sampling) + """ + model_pars = flatten(model.parameters()) + + def criterion(current_pars, input_data, target, return_predictions=True): + r""" + Loss function for Laplace + + current_pars (list/iterable): parameter list + input_data (tensor): input data for model + target (tensor): response + return_predictions (bool):if predictions should be returned as well as loss + """ + if eval_mode: + # this means prediction time + # so do a Fisher vector product + jitter, take the tmatrix invert the cholesky decomp and sample + # F \approx Q T Q' => F^{-1} \approx Q T^{-1} Q' + # F^{-1/2} \approx Q T^{-1/2} + fvp = ((num_data / input_data.shape[0]) * FVP_FD(model, input_data)).add_jitter(1.0) + qmat, tmat = lanczos_tridiag( + fvp.matmul, + 30, + dtype=current_pars[0].dtype, + device=current_pars[0].device, + init_vecs=None, + matrix_shape=[current_pars[0].shape[0], current_pars[0].shape[0]], + ) + + eigs, evecs = torch.symeig(tmat, eigenvectors=True) + + # only consider the top half of the eigenvalues bc they're reliable + eigs_gt_zero = torch.sort(eigs)[1][-int(tmat.shape[0] / 2) :] + + # update the eigendecomposition + # note that @ is a matmul + updated_evecs = (qmat @ evecs)[:, eigs_gt_zero] + + z = torch.randn( + eigs_gt_zero.shape[0], 1, device=tmat.device, dtype=tmat.dtype + ) + approx_lz = updated_evecs @ torch.diag(1.0 / eigs[eigs_gt_zero].pow(0.5)) @ z + sample = current_pars[0] + approx_lz + else: + sample = current_pars[0] + + rhs = sample + if bias: + rhs = sample - model_pars.view(-1, 1) + + predictions = Jacobian(model=model, data=input_data, num_outputs=1)._t_matmul(rhs) + predictions_reshaped = predictions.reshape(target.shape[0], num_classes) + + if bias: + predictions_reshaped = predictions_reshaped + model(input_data) + + loss = ( + torch.nn.functional.cross_entropy(predictions_reshaped, target) + * target.shape[0] + ) + regularizer = current_pars[0].norm() * wd + + if eval_mode: + output = loss + else: + output = num_data * loss + regularizer + + return output, predictions_reshaped + + return criterion + + +def vi_crossentropy( + model, + num_classes=10, + bias=True, + wd=1e-4, + current_pars=None, + num_data=1, + eval_mode=False, +): + r""" + constructor for SVI approximation, returns a loss function + model (nn.module): torch model + num_classes (int): number of classes + bias (bool): whether to include the bias parameters list in the loss. + current_pars (list/iterable): parameter list, only used when bias=True + num_data (int): number of data points + eval_mode: whether to include the regularizer in the model definition. 
overloaded + to more generally mean if the loss function is in train (regularizer included) or eval_mode mode + (regularizer not included and we perform sampling) + """ + model_pars = flatten(model.parameters()) + + def criterion(current_pars, input_data, target, return_predictions=True): + r""" + Loss function for SVI + + current_pars (list/iterable): parameter list + input_data (tensor): input data for model + target (tensor): response + return_predictions (bool):if predictions should be returned as well as loss + """ + current_dist = torch.distributions.Normal( + current_pars[0], torch.nn.functional.softplus(current_pars[1]) + ) + prior_dist = torch.distributions.Normal( + torch.zeros_like(current_pars[0]), 1 / wd * torch.ones_like(current_pars[1]) + ) + + if not eval_mode: + sample = current_dist.rsample() + else: + sample = current_pars[0] + + rhs = sample + if bias: + rhs = sample - model_pars.view_as(sample) + + # compute J^T \theta + predictions = Jacobian(model=model, data=input_data, num_outputs=1)._t_matmul(rhs) + + predictions_reshaped = predictions.reshape(target.shape[0], num_classes) + if bias: + predictions_reshaped = predictions_reshaped + model(input_data) + + loss = ( + torch.nn.functional.cross_entropy(predictions_reshaped, target) + * target.shape[0] + ) + + regularizer = ( + torch.distributions.kl_divergence(current_dist, prior_dist).sum() / num_data + ) + + if eval_mode: + output = loss + else: + output = loss + regularizer + + return output, predictions_reshaped + + return criterion diff --git a/finite_ntk/experiments/unsup/model.py b/finite_ntk/experiments/unsup/model.py new file mode 100644 index 0000000..472339c --- /dev/null +++ b/finite_ntk/experiments/unsup/model.py @@ -0,0 +1,233 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
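+#
+# Structure: Net composes FeatureNet (theta_1, typically frozen), a head net
+# (theta_2; STDHeadNet or NTKHeadNet, with selected layers frozen and/or
+# linearized), and a classifier (omega; STDClassifier or NTKClassifier). The
+# NTK variants return a (logits, jvp) pair, where jvp is the Jacobian-vector
+# product contributed by the linearized layers; benchmark_rop.py adds
+# logits + jvp in 'full' mode. A minimal sketch of assembling the model the
+# way benchmark_rop.py does (the .pt paths are placeholders):
+#
+#   net = Net(nclasses=10)
+#   net.load_fnet(torch.load("fnet1.pt"), freeze=True)
+#   net.load_hnet(torch.load("ntk_hnet1.pt"), reinit_idx=(),
+#                 freeze_idx=[1, 2, 3], linearize_idx=[1, 2, 3])
+#   net.load_clf(torch.load("ntk_clf1.pt"), reinit=False, linearize=True)
+#   logits, jvp = net(images)  # NTK classifier; the STD classifier returns logits only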
+# ============================================================================== + +## copied from https://github.com/fmu2/gradfeat20 + +import torch +import torch.nn as nn +from util import * + +LEAK = 0.01 # for Jenson-Shannon BiGAN + +class FeatureNet(nn.Module): + """Network section parametrized by theta_1""" + """standard parametrization, used by both baseline and proposed models.""" + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 32, 5, 1) + self.conv2 = nn.Conv2d(32, 64, 4, 2) + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + + def thaw(self): + for p in self.parameters(): + p.requires_grad = True + + def forward(self, x): + y1 = F.leaky_relu(self.conv1(x), LEAK, inplace=True) + y2 = F.leaky_relu(self.conv2(y1), LEAK, inplace=True) + return y2 + + +class STDHeadNet(nn.Module): + """Network section parametrized by theta_2""" + """standard parametrization, used by baseline""" + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(64, 128, 4, 1) + self.conv2 = nn.Conv2d(128, 256, 4, 2) + self.conv3 = nn.Conv2d(256, 512, 4, 1) + + def reinit(self, idx): + pass + + def freeze(self, idx): + if 1 in idx: + for p in self.conv1.parameters(): + p.requires_grad = False + + if 2 in idx: + for p in self.conv2.parameters(): + p.requires_grad = False + + if 3 in idx: + for p in self.conv3.parameters(): + p.requires_grad = False + + def thaw(self): + for p in self.parameters(): + p.requires_grad = True + + def linearize(self, idx): + pass + + def forward(self, x): + y1 = F.leaky_relu(self.conv1(x), LEAK, inplace=True) + y2 = F.leaky_relu(self.conv2(y1), LEAK, inplace=True) + y3 = F.leaky_relu(self.conv3(y2), LEAK, inplace=True) + #print(y3.shape) + return y3, None + + +class NTKHeadNet(nn.Module): + """Network section parametrized by theta_2""" + """NTK parametrization, used by proposed model""" + def __init__(self): + super().__init__() + self.conv1 = NTKConv2d(64, 128, 4, 1) + self.conv2 = NTKConv2d(128, 256, 4, 2) + self.conv3 = NTKConv2d(256, 512, 4, 1) + self.linear1 = self.linear2 = self.linear3 = None + + def reinit(self, idx): + if 1 in idx: self.conv1.init() + if 2 in idx: self.conv2.init() + if 3 in idx: self.conv3.init() + + def freeze(self, idx=(1, 2)): + if 1 in idx: self.conv1.freeze() + if 2 in idx: self.conv2.freeze() + if 3 in idx: self.conv3.freeze() + + def thaw(self): + for p in self.parameters(): + p.requires_grad = True + + def linearize(self, idx=(3,)): + self.freeze(idx) + if 1 in idx: self.linear1 = NTKConv2d(64, 128, 4, 1, zero_init=True) + if 2 in idx: self.linear2 = NTKConv2d(128, 256, 4, 2, zero_init=True) + if 3 in idx: self.linear3 = NTKConv2d(256, 512, 4, 1, zero_init=True) + + def forward(self, x): + y1 = F.leaky_relu(self.conv1(x), LEAK, inplace=True) + y2 = F.leaky_relu(self.conv2(y1), LEAK, inplace=True) + y3 = F.leaky_relu(self.conv3(y2), LEAK, inplace=True) + + jvp3 = None + if self.linear3 is not None: + jvp3 = self.linear3(y2) + if self.linear2 is not None: + jvp2 = self.linear2(y1) + if self.linear1 is not None: + jvp1 = self.linear1(x) * ((y1 > 0).float() + (y1 < 0).float() * LEAK) + jvp2 = self.conv2(jvp1, add_bias=False) + jvp2 + jvp2 = jvp2 * ((y2 > 0).float() + (y2 < 0).float() * LEAK) + jvp3 = self.conv3(jvp2, add_bias=False) + jvp3 + jvp3 = jvp3 * ((y3 > 0).float() + (y3 < 0).float() * LEAK) + return y3, jvp3 + + +class STDClassifier(nn.Module): + """Logistic regressor parametrized by omega""" + """standard parametrization, used by baseline""" + def __init__(self, nclasses): + 
super().__init__() + self.nclasses = nclasses + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(512, nclasses) + + def reinit(self): + self.fc = nn.Linear(512, self.nclasses) + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + + def thaw(self): + for p in self.parameters(): + p.requires_grad = True + + def linearize(self): + pass + + def forward(self, x, jvp=None): + x = self.avgpool(x).flatten(1) + logits = self.fc(x) + return logits + + +class NTKClassifier(nn.Module): + """Logistic regressor parametrized by omega""" + """NTK parametrization, used by proposed model""" + def __init__(self, nclasses): + super().__init__() + self.nclasses = nclasses + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = NTKLinear(512, nclasses) + self.linear = None + + def reinit(self): + self.fc.init() + + def freeze(self): + self.fc.freeze() + + def thaw(self): + self.fc.thaw() + + def linearize(self, static=True): + if static: + self.linear = NTKLinear(512, self.nclasses) + self.linear.weight.data = self.fc.weight.data + self.linear.bias.data = self.fc.bias.data + self.linear.freeze() + else: + self.fc.thaw() + self.linear = self.fc + + def forward(self, x, jvp=None): + x = self.avgpool(x).flatten(1) + logits = self.fc(x) + + if jvp is not None: + assert self.linear is not None + jvp = self.avgpool(jvp).flatten(1) + jvp = self.linear(jvp, add_bias=False) + + return logits, jvp + + +class Net(nn.Module): + """Network as either baseline or proposed method""" + def __init__(self, nclasses): + super().__init__() + self.fnet = FeatureNet() + self.hnet = STDHeadNet() + self.clf = STDClassifier(nclasses) + + def load_fnet(self, fnet, freeze=True): + self.fnet = fnet + self.fnet.thaw() + if freeze: + self.fnet.freeze() + + def load_hnet(self, hnet, reinit_idx, freeze_idx, linearize_idx): + self.hnet = hnet + self.hnet.thaw() + self.hnet.reinit(reinit_idx) + self.hnet.freeze(freeze_idx) + self.hnet.linearize(linearize_idx) + + def load_clf(self, clf, reinit=False, linearize=False, static=True): + self.clf = clf + self.clf.thaw() + if reinit: self.clf.reinit() + if linearize: self.clf.linearize(static) + + def forward(self, x): + x, jvp = self.hnet(self.fnet(x)) + return self.clf(x, jvp) \ No newline at end of file diff --git a/finite_ntk/experiments/unsup/parser.py b/finite_ntk/experiments/unsup/parser.py new file mode 100644 index 0000000..dde0f74 --- /dev/null +++ b/finite_ntk/experiments/unsup/parser.py @@ -0,0 +1,104 @@ +# Copyright 2020 anonymous. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import argparse + + +def parser(): + parser = argparse.ArgumentParser(description="fast adaptation training") + + parser.add_argument( + "--dataset", type=str, default="CIFAR10", help="dataset name (default: CIFAR10). 
Other options include CIFAR100 and STL10" + ) + parser.add_argument( + "--data_path", + type=str, + default=None, + required=True, + metavar="path", + help="path to datasets location (default: None)", + ) + parser.add_argument( + "--batch_size", + type=int, + default=128, + metavar="N", + help="input batch size (default: 128)", + ) + + + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--save_freq", + type=int, + default=25, + metavar="N", + help="save frequency (default: 25)", + ) + parser.add_argument( + "--eval_freq", + type=int, + default=5, + metavar="N", + help="evaluation frequency (default: 5)", + ) + parser.add_argument( + "--lr_init", + type=float, + default=0.1, + metavar="LR", + help="initial learning rate (default: 0.1)", + ) + parser.add_argument( + "--wd", type=float, default=1e-4, help="weight decay (default: 1e-4)" + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--save_path", + type=str, + default=None, + required=False, + help="path to npz results file", + ) + parser.add_argument( + "--bias", + action="store_true", + help="whether to use bias terms in the loss (default: off)", + ) + parser.add_argument( + "--inference", + type=str, + choices=["laplace", "vi", "map"], + required=True, + default="map", + help="inference choice to use", + ) + parser.add_argument('--normalize', action='store_true', + help='whether to normalize data') + parser.add_argument('--fnet_path', type=str, + help='path to load fnet') + parser.add_argument('--hnet_path', type=str, + help='path to load hnet') + + args = parser.parse_args() + + return args diff --git a/finite_ntk/experiments/unsup/util.py b/finite_ntk/experiments/unsup/util.py new file mode 100644 index 0000000..6228382 --- /dev/null +++ b/finite_ntk/experiments/unsup/util.py @@ -0,0 +1,218 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+# ============================================================================== + +## copied from https://github.com/fmu2/gradfeat20/src/ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import torch.utils.data as data +from torchvision.datasets import SVHN, CIFAR10, CIFAR100 +from torchvision import transforms + + +class NTKConv2d(nn.Module): + """Conv2d layer under NTK parametrization.""" + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, bias=True, zero_init=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + self.bias = None + self.weight = nn.Parameter(torch.Tensor( + out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + self.init(zero_init) + + def init(self, zero_init=False): + if zero_init: + nn.init.constant_(self.weight, 0.) + if self.bias is not None: + nn.init.constant_(self.bias, 0.) + else: + nn.init.normal_(self.weight, 0., 1.) + if self.bias is not None: + nn.init.normal_(self.bias, 0., 1.) + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + + def thaw(self): + for p in self.parameters(): + p.requires_grad = True + + def forward(self, x, add_bias=True): + weight = np.sqrt(1. / self.out_channels) * self.weight + if add_bias and self.bias is not None: + bias = np.sqrt(.1) * self.bias + return F.conv2d(x, weight, bias, self.stride, self.padding) + else: + return F.conv2d(x, weight, None, self.stride, self.padding) + + +class NTKLinear(nn.Module): + """Linear layer under NTK parametrization.""" + def __init__(self, in_features, out_features, bias=True, zero_init=False): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.bias = None + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + self.init(zero_init) + + def init(self, zero_init=False): + if zero_init: + nn.init.constant_(self.weight, 0.) + if self.bias is not None: + nn.init.constant_(self.bias, 0.) + else: + nn.init.normal_(self.weight, 0., 1.) + if self.bias is not None: + nn.init.normal_(self.bias, 0., 1.) + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + + def thaw(self): + for p in self.parameters(): + p.requires_grad = True + + def forward(self, x, add_bias=True): + weight = np.sqrt(1. / self.out_features) * self.weight + if add_bias and self.bias is not None: + bias = np.sqrt(.1) * self.bias + return F.linear(x, weight, bias) + else: + return F.linear(x, weight, None) + + +def std_to_ntk_conv2d(conv2d): + """STD Conv2d -> NTK Conv2d""" + if isinstance(conv2d, NTKConv2d): + return conv2d + bias = True if conv2d.bias is not None else False + ntk_conv2d = NTKConv2d(conv2d.in_channels, conv2d.out_channels, + conv2d.kernel_size[0], conv2d.stride, conv2d.padding, bias=bias) + # parameter rescaling (see Jacot et al. NeurIPS 18) + ntk_conv2d.weight.data = conv2d.weight.data / np.sqrt(1. 
/ conv2d.out_channels)
+    if bias:
+        ntk_conv2d.bias.data = conv2d.bias.data / np.sqrt(.1)
+    return ntk_conv2d
+
+
+def ntk_to_std_conv2d(conv2d):
+    """NTK Conv2d -> STD Conv2d"""
+    if isinstance(conv2d, nn.Conv2d):
+        return conv2d
+    bias = True if conv2d.bias is not None else False
+    std_conv2d = nn.Conv2d(conv2d.in_channels, conv2d.out_channels,
+        conv2d.kernel_size[0], conv2d.stride, conv2d.padding, bias=bias)
+    # parameter rescaling (see Jacot et al. NeurIPS 18)
+    std_conv2d.weight.data = conv2d.weight.data * np.sqrt(1. / conv2d.out_channels)
+    if bias:
+        std_conv2d.bias.data = conv2d.bias.data * np.sqrt(.1)
+    return std_conv2d
+
+
+def std_to_ntk_linear(fc):
+    """STD Linear -> NTK Linear"""
+    if isinstance(fc, NTKLinear):
+        return fc
+    bias = True if fc.bias is not None else False
+    ntk_fc = NTKLinear(fc.in_features, fc.out_features, bias=bias)
+    # parameter rescaling (see Jacot et al. NeurIPS 18)
+    ntk_fc.weight.data = fc.weight.data / np.sqrt(1. / fc.out_features)
+    if bias:
+        ntk_fc.bias.data = fc.bias.data / np.sqrt(.1)
+    return ntk_fc
+
+
+def ntk_to_std_linear(fc):
+    """NTK Linear -> STD Linear"""
+    if isinstance(fc, nn.Linear):
+        return fc
+    bias = True if fc.bias is not None else False
+    std_fc = nn.Linear(fc.in_features, fc.out_features, bias=bias)
+    # parameter rescaling (see Jacot et al. NeurIPS 18)
+    std_fc.weight.data = fc.weight.data * np.sqrt(1. / fc.out_features)
+    if bias:
+        std_fc.bias.data = fc.bias.data * np.sqrt(.1)
+    return std_fc
+
+
+def merge_batchnorm(conv2d, batchnorm):
+    """Folds BatchNorm2d into Conv2d."""
+    if isinstance(batchnorm, nn.Identity):
+        return conv2d
+    mean = batchnorm.running_mean
+    sigma = torch.sqrt(batchnorm.running_var + batchnorm.eps)
+    beta = batchnorm.weight
+    gamma = batchnorm.bias
+
+    w = conv2d.weight
+    if conv2d.bias is not None:
+        b = conv2d.bias
+    else:
+        b = torch.zeros_like(mean)
+
+    w = w * (beta / sigma).view(conv2d.out_channels, 1, 1, 1)
+    b = (b - mean) / sigma * beta + gamma
+
+    fused_conv2d = nn.Conv2d(
+        conv2d.in_channels, conv2d.out_channels, conv2d.kernel_size,
+        conv2d.stride, conv2d.padding)
+    fused_conv2d.weight.data = w
+    fused_conv2d.bias.data = b
+
+    return fused_conv2d
+
+
+def load_data(dataset, path, batch_size=64, normalize=False):
+    if normalize:
+        # Wasserstein BiGAN is trained on normalized data.
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    else:
+        # BiGAN is trained on unnormalized data (see Dumoulin et al. ICLR 16).
+        transform = transforms.ToTensor()
+
+    if dataset == 'svhn':
+        train_set = SVHN(path, split='extra', transform=transform, download=True)
+        val_set = SVHN(path, split='test', transform=transform, download=True)
+
+    elif dataset == 'cifar10':
+        train_set = CIFAR10(path, train=True, transform=transform, download=True)
+        val_set = CIFAR10(path, train=False, transform=transform, download=True)
+
+    elif dataset == 'cifar100':
+        train_set = CIFAR100(path, train=True, transform=transform, download=True)
+        val_set = CIFAR100(path, train=False, transform=transform, download=True)
+
+    train_loader = data.DataLoader(
+        train_set, batch_size, shuffle=True, num_workers=12)
+    val_loader = data.DataLoader(
+        val_set, 1, shuffle=False, num_workers=1, pin_memory=True)
+    return train_loader, val_loader
diff --git a/finite_ntk/experiments/unsup/vi_benchmark_rop.py b/finite_ntk/experiments/unsup/vi_benchmark_rop.py
new file mode 100644
index 0000000..1c2d267
--- /dev/null
+++ b/finite_ntk/experiments/unsup/vi_benchmark_rop.py
@@ -0,0 +1,314 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# ==============================================================================
+
+import time, configargparse, torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+from model import Net
+from vi_model import LinearizedVINet
+from util import load_data
+import torch.backends.cudnn as cudnn
+
+from finite_ntk.lazy import flatten, unflatten_like
+from finite_ntk.lazy.utils import Rop
+
+def parser_args():
+    parser = configargparse.ArgParser('BiGAN experiments')
+    parser.add('-c', '--config', required=True, is_config_file=True,
+               help='config file')
+
+    parser.add_argument('--gpu', type=int, default=0,
+                        help='gpu instance to use (default: 0)')
+    parser.add_argument('--seed', type=int, default=1818,
+                        help='random seed (default: 1818)')
+
+    # dataset
+    parser.add_argument('--dataset', type=str, default='cifar10',
+                        choices={'cifar10', 'cifar100', 'svhn'},
+                        help='dataset (default: cifar10)')
+    parser.add_argument('--data_path', type=str,
+                        help='path to load dataset')
+    parser.add_argument('--nclass', type=int, default=10,
+                        help='number of classes (default: 10)')
+    parser.add_argument('--batchsize', type=int, default=128,
+                        help='batch size (default: 128)')
+    parser.add_argument('--normalize', action='store_true',
+                        help='whether to normalize data')
+
+    # optimization
+    parser.add_argument('--optim', type=str, default='sgd',
+                        help='optimizer (default: sgd)')
+    parser.add_argument('--lr', type=float, default=1e-3,
+                        help='learning rate (default: 1e-3)')
+    parser.add_argument('--wd', type=float, default=1e-6,
+                        help='weight decay (default: 1e-6)')
+    parser.add_argument('--niter', type=int, default=80000,
+                        help='number of training iterations (default: 80000)')
+    parser.add_argument('--stepsize', type=int, default=20000,
+                        help='number of iterations after which the learning rate is halved (default: 20000)')
+
+    # network
+    parser.add_argument('--mode', type=str, default='full',
+                        
choices={'full_plus_gradients', 'gradfull', 'full', 'actv', 'grad'}, + help='features to use (default: full)') + parser.add_argument('--fnet_path', type=str, + help='path to load fnet') + parser.add_argument('--hnet_path', type=str, + help='path to load hnet') + parser.add_argument('--clf_path', type=str, + help='path to load clf') + parser.add_argument('--model_path', type=str, + help='path to save model') + parser.add_argument('--freeze_hnet', nargs='+', type=int, default=0, + help='hnet layers to freeze') + parser.add_argument('--linearize_hnet', nargs='+', type=int, default=0, + help='hnet layers to linearize') + parser.add_argument('--linearize_clf', action='store_true', + help='whether to linearize the classifier') + + return parser.parse_args() + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + + +def get_kl_terms(model): + # this first set of terms is the fnet jacobian parameterss + prior_dist = torch.distributions.Normal( + torch.zeros_like(model.base_model.fnet_mean), torch.ones_like(model.base_model.fnet_mean) + ) + current_dist = torch.distributions.Normal( + model.base_model.fnet_mean, torch.nn.functional.softplus(model.base_model.fnet_inv_sigma) + 1e-5 + ) + regularizer = ( + torch.distributions.kl_divergence(current_dist, prior_dist).sum() + ) + + # now add in the other linearization parameters + regularizer = regularizer + model.kl_divergence() + + return regularizer + +def prepare_logits(model, fnet, x, mode, training_mode=True): + if training_mode: + current_dist = torch.distributions.Normal( + model.base_model.fnet_mean, torch.nn.functional.softplus(model.base_model.fnet_inv_sigma) + ) + fnet_sample = current_dist.rsample() + else: + fnet_sample = model.base_model.fnet_mean + + if mode == 'full': + # proposed model + logits, jvp = model(x) + logits = logits + jvp + + elif mode == 'grad': + # gradient baseline (second term in proposed model) + _, jvp = model(x) + logits = jvp + + elif mode == 'gradfull': + _, jvp = model(x) + with torch.enable_grad(): + fnet_jvp = Rop(fnet(x), fnet.parameters(), + unflatten_like(fnet_sample, fnet.parameters()), + create_graph=True)[0] + fnet_jvp = fnet_jvp.view(fnet_jvp.shape[0], -1) / x.shape[0] + logits = jvp + fnet_jvp @ model.base_model.fixed_projection + + elif mode == 'full_plus_gradients': + logits, jvp = model(x) + + with torch.enable_grad(): + fnet_jvp = Rop(fnet(x), fnet.parameters(), + unflatten_like(fnet_sample, fnet.parameters()), + create_graph=True)[0] + fnet_jvp = fnet_jvp.view(fnet_jvp.shape[0], -1) / x.shape[0] + logits = logits + jvp + fnet_jvp @ model.base_model.fixed_projection + + else: + # activation baseline or fine-tuning (first term in proposed model) + logits = model(x) + + return logits + +def train(device, loader, model, fnet, mode, + optimizer, niter, stepsize, losses, it=0): + batch_time = AverageMeter() + data_time = AverageMeter() + end = time.time() + + curr_iter = it + model.train() + + for (x, y) in loader: + optimizer.zero_grad() + + data_time.update(time.time() - end) + + # update learning rate + if curr_iter != 0 and curr_iter % stepsize == 0: + for param_group in optimizer.param_groups: + param_group['lr'] = param_group['lr'] * 0.5 + print('iter %d learning rate is %.5f' % (curr_iter, param_group['lr'])) + + x, y = 
x.to(device), y.to(device) + + logits = prepare_logits(model, fnet, x, mode, training_mode=True) + + logprob = nn.CrossEntropyLoss()(logits, y) + loss = logprob * y.shape[0] + get_kl_terms(model) / 50000. #number of cifar10 data points + loss = loss / y.shape[0] # for avg data loss + + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), 10) # clip gradient + optimizer.step() + + losses.update(loss.item(), x.size(0)) + batch_time.update(time.time() - end) + end = time.time() + + if curr_iter % 50 == 0: + print('Iteration[{0}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( + curr_iter, + batch_time=batch_time, data_time=data_time, loss=losses)) + curr_iter += 1 + + if curr_iter == niter: + break + + return curr_iter + + +def evaluate(device, loader, model, fnet, mode): + model.eval() + ncorr = 0 + + for i, (x, y) in enumerate(loader): + x, y = x.to(device), y.to(device) + + with torch.no_grad(): + logits = prepare_logits(model, fnet, x, mode, training_mode=False) + + pred = torch.argmax(logits.detach_(), dim=1) + ncorr += (pred == y).sum() + + acc = ncorr.float() / len(loader) + print(acc.item()) + + +def main(): + args = parser_args() + print('Batch size: %d' % args.batchsize) + print('Initial learning rate: %.5f' % args.lr) + print('Weight decay: %.6f' % args.wd) + + device = torch.device('cuda:' + str(args.gpu) + if torch.cuda.is_available() else 'cpu') + cudnn.benchmark = True + + # fix random seed + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + + basenet = Net(nclasses=args.nclass) + fnet = torch.load(args.fnet_path) # feature net (theta_1) + hnet = torch.load(args.hnet_path) # head net (theta_2) + clf = torch.load(args.clf_path) # classifier (omega) + + # load parameters (random vs. pre-trained) as appropriate + basenet.load_fnet(fnet, freeze=True) + basenet.load_hnet(hnet, reinit_idx=(), + freeze_idx=args.freeze_hnet, linearize_idx=args.linearize_hnet) + basenet.load_clf(clf, reinit=False, linearize=args.linearize_clf) + + # construct + basenet.fnet_mean = torch.nn.Parameter(torch.zeros(flatten(basenet.fnet.parameters()).shape[0], + 1, requires_grad=True)) + basenet.fnet_inv_sigma = torch.nn.Parameter(-5. * torch.ones(flatten(basenet.fnet.parameters()).shape[0], 1)) + basenet.fnet_inv_sigma.requires_grad = True + + fixed_proj = 1e-4 * torch.randn(10816, 10) #TODO: make not hardcoded, infer from network shape + basenet.fixed_projection = torch.nn.Parameter(fixed_proj) + basenet.fixed_projection.requires_grad = False + basenet.to(device) + net = LinearizedVINet(net=basenet, linearize_idx=args.linearize_hnet, prior_std=1.) 
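+    # LinearizedVINet (vi_model.py) attaches a variational mean and softplus-scale to
+    # each linearized hnet layer, resamples those weights on every training forward
+    # pass, and exposes kl_divergence() for the regularizer used in train().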
+ net.to(device) + + for p in fnet.parameters(): + p.requires_grad = True + + fnet.to(device) + params = list(filter(lambda p: p.requires_grad, net.parameters())) + if args.optim == 'sgd': + optimizer = optim.SGD( + params, lr=args.lr, weight_decay=args.wd, momentum=0.9) + elif args.optim == 'adam': + optimizer = optim.Adam( + params, lr=args.lr, betas=(0.5, 0.999), weight_decay=args.wd) + + # check trainable parameters + for p in params: + print(p.size()) + + # load training and test data + train_loader, test_loader = load_data( + args.dataset, args.data_path, args.batchsize, args.normalize) + + print('----- Training phase -----') + it = 0 + losses = AverageMeter() + + while it < args.niter: + it = train( + device, train_loader, net, fnet, args.mode, optimizer, + args.niter, args.stepsize, losses, it=it) + + print(int(it / int(50000 / args.batchsize))) + if int(it / int(50000 / args.batchsize)) % 10 == 0: + print('----- Evaluation phase -----') + print('> test accuracy:') + evaluate(device, test_loader, net, fnet, args.mode) + + print('----- Final Evaluation phase -----') + print('> test accuracy:') + evaluate(device, test_loader, net, fnet, args.mode) + + torch.save(net.cpu(), args.model_path) + + +if __name__ == '__main__': + main() diff --git a/finite_ntk/experiments/unsup/vi_model.py b/finite_ntk/experiments/unsup/vi_model.py new file mode 100644 index 0000000..36d32ed --- /dev/null +++ b/finite_ntk/experiments/unsup/vi_model.py @@ -0,0 +1,143 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import torch + +from finite_ntk.lazy import flatten +import numpy as np + + +#### these utils from +# https://github.com/wjmaddox/drbayes/blob/master/subspace_inference/posteriors/utils.py + +def extract_parameters(model): + params = [] + for module in model.modules(): + for name in list(module._parameters.keys()): + if module._parameters[name] is None: + continue + param = module._parameters[name] + params.append((module, name, param.size())) + module._parameters.pop(name) + return params + +def set_weights(params, w, device): + offset = 0 + for module, name, shape in params: + size = np.prod(shape) + value = w[offset:offset + size] + setattr(module, name, value.view(shape).to(device)) + offset += size + +class LinearizedVINet(torch.nn.Module): + def __init__(self, net, prior_std=1., linearize_idx=(3,), eps=1e-6): + super(LinearizedVINet, self).__init__() + self.base_model = net + print(linearize_idx) + + self.eps = 1e-6 + + # this is NOT a general purpose vi method!!! 
+        # specifically, it's an adaptation of SVI for the linearized networks
+        # defined in Mu et al., ICLR 2020
+        numparams = 0
+        self.base_params = {}
+        self.linear_means = {}
+        self.linear_inv_softplus_sigma = {}
+        if 1 in linearize_idx:
+            numparams += sum([p.numel() for p in self.base_model.hnet.linear1.parameters()])
+            flattened_pars = flatten(self.base_model.hnet.linear1.parameters())
+            self.base_params['linear1'] = extract_parameters(self.base_model.hnet.linear1)
+
+            self.linear_means_linear1 = torch.nn.Parameter(
+                torch.zeros(flattened_pars.shape[0], device=flattened_pars.device)
+            )
+            self.linear_inv_softplus_sigma_linear1 = torch.nn.Parameter(
+                -5. * torch.ones(flattened_pars.shape[0], device=flattened_pars.device)
+            )
+            self.linear_means_linear1.requires_grad = True
+            self.linear_inv_softplus_sigma_linear1.requires_grad = True
+
+        if 2 in linearize_idx:
+            numparams += sum([p.numel() for p in self.base_model.hnet.linear2.parameters()])
+            flattened_pars = flatten(self.base_model.hnet.linear2.parameters())
+            self.base_params['linear2'] = extract_parameters(self.base_model.hnet.linear2)
+
+
+            self.linear_means_linear2 = torch.nn.Parameter(
+                torch.zeros(flattened_pars.shape[0], device=flattened_pars.device)
+            )
+            self.linear_inv_softplus_sigma_linear2 = torch.nn.Parameter(
+                -5. * torch.ones(flattened_pars.shape[0], device=flattened_pars.device)
+            )
+            self.linear_means_linear2.requires_grad = True
+            self.linear_inv_softplus_sigma_linear2.requires_grad = True
+
+        if 3 in linearize_idx:
+            numparams += sum([p.numel() for p in self.base_model.hnet.linear3.parameters()])
+            flattened_pars = flatten(self.base_model.hnet.linear3.parameters())
+            self.base_params['linear3'] = extract_parameters(self.base_model.hnet.linear3)
+
+
+            self.linear_means_linear3 = torch.nn.Parameter(
+                torch.zeros(flattened_pars.shape[0], device=flattened_pars.device)
+            )
+            self.linear_inv_softplus_sigma_linear3 = torch.nn.Parameter(
+                -5. * torch.ones(flattened_pars.shape[0], device=flattened_pars.device)
+            )
+            self.linear_means_linear3.requires_grad = True
+            self.linear_inv_softplus_sigma_linear3.requires_grad = True
+
+        self.linearize_idx = linearize_idx
+        self.prior_std = prior_std
+
+    def sample_and_set_weights(self, param, mean, inv_sigma):
+        sigma = torch.nn.functional.softplus(inv_sigma) + self.eps
+
+        if self.training:
+            z = mean + torch.randn_like(mean) * sigma
+        else:
+            z = mean
+
+        device = sigma.device
+        set_weights(param, z, device)
+
+    def _kldiv(self, mean, inv_sigma):
+        prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(mean) * self.prior_std)
+        q_dist = torch.distributions.Normal(mean, torch.nn.functional.softplus(inv_sigma) + self.eps)
+        return torch.distributions.kl.kl_divergence(q_dist, prior_dist)
+
+    def kl_divergence(self):
+        kldiv = 0.
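+        # accumulate KL(q || prior) over whichever linearized layers are active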
+ if 1 in self.linearize_idx: + kldiv = kldiv + self._kldiv(self.linear_means_linear1, self.linear_inv_softplus_sigma_linear1).sum() + if 2 in self.linearize_idx: + kldiv = kldiv + self._kldiv(self.linear_means_linear2, self.linear_inv_softplus_sigma_linear2).sum() + if 3 in self.linearize_idx: + kldiv = kldiv + self._kldiv(self.linear_means_linear3, self.linear_inv_softplus_sigma_linear3).sum() + return kldiv + + + def forward(self, *args, **kwargs): + if 1 in self.linearize_idx: + self.sample_and_set_weights(self.base_params['linear1'], + self.linear_means_linear1, self.linear_inv_softplus_sigma_linear1) + if 2 in self.linearize_idx: + self.sample_and_set_weights(self.base_params['linear2'], + self.linear_means_linear2, self.linear_inv_softplus_sigma_linear2) + if 3 in self.linearize_idx: + self.sample_and_set_weights(self.base_params['linear3'], + self.linear_means_linear3, self.linear_inv_softplus_sigma_linear3) + + return self.base_model(*args, **kwargs)
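+
+
+if __name__ == "__main__":
+    # Illustrative, self-contained sketch (not used by the training scripts) of the
+    # objective that vi_benchmark_rop.train() optimizes: draw the linearized weights
+    # from q = N(mean, softplus(rho)) with the reparameterization trick, score a
+    # batch, and add KL(q || N(0, 1)) scaled by the dataset size. All shapes and
+    # names below are illustrative placeholders, not values from the experiments.
+    import torch.nn.functional as F
+
+    torch.manual_seed(0)
+    n_data, n_batch, n_feat, n_class = 50000, 128, 32, 10
+    x = torch.randn(n_batch, n_feat)
+    y = torch.randint(n_class, (n_batch,))
+
+    # variational parameters: a mean and an inverse-softplus scale per weight
+    mean = torch.zeros(n_class, n_feat, requires_grad=True)
+    rho = torch.full((n_class, n_feat), -5., requires_grad=True)
+
+    sigma = F.softplus(rho) + 1e-6
+    weight = mean + sigma * torch.randn_like(mean)   # reparameterized sample
+    logits = F.linear(x, weight)
+
+    q = torch.distributions.Normal(mean, sigma)
+    prior = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(mean))
+    kl = torch.distributions.kl_divergence(q, prior).sum()
+
+    # same scaling as train(): batch-summed cross-entropy plus KL / n_data,
+    # then divided by the batch size to report a per-example loss
+    loss = (F.cross_entropy(logits, y) * n_batch + kl / n_data) / n_batch
+    loss.backward()
+    print(loss.item(), mean.grad.abs().mean().item())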