In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

from PIL import Image
import matplotlib.pyplot as plt

import copy

from pytorch_datasetloader import *

# The Discriminator

In [2]:
class Dis(nn.Module):
    
    
    def __init__(self):
        
        super().__init__()
        
        #block 1 input = 3*128*128 
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=15, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        
        #block 2 input = 15*64*64
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=15,out_channels=50, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        
        #block 3 input = 50*32*32
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=50,out_channels=200, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        
     
        #self.drop_out = nn.Dropout()
        self.fc1 = torch.nn.Linear(200*16*16, 120)
        self.fc2 = torch.nn.Linear(120, 60)
        
        self.outlayer = nn.Sequential(
            nn.Linear(60, 1),
            nn.Sigmoid()
        )
        
        
        
    def forward(self, x):
        

        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.reshape(out.size(0), -1)
        #out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.outlayer(out)
        
        
        return(out)
    


# The Generator

In [3]:
#pooling for rescaling was avoided in favor of strided convolutions
#based on following paper: https://arxiv.org/pdf/1606.03498.pdf
class Gen(nn.Module):
    

    
    def __init__(self):
        
        super().__init__()
        

        self.encode = nn.Sequential(
            
            #encoding 1
            nn.Conv2d(in_channels=3,out_channels=15, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(15),
            
            #encoding 2
            nn.Conv2d(in_channels=15,out_channels=50, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(50),
            
            #encoding 3
            nn.Conv2d(in_channels=50,out_channels=200, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(200)
            
            
        )
        
        
        self.decode = nn.Sequential(
            
            #decoding 1
            nn.ConvTranspose2d(in_channels=200,out_channels=50, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(50),
            
            #decoding 2
            nn.ConvTranspose2d(in_channels=50,out_channels=15, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(15),
            
            #decoding 3
            nn.ConvTranspose2d(in_channels=15,out_channels=3, kernel_size=4, stride=2, padding=1)
        )
        
        
        
        
    def forward(self, x):
        

        out = self.encode(x)
        out = self.decode(out)
  
        return(out)




In [4]:
#using cuda if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

D = Dis()
D = D.to(device)

G = Gen()
G = G.to(device)


cuda:0


# Loading Dataset

In [7]:
batch_size_Realpix = 32
batch_size_mixed = 64
dataset_len =  100

path = 'RealPix_224'
trainloader_nonpixel, validloader_nonpixel = get_loaders(path,batch_size=batch_size_Realpix)

path = 'mixed'
trainloader_pixel, validloader_pixel = get_loaders(path,batch_size=batch_size_mixed)


trainiter_nonpixel = iter(trainloader_nonpixel)
trainIter_pixel = iter(trainloader_pixel)




224 images from the dataset
448 images from the dataset


# Selecting Optimizer 

In [8]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

# Training the DCGAN

In [11]:
# Train the model
total_step = len(trainloader_pixel)
d_loss_list = []
g_loss_list = []
acc_list = []
num_epochs = 1000
sample_imflag = 0

for epoch in range(num_epochs):
    
    trainiter_nonpixel = iter(trainloader_nonpixel)
    trainIter_pixel = iter(trainloader_pixel)
    
    for i in range(2):
        
        images, labels = trainIter_pixel.next()
        inputImages, input_lbls = trainiter_nonpixel.next()
        
        
        images = images.to(device)
        inputImages = inputImages.to(device)
        
        labels = torch.tensor(labels, dtype=torch.float, device=device)
        
        # Run the forward pass
        real_classification = D(images) 
        d_loss_real = criterion(real_classification, labels)
        fake_images = G(inputImages)
        
        fake_classification = D(fake_images)
        
        fake_lbls = torch.zeros(len(input_lbls))
        fake_lbls = fake_lbls.to(device)
        
        d_loss_fake = criterion(fake_classification, fake_lbls)
        d_loss = d_loss_real + d_loss_fake
        
        
        

        # Backprop and perform Adam optimisation for Discriminator
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        
        
        #Discriminator
        fake_lbls_duped = torch.ones(len(input_lbls))
        fake_lbls_duped = fake_lbls_duped.to(device)
    
        fake_images = G(inputImages)
        fake_classification = D(fake_images)
    
        #this loss denotes how well the generator duped/tricked the discriminator
        g_loss = criterion(fake_classification, fake_lbls_duped)
        
        
        
        # Backprop and perform Adam optimisation for generator
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
      
    if (epoch+1) % 10 == 0:
        print('Epoch No [{}/{}] Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(epoch+1,num_epochs,d_loss.item(),g_loss.item()))
        #appending loss every 10 epochs
        d_loss_list.append(d_loss.item())
        g_loss_list.append(d_loss.item())
    
        # Save fake images
        fake_images = fake_images.view(fake_images.size(0), 3, 128, 128)
        save_image(fake_images.data, 'images/fakes/fake_images-%d.png' %(epoch+1))
        
        
        # Save input images
        if sample_imflag == 0:
            inputImages = inputImages.view(inputImages.size(0), 3, 128, 128)
            save_image(inputImages.data, 'images/input_images/inputImages-%d.png' %(epoch+1))
            sample_imflag = 1
    
#GANs are saved
torch.save(G.state_dict(), 'GAN_OUTs/DCGAN2_generator.pkl')
torch.save(D.state_dict(), 'GAN_OUTs/DCGAN2_discriminator.pkl')
    


Epoch No [10/1000] Discriminator Loss: 0.0037, Generator Loss: 10.0998
Epoch No [20/1000] Discriminator Loss: 0.0034, Generator Loss: 10.0202
Epoch No [30/1000] Discriminator Loss: 0.0027, Generator Loss: 10.7909
Epoch No [40/1000] Discriminator Loss: 0.0021, Generator Loss: 10.8625
Epoch No [50/1000] Discriminator Loss: 0.0025, Generator Loss: 11.2064
Epoch No [60/1000] Discriminator Loss: 0.0022, Generator Loss: 11.1205
Epoch No [70/1000] Discriminator Loss: 0.0024, Generator Loss: 10.8766
Epoch No [80/1000] Discriminator Loss: 0.0021, Generator Loss: 11.1700
Epoch No [90/1000] Discriminator Loss: 0.0023, Generator Loss: 11.1395
Epoch No [100/1000] Discriminator Loss: 0.0025, Generator Loss: 11.0610
Epoch No [110/1000] Discriminator Loss: 0.0023, Generator Loss: 11.3941
Epoch No [120/1000] Discriminator Loss: 0.0020, Generator Loss: 11.7314
Epoch No [130/1000] Discriminator Loss: 0.0020, Generator Loss: 11.5247
Epoch No [140/1000] Discriminator Loss: 0.0023, Generator Loss: 11.4969
E