In [13]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.optim.lr_scheduler import StepLR

In [14]:
# Run on a CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def to_img(x, height=28, width=28):
    """Clamp ``x`` to [0, 1] and reshape it into a batch of single-channel images.

    Args:
        x: tensor whose first dimension is the batch; each sample must contain
           exactly ``height * width`` elements (e.g. flat 784-vectors for MNIST).
        height: image height in pixels (default 28, MNIST).
        width: image width in pixels (default 28, MNIST).

    Returns:
        Tensor of shape ``(x.size(0), 1, height, width)`` with values in [0, 1].
    """
    x = x.clamp(0, 1)
    # reshape instead of view: clamp may preserve a non-contiguous input's strides,
    # on which view() raises; reshape only copies when it has to.
    return x.reshape(x.size(0), 1, height, width)


# Hyper-parameters.
num_epochs = 100
batch_size = 256
learning_rate = 1e-3  # NOTE: not used by the Adadelta training cells, which pass lr=1

# ToTensor converts PIL images to float CxHxW tensors scaled to [0, 1].
img_transform = transforms.Compose([
    transforms.ToTensor()])

# Load the MNIST training and test splits (downloaded into ./data on first use).
train_data = MNIST(root='./data', train=True, download=True, transform=img_transform)
test_data = MNIST(root='./data', train=False, download=True, transform=img_transform)

# Shuffle only the training data. Evaluation results do not depend on order, and a
# fixed test order keeps evaluation reproducible without drawing from the global RNG.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [15]:
class Cons_lin_layer(nn.Module):
    """Linear layer (optionally followed by a nonlinearity) that preserves the input's p-norm.

    The output ``y = nonlin(Linear(x))`` is rescaled along the last dimension so that
    ``||y||_p == ||x||_p`` for every sample. The layer can redistribute the input's
    "energy" across features but cannot change its total.

    Args:
        n_in: size of the last input dimension.
        n_out: size of the last output dimension.
        nonlin: one of ``'tanh'``, ``'relu'``, ``'sigmoid'`` or ``'none'`` (no activation).
        p: exponent of the norm that is preserved (default 2).

    Raises:
        ValueError: if ``nonlin`` is not one of the supported names.
    """

    # Supported activation names -> module constructors.
    _NONLINEARITIES = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
    }

    def __init__(self, n_in, n_out, nonlin='tanh', p=2):
        super(Cons_lin_layer, self).__init__()
        # `lin` is registered both on its own and inside `layer`, so the state_dict
        # keys (lin.*, layer.*) stay the same as in earlier versions of this class.
        self.lin = nn.Linear(n_in, n_out)
        if nonlin == 'none':
            self.layer = self.lin
        else:
            if nonlin not in self._NONLINEARITIES:
                # Previously an unknown name left self.nonlin unset and failed later
                # with an opaque AttributeError.
                raise ValueError(
                    "unsupported nonlin {!r}; expected one of 'none', {}".format(
                        nonlin, ", ".join(repr(k) for k in self._NONLINEARITIES)))
            self.nonlin = self._NONLINEARITIES[nonlin]()
            self.layer = nn.Sequential(self.lin, self.nonlin)
        self.p = p

    def forward(self, x):
        """Apply the layer, then rescale each output row to the p-norm of its input row."""
        energy = x.norm(p=self.p, dim=-1, keepdim=True)
        y = self.layer(x)
        # F.normalize clamps the denominator at eps, so an all-zero y (e.g. every
        # ReLU unit inactive) stays zero instead of becoming NaN.
        return F.normalize(y, p=self.p, dim=-1) * energy

In [16]:
class Net(nn.Module):
    """MLP classifier for flattened 28x28 images built from norm-preserving layers.

    Three ReLU ``Cons_lin_layer`` blocks (784 -> 128 -> 64 -> 32) feed a final
    activation-free ``Cons_lin_layer`` with 10 outputs. ``forward`` returns
    log-probabilities over the 10 classes (log-softmax along dim 1).
    """

    def __init__(self):
        super(Net, self).__init__()
        widths = [28 * 28, 128, 64, 32]
        hidden = [Cons_lin_layer(fan_in, fan_out, nonlin='relu')
                  for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        head = Cons_lin_layer(widths[-1], 10, nonlin='none')
        self.encoder = nn.Sequential(*hidden, head)

    def forward(self, x):
        return F.log_softmax(self.encoder(x), dim=1)

In [27]:
class Net_simple(nn.Module):
    """Plain MLP baseline: 784 -> 128 -> 64 -> 32 -> 10 with ReLU between layers.

    ``forward`` expects flattened 28x28 inputs and returns log-probabilities over
    the 10 classes (log-softmax along dim 1).
    """

    def __init__(self):
        super(Net_simple, self).__init__()
        sizes = (28 * 28, 128, 64, 32, 10)
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
            if idx:
                # ReLU goes between linear layers, never before the first one.
                modules.append(nn.ReLU())
            modules.append(nn.Linear(fan_in, fan_out))
        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        logits = self.encoder(x)
        return F.log_softmax(logits, dim=1)

In [24]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one epoch of NLL-loss training over ``train_loader``.

    Each batch is flattened to ``(batch, -1)`` before the forward pass, so ``model``
    must accept flat feature vectors and return log-probabilities. A progress line
    is printed every 50 batches.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        flat, labels = inputs.to(device), labels.to(device)
        flat = flat.view(flat.size(0), -1)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(flat), labels)
        batch_loss.backward()
        optimizer.step()
        if step % 50 != 0:
            continue
        seen = step * len(inputs)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset),
            100. * step / len(train_loader), batch_loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            img, target = data.to(device), target.to(device)
            img = img.view(img.size(0), -1)
            output = model(img)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [25]:
# Train the norm-preserving network.
# NOTE: StepLR(step_size=1, gamma=0.7) shrinks the Adadelta lr geometrically every
# epoch (0.7**20 ~ 8e-4, 0.7**50 ~ 2e-8). After roughly 20 epochs the weights barely
# move; in the recorded run the test accuracy froze at 98.08%. Most of num_epochs
# therefore does nothing. Consider fewer epochs or a slower decay.
torch.manual_seed(0)  # fixed seed so initialization and batch order are reproducible
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()


Test set: Average loss: 0.2642, Accuracy: 9198/10000 (91.98%)


Test set: Average loss: 0.1414, Accuracy: 9571/10000 (95.71%)


Test set: Average loss: 0.0959, Accuracy: 9713/10000 (97.13%)


Test set: Average loss: 0.0837, Accuracy: 9750/10000 (97.5%)


Test set: Average loss: 0.0751, Accuracy: 9784/10000 (97.84%)


Test set: Average loss: 0.0725, Accuracy: 9794/10000 (97.94%)


Test set: Average loss: 0.0720, Accuracy: 9798/10000 (97.98%)


Test set: Average loss: 0.0712, Accuracy: 9803/10000 (98.03%)


Test set: Average loss: 0.0707, Accuracy: 9809/10000 (98.09%)


Test set: Average loss: 0.0697, Accuracy: 9813/10000 (98.13%)


Test set: Average loss: 0.0709, Accuracy: 9807/10000 (98.07%)


Test set: Average loss: 0.0709, Accuracy: 9812/10000 (98.12%)


Test set: Average loss: 0.0713, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0713, Accuracy: 9809/10000 (98.09%)


Test set: Average loss: 0.0714, Accuracy: 9807/10000 (98.07%)


Test set: Average loss: 0.0713, Accuracy


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)


Test set: Average loss: 0.0715, Accuracy: 9808/10000 (98.08%)



KeyboardInterrupt: 

In [28]:
# Train the plain-MLP baseline with Adadelta + StepLR.
# NOTE: StepLR(step_size=1, gamma=0.7) makes the lr negligible after roughly 20
# epochs (0.7**20 ~ 8e-4). In the recorded run the test accuracy froze at 97.74%,
# so most of num_epochs does nothing. Consider fewer epochs or a slower decay.
torch.manual_seed(0)  # fixed seed so initialization and batch order are reproducible
model2 = Net_simple().to(device)
optimizer2 = optim.Adadelta(model2.parameters(), lr=1)

scheduler2 = StepLR(optimizer2, step_size=1, gamma=0.7)
for epoch in range(1, num_epochs + 1):
    train(model2, device, train_loader, optimizer2, epoch)
    test(model2, device, test_loader)
    scheduler2.step()


Test set: Average loss: 0.2236, Accuracy: 9305/10000 (93.05%)


Test set: Average loss: 0.1482, Accuracy: 9554/10000 (95.54%)


Test set: Average loss: 0.1167, Accuracy: 9633/10000 (96.33%)


Test set: Average loss: 0.0992, Accuracy: 9697/10000 (96.97%)


Test set: Average loss: 0.0833, Accuracy: 9751/10000 (97.51%)


Test set: Average loss: 0.0827, Accuracy: 9737/10000 (97.37%)


Test set: Average loss: 0.0774, Accuracy: 9770/10000 (97.7%)


Test set: Average loss: 0.0760, Accuracy: 9760/10000 (97.6%)


Test set: Average loss: 0.0746, Accuracy: 9772/10000 (97.72%)


Test set: Average loss: 0.0736, Accuracy: 9767/10000 (97.67%)


Test set: Average loss: 0.0726, Accuracy: 9771/10000 (97.71%)


Test set: Average loss: 0.0726, Accuracy: 9775/10000 (97.75%)


Test set: Average loss: 0.0726, Accuracy: 9777/10000 (97.77%)


Test set: Average loss: 0.0726, Accuracy: 9773/10000 (97.73%)


Test set: Average loss: 0.0724, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0724, Accuracy:


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accurac


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)


Test set: Average loss: 0.0722, Accuracy: 9774/10000 (97.74%)



KeyboardInterrupt: 