In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader


class ConvBlock(nn.Module):
    """Two 3x3 convolution stages, each followed by batch norm and LeakyReLU.

    Dropout sits between the two stages. Padding of 1 keeps the spatial size
    unchanged.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        dropout_p: dropout probability applied after the first stage.
    """

    def __init__(self, in_channels, out_channels, dropout_p):
        super().__init__()
        # The attribute name and layer order define the state_dict keys
        # (conv_conv.0 ... conv_conv.6), so saved checkpoints depend on them.
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        self.conv_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both convolution stages on ``x``."""
        out = self.conv_conv(x)
        return out


class DownBlock(nn.Module):
    """2x2 max pooling (halving height and width, flooring odd sizes) then a ConvBlock.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the ConvBlock.
        dropout_p: dropout probability passed to the ConvBlock.
    """

    def __init__(self, in_channels, out_channels, dropout_p):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = ConvBlock(in_channels, out_channels, dropout_p)
        # Attribute name and Sequential order define the state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run the convolutions on the result."""
        return self.maxpool_conv(x)


class UpBlock(nn.Module):
    """Upsample coarse features, concatenate them with a skip connection, then ConvBlock.

    The coarse input ``x1`` is projected to ``in_channels2`` channels with a
    1x1 convolution, bilinearly upsampled to the skip tensor's spatial size,
    concatenated with the skip input ``x2`` along the channel axis, and passed
    through a :class:`ConvBlock`.

    Args:
        in_channels1: channels of the coarse input ``x1``.
        in_channels2: channels of the skip input ``x2``; ``x1`` is projected
            to this many channels before concatenation.
        out_channels: channels produced by the block.
        dropout_p: dropout probability inside the ConvBlock.
    """

    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p):
        super(UpBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        """Fuse coarse features ``x1`` with skip features ``x2``.

        Both are 4-D (N, C, H, W) tensors; the result has ``x2``'s height
        and width.
        """
        x1 = self.conv1x1(x1)
        target = tuple(x2.shape[-2:])
        if target == tuple(2 * d for d in x1.shape[-2:]):
            x1 = self.up(x1)
        else:
            # An odd-sized encoder level was floored by MaxPool2d, so plain 2x
            # upsampling would not match the skip tensor and torch.cat would
            # raise. Resample straight to the skip's size instead.
            x1 = nn.functional.interpolate(
                x1, size=target, mode='bilinear', align_corners=True)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class Encoder(nn.Module):
    """Five-level U-Net encoder: an input ConvBlock followed by four DownBlocks.

    ``params`` is read for:
        ``'in_chns'``: input channels.
        ``'feature_chns'``: exactly five channel widths, one per level.
        ``'dropout'``: indexed 0..4, one dropout probability per level.

    ``forward(x)`` returns ``(deepest, features)``. ``features`` lists the
    output of every level from shallowest to deepest, and ``deepest`` is its
    last element.
    """

    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = params['in_chns']
        self.ft_chns = params['feature_chns']
        self.dropout = params['dropout']
        assert (len(self.ft_chns) == 5)
        widths, drops = self.ft_chns, self.dropout
        # Keep these attribute names and this registration order: they fix the
        # state_dict keys and the RNG draw order used for weight initialisation.
        self.in_conv = ConvBlock(self.in_chns, widths[0], drops[0])
        self.down1 = DownBlock(widths[0], widths[1], drops[1])
        self.down2 = DownBlock(widths[1], widths[2], drops[2])
        self.down3 = DownBlock(widths[2], widths[3], drops[3])
        self.down4 = DownBlock(widths[3], widths[4], drops[4])

    def forward(self, x):
        levels = [self.in_conv(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            levels.append(down(levels[-1]))
        return levels[-1], levels

class Decoder(nn.Module):
    """U-Net decoder: four UpBlocks followed by a 3x3 output convolution.

    ``params`` is read for:
        ``'in_chns'``: channels of the final output.
        ``'feature_chns'``: exactly five widths, matching the encoder levels.

    Dropout is disabled in every UpBlock.

    ``forward(feature)`` takes the per-level feature list (shallowest first,
    at least five entries). It returns ``(output, x_last)``: the output-conv
    result and the last decoder feature map it was computed from.
    """

    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = params['in_chns']
        self.ft_chns = params['feature_chns']
        assert (len(self.ft_chns) == 5)

        w = self.ft_chns
        self.up1 = UpBlock(w[4], w[3], w[3], dropout_p=0.0)
        self.up2 = UpBlock(w[3], w[2], w[2], dropout_p=0.0)
        self.up3 = UpBlock(w[2], w[1], w[1], dropout_p=0.0)
        self.up4 = UpBlock(w[1], w[0], w[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(w[0], self.in_chns, kernel_size=3, padding=1)

    def forward(self, feature):
        x = feature[4]
        # Walk back up the pyramid, fusing each map with the skip connection
        # one level shallower.
        stages = ((self.up1, feature[3]), (self.up2, feature[2]),
                  (self.up3, feature[1]), (self.up4, feature[0]))
        for up, skip in stages:
            x = up(x, skip)
        return self.out_conv(x), x
    

class Generator(nn.Module):
    """U-Net image-to-image generator whose output has ``in_channels`` channels.

    The encoder/decoder output is passed through a sigmoid, so every output
    value lies in (0, 1).

    Args:
        in_channels: number of channels of both the input and the output image.
    """

    def __init__(self, in_channels):
        super(Generator, self).__init__()

        params = {'in_chns': in_channels,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  # Not read by Encoder/Decoder (their blocks use LeakyReLU);
                  # kept so the shared config dict is unchanged.
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)

    def forward(self, x):
        # Only the per-level skip features are needed by the decoder, and only
        # its output image (not its last feature map) is returned.
        _, feature = self.encoder(x)
        output, _ = self.decoder(feature)
        return torch.sigmoid(output)
    

class Block(nn.Module):
    """Convolution (reflect padding, with bias) -> InstanceNorm2d -> LeakyReLU(0.2).

    With the defaults (kernel 4, stride 2, padding 1) height and width are
    halved (floored).
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super(Block, self).__init__()
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            padding_mode='reflect', bias=True,
        )
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.LeakyReLU(0.2)
        # `conv` and its (conv, norm, act) order define the state_dict keys.
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.conv(x)


class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a map of scores in (0, 1).

    Layout:
        1. A 4x4 stride-2 convolution with LeakyReLU(0.2).
        2. One ``Block`` per entry of ``features[1:]``. Each uses stride 2,
           except the last, which uses stride 1.
        3. A final 4x4 stride-1 convolution to one channel, followed by a
           sigmoid.

    Args:
        in_channels: channels of the input image.
        features: channel widths of the initial layer and of each Block.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=features[0],
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = features[0]
        last_idx = len(features) - 1
        for idx, feature in enumerate(features[1:], start=1):
            layers.append(Block(in_channels=in_channels,
                                out_channels=feature,
                                kernel_size=4,
                                # Only the final block keeps resolution. Compare
                                # by position, not by value, so repeated widths
                                # such as (64, 128, 128) still downsample.
                                stride=1 if idx == last_idx else 2,
                                padding=1,
            ))
            in_channels = feature

        layers.append(nn.Conv2d(in_channels,
                                1, 4, 1, 1, padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial_layer(x)
        return torch.sigmoid(self.model(x))

In [24]:
####-----------Define the dataloaders and dataset class-------------####
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np


class monet2photo(Dataset):
    """Unpaired Monet/photo dataset read from ``trainA`` and ``trainB``.

    ``root_dir/trainA`` holds the Monet images and ``root_dir/trainB`` the
    photos. The folders need not be the same size: the dataset length is the
    larger count, and indices wrap around (modulo) in the smaller folder.

    Each item is ``(monet_img, photo_img)``. Without ``transform``, both are
    RGB float32 arrays of shape (H, W, 3) scaled to [0, 1]. Otherwise they are
    whatever ``transform`` returns; it is applied to each image separately.

    Raises:
        ValueError: if either folder contains no usable files.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_monet = os.path.join(root_dir, 'trainA')
        self.root_photos = os.path.join(root_dir, 'trainB')

        self.monet_images = self._list_images(self.root_monet)
        self.photo_images = self._list_images(self.root_photos)

        self.monet_len = len(self.monet_images)
        self.photo_len = len(self.photo_images)
        # __getitem__ indexes with `index % len`, which would divide by zero
        # for an empty folder; fail early with a clear message instead.
        if self.monet_len == 0 or self.photo_len == 0:
            raise ValueError(
                f"expected images in both {self.root_monet!r} and "
                f"{self.root_photos!r}, found {self.monet_len} and {self.photo_len}"
            )
        self.length = max(self.monet_len, self.photo_len)

        self.transform = transform

    @staticmethod
    def _list_images(directory):
        """Return the regular, non-hidden file names in ``directory``, sorted."""
        # os.listdir order is arbitrary, so sort for reproducible indexing.
        # Hidden entries (e.g. macOS ".DS_Store") and sub-directories are
        # skipped because PIL cannot open them as images.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB float32 array scaled to [0, 1]."""
        # The context manager closes the file handle once convert() has
        # produced a decoded copy of the pixels.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return (np.array(rgb) / 255.).astype(np.float32)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        photo_img = self._load_rgb(
            os.path.join(self.root_photos, self.photo_images[index % self.photo_len]))
        monet_img = self._load_rgb(
            os.path.join(self.root_monet, self.monet_images[index % self.monet_len]))

        # Photo is transformed first, then Monet, so a random transform draws
        # its random numbers in that order.
        if self.transform:
            photo_img = self.transform(photo_img)
            monet_img = self.transform(monet_img)

        return monet_img, photo_img





In [25]:
# Define helper functions
import random, torch, os, numpy as np


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write ``model`` and ``optimizer`` state dicts to ``filename``.

    The saved object is a dict with keys ``"state_dict"`` (the model's state)
    and ``"optimizer"`` (the optimizer's state).
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr, map_location="cpu"):
    """Restore model and optimizer state written by ``save_checkpoint``.

    Args:
        checkpoint_file: path to a dict with ``"state_dict"`` and
            ``"optimizer"`` entries.
        model: module whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate forced onto every optimizer param group after loading.
        map_location: passed to ``torch.load``. The default ``"cpu"`` is always
            safe, because ``load_state_dict`` copies the values onto the devices
            the model parameters already live on. It also avoids depending on a
            global ``device`` that may not be defined yet.
    """
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this, the optimizer keeps the learning rate stored in the
    # old checkpoint, which is easy to miss and hard to debug.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Seeds ``random``, ``numpy.random``, the PyTorch CPU generator, and the
    current and all CUDA devices. It also sets ``PYTHONHASHSEED`` and turns
    off cuDNN autotuning in favour of deterministic kernels.
    """
    # Setting PYTHONHASHSEED at runtime only affects Python processes started
    # afterwards; the current interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [26]:
# train
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.transforms import Compose, Resize, RandomHorizontalFlip, Normalize, ToTensor

# Hyperparameters and configs
# Pick the best available backend: CUDA, then Apple MPS, then CPU.
# `is_available` must be *called*: the bare function object is always truthy
# and would select 'mps' even on machines that have no MPS support.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
root_dir = 'monet2photo'
batch_size = 8
lr_rate = 1e-4
lambda_identity = 5   # weight of the identity L1 loss in the generator objective
lambda_cycle = 10     # weight of the cycle-consistency L1 loss
num_epochs = 50
load_model = False
save_model = True
monet_generator = 'monet_gen.pth.tar'
photo_generator = 'photo_gen.pth.tar'
monet_discriminator = 'monet_dis.pth.tar'
photo_discriminator = 'photo_dis.pth.tar'
# NOTE(review): not passed to the DataLoader in this notebook; confirm intent.
num_workers = 4
transforms = Compose(
    [
        # The dataset yields HxWxC float32 arrays already scaled to [0, 1];
        # ToTensor converts them to CxHxW tensors.
        ToTensor(),
        Resize(size=(256, 256)),
        RandomHorizontalFlip(p=0.5),
    ]
)

In [27]:
## The training loop
def train_step(disc_y, disc_x, gen_ytox, gen_xtoy, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler):
    """Train both CycleGAN generators and discriminators for one pass over ``loader``.

    ``loader`` yields ``(x, y)`` batches: x from domain X (Monet) and y from
    domain Y (photo). Each batch runs two updates:

    1. A discriminator update: ``mse`` of real scores against 1 and of
       detached fake scores against 0.
    2. A generator update: adversarial loss
       + ``lambda_cycle`` * cycle-consistency L1
       + ``lambda_identity`` * identity L1.

    Every 200th batch, sample images are written to
    ``saved_images/<batch_idx>/``.

    Uses the notebook globals ``device``, ``lambda_cycle`` and ``lambda_identity``.

    Returns:
        dict with the epoch means of ``gen_loss``, ``disc_loss``, ``y_real``
        and ``y_fake`` (mean disc_y score on real / generated photos). All
        values are 0.0 for an empty loader.
    """
    # Mixed precision is only available on CUDA. Enable it there and switch it
    # off explicitly elsewhere (torch.cuda.amp.autocast would just warn and
    # disable itself on MPS/CPU).
    use_amp = str(device).startswith('cuda')
    loop = tqdm(loader)
    y_reals = 0.0
    y_fakes = 0.0
    gen_loss = 0.0
    disc_loss = 0.0
    n_batches = 0
    for batch_idx, (x, y) in enumerate(loop):
        x = x.to(device)
        y = y.to(device)

        # ---- Discriminator update --------------------------------------
        # Fakes are detached so this step does not backpropagate into the
        # generators.
        with torch.autocast(device_type='cuda', enabled=use_amp):
            fake_y = gen_xtoy(x)
            dy_fake = disc_y(fake_y.detach())
            dy_real = disc_y(y)
            y_reals += dy_real.mean().item()
            y_fakes += dy_fake.mean().item()
            discy_real_loss = mse(dy_real, torch.ones_like(dy_real))
            discy_fake_loss = mse(dy_fake, torch.zeros_like(dy_fake))
            discy_loss = (discy_real_loss + discy_fake_loss) / 2

            fake_x = gen_ytox(y)
            dx_fake = disc_x(fake_x.detach())
            dx_real = disc_x(x)
            discx_real_loss = mse(dx_real, torch.ones_like(dx_real))
            discx_fake_loss = mse(dx_fake, torch.zeros_like(dx_fake))
            discx_loss = (discx_real_loss + discx_fake_loss) / 2

            D_loss = discy_loss + discx_loss
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator update ------------------------------------------
        # This backward pass also leaves gradients on the discriminator
        # parameters; opt_disc.zero_grad() clears them before the next
        # discriminator step.
        with torch.autocast(device_type='cuda', enabled=use_amp):
            # 1. Adversarial loss: generators try to make the discriminators output 1.
            discx_fake = disc_x(fake_x)
            discy_fake = disc_y(fake_y)
            loss_g_xtoy = mse(discy_fake, torch.ones_like(discy_fake))
            loss_g_ytox = mse(discx_fake, torch.ones_like(discx_fake))
            adv_G_loss = loss_g_ytox + loss_g_xtoy

            # 2. Cycle-consistency loss: x -> y -> x and y -> x -> y should
            #    reconstruct the inputs.
            cycle_xtoytox = gen_ytox(fake_y)
            cycle_ytoxtoy = gen_xtoy(fake_x)
            cycle_x_loss = L1(cycle_xtoytox, x)
            cycle_y_loss = L1(cycle_ytoxtoy, y)
            cycle_G_loss = cycle_x_loss + cycle_y_loss

            # 3. Identity loss: a generator fed an image already in its target
            #    domain should leave it unchanged.
            identity_x = gen_ytox(x)
            identity_y = gen_xtoy(y)
            identity_x_loss = L1(identity_x, x)
            identity_y_loss = L1(identity_y, y)
            identity_G_loss = identity_x_loss + identity_y_loss

            # Add all together.
            G_loss = adv_G_loss + lambda_cycle * cycle_G_loss + lambda_identity * identity_G_loss

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if batch_idx % 200 == 0:
            # Folders are keyed by batch index only, so each epoch overwrites
            # the previous epoch's samples. makedirs also creates the
            # saved_images/ parent if it is missing.
            sample_dir = f"saved_images/{batch_idx}"
            os.makedirs(sample_dir, exist_ok=True)
            save_image(fake_y, f"{sample_dir}/fake_photo.png")
            save_image(fake_x, f"{sample_dir}/fake_monet.png")
            save_image(x, f"{sample_dir}/real_monet.png")
            save_image(y, f"{sample_dir}/real_photo.png")

        gen_loss += G_loss.item()
        disc_loss += D_loss.item()
        n_batches = batch_idx + 1
        loop.set_postfix(y_real=y_reals / n_batches, y_fake=y_fakes / n_batches,
                         gen_loss=gen_loss / n_batches, disc_loss=disc_loss / n_batches)

    denom = max(n_batches, 1)
    return {
        'gen_loss': gen_loss / denom,
        'disc_loss': disc_loss / denom,
        'y_real': y_reals / denom,
        'y_fake': y_fakes / denom,
    }


In [30]:
disc_x = Discriminator(in_channels=3).to(device)
disc_y = Discriminator(in_channels=3).to(device)
gen_xtoy = Generator(in_channels=3).to(device)
gen_ytox = Generator(in_channels=3).to(device)

opt_disc = optim.Adam(
    params=list(disc_x.parameters()) + list(disc_y.parameters()),
    lr=lr_rate,
    betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
    params=list(gen_xtoy.parameters()) + list(gen_ytox.parameters()),
    lr=lr_rate,
    betas=(0.5, 0.999),
)

L1 = nn.L1Loss()
mse = nn.MSELoss()

# (model, optimizer, checkpoint file) triples, in load/save order.
# x = Monet and y = photo (the dataset returns (monet_img, photo_img)), so
# gen_xtoy produces photos and disc_y scores photos. The file pairing is kept
# as-is so existing checkpoints still load, even though disc_y is stored under
# `monet_discriminator` and disc_x under `photo_discriminator`.
# Both generators share opt_gen (and both discriminators opt_disc), so each
# optimizer's state is written to / read from two files.
checkpoints = [
    (gen_xtoy, opt_gen, photo_generator),
    (gen_ytox, opt_gen, monet_generator),
    (disc_y, opt_disc, monet_discriminator),
    (disc_x, opt_disc, photo_discriminator),
]

if load_model:
    for ckpt_model, ckpt_opt, ckpt_file in checkpoints:
        load_checkpoint(ckpt_file, ckpt_model, ckpt_opt, lr_rate)

dataset = monet2photo(
    root_dir=root_dir,
    transform=transforms,
)
# TODO: add a validation dataset/loader.

# Pinned host memory only speeds up copies to CUDA, and the GradScalers only
# work with CUDA. On other backends (e.g. MPS) both would warn and fall back,
# so disable them explicitly there.
use_cuda = str(device).startswith('cuda')
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_cuda,
)
g_scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
d_scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

os.makedirs('saved_images', exist_ok=True)

print(device)

for epoch in range(num_epochs):
    train_step(
        disc_y,
        disc_x,
        gen_ytox,
        gen_xtoy,
        loader,
        opt_disc,
        opt_gen,
        L1,
        mse,
        d_scaler,
        g_scaler,
    )

    if save_model:
        for ckpt_model, ckpt_opt, ckpt_file in checkpoints:
            save_checkpoint(ckpt_model, ckpt_opt, filename=ckpt_file)

mps


100%|██████████| 786/786 [12:11<00:00,  1.08it/s, disc_loss=0.429, gen_loss=3.03, y_fake=0.426, y_real=0.568]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 786/786 [12:24<00:00,  1.06it/s, disc_loss=0.432, gen_loss=2.46, y_fake=0.417, y_real=0.573]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 786/786 [12:22<00:00,  1.06it/s, disc_loss=0.411, gen_loss=2.42, y_fake=0.396, y_real=0.593]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 786/786 [3:40:19<00:00, 16.82s/it, disc_loss=0.388, gen_loss=2.42, y_fake=0.381, y_real=0.611]    


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


  2%|▏         | 14/786 [00:13<11:55,  1.08it/s, disc_loss=0.36, gen_loss=2.6, y_fake=0.375, y_real=0.63]  