In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

def get_data_loaders(batch_size=16):
    import torchvision
    import torchvision.transforms as transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    testset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2
    )

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainloader, testloader, classes


In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),    # 32x32x3 → 32x32x64
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 32x32x64 → 32x32x128
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32 → 16x16
            nn.Dropout2d(0.25),

            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3),

            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4),

            nn.Linear(128, 10)
        )
    def forward(self,x):
        logits = self.model(x)
        return logits



model = Model()

In [3]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [4]:
def train_model(model, trainloader, criterion, optimizer, epochs=2):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss:.3f}')
            running_loss = 0.0

In [5]:
def test_model(model, testloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [6]:

trainloader, testloader, classes = get_data_loaders()
dataiter = iter(trainloader)
images, labels = next(dataiter)
model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
train_model(model, trainloader, criterion, optimizer, epochs=2)
test_model(model, testloader)


[1,     1] loss: 2.493
[1,     2] loss: 2.321
[1,     3] loss: 2.409
[1,     4] loss: 2.277
[1,     5] loss: 2.730
[1,     6] loss: 2.318
[1,     7] loss: 2.452
[1,     8] loss: 2.330
[1,     9] loss: 2.254
[1,    10] loss: 2.404
[1,    11] loss: 2.264
[1,    12] loss: 2.231
[1,    13] loss: 2.602
[1,    14] loss: 2.530
[1,    15] loss: 2.538
[1,    16] loss: 2.342
[1,    17] loss: 2.407
[1,    18] loss: 2.606
[1,    19] loss: 2.516
[1,    20] loss: 2.375
[1,    21] loss: 2.497
[1,    22] loss: 2.373
[1,    23] loss: 2.546
[1,    24] loss: 2.602
[1,    25] loss: 2.471
[1,    26] loss: 2.496
[1,    27] loss: 2.556
[1,    28] loss: 2.311
[1,    29] loss: 2.544
[1,    30] loss: 2.301
[1,    31] loss: 2.392
[1,    32] loss: 2.376
[1,    33] loss: 2.327
[1,    34] loss: 2.428
[1,    35] loss: 2.425
[1,    36] loss: 2.555
[1,    37] loss: 2.670
[1,    38] loss: 2.468
[1,    39] loss: 2.354
[1,    40] loss: 2.365
[1,    41] loss: 2.225
[1,    42] loss: 2.559
[1,    43] loss: 2.270
[1,    44] 

In [7]:
#print number of parameters
num_params = 0
for x in model.parameters():
  num_params += len(torch.flatten(x))

print(f'{num_params:,} parameters')

8,976,906 parameters
