In [2]:
import numpy as np  
import os 
import argparse
import torch 
from torch.autograd import Variable
import torchvision.transforms as transforms
import random
from torch.utils.data import DataLoader
from torchvision import datasets 
import torch.nn as nn
import torch.nn.functional as F 
import torchvision.utils as vutils


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, help='cifar10 | lsun | mnist',default='mnist')
parser.add_argument('--dataroot', required=False, help='path to data',default='../../affect-recognition/data/pics/')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=32, help='image size input')
parser.add_argument('--channels', type=int, default=1, help='number of channels')
parser.add_argument('--latentdim', type=int, default=100, help='size of latent vector')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes in data set')
parser.add_argument('--epoch', type=int, default=200, help='number of epoch')
parser.add_argument('--lrate', type=float, default=0.0002, help='learning rate')
parser.add_argument('--beta', type=float, default=0.5, help='beta for adam optimizer')
parser.add_argument('--beta1', type=float, default=0.999, help='beta1 for adam optimizer')
parser.add_argument('--output', default='../../affect-recognition/data/pics/', help='folder to output images and model checkpoints')
parser.add_argument('--randomseed', type=int, help='seed',default=42)

_StoreAction(option_strings=['--randomseed'], dest='randomseed', nargs=None, const=None, default=42, type=<class 'int'>, choices=None, help='seed', metavar=None)

In [3]:
opt,u = parser.parse_known_args()

In [4]:
img_shape = (opt.channels, opt.imageSize, opt.imageSize)
print(img_shape)
cuda = True if torch.cuda.is_available() else False 

os.makedirs(opt.output, exist_ok=True)

if opt.randomseed is None: 
    opt.randomseed = random.randint(1,10000)
random.seed(opt.randomseed)
torch.manual_seed(opt.randomseed)

# preprocessing for mnist, lsun, cifar10
if opt.dataset == 'mnist': 
    dataset = datasets.MNIST(root = opt.dataroot, train=True,download=True, 
        transform=transforms.Compose([transforms.Resize(opt.imageSize), 
            transforms.ToTensor(), 
            transforms.Normalize((0.5,), (0.5,))]))

elif opt.dataset == 'lsun': 
    dataset = datasets.LSUN(root = opt.dataroot, train=True,download=True, 
        transform=transforms.Compose([transforms.Resize(opt.imageSize), 
            transforms.CenterCrop(opt.imageSize),
            transforms.ToTensor(), 
            transforms.Normalize((0.5,), (0.5,))]))

elif opt.dataset == 'cifar10':  
    dataset = datasets.CIFAR10(root = opt.dataroot, train=True,download=True, 
        transform=transforms.Compose([transforms.Resize(opt.imageSize), 
            transforms.ToTensor(), 
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))



assert dataset 
dataloader = torch.utils.data.DataLoader(dataset, batch_size = opt.batchSize, shuffle=True)

# building generator
class Generator(nn.Module): 
    def __init__(self):
        super(Generator, self).__init__()
        self.label_embed = nn.Embedding(opt.n_classes, opt.n_classes)
        self.depth=128

        def init(input, output, normalize=True): 
            layers = [nn.Linear(input, output)]
            if normalize: 
                layers.append(nn.BatchNorm1d(output, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers 

        self.generator = nn.Sequential(

            *init(opt.latentdim+opt.n_classes, self.depth), 
            *init(self.depth, self.depth*2), 
            *init(self.depth*2, self.depth*4), 
            *init(self.depth*4, self.depth*8),
            nn.Linear(self.depth * 8, int(np.prod(img_shape))),
            nn.Tanh()

            )

    # torchcat needs to combine tensors 
    def forward(self, noise, labels): 
        gen_input = torch.cat((self.label_embed(labels), noise), -1)
        img = self.generator(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module): 
    def __init__(self): 
        super(Discriminator, self).__init__()
        self.label_embed1 = nn.Embedding(opt.n_classes, opt.n_classes)
        self.dropout = 0.4 
        self.depth = 512

        def init(input, output, normalize=True): 
            layers = [nn.Linear(input, output)]
            if normalize: 
                layers.append(nn.Dropout(self.dropout))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers 

        self.discriminator = nn.Sequential(
            *init(opt.n_classes+int(np.prod(img_shape)), self.depth, normalize=False),
            *init(self.depth, self.depth), 
            *init(self.depth, self.depth),
            nn.Linear(self.depth, 1),
            nn.Sigmoid()
            )

    def forward(self, img, labels): 
        imgs = img.view(img.size(0),-1)
        inpu = torch.cat((imgs, self.label_embed1(labels)), -1)
        validity = self.discriminator(inpu)
        return validity 


# weight initialization
def init_weights(m): 
    if type(m)==nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)


# Building generator 
generator = Generator()
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=opt.lrate, betas=(opt.beta, opt.beta1))

# Building discriminator  
discriminator = Discriminator()
discriminator.apply(init_weights)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=opt.lrate, betas=(opt.beta, opt.beta1))

# Loss functions 
a_loss = torch.nn.BCELoss()

# Labels 
real_label = 0.9
fake_label = 0.0

FT = torch.LongTensor
FT_a = torch.FloatTensor

if cuda: 
    generator.cuda()
    discriminator.cuda()
    a_loss.cuda()
FT = torch.cuda.LongTensor
FT_a = torch.cuda.FloatTensor

# training 
for epoch in range(opt.epoch): 
    for i, (imgs, labels) in enumerate(dataloader): 
        batch_size = imgs.shape[0]

        # convert img, labels into proper form 
        imgs = Variable(imgs.type(FT_a))
        labels = Variable(labels.type(FT))

        # creating real and fake tensors of labels 
        reall = Variable(FT_a(batch_size,1).fill_(real_label))
        f_label = Variable(FT_a(batch_size,1).fill_(fake_label))

        # initializing gradient
        gen_optimizer.zero_grad() 
        d_optimizer.zero_grad()

        #### TRAINING GENERATOR ####
        # Feeding generator noise and labels 
        noise = Variable(FT_a(np.random.normal(0, 1,(batch_size, opt.latentdim))))
        gen_labels = Variable(FT(np.random.randint(0, opt.n_classes, batch_size)))

        gen_imgs = generator(noise, gen_labels)

        # Ability for discriminator to discern the real v generated images 
        validity = discriminator(gen_imgs, gen_labels)

        # Generative loss function 
        g_loss = a_loss(validity, reall)

        # Gradients 
        g_loss.backward()
        gen_optimizer.step()

        #### TRAINING DISCRIMINTOR ####

        d_optimizer.zero_grad()

        # Loss for real images and labels 
        validity_real = discriminator(imgs, labels)
        d_real_loss = a_loss(validity_real, reall)

        # Loss for fake images and labels 
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = a_loss(validity_fake, f_label)

        # Total discriminator loss 
        d_loss = 0.5 * (d_fake_loss+d_real_loss)

        # calculates discriminator gradients
        d_loss.backward()
        d_optimizer.step()


        if i%100 == 0: 
            vutils.save_image(gen_imgs, '%s/real_samples.png' % opt.output, normalize=True)
            fake = generator(noise, gen_labels)
            vutils.save_image(fake.detach(), '%s/fake_samples_epoch_%03d.png' % (opt.output, epoch), normalize=True)

    print("[Epoch: %d/%d]" "[D loss: %f]" "[G loss: %f]" % (epoch+1, opt.epoch, d_loss.item(), g_loss.item()))

    # checkpoints 
    torch.save(generator.state_dict(), '%s/generator_epoch_%d.pth' % (opt.output, epoch))
    torch.save(discriminator.state_dict(), '%s/generator_epoch_%d.pth' % (opt.output, epoch))

(1, 32, 32)




[Epoch: 1/200][D loss: 0.446146][G loss: 1.147342]
[Epoch: 2/200][D loss: 0.424998][G loss: 1.095250]
[Epoch: 3/200][D loss: 0.538535][G loss: 1.661360]
[Epoch: 4/200][D loss: 0.478863][G loss: 1.246031]
[Epoch: 5/200][D loss: 0.450313][G loss: 1.370050]
[Epoch: 6/200][D loss: 0.472542][G loss: 1.548119]
[Epoch: 7/200][D loss: 0.569261][G loss: 1.175852]
[Epoch: 8/200][D loss: 0.509154][G loss: 1.520882]
[Epoch: 9/200][D loss: 0.620282][G loss: 0.997288]
[Epoch: 10/200][D loss: 0.584178][G loss: 1.535263]
[Epoch: 11/200][D loss: 0.556249][G loss: 1.195990]
[Epoch: 12/200][D loss: 0.612358][G loss: 1.374866]
[Epoch: 13/200][D loss: 0.640880][G loss: 1.130364]
[Epoch: 14/200][D loss: 0.579006][G loss: 1.432647]
[Epoch: 15/200][D loss: 0.576021][G loss: 0.794312]
[Epoch: 16/200][D loss: 0.589840][G loss: 1.078212]
[Epoch: 17/200][D loss: 0.530936][G loss: 1.193568]
[Epoch: 18/200][D loss: 0.596794][G loss: 0.936076]
[Epoch: 19/200][D loss: 0.593174][G loss: 0.858233]
[Epoch: 20/200][D los

[Epoch: 158/200][D loss: 0.595762][G loss: 0.979018]
[Epoch: 159/200][D loss: 0.425997][G loss: 2.507332]
[Epoch: 160/200][D loss: 0.475792][G loss: 2.171040]
[Epoch: 161/200][D loss: 0.554022][G loss: 0.931997]
[Epoch: 162/200][D loss: 0.477525][G loss: 1.947245]
[Epoch: 163/200][D loss: 0.496748][G loss: 1.199639]
[Epoch: 164/200][D loss: 0.564644][G loss: 1.279712]
[Epoch: 165/200][D loss: 0.544490][G loss: 1.424178]
[Epoch: 166/200][D loss: 0.679059][G loss: 2.584683]
[Epoch: 167/200][D loss: 0.332812][G loss: 3.166833]
[Epoch: 168/200][D loss: 0.572656][G loss: 1.735618]
[Epoch: 169/200][D loss: 0.525864][G loss: 1.282853]
[Epoch: 170/200][D loss: 0.619113][G loss: 1.208614]
[Epoch: 171/200][D loss: 0.604309][G loss: 0.871398]
[Epoch: 172/200][D loss: 0.602969][G loss: 1.249989]
[Epoch: 173/200][D loss: 0.550324][G loss: 1.163842]
[Epoch: 174/200][D loss: 0.326341][G loss: 1.891591]
[Epoch: 175/200][D loss: 0.501505][G loss: 1.319116]
[Epoch: 176/200][D loss: 0.566315][G loss: 1.7