# Two-Pass Forward Propagation Training Approaches

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

In [None]:
# Neural network model: a small fully connected classifier for MNIST digits
class SimpleNN(nn.Module):
    """Three-layer fully connected classifier.

    Architecture: 784 -> 128 -> ReLU -> 64 -> ReLU -> 10 (raw logits).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in this order (fc1, fc2, fc3); parameter
        # names and seeded initialisation depend on it.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10).

        All dimensions after the first are flattened, so ``x`` may be
        (batch, 1, 28, 28) or already (batch, 784).
        """
        flat = x.flatten(start_dim=1)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        logits = self.fc3(hidden)
        return logits

In [None]:
# Load the MNIST training split (downloaded into ../data if not already present)
# and wrap it in a shuffled loader of 64-sample batches.
mnist_transform = transforms.Compose([transforms.ToTensor()])
mnist_train = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=mnist_transform,
)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)

In [None]:
# Instantiate the model and the loss function. No torch.optim optimizer is
# created: the training loop applies plain gradient-descent updates by hand.
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class targets

In [None]:
# Fixed random projection matrix used for error-driven input modulation
def generate_random_matrix(input_size, output_size, scale=0.01, generator=None):
    """Return an (input_size, output_size) matrix of scaled Gaussian samples.

    Entries are drawn from N(0, 1) with ``torch.randn`` and multiplied by
    ``scale``. The default of 0.01 is the value previously hard-coded, which
    keeps the entries small.

    Args:
        input_size: Number of rows (the error-vector length when the matrix
            is applied as ``e @ F``).
        output_size: Number of columns.
        scale: Multiplier applied to the standard-normal samples.
        generator: Optional ``torch.Generator`` for reproducible draws;
            ``None`` uses the global RNG, exactly as before.

    Returns:
        A float tensor of shape (input_size, output_size).
    """
    return torch.randn(input_size, output_size, generator=generator) * scale

In [None]:
# Fixed random projection F of shape (10, 784): maps the 10-class error vector
# onto the flattened 28x28 input. Sampled once here; the training loop only reads it.
F = generate_random_matrix(input_size=10, output_size=28 * 28)

In [None]:
# Training loop with two forward passes per batch:
#   1) a clean pass on the input, used only to form the error signal e
#      (and a loss value for logging);
#   2) a pass on the modulated input data + e @ F, whose loss drives the update.
NUM_EPOCHS = 10
LEARNING_RATE = 0.01

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for data, target in train_loader:
        # First (clean) pass. No autograd graph is built here, so delta_x acts
        # as a constant input perturbation. Otherwise gradients of the
        # modulated loss would also flow back through e into this pass.
        with torch.no_grad():
            output = model(data)
            loss = criterion(output, target)
            softmax_output = torch.softmax(output, dim=1)

            # Error signal: one-hot target minus predicted probabilities, (batch, 10)
            e = torch.nn.functional.one_hot(target, num_classes=10).float() - softmax_output

            # Project the error into input space: (batch, 10) @ (10, 784) -> (batch, 1, 28, 28)
            delta_x = torch.matmul(e, F).view(-1, 1, 28, 28)

        # Second pass on the error-modulated input
        modulated_input = data + delta_x
        modulated_output = model(modulated_input)
        modulated_loss = criterion(modulated_output, target)

        # NOTE: this is ordinary autograd backpropagation, but only through the
        # second (modulated) pass.
        modulated_loss.backward()

        # Manual gradient-descent step, then reset gradients for the next batch
        with torch.no_grad():
            for param in model.parameters():
                if param.grad is None:
                    continue
                param -= LEARNING_RATE * param.grad
                param.grad.zero_()

        running_loss += loss.item()
        num_batches += 1

    # Mean clean-pass loss over the epoch; a single last-batch value is very noisy.
    # max(..., 1) avoids division by zero for an empty loader.
    mean_loss = running_loss / max(num_batches, 1)
    print(f'Epoch {epoch + 1}, Loss: {mean_loss}')

print('Training completed.')

Epoch 1, Loss: 0.9636867642402649
Epoch 2, Loss: 0.34817734360694885
Epoch 3, Loss: 0.2073085606098175
Epoch 4, Loss: 0.17624208331108093
Epoch 5, Loss: 0.06440341472625732
Epoch 6, Loss: 0.3194601535797119
Epoch 7, Loss: 0.05183613300323486
Epoch 8, Loss: 0.39785975217819214
Epoch 9, Loss: 0.2508266866207123
Epoch 10, Loss: 0.36374637484550476
Training completed.
