In [2]:
'''
It's a mini CNN example of pytorch
Li Teng 
29.06.2021
'''

import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset

BATCH_SIZE = 256
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use GPU if available

class ConvNet(nn.Module):# inherit from nn.Module
    
    def __init__(self):#init the module
        super().__init__()
        self.conv1 = nn.Conv2d(1,10,5)
        self.conv2 = nn.Conv2d(10,20,3)
        self.fc1 = nn.Linear(2000, 500) #input 20*10*10, out 500
        self.fc2 = nn.Linear(500, 10) # 10 classes
        
    def forward(self,x):#forward propagation
        in_size = x.size(0) # batch*1*28*28
        x = self.conv1(x) # batch*1*28*28 -> batch*10*24*24
        x = f.max_pool2d(f.relu(x),2) # batch*10*24*24 -> batch*10*12*12
        x = self.conv2(x) # batch*10*12*12 ->batch*20*10*10
        x = f.relu(x)
        x = x.view(in_size,-1) # flat batch*20*10*10 -> batch*2000
        x = self.fc1(x) # batch*2000 -> batch*500
        x = f.relu(x)
        x = self.fc2(x) # batch*500 -> batch*10
        x = f.log_softmax(x,dim=1) # 'dim=1' means logsoftmax along 10 not batch
        return x


def data_process(): #data processing
    transform = transforms.Compose([transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
                                    transforms.Normalize((0.5,), (0.5,)) # range [0.0,1.0] -> [-1.0,1.0]
                                   ])

    data_train = dset.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)

    data_test = dset.MNIST(root = "./data/",
                           transform=transform,
                           download = True)

    data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                    batch_size = BATCH_SIZE,
                                                    shuffle = True)

    data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                                   batch_size = BATCH_SIZE,
                                                   shuffle = True)
    
    return data_loader_train, data_loader_test

def train(model, device, train_loader, optimizer):
    model.train() # model in train function
    Loss_sum = 0 # sum loss of the every batch
    Loss_avg = 0 # to comput average loss for each epoch
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)# copy data to device
        optimizer.zero_grad() # gradient to 0 before each batch
        output = model(data) # forward propagation
        loss = f.nll_loss(output, target) # loss_function = log_likelihood
        loss.backward() # compute the gradients
        optimizer.step() # update weights
        Loss_sum += loss.item()
    Loss_avg = Loss_sum/BATCH_SIZE
    print("Train: Loss =",Loss_avg)
    
def test(model, device, test_loader):
    model.eval()
    Loss = 0 
    Correct = 0 # count the number of correct prediction
    with torch.no_grad(): # dont need back propagation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            predict = output.max(1, keepdim=True)[1] # predict the label
            Loss += f.nll_loss(output, target, reduction='sum').item() # same as training, sum loss for every case in batch
            Correct += predict.eq(target.view_as(predict)).sum().item()
        Loss /= len(test_loader.dataset)
        Correct /= len(test_loader.dataset)
        print("Test:  Loss =",Loss,"Acc:",Correct)

model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())
Train_loader,Test_loader = data_process()

for epoch in range(1, EPOCHS + 1):
    print("Epoch:",epoch)
    train(model, DEVICE, Train_loader, optimizer)
    test(model, DEVICE, Test_loader)

Epoch: 1
Train: Loss = 0.24470341840060428
Test:  Loss = 0.08160858153502146 Acc: 0.9759166666666667
Epoch: 2
Train: Loss = 0.061096217403246555
Test:  Loss = 0.050274733527501426 Acc: 0.9847833333333333
Epoch: 3
Train: Loss = 0.039239700217876816
Test:  Loss = 0.03247400665084521 Acc: 0.9903666666666666
Epoch: 4
Train: Loss = 0.031234498935191368
Test:  Loss = 0.023515976093212765 Acc: 0.99315
Epoch: 5
Train: Loss = 0.023477920290133625
Test:  Loss = 0.0173862152757744 Acc: 0.9949666666666667
