In [1]:
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
from tensorboardX import SummaryWriter
import uuid

from lib.DEC import DEC
from lib.model import train, predict
from lib.sdae import StackedDenoisingAutoEncoder
import lib.model_ae as ae
from lib.utils import cluster_accuracy

In [2]:
class CachedMNIST(Dataset):
    """MNIST wrapper that converts each sample once and memoises the result.

    Each item is a ``[image, label]`` list. ``image`` is a flat float32 tensor
    built from the image's raw bytes multiplied by 0.02, so byte values
    0..255 become 0.0..5.1. When ``cuda`` is true, the image and the label
    (converted to a long tensor) are moved to the GPU the first time an index
    is requested, and the GPU copies are what get cached.

    Args:
        train: load the MNIST training split if true, otherwise the test split.
        cuda: move cached samples to the GPU.
        testing_mode: if true, ``len()`` reports only 128 samples, so a
            smoke-test run only iterates over the first 128 indices.
    """

    def __init__(self, train, cuda, testing_mode=False):
        img_transform = transforms.Compose([
            transforms.Lambda(self._transformation)
        ])
        self.ds = MNIST(
            './data',
            download=True,
            train=train,
            transform=img_transform
        )
        self.cuda = cuda
        self.testing_mode = testing_mode
        # index -> [image, label]; filled lazily by __getitem__.
        self._cache = dict()

    @staticmethod
    def _transformation(img):
        """Flatten ``img.tobytes()`` into a float32 vector scaled by 0.02.

        ``img`` is presumably the PIL image torchvision's MNIST yields; only its
        ``tobytes()`` method is used.
        """
        # np.frombuffer + copy replaces the deprecated typed-storage API
        # (torch.ByteStorage.from_buffer); the copy makes the buffer writable so
        # torch.from_numpy does not warn about a read-only array.
        raw_bytes = np.frombuffer(img.tobytes(), dtype=np.uint8).copy()
        return torch.from_numpy(raw_bytes).float() * 0.02

    def __getitem__(self, index: int) -> list:
        """Return the cached ``[image, label]`` pair for ``index``, building it on first access."""
        if index not in self._cache:
            image, label = self.ds[index]
            if self.cuda:
                image = image.cuda(non_blocking=True)
                # `torch` itself is a module and not callable; build the label
                # tensor explicitly before moving it to the GPU.
                label = torch.tensor(label, dtype=torch.long).cuda(non_blocking=True)
            self._cache[index] = [image, label]
        return self._cache[index]

    def __len__(self) -> int:
        return 128 if self.testing_mode else len(self.ds)

In [8]:
def main(
    cuda,
    batch_size,
    pretrain_epochs,
    finetune_epochs,
    testing_mode,
    dec_epochs=100,
    dec_batch_size=256,
):
    """Train a stacked denoising autoencoder on MNIST, then cluster with DEC.

    Stages:
      1. pretraining of a 784-500-500-2000-10 autoencoder via ``ae.pretrain``;
      2. fine-tuning via ``ae.train``, logging lr / loss / validation loss to
         TensorBoard under ``data/autoencoder``;
      3. DEC clustering (10 clusters) on the encoder's 10-d output via ``train``;
      4. cluster accuracy on the training set and, outside testing mode, a
         row-normalised confusion-matrix heatmap saved as ``confusion_<uuid>.png``.

    Args:
        cuda: run models and cached data on the GPU.
        batch_size: mini-batch size for both autoencoder stages.
        pretrain_epochs: epochs for the pretraining stage.
        finetune_epochs: epochs for the fine-tuning stage.
        testing_mode: use a 128-sample view of each split and skip the heatmap.
        dec_epochs: epochs for the DEC stage (default matches the old fixed 100).
        dec_batch_size: mini-batch size for the DEC stage (default matches the
            old fixed 256).
    """
    writer = SummaryWriter()  # TensorBoard writer; always closed in the ``finally`` below

    def training_callback(epoch, lr, loss, validation_loss):
        """Log fine-tuning metrics to TensorBoard; passed to ``ae.train`` as ``update_callback``."""
        writer.add_scalars('data/autoencoder', {
            'lr': lr,
            'loss': loss,
            'validation_loss': validation_loss,
        }, epoch)

    try:
        ds_train = CachedMNIST(train=True, cuda=cuda, testing_mode=testing_mode)  # training dataset
        ds_val = CachedMNIST(train=False, cuda=cuda, testing_mode=testing_mode)  # evaluation dataset
        autoencoder = StackedDenoisingAutoEncoder(
            [28 * 28, 500, 500, 2000, 10],
            final_activation=None
        )
        if cuda:
            autoencoder.cuda()

        print('Pretraining stage.')
        ae.pretrain(
            ds_train,
            autoencoder,
            cuda=cuda,
            validation=ds_val,
            epochs=pretrain_epochs,
            batch_size=batch_size,
            optimizer=lambda model: SGD(model.parameters(), lr=0.1, momentum=0.9),
            scheduler=lambda x: StepLR(x, 100, gamma=0.1),
            corruption=0.2
        )

        print('Training stage.')
        ae_optimizer = SGD(params=autoencoder.parameters(), lr=0.1, momentum=0.9)
        ae.train(
            ds_train,
            autoencoder,
            cuda=cuda,
            validation=ds_val,
            epochs=finetune_epochs,
            batch_size=batch_size,
            optimizer=ae_optimizer,
            scheduler=StepLR(ae_optimizer, 100, gamma=0.1),
            corruption=0.2,
            update_callback=training_callback
        )

        print('DEC stage.')
        # hidden_dimension must match the last autoencoder layer width (10 above).
        model = DEC(
            cluster_number=10,
            hidden_dimension=10,
            encoder=autoencoder.encoder
        )
        if cuda:
            model.cuda()
        dec_optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
        train(
            dataset=ds_train,
            model=model,
            epochs=dec_epochs,
            batch_size=dec_batch_size,
            optimizer=dec_optimizer,
            stopping_delta=0.000001,
            cuda=cuda
        )

        predicted, actual = predict(ds_train, model, 1024, silent=True, return_actual=True, cuda=cuda)
        actual = actual.cpu().numpy()
        predicted = predicted.cpu().numpy()
        reassignment, accuracy = cluster_accuracy(actual, predicted)
        print('Final DEC accuracy: %s' % accuracy)

        if not testing_mode:
            predicted_reassigned = [reassignment[item] for item in predicted]  # TODO numpify
            confusion = confusion_matrix(actual, predicted_reassigned)
            row_totals = confusion.sum(axis=1, keepdims=True).astype('float')
            # A label that appears only in the predictions has an all-zero row;
            # keep it as zeros instead of producing NaN and a RuntimeWarning.
            normalised_confusion = np.divide(
                confusion.astype('float'),
                row_totals,
                out=np.zeros(confusion.shape, dtype='float'),
                where=row_totals != 0,
            )
            confusion_id = uuid.uuid4().hex
            sns.heatmap(normalised_confusion).get_figure().savefig('confusion_%s.png' % confusion_id)
            print('Writing out confusion diagram with UUID: %s' % confusion_id)
    finally:
        # Previously only closed outside testing mode and never on error.
        writer.close()

In [9]:
# Run configuration for main(). Values must be real Python types: the old
# string values made 'False' truthy (testing mode silently enabled) and left
# batch_size as an empty string. The dict is named `params` to match the call.
params = {
    'cuda': torch.cuda.is_available(),  # use the GPU only when one is present
    'batch_size': 256,
    'pretrain_epochs': 300,
    'finetune_epochs': 500,
    'testing_mode': False,
}

In [None]:
# Launch the full pipeline (autoencoder pretraining, fine-tuning, then DEC
# clustering) with the keyword arguments collected in `params`.
pipeline_kwargs = dict(params)
main(**pipeline_kwargs)