Analysis of the properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

# Seed the PyTorch (CPU and CUDA) and NumPy random number generators with one shared value.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def _covariance(module):
    """Return ``(s.T @ s) / len(s)`` for ``s = module.normalizer * module.decor_state`` as a NumPy array."""
    state = module.normalizer * module.decor_state
    cov = (state.T @ state) / len(state)
    # The state may live on a CUDA device; Tensor.numpy() only works on CPU tensors.
    return cov.detach().cpu().numpy()


def plot_results(init_model, model, dataloader):
    """Compare the decorrelation statistics of two models on one batch of data.

    The first batch of ``dataloader`` is passed through both models. For each
    pair of decorrelation modules (as returned by ``decor_modules``) the
    second-moment matrix of the scaled decorrelation state is computed, and
    two histograms are drawn per layer: the below-diagonal entries
    (top row, labelled ``x_i x_j``) and the diagonal entries (bottom row,
    labelled ``x_i^2``). Mean values of both are also printed.

    Args:
        init_model: reference model; its histograms are labelled 'correlated'.
        model: model to compare against it; labelled 'decorrelated'.
        dataloader: iterable of batches; only ``batch[0]`` of the first batch is used.

    NOTE(review): ``decor_modules``, ``lower_triangular`` and ``device`` must be
    available in the notebook namespace; the first two are not imported in the
    visible cells -- confirm where they are defined.
    """
    for batch in dataloader:
        # Move the inputs to the target device once and feed both models the same tensor.
        inputs = batch[0].to(device)
        init_model.forward(inputs)
        model.forward(inputs)
        init_modules = decor_modules(init_model)
        modules = decor_modules(model)
        n_layers = len(modules)
        labels = ['correlated', 'decorrelated']

        for i, (imod, mod) in enumerate(zip(init_modules, modules)):
            Ci = _covariance(imod)
            C = _covariance(mod)

            # Below-diagonal entries and diagonal entries, computed once and reused
            # for both the histograms and the printed summaries.
            cov_init = lower_triangular(Ci, offset=-1)
            cov_final = lower_triangular(C, offset=-1)
            var_init = np.diagonal(Ci)
            var_final = np.diagonal(C)

            # Top row: below-diagonal entries of the second-moment matrix.
            plt.subplot(2, n_layers, i+1)
            plt.title(f'layer {i+1}')
            plt.xlabel('$x_i x_j$')
            plt.hist([cov_init, cov_final], bins=30, label=labels)

            # Bottom row: diagonal entries of the second-moment matrix.
            plt.subplot(2, n_layers, i+1+n_layers)
            plt.hist([var_init, var_final], bins=30, label=labels)
            plt.xlabel('$x_i^2$')

            print(f'layer {i+1} mean covariance before decorrelation: {np.mean(cov_init):.2f}')
            print(f'layer {i+1} mean covariance after decorrelation: {np.mean(cov_final):.2f}')
            print(f'layer {i+1} mean variance before decorrelation: {np.mean(var_init):.2f}')
            print(f'layer {i+1} mean variance after decorrelation: {np.mean(var_final):.2f}')
        # Only the first batch is analysed.
        break
    # The legend is attached to the most recently selected axes.
    plt.legend()

In [3]:
# Preprocessing applied to every image: convert to a tensor, force a single
# grayscale channel, then shift by 0.5 and scale by 1/0.25.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(1),
    # One-element tuples (note the trailing comma): Normalize documents
    # per-channel sequences, and `(0.5)` is just a parenthesised float.
    transforms.Normalize((0.5,), (0.25,)),
    # torch.flatten # not necessary but useful for debugging
])

# Load the MNIST training split (downloaded to ~/Data if missing).
train_data = MNIST(root='~/Data', train=True, download=True, transform=transform)

# Keep a random subset of 1000 samples, drawn with the NumPy RNG.
subset_indices = np.random.permutation(len(train_data))[:1000]
train_data = Subset(train_data, subset_indices)

# Shuffled mini-batches of 64 loaded in the main process; the incomplete final batch is dropped.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [4]:
# Hyperparameters handed to decor_train.
# NOTE: training fails for e.g. variance=1.0 at this learning rate. In the run recorded
# below, the decorrelation loss drops to about 2.0 by epoch 2 and then climbs to about 327 by epoch 10.
args = argparse.Namespace(lr=1e-4, eta=0.3, decor_lr=1e-3, variance=1.0, epochs=10)

class Model(nn.Sequential):
    """Two-layer network: DecorLinear -> LeakyReLU -> DecorLinear.

    Args:
        in_features: number of input features after flattening each sample.
        eta: forwarded as ``eta`` to both DecorLinear layers.
        variance: forwarded as ``variance`` to both DecorLinear layers.
        hidden_features: width of the hidden layer (default 100).
        out_features: number of outputs (default 10).
    """

    def __init__(self, in_features, eta, variance, hidden_features=100, out_features=10):
        super().__init__(
            DecorLinear(in_features, hidden_features, decor_bias=True, eta=eta, variance=variance),
            nn.LeakyReLU(),
            DecorLinear(hidden_features, out_features, decor_bias=True, eta=eta, variance=variance),
        )

    def forward(self, x):
        """Flatten every sample to a vector and run it through the layer stack."""
        # reshape behaves like view for contiguous inputs but also accepts
        # non-contiguous ones (copying when needed) instead of raising.
        return super().forward(x.reshape(len(x), -1))
    
# Build the network for 784 input features (28 * 28) and place it on the selected device.
model = Model(28 * 28, eta=args.eta, variance=args.variance)
model = model.to(device)

# Cross-entropy loss on the model outputs.
lossfun = nn.CrossEntropyLoss().to(device)

# Run training; decor_train's return value is kept in `res`.
res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.637977	decorrelation loss: 108.277657
epoch 1  	time:0.327 s	bp loss: 2.417980	decorrelation loss: 16.301802
epoch 2  	time:0.353 s	bp loss: 2.140957	decorrelation loss: 1.996554
epoch 3  	time:0.332 s	bp loss: 1.900352	decorrelation loss: 2.015632
epoch 4  	time:0.357 s	bp loss: 1.712744	decorrelation loss: 2.895918
epoch 5  	time:0.347 s	bp loss: 1.566925	decorrelation loss: 6.614901
epoch 6  	time:0.333 s	bp loss: 1.429413	decorrelation loss: 29.154255
epoch 7  	time:0.361 s	bp loss: 1.336625	decorrelation loss: 68.512886
epoch 8  	time:0.338 s	bp loss: 1.237959	decorrelation loss: 113.681610
epoch 9  	time:0.329 s	bp loss: 1.131953	decorrelation loss: 230.516312
epoch 10 	time:0.346 s	bp loss: 1.075101	decorrelation loss: 327.518250
