In [0]:
# ! git clone https://github.com/brownvc/ganimorph.git

In [3]:
import argparse
import os
import numpy as np
import math
import itertools
import datetime
import time
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# from pytorch_msssim.pytorch_msssim import MSSSIM, msssim

The model uses a multi-scale structural similarity (MS-SSIM) loss. Most of the code below is adapted from pytorch_msssim.

In [0]:
# from pytorch_msssim.pytorch_msssim with some modification if necessary
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length `window_size` that sums to 1.

    The kernel is centred on index `window_size // 2` and uses standard
    deviation `sigma`.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    taps = torch.Tensor(
        [exp(-(pos - center) ** 2 / two_sigma_sq) for pos in range(window_size)]
    )
    return taps / taps.sum()


def create_window(window_size, channel=1):
    """Build a (channel, 1, window_size, window_size) Gaussian window.

    The 2-D kernel is the outer product of a 1-D Gaussian (sigma = 1.5) with
    itself, repeated once per channel so it can be used as a grouped
    convolution weight.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity between two image batches.

    Args:
        img1, img2: 4-D tensors; the channel count is read from dim 1 of `img1`.
        window_size: requested side of the Gaussian window; only used when
            `window` is None, and clamped to the image height/width.
        window: optional pre-built grouped-convolution window.
        size_average: if True return the mean over the whole SSIM map,
            otherwise the mean over dims 1..3 (one value per batch item).
        full: if True also return the mean contrast-structure term.
        val_range: dynamic range of the pixel values; when None it is guessed
            from the min/max of `img1`.

    Returns:
        The SSIM value(s), or a `(ssim, cs)` tuple when `full` is True.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is not None:
        dynamic_range = val_range
    else:
        upper = 255 if torch.max(img1) > 128 else 1
        lower = -1 if torch.min(img1) < -0.5 else 0
        dynamic_range = upper - lower

    _, channel, height, width = img1.size()
    if window is None:
        effective_size = min(window_size, height, width)
        window = create_window(effective_size, channel=channel).to(img1.device)

    def local_mean(tensor):
        # Per-channel windowed average; no padding, so the output map shrinks.
        return F.conv2d(tensor, window, padding=0, groups=channel)

    mean1 = local_mean(img1)
    mean2 = local_mean(img2)

    mean1_sq = mean1.pow(2)
    mean2_sq = mean2.pow(2)
    mean1_mean2 = mean1 * mean2

    var1 = local_mean(img1 * img1) - mean1_sq
    var2 = local_mean(img2 * img2) - mean2_sq
    covar = local_mean(img1 * img2) - mean1_mean2

    # Stabilising constants. NOTE(review): 0.03/0.05 differ from the commonly
    # used K1=0.01, K2=0.03 -- presumably an intentional modification; confirm.
    c1 = (0.03 * dynamic_range) ** 2
    c2 = (0.05 * dynamic_range) ** 2

    cs_num = 2.0 * covar + c2
    cs_den = var1 + var2 + c2
    cs = torch.mean(cs_num / cs_den)  # contrast sensitivity

    ssim_map = ((2 * mean1_mean2 + c1) * cs_num) / ((mean1_sq + mean2_sq + c1) * cs_den)

    result = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)

    return (result, cs) if full else result


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
    """Multi-scale structural similarity between two image batches.

    SSIM and its contrast-structure term are evaluated at five scales; after
    each scale both images are downsampled by 2x2 average pooling. The result
    combines the contrast-structure terms of the first four scales with the
    full SSIM of the coarsest scale, each raised to its per-scale weight.

    Args:
        img1, img2: 4-D image tensors, forwarded to `ssim`.
        window_size: Gaussian window side passed to `ssim`.
        size_average: forwarded to `ssim`.
        val_range: dynamic range forwarded to `ssim` (None = guess per call).
        normalize: if True, map the per-scale values from [-1, 1] to [0, 1]
            before exponentiation so negative values cannot produce NaNs.

    Returns:
        The MS-SSIM score as a tensor.
    """
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)

        # Halve the resolution for the next scale.
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = mssim ** weights
    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # The coarsest-scale SSIM term must be applied exactly once. Multiplying it
    # inside the product (as the code previously did) raised it to the power
    # of (levels - 1).
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output


# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around `ssim` that caches the Gaussian window.

    Args:
        window_size: side of the Gaussian window.
        size_average: forwarded to `ssim`.
        val_range: dynamic range forwarded to `ssim`; None lets `ssim` guess
            it from the input.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        """Return the SSIM of `img1` and `img2`, rebuilding the window if needed."""
        (_, channel, _, _) = img1.size()

        window = self.window
        # The cached window is only valid for the same channel count, dtype
        # and device as the input; otherwise the grouped convolution fails.
        reusable = (
            channel == self.channel
            and window.dtype == img1.dtype
            and window.device == img1.device
        )
        if not reusable:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        # Pass the configured val_range through; it used to be stored but ignored.
        return ssim(
            img1,
            img2,
            window=window,
            window_size=self.window_size,
            size_average=self.size_average,
            val_range=self.val_range,
        )

class MSSSIM(torch.nn.Module):
    """Module wrapper that evaluates `msssim` with a fixed configuration.

    `channel` is stored on the instance but is not used by `forward`.
    """

    def __init__(self, window_size=11, size_average=True, normalize = False, channel=3):
        super().__init__()
        self.window_size, self.size_average = window_size, size_average
        self.channel, self.normalize = channel, normalize

    def forward(self, img1, img2):
        """Return the multi-scale SSIM of `img1` and `img2`."""
        # TODO: store window between calls if possible
        options = dict(
            window_size=self.window_size,
            size_average=self.size_average,
            normalize=self.normalize,
        )
        return msssim(img1, img2, **options)


In [0]:
def weights_init_normal(m):
    """Initialise a module in place; intended for use with `Module.apply`.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02). Modules whose class name contains "BatchNorm2d" get weights
    drawn from N(1, 0.02) and zero bias. Every other module is left untouched.
    """
    layer_name = m.__class__.__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
#     elif classname.find("Disc_Block") != -1:
#         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
#         torch.nn.init.constant_(m.bias.data, 0.0)
#     elif classname.find("Dilated_Block") != -1:
#         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
#         torch.nn.init.constant_(m.bias.data, 0.0)
#     elif classname.find("InstanceNorm2d") != -1:
#         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
#         torch.nn.init.constant_(m.bias.data, 0.0)

Utilities for training

In [0]:
class ReplayBuffer:
    """Fixed-capacity pool of previously seen samples.

    While the pool is not full every incoming sample is stored and returned
    unchanged. Once it is full, each incoming sample is, with probability
    one half, swapped with a randomly chosen stored sample (the stored one is
    returned); otherwise the incoming sample is returned directly.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the buffer and return a batch of equal size."""
        batch_out = []
        for sample in data.data:
            sample = sample.unsqueeze(0)

            # Still filling up: store the sample and pass it straight through.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch_out.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Hand back an old sample and keep the new one in its slot.
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)
        return Variable(torch.cat(batch_out))


class LambdaLR:
    """Linear learning-rate decay factor.

    `step` returns 1.0 until `epoch + offset` reaches `decay_start_epoch`,
    then decreases linearly so that it would reach 0.0 at `n_epochs`.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        decay_span = n_epochs - decay_start_epoch
        assert decay_span > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative learning-rate factor for `epoch`."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / (self.n_epochs - self.decay_start_epoch)


The structure of the model is written here

In [None]:


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions whose output is concatenated with the input and
    fused back to `out_features` channels by a third 3x3 convolution.

    All convolutions use stride 1 and padding 1, so the spatial size is kept.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        first_conv = nn.Conv2d(in_features, out_features, 3, 1, 1)
        second_conv = nn.Conv2d(out_features, out_features, 3, 1, 1)
        self.block1 = nn.Sequential(first_conv, second_conv)

        # Consumes the channel-wise concatenation of block1's output and the input.
        fuse_conv = nn.Conv2d(in_features + out_features, out_features, 3, 1, 1)
        self.block2 = nn.Sequential(fuse_conv)

    def forward(self, x):
        stacked = torch.cat((self.block1(x), x), 1)
        return self.block2(stacked)
        

In [0]:
class Generator(nn.Module):
    """Encoder-decoder generator with residual blocks and two skip connections.

    The input is downsampled four times by stride-2 convolutions (to 1/16 of
    the input resolution) and upsampled four times by stride-2 transposed
    convolutions. The outputs of `block1` and `block2` are concatenated back
    into the decoder at matching resolutions.

    Args:
        channels: number of channels of both the input and the output image.
    """

    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            """Stride-2 conv (+ optional instance norm) + LeakyReLU(0.2)."""
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, normalize=True):
            """Stride-2 transposed conv (+ optional instance norm) + ReLU."""
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                # NOTE(review): the second positional argument of InstanceNorm2d
                # is `eps`, so this sets eps=0.8. Kept as-is to preserve the
                # behaviour of existing checkpoints -- confirm it is intended.
                layers.append(nn.InstanceNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers

        def residual(in_feat, depth = 3):
            """`depth` channel-preserving residual blocks."""
            return [ResidualBlock(in_feat, in_feat) for _ in range(depth)]

        nf = 64
        d = 3

        self.block1 = nn.Sequential(
                    nn.Conv2d(channels, nf, 4, 2, 1),
                    nn.LeakyReLU(),
                    *downsample(nf, 2*nf),
                    *residual(2*nf, d))

        self.block2 = nn.Sequential(
                    *downsample(2*nf, 4*nf),
                    *residual(4*nf, d)
                    )

        self.block3 = nn.Sequential(
                    *downsample(4*nf, 8*nf),
                    *residual(8*nf, d),
                    *upsample(8*nf, 4*nf),
                    )

        # Input is block3's output concatenated with block2's output (8*nf channels).
        self.block4 = nn.Sequential(
                    ResidualBlock(8*nf, 4*nf),
                    *residual(4*nf, d-1),
                    *upsample(4*nf, 2*nf)
                    )

        # Input is block4's output concatenated with block1's output (4*nf channels).
        self.block5 = nn.Sequential(
                    ResidualBlock(4*nf, 2*nf),
                    *residual(2*nf, d-1),
                    *upsample(2*nf, nf),
                    # Emit `channels` output channels (previously hard-coded to 3).
                    nn.ConvTranspose2d(nf, channels, 4, 2, 1),
        #                     nn.Sigmoid()
                    nn.Tanh()
        )

    def forward(self, x):
        """Map an image batch to an image batch of the same shape, in [-1, 1]."""
        y1 = self.block1(x)
        y2 = self.block2(y1)
        y3 = self.block3(y2)
        y4 = torch.cat((y3, y2), 1)
        y5 = self.block4(y4)
        y6 = torch.cat((y5, y1), 1)
        out = self.block5(y6)
        return out

In [0]:
class Disc_Block(nn.Module):
    """Convolution, optional instance normalisation, then in-place LeakyReLU(0.2).

    With the default kernel 4 / stride 2 / padding 1 the block halves the
    spatial resolution.
    """

    def __init__(self, in_features, out_features, normalize = True, kernel_size = 4, stride = 2, padding = 1):
        super().__init__()

        conv = nn.Conv2d(in_features, out_features, kernel_size, stride, padding)
        norm = [nn.InstanceNorm2d(out_features)] if normalize else []
        activation = nn.LeakyReLU(0.2, inplace=True)

        self.block = nn.Sequential(conv, *norm, activation)

    def forward(self, x):
        return self.block(x)

class Dilated_Block(nn.Module):
    """Dilated convolution, optional instance normalisation, then in-place LeakyReLU(0.2).

    The padding is set equal to the dilation, so with the default 3x3 kernel
    and stride 1 the spatial size is preserved.
    """

    def __init__(self, in_features, out_features, dilation = 2, normalize = True, kernel_size = 3, stride = 1):
        super().__init__()

        conv = nn.Conv2d(in_features, out_features, kernel_size, stride, padding = dilation, dilation = dilation)
        norm = [nn.InstanceNorm2d(out_features)] if normalize else []
        activation = nn.LeakyReLU(0.2, inplace=True)

        self.block = nn.Sequential(conv, *norm, activation)

    def forward(self, x):
        return self.block(x)

In [0]:
class Discriminator(nn.Module):
    """Patch discriminator with a stack of dilated convolutions.

    Three stride-2 blocks reduce the input to 1/8 resolution. Three dilated
    blocks (dilation 2, 4, 8) follow, and their output is concatenated with
    the pre-dilation features before the final convolutions.

    Args:
        channels: number of channels of the input image.
    """

    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        nf = 64

        # Use the requested channel count (previously hard-coded to 3).
        self.layer1 = Disc_Block(channels, 2*nf, normalize = False)
        #self.layer1 = nn.Conv2d(3, 2*nf, 4, 2, 1)
        self.layer2 = Disc_Block(2*nf, 4*nf)
        self.layer3 = Disc_Block(4*nf, 8*nf)
        self.layer4 = Disc_Block(8*nf, 8*nf, True, 3, 1, 1)
        self.layer5 = Dilated_Block(8*nf, 8*nf, 2)
        self.layer6 = Dilated_Block(8*nf, 8*nf, 4)
        self.layer7 = Dilated_Block(8*nf, 8*nf, 8)
        # layer9 takes the concatenation of layer4 and layer7 outputs (16*nf channels).
        self.layer9 = Disc_Block(16*nf, 8*nf, True, 3, 1, 1)
        # Final 1-channel logit map; no activation is applied here.
        self.layer10 = nn.Sequential(
                      nn.Conv2d(8*nf, 1, 3, 1, 1),
                      nn.Identity()
        )

    def forward(self, x):
        """Return `(logits, features)`.

        `logits` is the single-channel patch map; `features` is a tuple of
        the intermediate activations of layers 2-7 and 9.
        """
        l1 = self.layer1(x)
        l2 = self.layer2(l1)
        l3 = self.layer3(l2)
        l4 = self.layer4(l3)
        l5 = self.layer5(l4)
        l6 = self.layer6(l5)
        l7 = self.layer7(l6)
        l8 = torch.cat((l4, l7), 1)
        l9 = self.layer9(l8)
        l10 = self.layer10(l9)

        return l10, (l2, l3, l4, l5, l6, l7, l9)

Losses need to be normalised to stabilise training

In [0]:
def normalise_loss(loss, update_condition, epsilon = 1e-10):
    """Divide `loss` by a reference magnitude.

    The reference is 1 unless `update_condition` is true, in which case it is
    one moving-average step (decay 0.9999) from 1 towards `loss`.

    NOTE: no state is kept between calls -- the reference always starts from
    1 -- and when `update_condition` is true the divisor itself depends on
    `loss`.
    """
    baseline = 1
    baseline_smooth = 1

    # hard-coded the implementation of tf.python.training.moving_averages(variable, value, decay)
    # variable -= (1 - decay) * (variable - value)
    decay = 0.9999
    moved_baseline = baseline_smooth - (1 - decay)*(baseline_smooth - loss)

    divisor = moved_baseline if update_condition else baseline

    return loss/(divisor + epsilon)

Dataset class that loads data in the form required by the model: each item contains one image from domain A and one image from domain B.

In [0]:
import glob
import random

from torch.utils.data import Dataset
from PIL import Image


def to_rgb(image):
    """Return a new "RGB" image of the same size with `image` pasted onto it."""
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas


class ImageDataset(Dataset):
    """Pairs images from `<root>/<mode>A` with images from `<root>/<mode>B`.

    Args:
        root: dataset directory containing the `<mode>A` and `<mode>B` folders.
        transforms_: list of transforms, composed and applied to both images.
        unaligned: if True the B image is drawn at random instead of by index.
        mode: folder prefix, e.g. "train" or "test".
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        folder_A = os.path.join(root, "%sA" % mode)
        folder_B = os.path.join(root, "%sB" % mode)
        self.files_A = sorted(glob.glob(folder_A + "/*.*"))
        self.files_B = sorted(glob.glob(folder_B + "/*.*"))

    def __getitem__(self, index):
        """Return `{"A": ..., "B": ...}` with both images transformed."""
        path_A = self.files_A[index % len(self.files_A)]
        image_A = Image.open(path_A)

        # B is either a random pick (unaligned) or index-matched, wrapping around.
        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        image_B = Image.open(path_B)

        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        return {"A": self.transform(image_A), "B": self.transform(image_B)}

    def __len__(self):
        # The longer folder sets the length; the shorter one wraps around.
        return max(len(self.files_A), len(self.files_B))


The cell below sets up training on the cat_dog_face dataset. To train on another dataset, or to run from the command line, change the arguments passed to `parse_args`.

In [5]:
# Training configuration. Arguments are parsed from an explicit list (not
# sys.argv) so that the cell works inside a notebook kernel.
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
# n_epochs must be greater than decay_epoch, otherwise the LambdaLR helper's
# assertion fails when the schedulers are built. The previous default of 1
# (with decay_epoch=5) could not run.
parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="cat_dog_face", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=5, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=128, help="size of image height")
parser.add_argument("--img_width", type=int, default=128, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=10, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
# parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
# parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args(["--dataset_name", "cat_dog_face"])
print(opt)

# Training setup: output directories, losses, networks, optimizers, learning-rate
# schedulers, replay buffers and data loaders. Relies on `opt` and on the classes
# and functions defined in the cells above.

# Create sample and checkpoint directories
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# Shape of one discriminator output map: 1 channel at 1/8 of the image size
# (the Discriminator has three stride-2 blocks).
patch = (1, opt.img_height // 2**3, opt.img_width // 2**3)

# Losses
msssim_loss = MSSSIM(window_size = 8, normalize = True)
l1_loss = nn.L1Loss()
# The discriminator returns raw logits, so the sigmoid lives inside the loss.
gan_loss = nn.BCEWithLogitsLoss()

cuda = torch.cuda.is_available()

input_shape = (opt.channels, opt.img_height, opt.img_width)

# Initialize generator and discriminator
# G_AB maps domain A -> B and G_BA maps B -> A; D_A / D_B judge images of A / B.
G_AB = Generator()
G_BA = Generator()
D_A = Discriminator()
D_B = Discriminator()

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    msssim_loss.cuda()
    l1_loss.cuda()
    gan_loss.cuda()

if opt.epoch != 0:
    # Load pretrained models
    # NOTE(review): torch.load is called without map_location; presumably the
    # checkpoints are loaded on the same device type they were saved from -- confirm.
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# Optimizers
# Both generators share a single optimizer; each discriminator has its own.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
# The notebook's own LambdaLR helper supplies the decay factor; its constructor
# asserts that opt.n_epochs is greater than opt.decay_epoch.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

# Tensor type used to move batches onto the training device.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Image transformations
# Resize slightly larger than the target, random-crop back to the target size,
# random horizontal flip, then scale pixel values to [-1, 1].
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training data loader
# The dataset root is the dataset name itself, relative to the working directory.
dataloader = DataLoader(
    ImageDataset("%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
# Uses the same (random) transforms as training; batches of 5 feed sample_images.
val_dataloader = DataLoader(
    ImageDataset("%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)


def sample_images(batches_done):
    """Saves a generated sample from the test set.

    Draws one batch from `val_dataloader`, translates it in both directions
    and writes a grid (rows: real A, fake B, real B, fake A) to
    `images/<dataset_name>/<batches_done>.png`.

    Args:
        batches_done: global batch counter, used as the output file name.
    """
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    # Inference only: skip building the autograd graph for the generator passes.
    with torch.no_grad():
        real_A = imgs["A"].type(Tensor)
        fake_B = G_AB(real_A)
        real_B = imgs["B"].type(Tensor)
        fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)


Namespace(b1=0.5, b2=0.999, batch_size=32, channels=3, checkpoint_interval=10, dataset_name='cat_dog_face', decay_epoch=5, epoch=0, img_height=128, img_width=128, lr=0.0002, n_cpu=8, n_epochs=10, n_residual_blocks=9, sample_interval=100)


NameError: name 'MSSSIM' is not defined

In [0]:
# Training
#
# Each batch: the two generators are updated on every third batch, and both
# discriminators are updated on every batch using replay-buffered fakes, noisy
# inputs and smoothed labels. Progress is logged per batch, sample grids are
# saved every `opt.sample_interval` batches, and learning rates are stepped and
# checkpoints written per epoch.

prev_time = time.time()


# norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths
#         valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
#         fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)

        batches_done = epoch * len(dataloader) + i
        # `condition` selects the moving-average branch inside normalise_loss.
        condition = (batches_done == 36) or (batches_done % 90 == 0)

        # ------------------
        #  Train Generators
        # ------------------
        # Generators are only updated on every third batch. On the other batches
        # fake_A / fake_B and the generator loss values below keep the values
        # from the most recent generator step.
        if i%3 == 0:
            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Images
            fake_A = G_BA(real_B)        
            fake_B = G_AB(real_A)

            # Cycle reconstructions: A -> B -> A and B -> A -> B.
            recov_A = G_BA(fake_B)
            recov_B = G_AB(fake_A)

            # Discriminator logits plus intermediate features for real and fake images.
            real_A_dis, real_A_feat = D_A(real_A)
            fake_A_dis, fake_A_feat = D_A(fake_A)

            real_B_dis, real_B_feat = D_B(real_B)
            fake_B_dis, fake_B_feat = D_B(fake_B)

            # Lambdas for final arithmetic, taken from the paper
            lambda_gan = 0.49
            lambda_fm = 0.21
            lambda_cyc = 0.3
            lambda_ss = 0.7
            lambda_l1 = 0.3

            # Reconstruction loss
            # Weighted mix of (1 - MS-SSIM) and L1 between inputs and their cycle reconstructions.
            recon_loss_A = 1 - msssim_loss(real_A, recov_A)
            #         print(recon_loss_A)
            recon_loss_A_l = l1_loss(real_A, recov_A)
            #         print(recon_loss_A_l)

            recon_loss_B = 1 - msssim_loss(real_B, recov_B)
            recon_loss_B_l = l1_loss(real_B, recov_B)

            total_recon_loss = recon_loss_A + recon_loss_B
            total_recon_loss_l = recon_loss_A_l + recon_loss_B_l

            total_recon = lambda_ss*total_recon_loss + lambda_l1*total_recon_loss_l
            sln_recon_loss = normalise_loss(total_recon, condition)


            # GAN loss
            # Generators are rewarded when the discriminators label fakes as valid.
            gan_loss_A = gan_loss(fake_A_dis, valid)
            #           print(gan_loss_A)
            #           print(valid.shape)

            gan_loss_B = gan_loss(fake_B_dis, valid)

            total_gan_loss = gan_loss_A + gan_loss_B
            sln_gan_loss = normalise_loss(total_gan_loss, condition)

            # Feature Matching loss
            # NOTE(review): `feature_match_loss` is not defined anywhere in the
            # visible notebook; it presumably lived in a cell that was removed.
            # It must be (re)defined before this cell can run on a fresh kernel.
            fm_loss_A = feature_match_loss(real_A_feat, fake_A_feat)

            fm_loss_B = feature_match_loss(real_B_feat, fake_B_feat)

            total_fm_loss = fm_loss_A + fm_loss_B
            sln_fm_loss = normalise_loss(total_fm_loss, condition)

            # Total loss
            loss_G = lambda_gan*sln_gan_loss + lambda_fm*sln_fm_loss + lambda_cyc*sln_recon_loss

            loss_G.backward()
            optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        D_A.train()

        # Also clears any gradients the generator step accumulated in D_A.
        optimizer_D_A.zero_grad()

        # Real loss
#         loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
#         fake_A_ = fake_A_buffer.push_and_pop(fake_A)
#         loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
#         loss_D_A = (loss_real + loss_fake) / 2

        # Adding gaussian noise to inputs of discriminator
        # Note: real_A is rebound to its noisy version from here on.
        real_A = real_A + Variable(torch.randn(real_A.shape).type(Tensor))
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        fake_A_ = fake_A_ + Variable(torch.randn(fake_A_.shape).type(Tensor))

        # Real loss
        # Smoothed labels: one target value per batch, drawn from [0.8, 1) for
        # real images and from [0, 0.2) for fake images.

        pos_loss_A = gan_loss(D_A(real_A)[0], valid.new_full((real_A.size(0), *patch), np.random.uniform(0.8, 1, 1).item()))
        # Fake loss
        neg_loss_A = gan_loss(D_A(fake_A_.detach())[0], fake.new_full((fake_A_.shape[0], *patch), np.random.uniform(0, 0.2, 1).item()))
        # Total loss
        loss_D_A = (pos_loss_A + neg_loss_A) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        D_B.train()

        optimizer_D_B.zero_grad()

        # Real loss
#         loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
#         fake_B_ = fake_B_buffer.push_and_pop(fake_B)
#         loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
#         loss_D_B = (loss_real + loss_fake) / 2

        # Adding gaussian noise to inputs of discriminator
        real_B = real_B + Variable(torch.randn(real_B.shape).type(Tensor))
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        fake_B_ = fake_B_ + Variable(torch.randn(fake_B_.shape).type(Tensor))

        # Real loss
        pos_loss_B = gan_loss(D_B(real_B)[0], valid.new_full((real_B.size(0), *patch), np.random.uniform(0.8, 1, 1).item()))
        # Fake loss
        neg_loss_B = gan_loss(D_B(fake_B_.detach())[0], fake.new_full((fake_B_.shape[0], *patch), np.random.uniform(0, 0.2, 1).item()))
        # Total loss
        loss_D_B = (pos_loss_B + neg_loss_B) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        # Mean of both discriminator losses; used for logging only.
        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
#         batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        # ETA = remaining batches * duration of the batch that just finished.
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        # The leading "\r" overwrites the previous status line in place.
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, feature matching: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D,
                loss_G,
                total_gan_loss,
                total_recon,
                total_fm_loss,
                time_left
            )
        )

        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    # Update learning rates
    # Schedulers are stepped once per epoch.
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
        torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))

[Epoch 4/10] [Batch 1166/1498] [D loss: 0.327157] [G loss: 2.714543, adv: 4.572423, cycle: 0.297775, feature matching: 1.832016] ETA: 4:25:42.991716

KeyboardInterrupt: ignored

Save the model after training

In [0]:
# Write the current weights of all four networks, tagged with the dataset name
# and the last epoch index reached by the training loop.
for network, path_template in (
    (G_AB, "saved_models/%s/G_AB_%d.pth"),
    (G_BA, "saved_models/%s/G_BA_%d.pth"),
    (D_A, "saved_models/%s/D_A_%d.pth"),
    (D_B, "saved_models/%s/D_B_%d.pth"),
):
    torch.save(network.state_dict(), path_template % (opt.dataset_name, epoch))