In [139]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torchvision import datasets, transforms

import torch
# Seed PyTorch's global RNG before any model is constructed.
torch.manual_seed(2024)
# Width of the hidden linear layers in both ForwardModel and BackwardModel.
HIDDEN_DIM = 256
# When truthy: ForwardModel zero-initialises fc3 and detaches activations
# between layers, and the training loop runs the sigma_loss branch.
USE_SIGMA = 1
# When truthy: ForwardModel runs an ordinary (non-detached) forward pass and
# the training loop runs the cross-entropy branch.
USE_BP = 0

# Define the forward and backward models
class ForwardModel(nn.Module):
    """Three-layer MLP mapping flattened 28x28 inputs to 10 logits.

    Two execution modes are supported:

    * sigma mode -- the inputs of ``fc2`` and ``fc3`` are detached, so a loss
      on one layer's output only yields gradients for that layer.  In this
      mode ``fc3`` starts with all-zero weight and bias.
    * backprop mode -- a plain forward pass with an unbroken autograd graph.

    If both modes are enabled the unbroken (backprop) graph is returned.

    Args:
        hidden_dim: Width of the two hidden layers.  ``None`` falls back to
            the module-level ``HIDDEN_DIM``.
        use_sigma: Enables sigma mode.  ``None`` defers to the module-level
            ``USE_SIGMA`` each time the flag is consulted.
        use_bp: Enables backprop mode.  ``None`` defers to the module-level
            ``USE_BP`` each time the flag is consulted.
    """

    def __init__(self, hidden_dim=None, use_sigma=None, use_bp=None):
        super().__init__()
        width = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self._use_sigma = use_sigma
        self._use_bp = use_bp
        # Detached copy of the last hidden activation; filled in by forward()
        # and consumed by forward_logits().
        self.penultimate_feature = None

        self.fc1 = nn.Linear(28 * 28, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 10)

        sigma_enabled, _ = self._modes()
        if sigma_enabled:
            nn.init.zeros_(self.fc3.weight)
            nn.init.zeros_(self.fc3.bias)

    def _modes(self):
        """Return ``(use_sigma, use_bp)`` as booleans, honouring overrides."""
        use_sigma = USE_SIGMA if self._use_sigma is None else self._use_sigma
        use_bp = USE_BP if self._use_bp is None else self._use_bp
        return bool(use_sigma), bool(use_bp)

    def forward(self, x):
        """Run the network and return ``(a1, a2, a3)``.

        ``a1`` and ``a2`` are the ReLU outputs of ``fc1`` and ``fc2``; ``a3``
        is the raw output of ``fc3``.  A detached copy of ``a2`` is cached on
        ``self.penultimate_feature``.

        Raises:
            ValueError: If neither sigma mode nor backprop mode is enabled.
        """
        use_sigma, use_bp = self._modes()
        if not (use_sigma or use_bp):
            raise ValueError(
                "ForwardModel needs USE_SIGMA or USE_BP (or the matching "
                "constructor argument) to be enabled"
            )
        # The graph is only cut between layers when backprop mode is off.
        cut_graph = not use_bp

        x = x.view(-1, 28 * 28)
        a1 = F.relu(self.fc1(x))
        a2 = F.relu(self.fc2(a1.detach() if cut_graph else a1))
        a3 = self.fc3(a2.detach() if cut_graph else a2)

        self.penultimate_feature = a2.detach()
        return a1, a2, a3

    def forward_logits(self):
        """Re-apply ``fc3`` to the features cached by the last ``forward``.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.penultimate_feature is None:
            raise RuntimeError(
                "forward_logits() requires a prior forward() call"
            )
        return self.fc3(self.penultimate_feature)

class BackwardModel(nn.Module):
    """Two-layer MLP that maps integer class labels to hidden-size targets.

    Labels are one-hot encoded, passed through ``fc1`` + ReLU to give ``s2``,
    and a detached copy of ``s2`` is passed through ``fc2`` + ReLU to give
    ``s1``.  Because of the detach, a loss on ``s1`` only produces gradients
    for ``fc2`` and a loss on ``s2`` only for ``fc1``.

    ``fc1.weight`` is orthogonally initialised; ``fc2.weight`` keeps the
    ``nn.Linear`` default initialisation.  Both biases start at zero.

    Args:
        hidden_dim: Width of both layers' outputs.  ``None`` falls back to
            the module-level ``HIDDEN_DIM``.
        num_classes: Number of label classes (size of the one-hot input).
    """

    def __init__(self, hidden_dim=None, num_classes=10):
        super().__init__()
        width = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.num_classes = num_classes

        # Layer creation order is kept fixed so that seeded initialisation
        # draws random numbers in the same sequence.
        self.fc1 = nn.Linear(num_classes, width)
        self.fc2 = nn.Linear(width, width)

        init.orthogonal_(self.fc1.weight)
        init.zeros_(self.fc1.bias)
        init.zeros_(self.fc2.bias)

    def forward(self, t):
        """Return ``(s1, s2)`` for a tensor of integer class labels ``t``."""
        # one_hot produces its result on the same device as ``t``.
        encoded = F.one_hot(t, num_classes=self.num_classes).float()
        s2 = F.relu(self.fc1(encoded))
        s1 = F.relu(self.fc2(s2.detach()))
        return s1, s2

# Define the loss functions
def _scale_to_unit_norm(x, eps=1e-12):
    """Divide ``x`` by its overall 2-norm, guarding against a zero norm."""
    return x / x.norm().clamp_min(eps)


def sigma_loss(a1, a2, a3, s1, s2, t):
    """Compute the three layer-wise MSE losses.

    Args:
        a1, a2: Hidden activations of the forward model.
        a3: Output of the forward model; its last dimension is taken as the
            number of classes.
        s1, s2: Targets produced by the backward model for ``a1`` and ``a2``.
        t: Integer class labels.

    Returns:
        ``(loss1, loss2, loss3)`` where ``loss1`` / ``loss2`` compare the
        norm-scaled activations with the norm-scaled targets and ``loss3``
        compares ``a3`` with the one-hot encoding of ``t``.
    """
    # Each tensor is scaled by the norm of the whole tensor (not per row).
    # The norm is clamped so an all-zero tensor yields zeros instead of NaN.
    loss1 = F.mse_loss(_scale_to_unit_norm(a1), _scale_to_unit_norm(s1))
    loss2 = F.mse_loss(_scale_to_unit_norm(a2), _scale_to_unit_norm(s2))
    one_hot = F.one_hot(t, num_classes=a3.shape[-1]).float()
    loss3 = F.mse_loss(a3, one_hot)
    return loss1, loss2, loss3
                       
def bp_loss(a, b):
    """Return the module-level ``criteria`` loss evaluated on ``(a, b)``."""
    loss = criteria(a, b)
    return loss
    
# Load the MNIST train and test splits. Each image is converted to a tensor
# and normalised with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

# Build the two networks (forward model first, then backward model).
forward_model = ForwardModel()
backward_model = BackwardModel()

# One SGD optimizer per network, both with lr=0.1 and momentum=0.5; all other
# SGD options are left at their defaults.
forward_optimizer = optim.SGD(
    forward_model.parameters(), lr=0.1, momentum=0.5
)
backward_optimizer = optim.SGD(
    backward_model.parameters(), lr=0.1, momentum=0.5
)
# Cross-entropy criterion used by the backprop branch and by bp_loss.
criteria = nn.CrossEntropyLoss()

# Training loop: 30 epochs over train_loader, with a test-set evaluation after
# every epoch.
for epoch in range(30):
    for batch_idx, (data, target) in enumerate(train_loader):

        if USE_SIGMA:
            s1, s2 = backward_model(target)
            a1, a2, a3 = forward_model(data)
            loss1, loss2, loss3 = sigma_loss(a1, a2, a3, s1, s2, target)

            # Update both networks. Gradients are linear in the loss, so one
            # backward pass over the sum gives every parameter the same
            # gradient as three separate passes, without retaining graphs.
            forward_optimizer.zero_grad()
            backward_optimizer.zero_grad()
            (loss1 + loss2 + loss3).backward()
            forward_optimizer.step()
            backward_optimizer.step()

            # Update the linear head again, using the cached penultimate
            # features and the freshly updated fc3.
            a3 = forward_model.forward_logits()
            loss1, loss2, loss3 = sigma_loss(a1, a2, a3, s1, s2, target)
            forward_optimizer.zero_grad()
            loss3.backward()
            forward_optimizer.step()

        if USE_BP:
            # forward_model returns exactly three tensors.
            a1, a2, a3 = forward_model(data)
            loss = criteria(a3, target)
            forward_optimizer.zero_grad()
            loss.backward()
            forward_optimizer.step()

    # Print statistics for the last batch of the epoch; each branch reports
    # only the losses it actually computed.
    if USE_SIGMA:
        print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss1: {loss1.item():.3f}, Loss2: {loss2.item():.3f}, Loss3: {loss3.item():.3f}')
    if USE_BP:
        print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.3f}')

    # Evaluate on test set
    forward_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            _, _, outputs = forward_model(data)
            # Predicted class = index of the largest logit in each row.
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / total}%')
    forward_model.train()

Epoch: 0, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.060
Epoch: 0, Test Accuracy: 81.14%
Epoch: 1, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.053
Epoch: 1, Test Accuracy: 82.75%
Epoch: 2, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.045
Epoch: 2, Test Accuracy: 83.51%
Epoch: 3, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.041
Epoch: 3, Test Accuracy: 84.69%
Epoch: 4, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.040
Epoch: 4, Test Accuracy: 85.19%
Epoch: 5, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.039
Epoch: 5, Test Accuracy: 85.12%
Epoch: 6, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.041
Epoch: 6, Test Accuracy: 85.52%
Epoch: 7, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.028
Epoch: 7, Test Accuracy: 85.07%
Epoch: 8, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.040
Epoch: 8, Test Accuracy: 85.46%
Epoch: 9, Batch: 937, Loss1: 0.000, Loss2: 0.000, Loss3: 0.036
Epoch: 9, Test Accuracy: 86.14%
Epoch: 10, Batch: 937, Loss1: 0.000, Loss2: 0.000,