In [1]:
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
#make image directory for output
import os
os.makedirs("img", exist_ok=True)


#hyperparams
n_epochs = 200
batch_size = 64
lr = 0.0001
sample_interval = 1000


#MNIST
img_size = 32
n_classes = 10
channels = 1
img_shape = (channels, img_size, img_size)



In [3]:

class G(nn.Module):
    
    #Vanilla Gen was used, might add DCGan style
    
    def __init__(self):
        super(G, self).__init__()
        
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        self.chunk1 = [nn.Linear(100 + n_classes, 128) ,
                       nn.LeakyReLU(0.2, inplace=True) ]
        
        self.chunk2 = [nn.Linear(128 , 256) ,
                       nn.BatchNorm1d(256, 0.8), 
                       nn.LeakyReLU(0.2, inplace=True) ]
        
        self.chunk3 = [nn.Linear( 256, 512) ,
                       nn.BatchNorm1d(512, 0.8), 
                       nn.LeakyReLU(0.2, inplace=True) ]
        
        self.chunk4 = [nn.Linear(512 , 1024) ,
                       nn.BatchNorm1d(1024, 0.8), 
                       nn.LeakyReLU(0.2, inplace=True) ]
        
        self.model = nn.Sequential(*self.chunk1, *self.chunk2, *self.chunk3, *self.chunk4,
                                    nn.Linear(1024, int(np.prod(img_shape) )), nn.Tanh() )
        

    def forward(self, noise, labels):
        # conditional information imposed by concatenation
        # as in the paper
        gen = torch.cat( (self.label_emb(labels) , noise), -1 )  # just concat the labels with noise
        img = self.model(gen)
        img = img.view(img.size(0), channels, img_size, img_size)
        return img
        



In [4]:
class D(nn.Module):
    def __init__(self):
        super(D, self).__init__()

        self.label_embedding = nn.Embedding(n_classes, n_classes)

        self.model = nn.Sequential(
            nn.Linear(n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity


# Loss functions
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = G()
discriminator = D()

#if you don't have cuda comment this
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()



os.makedirs("./mnist", exist_ok=True)

# MNIST dataloadder
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)


# Optimizers
#optimizer_G = torch.optim.SGD(generator.parameters(), lr=lr, momentum = 0.9  )
#optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=lr, momentum = 0.9 )


optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.99))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5 , 0.99))



# if you don't have cuda use torch.FloatTensor
FloatTensor = torch.cuda.FloatTensor 
LongTensor = torch.cuda.LongTensor 


# stitch images for 0~9 for 20 rows
def sample_image(n_row, i):
    # noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, 100))))
    # Get labels 0~9
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "img/%d.png" % i, nrow=n_row, normalize=True)


counter = 0

for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        #Split data for training
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) 
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        
        #G
        optimizer_G.zero_grad()

        # Generate images to fool Discriminator
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
        gen_labels = Variable(LongTensor(np.random.randint(0, n_classes, batch_size)))
        gen_imgs = generator(z, gen_labels)

        
        #Loss, judged by Discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()


        
        
        #D
        optimizer_D.zero_grad()

        #Real image Sampleed
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)
        
        #Distinguish fakes generated by Generator
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        #Score on Real + Fake
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        
        if counter % sample_interval == 0:
            sample_image(10, i=counter)
        counter = counter + 1
        
        
        if i%100 == 0:  # print every 100 batch
            print(
                "[%d]-[%d] [Gen loss: %f] [Disc loss: %f]"
                % (epoch, i,  g_loss.item(), d_loss.item())
            )
            
            


[0]-[0] [Gen loss: 0.966336] [Disc loss: 0.510321]
[0]-[100] [Gen loss: 0.698209] [Disc loss: 0.053484]
[0]-[200] [Gen loss: 0.653012] [Disc loss: 0.091736]
[0]-[300] [Gen loss: 0.452887] [Disc loss: 0.097728]
[0]-[400] [Gen loss: 0.663939] [Disc loss: 0.072437]
[0]-[500] [Gen loss: 0.694015] [Disc loss: 0.089367]
[0]-[600] [Gen loss: 0.673241] [Disc loss: 0.075200]
[0]-[700] [Gen loss: 0.121877] [Disc loss: 0.315646]
[0]-[800] [Gen loss: 0.376481] [Disc loss: 0.099120]
[0]-[900] [Gen loss: 0.704450] [Disc loss: 0.079010]
[1]-[0] [Gen loss: 0.942351] [Disc loss: 0.099927]
[1]-[100] [Gen loss: 0.530000] [Disc loss: 0.105590]
[1]-[200] [Gen loss: 0.607756] [Disc loss: 0.089005]
[1]-[300] [Gen loss: 0.613646] [Disc loss: 0.073544]
[1]-[400] [Gen loss: 0.942642] [Disc loss: 0.093918]
[1]-[500] [Gen loss: 0.833987] [Disc loss: 0.068996]
[1]-[600] [Gen loss: 0.667217] [Disc loss: 0.077686]
[1]-[700] [Gen loss: 0.375414] [Disc loss: 0.125738]
[1]-[800] [Gen loss: 0.916494] [Disc loss: 0.11160

[15]-[500] [Gen loss: 0.519285] [Disc loss: 0.121045]
[15]-[600] [Gen loss: 0.761738] [Disc loss: 0.119335]
[15]-[700] [Gen loss: 0.458266] [Disc loss: 0.128974]
[15]-[800] [Gen loss: 0.481872] [Disc loss: 0.118651]
[15]-[900] [Gen loss: 0.557914] [Disc loss: 0.122660]
[16]-[0] [Gen loss: 1.097012] [Disc loss: 0.147661]
[16]-[100] [Gen loss: 0.703295] [Disc loss: 0.139589]
[16]-[200] [Gen loss: 0.593162] [Disc loss: 0.132479]
[16]-[300] [Gen loss: 0.570923] [Disc loss: 0.123098]
[16]-[400] [Gen loss: 0.661902] [Disc loss: 0.145068]
[16]-[500] [Gen loss: 0.447174] [Disc loss: 0.128048]
[16]-[600] [Gen loss: 0.905168] [Disc loss: 0.156794]
[16]-[700] [Gen loss: 0.780286] [Disc loss: 0.117161]
[16]-[800] [Gen loss: 0.506116] [Disc loss: 0.142845]
[16]-[900] [Gen loss: 0.774314] [Disc loss: 0.133119]
[17]-[0] [Gen loss: 0.666042] [Disc loss: 0.102858]
[17]-[100] [Gen loss: 0.598704] [Disc loss: 0.121322]
[17]-[200] [Gen loss: 0.172956] [Disc loss: 0.243160]
[17]-[300] [Gen loss: 0.640065] 

[30]-[800] [Gen loss: 0.583136] [Disc loss: 0.142850]
[30]-[900] [Gen loss: 0.942855] [Disc loss: 0.197316]
[31]-[0] [Gen loss: 0.627682] [Disc loss: 0.147970]
[31]-[100] [Gen loss: 0.654132] [Disc loss: 0.143340]
[31]-[200] [Gen loss: 0.673385] [Disc loss: 0.143640]
[31]-[300] [Gen loss: 0.546776] [Disc loss: 0.133517]
[31]-[400] [Gen loss: 0.242257] [Disc loss: 0.210667]
[31]-[500] [Gen loss: 0.631876] [Disc loss: 0.161326]
[31]-[600] [Gen loss: 0.578474] [Disc loss: 0.100766]
[31]-[700] [Gen loss: 0.668312] [Disc loss: 0.128598]
[31]-[800] [Gen loss: 0.660834] [Disc loss: 0.137747]
[31]-[900] [Gen loss: 0.570593] [Disc loss: 0.129298]
[32]-[0] [Gen loss: 0.452399] [Disc loss: 0.126049]
[32]-[100] [Gen loss: 0.465499] [Disc loss: 0.151089]
[32]-[200] [Gen loss: 0.641081] [Disc loss: 0.159161]
[32]-[300] [Gen loss: 0.490765] [Disc loss: 0.150389]
[32]-[400] [Gen loss: 0.551507] [Disc loss: 0.174942]
[32]-[500] [Gen loss: 0.508534] [Disc loss: 0.144715]
[32]-[600] [Gen loss: 0.487717] 

KeyboardInterrupt: 