In [1]:
from IPython import display

import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset

import numpy as np

import time
from visdom import Visdom

from lib.VisdomWrapper import *
from lib.DataManager import *

In [2]:
batch_size = 512
num_epochs = 30

loss = nn.CrossEntropyLoss()
vis = VisdomController()

mnist = dset.MNIST('input', train=True, download=True, transform=T.ToTensor())
balanced_train = DataLoader(mnist, batch_size =batch_size, shuffle=True)

mnist_low_zero = get_unbalanced_mnist([.05, .15, .1, .1, .1, .1, .1, .1, .1, .1], batch_size=batch_size)

mnist_test=dset.MNIST('input', train=False, download=True, transform=T.ToTensor())
balanced_test = DataLoader(mnist_test)

Setting up a new session...


In [3]:
def build_classifier():
    return nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=5, stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=5, stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(4*4*64, 4*4*64, bias=True),
        nn.LeakyReLU(0.01),
        nn.Linear(4*4*64, 10, bias=True)
    )

In [9]:
def get_optimizer(model):
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    return optimizer

In [5]:
def train_classifier(classifier, optimizer, data, key="Loss"):
    for epoch in range(num_epochs):
        for n_batch, (x, y) in enumerate(data):
            if len(x) != batch_size:
                continue
            optimizer.zero_grad
            x = x.cuda()
            y = y.cuda()
            scores = classifier(x)
            out = loss(scores, y)
            out.backward()
            optimizer.step()
        display.clear_output(True)
        print("Epoch {}, {} / {}".format(epoch, n_batch, len(data)))
        print("Loss: ", out.item())
        vis.loss_axis = epoch
        vis.PlotLoss(key, out.item())
            

In [6]:
def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()


    print('Accuracy: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [10]:
balanced_net = build_classifier().cuda()
balanced_net.train()
train_classifier(balanced_net, get_optimizer(balanced_net), balanced_train, key="Balanced")


Epoch 29, 117 / 118
Loss:  0.06898085027933121


In [12]:
low_zero_net = build_classifier().cuda()
low_zero_net.train()
train_classifier(low_zero_net, get_optimizer(low_zero_net), mnist_low_zero, key="Low Zero")

Epoch 29, 117 / 118
Loss:  0.04684881865978241


In [11]:
device = torch.device("cuda")
balanced_net.eval()
test(balanced_net, device, balanced_test)

Accuracy: 9675/10000 (97%)



In [13]:
device = torch.device("cuda")
low_zero_net.eval()
test(low_zero_net, device, balanced_test)

Accuracy: 8854/10000 (89%)

