In [1]:
%load_ext autoreload
%autoreload 2

import os, sys, itertools
from pathlib import Path
sys.path.append('../../')

import torch
from torch import nn, optim
from torch.nn import functional as F

from PreTrainedFIMv2.repara_model import VGG16bn_FIM
from PreTrainedFIMv2.util import CIFAR10Worker

params = {'epoch_num': 10, 'log_interval': 1250}
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [2]:
vgg_model = VGG16bn_FIM().to(device)
vgg_model = vgg_model.to_device_child_tensors(device)
criterion = nn.CrossEntropyLoss()
worker = CIFAR10Worker(device, vgg_model, criterion, params)
worker.model.not_eval_FIM()

worker = worker.load_data_loader()
worker = worker.set_save_path('1epoch_VGG16bn_p1_w0_diagonal.pth')
worker = worker.load_chckpt()
worker.test()

Files already downloaded and verified
Files already downloaded and verified
Test-set loss: 0.1237
Accuracy: 84.180%


In [3]:
worker.model.initialize_FIM_weight()
worker.model.inactivate_parameters_ex_specific_layer(12)
worker.test(default_env=False)

Test-set loss: 0.1727
Accuracy: 78.340%


In [4]:
worker.model.initialize_FIM_weight()
worker.model.inactivate_parameters_ex_specific_layer(12)
worker.test(default_env=False)

False
False
False
False
False
False
False
False
False
False
False
False
True
Test-set loss: 0.1730
Accuracy: 78.120%


# OK, Let's estimate Fisher Information Matrix!

In [5]:
for batch_idx, (inputs, targets) in enumerate(worker.testloader):
    if batch_idx > 1: break
    else: pass
inputs, targets = inputs.to(device), targets.to(device)
targets_hat, logvars = worker.model(inputs)
_, predicted = targets_hat.max(1)
logvars = torch.squeeze(torch.cat(logvars,dim=1))
print(targets, predicted, logvars)

tensor([3, 1, 0, 9], device='cuda:1') tensor([3, 1, 0, 9], device='cuda:1') tensor([[1.0064, 1.0038, 1.0045,  ..., 1.0067, 1.0067, 1.0028],
        [1.0060, 1.0038, 1.0052,  ..., 1.0304, 1.0210, 1.0028],
        [1.0057, 1.0038, 1.0046,  ..., 1.0077, 1.0071, 1.0031],
        [1.0068, 1.0038, 1.0052,  ..., 1.0909, 1.0524, 1.0094]],
       device='cuda:1', grad_fn=<SqueezeBackward0>)


# Evaluate FIM

In [6]:
next(worker.model.features[53].logvar[0].parameters())[0,0,:10,:10]

tensor([[ 0.0041, -0.0297,  0.0042],
        [ 0.0096, -0.0011, -0.0188],
        [ 0.0121, -0.0075, -0.0074]], device='cuda:1', grad_fn=<SliceBackward>)

In [7]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, worker.model.parameters()), lr=0.001)
epoch = 1
worker.evaluate_FIM(epoch, optimizer, [12])
targets_hat, logvars = worker.model(inputs)
logvars = torch.squeeze(torch.cat(logvars,dim=1))
print(logvars)
worker.test(default_env=False)

====> Epoch: 1 Average loss: 107255.8691
tensor([[1.0064, 1.0038, 1.0045,  ..., 1.0000, 1.0000, 1.0000],
        [1.0060, 1.0038, 1.0052,  ..., 1.0000, 1.0000, 1.0000],
        [1.0057, 1.0038, 1.0046,  ..., 1.0000, 1.0000, 1.0000],
        [1.0068, 1.0038, 1.0052,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:1', grad_fn=<SqueezeBackward0>)
Test-set loss: 0.1725
Accuracy: 78.210%


In [9]:
next(worker.model.features[53].logvar[0].parameters())[0,0,:10,:10]

tensor([[-0.0052, -0.0063,  0.0052],
        [ 0.0195, -0.0020,  0.0250],
        [ 0.0073, -0.0012,  0.0146]], device='cuda:1', grad_fn=<SliceBackward>)

In [12]:
worker.model.inactivate_parameters_ex_specific_layer(12)
worker.test(default_env=False)

Test-set loss: 0.1237
Accuracy: 84.190%


In [13]:
logvars[0, -580:-470]

tensor([1.0102, 1.0020, 1.0026, 1.0049, 1.0100, 1.0086, 1.0029, 1.0167, 1.0049,
        1.0063, 1.0082, 1.0028, 1.0074, 1.0072, 1.0062, 1.0070, 1.0116, 1.0184,
        1.0000, 1.0202, 1.0134, 1.0002, 1.0229, 1.0072, 1.0026, 1.0287, 1.0121,
        1.0073, 1.0186, 1.0030, 1.0068, 1.0061, 1.0096, 1.0190, 1.0070, 1.0062,
        1.0205, 1.0005, 1.0159, 1.0158, 1.0044, 1.0034, 1.0060, 1.0087, 1.0052,
        1.0000, 1.0055, 1.0029, 1.0168, 1.0088, 1.0145, 1.0189, 1.0130, 1.0108,
        1.0066, 1.0064, 1.0022, 1.0306, 1.0065, 1.0000, 1.0056, 1.0032, 1.0148,
        1.0054, 1.0071, 1.0027, 1.0057, 1.0185, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:1'

In [9]:
worker.test(default_env=True)

Test-set loss: 0.1237
Accuracy: 84.180%


In [10]:
from PreTrainedFIMv2.reparameterize_layer import ReparamterNorm
for layer in worker.model.features:
    if isinstance(layer, ReparamterNorm):
        print(layer.reparameterization)

False
False
False
False
False
False
False
False
False
False
False
False
False


In [11]:
worker.model.inactivate_parameters_ex_specific_layer(12)

In [12]:
for layer in worker.model.features:
    if isinstance(layer, ReparamterNorm):
        print(layer.reparameterization)

False
False
False
False
False
False
False
False
False
False
False
False
True


In [13]:
worker.test(default_env=False)

Test-set loss: 11.6129
Accuracy: 10.000%


In [14]:
from PreTrainedFIMv2.reparameterize_layer import ReparamterNorm
for layer in worker.model.features:
    if isinstance(layer, ReparamterNorm):
        print(layer.reparameterization)
        layer.logvar.apply(init_weights)
        x = layer

False
False
False
False
False
False
False
False
False
False
False
False
True


In [15]:
next(x.logvar[0].parameters())

Parameter containing:
tensor([[[[ 5.4633e-03, -1.3956e-02,  6.9591e-03],
          [ 1.4707e-02,  2.3867e-03,  9.0728e-03],
          [-3.7169e-03, -4.1768e-03,  1.4037e-02]],

         [[-7.2960e-03, -1.9033e-03,  3.9158e-03],
          [-3.8228e-03, -1.0482e-02, -9.5008e-03],
          [-5.2222e-03, -1.2270e-02,  1.4639e-02]],

         [[ 5.0269e-03, -2.6704e-03, -4.1772e-03],
          [ 7.5359e-03,  1.0762e-02,  2.0017e-02],
          [ 9.6328e-03,  1.5463e-02, -1.4614e-02]],

         ...,

         [[-3.2718e-03,  1.0801e-02, -5.7736e-03],
          [ 1.2905e-05, -4.2908e-03, -1.2092e-03],
          [ 3.0090e-03, -5.7543e-03, -3.8095e-03]],

         [[-1.8042e-02, -2.5839e-02,  1.0507e-03],
          [-6.3802e-03,  1.6689e-02, -1.9892e-02],
          [ 6.1013e-05,  5.9372e-03,  2.8489e-03]],

         [[-4.0836e-03,  3.1889e-03, -2.1419e-03],
          [-7.5283e-03, -1.7349e-03, -9.5544e-03],
          [ 1.1219e-02,  4.0675e-04, -5.0515e-03]]],


        [[[ 4.4917e-03,  5.5304

In [6]:
worker.model.eval_FIM()
for layerId in [12]:
    worker.model.inactivate_parameters_ex_specific_layer(layerId)
counter = 0
for layer in worker.model.features:
    if isinstance(layer, ReparamterNorm):
        print(layer.reparameterization)
        counter+=1

False
False
False
False
False
False
False
False
False
False
False
False
True


In [20]:
worker.model.features[49].reparameterization

False

In [17]:
counter

13

In [5]:
targets_hat, logvars = worker.model(inputs)
logvars = torch.squeeze(torch.cat(logvars,dim=1))
print(logvars)
_, predicted = targets_hat.max(1)
print(targets, predicted)

tensor([3, 1, 0, 9], device='cuda:1')
tensor([6, 6, 6, 9], device='cuda:1')
tensor([[1.0128, 1.0022, 1.0209,  ..., 1.0008, 1.0010, 1.0018],
        [1.0227, 1.0055, 1.0268,  ..., 1.0028, 1.0005, 1.0042],
        [1.0083, 1.0018, 1.0135,  ..., 1.0003, 1.0004, 1.0000],
        [1.0245, 1.0062, 1.0296,  ..., 1.0009, 1.0026, 1.0095]],
       device='cuda:1', grad_fn=<SqueezeBackward0>)


Test-set loss: 6.1692
Accuracy: 10.000%


In [36]:
targets_hat, logvars = worker.model(inputs)
_, predicted = targets_hat.max(1)
logvars = torch.squeeze(torch.cat(logvars,dim=1))
print(targets)
print(predicted)
print(logvars)

tensor([3, 1, 0, 9], device='cuda:1')
tensor([6, 6, 6, 6], device='cuda:1')
tensor([[1.0124, 1.0024, 1.0216,  ..., 1.0000, 1.0029, 1.0036],
        [1.0224, 1.0063, 1.0259,  ..., 1.0021, 1.0012, 1.0026],
        [1.0088, 1.0018, 1.0138,  ..., 1.0004, 1.0027, 1.0020],
        [1.0260, 1.0061, 1.0306,  ..., 1.0005, 1.0026, 1.0004]],
       device='cuda:1', grad_fn=<SqueezeBackward0>)


In [30]:
worker.model.eval_FIM()
for batch_idx, (inputs, targets) in enumerate(worker.testloader):
    if batch_idx > 1:
        break
    else:
        pass

inputs, targets = inputs.to(device), targets.to(device)
targets_hat, logvars = worker.model(inputs)
_, predicted = targets_hat.max(1)
logvars = torch.squeeze(torch.cat(logvars,dim=1))
print(targets)
print(predicted)

def loss_for_FIM(y_hat, logvars, y):
    kl_div = torch.sum(worker.model.layer_dims*(torch.exp(logvars) - logvars - 1))
    #w_squared = self.model.w_squared
    cross_entropy = F.cross_entropy(y_hat, y)
    return 10*cross_entropy +  kl_div # + w_squared
    #return kl_div + w_squared
loss = loss_for_FIM(targets_hat, logvars, targets)
loss.requires_grad = True

tensor([3, 1, 0, 9], device='cuda:1')
tensor([5, 5, 5, 5], device='cuda:1')


In [31]:
loss.backward()

In [32]:
optimizer.step()

In [33]:
logvars

tensor([[1.0081, 1.0185, 1.0333, 1.0183, 1.0248, 1.0197, 1.0199, 1.0214, 1.0262,
         1.0272, 1.0129, 1.0195, 1.0938],
        [1.0101, 1.0209, 1.0334, 1.0180, 1.0251, 1.0197, 1.0206, 1.0228, 1.0283,
         1.0284, 1.0130, 1.0213, 1.0876],
        [1.0073, 1.0185, 1.0332, 1.0177, 1.0241, 1.0193, 1.0191, 1.0212, 1.0285,
         1.0270, 1.0126, 1.0228, 1.0774],
        [1.0109, 1.0213, 1.0342, 1.0182, 1.0244, 1.0199, 1.0202, 1.0219, 1.0258,
         1.0275, 1.0125, 1.0214, 1.0987]], device='cuda:1')

In [9]:
worker.model = worker.model.to(device)