In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class SimpleCNN(nn.Module):
    """Small CNN: two 3x3 convolutions followed by two linear layers.

    The network takes single-channel images and returns 10 raw logits
    (no softmax is applied). The width of the first linear layer is measured
    from the convolutional stack for a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Measure the flattened conv output instead of hard-coding it.
        self.fc1 = nn.Linear(self._get_conv_output_size(), 128)
        self.fc2 = nn.Linear(128, 10)

    def _conv_features(self, x):
        """Apply both convolutions with ReLU and flatten all non-batch dims."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return torch.flatten(x, 1)

    def _get_conv_output_size(self):
        """Return the number of flattened features for one 1x28x28 input."""
        # Only the shape matters: a zero tensor avoids consuming global RNG
        # state, and no_grad avoids building an autograd graph for the probe.
        with torch.no_grad():
            probe = torch.zeros(1, 1, 28, 28)
            return self._conv_features(probe).size(1)

    def forward(self, x):
        """Return a (batch, 10) tensor of logits for a batch of images."""
        x = self._conv_features(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader`` with cross-entropy loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        device: Device the batches are moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepping the model's parameters.
        epoch: Index of the current epoch; accepted for API compatibility
            and not used by the loop.

    Returns:
        The mean per-sample loss over the pass as a float, or ``0.0`` when
        the loader yields no batches.
    """
    model.train()
    # Build the criterion once rather than on every batch.
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    samples_seen = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the mean;
        # detach so the running sum does not keep the autograd graph alive.
        batch_size = target.size(0)
        loss_sum = loss_sum + loss.detach() * batch_size
        samples_seen += batch_size
    if samples_seen == 0:
        return 0.0
    return float(loss_sum) / samples_seen

def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(test_loader.dataset)

def global_magnitude_pruning(model, amount):
    """Prune the weights of every Conv2d and Linear layer by global L1 magnitude.

    All ``weight`` tensors of the matching layers are ranked together, so the
    pruned entries are chosen across layers rather than per layer. ``amount``
    is passed unchanged to ``prune.global_unstructured``. The pruning
    reparameterization is then removed from each layer, so the zeros are
    written into ``weight`` itself and no mask remains on the module.
    """
    prunable_types = (nn.Conv2d, nn.Linear)
    targets = [
        (layer, 'weight')
        for layer in model.modules()
        if isinstance(layer, prunable_types)
    ]

    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Fold each mask into its weight to make the pruning permanent.
    for layer, param_name in targets:
        prune.remove(layer, param_name)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)


def _sparsity_percent(weights):
    """Return the percentage of exactly-zero entries across ``weights``."""
    zero_count = sum(int((w == 0).sum()) for w in weights)
    total_count = sum(w.nelement() for w in weights)
    return 100. * zero_count / total_count


# Initial training
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("Entraînement du modèle initial...")
for epoch in range(5):
    train(model, device, train_loader, optimizer, epoch)
    accuracy = test(model, device, test_loader)
    print(f"Époque {epoch+1}, Précision: {accuracy:.2f}%")

# Apply pruning
print("\nApplication du pruning...")
global_magnitude_pruning(model, amount=0.5)

# Same layer selection as global_magnitude_pruning: every Conv2d / Linear weight.
pruned_weights = [
    module.weight
    for module in model.modules()
    if isinstance(module, (nn.Conv2d, nn.Linear))
]
sparsity = _sparsity_percent(pruned_weights)
print(f"Éparsité après pruning: {sparsity:.2f}%")

# The pruning masks were removed, so nothing stops fine-tuning from regrowing
# the zeroed weights. Freeze them by zeroing their gradients through a hook.
for weight in pruned_weights:
    keep_mask = (weight.detach() != 0).to(weight.dtype)
    # Bind the mask as a default argument to avoid late-binding in the loop.
    weight.register_hook(lambda grad, keep_mask=keep_mask: grad * keep_mask)

# Use a fresh optimizer: the Adam moment estimates accumulated during initial
# training would otherwise keep moving pruned weights despite zero gradients.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Fine-tuning after pruning
print("\nFine-tuning après pruning...")
for epoch in range(3):
    train(model, device, train_loader, optimizer, epoch)
    accuracy = test(model, device, test_loader)
    print(f"Époque {epoch+1}, Précision: {accuracy:.2f}%")

# Report sparsity again to confirm fine-tuning preserved the pruned weights.
sparsity = _sparsity_percent(pruned_weights)
print(f"Éparsité après fine-tuning: {sparsity:.2f}%")

Entraînement du modèle initial...
Époque 1, Précision: 98.41%
Époque 2, Précision: 98.68%
Époque 3, Précision: 98.60%
Époque 4, Précision: 98.76%
