In [1]:
%matplotlib inline

In [90]:
from PIL import Image
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
import torchvision
import torch.utils.data as data
import torch.optim as optim
import time
import copy
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [3]:
plt.ion()   # turn on matplotlib interactive mode

In [4]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 and a bias term.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride; with stride 1 the padding keeps the spatial size.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

In [5]:
class ResidualBlock(nn.Module):
    """Basic residual block: two conv3x3 + BatchNorm stages plus an identity shortcut.

    Both convolutions keep ``planes`` channels and use the default stride of 1, so the
    block's input can be added directly to the second stage's output.
    """

    def __init__(self, planes):
        super(ResidualBlock, self).__init__()
        # Registration order follows the forward pass.
        self.conv1 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        shortcut = x
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # The in-place add and ReLU act on bn2's fresh output tensor, never on ``x``.
        hidden += shortcut
        return self.relu(hidden)

In [6]:
class ResidualNet(nn.Module):
    """Small ResNet-style classifier ending in a single linear layer.

    Layout: conv(->16) + residual block, stride-2 conv(->32) + residual block,
    stride-2 conv(->64), then ``number_of_blocks - 2`` residual blocks at 64 channels,
    then a linear layer over the flattened feature map.

    Args:
        input_planes: number of channels of the input images.
        height: input image height; used to size the final linear layer.
        width: input image width; used to size the final linear layer.
        number_of_blocks: total residual blocks (the two fixed ones plus the 64-channel
            ones). Must be at least 2.
        classes: number of output logits.

    Raises:
        ValueError: if ``number_of_blocks < 2``.
    """

    def __init__(self, input_planes, height, width, number_of_blocks, classes):
        super(ResidualNet, self).__init__()

        if number_of_blocks < 2:
            raise ValueError("The residual net needs at least two blocks.")

        self.conv1 = conv3x3(input_planes, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.residual1 = ResidualBlock(16)
        self.conv2 = conv3x3(16, 32, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.residual2 = ResidualBlock(32)
        self.conv3 = conv3x3(32, 64, stride=2)
        self.bn3 = nn.BatchNorm2d(64)

        self.laterResidualBlocks = nn.ModuleList()

        for _ in range(number_of_blocks - 2):
            self.laterResidualBlocks.append(ResidualBlock(64))

        # A 3x3 conv with stride 2 and padding 1 maps a side of length n to ceil(n / 2).
        # After two of them, each side is ceil(n / 4) and there are 64 channels.
        # For sides divisible by 4 this equals the previous height * width * 4.
        self.dense_input_dim = 64 * ((height + 3) // 4) * ((width + 3) // 4)
        self.dense = nn.Linear(self.dense_input_dim, classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.residual1(x)
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.residual2(x)
        x = self.relu(self.bn3(self.conv3(x)))

        for block in self.laterResidualBlocks:
            x = block(x)

        # Flatten per sample and keep the batch dimension explicit. A feature-size
        # mismatch then fails in the linear layer instead of silently changing the
        # batch size.
        return self.dense(x.view(x.size(0), -1))

In [7]:
def pil_loader(path):
    """Open the image at ``path`` with PIL and return it converted to RGB."""
    # Opening through an explicit file handle avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')


def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    import accimage
    try:
        loaded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, so fall back to PIL.Image.
        loaded = pil_loader(path)
    return loaded


def default_loader(path):
    """Load ``path`` with the image backend torchvision is currently configured to use."""
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)

class AdvancedImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that can filter and shuffle its sample list at construction time.

    Args:
        root, transform, target_transform, loader: forwarded to ``ImageFolder``.
        filter_fn: optional predicate applied to each ``(path, class_index)`` sample.
            Only samples for which it is truthy are kept. It sees the raw class index,
            before ``target_transform``.
        shuffle: if True, shuffle the kept samples with the ``random`` module. Seed
            ``random`` beforehand for a reproducible order.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, filter_fn=None, shuffle=False):
        super(AdvancedImageFolder, self).__init__(root, transform, target_transform, loader)

        # Newer torchvision versions index ``self.samples`` in __getitem__/__len__.
        # They keep ``self.imgs`` only as an alias and maintain a parallel
        # ``self.targets`` list. Older versions use ``self.imgs`` directly.
        # Work on one list and write it back to every attribute that exists, so
        # filtering and shuffling actually take effect and targets stay aligned.
        samples = list(getattr(self, 'samples', self.imgs))

        if filter_fn is not None:
            samples = [sample for sample in samples if filter_fn(sample)]

        if shuffle:
            random.shuffle(samples)

        self.imgs = samples
        if hasattr(self, 'samples'):
            self.samples = samples
        if hasattr(self, 'targets'):
            self.targets = [target for _, target in samples]

In [1]:
def tightest_image_crop(img, preserve_aspect_ratio=False, threshold=1e-7):
    """Crop ``img`` to the bounding box of its foreground pixels.

    A pixel counts as foreground when channel 0 exceeds ``threshold``. The crop is
    applied to all channels of ``img``, which is indexed as (channel, row, column).

    Args:
        img: 3-D tensor indexed as (C, H, W).
        preserve_aspect_ratio: if True, extend the shorter side of the box (downwards or
            rightwards, starting from the box's top-left corner) to match the longer
            side, giving a square window. Near the bottom/right edge the slice is
            truncated, so the result can still be non-square.
        threshold: intensity above which a channel-0 pixel is treated as foreground.

    Returns:
        The cropped tensor, a view of ``img``. If no pixel exceeds ``threshold``,
        ``img`` is returned unchanged.
    """
    foreground = (img[0] > threshold).nonzero()
    if foreground.numel() == 0:
        # Blank image: there is no bounding box to crop to.
        return img

    rows = foreground[:, 0]
    cols = foreground[:, 1]
    top_i, bottom_i = int(rows.min()), int(rows.max())
    left_i, right_i = int(cols.min()), int(cols.max())

    new_width = right_i - left_i + 1
    new_height = bottom_i - top_i + 1

    if preserve_aspect_ratio:
        if new_width > new_height:
            return img[:, top_i:top_i + new_width, left_i:right_i + 1]
        return img[:, top_i:bottom_i + 1, left_i:left_i + new_height]

    return img[:, top_i:bottom_i + 1, left_i:right_i + 1]

def square_padding(img):
    """Zero-pad a (C, H, W) tensor on the top or left so that H == W.

    The padding is inserted before the image, as extra leading rows or columns, and
    matches ``img``'s channel count, dtype and device.

    Args:
        img: 3-D tensor indexed as (channels, height, width).

    Returns:
        A new square tensor, or ``img`` itself when it is already square.
    """
    channels, height, width = img.size()

    if height > width:
        # type_as keeps dtype and device (e.g. CUDA) consistent with img for torch.cat.
        padding = torch.zeros(channels, height, height - width).type_as(img)
        return torch.cat((padding, img), 2)
    elif width > height:
        padding = torch.zeros(channels, width - height, width).type_as(img)
        return torch.cat((padding, img), 1)

    return img

def image_loader(path):
    """Open the image at ``path`` and return it converted to PIL mode 'L' (grayscale)."""
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('L')

In [1]:
def invert_intensity(img):
    """Return ``1 - img``.

    A named function rather than a lambda, so the transform stays picklable for
    DataLoader worker processes.
    """
    return 1 - img


# Per-image preprocessing:
#   1. ToTensor: PIL image -> float tensor with 8-bit values scaled to [0, 1].
#   2. Invert intensities, so that the background (presumably white in the source
#      scans; confirm) becomes ~0, which tightest_image_crop treats as empty.
#   3. Crop to the bounding box of the non-background pixels.
#   4. Zero-pad to a square so the resize below does not distort the glyph.
#   5-6. Resize to 32 px and convert back to a tensor.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(invert_intensity),
        transforms.Lambda(tightest_image_crop),
        transforms.Lambda(square_padding),
        transforms.ToPILImage(),
        transforms.Resize(32),
        transforms.ToTensor()
    ])

NameError: name 'transforms' is not defined

In [102]:
def split_dataset(dset, batch_size=128, thread_count=4, split_points=(0.7, 0.85),
                  pin_memory=True):
    """Split ``dset`` by index position into train / test / validation DataLoaders.

    With ``n = len(dset)`` and ``(a, b) = split_points``, the index ranges are
    ``[0, int(a*n))``, ``[int(a*n), int(b*n))`` and ``[int(b*n), n)``.

    The split is purely positional, so ``dset`` should already be shuffled (e.g.
    ``AdvancedImageFolder(..., shuffle=True)``). Otherwise each range follows the
    dataset's stored order. Within a range, ``SubsetRandomSampler`` draws the indices
    in a random order on every pass.

    Args:
        dset: any map-style dataset supporting ``len``.
        batch_size: batch size for all three loaders.
        thread_count: ``num_workers`` for all three loaders.
        split_points: cumulative cut points ``(a, b)`` with ``0 <= a <= b <= 1``.
        pin_memory: forwarded to every ``DataLoader``.

    Returns:
        ``(train_loader, test_loader, validation_loader)``, in that order.

    Raises:
        ValueError: if ``split_points`` is not ordered within [0, 1].
    """
    train_end_fraction, test_end_fraction = split_points
    if not 0.0 <= train_end_fraction <= test_end_fraction <= 1.0:
        raise ValueError("split_points must satisfy 0 <= a <= b <= 1, got {}".format(split_points))

    total = len(dset)
    train_end = int(train_end_fraction * total)
    test_end = int(test_end_fraction * total)

    def make_loader(indices):
        # All loaders share the same dataset object; only the sampled indices differ.
        return data.DataLoader(
            dset, batch_size=batch_size, num_workers=thread_count,
            pin_memory=pin_memory, sampler=data.sampler.SubsetRandomSampler(indices))

    loader_dset_train = make_loader(list(range(0, train_end)))
    loader_dset_test = make_loader(list(range(train_end, test_end)))
    loader_dset_validation = make_loader(list(range(test_end, total)))

    return loader_dset_train, loader_dset_test, loader_dset_validation

DATA_ROOT = 'by_class'

# Seed Python's RNG so the dataset shuffles below are reproducible across runs. The
# positional train/test/validation membership from split_dataset depends on them.
SPLIT_SEED = 0
random.seed(SPLIT_SEED)

# Module-level functions instead of lambdas: DataLoader workers started with the
# 'spawn' method must pickle the dataset (including target_transform), and lambdas
# cannot be pickled.
# NOTE(review): the index thresholds assume ImageFolder's sorted class order puts ten
# digit classes first (indices 0-9), then 26 uppercase letters (10-35), then lowercase
# letters. Confirm against the DATA_ROOT directory listing.
def type_target(class_index):
    """Map a class index to 0 if it is below 10, else 1."""
    return 0 if class_index < 10 else 1


def letter_target(class_index):
    """Shift a class index down by 10, so class 10 becomes 0."""
    return class_index - 10


def is_digit_sample(sample):
    """Keep ``(path, class_index)`` samples whose class index is below 10."""
    return sample[1] < 10


def is_letter_sample(sample):
    """Keep ``(path, class_index)`` samples whose class index is at least 10."""
    return sample[1] >= 10


def is_uppercase_sample(sample):
    """Keep ``(path, class_index)`` samples whose class index is in [10, 35]."""
    return 10 <= sample[1] <= 35


dset_type = AdvancedImageFolder(DATA_ROOT, transform,
                                target_transform=type_target,
                                loader=image_loader,
                                shuffle=True)
dset_digit = AdvancedImageFolder(DATA_ROOT, transform, loader=image_loader,
                                 filter_fn=is_digit_sample,
                                 shuffle=True)
dset_char = AdvancedImageFolder(DATA_ROOT, transform, target_transform=letter_target,
                                loader=image_loader, filter_fn=is_letter_sample,
                                shuffle=True)

dset_uppercase_char = AdvancedImageFolder(DATA_ROOT, transform, target_transform=letter_target,
                                          loader=image_loader, filter_fn=is_uppercase_sample,
                                          shuffle=True)

loader_type_train, loader_type_test, _ = split_dataset(dset_type)
loader_digit_train, loader_digit_test, _ = split_dataset(dset_digit)
loader_char_train, loader_char_test, _ = split_dataset(dset_char)
loader_uppercase_char_train, loader_uppercase_char_test, _ = split_dataset(dset_uppercase_char)

In [98]:
# Every network takes single-channel 32x32 inputs; only depth and class count differ.
resnet_type = ResidualNet(1, 32, 32, number_of_blocks=4, classes=2)
resnet_digit = ResidualNet(1, 32, 32, number_of_blocks=6, classes=10)
resnet_char = ResidualNet(1, 32, 32, number_of_blocks=2, classes=52)
resnet_uppercase_char = ResidualNet(1, 32, 32, number_of_blocks=3, classes=26)

# Move each model to the GPU and wrap it in nn.DataParallel.
resnet_type, resnet_digit, resnet_char, resnet_uppercase_char = (
    nn.DataParallel(net.cuda())
    for net in (resnet_type, resnet_digit, resnet_char, resnet_uppercase_char)
)

In [93]:
def train_model(model, dset_loader, criterion, optimizer, lr_scheduler, num_epochs=5):
    """Train ``model`` for ``num_epochs`` and return the copy with the best epoch accuracy.

    Each epoch:
      * calls ``lr_scheduler(optimizer, epoch)``, which must return the optimizer to use;
      * makes one pass over ``dset_loader``;
      * prints the running loss/accuracy every 20 batches.

    "Best" is judged by the epoch's *training* accuracy; no held-out data is used.
    Batches are moved to the device of the model's first parameter.
    Requires PyTorch >= 0.4 (uses ``Tensor.item()`` and ``Tensor.to()``).

    Args:
        model: module to train; updated in place.
        dset_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: callable ``(optimizer, epoch) -> optimizer``.
        num_epochs: number of passes over ``dset_loader``.

    Returns:
        A deep copy of ``model`` taken after the epoch with the highest training
        accuracy, or ``model`` itself if no epoch exceeded 0. The returned module is
        in eval mode, and ``model`` is also left in eval mode.
    """
    since = time.time()
    device = next(model.parameters()).device

    best_model = model
    best_acc = 0.0
    model.train(True)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        optimizer = lr_scheduler(optimizer, epoch)

        running_loss = 0.0
        running_corrects = 0
        samples_seen = 0

        for batch_number, (inputs, labels) in enumerate(dset_loader, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Count real samples, so a smaller final batch does not skew the averages.
            batch_size = labels.size(0)
            _, preds = torch.max(outputs.detach(), 1)
            # Assumes criterion averages over the batch (CrossEntropyLoss's default),
            # so weight by batch size to get a per-sample running mean.
            running_loss += loss.item() * batch_size
            running_corrects += int((preds == labels).sum().item())
            samples_seen += batch_size

            if batch_number % 20 == 0:
                curr_acc = running_corrects / samples_seen
                curr_loss = running_loss / samples_seen
                time_elapsed = time.time() - since

                print('Epoch Number: {}, Batch Number: {}, Loss: {:.4f}, Acc: {:.4f}'.format(
                        epoch, batch_number, curr_loss, curr_acc))
                print('Time so far is {:.0f}m {:.0f}s'.format(
                      time_elapsed // 60, time_elapsed % 60))

        epoch_loss = running_loss / samples_seen if samples_seen else 0.0
        epoch_acc = running_corrects / samples_seen if samples_seen else 0.0
        print('Epoch {} done. Loss: {:.4f}, Train Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))

        # Deep-copy the model whenever it beats the best epoch so far.
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model = copy.deepcopy(model)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best train Acc: {:.4f}'.format(best_acc))

    # The deep copy was taken in train mode; switch both models to eval mode.
    model.train(False)
    best_model.train(False)

    return best_model

In [94]:
def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
    """Set every param group's lr to ``init_lr * 0.1 ** (epoch // lr_decay_epoch)``.

    The new rate is printed whenever ``epoch`` is a multiple of ``lr_decay_epoch``.

    Returns:
        The same ``optimizer`` object, modified in place.
    """
    decay_steps, remainder = divmod(epoch, lr_decay_epoch)
    lr = init_lr * (0.1 ** decay_steps)

    if not remainder:
        print('LR is set to {}'.format(lr))

    for group in optimizer.param_groups:
        group['lr'] = lr

    return optimizer

In [100]:
criterion = nn.CrossEntropyLoss()

# One SGD optimizer per network. When train_model is driven by exp_lr_scheduler,
# 'lr' is overwritten every epoch, so the lr given here is just the initial value.
optimizer_type, optimizer_digit, optimizer_char, optimizer_uppercase_char = [
    optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for net in (resnet_type, resnet_digit, resnet_char, resnet_uppercase_char)
]

In [96]:
def test_model(model, dset_loader):
    model.train(False)
    
    running_corrects = 0
    
    for data in dset_loader:
        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs.cuda()), \
                         Variable(labels.cuda())

        # forward
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)

        running_corrects += torch.sum(preds == labels.data)
    
    return running_corrects/(len(dset_loader) * dset_loader.batch_size) 