In [1]:
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

# Run on the GPU whenever PyTorch reports one; is_available() already yields a bool.
cuda = torch.cuda.is_available()

In [2]:
# Directory that receives the image grids written while training.
os.makedirs(name='data/generated', exist_ok=True)

In [3]:
# Images are handled as single-channel squares of side `img_size`.
channels, img_size = 1, 32
img_shape = (channels,) + (img_size,) * 2

In [4]:
# Training schedule and optimizer step size.
n_epochs, batch_size = 200, 64
lr = 0.001
# Length of the generator's noise vector and number of distinct labels.
latent_dim, num_classes = 100, 10
# A grid of generated samples is written every `sample_interval` batches.
sample_interval = 400

## Define Generator & Discriminator 

In [7]:
# NOTE: leftover scratch cell removed. It called `sayHi`, which no cell in this
# notebook defines, so it raised NameError on Restart Kernel -> Run All.


In [15]:
class Generator(nn.Module):
    """Conditional fully-connected generator.

    Turns a noise vector and a class label into an image tensor of shape
    ``(batch, *img_shape)`` whose values lie in [-1, 1] (final ``Tanh``).

    Reads the module-level ``num_classes``, ``latent_dim`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        # Each label is looked up as a learned vector of length num_classes.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU(0.2)] as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, so the
                # former `BatchNorm1d(out_feat, 0.8)` silently set eps=0.8.
                # Pass the value by keyword so it reaches `momentum`.
                # NOTE(review): PyTorch's momentum is the weight of the *new*
                # batch statistic; if 0.8 was ported from a framework using
                # the opposite convention, the equivalent here is 0.2 - confirm.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # Input is the label embedding concatenated with the noise vector.
            *block(latent_dim + num_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: float tensor whose last dimension is ``latent_dim``.
            labels: integer tensor of class indices, one per noise vector.

        Returns:
            Tensor of shape ``(noise.size(0), *img_shape)``.
        """
        # Concatenate label embedding and noise to produce the network input.
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        # Reshape the flat output into image form.
        return img.view(img.size(0), *img_shape)

In [16]:
class Discriminator(nn.Module):
    """Conditional fully-connected discriminator.

    Scores an (image, label) pair with a probability in (0, 1) that the image
    is real. Reads the module-level ``num_classes`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        # Each label is looked up as a learned vector of length num_classes.
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            # Input is the flattened image concatenated with the label embedding.
            nn.Linear(num_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1)
        )

    def forward(self, img, labels):
        """Return the real/fake probability for each image.

        Args:
            img: image batch; everything after the first dimension is flattened.
            labels: integer tensor of class indices, one per image.

        Returns:
            Tensor of shape ``(img.size(0), 1)`` with values in (0, 1).
        """
        flat_img = img.view(img.size(0), -1)
        # Concatenate flattened image and label embedding to produce the input.
        d_in = torch.cat((flat_img, self.label_embedding(labels)), -1)
        # torch.sigmoid replaces the deprecated F.sigmoid; the result is the same.
        return torch.sigmoid(self.model(d_in))

In [17]:
# Binary cross-entropy between predicted probabilities and 0/1 targets.
adversarial_loss = nn.BCELoss()

In [18]:
# Build one fresh instance of each network (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

In [19]:
# When a GPU is in use, move both networks and the loss module onto it.
if cuda:
    for _module in (generator, discriminator, adversarial_loss):
        _module.cuda()

In [21]:
# Make sure the dataset root exists before torchvision reads from / writes to it.
os.makedirs('data/downloaded/mnist', exist_ok=True)

# MNIST training split, shuffled, in batches of `batch_size`.
#  * download=True fetches the files only when they are not already present,
#    so the cell also works on a fresh checkout.
#  * MNIST images have a single channel, so Normalize takes one mean/std value;
#    the former 3-tuples do not match a 1-channel tensor. (x - 0.5) / 0.5 maps
#    ToTensor's [0, 1] range onto [-1, 1], the range of the generator's Tanh.
dataloader = DataLoader(
    datasets.MNIST('data/downloaded/mnist', train=True, download=True,
                   transform=transforms.Compose([
                        transforms.Resize(img_size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                   ])),
    batch_size=batch_size, shuffle=True)

In [24]:
# One Adam optimizer per network, both using the shared learning rate `lr`.
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=lr)
    for net in (generator, discriminator)
)

In [25]:
# Tensor constructors for the active device; the CUDA variants are only
# touched when a GPU is actually in use.
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

In [26]:
def get_noise(size):
    """Draw `size` latent vectors from a standard normal distribution.

    Sampling goes through NumPy's global random state. The result is a
    Variable wrapping a FloatTensor of shape (size, latent_dim).
    """
    samples = np.random.normal(loc=0, scale=1, size=(size, latent_dim))
    return Variable(FloatTensor(samples))

In [27]:
def sample_image(n_row, batches_done):
    """Save an n_row x n_row grid of generated digits as a PNG.

    Each row of the grid uses the labels 0 .. n_row-1 in order, so `n_row`
    must not exceed `num_classes` (the label embedding has no larger index).
    The file is written to data/generated/<batches_done>.png.

    Args:
        n_row: number of rows and columns of the grid.
        batches_done: value used as the output file name.
    """
    # Sample one noise vector per grid cell.
    z = get_noise(n_row**2)
    # Labels 0 .. n_row-1, repeated once for every row.
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    # Pure inference: skip autograd bookkeeping instead of building a graph
    # that is thrown away immediately.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, 'data/generated/{}.png'.format(batches_done), nrow=n_row, normalize=True)

In [28]:
def train_generator(steps):
    """Run `steps` generator updates and return the loss of the last one.

    Each step samples `batch_size` noise vectors and random labels (both read
    from module level), generates images, and pushes the discriminator's
    verdict on them towards "real".

    Args:
        steps: number of optimizer steps to take.

    Returns:
        The generator loss tensor of the final step, or None when `steps` is 0.
    """
    g_loss = None  # stays None when no step is taken

    for _ in range(steps):

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = get_noise(batch_size)
        gen_labels = Variable(LongTensor(np.random.randint(0, num_classes, batch_size)))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator.
        # The all-ones target is created from `validity` so it always shares
        # its shape, dtype and device; a bare torch.ones(...) is a CPU tensor
        # and does not match a CUDA prediction.
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, torch.ones_like(validity))

        g_loss.backward()
        optimizer_G.step()

    return g_loss

In [29]:
def train_discriminator(steps,real_imgs,labels):
    """Run `steps` discriminator updates and return the loss of the last one.

    Each step scores the given real batch against an all-ones target and an
    equally sized freshly generated batch against an all-zeros target; the
    step loss is the mean of the two.

    Args:
        steps: number of optimizer steps to take.
        real_imgs: batch of real images.
        labels: class indices belonging to `real_imgs`.

    Returns:
        The discriminator loss tensor of the final step, or None when `steps`
        is 0.
    """
    # Size the fake batch from the real one instead of a module-level value,
    # so the two halves of the loss always cover the same number of samples.
    n_samples = real_imgs.size(0)
    d_loss = None  # stays None when no step is taken

    for _ in range(steps):

        optimizer_D.zero_grad()

        # Loss for real images. Targets are built with *_like so they share
        # shape, dtype and device with the predictions; a bare torch.ones(...)
        # is a CPU tensor and does not match a CUDA prediction.
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, torch.ones_like(validity_real))

        # Loss for fake images
        z = get_noise(n_samples)
        gen_labels = Variable(LongTensor(np.random.randint(0, num_classes, n_samples)))

        # detach(): this step must not send gradients into the generator.
        validity_fake = discriminator(generator(z, gen_labels).detach(), gen_labels)

        d_fake_loss = adversarial_loss(validity_fake, torch.zeros_like(validity_fake))

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

    return d_loss

In [59]:
# Main training loop: for every batch take two discriminator steps and one
# generator step, log the losses on the first batch of each epoch, and write a
# sample grid every `sample_interval` batches.
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        # Helpers such as train_generator read the module-level `batch_size`,
        # so keep it equal to the size of the current batch (the last batch of
        # an epoch can be smaller than the configured value).
        batch_size = imgs.shape[0]

        # Configure input: cast to the tensor types of the active device.
        # (The previously allocated `valid` / `fake` target tensors were never
        # used and have been dropped.)
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        d_loss = train_discriminator(2, real_imgs, labels)
        g_loss = train_generator(1)

        if i == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (
                epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

[Epoch 0/200] [Batch 0/938] [D loss: 0.477635] [G loss: 0.708978]
[Epoch 1/200] [Batch 0/938] [D loss: 0.000101] [G loss: 27.438683]
[Epoch 2/200] [Batch 0/938] [D loss: 0.000618] [G loss: 27.630949]
[Epoch 3/200] [Batch 0/938] [D loss: 0.015135] [G loss: 15.387774]
[Epoch 4/200] [Batch 0/938] [D loss: 0.003236] [G loss: 24.422503]
[Epoch 5/200] [Batch 0/938] [D loss: 0.156266] [G loss: 3.582526]
[Epoch 6/200] [Batch 0/938] [D loss: 0.132298] [G loss: 7.139694]
[Epoch 7/200] [Batch 0/938] [D loss: 0.210667] [G loss: 10.697095]
[Epoch 8/200] [Batch 0/938] [D loss: 0.078921] [G loss: 20.500151]
[Epoch 9/200] [Batch 0/938] [D loss: 0.225013] [G loss: 6.086752]
[Epoch 10/200] [Batch 0/938] [D loss: 0.136525] [G loss: 5.953531]
[Epoch 11/200] [Batch 0/938] [D loss: 0.305107] [G loss: 4.950181]
[Epoch 12/200] [Batch 0/938] [D loss: 0.196492] [G loss: 11.266280]
[Epoch 13/200] [Batch 0/938] [D loss: 0.235403] [G loss: 7.950664]
[Epoch 14/200] [Batch 0/938] [D loss: 0.132192] [G loss: 5.287601

KeyboardInterrupt: 

In [None]:
# Persist the learned weights (state dicts only) of both networks.
os.makedirs('models', exist_ok=True)
for _net, _path in ((generator, './models/generator.pt'),
                    (discriminator, './models/discriminator.pt')):
    torch.save(_net.state_dict(), _path)