In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.utils as vutils

import argparse
import sys
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

"Import libraries"

In [68]:
class cganG(nn.Module):
    def __init__(self, classes, channels, img_size, latent_dim):
        super(cganG, self).__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.classes, self.classes)

        self.model = nn.Sequential(
            *self._create_layer(self.latent_dim + self.classes, 128, False),
            *self._create_layer(128, 256),
            *self._create_layer(256, 512),
            *self._create_layer(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def _create_layer(self, size_in, size_out, normalize=True):
        layers = [nn.Linear(size_in, size_out)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, noise, labels):
        z = torch.cat((self.label_embedding(labels), noise), -1)
        x = self.model(z)
        x = x.view(x.size(0), *self.img_shape)
        return x

"Define the generator"
#The generator class consists for three functions by using torch.nn.Module which help you build your network models easily: __init__, _create_layer,and forward.
#The __init__ method is where we typically define the attributes of a class. You can do any setup here. I set the number of classes, the number of channels, the size of image, the dimension of latent vector, and the nn.Embedding module. This module is  simple lookup table that stores embeddings of a fixed dictionary and size. It is used to process the label information with the random latent vector.
#_create_layer is where we define layers. It consists of 5 linear layers, 3 of which are connected to batch normalization layers, and the first 4 linear layers have LeakyReLu activation functions while the last has a Tahn activation function. Batch normaliazation is a method for the extracted features in the hidden units to make training faster and more stable.
#The forward method is called when we use the neural network to make a prediction. torch.cat is used to concatenate the given sequence of seq tensors in the given dimension, -1 here.

In [70]:
class cganD(nn.Module):
    def __init__(self, classes, channels, img_size, latent_dim):
        super(cganD, self).__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.classes, self.classes)
        self.adv_loss = torch.nn.BCELoss()

        self.model = nn.Sequential(
            *self._create_layer(self.classes + int(np.prod(self.img_shape)), 1024, False, True),
            *self._create_layer(1024, 512, True, True),
            *self._create_layer(512, 256, True, True),
            *self._create_layer(256, 128, False, False),
            *self._create_layer(128, 1, False, False),
            nn.Sigmoid()
        )

    def _create_layer(self, size_in, size_out, drop_out=True, act_func=True):
        layers = [nn.Linear(size_in, size_out)]
        if drop_out:
            layers.append(nn.Dropout(0.4))
        if act_func:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, image, labels):
        x = torch.cat((image.view(image.size(0), -1), self.label_embedding(labels)), -1)
        return self.model(x)

    def loss(self, output, label):
        return self.adv_loss(output, label)

"Define the discriminator"
The discriminator outputs a single value to show how close an image is to the real images as given the label information. The network consists of 5 linear layers, 2 of which are connected to dropout layers to to prevent overfitting. It makes all the nodes work well as a team by making sure no node is too weak or too strong through some neurons are not included in a particular forward or backward pass. BCE loss function is typically used for the binary classification tasks.

In [42]:
class Model(object):
    def __init__(self,
                 device,
                 data_loader,
                 classes,
                 channels,
                 img_size,
                 latent_dim):
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.netG = cganG(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netG.to(self.device)
        self.netD = cganD(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netD.to(self.device)
        self.optim_G = None
        self.optim_D = None

    def create_optim(self, lr, alpha=0.5, beta=0.999):
        self.optim_G = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        self.netG.parameters()),
                                        lr=lr,
                                        betas=(alpha, beta))
        self.optim_D = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        self.netD.parameters()),
                                        lr=lr,
                                        betas=(alpha, beta))

    def train(self,
              epochs,
              log_interval,
              out_dir=''):
        netG.train()
        netD.train()
        viz_noise = torch.randn(self.data_loader.batch_size, self.latent_dim, device=self.device)
        nrows = self.data_loader.batch_size // 8
        viz_label = torch.LongTensor(np.array([num for _ in range(nrows) for num in range(8)])).to(self.device)

        for epoch in range(epochs):
            batch_time = time.time()
            for batch_idx, (data, target) in enumerate(self.data_loader):
                data, target = data.to(self.device), target.to(self.device)
                batch_size = data.size(0)
                real_label = torch.full((batch_size, 1), 1., device=self.device)
                fake_label = torch.full((batch_size, 1), 0., device=self.device)

                # Train G
                self.netG.zero_grad()
                z_noise = torch.randn(batch_size, self.latent_dim, device=self.device)
                x_fake_labels = torch.randint(0, self.classes, (batch_size,), device=self.device)
                x_fake = self.netG(z_noise, x_fake_labels)
                y_fake_g = self.netD(x_fake, x_fake_labels)
                g_loss = self.netD.loss(y_fake_g, real_label)
                g_loss.backward()
                self.optim_G.step()

                # Train D
                self.netD.zero_grad()
                y_real = self.netD(data, target)
                d_real_loss = self.netD.loss(y_real, real_label)
                y_fake_d = self.netD(x_fake.detach(), x_fake_labels)
                d_fake_loss = self.netD.loss(y_fake_d, fake_label)
                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward()
                self.optim_D.step()

                if batch_idx % log_interval == 0 and batch_idx > 0:
                    print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}'.format(
                              epoch, batch_idx, len(self.data_loader),
                              d_loss.mean().item(),
                              g_loss.mean().item(),
                              time.time() - batch_time))
                    #Show a real image
                    #vutils.save_image(data, os.path.join(out_dir, 'real_samples.png'), normalize=True)
                    with torch.no_grad():
                        viz_sample = self.netG(viz_noise, viz_label)
                        #Show a fake image
                        #vutils.save_image(viz_sample, os.path.join(out_dir, 'fake_samples_{}.png'.format(epoch)), nrow=8, normalize=True)
                    batch_time = time.time()

            #self.save_to(path=out_dir, name=self.name, verbose=False)
            print('Total train time: {:.2f}'.format(time.time() - total_time))

    def eval(self,
             mode=None,
             batch_size=None):
        self.netG.eval()
        self.netD.eval()
        if batch_size is None:
            batch_size = self.data_loader.batch_size
        nrows = batch_size // 8
        viz_labels = np.array([num for _ in range(nrows) for num in range(8)])
        viz_labels = torch.LongTensor(viz_labels).to(self.device)

        with torch.no_grad():
            viz_tensor = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
            viz_sample = self.netG(viz_tensor, viz_labels)
            viz_vector = utils.to_np(viz_tensor).reshape(batch_size, self.latent_dim)
            cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
            np.savetxt('vec_{}.txt'.format(cur_time), viz_vector)
            vutils.save_image(viz_sample, 'img_{}.png'.format(cur_time), nrow=8, normalize=True)

"Define model training"
#__init__: the generator(netG) and the discriminator(netD) are initialized based on the class number(classes), image channel(channels), image size(img_size_), and the length of the latent vector(latent_dim). The optim_G and optim_D are optimizers for the two networks.
#create_optim: define optimizers.
#train: parameters of the train funcntion are the number of training epochs(epochs), the message interval you can check results during training(log_interval), and the output directory(out_dir). self.netG.train() and self.netD.train() turn on train mode. Then, latent vectors (viz_nois) and label (viz_label) are defined. They are used to occasionally produce images during training so that we can track how the model is trained. Torch.randn returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1. Torch.LongTensor defines 64-bit integer(signed) as a data type. 

In [16]:
FLAGS = None

def main():
    device = torch.device("cuda:0" if FLAGS.cuda else "cpu")

    if FLAGS.train:
        print('Loading data...\n')
        dataset = dset.MNIST(root=FLAGS.data_dir, download=False,
                             transform=transforms.Compose([
                             transforms.Resize(FLAGS.img_size),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,))
                             ]))
        assert dataset
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=FLAGS.batch_size,
                                                 shuffle=True, num_workers=4, pin_memory=True)
        print('Creating model...\n')
        model = Model(FLAGS.model, device, dataloader, FLAGS.classes, FLAGS.channels, FLAGS.img_size, FLAGS.latent_dim)
        model.create_optim(FLAGS.lr)

        # Train
        model.train(FLAGS.epochs, FLAGS.log_interval, FLAGS.out_dir, True)

        model.save_to('')
    else:
        model = Model(FLAGS.model, device, None, FLAGS.classes, FLAGS.channels, FLAGS.img_size, FLAGS.latent_dim)
        model.load_from(FLAGS.out_dir)
        model.eval(mode=1, batch_size=FLAGS.batch_size)   

if __name__ == '__main__':
    args, unknown = parser.parse_known_args()
    parser.add_argument('--model', type=str, default='cgan', help='one of `cgan` and `infogan`.')
    parser.add_argument('--cuda', type=boolean_string, default=True, help='enable CUDA.')
    parser.add_argument('--train', type=boolean_string, default=True, help='train mode or eval mode.')
    parser.add_argument('--data_dir', type=str, default='~/data/MNIST', help='Directory for dataset.')
    parser.add_argument('--out_dir', type=str, default='output', help='Directory for output.')
    parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='size of batches')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--latent_dim', type=int, default=100, help='latent space dimension')
    parser.add_argument('--classes', type=int, default=10, help='number of classes')
    parser.add_argument('--img_size', type=int, default=64, help='size of images')
    parser.add_argument('--channels', type=int, default=1, help='number of image channels')
    parser.add_argument('--log_interval', type=int, default=100, help='interval between logging and image sampling')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    FLAGS = parser.parse_args()
    FLAGS.cuda = FLAGS.cuda and torch.cuda.is_available()

    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        if FLAGS.cuda:
            torch.cuda.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    cudnn.benchmark = True

    if FLAGS.train:
        utils.clear_folder(FLAGS.out_dir)

    log_file = os.path.join(FLAGS.out_dir, 'log.txt')
    print("Logging to {}\n".format(log_file))
    sys.stdout = utils.StdOut(log_file)

    print("PyTorch version: {}".format(torch.__version__))
    print("CUDA version: {}\n".format(torch.version.cuda))

    print(" " * 9 + "Args" + " " * 9 + "|    " + "Type" + \
          "    |    " + "Value")
    print("-" * 50)
    for arg in vars(FLAGS):
        arg_str = str(arg)
        var_str = str(getattr(FLAGS, arg))
        type_str = str(type(getattr(FLAGS, arg)).__name__)
        print("  " + arg_str + " " * (20-len(arg_str)) + "|" + \
              "  " + type_str + " " * (10-len(type_str)) + "|" + \
              "  " + var_str)

    main()

ArgumentError: argument --model: conflicting option string: --model

In [64]:
data_dir = './data'
img_size = 64
batch_size = 64
classes = 10
device = torch.device("cpu")
channels = 1
latent_dim = 100

In [51]:
dataset = dset.MNIST(root=data_dir, download=False,
                    transform=transforms.Compose([
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                    ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, num_workers=4, pin_memory=True)
model = Model(FLAGS.model, device, dataloader, FLAGS.classes, FLAGS.channels, FLAGS.img_size, FLAGS.latent_dim)
#model.create_optim(FLAGS.lr)

In [66]:
netG = Generator().to(device)
netD = Discriminator().to(device)

NameError: name 'Generator' is not defined

In [71]:
Model(device, dataloader, classes, channels, img_size, latent_dim)

<__main__.Model at 0x1ffc4981bb0>

In [72]:
Model.train(epochs=10, log_interval=100, out_dir='output')

TypeError: train() missing 1 required positional argument: 'self'

In [38]:
class TestClass:
    def __init__(self):
        print("in init")
    def testFunc(self):
        print("in Test Func")

In [39]:
testInstance = TestClass()

in init


In [40]:
testInstance.testFunc()

in Test Func


"Run the model"
#main(): The FLAGS object stores all the arguments and hyper-parameters needed for model definition and training. To make the configuration of the arguments more user-friendly, we will use the argparse module provided by Python.