<div>
  <hr>
  <h1 align="middle">Project: implementing a <b>generative model (GAN)</b>.</h1>
  <ul>
    <li><h4>Author: Vladislav Martynov</h4></li>
    <li><h4>Track: Advanced</h4></li>
    <li><h4>StepikID: 596247708</h4></li>
    <li><h4>GitHub: <a href="https://github.com/VladMartinov">VladMartinov</a></h4></li>
  </ul>
  <hr>
  <h2>Project plan:</h2>
  <ol>
    <li><h4>Defining the task and choosing an architecture;</h4></li>
    <li><h4>Testing <b>my own</b> model on an already solved task;</h4></li>
    <li><h4>Finding a dataset for my own task;</h4></li>
    <li><h4>Solving my own task with my own model and my own dataset;</h4></li>
    <li><h4>Building a convenient interface for generation.</h4></li>
  </ol>
  <hr>
  <h2>Model implementation plan:</h2>
  <ol>
    <li><h4>Studying the architectures of various generative models such as CycleGAN;</h4></li>
    <li><h4>Implementing the models and testing them.</h4></li>
  </ol>
  <hr>
</div>
<img src="https://files.realpython.com/media/An-Introduction-to-Generative-Adversarial-Networks-GANs_Watermarked.6b71bfd66fda.jpg" alt="GANS img" align="middle" />

<hr>
<h2>Importing all the <b>required libraries</b> for this project:</h2>

<p>Below we load the model weights (if we already have them), import everything we need, and define the training parameters and helper functions.</p>

In [None]:
!mkdir -p "/content/saved_imgs"

In [None]:
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

# Transform
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Other
from google.colab import drive
from tqdm import tqdm
from PIL import Image
import os
import numpy as np
import random
import copy

In [None]:
# Training configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Path to dataset
TRAIN_DIR = "/content/monet2photo/train"
VAL_DIR = "/content/monet2photo/val"

BATCH_SIZE = 1

LEARNING_RATE = 1e-5

LAMBDA_IDENTITY = 5
LAMBDA_CYCLE = 10

NUM_WORKERS = 2
NUM_EPOCHS = 20

LOAD_MODEL = True
SAVE_MODEL = True

# Google Colab: load weights from / save weights to Google Drive
LOAD_FROM_DRIVE = True
SAVE_TO_DRIVE = True

CHECKPOINT_GEN_M = "genm.pth.tar"
CHECKPOINT_GEN_P = "genp.pth.tar"
CHECKPOINT_DISC_M = "discm.pth.tar"
CHECKPOINT_DISC_P = "discp.pth.tar"

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)
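<p>The Normalize step uses mean = std = 0.5 and max_pixel_value = 255, so every channel is mapped from [0, 255] to [-1, 1]. That is the same range as the generator's final tanh, and it is why generated images are multiplied by 0.5 and shifted by 0.5 before being passed to save_image during training. A minimal pure-Python sketch of that per-channel arithmetic (the helper names normalize and denormalize are hypothetical; the real transform is applied by albumentations to whole arrays):</p>

```python
# Per-channel arithmetic of A.Normalize(mean=0.5, std=0.5, max_pixel_value=255):
#   normalized = (pixel / 255 - 0.5) / 0.5   -> range [-1, 1]
# and of the "* 0.5 + 0.5" step applied before save_image:
#   restored = normalized * 0.5 + 0.5         -> range [0, 1]

MEAN, STD, MAX_PIXEL = 0.5, 0.5, 255.0


def normalize(pixel):
    """Map a uint8-style value in [0, 255] to [-1, 1]."""
    return (pixel / MAX_PIXEL - MEAN) / STD


def denormalize(value):
    """Map a value in [-1, 1] (e.g. a tanh output) back to [0, 1]."""
    return value * STD + MEAN


for p in (0, 127.5, 255):
    print(p, "->", normalize(p), "->", denormalize(normalize(p)))
```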

In [None]:
# Utility functions
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
if LOAD_FROM_DRIVE:
    drive.mount('/content/gdrive')

    # Models weights in Drive
    !cp /content/gdrive/MyDrive/CycleGAN/discm.pth.tar /content
    !cp /content/gdrive/MyDrive/CycleGAN/discp.pth.tar /content
    !cp /content/gdrive/MyDrive/CycleGAN/genm.pth.tar /content
    !cp /content/gdrive/MyDrive/CycleGAN/genp.pth.tar /content
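<p>With LOAD_MODEL = True the training cell calls load_checkpoint (and therefore torch.load) on all four files, which raises FileNotFoundError on a first run when no weights exist yet, or when the copy from Drive failed. A small guard can fall back to training from scratch instead. The sketch assumes that fallback is acceptable; missing_checkpoints is a hypothetical helper, not part of the notebook:</p>

```python
import os
import tempfile


def missing_checkpoints(paths):
    """Return the checkpoint files from `paths` that do not exist (yet)."""
    return [p for p in paths if not os.path.isfile(p)]


# Hypothetical usage with the notebook's constants:
# missing = missing_checkpoints([CHECKPOINT_GEN_M, CHECKPOINT_GEN_P,
#                                CHECKPOINT_DISC_M, CHECKPOINT_DISC_P])
# if LOAD_MODEL and missing:
#     print("=> No checkpoint found, training from scratch:", missing)
#     LOAD_MODEL = False

# Self-contained demo: one checkpoint file exists, the other does not
with tempfile.TemporaryDirectory() as tmp:
    present = os.path.join(tmp, "genm.pth.tar")
    open(present, "wb").close()  # create an empty placeholder file
    absent = os.path.join(tmp, "genp.pth.tar")
    demo_missing = missing_checkpoints([present, absent])
    demo_names = [os.path.basename(p) for p in demo_missing]

print(demo_names)
```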

<hr>
<h2>Loading the <b>datasets</b> for training the models</h2>
<h3>Two datasets will be loaded in total:</h3>
<ul>
  <li><h4>Monet2Photo;</h4></li>
  <li><h4>Anime2Photo (coming soon).</h4></li>
</ul>

<p>P.S.: In my case the dataset was stored on Google Drive as a zip archive; I unpacked it so that it matches the path set earlier (TRAIN_DIR).</p>

In [None]:
# !unzip "/content/gdrive/MyDrive/CycleGAN/datasets/monet2photo.zip" -d "/content/"
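<p>MonetPhotoSet is built from TRAIN_DIR + "/trainA" (Monet paintings) and TRAIN_DIR + "/trainB" (photos), so after unpacking, those two folders must exist. Below is a sketch of the same unpacking with Python's zipfile module plus a layout check, assuming the archive's internal folders already match TRAIN_DIR. extract_dataset and check_layout are hypothetical helpers, and the demo archive is a throwaway, not the real dataset:</p>

```python
import os
import tempfile
import zipfile


def extract_dataset(zip_path, dest_dir):
    """Extract `zip_path` into `dest_dir` (same effect as `!unzip ... -d dest_dir`)."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)


def check_layout(train_dir):
    """Return the expected sub-folders that are missing under `train_dir`."""
    expected = ["trainA", "trainB"]  # trainA -> Monet, trainB -> photos (as in MonetPhotoSet)
    return [d for d in expected if not os.path.isdir(os.path.join(train_dir, d))]


# Self-contained demo with a throwaway archive
with tempfile.TemporaryDirectory() as tmp:
    zip_path = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(zip_path, "w") as archive:
        archive.writestr("monet2photo/train/trainA/a.jpg", b"")
        archive.writestr("monet2photo/train/trainB/b.jpg", b"")
    extract_dataset(zip_path, tmp)
    demo_missing = check_layout(os.path.join(tmp, "monet2photo", "train"))

print(demo_missing)
```

<p>In the notebook this would be called as check_layout(TRAIN_DIR) right after unpacking; an empty list means the dataset class will find both folders.</p>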

In [None]:
class MonetPhotoSet(Dataset):
  def __init__(self, root_monet, root_photo, transform=None):
    self.root_monet = root_monet
    self.root_photo = root_photo

    self.transform = transform

    self.monet_imgs = os.listdir(self.root_monet)
    self.photo_imgs = os.listdir(self.root_photo)

    # The two folders have different sizes: use the larger one as the dataset length
    self.length_dataset = max(len(self.monet_imgs), len(self.photo_imgs))

    self.monet_dataset_len = len(self.monet_imgs)
    self.photo_dataset_len = len(self.photo_imgs)

  def __len__(self):
    return self.length_dataset

  def __getitem__(self, index):
    monet_img = self.monet_imgs[index % self.monet_dataset_len]
    photo_img = self.photo_imgs[index % self.photo_dataset_len]

    monet_path = os.path.join(self.root_monet, monet_img)
    photo_path = os.path.join(self.root_photo, photo_img)

    monet_img = np.array(Image.open(monet_path).convert("RGB"))
    photo_img = np.array(Image.open(photo_path).convert("RGB"))

    if self.transform:
      aug = self.transform(image=monet_img, image0=photo_img)

      monet_img = aug['image']
      photo_img = aug['image0']

    return monet_img, photo_img
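<p>Because the two folders have different sizes, __len__ returns the larger one and __getitem__ wraps each index with %, so every image of the larger domain is visited once per epoch while images of the smaller domain are reused. Since CycleGAN is trained on unpaired data, which painting ends up next to which photo carries no meaning. A pure-Python sketch of the index arithmetic with hypothetical folder sizes (paired_indices is a hypothetical helper):</p>

```python
def paired_indices(index, len_monet, len_photo):
    """File indices that MonetPhotoSet.__getitem__ would pick for a given dataset index."""
    return index % len_monet, index % len_photo


# Hypothetical sizes: 3 Monet paintings, 5 photos -> len(dataset) == max(3, 5) == 5
len_monet, len_photo = 3, 5
pairs = [paired_indices(i, len_monet, len_photo) for i in range(max(len_monet, len_photo))]
print(pairs)  # [(0, 0), (1, 1), (2, 2), (0, 3), (1, 4)]
```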


<hr>
<h2><b>Implementing the CycleGAN model architecture</b></h2>

<hr>
<h2>The <b>generator</b> has the following structure:</h2>

In [None]:
class ConvBlock(nn.Module):
    # Conv2d (down-sampling or same size) or ConvTranspose2d (up-sampling), optionally followed by ReLU
    def __init__(self, input_channels, out_channels, is_down=True, is_act=True, **kwargs):
        super().__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(input_channels, out_channels, padding_mode="reflect", **kwargs)
            if is_down
            else nn.ConvTranspose2d(input_channels, out_channels, **kwargs),
            nn.ReLU(inplace=True) if is_act else nn.Identity(),
        )

    def forward(self, x):
        return self.conv_block(x)

class ResidualBlock(nn.Module):
    # Two size-preserving 3x3 convolutions with a skip connection: out = x + F(x)
    def __init__(self, channels):
        super().__init__()

        self.res_block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, is_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.res_block(x)


# GAN Generator #
class Generator(nn.Module):
    def __init__(self, input_channels=3, num_features=[64, 128, 256, 512], num_residual_blocks=9):
        super().__init__()
        #---------------------------------------------------
        # Generator architecture:
        # Conv2d (7x7) -> ReLU
        # 2 * ConvBlock (down-sampling, stride 2)
        # num_residual_blocks * ResidualBlock
        # 2 * ConvBlock (up-sampling, ConvTranspose2d, stride 2)
        # Conv2d (7x7) -> tanh
        #---------------------------------------------------

        # Initial convolution layer
        self.initial_blocks = nn.Sequential(
            nn.Conv2d(input_channels, num_features[0], kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features[0], num_features[1], kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features[1], num_features[2], kernel_size=3, stride=2, padding=1),
            ]
        )

        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features[2]) for _ in range(num_residual_blocks)]
        )

        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(num_features[2], num_features[1], is_down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features[1], num_features[0], is_down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ]
        )

        self.last_layer = nn.Conv2d(num_features[0], input_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial_blocks(x)

        for layer in self.down_blocks:
          x = layer(x)

        for layer in self.res_blocks:
          x = layer(x)

        for layer in self.up_blocks:
          x = layer(x)

        x = self.last_layer(x)

        return torch.tanh(x)

In [None]:
# Test the Generator
def test_G():
  img_chanel = 3
  img_size = 256

  x = torch.randn(2, img_chanel, img_size, img_size)
  model = Generator()
  preds = model(x)

  print(preds.shape)

test_G()
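<p>test_G should print torch.Size([2, 3, 256, 256]): the generator returns an image of the same size as its input. The shape bookkeeping behind that can be checked by hand with the output-size formulas from the PyTorch Conv2d and ConvTranspose2d documentation (dilation 1). A pure-Python sketch that traces the spatial size through the layers defined above (conv_out and conv_transpose_out are hypothetical helper names):</p>

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding, output_padding):
    """Spatial output size of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


size = conv_out(256, kernel=7, stride=1, padding=3)  # initial_blocks
trace = [size]
for _ in range(2):  # down_blocks: each halves the size
    size = conv_out(size, kernel=3, stride=2, padding=1)
    trace.append(size)
# residual blocks (kernel 3, stride 1, padding 1) keep the size unchanged
for _ in range(2):  # up_blocks: output_padding=1 makes each exactly double the size
    size = conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1)
    trace.append(size)
size = conv_out(size, kernel=7, stride=1, padding=3)  # last_layer
trace.append(size)
print(trace)  # [256, 128, 64, 128, 256, 256]
```

<p>Without output_padding=1 the first transposed convolution would produce 127 instead of 128, and the output would no longer match the input size.</p>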

<hr>
<h2>The <b>discriminator</b> has the following structure:</h2>

In [None]:
# GAN Discriminator #
class InstanceBlock(nn.Module):
    def __init__(self, input_channels, out_channels, stride):
      super().__init__()

      self.conv = nn.Sequential(
          nn.Conv2d(input_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
          nn.InstanceNorm2d(out_channels),
          nn.LeakyReLU(0.2),
      )

    def forward(self, x):
      return self.conv(x)


class Discriminator(nn.Module):
    def __init__(self, input_channels=3, num_features=[64, 128, 256, 512]):
      super().__init__()
      # ------------------------------------------------------
      # Discriminator architecture (PatchGAN):
      # Conv2d -> LeakyReLU
      # (len(num_features) - 1) * (Conv2d -> InstanceNorm2d -> LeakyReLU)
      # Conv2d -> sigmoid
      # ------------------------------------------------------

      self.initial_layer = nn.Sequential(
          nn.Conv2d(input_channels, num_features[0], 4, 2, 1, padding_mode="reflect"),
          nn.LeakyReLU(0.2),
      )

      # layers array
      layers = []

      input_channels = num_features[0]
      for feature in num_features[1:]:
        layers.append(InstanceBlock(input_channels, feature, stride=1 if feature==num_features[-1] else 2))
        input_channels = feature

      layers.append(nn.Conv2d(input_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
      self.model = nn.Sequential(*layers)

    def forward(self, x):
      x = self.initial_layer(x)
      return torch.sigmoid(self.model(x))

In [None]:
# Test the Discriminator
def test_D():
  x = torch.randn(5, 3, 256, 256)

  model = Discriminator(input_channels=3)
  preds = model(x)

  print(preds.shape)

test_D()
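<p>test_D prints torch.Size([5, 1, 30, 30]) rather than one score per image: the discriminator is a PatchGAN, so each of the 30 × 30 outputs scores one overlapping patch of the input. A pure-Python sketch that recomputes the grid size and the receptive field of one output value from the (kernel, stride, padding) of each convolution above, assuming the default num_features = [64, 128, 256, 512]:</p>

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of every Conv2d in Discriminator, in order
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size = 256
for k, s, p in LAYERS:
    size = conv_out(size, k, s, p)

# Receptive field: walk back from a single output value to the input pixels it sees
rf = 1
for k, s, _ in reversed(LAYERS):
    rf = (rf - 1) * s + k

print(size, rf)  # 30 70
```

<p>The 70 × 70 receptive field is the "70 × 70 PatchGAN" used as the discriminator in the CycleGAN paper.</p>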

<hr>
<h2><strong>Additional</strong> functionality:</h2>
<h3>- Image buffer</h3>

In [None]:
class CycleBuffer:
    def __init__(self, max_size=50):
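        assert max_size > 0, "max_size must be a positive number"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, img):
        # Store a detached copy so the buffer never keeps the generator's autograd graph alive
        stored = img.detach().clone()
        img_to_return = img

        if len(self.data) < self.max_size:
            # Buffer not full yet: remember the new image and return it unchanged
            self.data.append(stored)
        else:
            # Buffer full: with probability 0.5 return an older image and store the new one in its place
            if random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                img_to_return = self.data[i].clone()
                self.data[i] = stored

        return img_to_return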
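<p>The CycleGAN paper updates the discriminators with a history of the 50 previously generated images rather than only the latest ones (a technique from Shrivastava et al., 2017). A pure-Python simulation of CycleBuffer's policy, with integers standing in for image tensors (simulate_pool is a hypothetical helper), shows that after the filling phase roughly half of the calls hand the discriminator an older image:</p>

```python
import random


def simulate_pool(num_steps, max_size=50, seed=0):
    """Pure-Python mirror of CycleBuffer.push_and_pop; ints stand in for image tensors.

    Returns (final buffer size, fraction of calls that returned an older image).
    """
    rng = random.Random(seed)  # private RNG so the simulation is reproducible
    pool, from_history = [], 0
    for step in range(num_steps):
        img = step  # the "freshly generated image" for this step
        returned = img
        if len(pool) < max_size:
            pool.append(img)  # filling phase: always return the new image
        elif rng.uniform(0, 1) > 0.5:
            i = rng.randint(0, max_size - 1)  # pick a stored image at random
            returned, pool[i] = pool[i], img  # return it and store the new one instead
        if returned != img:
            from_history += 1
    return len(pool), from_history / num_steps


size, frac = simulate_pool(num_steps=10_050)
print(size, round(frac, 3))
```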

<hr>
<h2>Training our models on different tasks</h2>

In [None]:
disc_M = Discriminator(input_channels=3).to(DEVICE)
disc_P = Discriminator(input_channels=3).to(DEVICE)

gen_M = Generator(input_channels=3, num_residual_blocks=9).to(DEVICE)
gen_P = Generator(input_channels=3, num_residual_blocks=9).to(DEVICE)

opt_disc = optim.Adam(
    list(disc_M.parameters()) + list(disc_P.parameters()),
    lr = LEARNING_RATE,
    betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
    list(gen_M.parameters()) + list(gen_P.parameters()),
    lr = LEARNING_RATE,
    betas=(0.5, 0.999),
)

l1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

if LOAD_MODEL:
  load_checkpoint(
      CHECKPOINT_GEN_M, gen_M, opt_gen, LEARNING_RATE,
  )
  load_checkpoint(
      CHECKPOINT_GEN_P, gen_P, opt_gen, LEARNING_RATE,
  )

  load_checkpoint(
      CHECKPOINT_DISC_M, disc_M, opt_disc, LEARNING_RATE,
  )
  load_checkpoint(
      CHECKPOINT_DISC_P, disc_P, opt_disc, LEARNING_RATE,
  )

dataset = MonetPhotoSet(
    root_photo=TRAIN_DIR+"/trainB", root_monet=TRAIN_DIR+"/trainA", transform=transforms
)

loader = DataLoader(
    dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
    pin_memory = True,
)

g_scaler=torch.cuda.amp.GradScaler()
d_scaler=torch.cuda.amp.GradScaler()

buffer_AB = CycleBuffer()
buffer_BA = CycleBuffer()

In [None]:
def train(disc_M, disc_P, gen_M, gen_P, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, buf_AB, buf_BA):
  for epoch in range(NUM_EPOCHS):
    P_reals = 0
    P_fakes = 0

    loop = tqdm(loader, leave=True)

    for idx, (monet, photo) in enumerate(loop):
        monet = monet.to(DEVICE)
        photo = photo.to(DEVICE)

        # Train discriminators P and M
        with torch.cuda.amp.autocast():
            # Discriminator P: real photos vs. generated photos (fakes may come from the buffer)
            fake_photo = gen_P(monet)
            fake_photo_B = buf_BA.push_and_pop(fake_photo)

            D_P_real = disc_P(photo)
            D_P_fake = disc_P(fake_photo_B.detach())

            P_reals += D_P_real.mean().item()
            P_fakes += D_P_fake.mean().item()

            D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
            D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))

            D_P_loss = D_P_real_loss + D_P_fake_loss

            # Discriminator M: real Monet paintings vs. generated ones (fakes may come from the buffer)
            fake_monet = gen_M(photo)
            fake_monet_B = buf_AB.push_and_pop(fake_monet)

            D_M_real = disc_M(monet)
            D_M_fake = disc_M(fake_monet_B.detach())

            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))

            D_M_loss = D_M_real_loss + D_M_fake_loss

            # Add together
            D_loss = (D_P_loss + D_M_loss) / 2

        opt_disc.zero_grad()

        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generators P and M
        with torch.cuda.amp.autocast():
            # Adversarial loss
            D_P_fake = disc_P(fake_photo)
            D_M_fake = disc_M(fake_monet)

            G_P_loss = mse(D_P_fake, torch.ones_like(D_P_fake))
            G_M_loss = mse(D_M_fake, torch.ones_like(D_M_fake))

            # Cycle loss: Monet -> photo -> Monet and photo -> Monet -> photo
            cycle_monet = gen_M(fake_photo)
            cycle_photo = gen_P(fake_monet)
            cycle_monet_loss = l1(monet, cycle_monet)
            cycle_photo_loss = l1(photo, cycle_photo)

            # Identity loss: gen_M should leave a Monet image unchanged, gen_P a photo (can be skipped when LAMBDA_IDENTITY = 0)
            identity_monet = gen_M(monet)
            identity_photo = gen_P(photo)
            identity_monet_loss = l1(monet, identity_monet)
            identity_photo_loss = l1(photo, identity_photo)

            # Total generator loss
            G_loss = (
                G_M_loss +
                G_P_loss +
                cycle_monet_loss * LAMBDA_CYCLE +
                cycle_photo_loss * LAMBDA_CYCLE +
                identity_photo_loss * LAMBDA_IDENTITY +
                identity_monet_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()

        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            save_image(fake_photo * 0.5 + 0.5, f"/content/saved_imgs/photo_{idx}.png")
            save_image(fake_monet * 0.5 + 0.5, f"/content/saved_imgs/monet_{idx}.png")

        loop.set_postfix(P_real=P_reals / (idx + 1), P_fake=P_fakes / (idx + 1))

    if SAVE_MODEL:
      # Save to local #
      save_checkpoint(gen_M, opt_gen, filename=CHECKPOINT_GEN_M)
      save_checkpoint(gen_P, opt_gen, filename=CHECKPOINT_GEN_P)
      save_checkpoint(disc_M, opt_disc, filename=CHECKPOINT_DISC_M)
      save_checkpoint(disc_P, opt_disc, filename=CHECKPOINT_DISC_P)

      !zip -r /content/saved_imgs_{epoch}.zip /content/saved_imgs

      if SAVE_TO_DRIVE:
          # Save to google drive #
          !cp "/content/discm.pth.tar" "/content/gdrive/MyDrive/CycleGAN"
          !cp "/content/discp.pth.tar" "/content/gdrive/MyDrive/CycleGAN"
          !cp "/content/genm.pth.tar" "/content/gdrive/MyDrive/CycleGAN"
          !cp "/content/genp.pth.tar" "/content/gdrive/MyDrive/CycleGAN"

          !cp /content/saved_imgs_{epoch}.zip /content/gdrive/MyDrive/CycleGAN
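<p>The objective this training step targets is the CycleGAN objective of Zhu et al. with least-squares (MSE) adversarial terms, written here in the notebook's notation: m is a Monet image, p a photo, G_P = gen_P (Monet to photo), G_M = gen_M (photo to Monet), D_P = disc_P, D_M = disc_M, and \hat{p}, \hat{m} are generated images that may come from the image buffer. Every squared term and L1 norm stands for the mean over all elements, matching the default reduction of nn.MSELoss and nn.L1Loss; λ_cyc = LAMBDA_CYCLE = 10 and λ_id = LAMBDA_IDENTITY = 5, i.e. the 0.5·λ identity weight suggested in the CycleGAN paper.</p>

```latex
\begin{aligned}
\mathcal{L}_D &= \tfrac{1}{2}\Big[(D_P(p)-1)^2 + D_P(\hat{p})^2 + (D_M(m)-1)^2 + D_M(\hat{m})^2\Big] \\
\mathcal{L}_G &= (D_P(G_P(m))-1)^2 + (D_M(G_M(p))-1)^2 \\
&\quad + \lambda_{\text{cyc}}\big(\lVert G_M(G_P(m)) - m\rVert_1 + \lVert G_P(G_M(p)) - p\rVert_1\big) \\
&\quad + \lambda_{\text{id}}\big(\lVert G_M(m) - m\rVert_1 + \lVert G_P(p) - p\rVert_1\big)
\end{aligned}
```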

In [None]:
train(disc_M, disc_P, gen_M, gen_P, loader, opt_disc, opt_gen, l1_loss, mse_loss, d_scaler, g_scaler, buffer_AB, buffer_BA)

<hr>
<h2>References:</h2>
<ol>
  <li><h4><a href="https://arxiv.org/pdf/1703.10593.pdf">Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.</a></h4></li>
  <li><h4><a href="https://www.youtube.com/watch?v=5jziBapziYE&list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&index=8">CycleGAN Paper Walkthrough.</a></h4></li>
  <li><h4><a href="https://hannibunny.github.io/mlbook/gan/GAN.html">HOCHSCHULE DER MEDIEN. Generative Adversarial Nets (GAN)</a></h4></li>
  <li><h4><a href="https://nn.labml.ai/gan/cycle_gan/index.html">labml.ai. Cycle GAN</a></h4></li>
  <li><h4><a href="https://blog.paperspace.com/unpaired-image-to-image-translations-with-cycle-gans/">Unpaired Image to Image Translations with Cycle GANs</a></h4></li>
</ol>
<hr>