In [1]:
import os

import torch
import torch.nn as nn
import torchvision.transforms.v2 as v2
import torchvision.datasets as datasets

In [2]:
# Tensors created without an explicit device are placed on the CPU;
# GPU placement is done explicitly with `.to(...)` later in the notebook.
torch.set_default_device(torch.device("cpu"))

In [3]:
# Per-sample preprocessing applied by the datasets below.
transform_func = v2.Compose(
    [
        v2.ToImage(),                           # PIL image -> tensor image
        v2.ToDtype(torch.float32, scale=True),  # to float32; scale=True rescales integer pixels into [0, 1]
    ]
)

In [4]:
# NOTE: despite the historical `mnist_` prefix (kept so later cells keep
# working), these are the CIFAR-100 train and test splits, in that order.
# download=True skips the download once the archive is present and verified.
mnist_trainset, mnist_testset = (
    datasets.CIFAR100(
        root='cifar100_dataset',
        train=is_train_split,
        download=True,
        transform=transform_func,
    )
    for is_train_split in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to cifar100_dataset/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [36:39<00:00, 76843.72it/s]  


Extracting cifar100_dataset/cifar-100-python.tar.gz to cifar100_dataset
Files already downloaded and verified


In [5]:
# Never request more worker processes than the machine has CPU cores:
# extra workers only oversubscribe the CPU (and DataLoader warns about it).
NUM_WORKERS = min(16, os.cpu_count() or 1)

# Shuffled mini-batches for training; pinned host memory speeds up the
# host-to-GPU copies done in the training loop.
mnist_train_dl = torch.utils.data.DataLoader(
    mnist_trainset,
    shuffle=True,
    batch_size=64,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=4,
)

# The evaluation cell reads the entire test split as ONE batch via
# next(iter(...)), so batch_size must stay len(mnist_testset). With a single
# batch only one worker would ever do any work, so load it in the main
# process instead of spawning a pool of idle workers.
mnist_test_dl = torch.utils.data.DataLoader(
    mnist_testset,
    shuffle=False,
    batch_size=len(mnist_testset),
    num_workers=0,
)

In [12]:
class DeepConvNet(nn.Module):
    """Four-stage convolutional classifier for 32x32 images (CIFAR-100 by default).

    Spatial size for a 32x32 input:
    32 -> 26 (conv1, k=7) -> 12 (conv2, k=3, s=2) -> 6 (conv3, k=3, s=2, p=1)
    -> 3 (conv4, k=3, s=2, p=1), so the classifier head sees 128 * 3 * 3 = 1152
    flattened features. Other input resolutions will not match the head.

    ``forward`` returns raw, unnormalized logits. ``nn.CrossEntropyLoss``
    applies log-softmax internally, so a softmax inside the model would be
    applied twice, squashing the logits into [0, 1] and badly weakening the
    gradients. Use ``torch.softmax(logits, dim=1)`` explicitly when
    probabilities are needed; ``argmax`` works on the logits directly.

    Args:
        num_classes: number of output logits (default 100).
        in_channels: number of channels of the input images (default 3).
    """

    class ConvBlock(nn.Module):
        """Conv2d -> BatchNorm2d -> ReLU."""

        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
            super().__init__()
            self.module = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )

        def forward(self, input):
            return self.module(input)

    # Flattened size of conv4's output for a 32x32 input: 128 channels x 3 x 3.
    _FLAT_FEATURES = 128 * 3 * 3

    def __init__(self, num_classes=100, in_channels=3):
        super().__init__()
        self.conv1 = self.ConvBlock(in_channels, 16, 7)
        self.conv2 = self.ConvBlock(16, 32, 3, 2)
        self.conv3 = self.ConvBlock(32, 64, 3, 2, 1)
        self.conv4 = self.ConvBlock(64, 128, 3, 2, 1)
        self.flatten = nn.Flatten()
        # No final Softmax: the model emits logits for nn.CrossEntropyLoss.
        self.linear = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, X):
        """Map a batch of images (N, in_channels, 32, 32) to logits (N, num_classes)."""
        hidden = self.conv1(X)
        hidden = self.conv2(hidden)
        hidden = self.conv3(hidden)
        hidden = self.conv4(hidden)
        hidden = self.flatten(hidden)
        return self.linear(hidden)

In [13]:
from tqdm.auto import tqdm

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed the global RNG so weight init and the shuffling order are repeatable.
torch.manual_seed(0)

model = DeepConvNet().to(device)
loss_fn = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epoch = 10
epoch_losses = []  # mean per-batch training loss of each epoch
model.train()  # BatchNorm uses batch statistics while training
progress = tqdm(range(num_epoch), desc='training')
for _ in progress:
    running_loss = torch.zeros((), device=device)
    for inputs, labels in mnist_train_dl:
        # The loader pins host memory, so these copies can be asynchronous.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Accumulate on-device: calling .item() every step would force a sync.
        running_loss += loss.detach()
    epoch_loss = running_loss.item() / len(mnist_train_dl)
    epoch_losses.append(epoch_loss)
    progress.set_postfix(loss=f'{epoch_loss:.4f}')
    
        

100%|██████████| 10/10 [01:19<00:00,  7.91s/it]


In [None]:
from sklearn.metrics import classification_report
import numpy as np

# Evaluate on whichever device the model's weights live on.
eval_device = next(model.parameters()).device
# eval(): BatchNorm must use (and not update) its running statistics.
model.eval()
# The test loader is configured with batch_size=len(test set): one batch.
inputs_test, labels_test = next(iter(mnist_test_dl))
# no_grad(): don't build an autograd graph for the full test set.
with torch.no_grad():
    output = model(inputs_test.to(eval_device)).cpu().numpy()
labels_test = labels_test.numpy()
output = np.argmax(output, axis=1)  # predicted class index per sample
print(classification_report(labels_test, output))

In [9]:
# Total element count over every parameter tensor (trainable or not).
sum(map(torch.numel, model.parameters()))

1004132

In [10]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return their sum.

    Parameters with ``requires_grad=False`` are skipped, both in the table
    and in the total.

    Args:
        model: module whose ``named_parameters()`` are inspected.

    Returns:
        int: total number of trainable parameter elements.
    """
    trainable_rows = [
        [param_name, tensor.numel()]
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for row in trainable_rows:
        table.add_row(row)
    total_params = sum(count for _, count in trainable_rows)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

count_parameters(model)

+-----------------------+------------+
|        Modules        | Parameters |
+-----------------------+------------+
| conv1.module.0.weight |    2352    |
|  conv1.module.0.bias  |     16     |
| conv1.module.1.weight |     16     |
|  conv1.module.1.bias  |     16     |
| conv2.module.0.weight |    4608    |
|  conv2.module.0.bias  |     32     |
| conv2.module.1.weight |     32     |
|  conv2.module.1.bias  |     32     |
| conv3.module.0.weight |   18432    |
|  conv3.module.0.bias  |     64     |
| conv3.module.1.weight |     64     |
|  conv3.module.1.bias  |     64     |
| conv4.module.0.weight |   73728    |
|  conv4.module.0.bias  |    128     |
| conv4.module.1.weight |    128     |
|  conv4.module.1.bias  |    128     |
|    linear.0.weight    |   589824   |
|     linear.0.bias     |    512     |
|    linear.2.weight    |   262144   |
|     linear.2.bias     |    512     |
|    linear.4.weight    |   51200    |
|     linear.4.bias     |    100     |
+-----------------------+------------+

1004132