In [None]:
'''
@ author: haijun xiong
@ date  : 2021/9/26
'''
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.autograd as autograd

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN producing 10-way logits for single-channel images.

    Pipeline:
        conv(1->6, 5x5, padding 2) -> ReLU -> maxpool(2)
        conv(6->16, 5x5)           -> ReLU -> maxpool(2)
        flatten -> fc(400->120) -> ReLU -> fc(120->84) -> ReLU -> fc(84->10)

    ``fc1`` expects 16*5*5 features, which corresponds to 28x28 inputs:
    28 -> conv1 (padded) 28 -> pool 14 -> conv2 10 -> pool 5.
    The final layer has no activation, so ``forward`` returns raw logits.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as before so that a seeded
        # initialisation yields the same weights.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # One parameter-free ReLU module is reused for every activation.
        self.AF = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Feature extractor: two conv -> ReLU -> pool stages.
        for conv in (self.conv1, self.conv2):
            x = self.pooling(self.AF(conv(x)))
        # Collapse everything except the batch dimension.
        x = x.view(x.size(0), -1)
        # Classifier head: hidden layers are activated, the output layer is not.
        for hidden in (self.fc1, self.fc2):
            x = self.AF(hidden(x))
        return self.fc3(x)

In [None]:
class Accumulator:
    """Running sums over a fixed number of float slots.

    Typical use is tallying several counters (e.g. correct predictions and
    sample count) across batches.

    Args:
        n: number of slots; each starts at 0.0.
    """

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add ``float(args[i])`` to slot ``i`` for every slot.

        Raises:
            ValueError: if the number of values differs from the number of
                slots. Without this check ``zip`` would silently shrink the
                slot list (too few values) or drop values (too many).
        """
        if len(args) != len(self.data):
            raise ValueError(
                "expected {} values, got {}".format(len(self.data), len(args)))
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Set every slot back to 0.0, keeping the number of slots."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        """Return the current value of slot ``idx`` (plain list indexing)."""
        return self.data[idx]

In [None]:
class Model:
    """Train and evaluate a LeNet5 classifier with SGD + momentum.

    Args:
        train_data: iterable of ``(X, y)`` batches used by :meth:`fit`
            (e.g. a ``torch.utils.data.DataLoader``).
        test_data: iterable of ``(X, y)`` batches used by
            :meth:`evaluate_accuracy`.
        lr: SGD learning rate.
        epoch: number of passes over ``train_data`` in :meth:`fit`.
        device: device for the network and batches. Defaults to ``'cuda'``
            when CUDA is available and ``'cpu'`` otherwise, so the class no
            longer fails on CPU-only machines.
    """

    def __init__(self, train_data, test_data, lr, epoch, device=None):
        self.train_loadr = train_data
        self.test_loadr = test_data
        self.lr = lr
        self.epoch = epoch
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.net = LeNet5().float().to(self.device)
        self.loss = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr, momentum=0.9)

    def Var(self, x):
        """Return tensor ``x`` moved to the model's device.

        ``autograd.Variable`` is deprecated (plain tensors carry autograd
        state), so a ``.to`` is all that is needed.
        """
        return x.to(self.device)

    def fit(self):
        """Train for ``self.epoch`` epochs, printing test accuracy after each."""
        for i in range(self.epoch):
            self.net.train()
            for X, y in self.train_loadr:
                X, y = self.Var(X), self.Var(y)
                self.optimizer.zero_grad()
                l = self.loss(self.net(X), y)
                l.backward()
                self.optimizer.step()
            test_acc = self.evaluate_accuracy()
            print("epoch:{}, test_acc:{}".format(i, test_acc))

    def evaluate_accuracy(self):
        """Return the fraction of samples in ``self.test_loadr`` classified correctly.

        The network is switched to eval mode and left there (:meth:`fit`
        re-enables train mode each epoch). Forward passes run under
        ``torch.no_grad()`` so no autograd graph is built. Returns ``nan``
        if the loader yields no samples.
        """
        self.net.eval()
        # Count on the device so there is no host sync per batch.
        correct = torch.zeros((), dtype=torch.long, device=self.device)
        total = 0
        with torch.no_grad():
            for X, y in self.test_loadr:
                X, y = self.Var(X), self.Var(y)
                y_hat = self.net(X)
                # Multi-column output is treated as per-class scores.
                if y_hat.dim() > 1 and y_hat.shape[1] > 1:
                    y_hat = y_hat.argmax(dim=1)
                correct += (y_hat.type(y.dtype) == y).sum()
                total += y.size(0)
        if total == 0:
            return float('nan')
        return correct.item() / total

In [None]:
# Reproducibility: seeds the global torch RNG, which drives the training
# loader's shuffle order and (since the Model cell runs later) LeNet5's weight
# initialisation. GPU kernels may still introduce some nondeterminism.
SEED = 0
torch.manual_seed(SEED)

batch_size = 256

# (0.1307, 0.3081): the commonly used MNIST pixel mean/std; same transform
# for both splits.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Evaluation is batched (batch_size=1 would mean 10,000 single-image forward
# passes and device transfers per evaluation) and unshuffled: sample order
# cannot change the accuracy, so shuffling would be wasted work.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

In [None]:
# Hyperparameters for this training run.
LEARNING_RATE = 1e-2
NUM_EPOCHS = 10

model = Model(train_loader, test_loader, lr=LEARNING_RATE, epoch=NUM_EPOCHS)
model.fit()

In [None]:
# Final test-set accuracy of the trained model, shown as the cell output.
final_test_acc = model.evaluate_accuracy()
final_test_acc