In [None]:
import torch
import torch.nn as nn

In [None]:
class BASIC_D(nn.Module):
    """Conditional PatchGAN discriminator.

    ``forward(a, b)`` concatenates the condition image ``a`` and the candidate
    image ``b`` along the channel axis and returns a map of per-patch scores
    in [0, 1] (sigmoid output), one score per spatial position.

    Args:
        nc_in: channels of ``a``.
        nc_out: channels of ``b``.
        ndf: channels of the first convolution; layer ``k`` uses
            ``ndf * min(2**k, 8)`` channels.
        max_layers: number of stride-2 convolutions (the initial one plus
            ``max_layers - 1`` "pyramid" layers); two stride-1 convolutions
            follow them.
        ngpu: stored as ``self.ngpu``; not used inside this class.
    """

    def __init__(self, nc_in, nc_out, ndf, max_layers=3, ngpu=1):
        super(BASIC_D, self).__init__()
        self.ngpu = ngpu
        self.nc_in = nc_in
        self.nc_out = nc_out
        main = nn.Sequential()
        # nn.Module.add_module rejects names containing '.', so the parts of
        # each module name are joined with ':' instead.
        # input is (nc_in + nc_out) x isize x isize
        main.add_module('initial:conv:{0}-{1}'.format(nc_in + nc_out, ndf),
                        nn.Conv2d(nc_in + nc_out, ndf, 4, 2, 1))
        main.add_module('initial:relu:{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        out_feat = ndf
        for layer in range(1, max_layers):
            in_feat = out_feat
            out_feat = ndf * min(2 ** layer, 8)
            # The layer index keeps names unique once the width saturates at
            # 8 * ndf: add_module with an already-registered name silently
            # replaces the earlier module instead of appending a new one.
            # bias=False because the following BatchNorm supplies the shift.
            main.add_module('pyramid{0}:{1}-{2}:conv'.format(layer, in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid{0}:{1}:batchnorm'.format(layer, out_feat),
                            nn.BatchNorm2d(out_feat))
            main.add_module('pyramid{0}:{1}:relu'.format(layer, out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
        # Stride-1 convolution (4x4 kernel, padding 1): widens the features
        # while the spatial size only shrinks by one.
        in_feat = out_feat
        out_feat = ndf * min(2 ** max_layers, 8)
        main.add_module('last:{0}-{1}:conv'.format(in_feat, out_feat),
                        nn.Conv2d(in_feat, out_feat, 4, 1, 1, bias=False))
        main.add_module('last:{0}:batchnorm'.format(out_feat), nn.BatchNorm2d(out_feat))
        main.add_module('last:{0}:relu'.format(out_feat), nn.LeakyReLU(0.2, inplace=True))

        # Single output channel per patch, squashed to [0, 1].
        in_feat, out_feat = out_feat, 1
        main.add_module('final:{0}-{1}:conv'.format(in_feat, out_feat),
                        nn.Conv2d(in_feat, out_feat, 4, 1, 1, bias=True))
        main.add_module('final:{0}:sigmoid'.format(out_feat), nn.Sigmoid())
        self.main = main

    def forward(self, a, b):
        """Score candidate ``b`` conditioned on ``a``.

        Args:
            a: batch with ``nc_in`` channels.
            b: batch with ``nc_out`` channels; batch and spatial sizes must
                match ``a`` (they are concatenated on dim 1).

        Returns:
            Tensor of shape (N, 1, H', W') with per-patch scores in [0, 1].
        """
        x = torch.cat((a, b), 1)
        output = self.main(x)
        return output
    

In [None]:
class UNET_G(nn.Module):
    """U-Net generator mapping an ``nc_in``-channel image to an
    ``nc_out``-channel image of the same spatial size.

    Encoder: stride-2 4x4 convolutions halve the resolution from ``isize``
    down to 1x1. Decoder: mirrored stride-2 transposed convolutions; after
    each one the encoder activation of the same resolution is concatenated on
    the channel axis (skip connection). The final activation is ``tanh``.

    Args:
        isize: input height/width; must be a power of two and a multiple of 16.
        nc_in: channels of the input image.
        nc_out: channels of the generated image.
        ngf: channels of the first encoder stage; depth ``k`` uses
            ``ngf * min(2**k, 8)`` channels.
    """

    def __init__(self, isize, nc_in=3, nc_out=3, ngf=64):
        super(UNET_G, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        # Every stage halves the resolution exactly, all the way down to 1x1.
        assert isize > 0 and isize & (isize - 1) == 0, "isize has to be a power of 2"

        def width(level):
            # Channel count at depth `level`, capped at 8 * ngf.
            return ngf * min(2 ** level, 8)

        # ---- Encoder: one stage per input resolution isize, isize/2, ..., 4.
        # Stage k: [LeakyReLU(0.2)] -> Conv 4x4 stride 2 -> [BatchNorm].
        # Stage 0 has neither the activation nor BatchNorm (bias instead).
        conv_layers = []
        skip_feats = []  # output channels of each encoder stage
        tsize, in_feat = isize, nc_in
        while tsize > 2:
            level = len(conv_layers)
            out_feat = width(level)
            use_batchnorm = level > 0
            layers = []
            if level > 0:
                # Not in-place: the previous stage's output is kept as a skip.
                layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=not use_batchnorm))
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(out_feat))
            conv_layers.append(nn.Sequential(*layers))
            skip_feats.append(out_feat)
            in_feat = out_feat
            tsize //= 2

        # ---- Bottleneck at resolution 2: down to 1x1 and back up to 2x2.
        # Not in-place LeakyReLU: its input is also the deepest skip.
        inner_feat = width(len(conv_layers))
        convt_layers = [nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_feat, inner_feat, 4, 2, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_feat, in_feat, 4, 2, 1, bias=False),
            nn.BatchNorm2d(in_feat),
            nn.Dropout(0.5),
        )]

        # ---- Decoder: mirrors encoder stages len-1 .. 1. Each stage's input
        # is the previous output concatenated with the same-resolution skip,
        # hence 2 * skip_feats[level] channels. Stages mirroring an encoder
        # input of <= 8x8 also get Dropout. ReLU may be in-place because its
        # input is always a fresh torch.cat result.
        for level in range(len(skip_feats) - 1, 0, -1):
            layers = [
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(2 * skip_feats[level], skip_feats[level - 1],
                                   4, 2, 1, bias=False),
                nn.BatchNorm2d(skip_feats[level - 1]),
            ]
            if isize // 2 ** level <= 8:
                layers.append(nn.Dropout(0.5))
            convt_layers.append(nn.Sequential(*layers))

        self.conv_layers = nn.ModuleList(conv_layers)
        self.convt_layers = nn.ModuleList(convt_layers)
        # ---- Output stage (mirror of encoder stage 0): back to isize with
        # nc_out channels, squashed to [-1, 1].
        self.output = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(2 * skip_feats[0], nc_out, 4, 2, 1, bias=True),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate ``x`` of shape (N, nc_in, isize, isize) into a tensor of
        shape (N, nc_out, isize, isize) with values in [-1, 1]."""
        skips = []
        for down in self.conv_layers:
            x = down(x)
            skips.append(x)
        # Deepest skip is consumed first (LIFO), matching the mirrored order.
        for up in self.convt_layers:
            x = torch.cat((up(x), skips.pop()), 1)
        return self.output(x)

In [None]:
def weights_init(m):
    """Initializer intended for ``module.apply(weights_init)``.

    - Modules whose class name contains 'Conv' (e.g. Conv2d, ConvTranspose2d):
      weight ~ N(0, 0.02); bias is left unchanged.
    - Otherwise, modules whose class name contains 'BatchNorm':
      weight ~ N(1, 0.02) and bias = 0.
    - All other modules are left untouched.

    Parameters that are absent or None (e.g. ``BatchNorm2d(affine=False)``,
    or a container whose class name happens to contain 'Conv') are skipped
    instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)

In [None]:
# ---- Model widths -------------------------------------------------------
nc_in, nc_out = 3, 3   # channels consumed / produced by the generator
ngf, ndf = 64, 64      # base widths of the generator / discriminator
# NOTE(review): λ is not referenced in any visible cell; presumably the weight
# of an L1 reconstruction term — confirm.
λ = 10

# ---- Data & optimisation ------------------------------------------------
# NOTE(review): loadSize is not referenced in any visible cell; presumably the
# resize target before cropping to imageSize — confirm.
loadSize, imageSize = 286, 256
batchSize = 1
lrD = lrG = 2e-4  # learning rates of the discriminator / generator

In [None]:
# Discriminator over channel-concatenated (input, output) image pairs,
# initialized by weights_init; the returned module is displayed by the cell.
netD = BASIC_D(nc_in=nc_in, nc_out=nc_out, ndf=ndf)
netD.apply(weights_init)

In [None]:
# U-Net generator for imageSize x imageSize images, initialized by
# weights_init; the returned module is displayed by the cell.
netG = UNET_G(isize=imageSize, nc_in=nc_in, nc_out=nc_out, ngf=ngf)
netG.apply(weights_init)

In [None]:
# Display the generator's module tree (the cell's last expression is rendered).
netG

In [None]:
# Display the discriminator's module tree (the cell's last expression is rendered).
netD

In [None]:
# Preallocated buffers that the training loop refills in place
# (resize_ / copy_ / normal_); torch.FloatTensor(*sizes) does not initialize.
# FIXME(review): `nz` (noise length) is not defined in any visible cell, so this
# cell raises NameError; UNET_G also consumes images, not (nz, 1, 1) noise.
input = torch.FloatTensor(batchSize, 3, imageSize, imageSize)
noise, fixed_noise = [torch.FloatTensor(batchSize, nz, 1, 1) for _ in range(2)]
fixed_noise.normal_(0, 1)
# +1 / -1 gradients handed to .backward() to choose each loss term's sign.
one = torch.FloatTensor([1])
mone = -one

In [None]:
# Move both networks and every preallocated buffer to the default CUDA device.
# NOTE(review): requires a CUDA-capable device; there is no CPU fallback.
for net in (netD, netG):
    net.cuda()
input, one, mone, noise, fixed_noise = (
    t.cuda() for t in (input, one, mone, noise, fixed_noise)
)

In [None]:
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

In [None]:
# RMSprop for both networks, learning rates taken from the config cell.
# NOTE(review): pix2pix is usually trained with Adam(lr=2e-4, betas=(0.5, 0.999));
# RMSprop matches the WGAN-style loop below — confirm which is intended.
optimizerD, optimizerG = [
    optim.RMSprop(net.parameters(), lr=rate)
    for net, rate in ((netD, lrD), (netG, lrG))
]

In [None]:
# Training loop, WGAN-style: several critic (netD) updates with weight
# clipping per generator update; gradient signs are chosen via `one` / `mone`.
# FIXME(review): this cell cannot run against the models defined above:
#   - `train_X`, `np`, `Diters`, `clamp_lower`, `clamp_upper`, `showX` and `nz`
#     are not defined in any visible cell;
#   - `netD(...)` is called with one argument, but BASIC_D.forward(a, b) takes two;
#   - `netG(...)` is fed (batchSize, nz, 1, 1) noise, but UNET_G maps images
#     to images;
#   - netD returns a per-patch map, not a scalar, so `.backward(one)` with a
#     1-element gradient does not match its shape;
#   - `Variable(..., volatile=True)` and `.data[0]` are pre-0.4 PyTorch idioms
#     (use `torch.no_grad()` and `.item()` on a reduced loss).
import time
t0 = time.time()
niter = 1000
gen_iterations = 0
for epoch in range(niter):
    i = 0  # index of the next minibatch of train_X in this epoch
    batches = train_X.shape[0]//batchSize
    while i < batches:
        for p in netD.parameters(): # reset requires_grad
            p.requires_grad = True # they are set to False below in netG update
        # 100 critic steps during the first 25 generator steps and on every
        # 500th one; otherwise Diters critic steps.
        if gen_iterations < 25 or gen_iterations %500 == 0:
            _Diters = 100
        else:
            _Diters = Diters
        j = 0
        # Critic steps: each consumes one minibatch; also stops when the epoch
        # runs out of minibatches.
        while j < _Diters and i < batches:
            j+=1

            # clamp parameters to a cube
            for p in netD.parameters():
                p.data.clamp_(clamp_lower, clamp_upper)

            # Slice minibatch i and move axis 3 to axis 1 (assumes train_X is
            # channels-last, (N, H, W, C) -> (N, C, H, W) — TODO confirm).
            real_data = torch.from_numpy(
                np.moveaxis(train_X[i*batchSize:(i+1)*batchSize], 3,1)
            ).cuda()
            i+=1

            netD.zero_grad()
            # Reuse the preallocated `input` buffer, resized to the batch shape.
            input.resize_as_(real_data).copy_(real_data)
            inputv = Variable(input)

            errD_real = netD(inputv)
            errD_real.backward(one)

            # train with fake
            noise.resize_(batchSize, nz, 1, 1).normal_(0, 1)
            noisev = Variable(noise, volatile = True) # totally freeze netG
            # Re-wrapping .data detaches the fake batch from netG's graph.
            fake = Variable(netG(noisev).data)
            inputv = fake
            errD_fake = netD(inputv)
            errD_fake.backward(mone)
            # Only used for logging below; gradients were already accumulated.
            errD = errD_real - errD_fake
            optimizerD.step()

        # Generator step with the critic's parameters frozen.
        for p in netD.parameters():
            p.requires_grad = False # to avoid computation
        netG.zero_grad()
        noise.resize_(batchSize, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev)
        errG = netD(fake)
        errG.backward(one)
        optimizerG.step()
        gen_iterations += 1
        # Log every 500 generator steps; elapsed seconds are passed to print()
        # as a second argument, after the %-formatted string.
        if gen_iterations%500 ==0:
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
            % (epoch, niter, i, batches, gen_iterations,
            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]), time.time()-t0)
        # Every 500 generator steps, render samples from the fixed noise batch,
        # moved back to channels-last (presumably what showX expects).
        if gen_iterations%500 == 0:
            fake = netG(Variable(fixed_noise, volatile=True))
            showX(np.moveaxis(fake.data.cpu().numpy(),1,3), 4)
        

-------------------------