In [2]:
import os

import torch, torch.utils.data
import torch.nn as nn
import torchvision.transforms.v2 as v2
import torchvision.datasets as datasets

In [3]:
# Tensors created without an explicit device land on the CPU; anything that should
# live on an accelerator is moved there explicitly with .to(...).
torch.set_default_device(torch.device('cpu'))

In [4]:
# Per-sample preprocessing: convert the dataset image to an image tensor, then cast
# to float32 with pixel values rescaled (scale=True) to the [0, 1] range.
transform_func = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

In [5]:
# Both splits share the cache directory and preprocessing; the train split is built
# first, then the test split. download=True skips the download if files already exist.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='mnist_dataset', train=split_is_train, download=True, transform=transform_func)
    for split_is_train in (True, False)
)

In [35]:
# Never request more loader processes than there are CPUs: extra workers only contend
# for cores, and DataLoader warns when num_workers exceeds what the system can run.
TRAIN_NUM_WORKERS = min(16, os.cpu_count() or 1)

mnist_train_dl = torch.utils.data.DataLoader(
    mnist_trainset,
    shuffle=True,
    num_workers=TRAIN_NUM_WORKERS,
    batch_size=512,
    # Page-locked batches speed up host->GPU copies; useless (and warned about) without CUDA.
    pin_memory=torch.cuda.is_available(),
    prefetch_factor=4,  # batches loaded ahead of time by each worker
)
# The whole test split is served as ONE batch, which only a single worker would ever
# produce; load it in the main process instead of spawning idle worker processes.
mnist_test_dl = torch.utils.data.DataLoader(
    mnist_testset, shuffle=False, num_workers=0, batch_size=len(mnist_testset)
)

In [7]:
class DeepConvNet(nn.Module):
    """Four Conv-BN-ReLU blocks followed by a small MLP head for 10-way classification.

    The hard-coded ``nn.Linear(1152, 64)`` fixes the expected input size: for a
    ``(batch, 1, 28, 28)`` input the conv stack produces ``(batch, 128, 3, 3)``,
    i.e. 128 * 3 * 3 = 1152 features after flattening.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, 10)``. The model
    deliberately has no softmax layer: ``nn.CrossEntropyLoss`` applies log-softmax
    itself, and a second softmax would squash the logits and weaken the gradients.
    Use ``torch.softmax(logits, dim=1)`` when probabilities are needed; ``argmax`` over
    the logits already gives the predicted class.
    """

    class ConvBlock(nn.Module):
        """``Conv2d -> BatchNorm2d -> ReLU``.

        ``stride`` and ``padding`` are passed positionally and ``dilation`` by keyword
        to ``nn.Conv2d``.
        """

        def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation=1):
            super().__init__()
            self.module = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )

        def forward(self, input):
            return self.module(input)

    def __init__(self):
        super().__init__()
        # Output sizes below assume a 1 x 28 x 28 input.
        self.conv1 = self.ConvBlock(1, 16, 7)          # -> 16 x 22 x 22
        self.conv2 = self.ConvBlock(16, 32, 3, 2)      # -> 32 x 10 x 10
        self.conv3 = self.ConvBlock(32, 64, 3, 2, 1)   # -> 64 x 5 x 5
        self.conv4 = self.ConvBlock(64, 128, 3, 2, 1)  # -> 128 x 3 x 3
        self.flatten = nn.Flatten()                    # -> 1152
        self.linear = nn.Sequential(
            nn.Linear(1152, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
            # No Softmax here: the head emits logits for nn.CrossEntropyLoss.
            )

    def forward(self, X):
        """Map a batch of images to class logits of shape ``(batch, 10)``."""
        hidden = self.conv1(X)
        hidden = self.conv2(hidden)
        hidden = self.conv3(hidden)
        hidden = self.conv4(hidden)
        hidden = self.flatten(hidden)
        return self.linear(hidden)

In [8]:
# Loads the line_profiler IPython extension, which provides the %lprun magic.
# NOTE(review): no cell shown here calls %lprun — TODO remove if it stays unused.
%load_ext line_profiler

In [34]:
from tqdm.auto import tqdm

# Train on the GPU when one is present; otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the global RNG so weight initialisation and the training loader's shuffle order
# are repeatable (GPU kernels may still introduce some non-determinism).
torch.manual_seed(0)

model = DeepConvNet().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)  # takes unnormalised logits; applies log-softmax itself
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
num_epoch = 10

# BatchNorm must use per-batch statistics (and update its running stats) while training.
model.train()
for i in tqdm(range(num_epoch), desc='epochs'):
    for inputs, labels in mnist_train_dl:
        # With a pinned-memory loader, non_blocking copies can overlap with GPU compute.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:16<00:00,  1.66s/it]


In [25]:
from sklearn.metrics import classification_report
import numpy as np

# Inference mode. eval() makes BatchNorm use its running statistics instead of the test
# batch's own, and stops the test data from updating those statistics. no_grad() avoids
# building an autograd graph over the test set, which the loader serves as one batch.
model.eval()
model_device = next(model.parameters()).device  # evaluate wherever the model lives
inputs_test, labels_test = next(iter(mnist_test_dl))
inputs_test = inputs_test.to(model_device)
with torch.no_grad():
    logits_test = model(inputs_test)
# argmax over the class dimension gives the predicted digit (unchanged by softmax).
output = np.argmax(logits_test.cpu().numpy(), axis = 1)
labels_test = labels_test.numpy()
print(classification_report(labels_test, output))

              precision    recall  f1-score   support

           0       1.00      0.99      1.00       980
           1       0.99      1.00      0.99      1135
           2       1.00      0.99      0.99      1032
           3       0.98      1.00      0.99      1010
           4       0.99      0.99      0.99       982
           5       0.99      0.99      0.99       892
           6       0.99      0.99      0.99       958
           7       0.99      0.98      0.99      1028
           8       0.99      0.99      0.99       974
           9       0.98      0.99      0.99      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000



In [11]:
# Total number of parameter elements in `model`, trainable or not.
sum(map(torch.Tensor.numel, model.parameters()))

174474

In [12]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return their total.

    Parameters with ``requires_grad=False`` are skipped.
    """
    trainable_rows = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable_rows:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable_rows)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

count_parameters(model)

+-----------------------+------------+
|        Modules        | Parameters |
+-----------------------+------------+
| conv1.module.0.weight |    784     |
|  conv1.module.0.bias  |     16     |
| conv1.module.1.weight |     16     |
|  conv1.module.1.bias  |     16     |
| conv2.module.0.weight |    4608    |
|  conv2.module.0.bias  |     32     |
| conv2.module.1.weight |     32     |
|  conv2.module.1.bias  |     32     |
| conv3.module.0.weight |   18432    |
|  conv3.module.0.bias  |     64     |
| conv3.module.1.weight |     64     |
|  conv3.module.1.bias  |     64     |
| conv4.module.0.weight |   73728    |
|  conv4.module.0.bias  |    128     |
| conv4.module.1.weight |    128     |
|  conv4.module.1.bias  |    128     |
|    linear.0.weight    |   73728    |
|     linear.0.bias     |     64     |
|    linear.2.weight    |    2048    |
|     linear.2.bias     |     32     |
|    linear.4.weight    |    320     |
|     linear.4.bias     |     10     |
+-----------------------+

174474

In [None]:
# Test for multiprocessing (Hogwild-style training of one shared-memory model)
import torch.multiprocessing as mp

def train(model, data_loader=None, lr=0.001):
    """Train ``model`` in place for one pass over ``data_loader``.

    Intended as an ``mp.Process`` target. Each process builds its own optimizer and
    loss over the parameters of the model it receives, so ``optimizer.step()`` writes
    into that model's (shared-memory) parameters.

    Args:
        model: module to train; its parameters are updated in place.
        data_loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``mnist_train_dl``.
        lr: Adam learning rate (matches the single-process training cell).
    """
    if data_loader is None:
        data_loader = mnist_train_dl
    # Build the optimizer here, over *this* model's parameters: a notebook-global
    # optimizer is bound to a different model, and its step() would never update this one.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for data, labels in data_loader:
        optimizer.zero_grad()
        criterion(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    # Separate name so this experiment does not overwrite the trained `model`
    # used by the evaluation cells.
    hogwild_model = DeepConvNet()
    # NOTE: this is required for the ``fork`` method to work
    hogwild_model.share_memory()
    # NOTE(review): `train` is defined in the notebook's __main__, so this presumably
    # relies on the "fork" start method (Linux default); under "spawn" the child
    # processes cannot look `train` up — confirm on the target platform.
    processes = []
    for _ in range(num_processes):
        p = mp.Process(target=train, args=(hogwild_model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()