In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torchvision import datasets, transforms
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(2024)
HIDDEN_DIM = 256
USE_SIGMA = 1
device = torch.device("cpu")

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

class ForwardModel(nn.Module):
    def __init__(self, hidden_dim=128, use_sigma=True):
        super(ForwardModel, self).__init__()
        self.use_sigma = use_sigma
        self.hidden_dim = hidden_dim

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(2)

        # Linear layers
        self.fc1 = nn.Linear(32 * 7 * 7, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

        self.LN1 = nn.LayerNorm((32, 14, 14), elementwise_affine=False)
        self.LN2 = nn.LayerNorm((1, 28, 28), elementwise_affine=False)

        self.fc2.weight.data.zero_()
        self.fc2.bias.data.zero_()
        self.fc1.bias.data.zero_()
        self.conv2.bias.data.zero_()
        self.conv1.bias.data.zero_()

    def forward(self, x):
        
        x = x.view(-1, 1, 28, 28)
        x = self.LN2(x)

        a1 = F.elu(self.conv1(x.detach()))
        a1 = self.max_pool(a1)
        a1 = self.LN1(a1)

        a2 = F.elu(self.conv2(a1.detach()))
        a2 = self.max_pool(a2)
        a2 = a2.view(-1, 32 * 7 * 7)

        a3 = F.elu(self.fc1(a2.detach()))

        a4 = self.fc2(a3.detach())
        self.penultimate_feature = a3.detach()

        return a1, a2, a3, a4

    def forward_logits(self):
        return self.fc2(self.penultimate_feature)

class BackwardModel(nn.Module):
    def __init__(self, hidden_dim=128):
        super(BackwardModel, self).__init__()
        self.hidden_dim = hidden_dim
        
        alpha = 1
        self.alpha = alpha

        self.fc1 = nn.Linear(10, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, 32 * 7 * 7, bias=False)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(32, 32, kernel_size=5, padding=2, bias=False)
        self.LN1 = nn.LayerNorm((32, 14, 14), elementwise_affine=False)

    def forward(self, t, use_act_derivative=False):
        device = t.device
        t = F.one_hot(t, num_classes=10).float().to(device)
        if use_act_derivative:
            s3 = F.elu(self.fc1(t.detach()) * a3.sign().detach())
            s2 = F.elu(self.fc2(s3.detach()) * a2.sign().detach())
            s2_ = s2.view(-1, 32, 7, 7)
            s2_ = self.upsample1(s2_)
            s1 = F.elu(self.conv1(s2_.detach()) * a1.sign().detach())
        else:
            s3 = F.elu(self.fc1(t.detach()))
            # s3 = s3 - s3.mean()
            s2 = F.elu(self.fc2(s3.detach()))
            # s2 = s2 - s2.mean()
            s2_ = s2.view(-1, 32, 7, 7)
            s2_ = self.upsample1(s2_)
            s1 = F.elu(self.conv1(s2_.detach()))
        return s1[:, :32], s2[:,:32*7*7], s3[:, :self.hidden_dim]
    
class BackwardModel(nn.Module):
    def __init__(self, hidden_dim=128):
        super(BackwardModel, self).__init__()
        self.hidden_dim = hidden_dim
        
        alpha = 2
        self.alpha = alpha

        self.fc1_random = nn.Linear(10, 10*alpha, bias=False)
        self.fc1        = nn.Linear(10*alpha, hidden_dim, bias=False)
        self.fc2_random = nn.Linear(hidden_dim, hidden_dim * alpha, bias=False)
        self.fc2        = nn.Linear(hidden_dim* alpha, 32 * 7 * 7, bias=False)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1_random = nn.Conv2d(32, 32 * alpha, kernel_size=1, padding=0, bias=False)
        self.conv1        = nn.Conv2d(32 * alpha, 32, kernel_size=5, padding=2, bias=False)

    def forward(self, t, use_act_derivative=False):
        device = t.device
        t = F.one_hot(t, num_classes=10).float().to(device)
        
        if use_act_derivative:
            
            assert 0 == 1, "Not implemented errors"
            s3 = F.elu(self.fc1(t.detach()) * a3.sign().detach())
            s2 = F.elu(self.fc2(s3.detach()) * a2.sign().detach())
            s2_ = s2.view(-1, 32, 7, 7)
            s2_ = self.upsample1(s2_)
            s1 = F.elu(self.conv1(s2_.detach()) * a1.sign().detach())
            
        else:
            
            s3 = F.elu(self.fc1_random(t.detach()))
            s3 = F.elu(self.fc1(s3.detach()))
            
            s2 = F.elu(self.fc2_random(s3.detach()))
            s2 = F.elu(self.fc2(s2.detach()))
            
            s2_ = s2.view(-1, 32, 7, 7)

            s2_ = self.upsample1(s2_)
            
            s1 = F.elu(self.conv1_random(s2_.detach()))
            
            s1 = F.elu(self.conv1(s1.detach()))
            
        return s1[:, :32], s2[:,:32*7*7], s3[:, :self.hidden_dim]
    
        #     s3 = F.elu(self.fc1(t.detach()))
        #     s2 = F.elu(self.fc2(s3.detach()))
        #     s2_ = s2.view(-1, 32, 7, 7)
        #     s2_ = self.upsample1(s2_)
        #     s1 = F.elu(self.conv1(s2_.detach()))
        # return s1[:, :32], s2[:,:32*7*7], s3[:, :self.hidden_dim]

def normal(x): 
    if len(x.shape) == 2:
        return x / x.std(dim=1, keepdim=True)
    if len(x.shape) == 4:
        return x / x.std(dim=[1,2,3], keepdim=True)

def sigma_loss(a1, a2, a3, a4, s1, s2, s3, t):

    loss1 = F.mse_loss(normal(a1), normal(s1))
    loss2 = F.mse_loss(normal(a2), normal(s2))
    loss3 = F.mse_loss(normal(a3), normal(s3))
    loss4 = criteria(a4, t)
    loss = loss1 + loss2 + loss3 + loss4
    return loss, loss1.item(), loss2.item(), loss3.item(), loss4.item()

# Initialize the models
forward_model = ForwardModel(use_sigma=USE_SIGMA)
backward_model = BackwardModel()

forward_model.to(device)
backward_model.to(device)

# Define the optimizers
# forward_optimizer = optim.SGD(forward_model.parameters(), lr=0.01, weight_decay=0.01, momentum=0.5)
# backward_optimizer = optim.SGD(backward_model.parameters(), lr=0.001)
forward_optimizer = optim.RMSprop(forward_model.parameters(), lr=0.0001)
backward_optimizer = optim.RMSprop(backward_model.parameters(), lr=0.0000005)
criteria = nn.CrossEntropyLoss()

from tqdm import tqdm

for epoch in tqdm(range(30)):
    
    for batch_idx, (data, target) in enumerate(train_loader):
        
        # if batch_idx > 100: continue

        data, target = data.to(device), target.to(device)
        a1, a2, a3, a4 = forward_model(data)
        s1, s2, s3 = backward_model(target, use_act_derivative=False)
        
        loss, l1, l2, l3, l4 = sigma_loss(a1, a2, a3, a4, s1, s2, s3, target)
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad()
        loss.backward()
        # if batch_idx == 50: assert 0 == 1
        clip_grad_norm_(backward_model.parameters(), 0.1), forward_optimizer.step(), backward_optimizer.step()
        
        # Update the backward model more times
        # for _ in range():
        #     s1, s2, s3 = backward_model(target, use_act_derivative=False)
        #     loss, l1, l2, l3, l4 = sigma_loss(a1.detach(), a2.detach(), a3.detach(), a4.detach(), s1, s2, s3, target)
        #     forward_optimizer.zero_grad(), backward_optimizer.zero_grad()
        #     loss.backward()
        #     # clip_grad_norm_(backward_model.parameters(), 0.1), 
            # forward_optimizer.step(), backward_optimizer.step()

    for batch_idx, (data, target) in enumerate(train_loader):
        # if batch_idx > 100: continue
        a1, a2, a3, a4 = forward_model(data.to(device))
        loss = criteria(a4, target.to(device))
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()
        
    # print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss1: {l1*100:.1f}, Loss2: {l2*100:.1f}, Loss3: {l3*100:.1f}, Loss4: {l4:.1f}')
    
    forward_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            _, _, _, outputs = forward_model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / total}%')
    forward_model.train()

  3%|██▎                                                                  | 1/30 [01:50<53:29, 110.67s/it]

Epoch: 0, Test Accuracy: 96.27%


  7%|████▌                                                                | 2/30 [03:47<53:15, 114.11s/it]

Epoch: 1, Test Accuracy: 96.77%


 10%|██████▉                                                              | 3/30 [05:43<51:50, 115.19s/it]

Epoch: 2, Test Accuracy: 96.66%


 13%|█████████▏                                                           | 4/30 [07:41<50:18, 116.11s/it]

Epoch: 3, Test Accuracy: 96.53%


 17%|███████████▌                                                         | 5/30 [09:38<48:36, 116.67s/it]

Epoch: 4, Test Accuracy: 96.22%


 20%|█████████████▊                                                       | 6/30 [11:36<46:47, 116.99s/it]

Epoch: 5, Test Accuracy: 96.22%


 23%|████████████████                                                     | 7/30 [13:34<44:55, 117.20s/it]

Epoch: 6, Test Accuracy: 96.22%


 27%|██████████████████▍                                                  | 8/30 [15:31<42:58, 117.19s/it]

Epoch: 7, Test Accuracy: 96.15%


 30%|████████████████████▋                                                | 9/30 [17:30<41:14, 117.84s/it]

Epoch: 8, Test Accuracy: 96.15%


 33%|██████████████████████▋                                             | 10/30 [19:29<39:26, 118.32s/it]

Epoch: 9, Test Accuracy: 96.09%


 37%|████████████████████████▉                                           | 11/30 [21:28<37:27, 118.30s/it]

Epoch: 10, Test Accuracy: 96.59%


 40%|███████████████████████████▏                                        | 12/30 [23:26<35:27, 118.21s/it]

Epoch: 11, Test Accuracy: 96.39%


 43%|█████████████████████████████▍                                      | 13/30 [25:24<33:29, 118.18s/it]

Epoch: 12, Test Accuracy: 96.48%


 47%|███████████████████████████████▋                                    | 14/30 [27:21<31:27, 118.00s/it]

Epoch: 13, Test Accuracy: 96.63%


 50%|██████████████████████████████████                                  | 15/30 [29:19<29:29, 117.98s/it]

Epoch: 14, Test Accuracy: 96.74%


 53%|████████████████████████████████████▎                               | 16/30 [31:18<27:34, 118.19s/it]

Epoch: 15, Test Accuracy: 95.84%


 57%|██████████████████████████████████████▌                             | 17/30 [33:16<25:36, 118.21s/it]

Epoch: 16, Test Accuracy: 96.42%


 60%|████████████████████████████████████████▊                           | 18/30 [35:15<23:39, 118.32s/it]

Epoch: 17, Test Accuracy: 94.97%


 63%|███████████████████████████████████████████                         | 19/30 [37:13<21:39, 118.17s/it]

Epoch: 18, Test Accuracy: 94.13%


 67%|█████████████████████████████████████████████▎                      | 20/30 [39:11<19:41, 118.16s/it]

Epoch: 19, Test Accuracy: 92.44%


 70%|███████████████████████████████████████████████▌                    | 21/30 [41:09<17:44, 118.25s/it]

Epoch: 20, Test Accuracy: 81.13%


 73%|█████████████████████████████████████████████████▊                  | 22/30 [43:08<15:46, 118.28s/it]

Epoch: 21, Test Accuracy: 63.62%


 77%|████████████████████████████████████████████████████▏               | 23/30 [45:06<13:48, 118.32s/it]

Epoch: 22, Test Accuracy: 19.52%


 80%|██████████████████████████████████████████████████████▍             | 24/30 [47:04<11:49, 118.21s/it]

Epoch: 23, Test Accuracy: 11.78%


 83%|████████████████████████████████████████████████████████▋           | 25/30 [49:01<09:49, 118.00s/it]

Epoch: 24, Test Accuracy: 22.42%


 87%|██████████████████████████████████████████████████████████▉         | 26/30 [51:00<07:52, 118.14s/it]

Epoch: 25, Test Accuracy: 40.18%


 90%|█████████████████████████████████████████████████████████████▏      | 27/30 [52:58<05:54, 118.09s/it]

Epoch: 26, Test Accuracy: 50.9%


 93%|███████████████████████████████████████████████████████████████▍    | 28/30 [54:56<03:56, 118.08s/it]

Epoch: 27, Test Accuracy: 27.68%


 97%|█████████████████████████████████████████████████████████████████▋  | 29/30 [56:54<01:58, 118.10s/it]

Epoch: 28, Test Accuracy: 56.53%


100%|████████████████████████████████████████████████████████████████████| 30/30 [58:52<00:00, 117.75s/it]

Epoch: 29, Test Accuracy: 50.87%





In [2]:
for p1, p2 in backward_model.named_parameters():
    print(p1, p2.grad.shape)
    print(p2.grad.norm())

AttributeError: 'NoneType' object has no attribute 'shape'

In [None]:
import numpy as np
for p1, p2 in forward_model.named_parameters():
    if "bias" in p1: continue
    print(p1, "\t", np.round(p2.data.norm().item(), decimals=2), "\t", np.round(p2.grad.norm().item(), decimals=2))

print()
 
for p1, p2 in backward_model.named_parameters():
    if "bias" in p1: continue
    print(p1, "\t", np.round(p2.data.norm().item(), decimals=2), "\t", np.round(p2.grad.norm().item(), decimals=2))

In [None]:
data.shape, a1.shape, a2.shape, a3.shape

In [None]:
import matplotlib.pyplot as plt

count = 0
for data, target in test_loader: 

    a1, a2, a3, a4 = forward_model(data)
    s1, s2, s3 = backward_model(target, use_act_derivative=False)
    loss, l1, l2, l3, l4 = sigma_loss(a1, a2, a3, a4, s1, s2, s3, target)
    forward_optimizer.zero_grad(), backward_optimizer.zero_grad()
    loss.backward()
    # p = torch.argmax(a4, dim=-1)
    p = torch.topk(a4, k=3, dim=-1).indices
    v = torch.topk(a4, k=3, dim=-1).values

    for i1 in range(16):

        if p[i1][0] != target[i1]:
            
            count += 1
            if count > 3: assert 0 == 1

            i2 = 3
            print(f"prediction: {p[i1]}, {v[i1].detach()} groudtruth: {target[i1]}")

            plt.figure()
            fig, axes = plt.subplots(1,2)

            X = a1.detach().cpu().numpy()
            axes[0].imshow(X[i1,i2], vmin=-1, vmax=1)

            X = s1.detach().cpu().numpy()
            print(X[i1,i2].min(), X[i1,i2].max())
            axes[1].imshow(X[i1,i2], vmin=-0.03, vmax=0.03)
            plt.pause(0.2)
    # break

In [None]:
import matplotlib.pyplot as plt

count = 0
for data, target in test_loader: 

    a1, a2, a3, a4 = forward_model(data)
    s1, s2, s3 = backward_model(target, use_act_derivative=False)
    # p = torch.argmax(a4, dim=-1)
    p = torch.topk(a4, k=3, dim=-1).indices
    v = torch.topk(a4, k=3, dim=-1).values

    for i1 in range(16):

        if p[i1][0] != target[i1]:
            
            count += 1
            if count > 10: assert 0 == 1

            i2 = 3
            print(f"prediction: {p[i1]}, {v[i1].detach()} groudtruth: {target[i1]}")

            plt.figure()
            fig, axes = plt.subplots(1,2)

            X = a1.detach().cpu().numpy()
            axes[0].imshow(X[i1,i2], vmin=-1, vmax=1)

            X = s1.detach().cpu().numpy()
            print(X[i1,i2].min(), X[i1,i2].max())
            axes[1].imshow(X[i1,i2], vmin=-0.03, vmax=0.03)
            plt.pause(0.2)
    # break

In [None]:
torch.manual_seed(2024)
forward_model_rand_init = ForwardModel(use_sigma=USE_SIGMA)
maps = forward_model_rand_init.conv1.weight[:,0].detach().cpu().numpy()
fig, axes = plt.subplots(4,8,figsize=(10,5))
for i in range(4):
    for j in range(8):
        axes[i,j].imshow(maps[i*8+j])

In [None]:
maps = forward_model.conv1.weight[:,0].detach().cpu().numpy()
fig, axes = plt.subplots(4,8,figsize=(10,5))
for i in range(4):
    for j in range(8):
        axes[i,j].imshow(maps[i*8+j])

In [None]:
1568 / 32

In [None]:
## import matplotlib.pyplot as plt

backward_model = BackwardModel()
s1, s2, s3 = backward_model(target)

i1, i2 = 7, 2
X = s1.detach().cpu()
X = X / X.norm()
X = X.numpy()
print(X[i1,i2].min(), X[i1,i2].max())
plt.imshow(X[i1,i2])

In [None]:
i1, i2 = 0, 6
X = a1.detach().cpu()
X = X / X.norm()
X = X.numpy()
print(X[i1,i2].min(), X[i1,i2].max())
plt.imshow(X[i1,i2])

In [None]:
s2.shape