In [1]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms
from torchinfo import summary
from tqdm import tqdm

In [3]:
# Number of samples per mini-batch; shared by the train and test DataLoaders.
BATCH_SIZE = 128

In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Use CUDA? {use_cuda}")

Use CUDA? True


In [5]:
# Single-channel mean / std applied by Normalize after ToTensor().  These are
# the values conventionally quoted for the MNIST training set; naming them once
# keeps the train and test pipelines from drifting apart.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# The two pipelines are currently identical but remain separate objects so that
# augmentation can later be added to the training side only.
train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)]
)

test_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)]
)

In [6]:
# Both splits live under the same root directory; `download=True` fetches the
# raw files into it when they are not already present.
train_dataset = datasets.MNIST(
    root="../data",
    train=True,
    transform=train_transforms,
    download=True,
)

test_dataset = datasets.MNIST(
    root="../data",
    train=False,
    transform=test_transforms,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 107764653.91it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 21889355.59it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25578507.93it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17240297.53it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [7]:
# Seed PyTorch's global RNG so that batch shuffling is repeatable across runs.
torch.manual_seed(1)

# A background worker and pinned host memory are only requested when batches
# will be copied to a GPU; on CPU the DataLoader defaults are used.
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

# Training batches are reshuffled every epoch.
train_loader = data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, **kwargs
)

# Evaluation only accumulates a summed loss and a correct-prediction count over
# the whole set, so sample order is irrelevant and shuffling is skipped.
test_loader = data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, **kwargs
)

In [8]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over every batch in ``train_loader``.

    Args:
        model: Module to optimise; it is switched to training mode first.
            Its output is fed to ``F.nll_loss``, so it is presumably expected
            to produce log-probabilities.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(inputs, labels)`` pairs.
        optimizer: Optimiser whose gradients are cleared and stepped per batch.
        epoch: Value shown in the progress-bar description (display only).

    Returns:
        None. Progress is reported through the tqdm bar description.
    """
    model.train()
    progress = tqdm(train_loader)

    for step, batch in enumerate(progress):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Gradients accumulate by default, so clear them before each backward pass.
        optimizer.zero_grad()
        log_probs = model(inputs)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()

        progress.set_description(
            desc=f"Epoch={epoch} loss={batch_loss.item()} batch_id={step}"
        )


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )

In [9]:
class Net(nn.Module):
    """Compact four-convolution CNN mapping 1x28x28 inputs to 10 log-probabilities.

    Layout: two conv -> ReLU -> batch-norm stages, a 2x2 max-pool with dropout,
    two more conv -> ReLU -> batch-norm stages, global average pooling with
    dropout, and a linear classifier followed by ``log_softmax``.
    """

    def __init__(self):
        super().__init__()

        # Regularisation, normalisation and pooling layers.
        self.drop1 = nn.Dropout(p=0.1)
        self.drop2 = nn.Dropout(p=0.1)
        self.batch_norm1 = nn.BatchNorm2d(num_features=32)
        self.batch_norm2 = nn.BatchNorm2d(num_features=16)
        self.batch_norm3 = nn.BatchNorm2d(num_features=16)
        self.batch_norm4 = nn.BatchNorm2d(num_features=64)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

        # Convolutions: no padding, stride 1, channels 1 -> 32 -> 16 -> 16 -> 64.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=3)
        self.conv4 = nn.Conv2d(16, 64, kernel_size=3)

        # Classifier over the 64 globally pooled channel activations.
        self.fc = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1x28x28 inputs.

        Receptive-field bookkeeping (rf = receptive field, n = spatial size,
        j = jump, s = stride, p = padding, k = kernel size):

        ----------------------------------------------------------------------
        | Layer   | rf_in | n_in | j_in | s | p | k | rf_out | n_out | j_out |
        |---------|-------|------|------|---|---|---|--------|-------|-------|
        | conv1   | 1     | 28   | 1    | 1 | 0 | 5 | 5      | 24    | 1     |
        | relu    | -     | -    | -    | - | - | - | -      | -     | -     |
        | bn      | -     | -    | -    | - | - | - | -      | -     | -     |
        | conv2   | 5     | 24   | 1    | 1 | 0 | 3 | 7      | 22    | 1     |
        | relu    | -     | -    | -    | - | - | - | -      | -     | -     |
        | bn      | -     | -    | -    | - | - | - | -      | -     | -     |
        | maxpool | 7     | 22   | 1    | 2 | 0 | 2 | 8      | 11    | 2     |
        | drop    | -     | -    | -    | - | - | - | -      | -     | -     |
        | conv3   | 8     | 11   | 2    | 1 | 0 | 3 | 12     | 9     | 2     |
        | relu    | -     | -    | -    | - | - | - | -      | -     | -     |
        | bn      | -     | -    | -    | - | - | - | -      | -     | -     |
        | conv4   | 12    | 9    | 2    | 1 | 0 | 3 | 16     | 7     | 2     |
        | relu    | -     | -    | -    | - | - | - | -      | -     | -     |
        | bn      | -     | -    | -    | - | - | - | -      | -     | -     |
        | gap     | -     | -    | -    | - | - | - | -      | -     | -     |
        | drop    | -     | -    | -    | - | - | - | -      | -     | -     |
        | fc      | -     | -    | -    | - | - | - | -      | -     | -     |
        ----------------------------------------------------------------------

        RF of the last convolution = 16. The global average pool that follows
        spans the whole 7x7 map, so the pooled features see 16 + 6 * 2 = 28
        input pixels per side.
        """
        # Stage 1: 28x28x1 => 24x24x32 => 22x22x16
        x = self.batch_norm1(F.relu(self.conv1(x)))
        x = self.batch_norm2(F.relu(self.conv2(x)))

        # Downsample: 22x22x16 => 11x11x16
        x = self.drop1(F.max_pool2d(x, kernel_size=2, stride=2))

        # Stage 2: 11x11x16 => 9x9x16 => 7x7x64
        x = self.batch_norm3(F.relu(self.conv3(x)))
        x = self.batch_norm4(F.relu(self.conv4(x)))

        # Collapse the spatial dimensions: 7x7x64 => 1x1x64
        x = self.drop2(self.gap(x))

        # Flatten to (batch, 64) for the classifier.
        x = x.reshape(-1, self.fc.in_features)
        logits = self.fc(x)  # 64 => 10

        return F.log_softmax(logits, dim=1)

In [10]:
# Instance used here to print the layer-by-layer summary; the training cell
# further down constructs its own fresh model.
model = Net().to(device)
# input_size is one single-channel 28x28 image with a leading batch dimension.
summary(model, input_size=(1, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
Net                                      [1, 10]                   --
├─Conv2d: 1-1                            [1, 32, 24, 24]           832
├─BatchNorm2d: 1-2                       [1, 32, 24, 24]           64
├─Conv2d: 1-3                            [1, 16, 22, 22]           4,624
├─BatchNorm2d: 1-4                       [1, 16, 22, 22]           32
├─Dropout: 1-5                           [1, 16, 11, 11]           --
├─Conv2d: 1-6                            [1, 16, 9, 9]             2,320
├─BatchNorm2d: 1-7                       [1, 16, 9, 9]             32
├─Conv2d: 1-8                            [1, 64, 7, 7]             9,280
├─BatchNorm2d: 1-9                       [1, 64, 7, 7]             128
├─AdaptiveAvgPool2d: 1-10                [1, 64, 1, 1]             --
├─Dropout: 1-11                          [1, 64, 1, 1]             --
├─Linear: 1-12                           [1, 10]                   650
Tot

In [11]:
# Optimiser and schedule hyper-parameters, named so they can be tuned in one place.
LEARNING_RATE = 0.01
MOMENTUM = 0.9
NUM_EPOCHS = 19  # epochs are numbered 1..NUM_EPOCHS in the progress output

# Start from freshly initialised weights rather than reusing an earlier instance.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, NUM_EPOCHS + 1):
    train(model, device, train_loader, optimizer, epoch)
    # Evaluate after every epoch to track loss and accuracy as training progresses.
    test(model, device, test_loader)

Epoch=1 loss=0.7818135023117065 batch_id=78: 100%|██████████| 79/79 [00:07<00:00, 11.00it/s]



Test set: Average loss: 0.6988, Accuracy: 8614/10000 (86.14%)



Epoch=2 loss=0.6411963105201721 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.46it/s]



Test set: Average loss: 0.2598, Accuracy: 9579/10000 (95.79%)



Epoch=3 loss=0.25512993335723877 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.49it/s]



Test set: Average loss: 0.1505, Accuracy: 9665/10000 (96.65%)



Epoch=4 loss=0.17269328236579895 batch_id=78: 100%|██████████| 79/79 [00:04<00:00, 18.87it/s]



Test set: Average loss: 0.1035, Accuracy: 9788/10000 (97.88%)



Epoch=5 loss=0.35884496569633484 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.57it/s]



Test set: Average loss: 0.0991, Accuracy: 9767/10000 (97.67%)



Epoch=6 loss=0.13317173719406128 batch_id=78: 100%|██████████| 79/79 [00:04<00:00, 17.22it/s]



Test set: Average loss: 0.0647, Accuracy: 9865/10000 (98.65%)



Epoch=7 loss=0.2865554094314575 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.36it/s]



Test set: Average loss: 0.0584, Accuracy: 9878/10000 (98.78%)



Epoch=8 loss=0.12529927492141724 batch_id=78: 100%|██████████| 79/79 [00:04<00:00, 17.84it/s]



Test set: Average loss: 0.0610, Accuracy: 9859/10000 (98.59%)



Epoch=9 loss=0.11400628834962845 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.73it/s]



Test set: Average loss: 0.0478, Accuracy: 9894/10000 (98.94%)



Epoch=10 loss=0.17284023761749268 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 20.40it/s]



Test set: Average loss: 0.0393, Accuracy: 9919/10000 (99.19%)



Epoch=11 loss=0.03258098289370537 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.53it/s]



Test set: Average loss: 0.0341, Accuracy: 9938/10000 (99.38%)



Epoch=12 loss=0.033750422298908234 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.41it/s]



Test set: Average loss: 0.0297, Accuracy: 9951/10000 (99.51%)



Epoch=13 loss=0.08527274429798126 batch_id=78: 100%|██████████| 79/79 [00:04<00:00, 18.57it/s]



Test set: Average loss: 0.0286, Accuracy: 9949/10000 (99.49%)



Epoch=14 loss=0.06614154577255249 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.80it/s]



Test set: Average loss: 0.0293, Accuracy: 9944/10000 (99.44%)



Epoch=15 loss=0.33717823028564453 batch_id=78: 100%|██████████| 79/79 [00:05<00:00, 15.55it/s]



Test set: Average loss: 0.0591, Accuracy: 9824/10000 (98.24%)



Epoch=16 loss=0.02594837173819542 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.69it/s]



Test set: Average loss: 0.0243, Accuracy: 9959/10000 (99.59%)



Epoch=17 loss=0.039352141320705414 batch_id=78: 100%|██████████| 79/79 [00:04<00:00, 19.13it/s]



Test set: Average loss: 0.0209, Accuracy: 9963/10000 (99.63%)



Epoch=18 loss=0.01776367798447609 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.52it/s]



Test set: Average loss: 0.0258, Accuracy: 9950/10000 (99.50%)



Epoch=19 loss=0.028158485889434814 batch_id=78: 100%|██████████| 79/79 [00:03<00:00, 21.98it/s]



Test set: Average loss: 0.0172, Accuracy: 9979/10000 (99.79%)

