In [1]:
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor


def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution
    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size)
    """
    coords = torch.arange(size, dtype=torch.float)
    coords -= size // 2

    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()

    return g.unsqueeze(0).unsqueeze(0)


def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
    r"""Separably blur a batch with a 1-D kernel, one spatial axis at a time.

    Args:
        input (torch.Tensor): 4-D (N, C, H, W) or 5-D (N, C, D, H, W) batch to blur
        win (torch.Tensor): 1-D kernel laid out along the last axis, with every
            axis between the first and the last of size 1 (e.g. (C, 1, 1, K) for
            4-D input); applied depthwise with ``groups=C``

    Returns:
        torch.Tensor: the blurred batch. Each filtered axis shrinks by K - 1
        because no padding is used. Axes shorter than K are left unfiltered,
        and a warning is emitted for them.

    Raises:
        NotImplementedError: if ``input`` is neither 4-D nor 5-D.
    """
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape

    conv_for_rank = {4: F.conv2d, 5: F.conv3d}
    rank = len(input.shape)
    if rank not in conv_for_rank:
        raise NotImplementedError(input.shape)
    conv = conv_for_rank[rank]

    channels = input.shape[1]
    kernel_len = win.shape[-1]
    blurred = input
    for axis_offset, extent in enumerate(input.shape[2:]):
        if extent < kernel_len:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{axis_offset} for input: {input.shape} and win size: {win.shape[-1]}"
            )
            continue
        # Rotate the kernel so its taps lie along spatial axis 2 + axis_offset.
        oriented = win.transpose(2 + axis_offset, -1)
        blurred = conv(blurred, weight=oriented, stride=1, padding=0, groups=channels)

    return blurred


def _ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float,
    win: Tensor,
    size_average: bool = True,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tuple[Tensor, Tensor]:
    r"""Compute per-channel SSIM and contrast-structure (cs) terms for X and Y.

    Args:
        X (torch.Tensor): images, (N, C, H, W) or (N, C, D, H, W)
        Y (torch.Tensor): images with the same shape as X
        data_range (float or int): value range of input images. (usually 1.0 or 255)
        win (torch.Tensor): 1-D gauss kernel (see ``gaussian_filter``)
        size_average (bool, optional): accepted for API compatibility; not used here.
            Callers do the averaging themselves.
        K (list or tuple, optional): stabilising constants (K1, K2)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (ssim_per_channel, cs), each of shape
        (N, C), the spatial means of the SSIM map and of the cs map.
    """
    K1, K2 = K
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    # Match the kernel to the inputs' device and dtype (e.g. under autocast).
    win = win.to(X.device, dtype=X.dtype)

    def blur(t: Tensor) -> Tensor:
        return gaussian_filter(t, win)

    # Local means.
    mu1 = blur(X)
    mu2 = blur(Y)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance: E[ab] - E[a]E[b].
    sigma1_sq = compensation * (blur(X * X) - mu1_sq)
    sigma2_sq = compensation * (blur(Y * Y) - mu2_sq)
    sigma12 = compensation * (blur(X * Y) - mu1_mu2)

    # alpha = beta = gamma = 1
    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
    ssim_map = luminance_map * cs_map

    # Collapse every spatial axis into one, then average it away -> (N, C).
    ssim_per_channel = ssim_map.flatten(2).mean(-1)
    cs = cs_map.flatten(2).mean(-1)
    return ssim_per_channel, cs


def ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float = 255,
    size_average: bool = True,
    win_size: int = 11,
    win_sigma: float = 1.5,
    win: Optional[Tensor] = None,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
    nonnegative_ssim: bool = False,
) -> Tensor:
    r"""Structural similarity between two image batches.

    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W) or (N,C,D,H,W)
        Y (torch.Tensor): a batch of images with the same shape as X
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, return one scalar averaged over
            batch and channels; otherwise return one value per image, shape (N,)
        win_size: (int, optional): the size of gauss kernel (must be odd)
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. If given, it overrides
            win_size/win_sigma; if None, one is built from them
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        nonnegative_ssim (bool, optional): clamp per-channel SSIM at 0 with relu

    Returns:
        torch.Tensor: ssim results

    Raises:
        ValueError: on shape mismatch, wrong rank, or an even window size.
    """
    if X.shape != Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop singleton spatial axes, highest index first so lower indices stay valid.
    for dim in reversed(range(2, X.ndim)):
        X = X.squeeze(dim=dim)
        Y = Y.squeeze(dim=dim)

    if X.ndim not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    # NOTE: X and Y are not required to share a dtype (that check is disabled here).

    if win is not None:
        win_size = win.shape[-1]

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    if win is None:
        # One copy of the kernel per channel for the depthwise convolution.
        win = _fspecial_gauss_1d(win_size, win_sigma).repeat([X.shape[1]] + [1] * (X.ndim - 1))

    ssim_per_channel, _ = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    return ssim_per_channel.mean() if size_average else ssim_per_channel.mean(1)


def ms_ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float = 255,
    size_average: bool = True,
    win_size: int = 11,
    win_sigma: float = 1.5,
    win: Optional[Tensor] = None,
    weights: Optional[List[float]] = None,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tensor:
    r"""Multi-scale structural similarity between two image batches.

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, return one scalar; otherwise one value per image, shape (N,)
        win_size: (int, optional): the size of gauss kernel (must be odd)
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        weights (list, optional): per-scale exponents. The number of entries sets
            the number of scales (default: 5 standard weights)
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        torch.Tensor: ms-ssim results

    Raises:
        ValueError: on shape mismatch, wrong rank, or an even window size.
        AssertionError: if the smaller of the last two axes is too small for 4 downsamplings.
    """
    if X.shape != Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop singleton spatial axes, highest index first.
    for dim in reversed(range(2, X.ndim)):
        X = X.squeeze(dim=dim)
        Y = Y.squeeze(dim=dim)

    # NOTE: X and Y are not required to share a dtype (that check is disabled here).

    pool_for_rank = {4: F.avg_pool2d, 5: F.avg_pool3d}
    if X.ndim not in pool_for_rank:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
    avg_pool = pool_for_rank[X.ndim]

    if win is not None:
        win_size = win.shape[-1]

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    min_side = (win_size - 1) * (2 ** 4)
    assert min(X.shape[-2:]) > min_side, "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % min_side

    default_weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights_tensor = X.new_tensor(default_weights if weights is None else weights)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma).repeat([X.shape[1]] + [1] * (X.ndim - 1))

    levels = weights_tensor.shape[0]
    mcs = []
    for level in range(levels):
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
        if level == levels - 1:
            break
        mcs.append(torch.relu(cs))
        # Pad odd-sized axes by one so pooling covers every pixel.
        padding = [extent % 2 for extent in X.shape[2:]]
        X = avg_pool(X, kernel_size=2, padding=padding)
        Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # type: ignore  # (batch, channel)
    per_level = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)
    ms_ssim_val = torch.prod(per_level ** weights_tensor.view(-1, 1, 1), dim=0)

    return ms_ssim_val.mean() if size_average else ms_ssim_val.mean(1)


class SSIM(torch.nn.Module):
    """Module wrapper around :func:`ssim` with a pre-built Gaussian window."""

    def __init__(
        self,
        data_range: float = 255,
        size_average: bool = True,
        win_size: int = 11,
        win_sigma: float = 1.5,
        channel: int = 3,
        spatial_dims: int = 2,
        K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
        nonnegative_ssim: bool = False,
    ) -> None:
        r"""Configure the SSIM metric.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if True, return one scalar; otherwise one value per image
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): number of spatial axes of the inputs (2 for images)
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
            nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
        """
        super().__init__()
        base_kernel = _fspecial_gauss_1d(win_size, win_sigma)
        self.win_size = win_size
        # One kernel copy per channel. Kept as a plain attribute (not a buffer);
        # ``_ssim`` moves it to the input's device/dtype on every call.
        self.win = base_kernel.repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.K = K
        self.nonnegative_ssim = nonnegative_ssim

    def forward(self, X: Tensor, Y: Tensor) -> Tensor:
        """Return the SSIM of ``X`` against ``Y`` using the stored configuration."""
        options = dict(
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            K=self.K,
            nonnegative_ssim=self.nonnegative_ssim,
        )
        return ssim(X, Y, **options)


class MS_SSIM(torch.nn.Module):
    """Module wrapper around :func:`ms_ssim` with a pre-built Gaussian window."""

    def __init__(
        self,
        data_range: float = 255,
        size_average: bool = True,
        win_size: int = 11,
        win_sigma: float = 1.5,
        channel: int = 3,
        spatial_dims: int = 2,
        weights: Optional[List[float]] = None,
        K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
    ) -> None:
        r"""Configure the MS-SSIM metric.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if True, return one scalar; otherwise one value per image
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): number of spatial axes of the inputs (2 for images)
            weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        """
        super().__init__()
        base_kernel = _fspecial_gauss_1d(win_size, win_sigma)
        self.win_size = win_size
        # One kernel copy per channel; moved to the input's device inside ``_ssim``.
        self.win = base_kernel.repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights
        self.K = K

    def forward(self, X: Tensor, Y: Tensor) -> Tensor:
        """Return the MS-SSIM of ``X`` against ``Y`` using the stored configuration."""
        options = dict(
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            weights=self.weights,
            K=self.K,
        )
        return ms_ssim(X, Y, **options)

### Defining the model architecture.

As stated in the Readme.md, I am using CycleGAN (Zhu et al., 2017) to transfer Van Gogh's artistic style to landscape photographs. The model consists of two generators (U-Net-style encoder–decoders), one per translation direction, and two discriminators, one per domain.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Similarity metrics for 3-channel inputs with a value range of 1.0
# (use channel=1 for grayscale images).
ssim_module = SSIM(channel=3, data_range=1.0, size_average=True)
ms_ssim_module = MS_SSIM(channel=3, data_range=1.0, size_average=True)

# Conv block: two 3x3 conv layers, each followed by batch norm and leaky ReLU,
# with dropout between them.
class ConvBlock(nn.Module):
    """Two 3x3 convolutions with batch norm and leaky ReLU (size-preserving)."""

    def __init__(self, in_channels, out_channels, dropout_p):
        super().__init__()
        # Layer order fixes the state_dict keys (conv_conv.0 ... conv_conv.6),
        # so it must stay stable for saved checkpoints to load.
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        self.conv_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv_conv(x)

# Downsampling module: 2x2 max pooling, then a ConvBlock.
class DownBlock(nn.Module):
    """Downsampling followed by ConvBlock (halves H and W)."""

    def __init__(self, in_channels, out_channels, dropout_p):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = ConvBlock(in_channels, out_channels, dropout_p)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

# Upsampling module: a 1x1 conv to adjust channels, 2x bilinear upsampling,
# concatenation with the skip connection, then a ConvBlock.
class UpBlock(nn.Module):
    """Upsampling followed by ConvBlock"""

    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p):
        super().__init__()
        # Registration order (conv1x1, up, conv) fixes the parameter order that
        # optimizer checkpoints rely on.
        self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        """Upsample the deeper map ``x1`` and fuse it with the skip map ``x2``."""
        upsampled = self.up(self.conv1x1(x1))
        fused = torch.cat([x2, upsampled], dim=1)  # skip features first
        return self.conv(fused)

# The U-Net encoder
class Encoder(nn.Module):
    """Five-level U-Net encoder: an input ConvBlock followed by four DownBlocks.

    ``params`` must provide ``'in_chns'``, ``'feature_chns'`` (5 widths) and
    ``'dropout'`` (one dropout probability per level).
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.dropout = self.params['dropout']
        assert len(self.ft_chns) == 5
        self.in_conv = ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
        # Registers down1 .. down4 in order (same names/order as state_dict keys).
        for level in range(1, 5):
            block = DownBlock(self.ft_chns[level - 1], self.ft_chns[level], self.dropout[level])
            setattr(self, f"down{level}", block)

    def forward(self, x):
        """Return the deepest feature map and the list of all five level outputs."""
        features = [self.in_conv(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))
        return features[-1], features

# The U-Net decoder
class Decoder(nn.Module):
    """U-Net decoder: four UpBlocks with skip connections, then a 3x3 output conv
    that maps back to ``params['in_chns']`` channels."""

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        assert len(self.ft_chns) == 5

        # up1 fuses levels 4->3, ..., up4 fuses levels 1->0 (registered in that order).
        for level in range(1, 5):
            deep = self.ft_chns[5 - level]
            shallow = self.ft_chns[4 - level]
            setattr(self, f"up{level}", UpBlock(deep, shallow, shallow, dropout_p=0.0))

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.in_chns, kernel_size=3, padding=1)

    def forward(self, feature):
        """Decode the encoder's feature list; return (output logits, last feature map)."""
        x = feature[4]
        stages = (
            (self.up1, feature[3]),
            (self.up2, feature[2]),
            (self.up3, feature[1]),
            (self.up4, feature[0]),
        )
        for up, skip in stages:
            x = up(x, skip)
        return self.out_conv(x), x


# Combining the encoder and the decoder to form the generator
class Generator(nn.Module):
    """U-Net image-to-image generator whose output is squashed to (0, 1) by a sigmoid."""

    def __init__(self, in_channels):
        super().__init__()
        config = {
            'in_chns': in_channels,
            'feature_chns': [16, 32, 64, 128, 256],
            'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
            'acti_func': 'relu',  # stored in params; not read by Encoder/Decoder
        }
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(self, x):
        _, skips = self.encoder(x)
        logits, _ = self.decoder(skips)
        return torch.sigmoid(logits)

# Conv block for the discriminator: reflect-padded conv, instance norm, leaky ReLU.
class Block(nn.Module):
    """Strided convolution + InstanceNorm2d + LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            padding_mode='reflect',
            bias=True,
        )
        self.conv = nn.Sequential(conv, nn.InstanceNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        return self.conv(x)

# The discriminator class
class Discriminator(nn.Module):
    """Fully convolutional discriminator that maps an image to a spatial grid of
    scores in (0, 1) (no global pooling; sigmoid on the final conv)."""

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode='reflect',
            ),
            nn.LeakyReLU(0.2),
        )
        last_width = features[-1]
        # Halve the resolution at every block except the one producing the last width.
        layers = [
            Block(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=4,
                stride=1 if c_out == last_width else 2,
                padding=1,
            )
            for c_in, c_out in zip(features, features[1:])
        ]
        layers.append(nn.Conv2d(last_width, 1, 4, 1, 1, padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return torch.sigmoid(self.model(self.initial_layer(x)))

The following cell defines the Dataset class for my custom vangogh2photo dataset. Its `__getitem__` method returns a pair of images, one from each domain (a Van Gogh painting and a photo, in that order).
The constructor takes the root directory where the dataset is located, the split name (`train` or `test`) and an optional transform.

In [7]:
####-----------Define the dataloaders and dataset class-------------####
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np


class vangogh2photo(Dataset):
    """Unpaired Van Gogh / photo dataset.

    Expects ``<root_dir>/<split>A`` (Van Gogh paintings) and
    ``<root_dir>/<split>B`` (photos). The dataset length is that of the larger
    domain; the smaller domain is cycled with ``index % len``.

    Each item is ``(vangogh_img, photo_img)``: float32 RGB arrays scaled to
    [0, 1] (HWC), passed through ``transform`` when one is given.
    """

    def __init__(self, root_dir, split='train', transform=None):
        super().__init__()
        self.root_vangogh = os.path.join(root_dir, f'{split}A')
        self.root_photos = os.path.join(root_dir, f'{split}B')

        # Sorted so the index -> file mapping is deterministic (os.listdir order
        # is arbitrary); hidden entries such as .DS_Store are skipped.
        self.vangogh_images = self._list_images(self.root_vangogh)
        self.photo_images = self._list_images(self.root_photos)

        self.vangogh_len = len(self.vangogh_images)
        self.photo_len = len(self.photo_images)
        if self.vangogh_len == 0 or self.photo_len == 0:
            # Otherwise __getitem__ would fail later with a ZeroDivisionError.
            raise ValueError(
                f"No images found in {self.root_vangogh!r} ({self.vangogh_len}) "
                f"or {self.root_photos!r} ({self.photo_len})"
            )
        self.length = max(self.vangogh_len, self.photo_len)

        self.transform = transform

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of the non-hidden regular files in ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB float32 array scaled to [0, 1]; the file is closed afterwards."""
        with Image.open(path) as img:
            return (np.array(img.convert('RGB')) / 255.).astype(np.float32)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        photo_img = self._load_rgb(os.path.join(self.root_photos, self.photo_images[index % self.photo_len]))
        vangogh_img = self._load_rgb(os.path.join(self.root_vangogh, self.vangogh_images[index % self.vangogh_len]))

        if self.transform:
            photo_img = self.transform(photo_img)
            vangogh_img = self.transform(vangogh_img)

        return vangogh_img, photo_img





In [15]:
# Define helper functions
import random, torch, os, numpy as np


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise the model and optimizer state dicts to ``filename`` with ``torch.save``.

    The file holds a dict with keys ``"state_dict"`` (model) and ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr, map_location=None):
    """Restore model and optimizer state saved by ``save_checkpoint``, then reset the LR.

    Args:
        checkpoint_file: path to a file containing ``{"state_dict", "optimizer"}``.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate written into every parameter group after loading.
        map_location: passed to ``torch.load``. When None, tensors are mapped to
            the device of the model's first parameter (or "cpu" if it has none).
            This removes the dependency on a global ``device`` defined in another
            notebook cell.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    print("=> Loading checkpoint")
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Otherwise the optimizer would keep the learning rate stored in the old checkpoint.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [5]:
# train
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.transforms import Compose, Resize, RandomHorizontalFlip, Normalize, ToTensor

# Hyperparameters and configs
# ``torch.backends.mps.is_available`` must be *called*: the bare function object
# is always truthy, so without the call every non-CUDA machine selected 'mps'.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
root_dir = '.'
batch_size = 4
lr_rate = 1e-3
lambda_identity = 5   # weight of the identity loss in the generator objective
lambda_cycle = 10     # weight of the cycle-consistency loss in the generator objective
num_epochs = 50
load_model = False    # load the checkpoint files below when setting up the models
save_model = True
vangogh_generator = root_dir + '/checkpoints/vangogh_gen.pth.tar'
photo_generator = root_dir + '/checkpoints/photo_gen.pth.tar'
vangogh_discriminator = root_dir + '/checkpoints/vangogh_dis.pth.tar'
# Previously lacked the ``root_dir`` prefix and pointed at /checkpoints (filesystem root).
photo_discriminator = root_dir + '/checkpoints/photo_dis.pth.tar'
num_workers = 2
# Training preprocessing: HWC array -> CHW tensor, resize to 256x256, random horizontal flip.
transforms = Compose([
    ToTensor(),
    Resize(size=(256, 256)),
    RandomHorizontalFlip(p=0.5),
])

# Validation preprocessing: the same, minus the random augmentation.
val_transforms = Compose([ToTensor(), Resize(size=(256, 256))])

The following cell defines the training step for one epoch.

In [21]:
## The training loop
def _save_samples(batch_idx, fake_y, fake_x, x, y):
    """Save the current batch's real and translated images for visual inspection.

    Images are written to ``<root_dir>/saved_images/<batch_idx>/``.
    ``os.makedirs`` also creates the ``saved_images`` parent when it does not
    exist yet, which a plain ``os.mkdir`` cannot do.
    """
    out_dir = os.path.join(root_dir, "saved_images", str(batch_idx))
    os.makedirs(out_dir, exist_ok=True)
    save_image(fake_y, os.path.join(out_dir, "fake_photo.png"))
    save_image(fake_x, os.path.join(out_dir, "fake_vangogh.png"))
    save_image(x, os.path.join(out_dir, "real_vangogh.png"))
    save_image(y, os.path.join(out_dir, "real_photo.png"))


def train_step(disc_y, disc_x, gen_ytox, gen_xtoy, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler):
    """Train the CycleGAN for one pass over ``loader``.

    Each batch first updates both discriminators (MSE against all-ones / all-zeros
    targets), then both generators. The generator objective is
    adversarial + ``lambda_cycle`` * cycle-consistency + ``lambda_identity`` *
    identity, using ``L1`` for the reconstruction terms. Both updates run under
    ``torch.cuda.amp.autocast`` and step through their GradScaler.

    Args:
        disc_y, disc_x: discriminators for domain y (photos) and x (Van Gogh).
        gen_ytox, gen_xtoy: generators mapping y -> x and x -> y.
        loader: yields ``(x, y)`` batches.
        opt_disc, opt_gen: optimizers over the discriminator / generator parameters.
        L1, mse: loss criteria.
        d_scaler, g_scaler: GradScalers for the two updates.

    The tqdm bar shows running means of the discriminator outputs, both losses,
    ``1 - SSIM(fake_x, y)`` (``ssim_score``) and ``1 - SSIM(fake_x, x)`` (``style_score``).
    Every 200th batch the images are saved via ``_save_samples``.
    """
    loop = tqdm(loader)
    y_reals = 0
    y_fakes = 0
    gen_loss = 0.0
    disc_loss = 0.0
    ssim_score, ssim_score_style = 0, 0
    for batch_idx, (x, y) in enumerate(loop):
        x = x.to(device)
        y = y.to(device)

        # ---- 1. Discriminator update ----
        with torch.cuda.amp.autocast():
            # generate fake y from x; detach so the D loss does not backprop into the generator
            fake_y = gen_xtoy(x)
            dy_fake = disc_y(fake_y.detach())
            dy_real = disc_y(y)
            y_reals += dy_real.mean().item()
            y_fakes += dy_fake.mean().item()
            discy_real_loss = mse(dy_real, torch.ones_like(dy_real))
            discy_fake_loss = mse(dy_fake, torch.zeros_like(dy_fake))
            discy_loss = (discy_real_loss + discy_fake_loss) / 2

            # generate fake x from y
            fake_x = gen_ytox(y)
            dx_fake = disc_x(fake_x.detach())
            dx_real = disc_x(x)
            discx_real_loss = mse(dx_real, torch.ones_like(dx_real))
            discx_fake_loss = mse(dx_fake, torch.zeros_like(dx_fake))
            discx_loss = (discx_real_loss + discx_fake_loss) / 2

            D_loss = discy_loss + discx_loss

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- 2. Generator update ----
        with torch.cuda.amp.autocast():
            # Adversarial loss: the generators try to make the discriminators output 1.
            discx_fake = disc_x(fake_x)
            discy_fake = disc_y(fake_y)
            loss_g_xtoy = mse(discy_fake, torch.ones_like(discy_fake))
            loss_g_ytox = mse(discx_fake, torch.ones_like(discx_fake))
            adv_G_loss = loss_g_ytox + loss_g_xtoy

            # Cycle-consistency loss: x -> y -> x and y -> x -> y should reconstruct the input.
            cycle_xtoytox = gen_ytox(fake_y)
            cycle_ytoxtoy = gen_xtoy(fake_x)
            cycle_x_loss = L1(cycle_xtoytox, x)
            cycle_y_loss = L1(cycle_ytoxtoy, y)
            cycle_G_loss = cycle_x_loss + cycle_y_loss

            # Identity loss: a generator fed an image already in its target domain should return it unchanged.
            identity_x = gen_ytox(x)
            identity_y = gen_xtoy(y)
            identity_x_loss = L1(identity_x, x)
            identity_y_loss = L1(identity_y, y)
            identity_G_loss = identity_x_loss + identity_y_loss

            G_loss = adv_G_loss + lambda_cycle * cycle_G_loss + lambda_identity * identity_G_loss

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if batch_idx % 200 == 0:
            _save_samples(batch_idx, fake_y, fake_x, x, y)

        gen_loss += G_loss.item()
        disc_loss += D_loss.item()
        # SSIM is only a monitoring metric, so no autograd graph is needed.
        with torch.no_grad():
            content_ssim = ssim_module(fake_x.detach().float(), y.float()).item()
            style_ssim = ssim_module(fake_x.detach().float(), x.float()).item()
        ssim_score += 1 - content_ssim
        ssim_score_style += 1 - style_ssim

        seen = batch_idx + 1
        loop.set_postfix(
            y_real=y_reals / seen,
            y_fake=y_fakes / seen,
            gen_loss=gen_loss / seen,
            disc_loss=disc_loss / seen,
            ssim_score=ssim_score / seen,
            style_score=ssim_score_style / seen,
        )


In [8]:
# Build the two generators and two discriminators on the selected device.
disc_x = Discriminator(in_channels=3).to(device)
disc_y = Discriminator(in_channels=3).to(device)
gen_xtoy = Generator(in_channels=3).to(device)
gen_ytox = Generator(in_channels=3).to(device)

# One Adam optimizer for both discriminators, one for both generators.
opt_disc = optim.Adam(params=[*disc_x.parameters(), *disc_y.parameters()], lr=lr_rate, betas=(0.5, 0.999))
opt_gen = optim.Adam(params=[*gen_xtoy.parameters(), *gen_ytox.parameters()], lr=lr_rate, betas=(0.5, 0.999))

# Loss criteria: L1 for cycle/identity reconstruction, MSE for the adversarial terms.
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Resume from saved checkpoints when load_model is set.
if load_model:
    load_checkpoint(photo_generator, gen_xtoy, opt_gen, lr_rate)
    load_checkpoint(vangogh_generator, gen_ytox, opt_gen, lr_rate)
    load_checkpoint(vangogh_discriminator, disc_y, opt_disc, lr_rate)
    load_checkpoint(photo_discriminator, disc_x, opt_disc, lr_rate)

# Create the training and test datasets from the custom dataset class defined previously.
# os.path.join supplies the path separator: the previous ``root_dir + 'vangogh2photo'``
# evaluated to '.vangogh2photo' for root_dir = '.'.
data_root = os.path.join(root_dir, 'vangogh2photo')
dataset = vangogh2photo(
    root_dir=data_root,
    transform=transforms,
)
val_dataset = vangogh2photo(
    root_dir=data_root,
    split='test',
    transform=val_transforms,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

# The training run was disabled in this notebook (it used to be commented out).
# Set train_model = True to run it.
train_model = False

if train_model:
    # For mixed-precision training
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # makedirs(exist_ok=True) replaces the exists()/mkdir() pairs.
    os.makedirs(os.path.join(root_dir, "saved_images"), exist_ok=True)
    os.makedirs(os.path.join(root_dir, "checkpoints"), exist_ok=True)

    # The training loop
    for epoch in range(num_epochs):
        train_step(
            disc_y,
            disc_x,
            gen_ytox,
            gen_xtoy,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
        )

        if save_model:
            save_checkpoint(gen_xtoy, opt_gen, filename=photo_generator)
            save_checkpoint(gen_ytox, opt_gen, filename=vangogh_generator)
            save_checkpoint(disc_y, opt_disc, filename=vangogh_discriminator)
            save_checkpoint(disc_x, opt_disc, filename=photo_discriminator)




FileNotFoundError: [Errno 2] No such file or directory: './trainA'

Inference script: load the trained checkpoints and evaluate the models on the test split.

In [None]:
# Rebuild the models and load the trained weights for inference
disc_x = Discriminator(in_channels=3).to(device)
disc_y = Discriminator(in_channels=3).to(device)
gen_xtoy = Generator(in_channels=3).to(device)
gen_ytox = Generator(in_channels=3).to(device)

# Optimizers are only needed because load_checkpoint also restores optimizer state.
opt_disc = optim.Adam(
    params=list(disc_x.parameters()) + list(disc_y.parameters()),
    lr=lr_rate,
    betas=(0.5, 0.999)
)
opt_gen = optim.Adam(
    params=list(gen_xtoy.parameters()) + list(gen_ytox.parameters()),
    lr=lr_rate,
    betas=(0.5, 0.999)
)

# Checkpoint paths. Previously ``photo_discriminator`` was assigned the *generator*
# file twice (so disc_x was loaded from a generator checkpoint), and
# ``photo_generator`` was never set in this cell.
vangogh_discriminator = './checkpoints/vangogh_dis.pth.tar'
vangogh_generator = './checkpoints/vangogh_gen.pth.tar'
photo_generator = './checkpoints/photo_gen.pth.tar'
photo_discriminator = './checkpoints/photo_dis.pth.tar'

load_checkpoint(photo_generator, gen_xtoy, opt_gen, lr_rate)
load_checkpoint(vangogh_generator, gen_ytox, opt_gen, lr_rate)
load_checkpoint(vangogh_discriminator, disc_y, opt_disc, lr_rate)
load_checkpoint(photo_discriminator, disc_x, opt_disc, lr_rate)


# Initialise the ssim module for evaluating content preservation
ssim_module = SSIM(data_range=1.0, channel=3)

# Inference mode: disables Dropout and makes BatchNorm use its running
# statistics (both are used inside the generators' ConvBlocks).
gen_xtoy.eval()
gen_ytox.eval()
disc_x.eval()
disc_y.eval()

loop = tqdm(val_loader)
x_reals = x_fakes = 0.0
content_ssim_total = 0.0
for batch_idx, (x, y) in enumerate(loop):
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        # fake_x = gen_ytox(y) is an x-domain (Van Gogh) image, so it is scored by
        # disc_x, the discriminator train_step trains on x vs. fake_x
        # (previously it was scored by disc_y, the photo-domain discriminator).
        fake_x = gen_ytox(y)
        dx_fake = disc_x(fake_x)
        dx_real = disc_x(x)
        # SSIM between the translation and its source photo (content preservation).
        content_ssim = ssim_module(fake_x.float(), y.float())

    x_fakes += dx_fake.mean().item()
    x_reals += dx_real.mean().item()
    content_ssim_total += content_ssim.item()
    seen = batch_idx + 1
    loop.set_postfix(x_real=x_reals / seen, x_fake=x_fakes / seen, ssim=content_ssim_total / seen)