Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(1),
        transforms.Normalize((0.5), (0.25)),
        # torch.flatten # not necessary but useful for debugging
        ])

train_data = MNIST(root='~/Data', train=True, download=True, transform=transform)
train_data = Subset(train_data, np.random.permutation(len(train_data.data))[:1000])
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

In [7]:
args = argparse.Namespace(lr=0.0, eta=0.3, decor_lr=1e-4, variance=1.0, epochs=10)

model = Decorrelation(784, bias=False, eta=args.eta, variance=args.variance).to(device)

lossfun = lambda x, y: nn.Parameter(torch.zeros(1, device=device, dtype=float), requires_grad=True)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 106.592293
epoch 1  	time:0.541 s	bp loss: 0.000000	decorrelation loss: 71.298180
epoch 2  	time:0.545 s	bp loss: 0.000000	decorrelation loss: 28.939163
epoch 3  	time:0.549 s	bp loss: 0.000000	decorrelation loss: 12.465863
epoch 4  	time:0.549 s	bp loss: 0.000000	decorrelation loss: 5.836260
epoch 5  	time:0.545 s	bp loss: 0.000000	decorrelation loss: 3.068482
epoch 6  	time:0.547 s	bp loss: 0.000000	decorrelation loss: 1.828948
epoch 7  	time:0.542 s	bp loss: 0.000000	decorrelation loss: 1.282410
epoch 8  	time:0.543 s	bp loss: 0.000000	decorrelation loss: 1.005300
epoch 9  	time:0.540 s	bp loss: 0.000000	decorrelation loss: 0.890753
epoch 10 	time:0.540 s	bp loss: 0.000000	decorrelation loss: 0.824732


In [8]:
args = argparse.Namespace(lr=1e-4, eta=0.3, decor_lr=1e-4, variance=1.0, epochs=10)

class Model(nn.Sequential):
    def __init__(self, in_features, eta, variance):
        super().__init__(DecorLinear(in_features, 100, eta=eta, variance=variance))

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, eta=args.eta, variance=args.variance).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.491928	decorrelation loss: 106.572433
epoch 1  	time:0.569 s	bp loss: 4.314556	decorrelation loss: 71.382492
epoch 2  	time:0.568 s	bp loss: 2.989045	decorrelation loss: 28.873865
epoch 3  	time:0.567 s	bp loss: 2.557576	decorrelation loss: 12.559443
epoch 4  	time:0.567 s	bp loss: 2.350261	decorrelation loss: 5.853016
epoch 5  	time:0.566 s	bp loss: 2.232522	decorrelation loss: 3.052009
epoch 6  	time:0.568 s	bp loss: 2.139393	decorrelation loss: 1.813134
epoch 7  	time:0.569 s	bp loss: 2.077643	decorrelation loss: 1.269703
epoch 8  	time:0.567 s	bp loss: 2.056343	decorrelation loss: 1.004370
epoch 9  	time:0.566 s	bp loss: 2.018654	decorrelation loss: 0.884865
epoch 10 	time:0.564 s	bp loss: 1.939791	decorrelation loss: 0.818513


In [17]:
args = argparse.Namespace(lr=1e-4, eta=0.3, decor_lr=1e-3, variance=None, epochs=10) # FAILS FOR E.G. variance=1.0 at this learning rate

class Model(nn.Sequential):
    def __init__(self, in_features, eta, variance):
        super().__init__(DecorLinear(in_features, 100, bias=False, eta=eta, variance=variance),
                        nn.LeakyReLU(),
                        DecorLinear(100, 10, bias=False, eta=eta, variance=variance)
                        )

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, eta=args.eta, variance=args.variance).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.544302	decorrelation loss: 109.754143
epoch 1  	time:0.566 s	bp loss: 2.414789	decorrelation loss: 8.191069
epoch 2  	time:0.571 s	bp loss: 2.217825	decorrelation loss: 0.575085
epoch 3  	time:0.570 s	bp loss: 2.078929	decorrelation loss: 0.408783
epoch 4  	time:0.576 s	bp loss: 1.962890	decorrelation loss: 0.336117
epoch 5  	time:0.565 s	bp loss: 1.860005	decorrelation loss: 0.304384
epoch 6  	time:0.564 s	bp loss: 1.764048	decorrelation loss: 0.283378
epoch 7  	time:0.562 s	bp loss: 1.667927	decorrelation loss: 0.273111
epoch 8  	time:0.562 s	bp loss: 1.593011	decorrelation loss: 0.264786
epoch 9  	time:0.560 s	bp loss: 1.515942	decorrelation loss: 0.259102
epoch 10 	time:0.560 s	bp loss: 1.447145	decorrelation loss: 0.251755


In [6]:
# TO DO: ConvNet
