In [13]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader

# Define dataset and data loader
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load Oxford Flowers 102 dataset
train_dataset = datasets.Flowers102(root="./data", split="train", download=True, transform=transform)
val_dataset = datasets.Flowers102(root="./data", split="val", download=True, transform=transform)
test_dataset = datasets.Flowers102(root="./data", split="test", download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load Pretrained ResNet-50 Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(pretrained=True)

# Modify the last FC layer for 102 classes
model.fc = nn.Linear(model.fc.in_features, 102)
model = model.to(device)

# Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Train function
def train_model(model, dataloader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss, correct, total = 0, 0, 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}, Accuracy: {correct/total:.4f}")

# Train the model before pruning
print("Training Base Model...")
train_model(model, train_loader, optimizer, criterion, epochs=5)

# ----------------- STRUCTURED PRUNING -----------------

def prune_model(model, amount=0.3):
    """
    Apply structured pruning (L1 Norm based) on convolutional layers.
    Removes `amount` of least important filters per layer.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name='weight', amount=amount, n=1, dim=0)  # L1 Norm Pruning
            prune.remove(module, 'weight')  # Remove mask to make pruning permanent
    return model

# Apply structured pruning (30% of filters)
print("Applying Structured Pruning...")
pruned_model = prune_model(model, amount=0.3)

# Fine-tune after pruning
print("Fine-tuning Pruned Model...")
train_model(pruned_model, train_loader, optimizer, criterion, epochs=5)

# Evaluate the model
def evaluate_model(model, dataloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Evaluate on test set
accuracy = evaluate_model(pruned_model, test_loader)
print(f"Final Pruned Model Accuracy: {accuracy:.4f}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 177MB/s]


Training Base Model...
Epoch 1, Loss: 4.1190, Accuracy: 0.2147
Epoch 2, Loss: 2.3011, Accuracy: 0.8490
Epoch 3, Loss: 1.0566, Accuracy: 0.9716
Epoch 4, Loss: 0.3771, Accuracy: 0.9990
Epoch 5, Loss: 0.1533, Accuracy: 1.0000
Applying Structured Pruning...
Fine-tuning Pruned Model...
Epoch 1, Loss: 1.6745, Accuracy: 0.7412
Epoch 2, Loss: 0.4457, Accuracy: 0.9431
Epoch 3, Loss: 0.1798, Accuracy: 0.9804
Epoch 4, Loss: 0.0734, Accuracy: 0.9941
Epoch 5, Loss: 0.0543, Accuracy: 0.9931
Final Pruned Model Accuracy: 0.8618
