In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import itertools
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image

In [None]:
# Training hyperparameters
num_epochs = 200
decay_epoch = 100
batch_size = 32
lr = 0.0002
size = 256  # Image size
input_nc = 3  # Number of input image channels
output_nc = 3  # Number of output image channels
ngf = 64  # Number of generator filters in first conv layer
ndf = 64  # Number of discriminator filters in first conv layer
lambda_cyc = 10  # Cycle consistency loss weight
lambda_id = 0.5 * lambda_cyc  # Identity loss weight
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
##############################
#        NETWORKS            #
##############################

# Define the Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),

            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim)
        )

    def forward(self, x):
        return x + self.block(x)

# Define the Generator
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        ]

        # Downsampling
        in_features = ngf
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        model = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, padding=1),
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf * 8, 1, kernel_size=4, padding=1)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

##############################
#          MODELS            #
##############################

netG_A2B = Generator(input_nc, output_nc).to(device)
netG_B2A = Generator(output_nc, input_nc).to(device)
netD_A = Discriminator(input_nc).to(device)
netD_B = Discriminator(output_nc).to(device)

In [10]:
class UnalignedDataset(Dataset):
    def __init__(self, root_A, root_B, transform=None):
        self.transform = transform
        self.files_A = sorted([os.path.join(root_A, f) for f in os.listdir(root_A) if os.path.isfile(os.path.join(root_A, f))])
        self.files_B = sorted([os.path.join(root_B, f) for f in os.listdir(root_B) if os.path.isfile(os.path.join(root_B, f))])
        self.length_A = len(self.files_A)
        self.length_B = len(self.files_B)

    def __getitem__(self, index):
        img_A = Image.open(self.files_A[index % self.length_A]).convert('RGB')
        img_B = Image.open(self.files_B[index % self.length_B]).convert('RGB')

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {'A': img_A, 'B': img_B}

    def __len__(self):
        return max(self.length_A, self.length_B)
    
transform = transforms.Compose([
    transforms.Resize(int(size * 1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

data_folder = "/home/johntoro/code/GAN/data/"
path_A = data_folder + "maps/trainA"
path_B = data_folder + "maps/trainB"
dataset = UnalignedDataset(root_A=path_A, root_B=path_B, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Initialize optimizers
optimizer_G = optim.Adam(
    itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
    lr=lr, betas=(0.5, 0.999)
)
optimizer_D_A = optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

In [None]:
##############################
#           TRAINING         #
##############################
for epoch in range(1, num_epochs + 1):
    for i, data in enumerate(dataloader):
        real_A = data['A'].to(device)
        real_B = data['B'].to(device)
        
        print(real_A.shape, real_B.shape)

        # Generators
        optimizer_G.zero_grad()

        # Identity loss
        same_B = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * lambda_cyc * lambda_id

        same_A = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * lambda_cyc * lambda_id

        # GAN loss
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake).to(device))

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake).to(device))

        # Cycle loss
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * lambda_cyc

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * lambda_cyc

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        # Discriminator A
        optimizer_D_A.zero_grad()

        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real).to(device))

        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake).to(device))

        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # Discriminator B
        optimizer_D_B.zero_grad()

        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real).to(device))

        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake).to(device))

        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        # Print log info
        if i % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}] Batch {i}/{len(dataloader)} '
                  f'Loss_D: {loss_D_A + loss_D_B}, Loss_G: {loss_G}')

    # Update learning rates
    if epoch > decay_epoch:
        lr -= lr / (num_epochs - decay_epoch)
        for param_group in optimizer_G.param_groups:
            param_group['lr'] = lr
        for param_group in optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in optimizer_D_B.param_groups:
            param_group['lr'] = lr


torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
Epoch [1/200] Batch 0/1096 Loss_D: 1.281421422958374, Loss_G: 72.14110565185547
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256