# First impl of FFF (obsolete)

Compares FFF against PyTorch's `nn.Linear`

This notebook is obsolete; it is kept (for now) for posterity.

Our FFF code has since improved.

Look in `notebooks/` at the Benchmarks notebooks for clearer code.

In [1]:
! pip install -q torch torchvision tqdm

In [2]:
# Number of passes over the training set; read by both training loops in this notebook.
NEPOCH = 5

In [3]:
from torch import nn, functional as F
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [4]:
import torchvision
import torchvision.transforms as transforms

# Preprocessing pipeline: ToTensor followed by Normalize((0.5,), (0.5,)).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train and test splits, both stored under ./data with download=True.
trainset, testset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Loaders yield batches of 64; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=64)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=64)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Neural network architecture
class Net(nn.Module):
    """Baseline MNIST classifier: two nn.Linear layers (784 -> 500 -> 10) with a ReLU between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Flatten ``x`` into rows of 784 values and return the 10 class scores per row."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Baseline model built from the nn.Linear-based Net defined in this cell.
net = Net()

# Cross-entropy objective, optimised with AdamW.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=1e-3)

# Training: NEPOCH passes over trainloader.
for epoch in tqdm(range(NEPOCH)):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once every 2000 mini-batches, then restart the sum.
        # NOTE(review): the saved output of this cell contains no loss lines, so this
        # threshold is apparently never reached within an epoch.
        running_loss += loss.item()
        if (i + 1) % 2000 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

# Evaluation: count test samples whose highest-scoring class equals the label.
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


100%|██████████| 5/5 [00:15<00:00,  3.07s/it]


Finished Training
Accuracy of the network on the 10000 test images: 97 %


# 🔸 FFF

In [6]:
! pip install -q lovely-tensors

In [7]:
import lovely_tensors as lt
# NOTE(review): presumably patches how torch tensors are displayed (compact summaries) —
# confirm against the lovely-tensors documentation.
lt.monkey_patch()

In [8]:
class FFF(nn.Module):
    """Feed-forward layer that routes every sample down a binary tree of nodes.

    Each node owns one input vector (a row of ``w1s``) and one output vector
    (a row of ``w2s``). For every sample the forward pass starts at node 0 and
    visits ``depth`` nodes. At each node it takes the dot product ``lambda``
    between the sample and the node's input vector, then descends to child
    ``2*node + 2`` when ``lambda > 0`` and to child ``2*node + 1`` otherwise.
    The output is the sum, over the visited nodes, of ``lambda`` times that
    node's output vector.

    Args:
        nIn: Size of each input sample.
        nOut: Size of each output sample.
        depth: Number of nodes visited per sample; the tree holds
            ``2**depth - 1`` nodes in total.
    """

    def __init__(self, nIn, nOut, depth=8) -> None:
        super().__init__()

        self.input_width = nIn
        self.output_width = nOut
        self.depth = depth
        self.n_nodes = 2**depth - 1

        self._initiate_weights()

    def _initiate_weights(self):
        """Create ``w1s`` and ``w2s`` as parameters whose rows are random unit vectors.

        This replaces an earlier uniform initialisation scaled by
        ``1 / sqrt(input_width)`` and ``1 / sqrt(depth + 1)``.
        """

        def create_random_unit_vectors(n_nodes, width):
            weights = torch.randn(n_nodes, width)
            # L2-normalise each row so every node vector has unit length.
            weights = F.normalize(weights, p=2, dim=-1)
            return nn.Parameter(weights)

        self.w1s = create_random_unit_vectors(self.n_nodes, self.input_width)
        self.w2s = create_random_unit_vectors(self.n_nodes, self.output_width)

    def forward(self, x: torch.Tensor):
        """Route each row of ``x`` through the tree and return the weighted sum.

        Args:
            x: Tensor of shape ``(batch, nIn)``.

        Returns:
            Tensor of shape ``(batch, nOut)``.
        """
        batch_size = x.shape[0]

        # Bookkeeping buffers follow the input's device and dtype, so the layer
        # also works after ``.to(device)`` / ``.double()`` instead of assuming
        # CPU float32. Indices are int64, the canonical index dtype.
        current_node = torch.zeros((batch_size,), dtype=torch.long, device=x.device)
        winner_indices = torch.zeros((batch_size, self.depth), dtype=torch.long, device=x.device)
        winner_lambdas = torch.empty((batch_size, self.depth), dtype=x.dtype, device=x.device)

        for i in range(self.depth):
            # Score of each sample against the input vector of its current node: shape (batch,).
            lambda_ = torch.einsum('b i, b i -> b', x, self.w1s[current_node])
            winner_indices[:, i] = current_node
            winner_lambdas[:, i] = lambda_

            # Descend: left child for lambda <= 0, right child for lambda > 0.
            plane_choice = (lambda_ > 0).long()
            current_node = (current_node * 2) + plane_choice + 1

        # Output vectors of the visited nodes: shape (batch, depth, nOut).
        selected_w2s = self.w2s[winner_indices]
        # The raw scores weight the output vectors directly; no activation is
        # applied (an earlier GeLU variant was dropped).
        y = torch.einsum('b i j , b i -> b j', selected_w2s, winner_lambdas)
        return y

In [9]:
# Neural network architecture
class Net(nn.Module):
    """MNIST classifier made of two stacked FFF layers (784 -> 500 -> 10), each of depth 8."""

    def __init__(self):
        super().__init__()
        self.fc1 = FFF(nIn=28 * 28, nOut=500, depth=8)
        self.fc2 = FFF(nIn=500, nOut=10, depth=8)

    def forward(self, x):
        """Flatten ``x`` into rows of 784 values and return the output of the second FFF layer."""
        flat = x.view(-1, 28 * 28)
        # The two FFF layers are chained directly with no activation in between.
        # Earlier alternatives: a ReLU between the layers, or a single FFF(784 -> 10).
        hidden = self.fc1(flat)
        return self.fc2(hidden)

# Model under test: the FFF-based Net defined in this cell.
net = Net()

# Cross-entropy objective, optimised with AdamW.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=1e-3)

# Training: NEPOCH passes over trainloader.
for epoch in tqdm(range(NEPOCH)):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once every 100 mini-batches, then restart the sum.
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

# Evaluation: count test samples whose highest-scoring class equals the label.
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


  0%|          | 0/5 [00:00<?, ?it/s]

[1,   100] loss: 1.741
[1,   200] loss: 0.943
[1,   300] loss: 0.712
[1,   400] loss: 0.605
[1,   500] loss: 0.545
[1,   600] loss: 0.516
[1,   700] loss: 0.463
[1,   800] loss: 0.460
[1,   900] loss: 0.423


 20%|██        | 1/5 [00:10<00:42, 10.74s/it]

[2,   100] loss: 0.390
[2,   200] loss: 0.387
[2,   300] loss: 0.352
[2,   400] loss: 0.377
[2,   500] loss: 0.369
[2,   600] loss: 0.343
[2,   700] loss: 0.381
[2,   800] loss: 0.360
[2,   900] loss: 0.327


 40%|████      | 2/5 [00:21<00:32, 10.70s/it]

[3,   100] loss: 0.316
[3,   200] loss: 0.303
[3,   300] loss: 0.324
[3,   400] loss: 0.308
[3,   500] loss: 0.300
[3,   600] loss: 0.312
[3,   700] loss: 0.306
[3,   800] loss: 0.296
[3,   900] loss: 0.279


 60%|██████    | 3/5 [00:32<00:21, 10.68s/it]

[4,   100] loss: 0.290
[4,   200] loss: 0.282
[4,   300] loss: 0.303
[4,   400] loss: 0.293
[4,   500] loss: 0.278
[4,   600] loss: 0.315
[4,   700] loss: 0.278
[4,   800] loss: 0.280
[4,   900] loss: 0.282


 80%|████████  | 4/5 [00:42<00:10, 10.73s/it]

[5,   100] loss: 0.270
[5,   200] loss: 0.275
[5,   300] loss: 0.283
[5,   400] loss: 0.294
[5,   500] loss: 0.255
[5,   600] loss: 0.279
[5,   700] loss: 0.291
[5,   800] loss: 0.248
[5,   900] loss: 0.273


100%|██████████| 5/5 [00:53<00:00, 10.73s/it]


Finished Training
Accuracy of the network on the 10000 test images: 92 %


In [10]:
# Neural network architecture
class _Net(nn.Module):
    """Single-layer variant: one FFF of depth 16 mapping 784 inputs to 10 outputs."""

    def __init__(self):
        # Zero-argument super(): the previous `super(Net, self)` named a different
        # class (`Net`), which raises TypeError because `_Net` does not derive from it.
        super().__init__()
        self.fc1 = FFF(nIn=28*28, nOut=10, depth=16)

    def forward(self, x):
        """Flatten ``x`` into rows of 784 values and return the FFF layer's output."""
        x = x.view(-1, 28*28)
        y_hat = self.fc1(x)
        return y_hat

'''
  0%|          | 0/5 [00:00<?, ?it/s]
[1,   100] loss: 2.112
[1,   200] loss: 1.615
[1,   300] loss: 1.647
[1,   400] loss: 1.540
[1,   500] loss: 1.393
[1,   600] loss: 1.377
[1,   700] loss: 1.228
[1,   800] loss: 1.451
[1,   900] loss: 1.432
 20%|██        | 1/5 [12:54<51:39, 774.81s/it]
'''
# ^ ~13 minutes for a single epoch (12:54 for the first of 5) is too slow to be practical



'\n  0%|          | 0/5 [00:00<?, ?it/s]\n[1,   100] loss: 2.112\n[1,   200] loss: 1.615\n[1,   300] loss: 1.647\n[1,   400] loss: 1.540\n[1,   500] loss: 1.393\n[1,   600] loss: 1.377\n[1,   700] loss: 1.228\n[1,   800] loss: 1.451\n[1,   900] loss: 1.432\n 20%|██        | 1/5 [12:54<51:39, 774.81s/it]\n'