In [2]:
import chainer
from chainer import functions as F
from chainer import links as L


class Encoder(chainer.Chain):
    """Convolutional encoder producing a global code and a local feature map.

    Four convolutions with kernel size 4 are stacked; the output of the
    second one (after batch normalization and ReLU) is kept as the local
    feature map, and the output of the fourth is projected by a linear
    layer to a 64-unit code.
    """

    def __init__(self):
        super().__init__()
        with self.init_scope():
            # Convolution stack; the first layer infers its input channels.
            self.c0 = L.Convolution2D(None, 64, ksize=4)
            self.c1 = L.Convolution2D(64, 128, ksize=4)
            self.c2 = L.Convolution2D(128, 256, ksize=4)
            self.c3 = L.Convolution2D(256, 512, ksize=4)
            # Projection to the 64-unit code; input size is inferred lazily.
            self.linear = L.Linear(None, 64)
            # No batch normalization is applied after c0.
            self.bn1 = L.BatchNormalization(128)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(512)

    def __call__(self, x):
        """Encode ``x``.

        Args:
            x: Input batch fed to the first convolution (the demo in this
                file uses a float32 array of shape (1, 3, 32, 32)).

        Returns:
            A ``(code, features)`` tuple: the linear projection of the last
            convolution block and the activation of the second block.
        """
        stem = F.relu(self.c0(x))
        features = F.relu(self.bn1(self.c1(stem)))
        deep = F.relu(self.bn2(self.c2(features)))
        deep = F.relu(self.bn3(self.c3(deep)))
        code = self.linear(deep)
        return code, features

class GlobalDiscriminator(chainer.Chain):
    """Scores a (code, feature map) pair with a single value per sample.

    The feature map is reduced by two 3x3 convolutions, flattened,
    concatenated with the code and passed through a three-layer MLP.
    """

    def __init__(self):
        super().__init__()
        with self.init_scope():
            # MLP head applied to the concatenated [code, flattened map].
            self.l0 = L.Linear(None, 512)
            self.l1 = L.Linear(512, 512)
            self.l2 = L.Linear(512, 1)
            # Convolutions that shrink the feature map before flattening.
            self.c0 = L.Convolution2D(None, 64, ksize=3)
            self.c1 = L.Convolution2D(64, 32, ksize=3)

    def __call__(self, y, M):
        """Return the unnormalized score for each (``y``, ``M``) pair.

        Args:
            y: Batch of codes; its first dimension is used as the batch size.
            M: Batch of feature maps fed to the convolutions.

        Returns:
            Output of the final linear layer (one unit per sample).
        """
        batch_size = y.shape[0]
        conv_out = self.c1(F.relu(self.c0(M)))
        flat_map = F.reshape(conv_out, (batch_size, -1))
        joined = F.concat((y, flat_map), axis=1)

        hidden = F.relu(self.l0(joined))
        hidden = F.relu(self.l1(hidden))
        return self.l2(hidden)

class LocalDiscriminator(chainer.Chain):
    """Per-location scorer built from three 1x1 convolutions."""

    def __init__(self):
        super().__init__()
        with self.init_scope():
            # Kernel size 1 keeps the spatial layout: every location is
            # scored independently from its own channel vector.
            self.c0 = L.Convolution2D(None, 512, ksize=1)
            self.c1 = L.Convolution2D(512, 512, ksize=1)
            self.c2 = L.Convolution2D(512, 1, ksize=1)

    def __call__(self, x):
        """Return a single-channel score map for the input ``x``.

        Args:
            x: Batch of feature maps fed to the first 1x1 convolution.

        Returns:
            Output of the last convolution (one channel, no activation).
        """
        hidden = F.relu(self.c0(x))
        hidden = F.relu(self.c1(hidden))
        scores = self.c2(hidden)
        return scores

class PriorDiscriminator(chainer.Chain):
    """Three-layer MLP that maps a code to a value in (0, 1) via a sigmoid."""

    def __init__(self):
        super().__init__()
        with self.init_scope():
            # Input size of the first layer is inferred on the first call.
            self.l0 = L.Linear(None, 1000)
            self.l1 = L.Linear(1000, 200)
            self.l2 = L.Linear(200, 1)

    def __call__(self, x):
        """Return the sigmoid-activated output for each row of ``x``.

        Args:
            x: Batch of codes (or samples of the same shape).

        Returns:
            Sigmoid of the final linear layer's single output unit.
        """
        hidden = F.relu(self.l0(x))
        hidden = F.relu(self.l1(hidden))
        logit = self.l2(hidden)
        return F.sigmoid(logit)

if __name__ == "__main__":
    # Smoke test: push one dummy 3x32x32 image through every network defined
    # above and print the resulting shapes.
    import numpy as np

    encoder = Encoder()
    x = np.ones((1, 3, 32, 32), dtype=np.float32)
    # Encoder returns a (code, feature map) tuple, so it must be unpacked
    # before `.shape` can be read.
    y, M = encoder(x)
    print(y.shape, M.shape)

    # There is no `Discriminator` class in this module; exercise the three
    # discriminators that actually exist instead.
    global_disc = GlobalDiscriminator()
    print(global_disc(y, M).shape)

    # The local discriminator takes the code tiled over every spatial
    # location of the feature map and concatenated along the channel axis.
    y_tiled = F.broadcast_to(
        y[:, :, None, None],
        (x.shape[0], y.shape[1], M.shape[-2], M.shape[-1]))
    local_disc = LocalDiscriminator()
    print(local_disc(F.concat((y_tiled, M), axis=1)).shape)

    prior_disc = PriorDiscriminator()
    print(prior_disc(y).shape)


ModuleNotFoundError: No module named 'chainer'

In [None]:
import argparse
import chainer
from chainer import iterators, optimizers, serializers, reporter, training
from chainer.training import extensions
from chainer import functions as F
from chainer.dataset import concat_examples

from networks import Encoder, LocalDiscriminator, GlobalDiscriminator, PriorDiscriminator


class DeepINFOMAX(chainer.Chain):
    """Encoder plus three discriminators combined into one training loss.

    The loss is the sum of a global term (weighted by ``alpha``), a local
    term (weighted by ``beta``) and a prior term (weighted by ``gamma``).
    """

    def __init__(self, alpha=1., beta=1., gamma=0.1):
        """Store the loss weights and build the sub-networks.

        Args:
            alpha: Weight of the global term.
            beta: Weight of the local term.
            gamma: Weight of the prior term.
        """
        super().__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

        with self.init_scope():
            self.encoder = Encoder()
            self.local_disc = LocalDiscriminator()
            self.global_disc = GlobalDiscriminator()
            self.prior_disc = PriorDiscriminator()

    def __call__(self, x, t):
        """Compute and report the total loss for a batch.

        Args:
            x: Input batch passed to the encoder.
            t: Accepted for call compatibility; not used in the computation.

        Returns:
            The summed loss (global + local + prior).
        """
        # Global code and local feature map for every sample.
        y, M = self.encoder(x)

        # Rotate the batch by one so each code is paired with the feature
        # map of a different sample.
        M_prime = F.concat(
            (M[1:], (M[0])[None, :, :, :]),
            axis=0,
        )

        tiled_shape = (x.shape[0], y.shape[1], M.shape[-2], M.shape[-1])

        def attach_code(feature_map):
            # Tile the code over all spatial positions and stack it on top
            # of the feature map along the channel axis.
            tiled_code = F.broadcast_to(y[:, :, None, None], tiled_shape)
            return F.concat((tiled_code, feature_map), axis=1)

        # Local term: matched pairs versus shifted pairs.
        y_M = attach_code(M)
        y_M_prime = attach_code(M_prime)

        local_joint = F.mean(-F.softplus(-self.local_disc(y_M)))
        local_marginal = F.mean(F.softplus(self.local_disc(y_M_prime)))
        local_loss = (local_marginal - local_joint) * self.beta

        # Global term: same scoring scheme with the global discriminator.
        global_joint = F.mean(-F.softplus(-self.global_disc(y, M)))
        global_marginal = F.mean(F.softplus(self.global_disc(y, M_prime)))
        global_loss = (global_marginal - global_joint) * self.alpha

        # Prior term: compare codes against uniform random samples of the
        # same shape.
        xp = self.xp
        z = xp.random.uniform(size=y.shape).astype(xp.float32)

        log_prior_sample = F.mean(F.log(self.prior_disc(z)))
        log_not_code = F.mean(F.log(1. - self.prior_disc(y)))
        prior_loss = -(log_prior_sample + log_not_code) * self.gamma

        loss = global_loss + local_loss + prior_loss

        observation = {
            "loss": loss,
            "local_loss": local_loss,
            "global_loss": global_loss,
            "prior_loss": prior_loss,
        }
        reporter.report(observation, self)
        return loss


def main(args):
    """Train a DeepINFOMAX model on the CIFAR-10 training split.

    Args:
        args: Parsed command-line namespace; the attributes read here are
            ``batchsize``, ``alpha``, ``beta``, ``gamma``, ``device``,
            ``learning_rate``, ``epochs`` and ``output``.
    """
    # Only the training split is consumed.
    train, _ = chainer.datasets.get_cifar10()
    train_iter = iterators.SerialIterator(train, args.batchsize)

    model = DeepINFOMAX(alpha=args.alpha, beta=args.beta, gamma=args.gamma)

    # A non-negative device id selects a GPU; negative keeps the model on CPU.
    if args.device >= 0:
        chainer.backends.cuda.get_device_from_id(args.device).use()
        model.to_gpu(args.device)

    optimizer = optimizers.Adam(alpha=args.learning_rate)
    optimizer.setup(model)

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.device)
    stop_trigger = (args.epochs, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out=args.output)

    # Log and print every 10 iterations.
    report_every = 10
    log_trigger = (report_every, "iteration")
    reported_keys = [
        'epoch',
        'iteration',
        'main/loss',
        'main/local_loss',
        'main/global_loss',
        'main/prior_loss',
        'elapsed_time',
    ]
    trainer.extend(extensions.LogReport(trigger=log_trigger))
    trainer.extend(extensions.PrintReport(reported_keys), trigger=log_trigger)

    # Progress bar on stdout, refreshed at the logging interval.
    trainer.extend(extensions.ProgressBar(update_interval=report_every))

    # Save only the encoder, once every 100 epochs.
    encoder_snapshot = extensions.snapshot_object(
        model.encoder, 'encoder_epoch_{.updater.epoch}')
    trainer.extend(encoder_snapshot, trigger=(100, "epoch"))

    # Run the training loop.
    trainer.run()


if __name__ == "__main__":
    # Command-line options as (long flag, short flag, type, default).
    cli_options = (
        ("--device", "-g", int, -1),
        ("--epochs", "-e", int, 1000),
        ("--batchsize", "-b", int, 256),
        ("--learning_rate", "-l", float, 1.E-4),
        ("--output", "-o", str, "results"),
        ("--alpha", "-A", float, 0.5),
        ("--beta", "-B", float, 1.0),
        ("--gamma", "-G", float, 0.1),
    )

    parser = argparse.ArgumentParser()
    for long_flag, short_flag, value_type, default_value in cli_options:
        parser.add_argument(
            long_flag, short_flag, type=value_type, default=default_value)
    args = parser.parse_args()

    main(args)
