Analyses properties of decorrelation and whitening methods for decorrelated networks

In [7]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

# Seed the NumPy and PyTorch (CPU and CUDA) random number generators with one
# shared value so that runs starting from a fresh kernel draw the same numbers.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Size of the random MNIST subset used for these quick experiments and the
# mini-batch size of the loader (previously unnamed magic numbers).
N_TRAIN_SAMPLES = 1000
BATCH_SIZE = 64

# Preprocessing pipeline: convert to a tensor, force a single channel, then
# shift/scale each pixel with (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(1),
    # Normalize documents per-channel sequences for mean and std. `(0.5)` is
    # just the float 0.5 in parentheses, so spell out real one-element tuples.
    transforms.Normalize((0.5,), (0.5,)),
])

# Keep the full dataset under its own name so `train_data` always refers to
# the subset that is actually fed to the loader.
mnist_train = MNIST(root='~/Data', train=True, download=True, transform=transform)

# The subset is drawn from the global NumPy RNG, so it depends on the seed set
# earlier and on how many random draws happened before this cell ran.
subset_indices = np.random.permutation(len(mnist_train))[:N_TRAIN_SAMPLES]
train_data = Subset(mnist_train, subset_indices)

# drop_last=True discards the final incomplete batch, so every batch holds
# exactly BATCH_SIZE samples.
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [9]:
# Hyperparameters handed to decor_train as an attribute namespace.
args = argparse.Namespace(lr=1e-4, eta=1e-5, epochs=10)


def lossfun(x, y):
    """Dummy loss that ignores both arguments.

    Returns a freshly created one-element, zero-valued parameter (Python
    ``float`` dtype, i.e. float64) on ``device`` with ``requires_grad=True``.
    NOTE(review): presumably this lets decor_train run its usual loss/backward
    path while only the decorrelation update has any effect — confirm against
    decor_train.
    """
    zero = torch.zeros(1, device=device, dtype=float)
    return nn.Parameter(zero, requires_grad=True)


# Stand-alone decorrelation module over 784 inputs
# (presumably the 28x28 MNIST pixels — confirm).
model = Decorrelation(784, variance=1.0).to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 3.325848
epoch 1  	time:0.529 s	bp loss: 0.000000	decorrelation loss: 2.886503
epoch 2  	time:0.548 s	bp loss: 0.000000	decorrelation loss: 2.119532
epoch 3  	time:0.573 s	bp loss: 0.000000	decorrelation loss: 1.562483
epoch 4  	time:0.556 s	bp loss: 0.000000	decorrelation loss: 1.166143
epoch 5  	time:0.562 s	bp loss: 0.000000	decorrelation loss: 0.876968
epoch 6  	time:0.542 s	bp loss: 0.000000	decorrelation loss: 0.660428
epoch 7  	time:0.540 s	bp loss: 0.000000	decorrelation loss: 0.502379
epoch 8  	time:0.529 s	bp loss: 0.000000	decorrelation loss: 0.385764
epoch 9  	time:0.529 s	bp loss: 0.000000	decorrelation loss: 0.300305
epoch 10 	time:0.527 s	bp loss: 0.000000	decorrelation loss: 0.237025


In [15]:
class Model(nn.Sequential):
    """Sequential container holding a single DecorLinear layer.

    The layer is built as ``DecorLinear(in_features, 100, variance=variance)``;
    the 100 is presumably its output width (confirm against DecorLinear).
    Inputs are flattened to ``(batch, -1)`` before being passed through.
    """

    def __init__(self, in_features, variance):
        layer = DecorLinear(in_features, 100, variance=variance)
        super().__init__(layer)

    def forward(self, x):
        """Flatten everything after the leading (batch) dimension, then run the stack."""
        batch_size = len(x)
        flat = x.view(batch_size, -1)
        return super().forward(flat)
    
# Hyperparameters handed to decor_train as an attribute namespace.
args = argparse.Namespace(
    lr=1e-4,
    eta=1e-5,
    epochs=10,
)

# Cross-entropy loss, moved to the same device as the model.
lossfun = nn.CrossEntropyLoss().to(device)

# Single-layer model over 784 flattened inputs.
model = Model(784, variance=1.0).to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.404356	decorrelation loss: 3.324482
epoch 1  	time:0.572 s	bp loss: 4.667752	decorrelation loss: 2.885933
epoch 2  	time:0.557 s	bp loss: 3.502808	decorrelation loss: 2.113524
epoch 3  	time:0.598 s	bp loss: 2.883469	decorrelation loss: 1.571422
epoch 4  	time:0.613 s	bp loss: 2.567964	decorrelation loss: 1.164550
epoch 5  	time:0.569 s	bp loss: 2.384269	decorrelation loss: 0.872825
epoch 6  	time:0.556 s	bp loss: 2.252117	decorrelation loss: 0.659291
epoch 7  	time:0.550 s	bp loss: 2.152915	decorrelation loss: 0.503887
epoch 8  	time:0.552 s	bp loss: 2.068698	decorrelation loss: 0.386170
epoch 9  	time:0.551 s	bp loss: 1.983859	decorrelation loss: 0.301467
epoch 10 	time:0.549 s	bp loss: 1.913287	decorrelation loss: 0.236745


In [16]:
class Model(nn.Sequential):
    """Sequential stack: DecorLinear -> ReLU -> DecorLinear.

    The layers are built as ``DecorLinear(in_features, 100)`` and
    ``DecorLinear(100, 10)``; the second arguments are presumably the output
    widths (confirm against DecorLinear). Inputs are flattened to
    ``(batch, -1)`` before being passed through.
    """

    def __init__(self, in_features, variance):
        # Open question kept from the original author:
        # CONFLATION OF DECORRELATION AND BP PARAMETERS??? HAS_GRAD??? SET NONE
        hidden = DecorLinear(in_features, 100, variance=variance)
        activation = nn.ReLU()
        readout = DecorLinear(100, 10, variance=variance)
        super().__init__(hidden, activation, readout)

    def forward(self, x):
        """Flatten everything after the leading (batch) dimension, then run the stack."""
        batch_size = len(x)
        flat = x.view(batch_size, -1)
        return super().forward(flat)
    
# Hyperparameters handed to decor_train as an attribute namespace.
args = argparse.Namespace(
    lr=1e-4,
    eta=1e-5,
    epochs=20,
)

# Cross-entropy loss, moved to the same device as the model.
lossfun = nn.CrossEntropyLoss().to(device)

# Two-layer model over 784 flattened inputs.
model = Model(784, variance=1.0).to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.504519	decorrelation loss: 3.514931
epoch 1  	time:0.586 s	bp loss: 2.413006	decorrelation loss: 3.069027
epoch 2  	time:0.570 s	bp loss: 2.244280	decorrelation loss: 2.355799
epoch 3  	time:0.584 s	bp loss: 2.054247	decorrelation loss: 2.765030
epoch 4  	time:0.577 s	bp loss: 1.841620	decorrelation loss: 6.795751
epoch 5  	time:0.562 s	bp loss: 1.666811	decorrelation loss: 16.553492
epoch 6  	time:0.561 s	bp loss: 1.505444	decorrelation loss: 17.876627
epoch 7  	time:0.561 s	bp loss: 1.376239	decorrelation loss: 21.855539
epoch 8  	time:0.568 s	bp loss: 1.252189	decorrelation loss: 30.634464
epoch 9  	time:0.586 s	bp loss: 1.137028	decorrelation loss: 28.027956
epoch 10 	time:0.570 s	bp loss: 1.037167	decorrelation loss: 34.538498
epoch 11 	time:0.577 s	bp loss: 0.949478	decorrelation loss: 40.352013
epoch 12 	time:0.576 s	bp loss: 0.868624	decorrelation loss: 49.843830
epoch 13 	time:0.579 s	bp loss: 0.809069	decorrelation loss: 62.101517
epoch 14 	t