In [2]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
batch_size      = 64
test_batch_size = 1000
epochs          = 15
lr              = 0.01
momentum        = 0.9
log_interval    = 100
prune_ratio     = 0.5
device          = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── Data transforms & loaders ─────────────────────────────────────────────────
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_loader = DataLoader(
    datasets.MNIST(root='./data', train=True,  download=True,  transform=transform),
    batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(
    datasets.MNIST(root='./data', train=False, download=True,  transform=transform),
    batch_size=test_batch_size, shuffle=False
)

In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net().to(device)

In [5]:
def magnitude_prune(model, prune_ratio):
    all_weights = torch.cat([
        param.view(-1) for name, param in model.named_parameters()
        if 'weight' in name and param.requires_grad
    ])
    num_params_to_keep = int(len(all_weights) * (1 - prune_ratio))
    threshold, _ = torch.kthvalue(all_weights.abs(), len(all_weights) - num_params_to_keep)

    for name, param in model.named_parameters():
        if 'weight' in name and param.requires_grad:
            mask = (param.abs() >= threshold).float()
            param.data *= mask
            print(f"Pruned {name}: Kept {mask.sum().item()} / {mask.numel()}")

magnitude_prune(model, prune_ratio)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# ─── Training & Validation ─────────────────────────────────────────────────────
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []


Pruned conv1.weight: Kept 285.0 / 288
Pruned conv2.weight: Kept 16764.0 / 18432
Pruned fc1.weight: Kept 581558.0 / 1179648
Pruned fc2.weight: Kept 1218.0 / 1280


In [6]:
def train(epoch):
    model.train()
    correct, total, running_loss = 0, 0, 0
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

        if batch_idx % log_interval == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')

    train_losses.append(running_loss / len(train_loader.dataset))
    train_accuracies.append(100. * correct / total)

def validate():
    model.eval()
    correct, val_loss = 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    val_losses.append(val_loss / len(val_loader.dataset))
    val_accuracies.append(100. * correct / len(val_loader.dataset))
    print(f"\nValidation — Avg Loss: {val_losses[-1]:.4f}, Accuracy: {val_accuracies[-1]:.2f}%\n")


In [None]:
if __name__ == '__main__':
    start_time = time.time()
    for epoch in range(1, epochs + 1):
        train(epoch)
        validate()
    elapsed = time.time() - start_time
    mins, secs = divmod(elapsed, 60)
    print(f'\n⏱️ Total Time: {mins:.0f}m {secs:.2f}s')

    print(f"Training Accuracy: {train_accuracies[-1]:.2f}%")
    print(f"Validation Accuracy: {val_accuracies[-1]:.2f}%")

    # ─── Plotting ──────────────────────────────────────────────────────────────
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs+1), train_losses, label="Train Loss")
    plt.plot(range(1, epochs+1), val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss over Epochs")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs+1), train_accuracies, label="Train Acc")
    plt.plot(range(1, epochs+1), val_accuracies, label="Val Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title("Accuracy over Epochs")
    plt.legend()

    plt.tight_layout()
    plt.show()

Epoch 1 [6400/60000] Loss: 0.6465
Epoch 1 [12800/60000] Loss: 0.4022
Epoch 1 [19200/60000] Loss: 0.3339
Epoch 1 [25600/60000] Loss: 0.1120
Epoch 1 [32000/60000] Loss: 0.2125
Epoch 1 [38400/60000] Loss: 0.1453
Epoch 1 [44800/60000] Loss: 0.0529
Epoch 1 [51200/60000] Loss: 0.1215
Epoch 1 [57600/60000] Loss: 0.1002

Validation — Avg Loss: 0.0645, Accuracy: 98.10%

Epoch 2 [6400/60000] Loss: 0.1148
Epoch 2 [12800/60000] Loss: 0.0404
Epoch 2 [19200/60000] Loss: 0.0396
Epoch 2 [25600/60000] Loss: 0.0583
Epoch 2 [32000/60000] Loss: 0.2119
Epoch 2 [38400/60000] Loss: 0.0361
Epoch 2 [44800/60000] Loss: 0.0583
Epoch 2 [51200/60000] Loss: 0.0317
Epoch 2 [57600/60000] Loss: 0.1105

Validation — Avg Loss: 0.0406, Accuracy: 98.71%

Epoch 3 [6400/60000] Loss: 0.0789
