Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# Automatically reload imported modules before each cell is executed, so that
# edits to local packages (e.g. `decorrelation`) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

# Seed every random number generator used in this notebook with the same value
# so that the data subset and the training runs are repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# Preprocessing applied to every image: convert to a tensor, force a single
# channel, then shift/scale each pixel as (x - 0.5) / 0.25.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(1),
    # Normalize is documented to take per-channel sequences; `(0.5)` is just the
    # float 0.5, so pass explicit 1-tuples for the single channel.
    transforms.Normalize((0.5,), (0.25,)),
    # torch.flatten # not necessary but useful for debugging
])

# Number of training images kept for these quick experiments.
n_train_samples = 1000

# Full MNIST training split (downloaded to ~/Data on first use).
mnist_train = MNIST(root='~/Data', train=True, download=True, transform=transform)

# Random subset of the training split; the permutation is drawn from the global
# NumPy generator, so it is reproducible given the seed set at the top of the notebook.
train_data = Subset(mnist_train, np.random.permutation(len(mnist_train.data))[:n_train_samples])

# drop_last=True discards the final incomplete batch so every batch has exactly 64 samples.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

In [3]:
# Run with lr=0.0 and a constant zero loss: the recorded "bp loss" stays at 0
# while the decorrelation loss decreases.
# NOTE(review): presumably only the decorrelation update is active in this
# configuration — confirm against decor_train.
args = argparse.Namespace(
    lr=0.0,
    eta=0.3,
    decor_lr=1e-4,
    variance=1.0,
    epochs=10,
)

# A single decorrelation layer over the 784 pixel inputs, without bias.
model = Decorrelation(784, bias=False, eta=args.eta, variance=args.variance).to(device)


def lossfun(x, y):
    """Return a constant zero loss that still supports gradients.

    Both arguments are ignored; a fresh one-element float64 parameter holding
    zero is created on `device` for every call.
    """
    zero = torch.zeros(1, device=device, dtype=float)
    return nn.Parameter(zero, requires_grad=True)


res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 106.560814
epoch 1  	time:0.361 s	bp loss: 0.000000	decorrelation loss: 71.353897
epoch 2  	time:0.359 s	bp loss: 0.000000	decorrelation loss: 28.923294
epoch 3  	time:0.360 s	bp loss: 0.000000	decorrelation loss: 12.493031
epoch 4  	time:0.357 s	bp loss: 0.000000	decorrelation loss: 5.892365
epoch 5  	time:0.372 s	bp loss: 0.000000	decorrelation loss: 3.095443
epoch 6  	time:0.359 s	bp loss: 0.000000	decorrelation loss: 1.836523
epoch 7  	time:0.362 s	bp loss: 0.000000	decorrelation loss: 1.248613
epoch 8  	time:0.371 s	bp loss: 0.000000	decorrelation loss: 1.002738
epoch 9  	time:0.362 s	bp loss: 0.000000	decorrelation loss: 0.887474
epoch 10 	time:0.360 s	bp loss: 0.000000	decorrelation loss: 0.812327


In [4]:
# Hyperparameters for the single-layer classifier run; unlike the
# decorrelation-only run above, lr is non-zero here.
# NOTE(review): lr / decor_lr are presumably the backprop and decorrelation
# learning rates — confirm against decor_train.
args = argparse.Namespace(
    lr=1e-4,
    eta=0.3,
    decor_lr=1e-4,
    variance=1.0,
    epochs=10,
)

class Model(nn.Sequential):
    """Single DecorLinear layer applied to flattened inputs.

    Note: a later cell of this notebook defines another class that is also
    called `Model`; whichever cell ran last determines what `Model` refers to.

    Args:
        in_features: number of features per sample after flattening.
        eta: passed through to DecorLinear.
        variance: passed through to DecorLinear.
        out_features: output width of the layer. Defaults to 100, the value
            that was previously hard-coded.
    """

    def __init__(self, in_features, eta, variance, out_features=100):
        super().__init__(DecorLinear(in_features, out_features, eta=eta, variance=variance))

    def forward(self, x):
        """Flatten each sample of `x` to a vector and apply the layer.

        `reshape` is used instead of `view` so that non-contiguous batches
        (e.g. transposed or permuted tensors) are accepted as well; for
        contiguous inputs the result is identical.
        """
        return super().forward(x.reshape(len(x), -1))
    
# Build the model for 784-dimensional (flattened 28x28) inputs and move it to the target device.
model = Model(784, eta=args.eta, variance=args.variance)
model = model.to(device)

# Cross-entropy loss on the model outputs.
lossfun = nn.CrossEntropyLoss().to(device)

# Train; `res` holds whatever decor_train returns.
res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.380835	decorrelation loss: 106.638420
epoch 1  	time:0.389 s	bp loss: 4.167106	decorrelation loss: 71.140137
epoch 2  	time:0.387 s	bp loss: 2.883932	decorrelation loss: 28.747271
epoch 3  	time:0.387 s	bp loss: 2.493549	decorrelation loss: 12.492141
epoch 4  	time:0.386 s	bp loss: 2.280089	decorrelation loss: 5.872338
epoch 5  	time:0.383 s	bp loss: 2.151146	decorrelation loss: 3.069327
epoch 6  	time:0.396 s	bp loss: 2.063200	decorrelation loss: 1.810406
epoch 7  	time:0.386 s	bp loss: 2.014316	decorrelation loss: 1.257885
epoch 8  	time:0.388 s	bp loss: 1.983537	decorrelation loss: 1.006842
epoch 9  	time:0.380 s	bp loss: 1.943933	decorrelation loss: 0.883589
epoch 10 	time:0.382 s	bp loss: 1.870017	decorrelation loss: 0.820903


In [9]:
# Hyperparameters for the two-layer run: decor_lr is raised to 1e-3 and
# variance is set to None.
# FAILS FOR E.G. variance=1.0 at this learning rate.
# NOTE(review): in the recorded output of this cell the decorrelation loss
# falls to ~0.83 by epoch 3 and then grows to ~288 by epoch 10 — this
# configuration may be unstable as well; confirm.
args = argparse.Namespace(
    lr=1e-4,
    eta=0.3,
    decor_lr=1e-3,
    variance=None,
    epochs=10,
)

class Model(nn.Sequential):
    """Two DecorLinear layers with a LeakyReLU in between, applied to flattened inputs.

    Note: this definition replaces the single-layer `Model` class defined in
    an earlier cell of this notebook.

    Args:
        in_features: number of features per sample after flattening.
        eta: passed through to both DecorLinear layers.
        variance: passed through to both DecorLinear layers.
        hidden_features: width of the hidden layer. Defaults to 100, the value
            that was previously hard-coded.
        out_features: width of the output layer. Defaults to 10, the value
            that was previously hard-coded.
    """

    def __init__(self, in_features, eta, variance, hidden_features=100, out_features=10):
        super().__init__(
            DecorLinear(in_features, hidden_features, decor_bias=True, eta=eta, variance=variance),
            nn.LeakyReLU(),
            DecorLinear(hidden_features, out_features, decor_bias=True, eta=eta, variance=variance),
        )

    def forward(self, x):
        """Flatten each sample of `x` to a vector and apply the layers in order.

        `reshape` is used instead of `view` so that non-contiguous batches
        (e.g. transposed or permuted tensors) are accepted as well; for
        contiguous inputs the result is identical.
        """
        return super().forward(x.reshape(len(x), -1))
    
# Build the model for 784-dimensional (flattened 28x28) inputs and move it to the target device.
model = Model(784, eta=args.eta, variance=args.variance)
model = model.to(device)

# Cross-entropy loss on the model outputs.
lossfun = nn.CrossEntropyLoss().to(device)

# Train; `res` holds whatever decor_train returns.
res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.574494	decorrelation loss: 110.781136
epoch 1  	time:0.356 s	bp loss: 2.431993	decorrelation loss: 16.077549
epoch 2  	time:0.348 s	bp loss: 2.198667	decorrelation loss: 0.958726
epoch 3  	time:0.332 s	bp loss: 1.991839	decorrelation loss: 0.831519
epoch 4  	time:0.357 s	bp loss: 1.816990	decorrelation loss: 1.540964
epoch 5  	time:0.353 s	bp loss: 1.665717	decorrelation loss: 4.488208
epoch 6  	time:0.367 s	bp loss: 1.533112	decorrelation loss: 26.396355
epoch 7  	time:0.348 s	bp loss: 1.421979	decorrelation loss: 84.714844
epoch 8  	time:0.348 s	bp loss: 1.317872	decorrelation loss: 140.164139
epoch 9  	time:0.359 s	bp loss: 1.249197	decorrelation loss: 256.773102
epoch 10 	time:0.350 s	bp loss: 1.161782	decorrelation loss: 287.796539


In [7]:
# TODO: add a ConvNet example (e.g. built from DecorConv2d)
