In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import time

In [2]:
USE_CUDA = torch.cuda.is_available()
torch.cuda.manual_seed(1) if USE_CUDA else torch.manual_seed(1)

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True} if USE_CUDA else {}

In [1]:
import json
import loguru

with open('../configs/cnn_configs.json', 'r') as f:
    configs = json.load(f)
    loguru.logger.info('Loaded the configuration file.')

BATCH_SIZE = configs['BATCH_SIZE']
LR = configs['LR']
MOMENTUM = configs['MOMENTUM']
EPOCHS = configs['EPOCHS']
LOG_INTERVAL = configs['LOG_INTERVAL']

[32m2024-04-05 13:30:12.569[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mLoaded the configuration file.[0m


In [5]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

In [20]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        loguru.logger.debug(f'params num: {sum(p.numel() for p in self.parameters() if p.requires_grad)}')

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [21]:
model = Net()
if USE_CUDA:
    model.cuda()

[32m2024-04-04 19:01:16.196[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m__init__[0m:[36m9[0m - [34m[1mparams num: 21840[0m


In [8]:
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

In [9]:
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if USE_CUDA:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [10]:
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if USE_CUDA:
            data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            data, target = Variable(data), Variable(target)
            output = model(data)
            test_loss += F.nll_loss(output, target).item()
            pred = output.data.max(1)[1] # get the index of the max log-probability
            correct += pred.eq(target.data).cpu().sum()
    
    test_loss = test_loss
    test_loss /= len(test_loader) # loss function already averages over batch size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


In [11]:
for epoch in range(1, EPOCHS + 1):
    init = time.time()
    train(epoch)
    ending = time.time()
    loguru.logger.debug(f'Time for the epoch: {ending - init:.4f}s.')
    test(epoch)



[32m2024-04-04 18:00:07.568[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 17.4860s.[0m



Test set: Average loss: 0.2244, Accuracy: 9348/10000 (93%)



[32m2024-04-04 18:00:21.672[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.6831s.[0m



Test set: Average loss: 0.1278, Accuracy: 9623/10000 (96%)



[32m2024-04-04 18:00:35.739[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.7143s.[0m



Test set: Average loss: 0.0912, Accuracy: 9704/10000 (97%)



[32m2024-04-04 18:00:49.847[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.7178s.[0m



Test set: Average loss: 0.0785, Accuracy: 9761/10000 (98%)



[32m2024-04-04 18:01:03.762[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.6005s.[0m



Test set: Average loss: 0.0720, Accuracy: 9777/10000 (98%)



[32m2024-04-04 18:01:17.863[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.7315s.[0m



Test set: Average loss: 0.0628, Accuracy: 9796/10000 (98%)



[32m2024-04-04 18:01:31.929[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.6869s.[0m



Test set: Average loss: 0.0575, Accuracy: 9812/10000 (98%)



[32m2024-04-04 18:01:45.861[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.5611s.[0m



Test set: Average loss: 0.0574, Accuracy: 9814/10000 (98%)



[32m2024-04-04 18:01:59.868[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.6024s.[0m



Test set: Average loss: 0.0491, Accuracy: 9840/10000 (98%)



[32m2024-04-04 18:02:13.953[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [34m[1mTime for the epoch: 10.7101s.[0m



Test set: Average loss: 0.0482, Accuracy: 9846/10000 (98%)



In [16]:
torch.save(model.state_dict(), '../model/cnn/mnist_cnn.pt')

In [19]:
torch.load('../model/cnn/mnist_cnn.pt', map_location=lambda storage, loc: storage)

OrderedDict([('conv1.weight',
              tensor([[[[-2.2801e-01, -2.0178e-01,  1.1392e-01,  1.4996e-01,  6.7549e-02],
                        [-2.4287e-01, -8.0902e-03,  4.7206e-02, -8.0363e-02, -7.8436e-02],
                        [-2.5816e-01, -2.2800e-01, -1.5778e-01, -6.3274e-02, -1.7308e-01],
                        [-1.3900e-01,  1.3632e-02, -5.6343e-02, -1.5601e-01, -2.3128e-01],
                        [-1.1529e-01, -1.5112e-01, -1.7182e-01, -1.0877e-01,  5.7885e-02]]],
              
              
                      [[[ 2.6658e-01,  4.1212e-01,  4.6037e-01,  2.5170e-01,  8.1825e-02],
                        [ 1.3873e-01,  2.6663e-01,  5.7617e-02, -9.3925e-02, -9.6656e-02],
                        [ 2.5818e-01, -1.6053e-01, -3.8164e-01, -1.0151e-01, -2.5368e-01],
                        [-1.1415e-01, -3.3096e-01, -1.5438e-01, -2.7884e-01,  1.0435e-01],
                        [-1.9238e-01, -3.3290e-01, -5.2033e-02,  1.6483e-01,  2.0401e-01]]],
              
           