Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(1),
        transforms.Normalize((0.5), (0.25)),
        # torch.flatten # not necessary but useful for debugging
        ])

train_data = MNIST(root='~/Data', train=True, download=True, transform=transform)
train_data = Subset(train_data, np.random.permutation(len(train_data.data))[:1000])
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

In [3]:
args = argparse.Namespace(lr=0.0, eta=0.3, decor_lr=1e-4, whiten=True, epochs=10)

model = Decorrelation(784, bias=False, eta=args.eta, whiten=args.whiten).to(device)

lossfun = lambda x, y: nn.Parameter(torch.zeros(1, device=device, dtype=float), requires_grad=True)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 106.554642
epoch 1  	time:0.339 s	bp loss: 0.000000	decorrelation loss: 71.350754
epoch 2  	time:0.338 s	bp loss: 0.000000	decorrelation loss: 28.919159
epoch 3  	time:0.373 s	bp loss: 0.000000	decorrelation loss: 12.483562
epoch 4  	time:0.396 s	bp loss: 0.000000	decorrelation loss: 5.877904
epoch 5  	time:0.300 s	bp loss: 0.000000	decorrelation loss: 3.077055
epoch 6  	time:0.300 s	bp loss: 0.000000	decorrelation loss: 1.815406
epoch 7  	time:0.299 s	bp loss: 0.000000	decorrelation loss: 1.225905
epoch 8  	time:0.286 s	bp loss: 0.000000	decorrelation loss: 0.979269
epoch 9  	time:0.332 s	bp loss: 0.000000	decorrelation loss: 0.863898
epoch 10 	time:0.315 s	bp loss: 0.000000	decorrelation loss: 0.789051


In [5]:
args = argparse.Namespace(lr=1e-4, eta=0.3, decor_lr=1e-4, whiten=True, epochs=10)

class Model(nn.Sequential):
    def __init__(self, in_features, eta, whiten):
        super().__init__(DecorLinear(in_features, 100, eta=eta, whiten=whiten))

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, eta=args.eta, whiten=args.whiten).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.380835	decorrelation loss: 106.632347
epoch 1  	time:0.349 s	bp loss: 4.167106	decorrelation loss: 71.136711
epoch 2  	time:0.352 s	bp loss: 2.883931	decorrelation loss: 28.742876
epoch 3  	time:0.360 s	bp loss: 2.493549	decorrelation loss: 12.482757
epoch 4  	time:0.340 s	bp loss: 2.280089	decorrelation loss: 5.857890
epoch 5  	time:0.343 s	bp loss: 2.151146	decorrelation loss: 3.050898
epoch 6  	time:0.325 s	bp loss: 2.063200	decorrelation loss: 1.789254
epoch 7  	time:0.349 s	bp loss: 2.014316	decorrelation loss: 1.235146
epoch 8  	time:0.398 s	bp loss: 1.983537	decorrelation loss: 0.983310
epoch 9  	time:0.359 s	bp loss: 1.943933	decorrelation loss: 0.860065
epoch 10 	time:0.341 s	bp loss: 1.870017	decorrelation loss: 0.797620


In [12]:
args = argparse.Namespace(lr=1e-3, eta=0.3, decor_lr=1e-4, whiten=False, epochs=30) # NOTE: Fails for whiten=True

class Model(nn.Sequential):
    def __init__(self, in_features, eta, whiten):
        super().__init__(DecorLinear(in_features, 100, decor_bias=False, eta=eta, whiten=whiten),
                        nn.LeakyReLU(),
                        DecorLinear(100, 10, decor_bias=False, eta=eta, whiten=whiten)
                        )

    def forward(self, x):
        return super().forward(x.view(len(x), -1))
    
model = Model(784, eta=args.eta, whiten=args.whiten).to(device)

lossfun = torch.nn.CrossEntropyLoss().to(device)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.620473	decorrelation loss: 112.731506
epoch 1  	time:0.260 s	bp loss: 1.902638	decorrelation loss: 21.809277
epoch 2  	time:0.251 s	bp loss: 0.918407	decorrelation loss: 21.362579
epoch 3  	time:0.253 s	bp loss: 0.553532	decorrelation loss: 24.658171
epoch 4  	time:0.256 s	bp loss: 0.406264	decorrelation loss: 17.113710
epoch 5  	time:0.255 s	bp loss: 0.328443	decorrelation loss: 11.661909
epoch 6  	time:0.258 s	bp loss: 0.267328	decorrelation loss: 8.902997
epoch 7  	time:0.256 s	bp loss: 0.222273	decorrelation loss: 7.267405
epoch 8  	time:0.263 s	bp loss: 0.188145	decorrelation loss: 6.286094
epoch 9  	time:0.261 s	bp loss: 0.157164	decorrelation loss: 5.313452
epoch 10 	time:0.252 s	bp loss: 0.132010	decorrelation loss: 4.743002
epoch 11 	time:0.262 s	bp loss: 0.110575	decorrelation loss: 4.302014
epoch 12 	time:0.266 s	bp loss: 0.093747	decorrelation loss: 3.870065
epoch 13 	time:0.256 s	bp loss: 0.080258	decorrelation loss: 3.629483
epoch 14 	tim

In [None]:
# TO DO: ConvNet
