In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

In [2]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

num_train = train_dataset.data.size()[0]
num_test = test_dataset.data.size()[0]

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [3]:
class myReshape(nn.Module):
    def __init__(self, sz):
        super().__init__()
        self.size = sz
        pass
    
    def forward(self, x):
        return torch.reshape(x, [x.size()[0], self.size])

class SimpleNet(nn.Module):
    def __init__(self, sz_input, sz_cv1, sz_mp1, sz_cv2, sz_mp2, sz_fc1, sz_fc2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=sz_input[1],
            out_channels=sz_cv1[1],
            kernel_size=sz_input[2]-sz_cv1[2]+1
            )
        self.maxp1 = nn.MaxPool2d(kernel_size=int(sz_cv1[2]/sz_mp1[2]))
        self.conv2 = nn.Conv2d(in_channels=sz_mp1[1],
                               out_channels=sz_cv2[1],
                               kernel_size=sz_mp1[2] - sz_cv2[2] + 1
                               )
        self.maxp2 = nn.MaxPool2d(kernel_size=int(sz_cv2[2] / sz_mp2[2]))
        self.rs = myReshape(sz_mp2[1]*sz_mp2[2]*sz_mp2[3])
        self.ln1 = nn.Linear(in_features=sz_mp2[1]*sz_mp2[2]*sz_mp2[3],
                out_features=sz_fc1[1])
        self.ac1 = nn.ReLU()
        self.ln2 = nn.Linear(in_features=sz_fc1[1],
                out_features=sz_fc2[1])
        self.ac2 = nn.ReLU()
        self.net = nn.Sequential(
            self.conv1, self.maxp1, self.conv2, self.maxp2,
            self.rs, self.ln1, self.ac1, self.ln2, self.ac2
        )
        pass

    def forward(self, x):
        y_pred = self.net(x)
        return y_pred
    
    def score(self, x, gt):
        y_pred = self.forward(x)
        [max_score, max_position] = torch.max(y_pred, dim=1)
        num_false = torch.count_nonzero(max_position - gt)
        return (gt.size()[0] - num_false).true_divide(gt.size()[0])

    def predict(self, x):
        x = torch.reshape(x, [1, 1, 28, 28])
        xmax = torch.max(x)
        xmin = torch.min(x)
        x = ((x - xmin) * 2).true_divide((xmax - xmin)) - 1
        y_pred = self.forward(x)
        [max_score, max_position] = torch.max(y_pred, dim=1)
        return max_position


# TODO:define model
size_input = [BATCH_SIZE, 1, 28, 28]
size_conv1 = [BATCH_SIZE, 3, 24, 24]
size_maxp1 = [BATCH_SIZE, 3, 12, 12]
size_conv2 = [BATCH_SIZE, 6, 8, 8]
size_maxp2 = [BATCH_SIZE, 6, 4, 4]
size_fc1 = [BATCH_SIZE, 32]
size_fc2 = [BATCH_SIZE, 10]
    
model = SimpleNet(sz_input=size_input,
    sz_cv1=size_conv1,
    sz_mp1=size_maxp1,
    sz_cv2=size_conv2,
    sz_mp2=size_maxp2,
    sz_fc1=size_fc1,
    sz_fc2=size_fc2,
    )

# TODO:define loss function
criterion = nn.CrossEntropyLoss()

In [4]:
lnrate = 1e-3
# train and evaluate
start_time = time.time()
train_time = 0

for epoch in range(NUM_EPOCHS):
    st_epoch = time.time()
    optimizer = torch.optim.SGD(model.net.parameters(), lr=lnrate)
    optimizer.zero_grad()
    lnrate = lnrate * 9 / 10
    num_false = 0

    for images, labels in tqdm(train_loader):
        # TODO:forward + backward + optimize
        y_pred = model.forward(images.data)
        loss = criterion(y_pred, labels.data)
        loss.backward()
        optimizer.step()

    dr_epoch = time.time() - st_epoch
    train_time = train_time + dr_epoch

    # evaluate
    # TODO:calculate training accuracy and test accuracy
    num_false = 0
    for images, labels in tqdm(train_loader):
        y_pred = model.forward(images.data)
        [max_score, max_position] = torch.max(y_pred, dim=1)
        num_false = num_false + torch.count_nonzero(max_position - labels.data)

    train_accuracy = (num_train - num_false).true_divide(num_train)

    num_false = 0
    for images, labels in tqdm(test_loader):
        y_pred = model.forward(images.data)
        [max_score, max_position] = torch.max(y_pred, dim=1)
        num_false = num_false + torch.count_nonzero(max_position - labels.data)

    test_accuracy = (num_test - num_false).true_divide(num_test)

    print('\nEpoch %d' % epoch)
    print('Training accuracy: %0.2f%%' % (train_accuracy*100))
    print('Testing accuracy: %0.2f%%' % (test_accuracy*100))

total_time = time.time() - start_time
print('Total training time: %ds' % train_time)
print('Total time, test included: %ds' % total_time)

100%|██████████| 468/468 [00:14<00:00, 31.94it/s]
100%|██████████| 468/468 [00:11<00:00, 41.67it/s]
100%|██████████| 78/78 [00:01<00:00, 42.23it/s]
  1%|          | 4/468 [00:00<00:14, 32.09it/s]
Epoch 0
Training accuracy: 80.26%
Testing accuracy: 80.21%
100%|██████████| 468/468 [00:13<00:00, 34.32it/s]
100%|██████████| 468/468 [00:11<00:00, 42.11it/s]
100%|██████████| 78/78 [00:01<00:00, 41.91it/s]
  1%|          | 4/468 [00:00<00:13, 35.49it/s]
Epoch 1
Training accuracy: 87.16%
Testing accuracy: 87.48%
100%|██████████| 468/468 [00:13<00:00, 34.77it/s]
100%|██████████| 468/468 [00:11<00:00, 41.26it/s]
100%|██████████| 78/78 [00:01<00:00, 41.21it/s]
  1%|          | 4/468 [00:00<00:13, 34.28it/s]
Epoch 2
Training accuracy: 90.28%
Testing accuracy: 90.52%
100%|██████████| 468/468 [00:13<00:00, 33.69it/s]
100%|██████████| 468/468 [00:11<00:00, 41.79it/s]
100%|██████████| 78/78 [00:01<00:00, 42.05it/s]
  1%|          | 4/468 [00:00<00:14, 32.09it/s]
Epoch 3
Training accuracy: 92.59%
Testi

In [5]:
from torchvision.transforms import ToPILImage
show = ToPILImage()
pic = test_dataset.data[19]
show(pic).show()
pred = model.predict(pic)
print('The number is predicted to be %d' % pred)

The number is predicted to be 4
