In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.v2 as v2
import torchvision.datasets as datasets

In [2]:
torch.set_default_device('cpu')

In [3]:
transform_func = v2.Compose([
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True)
])

In [4]:
mnist_trainset = datasets.CIFAR10(root='cifar10_dataset', train=True, download=True, transform=transform_func)
mnist_testset = datasets.CIFAR10(root='cifar10_dataset', train=False, download=True, transform=transform_func)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
mnist_train_dl = torch.utils.data.DataLoader(mnist_trainset, shuffle = True, num_workers=16, batch_size = 64, pin_memory=True, prefetch_factor=4)
mnist_test_dl = torch.utils.data.DataLoader(mnist_testset, shuffle = False, num_workers=16, batch_size = len(mnist_testset))

In [6]:
class ConvNet(nn.Module):
    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation=1):
            super().__init__()
            self.module = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation),    
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
        
        def forward(self, input):
            return self.module(input)
        
    def __init__(self):
        super().__init__()
        self.conv1 = self.ConvBlock(3, 16, 7)
        self.conv2 = self.ConvBlock(16, 32, 3, 2)
        self.conv3 = self.ConvBlock(32, 64, 3, 2, 1)
        self.conv4 = self.ConvBlock(64, 128, 3, 2, 1)
        self.flatten = nn.Flatten()
        self.linear = nn.Sequential(
            nn.Linear(1152, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.Softmax(dim=1)
            )
    def forward(self, X):
        hidden = self.conv1(X)
        hidden = self.conv2(hidden)
        hidden = self.conv3(hidden)
        hidden = self.conv4(hidden)
        hidden = self.flatten(hidden)
        return self.linear(hidden)

In [7]:
from tqdm import tqdm
from accelerate import Accelerator
model = ConvNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

num_epoch = 20
for i in tqdm(range(num_epoch)):
    loss_arr = []
    for inputs, labels in mnist_train_dl:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss_arr.append(loss.item())
        loss.backward()
        optimizer.step()
    print(sum(loss_arr))
    
        

100%|██████████| 20/20 [02:39<00:00,  7.99s/it]


In [8]:
from sklearn.metrics import classification_report
import numpy as np
inputs_test, labels_test = next(iter(mnist_test_dl))
inputs_test = inputs_test.cuda()
output = model(inputs_test).cpu().detach().numpy()
labels_test = labels_test.cpu().detach().numpy()
output = np.argmax(output, axis = 1)
print(classification_report(labels_test, output))

              precision    recall  f1-score   support

           0       0.73      0.65      0.69      1000
           1       0.86      0.71      0.78      1000
           2       0.61      0.50      0.55      1000
           3       0.48      0.49      0.48      1000
           4       0.63      0.55      0.59      1000
           5       0.59      0.57      0.58      1000
           6       0.71      0.79      0.75      1000
           7       0.69      0.74      0.72      1000
           8       0.67      0.87      0.75      1000
           9       0.69      0.78      0.73      1000

    accuracy                           0.67     10000
   macro avg       0.67      0.67      0.66     10000
weighted avg       0.67      0.67      0.66     10000



In [9]:
sum(p.numel() for p in model.parameters())

429194

In [10]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

+-----------------------+------------+
|        Modules        | Parameters |
+-----------------------+------------+
| conv1.module.0.weight |    2352    |
|  conv1.module.0.bias  |     16     |
| conv1.module.1.weight |     16     |
|  conv1.module.1.bias  |     16     |
| conv2.module.0.weight |    4608    |
|  conv2.module.0.bias  |     32     |
| conv2.module.1.weight |     32     |
|  conv2.module.1.bias  |     32     |
| conv3.module.0.weight |   18432    |
|  conv3.module.0.bias  |     64     |
| conv3.module.1.weight |     64     |
|  conv3.module.1.bias  |     64     |
| conv4.module.0.weight |   73728    |
|  conv4.module.0.bias  |    128     |
| conv4.module.1.weight |    128     |
|  conv4.module.1.bias  |    128     |
|    linear.0.weight    |   294912   |
|     linear.0.bias     |    256     |
|    linear.2.weight    |   32768    |
|     linear.2.bias     |    128     |
|    linear.4.weight    |    1280    |
|     linear.4.bias     |     10     |
+-----------------------+

429194