In [15]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torchvision import datasets, transforms
from torch.nn.utils import clip_grad_norm_
import matplotlib.pyplot as plt
import copy
import numpy as np

torch.manual_seed(2024)
# torch.manual_seed(2023)
HIDDEN_DIM = 256
USE_SIGMA = 1
device = torch.device("cpu")

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

def gradient_centralization():
    with torch.no_grad():
        for p1, p2 in forward_model.named_parameters():
            if "bias" in p1 or p2.grad is None: continue
            if len(p2.shape) == 2: p2.grad -= p2.grad.mean(dim=1,keepdim=True)
            elif len(p2.shape) == 4: p2.grad -= p2.grad.mean(dim=[1,2,3],keepdim=True) 

class ForwardModel(nn.Module):
    def __init__(self, conv_dim=32, hidden_dim=128, use_sigma=True):
        super(ForwardModel, self).__init__()
        self.use_sigma = use_sigma
        self.hidden_dim = hidden_dim
        self.conv_dim = conv_dim

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, conv_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(conv_dim, conv_dim, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(2)

        # Linear layers
        self.fc1 = nn.Linear(conv_dim * 7 * 7, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

        self.LN1 = nn.LayerNorm((conv_dim, 14, 14), elementwise_affine=False)
        self.LN2 = nn.LayerNorm((1, 28, 28), elementwise_affine=False)

        self.fc2.weight.data.zero_()
        self.fc2.bias.data.zero_()
        self.fc1.bias.data.zero_()
        self.conv2.bias.data.zero_()
        self.conv1.bias.data.zero_()
        
        self.act = F.elu

    def forward(self, x):
        
        x = x.view(-1, 1, 28, 28)
        x = self.LN2(x)

        a1 = self.act(self.conv1(x.detach()))
        a1 = self.max_pool(a1)
        
        a2 = self.act(self.conv2(a1.detach()))
        a2 = self.max_pool(a2)
        a2 = a2.view(-1, self.conv_dim * 7 * 7)

        a3 = self.act(self.fc1(a2.detach()))

        a4 = self.fc2(a3.detach())
        self.penultimate_feature = a3.detach()

        return a1, a2, a3, a4

    def forward_logits(self):
        return self.fc2(self.penultimate_feature)

class BackwardModel(nn.Module):
    def __init__(self, conv_dim=32,  hidden_dim=128):
        super(BackwardModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.conv_dim = conv_dim
        self.use_bias = False
        
        alpha = 1
        self.alpha = alpha
        self.fc1 = nn.Linear(10, hidden_dim * alpha, bias=False)
        self.fc2 = nn.Linear(hidden_dim * alpha, self.conv_dim * 7 * 7 * alpha, bias=False)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(self.conv_dim * alpha, self.conv_dim * alpha, kernel_size=5, padding=2, bias=False)
        self.act = F.elu
        
        self.BN1 = torch.nn.BatchNorm2d(self.conv_dim * alpha)
        self.BN2 = torch.nn.BatchNorm1d(self.conv_dim * 7 * 7 * alpha)
        self.BN3 = torch.nn.BatchNorm1d(hidden_dim * alpha)

    def forward(self, t, use_act_derivative=False):
        device = t.device
        t = F.one_hot(t, num_classes=10).float().to(device)
        s3 = self.act(self.BN3(self.fc1(t.detach()))) 
        s2 = self.act(self.BN2(self.fc2(s3.detach()))  )         
        s2_ = s2.view(-1, self.conv_dim * self.alpha, 7, 7)
        s2_ = self.upsample1(s2_)        
        s1 = self.act(self.BN1(self.conv1(s2_.detach())))
        return s1[:, :self.conv_dim], s2[:,:self.conv_dim*7*7], s3[:, :self.hidden_dim]

def normalize_along_axis(matrix, axis=1):
    norm = torch.norm(matrix, dim=axis, keepdim=True)
    return matrix / (norm + 1e-8)

def compute_loss(A, B):
    A = A.reshape(len(A), -1)
    B = B.reshape(len(B), -1)
    m = A.shape[1]
    A_norm = normalize_along_axis(A, axis=1)
    B_norm = normalize_along_axis(B, axis=1)
    C = torch.matmul(A_norm, B_norm.T)
    target_ = F.one_hot(target, num_classes=10).float().to(device)
    identity = target_@target_.T * 1.1 - 0.1
    loss = F.mse_loss(C, identity)
    return loss

def sigma_loss(a1, a2, a3, a4, s1, s2, s3, t):
    loss1 = compute_loss(a1,s1)
    loss2 = compute_loss(a2,s2)
    loss3 = compute_loss(a3,s3)
    loss4 = criteria(a4, t)
    loss = loss1 + loss2 + loss3 + loss4
    return loss, loss1.item(), loss2.item(), loss3.item(), loss4.item()

# Initialize the models
forward_model = ForwardModel(use_sigma=USE_SIGMA, conv_dim=64, hidden_dim=128)
backward_model = BackwardModel(conv_dim=64, hidden_dim=128)

forward_model.to(device)
backward_model.to(device)

# Define the optimizers
forward_optimizer = optim.RMSprop(forward_model.parameters(), lr=0.002)
criteria = nn.CrossEntropyLoss()
    
with torch.no_grad():
    t1, t2, t3 = backward_model(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long(), use_act_derivative=False)
    
import torch
import numpy as np
    
# for epoch in range(30):
for epoch in tqdm(range(30)):
    
    for batch_idx, (data, target) in enumerate(train_loader):
        
        if batch_idx > 100: continue
        a1, a2, a3, a4 = forward_model(data.to(device))
        s1, s2, s3 = t1[target], t2[target], t3[target]
        loss, l1, l2, l3, l4 = sigma_loss(a1, a2, a3, a4, s1, s2, s3, target.to(device))
        forward_optimizer.zero_grad(), loss.backward(), gradient_centralization(), forward_optimizer.step()
        
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        a1, a2, a3, a4 = forward_model(data.to(device))
        loss = criteria(a4, target.to(device))
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()
    
    forward_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            _, _, _, outputs = forward_model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / total}%')
    forward_model.train()
    
    # assert 0 == 1
    # if epoch in [19]:
    #     torch.save(forward_model.state_dict(), "./saved_models/v4-20EP-F.pt") #Epoch: 46, Test Accuracy: 98.05%; Epoch: 47, Test Accuracy: 98.16%; Epoch: 48, Test Accuracy: 98.01%; Epoch: 49, Test Accuracy: 98.13%
    #     torch.save(backward_model.state_dict(),"./saved_models/v4-20EP-B.pt") #Epoch: 46, Test Accuracy: 98.05%; Epoch: 47, Test Accuracy: 98.16%; Epoch: 48, Test Accuracy: 98.01%; Epoch: 49, Test Accuracy: 98.13%

# Norm (no avg) + SCL
# Epoch: 0, Test Accuracy: 95.37%
# Epoch: 1, Test Accuracy: 96.76%
# Epoch: 2, Test Accuracy: 97.34%
# Epoch: 3, Test Accuracy: 97.91%
# Epoch: 4, Test Accuracy: 97.87%
# Epoch: 5, Test Accuracy: 98.12%
# Epoch: 6, Test Accuracy: 98.12%
# Epoch: 7, Test Accuracy: 98.27%
# Epoch: 8, Test Accuracy: 98.51%
# Epoch: 9, Test Accuracy: 98.58%
# Epoch: 10, Test Accuracy: 98.66%
# Epoch: 11, Test Accuracy: 98.53%
# Epoch: 12, Test Accuracy: 98.64%
# Epoch: 13, Test Accuracy: 98.75%
# Epoch: 14, Test Accuracy: 98.66%
# Epoch: 15, Test Accuracy: 98.58%
# Epoch: 16, Test Accuracy: 98.77%
# Epoch: 17, Test Accuracy: 98.80%
# Epoch: 18, Test Accuracy: 98.70%
# Epoch: 19, Test Accuracy: 98.80%
# Epoch: 20, Test Accuracy: 98.97%

  3%|█▍                                          | 1/30 [00:09<04:48,  9.94s/it]

Epoch: 0, Test Accuracy: 95.67%


  7%|██▉                                         | 2/30 [00:19<04:31,  9.69s/it]

Epoch: 1, Test Accuracy: 97.03%


 10%|████▍                                       | 3/30 [00:29<04:23,  9.74s/it]

Epoch: 2, Test Accuracy: 97.66%


 13%|█████▊                                      | 4/30 [00:38<04:12,  9.70s/it]

Epoch: 3, Test Accuracy: 98.19%


 17%|███████▎                                    | 5/30 [00:48<04:02,  9.69s/it]

Epoch: 4, Test Accuracy: 98.18%


 20%|████████▊                                   | 6/30 [00:58<03:52,  9.67s/it]

Epoch: 5, Test Accuracy: 98.15%


 23%|██████████▎                                 | 7/30 [01:07<03:41,  9.62s/it]

Epoch: 6, Test Accuracy: 98.4%


 27%|███████████▋                                | 8/30 [01:17<03:31,  9.60s/it]

Epoch: 7, Test Accuracy: 98.39%


 30%|█████████████▏                              | 9/30 [01:27<03:23,  9.70s/it]

Epoch: 8, Test Accuracy: 98.7%


 33%|██████████████▎                            | 10/30 [01:37<03:14,  9.75s/it]

Epoch: 9, Test Accuracy: 98.58%


 37%|███████████████▊                           | 11/30 [01:46<03:05,  9.76s/it]

Epoch: 10, Test Accuracy: 98.79%


 40%|█████████████████▏                         | 12/30 [01:56<02:54,  9.68s/it]

Epoch: 11, Test Accuracy: 98.7%


 43%|██████████████████▋                        | 13/30 [02:05<02:44,  9.67s/it]

Epoch: 12, Test Accuracy: 98.6%


 47%|████████████████████                       | 14/30 [02:15<02:34,  9.68s/it]

Epoch: 13, Test Accuracy: 98.76%


 50%|█████████████████████▌                     | 15/30 [02:25<02:25,  9.68s/it]

Epoch: 14, Test Accuracy: 98.8%


 53%|██████████████████████▉                    | 16/30 [02:35<02:15,  9.67s/it]

Epoch: 15, Test Accuracy: 98.54%


 57%|████████████████████████▎                  | 17/30 [02:44<02:05,  9.63s/it]

Epoch: 16, Test Accuracy: 98.8%


 60%|█████████████████████████▊                 | 18/30 [02:54<01:55,  9.59s/it]

Epoch: 17, Test Accuracy: 98.71%


 63%|███████████████████████████▏               | 19/30 [03:03<01:46,  9.67s/it]

Epoch: 18, Test Accuracy: 98.87%


 67%|████████████████████████████▋              | 20/30 [03:13<01:37,  9.73s/it]

Epoch: 19, Test Accuracy: 98.71%


 70%|██████████████████████████████             | 21/30 [03:23<01:28,  9.82s/it]

Epoch: 20, Test Accuracy: 98.83%


 73%|███████████████████████████████▌           | 22/30 [03:33<01:18,  9.76s/it]

Epoch: 21, Test Accuracy: 98.87%


 77%|████████████████████████████████▉          | 23/30 [03:43<01:08,  9.75s/it]

Epoch: 22, Test Accuracy: 98.78%


 80%|██████████████████████████████████▍        | 24/30 [03:52<00:58,  9.70s/it]

Epoch: 23, Test Accuracy: 98.72%


 83%|███████████████████████████████████▊       | 25/30 [04:02<00:48,  9.62s/it]

Epoch: 24, Test Accuracy: 98.82%


 87%|█████████████████████████████████████▎     | 26/30 [04:11<00:38,  9.60s/it]

Epoch: 25, Test Accuracy: 98.69%


 90%|██████████████████████████████████████▋    | 27/30 [04:21<00:28,  9.63s/it]

Epoch: 26, Test Accuracy: 98.62%


 93%|████████████████████████████████████████▏  | 28/30 [04:30<00:19,  9.59s/it]

Epoch: 27, Test Accuracy: 98.77%


 97%|█████████████████████████████████████████▌ | 29/30 [04:40<00:09,  9.54s/it]

Epoch: 28, Test Accuracy: 98.76%


100%|███████████████████████████████████████████| 30/30 [04:49<00:00,  9.66s/it]

Epoch: 29, Test Accuracy: 98.77%



