In [22]:
import torch
import torchvision
import torch.nn.functional as F

In [23]:
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data',train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,),(0.3081, ))
                               ])),
    batch_size=200, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data',train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,),(0.3081, ))
                               ])),
    batch_size=200, shuffle=True)

In [30]:
w1 = torch.randn(200,784,requires_grad=True)
b1 = torch.randn(200,requires_grad=True)
w2 = torch.randn(200,200,requires_grad=True)
b2 = torch.randn(200,requires_grad=True)
w3 = torch.randn(10,200,requires_grad=True)
b3 = torch.randn(10,requires_grad=True)

In [25]:
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

tensor([[-0.1358, -0.0360, -0.0545,  ..., -0.2844, -0.0529,  0.0405],
        [ 0.0763, -0.0358, -0.0787,  ..., -0.0214,  0.0060,  0.1058],
        [-0.0387, -0.0824, -0.1132,  ...,  0.0272,  0.1101, -0.0306],
        ...,
        [-0.0434,  0.0549,  0.0980,  ..., -0.0154, -0.0412,  0.1141],
        [ 0.0676,  0.1171,  0.0294,  ...,  0.0701, -0.0571, -0.0504],
        [-0.0504,  0.0371,  0.0581,  ...,  0.0493, -0.0357, -0.0786]],
       requires_grad=True)

In [26]:
def forward(x):
    x = x@w1.t() + b1
    x = F.relu(x)
    x = x@w2.t() + b2
    x = F.relu(x)
    x = x@w3.t() + b3
    x = F.relu(x)
    return x

In [27]:
optimizer = torch.optim.SGD([w1,b1,w2,b2,w3,b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

In [28]:
for epoch in range(10):
    for batch_idx,(data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        logits = forward(data)
        loss = criterion(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100.*batch_idx / len(train_loader), loss.item()))
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28*28)
        logits = forward(data)
        test_loss += criterion(logits, target).item()
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set:Average loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset), 
        100. * correct / len(test_loader.dataset)))


Test set:Average loss:0.0086, Accuracy:4503/10000(45%)


Test set:Average loss:0.0058, Accuracy:6785/10000(67%)


Test set:Average loss:0.0045, Accuracy:7533/10000(75%)


Test set:Average loss:0.0039, Accuracy:7821/10000(78%)


Test set:Average loss:0.0036, Accuracy:7971/10000(79%)


Test set:Average loss:0.0034, Accuracy:8084/10000(80%)


Test set:Average loss:0.0032, Accuracy:8148/10000(81%)


Test set:Average loss:0.0031, Accuracy:8194/10000(81%)


Test set:Average loss:0.0030, Accuracy:8243/10000(82%)


Test set:Average loss:0.0029, Accuracy:8274/10000(82%)

