In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.backends.mps as mps

# Configuration

In [26]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print("Running on device %s" % device)

Running on device mps


In [27]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load Datasets

In [28]:
batch_size = 64
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


# Simple CNN model

In [29]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Initialization


In [30]:
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# AGC Function

In [None]:
def apply_agc(model, compression_ratio=0.5):
    for param in model.parameters():
        if param.grad is not None:
            grad = param.grad
            
            k = int(compression_ratio * grad.numel())  # Keep top-k gradients
            if k < 1:
                continue
            
            top_k_values, top_k_indices = torch.topk(grad.view(-1), k)
            
            sparse_grad = torch.zeros_like(grad.view(-1))
            sparse_grad[top_k_indices] = top_k_values
            
            param.grad = sparse_grad.view(grad.shape)

# Training loop


In [32]:
epochs = 10
compression_ratio = 0.5  # Keep 50% of gradients

for epoch in range(epochs):
    running_loss = 0.0

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Apply AGC before optimization
        apply_agc(model, compression_ratio)

        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(trainloader):.4f}")

print("Training complete!")

Epoch 1/10, Loss: 2.3030
Epoch 2/10, Loss: 2.3029
Epoch 3/10, Loss: 2.3027
Epoch 4/10, Loss: 2.3027
Epoch 5/10, Loss: 2.3027
Epoch 6/10, Loss: 2.3027
Epoch 7/10, Loss: 2.3027
Epoch 8/10, Loss: 2.3027
Epoch 9/10, Loss: 2.3027
Epoch 10/10, Loss: 2.3027
Training complete!


# Evaluate the model

In [33]:
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 10.00%
