In [None]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torchvision import datasets, transforms
from torch.nn.utils import clip_grad_norm_
import matplotlib.pyplot as plt
import copy
import numpy as np

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Seed torch's global RNG first: the weight initialisation and the shuffled
# training loader created further down both draw from it.
torch.manual_seed(2024)
# torch.manual_seed(2023)
# NOTE(review): HIDDEN_DIM is not read by the code visible in this cell; the
# models are constructed with an explicit hidden_dim=128.
HIDDEN_DIM = 256
USE_SIGMA = 1
device = torch.device("cpu")

# ---------------------------------------------------------------------------
# MNIST data pipeline
# ---------------------------------------------------------------------------
# Images are converted to tensors and then normalised with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Training batches are reshuffled on every pass; the test order is fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)

def gradient_centralization(model=None):
    """Centre weight gradients in place so each output unit's gradient has zero mean.

    For every parameter whose name does not contain ``"bias"`` and that
    currently holds a gradient:

    * 2-D parameters have the mean over dim 1 subtracted from each row;
    * 4-D parameters have the mean over dims 1, 2, 3 subtracted from each
      slice along dim 0.

    Parameters of any other rank are left untouched.

    Args:
        model: module whose parameter gradients are centred. When omitted, the
            module-level ``forward_model`` is used, so existing no-argument
            calls behave exactly as before.
    """
    if model is None:
        # Looked up at call time, exactly like the original global reference.
        model = forward_model
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "bias" in name or param.grad is None:
                continue
            if param.dim() == 2:
                param.grad -= param.grad.mean(dim=1, keepdim=True)
            elif param.dim() == 4:
                param.grad -= param.grad.mean(dim=[1, 2, 3], keepdim=True)

class ForwardModel(nn.Module):
    """Small CNN for 28x28 single-channel images with gradient-isolated layers.

    Each layer is fed a *detached* copy of the previous activation, so a loss
    attached to one layer's output only produces gradients for that layer's
    own parameters. ``forward`` returns every intermediate activation so that
    a separate loss can be attached to each of them.

    Args:
        conv_dim: number of channels produced by both convolutions.
        hidden_dim: width of the hidden fully connected layer.
        use_sigma: stored on the instance as ``self.use_sigma``; it is not
            read anywhere inside this class.
    """

    def __init__(self, conv_dim=32, hidden_dim=128, use_sigma=True):
        super().__init__()
        self.use_sigma = use_sigma
        self.hidden_dim = hidden_dim
        self.conv_dim = conv_dim

        # Convolutional layers (3x3, padding 1 keeps the spatial size; each is
        # followed by a 2x2 max-pool in forward: 28 -> 14 -> 7).
        self.conv1 = nn.Conv2d(1, conv_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(conv_dim, conv_dim, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(2)

        # Linear layers
        self.fc1 = nn.Linear(conv_dim * 7 * 7, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

        # Parameter-free layer norms: LN2 is applied to the input image,
        # LN1 to the first pooled feature map.
        self.LN1 = nn.LayerNorm((conv_dim, 14, 14), elementwise_affine=False)
        self.LN2 = nn.LayerNorm((1, 28, 28), elementwise_affine=False)

        # The classifier starts from all-zero weights, and every bias starts
        # at zero. nn.init works under no_grad instead of going through .data.
        nn.init.zeros_(self.fc2.weight)
        for layer in (self.fc2, self.fc1, self.conv2, self.conv1):
            nn.init.zeros_(layer.bias)

        self.act = F.elu

        # Filled in by forward(); None until the first forward pass.
        self.penultimate_feature = None

    def forward(self, x):
        """Run the network and return all layer activations.

        Args:
            x: input that can be reshaped to ``(-1, 1, 28, 28)``.

        Returns:
            Tuple ``(a1, a2, a3, a4)``:
            ``a1`` - first conv block output, ``(N, conv_dim, 14, 14)``;
            ``a2`` - second conv block output flattened to ``(N, conv_dim*7*7)``;
            ``a3`` - hidden fully connected activation, ``(N, hidden_dim)``;
            ``a4`` - class logits, ``(N, 10)``.
        """
        x = x.reshape(-1, 1, 28, 28)
        x = self.LN2(x)

        # Every layer consumes a detached input: gradients stop at layer borders.
        a1 = self.act(self.conv1(x.detach()))
        a1 = self.max_pool(a1)
        a1 = self.LN1(a1)

        a2 = self.act(self.conv2(a1.detach()))
        a2 = self.max_pool(a2)
        # reshape (rather than view) also handles non-contiguous activations.
        a2 = a2.reshape(-1, self.conv_dim * 7 * 7)

        a3 = self.act(self.fc1(a2.detach()))

        a4 = self.fc2(a3.detach())
        # Cached so forward_logits() can re-apply the classifier later.
        self.penultimate_feature = a3.detach()

        return a1, a2, a3, a4

    def forward_logits(self):
        """Re-apply ``fc2`` to the hidden features cached by the last ``forward``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.penultimate_feature is None:
            raise RuntimeError(
                "forward_logits() requires a prior call to forward() to cache the penultimate features"
            )
        return self.fc2(self.penultimate_feature)

class BackwardModel(nn.Module):
    """Turns class labels into one target signal per layer of the forward net.

    Labels are one-hot encoded (10 classes) and pushed through
    linear -> linear -> upsample + 5x5 conv, each followed by batch norm and
    the activation. As in the forward network, every stage receives a detached
    input, so gradients do not cross stage boundaries.

    Args:
        conv_dim: channel count of the spatial targets.
        hidden_dim: width of the first (vector) target.
    """

    def __init__(self, conv_dim=32,  hidden_dim=128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.conv_dim = conv_dim

        # Width multiplier for the internal layers; fixed to 1 here.
        alpha = 1
        self.alpha = alpha
        wide_hidden = hidden_dim * alpha
        wide_channels = self.conv_dim * alpha
        wide_flat = self.conv_dim * 7 * 7 * alpha

        # Layers are created in the same order as before so that parameter
        # initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(10, wide_hidden, bias=False)
        self.fc2 = nn.Linear(wide_hidden, wide_flat, bias=False)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(wide_channels, wide_channels, kernel_size=5, padding=2, bias=False)
        self.act = F.elu

        self.BN1 = torch.nn.BatchNorm2d(wide_channels)
        self.BN2 = torch.nn.BatchNorm1d(wide_flat)
        self.BN3 = torch.nn.BatchNorm1d(wide_hidden)

    def forward(self, t, use_act_derivative=False):
        """Compute the per-layer targets for a batch of integer labels.

        Args:
            t: integer class labels accepted by ``F.one_hot`` with 10 classes.
            use_act_derivative: accepted for interface compatibility; not read.

        Returns:
            Tuple ``(s1, s2, s3)``: the spatial target after upsampling and
            convolution, the flat target of width ``conv_dim*7*7``, and the
            hidden target of width ``hidden_dim``.
        """
        onehot = F.one_hot(t, num_classes=10).float().to(t.device)

        s3 = self.act(self.BN3(self.fc1(onehot.detach())))
        s2 = self.act(self.BN2(self.fc2(s3.detach())))

        # Fold the flat signal into 7x7 maps and upsample before the conv stage.
        maps = self.upsample1(s2.view(-1, self.conv_dim * self.alpha, 7, 7))
        s1 = self.act(self.BN1(self.conv1(maps.detach())))

        # Trim any extra width introduced by alpha.
        return (
            s1[:, :self.conv_dim],
            s2[:, :self.conv_dim * 7 * 7],
            s3[:, :self.hidden_dim],
        )

def normal(x):
    """Scale every sample to unit standard deviation (the mean is not removed).

    The standard deviation is taken over all dimensions except dim 0, so each
    sample along dim 0 is scaled independently. For 2-D and 4-D inputs this is
    the same computation as before (dim 1, and dims 1-3 respectively); other
    ranks >= 2 are now handled the same way instead of silently returning
    ``None``.

    Args:
        x: tensor with the sample dimension first and at least one more dim.

    Returns:
        Tensor of the same shape as ``x``.

    Raises:
        ValueError: if ``x`` has fewer than two dimensions.
    """
    if x.dim() < 2:
        raise ValueError(f"normal() expects a tensor with at least 2 dims, got {x.dim()}")
    sample_dims = tuple(range(1, x.dim()))
    # A damped variant, x / (1 + std), was tried here and is left disabled.
    return x / x.std(dim=sample_dims, keepdim=True)

def standard(x):
    """Standardise every sample to zero mean and unit standard deviation.

    Mean and standard deviation are taken over all dimensions except dim 0, so
    each sample along dim 0 is standardised independently. For 2-D and 4-D
    inputs this is the same computation as before; other ranks >= 2 are now
    handled the same way instead of silently returning ``None``.

    Args:
        x: tensor with the sample dimension first and at least one more dim.

    Returns:
        Tensor of the same shape as ``x``.

    Raises:
        ValueError: if ``x`` has fewer than two dimensions.
    """
    if x.dim() < 2:
        raise ValueError(f"standard() expects a tensor with at least 2 dims, got {x.dim()}")
    sample_dims = tuple(range(1, x.dim()))
    centered = x - x.mean(dim=sample_dims, keepdim=True)
    return centered / centered.std(dim=sample_dims, keepdim=True)

def single_loss(a, s):
    """Mean squared error between ``a`` and ``s`` after both pass through ``normal``.

    Args:
        a: activation tensor.
        s: target tensor compared element-wise against ``a``.

    Returns:
        Scalar loss tensor.
    """
    return F.mse_loss(normal(a), normal(s))

def sigma_loss(a1, a2, a3, a4, s1, s2, s3, t):
    """Total training loss: three activation/target matching terms plus the classifier loss.

    Args:
        a1, a2, a3: intermediate activations, matched against ``s1``, ``s2``, ``s3``
            with ``single_loss``.
        a4: class logits, scored against ``t`` with the module-level ``criteria``.
        s1, s2, s3: target signals paired with ``a1``, ``a2``, ``a3``.
        t: class labels.

    Returns:
        Tuple ``(loss, l1, l2, l3, l4)``: the summed loss tensor followed by the
        four individual terms as Python floats.
    """
    matching_terms = [
        single_loss(activation, signal)
        for activation, signal in zip((a1, a2, a3), (s1, s2, s3))
    ]
    class_term = criteria(a4, t)
    # Summed left to right in the same order as the individual terms.
    total = matching_terms[0] + matching_terms[1] + matching_terms[2] + class_term
    return (total, *[term.item() for term in matching_terms], class_term.item())

def gradient_centralization_B(model=None):
    """Centre weight gradients in place; defaults to the backward model.

    For every parameter whose name does not contain ``"bias"`` and that
    currently holds a gradient, 2-D parameters have the mean over dim 1
    subtracted and 4-D parameters the mean over dims 1, 2, 3. Parameters of
    any other rank, and parameters without a gradient, are skipped.

    Args:
        model: module whose parameter gradients are centred. When omitted, the
            module-level ``backward_model`` is used, so existing no-argument
            calls behave exactly as before.
    """
    if model is None:
        # Looked up at call time, exactly like the original global reference.
        model = backward_model
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "bias" in name or param.grad is None:
                continue
            if param.dim() == 2:
                param.grad -= param.grad.mean(dim=1, keepdim=True)
            elif param.dim() == 4:
                param.grad -= param.grad.mean(dim=[1, 2, 3], keepdim=True)

# Build the two networks. The forward model is created first, then the backward
# model, so parameter initialisation draws from the RNG in the same order.
# nn.Module.to() returns the module itself, so it can be chained.
forward_model = ForwardModel(use_sigma=USE_SIGMA, conv_dim=64, hidden_dim=128).to(device)
backward_model = BackwardModel(conv_dim=64, hidden_dim=128).to(device)

# One RMSprop optimizer per network, both with the same learning rate.
forward_optimizer = optim.RMSprop(forward_model.parameters(), lr=0.001)
backward_optimizer = optim.RMSprop(backward_model.parameters(), lr=0.001)

# Classification loss applied to the forward model's logits.
criteria = nn.CrossEntropyLoss()
    
def small_loss(s, num_classes=10):
    """Penalise non-orthogonality between the per-class rows of ``s``.

    ``s`` is reshaped to ``(num_classes, -1)``, every row is scaled to unit L2
    norm, and the Gram matrix of the rows is compared with the identity using
    mean squared error. The loss is zero when the rows are mutually orthogonal.

    Args:
        s: tensor whose leading dimension indexes the classes.
        num_classes: number of rows after reshaping; defaults to 10.

    Returns:
        Scalar loss tensor on the same device and with the same dtype as ``s``.
    """
    rows = s.reshape(num_classes, -1)
    rows = rows / rows.norm(dim=1, keepdim=True)
    gram = rows @ rows.T
    # Build the target next to the data; a default CPU/float32 identity would
    # not match inputs living on another device or using another dtype.
    target = torch.eye(num_classes, device=gram.device, dtype=gram.dtype)
    return F.mse_loss(gram, target)
    
# Compute the per-class targets once, for labels 0..9, without recording a graph.
# Row k of t1 / t2 / t3 is the target for class k, so the training loop can pick
# targets with t1[target] etc. Because this runs under no_grad, losses built on
# these tensors do not send gradients into backward_model.
# The label tensor is created on `device` so it matches where the model lives.
with torch.no_grad():
    t1, t2, t3 = backward_model(torch.arange(10, device=device), use_act_derivative=False)
    
import torch
import numpy as np
    
# Each epoch has three stages:
#   1. layer-wise training of the forward model against the precomputed
#      per-class targets t1/t2/t3, plus cross-entropy on the logits;
#   2. a cross-entropy-only pass over the logits;
#   3. diagnostics on the targets and test-set accuracy.
for epoch in range(30):

    # Stage 1: layer-wise training against the fixed targets.
    for batch_idx, (data, target) in enumerate(train_loader):

        # Only batches 0..100 of each pass are used. Leave the loop instead of
        # skipping: the same batches are trained on, but the rest of the epoch
        # is no longer loaded and transformed just to be discarded.
        if batch_idx > 100: break
        data, target = data.to(device), target.to(device)
        a1, a2, a3, a4 = forward_model(data)
        # t1, t2, t3 = backward_model(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long(), use_act_derivative=False)
        # Pick, per sample, the target row that belongs to its label.
        s1, s2, s3 = t1[target], t2[target], t3[target]
        loss, l1, l2, l3, l4 = sigma_loss(a1, a2, a3, a4, s1, s2, s3, target)
        forward_optimizer.zero_grad()
        backward_optimizer.zero_grad()
        loss.backward()
        gradient_centralization()
        gradient_centralization_B()
        forward_optimizer.step()
        # NOTE(review): while t1/t2/t3 come from a no_grad precomputation, the
        # backward model receives no gradients here, so this step (and the
        # centralization call above) has nothing to apply.
        backward_optimizer.step()

    # Stage 2: cross-entropy on the logits only. The classifier input is
    # detached inside the forward model, so this loss reaches fc2 alone.
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: break
        a1, a2, a3, a4 = forward_model(data.to(device))
        loss = criteria(a4, target.to(device))
        forward_optimizer.zero_grad()
        loss.backward()
        forward_optimizer.step()

    # Stage 3a: scale of the targets and singular values of t2.
    with torch.no_grad():
        print(t1.norm(), t1.std(), t2.norm(), t2.std())
        print(torch.svd(t2)[1].detach())

    # Stage 3b: accuracy on the full test set.
    forward_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            _, _, _, outputs = forward_model(data)
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / total}%')
    forward_model.train()

    # assert 0 == 1
    # if epoch in [19]:
    #     torch.save(forward_model.state_dict(), "./saved_models/v4-20EP-F.pt") #Epoch: 46, Test Accuracy: 98.05%; Epoch: 47, Test Accuracy: 98.16%; Epoch: 48, Test Accuracy: 98.01%; Epoch: 49, Test Accuracy: 98.13%
    #     torch.save(backward_model.state_dict(),"./saved_models/v4-20EP-B.pt") #Epoch: 46, Test Accuracy: 98.05%; Epoch: 47, Test Accuracy: 98.16%; Epoch: 48, Test Accuracy: 98.01%; Epoch: 49, Test Accuracy: 98.13%

# conv_dim 64; s1, s2, s3 use BN before act.
# Epoch: 0, Test Accuracy: 95.15%
# Epoch: 1, Test Accuracy: 96.6%
# Epoch: 2, Test Accuracy: 97.13%
# Epoch: 3, Test Accuracy: 97.89%
# Epoch: 4, Test Accuracy: 97.6%
# Epoch: 5, Test Accuracy: 98.06%
# Epoch: 6, Test Accuracy: 98.1%
# Epoch: 7, Test Accuracy: 98.11%
# Epoch: 8, Test Accuracy: 98.54%
# Epoch: 9, Test Accuracy: 98.47%
# Epoch: 10, Test Accuracy: 98.44%
# Epoch: 11, Test Accuracy: 98.41%
# Epoch: 12, Test Accuracy: 98.59%
# Epoch: 13, Test Accuracy: 98.87%
# Epoch: 14, Test Accuracy: 98.59%
# Epoch: 15, Test Accuracy: 98.65%
# Epoch: 16, Test Accuracy: 98.52%
# Epoch: 17, Test Accuracy: 98.58%
# Epoch: 18, Test Accuracy: 98.76%
# Epoch: 19, Test Accuracy: 98.67%

tensor(283.8713) tensor(0.7854) tensor(142.7051) tensor(0.7888)
tensor([54.7917, 51.2640, 49.5042, 48.6735, 46.3613, 43.9471, 42.0677, 41.4580,
        38.5330, 29.3407])
Epoch: 0, Test Accuracy: 95.15%


In [None]:
Epoch: 0, Test Accuracy: 92.02%
Epoch: 1, Test Accuracy: 92.04%
Epoch: 2, Test Accuracy: 91.61%
Epoch: 3, Test Accuracy: 90.56%
Epoch: 4, Test Accuracy: 90.93%
Epoch: 5, Test Accuracy: 90.65%
Epoch: 6, Test Accuracy: 90.6%
Epoch: 7, Test Accuracy: 90.62%
Epoch: 8, Test Accuracy: 90.27%
Epoch: 9, Test Accuracy: 90.04%
Epoch: 10, Test Accuracy: 89.1%

In [32]:
    # Diagnostic cell: print the norm and standard deviation of the backward-model
    # targets t1 and t2, then the singular values of t2 (index 1 of torch.svd's
    # result). Relies on t1/t2 already existing in the kernel from the training cell.
    with torch.no_grad():
        print(t1.norm(), t1.std(), t2.norm(), t2.std())
        print(torch.svd(t2)[1].detach())


tensor(230.1083) tensor(0.6460) tensor(54.8109) tensor(0.2087)
tensor([41.3883, 15.7216, 14.2806, 13.1660, 12.3549, 11.4565, 11.1876, 10.4063,
         9.1648,  8.0901])


In [25]:
# Norm and standard deviation of the backward-model targets t1 and t2.
# NOTE(review): execution counts in this notebook are out of order, so the values
# shown depend on which run last populated t1/t2 in the kernel.
t1.norm(), t1.std(), t2.norm(), t2.std()

(tensor(283.8713), tensor(0.7854), tensor(142.7051), tensor(0.7888))

In [31]:
# Singular values of t2 (index 1 of torch.svd's result); t2 holds one row per class.
torch.svd(t2)[1].detach()

tensor([41.3883, 15.7216, 14.2806, 13.1660, 12.3549, 11.4565, 11.1876, 10.4063,
         9.1648,  8.0901])