In [None]:
# Code inspired by https://www.youtube.com/watch?v=4LktBHGCNfw

In [None]:
# Mount Google Drive inside the Colab VM; the dataset folders and the
# checkpoint directory used further down live under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


In [None]:
import os
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import requests
from io import BytesIO
from torch.nn import init
import torch.optim as optim
import tensorflow_datasets as tfds
import albumentations as A
from tqdm import tqdm
import cv2

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Data Initialization

In [None]:
def preprocess(image, target_size=(256, 256)):
    """Turn ``image`` into a NumPy array and resize it with OpenCV.

    Args:
        image: Anything ``np.array`` can convert (the dataset passes a PIL
            image that was converted to RGB).
        target_size: Forwarded to ``cv2.resize`` as ``dsize``.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    pixels = np.array(image)
    resized = cv2.resize(src=pixels, dsize=target_size)
    return resized

In [None]:
class PaintingPictureDataset(Dataset):
    """Unpaired dataset that yields one painting and one picture per index.

    The two folders may contain a different number of files. The dataset
    length is the larger of the two counts and the shorter list is cycled
    with a modulo, so every file of the larger folder is visited once per
    epoch.

    Args:
        root_painting: Directory containing the painting images.
        root_picture: Directory containing the picture (photo) images.
        transform: Optional callable invoked as
            ``transform(image=painting, image0=picture)`` that returns a
            mapping with the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: If either directory contains no usable files.
    """

    def __init__(self, root_painting, root_picture, transform=None):
        self.root_painting = root_painting
        self.root_picture = root_picture
        self.transform = transform

        self.painting_images = self._list_files(root_painting)
        self.picture_images = self._list_files(root_picture)
        self.painting_len = len(self.painting_images)
        self.picture_len = len(self.picture_images)

        # Fail early: an empty folder would otherwise surface as a
        # ZeroDivisionError from the modulo in __getitem__.
        if self.painting_len == 0 or self.picture_len == 0:
            raise ValueError(
                "Both image directories must contain at least one file "
                f"(found {self.painting_len} in {root_painting!r} and "
                f"{self.picture_len} in {root_picture!r})"
            )

        self.length_dataset = max(self.painting_len, self.picture_len)

    @staticmethod
    def _list_files(root):
        """Return the sorted names of the regular, non-hidden files in ``root``."""
        # Sub-directories and dotfiles (e.g. ``.ipynb_checkpoints``) are skipped
        # because they cannot be opened as images; sorting makes the
        # index -> file mapping reproducible across runs.
        with os.scandir(root) as entries:
            return sorted(
                entry.name
                for entry in entries
                if entry.is_file() and not entry.name.startswith(".")
            )

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, convert it to RGB and run it through ``preprocess``."""
        # The context manager closes the underlying file handle immediately
        # instead of leaving it to garbage collection.
        with Image.open(path) as img:
            return preprocess(img.convert("RGB"))

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        # Wrap around the shorter list so both domains always yield a sample.
        painting_name = self.painting_images[index % self.painting_len]
        picture_name = self.picture_images[index % self.picture_len]

        painting_img = self._load_rgb(os.path.join(self.root_painting, painting_name))
        picture_img = self._load_rgb(os.path.join(self.root_picture, picture_name))

        if self.transform is not None:
            augmentations = self.transform(image=painting_img, image0=picture_img)
            painting_img = augmentations["image"]
            picture_img = augmentations["image0"]

        return painting_img, picture_img

In [None]:
# Augmentation pipeline shared by both domains: "image" is the painting and
# the additional target "image0" is the picture, so the same pipeline is
# applied to both arrays in a single call.
# Normalising with mean 0.5 / std 0.5 presumably maps pixels to [-1, 1],
# matching the `* 0.5 + 0.5` de-normalisation used when saving samples.
# NOTE(review): this name shadows the `torchvision.transforms` module imported
# at the top of the notebook; renaming it also requires updating the cell that
# builds the dataset.
transforms = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    is_check_shapes=False,
    additional_targets={"image0": "image"},
)

In [None]:
# Training data: Munch paintings and photos, both read from Google Drive.
dataset = PaintingPictureDataset(
    root_painting="/content/drive/MyDrive/munch_paintings",
    root_picture="/content/drive/MyDrive/photo_jpg",
    transform=transforms,
)
# Batches of a single (painting, picture) pair, reshuffled every epoch.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Validation split is not wired up yet; template kept for later use:
# val_dataset = PaintingPictureDataset(root_painting="cyclegan_test/horse1", root_picture="cyclegan_test/zebra1", transform=transforms)
# val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)

Define Discriminator

In [None]:
class Block(nn.Module):
    """Discriminator stage: 4x4 convolution -> instance norm -> LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        # Layer order (and therefore the state-dict keys conv.0 / conv.1 /
        # conv.2) must stay as is so saved checkpoints keep loading.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the convolution, normalisation and activation to ``x``."""
        out = self.conv(x)
        return out

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a grid of sigmoid scores.

    The network is an initial strided convolution (no normalisation), one
    ``Block`` per remaining entry of ``features`` and a final convolution
    down to a single channel whose output is passed through a sigmoid.

    Args:
        in_channels: Number of channels of the input image.
        features: Channel widths of the successive stages. Every ``Block``
            uses stride 2 except the one built for the last entry, which
            uses stride 1.

    Raises:
        ValueError: If ``features`` is empty.
    """

    # The default is a tuple so the default argument is not a shared mutable list.
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 stage by position, not by value: comparing
            # against features[-1] would also give stride 1 to any earlier
            # stage that happens to have the same width as the last one.
            stride = 1 if index == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-activated score map for the image batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))


In [None]:
# One discriminator per image domain, both moved to the selected device.
disc_Picture, disc_Painting = (Discriminator().to(device) for _ in range(2))

Define Generator

In [None]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> instance norm -> optional ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        down: When truthy use ``nn.Conv2d`` with reflect padding, otherwise
            use ``nn.ConvTranspose2d``.
        use_act: When truthy finish with ``nn.ReLU``; otherwise with
            ``nn.Identity``.
        **kwargs: Forwarded unchanged to the chosen convolution layer
            (``kernel_size``, ``stride``, ``padding``, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            convolution = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            convolution = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        # Keep the three-slot layout so the state-dict keys stay conv.0.*.
        self.conv = nn.Sequential(
            convolution,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Run ``x`` through the convolution, normalisation and activation."""
        out = self.conv(x)
        return out

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers wrapped in a skip connection.

    The first ``ConvBlock`` keeps its activation, the second one is built with
    ``use_act=False``; the block's input is added to the result.

    Args:
        channels: Channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked convolution blocks."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator with tanh output.

    Layout: a 7x7 stem convolution, two stride-2 ``ConvBlock`` layers that
    widen the channels to ``4 * num_features``, ``num_residuals`` residual
    blocks, two transposed ``ConvBlock`` layers that narrow the channels back
    to ``num_features`` and a final 7x7 convolution to ``img_channels``.

    Args:
        img_channels: Channel count of the input and output images.
        num_features: Channel width after the stem convolution.
        num_residuals: Number of ``ResidualBlock`` modules in the middle.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()

        # Stem: 7x7 convolution followed by instance norm and ReLU.
        stem_conv = nn.Conv2d(
            img_channels,
            num_features,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )
        self.initial = nn.Sequential(
            stem_conv,
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )

        # Encoder: each stage doubles the channel count with a stride-2 conv.
        down_channels = [
            (num_features, num_features * 2),
            (num_features * 2, num_features * 4),
        ]
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, kernel_size=3, stride=2, padding=1)
                for c_in, c_out in down_channels
            ]
        )

        # Bottleneck: a stack of residual blocks at the widest channel count.
        bottleneck = [ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        self.res_blocks = nn.Sequential(*bottleneck)

        # Decoder: mirror of the encoder built from transposed convolutions.
        up_channels = [
            (num_features * 4, num_features * 2),
            (num_features * 2, num_features * 1),
        ]
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(
                    c_in,
                    c_out,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
                for c_in, c_out in up_channels
            ]
        )

        # Head: project back to the image channel count.
        self.last = nn.Conv2d(
            num_features * 1,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        """Translate the image batch ``x``; the result is passed through tanh."""
        out = self.initial(x)
        for down in self.down_blocks:
            out = down(out)
        out = self.res_blocks(out)
        for up in self.up_blocks:
            out = up(out)
        out = self.last(out)
        return torch.tanh(out)

In [None]:
# Two generators, one per translation direction, both on the selected device.
gen_Painting, gen_Picture = (
    Generator(img_channels=3, num_residuals=9).to(device) for _ in range(2)
)

Training Loop

In [None]:
# Directory on the mounted Google Drive where train_fn writes the generator
# checkpoints (gen_Painting.pt / gen_Picture.pt).
save_path = "/content/drive/MyDrive/saved_models"

In [None]:
from torchvision.utils import save_image
def train_fn(
    disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler,
    lambda_cycle=10, lambda_identity=0,
):
    """Run one epoch of adversarial + cycle-consistency training.

    In this function the ``H`` networks work on the picture domain and the
    ``Z`` networks on the painting domain: ``gen_H`` maps a painting to a fake
    picture, ``gen_Z`` maps a picture to a fake painting, ``disc_H`` scores
    pictures and ``disc_Z`` scores paintings.

    Args:
        disc_H: Discriminator applied to real and generated pictures.
        disc_Z: Discriminator applied to real and generated paintings.
        gen_Z: Generator producing paintings from pictures.
        gen_H: Generator producing pictures from paintings.
        loader: Iterable yielding ``(painting, picture)`` batches.
        opt_disc: Optimizer stepping both discriminators.
        opt_gen: Optimizer stepping both generators.
        l1: Loss used for the cycle and identity terms.
        mse: Loss used for the adversarial terms.
        d_scaler: Gradient scaler used for the discriminator step.
        g_scaler: Gradient scaler used for the generator step.
        lambda_cycle: Weight of each cycle-consistency term.
        lambda_identity: Weight of each identity term. When it is 0 (the
            default) the identity forward passes are skipped entirely.

    Side effects:
        Writes sample images to ``/content/saved_images`` every 200 batches
        and saves the state dicts of ``gen_Z`` / ``gen_H`` to ``save_path``
        every 10 batches.
    """
    image_dir = "/content/saved_images"
    # Make sure both output directories exist before the first write;
    # otherwise the very first batch (idx == 0) fails when saving.
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (painting, picture) in enumerate(loop):
        painting = painting.to(device)
        picture = picture.to(device)

        # Train Discriminators H and Z
        # torch.autocast(device_type="cuda") replaces the deprecated
        # torch.cuda.amp.autocast() spelling.
        with torch.autocast(device_type="cuda"):
            fake_picture = gen_H(painting)
            D_H_real = disc_H(picture)
            # detach(): the discriminator update must not reach the generator.
            D_H_fake = disc_H(fake_picture.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_painting = gen_Z(picture)
            D_Z_real = disc_Z(painting)
            D_Z_fake = disc_Z(fake_painting.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it together
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.autocast(device_type="cuda"):
            # adversarial loss for both generators
            D_H_fake = disc_H(fake_picture)
            D_Z_fake = disc_Z(fake_painting)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss
            cycle_painting = gen_Z(fake_picture)
            cycle_picture = gen_H(fake_painting)
            cycle_painting_loss = l1(painting, cycle_painting)
            cycle_picture_loss = l1(picture, cycle_picture)

            # add all together
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_painting_loss * lambda_cycle
                + cycle_picture_loss * lambda_cycle
            )

            # identity loss: only pay for the two extra generator passes when
            # the term actually contributes to the objective.
            if lambda_identity:
                identity_painting = gen_Z(painting)
                identity_picture = gen_H(picture)
                identity_painting_loss = l1(painting, identity_painting)
                identity_picture_loss = l1(picture, identity_picture)
                G_loss = (
                    G_loss
                    + identity_picture_loss * lambda_identity
                    + identity_painting_loss * lambda_identity
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # Undo the mean 0.5 / std 0.5 normalisation before writing to disk.
            save_image(fake_picture * 0.5 + 0.5, f"{image_dir}/picture_{idx}.png")
            save_image(fake_painting * 0.5 + 0.5, f"{image_dir}/painting_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
        if idx % 10 == 0:
            # Checkpoint the generators that were passed in rather than the
            # notebook-level globals, so the function saves what it trains.
            torch.save(gen_Z.state_dict(), f"{save_path}/gen_Painting.pt")
            torch.save(gen_H.state_dict(), f"{save_path}/gen_Picture.pt")

In [None]:
# A single Adam optimizer updates both discriminators...
opt_disc = optim.Adam(
    [*disc_Picture.parameters(), *disc_Painting.parameters()],
    lr=1e-5,
    betas=(0.5, 0.999),
)

# ...and a second one updates both generators.
opt_gen = optim.Adam(
    [*gen_Painting.parameters(), *gen_Picture.parameters()],
    lr=1e-5,
    betas=(0.5, 0.999),
)

# L1 drives the cycle/identity terms, MSE the adversarial terms.
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Separate gradient scalers for the generator and discriminator steps.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

# Train for 10 epochs; train_fn writes samples and checkpoints as it goes.
for epoch in range(10):
    train_fn(
        disc_H=disc_Picture,
        disc_Z=disc_Painting,
        gen_Z=gen_Painting,
        gen_H=gen_Picture,
        loader=loader,
        opt_disc=opt_disc,
        opt_gen=opt_gen,
        l1=L1,
        mse=mse,
        d_scaler=d_scaler,
        g_scaler=g_scaler,
    )

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
 22%|██▏       | 1528/7038 [06:58<12:37,  7.27it/s, H_fake=0.434, H_real=0.56]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7eaa983f6c20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7eaa983f6c20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/u