In [1]:
import torch
import itertools
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import ImageDataset

In [2]:
class ResidualBlock(nn.Module):
    """Residual unit made of two reflection-padded 3x3 convolutions.

    The convolutional branch is
    ``pad -> conv -> instance-norm -> ReLU -> pad -> conv -> instance-norm``.
    Every convolution maps ``in_features`` channels to ``in_features``
    channels, and the 1-pixel reflection padding in front of each 3x3 kernel
    keeps the spatial size unchanged, so the branch output can be added to
    the block input.

    Args:
        in_features: number of channels of the input (and output) feature map.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        def padded_conv():
            # One "pad(1) + 3x3 conv + instance norm" group.
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_features, in_features, 3),
                    nn.InstanceNorm2d(in_features)]

        # Two identical groups with a single ReLU between them.
        layers = padded_conv() + [nn.ReLU(inplace=True)] + padded_conv()
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.conv_block(x)
        return x + branch

In [3]:
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image network.

    All layers are stored flat, in order, inside ``self.model``:

    1. Stem: reflection pad 3, 7x7 convolution to 64 channels, instance norm,
       ReLU.
    2. Two stride-2 3x3 convolutions, each doubling the channel count
       (64 -> 128 -> 256), each followed by instance norm and ReLU.
    3. ``n_residual_blocks`` ``ResidualBlock`` modules at 256 channels.
    4. Two stride-2 3x3 transposed convolutions, each halving the channel
       count (256 -> 128 -> 64), each followed by instance norm and ReLU.
    5. Head: reflection pad 3, 7x7 convolution to ``output_nc`` channels, Tanh.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        n_residual_blocks: number of residual blocks between the
            downsampling and upsampling stages.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # `channels` tracks the width of the feature map produced so far.
        channels = 64

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, channels, 7),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: two stride-2 convolutions, doubling the width each time
        for _ in range(2):
            wider = channels * 2
            layers.extend([
                nn.Conv2d(channels, wider, 3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            channels = wider

        # Residual blocks at the narrowest spatial resolution
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: two stride-2 transposed convolutions, halving the width
        for _ in range(2):
            narrower = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrower, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            channels = narrower

        # Output head; `channels` is back to 64 here
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, output_nc, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole stack and return the Tanh output."""
        return self.model(x)

In [4]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that returns one score per image.

    ``self.model`` is a flat stack of 4x4 convolutions (padding 1):

    * ``input_nc -> 64``, stride 2, LeakyReLU(0.2) (no normalisation)
    * ``64 -> 128`` and ``128 -> 256``, stride 2, instance norm, LeakyReLU(0.2)
    * ``256 -> 512``, stride 1, instance norm, LeakyReLU(0.2)
    * ``512 -> 1``, stride 1, producing a single-channel score map

    ``forward`` averages that score map over its spatial dimensions.

    Args:
        input_nc: number of channels of the images being judged.
    """

    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # First stage has no normalisation layer.
        layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1), leaky()]

        # (in_channels, out_channels, stride) of the normalised stages.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                leaky(),
            ]

        # Final convolution collapses the features to a 1-channel score map.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor: the spatial mean of the score map."""
        score_map = self.model(x)
        # Pool with a kernel as large as the map itself, then flatten.
        pooled = F.avg_pool2d(score_map, score_map.size()[2:])
        return pooled.view(score_map.size()[0], -1)

In [5]:
###### Definition of variables ######
input_nc = 3      # channels of domain-A images
output_nc = 3     # channels of domain-B images
size = 256        # side length of the square crops fed to the networks
# Prefer the GPU when one is available, otherwise use the CPU rather than
# hard-coding "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 4
lr = 0.001        # learning rate passed to every Adam optimizer

# Networks: one generator per translation direction (A->B, B->A) and one
# discriminator per image domain.
netG_A2B = Generator(input_nc, output_nc)
netG_B2A = Generator(output_nc, input_nc)
netD_A = Discriminator(input_nc)
netD_B = Discriminator(output_nc)

In [6]:
# Move all four networks to the configured device
netG_A2B.to(device)
netG_B2A.to(device)
netD_A.to(device)
netD_B.to(device)

# Losses: MSE for the adversarial terms, L1 for the cycle-consistency terms
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()

# Optimizers: a single Adam instance updates both generators together; each
# discriminator has its own.
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

# Intermediate tensors, allocated on `device` instead of through the legacy
# `torch.cuda.FloatTensor` constructor, which ignores `device` and needs CUDA.
# The `Tensor` alias is kept only so that other cells referring to it keep working.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor
input_A = torch.empty(batch_size, input_nc, size, size, device=device)
# Domain-B images have `output_nc` channels (the original used `input_nc`,
# which is only correct while both counts are equal).
input_B = torch.empty(batch_size, output_nc, size, size, device=device)
target_real = torch.ones(batch_size, 1, device=device)
target_fake = torch.zeros(batch_size, 1, device=device)

In [7]:
# Pre-processing pipeline applied to every image:
#   resize to int(size*1.12) with bicubic interpolation, take a random
#   size x size crop, flip horizontally at random, convert to a tensor and
#   normalise each of the three channels with mean 0.5 and std 0.5.
resize_to = int(size*1.12)
transforms_ = [
    transforms.Resize(resize_to, Image.BICUBIC),
    transforms.RandomCrop(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
]

# Dataset rooted at 'datasets/apple2orange', served in shuffled mini-batches
train_set = ImageDataset('datasets/apple2orange', transforms_=transforms_, unaligned=False)
dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [8]:
# Training loop. Note: range(1, 200) runs epochs 1..199 inclusive.
for epoch in range(1, 200):
    for i, batch in enumerate(dataloader):
        # Set model input.
        # Move the batch to the training device directly instead of copying it
        # into the fixed-size `input_A` / `input_B` buffers: `copy_` raises a
        # shape error when the last batch of an epoch holds fewer than
        # `batch_size` samples (and silently broadcasts a batch of one).
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # GAN loss: each generator is rewarded when the discriminator of its
        # target domain scores the generated image as real.
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        # Build the targets from the prediction so they always match the
        # actual batch size, dtype and device.
        real_label = torch.ones_like(pred_fake)
        fake_label = torch.zeros_like(pred_fake)
        loss_GAN_A2B = criterion_GAN(pred_fake, real_label)

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, real_label)

        # Cycle loss: translating there and back should reproduce the input.
        # Each cycle term is weighted by 10.0.
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        # Total generator loss
        loss_G =  loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        print (loss_G)
        optimizer_G.step()
        ###################################

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss; detach so this backward pass does not reach the generator
        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake)*0.5
        loss_D_A.backward()
        optimizer_D_A.step()
        ###################################

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss; detach so this backward pass does not reach the generator
        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake)*0.5
        loss_D_B.backward()
        optimizer_D_B.step()
        ###################################

tensor(15.2123, device='cuda:0')
tensor(0.4930, device='cuda:0')
tensor(0.4927, device='cuda:0')
tensor(18.9966, device='cuda:0')
tensor(6.3532, device='cuda:0')
tensor(5.4807, device='cuda:0')
tensor(13.1298, device='cuda:0')
tensor(0.2031, device='cuda:0')
tensor(0.2293, device='cuda:0')
tensor(15.6228, device='cuda:0')
tensor(1.4290, device='cuda:0')
tensor(1.0899, device='cuda:0')
tensor(10.9677, device='cuda:0')
tensor(0.2917, device='cuda:0')
tensor(0.2828, device='cuda:0')
tensor(10.2554, device='cuda:0')
tensor(0.5093, device='cuda:0')
tensor(0.4046, device='cuda:0')
tensor(9.9535, device='cuda:0')
tensor(0.1355, device='cuda:0')
tensor(0.1926, device='cuda:0')
tensor(8.9080, device='cuda:0')
tensor(0.1413, device='cuda:0')
tensor(0.2267, device='cuda:0')
tensor(9.3137, device='cuda:0')
tensor(0.1574, device='cuda:0')
tensor(0.1691, device='cuda:0')
tensor(9.0695, device='cuda:0')
tensor(0.1463, device='cuda:0')
tensor(0.1901, device='cuda:0')


KeyboardInterrupt: 