In [2]:
import multiprocessing
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2

In [3]:
# run variables

# Seed for PyTorch's global random number generator; it is handed to
# torch.manual_seed() right before the model is instantiated.
seed = 42

In [4]:
# device settings

# Half of the CPU count reported by multiprocessing is used for the
# DataLoader worker processes.
num_workers = multiprocessing.cpu_count() // 2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

Using cuda device


In [5]:
# define datasets and loaders

# Convert PIL images to float tensors in [0, 1], then normalise each channel.
# NOTE(review): these mean/std values are the usual ImageNet statistics, not
# CIFAR-10's own; kept unchanged so downstream results stay comparable.
# NOTE(review): recent torchvision releases deprecate v2.ToTensor in favour of
# v2.ToImage + v2.ToDtype(torch.float32, scale=True); kept for compatibility.
transform = v2.Compose([
    v2.ToTensor(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

full_train_dataset = CIFAR10("./../data", train=True, transform=transform, download=True)
test_dataset = CIFAR10("./../data", train=False, transform=transform, download=True)

# Hold out 20% of the training images for validation. A dedicated, seeded
# generator makes the split identical on every "Restart & Run All"; without
# it the split depends on whatever state the global RNG happens to be in.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, validation_dataset = random_split(
    full_train_dataset, [0.8, 0.2], generator=split_generator
)

print('train set size:', len(train_dataset))
print('validation set size:', len(validation_dataset))
print('test set size:', len(test_dataset))

# Only the training data is shuffled; evaluation order does not influence the
# averaged loss, so the validation and test loaders iterate deterministically.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=num_workers)
validation_loader = DataLoader(validation_dataset, batch_size=128, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=num_workers)

# CIFAR-10 label index -> human-readable class name.
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]



Files already downloaded and verified
Files already downloaded and verified
train set size: 40000
validation set size: 10000
test set size: 10000


In [6]:
class BaseNN(nn.Module):
    """Small convolutional classifier wrapped in quantization stubs.

    The feature extractor is four 3x3 conv -> batch-norm -> ReLU stages with a
    2x2 max-pool after the second and the fourth stage. The classifier is a
    two-layer MLP whose first layer expects 2048 flattened features, which
    matches 32 channels x 8 x 8, i.e. a 32x32 input halved twice by pooling.

    Args:
        num_classes: number of output logits produced by the last linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # (in_channels, out_channels, max-pool after this stage?)
        stage_specs = (
            (3, 128, False),
            (128, 64, True),
            (64, 64, False),
            (64, 32, True),
        )
        feature_layers = []
        for in_channels, out_channels, pool_after in stage_specs:
            feature_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            feature_layers.append(nn.BatchNorm2d(out_channels))
            feature_layers.append(nn.ReLU())
            if pool_after:
                feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Layers are registered in creation order, so the Sequential indices
        # (and therefore the state_dict keys) run from features.0 to features.13.
        self.features = nn.Sequential(*feature_layers)

        hidden_layer = nn.Linear(2048, 512)
        activation = nn.ReLU()
        dropout = nn.Dropout(0.1)
        output_layer = nn.Linear(512, num_classes)
        self.classifier = nn.Sequential(hidden_layer, activation, dropout, output_layer)

        # Stubs marking where tensors enter and leave the quantized region.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        quantized = self.quant(x)
        feature_maps = self.features(quantized)
        # Flatten everything except the batch dimension for the linear layers.
        flattened = feature_maps.reshape(feature_maps.size(0), -1)
        logits = self.classifier(flattened)
        return self.dequant(logits)

In [7]:
# Seed the global RNG first so the freshly created weights are reproducible.
torch.manual_seed(seed)
model = BaseNN(num_classes=10)
model = model.to(device)

In [8]:
def train(model, epochs, learning_rate):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Batches come from the notebook-level ``train_loader`` and
    ``validation_loader`` and are moved to the notebook-level ``device``.

    Args:
        model: module to optimise in place; its output may be the logits
            themselves or a tuple whose first element is the logits.
        epochs: number of full passes over ``train_loader``.
        learning_rate: learning rate handed to ``torch.optim.Adam``.

    Returns:
        A pair ``(training_losses, validation_losses)`` of lists holding one
        mean per-sample loss per epoch. An epoch over an empty loader is
        reported as NaN instead of raising ``ZeroDivisionError``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    training_losses = []
    validation_losses = []

    def _batch_loss(inputs, labels):
        """Forward one batch and return (mean batch loss, number of samples)."""
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        return criterion(logits, labels), labels.size(0)

    def _mean_loss(loss_sum, sample_count):
        """Per-sample mean, or NaN when no samples were processed."""
        return loss_sum / sample_count if sample_count else float("nan")

    for epoch in range(epochs):
        # Training pass
        model.train()
        train_loss_sum, train_samples = 0.0, 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            loss, batch_size = _batch_loss(inputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller final batch is not over-represented.
            train_loss_sum += loss.item() * batch_size
            train_samples += batch_size

        avg_train_loss = _mean_loss(train_loss_sum, train_samples)
        training_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss:.4f}")

        # Evaluation pass
        model.eval()
        val_loss_sum, val_samples = 0.0, 0
        with torch.no_grad():
            for inputs, labels in validation_loader:
                loss, batch_size = _batch_loss(inputs, labels)
                val_loss_sum += loss.item() * batch_size
                val_samples += batch_size

        avg_val_loss = _mean_loss(val_loss_sum, val_samples)
        validation_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")

    return training_losses, validation_losses

In [9]:
# Train for seven epochs and keep the per-epoch loss curves.
trainingEpoch_loss, validationEpoch_loss = train(
    model=model,
    epochs=7,
    learning_rate=1e-4,
)

Epoch 1/7, Training Loss: 1.4158
Epoch 1/7, Validation Loss: 1.1367
Epoch 2/7, Training Loss: 1.0025
Epoch 2/7, Validation Loss: 0.9575
Epoch 3/7, Training Loss: 0.8531
Epoch 3/7, Validation Loss: 0.8859
Epoch 4/7, Training Loss: 0.7614
Epoch 4/7, Validation Loss: 0.8120
Epoch 5/7, Training Loss: 0.6935
Epoch 5/7, Validation Loss: 0.7987
Epoch 6/7, Training Loss: 0.6364
Epoch 6/7, Validation Loss: 0.7949
Epoch 7/7, Training Loss: 0.5884
Epoch 7/7, Validation Loss: 0.7684


In [10]:
# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure the target folder exists first.
model_path = Path("../models/quantized_base_model.pt")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)