In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader,Subset

import numpy as np
import matplotlib.pyplot as plt

from torchvision.utils import save_image
from torchvision.utils import make_grid
import sys
import os

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
imageAugment = T.Compose([
    T.Resize((64,64)),
    T.ToTensor(),
    T.Normalize([0.5,0.5,0.5],
                [0.5,0.5,0.5])])

In [7]:
batchsize = 200
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

In [8]:
dataPath = './data'

In [9]:
data =  DataLoader(torchvision.datasets.ImageFolder(dataPath,transform= imageAugment),batch_size=batchsize,shuffle=True)

In [10]:
class discriminatorNet(nn.Module):
  def __init__(self):
    super().__init__()

    # convolution layers
    self.conv1 = nn.Conv2d(  3, 64, 4, 2, 1, bias=False)
    self.conv2 = nn.Conv2d( 64,128, 4, 2, 1, bias=False)
    self.conv3 = nn.Conv2d(128,256, 4, 2, 1, bias=False)
    self.conv4 = nn.Conv2d(256,512, 4, 2, 1, bias=False)
    self.conv5 = nn.Conv2d(512,  1, 4, 1, 0, bias=False)

    # batchnorm
    self.bn2 = nn.BatchNorm2d(128)
    self.bn3 = nn.BatchNorm2d(256)
    self.bn4 = nn.BatchNorm2d(512)
    
  def forward(self,x):
    x = F.leaky_relu( self.conv1(x) ,.2)
    x = F.leaky_relu( self.conv2(x) ,.2)
    x = self.bn2(x)
    x = F.leaky_relu( self.conv3(x) ,.2)
    x = self.bn3(x)
    x = F.leaky_relu( self.conv4(x) ,.2)
    x = self.bn4(x)
    return torch.sigmoid(self.conv5(x)).view(-1,1)

In [11]:
class generatorNet(nn.Module):
  def __init__(self):
    super().__init__()

    # convolution layers
    self.conv1 = nn.ConvTranspose2d(200,512, 4, 1, 0, bias=False)
    self.conv2 = nn.ConvTranspose2d(512,256, 4, 2, 1, bias=False)
    self.conv3 = nn.ConvTranspose2d(256,128, 4, 2, 1, bias=False)
    self.conv4 = nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False)
    self.conv5 = nn.ConvTranspose2d(64,   3, 4, 2, 1, bias=False)

    # batchnorm
    self.bn1 = nn.BatchNorm2d(512)
    self.bn2 = nn.BatchNorm2d(256)
    self.bn3 = nn.BatchNorm2d(128)
    self.bn4 = nn.BatchNorm2d( 64)


  def forward(self,x):
    x = F.relu( self.bn1(self.conv1(x)) )
    x = F.relu( self.bn2(self.conv2(x)) )
    x = F.relu( self.bn3(self.conv3(x)) )
    x = F.relu( self.bn4(self.conv4(x)) )
    x = torch.tanh( self.conv5(x) )
    return x

In [12]:
discriminator = discriminatorNet().to(device=device)
generator = generatorNet().to(device=device)

In [13]:
optimizer_d = torch.optim.Adam(discriminator.parameters(),lr=0.001)
optimizer_g = torch.optim.Adam(generator.parameters(),lr = 0.001)

In [14]:
def trainDISCRIMINATOR(realImages,optimizer):
    optimizer.zero_grad()

    #discriminator with real images
    
    yHat_real = discriminator(realImages.to(device)).to(device)
    real_labels = torch.ones(realImages.size(0),1,device=device)
    #print(yHat_real.size())
    real_loss = F.binary_cross_entropy(yHat_real,real_labels)
   


    #generating fake data

    imageParameter = torch.randn(batchsize, 200, 1, 1)
    fake_images = generator(imageParameter.to(device))

    #discriminator with fake images

    fake_labels = torch.zeros(fake_images.size(0), 1, device=device)
    yHat_fake = discriminator(fake_images.to(device))
    fake_loss = F.binary_cross_entropy(yHat_fake,fake_labels)


    loss = real_loss + fake_loss
    loss.backward()
    optimizer.step()
    return loss.item()

In [15]:
def trainGenerator(optimizer):
    optimizer.zero_grad()

    #generating fake images
    imageParameter = torch.randn(batchsize, 200, 1, 1)
    fakeImages = generator(imageParameter.to(device))

    #fooling discriminator
    prediction = discriminator(fakeImages.to(device))
    labels = torch.ones(200,1,device=device)
    loss = F.binary_cross_entropy(prediction,labels)

    loss.backward()
    optimizer.step()

    return loss.item()

In [16]:
def denorm(img_tensors):
    return img_tensors * stats[1][0] + stats[0][0]

In [17]:
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

In [18]:
def save_samples(index, latent_tensors, show=True):
    fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))

In [19]:
def trainModel(epochs):

    generator_Losses = []
    discriminator_Losses = []
    
    for epoch in range(epochs):
        for realImages, _ in  data:
            
            lossD = trainDISCRIMINATOR(realImages,optimizer_d)

            lossG = trainGenerator(optimizer_g)

        generator_Losses.append(lossG)
        discriminator_Losses.append(lossD)

        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}".format(
        epoch+1, epochs, lossG, lossD))

        fixed_latent = torch.randn(64, 200, 1, 1, device=device)

        save_samples(epoch+1,fixed_latent,show=False)

    return generator_Losses,discriminator_Losses


In [20]:
epochs = 2000

In [21]:
history = trainModel(epochs)

Epoch [1/2000], loss_g: 10.9178, loss_d: 0.0188
Saving generated-images-0001.png
Epoch [2/2000], loss_g: 8.5219, loss_d: 0.0510
Saving generated-images-0002.png
Epoch [3/2000], loss_g: 13.1953, loss_d: 0.0687
Saving generated-images-0003.png
Epoch [4/2000], loss_g: 10.3835, loss_d: 0.0530
Saving generated-images-0004.png
Epoch [5/2000], loss_g: 11.8330, loss_d: 0.0143
Saving generated-images-0005.png
Epoch [6/2000], loss_g: 16.4586, loss_d: 0.0582
Saving generated-images-0006.png
Epoch [7/2000], loss_g: 11.8064, loss_d: 0.0099
Saving generated-images-0007.png
Epoch [8/2000], loss_g: 7.7453, loss_d: 0.0341
Saving generated-images-0008.png
Epoch [9/2000], loss_g: 11.2003, loss_d: 0.0292
Saving generated-images-0009.png
Epoch [10/2000], loss_g: 14.0679, loss_d: 0.0785
Saving generated-images-0010.png
Epoch [11/2000], loss_g: 11.7605, loss_d: 0.0199
Saving generated-images-0011.png
Epoch [12/2000], loss_g: 12.3835, loss_d: 0.0213
Saving generated-images-0012.png
Epoch [13/2000], loss_g: 13

KeyboardInterrupt: 