## Импорты:

In [1]:
import numpy as np
import pandas as pd
import math
import os
import sys

from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

## КОНСТАНТЫ

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
METADATA_PATH = '/kaggle/input/edges2shoes-dataset/metadata.csv'
IMG_PATH = '/kaggle/input/edges2shoes-dataset/'
L1_LAMBDA = 100
LAMBDA_GP = 10
NUM_WORKERS = 2
LEARNING_RATE = 2e-4
BETA_1 = 0.5
BETA_2 = 0.999
NUM_EPOCHS = 16
BATCH_SIZE = 4
IMAGE_SIZE = 256
SAVE_MODEL = True

transform_both = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.HorizontalFlip(p=0.5)                
    ],
    additional_targets={'image0':'image'}
)
        
transform_input = A.Compose(
    [
        A.ColorJitter(p=0.1),
        A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 255.0),
        ToTensorV2()
    ]
)
        
transform_target = A.Compose(
    [
        A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), 255.0),
        ToTensorV2()            
    ]
)

## БЛОКИ И МОДЕЛИ

In [3]:
class CkBLOCK(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 4, padding = 0, stride = 2):
        super(CkBLOCK, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding = padding,
                      padding_mode='reflect',
                      bias=False
                     ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
        
    def forward(self, x):
        x = self.block(x)
        return x

In [4]:
class CkBLOCK_Transpose(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 4, padding = 0, stride = 2):
        super(CkBLOCK_Transpose, self).__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding = padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
    def forward(self, x):
        x = self.block(x)
        return x

In [5]:
class CDkBLOCK(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 4, padding = 0, stride = 2):
        super(CDkBLOCK, self).__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding = padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.dropout = nn.Dropout2d(p = 0.5)
    def forward(self, x):
        x = self.block(x)
        return self.dropout(x)

In [6]:
class Discriminator(nn.Module):
    def __init__(self, in_channels = 3):
        super(Discriminator, self).__init__()
        self.c64 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels*2,
                      out_channels=64,
                      kernel_size=4,
                      stride=2,
                      bias=False,
                      padding = 1,
                      padding_mode='reflect'
                     ),
            nn.LeakyReLU(0.2))
        self.c128 = CkBLOCK(64, 128, padding = 1)
        self.c256 = CkBLOCK(128, 256, padding = 1)
        self.c512 = CkBLOCK(256, 512, padding = 1, stride = 1)
        self.fin_conv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding = 1, padding_mode='reflect')
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, y):
        x = torch.cat([x,y], dim=1)
        x = self.c64(x)
        x = self.c128(x)
        x = self.c256(x)
        x = self.c512(x)
        x = self.sigmoid(self.fin_conv(x))
        return x
        

In [7]:
class Generator(nn.Module):
    def __init__(self, in_channels = 3):
        super(Generator, self).__init__()
        # encoder        
        self.c64_enc = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=64,
                      kernel_size=4,
                      stride=2,
                      padding = 1,
                      padding_mode='reflect',
                      bias=False
                     ),
            nn.LeakyReLU(0.2)
        ) # 256 -> 128
        self.c128_enc = CkBLOCK(64, 128, padding = 1) # 128 -> 64
        self.c256_enc = CkBLOCK(128, 256, padding = 1) # 64 -> 32
        self.c512_1_enc = CkBLOCK(256, 512, padding = 1) # 32 -> 16
        self.c512_2_enc = CkBLOCK(512, 512, padding = 1) # 16 -> 8
        self.c512_3_enc = CkBLOCK(512, 512, padding = 1) # 8 -> 4
        self.c512_4_enc = CkBLOCK(512, 512, padding = 1) # 4 -> 2
        #bottleneck
        self.bottleneck = nn.Sequential(
            CkBLOCK(512, 1024, kernel_size = 2, padding = 1),
            CDkBLOCK(1024, 512, kernel_size = 2, padding = 1)
        )
        #decoder
        self.c1024_4_dec = CDkBLOCK(1024, 512, padding = 1) # 2 -> 4
        self.c1024_3_dec = CDkBLOCK(1024, 512, padding = 1) # 4 -> 8
        self.c1024_2_dec = CkBLOCK_Transpose(1024, 512, padding = 1) # 8 -> 16
        self.c1024_1_dec = CkBLOCK_Transpose(1024, 256, padding = 1) # 16 -> 32
        self.c512_dec = CkBLOCK_Transpose(512, 128, padding = 1) # 32 -> 64
        self.c256_dec = CkBLOCK_Transpose(256, 64, padding = 1) # 64 -> 128
        self.c128_dec = CkBLOCK_Transpose(128, in_channels, padding = 1) # 128 -> 256
        self.tanh = nn.Tanh()
    def forward(self, x):
        d1 = self.c64_enc(x)
        d2 = self.c128_enc(d1)
        d3 = self.c256_enc(d2)
        d4 = self.c512_1_enc(d3)
        d5 = self.c512_2_enc(d4)
        d6 = self.c512_3_enc(d5)
        d7 = self.c512_4_enc(d6)
        bottleneck = self.bottleneck(d7)
        u2 = self.c1024_4_dec(torch.cat([bottleneck, d7], dim=1))
        u3 = self.c1024_3_dec(torch.cat([u2, d6], dim=1))
        u4 = self.c1024_2_dec(torch.cat([u3, d5], dim=1))
        u5 = self.c1024_1_dec(torch.cat([u4, d4], dim=1))
        u6 = self.c512_dec(torch.cat([u5, d3], dim=1))
        u7 = self.c256_dec(torch.cat([u6, d2], dim=1))
        u8 = self.c128_dec(torch.cat([u7, d1], dim=1))
        x = self.tanh(u8)
        return x

## ЗАГРУЗКА ДАТАСЕТА

In [8]:
class ShoesDataset(Dataset):
    def __init__(self, root_dir: str, mode: str):
        self.mode = mode
        self.root_dir = root_dir + f"{mode}/"
        self.file_list = os.listdir(self.root_dir)
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, index):
        img_file = self.file_list[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = Image.open(img_path)
        
        img_w, img_h = image.size
        
        input_image_PIL = image.crop((0, 0, img_w//2, img_h))
        target_image_PIL = image.crop((img_w//2, 0, img_w, img_h))
        
        input_image = np.array(input_image_PIL)
        target_image = np.array(target_image_PIL)
        
        augmentations = transform_both(image = input_image, image0 = target_image)
        input_image, target_image = augmentations['image'], augmentations['image0']
        
        input_image = transform_input(image = input_image)['image']
        target_image = transform_target(image = target_image)['image']
        
        return input_image, target_image
                

In [9]:
train_dataset = ShoesDataset(root_dir = IMG_PATH, mode = "train")
val_dataset = ShoesDataset(root_dir = IMG_PATH, mode = "val")

In [10]:
len(val_dataset)

200

In [11]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle = True, num_workers=NUM_WORKERS)
val_dataloader = DataLoader(val_dataset, batch_size=10, shuffle = False)

In [12]:
x, y = next(iter(train_dataloader))
len(x), len(y)

(4, 4)

## ЧЕКПОИНТЫ

In [13]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"label_{epoch}.png")
    gen.train()

In [14]:
def save_checkpoint(model, opt, filename):
    print('---SAVING CHECKPOINT---')
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict()
    }
    torch.save(checkpoint, filename)

In [15]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print('---LOADING CHECKPOINT---')
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

## ОБУЧЕНИЕ

In [16]:
disc = Discriminator().to(DEVICE)
gen = Generator().to(DEVICE)
opt_disc = torch.optim.Adam(disc.parameters(), lr = LEARNING_RATE, betas=(BETA_1, BETA_2))
opt_gen = torch.optim.Adam(gen.parameters(), lr = LEARNING_RATE, betas=(BETA_1, BETA_2))
BCE = nn.BCEWithLogitsLoss()
L1_loss = nn.L1Loss()
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

In [17]:
load_checkpoint("/kaggle/input/checkpoints/gen.pth.tar", gen, opt_gen, LEARNING_RATE)

---LOADING CHECKPOINT---


In [18]:
x, y = next(iter(val_dataloader))
x, y = x.to(DEVICE), y.to(DEVICE)
gen.eval()
with torch.no_grad():
    y_fake = gen(x)
    save_image(y_fake, "/kaggle/working/" + f"/y_val_normalized.png")
    y_fake1 = y_fake * 0.5 + 0.5
    save_image(y_fake1, "/kaggle/working/" + f"/y_val_denorm.png")
gen.train()

Generator(
  (c64_enc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (c128_enc): CkBLOCK(
    (block): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (c256_enc): CkBLOCK(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (c512_1_enc): CkBLOCK(
    (block): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(512, eps=1e-05, mome

train func was adopted from this video https://www.youtube.com/watch?v=SuddDSqGRzg

In [17]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler):
    loop = tqdm(loader, leave = True)
    for idx, (x, y) in enumerate(loop):
        x, y = x.to(DEVICE), y.to(DEVICE)
        
        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_fake = disc(x, y_fake.detach())
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2
        
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()
        
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1
            
        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

In [None]:
for epoch in range(NUM_EPOCHS):
    train_fn(disc, gen, train_dataloader, opt_disc, opt_gen, L1_loss, BCE, g_scaler, d_scaler)
    if SAVE_MODEL and epoch % 5 == 0:
        save_checkpoint(gen, opt_gen, filename = 'gen.pth.tar')
        save_checkpoint(disc, opt_disc, filename = 'disc.pth.tar')
    save_some_examples(gen, val_dataloader, epoch, folder = '/kaggle/working/')

  0%|          | 0/12457 [00:00<?, ?it/s]

---SAVING CHECKPOINT---
---SAVING CHECKPOINT---


  0%|          | 0/12457 [00:00<?, ?it/s]

  0%|          | 0/12457 [00:00<?, ?it/s]

KeyboardInterrupt: 