In [3]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import EuroSAT

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define dataset transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for EfficientNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Automatically download and load EuroSAT RGB dataset
dataset = EuroSAT(root="./data", transform=transform, download=True)

# Split dataset (80% train, 20% validation)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load EfficientNet-B0 (best accuracy for structured pruning)
model = models.efficientnet_b0(weights=None)  # Training from scratch

# Modify classifier for 10 EuroSAT classes
num_classes = len(dataset.classes)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = model.to(device)

# Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training function
def train_model(model, dataloader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss, correct, total = 0, 0, 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}, Accuracy: {correct/total:.4f}")

# Train before pruning
print("Training Base Model from Scratch...")
train_model(model, train_loader, optimizer, criterion, epochs=10)

# ----------------- STRUCTURED PRUNING -----------------

def prune_model(model, amount=0.3):
    """
    Apply structured pruning (L1 Norm based) on convolutional layers.
    Removes `amount` of least important filters per layer.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name='weight', amount=amount, n=1, dim=0)  # L1 Norm Pruning
            prune.remove(module, 'weight')  # Make pruning permanent
    return model

# Apply structured pruning (30% of filters)
print("Applying Structured Pruning...")
pruned_model = prune_model(model, amount=0.3)

# Fine-tune after pruning
print("Fine-tuning Pruned Model...")
train_model(pruned_model, train_loader, optimizer, criterion, epochs=5)

# Evaluate the model
def evaluate_model(model, dataloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Evaluate on validation set
accuracy = evaluate_model(pruned_model, val_loader)
print(f"Final Pruned Model Accuracy: {accuracy:.4f}")


Downloading https://cdn-lfs.hf.co/repos/fc/1d/fc1dee780dee1dae2ad48856d0961ac6aa5dfcaaaa4fb3561be4aedf19b7ccc7/8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27EuroSAT.zip%3B+filename%3D%22EuroSAT.zip%22%3B&response-content-type=application%2Fzip&Expires=1740384316&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MDM4NDMxNn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mYy8xZC9mYzFkZWU3ODBkZWUxZGFlMmFkNDg4NTZkMDk2MWFjNmFhNWRmY2FhYWE0ZmIzNTYxYmU0YWVkZjE5YjdjY2M3LzhlYmVhNjI2MzQ5MzU0YzUzMjhiMTQyYjk2ZDA0MzBlNjQ3MDUxZjI2ZWZjMmRjOTc0Yzg0M2YyNWVjZjcwYmQ%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=i5cJ0LyP4JSVGdUKCeSnxlRxKcG1ggglcgCLxSqWaO72gBJIx6HHG9YpL8th1uaGmuNA70tP1WgdhbBqgs-ITOShp57ahbB-mtX-pd-xRJ5UoyzN-KkvNlanKKZ%7EDtWiXvne5gyc%7EqUEvqaE3VoafDUKbGZPYuhRdbY5hljeFfLotiTcKGGvCFeNhOhcgT6e6sr7PS0Zy3AXdEW3K1N2baiSvA%7

100%|██████████| 94.3M/94.3M [00:00<00:00, 205MB/s]


Extracting ./data/eurosat/EuroSAT.zip to ./data/eurosat
Training Base Model from Scratch...
Epoch 1, Loss: 1.0067, Accuracy: 0.6485
Epoch 2, Loss: 0.5260, Accuracy: 0.8225
Epoch 3, Loss: 0.3638, Accuracy: 0.8784
Epoch 4, Loss: 0.2880, Accuracy: 0.9039
Epoch 5, Loss: 0.2403, Accuracy: 0.9184
Epoch 6, Loss: 0.2027, Accuracy: 0.9333
Epoch 7, Loss: 0.1862, Accuracy: 0.9364
Epoch 8, Loss: 0.1593, Accuracy: 0.9461
Epoch 9, Loss: 0.1511, Accuracy: 0.9494
Epoch 10, Loss: 0.1306, Accuracy: 0.9562
Applying Structured Pruning...
Fine-tuning Pruned Model...
Epoch 1, Loss: 0.5669, Accuracy: 0.8081
Epoch 2, Loss: 0.2636, Accuracy: 0.9128
Epoch 3, Loss: 0.2013, Accuracy: 0.9297
Epoch 4, Loss: 0.1711, Accuracy: 0.9413
Epoch 5, Loss: 0.1540, Accuracy: 0.9482
Final Pruned Model Accuracy: 0.9526
