In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torchvision import datasets, transforms

import torch
# Seed PyTorch's global RNG so that weight initialisation (and any other draw
# from the global generator) is repeatable across runs.
torch.manual_seed(2024)
# NOTE(review): HIDDEN_DIM is not referenced by any code visible in this
# notebook; both models are constructed with their default hidden_dim.
# Confirm whether it was meant to be passed to the model constructors.
HIDDEN_DIM = 256
# Truthy -> train with the layer-wise "sigma" losses; falsy -> plain
# end-to-end training with cross-entropy.
USE_SIGMA = 1
# USE_BP = 1

import torch.nn as nn
import torch.nn.functional as F

    
# Load MNIST dataset
# Images are converted to tensors and then normalised with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits share the same storage root, download flag and preprocessing.
_mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **_mnist_kwargs)

# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)


class ForwardModel(nn.Module):
    """Small CNN classifier for inputs reshaped to ``(N, 1, 28, 28)``.

    The network is conv -> pool -> conv -> pool -> fc -> fc with ELU
    activations and can run in two modes, selected by ``use_sigma``:

    * ``use_sigma`` truthy: every layer receives a *detached* copy of the
      previous layer's output, so a loss on one layer's output only produces
      gradients for that layer's own parameters.  The forward pass returns all
      four layer outputs.  The last layer's weights and every bias are
      initialised to zero.
    * ``use_sigma`` falsy: an ordinary end-to-end forward pass (no layer norm,
      no detaching) that returns only the logits.

    Args:
        hidden_dim: Width of the hidden fully connected layer.
        use_sigma: Selects the mode described above.
    """

    def __init__(self, hidden_dim=128, use_sigma=True):
        super(ForwardModel, self).__init__()
        self.use_sigma = use_sigma
        self.hidden_dim = hidden_dim

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(2)

        # Linear layers
        self.fc1 = nn.Linear(32 * 7 * 7, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

        # Parameter-free layer norms, used only in sigma mode: LN2 on the raw
        # input, LN1 on the pooled output of the first conv stage.
        self.LN1 = torch.nn.LayerNorm((32, 14, 14), elementwise_affine=False)
        self.LN2 = torch.nn.LayerNorm((1, 28, 28), elementwise_affine=False)

        # Filled in by the sigma-mode forward pass and consumed by
        # forward_logits(); None means "no sigma-mode forward has run yet".
        self.penultimate_feature = None

        if use_sigma:
            # Zero the output layer's weights and all biases in place.
            with torch.no_grad():
                self.fc2.weight.zero_()
                for layer in (self.fc2, self.fc1, self.conv2, self.conv1):
                    layer.bias.zero_()

    def forward(self, x):
        """Run the network on ``x`` (any shape that views to ``(-1, 1, 28, 28)``).

        Returns:
            A 4-tuple ``(a1, a2, a3, a4)``.  In sigma mode these are the
            layer-normed first conv stage ``(N, 32, 14, 14)``, the flattened
            second conv stage ``(N, 32*7*7)``, the hidden features
            ``(N, hidden_dim)`` and the logits ``(N, 10)``.  Otherwise the
            first three entries are ``None`` and the last is the logits.
        """
        if self.use_sigma:
            x = x.view(-1, 1, 28, 28)
            x = self.LN2(x)
            a1 = F.elu(self.conv1(x.detach()))
            a1 = self.max_pool(a1)
            a1 = self.LN1(a1)

            # Each stage sees a detached input so gradients stay layer-local.
            a2 = F.elu(self.conv2(a1.detach()))
            a2 = self.max_pool(a2)
            a2 = a2.view(-1, 32 * 7 * 7)

            a3 = F.elu(self.fc1(a2.detach()))

            a4 = self.fc2(a3.detach())
            # Cache the detached hidden features so the head can be
            # re-evaluated later without re-running the backbone.
            self.penultimate_feature = a3.detach()

            return a1, a2, a3, a4

        else:
            x = x.view(-1, 1, 28, 28)
            x = F.elu(self.conv1(x))
            x = self.max_pool(x)
            x = F.elu(self.conv2(x))
            x = self.max_pool(x)
            x = x.view(-1, 32 * 7 * 7)
            x = F.elu(self.fc1(x))
            x = self.fc2(x)
            return None, None, None, x

    def forward_logits(self):
        """Re-apply the output layer to the cached hidden features.

        Returns:
            ``fc2`` applied to the features cached by the most recent
            sigma-mode ``forward`` call.

        Raises:
            RuntimeError: If no sigma-mode forward pass has populated the
                cache yet.
        """
        if self.penultimate_feature is None:
            raise RuntimeError(
                "forward_logits() requires a prior forward() call with "
                "use_sigma enabled to populate penultimate_feature"
            )
        return self.fc2(self.penultimate_feature)

class BackwardModel(nn.Module):
    """Maps class labels to per-layer target signals ``(s1, s2, s3)``.

    A label is one-hot encoded and pushed through three stages
    (10 -> hidden_dim -> 32*7*7 -> upsampled 32x14x14).  Each stage has two
    parallel bias-free branches whose ReLU outputs are averaged.  Every stage
    receives a detached copy of the previous stage's output.

    Args:
        hidden_dim: Width of the first signal ``s3``.
    """

    def __init__(self, hidden_dim=128):
        super(BackwardModel, self).__init__()
        self.hidden_dim = hidden_dim

        # Linear layers (two parallel branches per stage)
        self.fc1 = nn.Linear(10, hidden_dim, bias=False)
        self.fc1_2 = nn.Linear(10, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, 32 * 7 * 7, bias=False)
        self.fc2_2 = nn.Linear(hidden_dim, 32 * 7 * 7, bias=False)

        # Nearest-neighbour upsampling followed by same-size convolutions
        # takes the (32, 7, 7) signal to (32, 14, 14).
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(32, 32, kernel_size=5, padding=2, bias=False)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2, bias=False)

    @staticmethod
    def _branch_mean(branch_a, branch_b, inp):
        """Average the ReLU outputs of two parallel branches applied to ``inp``."""
        return (F.relu(branch_a(inp)) + F.relu(branch_b(inp))) / 2

    def forward(self, t, use_act_derivative=False, activations=None):
        """Compute the target signals for integer class labels ``t``.

        Args:
            t: Integer label tensor accepted by ``F.one_hot`` with 10 classes.
            use_act_derivative: When true, each signal is multiplied by the
                sign of the matching forward activation.
            activations: Optional ``(a1, a2, a3)`` tuple of forward
                activations used for the sign gating.  When omitted, the
                module-level names ``a1``, ``a2`` and ``a3`` are read instead
                (the historical behaviour), which raises ``NameError`` if they
                do not exist.

        Returns:
            ``(s1, s2, s3)`` with shapes ``(N, 32, 14, 14)``, ``(N, 32*7*7)``
            and ``(N, hidden_dim)``.
        """
        onehot = F.one_hot(t, num_classes=10).float().to(t.device)

        gate1 = gate2 = gate3 = None
        if use_act_derivative:
            if activations is None:
                # Legacy path: rely on globals set by the caller's forward pass.
                act1, act2, act3 = a1, a2, a3
            else:
                act1, act2, act3 = activations
            gate1 = act1.sign().detach()
            gate2 = act2.sign().detach()
            gate3 = act3.sign().detach()

        # Both branches are used in both modes (the ungated path previously
        # averaged fc1 with itself, leaving fc1_2 unused).
        s3 = self._branch_mean(self.fc1, self.fc1_2, onehot.detach())
        if gate3 is not None:
            s3 = s3 * gate3

        s2 = self._branch_mean(self.fc2, self.fc2_2, s3.detach())
        if gate2 is not None:
            s2 = s2 * gate2

        s1 = self.upsample1(s2.view(-1, 32, 7, 7).detach())
        s1 = self._branch_mean(self.conv1, self.conv1_2, s1)
        if gate1 is not None:
            s1 = s1 * gate1

        return s1, s2, s3
    
def normal(x):
    """Divide ``x`` by the norm of the whole tensor (plus 1e-6 for stability)."""
    denom = x.norm() + 1e-6
    return x / denom
    
def normal_2d(x):
    """Rescale ``x`` by one over its whole-tensor norm (offset by 1e-6).

    Despite the name, the computation is the same as ``normal``: a single
    norm is taken over every element of ``x``.
    """
    scale = x.norm().add(1e-6)
    return x.div(scale)

# Define the loss functions
WEIGHT_DECAY = 0.0
def sigma_loss(a1, a2, a3, a4, s1, s2, s3, t):
    """Sum of three activation/signal alignment losses and the head loss.

    Each forward activation ``a_k`` is compared with its backward signal
    ``s_k`` by MSE after both are scaled by their whole-tensor norm; the
    logits ``a4`` are compared by MSE with the one-hot encoding of ``t``.

    Returns:
        ``(loss, loss1, loss2, loss3, loss4)`` where ``loss`` is the
        differentiable total and the remaining entries are Python floats.

    Raises:
        AssertionError: If an activation and its signal differ in shape.
    """
    for act, sig in ((a1, s1), (a2, s2), (a3, s3)):
        assert act.shape == sig.shape, f"shape {act.shape} does not align with shape {sig.shape}"

    onehot_target = F.one_hot(t, num_classes=10).float().to(t.device)

    conv_term = F.mse_loss(normal_2d(a1), normal_2d(s1))
    flat_term = F.mse_loss(normal(a2), normal(s2))
    hidden_term = F.mse_loss(normal(a3), normal(s3))
    head_term = F.mse_loss(a4, onehot_target)

    total = conv_term + flat_term + hidden_term + head_term
    # + (a1.norm()+a2.norm()+a3.norm()+s1.norm()+s2.norm()+s3.norm())*WEIGHT_DECAY
    return (
        total,
        conv_term.item(),
        flat_term.item(),
        hidden_term.item(),
        head_term.item(),
    )

def sigma_loss_head(a4, t):
    """MSE between the logits ``a4`` and the 10-class one-hot encoding of ``t``."""
    onehot_target = F.one_hot(t, num_classes=10).float().to(t.device)
    return F.mse_loss(a4, onehot_target)
                       
def bp_loss(a, b):
    """Evaluate the module-level ``criteria`` loss on ``(a, b)`` and return it."""
    loss_value = criteria(a, b)
    return loss_value

# Initialize the models
# NOTE(review): HIDDEN_DIM is defined at the top of this cell but is not passed
# here, so both models use their default hidden_dim — confirm this is intended.
forward_model = ForwardModel(use_sigma=USE_SIGMA)
backward_model = BackwardModel()

# Define the optimizers: momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False
forward_optimizer = optim.SGD(forward_model.parameters(), lr=0.1)
backward_optimizer = optim.SGD(backward_model.parameters(), lr=0.1)
criteria = nn.CrossEntropyLoss()

# Training loop
for epoch in range(20):
    for batch_idx, (data, target) in enumerate(train_loader):

        if USE_SIGMA:

            # a1/a2/a3 must keep these module-level names: BackwardModel's
            # sign gating reads them when use_act_derivative=True.
            a1, a2, a3, a4 = forward_model(data)
            s1, s2, s3 = backward_model(target, use_act_derivative=True)
            loss, l1, l2, l3, l4 = sigma_loss(a1, a2, a3, a4, s1, s2, s3, target)

            # Update parameters of both models from the combined loss
            forward_optimizer.zero_grad()
            backward_optimizer.zero_grad()
            loss.backward()
            forward_optimizer.step()
            backward_optimizer.step()

            # Update the linear head again, re-using the cached (detached)
            # hidden features so only the output layer receives a gradient.
            x = forward_model.forward_logits()
            loss = sigma_loss_head(x, target)
            # loss = criteria(x, target)
            forward_optimizer.zero_grad()
            loss.backward()
            forward_optimizer.step()

        else:
            _, _, _, x = forward_model(data)
            loss = criteria(x, target)
            forward_optimizer.zero_grad()
            # The graph is never re-used, so it does not need to be retained.
            loss.backward()
            forward_optimizer.step()

    # Print statistics for the last batch of the epoch.  l1..l4 only exist in
    # sigma mode, so the backprop branch reports its own single loss instead
    # of failing with a NameError.
    if USE_SIGMA:
        print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss1: {l1*100:.3f}, Loss2: {l2*100:.3f}, Loss3: {l3*100:.3f}, Loss4: {l4:.3f}')
    else:
        print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.3f}')

    # Evaluate on test set
    forward_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            _, _, _, outputs = forward_model(data)
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / total}%')
    # The label says "fc3", but the value printed is the norm of fc2's weight.
    print(f"norm of fc3 {forward_model.fc2.weight.data.norm():.1f}")
    forward_model.train()

Epoch: 0, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.019, Loss4: 0.044
Epoch: 0, Test Accuracy: 88.27%
norm of fc3 2.0
Epoch: 1, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.019, Loss4: 0.041
Epoch: 1, Test Accuracy: 90.26%
norm of fc3 2.5
Epoch: 2, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.017, Loss4: 0.032
Epoch: 2, Test Accuracy: 90.85%
norm of fc3 2.9
Epoch: 3, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.017, Loss4: 0.031
Epoch: 3, Test Accuracy: 91.23%
norm of fc3 3.1
Epoch: 4, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.016, Loss4: 0.036
Epoch: 4, Test Accuracy: 91.05%
norm of fc3 3.3
Epoch: 5, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.016, Loss4: 0.033
Epoch: 5, Test Accuracy: 91.38%
norm of fc3 3.4
Epoch: 6, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.015, Loss4: 0.040
Epoch: 6, Test Accuracy: 91.19%
norm of fc3 3.5
Epoch: 7, Batch: 937, Loss1: 0.001, Loss2: 0.002, Loss3: 0.015, Loss4: 0.034
Epoch: 7, Test Accuracy: 91.27%
norm of fc3 3.6


In [9]:
# Show, for every row of `a4`, the indices of its three largest entries.
# NOTE(review): `a4` is not defined in this cell; it must already exist in the
# kernel from an earlier cell — confirm which batch it came from.
torch.topk(a4, k=3, dim=-1).indices

tensor([[1, 8, 7],
        [2, 7, 9],
        [3, 8, 5],
        [4, 8, 1],
        [5, 3, 8],
        [6, 8, 5],
        [7, 9, 3],
        [8, 3, 2],
        [9, 4, 0],
        [0, 8, 1],
        [1, 2, 8],
        [2, 7, 8],
        [3, 2, 0],
        [4, 7, 6],
        [6, 5, 8],
        [6, 2, 0]])

In [28]:
import matplotlib.pyplot as plt

# Take only the first batch of the test loader.
data, target = next(iter(test_loader))

# a1/a2/a3 are assigned at module level on purpose: the backward model's sign
# gating (use_act_derivative=True) reads them.
a1, a2, a3, a4 = forward_model(data)
s1, s2, s3 = backward_model(target, use_act_derivative=True)
# p = torch.argmax(a4, dim=-1)
top3 = torch.topk(a4, k=3, dim=-1)  # computed once, used for both p and v
p = top3.indices
v = top3.values

i2 = 11  # channel of the first-stage feature map to display
forward_maps = a1.detach().cpu().numpy()
backward_maps = s1.detach().cpu().numpy()

# Inspect at most the first 16 samples and plot only the misclassified ones.
for i1 in range(min(16, len(target))):

    if p[i1][0] != target[i1]:

        print(f"prediction: {p[i1]}, {v[i1].detach()} groudtruth: {target[i1]}")

        # One figure with two panels; no separate plt.figure() call, which
        # would only leave an extra empty figure behind for every sample.
        fig, axes = plt.subplots(1, 2)

        # Left: forward activation. Right: backward signal, same sample/channel.
        axes[0].imshow(forward_maps[i1, i2], vmin=-1, vmax=1)

        X = backward_maps
        print(X[i1, i2].min(), X[i1, i2].max())

        axes[1].imshow(X[i1, i2], vmin=-0.0, vmax=0.03)
        plt.pause(0.2)

In [17]:
# Show, for every row of `a4`, its three largest values (largest first).
# NOTE(review): relies on `a4` left in the kernel by an earlier cell.
torch.topk(a4, k=3, dim=-1).values

tensor([[0.6677, 0.1409, 0.1319],
        [0.3594, 0.2306, 0.1896],
        [0.3866, 0.2124, 0.1795],
        [0.8142, 0.1682, 0.1573],
        [0.6672, 0.1992, 0.1568],
        [0.8042, 0.1464, 0.1125],
        [0.5750, 0.3306, 0.1425],
        [0.5981, 0.2090, 0.1507],
        [0.6315, 0.4538, 0.1122],
        [0.7532, 0.1984, 0.1276],
        [0.9192, 0.1506, 0.1333],
        [0.9009, 0.2255, 0.1081],
        [0.8842, 0.2517, 0.1637],
        [0.5747, 0.2227, 0.1921],
        [0.3461, 0.3427, 0.1909],
        [0.8736, 0.1625, 0.1305]], grad_fn=<TopkBackward0>)

In [None]:
## import matplotlib.pyplot as plt

# Build a fresh backward model under its own name.  Rebinding `backward_model`
# here would throw away the trained instance, and any cell re-run afterwards
# would silently use untrained weights.
untrained_backward_model = BackwardModel()
# `target` comes from an earlier cell; no sign gating is applied here.
s1, s2, s3 = untrained_backward_model(target)

i1, i2 = 7, 2  # sample index, channel index
X = s1.detach().cpu()
X = X / X.norm()  # scale by the norm of the whole batch tensor
X = X.numpy()
print(X[i1,i2].min(), X[i1,i2].max())
plt.imshow(X[i1,i2])

In [None]:
# Display one channel of the first-stage forward activation `a1` (taken from
# an earlier cell) after scaling the whole tensor by its norm.
i1, i2 = 0, 6  # sample index, channel index
a1_cpu = a1.detach().cpu()
X = (a1_cpu / a1_cpu.norm()).numpy()
feature_map = X[i1, i2]
print(feature_map.min(), feature_map.max())
plt.imshow(feature_map)

In [None]:
# Inspect the shape of `s2`, which is left in the kernel by an earlier cell.
s2.shape