This notebook analyses the properties of decorrelation and whitening methods for decorrelated networks, using a random 1,000-image subset of MNIST.

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

# Fix the RNGs the notebook relies on so that every run is reproducible.
seed = 42
np.random.seed(seed)          # numpy RNG: drives the training-subset permutation
torch.manual_seed(seed)       # torch RNG (per the PyTorch docs this seeds all devices)
torch.cuda.manual_seed(seed)  # current CUDA device; redundant with the line above, kept for explicitness

# Prefer the first GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# MNIST preprocessing: PIL image -> float tensor in [0, 1] -> normalised tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    # No-op for MNIST, which is already single-channel; presumably kept so the
    # same pipeline also accepts RGB datasets -- TODO confirm.
    transforms.Grayscale(1),
    # Per-channel mean/std as one-element tuples ((0.5) is just a float, not a tuple).
    # NOTE(review): 0.5 / 0.25 are not MNIST's empirical pixel statistics
    # (about 0.1307 / 0.3081), so the normalised inputs are not zero-mean --
    # confirm this is intended for the decorrelation analysis.
    transforms.Normalize((0.5,), (0.25,)),
    # torch.flatten # not necessary but useful for debugging
])

N_TRAIN_SAMPLES = 1000  # size of the random training subset
BATCH_SIZE = 64

mnist_train = MNIST(root='~/Data', train=True, download=True, transform=transform)
# The subset is drawn from the globally seeded numpy RNG, so it is the same on every run.
train_data = Subset(mnist_train, np.random.permutation(len(mnist_train))[:N_TRAIN_SAMPLES])
# drop_last=True: 1000 samples give 15 full batches of 64; the 40 left over are
# dropped each epoch (a different 40 each time because of shuffle=True).
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)

In [3]:
# Pure decorrelation run: lr=0.0 (presumably the backprop learning rate), so
# only the decorrelation update should change the layer.
args = argparse.Namespace(
    lr=0.0,
    eta=0.3,
    decor_lr=1e-4,
    variance=1.0,
    epochs=10,
)

# Stand-alone Decorrelation layer over the 784 (= 28 * 28) MNIST pixel features.
model = Decorrelation(784, bias=False, eta=args.eta, variance=args.variance)
model = model.to(device)

def lossfun(x, y):
    """Dummy backprop loss that is identically zero.

    Paired with ``lr=0.0`` in this cell so that training exercises only the
    decorrelation update. The returned leaf tensor requires grad, so calling
    ``backward()`` on it is valid. It is not connected to the model's graph,
    however, so no model parameter receives a gradient from it.

    Args:
        x: model output (presumably; only its ``device`` is used).
        y: targets (ignored).

    Returns:
        A tensor of shape ``(1,)`` with value ``0.0``, default float dtype,
        located on ``x.device``.
    """
    # A plain leaf tensor is enough here. Wrapping it in nn.Parameter misused
    # the class, and dtype=float produced a float64 loss. Taking the device
    # from ``x`` removes the dependency on the notebook-global ``device``.
    return torch.zeros(1, device=x.device, requires_grad=True)

# Run decor_train (from decorrelation.train) on the training subset.
# The backprop loss is identically zero and args.lr is 0.0, so presumably only
# the decorrelation rule updates the layer. In the recorded log, the
# decorrelation loss falls from ~107 to ~0.8 over the 10 epochs.
# NOTE(review): the structure of `res` is not visible here -- confirm before use.
res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 106.560814
epoch 1  	time:0.297 s	bp loss: 0.000000	decorrelation loss: 71.353897
epoch 2  	time:0.269 s	bp loss: 0.000000	decorrelation loss: 28.923294
epoch 3  	time:0.268 s	bp loss: 0.000000	decorrelation loss: 12.493031
epoch 4  	time:0.282 s	bp loss: 0.000000	decorrelation loss: 5.892365
epoch 5  	time:0.297 s	bp loss: 0.000000	decorrelation loss: 3.095443
epoch 6  	time:0.276 s	bp loss: 0.000000	decorrelation loss: 1.836523
epoch 7  	time:0.270 s	bp loss: 0.000000	decorrelation loss: 1.248613
epoch 8  	time:0.271 s	bp loss: 0.000000	decorrelation loss: 1.002738
epoch 9  	time:0.277 s	bp loss: 0.000000	decorrelation loss: 0.887474
epoch 10 	time:0.286 s	bp loss: 0.000000	decorrelation loss: 0.812327


In [4]:
# Same decorrelation settings as the pure-decorrelation run, but with a
# non-zero lr (presumably the backprop learning rate) so the task loss also trains.
args = argparse.Namespace(
    lr=1e-4,
    eta=0.3,
    decor_lr=1e-4,
    variance=1.0,
    epochs=10,
)

class Model(nn.Sequential):
    """A single DecorLinear layer applied to flattened inputs.

    Args:
        in_features: size of each flattened input sample (784 for 28x28 MNIST).
        eta: forwarded to ``DecorLinear``.
        variance: forwarded to ``DecorLinear``.
        out_features: number of outputs; defaults to 100.
    """

    def __init__(self, in_features, eta, variance, out_features=100):
        # NOTE(review): this model is trained with CrossEntropyLoss on the 10
        # MNIST digit classes. With the default 100 outputs, 90 logits are never
        # a target. Pass out_features=10 for a plain 10-way classifier; the
        # default keeps the original experiment unchanged.
        super().__init__(DecorLinear(in_features, out_features, eta=eta, variance=variance))

    def forward(self, x):
        # Keep the batch dimension and merge all others. Unlike view(), flatten
        # also works for non-contiguous inputs and for empty batches.
        return super().forward(x.flatten(start_dim=1))
    
# Backprop loss: cross-entropy between the model outputs and the digit labels.
lossfun = nn.CrossEntropyLoss().to(device)

model = Model(784, eta=args.eta, variance=args.variance)
model = model.to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.380835	decorrelation loss: 106.638420
epoch 1  	time:0.375 s	bp loss: 4.167106	decorrelation loss: 71.140137
epoch 2  	time:0.308 s	bp loss: 2.883932	decorrelation loss: 28.747271
epoch 3  	time:0.343 s	bp loss: 2.493549	decorrelation loss: 12.492141
epoch 4  	time:0.401 s	bp loss: 2.280089	decorrelation loss: 5.872338
epoch 5  	time:0.336 s	bp loss: 2.151146	decorrelation loss: 3.069327
epoch 6  	time:0.365 s	bp loss: 2.063200	decorrelation loss: 1.810406
epoch 7  	time:0.339 s	bp loss: 2.014316	decorrelation loss: 1.257885
epoch 8  	time:0.267 s	bp loss: 1.983537	decorrelation loss: 1.006842
epoch 9  	time:0.267 s	bp loss: 1.943933	decorrelation loss: 0.883589
epoch 10 	time:0.341 s	bp loss: 1.870017	decorrelation loss: 0.820903


In [11]:
# NOTE: at this decorrelation learning rate (decor_lr=1e-3, ten times the
# earlier runs) training fails, e.g. for variance=1.0. In the recorded log, the
# decorrelation loss grows from ~110 to ~303 over the first 13 epochs while the
# bp loss keeps falling.
args = argparse.Namespace(
    lr=1e-4,
    eta=0.3,
    decor_lr=1e-3,
    variance=1.0,
    epochs=30,
)

class Model(nn.Sequential):
    """Two-layer decorrelated MLP: DecorLinear -> LeakyReLU -> DecorLinear.

    NOTE: this redefines (shadows) the single-layer ``Model`` defined in an
    earlier cell; any later cell sees this version.

    Args:
        in_features: size of each flattened input sample (784 for 28x28 MNIST).
        eta: forwarded to both ``DecorLinear`` layers.
        variance: forwarded to both ``DecorLinear`` layers.
        hidden_features: width of the hidden layer; defaults to 100.
        out_features: number of output logits; defaults to 10 (one per MNIST digit).
    """

    def __init__(self, in_features, eta, variance, hidden_features=100, out_features=10):
        super().__init__(
            DecorLinear(in_features, hidden_features, decor_bias=True, eta=eta, variance=variance),
            nn.LeakyReLU(),
            DecorLinear(hidden_features, out_features, decor_bias=True, eta=eta, variance=variance),
        )

    def forward(self, x):
        # Keep the batch dimension and merge all others. Unlike view(), flatten
        # also works for non-contiguous inputs and for empty batches.
        return super().forward(x.flatten(start_dim=1))
    
# Backprop loss: cross-entropy between the model outputs and the digit labels.
lossfun = nn.CrossEntropyLoss().to(device)

model = Model(784, eta=args.eta, variance=args.variance)
model = model.to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.500252	decorrelation loss: 110.414078
epoch 1  	time:0.294 s	bp loss: 2.369059	decorrelation loss: 110.095940
epoch 2  	time:0.274 s	bp loss: 2.093243	decorrelation loss: 111.840752
epoch 3  	time:0.336 s	bp loss: 1.851873	decorrelation loss: 119.033813
epoch 4  	time:0.313 s	bp loss: 1.649912	decorrelation loss: 128.727875
epoch 5  	time:0.266 s	bp loss: 1.483499	decorrelation loss: 141.375565
epoch 6  	time:0.280 s	bp loss: 1.328734	decorrelation loss: 156.265747
epoch 7  	time:0.323 s	bp loss: 1.215306	decorrelation loss: 171.440079
epoch 8  	time:0.290 s	bp loss: 1.098081	decorrelation loss: 188.659470
epoch 9  	time:0.265 s	bp loss: 1.006956	decorrelation loss: 210.630112
epoch 10 	time:0.312 s	bp loss: 0.907868	decorrelation loss: 234.998413
epoch 11 	time:0.329 s	bp loss: 0.834331	decorrelation loss: 259.065826
epoch 12 	time:0.276 s	bp loss: 0.777433	decorrelation loss: 285.908173
epoch 13 	time:0.282 s	bp loss: 0.727131	decorrelation loss: 303

In [6]:
# TODO: add a ConvNet example (e.g. built from DecorConv2d).
