Analyses properties of decorrelation and whitening methods for decorrelated networks

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

# Seed every random number generator used in this notebook (NumPy, torch CPU
# and torch CUDA) with the same value so that runs are repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
def plot_results(model, data):
    """Plot a correlation histogram and heat maps of weights and correlations.

    Draws a histogram comparing the entries selected by ``lower_triangular``
    from ``A1`` and ``A2`` and, in the bottom row of a 2x3 grid, heat maps of
    ``|model.weight|`` (logarithmic colour scale), ``A1`` and ``A2``.

    Args:
        model: object exposing a ``weight`` torch tensor; it is converted to
            a NumPy array and shown as an image, so it is presumably 2-D --
            TODO confirm.
        data: currently unused; kept so existing callers keep working.

    NOTE(review): ``A1``, ``A2`` and ``lower_triangular`` are not defined in
    this function and are looked up as globals. Judging by the plot labels,
    ``A1``/``A2`` are presumably the correlation matrices before and after
    decorrelation -- confirm where they are supposed to come from.
    """
    plt.figure(figsize=(14, 9))

    # NOTE(review): no subplot is selected before the histogram, so it is
    # drawn on a default full-figure axes -- confirm the intended placement.
    plt.hist(
        [lower_triangular(A1, offset=0), lower_triangular(A2, offset=0)],
        bins=30,
        label=['correlated', 'decorrelated'],
    )
    plt.xlabel('$x_i x_j$')
    plt.legend()

    plt.subplot(2, 3, 4)
    # detach()/cpu() make the conversion work for grad-tracking or CUDA weights.
    weights = np.abs(model.weight.detach().cpu().numpy())
    # LogNorm needs a strictly positive lower bound and vmin <= vmax.
    vmin = max(float(weights.min()), 1e-10)
    vmax = max(float(weights.max()), vmin)
    plt.imshow(
        weights,
        cmap=plt.get_cmap('hot'),
        interpolation='nearest',
        norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax),
    )
    plt.title('|decorrelation weights|')
    plt.colorbar(fraction=0.046, pad=0.04)

    plt.subplot(2, 3, 5)
    plt.imshow(A1, cmap=plt.get_cmap('hot'), interpolation='nearest')
    plt.title('$x_i x_j$')
    plt.colorbar(fraction=0.046, pad=0.04)

    plt.subplot(2, 3, 6)
    plt.imshow(A2, cmap=plt.get_cmap('hot'), interpolation='nearest')
    plt.title('$x_i x_j$ decorrelated')
    plt.colorbar(fraction=0.046, pad=0.04)

In [3]:
# Per-image preprocessing: convert to a tensor, force a single grayscale
# channel, then normalise with mean 0.5 and standard deviation 0.25.
# Images are not flattened at this stage.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(1),
        transforms.Normalize((0.5), (0.25)),
    ]
)

# MNIST training split (downloaded to ~/Data when missing), reduced to a
# random subset of 1000 samples drawn with NumPy's global RNG.
train_data = MNIST(root='~/Data', train=True, download=True, transform=transform)
subset_indices = np.random.permutation(len(train_data.data))[:1000]
train_data = Subset(train_data, subset_indices)

# Shuffled mini-batches of 64, loaded in the main process; the final
# incomplete batch of each epoch is dropped.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [8]:
# Hyperparameters handed to the model constructor and to decor_train.
# NOTE: the original author flagged that training FAILS for e.g. variance=1.0
# at this learning rate.
args = argparse.Namespace(
    lr=1e-4,
    eta=0.3,
    decor_lr=1e-3,
    variance=1.0,
    epochs=10,
)

class Model(nn.Sequential):
    """Two DecorLinear layers (in_features -> 100 -> 10) with a LeakyReLU between.

    Args:
        in_features: number of input features per sample after flattening.
        eta: forwarded unchanged to both DecorLinear layers.
        variance: forwarded unchanged to both DecorLinear layers.
    """

    def __init__(self, in_features, eta, variance):
        hidden_layer = DecorLinear(in_features, 100, decor_bias=True, eta=eta, variance=variance)
        activation = nn.LeakyReLU()
        output_layer = DecorLinear(100, 10, decor_bias=True, eta=eta, variance=variance)
        super().__init__(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Flatten each sample to one dimension and run the layer stack."""
        flattened = x.view(len(x), -1)
        return super().forward(flattened)
    
# Build the network for 784 input features and move it to the chosen device.
model = Model(in_features=784, eta=args.eta, variance=args.variance)
model = model.to(device)

# Cross-entropy loss, placed on the same device as the model.
lossfun = nn.CrossEntropyLoss().to(device)

# Run training; the per-epoch losses are printed below the cell.
res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 2.632900	decorrelation loss: 111.242065
epoch 1  	time:0.331 s	bp loss: 2.447502	decorrelation loss: 109.873917
epoch 2  	time:0.383 s	bp loss: 2.158636	decorrelation loss: 111.574059
epoch 3  	time:0.346 s	bp loss: 1.918815	decorrelation loss: 116.422638
epoch 4  	time:0.311 s	bp loss: 1.718795	decorrelation loss: 124.699272
epoch 5  	time:0.377 s	bp loss: 1.537218	decorrelation loss: 135.436310
epoch 6  	time:0.330 s	bp loss: 1.369530	decorrelation loss: 151.938904
epoch 7  	time:0.324 s	bp loss: 1.226115	decorrelation loss: 174.656601
epoch 8  	time:0.391 s	bp loss: 1.112562	decorrelation loss: 197.272705
epoch 9  	time:0.354 s	bp loss: 1.013306	decorrelation loss: 223.131668
epoch 10 	time:0.328 s	bp loss: 0.925326	decorrelation loss: 248.646149
