In [1]:
#!/usr/bin/env python
import torch
import torchvision.datasets
import torchvision.transforms as transforms
from torch.utils.data import dataset
from torch import nn
# from torch.nn.modules import upsampling
# from torch.functional import F
from torch.optim import Adam

# Preprocessing applied to every SVHN image: convert to a float tensor, then
# normalize each of the three channels as (x - 0.5) / 0.5. For inputs in
# [0, 1] (what ToTensor produces from 8-bit images) this yields [-1, 1].
image_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


def get_data_loader(dataset_location, batch_size, seed=None):
    """Build SVHN train / validation / test data loaders.

    The official SVHN ``train`` split is divided 90% / 10% into training and
    validation subsets; the official ``test`` split is used unchanged. Both
    splits are downloaded into ``dataset_location`` if missing and are
    preprocessed with the module-level ``image_transform``.

    Args:
        dataset_location: Directory where SVHN is stored / downloaded.
        batch_size: Batch size used by all three loaders.
        seed: Optional integer seed for the train/validation split. ``None``
            (the default) keeps the original behavior of drawing the split
            from torch's global RNG, so the validation set differs between
            runs unless the caller seeds torch beforehand.

    Returns:
        Tuple ``(trainloader, validloader, testloader)``. Only the training
        loader shuffles and uses worker processes.
    """
    def load_split(split):
        # Download (if needed) and wrap one official SVHN split.
        return torchvision.datasets.SVHN(
            dataset_location, split=split,
            download=True,
            transform=image_transform
        )

    trainvalid = load_split('train')

    trainset_size = int(len(trainvalid) * 0.9)
    lengths = [trainset_size, len(trainvalid) - trainset_size]
    if seed is None:
        trainset, validset = dataset.random_split(trainvalid, lengths)
    else:
        # A dedicated generator makes the split reproducible without
        # consuming or reseeding the global RNG (used for weight init and
        # loader shuffling). Only passed when requested, so older torch
        # versions without the ``generator`` kwarg keep working by default.
        generator = torch.Generator().manual_seed(seed)
        trainset, validset = dataset.random_split(
            trainvalid, lengths, generator=generator
        )

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    validloader = torch.utils.data.DataLoader(
        validset,
        batch_size=batch_size,
    )

    testloader = torch.utils.data.DataLoader(
        load_split('test'),
        batch_size=batch_size,
    )

    return trainloader, validloader, testloader


class Classifier(nn.Module):
    """Small convolutional classifier producing 10 class logits.

    ``conv_stack`` has three stages, each made of two (3x3 conv -> ELU ->
    Dropout2d) units followed by a 2x2 max-pool, and ends with an unpadded
    2x2 convolution to 512 channels. ``mlp`` maps the 512-dim feature vector
    taken from that map to 10 logits.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) of the two 3x3 convs in each stage.
        stage_channels = (
            ((3, 8), (8, 16)),
            ((16, 16), (16, 32)),
            ((32, 64), (64, 128)),
        )
        layers = []
        for stage in stage_channels:
            for in_ch, out_ch in stage:
                layers += [
                    nn.Conv2d(in_ch, out_ch, 3, padding=1),
                    nn.ELU(),
                    nn.Dropout2d(p=0.1),
                ]
            layers.append(nn.MaxPool2d(2))
        # Unpadded 2x2 conv: a 4x4 map (32x32 input after three pools)
        # becomes 3x3 here.
        layers.append(nn.Conv2d(128, 512, 2))
        self.conv_stack = nn.Sequential(*layers)

        self.mlp = nn.Sequential(
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        features = self.extract_features(x)
        return self.mlp(features)

    def extract_features(self, x):
        """Return the (batch, 512) feature vector that ``mlp`` consumes.

        Only spatial position (0, 0) of the final conv map is kept.
        NOTE(review): for 32x32 inputs that map is 3x3, so 8 of the 9
        positions are discarded -- probably unintended, but changing it would
        alter the features of already-trained (pickled) checkpoints.
        """
        feature_map = self.conv_stack(x)
        return feature_map[:, :, 0, 0]


def evaluate(classify, dataset):
    """Return the top-1 accuracy of ``classify`` over ``dataset``.

    Args:
        classify: An ``nn.Module`` mapping a batch of inputs to per-class
            scores; the prediction is the argmax over the last dimension.
        dataset: Iterable of ``(x, y)`` batches (e.g. a ``DataLoader``) where
            ``y`` holds integer class labels.

    Returns:
        Fraction of correctly classified examples, a float in [0, 1].

    Raises:
        ValueError: If ``dataset`` yields no examples.

    Note:
        The model is switched to eval mode and left in it on return.
    """
    # Send each batch to wherever the model's parameters live instead of
    # reading a module-level ``cuda`` flag, which only exists when this file
    # is executed as a script (importers would hit a NameError).
    first_param = next(classify.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cpu")

    correct = 0
    total = 0
    classify.eval()
    with torch.no_grad():
        for x, y in dataset:
            x = x.to(device)
            y = y.to(device)
            predictions = classify(x).argmax(dim=-1)
            correct += (predictions == y).sum().item()
            total += x.size(0)

    if total == 0:
        raise ValueError("evaluate() received an empty dataset")
    return correct / float(total)


if __name__ == "__main__":
    import copy

    # Train the SVHN classifier for 50 epochs, keep the weights with the best
    # validation accuracy, and report test accuracy for those weights.
    train, valid, test = get_data_loader("svhn", 32)
    classify = Classifier()
    # Kept at module scope (this `if` does not open a new scope); other
    # functions in this file may read `cuda`.
    cuda = torch.cuda.is_available()
    # Move the model before building the optimizer so the optimizer is
    # created over the parameters in their final location.
    if cuda:
        classify = classify.cuda()
    optimizer = Adam(classify.parameters())
    ce = nn.CrossEntropyLoss()
    best_acc = 0.
    best_state = None

    for _ in range(50):
        classify.train()
        for i, (x, y) in enumerate(train):
            if cuda:
                x = x.cuda()
                y = y.cuda()
            out = classify(x)
            loss = ce(out, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % 200 == 0:
                print(loss.item())
        acc = evaluate(classify, valid)
        print("Validation acc:", acc,)

        if acc > best_acc:
            best_acc = acc
            # The full pickled module is still written for downstream users.
            torch.save(classify, "svhn_classifier.pt")
            # An in-memory copy of the best weights lets the final test avoid
            # torch.load, which defaults to weights_only=True since PyTorch
            # 2.6 and rejects full-module pickles.
            best_state = copy.deepcopy(classify.state_dict())
            print("Saved.")

    # If validation accuracy never improved, test the final weights instead
    # of failing on a missing (or stale) checkpoint file.
    if best_state is not None:
        classify.load_state_dict(best_state)
    print("Test accuracy:", evaluate(classify, test))


Using downloaded and verified file: svhn/train_32x32.mat
Using downloaded and verified file: svhn/test_32x32.mat
1.1669793128967285
0.7245062589645386
1.2448557615280151
0.5017054080963135
1.5755804777145386
0.4533226191997528
0.38006654381752014
0.5034358501434326
0.643135666847229
0.4568672180175781
Validation acc: 0.8588588588588588


  "type " + obj.__name__ + ". It won't be checked "


Saved.
0.6984328031539917
0.4277397692203522
0.5440733432769775
0.7335171699523926
0.37157973647117615
0.6662212610244751
0.7070134878158569
0.43206557631492615
0.37481260299682617
0.24624542891979218
Validation acc: 0.8835653835653836
Saved.
0.6158313751220703
0.4484349489212036
0.5603472590446472
0.4165746569633484
0.3895421624183655
0.33405473828315735
0.267414391040802
0.3296146094799042
0.2902681231498718
0.26894646883010864
Validation acc: 0.8951678951678952
Saved.
0.4139152765274048
0.5824165940284729
0.6200612783432007
0.3874046802520752
0.19156727194786072
0.7204138040542603
0.35433176159858704
0.08428553491830826
0.4587779641151428
0.15806102752685547
Validation acc: 0.8993993993993994
Saved.
0.3505890369415283
0.12056907266378403
0.39257779717445374
0.4409851133823395
0.28597304224967957
0.2988276183605194
0.4879668056964874
0.6041111350059509
0.2834666073322296
0.3158580958843231
Validation acc: 0.9085449085449085
Saved.
0.4161139130592346
0.6342220306396484
0.2497352957725

0.14851875603199005
0.301197350025177
0.21809101104736328
0.9045352339744568
Validation acc: 0.9167349167349167
0.2208276093006134
0.45374542474746704
0.22437754273414612
0.13263556361198425
0.21714140474796295
0.19487105309963226
0.528866171836853
0.340954065322876
0.26738497614860535
0.30825167894363403
Validation acc: 0.9183729183729183
0.47794201970100403
0.375997394323349
0.23906715214252472
0.5660343170166016
0.20854803919792175
0.16611750423908234
0.2517043948173523
0.6213659644126892
0.03372389078140259
0.281422883272171
Validation acc: 0.9213759213759214
0.291009783744812
0.43557292222976685
0.4943186938762665
0.28734397888183594
0.1138266772031784
0.2656649351119995
0.1416906714439392
0.09671378135681152
0.46039968729019165
1.0157763957977295
Validation acc: 0.9242424242424242
Saved.
0.19095459580421448
0.13835810124874115
0.49721643328666687
0.07535502314567566
0.36653274297714233
0.4311251640319824
0.083079993724823
0.1947057545185089
0.14823117852210999
0.23192189633846283