In [None]:
import torch
import torch.nn as nn

In [2]:
class ResidualBlock(nn.Module):
    """Two 3x3 ConvolutionalBlocks with an identity skip: ``forward(x) = x + block(x)``.

    The first block ends with ReLU, the second without an activation. ``padding=1``
    keeps the spatial size, and the skip connection requires the channel count to be
    preserved, so ``in_channel`` must equal ``out_channel``.

    Raises:
        ValueError: if ``in_channel != out_channel`` (the addition in ``forward``
            could never succeed).
    """

    def __init__(self, in_channel: int, out_channel=256):
        super().__init__()
        if in_channel != out_channel:
            raise ValueError(
                f"ResidualBlock needs in_channel == out_channel for the skip connection, "
                f"got {in_channel} and {out_channel}"
            )
        self.block = nn.Sequential(
            ConvolutionalBlock(in_channel, out_channel, is_activation=True, kernel_size=3, padding=1),
            # The second conv consumes the first conv's output, so its input is out_channel.
            ConvolutionalBlock(out_channel, out_channel, is_activation=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)

In [3]:
class ConvolutionalBlock(nn.Module):
    """Convolution -> InstanceNorm2d -> ReLU (or Identity), stored as ``self.main``.

    With ``is_downsample=True`` the convolution is ``nn.Conv2d``; otherwise it is an
    ``nn.ConvTranspose2d`` that also receives ``output_padding=out_padding``.
    ``kernel_size``, ``stride``, ``padding`` and any extra keyword arguments are
    forwarded to the convolution. ``is_activation`` selects ReLU (in-place) versus
    an Identity as the last layer.
    """

    def __init__(
            self,
            in_channel: int,
            out_channel: int,
            kernel_size: int,
            stride=1,
            padding=0,
            is_downsample=True,
            is_activation=True,
            out_padding=1,
            **kwargs
    ):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, **kwargs)

        if is_downsample:
            conv = nn.Conv2d(in_channel, out_channel, **conv_options)
        else:
            conv = nn.ConvTranspose2d(in_channel, out_channel, output_padding=out_padding, **conv_options)

        # An Identity tail keeps the Sequential at three entries in both cases.
        tail = nn.ReLU(inplace=True) if is_activation else nn.Identity()
        self.main = nn.Sequential(conv, nn.InstanceNorm2d(out_channel), tail)

    def forward(self, x):
        return self.main(x)

In [4]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 conv (64 ch, reflect padding) + InstanceNorm + ReLU ->
    two stride-2 ConvolutionalBlocks (128, 256 ch) -> ``num_residuals`` 256-channel
    ResidualBlocks -> two stride-2 transposed ConvolutionalBlocks (128, 64 ch) ->
    7x7 conv to ``out_channel`` channels -> tanh, so outputs lie in [-1, 1].

    Args:
        in_channel: channels of the input image (default 3).
        out_channel: channels of the output image (default 3).
        num_residuals: number of residual blocks at the bottleneck (default 9).

    The output has the same height/width as the input when both are divisible by 4.
    """

    def __init__(
            self,
            in_channel=3,
            out_channel=3,
            num_residuals=9,
    ):
        super().__init__()
        channel = [64, 128, 256, 128, 64, out_channel]
        self.layers_1 = nn.Sequential(
            nn.Conv2d(in_channel, channel[0], kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(channel[0]),
            nn.ReLU(inplace=True)
        )
        # Sequential (rather than a manually iterated ModuleList) yields the same
        # state_dict keys, e.g. "layers_2.0.main.0.weight".
        self.layers_2 = nn.Sequential(
            ConvolutionalBlock(channel[0], channel[1], kernel_size=3, stride=2, padding=1, is_downsample=True,
                               is_activation=True),
            ConvolutionalBlock(channel[1], channel[2], kernel_size=3, stride=2, padding=1, is_downsample=True,
                               is_activation=True),
        )
        self.layers_3 = nn.Sequential(
            *[ResidualBlock(channel[2], channel[2]) for _ in range(num_residuals)]
        )
        self.layers_4 = nn.Sequential(
            ConvolutionalBlock(channel[2], channel[3], kernel_size=3, stride=2, padding=1, is_downsample=False,
                               is_activation=True, out_padding=1),
            ConvolutionalBlock(channel[3], channel[4], kernel_size=3, stride=2, padding=1, is_downsample=False,
                               is_activation=True, out_padding=1),
        )
        self.layers_5 = nn.Sequential(
            nn.Conv2d(channel[4], channel[5], kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        )

    def forward(self, x):
        x = self.layers_1(x)
        x = self.layers_2(x)  # downsample x4
        x = self.layers_3(x)  # residual bottleneck
        x = self.layers_4(x)  # upsample x4
        return torch.tanh(self.layers_5(x))

In [5]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial map of scores in [0, 1].

    Five 4x4 convolutions with 64, 128, 256, 512 and 1 output channels. The first
    three use stride 2 and the last two stride 1. Each is followed by LeakyReLU(0.2),
    except the last, which is followed by a sigmoid. All layers except the first and
    the last are normalized.

    Args:
        in_channel: channels of the input image (default 3).
        norm_layer: normalization class, called as ``norm_layer(num_channels)``.
            Defaults to ``nn.InstanceNorm2d``, matching the helper's name and the
            generator. Pass ``nn.BatchNorm2d`` to reproduce the earlier
            batch-normalized variant.
    """

    def __init__(self, in_channel=3, norm_layer=nn.InstanceNorm2d):
        super().__init__()

        channels = [64, 128, 256, 512]

        def ConvInstanceNormLeakyReLUBlock(
                in_channel,
                out_channel,
                normalize=True,
                kernel_size=4,
                stride=2,
                padding=1,
                activation=None
        ):
            """Return ``[Conv2d, norm (if normalize), activation]`` as a list of layers.

            ``activation`` defaults to ``LeakyReLU(0.2, inplace=True)``.
            """
            layers = [nn.Conv2d(
                in_channel,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                # The per-channel mean subtraction of the norm layer cancels a conv bias.
                bias=not normalize)]

            if normalize:
                # Instance norm normalizes each image on its own, so a sample's score
                # does not depend on the other images in its batch (real and fake
                # batches are scored in separate calls during training).
                layers.append(norm_layer(out_channel))

            layers.append(nn.LeakyReLU(0.2, inplace=True) if activation is None else activation)

            return layers

        self.main = nn.Sequential(
            *ConvInstanceNormLeakyReLUBlock(in_channel, channels[0], normalize=False),
            *ConvInstanceNormLeakyReLUBlock(channels[0], channels[1]),
            *ConvInstanceNormLeakyReLUBlock(channels[1], channels[2]),
            *ConvInstanceNormLeakyReLUBlock(channels[2], channels[3], stride=1),
            *ConvInstanceNormLeakyReLUBlock(channels[3], 1, normalize=False, stride=1, activation=nn.Sigmoid())
        )

    def forward(self, x):
        return self.main(x)


In [6]:
import os
import numpy as np

from PIL import Image
from torch.utils.data import Dataset

In [7]:
class MyDataset(Dataset):
    """Unpaired Monet / photo dataset.

    Item ``idx`` is the ``idx``-th Monet image (file names sorted) paired with a
    photo drawn uniformly at random from the whole photo directory, so the two
    domains are not aligned. Items are returned as ``(photo, monet)`` tensors.

    Args:
        monet_dir: directory containing the Monet ``.jpg`` files.
        photo_dir: directory containing the photo ``.jpg`` files.
        size: target size passed to ``transforms.Resize``.
        normalize: if True (default), map pixel values from [0, 1] to [-1, 1].
    """

    def __init__(self, monet_dir, photo_dir, size=(256, 256), normalize=True):
        super().__init__()

        def get_img_list(path):
            # sorted() makes the index -> file mapping independent of os.listdir order.
            return sorted(x for x in os.listdir(path) if x.endswith('.jpg'))

        self.monet_dir = monet_dir
        self.photo_dir = photo_dir
        # index -> file name
        self.monet_idx = dict(enumerate(get_img_list(self.monet_dir)))
        self.photo_idx = dict(enumerate(get_img_list(self.photo_dir)))

        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

    def _load(self, path):
        # convert("RGB") guarantees 3 channels (grayscale/RGBA files would break the
        # 3-channel Normalize); the with-block closes the underlying file.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        # Draw from the full photo pool, independent of how many Monet images exist.
        rand_idx = int(np.random.randint(len(self.photo_idx)))
        photo_img = self._load(os.path.join(self.photo_dir, self.photo_idx[rand_idx]))
        monet_img = self._load(os.path.join(self.monet_dir, self.monet_idx[idx]))
        return photo_img, monet_img

    def __len__(self):
        return min(len(self.monet_idx), len(self.photo_idx))

In [8]:
from torch.utils.data import DataLoader
from torchvision import transforms

from tqdm import tqdm
import torch.optim as optim

In [9]:
import matplotlib.pyplot as plt

In [10]:
class config():
    """Training hyper-parameters and data paths, read as class attributes."""
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_EPOCHS = 50
    BATCH_SIZE = 5
    # DataLoader worker processes. On Windows (os.name == "nt") workers are started
    # with "spawn" and cannot import classes defined in a notebook's __main__
    # (e.g. MyDataset), so load in the main process there.
    NUM_WORKERS = 0 if os.name == "nt" else 8
    # NOTE(review): the CycleGAN paper uses 2e-4; 2e-3 is 10x higher — confirm intended.
    LEARNING_RATE = 2e-3
    # Weight of the cycle-consistency loss; the identity loss uses LMBDA * COEF.
    LMBDA = 10
    # Raw strings: "\g", "\C", "\d", ... are invalid escape sequences in ordinary
    # string literals (DeprecationWarning, SyntaxWarning on Python 3.12+).
    ROOT_MONET = r"E:\github\CycleGAN\data\monet_jpg"
    ROOT_PHOTO = r"E:\github\CycleGAN\data\photo_jpg"
    COEF = 0.5

In [11]:
# Unpaired Monet/photo training data. shuffle=True varies the Monet batch
# composition every epoch (photos are already sampled at random per item);
# pinned host memory only helps when copying to a CUDA device.
train_data = MyDataset(config.ROOT_MONET, config.ROOT_PHOTO)
loader = DataLoader(train_data, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS,
                    drop_last=True, pin_memory=config.DEVICE.type == "cuda")

In [12]:
def save_checkpoint(state, save_path):
    """Write ``state`` (e.g. a dict of state_dicts) to ``save_path`` with ``torch.save``.

    Any existing file at ``save_path`` is overwritten. Returns None.
    """
    torch.save(obj=state, f=save_path)

In [13]:
class AvgStats(object):
    """Records a history of loss values, in insertion order, in ``self.losses``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded values."""
        self.losses = list()

    def append(self, loss):
        """Record one loss value."""
        self.losses += [loss]

In [14]:
class CycleGAN():
    """CycleGAN trainer for unpaired photo <-> Monet translation.

    Holds two generators (``g_ptm``: photo -> Monet, ``g_mtp``: Monet -> photo) and
    two discriminators (``d_m`` scores Monet images, ``d_p`` scores photos), all
    moved to ``config.DEVICE``. Also holds their Adam optimizers and per-epoch loss
    history in ``gen_stats`` / ``desc_stats``.
    """

    def __init__(self):
        # Running sums of d_m's mean score on real / fake Monet batches; reset every
        # epoch and shown (averaged) in the progress bar.
        self.m_reals = 0
        self.m_fakes = 0
        # Mean generator / discriminator loss of the most recent epoch.
        self.avg_g_loss = 0
        self.avg_d_loss = 0

        self.d_p = Discriminator()
        self.d_m = Discriminator()
        self.g_ptm = Generator()
        self.g_mtp = Generator()

        self.gen_stats = AvgStats()
        self.desc_stats = AvgStats()

        self.init_models()

        self.opt_disc = optim.Adam(
            list(self.d_m.parameters()) + list(self.d_p.parameters()),
            lr=config.LEARNING_RATE,
            betas=(0.5, 0.999),
        )
        self.opt_gen = optim.Adam(
            list(self.g_ptm.parameters()) + list(self.g_mtp.parameters()),
            lr=config.LEARNING_RATE,
            betas=(0.5, 0.999),
        )

        self.l1 = nn.L1Loss()
        self.mse = nn.MSELoss()

        # Mixed precision is a CUDA feature; on CPU the scalers and autocast are
        # disabled (pass-through) instead of warning on every step.
        self.use_amp = config.DEVICE.type == "cuda"
        self.g_scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
        self.d_scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def init_models(self):
        """Move all four networks to ``config.DEVICE``."""
        self.d_p = self.d_p.to(config.DEVICE)
        self.d_m = self.d_m.to(config.DEVICE)
        self.g_ptm = self.g_ptm.to(config.DEVICE)
        self.g_mtp = self.g_mtp.to(config.DEVICE)

    def _generator_step(self, photo, monet):
        """One optimizer step for both generators.

        Returns ``(g_loss_value, fake_monet, fake_photo)``; the fakes are reused by
        the discriminator step.
        """
        self.opt_gen.zero_grad()
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            # Identity loss: an image already in a generator's target domain should
            # pass through (approximately) unchanged.
            id_monet_loss = self.l1(self.g_ptm(monet), monet) * config.LMBDA * config.COEF
            id_photo_loss = self.l1(self.g_mtp(photo), photo) * config.LMBDA * config.COEF
            id_loss = id_photo_loss + id_monet_loss

            # Least-squares adversarial loss: fakes should be scored as real (1).
            fake_monet = self.g_ptm(photo)
            d_m_fake = self.d_m(fake_monet)
            loss_g_m = self.mse(d_m_fake, torch.ones_like(d_m_fake))

            fake_photo = self.g_mtp(monet)
            d_p_fake = self.d_p(fake_photo)
            loss_g_p = self.mse(d_p_fake, torch.ones_like(d_p_fake))

            # Cycle-consistency: translating there and back should reconstruct the input.
            cycle_monet_loss = self.l1(self.g_ptm(fake_photo), monet)
            cycle_photo_loss = self.l1(self.g_mtp(fake_monet), photo)

            g_loss = (
                    loss_g_m + loss_g_p
                    + cycle_photo_loss * config.LMBDA
                    + cycle_monet_loss * config.LMBDA
                    + id_loss
            )

        self.g_scaler.scale(g_loss).backward()
        self.g_scaler.step(self.opt_gen)
        self.g_scaler.update()
        return g_loss.item(), fake_monet, fake_photo

    def _discriminator_step(self, photo, monet, fake_monet, fake_photo):
        """One optimizer step for both discriminators; returns the scalar loss value.

        Also accumulates d_m's mean real/fake scores into ``m_reals`` / ``m_fakes``.
        """
        # Also clears the discriminator grads left over from the generator backward.
        self.opt_disc.zero_grad()
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            # detach(): discriminator loss must not backpropagate into the generators.
            d_m_real = self.d_m(monet)
            d_m_fake = self.d_m(fake_monet.detach())
            self.m_reals += d_m_real.mean().item()
            self.m_fakes += d_m_fake.mean().item()
            d_m_loss = (self.mse(d_m_real, torch.ones_like(d_m_real))
                        + self.mse(d_m_fake, torch.zeros_like(d_m_fake)))

            d_p_real = self.d_p(photo)
            d_p_fake = self.d_p(fake_photo.detach())
            # The fake term must score d_p_fake (previously d_p_real was used, so
            # d_p never saw generated photos).
            d_p_loss = (self.mse(d_p_real, torch.ones_like(d_p_real))
                        + self.mse(d_p_fake, torch.zeros_like(d_p_fake)))

            d_loss = (d_m_loss + d_p_loss) / 2

        self.d_scaler.scale(d_loss).backward()
        self.d_scaler.step(self.opt_disc)
        self.d_scaler.update()
        return d_loss.item()

    def train(self, dataset):
        """Train for ``config.NUM_EPOCHS`` epochs.

        Args:
            dataset: sized iterable (e.g. a DataLoader over MyDataset) yielding
                ``(photo, monet)`` batches.

        After every epoch: overwrites the checkpoint ``current.ckpt``, prints the
        mean generator/discriminator losses, and appends them to ``gen_stats`` /
        ``desc_stats``.
        """
        # Guard against division by zero for an empty dataset.
        num_batches = max(len(dataset), 1)

        for epoch in range(config.NUM_EPOCHS):

            total_g_loss = 0
            total_d_loss = 0
            # Progress-bar averages divide by idx + 1, which restarts every epoch.
            self.m_reals = 0
            self.m_fakes = 0

            loop = tqdm(dataset, leave=True)

            # MyDataset.__getitem__ returns (photo, monet): unpack in that order.
            for idx, (photo, monet) in enumerate(loop):
                photo = photo.to(config.DEVICE)
                monet = monet.to(config.DEVICE)

                g_loss, fake_monet, fake_photo = self._generator_step(photo, monet)
                total_g_loss += g_loss
                total_d_loss += self._discriminator_step(photo, monet, fake_monet, fake_photo)

                loop.set_postfix(m_real=self.m_reals / (idx + 1), m_fake=self.m_fakes / (idx + 1))

            self.avg_d_loss = total_d_loss / num_batches
            self.avg_g_loss = total_g_loss / num_batches

            save_dict = {
                'epoch': epoch + 1,
                'g_mtp': self.g_mtp.state_dict(),
                'g_ptm': self.g_ptm.state_dict(),
                'd_m': self.d_m.state_dict(),
                'd_p': self.d_p.state_dict(),
                'optimizer_gen': self.opt_gen.state_dict(),
                'optimizer_desc': self.opt_disc.state_dict()
            }
            save_checkpoint(save_dict, 'current.ckpt')

            print("Epoch: (%d) | Generator Loss:%f | Discriminator Loss:%f" % (epoch +
                                                                               1, self.avg_g_loss, self.avg_d_loss))

            self.gen_stats.append(self.avg_g_loss)
            self.desc_stats.append(self.avg_d_loss)

In [15]:
# Build both generators and discriminators, move them to config.DEVICE, and create
# their Adam optimizers and loss functions.
gan = CycleGAN()

In [16]:
# Snapshot of the untrained models and optimizers, in the same layout CycleGAN.train
# writes to 'current.ckpt'.
# NOTE(review): this dict is never saved or read anywhere in the visible notebook —
# pass it to save_checkpoint(...) or remove the cell.
save_dict = {
    'epoch': 0,
    'g_mtp': gan.g_mtp.state_dict(),
    'g_ptm': gan.g_ptm.state_dict(),
    'd_m': gan.d_m.state_dict(),
    'd_p': gan.d_p.state_dict(),
    'optimizer_gen': gan.opt_gen.state_dict(),
    'optimizer_desc': gan.opt_disc.state_dict()
}

In [None]:
# Runs config.NUM_EPOCHS epochs; 'current.ckpt' is overwritten after each epoch.
gan.train(loader)

  0%|                                                                                                                                                                                        | 0/60 [00:00<?, ?it/s]

In [None]:
# Mean generator / discriminator loss per epoch, as recorded by CycleGAN.train.
fig, ax = plt.subplots()
ax.plot(gan.gen_stats.losses, 'r', label='Generator Loss')
ax.plot(gan.desc_stats.losses, 'b', label='Discriminator Loss')
ax.set(title="CycleGAN mean losses per epoch", xlabel="Epochs", ylabel="Losses")
ax.legend()
plt.show()

In [None]:
class PhotoDataset(Dataset):
    """Single-domain dataset of photos, used for inference.

    Item ``idx`` is the ``idx``-th ``.jpg`` in ``photo_dir`` (file names sorted),
    returned as a tensor.

    Args:
        photo_dir: directory containing the ``.jpg`` photos.
        size: target size passed to ``transforms.Resize``.
        normalize: if True (default), map pixel values from [0, 1] to [-1, 1].
    """

    def __init__(self, photo_dir, size=(256, 256), normalize=True):
        super().__init__()

        def get_img_list(path):
            # sorted() makes the index -> file mapping independent of os.listdir order.
            return sorted(x for x in os.listdir(path) if x.endswith('.jpg'))

        self.photo_dir = photo_dir
        # index -> file name
        self.photo_idx = dict(enumerate(get_img_list(self.photo_dir)))

        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        photo_path = os.path.join(self.photo_dir, self.photo_idx[idx])
        # convert("RGB") guarantees 3 channels; the with-block closes the file.
        with Image.open(photo_path) as photo_img:
            return self.transform(photo_img.convert("RGB"))

    def __len__(self):
        return len(self.photo_idx)

In [None]:
# Photos to translate at inference time.
# NOTE(review): this relative Kaggle path differs from config.ROOT_PHOTO used for
# training — confirm which photo directory is intended here.
ph_ds = PhotoDataset("./kaggle/input/gan-getting-started/photo_jpg")

In [None]:
# One photo per batch, in dataset index order (no shuffle).
ph_dl = DataLoader(ph_ds, batch_size=1, pin_memory=True)

In [None]:
# Output directory for generated images. exist_ok=True keeps the cell re-runnable,
# and os.makedirs works on Windows, where cmd's mkdir rejects "./images".
os.makedirs("./images", exist_ok=True)

In [None]:
# Converts a (C, H, W) tensor into a PIL image; applied below to unnormalized
# generator output.
trans = transforms.ToPILImage()

In [None]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo ``transforms.Normalize(mean, std)`` in place: ``img = img * std + mean``.

    Works on a single image ``(C, H, W)`` or a batch ``(B, C, H, W)``: the
    per-channel statistics are broadcast along the channel axis (third from last).
    Iterating ``zip(img, mean, std)`` would walk the batch axis instead, leaving
    images past index ``len(mean) - 1`` untouched.

    Returns:
        ``img`` itself (modified in place).
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img.mul_(std_t).add_(mean_t)

In [None]:
# Translate the first NUM_SAMPLES photos to Monet style and save them as
# ./images/1.jpg, ./images/2.jpg, ...
NUM_SAMPLES = 10
progress = tqdm(ph_dl, leave=False, total=min(NUM_SAMPLES, len(ph_dl)))
for i, photo in enumerate(progress):
    if i == NUM_SAMPLES:
        break
    with torch.no_grad():
        # config.DEVICE instead of a hard-coded "cuda" so this also runs on CPU;
        # no_grad already prevents graph construction, so no detach() is needed.
        pred_monet = gan.g_ptm(photo.to(config.DEVICE)).cpu()
    # The generator ends in tanh ([-1, 1]); map back to [0, 1] for ToPILImage.
    pred_monet = unnorm(pred_monet)
    img = trans(pred_monet[0]).convert("RGB")
    img.save(os.path.join("./images", str(i + 1) + ".jpg"))