In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid

plt.ion()

In [2]:
class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.block = nn.Sequential(

            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2,True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,True),
            
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.block(x)
        #print(x.size())
        return x

In [3]:
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.block = nn.Sequential(
            #nn.Dropout2d(0.2),
            nn.Conv2d(3, 16, kernel_size=(3,3),padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2,inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
 
            nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2,inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
 
            nn.Conv2d(32, 64, kernel_size=(3,3),padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=(3,3),padding=1, bias=False),            
            nn.AvgPool2d(kernel_size=(4,4)),
            #nn.Linear(512,2)
        )
        self.out = nn.Sequential(nn.Linear(256,1),nn.Sigmoid())
    def forward(self, x):
        x1 = self.block(x)
        x1 = x1.view(x.size(0),-1)
        x = self.out(x1)
        #print(x.size())
        return x

gen_model = Generator().cuda()
dis_model = Discriminator().cuda()



In [4]:
avDev = torch.device("cuda:0")

In [5]:
def dataloaders(path=r"./images"):
    transforms_list = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])
    
    dataset = datasets.ImageFolder(path, transforms_list)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    return dataloader


def generate_images(gen_model, epoch_no, batch_size=16):
    gen_inp = torch.randn(batch_size, 100, 1, 1, device=avDev)
    gen_inp = gen_inp.cuda()
    fake_images = gen_model(gen_inp)
    fake_images = fake_images.cpu().detach()
    show(make_grid(fake_images, nrow=4))

    
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def show(img):
    npimg = img.numpy()
    plt.figure(figsize = (8,8))
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest', aspect='auto')

In [None]:
if __name__ == "__main__":

    image_size = 64
    batch_size = 16
    n_epochs = 250

    gen_lr = 1e-3
    dis_lr = 1e-4

    gen_model = Generator().cuda()
    dis_model = Discriminator().cuda()

    optimizer_gen = optim.Adam(gen_model.parameters(), lr=gen_lr, betas=(0.5, 0.999))
    optimizer_dis = optim.Adam(dis_model.parameters(), lr=dis_lr, betas=(0.5, 0.999))

    true_labels = 0.0   # CHANGE TO ZERO TO FLIP LABELS
    
    d_true_labels = Variable(torch.Tensor(batch_size).fill_(true_labels), requires_grad=False).cuda()
    g_labels = Variable(torch.Tensor(batch_size).fill_(1.0-true_labels), requires_grad=False).cuda()
      
    place_holder_true = np.ones([batch_size]) * true_labels
    place_holder_false = np.ones([batch_size]) * (1.0-true_labels)

    loss = nn.BCELoss().cuda()

    loss_list = []

    disc_dataloader = dataloaders()

    for epoch in range(n_epochs):

        total = len(disc_dataloader) * batch_size * 2
        correct_pos = 0.0
        correct_neg = 0.0
        gen_acc = 0.0

        disc_loss_epoch = 0
        gen_loss_epoch = 0

        #for batch_id, (true_images,_) in enumerate(disc_dataloader, 0): # For MNIST Example
        for batch_id, (true_images,_) in enumerate(disc_dataloader, 0):
            if epoch==0 and batch_id == 0:
                sample_true = true_images
            gen_model.train()
            dis_model.train()

            optimizer_dis.zero_grad()
            gen_inp_1 = torch.randn(batch_size, 100, 1, 1, device=avDev, requires_grad=False)

            # gen_inp = gen_inp.to(device)
            fake_images = gen_model(gen_inp_1)
            # TRAINING DISCRIMINATOR

            true_images = true_images.cuda()
            
            d_true_loss = loss(dis_model(true_images).view(-1), d_true_labels)
            d_fake_loss = loss(dis_model(fake_images.detach()).view(-1), g_labels)  # Have to detach from graph for some reason

            d_loss = (d_true_loss + d_fake_loss) / 2

            d_true_loss.backward()
            d_fake_loss.backward()
            optimizer_dis.step()

            # TRAINING GENERATOR
            optimizer_gen.zero_grad()

            # CREATING A BATCH OF RANDOM NOISE IN [0,1] FOR GENERATOR INPUT
            gen_loss = loss(dis_model(fake_images).view(-1), d_true_labels)
            gen_loss.backward()
            optimizer_gen.step()

            gen_loss_epoch += gen_loss.item()
            disc_loss_epoch += d_loss.item()

            with torch.no_grad():
                outputs = dis_model(true_images)
                predicted = (outputs > 0.5).to("cpu").numpy()
                batch_disc_acc = (predicted == place_holder_true).sum()/batch_size
                correct_pos += (predicted == place_holder_true).sum()

                outputs = dis_model(fake_images.detach())
                predicted = (outputs > 0.5).to("cpu").numpy()
                correct_neg += (predicted == place_holder_false).sum()

                outputs = dis_model(fake_images.detach())
                predicted = (outputs > 0.5).to("cpu").numpy()
                batch_gen_acc = (predicted == place_holder_true).sum()/batch_size
                gen_acc += (predicted == place_holder_true).sum()

        correct_pos/=total
        correct_neg/=total
        gen_acc/=(total / 2)

        
        loss_list.append((gen_loss_epoch/len(disc_dataloader),disc_loss_epoch/len(disc_dataloader)))
        
        if epoch==0:
          print("SAMPLE OF TRUE IMAGES..")
          true_images = sample_true.cpu().detach()
          show(make_grid(true_images, nrow=4))
          
        if epoch%10==0:
          print("SAMPLES OF GENERATED IMAGES..")
          print("Epoch {}".format(epoch))
          print("Generator and Discrimator average loss: {}, {}".
              format(gen_loss_epoch / len(disc_dataloader), disc_loss_epoch / len(disc_dataloader)))
          print("Current Discriminator Accuracy = {}".format(correct_pos+correct_neg))
          print("Current Discriminator Positives Accuracy = {}".format(correct_pos*2))
          print("Current Discriminator Negatives Accuracy = {}".format(correct_neg*2))
          print("Current Generator Accuracy = {}".format(gen_acc))
          generate_images(gen_model, epoch, 16)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLE OF TRUE IMAGES..
SAMPLES OF GENERATED IMAGES..
Epoch 0
Generator and Discrimator average loss: 0.8433278072625399, 0.7020684201270342
Current Discriminator Accuracy = 8.984375
Current Discriminator Positives Accuracy = 2.03125
Current Discriminator Negatives Accuracy = 15.9375
Current Generator Accuracy = 0.0625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 10
Generator and Discrimator average loss: 1.1027829237282276, 0.5546630038879812
Current Discriminator Accuracy = 13.4453125
Current Discriminator Positives Accuracy = 12.84375
Current Discriminator Negatives Accuracy = 14.046875
Current Generator Accuracy = 1.953125


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 20
Generator and Discrimator average loss: 1.237726324237883, 0.5462089609354734
Current Discriminator Accuracy = 13.5546875
Current Discriminator Positives Accuracy = 12.9375
Current Discriminator Negatives Accuracy = 14.171875
Current Generator Accuracy = 1.828125


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 30
Generator and Discrimator average loss: 1.6263795122504234, 0.4447324469219893
Current Discriminator Accuracy = 14.6015625
Current Discriminator Positives Accuracy = 14.359375
Current Discriminator Negatives Accuracy = 14.84375
Current Generator Accuracy = 1.15625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 40
Generator and Discrimator average loss: 1.5346604892984033, 0.4527559978887439
Current Discriminator Accuracy = 14.3359375
Current Discriminator Positives Accuracy = 14.328125
Current Discriminator Negatives Accuracy = 14.34375
Current Generator Accuracy = 1.65625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 50
Generator and Discrimator average loss: 1.9514185544103384, 0.35870601050555706
Current Discriminator Accuracy = 14.8984375
Current Discriminator Positives Accuracy = 14.90625
Current Discriminator Negatives Accuracy = 14.890625
Current Generator Accuracy = 1.109375


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 60
Generator and Discrimator average loss: 2.345680011436343, 0.3231601509032771
Current Discriminator Accuracy = 15.25
Current Discriminator Positives Accuracy = 15.15625
Current Discriminator Negatives Accuracy = 15.34375
Current Generator Accuracy = 0.65625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 70
Generator and Discrimator average loss: 2.0325023345649242, 0.37777010328136384
Current Discriminator Accuracy = 14.859375
Current Discriminator Positives Accuracy = 14.703125
Current Discriminator Negatives Accuracy = 15.015625
Current Generator Accuracy = 0.984375


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 80
Generator and Discrimator average loss: 2.46107833692804, 0.3464687436353415
Current Discriminator Accuracy = 14.984375
Current Discriminator Positives Accuracy = 15.03125
Current Discriminator Negatives Accuracy = 14.9375
Current Generator Accuracy = 1.0625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 90
Generator and Discrimator average loss: 2.516309032216668, 0.30132305034203455
Current Discriminator Accuracy = 15.40625
Current Discriminator Positives Accuracy = 15.21875
Current Discriminator Negatives Accuracy = 15.59375
Current Generator Accuracy = 0.40625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 100
Generator and Discrimator average loss: 2.4102360298857093, 0.30773115216288716
Current Discriminator Accuracy = 15.1796875
Current Discriminator Positives Accuracy = 15.125
Current Discriminator Negatives Accuracy = 15.234375
Current Generator Accuracy = 0.765625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 110
Generator and Discrimator average loss: 2.454275853931904, 0.33393732469994575
Current Discriminator Accuracy = 15.03125
Current Discriminator Positives Accuracy = 14.90625
Current Discriminator Negatives Accuracy = 15.15625
Current Generator Accuracy = 0.84375


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 120
Generator and Discrimator average loss: 2.804047922603786, 0.28140960415476
Current Discriminator Accuracy = 15.296875
Current Discriminator Positives Accuracy = 15.09375
Current Discriminator Negatives Accuracy = 15.5
Current Generator Accuracy = 0.5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 130
Generator and Discrimator average loss: 2.919398915488273, 0.25064292081515305
Current Discriminator Accuracy = 15.484375
Current Discriminator Positives Accuracy = 15.5625
Current Discriminator Negatives Accuracy = 15.40625
Current Generator Accuracy = 0.59375


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 140
Generator and Discrimator average loss: 2.734283212572336, 0.2856207637814805
Current Discriminator Accuracy = 15.4609375
Current Discriminator Positives Accuracy = 15.265625
Current Discriminator Negatives Accuracy = 15.65625
Current Generator Accuracy = 0.34375


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 150
Generator and Discrimator average loss: 2.989222090691328, 0.24727321192040108
Current Discriminator Accuracy = 15.4609375
Current Discriminator Positives Accuracy = 15.390625
Current Discriminator Negatives Accuracy = 15.53125
Current Generator Accuracy = 0.46875


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 160
Generator and Discrimator average loss: 3.4722580751404166, 0.2592405207105912
Current Discriminator Accuracy = 15.3671875
Current Discriminator Positives Accuracy = 15.5
Current Discriminator Negatives Accuracy = 15.234375
Current Generator Accuracy = 0.765625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 170
Generator and Discrimator average loss: 2.9757222793996334, 0.24509906495222822
Current Discriminator Accuracy = 15.546875
Current Discriminator Positives Accuracy = 15.359375
Current Discriminator Negatives Accuracy = 15.734375
Current Generator Accuracy = 0.265625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 180
Generator and Discrimator average loss: 2.94578024931252, 0.20245543465716764
Current Discriminator Accuracy = 15.6953125
Current Discriminator Positives Accuracy = 15.640625
Current Discriminator Negatives Accuracy = 15.75
Current Generator Accuracy = 0.25


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 190
Generator and Discrimator average loss: 3.588310773484409, 0.24698671052465215
Current Discriminator Accuracy = 15.4765625
Current Discriminator Positives Accuracy = 15.609375
Current Discriminator Negatives Accuracy = 15.34375
Current Generator Accuracy = 0.65625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 200
Generator and Discrimator average loss: 3.5932771526277065, 0.2598382553551346
Current Discriminator Accuracy = 15.515625
Current Discriminator Positives Accuracy = 15.546875
Current Discriminator Negatives Accuracy = 15.484375
Current Generator Accuracy = 0.515625


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


SAMPLES OF GENERATED IMAGES..
Epoch 210
Generator and Discrimator average loss: 3.542491387575865, 0.22385903703980148
Current Discriminator Accuracy = 15.5625
Current Discriminator Positives Accuracy = 15.546875
Current Discriminator Negatives Accuracy = 15.578125
Current Generator Accuracy = 0.421875
