In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the full path of every file in it.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/horse2zebra-dataset/metadata.csv
/kaggle/input/horse2zebra-dataset/testB/n02391049_80.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_400.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_3770.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_8830.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_8340.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_6690.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_10980.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_8140.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_3060.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_1220.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_1880.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_2220.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_10160.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_2510.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_890.jpg
/kaggle/input/horse2zebra-dataset/testB/n02391049_2810.jpg
/kaggle/inp

In [2]:
import torch
import torch.nn as nn

In [3]:
from torch.utils.data import Dataset
from PIL import Image

In [4]:
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image

In [6]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random

In [18]:
class Block(nn.Module):
    """One discriminator stage: a 4x4 reflect-padded convolution then LeakyReLU(0.2).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        strides: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, strides):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=strides,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        activation = nn.LeakyReLU(0.2)
        # Kept as a Sequential named ``conv`` so parameter names stay ``conv.0.*``.
        self.conv = nn.Sequential(convolution, activation)

    def forward(self, x):
        """Apply the convolution followed by the activation to ``x``."""
        out = self.conv(x)
        return out


In [19]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores an image on a spatial grid.

    The network is an initial strided convolution, one ``Block`` per entry of
    ``features[1:]`` and a final 1-channel convolution. The output is passed
    through a sigmoid, so every element of the returned map lies in (0, 1).

    Args:
        in_channels: number of channels of the input image.
        features: channel widths of the successive stages. Any sequence is
            accepted; it is copied, never mutated.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default + local copy: avoids the shared-mutable-default pitfall.
        features = list(features)

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_position = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Only the LAST stage uses stride 1. Decide by position, not by value,
            # so a width that happens to equal features[-1] earlier in the list
            # still downsamples.
            stride = 1 if position == last_position else 2
            layers.append(Block(in_channels, feature, strides=stride))
            in_channels = feature

        # Collapse to a single-channel score map.
        layers.append(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-activated score map for the image batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [20]:
def test():
    """Smoke test: run the discriminator on a random batch and print the output shape."""
    batch = torch.randn((5, 3, 256, 256))
    discriminator = Discriminator()
    scores = discriminator(batch)
    print(scores.shape)


test()

torch.Size([5, 1, 30, 30])


In [21]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + InstanceNorm2d + optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the (transposed) convolution.
        down: when true use a reflect-padded ``nn.Conv2d``; otherwise use
            ``nn.ConvTranspose2d``.
        use_act: when true finish with an in-place ReLU, otherwise with an
            identity layer.
        **kwargs: forwarded unchanged to the convolution constructor
            (kernel size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()

        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        norm_layer = nn.InstanceNorm2d(out_channels)

        if use_act:
            act_layer = nn.ReLU(inplace=True)
        else:
            act_layer = nn.Identity()

        self.conv = nn.Sequential(conv_layer, norm_layer, act_layer)

    def forward(self, x):
        """Apply convolution, normalisation and the optional activation to ``x``."""
        out = self.conv(x)
        return out

    
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s wrapped in a skip connection.

    The second ``ConvBlock`` is built with ``use_act=False``; the block's
    output is the input plus the result of the two convolutions.

    Args:
        channels: channel count, identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1, stride=1)
        second = ConvBlock(channels, channels, kernel_size=3, padding=1, use_act=False)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the transformed ``x``."""
        transformed = self.block(x)
        return x + transformed

In [22]:
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: a 7x7 stem, two stride-2 ``ConvBlock``s, ``num_residuals``
    ``ResidualBlock``s, two stride-2 transposed ``ConvBlock``s and a final 7x7
    convolution followed by ``tanh`` (so outputs lie in [-1, 1]).

    Args:
        img_channels: channels of both the input and the generated image.
        num_residuals: how many residual blocks sit in the bottleneck.
        num_features: channel width after the stem; deeper stages use 2x and 4x.
    """

    def __init__(self, img_channels, num_residuals=9, num_features=64):
        super().__init__()
        width_1x = num_features * 1
        width_2x = num_features * 2
        width_4x = num_features * 4

        # Stem: full-resolution 7x7 convolution.
        stem_conv = nn.Conv2d(
            img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"
        )
        self.initial = nn.Sequential(
            stem_conv,
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )

        # Encoder: each stage uses stride 2 and doubles the channel count.
        down_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList([
            ConvBlock(c_in, c_out, **down_kwargs)
            for c_in, c_out in ((num_features, width_2x), (width_2x, width_4x))
        ])

        # Bottleneck of residual blocks at the widest channel count.
        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(width_4x) for _ in range(num_residuals))
        )

        # Decoder: transposed convolutions mirroring the encoder.
        up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList([
            ConvBlock(c_in, c_out, **up_kwargs)
            for c_in, c_out in ((width_4x, width_2x), (width_2x, width_1x))
        ])

        # Project back to image channels.
        self.last = nn.Conv2d(
            width_1x, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect"
        )

    def forward(self, x):
        """Translate the image batch ``x``; the result is squashed with ``tanh``."""
        out = self.initial(x)
        for stage in (*self.down_blocks, self.residual_blocks, *self.up_blocks):
            out = stage(out)
        return torch.tanh(self.last(out))

In [23]:
def test():
    """Smoke test: run the generator on a random batch and print the output shape."""
    channels, side = 3, 256
    batch = torch.randn((2, channels, side, side))
    generator = Generator(channels, 9)
    translated = generator(batch)
    print(translated.shape)


test()

torch.Size([2, 3, 256, 256])


In [24]:
class HorseZebraDataset(Dataset):
    """Unpaired horse/zebra image dataset.

    Item ``i`` pairs the ``i``-th horse image with the ``i``-th zebra image,
    each index wrapping around its own folder, so the dataset is as long as
    the larger of the two folders.

    Args:
        zebra: directory containing the zebra images.
        horse: directory containing the horse images.
        transform: optional callable invoked as
            ``transform(image=horse, image0=zebra)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: if either directory contains no entries (indexing would
            otherwise fail later with a bare ``ZeroDivisionError``).
    """

    def __init__(self, zebra, horse, transform=None):
        self.horse = horse
        self.zebra = zebra
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is the same on every run and machine.
        self.horse_images = sorted(os.listdir(horse))
        self.zebra_images = sorted(os.listdir(zebra))
        self.horse_length = len(self.horse_images)
        self.zebra_length = len(self.zebra_images)

        if self.horse_length == 0 or self.zebra_length == 0:
            raise ValueError(
                f"Both image directories must be non-empty "
                f"(horse={horse!r}: {self.horse_length} files, "
                f"zebra={zebra!r}: {self.zebra_length} files)"
            )

        self.dataset_length = max(self.horse_length, self.zebra_length)

    def __len__(self):
        """Return the size of the larger of the two image folders."""
        return self.dataset_length

    def __getitem__(self, index):
        """Return ``(horse_img, zebra_img)`` for ``index``.

        Images are loaded as RGB arrays and, when a transform was supplied,
        replaced by the transform's ``"image"`` / ``"image0"`` outputs.
        """
        # The modulo lets the shorter folder repeat while the longer one advances.
        zebra_name = self.zebra_images[index % self.zebra_length]
        horse_name = self.horse_images[index % self.horse_length]

        horse_path = os.path.join(self.horse, horse_name)
        zebra_path = os.path.join(self.zebra, zebra_name)

        # Context managers close the underlying file handles deterministically
        # instead of leaving them to the garbage collector.
        with Image.open(zebra_path) as zebra_file:
            zebra_img = np.array(zebra_file.convert("RGB"))
        with Image.open(horse_path) as horse_file:
            horse_img = np.array(horse_file.convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=horse_img, image0=zebra_img)
            horse_img = augmentations["image"]
            zebra_img = augmentations["image0"]

        return horse_img, zebra_img

In [25]:
# Training configuration

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Local dataset locations (not referenced by the other cells visible in this notebook).
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"

# Optimisation settings.
BATCH_SIZE = 1
LEARNING_RATE = 2e-4
NUM_WORKERS = 4
NUM_EPOCHS = 45

# Weights of the identity and cycle-consistency terms in the generator loss.
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10

# Checkpoint handling: whether to restore / persist, and the file used per network.
LOAD_MODEL = False
SAVE_MODEL = True
HORSE_GEN_CHECKPOINT = "hgen.pth.tar"
ZEBRA_GEN_CHECKPOINT = "zgen.pth.tar"
HORSE_DISC_CHECKPOINT = "hdisc.pth.tar"
ZEBRA_DISC_CHECKPOINT = "zdisc.pth.tar"

# Augmentation pipeline; "image0" is registered as a second image target so one
# call can transform a horse and a zebra picture together.
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.1),
        # mean = std = 0.5 with max_pixel_value=255; the training loop undoes
        # this with `* 0.5 + 0.5` before saving images.
        A.Normalize(mean=[0.5] * 3, std=[0.5] * 3, max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [26]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model's and the optimizer's state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer), serialised with ``torch.save``.
    """
    print("-> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``.

    The checkpoint is expected to hold the keys ``"state_dict"`` and
    ``"optimizer"``. Tensors are mapped onto ``DEVICE`` while loading. After
    the optimizer state is restored, every parameter group's learning rate is
    overwritten with ``lr``.
    """
    print("-> Loading checkpoint")
    restored = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(restored["state_dict"])
    optimizer.load_state_dict(restored["optimizer"])

    # load_state_dict also restored the saved learning rate; replace it with
    # the value requested by the caller.
    for group in optimizer.param_groups:
        group["lr"] = lr
        

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random``, the
            PyTorch CPU generator and all CUDA generators.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting
        it here only influences processes launched afterwards, not the hash
        randomisation of the current process.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)  # was misspelled "PYTHONASHSEED"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True  # was the undefined name `Trie`
    torch.backends.cudnn.benchmark = False

In [27]:
# Running counter appended to the names of the sample images written by train_fn.
save_count = 0


def train_fn(h_disc, z_disc, h_gen, z_gen, loader, disc_opt, gen_opt, l1,
            mse, d_scaler, g_scaler):
    """Run one pass over ``loader``, updating discriminators then generators.

    Args:
        h_disc: discriminator applied to real and generated horse images.
        z_disc: discriminator applied to real and generated zebra images.
        h_gen: generator producing horse images (called on zebra batches).
        z_gen: generator producing zebra images (called on horse batches).
        loader: iterable yielding ``(horse, zebra)`` batches.
        disc_opt: optimizer stepped with the discriminator loss.
        gen_opt: optimizer stepped with the generator loss.
        l1: loss used for the cycle-consistency and identity terms.
        mse: loss used for the adversarial terms.
        d_scaler: gradient scaler used for the discriminator update.
        g_scaler: gradient scaler used for the generator update.

    Side effects:
        Every 200 batches the current fake horse and zebra are written to
        ``horse_{idx}_{save_count}.png`` / ``zebra_{idx}_{save_count}.png``
        and the module-level ``save_count`` is incremented.
    """
    global save_count
    loop = tqdm(loader, leave=True)

    # Running sums of the mean horse-discriminator output, shown on the progress bar.
    fake_horses = 0
    real_horses = 0

    for idx, (horse, zebra) in enumerate(loop):
        horse = horse.to(DEVICE)
        zebra = zebra.to(DEVICE)

        # Train the discriminators.
        with torch.cuda.amp.autocast():
            fake_horse = h_gen(zebra)
            real_horse_disc = h_disc(horse)
            # detach(): the discriminator loss must not backpropagate into the generator.
            fake_horse_disc = h_disc(fake_horse.detach())
            real_horses += real_horse_disc.mean().item()
            fake_horses += fake_horse_disc.mean().item()
            real_horse_disc_loss = mse(real_horse_disc, torch.ones_like(real_horse_disc))
            fake_horse_disc_loss = mse(fake_horse_disc, torch.zeros_like(fake_horse_disc))
            horse_disc_loss = real_horse_disc_loss + fake_horse_disc_loss

            fake_zebra = z_gen(horse)
            real_zebra_disc = z_disc(zebra)
            fake_zebra_disc = z_disc(fake_zebra.detach())

            real_zebra_disc_loss = mse(real_zebra_disc, torch.ones_like(real_zebra_disc))
            fake_zebra_disc_loss = mse(fake_zebra_disc, torch.zeros_like(fake_zebra_disc))
            zebra_disc_loss = real_zebra_disc_loss + fake_zebra_disc_loss

            disc_loss = (zebra_disc_loss + horse_disc_loss) / 2

        # Use the optimizer that was passed in, not a same-purpose global.
        disc_opt.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(disc_opt)
        d_scaler.update()

        # Train the generators.
        with torch.cuda.amp.autocast():
            # Adversarial loss: the generators want the discriminators to output 1.
            fake_horse_disc = h_disc(fake_horse)
            fake_zebra_disc = z_disc(fake_zebra)
            horse_gen_loss = mse(fake_horse_disc, torch.ones_like(fake_horse_disc))
            zebra_gen_loss = mse(fake_zebra_disc, torch.ones_like(fake_zebra_disc))

            # Cycle loss: translating there and back should reproduce the original image.
            cycled_zebra = z_gen(fake_horse)  # zebra -> fake horse -> zebra
            cycled_horse = h_gen(fake_zebra)  # horse -> fake zebra -> horse
            cycled_zebra_loss = l1(zebra, cycled_zebra)
            cycled_horse_loss = l1(horse, cycled_horse)

            gen_loss = (
                horse_gen_loss + zebra_gen_loss
                + (cycled_zebra_loss * LAMBDA_CYCLE)
                + (cycled_horse_loss * LAMBDA_CYCLE)
            )

            # Identity loss: a generator fed an image of its own target domain
            # should return it unchanged. Skipped when the weight is zero, which
            # saves two generator forward passes that would contribute nothing.
            if LAMBDA_IDENTITY:
                zebra_identity = z_gen(zebra)
                horse_identity = h_gen(horse)
                zebra_identity_loss = l1(zebra, zebra_identity)
                horse_identity_loss = l1(horse, horse_identity)
                gen_loss = (
                    gen_loss
                    + (horse_identity_loss * LAMBDA_IDENTITY)
                    + (zebra_identity_loss * LAMBDA_IDENTITY)
                )

        gen_opt.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(gen_opt)
        g_scaler.update()

        if idx % 200 == 0:
            # `* 0.5 + 0.5` maps the tanh range [-1, 1] back to [0, 1] for saving.
            save_image(fake_horse * 0.5 + 0.5, f"horse_{idx}_{save_count}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"zebra_{idx}_{save_count}.png")
            save_count = save_count + 1

        loop.set_postfix(real_horse=real_horses / (idx + 1), fake_horse=fake_horses / (idx + 1))

In [None]:
# Networks: one discriminator and one generator per domain.
h_disc = Discriminator(in_channels=3).to(DEVICE)
z_disc = Discriminator(in_channels=3).to(DEVICE)
h_gen = Generator(img_channels=3, num_residuals=9).to(DEVICE)
z_gen = Generator(img_channels=3, num_residuals=9).to(DEVICE)

# A single optimizer drives both discriminators, another both generators.
opt_disc = optim.Adam(
    list(h_disc.parameters()) + list(z_disc.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

opt_gen = optim.Adam(
    list(z_gen.parameters()) + list(h_gen.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)


L1 = nn.L1Loss()
mse = nn.MSELoss()

# Optionally resume every network (and its optimizer state) from disk.
if LOAD_MODEL:
    load_checkpoint(
        HORSE_GEN_CHECKPOINT,
        h_gen,
        opt_gen,
        LEARNING_RATE,
    )
    load_checkpoint(
        ZEBRA_GEN_CHECKPOINT,
        z_gen,
        opt_gen,
        LEARNING_RATE,
    )
    load_checkpoint(
        HORSE_DISC_CHECKPOINT,
        h_disc,
        opt_disc,
        LEARNING_RATE,
    )
    load_checkpoint(
        ZEBRA_DISC_CHECKPOINT,
        z_disc,
        opt_disc,
        LEARNING_RATE,
    )


dataset = HorseZebraDataset(
    horse="/kaggle/input/horse2zebra-dataset/trainA",
    zebra="/kaggle/input/horse2zebra-dataset/trainB",
    transform=transforms,
)
val_dataset = HorseZebraDataset(
    horse="/kaggle/input/horse2zebra-dataset/testA",
    zebra="/kaggle/input/horse2zebra-dataset/testB",
    transform=transforms,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Separate gradient scalers for the generator and discriminator updates.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    # Argument order follows train_fn's signature (h_gen before z_gen). Passing
    # them swapped trains each generator in the other's role, so the
    # checkpoints below would be saved under the wrong names.
    train_fn(h_disc, z_disc, h_gen, z_gen, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler,)

    if SAVE_MODEL:
        save_checkpoint(h_gen, opt_gen, filename=HORSE_GEN_CHECKPOINT)
        save_checkpoint(z_gen, opt_gen, filename=ZEBRA_GEN_CHECKPOINT)
        save_checkpoint(h_disc, opt_disc, filename=HORSE_DISC_CHECKPOINT)
        # The zebra discriminator file must hold z_disc, not a second copy of h_disc.
        save_checkpoint(z_disc, opt_disc, filename=ZEBRA_DISC_CHECKPOINT)

 65%|██████▍   | 865/1334 [4:56:53<2:40:55, 20.59s/it, fake_horse=0.43, real_horse=0.562] 