# Обучение CycleGAN

## Используемые библиотеки

In [1]:
import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as tt

from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision.utils import save_image
from tqdm import notebook
from PIL import Image

## Константы

In [2]:
# Fixed seed so that NumPy and PyTorch (CPU and every CUDA device) draw the
# same pseudo-random numbers on each run.
SEED = 21

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Dataset folders. The values below are placeholders and must be replaced
# with real paths before the notebook is run.
SELFIE_PATH = 'путь до папки с тренировочным набором фотографий'
ANIME_PATH = 'путь до папки с тренировочным набором аниме изображений'
SELFIE_TEST_PATH = 'путь до папки с тестовым набором фотографий'
ANIME_TEST_PATH = 'путь до папки с тестовым набором аниме изображений'

# General constants:
# IMAGE_SIZE is the side used by Resize/CenterCrop in the dataset transform,
# BATCH_SIZE is passed to every DataLoader, EPOCHS bounds the training loop.
IMAGE_SIZE = 128
BATCH_SIZE = 25
EPOCHS = 200

# Optimizer constants (learning rate and the two AdamW betas):
LEARNING_RATE = 0.0003
BETA_ONE = 0.5
BETA_TWO = 0.999

# Scheduler constants (StepLR step_size and gamma): the learning rate is
# multiplied by GAMMA every STEP scheduler steps, i.e. every STEP epochs.
STEP = 101
GAMMA = 0.8

# These constants are needed only if you decide to continue training the
# model with other parameters or datasets. They are placeholders as well.
# NOTE(review): the values are handed to torch.load, so they presumably have
# to point at saved state_dict files rather than folders - confirm.
D_SELFIE = 'путь до папки с весами для фото-дискриминатора'
D_ANIME = 'путь до папки с весами для аниме-дискриминатора'
G_SELFIE_TO_ANIME = 'путь до папки с весами для генератора аниме из фотографий'
G_ANIME_TO_SELFIE = 'путь до папки с весами для генератора фотографий из аниме'

## Формирование датасета

In [3]:
class SelfieToAnimeDataset(Dataset):
    """Dataset that serves every image file of one folder as a tensor.

    Each item is opened with PIL, converted to RGB, resized and
    center-cropped to ``IMAGE_SIZE``, randomly mirrored, and normalized
    with mean 0.5 and std 0.5 per channel.
    """

    def __init__(self, directory: str) -> None:
        """Collect the absolute paths of the files found in ``directory``.

        Parameters:
        directory - path to the folder that contains the images.
        """
        abspath = os.path.abspath(directory)

        self.directory = directory

        # os.listdir returns entries in arbitrary order, so sort them to keep
        # the index -> image mapping reproducible between runs. Entries that
        # are not regular files (for example the ".ipynb_checkpoints" folder
        # Jupyter creates) are skipped because they cannot be opened as images.
        candidates = (os.path.join(abspath, name)
                      for name in sorted(os.listdir(directory)))
        self.image_list = [path for path in candidates if os.path.isfile(path)]

        # Preprocessing and augmentation applied to every loaded image:
        self.transform = tt.Compose([
            tt.Resize(IMAGE_SIZE),
            tt.CenterCrop(IMAGE_SIZE),
            tt.RandomHorizontalFlip(p=0.5),
            tt.ToTensor(),
            tt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __len__(self) -> int:
        """Return the number of image files found in the folder."""
        return len(self.image_list)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Load the image at ``index`` and return it as a transformed tensor."""
        # The context manager closes the underlying file as soon as the
        # pixel data has been converted and transformed.
        with Image.open(self.image_list[index]) as image:
            return self.transform(image.convert('RGB'))

## Дискриминатор

In [4]:
class Discriminator(nn.Module):
    """Convolutional discriminator that returns one score per image.

    Three stride-2 convolutions and one stride-1 convolution widen the
    feature maps from ``features`` to ``features*8`` channels; a final
    convolution reduces them to a single-channel score map, which
    ``forward`` averages over its spatial dimensions.
    """

    def __init__(self, features: int = 64) -> None:
        """Build the layer stack.

        Parameters:
        features - number of output channels of the first convolution;
                   later stages use multiples of it.
        """
        super().__init__()

        # The layer order is kept fixed so that state_dict keys
        # ("model.<index>...") stay compatible with saved weights.
        self.model = nn.Sequential(

            nn.Conv2d(3, features, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features, features*2, kernel_size=4,
                      stride=2, padding=1),
            nn.InstanceNorm2d(features*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features*2, features*4, kernel_size=4,
                      stride=2, padding=1),
            nn.InstanceNorm2d(features*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features*4, features*8, kernel_size=4,
                      stride=1, padding=1),
            nn.InstanceNorm2d(features*8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features*8, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of images.

        Parameters:
        x - batch of 3-channel images.

        Returns a tensor with one column holding the spatial average of
        the score map for every image in the batch.
        """
        score_map = self.model(x)
        # Pool over the full spatial extent so every image gets one score:
        pooled = F.avg_pool2d(score_map, score_map.size()[2:])
        return torch.flatten(pooled, 1)

## Генератор

In [5]:
class ResidualBlock(nn.Module):
    """Residual block: two padded 3x3 convolutions plus a skip connection.

    Reflection padding of 1 before each 3x3 convolution keeps the spatial
    size, and the channel count stays at ``features``, so the block output
    can be added to its input.
    """

    def __init__(self, features: int) -> None:
        """Build the block.

        Parameters:
        features - number of input and output channels.
        """
        super().__init__()

        self.block = nn.Sequential(

            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, kernel_size=3),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),

            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, kernel_size=3),
            nn.InstanceNorm2d(features)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.block(x)

In [6]:
class Generator(nn.Module):
    """Image-to-image generator: encoder, residual blocks, decoder.

    A 7x7 convolution and two stride-2 convolutions raise the channel count
    to ``features*4``; a stack of residual blocks follows; two transposed
    convolutions and a final 7x7 convolution map back to ``in_features``
    channels. The closing ``Tanh`` bounds the output to [-1, 1].
    """

    def __init__(self, in_features: int = 3,
                 features: int = 64,
                 n_residual_blocks: int = 9) -> None:
        """Build the layer stack.

        Parameters:
        in_features - number of channels of the input (and output) image;
        features - number of output channels of the first convolution;
        n_residual_blocks - how many residual blocks sit between the
                            encoder and the decoder.
        """
        super().__init__()

        # A 7x7 convolution needs 3 pixels of padding on every side to keep
        # the spatial size. The padding must not depend on the channel
        # count: tying it to in_features is only correct for 3 channels.
        edge_padding = 3

        # With the default n_residual_blocks the layer indices are the same
        # as before, so previously saved state_dicts still load.
        self.model = nn.Sequential(

            nn.ReflectionPad2d(edge_padding),
            nn.Conv2d(in_features, features,
                      kernel_size=7, stride=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),

            nn.Conv2d(features, features*2,
                      kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(features*2),
            nn.ReLU(inplace=True),

            nn.Conv2d(features*2, features*4,
                      kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(features*4),
            nn.ReLU(inplace=True),

            *[ResidualBlock(features*4) for _ in range(n_residual_blocks)],

            nn.ConvTranspose2d(features*4, features*2, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(features*2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(features*2, features, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),

            nn.ReflectionPad2d(edge_padding),
            nn.Conv2d(features, in_features, kernel_size=7, stride=1),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Translate a batch of images and return the generated batch."""
        return self.model(x)

## CycleGAN

In [7]:
class CycleGAN:
    def __init__(self):
        
        # Инициализация CUDA-ядер, при наличии:
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        
        self.G_selfie_to_anime = Generator().to(self.device)
        self.G_anime_to_selfie = Generator().to(self.device)
        
        self.D_selfie = Discriminator().to(self.device)
        self.D_anime = Discriminator().to(self.device)
    
    def load_weights(self, G_selfie_to_anime: str, G_anime_to_selfie: str,
                     D_selfie: str, D_anime: str):
        """
        Функция загружает необходимые веса для работы модели.

        Параметры:
        G_selfie_to_anime - путь до директории с весами
                            для генератора "из фотографии в аниме";
        G_anime_to_selfie - путь до директории с весами
                            для генератора "из аниме в фотографию";
        D_selfie - путь до директории с весами
                   для дискриминатора для "фотографий";
        D_anime - путь до директории с весами
                  для дискриминатора для "аниме".
        """
        self.G_selfie_to_anime.load_state_dict(torch.load(G_SELFIE_TO_ANIME))
        self.G_anime_to_selfie.load_state_dict(torch.load(G_ANIME_TO_SELFIE))
        self.D_selfie.load_state_dict(torch.load(D_SELFIE))
        self.D_anime.load_state_dict(torch.load(D_ANIME))
    
    def real_mse_loss(self, D_out: torch.Tensor):
        return torch.mean((D_out-1)**2)
    
    def fake_mse_loss(self, D_out: torch.Tensor):
        return torch.mean(D_out**2)
    
    def cycle_consistency_loss(self, real_img: torch.Tensor, 
                               reconstructed_img: torch.Tensor, lambda_w: int | float):
        return (torch.mean(torch.abs(real_img - reconstructed_img))) * lambda_w
    
    def train_generator(self, optimizer, selfie: torch.Tensor, anime: torch.Tensor):
        
        optimizer['generator'].zero_grad()
        
        with torch.cuda.amp.autocast():
            
            # Обучаем генератор "из аниме в фотографию":
            fake_selfie = self.G_anime_to_selfie(anime)
            real_selfie = self.D_selfie(fake_selfie)
            G_anime_to_selfie_loss = self.real_mse_loss(real_selfie)
            reconstructed_anime = self.G_selfie_to_anime(fake_selfie)
            reconstructed_anime_loss = self.cycle_consistency_loss(anime, reconstructed_anime, 10)
            
            # Обучаем генератор "из фотографии в аниме":
            fake_anime = self.G_selfie_to_anime(selfie)
            real_anime = self.D_anime(fake_anime)
            G_selfie_to_anime_loss = self.real_mse_loss(real_anime)
            reconstructed_selfie = self.G_anime_to_selfie(fake_anime)
            reconstructed_selfie_loss = self.cycle_consistency_loss(selfie, reconstructed_selfie, 10)
        
            G_loss = (G_anime_to_selfie_loss + 
                      reconstructed_anime_loss +
                      G_selfie_to_anime_loss +
                      reconstructed_selfie_loss)
        
        G_loss.backward()
        optimizer['generator'].step()
        
        return G_loss.item()
    
    def train_discriminator(self, optimizer: torch.optim.AdamW, 
                            selfie: torch.Tensor, anime: torch.Tensor,
                            real_noise: torch.Tensor,
                            fake_noise: torch.Tensor):
        
        optimizer['selfie_discriminator'].zero_grad()
        
        with torch.cuda.amp.autocast():
            
            # Обучаем фото-дискриминатор с добавлением шума:
            real_selfie = self.D_selfie(selfie)
            real_selfie_loss = self.real_mse_loss(real_selfie - real_noise)
            fake_selfie_images = self.G_anime_to_selfie(anime)
            fake_selfie = self.D_selfie(fake_selfie_images)
            fake_selfie_loss = self.fake_mse_loss(fake_selfie + fake_noise)

            D_selfie_loss = real_selfie_loss + fake_selfie_loss
        
        D_selfie_loss.backward()
        optimizer['selfie_discriminator'].step()
        
        optimizer['anime_discriminator'].zero_grad()
        
        with torch.cuda.amp.autocast():
            
            # Обучаем аниме-дискриминатор с добавлением шума:
            real_anime = self.D_anime(anime)
            real_anime_loss = self.real_mse_loss(real_anime - real_noise)
            fake_anime_image = self.G_selfie_to_anime(selfie)
            fake_anime = self.D_anime(fake_anime_image)
            fake_anime_loss = self.fake_mse_loss(fake_anime + fake_noise)

            D_anime_loss = real_anime_loss + fake_anime_loss
        
        D_anime_loss.backward()
        optimizer['anime_discriminator'].step()
        
        return D_selfie_loss.item(), D_anime_loss.item()
    
    def train(self, optimizer: torch.optim.AdamW, sheduler: StepLR, selfie_loader: DataLoader, 
              anime_loader: DataLoader, test_selfie_loader: DataLoader) -> tuple:
        
        losses = []   
        
        for epoch in notebook.tqdm(range(EPOCHS)):
            for (selfie_image, anime_image) in notebook.tqdm(zip(selfie_loader, anime_loader)):
                
                # Переносим батчи на СUDA-ядра:
                selfie = selfie_image.to(self.device)
                anime = anime_image.to(self.device)
                
                # генерируем шум:
                real_noise = 0.05 * torch.rand(selfie.size(0), 1, device=self.device)
                fake_noise = 0.05 * torch.rand(selfie.size(0), 1, device=self.device)
                
                # обучаем CycleGAN:
                G_loss = self.train_generator(optimizer, selfie, anime)
                D_selfie_loss, D_anime_loss = self.train_discriminator(
                    optimizer, selfie, anime, real_noise, fake_noise)
                
                losses.append((D_selfie_loss, D_anime_loss, G_loss))
            
            # Делаем шаг шедулером:
            sheduler['generator'].step()
            sheduler['selfie_discriminator'].step()
            sheduler['anime_discriminator'].step()
            
            # При необходимости выводим Learning Rate:
            print(sheduler['anime_discriminator'].get_last_lr())
            
            # Выводим общую информацию:
            print('Epoch [{}/{}] | D_selfie_loss: {:.4f} | D_anime_loss: {:.4f} | G_loss: {:.4f}'.format(
                epoch+1, EPOCHS, D_selfie_loss, D_anime_loss, G_loss))
            
            # Сохраняем веса каждого шага:
            torch.save(self.G_selfie_to_anime.state_dict(), str(epoch + 1) + 'G_selfie_to_anime')
            torch.save(self.G_anime_to_selfie.state_dict(), str(epoch + 1) + 'G_anime_to_selfie')
            torch.save(self.D_selfie.state_dict(), str(epoch + 1) + 'D_selfie')
            torch.save(self.D_anime.state_dict(), str(epoch + 1) + 'D_anime')
            
            # Визуализируем результат обучения для каждой эпохи:
            samples = []
            with torch.no_grad():
                for i in range(2):
                    fixed_selfie = next(iter(test_selfie_loader))[i].to(self.device)
                    fake_anime = self.G_selfie_to_anime(fixed_selfie)
                    samples.append(fixed_selfie)
                    samples.append(fake_anime)
            
            plt.figure(figsize=(10, 10))
            title = ['Selfie', 'Anime', 'Selfie', 'Anime']
                
            for i in range(4):
                plt.subplot(1, 4, i+1)
                plt.axis('off')
                plt.title(title[i])
                plt.imshow((samples[i] * 0.5 + 0.5).cpu().detach().permute(1, 2, 0))
            plt.show();
            
        return losses

## Подготовка к обучению

In [8]:
# One dataset per image folder (training and test, selfie and anime):
selfie_dataset = SelfieToAnimeDataset(directory=SELFIE_PATH)
anime_dataset = SelfieToAnimeDataset(directory=ANIME_PATH)
test_selfie_dataset = SelfieToAnimeDataset(directory=SELFIE_TEST_PATH)
test_anime_dataset = SelfieToAnimeDataset(directory=ANIME_TEST_PATH)

# Shuffled loaders over the datasets above:
selfie_loader = DataLoader(selfie_dataset, batch_size=BATCH_SIZE, shuffle=True)
anime_loader = DataLoader(anime_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_selfie_loader = DataLoader(test_selfie_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_anime_loader = DataLoader(test_anime_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Create the networks (their weights are initialized here):
model = CycleGAN()

# To continue training from previously saved weights, uncomment the line
# below so that the stored weights are loaded into the model:
# --------------------------------------------------------------------------
#model.load_weights(G_SELFIE_TO_ANIME, G_ANIME_TO_SELFIE, D_SELFIE, D_ANIME)
# --------------------------------------------------------------------------

# Both generators are optimized together, so merge their parameters:
generator_params: list = [
    *model.G_selfie_to_anime.parameters(),
    *model.G_anime_to_selfie.parameters(),
]

# One AdamW optimizer per trained part, all with the same hyperparameters:
optimizer = {
    name: torch.optim.AdamW(params, lr=LEARNING_RATE, betas=[BETA_ONE, BETA_TWO])
    for name, params in (
        ('generator', generator_params),
        ('selfie_discriminator', model.D_selfie.parameters()),
        ('anime_discriminator', model.D_anime.parameters()),
    )
}

# One StepLR scheduler per optimizer, stored under the same key:
sheduler = {
    name: StepLR(optimizer_for_part, step_size=STEP, gamma=GAMMA)
    for name, optimizer_for_part in optimizer.items()
}

## Обучение CycleGAN

In [None]:
# Run the full training loop and keep the per-iteration losses it returns.
losses = model.train(
    optimizer=optimizer,
    sheduler=sheduler,
    selfie_loader=selfie_loader,
    anime_loader=anime_loader,
    test_selfie_loader=test_selfie_loader,
)