In [None]:
#Library Import
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import albumentations as A
import torch.optim as optim
from albumentations.pytorch import ToTensorV2

from PIL import Image
import os
from tqdm import tqdm
import numpy as np
import random, torch, os, numpy as np
import copy
import sys

In [None]:
# Use the GPU when one is visible; otherwise fall back to CPU so checkpoints can
# still be loaded (map_location) and the code run on a CPU-only kernel, slowly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# directory paths: both point at the same dataset root; the Train*/Test*
# subfolders underneath separate the training and test splits.
training_directory = "/content/drive/MyDrive/GANDataset/"
validation_directory = "/content/drive/MyDrive/GANDataset/"

In [None]:
# ---------------- hyper parameters ----------------

# optimisation
learning_rate = 1e-5      # Adam step size; also forced onto every param group after a checkpoint load
batch_size = 1            # images per domain per training step

# loss weights
lambda_cycle = 10         # weight of the two L1 cycle-consistency terms
lambda_identity = 0.0     # weight of the two L1 identity terms (0.0 = no contribution)

# run control
epochs = 20
num_of_workers = 4        # DataLoader worker processes for the training set
load_model = True         # resume from the CHECKPOINT_* files before training
save_model = True         # rewrite the CHECKPOINT_* files after every epoch

In [None]:
# Checkpoint files: read before training when load_model is True and
# rewritten after every epoch when save_model is True.
# Suffix "h" = art domain, "z" = architecture domain (horse2zebra naming).

# art domain: the generator that produces art, and the critic that judges art
CHECKPOINT_GEN_ART = "/content/drive/MyDrive/GAN_Weights/genh.pth.tar"
CHECKPOINT_CRITIC_ART = "/content/drive/MyDrive/GAN_Weights/critich.pth.tar"

# architecture domain: the generator that produces architecture, and its critic
CHECKPOINT_GEN_ARCHITECTURE = "/content/drive/MyDrive/GAN_Weights/genz.pth.tar"
CHECKPOINT_CRITIC_ARCHITECTURE = "/content/drive/MyDrive/GAN_Weights/criticz.pth.tar"


In [None]:
# Data augmentation used by both the training and the test datasets.
# Pipeline: resize to 256x256 -> random horizontal flip (p=0.5) ->
# (x / 255 - 0.5) / 0.5, i.e. pixels scaled to [-1, 1] -> CHW tensor.
# "image0" is registered as a second image target so a single call transforms
# the architecture image (passed as "image") and the art image ("image0").
# NOTE(review): albumentations samples one set of random parameters per call,
# so both images presumably share the flip decision; newer albumentations
# versions also require all image targets to have equal height/width and
# raise otherwise - confirm against the installed version.
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)


In [None]:
class ArtAndArchitecureDataset(Dataset):
    """Unpaired two-domain dataset yielding ``(architecture_image, art_image)``.

    Images are listed from two flat directories. The dataset length is the size
    of the larger domain, and each domain's index wraps around modulo its own
    length, so every image of the larger domain is visited once per epoch.

    Args:
        root_architecture: directory holding the architecture images.
        root_arts: directory holding the art images.
        transform: optional callable invoked as
            ``transform(image=<architecture ndarray>, image0=<art ndarray>)``
            that returns a dict with "image" and "image0" entries.

    Raises:
        ValueError: if either directory has no (non-hidden) entries.
    """

    def __init__(self, root_architecture, root_arts, transform=None):
        self.root_architecture = root_architecture
        self.root_arts = root_arts
        self.transform = transform
        # Sorted, so the index -> file mapping (and therefore which architecture
        # image is paired with which art image) is reproducible; os.listdir
        # order is arbitrary. Hidden entries (e.g. ".ipynb_checkpoints") are
        # skipped because PIL cannot open them.
        self.architecture_images = self._list_images(root_architecture)
        self.art_images = self._list_images(root_arts)
        self.architecture_len = len(self.architecture_images)
        self.art_len = len(self.art_images)
        # Fail early: an empty domain would otherwise only show up as a
        # ZeroDivisionError from the modulo in __getitem__.
        if self.architecture_len == 0 or self.art_len == 0:
            raise ValueError(
                f"Both image directories must be non-empty: {root_architecture!r} has "
                f"{self.architecture_len} entries, {root_arts!r} has {self.art_len}."
            )
        self.length_dataset = max(self.architecture_len, self.art_len)

    @staticmethod
    def _list_images(root):
        """Return the sorted names of the non-hidden entries of ``root``."""
        return sorted(name for name in os.listdir(root) if not name.startswith("."))

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB ndarray, closing the underlying file handle."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        """Return the (architecture, art) pair for ``index`` after the optional transform."""
        archi_name = self.architecture_images[index % self.architecture_len]
        art_name = self.art_images[index % self.art_len]

        archi_image = self._load_rgb(os.path.join(self.root_architecture, archi_name))
        art_image = self._load_rgb(os.path.join(self.root_arts, art_name))

        if self.transform:
            augmented = self.transform(image=archi_image, image0=art_image)
            archi_image = augmented["image"]
            art_image = augmented["image0"]

        return archi_image, art_image

In [None]:
class Block(nn.Module):
    """Critic stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride; with kernel 4 / padding 1, stride 2 halves an
            even spatial size and stride 1 shrinks it by one pixel.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        # Attribute name and layer order define the state_dict keys stored in
        # the critic checkpoints; keep them stable.
        self.conv = nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Discriminator(nn.Module):
    """Patch-level critic: maps an image batch to a map of scores in (0, 1).

    For a (N, 3, 256, 256) input with the default widths the output is
    (N, 1, 30, 30).

    Args:
        in_channels: channels of the input images.
        features: channel widths. ``features[0]`` is used by the un-normalised
            stride-2 stem; each following width gets a ``Block``. Every Block is
            stride 2 except the last one, which is stride 1.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Accept any sequence/iterable (the default is an immutable tuple so the
        # shared default value can never be mutated between calls).
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_position = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Choose the stride by position, not by value: comparing
            # feature == features[-1] also made any earlier stage whose width
            # equals the last one stride 1.
            stride = 1 if position == last_position else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        # NOTE(review): scores are squashed by a sigmoid and train_fn regresses
        # them onto 0/1 targets with MSE; kept as-is for checkpoint compatibility.
        return torch.sigmoid(self.model(x))

def test():
    """Smoke test: print the Discriminator output shape for a random 5x3x256x256 batch.

    NOTE: a later cell defines another ``test`` (for the Generator), which
    replaces this one in the notebook namespace.
    """
    batch = torch.randn((5, 3, 256, 256))
    critic = Discriminator(in_channels=3)
    print(critic(batch).shape)


if __name__ == "__main__":
    test()

torch.Size([5, 1, 30, 30])


In [None]:
class ConvBlock(nn.Module):
    """Conv (or transposed conv) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        down: True builds a reflect-padded ``nn.Conv2d``; False builds an
            ``nn.ConvTranspose2d``.
        use_act: True appends ReLU, False appends ``nn.Identity``.
        **kwargs: forwarded to the convolution (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.ReLU(inplace=True) if use_act else nn.Identity()
        # `self.conv` and the [conv, norm, act] order fix the state_dict keys.
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks (the second without activation) added onto the input."""

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: stem -> 2 down-samples -> residual stack -> 2 up-samples -> tanh.

    Output values lie in [-1, 1]; for a spatial size divisible by 4 the output
    has the same shape as the input.

    Args:
        img_channels: channels of the input and output images.
        num_features: width of the stem; down-sampling doubles it twice.
        num_residuals: number of ResidualBlocks at the bottleneck.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        width = num_features
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, width, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        )
        # width -> 2*width -> 4*width channels, spatial size halved each step
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(width * mult, width * mult * 2, kernel_size=3, stride=2, padding=1)
                for mult in (1, 2)
            ]
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(width * 4) for _ in range(num_residuals)]
        )
        # 4*width -> 2*width -> width channels, spatial size doubled each step
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(
                    width * mult,
                    width * mult // 2,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
                for mult in (4, 2)
            ]
        )
        self.last = nn.Conv2d(width, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        h = self.initial(x)
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            h = stage(h)
        return torch.tanh(self.last(h))

def test():
    """Smoke test: a random 2x3x256x256 batch must come out of the Generator unchanged in shape."""
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    # Generator's second positional parameter is num_features, so the residual
    # count must be passed by keyword; `Generator(img_channels, 9)` built a
    # 9-feature-wide model instead of the 64-feature one that is trained.
    gen = Generator(img_channels, num_residuals=9)
    print(gen(x).shape)

if __name__ == "__main__":
    test()

torch.Size([2, 3, 256, 256])


In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write ``model`` and ``optimizer`` state to ``filename`` with ``torch.save``.

    The file holds a dict with two entries: "state_dict" (the model's
    state_dict) and "optimizer" (the optimizer's state_dict). An existing file
    at ``filename`` is overwritten.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model/optimizer state from ``checkpoint_file`` and then force the learning rate.

    Args:
        checkpoint_file: path to a file with "state_dict" and "optimizer" entries
            (the layout written by ``save_checkpoint``).
        model: module that receives ``checkpoint["state_dict"]``.
        optimizer: optimizer that receives ``checkpoint["optimizer"]``.
        lr: learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    # Tensors are mapped onto the module-level `device`.
    saved = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # optimizer.load_state_dict restores the lr stored in the checkpoint;
    # overwrite it so the current `lr` setting is the one actually used.
    for group in optimizer.param_groups:
        group.update(lr=lr)


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic.

    NOTE: PYTHONHASHSEED only affects Python processes started after this call;
    the running interpreter's hash seed is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Deterministic kernels and no cuDNN autotuning (reproducible, possibly slower).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
import torch

import sys

from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from torchvision.utils import save_image

# Naming inherited from the horse2zebra CycleGAN code:
# "H" (horse) -> art domain, "Z" (zebra) -> architecture domain.

def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one CycleGAN training epoch over ``loader``.

    Domains: H = art, Z = architecture. ``gen_H`` maps architecture -> art,
    ``gen_Z`` maps art -> architecture; ``disc_H`` scores art images and
    ``disc_Z`` scores architecture images.

    Per batch:
      1. one critic update: MSE against 1 for real and 0 for (detached) fake
         images, averaged over both critics;
      2. one generator update: adversarial MSE (fakes -> 1)
         + ``lambda_cycle`` * L1 cycle losses
         + ``lambda_identity`` * L1 identity losses (only computed when the
         weight is non-zero).
    Both steps run under CUDA autocast with their own GradScaler.

    Every 200 batches the current fakes are saved, rescaled from [-1, 1] to
    [0, 1], to /content/drive/MyDrive/GANSaveImages/. The progress bar shows
    the running mean of ``disc_H``'s score on real (H_real) and fake (H_fake) art.

    Args:
        disc_H, disc_Z: critics for the art / architecture domains.
        gen_Z, gen_H: generators producing architecture / art images.
        loader: yields ``(architecture, art)`` batches.
        opt_disc, opt_gen: optimizers over both critics / both generators.
        l1, mse: loss modules (cycle + identity, adversarial).
        d_scaler, g_scaler: GradScalers for the critic / generator steps.

    Reads the module-level ``device``, ``lambda_cycle`` and ``lambda_identity``.
    """
    art_real_score_sum = 0.0
    art_fake_score_sum = 0.0
    loop = tqdm(loader, leave=True)

    for idx, (architecture, art) in enumerate(loop):
        architecture = architecture.to(device)
        art = art.to(device)

        # ---- Train critics H (art) and Z (architecture) ----
        with torch.cuda.amp.autocast():
            fake_art = gen_H(architecture)
            D_H_real = disc_H(art)
            # detach: the critic step must not send gradients into gen_H
            D_H_fake = disc_H(fake_art.detach())
            art_real_score_sum += D_H_real.mean().item()
            art_fake_score_sum += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_architecture = gen_Z(art)
            D_Z_real = disc_Z(architecture)
            D_Z_fake = disc_Z(fake_architecture.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it together
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Train generators H and Z ----
        with torch.cuda.amp.autocast():
            # adversarial loss: generators want the critics to score fakes as 1
            D_H_fake = disc_H(fake_art)
            D_Z_fake = disc_Z(fake_architecture)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss: architecture -> art -> architecture, art -> architecture -> art
            cycle_architecture = gen_Z(fake_art)
            cycle_art = gen_H(fake_architecture)
            cycle_architecture_loss = l1(architecture, cycle_architecture)
            cycle_art_loss = l1(art, cycle_art)

            # add all together
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_architecture_loss * lambda_cycle
                + cycle_art_loss * lambda_cycle
            )

            # identity loss: with lambda_identity == 0 these terms add exactly
            # zero loss and zero gradient, so skip the two extra generator
            # forward/backward passes instead of computing and discarding them.
            if lambda_identity != 0:
                identity_architecture = gen_Z(architecture)
                identity_art = gen_H(art)
                identity_architecture_loss = l1(architecture, identity_architecture)
                identity_art_loss = l1(art, identity_art)
                G_loss = (
                    G_loss
                    + identity_art_loss * lambda_identity
                    + identity_architecture_loss * lambda_identity
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # *0.5+0.5 undoes the [-1, 1] normalisation for image export
            save_image(fake_art*0.5+0.5, f"/content/drive/MyDrive/GANSaveImages/art_{idx}.png")
            save_image(fake_architecture*0.5+0.5, f"/content/drive/MyDrive/GANSaveImages/architecture_{idx}.png")

        loop.set_postfix(H_real=art_real_score_sum/(idx+1), H_fake=art_fake_score_sum/(idx+1))



def main():
    """Build both generator/critic pairs, optionally resume from checkpoints, and train.

    Reads the module-level configuration: device, learning_rate, batch_size,
    num_of_workers, epochs, load_model, save_model, the CHECKPOINT_* paths,
    training_directory, validation_directory and transforms.

    Expected folder layout:
        <training_directory>/TrainA   architecture images (training)
        <training_directory>/TrainB   art images (training)
        <validation_directory>/TestA  architecture images (test)
        <validation_directory>/TestB  art images (test)
    """
    disc_H = Discriminator(in_channels=3).to(device)               # critic for art
    disc_Z = Discriminator(in_channels=3).to(device)               # critic for architecture
    gen_Z = Generator(img_channels=3, num_residuals=9).to(device)  # art -> architecture
    gen_H = Generator(img_channels=3, num_residuals=9).to(device)  # architecture -> art
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=learning_rate,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=learning_rate,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if load_model:
        # The two generators share opt_gen (and the two critics share opt_disc),
        # so the same optimizer state is restored from two files each.
        load_checkpoint(
            CHECKPOINT_GEN_ART, gen_H, opt_gen, learning_rate,
        )
        load_checkpoint(
            CHECKPOINT_GEN_ARCHITECTURE, gen_Z, opt_gen, learning_rate,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_ART, disc_H, opt_disc, learning_rate,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_ARCHITECTURE, disc_Z, opt_disc, learning_rate,
        )

    dataset = ArtAndArchitecureDataset(
        root_architecture=os.path.join(training_directory, "TrainA"),
        root_arts=os.path.join(training_directory, "TrainB"),
        transform=transforms,
    )
    # The test split is read from validation_directory.
    val_dataset = ArtAndArchitecureDataset(
        root_architecture=os.path.join(validation_directory, "TestA"),
        root_arts=os.path.join(validation_directory, "TestB"),
        transform=transforms,
    )

    # NOTE(review): val_loader is built but not consumed by the training loop yet.
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_of_workers,
        pin_memory=True
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(epochs):
        train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)

        if save_model:
            # Overwrites the same four files every epoch.
            save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_ART)
            save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_ARCHITECTURE)
            save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_ART)
            save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_ARCHITECTURE)

if __name__ == "__main__":
    main()

=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint


  cpuset_checked))
100%|██████████| 6738/6738 [27:45<00:00,  4.05it/s, H_fake=0.376, H_real=0.617]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 6738/6738 [27:32<00:00,  4.08it/s, H_fake=0.374, H_real=0.62]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 6738/6738 [27:33<00:00,  4.07it/s, H_fake=0.373, H_real=0.62]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 6738/6738 [27:31<00:00,  4.08it/s, H_fake=0.372, H_real=0.623]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 6738/6738 [27:32<00:00,  4.08it/s, H_fake=0.37, H_real=0.624]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 6738/6738 [27:31<00:00,  4.08it/s, H_fake=0.368, H_real=0.627]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 6738/6738 [27:27<00:00,  4.09it/s, H_fake=0.366, H_real=0.63]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


 59%|█████▊    | 3952/6738 [16:04<11:15,  4.13it/s, H_fake=0.365, H_real=0.631]