Experiments on variations of FFF.  
π 25 Nov 2023.

Kept (for now) for posterity.

For a more up-to-date version of these experiments, go to `experiments/2023-11-29--fff-topk-lora.ipynb`

In [1]:
! pip install -q torch torchvision tqdm

In [2]:
NEPOCH, BATCH_SIZE = 5, 128  # training epochs; mini-batch size of the training loader

# Print the running loss every EVERY_N batches; -1 disables periodic printing.
EVERY_N = -1

In [3]:
import torch
import numpy as np
import random

random_seed = 1337  # single seed shared by every RNG seeded below

# Seed PyTorch (CPU, plus every CUDA device when one is available),
# then NumPy, then Python's built-in `random` module.
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

# cuDNN: request deterministic kernels and turn off autotuning.
# Trades some speed for reproducible runs.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


In [4]:
from torch import nn, functional as F
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# Load MNIST

In [5]:
import torchvision
import torchvision.transforms as transforms

# Preprocessing: ToTensor, then Normalize with mean 0.5 / std 0.5,
# i.e. each pixel value p becomes (p - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train split first, then test split (download=True fetches them into ./data).
trainset, testset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Shuffled mini-batches for training; the whole test split as one single batch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)


# Test harness

In [6]:
def train_and_test(net):
    """Train ``net`` for ``NEPOCH`` epochs, then print and return its test accuracy.

    Optimises cross-entropy with AdamW (lr=0.001). If ``net`` defines an
    ``orthogonality_penalty()`` method, ``0.001 * net.orthogonality_penalty()``
    is added to every mini-batch loss.

    Reads the notebook globals ``NEPOCH``, ``EVERY_N``, ``trainloader`` and
    ``testloader``. ``EVERY_N <= 0`` disables the periodic loss printout.

    Args:
        net: the model to train; called as ``net(inputs)`` and expected to
            return per-class scores along dim 1.

    Returns:
        Test accuracy in percent (also printed), or NaN if the test loader
        yields no samples. ``net`` is left in eval mode.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=0.001)
    has_penalty = hasattr(net, 'orthogonality_penalty')

    net.train()
    for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if has_penalty:
                # Out-of-place add: never mutate a tensor that is part of the autograd graph.
                loss = loss + 0.001 * net.orthogonality_penalty()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Explicit guard: -1 (or 0) disables logging instead of relying on a
            # modulo quirk / dividing by zero.
            if EVERY_N > 0 and (i + 1) % EVERY_N == 0:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / EVERY_N:.3f}')
                running_loss = 0.0

    print('Finished Training')

    # Evaluate on the test data.
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else float('nan')
    print(f'Accuracy of the network over test images: {accuracy:.3f} %')
    return accuracy


# 🔹 Baseline (Fully connected) `97.08%`

In [7]:
# Baseline: a plain fully connected two-layer network.
class Net(nn.Module):
    """MLP classifier: flattened 28x28 input -> 500 hidden units (ReLU) -> 10 class scores."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Train a fresh instance of the Net defined above; prints "Finished Training" and the test accuracy.
train_and_test(Net())


100%|██████████| 5/5 [00:14<00:00,  2.80s/it]


Finished Training
Accuracy of the network over test images: 97.080 %


# 🔸 PiSlice_topk `95.82%`

For each input, choose the k best nodes, ranked by `x DOT node.x`.

Significantly outperforms FFF

Per input, this costs nNodes dot products (one per node), one top-k selection, and k Y-projections.

In [8]:
! pip install -q lovely-tensors

In [9]:
import lovely_tensors as lt
# NOTE(review): monkey_patch() presumably replaces torch.Tensor's repr with
# lovely-tensors' compact summaries (a debugging aid); nothing else in this
# notebook references `lt` — confirm before relying on it.
lt.monkey_patch()

In [10]:
class PiSlice_topk(nn.Module):
    """Sparse "slice" layer: each input is expressed through only its top-k nodes.

    Every node ``n`` owns a trainable input direction ``basisX[n]`` (length
    ``nIn``) and output direction ``basisY[n]`` (length ``nOut``), both
    initialised as random L2-normalised vectors. For an input ``x`` the layer
    computes ``lambda_n = x DOT basisX[n]`` for every node. It keeps the
    ``nWinners`` largest lambdas and returns ``sum_k lambda_k * basisY[k]``
    over those winners.

    Args:
        nIn: input feature size.
        nOut: output feature size.
        nNodes: number of (basisX, basisY) node pairs.
        nWinners: nodes selected per input; values above ``nNodes`` are
            clamped to ``nNodes`` (every node wins) instead of crashing ``topk``.
    """

    def __init__(self, nIn, nOut, nNodes, nWinners) -> None:
        super().__init__()
        self.nWinners = nWinners

        def random_unit_vectors_of(length):
            # One random direction per node, L2-normalised along the feature axis.
            weights = F.normalize(torch.randn(nNodes, length), p=2, dim=-1)
            return nn.Parameter(weights)

        self.basisX = random_unit_vectors_of(length=nIn)
        self.basisY = random_unit_vectors_of(length=nOut)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., nIn)`` to ``(..., nOut)``."""
        # lambdas[..., n] = x DOT basisX[n]  -> (..., nNodes)
        lambdas = torch.einsum('...x, nx -> ...n', x, self.basisX)

        # Winning lambdas and their node indices, both (..., k).
        k = min(self.nWinners, self.basisX.size(0))
        topk_lambdas, topk_indices = torch.topk(lambdas, k, dim=-1)

        # Gather each winner's output direction -> (..., k, nOut), then take the
        # lambda-weighted sum over winners in one vectorised step (replaces a
        # Python loop over k that re-gathered the lambdas it already had).
        winner_Y = self.basisY[topk_indices]
        return torch.einsum('...k, ...ky -> ...y', topk_lambdas, winner_Y)

# Network built from two stacked top-k slice layers with a ReLU in between.
class Net(nn.Module):
    """784 -> 500 -> 10 classifier using PiSlice_topk layers (64 nodes, 8 winners each)."""

    def __init__(self):
        super().__init__()
        self.fc1 = PiSlice_topk(nIn=28 * 28, nOut=500, nNodes=64, nWinners=8)
        self.fc2 = PiSlice_topk(nIn=500, nOut=10, nNodes=64, nWinners=8)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x.view(-1, 28 * 28)))
        return self.fc2(hidden)

# Train a fresh instance of the Net defined above; prints "Finished Training" and the test accuracy.
train_and_test(Net())

100%|██████████| 5/5 [00:19<00:00,  3.97s/it]


Finished Training
Accuracy of the network over test images: 95.820 %


# 🔸 .clamp(min=0) `96.12%`

Simplified: instead of selecting the top k, every node contributes, with any lambda < 0 clamped to 0.

This costs nNodes Y-projections per input instead of only k.

In [11]:
class PiSlice_relu(nn.Module):
    """Dense "slice" layer: y = sum_n max(0, x DOT basisX[n]) * basisY[n].

    Every node ``n`` owns a trainable input direction ``basisX[n]`` (length
    ``nIn``) and output direction ``basisY[n]`` (length ``nOut``), both
    initialised as random L2-normalised vectors. Unlike the top-k variant,
    every node contributes, with negative lambdas clamped to 0.

    Args:
        nIn: input feature size.
        nOut: output feature size.
        nNodes: number of (basisX, basisY) node pairs.
    """

    def __init__(self, nIn, nOut, nNodes) -> None:
        super().__init__()

        def random_unit_vectors_of(length):
            # One random direction per node, L2-normalised along the feature axis.
            weights = F.normalize(torch.randn(nNodes, length), p=2, dim=-1)
            return nn.Parameter(weights)

        self.basisX = random_unit_vectors_of(length=nIn)
        self.basisY = random_unit_vectors_of(length=nOut)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., nIn)`` to ``(..., nOut)``."""
        # lambdas[..., n] = x DOT basisX[n]  -> (..., nNodes)
        lambdas = torch.einsum('...x, nx -> ...n', x, self.basisX)

        # Drop negative activations (clamp == ReLU), then mix all node outputs.
        return torch.einsum('...n, ny -> ...y', lambdas.clamp(min=0), self.basisY)

# Network built from two stacked dense (clamp/ReLU) slice layers with a ReLU in between.
class Net(nn.Module):
    """784 -> 500 -> 10 classifier using PiSlice_relu layers (64 nodes each)."""

    def __init__(self):
        super().__init__()
        self.fc1 = PiSlice_relu(nIn=28 * 28, nOut=500, nNodes=64)
        self.fc2 = PiSlice_relu(nIn=500, nOut=10, nNodes=64)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x.view(-1, 28 * 28)))
        return self.fc2(hidden)

# Train a fresh instance of the Net defined above; prints "Finished Training" and the test accuracy.
train_and_test(Net())

100%|██████████| 5/5 [00:13<00:00,  2.65s/it]


Finished Training
Accuracy of the network over test images: 96.120 %


# 🔸 Orthogonality penalty `97.01%`

Add a penalty that encourages the basis vectors to be mutually orthogonal in input space (`basisX`).
The same penalty applies in output space (`basisY`).

In [12]:
class PiSlice_relu_ortho(nn.Module):
    """Dense "slice" layer: y = sum_n max(0, x DOT basisX[n]) * basisY[n].

    Same computation as the clamp/ReLU slice layer. The layer itself adds no
    penalty; encouraging orthogonal ``basisX`` / ``basisY`` rows is left to
    the owning network's loss.

    Args:
        nIn: input feature size.
        nOut: output feature size.
        nNodes: number of (basisX, basisY) node pairs; each basis starts as
            random L2-normalised rows.
    """

    def __init__(self, nIn, nOut, nNodes) -> None:
        super().__init__()

        def random_unit_vectors_of(length):
            # One random direction per node, L2-normalised along the feature axis.
            weights = F.normalize(torch.randn(nNodes, length), p=2, dim=-1)
            return nn.Parameter(weights)

        self.basisX = random_unit_vectors_of(length=nIn)
        self.basisY = random_unit_vectors_of(length=nOut)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., nIn)`` to ``(..., nOut)``."""
        # lambdas[..., n] = x DOT basisX[n]  -> (..., nNodes)
        lambdas = torch.einsum('...x, nx -> ...n', x, self.basisX)

        # y = sum(node.lambda * node.y if node.lambda > 0)
        return torch.einsum('...n, ny -> ...y', lambdas.clamp(min=0), self.basisY)

def orthogonality_loss(basis_vectors):
    """Penalise non-orthogonality among the rows of ``basis_vectors``.

    Args:
        basis_vectors: 2-D tensor of shape (nVectors, dim); each row is one vector.

    Returns:
        Scalar tensor: the sum of squared dot products between distinct rows
        (the off-diagonal entries of the Gram matrix). It is 0 exactly when the
        rows are mutually orthogonal; the rows' own norms are not penalised.
    """
    # Pairwise dot products: gram[i, j] = row_i DOT row_j
    gram = basis_vectors @ basis_vectors.T

    # Boolean diagonal mask created directly on the Gram matrix's device: no
    # per-call host allocation + copy, and masked_fill zeroes the diagonal
    # exactly (multiplying by 0 would turn an inf self-dot into NaN).
    self_dots = torch.eye(gram.size(0), dtype=torch.bool, device=gram.device)
    off_diagonal = gram.masked_fill(self_dots, 0.0)

    return off_diagonal.pow(2).sum()

# Network of two dense slice layers, trained with an extra orthogonality penalty.
class Net(nn.Module):
    """784 -> 500 -> 10 classifier using PiSlice_relu_ortho layers (64 then 32 nodes).

    Note: fc2.basisY holds 32 vectors of length 10, so those rows can never be
    fully mutually orthogonal; its penalty term cannot reach zero.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = PiSlice_relu_ortho(nIn=28 * 28, nOut=500, nNodes=64)
        self.fc2 = PiSlice_relu_ortho(nIn=500, nOut=10, nNodes=32)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x.view(-1, 28 * 28)))
        return self.fc2(hidden)

    def orthogonality_penalty(self):
        """Sum of orthogonality_loss over the input and output bases of both layers."""
        def layer_penalty(layer):
            return orthogonality_loss(layer.basisX) + orthogonality_loss(layer.basisY)

        return layer_penalty(self.fc1) + layer_penalty(self.fc2)
    
# Train a fresh instance of the Net defined above. train_and_test adds 0.001 * net.orthogonality_penalty() to each batch loss.
train_and_test(Net())

100%|██████████| 5/5 [00:13<00:00,  2.75s/it]


Finished Training
Accuracy of the network over test images: 97.010 %
