In [49]:
import copy

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, random_split
import torchvision
import torch.nn as nn
from torchvision.transforms import v2
import multiprocessing
import torch.nn.utils.prune as prune
import torch.nn.functional as F


In [33]:
print(torch.cuda.is_available())

True


In [72]:
# run variables

seed = 42
file_name = 'base_pruned'
# file_name = 'student_self_taught'

class BaseNN(nn.Module):
    def __init__(self, conv1_out = 128, num_classes=10):
        super(BaseNN, self).__init__()

        ### teacher
        self.features = nn.Sequential(
            nn.Conv2d(3, conv1_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(conv1_out, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

        ### student
        # self.features = nn.Sequential(
        #     nn.Conv2d(3, 16, kernel_size=3, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
        #     nn.Conv2d(16, 16, kernel_size=3, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
        # )
        # self.classifier = nn.Sequential(
        #     nn.Linear(1024, 256),
        #     nn.ReLU(),
        #     nn.Dropout(0.1),
        #     nn.Linear(256, num_classes)
        # )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [36]:
# device settings

num_workers = 2

print(torch.cuda.get_device_name(torch.cuda.current_device()))

#device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using {device} device")

NVIDIA GeForce RTX 2060
Using cuda device


In [44]:
# define datasets and loaders

transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = torchvision.datasets.CIFAR10("./../data", train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10("./../data", train=False, transform=transform, download=True)

train_dataset, validation_dataset =  random_split(train_dataset, [0.8, 0.2])

print('train set size:', len(train_dataset))
print('validation set size:', len(validation_dataset))
print('test set size:', len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=num_workers)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=num_workers)

class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

Files already downloaded and verified
Files already downloaded and verified
train set size: 40000
validation set size: 10000
test set size: 10000


In [38]:
def train(model, epochs, learning_rate):
    trainingEpoch_loss = []
    validationEpoch_loss = []
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        # training
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
        trainingEpoch_loss.append(running_loss / len(train_loader))

        # validation
        model.eval()
        validation_loss = 0.0
        for inputs, labels in validation_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            validation_loss += loss.item()

        validationEpoch_loss.append(validation_loss / len(validation_loader))

    return trainingEpoch_loss, validationEpoch_loss

In [39]:
def test(model):
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [57]:
torch.manual_seed(seed)
model_base = BaseNN(num_classes=10).to(device)


In [58]:
trainingEpoch_loss, validationEpoch_loss = train(model_base, epochs=10, learning_rate=0.001)

Epoch 1/10, Loss: 1.394992379283905
Epoch 2/10, Loss: 0.910749391746521
Epoch 3/10, Loss: 0.7008289664268493
Epoch 4/10, Loss: 0.5391978739261627
Epoch 5/10, Loss: 0.39849857199192046
Epoch 6/10, Loss: 0.2833144622325897
Epoch 7/10, Loss: 0.20346237179636956
Epoch 8/10, Loss: 0.15738942106366158
Epoch 9/10, Loss: 0.13843745221942663
Epoch 10/10, Loss: 0.12825698957666753


In [59]:
test(model_base)

Test Accuracy: 73.15%


73.15

In [60]:
pruned_model = copy.deepcopy(model_base)
to_prune = [
    (pruned_model.features[0], 'weight'),
    (pruned_model.features[2], 'weight'),
    (pruned_model.features[5], 'weight'),
    (pruned_model.features[7], 'weight'),
    (pruned_model.classifier[0], 'weight'),
    (pruned_model.classifier[3], 'weight'),
]

prune.global_unstructured(
    to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.8,
)

In [61]:
for (module, _) in to_prune:
    sparsity = 100. * float(torch.sum(module.weight == 0)) / module.weight.nelement()
    print(f"Sparsity in {module.__class__.__name__}: {sparsity:.2f}%")

Sparsity in Conv2d: 42.48%
Sparsity in Conv2d: 84.11%
Sparsity in Conv2d: 83.53%
Sparsity in Conv2d: 73.45%
Sparsity in Linear: 79.85%
Sparsity in Linear: 75.53%


In [62]:
trainingEpoch_loss, validationEpoch_loss = train(pruned_model, epochs=10, learning_rate=0.001)

Epoch 1/10, Loss: 0.2623787738263607
Epoch 2/10, Loss: 0.1002238373786211
Epoch 3/10, Loss: 0.064100375540182
Epoch 4/10, Loss: 0.05431569690480828
Epoch 5/10, Loss: 0.05003045916259289
Epoch 6/10, Loss: 0.04500716479923576
Epoch 7/10, Loss: 0.0435512925138697
Epoch 8/10, Loss: 0.042859381295274945
Epoch 9/10, Loss: 0.03539651275193319
Epoch 10/10, Loss: 0.03438207645555958


In [63]:
test(pruned_model)

Test Accuracy: 74.28%


74.28

In [64]:
# (opcjonalnie) usunięcie masek, by model był gotowy do eksportu
for module, name in to_prune:
    prune.remove(module, name)

In [65]:
torch.save(pruned_model.state_dict(), "../models/" + file_name + ".pt")

In [66]:
# === Structured pruning: kanały Conv2d ===
def get_pruned_channels(conv_layer, amount):
    weight = conv_layer.weight.detach().cpu()
    norms = weight.view(weight.size(0), -1).norm(p=2, dim=1)
    num_prune = int(amount * weight.size(0))
    prune_indices = torch.argsort(norms)[:num_prune]
    keep_indices = torch.argsort(norms)[num_prune:]
    return keep_indices.tolist()

In [73]:
keep = get_pruned_channels(model_base.features[0], amount=0.5)
pruned_conv_model = BaseNN(conv1_out=len(keep), num_classes=10).to(device)

In [74]:
train(pruned_conv_model, epochs=10, learning_rate=0.001)

Epoch 1/10, Loss: 1.3979026472091676
Epoch 2/10, Loss: 0.9206561213493347
Epoch 3/10, Loss: 0.7185071380615234
Epoch 4/10, Loss: 0.571894864320755
Epoch 5/10, Loss: 0.43536783180236815
Epoch 6/10, Loss: 0.3214219146490097
Epoch 7/10, Loss: 0.2346030507683754
Epoch 8/10, Loss: 0.18979957507550715
Epoch 9/10, Loss: 0.15605480465590954
Epoch 10/10, Loss: 0.12942429059743882


([1.3979026472091676,
  0.9206561213493347,
  0.7185071380615234,
  0.571894864320755,
  0.43536783180236815,
  0.3214219146490097,
  0.2346030507683754,
  0.18979957507550715,
  0.15605480465590954,
  0.12942429059743882],
 [1.0840710431906828,
  0.8283173024274741,
  0.7873738983254523,
  0.7937244482480796,
  0.7859747411718794,
  0.944733291294924,
  1.0018971938236503,
  1.1450536600343741,
  1.1957154201853806,
  1.349167570946323])

In [75]:
test(pruned_conv_model)

Test Accuracy: 72.68%


72.68

In [76]:
torch.save(pruned_conv_model.state_dict(), "../models/" + "pruned_conv_model" + ".pt")