In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import os
import glob
import datetime
import itertools

class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Item ``index`` yields ``{'A': image_A, 'B': image_B}``. The dataset length is that of the
    larger domain; the smaller domain wraps around via ``index % len(files)``.

    Args:
        root: Dataset root directory.
        transforms_: Callable applied to each (PIL) image, e.g. a torchvision ``Compose``.
            If None, the PIL image itself is returned.
        mode: Split prefix; ``'train'`` selects ``trainA`` / ``trainB``.
        unaligned: If True, pair each A image with a randomly drawn B image instead of the
            fixed partner ``B[index % len(B)]``. Defaults to False (original pairing).
        image_mode: PIL mode every image is converted to before transforming (default
            ``'RGB'`` so RGBA/palette/grayscale files match a 3-channel model). Pass None to
            keep each file's native mode (e.g. for single-channel models).

    Raises:
        ValueError: If either domain directory contains no files.
    """

    def __init__(self, root, transforms_=None, mode='train', unaligned=False, image_mode='RGB'):
        self.transform = transforms_
        self.unaligned = unaligned
        self.image_mode = image_mode
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))

        # __getitem__ indexes with `index % len(files)`; an empty domain would otherwise
        # surface much later as a ZeroDivisionError inside the DataLoader.
        for suffix, files in (('A', self.files_A), ('B', self.files_B)):
            if not files:
                raise ValueError('No files found in %s' % os.path.join(root, '%s%s' % (mode, suffix)))

    def _load(self, path):
        """Open one image file, normalise its colour mode and apply the transform."""
        # The context manager closes the file handle; convert()/copy() return fully
        # loaded images, so the result stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert(self.image_mode) if self.image_mode is not None else img.copy()
        return self.transform(image) if self.transform is not None else image

    def __getitem__(self, index):
        item_A = self._load(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            # torch RNG (rather than numpy) is re-seeded per DataLoader worker.
            b_index = int(torch.randint(len(self.files_B), (1,)).item())
        else:
            b_index = index % len(self.files_B)
        item_B = self._load(self.files_B[b_index])

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

class ReplayBuffer():
    """Fixed-capacity pool of previously generated samples.

    Used so the discriminator sees a mix of freshly generated and older fakes.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store the samples of ``data`` (split along dim 0) and return a same-sized batch.

        While the pool has free slots, every sample is stored and passed through unchanged.
        Once full, each sample is, with probability 1/2, swapped with a randomly chosen stored
        sample (the stored one is returned, cloned); otherwise it is passed through unchanged.
        Gradients are not tracked because ``data.data`` is iterated.
        """
        batch_out = []
        for sample in data.data:
            sample = sample.unsqueeze(0)

            if len(self.data) < self.max_size:
                # Pool still filling: remember the sample and return it as-is.
                self.data.append(sample)
                batch_out.append(sample)
                continue

            if np.random.uniform(0, 1) > 0.5:
                # Swap: hand back a historical sample, keep the new one in its slot.
                slot = np.random.randint(0, self.max_size)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)

        return torch.cat(batch_out)

class GeneratorResNet(nn.Module):
    """ResNet-style encoder/decoder generator.

    Layout:
    - 7x7 conv stem;
    - ``n_sampling`` stride-2 conv downsampling stages (channels doubled each time);
    - ``num_residual_blocks`` ResidualBlocks;
    - ``n_sampling`` stride-2 transposed-conv upsampling stages (channels halved each time);
    - 7x7 conv back to the input channel count, followed by Tanh.

    With the defaults the layer list (and therefore the state_dict keys) is unchanged from
    the original fixed 64-filter / 2-stage design.

    Args:
        input_shape: ``(channels, height, width)``; only ``channels`` sizes the layers.
            Height/width should be divisible by ``2 ** n_sampling`` so the output spatial
            size equals the input size.
        num_residual_blocks: Number of residual blocks at the bottleneck.
        ngf: Filters in the stem conv (default 64).
        n_sampling: Number of down- and up-sampling stages (default 2).
    """

    def __init__(self, input_shape, num_residual_blocks, ngf=64, n_sampling=2):
        super(GeneratorResNet, self).__init__()

        self.input_shape = input_shape
        self.num_residual_blocks = num_residual_blocks
        channels = input_shape[0]

        # Initial convolution block (padding 3 keeps spatial size for a 7x7 kernel)
        model = [nn.Conv2d(channels, ngf, kernel_size=7, stride=1, padding=3, bias=False),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(inplace=True)]

        # Downsampling: each stage halves H/W and doubles channels
        in_features = ngf
        for _ in range(n_sampling):
            out_features = in_features * 2
            model += [nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1, bias=False),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Residual blocks (channel count and spatial size preserved)
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: output_padding=1 makes each stage exactly double H/W
        for _ in range(n_sampling):
            out_features = in_features // 2
            model += [nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1,
                                         output_padding=1, bias=False),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Output layer: in_features is back to ngf here (previously hard-coded as 64)
        model += [nn.Conv2d(in_features, channels, kernel_size=7, stride=1, padding=3),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class ResidualBlock(nn.Module):
    """Identity skip around two 3x3 conv + InstanceNorm layers: ``y = x + F(x)``.

    Channel count and spatial size are preserved (3x3 kernels with padding 1, stride 1),
    which is what makes the element-wise addition valid.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class Discriminator(nn.Module):
    """Convolutional discriminator producing a 1-channel spatial map of real/fake scores.

    Four stride-2 4x4 conv stages (64 -> 128 -> 256 -> 512 filters, LeakyReLU 0.2; all but the
    first followed by InstanceNorm), then an asymmetric zero pad and a final 4x4 conv to one
    channel. The output spatial size depends on the input size and is not stored here.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, _height, _width = self.input_shape

        widths = [in_channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if stage > 0:  # the first stage is deliberately left unnormalised
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

class CycleGAN():
    """Unpaired image-to-image translation between two domains A and B.

    Owns:
    - two ResNet generators, G_AB (A -> B) and G_BA (B -> A);
    - one discriminator per domain, D_A and D_B;
    - their Adam optimizers;
    - one ReplayBuffer of previously generated images per domain.

    Args:
        dataset_name: Dataset directory, resolved as ``./<dataset_name>``; ImageDataset reads
            its ``trainA`` / ``trainB`` sub-folders.
        img_height, img_width: Size of the random crops fed to the networks.
        channels: Number of image channels the networks are built for.
        lr, b1, b2: Adam learning rate and betas shared by all three optimizers.
        n_residual_blocks: Residual blocks per generator (default 9).
        lambda_cyc: Weight of the cycle-consistency loss (default 10.0).
        lambda_id: Weight of the identity loss (default 5.0).
    """

    def __init__(self, dataset_name='dataset_preprocessed', img_height=100, img_width=100, channels=3,
                 lr=0.0002, b1=0.5, b2=0.999, n_residual_blocks=9, lambda_cyc=10.0, lambda_id=5.0):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.img_height = img_height
        self.img_width = img_width
        self.channels = channels
        self.input_shape = (self.channels, self.img_height, self.img_width)

        self.dataset_name = dataset_name
        self.lambda_cyc = lambda_cyc
        self.lambda_id = lambda_id

        # Initialize the generators and discriminators
        self.G_AB = GeneratorResNet(self.input_shape, num_residual_blocks=n_residual_blocks).to(self.device)
        self.G_BA = GeneratorResNet(self.input_shape, num_residual_blocks=n_residual_blocks).to(self.device)
        self.D_A = Discriminator(self.input_shape).to(self.device)
        self.D_B = Discriminator(self.input_shape).to(self.device)

        # Initialize weights (same order as before: G_AB, G_BA, D_A, D_B)
        for net in (self.G_AB, self.G_BA, self.D_A, self.D_B):
            net.apply(self.weights_init_normal)

        # Losses
        self.criterion_GAN = torch.nn.MSELoss().to(self.device)
        self.criterion_cycle = torch.nn.L1Loss().to(self.device)
        self.criterion_identity = torch.nn.L1Loss().to(self.device)

        # Optimizers: both generators share one optimizer, each discriminator has its own
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()), lr=lr, betas=(b1, b2))
        self.optimizer_D_A = torch.optim.Adam(self.D_A.parameters(), lr=lr, betas=(b1, b2))
        self.optimizer_D_B = torch.optim.Adam(self.D_B.parameters(), lr=lr, betas=(b1, b2))

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations: upscale by 12 %, then random-crop back (jitter augmentation).
        # Sizing the shorter edge from the larger target dimension guarantees the crop fits
        # (identical to the previous behaviour for square targets).
        load_size = int(max(self.img_height, self.img_width) * 1.12)
        self.transforms_ = transforms.Compose([
            transforms.Resize(load_size, transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def weights_init_normal(self, m):
        """Initialise one module in place (used with ``nn.Module.apply``).

        Conv / ConvTranspose weights ~ N(0, 0.02) with zeroed biases; BatchNorm2d weights
        ~ N(1, 0.02) with zeroed biases. Other modules are left untouched.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    def _adversarial_loss(self, prediction, target_is_real):
        """LSGAN (MSE) loss against an all-ones (real) or all-zeros (fake) target.

        The target is built to match ``prediction``, so any batch size (including a smaller
        final batch) and any discriminator output size works.
        """
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return self.criterion_GAN(prediction, target)

    def _generator_step(self, real_A, real_B):
        """One optimisation step for both generators.

        Returns:
            ``(loss_G, loss_GAN, loss_cycle, loss_identity, fake_A, fake_B)``
        """
        self.optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its output domain should not change it
        loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
        loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: generators try to make the discriminators output "real"
        fake_B = self.G_AB(real_A)
        loss_GAN_AB = self._adversarial_loss(self.D_B(fake_B), True)
        fake_A = self.G_BA(real_B)
        loss_GAN_BA = self._adversarial_loss(self.D_A(fake_A), True)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the input
        loss_cycle_A = self.criterion_cycle(self.G_BA(fake_B), real_A)
        loss_cycle_B = self.criterion_cycle(self.G_AB(fake_A), real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        loss_G = loss_GAN + self.lambda_cyc * loss_cycle + self.lambda_id * loss_identity
        loss_G.backward()
        self.optimizer_G.step()

        return loss_G, loss_GAN, loss_cycle, loss_identity, fake_A, fake_B

    def _discriminator_step(self, discriminator, optimizer, buffer, real, fake_batch):
        """One optimisation step for a single discriminator; returns its loss.

        Gradients that leaked into the discriminator during the generator step are cleared by
        ``optimizer.zero_grad()`` first.
        """
        optimizer.zero_grad()

        loss_real = self._adversarial_loss(discriminator(real), True)
        # Fake loss on a mix of current and previously generated samples
        fake_samples = buffer.push_and_pop(fake_batch)
        loss_fake = self._adversarial_loss(discriminator(fake_samples.detach()), False)

        loss = (loss_real + loss_fake) / 2
        loss.backward()
        optimizer.step()
        return loss

    @staticmethod
    def _estimate_time_left(elapsed, batches_done, batches_left):
        """Linear ETA from the average time per completed batch.

        Uses ``timedelta.total_seconds()``; ``.seconds`` is only the seconds *component*
        (days are dropped), which made the ETA wrong once training exceeded a day.
        """
        return datetime.timedelta(seconds=batches_left * elapsed.total_seconds() / max(1, batches_done))

    @staticmethod
    def _should_checkpoint(epoch, epochs, interval):
        """True every ``interval`` epochs (if ``interval`` > 0) and always on the final epoch."""
        return (interval > 0 and epoch % interval == 0) or epoch == epochs - 1

    def _save_checkpoint(self, saving_dir, epoch):
        """Write the four networks' state_dicts as ``<name>_<epoch>.pth`` into ``saving_dir``."""
        os.makedirs(saving_dir, exist_ok=True)
        for name, net in (('G_AB', self.G_AB), ('G_BA', self.G_BA), ('D_A', self.D_A), ('D_B', self.D_B)):
            torch.save(net.state_dict(), os.path.join(saving_dir, f'{name}_{epoch}.pth'))

    def train(self, epochs, batch_size=1, sample_interval=50, saving_dir='saved_models', checkpoint_interval=2):
        """Train all four networks.

        Args:
            epochs: Number of passes over the dataset.
            batch_size: Images per batch (the final batch may be smaller).
            sample_interval: Save a sample grid every this many batches (counted globally).
            saving_dir: Directory for checkpoints.
            checkpoint_interval: Save every this many epochs; the final epoch is always saved.
        """
        # Load dataset
        dataloader = DataLoader(ImageDataset(f'./{self.dataset_name}', transforms_=self.transforms_),
                                batch_size=batch_size, shuffle=True, num_workers=0)
        n_batches = len(dataloader)

        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            for i, batch in enumerate(dataloader):
                real_A = batch['A'].to(self.device)
                real_B = batch['B'].to(self.device)

                loss_G, loss_GAN, loss_cycle, loss_identity, fake_A, fake_B = self._generator_step(real_A, real_B)
                loss_D_A = self._discriminator_step(self.D_A, self.optimizer_D_A, self.fake_A_buffer, real_A, fake_A)
                loss_D_B = self._discriminator_step(self.D_B, self.optimizer_D_B, self.fake_B_buffer, real_B, fake_B)
                loss_D = (loss_D_A + loss_D_B) / 2

                # Log progress
                batches_done = epoch * n_batches + i
                batches_left = epochs * n_batches - batches_done
                time_left = self._estimate_time_left(datetime.datetime.now() - start_time, batches_done, batches_left)

                print(f'\r[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {loss_D.item()}] [G loss: {loss_G.item()}] '
                      f'[Adv: {loss_GAN.item()}] [Cycle: {loss_cycle.item()}] [Identity: {loss_identity.item()}] ETA: {time_left}', end='')

                # If at sample interval save image
                if batches_done % sample_interval == 0:
                    self.sample_images(batches_done, epoch)

            # Save model checkpoints (including the last epoch, which the old `epoch % 2` skipped for even `epochs`)
            if self._should_checkpoint(epoch, epochs, checkpoint_interval):
                self._save_checkpoint(saving_dir, epoch)

        # Terminate the carriage-return progress line
        print()

    def sample_images(self, batches_done, epoch):
        """Save a grid of real / translated / reconstructed images from a random training batch.

        Rows: real_A, fake_B, recov_A, real_B, fake_A, recov_B.
        """
        os.makedirs('images/%s/%s' % (self.dataset_name, epoch), exist_ok=True)
        imgs = next(iter(DataLoader(ImageDataset(f'./{self.dataset_name}', transforms_=self.transforms_, mode='train'), batch_size=5, shuffle=True)))
        real_A = imgs['A'].to(self.device)
        real_B = imgs['B'].to(self.device)
        # Visualisation only: no autograd graph needed
        with torch.no_grad():
            fake_B = self.G_AB(real_A)
            fake_A = self.G_BA(real_B)
            recov_A = self.G_BA(fake_B)
            recov_B = self.G_AB(fake_A)
        img_sample = torch.cat((real_A, fake_B, recov_A, real_B, fake_A, recov_B), 0)
        # One row per image type, even when the dataset holds fewer than 5 images
        save_image(img_sample, 'images/%s/%s/%s.png' % (self.dataset_name, epoch, batches_done),
                   nrow=real_A.size(0), normalize=True)

In [None]:
# Create CycleGAN instance and start training.
# Seed the RNGs the pipeline uses (torch: weight init, augmentation, shuffling;
# numpy: ReplayBuffer) before building the model so re-runs start from the same state.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
EPOCHS = 200
BATCH_SIZE = 1
SAMPLE_INTERVAL = 200  # save a sample grid every N batches

torch.manual_seed(SEED)
np.random.seed(SEED)

cyclegan = CycleGAN()
cyclegan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, sample_interval=SAMPLE_INTERVAL)