In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

class LinearWithInverse(nn.Module):
    """Linear layer with a learned approximate inverse"""
    def __init__(self, in_features, out_features):
        super(LinearWithInverse, self).__init__()
        self.forward_layer = nn.Linear(in_features, out_features)
        self.inverse_layer = nn.Linear(out_features, in_features)
        
    def forward(self, x):
        return self.forward_layer(x)
    
    def inverse(self, y):
        return self.inverse_layer(y)

class DTPNetwork(nn.Module):
    """Network for Difference Target Propagation"""
    def __init__(self, layer_sizes):
        super(DTPNetwork, self).__init__()
        self.layers = nn.ModuleList()
        
        # Create layers with forward and inverse mappings
        for i in range(len(layer_sizes)-1):
            self.layers.append(LinearWithInverse(layer_sizes[i], layer_sizes[i+1]))
        
    def forward(self, x):
        activations = [x]
        for layer in self.layers:
            x = F.relu(layer(x))
            activations.append(x)
        return activations
    
    def compute_targets(self, activations, labels, learning_rate=0.1):
        # Start with the top layer target (difference to true label)
        targets = [None] * len(activations)
        top_layer = len(activations) - 1
        targets[top_layer] = labels - activations[top_layer]
        
        # Propagate targets downward
        for i in range(top_layer-1, 0, -1):
            # Difference target propagation formula
            targets[i] = activations[i] + self.layers[i].inverse(
                activations[i+1] + learning_rate * targets[i+1]) - self.layers[i].inverse(activations[i+1])
        
        return targets

def train_dtp(model, train_loader, optimizer, epochs=10, learning_rate=0.1):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten the image
            data = data.view(data.size(0), -1)
            
            # Forward pass to get activations
            activations = model(data)
            
            # Convert target to one-hot encoding
            target_onehot = F.one_hot(target, num_classes=10).float()
            
            # Compute targets using DTP
            targets = model.compute_targets(activations, target_onehot, learning_rate)
            
            # Update each layer
            optimizer.zero_grad()
            
            # Compute loss for each layer and update
            for i in range(1, len(activations)):
                # Compute layer-specific loss
                layer_loss = F.mse_loss(activations[i], targets[i])
                
                # Backward pass for this layer only
                layer_loss.backward(retain_graph=True)
                
                # Update only the current layer's parameters
                for param in model.layers[i-1].parameters():
                    if param.grad is not None:
                        param.data -= learning_rate * param.grad
                        param.grad.zero_()
                
                total_loss += layer_loss.item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                      f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {layer_loss.item():.6f}')
        
        print(f'Epoch: {epoch}, Average Loss: {total_loss / len(train_loader.dataset):.6f}')

def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0), -1)
            output = model(data)[-1]  # Get final layer output
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}'
          f' ({100. * correct / len(test_loader.dataset):.0f}%)\n')

def main():
    # Hyperparameters
    batch_size = 64
    epochs = 10
    learning_rate = 0.01
    
    # Load MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,))])
    
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, transform=transform)
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    
    # Create model
    layer_sizes = [784, 500, 200, 100, 10]  # MNIST input is 28x28=784
    model = DTPNetwork(layer_sizes)
    
    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    
    # Train and test
    train_dtp(model, train_loader, optimizer, epochs, learning_rate)
    test(model, test_loader)

if __name__ == '__main__':
    main()

  _torch_pytree._register_pytree_node(


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:11<00:00, 854187.87it/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 124625.71it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1083471.78it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3584969.66it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch: 0, Average Loss: 0.001511
Epoch: 1, Average Loss: 0.001497
Epoch: 2, Average Loss: 0.001474
Epoch: 3, Average Loss: 0.001457
Epoch: 4, Average Loss: 0.001438
Epoch: 5, Average Loss: 0.001411
Epoch: 6, Average Loss: 0.001396
Epoch: 7, Average Loss: 0.001383
Epoch: 8, Average Loss: 0.001371
Epoch: 9, Average Loss: 0.001360

Test set: Average loss: 2.2780, Accuracy: 3886/10000 (39%)

