In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from warnings import warn
from torch.autograd import Variable
from torch.optim import Adam, SGD
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [2]:
class Generator(nn.Module):
    r"""Base class for all Generator models
    Args:
        encoding_dims(torch.Size) : Dimensions of the sample from the noise prior
    """
    # FIXME(Aniket1998): If a user is overriding the default initializer, he must also override the constructor
    # Find an efficient workaround by fixing the initialization mechanism
    def __init__(self, encoding_dims):
        super(Generator, self).__init__()
        self.encoding_dims = encoding_dims

    r"""Default weight initializer for all generator models
    Models that require custom weight initialization can override this method"""
    # TODO(Aniket1998): Think of better dictionary lookup based approaches to initialization
    # That allows easy and customizable weight initialization without overriding
    def _weight_initializer(self):
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)


class Discriminator(nn.Module):
    r"""Base class for all Discriminator models
    Args:
        input_dims(torch.Size) : Dimensions of the input
    """
    def __init__(self, input_dims):
        super(Discriminator, self).__init__()
        self.input_dims = input_dims

    r"""Default weight initializer for all disciminator models
    Models that require custom weight initialization can override this method"""
    # TODO(Aniket1998): Think of better dictionary lookup based approaches to initialization
    # That allows easy and customizable weight initialization without overriding
    def _weight_initializer(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

In [3]:
class DCGANGenerator(Generator):
    r"""Deep Convolutional GAN (DCGAN) generator from
    "Unsupervised Representation Learning With Deep Convolutional Generative Aversarial Networks
     by Radford et. al. " <https://arxiv.org/abs/1511.06434>
     Args:
        encoding_dims (int, optional) : Dimension of the encoding vector sampled from the noise prior. Default 100
        out_channels (int, optional) : Number of channels in the output Tensor. Default 3
        step_channels (int, optional) : Number of channels in multiples of which the DCGAN steps up
                                        the convolutional features
                                        The step up is done as dim z -> d - > 2 * d -> 4 * d - > 8 * d
                                        where d = step_channels. Default 64
        batchnorm (bool, optional) : If True, use batch normalization in the convolutional layers of the generator
                                     Default True
        nonlinearity(torch.nn.Module, optional) : Nonlinearity to be used in the intermediate convolutional layers
                                                  Defaults to LeakyReLU(0.2) when None is passed. Default None
        last_nonlinearity(torch.nn.Module, optional) : Nonlinearity to be used in the final convolutional layer
                                                       Defaults to tanh when None is passed. Default None
    """
    def __init__(self, encoding_dims=100, out_channels=3, step_channels=64,
                 batchnorm=True, nonlinearity=None, last_nonlinearity=None):
        super(DCGANGenerator, self).__init__(encoding_dims)
        self.ch = out_channels
        self.step_ch = step_channels
        use_bias = not batchnorm

        nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity
        last_nl = nn.Tanh() if last_nonlinearity is None else last_nonlinearity

        if batchnorm is True:
            self.model = nn.Sequential(
                nn.ConvTranspose2d(self.encoding_dims, self.step_ch * 8, 4, 1, 0, bias=use_bias),
                nn.BatchNorm2d(self.step_ch * 8), nl,
                nn.ConvTranspose2d(self.step_ch * 8, self.step_ch * 4, 4, 2, 1, bias=use_bias),
                nn.BatchNorm2d(self.step_ch * 4), nl,
                nn.ConvTranspose2d(self.step_ch * 4, self.step_ch * 2, 4, 2, 1, bias=use_bias),
                nn.BatchNorm2d(self.step_ch * 2), nl,
                nn.ConvTranspose2d(self.step_ch * 2, self.ch, 4, 2, 1, bias=use_bias),
                last_nl)
        else:
            self.model = nn.Sequential(
                nn.ConvTranspose2d(self.encoding_dims, self.step_ch * 8, 4, 1, 0, bias=use_bias), nl,
                nn.ConvTranspose2d(self.step_ch * 8, self.step_ch * 4, 4, 2, 1, bias=use_bias), nl,
                nn.ConvTranspose2d(self.step_ch * 4, self.step_ch * 2, 4, 2, 1, bias=use_bias), nl,
                nn.ConvTranspose2d(self.step_ch * 2, ch, 4, 2, 1, bias=use_bias), last_nl)

        self._weight_initializer()

    def forward(self, x):
        return self.model(x)


class DCGANDiscriminator(Discriminator):
    r"""Deep Convolutional GAN (DCGAN) discriminator from
    "Unsupervised Representation Learning With Deep Convolutional Generative Aversarial Networks
     by Radford et. al. " <https://arxiv.org/abs/1511.06434>
     Args:
        encoding_dims (int, optional) : Dimension of the encoding vector sampled from the noise prior. Default 100
        out_channels (int, optional) : Number of channels in the output Tensor. Default 3
        step_channels (int, optional) : Number of channels in multiples of which the DCGAN steps up
                                        the convolutional features
                                        The step up is done as dim z -> d - > 2 * d -> 4 * d - > 8 * d
                                        where d = step_channels. Default 64
        batchnorm (bool, optional) : If True, use batch normalization in the convolutional layers of the generator
                                     Default True
        nonlinearity(torch.nn.Module, optional) : Nonlinearity to be used in the intermediate convolutional layers
                                                  Defaults to LeakyReLU(0.2) when None is passed. Default None
        last_nonlinearity(toch.nn.Module, optional) : Nonlinearity to be used in the final convolutional layer
                                                      Defaults to sigmoid when None is passed. Default None
    """

    def __init__(self, in_channels=3, step_channels=64, batchnorm=True,
                 nonlinearity=None, last_nonlinearity=None):
        super(DCGANDiscriminator, self).__init__(in_channels)
        self.step_ch = step_channels
        self.batchnorm = batchnorm
        use_bias = not batchnorm

        nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity
        last_nl = nn.LeakyReLU(0.2) if last_nonlinearity is None else last_nonlinearity

        if batchnorm is True:
            self.model = nn.Sequential(
                nn.Conv2d(self.input_dims, self.step_ch, 4, 2, 1, bias=use_bias),
                nn.BatchNorm2d(self.step_ch), nl,
                nn.Conv2d(self.step_ch, self.step_ch * 2, 4, 2, 1, bias=use_bias),
                nn.BatchNorm2d(self.step_ch * 2), nl,
                nn.Conv2d(self.step_ch * 2, self.step_ch * 4, 4, 2, 1, bias=use_bias),
                nn.BatchNorm2d(self.step_ch * 4), nl,
                nn.Conv2d(self.step_ch * 4, self.step_ch * 8, 4, 2, 1, bias=use_bias),
                nn.BatchNorm2d(self.step_ch * 8), nl,
                nn.Conv2d(self.step_ch * 8, 1, 4, 2, 1, bias=use_bias),
                last_nl)
        else:
            self.model = nn.Sequential(
                nn.Conv2d(self.input_dims, self.step_ch, 4, 2, 1, bias=use_bias), nl,
                nn.Conv2d(self.step_ch, self.step_ch * 2, 4, 2, 1, bias=use_bias), nl,
                nn.Conv2d(self.step_ch * 2, self.step_ch * 4, 4, 2, 1, bias=use_bias), nl,
                nn.Conv2d(self.step_ch * 4, self.step_ch * 8, 4, 2, 1, bias=use_bias), nl, 
                nn.Conv2d(self.step_ch * 8, 1, 4, 2, 1, bias=use_bias), last_nl)

        self._weight_initializer()

    def forward(self, x):
        return self.model(x)

In [4]:
class Trainer(object):
    def __init__(self, generator, discriminator, optimizer_generator, optimizer_discriminator,
                 generator_loss, discriminator_loss, device=torch.device("cuda:0"),
                 batch_size=128, sample_size=8, epochs=5, checkpoints="./model/gan",
                 retain_checkpoints=5, recon="./images/", test_noise=None, **kwargs):
        self.device = device
        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)
        if "optimizer_generator_options" in kwargs:
            self.optimizer_generator = optimizer_generator(self.generator.parameters(),
                                                           **kwargs["optimizer_generator_options"])
        else:
            self.optimizer_generator = optimizer_generator(self.generator.parameters())
        if "optimizer_discriminator_options" in kwargs:
            self.optimizer_discriminator = optimizer_discriminator(self.discriminator.parameters(),
                                                                   **kwargs["optimizer_discriminator_options"])
        else:
            self.optimizer_discriminator = optimizer_discriminator(self.discriminator.parameters())
        if "loss_generator_options" in kwargs:
            self.generator_loss = generator_loss(**kwargs["loss_generator_options"])
        else:
            self.generator_loss = generator_loss()
        if "loss_discriminator_options" in kwargs:
            self.discriminator_loss = discriminator_loss(**kwargs["loss_discriminator_options"])
        else:
            self.discriminator_loss = discriminator_loss()
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.epochs = epochs
        self.checkpoints = checkpoints
        self.retain_checkpoints = retain_checkpoints
        self.recon = recon
        self.test_noise = torch.randn(self.sample_size, self.generator.encoding_dims, 1, 1,
                                      device=self.device) if test_noise is None else test_noise
        self.loss_information = {
            'generator_losses': [],
            'discriminator_loss': [],
            'generator_iters': 0,
            'discriminator_iters': 0,
        }
        if "loss_information" in kwargs:
            self.loss_information.update(kwargs["loss_information"])
        if not "target_dim" in kwargs:
            target_dim = 1
        self.targets = {
            'discriminator_target_real': torch.ones(self.batch_size, target_dim, device=self.device).squeeze(),
            'discriminator_target_fake': torch.zeros(self.batch_size, target_dim, device=self.device).squeeze()
        }
        self.start_epoch = 0
        self.last_retained_checkpoint = 0

    def save_model_extras(self, save_path):
        return {}

    def save_model(self, epoch):
        if self.last_retained_checkpoint == self.retain_checkpoints:
            self.last_retained_checkpoint = 0
        save_path = self.checkpoints + str(self.last_retained_checkpoint) + '.model'
        print("Saving Model at '{}'".format(save_path))
        model = {
            'epoch': epoch + 1,
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'optimizer_discriminator': self.optimizer_discriminator.state_dict(),
            'generator_losses': self.generator_losses,
            'discriminator_losses': self.discriminator_losses
        }
        # FIXME(avik-pal): Not a very good function name
        model.update(self.save_model_extras(save_path))
        torch.save(model, save_path)

    def load_model_extras(self, load_path):
        pass

    def load_model(self, load_path=""):
        if load_path == "":
            load_path = self.checkpoints + str(self.last_retained_checkpoint) + '.model'
        print("Loading Model From '{}'".format(load_path))
        try:
            check = torch.load(load_path)
            self.start_epoch = check['epoch']
            self.generator_losses = check['generator_losses']
            self.discriminator_losses = check['discriminator_losses']
            self.generator.load_state_dict(check['generator'])
            self.discriminator.load_state_dict(check['discriminator'])
            self.optimizer_generator.load_state_dict(check['optimizer_generator'])
            self.optimizer_discriminator.load_state_dict(check['optimizer_discriminator'])
            # FIXME(avik-pal): Not a very good function name
            self.load_model_extras(check)
        except:
            warn("Model could not be loaded from {}. Training from Scratch".format(load_path))
            self.start_epoch = 0
            self.generator_losses = []
            self.discriminator_losses = []

    def sample_images(self, epoch, nrow=8):
        save_path = "{}/epoch{}.png".format(self.recon, epoch + 1)
        print("Generating and Saving Images to {}".format(save_path))
        self.generator.eval()
        with torch.no_grad():
            images = self.generator(self.test_noise.to(self.device))
            img = torchvision.utils.make_grid(images)
            torchvision.utils.save_image(img, save_path, nrow=nrow)
        self.generator.train()

    def _verbose_matching(self, verbose):
        assert verbose >= 0 and verbose <= 5
        self.save_iter = 10**((6 - verbose) / 2)
        self.save_epoch = 6 - verbose
        self.generate_images = 6 - verbose

    def train_logger(self, running_generator_loss, running_discriminator_loss, epoch, itr=None):
        if itr is None:
            if (epoch + 1) % self.save_epoch == 0 or epoch == self.epochs:
                self.save_model(epoch)
            if (epoch + 1) % self.generate_images or epoch == self.epochs:
                self.sample_images(epoch)
            print("Epoch {} Complete | Mean Generator Loss : {} | Mean Discriminator Loss : {}".format(epoch + 1,
                  running_generator_loss, running_generator_loss))
        else:
            print("Epoch {} | Iteration {} | Mean Generator Loss : {} | Mean Discriminator Loss : {}".format(
                  epoch + 1, itr + 1, running_generator_loss, running_discriminator_loss))

    def train_stopper(self):
        return False

    def generator_train_iter(self, **kwargs):
        sampled_noise = torch.randn(self.batch_size, self.generator.encoding_dims, 1, 1, device=self.device)
        g_loss = self.generator_loss(self.discriminator(self.generator(sampled_noise)))
        g_loss.backward()
        self.loss_information['generator_losses'].append(g_loss)
        self.loss_information['generator_iters'] += 1

    def discriminator_train_iter(self, images, labels, **kwargs):
        sampled_noise = torch.randn(self.batch_size, self.generator.encoding_dims, 1, 1, device=self.device)
        d_real = self.discriminator(images).squeeze()
        d_loss_real = self.discriminator_loss(d_real, self.targets["discriminator_target_real"])
        d_fake = self.discriminator(self.generator(sampled_noise).detach()).squeeze()
        d_loss_fake = self.discriminator_loss(d_fake, self.targets["discriminator_target_fake"])
        d_loss = d_loss_fake + d_loss_real
        d_loss.backward()
        self.loss_information['discriminator_losses'].append(d_loss)
        self.loss_information['discriminator_iters'] += 1

    def train(self, data_loader, **kwargs):
        self.generator.train()
        self.discriminator.train()

        generator_options = {}
        discriminator_options = {}

        if "discriminator_options" in kwargs:
            discriminator_options = kwargs["discriminator_options"]
        if "generator_options" in kwargs:
            generator_options = kwargs["generator_options"]

        for epoch in range(self.start_epoch, self.epochs):

            running_generator_loss = 0.0
            running_discriminator_loss = 0.0

            for images, labels in data_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                self.discriminator.zero_grad()
                self.generator.zero_grad()
                self.discriminator_train_iter(images, labels, **discriminator_options)
                self.optimizer_discriminator.step()
                running_discriminator_loss += self.loss_information['discriminator_losses'][-1]

                self.generator.zero_grad()
                self.generator_train_iter(**generator_options)
                self.optimizer_generator.step()
                running_generator_loss += self.loss_information['generator_losses'][-1]

                # NOTE(avik-pal): A small hack to support WGAN
                if self.train_stopper():
                    break

                if self.loss_information['discriminator_iters'] % self.niter_print_losses == 0:
                    # FIXME(avik-pal): Sadly the iteration printed will be the discriminator iters
                    self.train_logger(running_generator_loss / self.loss_information['generator_iters'],
                                      running_discriminator_loss / self.loss_information['discriminator_iters'],
                                      self.loss_information['discriminator_iters'])

            self.train_logger(running_generator_loss / self.loss_information['generator_iters'],
                              running_discriminator_loss / self.loss_information['discriminator_iters'])

        print("Training of the Model is Complete")

    def __call__(self, data_loader, verbose=1, **kwargs):
        self._verbose_matching(verbose)
        self.train(data_loader, **kwargs)

In [5]:
def get_dataset():
    train_dataset = dsets.CIFAR10(root='/data/avikpal/',
                                  train=True,
                                  transform = transforms.Compose([transforms.ToTensor(),
                                                                  transforms.Normalize(mean = (0.0, 0.0, 0.0), std = (1.0, 1.0, 1.0))]),
                               download=True)
    train_loader = data.DataLoader(train_dataset, batch_size = 128, shuffle = True)
    return train_loader

In [6]:
dataset = get_dataset()

Files already downloaded and verified


In [11]:
trainer = Trainer(DCGANGenerator(out_channels=3), DCGANDiscriminator(in_channels=3), Adam, SGD, nn.BCELoss, nn.BCELoss, device=torch.device("cuda:2"), optimizer_generator_options={"lr": 0.001}, optimizer_discriminator_options={"lr": 1.0})

RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable

In [8]:
trainer(dataset)

RuntimeError: cudaEventSynchronize in future::wait: device-side assert triggered