In [3]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.autograd import Variable
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

class SimpleNet(torch.nn.Module):
    # TODO:define model
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 576)
        self.bc1 = nn.BatchNorm1d(576)

        self.fc2 = nn.Linear(576, 324)
        self.bc2 = nn.BatchNorm1d(324)

        self.fc3 = nn.Linear(324, 144)
        self.bc3 = nn.BatchNorm1d(144)

        self.fc4 = nn.Linear(144, 10)

    def forward(self, x):
        x = x.view((-1, 784))
        h = self.fc1(x)
        h = self.bc1(h)
        h = nn.functional.relu(h)
        h = nn.functional.dropout(h, p=0.5, training=self.training)  

        h = self.fc2(h)
        h = self.bc2(h)
        h = nn.functional.relu(h)
        h = nn.functional.dropout(h, p=0.2, training=self.training)  

        h = self.fc3(h)
        h = self.bc3(h)
        h = nn.functional.relu(h)
        h = nn.functional.dropout(h, p=0.1, training=self.training)  

        h = self.fc4(h)
        out = nn.functional.log_softmax(h, dim=0)
        return out

model = SimpleNet()

# TODO:define loss function and optimiter
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# train and evaluate
for epoch in range(NUM_EPOCHS):
    train_loss = 0
    train_acc = 0
    for images, labels in tqdm(train_loader):
        # TODO:forward + backward + optimize
        optimizer.zero_grad()  
        out = model(images)  
        lossvalue = criterion(out, labels) 
        optimizer.zero_grad()  
        lossvalue.backward()  
        optimizer.step() 
        train_loss += float(lossvalue)
        _, pred = out.max(1)
        num_correct = (pred == labels).sum()
        acc = int(num_correct) / images.shape[0]
        train_acc += acc
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
    eval_loss = 0
    eval_acc = 0
    model.eval()  
    for images, labels in tqdm(test_loader):
        images = images.view(-1, 784)
        testout = model(images)
        testloss = criterion(testout, labels)
        eval_loss += float(testloss)

        _, pred = testout.max(1)
        num_correct = (pred == labels).sum()
        acc = int(num_correct) / images.shape[0]
        eval_acc += acc
        
    train_loss = train_loss / len(train_loader)
    train_acc = train_acc / len(train_loader)
    eval_loss = eval_loss / len(test_loader)
    eval_acc = eval_acc / len(test_loader)
    print("[Epoch: %d] Train Loss: %5.5f Train Accuracy: %5.5f" % (epoch + 1, train_loss, train_acc))
    print("[Epoch: %d] Test Loss: %5.5f Test Accuracy: %5.5f" % (epoch + 1, eval_loss, eval_acc))

print('Training Accuracy: %.2f%%' % (train_acc * 100))
print('Testing Accuracy: %.2f%%' % (eval_acc * 100))

100%|██████████| 468/468 [00:21<00:00, 21.62it/s]
100%|██████████| 78/78 [00:01<00:00, 50.38it/s]
  1%|          | 3/468 [00:00<00:17, 26.69it/s]

[Epoch: 1] Train Loss: 0.78197 Train Accuracy: 0.82068
[Epoch: 1] Test Loss: 0.27383 Test Accuracy: 0.93610


100%|██████████| 468/468 [00:17<00:00, 27.49it/s]
100%|██████████| 78/78 [00:01<00:00, 40.80it/s]
  1%|          | 3/468 [00:00<00:16, 28.65it/s]

[Epoch: 2] Train Loss: 0.20731 Train Accuracy: 0.93718
[Epoch: 2] Test Loss: 0.14653 Test Accuracy: 0.95743


100%|██████████| 468/468 [00:16<00:00, 29.02it/s]
100%|██████████| 78/78 [00:01<00:00, 53.09it/s]
  1%|          | 4/468 [00:00<00:15, 29.82it/s]

[Epoch: 3] Train Loss: 0.13008 Train Accuracy: 0.96120
[Epoch: 3] Test Loss: 0.11232 Test Accuracy: 0.96575


100%|██████████| 468/468 [00:17<00:00, 27.03it/s]
100%|██████████| 78/78 [00:01<00:00, 47.32it/s]
  1%|          | 3/468 [00:00<00:16, 28.26it/s]

[Epoch: 4] Train Loss: 0.09777 Train Accuracy: 0.97067
[Epoch: 4] Test Loss: 0.09278 Test Accuracy: 0.97115


100%|██████████| 468/468 [00:16<00:00, 28.12it/s]
100%|██████████| 78/78 [00:01<00:00, 50.68it/s]
  1%|          | 3/468 [00:00<00:16, 28.13it/s]

[Epoch: 5] Train Loss: 0.07735 Train Accuracy: 0.97574
[Epoch: 5] Test Loss: 0.08630 Test Accuracy: 0.97306


100%|██████████| 468/468 [00:16<00:00, 27.54it/s]
100%|██████████| 78/78 [00:01<00:00, 50.69it/s]
  1%|          | 3/468 [00:00<00:16, 27.73it/s]

[Epoch: 6] Train Loss: 0.06416 Train Accuracy: 0.98002
[Epoch: 6] Test Loss: 0.08068 Test Accuracy: 0.97546


100%|██████████| 468/468 [00:16<00:00, 28.79it/s]
100%|██████████| 78/78 [00:01<00:00, 50.60it/s]
  1%|          | 3/468 [00:00<00:17, 26.85it/s]

[Epoch: 7] Train Loss: 0.05403 Train Accuracy: 0.98302
[Epoch: 7] Test Loss: 0.07159 Test Accuracy: 0.97766


100%|██████████| 468/468 [00:16<00:00, 28.35it/s]
100%|██████████| 78/78 [00:01<00:00, 50.82it/s]
  1%|          | 3/468 [00:00<00:17, 26.92it/s]

[Epoch: 8] Train Loss: 0.04605 Train Accuracy: 0.98506
[Epoch: 8] Test Loss: 0.06398 Test Accuracy: 0.98057


100%|██████████| 468/468 [00:16<00:00, 28.40it/s]
100%|██████████| 78/78 [00:01<00:00, 50.85it/s]
  1%|          | 3/468 [00:00<00:16, 28.24it/s]

[Epoch: 9] Train Loss: 0.03979 Train Accuracy: 0.98723
[Epoch: 9] Test Loss: 0.07154 Test Accuracy: 0.97766


100%|██████████| 468/468 [00:16<00:00, 28.38it/s]
100%|██████████| 78/78 [00:01<00:00, 50.52it/s]

[Epoch: 10] Train Loss: 0.03542 Train Accuracy: 0.98831
[Epoch: 10] Test Loss: 0.07456 Test Accuracy: 0.97867
Training Accuracy: 98.83%
Testing Accuracy: 97.87%



