In [25]:
!pip install torch tqdm albumentations numpy torchvision tensorflow matplotlib



In [27]:
import torch


class ConvBlock(torch.nn.Module):
    def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, activation='relu', batch_norm=True):
        super(ConvBlock, self).__init__()
        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        self.lrelu = torch.nn.LeakyReLU(0.2, True)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        if self.batch_norm:
            out = self.bn(self.conv(x))
        else:
            out = self.conv(x)

        if self.activation == 'relu':
            return self.relu(out)
        elif self.activation == 'lrelu':
            return self.lrelu(out)
        elif self.activation == 'tanh':
            return self.tanh(out)
        elif self.activation == 'no_act':
            return out


class DeconvBlock(torch.nn.Module):
    def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, output_padding=1, activation='relu', batch_norm=True):
        super(DeconvBlock, self).__init__()
        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, output_padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)

    def forward(self, x):
        if self.batch_norm:
            out = self.bn(self.deconv(x))
        else:
            out = self.deconv(x)

        if self.activation == 'relu':
            return self.relu(out)
        elif self.activation == 'lrelu':
            return self.lrelu(out)
        elif self.activation == 'tanh':
            return self.tanh(out)
        elif self.activation == 'no_act':
            return out


class ResnetBlock(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=3, stride=1, padding=0):
        super(ResnetBlock, self).__init__()
        conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
        conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
        bn = torch.nn.InstanceNorm2d(num_filter)
        relu = torch.nn.ReLU(True)
        pad = torch.nn.ReflectionPad2d(1)

        self.resnet_block = torch.nn.Sequential(
            pad,
            conv1,
            bn,
            relu,
            pad,
            conv2,
            bn
        )

    def forward(self, x):
        out = self.resnet_block(x)
        return out


class Generator(torch.nn.Module):
    def __init__(self, input_dim, num_filter, output_dim, num_resnet):
        super(Generator, self).__init__()

        # Reflection padding
        self.pad = torch.nn.ReflectionPad2d(3)
        # Encoder
        self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0)
        self.conv2 = ConvBlock(num_filter, num_filter * 2)
        self.conv3 = ConvBlock(num_filter * 2, num_filter * 4)
        # Resnet blocks
        self.resnet_blocks = []
        for i in range(num_resnet):
            self.resnet_blocks.append(ResnetBlock(num_filter * 4))
        self.resnet_blocks = torch.nn.Sequential(*self.resnet_blocks)
        # Decoder
        self.deconv1 = DeconvBlock(num_filter * 4, num_filter * 2)
        self.deconv2 = DeconvBlock(num_filter * 2, num_filter)
        self.deconv3 = ConvBlock(num_filter, output_dim,
                                 kernel_size=7, stride=1, padding=0, activation='tanh', batch_norm=False)

    def forward(self, x):
        # Encoder
        enc1 = self.conv1(self.pad(x))
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        # Resnet blocks
        res = self.resnet_blocks(enc3)
        # Decoder
        dec1 = self.deconv1(res)
        dec2 = self.deconv2(dec1)
        out = self.deconv3(self.pad(dec2))
        return out

    def normal_weight_init(self, mean=0.0, std=0.02):
        for m in self.children():
            if isinstance(m, ConvBlock):
                torch.nn.init.normal(m.conv.weight, mean, std)
            if isinstance(m, DeconvBlock):
                torch.nn.init.normal(m.deconv.weight, mean, std)
            if isinstance(m, ResnetBlock):
                torch.nn.init.normal(m.conv.weight, mean, std)
                torch.nn.init.constant(m.conv.bias, 0)


class Discriminator(torch.nn.Module):
    def __init__(self, input_dim, num_filter, output_dim):
        super(Discriminator, self).__init__()

        conv1 = ConvBlock(input_dim, num_filter, kernel_size=4, stride=2, padding=1, activation='lrelu', batch_norm=False)
        conv2 = ConvBlock(num_filter, num_filter * 2, kernel_size=4, stride=2, padding=1, activation='lrelu')
        conv3 = ConvBlock(num_filter * 2, num_filter * 4, kernel_size=4, stride=2, padding=1, activation='lrelu')
        conv4 = ConvBlock(num_filter * 4, num_filter * 8, kernel_size=4, stride=1, padding=1, activation='lrelu')
        conv5 = ConvBlock(num_filter * 8, output_dim, kernel_size=4, stride=1, padding=1, activation='no_act', batch_norm=False)

        self.conv_blocks = torch.nn.Sequential(
            conv1,
            conv2,
            conv3,
            conv4,
            conv5
        )

    def forward(self, x):
        out = self.conv_blocks(x)
        return out

    def normal_weight_init(self, mean=0.0, std=0.02):
        for m in self.children():
            if isinstance(m, ConvBlock):
                torch.nn.init.normal(m.conv.weight, mean, std)

In [45]:
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
import random


# For logger
def to_np(x):
    return x.data.cpu().numpy()


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


# De-normalization
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)


# Plot losses
def plot_loss(avg_losses, num_epochs, save=False, save_dir='results/', show=False):
    fig, ax = plt.subplots()
    ax.set_xlim(0, num_epochs)
    temp = 0.0
    for i in range(len(avg_losses)):
        temp = max(np.max(avg_losses[i]), temp)
    ax.set_ylim(0, temp*1.1)
    plt.xlabel('# of Epochs')
    plt.ylabel('Loss values')

    plt.plot(avg_losses[0], label='D_A')
    plt.plot(avg_losses[1], label='D_B')
    plt.plot(avg_losses[2], label='G_A')
    plt.plot(avg_losses[3], label='G_B')
    plt.plot(avg_losses[4], label='cycle_A')
    plt.plot(avg_losses[5], label='cycle_B')
    plt.legend()

    # save figure
    if save:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        save_fn = save_dir + 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png'
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()


def plot_train_result(real_image, gen_image, recon_image, epoch, save=False, save_dir='results/', show=False, fig_size=(5, 5)):
    import matplotlib.pyplot as plt
    import os
    import numpy as np

    fig, axes = plt.subplots(2, 3, figsize=fig_size)

    imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]),
            to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        ax.set_adjustable('box')  # Corrected here
        # Scale to 0-255
        img = img.squeeze()
        img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)  # Changed to makedirs for nested directory creation

        save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch+1) + '.png'
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()



def plot_test_result(real_image, gen_image, recon_image, index, save=False, save_dir='results/', show=False):
    import matplotlib.pyplot as plt
    import os
    import numpy as np

    fig_size = (real_image.size(2) * 3 / 100, real_image.size(3) / 100)
    fig, axes = plt.subplots(1, 3, figsize=fig_size)

    imgs = [to_np(real_image), to_np(gen_image), to_np(recon_image)]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        ax.set_adjustable('box')  # Corrected here
        # Scale to 0-255
        img = img.squeeze()
        img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    # save figure
    if save:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)  # Changed to makedirs for nested directory creation

        save_fn = save_dir + 'Test_result_{:d}'.format(index + 1) + '.png'
        fig.subplots_adjust(bottom=0)
        fig.subplots_adjust(top=1)
        fig.subplots_adjust(right=1)
        fig.subplots_adjust(left=0)
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()


# Make gif
def make_gif(dataset, num_epochs, save_dir='results/'):
    gen_image_plots = []
    for epoch in range(num_epochs):
        # plot for generating gif
        save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png'
        gen_image_plots.append(imageio.imread(save_fn))

    imageio.mimsave(save_dir + dataset + '_CycleGAN_epochs_{:d}'.format(num_epochs) + '.gif', gen_image_plots, fps=5)


class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images.data:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size-1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images

In [46]:
from torchvision import transforms
import albumentations as A

dataset = 'summer2winter'
batch_size = 1
ngf = 32
ndf = 64
num_resnet = 6
input_size = 256
resize_scale = 286
crop_size = 256
fliplr = True
num_epochs = 10
decay_epoch = 100
lrG = 0.0002  # Learning rate for generator
lrD = 0.0002  # Learning rate for discriminator
lambdaA = 10  # Weight for cycle loss A to B
lambdaB = 10  # Weight for cycle loss B to A
beta1 = 0.5   # Beta1 hyperparameter for the Adam optimizer
beta2 = 0.999 # Beta2 hyperparameter for the Adam optimizer

# Directories for loading data and saving results
data_dir = '/kaggle/input/summer2winter-yosemite/'
save_dir = dataset + '_results/'
model_dir = dataset + '_model/'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)
if not os.path.exists(model_dir):
    os.mkdir(model_dir)

# Data pre-processing
transform = transforms.Compose([transforms.Resize(input_size),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

In [31]:
from PIL import Image
import torch.utils.data as data
import os
import random


class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False):
        super(DatasetFromFolder, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = [x for x in sorted(os.listdir(self.input_path))]
        self.transform = transform
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    def __getitem__(self, index):
        # Load Image
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        img = Image.open(img_fn).convert('RGB')

        # preprocessing
        if self.resize_scale:
            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)

        if self.crop_size:
            x = random.randint(0, self.resize_scale - self.crop_size + 1)
            y = random.randint(0, self.resize_scale - self.crop_size + 1)
            img = img.crop((x, y, x + self.crop_size, y + self.crop_size))
        if self.fliplr:
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.image_filenames)

In [41]:
import tensorflow as tf
import numpy as np
import scipy.misc

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x


class Logger(object):
    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.create_file_writer(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        with self.writer.as_default():
            tf.summary.scalar(name=tag, data=value, step=step)
            self.writer.flush()

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        with self.writer.as_default():
            for i, img in enumerate(images):
                # Convert to a format suitable for logging
                if img.dtype != np.uint8:
                    img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)
                
                img = np.expand_dims(img, axis=0)  # Add batch dimension if necessary
                tf.summary.image(name=f"{tag}/{i}", data=img, step=step)
            self.writer.flush()


    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

In [47]:
import torch
from torchvision import transforms
from torch.autograd import Variable

import argparse
import os, itertools

import numpy as np


# Train data
train_data_A = DatasetFromFolder(data_dir, subfolder='trainA', transform=transform,
                                 resize_scale=resize_scale, crop_size=crop_size, fliplr=fliplr)
train_data_loader_A = torch.utils.data.DataLoader(dataset=train_data_A,
                                                  batch_size=batch_size,
                                                  shuffle=True)
train_data_B = DatasetFromFolder(data_dir, subfolder='trainB', transform=transform,
                                 resize_scale=resize_scale, crop_size=crop_size, fliplr=fliplr)
train_data_loader_B = torch.utils.data.DataLoader(dataset=train_data_B,
                                                  batch_size=batch_size,
                                                  shuffle=True)

# Test data
test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform)
test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A,
                                                 batch_size=batch_size,
                                                 shuffle=False)
test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform)
test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
                                                 batch_size=batch_size,
                                                 shuffle=False)

# Get specific test images
test_real_A_data = test_data_A.__getitem__(11).unsqueeze(0)  # Convert to 4d tensor (BxNxHxW)
test_real_B_data = test_data_B.__getitem__(91).unsqueeze(0)

# Models
G_A = Generator(3, ngf, 3, num_resnet)
G_B = Generator(3, ngf, 3, num_resnet)
D_A = Discriminator(3, ndf, 1)
D_B = Discriminator(3, ndf, 1)
G_A.normal_weight_init(mean=0.0, std=0.02)
G_B.normal_weight_init(mean=0.0, std=0.02)
D_A.normal_weight_init(mean=0.0, std=0.02)
D_B.normal_weight_init(mean=0.0, std=0.02)
G_A.cuda()
G_B.cuda()
D_A.cuda()
D_B.cuda()


# Set the logger
D_A_log_dir = save_dir + 'D_A_logs'
D_B_log_dir = save_dir + 'D_B_logs'
if not os.path.exists(D_A_log_dir):
    os.mkdir(D_A_log_dir)
D_A_logger = Logger(D_A_log_dir)
if not os.path.exists(D_B_log_dir):
    os.mkdir(D_B_log_dir)
D_B_logger = Logger(D_B_log_dir)

G_A_log_dir = save_dir + 'G_A_logs'
G_B_log_dir = save_dir + 'G_B_logs'
if not os.path.exists(G_A_log_dir):
    os.mkdir(G_A_log_dir)
G_A_logger = Logger(G_A_log_dir)
if not os.path.exists(G_B_log_dir):
    os.mkdir(G_B_log_dir)
G_B_logger = Logger(G_B_log_dir)

cycle_A_log_dir = save_dir + 'cycle_A_logs'
cycle_B_log_dir = save_dir + 'cycle_B_logs'
if not os.path.exists(cycle_A_log_dir):
    os.mkdir(cycle_A_log_dir)
cycle_A_logger = Logger(cycle_A_log_dir)
if not os.path.exists(cycle_B_log_dir):
    os.mkdir(cycle_B_log_dir)
cycle_B_logger = Logger(cycle_B_log_dir)

img_log_dir = save_dir + 'img_logs'
if not os.path.exists(img_log_dir):
    os.mkdir(img_log_dir)
img_logger = Logger(img_log_dir)


# Loss function
MSE_loss = torch.nn.MSELoss().cuda()
L1_loss = torch.nn.L1Loss().cuda()

# optimizers
G_optimizer = torch.optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lrG, betas=(beta1, beta2))
D_A_optimizer = torch.optim.Adam(D_A.parameters(), lr=lrD, betas=(beta1, beta2))
D_B_optimizer = torch.optim.Adam(D_B.parameters(), lr=lrD, betas=(beta1, beta2))

# Training GAN
D_A_avg_losses = []
D_B_avg_losses = []
G_A_avg_losses = []
G_B_avg_losses = []
cycle_A_avg_losses = []
cycle_B_avg_losses = []

# Generated image pool
num_pool = 50
fake_A_pool = ImagePool(num_pool)
fake_B_pool = ImagePool(num_pool)

step = 0
for epoch in range(num_epochs):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    cycle_A_losses = []
    cycle_B_losses = []

    # learning rate decay
    if (epoch + 1) > decay_epoch:
        D_A_optimizer.param_groups[0]['lr'] -= lrD / (num_epochs - decay_epoch)
        D_B_optimizer.param_groups[0]['lr'] -= lrD / (num_epochs - decay_epoch)
        G_optimizer.param_groups[0]['lr'] -= lrG / (num_epochs - decay_epoch)

    # training
    for i, (real_A, real_B) in enumerate(zip(train_data_loader_A, train_data_loader_B)):

        # input image data
        real_A = Variable(real_A.cuda())
        real_B = Variable(real_B.cuda())

        # Train generator G
        # A -> B
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = MSE_loss(D_B_fake_decision, Variable(torch.ones(D_B_fake_decision.size()).cuda()))

        # forward cycle loss
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_loss(recon_A, real_A) * lambdaA

        # B -> A
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = MSE_loss(D_A_fake_decision, Variable(torch.ones(D_A_fake_decision.size()).cuda()))

        # backward cycle loss
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_loss(recon_B, real_B) * lambdaB

        # Back propagation
        G_loss = G_A_loss + G_B_loss + cycle_A_loss + cycle_B_loss
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Train discriminator D_A
        D_A_real_decision = D_A(real_A)
        D_A_real_loss = MSE_loss(D_A_real_decision, Variable(torch.ones(D_A_real_decision.size()).cuda()))
        fake_A = fake_A_pool.query(fake_A)
        D_A_fake_decision = D_A(fake_A)
        D_A_fake_loss = MSE_loss(D_A_fake_decision, Variable(torch.zeros(D_A_fake_decision.size()).cuda()))

        # Back propagation
        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_optimizer.zero_grad()
        D_A_loss.backward()
        D_A_optimizer.step()

        # Train discriminator D_B
        D_B_real_decision = D_B(real_B)
        D_B_real_loss = MSE_loss(D_B_real_decision, Variable(torch.ones(D_B_real_decision.size()).cuda()))
        fake_B = fake_B_pool.query(fake_B)
        D_B_fake_decision = D_B(fake_B)
        D_B_fake_loss = MSE_loss(D_B_fake_decision, Variable(torch.zeros(D_B_fake_decision.size()).cuda()))

        # Back propagation
        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_optimizer.zero_grad()
        D_B_loss.backward()
        D_B_optimizer.step()

        # loss values
        D_A_losses.append(D_A_loss.item())  # Corrected
        D_B_losses.append(D_B_loss.item())  # Corrected
        G_A_losses.append(G_A_loss.item())  # Corrected
        G_B_losses.append(G_B_loss.item())  # Corrected
        cycle_A_losses.append(cycle_A_loss.item())  # Corrected
        cycle_B_losses.append(cycle_B_loss.item())  # Corrected

        print('Epoch [%d/%d], Step [%d/%d], D_A_loss: %.4f, D_B_loss: %.4f, G_A_loss: %.4f, G_B_loss: %.4f'
              % (epoch + 1, num_epochs, i + 1, len(train_data_loader_A), D_A_loss.item(), D_B_loss.item(), G_A_loss.item(), G_B_loss.item()))

        # ============ TensorBoard logging ============#
        D_A_logger.scalar_summary('losses', D_A_loss.item(), step + 1)
        D_B_logger.scalar_summary('losses', D_B_loss.item(), step + 1)
        G_A_logger.scalar_summary('losses', G_A_loss.item(), step + 1)
        G_B_logger.scalar_summary('losses', G_B_loss.item(), step + 1)
        cycle_A_logger.scalar_summary('losses', cycle_A_loss.item(), step + 1)
        cycle_B_logger.scalar_summary('losses', cycle_B_loss.item(), step + 1)
        step += 1

    D_A_avg_loss = torch.mean(torch.FloatTensor(D_A_losses))
    D_B_avg_loss = torch.mean(torch.FloatTensor(D_B_losses))
    G_A_avg_loss = torch.mean(torch.FloatTensor(G_A_losses))
    G_B_avg_loss = torch.mean(torch.FloatTensor(G_B_losses))
    cycle_A_avg_loss = torch.mean(torch.FloatTensor(cycle_A_losses))
    cycle_B_avg_loss = torch.mean(torch.FloatTensor(cycle_B_losses))

    # avg loss values for plot
    D_A_avg_losses.append(D_A_avg_loss)
    D_B_avg_losses.append(D_B_avg_loss)
    G_A_avg_losses.append(G_A_avg_loss)
    G_B_avg_losses.append(G_B_avg_loss)
    cycle_A_avg_losses.append(cycle_A_avg_loss)
    cycle_B_avg_losses.append(cycle_B_avg_loss)

    # Show result for test image
    test_real_A = Variable(test_real_A_data.cuda())
    test_fake_B = G_A(test_real_A)
    test_recon_A = G_B(test_fake_B)

    test_real_B = Variable(test_real_B_data.cuda())
    test_fake_A = G_B(test_real_B)
    test_recon_B = G_A(test_fake_A)

    plot_train_result([test_real_A, test_real_B], [test_fake_B, test_fake_A], [test_recon_A, test_recon_B],
                            epoch, save=True, save_dir=save_dir)

    # log the images
    result_AtoB = np.concatenate((to_np(test_real_A), to_np(test_fake_B), to_np(test_recon_A)), axis=3)
    result_BtoA = np.concatenate((to_np(test_real_B), to_np(test_fake_A), to_np(test_recon_B)), axis=3)

    info = {
        'result_AtoB': result_AtoB.transpose(0, 2, 3, 1),  # convert to BxHxWxC
        'result_BtoA': result_BtoA.transpose(0, 2, 3, 1)
    }

    for tag, images in info.items():
        img_logger.image_summary(tag, images, epoch + 1)


# Plot average losses
avg_losses = []
avg_losses.append(D_A_avg_losses)
avg_losses.append(D_B_avg_losses)
avg_losses.append(G_A_avg_losses)
avg_losses.append(G_B_avg_losses)
avg_losses.append(cycle_A_avg_losses)
avg_losses.append(cycle_B_avg_losses)
plot_loss(avg_losses, num_epochs, save=True, save_dir=save_dir)

# Make gif
make_gif(dataset, num_epochs, save_dir=save_dir)
# Save trained parameters of model
torch.save(G_A.state_dict(), model_dir + 'generator_A_param.pkl')
torch.save(G_B.state_dict(), model_dir + 'generator_B_param.pkl')
torch.save(D_A.state_dict(), model_dir + 'discriminator_A_param.pkl')
torch.save(D_B.state_dict(), model_dir + 'discriminator_B_param.pkl')


  torch.nn.init.normal(m.conv.weight, mean, std)
  torch.nn.init.normal(m.deconv.weight, mean, std)


Epoch [1/10], Step [1/1231], D_A_loss: 1.0925, D_B_loss: 0.5151, G_A_loss: 0.8837, G_B_loss: 2.0017
Epoch [1/10], Step [2/1231], D_A_loss: 1.0271, D_B_loss: 1.7351, G_A_loss: 0.6730, G_B_loss: 1.0747
Epoch [1/10], Step [3/1231], D_A_loss: 1.5336, D_B_loss: 1.2694, G_A_loss: 0.3942, G_B_loss: 0.6698
Epoch [1/10], Step [4/1231], D_A_loss: 0.7682, D_B_loss: 0.5051, G_A_loss: 0.2315, G_B_loss: 0.3622
Epoch [1/10], Step [5/1231], D_A_loss: 0.4940, D_B_loss: 0.6609, G_A_loss: 0.3766, G_B_loss: 0.4724
Epoch [1/10], Step [6/1231], D_A_loss: 1.1676, D_B_loss: 0.7108, G_A_loss: 0.7239, G_B_loss: 0.5016
Epoch [1/10], Step [7/1231], D_A_loss: 0.6887, D_B_loss: 0.4278, G_A_loss: 0.6237, G_B_loss: 0.5183
Epoch [1/10], Step [8/1231], D_A_loss: 0.3674, D_B_loss: 0.3847, G_A_loss: 0.4840, G_B_loss: 0.5221
Epoch [1/10], Step [9/1231], D_A_loss: 0.3989, D_B_loss: 0.3313, G_A_loss: 0.4088, G_B_loss: 0.4489
Epoch [1/10], Step [10/1231], D_A_loss: 0.3836, D_B_loss: 0.3505, G_A_loss: 0.3604, G_B_loss: 0.3905

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 15.89 GiB of which 38.12 MiB is free. Process 4601 has 15.86 GiB memory in use. Of the allocated memory 1.87 GiB is allocated by PyTorch, and 187.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF