In [15]:
import torch
from transformers import AutoModel, AutoTokenizer
from fvcore.nn import FlopCountAnalysis, flop_count_table
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.nn import FlopCountAnalysis, flop_count_table
import torch.profiler as profiler
from torch.utils.data import DataLoader, TensorDataset
from copy import deepcopy


In [29]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import copy
import numpy as np
from thop import profile


############################################
# Device Selection
############################################

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed for reproducibility
torch.manual_seed(42)
if device == 'cuda':
    torch.cuda.manual_seed(42)

############################################
# MNIST loading
############################################
# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)



############################################
# Model Definition
############################################
# Model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 28 * 28, 10)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.constant_(self.conv1.bias, 0)
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.constant_(self.conv2.bias, 0)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

def train_one_epoch(model, optimizer, criterion, dataloader, device='cpu'):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for (images, labels) in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return running_loss / total, 100.0 * correct / total

def evaluate(model, dataloader, criterion, device='cpu'):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for (images, labels) in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / total, 100.0 * correct / total

def prune_conv_layer(conv_layer, prune_ratio=0.5, device='cpu'):
    w = conv_layer.weight.data
    out_channels = w.shape[0]

    # Compute L1-norm
    channel_norms = w.abs().mean(dim=(1,2,3))
    keep_channels = int(out_channels * (1 - prune_ratio))
    _, indices_to_keep = torch.topk(channel_norms, k=keep_channels)
    indices_to_keep = indices_to_keep.sort()[0]

    new_w = w[indices_to_keep, :, :, :].clone().to(device)
    if conv_layer.bias is not None:
        new_bias = conv_layer.bias.data[indices_to_keep].clone().to(device)
    else:
        new_bias = None

    new_conv = nn.Conv2d(
        in_channels=conv_layer.in_channels,
        out_channels=keep_channels,
        kernel_size=conv_layer.kernel_size,
        stride=conv_layer.stride,
        padding=conv_layer.padding,
        dilation=conv_layer.dilation,
        groups=conv_layer.groups,
        bias=(conv_layer.bias is not None)
    ).to(device)

    # Re-init weights for stability
    nn.init.kaiming_uniform_(new_conv.weight, nonlinearity='relu')
    nn.init.constant_(new_conv.bias, 0)

    with torch.no_grad():
        new_conv.weight.data = new_w
        if new_bias is not None:
            new_conv.bias.data = new_bias

    return new_conv

# Create baseline model
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

dummy_input = torch.randn(1, 1, 28, 28).to(device)
flops_before, _ = profile(model, inputs=(dummy_input,), verbose=False)

##########################################
# Train Baseline Model (5 epochs)
##########################################
print("=== Baseline Model Training ===")
if device == 'cuda':
    torch.cuda.synchronize()
baseline_start = time.time()
for epoch in range(5):
    train_loss, train_acc = train_one_epoch(model, optimizer, criterion, train_loader, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Test Loss: {val_loss:.4f}, Test Acc: {val_acc:.2f}%")
if device == 'cuda':
    torch.cuda.synchronize()
baseline_end = time.time()
baseline_total_time = baseline_end - baseline_start

baseline_val_loss, baseline_val_acc = val_loss, val_acc
if device == 'cuda':
    baseline_memory = torch.cuda.max_memory_allocated(device)
else:
    baseline_memory = None

##########################################
# Prune both conv1 and conv2 to improve the model
##########################################
pruned_model = copy.deepcopy(model).to(device)
pruned_model.eval()

# Prune conv1 by 50%
pruned_model.conv1 = prune_conv_layer(pruned_model.conv1, prune_ratio=0.5, device=device)

# After pruning conv1, the input channels of conv2 must be adjusted
# The original conv2 expects 16 channels in. After pruning conv1 by half, we have 8 channels out of conv1 now.
# We must re-initialize conv2 accordingly.
old_conv2 = pruned_model.conv2
in_channels_new = pruned_model.conv1.out_channels
out_channels_old = old_conv2.out_channels

# Create a new conv2 with updated in_channels
new_conv2 = nn.Conv2d(in_channels_new, out_channels_old, kernel_size=3, padding=1, bias=True).to(device)
nn.init.kaiming_uniform_(new_conv2.weight, nonlinearity='relu')
nn.init.constant_(new_conv2.bias, 0)
pruned_model.conv2 = new_conv2

# Now prune conv2 by 50%
pruned_model.conv2 = prune_conv_layer(pruned_model.conv2, prune_ratio=0.5, device=device)

# Adjust FC layer
new_out_channels = pruned_model.conv2.out_channels
new_fc_in_features = new_out_channels * 28 * 28
pruned_model.fc1 = nn.Linear(new_fc_in_features, 10).to(device)
nn.init.xavier_uniform_(pruned_model.fc1.weight)
nn.init.constant_(pruned_model.fc1.bias, 0)

optimizer_pruned = optim.Adam(pruned_model.parameters(), lr=1e-3)

flops_after, _ = profile(pruned_model, inputs=(dummy_input,), verbose=False)

#####################
# Train Pruned Model (5 epochs)
#####################
print("=== Pruned Model Training ===")
if device == 'cuda':
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
pruned_start = time.time()
for epoch in range(5):
    train_loss_p, train_acc_p = train_one_epoch(pruned_model, optimizer_pruned, criterion, train_loader, device)
    val_loss_p, val_acc_p = evaluate(pruned_model, test_loader, criterion, device)
    print(f"[Pruned] Epoch {epoch+1}: Train Loss: {train_loss_p:.4f}, Train Acc: {train_acc_p:.2f}%, Test Loss: {val_loss_p:.4f}, Test Acc: {val_acc_p:.2f}%")
if device == 'cuda':
    torch.cuda.synchronize()
pruned_end = time.time()
pruned_total_time = pruned_end - pruned_start

pruned_val_loss, pruned_val_acc = val_loss_p, val_acc_p
if device == 'cuda':
    pruned_memory = torch.cuda.max_memory_allocated(device)
else:
    pruned_memory = None

#####################
# Final Report
#####################
print("\n=== Final Results Summary ===")
print("Baseline Model Final Metrics:")
print(f" - Final Test Loss: {baseline_val_loss:.4f}")
print(f" - Final Test Acc: {baseline_val_acc:.2f}%")
print(f" - FLOPs Before Pruning: {flops_before}")
print(f" - Total Training Time (5 epochs): {baseline_total_time:.2f}s")
if baseline_memory is not None:
    print(f" - Peak Memory Usage: {baseline_memory / (1024**2):.2f} MB")

print("\nPruned Model Final Metrics:")
print(f" - Final Test Loss: {pruned_val_loss:.4f}")
print(f" - Final Test Acc: {pruned_val_acc:.2f}%")
print(f" - FLOPs After Pruning: {flops_after}")
print(f" - Total Training Time (5 epochs): {pruned_total_time:.2f}s")
if pruned_memory is not None:
    print(f" - Peak Memory Usage: {pruned_memory / (1024**2):.2f} MB")

print("\nComparison:")
print(f"FLOPs reduction: {flops_before} -> {flops_after}")
print(f"Time Reduction (training): {baseline_total_time:.2f}s -> {pruned_total_time:.2f}s")
if baseline_memory is not None and pruned_memory is not None:
    print(f"Memory Reduction: {baseline_memory / (1024**2):.2f} MB -> {pruned_memory / (1024**2):.2f} MB")


=== Baseline Model Training ===
Epoch 1: Train Loss: 0.1866, Train Acc: 95.29%, Test Loss: 0.0773, Test Acc: 97.74%
Epoch 2: Train Loss: 0.0424, Train Acc: 98.70%, Test Loss: 0.0549, Test Acc: 98.36%
Epoch 3: Train Loss: 0.0222, Train Acc: 99.28%, Test Loss: 0.0576, Test Acc: 98.21%
Epoch 4: Train Loss: 0.0149, Train Acc: 99.52%, Test Loss: 0.0616, Test Acc: 98.32%
Epoch 5: Train Loss: 0.0125, Train Acc: 99.55%, Test Loss: 0.0664, Test Acc: 98.29%
=== Pruned Model Training ===
[Pruned] Epoch 1: Train Loss: 0.1385, Train Acc: 95.72%, Test Loss: 0.0722, Test Acc: 97.73%
[Pruned] Epoch 2: Train Loss: 0.0447, Train Acc: 98.61%, Test Loss: 0.0570, Test Acc: 98.22%
[Pruned] Epoch 3: Train Loss: 0.0234, Train Acc: 99.25%, Test Loss: 0.0598, Test Acc: 98.23%
[Pruned] Epoch 4: Train Loss: 0.0162, Train Acc: 99.46%, Test Loss: 0.0653, Test Acc: 98.13%
[Pruned] Epoch 5: Train Loss: 0.0106, Train Acc: 99.67%, Test Loss: 0.0652, Test Acc: 98.29%

=== Final Results Summary ===
Baseline Model Final M