Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(1),
        transforms.Normalize((0.5), (0.25)),
        torch.flatten # can be removed when moving to Conv2d
        ])

train_data = MNIST(root='~/Data', train=True, download=True, transform=transform)
train_data = Subset(train_data, np.random.permutation(len(train_data.data))[:1000])
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

In [3]:
model = Decorrelation(784, bias=False, variance=1.0).to(device)

lossfun = lambda x, y: nn.Parameter(torch.zeros(1, device=device, dtype=float), requires_grad=True)

args = argparse.Namespace(lr=1e-4, eta=1e-5, epochs=10)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 106.560814
epoch 1  	time:0.578 s	bp loss: 0.000000	decorrelation loss: 92.458618
epoch 2  	time:0.543 s	bp loss: 0.000000	decorrelation loss: 67.835510
epoch 3  	time:0.532 s	bp loss: 0.000000	decorrelation loss: 49.946434
epoch 4  	time:0.541 s	bp loss: 0.000000	decorrelation loss: 37.211742
epoch 5  	time:0.536 s	bp loss: 0.000000	decorrelation loss: 27.915833
epoch 6  	time:0.532 s	bp loss: 0.000000	decorrelation loss: 20.952473
epoch 7  	time:0.528 s	bp loss: 0.000000	decorrelation loss: 15.865977
epoch 8  	time:0.530 s	bp loss: 0.000000	decorrelation loss: 12.110641
epoch 9  	time:0.528 s	bp loss: 0.000000	decorrelation loss: 9.355850
epoch 10 	time:0.533 s	bp loss: 0.000000	decorrelation loss: 7.313002


In [4]:
class Model(nn.Sequential):
    def __init__(self, in_features, variance):
        super().__init__(DecorLinear(in_features, 100, variance=variance))

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, variance=1.0).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

args = argparse.Namespace(lr=1e-4, eta=1e-5, epochs=10)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.380835	decorrelation loss: 106.638420
epoch 1  	time:0.557 s	bp loss: 4.126562	decorrelation loss: 92.378510
epoch 2  	time:0.550 s	bp loss: 2.728143	decorrelation loss: 67.550423
epoch 3  	time:0.549 s	bp loss: 2.297279	decorrelation loss: 50.074562
epoch 4  	time:0.575 s	bp loss: 2.049079	decorrelation loss: 37.256287
epoch 5  	time:0.577 s	bp loss: 1.873144	decorrelation loss: 27.816923
epoch 6  	time:0.558 s	bp loss: 1.731331	decorrelation loss: 20.886711
epoch 7  	time:0.558 s	bp loss: 1.619119	decorrelation loss: 15.840314
epoch 8  	time:0.561 s	bp loss: 1.512893	decorrelation loss: 12.159686
epoch 9  	time:0.553 s	bp loss: 1.421352	decorrelation loss: 9.431981
epoch 10 	time:0.557 s	bp loss: 1.346510	decorrelation loss: 7.337277


In [6]:
class Model(nn.Sequential):
    def __init__(self, in_features, variance):
        super().__init__(DecorLinear(in_features, 100, bias=True, variance=1.0),
                        nn.LeakyReLU(),
                        DecorLinear(100, 10, bias=True, variance=1.0)
                        )

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, variance=1.0).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

args = argparse.Namespace(lr=1e-5, eta=1e-5, epochs=20)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.659184	decorrelation loss: 111.116158
epoch 1  	time:0.565 s	bp loss: 2.588839	decorrelation loss: 95.809464
epoch 2  	time:0.562 s	bp loss: 2.490688	decorrelation loss: 70.225693
epoch 3  	time:0.562 s	bp loss: 2.426370	decorrelation loss: 51.945717
epoch 4  	time:0.561 s	bp loss: 2.382820	decorrelation loss: 38.757549
epoch 5  	time:0.567 s	bp loss: 2.347725	decorrelation loss: 28.985779
epoch 6  	time:0.569 s	bp loss: 2.312520	decorrelation loss: 21.931208
epoch 7  	time:0.578 s	bp loss: 2.285963	decorrelation loss: 16.769054
epoch 8  	time:0.566 s	bp loss: 2.259726	decorrelation loss: 12.832211
epoch 9  	time:0.561 s	bp loss: 2.230120	decorrelation loss: 9.983570
epoch 10 	time:0.562 s	bp loss: 2.205403	decorrelation loss: 7.837932
epoch 11 	time:0.560 s	bp loss: 2.179158	decorrelation loss: 6.218883
epoch 12 	time:0.561 s	bp loss: 2.155225	decorrelation loss: 5.047639
epoch 13 	time:0.560 s	bp loss: 2.133130	decorrelation loss: 4.105361
epoch 14 	

Hypothesis for errors:
 - MNIST zeros
 - MNIST scaling (bias?)
 - variance choice
 - optimizer issues