In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from typing import Tuple, List
import matplotlib.pyplot as plt
import matplotlib
import os

class Data:
    """
    Keeps the names of the source and target domains together with the DataLoaders built for them.

    Attributes:
        source / target: domain names (swapped with respect to the arguments when `flip` is True)
        flip: whether the domains were swapped
        bs: batch size handed to every DataLoader
        im_size: image size handed to every DataLoader
        dlSourceTrain / dlTargetTrain / dlSourceTest / dlTargetTest: loaders, None until `get_loaders` finds them
    """
    def __init__(self, source: str, target: str, flip: bool, bs: int, im_size: Tuple[int], make_valid: bool = True):
        """
        Args:
            source: name of the source domain (prefix of the `<name>_train` / `<name>_test` folders)
            target: name of the target domain
            flip: True to translate in the opposite direction (target -> source)
            bs: batch size
            im_size: size the images are resized/cropped to
            make_valid: accepted for compatibility; currently unused (no validation split is created)
        """
        if not flip:
            self.source=source
            self.target=target
        else:
            self.source=target
            self.target=source
        self.flip=flip
        self.bs=bs
        self.im_size = im_size
        self.dlSourceTrain=None
        self.dlTargetTrain=None
        self.dlSourceTest=None
        self.dlTargetTest=None

    def get_loaders(self, data_dir_name: str) -> List[DataLoader]:
        """
        Builds the train/test loaders found under drive/MyDrive/data/<data_dir_name> and stores them on the instance.

        Only the four expected sub-directories (`<source>_train`, `<source>_test`, `<target>_train`,
        `<target>_test`) are loaded; any other entry in the folder is ignored.

        Returns:
            [dlSourceTrain, dlTargetTrain, dlSourceTest, dlTargetTest]; an entry is None when its folder is missing
        """
        data_dir_path = os.path.join('drive', 'MyDrive', 'data', data_dir_name)
        # expected folder name -> attribute that receives its loader
        # (source entries are listed last so they win if both domains share a name)
        attr_by_folder = {
            f'{self.target}_test': 'dlTargetTest',
            f'{self.target}_train': 'dlTargetTrain',
            f'{self.source}_test': 'dlSourceTest',
            f'{self.source}_train': 'dlSourceTrain',
        }
        for entry in os.listdir(data_dir_path):
            attr_name = attr_by_folder.get(entry)
            if attr_name is None:
                # unrelated folder or stray file: do not try to build an image dataset from it
                continue
            loader = create_dataloader(os.path.join(data_dir_path, entry), self.im_size, self.bs)
            setattr(self, attr_name, loader)
        return [self.dlSourceTrain, self.dlTargetTrain, self.dlSourceTest, self.dlTargetTest]


def create_dataloader(root: str, im_size: Tuple, bs: int) -> DataLoader:
    """
    Wraps the image folder at `root` in a shuffling DataLoader.

    Args:
        root: directory passed to ImageFolder
        im_size: size used for both the resize and the center crop
        bs: batch size
    """
    preprocessing = transforms.Compose([
        # bring every image to the requested dimensions
        transforms.Resize(im_size),
        transforms.CenterCrop(im_size),
        # PIL image -> tensor
        transforms.ToTensor(),
        # subtracts 0 and divides by 1 per channel, i.e. the tensor values are left unchanged
        transforms.Normalize((0, 0, 0), (1, 1, 1)),
    ])
    image_folder = datasets.ImageFolder(root=root, transform=preprocessing)
    return DataLoader(image_folder, batch_size=bs, shuffle=True)

def save_images(img_tensors: torch.Tensor, file_name: str, folder: str='output'):
    """
    Plots every image of a batch in its own row and saves the figure to `folder/file_name`.

    Args:
        img_tensors: batch of images; each item is permuted with (1, 2, 0) before plotting,
            so items are expected to be channel-first
        file_name: name of the image file to write
        folder: output directory, created when missing
    """
    n_imgs = img_tensors.shape[0]
    # squeeze=False always returns a 2D array of axes, so one image and many images share one code path
    fig, axs = plt.subplots(n_imgs, figsize=(5, n_imgs), squeeze=False)
    try:
        for ax, img in zip(axs[:, 0], img_tensors):
            ax.imshow(img.permute(1, 2, 0).detach().cpu())
            ax.set_xticks([])
            ax.set_yticks([])

        # save the figure results
        os.makedirs(folder, exist_ok=True)
        fig.savefig(os.path.join(folder, file_name))
    finally:
        # release the figure; otherwise every call leaves one more open figure in memory
        plt.close(fig)

def show_batch(data: "Data", n: int, folder: str='output') -> None:
    """
    Saves a square grid with samples from the first batch of the source and target train loaders.

    One file per domain is written to `folder`, named `<data.source>.png` and `<data.target>.png`.

    Args:
        data: object exposing `bs`, `source`, `target`, `dlSourceTrain` and `dlTargetTrain`
        n: number of grid cells; must be a perfect square. When the batch holds fewer than n
            images the remaining cells are left empty.
        folder: output directory, created when missing

    Raises:
        ValueError: if n is not positive
        Exception: if n is not a perfect square
    """
    if n <= 0:
        raise ValueError(f"Expected the number of samples to be >= 1 but got {n} instead")
    elif n > data.bs:
        print(f"Will only show max of {data.bs} samples")

    # integer check avoids relying on exact float square roots
    grid_side = int(round(n ** (1/2)))
    if grid_side * grid_side != n:
        raise Exception('Num res should be a perfect square (eg. 1, 4, 9, 16, 25, 36...)')
    os.makedirs(folder, exist_ok=True)

    for loader, domain in ((data.dlSourceTrain, data.source), (data.dlTargetTrain, data.target)):
        # only the first batch of each loader is shown
        first_batch = next(iter(loader), None)
        if first_batch is None:
            continue
        x, _ = first_batch
        n_shown = min(n, x.shape[0])

        # squeeze=False keeps a 2D array of axes even for a 1x1 grid
        fig, axs = plt.subplots(grid_side, grid_side, figsize=(5, 5), squeeze=False)
        try:
            for img_idx, ax in enumerate(axs.flat):
                ax.set_xticks([])
                ax.set_yticks([])
                # cells beyond the available images stay blank instead of indexing past the batch
                if img_idx < n_shown:
                    ax.imshow(x[img_idx].permute(1, 2, 0).detach().cpu())

            # save the figure results
            fig.savefig(os.path.join(folder, domain+'.png'))
        finally:
            plt.close(fig)
    
def set_requires_grad(net, requires_grad):
    """
    Switches gradient tracking on or off for every parameter of `net` (freeze / unfreeze).

    The parameters are modified in place, so nothing is returned.
    """
    for weight in net.parameters():
        weight.requires_grad = requires_grad
    return None


def report_losses(losses, loss_names, epoch, time, scale=1):
    """
    Prints one summary line for an epoch: the epoch number, every loss divided by
    `scale` and rounded to 2 decimals next to its name, and the elapsed time in seconds.
    """
    fields = [f'\nEpoch: {epoch}']
    for value, label in zip(losses, loss_names):
        fields.append(f"{label}: {round(value/scale, 2)}")
    fields.append(f"Time: {round(time, 2)}s\n")
    print(" | ".join(fields))

In [21]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

# module-level NumPy random Generator; EncoderFeatureExtractor.forward uses it to draw random patch locations
# NOTE(review): no seed is passed, so the sampled patch locations differ from run to run
rng = np.random.default_rng()

class Normalize(nn.Module):
    """
    Normalization layer taken from https://github.com/taesungp/contrastive-unpaired-translation/blob/57430e99df041515c57a7ffd18bb7cbc3c1af0a9/models/networks.py#L449

    Divides the input by its p-norm computed along dimension 1 (p = `power`, 2 by default),
    which places the vectors on a unit sphere. The CUT GAN paper states "We normalize vectors
    onto a unit sphere to prevent the space from collapsing or expanding".
    """
    def __init__(self, power=2):
        super().__init__()
        self.power = power

    def forward(self, x):
        exponent = self.power
        # p-norm along dim 1, kept as a size-1 dimension so it broadcasts against x
        magnitude = torch.pow(torch.sum(torch.pow(x, exponent), 1, keepdim=True), 1. / exponent)
        # the small constant protects against a division by zero
        return x / (magnitude + 1e-7)

def _single_conv(ch_in, ch_out, ks, stride=1, act=True, gammaZero=False, norm='instance', transpose=False, leaky=False):
    """
    Layer to perform a single convolution, normalization (batch or instance) and activation function (relu and leaky relu)
    """
    # do not reduce size due to ks mismatch
    padding = ks//2
    if not transpose:
        layers = [nn.Conv2d(ch_in, ch_out, ks, stride=stride, padding=padding)]
    else:
        layers = [nn.ConvTranspose2d(ch_in, ch_out, ks, stride=stride, padding=padding)]

    # add norm layer to prevent activations from getting too high
    if norm=='instance':
        norm_layer = nn.InstanceNorm2d(ch_out, affine=False, track_running_stats=False)
    elif norm=='batch':
        norm_layer = nn.BatchNorm2d(ch_out, affine=True, track_running_stats=True)
    else:
        raise Exception(f'Norm should be either "instance" or "batch" but {norm} was passed')

    if gammaZero and norm_layer=='batch':
        # init batch norm gamma param to zero to speed up training 
        nn.init.zeros_(norm_layer.weight.data)

    layers.append(norm_layer)
    # check if this layer should have an activation - yes unless the final layer
    if act and not leaky:
        layers.append(nn.ReLU(inplace=True))
    elif act and leaky:
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    
    layers = nn.Sequential(*layers)
    return layers

class ResBlock(nn.Module):
    """
    Bottleneck residual block shared by the generator and the discriminator.

    The output is activation(conv_path(x) + shortcut(x)). The shortcut average-pools the
    input when stride != 1 and applies a 1x1 convolution when the channel count changes.
    """
    def __init__(self, ch_in, ch_out, stride=1, leaky=False):
        super().__init__()
        self.conv = self._resblock_conv(ch_in, ch_out, leaky, stride=stride)

        # spatial part of the shortcut: pass-through unless the conv path downsamples
        if stride == 1:
            self.pool = self._return
        else:
            self.pool = nn.AvgPool2d(stride, ceil_mode=True)

        # channel part of the shortcut: pass-through unless the channel count changes
        if ch_in == ch_out:
            self.id_conv = self._return
        else:
            self.id_conv = _single_conv(ch_in, ch_out, 1, stride=1, act=False)

        self.relu = nn.LeakyReLU(0.2, inplace=True) if leaky else nn.ReLU(inplace=True)

    def _return(self, x):
        # pass-through used wherever the shortcut needs no transformation
        return x

    def _resblock_conv(self, ch_in, ch_out, leaky, stride=1, ks=3):
        # bottleneck path: 1x1 reduce -> ks x ks (carries the stride) -> 1x1 expand without activation
        narrow = ch_out//4
        reduce = _single_conv(ch_in, narrow, 1, leaky=leaky)
        spatial = _single_conv(narrow, narrow, ks, stride=stride, leaky=leaky)
        expand = _single_conv(narrow, ch_out, 1, act=False, gammaZero=True)
        return nn.Sequential(reduce, spatial, expand)

    def forward(self, x):
        # residual path first, then the skip connection, then the shared activation
        residual = self.conv(x)
        shortcut = self.id_conv(self.pool(x))
        return self.relu(residual + shortcut)

class Generator(nn.Module):
    """
    Create a generator model (both encoder and decoder) which uses a resnet architecture as the encoder and transposed 2d convolutions for decoder

    Note that sometimes only certain layers will be taken from the encoder
    """
    def __init__(self, in_ch, out_ch, nce_layers, base_ch=64, n_blocks=6, n_downsamples=2):
        """
        Args:
            in_ch: Number of input channels in input tensor image
            out_ch: Number of output channels for output tensor image
            nce_layers: A list of layers that the patchNCE loss function will be using from the encoder
            base_ch: Base number of channels throughout network
            n_blocks: Number of residual blocks to use
            n_downsamples: Number of downsamples to apply in the network stem
        """
        # network attributes
        super().__init__()
        self.n_downsamples = n_downsamples
        self.n_blocks = n_blocks
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.base_ch = base_ch

        # create resnet encoder stem to downsample data n_downsample times and reach initial filter param
        # with n_downsamples=2 the sizes are: [in_ch, base_ch//2, base_ch//2, base_ch, base_ch*2, base_ch*4]
        stem_sizes = self._create_stem_sizes()
        self.stem = self._create_stem(stem_sizes)

        # Add residual layers for resnet encoder
        self.res_layers = self._create_res_layers()

        # the encoder is the stem followed by the residual blocks; layer ids used for NCE index into it
        self.encoder = nn.Sequential(*self.stem+self.res_layers)

        # the decoder mirrors the stem sizes, except that it must end on out_ch rather than in_ch
        anti_stem_sizes = stem_sizes[::-1]
        anti_stem_sizes[-1] = self.out_ch
        self.decoder = self._create_decoder(anti_stem_sizes)

        # determine the layer channels for the Hl (feature extractor) network
        self.nce_layers = nce_layers
        self._determine_layer_channels(nce_layers)

    def _create_stem_sizes(self):
        """Returns the channel sizes of the stem, from the input channels to the encoder width."""
        sizes = [self.in_ch, self.base_ch//2, self.base_ch//2, self.base_ch]
        # add the extra sizes as done in the following repo: https://github.com/taesungp/contrastive-unpaired-translation/blob/57430e99df041515c57a7ffd18bb7cbc3c1af0a9/models/networks.py#L914
        for i in range(self.n_downsamples):
            mult = 2 ** i
            sizes.append(self.base_ch*mult*2)
        return sizes


    def _create_stem(self, sizes):
        """Returns the list of stem layers; the first n_downsamples of them use stride 2."""
        stem = [
            _single_conv(sizes[i], sizes[i+1], 3, stride = 2 if i < self.n_downsamples else 1)
            for i in range(len(sizes)-1)
        ]
        return stem

    def _create_res_layers(self):
        """Returns n_blocks residual blocks that keep the channel count of the stem output."""
        # create multiplication factor
        mult = 2**self.n_downsamples
        layers = [ResBlock(self.base_ch*mult, self.base_ch*mult) for _ in range(self.n_blocks)]
        return layers

    def _create_decoder(self, sizes):
        """
        Returns the decoder: n_downsamples stride-2 transposed convolutions (kernel 3),
        then stride-1 convolutions (kernel 4), then a final Tanh.
        """
        decoder = []
        for i in range(len(sizes)-1):
            upsampling = i < self.n_downsamples
            decoder.append(_single_conv(
                sizes[i], sizes[i+1],
                3 if upsampling else 4,
                stride = 2 if upsampling else 1,
                transpose = upsampling,
            ))
        decoder.append(nn.Tanh())
        return nn.Sequential(*decoder)

    def _determine_layer_channels(self, layers):
        """
        Determines the channels for the output layers that generator encoder will produce

        Stores, for every encoder layer whose index is in `layers`, the out_channels of its
        last convolution in `self.feature_extractor_channels` (in encoder order).

        Raises:
            ValueError: if a selected layer does not have the expected structure. Failing here
                avoids recording a missing or stale channel count for that layer.
        """
        # create a list of channels that will be used by the feature extractor
        self.feature_extractor_channels = []
        for layer_id, layer in enumerate(self.encoder):
            # only add the channels to the feature extractor channels if it is in the list of layers
            if layer_id not in layers:
                continue
            try:
                if hasattr(layer, 'conv'):
                    # residual block: the last _single_conv of the bottleneck path defines the output channels
                    conv_out = layer.conv[2][0]
                else:
                    # stem layer: the convolution is the first element of the sequential
                    conv_out = layer[0]
            except (IndexError, TypeError, AttributeError) as err:
                raise ValueError(f"Encoder layer {layer_id} has a different configuration than expected") from err
            if not isinstance(conv_out, nn.Conv2d):
                raise ValueError(f"Encoder layer {layer_id} does not expose a Conv2d where one was expected")
            self.feature_extractor_channels.append(conv_out.out_channels)


    def forward(self, x, layers=None, encode_only=False):
        """
        Generator forward pass.

        Args:
            x: input batch
            layers: currently ignored; the layers given at construction (`self.nce_layers`) are always used
            encode_only: when True only the encoder is applied

        Returns:
            - encode_only and nce_layers non-empty: list with the activations of the encoder layers in nce_layers
            - encode_only and nce_layers empty: the encoder output
            - otherwise: the decoder output
        """
        # TODO: Consider removing the layers input var
        layers = self.nce_layers
        # only output specific generator encoder layers
        if len(layers) > 0 and encode_only:
            encoder_layer_outs = [] # list of activation maps when one index corresponds to a layer of the encoder
            for layer_id, layer in enumerate(self.encoder):
                # compute the output for the layer
                x = layer(x)
                # only add the layer activation map to the output if in the list of layers (this will be used as one of the layers by PatchNCELoss)
                if layer_id in layers:
                    encoder_layer_outs.append(x)
            return encoder_layer_outs

        # first apply encoder - only reaches this part if layers is empty or the full pass is requested
        enc_x = self.encoder(x)
        # return only encoder results for patchNCELoss
        if encode_only:
            return enc_x

        # second apply decoder
        dec_x = self.decoder(enc_x)
        return dec_x


class Disciminator(nn.Module):
    """
    Create a discriminator model to tell the difference between real and fake images (assumes 128*128 input images)

    The output is a single channel map of patch predictions (PatchGAN style).
    """
    def __init__(self, ch_in, base_ch=64, n_layers=3, n_downsamples=3):
        """
        Args:
            ch_in: number of channels of the input images
            base_ch: channels produced by the first convolution
            n_layers: number of widening layers (the first convolution counts as one) and of narrowing blocks
            n_downsamples: layers with index below this value use stride 2
        """
        super().__init__()
        self.ch_in = ch_in
        self.base_ch = base_ch
        self.n_layers = n_layers
        self.n_downsamples = n_downsamples
        self.convs = self._create_conv_discriminator()

    def _create_conv_discriminator(self):
        # start with a 1x1 conv assuming  a 128*128 input image; convert to base_ch channels
        convs = [_single_conv(self.ch_in, self.base_ch, 1, stride=2, leaky=True)]

        # add multiple res blocks to reduce the 128*128 input
        ch_mult = 1
        # first layer was already applied above; apply all others
        for i in range(1, self.n_layers):
            ch_mult_prev = ch_mult
            # set the multiplier to a max of 8 or 2**current layer
            ch_mult = min(2**i, 8)
            convs += [
                ResBlock(self.base_ch * ch_mult_prev, self.base_ch * ch_mult, stride=2 if i < self.n_downsamples else 1, leaky=True)
            ]

        # narrow the channels again, halving them with every block
        curr_ch = self.base_ch * ch_mult
        for _ in range(self.n_layers):
            convs += [
                ResBlock(curr_ch, curr_ch//2, leaky=True)
            ]
            curr_ch = curr_ch//2

        # output a single channel feature map of activations from the discriminator (from the Patch GAN paper)
        # The head is a plain convolution: instance-normalising a 1-channel map would standardise every
        # prediction map to zero mean / unit variance and erase the real-vs-fake level the losses rely on.
        # It stays wrapped in nn.Sequential so the parameter names match the previous layout.
        convs += [nn.Sequential(nn.Conv2d(curr_ch, 1, 3, stride=1, padding=1))]
        return nn.Sequential(*convs)

    def forward(self, x):
        # apply the convolutional discriminator
        out = self.convs(x)
        return out



class EncoderFeatureExtractor(nn.Module):
    """
    Create a MLP (multilayer perceptron) to transform the patch features from output and input into shared feature space
    Approach is taken from SimCLR: https://arxiv.org/pdf/2002.05709.pdf
    """
    def __init__(self, nce_layer_channels, n_features=256):
        """
        Creates the MLP network (H sub l in the paper) that will transform input encoder patches to a shared embedding space

        Args: 
            nce_layer_channels: A list containing the size of the channels for each of the layers that will be used by nce loss
            n_features: An integer which is the number of features that the transformed space will have
        """
        super().__init__()
        self.norm = Normalize(2)
        self.nce_layer_channels = nce_layer_channels
        self.n_features = n_features
        # create the mlp for each layer that will be used for nce (based on the channels)
        self.mlps = nn.ModuleList(self._create_mlp()) # define this as a module list for pytorch to treat each module in the list on it's own

    def _create_mlp(self):
        """
        Create a two layer fully connected network for each of the layers with the corresponding channel sizes
        """
        return [
            nn.Sequential(
                nn.Linear(ch_in, self.n_features),
                nn.ReLU(inplace=True),
                nn.Linear(self.n_features, self.n_features),
            )
            for ch_in in self.nce_layer_channels
        ]


    def forward(self, feats, num_patches, patch_ids=None):
        """
        Performs a forward pass for an EncoderFeatureExtractor (called Hl in this paper: https://arxiv.org/pdf/2007.15651.pdf)

        Args: 
            feats: A list containing tensor features passed from specific layer of the generator encoder (assume tensors of size bs*channels*H*W)
            num_patches: The number of patches to sample from feats; zero or negative means every location is used
            patch_ids: The indexes of patches to select from each layer of feats (this is != None when a forward call has been used on the source images and we want to take the same patches from the target images)
        
        Returns: 
            A tuple (return_feats, return_ids) of lists with one entry per layer.
            return_feats holds the embedded, unit-normalized features: shape (bs*patches, n_features)
            when patches are sampled, (bs, n_features, H, W) when every location is used.
            return_ids holds the location indexes used for that layer ([] when every location is used).
        """
        return_ids = []
        return_feats = []
        take_all = num_patches <= 0

        # go through each of the feature layers while grabbing corresponding mlp
        for feat_id, (mlp, feat) in enumerate(zip(self.mlps, feats)):
            batch, _, h, w = feat.shape
            # reshape the feature tensor to be (bs*img_locs*channels)
            feat_reshaped = feat.flatten(2, 3).permute(0, 2, 1)

            if take_all:
                # the number of patches is zero or negative; take all patches
                patch_id = []
                selected = feat_reshaped
            else:
                # get the patch id for the current layer if the patch_ids exist
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                # create the patches if the patch ids DNE
                else:
                    # get random permutation of all indices from 0 to max img loc (axis=1)
                    patch_id = rng.permutation(feat_reshaped.shape[1])
                    # index the patch_id to extract first num_patch locations
                    if num_patches < len(patch_id):
                        patch_id = patch_id[:num_patches]
                # index only the patches that will be used from feats_reshaped (axis 1 is the img_loc axis)
                # note that in practice, bs=1 so that we only compare batches in the same image (internal patches outperforms external patches)
                selected = feat_reshaped[:, patch_id, :]

            # flatten to (bs*locations, channels) so that dim 1 is the feature axis for both the mlp and the norm.
            # flatten copies when needed, so it also works on the permuted (non-contiguous) tensor
            feat_patch_sampled = selected.flatten(0, 1)

            # apply the mlp (H sub l) to select patches
            feat_sampled_emb_space = mlp(feat_patch_sampled)
            # "We normalize vectors onto a unit sphere to prevent the space from collapsing or expanding"
            feat_norm = self.norm(feat_sampled_emb_space)
            return_ids.append(patch_id)

            # if all locations are used, return in as similar a shape as possible (but in dimensions of shared embedding space after passing through mlp)
            if take_all:
                n_emb_feat = feat_norm.shape[-1]
                # (bs*img_loc, emb_features) -> (bs, img_loc, emb_features) -> (bs, emb_features, h, w)
                # reshape (not view) is required after the permute because the result is not contiguous
                feat_norm = feat_norm.view(batch, h * w, n_emb_feat).permute(0, 2, 1).reshape(batch, n_emb_feat, h, w)
            return_feats.append(feat_norm)

        return return_feats, return_ids

        
    

In [15]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class DGANLoss(nn.Module):
    """
    GAN loss applied to the discriminator's grid of real/fake predictions.

    Modes:
        'lsgan': mean squared error against smoothed labels
        'vanilla': binary cross entropy with logits against smoothed labels
        'non-saturating': softplus based loss, averaged per batch item
    """
    def __init__(self, mode, device):
        super().__init__()
        self.mode = mode
        self.device=device

        # criterion used when the predictions are compared against a target tensor
        if mode == 'non-saturating':
            self.loss = None
        elif mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(f"The mode {mode} for DGANLoss is not implemented")

    def create_targ_tensor(self, inp, is_real):
        # smoothed labels: 0.9 stands for real, 0.1 for fake
        label = 0.9 if is_real else 0.1
        single = torch.Tensor([label]).to(self.device)
        # broadcast the single value to the shape of the prediction grid
        return single.expand_as(inp)


    def forward(self, x, is_real):
        if self.mode not in ['lsgan', 'vanilla']:
            # non-saturating loss: softplus(-x) shrinks as real predictions grow,
            # softplus(x) shrinks as fake predictions decrease
            per_activation = F.softplus(-x) if is_real else F.softplus(x)
            # one value per batch item: the mean over that item's grid of activations
            return per_activation.view(x.shape[0], -1).mean(1)

        # compare against a target tensor of the same shape as the predictions
        target = self.create_targ_tensor(x, is_real)
        return self.loss(x, target)

# TODO: Ask prof for help understanding math behind this loss function
class PatchNCELoss(nn.Module):
    """
    The patch NCE loss to associate similar sections in source and target images

    Every patch of `real_feats` is classified against the patch at the same location in
    `fake_feats` (the positive) and the patches at all other locations of the same image
    (the negatives) with a cross entropy over the similarities divided by `tau`.
    """
    def __init__(self, tau, bs):
        """
        Args:
            tau: temperature; the similarities are divided by it before the cross entropy
            bs: batch size used to regroup the patches per image
        """
        super().__init__()
        self.loss = nn.CrossEntropyLoss(reduction='none')
        # assume bs=1 by default
        self.bs = bs
        # division factor for scaling outputs
        self.tau = tau

    def forward(self, real_feats, fake_feats):
        """
        Args:
            real_feats: tensor of shape (bs*patches_per_image, features); gradients flow through it
            fake_feats: tensor of the same shape; it is detached, so no gradient flows through it

        Returns:
            tensor of shape (bs*patches_per_image,) with one loss value per patch
        """
        n_patches = real_feats.shape[0]
        n_transformed_space = real_feats.shape[1]
        # no gradient flows through the second set of features
        fake_feats = fake_feats.detach()

        # positive logit: dot product of each patch with the patch at the same location -> (n_patches, 1)
        l_pos = torch.bmm(real_feats.view(n_patches, 1, -1), fake_feats.view(n_patches, -1, 1)).view(n_patches, 1)

        # regroup the patches per image so negatives only come from the same image
        real_feats = real_feats.view(self.bs, -1, n_transformed_space)
        fake_feats = fake_feats.view(self.bs, -1, n_transformed_space)
        # the new number of patches is the number of patches by image
        n_patches = real_feats.shape[1]
        # negative logits per image: (n_patches, n_transformed_space) * (n_transformed_space, n_patches) = (n_patches, n_patches)
        l_neg_batch = torch.bmm(real_feats, fake_feats.transpose(2, 1))
        # the diagonal holds the positive pairs again; mask it with a low value so it does not count as a negative
        diag = torch.eye(n_patches, device=real_feats.device).bool()
        l_neg_batch.masked_fill_(diag, -10.0)
        # one row per patch over the whole batch, matching the rows of l_pos.
        # (view(n_patches, -1) only had the right number of rows when bs == 1)
        l_neg = l_neg_batch.view(-1, n_patches)

        # concat over patch dimension (1 pos and n_patch neg patches)
        out = torch.cat([l_pos, l_neg], dim=1) / self.tau
        # the target feature is always the l_pos feature of out which is at index 0
        targs = torch.zeros(out.shape[0], device=real_feats.device).long()
        loss = self.loss(out, targs)
        return loss




In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
# from create_gen_discr import Generator, Disciminator, EncoderFeatureExtractor
# from losses import DGANLoss, PatchNCELoss
# from data_utility import set_requires_grad
import os

class CUT_gan(nn.Module):
    """
    CUT model: a generator, and (in training mode) a discriminator plus the feature
    network used by the patch NCE loss, together with their losses and optimizers.
    """
    def __init__(self, lambda_gan, lambda_nce, nce_layers, device, lr, nce_idt=True, encoder_net_features=256, nce_tau=0.07, num_patches=256, train=True, gan_l_type='non-saturating', bs=1):
        """
        Creates a CUT model which is a type of GAN for image to image translation

        Args: 
            lambda_gan: The weight for the GAN loss for the generator (since generator loss depends on gan and nce loss)
            lambda_nce: The weight for the NCE loss for the generator (since generator loss depends on gan and nce loss)
            nce_layers: A list of layers that the generator encoder will return for the nce_loss from convolutional layer (can also be residual blocks) activations
            device: torch.Device('cpu') if a cpu is used to train on otherwise torch.Device('cuda:0') if gpu is avaiable
            lr: The learning rate for stepping the weights of all of the optimizers
            nce_idt: True if the loss consists of the identity loss NCE(Y, X-tilde), False otherwise
            encoder_net_features: The number of features that the EncoderFeatureExtractor will produce in it's new space
            nce_tau: A constant that the nce loss will use to scale the matrices by in the loss
            num_patches: The number of pathces that will be used by nce to compute the loss
            train: True if training False if evaluating/inferencing
            gan_l_type: The type of dgan loss to be used (either 'non-saturating', 'vanilla', or 'lsgan')
            bs: The batch size that is going to be used in training
        """
        super().__init__()
        # keep relevant attributes for the training loop
        self.device = device
        self.lambda_gan = lambda_gan
        self.lambda_nce = lambda_nce
        self.nce_idt = nce_idt
        self.num_patches = num_patches
        self.nce_layers = nce_layers
        # remember the requested mode. `self.train` is the nn.Module method (always truthy),
        # so the flag needs its own attribute
        self.is_train = train
        # define the generator for the CUT model to go from rgb -> rgb image
        self.gen = Generator(3, 3, nce_layers).to(self.device)

        if train:
            # define a discriminator to take 3 input channels with 4 residual blocks
            self.disc = Disciminator(3, n_layers=4).to(self.device)
            # define a feature extractor network H sub l to transform generator encoder features to a new embedding space for nce loss
            self.feat_net = EncoderFeatureExtractor(self.gen.feature_extractor_channels, n_features=encoder_net_features).to(self.device)

            # define loss functions
            self.dgan_loss = DGANLoss(gan_l_type, self.device).to(self.device)
            self.nce_losses = []
            for _ in nce_layers:
                self.nce_losses.append(PatchNCELoss(nce_tau, bs).to(self.device))

            # create adam optimizers
            self.gen_optim = optim.Adam(self.gen.parameters(), lr=lr)
            self.disc_optim = optim.Adam(self.disc.parameters(), lr=lr)
            self.feat_net_optim = optim.Adam(self.feat_net.parameters(), lr=lr)


    def train(self, mode=True):
        """
        Set the networks to training mode (or to eval mode when mode is False)

        Delegates to nn.Module.train, which applies the mode to every registered
        sub-network, so it also works when only the generator exists.
        """
        return super().train(mode)


    def eval(self):
        """
        Set every network that exists to eval mode
        """
        return self.train(False)

    def save_nets(self, epoch, folder='models'):
        """
        Save the network params on the cpu for all 3 networks

        Files are written to `folder` as `<epoch>_gen.pth`, `<epoch>_disc.pth` and `<epoch>_feat_net.pth`.
        """
        gen_checkpoint = {'state_dict': self.gen.cpu().state_dict()}
        disc_checkpoint = {'state_dict': self.disc.cpu().state_dict()}
        feat_checkpoint = {'state_dict': self.feat_net.cpu().state_dict()}

        # move the nets back to the current device
        self.gen.to(self.device)
        self.disc.to(self.device)
        self.feat_net.to(self.device)

        os.makedirs(folder, exist_ok=True)

        get_path = lambda model_name: os.path.join(folder, f"{epoch}_{model_name}.pth")
        torch.save(gen_checkpoint, get_path('gen'))
        torch.save(disc_checkpoint, get_path('disc'))
        torch.save(feat_checkpoint, get_path('feat_net'))

    def load_gen(self, epoch, folder='models'):
        """
        Loads `<folder>/<epoch>_gen.pth` into a new Generator (on the cpu) and returns it.
        The generator held by this model is left untouched.
        """
        checkpoint = torch.load(os.path.join(folder, f"{epoch}_gen.pth"), map_location=torch.device('cpu'))
        gen = Generator(3, 3, self.nce_layers)
        gen.load_state_dict(checkpoint['state_dict'])
        return gen


    def forward(self, real_src, real_targ=None):
        """
        Does a forward pass of the generator for training and inference.

        Saves the real source, real target, and fake target images.
        Also, if nce_idt and train are True, saves the fake source images. 
        """
        # save the current real src and targ images to pass to other functions
        self.real_src = real_src
        self.real_targ = real_targ
        # identity comparison: `!= None` on a tensor is not a reliable None check
        use_identity = self.is_train and self.nce_idt and real_targ is not None
        if use_identity:
            # put the real source and target images if in training and using identity nce loss
            real = torch.cat((real_src, real_targ), dim=0)
        else:
            real = real_src

        # use the generator on real images
        fake = self.gen(real)
        # get fake target images (y hat)
        self.fake_targ = fake[:real_src.shape[0]]
        # if possible, get fake source images for identity loss (x tilde)
        if use_identity:
            self.fake_src = fake[real_src.shape[0]:]


    def optimize_params(self, real_src, real_targ, discriminator_train=1):
        """
        Forward pass, loss, back propagate, and step for all 3 networks to optmize all the params
        """
        # forward pass
        self.forward(real_src, real_targ)

        # discriminator param update
        for _ in range(discriminator_train):
            set_requires_grad(self.disc, True)
            self.disc_optim.zero_grad()
            self.loss_d = self.calc_d_loss()
            self.loss_d.backward()
            self.disc_optim.step()

        # generator and encoder feature extractor param update
        set_requires_grad(self.disc, False)
        self.gen_optim.zero_grad()
        self.feat_net_optim.zero_grad()
        self.loss_g = self.calc_g_loss()
        self.loss_g.backward()
        self.gen_optim.step()
        self.feat_net_optim.step()


    def calc_d_loss(self):
        """
        Calculates discriminator loss
        """
        # prevent generator from updating
        fake_targ = self.fake_targ.detach()
        # fake target loss; .mean() reduces the per-item values of the non-saturating mode to a scalar
        fake_pred = self.disc(fake_targ)
        self.fake_d_loss = self.dgan_loss(fake_pred, False).mean()
        # real target loss
        real_pred = self.disc(self.real_targ)
        self.real_d_loss = self.dgan_loss(real_pred, True).mean()
        # combine both fake and real target loss
        return (self.fake_d_loss + self.real_d_loss) * 0.5


    def calc_g_loss(self):
        """
        Calculates generator loss
        """
        # the generator is rewarded when the discriminator labels its images as real,
        # so the adversarial term uses the real label
        pred_fake = self.disc(self.fake_targ)
        self.gan_g_loss = self.dgan_loss(pred_fake, True).mean() * self.lambda_gan

        # use patch NCE loss for src -> fake targ
        self.nce_loss = self.calc_nce_loss(self.real_src, self.fake_targ)
        if self.nce_idt:
            # use patch NCE loss for targ -> fake source (identity loss)
            self.nce_identity_loss = self.calc_nce_loss(self.real_targ, self.fake_src)
            # get total nce loss
            nce_loss_total = (self.nce_loss + self.nce_identity_loss) * 0.5
        else:
            # no identity pass was run, so there are no fake source images to compare
            self.nce_identity_loss = torch.zeros((), device=self.device)
            nce_loss_total = self.nce_loss
        # get total loss (Lgan + NCE loss + identity NCE loss)
        loss_total = nce_loss_total + self.gan_g_loss
        return loss_total


    def calc_nce_loss(self, src, targ):
        """
        Calculates the NCE loss using patches to associate similar locations and dissociate different locations
        """
        # get the patches for source after doing H sub l(G enc(x))
        src_feats = self.gen(src, encode_only=True)
        transformed_src_feats, patch_ids = self.feat_net(src_feats, self.num_patches)

        # get the patches for target after doing H sub l(G enc(G(x))), reusing the same patch locations
        targ_feats = self.gen(targ, encode_only=True)
        transformed_targ_feats, _ = self.feat_net(targ_feats, self.num_patches, patch_ids=patch_ids)

        total_loss = 0
        # calculate the loss for each layer in the transformed returned features
        for src_feat, targ_feat, nce_loss in zip(transformed_src_feats, transformed_targ_feats, self.nce_losses):
            # TODO: Consider switching src_feats and targ_feats if training is not working well
            total_loss += (nce_loss(targ_feat, src_feat) * self.lambda_nce).mean()

        return total_loss/len(self.gen.nce_layers)


In [None]:
# from data_utility import Data, save_images, report_losses
# from cut_model import CUT_gan
import torch
import time
from tqdm import tqdm

# ---- settings handed to the CUT gan model ----
lambda_gan, lambda_nce = 1, 1  # weights applied to the GAN and NCE loss terms
nce_layers = [0, 1, 2, 3, 4, 6, 8, 10]
enc_net_feats = 16
num_patches = 512

# ---- hardware: run on the GPU when one is visible, otherwise on the CPU ----
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

# ---- optimisation ----
lr = 2e-3  # use the lr as recommended by the paper
gan_l_type = 'lsgan'  # GAN loss variant passed to the model
bs = 1

# ---- schedule and logging ----
save_every = 100  # keep the input batch of every save_every-th step for the end-of-epoch preview
epochs = 30  # number of training epochs
# display names for the accumulated losses, in the order main() collects them
loss_names = [
    'DLoss', 'FakeDLoss', 'RealDLoss',
    'GLoss', 'GANGLoss', 'NCELoss', 'NCEIdentityLoss',
]

def main():
    """
    Trains the CUT model on the apples/oranges data and saves a preview image per epoch.

    Uses the module-level settings (lambda_gan, lambda_nce, nce_layers, device,
    lr, gan_l_type, bs, num_patches, enc_net_feats, epochs, save_every and
    loss_names). After every epoch the networks are saved, the accumulated
    losses are reported, and the sampled inputs are written next to their
    translations in ``ep{epoch}.png``.

    Raises:
        RuntimeError: if a training loader was not created or yields no batches.
    """
    data = Data('apples', 'oranges', False, bs, (128, 128))
    data.get_loaders('apples_and_oranges')

    # get_loaders only fills in loaders for directories it finds, so fail
    # with a clear message instead of an opaque TypeError further down
    if data.dlSourceTrain is None or data.dlTargetTrain is None:
        raise RuntimeError('training loaders were not created; check the data directory names')

    # zip() stops at the shorter loader, so that is the real number of steps per epoch
    steps_per_epoch = min(len(data.dlSourceTrain), len(data.dlTargetTrain))
    if steps_per_epoch == 0:
        raise RuntimeError('training loaders are empty; nothing to train on')

    # define the CUT gan model which has all 3 nets and training loop for the 3 nets
    cut_model = CUT_gan(lambda_gan, lambda_nce, nce_layers, device, lr, gan_l_type=gan_l_type, bs=bs, num_patches=num_patches, encoder_net_features=enc_net_feats)

    # model attributes holding the latest loss values, in the same order as loss_names
    loss_attrs = ['loss_d', 'fake_d_loss', 'real_d_loss', 'loss_g', 'gan_g_loss', 'nce_loss', 'nce_identity_loss']

    for epoch in range(epochs):
        # running sums of every tracked loss for this epoch
        loss_list = [0] * len(loss_attrs)
        num_batches = 0
        start_ep = time.time()
        x_check = []

        # back to train mode once per epoch (eval mode is set for the preview below)
        cut_model.train()
        with tqdm(total=steps_per_epoch) as pbar:
            for i, ((x, _), (y, _)) in enumerate(zip(data.dlSourceTrain, data.dlTargetTrain)):
                # move the image tensors onto the correct device
                x = x.to(device)
                y = y.to(device)
                # train model
                cut_model.optimize_params(x, y)
                # update losses (.item() yields plain python numbers)
                for idx, attr in enumerate(loss_attrs):
                    loss_list[idx] += getattr(cut_model, attr).item()
                num_batches += 1
                # save image to show in results at end of epoch
                if i % save_every == 0:
                    x_check.append(x)
                # the bar total counts batches, so advance one step per batch
                pbar.update(1)

        ep_time = time.time() - start_ep
        cut_model.save_nets(epoch) # save all 3 networks
        # output training loss and time
        report_losses(loss_list, loss_names, epoch, ep_time, num_batches)

        # output visuals; no gradients are needed for the preview forward pass
        cut_model.eval()
        x_check = torch.cat(x_check, dim=0)
        with torch.no_grad():
            cut_model(x_check)
        x_fake = cut_model.fake_targ
        # put the images together to save them
        ims = torch.cat((x_check, x_fake), dim=3)
        save_images(ims, f'ep{epoch}.png')

main()


Device: cuda


 26%|██▌       | 258/995 [01:07<03:08,  3.91it/s]