In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch import autograd

from utils_2 import mnist

In [21]:
train_loader, test_loader = mnist()

In [22]:
def CrossEntropy(x, y):
    m = y.shape[0]
    p = F.softmax(x, dim=1)
    log_likelihood = -1*torch.log(p[range(m),y])
    loss = torch.sum(log_likelihood) / m
    return loss

In [23]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.optim = optim.Adam(self.parameters(), lr=0.01)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.sigmoid(self.fc1(x))
        x = F.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def loss(self, output, target, **kwargs):
        self._loss = CrossEntropy(output, target)
        return self._loss

In [24]:
class Net_(nn.Module):
    def __init__(self):
        super(Net_, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.optim = optim.Adam(self.parameters(), lr=0.01)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.sigmoid(self.fc1(x))
        x = F.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
    
    def loss(self, output, target, **kwargs):
        self._loss = F.nll_loss(output, target)
        return self._loss

In [25]:
def train(epoch, models):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data, target
        for model in models:
            model.optim.zero_grad()
            output = model(data)
            loss = model.loss(output, target)
            loss.backward()
            model.optim.step()
            
        if batch_idx % 200 == 0:
            line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader))
            losses = ' '.join(['{}: {:.6f}'.format(i, m._loss.item()) for i, m in enumerate(models)])
            print(line + losses)
            
    else:
        batch_idx += 1
        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader))
        losses = ' '.join(['{}: {:.6f}'.format(i, m._loss.item()) for i, m in enumerate(models)])
        print(line + losses)

In [26]:
models = [Net(), Net_()]

In [27]:
avg_lambda = lambda l: 'Loss: {:.4f}'.format(l)
acc_lambda = lambda c, p: 'Accuracy: {}/{} ({:.0f}%)'.format(c, len(test_loader.dataset), p)
line = lambda i, l, c, p: '{}: '.format(i) + avg_lambda(l) + '\t' + acc_lambda(c, p)

def test(models):
    test_loss = [0]*len(models)
    correct = [0]*len(models)
    with torch.no_grad():
        for data, target in test_loader:
            output = [m(data) for m in models]
            for i, m in enumerate(models):
                test_loss[i] += m.loss(output[i], target, size_average=False).item() # sum up batch loss
                pred = output[i].data.max(1, keepdim=True)[1] # get the index of the max log-probability
                correct[i] += pred.eq(target.data.view_as(pred)).cpu().sum()
    
    for i in range(len(models)):
        test_loss[i] /= len(test_loader.dataset)
    correct_pct = [100. * c / len(test_loader.dataset) for c in correct]
    lines = '\n'.join([line(i, test_loss[i], correct[i], correct_pct[i]) for i in range(len(models))]) + '\n'
    report = 'Test set:\n' + lines
    
    print(report)

In [None]:
for epoch in range(1, 21):
    train(epoch, models)
    test(models)

Test set:
0: Loss: 0.0043	Accuracy: 9328/10000 (93%)
1: Loss: 0.0043	Accuracy: 9315/10000 (93%)

Test set:
0: Loss: 0.0039	Accuracy: 9411/10000 (94%)
1: Loss: 0.0038	Accuracy: 9407/10000 (94%)

Test set:
0: Loss: 0.0037	Accuracy: 9405/10000 (94%)
1: Loss: 0.0038	Accuracy: 9428/10000 (94%)

Test set:
0: Loss: 0.0036	Accuracy: 9430/10000 (94%)
1: Loss: 0.0036	Accuracy: 9436/10000 (94%)

