In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from tqdm.notebook import trange, tqdm

## Model: LeNet-style CNN for MNIST

In [2]:
class LeNet(nn.Module):
    """LeNet-style convolutional classifier.

    The network is a single ``nn.Sequential`` stack: three 5x5 convolutions
    with 6, 16 and 120 output channels. The first two are each followed by
    2x2 max-pooling and ReLU, the third by ReLU only. The result is flattened
    and passed to one linear layer.

    The linear layer takes exactly 120 features, so the last convolution has
    to produce a 1x1 map. A 32x32 input satisfies this
    (32 -> 28 -> 14 -> 10 -> 5 -> 1).

    Args:
        in_channels: number of channels of the input images.
        out_channels: number of scores produced per sample.
    """

    def __init__(self, in_channels = 1, out_channels = 10):
        super().__init__()
        # Layers are created in forward order so parameter indices inside the
        # Sequential (0, 3, 6, 9) stay stable.
        layers = [
            nn.Conv2d(in_channels, 6, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 120, kernel_size=5),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(120, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the linear layer's output."""
        scores = self.model(x)
        return scores

In [3]:
# DataLoader settings shared by the train and test loaders.
BATCH_SIZE = 256
NUM_WORKERS = 8

# Resize every image to 32x32 (the size the LeNet defined in this notebook
# needs for its final linear layer) and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
])

# MNIST is downloaded into the current directory on first use.
dataset_train = torchvision.datasets.MNIST(root = '.', train = True, download = True, transform = transform)
dataset_test = torchvision.datasets.MNIST(root = '.', train = False, download = True, transform = transform)

# Only the training split is shuffled. Accuracy does not depend on batch
# order, so the test loader keeps a fixed order.
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size = BATCH_SIZE, shuffle = True, num_workers = NUM_WORKERS)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS)

In [4]:
# One input channel and ten output scores, spelled out by keyword.
model = LeNet(in_channels = 1, out_channels = 10)

In [5]:
# Adam over every parameter of `model`, learning rate 1e-3.
optimizer = optim.Adam(params = model.parameters(), lr = 1e-3)

In [6]:
# Cross-entropy averaged over the batch ('mean' is the default reduction).
loss_func = nn.CrossEntropyLoss(reduction = 'mean')

In [7]:
def _evaluate(model, dataloader, device):
    """Count correct top-1 predictions of ``model`` over ``dataloader``.

    Puts the model in eval mode and runs without gradient tracking.

    Returns:
        ``(correct, total)`` as plain Python ints.
    """
    model.eval()
    correct = 0
    total = 0
    # No backward pass happens here, so skip building autograd graphs.
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            predictions = model(x).argmax(-1)
            # Reduce on the tensor, then pull out a single Python int. The
            # builtin sum() would iterate the batch element by element.
            correct += (predictions == y).sum().item()
            total += len(y)
    return correct, total


def train(model, epoch_num, dataloader_train, dataloader_test, optimizer, loss_func, use_cuda = True):
    """Train ``model`` and print its test accuracy after every epoch.

    Args:
        model: module to optimise; it is moved to the selected device in place.
        epoch_num: number of passes over ``dataloader_train``.
        dataloader_train: iterable of ``(x, y)`` training batches.
        dataloader_test: iterable of ``(x, y)`` batches used for the accuracy
            report at the end of each epoch.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_func: callable taking ``(model_output, y)`` and returning the loss
            tensor to backpropagate.
        use_cuda: request ``cuda:0``. If CUDA is not available, training falls
            back to the CPU instead of raising.

    Returns:
        The same ``model`` object, trained.
    """
    cuda_ok = use_cuda and torch.cuda.is_available()
    device = torch.device('cuda:0' if cuda_ok else 'cpu')
    model.to(device)

    for epoch in trange(epoch_num):
        data_iter = tqdm(dataloader_train)
        model.train()
        loss_sum = 0.0
        for batch_count, (x, y) in enumerate(data_iter, start = 1):
            x = x.to(device)
            y = y.to(device)
            loss = loss_func(model(x), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            # Show the running mean: a raw cumulative sum grows with the
            # number of batches and cannot be compared across runs.
            data_iter.set_postfix(loss = loss_sum / batch_count)

        correct, total = _evaluate(model, dataloader_test, device)
        # An empty test loader yields NaN rather than a division error.
        accuracy = correct / total * 100 if total else float('nan')
        print('Accuracy:{acc:.2f}%'.format(acc = accuracy))

    return model

In [8]:
# Run 30 epochs; `train` hands back the model object it was given.
model = train(
    model = model,
    epoch_num = 30,
    dataloader_train = dataloader_train,
    dataloader_test = dataloader_test,
    optimizer = optimizer,
    loss_func = loss_func,
)

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:96.11%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:97.39%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:97.94%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.33%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.40%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.50%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.87%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.80%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.92%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.00%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.00%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.96%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.98%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.94%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.09%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.05%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.09%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.07%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.07%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.13%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.00%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.06%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.09%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.91%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.09%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.10%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.16%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:98.99%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.01%


  0%|          | 0/235 [00:00<?, ?it/s]

Accuracy:99.17%
