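# MNIST Classification with ADMM-based L1 Sparsity

Trains a small fully connected network on MNIST while using ADMM (alternating direction method of multipliers) to push its weights towards an L1-sparse solution.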

## Import libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
SEED = 42
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.
# The trailing `;` suppresses the returned Generator's repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x79723813a930>

## Load dataset

In [2]:
# Widely used MNIST training-set mean/std, applied after ToTensor scales pixels to [0, 1].
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
BATCH_SIZE = 128
DATA_ROOT = './'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# download=True on both splits so neither line depends on the other having run first.
train_dataset = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(DATA_ROOT, train=False, download=True, transform=transform)

# Only the training split is shuffled; evaluation order does not affect the metrics.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Neural net architecture

In [3]:
class model(nn.Module):
    """Three-layer fully connected classifier (default 784 -> 256 -> 128 -> 10).

    Inputs are flattened to ``(-1, in_features)``, so both image batches such as
    ``(B, 1, 28, 28)`` and already-flat ``(B, 784)`` batches are accepted. The
    output is per-row log-probabilities (``log_softmax``), the form
    ``F.nll_loss`` expects.

    Args:
        in_features: number of values in each flattened input sample.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output classes.
    """

    def __init__(self, in_features=784, hidden1=256, hidden2=128, num_classes=10):
        # Zero-argument super(): the explicit super(model, self) form resolves the
        # global name `model` at call time, which breaks once that name is rebound
        # to an instance (as `model = model()` does later in this notebook).
        super().__init__()
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(N, num_classes)``."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.fc3(x), dim=1)

## Model, optimizer and ADMM hyperparameters

In [4]:
# Instantiate the network. NOTE(review): this rebinds the name `model` from the
# class to the instance, so re-running this cell without first re-running the
# class cell raises a TypeError; a distinct class name would avoid that.
model = model()

rho = 0.001         # ADMM penalty weight on ||w - z + u||^2
l1_lambda = 0.0001  # L1 strength; the z-update soft-threshold is l1_lambda / rho
primal_optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

## ADMM variables: sparse copy `z` and scaled dual `u`

In [5]:
# ADMM state: one zero-initialised tensor per model parameter, each matching that
# parameter's shape/dtype/device.
z = [torch.zeros_like(weight) for weight in model.parameters()]  # proximal (soft-thresholded) copy of the weights
u = [torch.zeros_like(z_i) for z_i in z]                         # scaled dual variable

## Training loop

In [6]:
EPOCHS: int = 6  # number of full passes over train_loader

In [7]:
for epoch in range(EPOCHS):
    model.train()
    # Running sums are Python floats: adding the loss *tensors* here would keep
    # every batch's autograd graph reachable from the total until the epoch ends.
    total_task_loss = 0.0
    total_admm_loss = 0.0

    for batch_idx, (data, target) in enumerate(train_loader):
        primal_optimizer.zero_grad()

        output = model(data)
        task_loss = F.nll_loss(output, target)

        # Augmented-Lagrangian penalty in scaled form: (rho/2) * sum ||w - z + u||^2.
        # z and u are plain tensors, so gradients flow only into the weights w.
        admm_loss = sum(
            (rho / 2) * torch.sum((param - z[idx] + u[idx]) ** 2)
            for idx, param in enumerate(model.parameters())
        )
        total_loss = task_loss + admm_loss

        # Primal (w) update: a single Adam step on task loss + penalty
        # (an inexact minimisation rather than a full solve).
        total_loss.backward()
        primal_optimizer.step()

        with torch.no_grad():
            for idx, param in enumerate(model.parameters()):
                # Proximal (z) update: soft-thresholding,
                # sign(v) * max(|v| - l1_lambda/rho, 0) with v = w + u.
                z[idx] = F.softshrink(param + u[idx], l1_lambda / rho)
                # Scaled dual (u) update: accumulate the residual w - z.
                u[idx] = u[idx] + param - z[idx]

        task_loss_value = task_loss.item()
        admm_loss_value = float(admm_loss)
        total_task_loss += task_loss_value
        total_admm_loss += admm_loss_value

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1}/{EPOCHS} [{batch_idx*len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                  f'Task Loss: {task_loss_value:.6f}\t'
                  f'ADMM Loss: {admm_loss_value:.6f}')

    # Evaluate on the held-out test split after every epoch.
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # reduction='sum' so dividing by the dataset size below gives a per-sample mean.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'\nEpoch: {epoch+1}')
    print(f'Average Training Task Loss: {total_task_loss/len(train_loader):.6f}')
    print(f'Average Training ADMM Loss: {total_admm_loss/len(train_loader):.6f}')
    print(f'Test Loss: {test_loss:.6f}')
    print(f'Test Accuracy: {accuracy:.2f}%\n')


Epoch: 1
Average Training Task Loss: 0.286753
Average Training ADMM Loss: 0.954863
Test Loss: 0.153135
Test Accuracy: 95.54%


Epoch: 2
Average Training Task Loss: 0.131405
Average Training ADMM Loss: 0.858187
Test Loss: 0.110601
Test Accuracy: 96.61%


Epoch: 3
Average Training Task Loss: 0.104709
Average Training ADMM Loss: 0.827278
Test Loss: 0.099504
Test Accuracy: 97.00%


Epoch: 4
Average Training Task Loss: 0.088673
Average Training ADMM Loss: 0.808508
Test Loss: 0.090686
Test Accuracy: 97.17%


Epoch: 5
Average Training Task Loss: 0.079736
Average Training ADMM Loss: 0.800099
Test Loss: 0.089905
Test Accuracy: 97.21%


Epoch: 6
Average Training Task Loss: 0.074099
Average Training ADMM Loss: 0.796454
Test Loss: 0.080473
Test Accuracy: 97.56%



## Check weight sparsity

In [8]:
def compute_sparsity(model, tol=1e-5):
    """Return the percentage of entries whose magnitude is below ``tol``.

    Args:
        model: an ``nn.Module`` (its ``parameters()`` are inspected) or any
            iterable of tensors, e.g. the ADMM ``z`` list.
        tol: magnitude threshold below which an entry counts as zero.

    Returns:
        Percentage in [0, 100]; 0.0 when there are no entries at all.
    """
    tensors = model.parameters() if hasattr(model, 'parameters') else model
    total_params = 0
    zero_params = 0
    for tensor in tensors:
        total_params += tensor.numel()
        zero_params += (tensor.detach().abs() < tol).sum().item()
    if total_params == 0:
        return 0.0
    return zero_params / total_params * 100

sparsity = compute_sparsity(model)
print(f'\nFinal model sparsity: {sparsity:.2f}% parameters near zero')

# The dense weights w only approach z as ADMM converges, so they are rarely exactly
# zero. The sparse iterate ADMM actually produces is z; report its sparsity too.
z_total = sum(t.numel() for t in z)
z_zeros = sum((t == 0).sum().item() for t in z)
z_sparsity = z_zeros / z_total * 100 if z_total else 0.0
print(f'ADMM z sparsity: {z_sparsity:.2f}% entries exactly zero')


Final model sparsity: 0.17% parameters near zero
