In [11]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, random_split
import torchvision
import torch.nn as nn
from torchvision.transforms import v2
import multiprocessing
import torch.nn.utils.prune as prune
import torch.nn.functional as F


In [18]:
# Sanity check: report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

True


In [3]:
# run variables

# Seed handed to torch.manual_seed right before each model is built.
seed = 42
# Base name (no extension) of the checkpoint written by the final cell.
file_name = 'base_pruned'
# file_name = 'student_self_taught'

class BaseNN(nn.Module):
    """Convolutional image classifier (the "teacher" configuration).

    ``features`` holds four 3x3 convolutions (padding 1) with ReLU
    activations and two 2x2 max-pools; ``classifier`` is a two-layer MLP with
    dropout. The first linear layer takes 2048 inputs, i.e. 32 channels on an
    8x8 map, which is what a 3x32x32 image yields after the two pools.

    Args:
        num_classes: size of the output layer (number of logits per sample).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        ### teacher
        # Layers are instantiated in the same order as they are stacked, so
        # parameter initialisation under a fixed seed stays unchanged.
        conv_stack = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

        ### student
        # Smaller alternative architecture, kept for reference (disabled):
        # self.features = nn.Sequential(
        #     nn.Conv2d(3, 16, kernel_size=3, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
        #     nn.Conv2d(16, 16, kernel_size=3, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
        # )
        # self.classifier = nn.Sequential(
        #     nn.Linear(1024, 256),
        #     nn.ReLU(),
        #     nn.Dropout(0.1),
        #     nn.Linear(256, num_classes)
        # )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feature_maps = self.features(x)
        # Keep the batch dimension, collapse everything else for the MLP head.
        flat = torch.flatten(feature_maps, 1)
        return self.classifier(flat)

In [4]:
# device settings

# Number of worker processes used by every DataLoader in this notebook.
num_workers = 2

#device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

# Decide on the device first, so the CPU fallback actually works.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only ask for the GPU name when CUDA is present; querying the current CUDA
# device on a CPU-only machine raises instead of falling back.
if device == "cuda":
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

print(f"Using {device} device")

NVIDIA GeForce RTX 2060
Using cuda device


In [5]:
# define datasets and loaders

# Convert to a float32 image tensor scaled to [0, 1], then normalise per channel.
# NOTE(review): these mean/std values look like the commonly used ImageNet
# statistics rather than CIFAR-10's own; confirm this is intended.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = torchvision.datasets.CIFAR10("./../data", train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10("./../data", train=False, transform=transform, download=True)

# Use a dedicated, seeded generator for the 80/20 split: without it the split
# depends on whatever the global RNG state happens to be when this cell runs,
# so train/validation membership would differ between kernel restarts.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, validation_dataset = random_split(train_dataset, [0.8, 0.2], generator=split_generator)

print('train set size:', len(train_dataset))
print('validation set size:', len(validation_dataset))
print('test set size:', len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=num_workers)
# Evaluation does not benefit from shuffling; a fixed order keeps the per-epoch
# validation loss comparable and avoids drawing from the global RNG.
validation_loader = DataLoader(validation_dataset, batch_size=128, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=num_workers)

# Human-readable label for each CIFAR-10 class index.
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

Files already downloaded and verified
Files already downloaded and verified
train set size: 40000
validation set size: 10000
test set size: 10000


In [6]:
def train(model, epochs, learning_rate):
    """Train ``model`` with Adam and cross-entropy, validating after each epoch.

    Batches are read from the notebook-level ``train_loader`` and
    ``validation_loader`` and moved to the notebook-level ``device``.
    The mean training loss of every epoch is printed.

    Args:
        model: network to optimise; its parameters are updated in place.
        epochs: number of full passes over ``train_loader``.
        learning_rate: learning rate passed to ``torch.optim.Adam``.

    Returns:
        Tuple ``(trainingEpoch_loss, validationEpoch_loss)``: two lists with
        one entry per epoch, each the mean of the per-batch losses.
    """
    trainingEpoch_loss = []
    validationEpoch_loss = []
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        # training
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
        trainingEpoch_loss.append(running_loss / len(train_loader))

        # validation
        model.eval()
        validation_loss = 0.0
        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every validation batch.
        with torch.no_grad():
            for inputs, labels in validation_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                validation_loss += loss.item()

        validationEpoch_loss.append(validation_loss / len(validation_loader))

    return trainingEpoch_loss, validationEpoch_loss

In [7]:
def test(model):
    """Measure top-1 accuracy of ``model`` on the notebook-level ``test_loader``.

    The model is switched to eval mode, batches are moved to the notebook-level
    ``device``, and the result is printed with two decimals.

    Args:
        model: network to evaluate.

    Returns:
        Accuracy as a percentage (0-100).
    """
    model.eval()

    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch, targets in test_loader:
            batch = batch.to(device)
            targets = targets.to(device)

            logits = model(batch)
            # Index of the largest logit per row is the predicted class.
            predicted = torch.max(logits, dim=1).indices

            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [9]:
# Seed first so the weight initialisation below is repeatable.
torch.manual_seed(seed)
model_base = BaseNN(num_classes=10).to(device)


In [10]:
# Baseline run: train the unpruned network for 10 epochs.
trainingEpoch_loss, validationEpoch_loss = train(model_base, epochs=10, learning_rate=0.001)

Epoch 1/10, Loss: 1.4040568173408507
Epoch 2/10, Loss: 0.9313389938354493
Epoch 3/10, Loss: 0.7309231795310974
Epoch 4/10, Loss: 0.5812270737171173
Epoch 5/10, Loss: 0.4447835987329483
Epoch 6/10, Loss: 0.31887144092321396
Epoch 7/10, Loss: 0.2456826867520809
Epoch 8/10, Loss: 0.1893038306683302
Epoch 9/10, Loss: 0.1533369359701872
Epoch 10/10, Loss: 0.1299334236934781


In [15]:
# Accuracy of the baseline model on the test split.
test(model_base)

Test Accuracy: 72.36%


72.36

In [26]:
# Re-create the network from the same seed. Pruning is therefore applied to
# freshly initialised (untrained) weights, and `model_base` no longer refers
# to the model trained above.
torch.manual_seed(seed)
model_base = BaseNN(num_classes=10).to(device)

# Every Conv2d and Linear layer of the network takes part in the pruning.
conv_layers = [model_base.features[idx] for idx in (0, 2, 5, 7)]
linear_layers = [model_base.classifier[idx] for idx in (0, 3)]
to_prune = [(layer, 'weight') for layer in conv_layers + linear_layers]

# Global L1 pruning: the 80% budget is shared across all listed tensors, so
# individual layers can end up more or less sparse than 80%.
prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=0.8)

In [27]:
# Per pruned layer, print the percentage of weight entries that are exactly zero.
for layer, _ in to_prune:
    weights = layer.weight
    zero_share = 100. * float((weights == 0).sum()) / weights.nelement()
    print(f"Sparsity in {layer.__class__.__name__}: {zero_share:.2f}%")

Sparsity in Conv2d: 9.14%
Sparsity in Conv2d: 62.55%
Sparsity in Conv2d: 44.25%
Sparsity in Conv2d: 44.30%
Sparsity in Linear: 83.53%
Sparsity in Linear: 41.91%


In [28]:
# Train the pruned network with the same settings as the baseline run
# (10 epochs, lr 0.001). This rebinds trainingEpoch_loss / validationEpoch_loss,
# replacing the lists returned by the baseline run.
trainingEpoch_loss, validationEpoch_loss = train(model_base, epochs=10, learning_rate=0.001)

Epoch 1/10, Loss: 1.494864939212799
Epoch 2/10, Loss: 1.056314612197876
Epoch 3/10, Loss: 0.8388438291072845
Epoch 4/10, Loss: 0.7041443522453308
Epoch 5/10, Loss: 0.5933612693309784
Epoch 6/10, Loss: 0.5056564187288284
Epoch 7/10, Loss: 0.43067469470500946
Epoch 8/10, Loss: 0.35933805131912233
Epoch 9/10, Loss: 0.29892502924203873
Epoch 10/10, Loss: 0.2596704077124596


In [29]:
# Accuracy of the pruned-and-trained model on the test split.
test(model_base)

Test Accuracy: 73.12%


73.12

In [30]:
# (optional) remove the pruning masks so the model is ready for export
for module, name in to_prune:
    # prune.remove raises if the layer is no longer pruned, which would make
    # re-running this cell fail; skip layers that were already processed.
    if prune.is_pruned(module):
        prune.remove(module, name)

In [31]:
# Write the model's state_dict to ../models/<file_name>.pt
checkpoint_path = f"../models/{file_name}.pt"
torch.save(model_base.state_dict(), checkpoint_path)