In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torchvision import datasets, transforms

import torch
# Reproducibility: seed PyTorch's global RNG before any layer is built, so the
# random weight initialisation and the DataLoader shuffling repeat across runs.
SEED = 2024
torch.manual_seed(SEED)

# Width of both hidden layers of the network defined below.
HIDDEN_DIM = 256

# Define the forward model
class ForwardModel(nn.Module):
    """Fully connected 784 -> HIDDEN_DIM -> HIDDEN_DIM -> 10 network.

    Each layer is fed a *detached* copy of the previous layer's activation,
    so a loss computed on the final output only produces gradients for
    ``fc3``; nothing is back-propagated into ``fc1`` or ``fc2``.

    ``fc3`` starts with all-zero weights and every layer starts with a zero
    bias; ``fc1``/``fc2`` weights keep ``nn.Linear``'s default initialisation.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        self.fc3 = nn.Linear(HIDDEN_DIM, 10)

        # Zero the readout weights, then the bias of every layer.
        nn.init.zeros_(self.fc3.weight)
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.zeros_(layer.bias)

    def forward(self, x, s2=None):
        """Run the network on ``x`` and optionally push ``s2`` through ``fc3``.

        Args:
            x: input tensor; it is reshaped to rows of 28 * 28 values.
            s2: optional tensor that is detached and passed through ``fc3``
                only (it must therefore be accepted by ``fc3``).

        Returns:
            ``(a1, a2, a3, s2_recon)`` where ``a1``/``a2`` are the ELU
            activations of the two hidden layers, ``a3`` is the raw output
            of ``fc3`` and ``s2_recon`` is ``fc3(s2)`` or ``None`` when
            ``s2`` was not supplied.
        """
        flat = x.view(-1, 28 * 28)
        a1 = F.elu(self.fc1(flat))
        # detach() cuts the autograd graph between consecutive layers.
        a2 = F.elu(self.fc2(a1.detach()))
        a3 = self.fc3(a2.detach())

        s2_recon = None if s2 is None else self.fc3(s2.detach())
        return a1, a2, a3, s2_recon

# Earlier MSE-on-one-hot version of the loss, kept disabled for reference:
# def sigma_loss(a3, t):
#     loss3 = F.mse_loss(a3, torch.nn.functional.one_hot(t, num_classes=10).float().to(t.device))
#     return loss3

# Define the loss function
def sigma_loss(a3, t):
    """Return the cross-entropy loss between logits ``a3`` and labels ``t``.

    The loss is computed with ``F.cross_entropy`` directly rather than through
    a module-level ``criteria`` object, so the function does not depend on a
    global that is only created further down the notebook. With default
    settings this is the same computation as ``nn.CrossEntropyLoss()``.

    Args:
        a3: unnormalised class scores (logits), as accepted by
            ``F.cross_entropy``.
        t: integer class indices, one per row of ``a3``.

    Returns:
        The mean cross-entropy over the batch as a scalar tensor.
    """
    return F.cross_entropy(a3, t)

# Training hyper-parameters. The batch size has its own constant instead of
# reusing HIDDEN_DIM: the two values are unrelated, and changing the layer
# width should not silently change the mini-batch size.
BATCH_SIZE = 256
LEARNING_RATE = 0.03
MOMENTUM = 0.5

# Load MNIST dataset (downloaded to ./data on first use). Images are converted
# to tensors and normalised with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Only the training data is shuffled; the test set is evaluated in fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize the model
forward_model = ForwardModel()

# Define the optimizer: SGD with momentum over all model parameters
forward_optimizer = optim.SGD(forward_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# Classification loss applied to the network's final output
criteria = nn.CrossEntropyLoss()

# Training loop: one pass over the training set per epoch, followed by an
# evaluation pass over the test set.
NUM_EPOCHS = 30
for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, target) in enumerate(train_loader):

        # Forward pass (s2 is not supplied, so s2_recon is None)
        a1, a2, a3, s2_recon = forward_model(data)
        # The loss is computed on the final output only
        loss = criteria(a3, target)

        # Update parameters. A fresh graph is built by every forward pass and
        # backward() is called once per graph, so retain_graph is not needed;
        # omitting it lets autograd free the graph buffers immediately.
        forward_optimizer.zero_grad()
        loss.backward()
        forward_optimizer.step()

    # Print statistics once per epoch: index and loss of the last mini-batch
    print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.3f}')

    # Evaluate on test set
    forward_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            _, _, outputs, _ = forward_model(data)
            # Predicted class = index of the largest output per row. Under
            # no_grad the outputs carry no graph, so `.data` is unnecessary.
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / total}%')
    # Switch back to training mode for the next epoch
    forward_model.train()

Epoch: 0, Batch: 234, Loss: 1.432
Epoch: 0, Test Accuracy: 75.85%
Epoch: 1, Batch: 234, Loss: 1.205
Epoch: 1, Test Accuracy: 79.66%
Epoch: 2, Batch: 234, Loss: 1.035
Epoch: 2, Test Accuracy: 81.43%
Epoch: 3, Batch: 234, Loss: 0.861
Epoch: 3, Test Accuracy: 82.39%
Epoch: 4, Batch: 234, Loss: 0.741
Epoch: 4, Test Accuracy: 83.1%
Epoch: 5, Batch: 234, Loss: 0.706
Epoch: 5, Test Accuracy: 83.67%
Epoch: 6, Batch: 234, Loss: 0.583
Epoch: 6, Test Accuracy: 84.08%
Epoch: 7, Batch: 234, Loss: 0.763
Epoch: 7, Test Accuracy: 84.28%
Epoch: 8, Batch: 234, Loss: 0.726
Epoch: 8, Test Accuracy: 84.67%
Epoch: 9, Batch: 234, Loss: 0.561
Epoch: 9, Test Accuracy: 84.91%
Epoch: 10, Batch: 234, Loss: 0.526
Epoch: 10, Test Accuracy: 85.23%
Epoch: 11, Batch: 234, Loss: 0.573
Epoch: 11, Test Accuracy: 85.33%
Epoch: 12, Batch: 234, Loss: 0.483
Epoch: 12, Test Accuracy: 85.61%
Epoch: 13, Batch: 234, Loss: 0.523
Epoch: 13, Test Accuracy: 85.71%
Epoch: 14, Batch: 234, Loss: 0.701
Epoch: 14, Test Accuracy: 85.96%
E

In [2]:
# Inspect the first layer's weights after training. ForwardModel.forward
# detaches a1 and a2 before they enter the next layer and the loss is computed
# from a3 only, so no gradient reaches fc1; its weights are therefore expected
# to still hold their initialisation values.
forward_model.fc1.weight

Parameter containing:
tensor([[ 0.0023,  0.0237,  0.0337,  ..., -0.0043, -0.0330, -0.0205],
        [-0.0297,  0.0216, -0.0344,  ...,  0.0200, -0.0087, -0.0302],
        [ 0.0146, -0.0248,  0.0243,  ...,  0.0059,  0.0279, -0.0065],
        ...,
        [-0.0246, -0.0254,  0.0223,  ...,  0.0059, -0.0134,  0.0138],
        [ 0.0267, -0.0255, -0.0024,  ..., -0.0045, -0.0024,  0.0087],
        [-0.0053, -0.0010,  0.0096,  ...,  0.0314, -0.0324,  0.0099]],
       requires_grad=True)