In [0]:
#Run this cell to check how much GPU has been allotted (GPU RAM Free)
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# Colab exposes at most one GPU, and a CPU-only runtime exposes none: guard the index
# so this cell still runs (and reports "no GPU") instead of raising IndexError.
gpu = GPUs[0] if GPUs else None


def printm():
    """Print free system RAM, this process's resident memory, and GPU memory usage.

    The GPU list is re-queried on every call: the objects returned by GPUtil.getGPUs()
    hold the values read at query time, so reusing the module-level ``gpu`` would keep
    reporting the numbers from when this cell first ran.
    """
    process = psutil.Process(os.getpid())
    print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
    current = GPU.getGPUs()
    if not current:
        print("GPU RAM Free: no GPU detected")
        return
    g = current[0]
    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(g.memoryFree, g.memoryUsed, g.memoryUtil*100, g.memoryTotal))


printm()

In [0]:
#This cell mounts the Google Drive
# Everything below (dataset, checkpoints, epoch.txt, loss pickles, output images) lives
# under /content/drive/My Drive, so this mount has to succeed before those cells run.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
#Importing necessary packages
import torch
import torch.nn as nn
import os

In [0]:

# Sentinel meaning "use a fresh LeakyReLU(0.2)"; callers can still pass a falsy value
# (None/False) to get no activation at all.
_DEFAULT_ACTIVATION = object()


def conv_block(in_feat, out_feat, ksize, stride, padding,
               activation=_DEFAULT_ACTIVATION, use_batchnorm=True):
    """Build a Conv2d -> [BatchNorm2d] -> [activation] block.

    Args:
        in_feat, out_feat: input / output channels of the convolution.
        ksize, stride, padding: passed straight to nn.Conv2d.
        activation: module appended last. Defaults to a new
            nn.LeakyReLU(0.2, inplace=True) for every call; pass a falsy value
            (e.g. None or False) to omit the activation.
        use_batchnorm: append nn.BatchNorm2d(out_feat). The conv bias is disabled
            when True, since BatchNorm's own shift makes it redundant.

    Returns:
        nn.Sequential containing the layers above, in that order.
    """
    if activation is _DEFAULT_ACTIVATION:
        # Built per call: a module used as a default argument is created once at
        # definition time and would be shared by every block relying on the default.
        activation = nn.LeakyReLU(0.2, inplace=True)
    layers = [nn.Conv2d(in_feat, out_feat, ksize, stride, padding, bias=not use_batchnorm)]
    if use_batchnorm:
        layers.append(nn.BatchNorm2d(out_feat))
    if activation:
        layers.append(activation)
    return nn.Sequential(*layers)

class BASIC_D(nn.Module):
    """Conditional PatchGAN discriminator.

    ``forward(a, b)`` concatenates the two images along dim 1 and returns a map of
    per-patch scores in (0, 1) (final Sigmoid) rather than a single scalar.

    Layers (all 4x4 kernels, padding 1): one stride-2 conv block without BatchNorm
    (nc_in + nc_out -> ndf), then ``max_layers - 1`` stride-2 conv blocks doubling the
    width up to a cap of 8 * ndf, one stride-1 conv block, and a stride-1 conv down to
    1 channel followed by Sigmoid.

    Args:
        nc_in: channels of the conditioning image ``a``.
        nc_out: channels of the image ``b`` being judged.
        ndf: filters in the first conv layer.
        max_layers: number of stride-2 conv blocks (including the first one).
    """

    def __init__(self, nc_in, nc_out, ndf, max_layers=3):
        super(BASIC_D, self).__init__()
        main = nn.Sequential()
        used_names = set()

        def add(prefix, in_feat, out_feat, block):
            # Module names are '<prefix>-<in>-<out>'. Once the width cap is reached
            # (max_layers >= 6), two pyramid layers would both be named
            # 'pyramid-<8*ndf>-<8*ndf>' and add_module would silently replace the first
            # one, dropping a layer. Suffix a counter on collision instead; names that
            # do not collide (all of them for the default max_layers=3) are unchanged.
            name = '{0}-{1}-{2}'.format(prefix, in_feat, out_feat)
            unique = name
            k = 1
            while unique in used_names:
                unique = '{0}-{1}'.format(name, k)
                k += 1
            used_names.add(unique)
            main.add_module(unique, block)

        # input is (nc_in + nc_out) x isize x isize
        add('initial', nc_in + nc_out, ndf,
            conv_block(nc_in + nc_out, ndf, 4, 2, 1, use_batchnorm=False))
        out_feat = ndf
        for layer in range(1, max_layers):
            in_feat = out_feat
            out_feat = ndf * min(2**layer, 8)
            add('pyramid', in_feat, out_feat, conv_block(in_feat, out_feat, 4, 2, 1))
        in_feat = out_feat
        out_feat = ndf * min(2**max_layers, 8)
        add('last', in_feat, out_feat, conv_block(in_feat, out_feat, 4, 1, 1))

        in_feat, out_feat = out_feat, 1
        add('output', in_feat, out_feat,
            conv_block(in_feat, out_feat, 4, 1, 1, nn.Sigmoid(), False))
        self.main = main

    def forward(self, a, b):
        x = torch.cat((a, b), 1)
        output = self.main(x)
        return output

In [0]:
class UBlock(nn.Module):
    """Recursive U-Net block; nesting these builds the generator (see UNET_G).

    U-Net is an encoder-decoder with skip connections.
    U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
    The original U-Net paper: https://arxiv.org/abs/1505.04597

    Each block: a stride-2 4x4 Conv2d halves the spatial size; if ``s > 2`` the result
    goes through [BatchNorm2d] -> LeakyReLU(0.2) -> an inner UBlock of size ``s // 2``
    and is concatenated with the inner block's output along dim 1 (skip connection);
    finally ReLU -> stride-2 ConvTranspose2d [-> BatchNorm2d] [-> Dropout(0.5) when
    ``s <= 8``] restores the spatial size.

    Args:
        s: nominal spatial size of this block (UNET_G passes the image size); must be
            even and >= 2.
        nf_in: input channels.
        max_nf: cap on the channel count of inner blocks.
        use_batchnorm: add BatchNorm layers to this block. Inner blocks are always
            built with the default (True).
        nf_out: output channels (defaults to nf_in).
        nf_next: channels produced by the downsampling conv
            (defaults to min(2 * nf_in, max_nf)).
    """

    def __init__(self, s, nf_in, max_nf, use_batchnorm=True, nf_out=None, nf_next=None):
        super(UBlock, self).__init__()
        assert s>=2 and s%2==0
        nf_next = nf_next if nf_next else min(nf_in*2, max_nf)
        nf_out = nf_out if nf_out else nf_in
        # The conv bias is redundant only when a BatchNorm follows it (s > 2 with BN).
        self.conv = nn.Conv2d(nf_in, nf_next, 4, 2, 1, bias=not (use_batchnorm and s>2))
        if s>2:
            next_block = [nn.BatchNorm2d(nf_next)] if use_batchnorm else []
            next_block += [nn.LeakyReLU(0.2, inplace=True), UBlock(s//2, nf_next, max_nf)]
            self.next_block = nn.Sequential(*next_block)
        else:
            self.next_block = None
        # With a skip connection the decoder sees this level's nf_next channels plus the
        # inner block's output (which defaults to its own nf_in == nf_next channels).
        convt_in = nf_next*2 if self.next_block is not None else nf_next
        convt = [nn.ReLU(),
                 nn.ConvTranspose2d(convt_in, nf_out,
                                    kernel_size=4, stride=2, padding=1, bias=not use_batchnorm)]
        if use_batchnorm:
            convt += [nn.BatchNorm2d(nf_out)]
        if s <= 8:
            convt += [nn.Dropout(0.5, inplace=True)]
        self.convt = nn.Sequential(*convt)

    def forward(self, x):
        x = self.conv(x)
        if self.next_block is not None:
            # Skip connection. NOTE: without BatchNorm, next_block starts with the
            # in-place LeakyReLU, so x has already been through LeakyReLU by the time
            # it is concatenated here.
            x2 = self.next_block(x)
            x = torch.cat((x, x2), 1)
        return self.convt(x)


def UNET_G(isize, nc_in=3, nc_out=3, ngf=64):
    """Build the U-Net generator for square ``isize`` x ``isize`` inputs.

    The outermost UBlock maps nc_in -> ngf channels on the way down (without
    BatchNorm) and back to nc_out on the way up; inner blocks are capped at
    8 * ngf channels. A final Tanh squashes the output to (-1, 1).
    """
    outermost = UBlock(isize, nc_in, 8 * ngf, use_batchnorm=False,
                       nf_out=nc_out, nf_next=ngf)
    return nn.Sequential(outermost, nn.Tanh())

In [0]:
def weights_init(m):
    """Initialise one layer in place; intended for ``model.apply(weights_init)``.

    Layers are matched by class name, so 'Conv' covers Conv2d and ConvTranspose2d:
    conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02) and biases = 0.
    Conv biases and every other layer type are left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [0]:
"""
Parameters for our models
nc_in -> Number of channels in input (RGB is 3 channels)
nc_out -> Number of channels in output
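ngf -> Number of filters in the generator's first (outermost) conv layer; inner layers use multiples of it, capped at 8*ngf
ndf -> Number of filters in the first discriminator conv layer
imageSize -> Input image size (height and width). loadSize is defined below but is not referenced elsewhere in this notebook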
batchSize -> Number of images in each batch. Total number of batches per epoch is total_dataset/batch_size
lrD and lrG - Learning rates of discriminator and generator respectively


"""
# Image channels in / out (RGB -> RGB).
nc_in, nc_out = 3, 3
# Base filter counts for the generator and the discriminator.
ngf, ndf = 64, 64

loadSize = 286
imageSize = 256
batchSize = 32
# Same Adam learning rate for both networks.
lrD = lrG = 2e-4

In [0]:
netD = BASIC_D(nc_in, nc_out, ndf)

netD_path = "/content/drive/My Drive/checkpoints/Discy.pt"
if os.path.exists(netD_path):
    # The checkpoint is a whole pickled module (written by torch.save(netD, ...)), so it
    # must be fully unpickled -- only load checkpoints you created yourself. PyTorch >= 2.6
    # defaults to weights_only=True, which rejects whole-module pickles, hence the explicit
    # flag; PyTorch versions without that argument raise TypeError and use the plain call.
    # map_location="cpu" lets a GPU-saved checkpoint load on any runtime.
    try:
        netD = torch.load(netD_path, map_location="cpu", weights_only=False)
    except TypeError:
        netD = torch.load(netD_path, map_location="cpu")
else:
    netD.apply(weights_init)

In [0]:
netG = UNET_G(imageSize, nc_in, nc_out, ngf)

netG_path = "/content/drive/My Drive/checkpoints/Generator.pt"
if os.path.exists(netG_path):
    # The checkpoint is a whole pickled module (written by torch.save(netG, ...)), so it
    # must be fully unpickled -- only load checkpoints you created yourself. PyTorch >= 2.6
    # defaults to weights_only=True, which rejects whole-module pickles, hence the explicit
    # flag; PyTorch versions without that argument raise TypeError and use the plain call.
    # map_location="cpu" lets a GPU-saved checkpoint load on any runtime.
    try:
        netG = torch.load(netG_path, map_location="cpu", weights_only=False)
    except TypeError:
        netG = torch.load(netG_path, map_location="cpu")
else:
    netG.apply(weights_init)

In [0]:
# Uninitialised float32 CPU buffers laid out as (batch, channels, height, width).
# NOTE(review): nothing else in this notebook reads inputA/inputB; training feeds V(trainA) directly.
inputA = torch.empty(batchSize, nc_in, imageSize, imageSize, dtype=torch.float32, device="cpu")
inputB = torch.empty(batchSize, nc_out, imageSize, imageSize, dtype=torch.float32, device="cpu")

In [0]:
#Moves both models and the input buffers to the GPU. Remove this cell for CPU-only runs
# nn.Module.cuda() moves parameters in place and returns the module itself, so rebinding
# the models is a no-op; Tensor.cuda() returns a new tensor, so the buffers need it.
netD, netG = netD.cuda(), netG.cuda()
inputA, inputB = inputA.cuda(), inputB.cuda()

In [0]:
from PIL import Image
import numpy as np
import glob
from random import randint, shuffle

#Preprocessing of images in dataset. 
def resize_image(im,bandwidth=256):
    """Resize a PIL image to ``bandwidth`` x ``bandwidth`` RGB and scale it to [-1, 1].

    The image is converted to RGB first. This drops an alpha channel (PNG images may have
    one; the model does not use it) and expands palette/grayscale images that would
    otherwise not produce 3 channels.

    Args:
        im: PIL image.
        bandwidth: output side length in pixels. read_image relies on the default, so
            it must match imageSize.

    Returns:
        float ndarray of shape (bandwidth, bandwidth, 3) with values in [-1, 1].
    """
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter under its
    # current name (and exists in older Pillow releases too).
    img = im.convert("RGB").resize((bandwidth, bandwidth), Image.LANCZOS)
    return np.array(img) / 255 * 2 - 1
def read_image(rc,direction=0):
    """Load one (semantic label map, real photo) pair as normalised float arrays.

    Args:
        rc: pair of paths ``(semantic_path, real_path)``, e.g. one element of trainAB.
        direction: 0 returns (semantic, real); any other value returns (real, semantic).

    Both images go through resize_image and, with probability 1/2, are flipped
    horizontally together so the pair stays aligned. When the global ``channel_first``
    is true, both arrays are returned channels-first (3, H, W).
    """
    # Context managers release the file handles even if decoding fails.
    with Image.open(rc[1]) as real_im:
        realimage = resize_image(real_im)
    with Image.open(rc[0]) as semantic_im:
        semantic = resize_image(semantic_im)
    if randint(0,1):
        # Same coin toss for both images: flip the width axis of (H, W, 3) arrays.
        realimage = realimage[:, ::-1]
        semantic = semantic[:, ::-1]
    if channel_first:
        realimage = np.moveaxis(realimage, 2, 0)
        semantic = np.moveaxis(semantic, 2, 0)
    if direction==0:
        return semantic, realimage
    return realimage, semantic


In [0]:
anuepath="/content/drive/My Drive/anue"
TrainCol=glob.glob(anuepath+'/gtFine/train/*/*_labelColors.png')
TrainReal=glob.glob(anuepath+'/leftImg8bit/train/*/*_leftImg8bit.png')

ValCol=glob.glob(anuepath+'/gtFine/val/*/*_labelColors.png')
ValReal=glob.glob(anuepath+'/leftImg8bit/val/*/*_leftImg8bit.png')
direction = 0

# Label maps and photos are paired purely by sorted position. If the counts differ, zip()
# silently truncates and every pair after a missing file is mismatched, so check first.
assert len(TrainCol) and len(TrainReal), "no training images found under " + anuepath
assert len(TrainCol) == len(TrainReal), \
    "train label/photo counts differ: {0} vs {1}".format(len(TrainCol), len(TrainReal))
assert len(ValCol) == len(ValReal), \
    "val label/photo counts differ: {0} vs {1}".format(len(ValCol), len(ValReal))
trainAB=list(zip(sorted(TrainCol), sorted(TrainReal)))
valAB=list(zip(sorted(ValCol), sorted(ValReal)))

In [0]:
def minibatch(dataAB, batchsize, direction=0, epoch_file='/content/drive/My Drive/epoch.txt'):
    """Endless generator of image minibatches.

    Yields ``(epoch, dataA, dataB)``: two float32 arrays stacking read_image outputs
    for consecutive pairs of ``dataAB``. When too few pairs remain for the requested
    batch, ``dataAB`` is shuffled *in place*, reading restarts at the beginning and
    ``epoch`` is incremented, so the first pass uses the list's existing order.

    ``gen.send(n)`` returns a batch of ``n`` pairs instead of ``batchsize``; the
    following ``next(gen)`` goes back to ``batchsize``.

    Args:
        dataAB: list of path pairs accepted by read_image.
        batchsize: default number of pairs per batch.
        direction: forwarded to read_image.
        epoch_file: file holding the starting epoch counter. If it does not exist
            (e.g. on a fresh run), counting starts at 0.

    Raises:
        ValueError: if a requested batch is larger than the dataset and could never
            be filled.
    """
    length = len(dataAB)
    i = 0
    epoch = 0
    if os.path.exists(epoch_file):
        with open(epoch_file, 'r') as fp:
            epoch = int(fp.read())
    tmpsize = None
    while True:
        size = tmpsize if tmpsize else batchsize
        if size > length:
            raise ValueError(
                'batch of {0} pairs requested but dataset has only {1}'.format(size, length))
        if i + size > length:
            shuffle(dataAB)
            i = 0
            epoch += 1
        dataA = []
        dataB = []
        for j in range(i, i + size):
            imgA, imgB = read_image(dataAB[j], direction)
            dataA.append(imgA)
            dataB.append(imgB)
        i += size
        # The value passed to send() (None for next()) sets the size of the next batch.
        tmpsize = yield epoch, np.float32(dataA), np.float32(dataB)

In [0]:
from IPython.display import display
def showX(X,epoch,geniter, rows=1, out_dir="/content/drive/My Drive/output"):
    """Tile a batch of images into a grid, save it as a PNG, and display it inline.

    Args:
        X: batch of images scaled to [-1, 1] (values outside are clipped), laid out as
            (N, 3, imageSize, imageSize) when the global ``channel_first`` is true, else
            (N, imageSize, imageSize, 3).
        epoch, geniter: used only to build the filename ``epoch<epoch>genit<geniter>.png``.
        rows: number of grid rows; ``X.shape[0]`` must be divisible by it.
        out_dir: directory the PNG is written to; created if it does not exist.
    """
    assert X.shape[0]%rows == 0
    int_X = ((X+1)/2*255).clip(0,255).astype('uint8')
    if channel_first:
        int_X = np.moveaxis(int_X.reshape(-1,3,imageSize,imageSize), 1, 3)
    else:
        int_X = int_X.reshape(-1,imageSize,imageSize, 3)
    # (rows, cols, H, W, 3) -> (rows, H, cols, W, 3) -> (rows*H, cols*W, 3): a row-major grid.
    int_X = int_X.reshape(rows, -1, imageSize, imageSize,3).swapaxes(1,2).reshape(rows*imageSize,-1, 3)

    grid = Image.fromarray(int_X)
    os.makedirs(out_dir, exist_ok=True)
    grid.save(os.path.join(out_dir, "epoch"+str(epoch)+"genit"+str(geniter)+".png"))
    display(grid)

In [0]:
# read_image() and showX() look up channel_first as a global, so it must be bound
# before the first batch is drawn below.
channel_first, channel_axis = True, 1
# Pull one 6-pair training batch up front (presumably a quick check that the data
# pipeline works); trainA/trainB are reassigned by the training loop later.
train_batch = minibatch(trainAB, 6, direction=direction)
_, trainA, trainB = next(train_batch)

In [0]:
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

In [0]:
# Adam with beta1 = 0.5 (instead of the default 0.9) for both networks.
optimizerD, optimizerG = [
    optim.Adam(net.parameters(), lr=rate, betas=(0.5, 0.999))
    for net, rate in ((netD, lrD), (netG, lrG))
]

In [0]:
loss = nn.BCELoss()
lossL1 = nn.L1Loss()
# Cached all-ones / all-zeros BCE targets. They are rebuilt whenever the discriminator
# output's shape or device changes (e.g. a different batch size) instead of being fixed
# by the first call.
one = None
zero = None


def _target_like(cached, output, value):
    """Return ``cached`` if it matches ``output``'s shape and device, else a new tensor filled with ``value``."""
    if cached is None or cached.shape != output.shape or cached.device != output.device:
        return torch.full_like(output, value)
    return cached


def netD_train(A, B):
    """One discriminator update on a batch of (A, B) pairs.

    Real pairs (A, B) are pushed towards 1 and generated pairs (A, netG(A)) towards 0
    with BCE; gradients of both terms accumulate before a single optimizerD.step().

    Returns:
        1-tuple holding the mean of the fake and real BCE losses (detached tensor);
        callers unpack it as ``errD, = netD_train(A, B)``.
    """
    global one, zero
    netD.zero_grad()
    output_D_real = netD(A, B)
    one = _target_like(one, output_D_real, 1.0)
    errD_real = loss(output_D_real, one)
    errD_real.backward()

    # detach(): this step only updates netD, so there is no need to backpropagate into netG.
    output_G = netG(A).detach()
    output_D_fake = netD(A, output_G)
    zero = _target_like(zero, output_D_fake, 0.0)
    errD_fake = loss(output_D_fake, zero)
    errD_fake.backward()
    optimizerD.step()
    return (errD_fake.detach() + errD_real.detach()) / 2,


def netG_train(A, B):
    """One generator update on a batch of (A, B) pairs.

    The loss is ``BCE(netD(A, netG(A)), 1) + 100 * L1(netG(A), B)``: fool the
    discriminator while staying close to the target B. Only optimizerG steps; the
    gradients this leaves on netD's parameters are cleared by netD_train's zero_grad().

    Returns:
        (adversarial BCE loss, L1 loss) as detached tensors.
    """
    global one
    netG.zero_grad()
    output_G = netG(A)
    output_D_fake = netD(A, output_G)
    one = _target_like(one, output_D_fake, 1.0)
    errG_fake = loss(output_D_fake, one)
    errG_L1 = lossL1(output_G, B)
    # L1 term weighted by 100 relative to the adversarial term.
    errG = errG_fake + 100 * errG_L1
    errG.backward()

    optimizerG.step()
    return errG_fake.detach(), errG_L1.detach()

In [0]:
def V(x, device=None):
    """Convert a NumPy array to a torch tensor on the training device.

    Args:
        x: NumPy array; converted with torch.from_numpy, so its dtype is kept.
        device: target device. Defaults to CUDA when available, else CPU, so the
            helper also works on CPU-only runtimes.

    Returns:
        torch.Tensor on ``device``. The old autograd.Variable wrapper is a no-op on
        modern PyTorch, so a plain tensor is returned.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.from_numpy(x).to(device)

In [0]:
def netG_gen(A):
    """Run netG on each sample of ``A`` separately and stack the results.

    Samples are fed one at a time (batches of 1), as before; netG's train/eval mode is
    not changed here. The outputs are only used for display, so the forward passes run
    under torch.no_grad() to avoid building an autograd graph.

    Args:
        A: tensor whose first dimension indexes samples.

    Returns:
        np.ndarray of netG outputs concatenated along axis 0.
    """
    with torch.no_grad():
        outputs = [netG(A[i:i + 1]).cpu().numpy() for i in range(A.size(0))]
    return np.concatenate(outputs, axis=0)

In [0]:
import time
import pickle
from IPython.display import clear_output
t0 = time.time()
niter = 150
gen_iterations = 0
errL1 = epoch = errG = 0
# Resume the epoch counter saved by a previous session; a fresh run has no file yet.
if os.path.exists('/content/drive/My Drive/epoch.txt'):
    with open('/content/drive/My Drive/epoch.txt','r') as fp:
        epoch=int(fp.read())
errL1_sum = errG_sum = errD_sum = 0

display_iters = 100
val_batch = minibatch(valAB, 6, direction)
train_batch = minibatch(trainAB, batchSize, direction)
# torch.save does not create missing directories.
os.makedirs("/content/drive/My Drive/checkpoints", exist_ok=True)
genLoss=[]
discLoss=[]
while epoch < niter:
    epoch, trainA, trainB = next(train_batch)
    vA, vB = V(trainA), V(trainB)
    errD, = netD_train(vA, vB)
    errD_sum += errD
    if gen_iterations%10==0:
        print("Epoch # ",epoch,"Minibatch# ",gen_iterations)

    errG, errL1 = netG_train(vA, vB)
    errG_sum += errG
    errL1_sum += errL1
    gen_iterations += 1
    if gen_iterations%display_iters==0:
        if gen_iterations%(2*display_iters)==0:
            # Every 2*display_iters iterations: checkpoint both models, append the loss
            # history collected since the last checkpoint, record the epoch, clear output.
            torch.save(netG, "/content/drive/My Drive/checkpoints/Generator.pt")
            torch.save(netD, "/content/drive/My Drive/checkpoints/Discy.pt")
            with open('/content/drive/My Drive/generatorloss.pickle', 'ab') as fp:
                pickle.dump(genLoss, fp)
            with open('/content/drive/My Drive/discriminatorloss.pickle', 'ab') as fp:
                pickle.dump(discLoss, fp)
            genLoss=[]
            discLoss=[]
            with open('/content/drive/My Drive/epoch.txt','w') as fp:
                fp.write(str(epoch))
            clear_output()
        # Stored as Python floats: the running sums are (possibly CUDA) tensors, which
        # would make the pickled histories unloadable without a GPU.
        genLoss.append(float(errG_sum/display_iters))
        discLoss.append(float(errD_sum/display_iters))
        print('[%d/%d][%d] Loss_D: %f Loss_G: %f loss_L1: %f'
              % (epoch, niter, gen_iterations, errD_sum/display_iters,
                 errG_sum/display_iters, errL1_sum/display_iters), time.time()-t0)
        # Preview on 6 *training* pairs: send(6) makes the generator yield one 6-pair batch.
        _, sampleA, sampleB = train_batch.send(6)
        fakeB = netG_gen(V(sampleA))
        # Distinct filename suffix so the validation preview below (same epoch and
        # iteration) does not overwrite this image.
        showX(np.concatenate([sampleA, sampleB, fakeB], axis=0), epoch, str(gen_iterations) + "_train", 3)

        errL1_sum = errG_sum = errD_sum = 0
        _, valA, valB = next(val_batch)
        fakeB = netG_gen(V(valA))
        showX(np.concatenate([valA, valB, fakeB], axis=0), epoch, gen_iterations, 3)
        