Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(1),
        transforms.Normalize((0.5), (0.25)),
        # torch.flatten # can be removed when moving to Conv2d
        ])

train_data = MNIST(root='~/Data', train=True, download=True, transform=transform)
train_data = Subset(train_data, np.random.permutation(len(train_data.data))[:1000])
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

In [3]:
model = Decorrelation(784, bias=False, variance=1.0).to(device)

lossfun = lambda x, y: nn.Parameter(torch.zeros(1, device=device, dtype=float), requires_grad=True)

args = argparse.Namespace(lr=1e-4, eta=1e-5, epochs=10)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 106.560814
epoch 1  	time:0.554 s	bp loss: 0.000000	decorrelation loss: 92.458618
epoch 2  	time:0.525 s	bp loss: 0.000000	decorrelation loss: 67.835510
epoch 3  	time:0.527 s	bp loss: 0.000000	decorrelation loss: 49.946434
epoch 4  	time:0.549 s	bp loss: 0.000000	decorrelation loss: 37.211742
epoch 5  	time:0.557 s	bp loss: 0.000000	decorrelation loss: 27.915833
epoch 6  	time:0.547 s	bp loss: 0.000000	decorrelation loss: 20.952473
epoch 7  	time:0.524 s	bp loss: 0.000000	decorrelation loss: 15.865977
epoch 8  	time:0.529 s	bp loss: 0.000000	decorrelation loss: 12.110641
epoch 9  	time:0.546 s	bp loss: 0.000000	decorrelation loss: 9.355850
epoch 10 	time:0.538 s	bp loss: 0.000000	decorrelation loss: 7.313002


In [4]:
class Model(nn.Sequential):
    def __init__(self, in_features, variance):
        super().__init__(DecorLinear(in_features, 100, variance=variance))

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, variance=1.0).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

args = argparse.Namespace(lr=1e-4, eta=1e-5, epochs=10)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.380835	decorrelation loss: 106.638420
epoch 1  	time:0.566 s	bp loss: 4.126562	decorrelation loss: 92.378510
epoch 2  	time:0.558 s	bp loss: 2.728143	decorrelation loss: 67.550423
epoch 3  	time:0.567 s	bp loss: 2.297279	decorrelation loss: 50.074562
epoch 4  	time:0.546 s	bp loss: 2.049079	decorrelation loss: 37.256287
epoch 5  	time:0.552 s	bp loss: 1.873144	decorrelation loss: 27.816923
epoch 6  	time:0.580 s	bp loss: 1.731331	decorrelation loss: 20.886711
epoch 7  	time:0.567 s	bp loss: 1.619119	decorrelation loss: 15.840314
epoch 8  	time:0.558 s	bp loss: 1.512893	decorrelation loss: 12.159686
epoch 9  	time:0.579 s	bp loss: 1.421352	decorrelation loss: 9.431981
epoch 10 	time:0.563 s	bp loss: 1.346510	decorrelation loss: 7.337277


In [6]:
class Model(nn.Sequential):
    def __init__(self, in_features, variance):
        super().__init__(DecorLinear(in_features, 100, bias=True, variance=variance),
                        nn.LeakyReLU(),
                        DecorLinear(100, 10, bias=True, variance=variance)
                        )

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, variance=1.0).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

args = argparse.Namespace(lr=1e-5, eta=1e-5, epochs=20)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.659184	decorrelation loss: 111.116158
epoch 1  	time:0.559 s	bp loss: 2.598590	decorrelation loss: 110.478119
epoch 2  	time:0.568 s	bp loss: 2.503471	decorrelation loss: 110.325562
epoch 3  	time:0.561 s	bp loss: 2.433827	decorrelation loss: 110.372955
epoch 4  	time:0.562 s	bp loss: 2.385222	decorrelation loss: 110.450310
epoch 5  	time:0.558 s	bp loss: 2.344632	decorrelation loss: 110.519104
epoch 6  	time:0.558 s	bp loss: 2.303873	decorrelation loss: 110.553780
epoch 7  	time:0.557 s	bp loss: 2.271375	decorrelation loss: 110.941040
epoch 8  	time:0.557 s	bp loss: 2.238182	decorrelation loss: 110.781067
epoch 9  	time:0.555 s	bp loss: 2.203967	decorrelation loss: 111.277313
epoch 10 	time:0.555 s	bp loss: 2.173552	decorrelation loss: 111.346657
epoch 11 	time:0.556 s	bp loss: 2.141980	decorrelation loss: 111.523827
epoch 12 	time:0.557 s	bp loss: 2.111793	decorrelation loss: 112.008949
epoch 13 	time:0.557 s	bp loss: 2.081334	decorrelation loss: 112

In [None]:
# Convolutional; and run main experiment on full data; do we have a gain?


Hypothesis for errors:
 - MNIST zeros
 - MNIST scaling (bias?)
 - variance choice => no normalization would also be an option
 - optimizer issues
 - add other learning rule?