This notebook analyses the properties of decorrelation and whitening methods for decorrelated networks.

In [1]:
import numpy as np
import torch
import torch.nn as nn
from decorrelation.decorrelation import Decorrelation, DecorLinear, DecorConv2d
import matplotlib.pyplot as plt
import matplotlib
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from decorrelation.train import decor_train
import argparse

# automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2

# Seed NumPy and torch (CPU and CUDA) with the same value so that
# the random draws made further down are repeatable from a fresh kernel.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Number of training images kept for this analysis and the mini-batch size.
N_TRAIN_SAMPLES = 1000
BATCH_SIZE = 64

# Preprocessing applied to every image: convert to a tensor, force a single
# channel, then normalise with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(1),
    # One-element tuples: `(0.5)` is just the float 0.5, whereas Normalize is
    # documented to take a per-channel sequence of means and stds.
    transforms.Normalize((0.5,), (0.5,)),
])

# Full MNIST training split (downloaded to ~/Data on first use).
mnist_train = MNIST(root='~/Data', train=True, download=True, transform=transform)

# Keep a random subset; the permutation comes from NumPy's global RNG, so it is
# determined by the np.random.seed call made earlier in the notebook.
subset_indices = np.random.permutation(len(mnist_train))[:N_TRAIN_SAMPLES]
train_data = Subset(mnist_train, subset_indices)

# drop_last=True discards the final incomplete batch so every batch has BATCH_SIZE items.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)

In [15]:
# A single decorrelation layer over 784 input features.
model = Decorrelation(784, variance=1.0).to(device)

# Hyperparameters handed to decor_train as an argparse-style namespace.
args = argparse.Namespace(
    lr=1e-4,
    eta=1e-5,
    epochs=10,
)


def lossfun(x, y):
    """Dummy objective: ignore both arguments and return a fresh zero that requires grad.

    The value is always 0, which is why the log reports a bp loss of 0.000000.
    NOTE(review): presumably this leaves only the decorrelation update active
    inside decor_train -- confirm against its implementation.
    """
    zero = torch.zeros(1, device=device, dtype=float)
    return nn.Parameter(zero, requires_grad=True)


res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 0.000000	decorrelation loss: 3.323010
epoch 1  	time:0.537 s	bp loss: 0.000000	decorrelation loss: 2.883162
epoch 2  	time:0.546 s	bp loss: 0.000000	decorrelation loss: 2.118621
epoch 3  	time:0.536 s	bp loss: 0.000000	decorrelation loss: 1.569368
epoch 4  	time:0.531 s	bp loss: 0.000000	decorrelation loss: 1.163987
epoch 5  	time:0.533 s	bp loss: 0.000000	decorrelation loss: 0.873308
epoch 6  	time:0.537 s	bp loss: 0.000000	decorrelation loss: 0.659124
epoch 7  	time:0.530 s	bp loss: 0.000000	decorrelation loss: 0.503914
epoch 8  	time:0.531 s	bp loss: 0.000000	decorrelation loss: 0.389484
epoch 9  	time:0.529 s	bp loss: 0.000000	decorrelation loss: 0.302870
epoch 10 	time:0.529 s	bp loss: 0.000000	decorrelation loss: 0.236460


In [16]:
class MLP(nn.Sequential):
    """Simple MLP example: DecorLinear(in_features, 100) -> ReLU -> DecorLinear(100, 10).

    Inputs are flattened to shape (batch, -1) before they reach the first layer.
    """

    def __init__(self, in_features, variance):
        """
        Args:
            in_features: int, number of inputs (per sample, after flattening)
            variance: passed unchanged as the `variance` keyword of both
                DecorLinear layers
        """
        super().__init__(DecorLinear(in_features, 100, variance=variance),
                        nn.ReLU(),
                        DecorLinear(100, 10, variance=variance)
                        )

    def forward(self, x):
        """Flatten every sample of `x` to one dimension and run the layer stack.

        Args:
            x: tensor whose first dimension indexes the samples of the batch

        Returns:
            Output of the final layer of the sequence.
        """
        # reshape rather than view: view raises a RuntimeError on non-contiguous
        # input (e.g. after a transpose/permute), while reshape returns the same
        # result for contiguous input and copies only when it has to.
        return super().forward(x.reshape(len(x), -1))

In [17]:
# Two-layer decorrelated MLP over 784 flattened input features.
model = MLP(784, variance=1.0).to(device)

# Cross-entropy objective, reported as "bp loss" in the training log.
lossfun = nn.CrossEntropyLoss().to(device)

# Hyperparameters handed to decor_train as an argparse-style namespace.
args = argparse.Namespace(
    lr=1e-4,
    eta=1e-5,
    epochs=10,
)

res = decor_train(args, model, lossfun, train_loader, device)

epoch 0  	time:0.000 s	bp loss: 5.198516	decorrelation loss: 3.328140
epoch 1  	time:0.579 s	bp loss: 4.492834	decorrelation loss: 2.879488
epoch 2  	time:0.564 s	bp loss: 3.411903	decorrelation loss: 2.117524
epoch 3  	time:0.554 s	bp loss: 2.865552	decorrelation loss: 1.559836
epoch 4  	time:0.552 s	bp loss: 2.585968	decorrelation loss: 1.167807
epoch 5  	time:0.547 s	bp loss: 2.420695	decorrelation loss: 0.871845
epoch 6  	time:0.547 s	bp loss: 2.290833	decorrelation loss: 0.661645
epoch 7  	time:0.550 s	bp loss: 2.194351	decorrelation loss: 0.504670
epoch 8  	time:0.549 s	bp loss: 2.112690	decorrelation loss: 0.387976
epoch 9  	time:0.550 s	bp loss: 2.034244	decorrelation loss: 0.300150
epoch 10 	time:0.548 s	bp loss: 1.970677	decorrelation loss: 0.236579
