Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# Notebook-wide setup: live module reloading, RNG seeding and device selection.

# Load IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Seed the torch (CPU and CUDA) and NumPy global random number generators
# with the same fixed value.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Per-image preprocessing: convert to a tensor, force a single grayscale
# channel, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(1),
        transforms.Normalize((0.5), (0.5)),
        # torch.flatten # not necessary but useful for debugging
    ]
)

# Keep a random 1000-image subset of the MNIST training split
# (downloaded to ~/Data on first use).
mnist_train = MNIST(root='~/Data', train=True, download=True, transform=transform)
subset_indices = np.random.permutation(len(mnist_train.data))[:1000]
train_data = Subset(mnist_train, subset_indices)

# Shuffled mini-batches of 64; the final incomplete batch is dropped.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [3]:
# Run a bare Decorrelation layer with a zero learning rate (lr=0.0) and a
# constant-zero loss -- presumably so that only the decorrelation update is
# exercised; confirm against decor_train.
args = argparse.Namespace(
    lr=0.0,
    eta=1e-2,
    decor_lr=1e0,
    whiten=True,
    epochs=10,
)

model = Decorrelation(784, bias=False, eta=args.eta, whiten=args.whiten).to(device)


def lossfun(x, y):
    """Dummy loss: ignore both arguments and return a fresh zero-valued parameter."""
    zero = torch.zeros(1, device=device, dtype=float)
    return nn.Parameter(zero, requires_grad=True)


res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 6.627906
epoch 1  	time:0.368 s	bp loss: 0.000000	decorrelation loss: 6.338502
epoch 2  	time:0.401 s	bp loss: 0.000000	decorrelation loss: 5.772131
epoch 3  	time:0.454 s	bp loss: 0.000000	decorrelation loss: 5.262663
epoch 4  	time:0.352 s	bp loss: 0.000000	decorrelation loss: 4.841757
epoch 5  	time:0.347 s	bp loss: 0.000000	decorrelation loss: 4.469853
epoch 6  	time:0.462 s	bp loss: 0.000000	decorrelation loss: 4.121214
epoch 7  	time:0.415 s	bp loss: 0.000000	decorrelation loss: 3.824177
epoch 8  	time:0.379 s	bp loss: 0.000000	decorrelation loss: 3.550914
epoch 9  	time:0.355 s	bp loss: 0.000000	decorrelation loss: 3.314016
epoch 10 	time:0.387 s	bp loss: 0.000000	decorrelation loss: 3.112059


In [4]:
# Settings for the single-layer classifier run. `eta` and `whiten` are passed
# to the model constructor; the whole namespace is handed to decor_train.
args = argparse.Namespace(
    lr=1e-4,
    eta=1e-2,
    decor_lr=1e0,
    whiten=True,
    epochs=10,
)

class Model(nn.Sequential):
    """Single ``DecorLinear`` layer (``in_features`` -> 100) applied to flattened inputs.

    Args:
        in_features: size of each flattened input sample.
        eta: forwarded to ``DecorLinear`` as ``eta``.
        whiten: forwarded to ``DecorLinear`` as ``whiten``.
    """

    def __init__(self, in_features, eta, whiten):
        super().__init__(DecorLinear(in_features, 100, eta=eta, whiten=whiten))

    def forward(self, x):
        """Flatten each sample in the batch and run it through the layer stack.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. a transposed or permuted batch) are accepted rather than raising;
        for contiguous inputs the result is identical.
        """
        return super().forward(x.reshape(len(x), -1))
    
# Build the classifier for 784-dimensional (flattened) inputs, move it to the
# selected device and train it with a cross-entropy loss.
model = Model(in_features=784, eta=args.eta, whiten=args.whiten)
model = model.to(device)

lossfun = nn.CrossEntropyLoss().to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.008767	decorrelation loss: 6.632705
epoch 1  	time:0.340 s	bp loss: 4.302653	decorrelation loss: 6.338794
epoch 2  	time:0.343 s	bp loss: 3.197514	decorrelation loss: 5.750855
epoch 3  	time:0.339 s	bp loss: 2.656548	decorrelation loss: 5.279569
epoch 4  	time:0.427 s	bp loss: 2.371899	decorrelation loss: 4.851607
epoch 5  	time:0.391 s	bp loss: 2.208116	decorrelation loss: 4.459737
epoch 6  	time:0.374 s	bp loss: 2.083977	decorrelation loss: 4.116480
epoch 7  	time:0.334 s	bp loss: 1.984350	decorrelation loss: 3.816057
epoch 8  	time:0.357 s	bp loss: 1.891012	decorrelation loss: 3.558414
epoch 9  	time:0.429 s	bp loss: 1.808348	decorrelation loss: 3.338013
epoch 10 	time:0.413 s	bp loss: 1.733128	decorrelation loss: 3.116416


In [5]:
# Settings for the two-layer classifier run.
# NOTE (original author): "Fails for whiten=True" -- yet whiten=True is the
# value set here. TODO: confirm whether that note is still accurate.
args = argparse.Namespace(
    lr=1e-3,
    eta=1e-2,
    decor_lr=1e1,
    whiten=True,
    epochs=30,
)

class Model(nn.Sequential):
    """Two-layer network: DecorLinear(in_features, 100) -> LeakyReLU -> DecorLinear(100, 10).

    Both ``DecorLinear`` layers are built with ``decor_bias=False``.

    Args:
        in_features: size of each flattened input sample.
        eta: forwarded to both ``DecorLinear`` layers as ``eta``.
        whiten: forwarded to both ``DecorLinear`` layers as ``whiten``.
    """

    def __init__(self, in_features, eta, whiten):
        super().__init__(
            DecorLinear(in_features, 100, decor_bias=False, eta=eta, whiten=whiten),
            nn.LeakyReLU(),
            DecorLinear(100, 10, decor_bias=False, eta=eta, whiten=whiten),
        )

    def forward(self, x):
        """Flatten each sample in the batch and run it through the layer stack.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. a transposed or permuted batch) are accepted rather than raising;
        for contiguous inputs the result is identical.
        """
        return super().forward(x.reshape(len(x), -1))
    
# Build the two-layer classifier for 784-dimensional (flattened) inputs, move
# it to the selected device and train it with a cross-entropy loss.
model = Model(in_features=784, eta=args.eta, whiten=args.whiten)
model = model.to(device)

lossfun = nn.CrossEntropyLoss().to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.482052	decorrelation loss: 6.917549
epoch 1  	time:0.458 s	bp loss: 2.086559	decorrelation loss: 5.980656
epoch 2  	time:0.404 s	bp loss: 1.297697	decorrelation loss: 4.035069
epoch 3  	time:0.436 s	bp loss: 0.941635	decorrelation loss: 2.515571
epoch 4  	time:0.397 s	bp loss: 0.758750	decorrelation loss: 2.012101
epoch 5  	time:0.370 s	bp loss: 0.665646	decorrelation loss: 2.113142
epoch 6  	time:0.409 s	bp loss: 0.595319	decorrelation loss: 1.488684
epoch 7  	time:0.605 s	bp loss: 0.530433	decorrelation loss: 1.215478
epoch 8  	time:0.379 s	bp loss: 0.474380	decorrelation loss: 1.158361
epoch 9  	time:0.358 s	bp loss: 0.437583	decorrelation loss: 0.948612
epoch 10 	time:0.354 s	bp loss: 0.411881	decorrelation loss: 0.876683
epoch 11 	time:0.505 s	bp loss: 0.383885	decorrelation loss: 0.754988
epoch 12 	time:0.356 s	bp loss: 0.353847	decorrelation loss: 0.715052
epoch 13 	time:0.359 s	bp loss: 0.333629	decorrelation loss: 0.731375
epoch 14 	time:0.364

In [6]:
# TODO: add a convolutional network (ConvNet) example
