# Direct Optimization of initial weights via Linear Programming

For a fully connected feed forward neural net using RELU activation functions $f$, for a given input $x \in D$ and parameter set $P = (W_1, W_2, ...)$ where each $W_i$ is the weight matrix for a layer of the network, then there exists matrix $A$ such that $f(x, P) = Ax$.

Consider the following generalization loss function, defined as:

$L_{gen} = (f(x, P) - f(x, P^*)) - (f(x^*, P) - f(x^*, P^*))$

where $x^* = x + x'$ and $P^* = P + P'$ for some small perturbations $x', P'$ (applying a different perturbation to each weight matrix in $P'$)

Then if perturbations $x', P'$ do not change the activation pattern of $f(x, P)$, $L_{gen}$ is equivalent to:

$L_{gen} = Ax - (A + A')x - A(x+x') + (A+A')(x+x')$

$ = Ax - Ax - A'x - Ax -Ax' + Ax + Ax' + A'x + A'x'$

$ = A'x'$

Therefore, we can obtain a small value of $L_{gen}$ by finding an initialization $P$ which has few activation patterns for the training data. Since the activation function is RELU, this is equivalent to reducing the amount of times RELU sets an element to 0.

In [1]:
""" First goal: investigate the claim that having fewer RELU 0s is a good thing for an initialization

    - Create a custom RELU subclass using PyTorch to count how often RELU 0s happen for initializations of a MLP
    - Create many different initializations of a model and count the intial RELU 0s
    - Train them all and compare trained model performance vs RELU 0 count

    Ideally, will see a strongly negative correlation between number of RELU 0s and model performance
"""
import math
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.datasets import MNIST

In [2]:
class ReluCount(nn.ReLU):
    """
        Intended to behave identically to nn.ReLU, except for the zero counting.
        To disable zero counting, set count_zeros to False.
    """
    def __init__(self, inplace: bool = False):
        super().__init__(inplace)
        self.zero_count = 0
        self.count_zeros = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        true_result = F.relu(input, inplace=self.inplace)
        if self.count_zeros == True:
            for entry in torch.flatten(true_result):
                if entry.item() == 0:
                    self.zero_count += 1
        return true_result


In [3]:
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

batch_size_train = 64
batch_size_test = 1000


train_loader = DataLoader(MNIST('images/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
                               ])),
                          batch_size=batch_size_train, shuffle=True, pin_memory=False)

test_loader = DataLoader(MNIST('images/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
                               ])),
                         batch_size=batch_size_test, shuffle=True)

In [4]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
example_data.shape

flat = torch.flatten(example_data, 1)
flat.shape

torch.Size([1000, 784])

In [5]:
class MnistReluCount(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu_count = ReluCount()
        # mnist images are 1x28x28, so flattened they will have a length of 28*28=784
        self.fc1 = nn.Linear(784, 750)
        self.fc2 = nn.Linear(750, 320)
        self.fc3 = nn.Linear(320, 50)
        self.fc4 = nn.Linear(50, 10)

    def forward(self, x):
        x = torch.flatten(x, 1) # flatten batches 2D images to 1D vectors
        x = self.relu_count(self.fc1(x))
        x = self.relu_count(self.fc2(x))
        x = self.relu_count(self.fc3(x))
        x = self.relu_count(self.fc4(x))
        return x

In [6]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    total_loss = 0
    batches = 0
    
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        batches += 1
        
    return total_loss / batches

def count_relu_0s(dataloader, model: ReluCount):
    model.eval()
    model.relu_count.count_zeros = True
    for (X, y) in dataloader:
        model(X)
    model.relu_count.count_zeros = False
    return model.relu_count.zero_count

def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    correct = 0
    
    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for (X, y) in dataloader:
            out = model(X)
            test_loss += loss_fn(out, y).item()
            pred = out.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).sum()

    test_loss /= num_batches
    accuracy = 100.0 * correct / len(dataloader.dataset)
    print(f"Avg loss: {test_loss:>8f}")
    print(f"Accuracy: {correct}/{len(dataloader.dataset)} = {accuracy}")
    print()
    return test_loss, accuracy

In [7]:
lr = .01
epochs = 5
model = MnistReluCount()

In [8]:
# get RELU 0 count
start_t = time.time()
model_zeros = count_relu_0s(test_loader, model)
print(f"Counting initial ReLU 0s took {time.time() - start_t} seconds")
print(model_zeros)

Counting initial ReLU 0s took 31.979676246643066 seconds
5779271


In [9]:
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss()

In [10]:
# train model
train_losses = []
test_losses = []


start_t = time.time()
for t in range(epochs):
    
    train_losses.append(train_loop(train_loader, model, loss_function, optimizer))
    if True:
        print(f"Epoch: {t+1}")
        test_losses.append(test_loop(test_loader, model, loss_function))
end_t = time.time()
print(f"Finished after {end_t - start_t} seconds")

Epoch: 1
Avg loss: 0.908930
Accuracy: 7288/10000 = 72.87999725341797

Epoch: 2
Avg loss: 0.728550
Accuracy: 7505/10000 = 75.05000305175781

Epoch: 3
Avg loss: 0.664025
Accuracy: 7619/10000 = 76.19000244140625

Epoch: 4
Avg loss: 0.628251
Accuracy: 7689/10000 = 76.88999938964844

Epoch: 5
Avg loss: 0.607580
Accuracy: 7725/10000 = 77.25

Finished after 55.181816816329956 seconds
