# COMP9444 HorseToZebra

## Imports and Initial Setup

In [37]:
from abc import ABC, abstractmethod
from collections import OrderedDict
import cv2
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import functools
import itertools
import importlib
from IPython.display import display, HTML, clear_output
import multiprocessing
from multiprocessing import freeze_support
import matplotlib.pyplot as plt
import numpy as np
import ntpath
import os
import os.path
from PIL import Image
import random
import sys
from subprocess import Popen, PIPE
from scipy.stats import gaussian_kde
import time
import torch
import torch.nn as nn
from torch.nn import init
import torch.utils.data
import torch.utils.data as data
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torchvision.transforms as transforms

# Global run configuration, read by the dataset / network helpers below.
opt = dict(
    # --- Basic parameters ---
    dataroot='./datasets',          # contains <phase>A / <phase>B folders
    name='horse2zebra_attentiongan',
    gpu_ids=[0],
    checkpoints_dir='./checkpoints',

    # --- Model parameters ---
    model='attention_gan',
    input_nc=3,
    output_nc=3,
    ngf=64,
    ndf=64,
    netD='basic',
    netG='resnet_9blocks',
    n_layers_D=3,
    norm='instance',
    init_type='normal',
    init_gain=0.02,
    no_dropout=True,

    # --- Dataset parameters ---
    dataset_mode='unaligned',
    direction='AtoB',
    serial_batches=False,
    num_threads=4,
    batch_size=4,
    load_size=286,
    crop_size=256,
    max_dataset_size=float("inf"),  # inf = use every image found
    preprocess='resize_and_crop',
    no_flip=False,
    display_winsize=256,

    # --- Training parameters ---
    phase='train',
    niter=60,
    niter_decay=0,
    beta1=0.5,
    lr=0.0002,
    gan_mode='lsgan',
    pool_size=50,
    lambda_A=10.0,
    lambda_B=10.0,
    lambda_identity=0.5,
    isTrain=True,
    lr_policy='linear',

    # --- Display / checkpointing parameters ---
    display_freq=100,
    display_id=1,
    display_ncols=10,
    no_html=True,
    display_port=8097,
    display_server="http://localhost",
    display_env='main',
    print_freq=100,
    save_latest_freq=5000,
    save_epoch_freq=5,
    save_by_iter=False,
    continue_train=False,
    epoch_count=1,
    update_html_freq=1000,
    save_result=False,

    # --- Additional ---
    saveDisk=True,
    verbose=True,
)

# Suffixes accepted when scanning dataset folders. Matching is
# case-sensitive, so both lower- and upper-case spellings are listed.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG',
    '.ppm', '.PPM', '.bmp', '.BMP',
]

freeze_support()

# Make sure the per-experiment checkpoint folder exists up front.
output_dir = os.path.join(opt['checkpoints_dir'], opt['name'])
os.makedirs(output_dir, exist_ok=True)

### Dataset Helper Functions

In [38]:
class BaseDataset(data.Dataset, ABC):
    """Abstract base for the datasets in this notebook.

    Concrete subclasses must implement ``__len__`` and ``__getitem__``.
    Configuration comes from the module-level ``opt`` dict rather than
    being passed to the constructor.
    """

    def __init__(self):
        # Keep a handle on the global options and cache the dataset root.
        self.opt = opt
        self.root = self.opt['dataroot']

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for dataset-specific options; hands ``parser`` back unchanged."""
        return parser

    @abstractmethod
    def __len__(self):
        """Number of samples (must be overridden)."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Sample at ``index`` (must be overridden)."""
        return None

def get_params(opt, size):
    """Draw random crop-offset and flip parameters for one image.

    The crop offset is sampled inside the image size that results from the
    resize implied by ``opt['preprocess']``; the flip flag is a coin toss.

    Args:
        opt: option dict; reads 'preprocess', 'load_size' and 'crop_size'.
        size: original ``(width, height)`` of the image.

    Returns:
        ``{'crop_pos': (x, y), 'flip': bool}``.
    """
    width, height = size
    mode = opt['preprocess']
    if mode == 'resize_and_crop':
        width = height = opt['load_size']
    elif mode == 'scale_width_and_crop':
        # Keep the aspect ratio: height scales by the same factor as width.
        height = opt['load_size'] * height // width
        width = opt['load_size']

    crop = opt['crop_size']
    x = random.randint(0, max(0, width - crop))
    y = random.randint(0, max(0, height - crop))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}
    
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the torchvision preprocessing pipeline described by ``opt``.

    Steps, in order:
      1. optional grayscale conversion;
      2. resize to ``load_size`` x ``load_size`` (preprocess contains
         'resize') or scale to ``load_size`` wide (contains 'scale_width');
      3. crop to ``crop_size`` -- random, or at ``params['crop_pos']``;
      4. round the size to a multiple of 4 when preprocess == 'none';
      5. horizontal flip -- random, or as dictated by ``params['flip']``
         (skipped entirely when ``opt['no_flip']``);
      6. when ``convert``: ToTensor + Normalize with mean/std 0.5.

    Args:
        opt: option dict; reads 'preprocess', 'load_size', 'crop_size', 'no_flip'.
        params: optional result of ``get_params``; makes crop/flip deterministic.
        grayscale: convert to one channel first (and normalise one channel).
        method: PIL resampling filter used for resizing.
        convert: append the tensor conversion / normalisation steps.
    """
    preprocess = opt['preprocess']
    steps = []

    if grayscale:
        steps.append(transforms.Grayscale(1))

    if 'resize' in preprocess:
        target = [opt['load_size'], opt['load_size']]
        steps.append(transforms.Resize(target, method))
    elif 'scale_width' in preprocess:
        steps.append(transforms.Lambda(
            lambda img: __scale_width(img, opt['load_size'], method)))

    if 'crop' in preprocess:
        if params is not None:
            steps.append(transforms.Lambda(
                lambda img: __crop(img, params['crop_pos'], opt['crop_size'])))
        else:
            steps.append(transforms.RandomCrop(opt['crop_size']))

    if preprocess == 'none':
        steps.append(transforms.Lambda(
            lambda img: __make_power_2(img, base=4, method=method)))

    if not opt['no_flip']:
        if params is None:
            steps.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        channels = 1 if grayscale else 3
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize((0.5,) * channels, (0.5,) * channels))

    return transforms.Compose(steps)


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are the nearest multiple of ``base``.

    Despite the name this rounds to a *multiple* of ``base``, not a power of
    two. An image that already fits is returned unchanged; otherwise a
    one-time warning is printed and the resized image is returned.
    """
    orig_w, orig_h = img.size
    new_w, new_h = (int(round(side / base) * base) for side in (orig_w, orig_h))
    if (new_w, new_h) == (orig_w, orig_h):
        return img
    __print_size_warning(orig_w, orig_h, new_w, new_h)
    return img.resize((new_w, new_h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to ``target_width`` pixels wide, keeping aspect ratio.

    The new height is truncated to an int. Images already at the target
    width are returned as-is.
    """
    cur_w, cur_h = img.size
    if cur_w != target_width:
        new_h = int(target_width * cur_h / cur_w)
        return img.resize((target_width, new_h), method)
    return img

def __crop(img, pos, size):
    """Crop a ``size`` x ``size`` square whose top-left corner is ``pos``.

    When the image is no larger than ``size`` in both dimensions it is
    returned unchanged.
    """
    width, height = img.size
    left, top = pos
    if width <= size and height <= size:
        return img
    return img.crop((left, top, left + size, top + size))

def __flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy, else return it."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img

def __print_size_warning(ow, oh, w, h):
    """Print the multiple-of-4 resize notice, only on the very first call.

    The "already printed" state lives as an attribute on the function itself.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return
    print("The image size needs to be a multiple of 4. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of 4" % (ow, oh, w, h))
    __print_size_warning.has_printed = True
def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-sensitive.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image file paths under ``dir``.

    Args:
        dir: folder to scan (walked recursively with ``os.walk``).
        max_dataset_size: keep at most this many paths. ``inf`` means no
            limit; finite floats such as ``100.0`` are accepted and truncated.

    Returns:
        List of paths. Directories are visited in sorted path order and file
        names are sorted within each directory, so the order is reproducible.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is unchanged for callers that expect it.
    if not os.path.isdir(dir):
        raise AssertionError('%s is not a valid directory' % dir)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk yields file names in arbitrary order; sort for reproducibility.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))

    # Slice bounds must be ints: a finite float limit used to raise TypeError.
    if max_dataset_size < len(images):
        images = images[:int(max_dataset_size)]
    return images

def default_loader(path):
    """Load the image at ``path`` and return it converted to RGB.

    ``Image.open`` keeps the underlying file open until the image object is
    closed. The ``with`` block closes it once ``convert`` has produced a
    separate in-memory image, instead of leaking one handle per sample.
    """
    with Image.open(path) as img:
        return img.convert('RGB')

class ImageFolder(data.Dataset):
    """Flat dataset over every image found (recursively) under ``root``.

    Each item is ``loader(path)``, passed through ``transform`` when one is
    given. With ``return_paths=True`` items are ``(image, path)`` tuples.

    Raises:
        RuntimeError: if no supported image file exists under ``root``.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, path) if self.return_paths else sample

    def __len__(self):
        return len(self.imgs)

class UnalignedDataset(BaseDataset):
    """Unpaired two-domain dataset (e.g. horses in A, zebras in B).

    Images are read from ``<dataroot>/<phase>A`` and ``<dataroot>/<phase>B``.
    Item ``index`` pairs the ``index``-th A image (wrapping around) with a B
    image. The B image is the ``index``-th one when ``serial_batches`` is set;
    otherwise it is drawn at random, so A/B pairs are not fixed.
    """

    def __init__(self):
        BaseDataset.__init__(self)
        # Read all configuration through self.opt (set by BaseDataset) so the
        # class has a single source of options.
        self.dir_A = os.path.join(self.opt['dataroot'], self.opt['phase'] + 'A')  # e.g. ./datasets/trainA
        self.dir_B = os.path.join(self.opt['dataroot'], self.opt['phase'] + 'B')  # e.g. ./datasets/trainB

        self.A_paths = sorted(make_dataset(self.dir_A, self.opt['max_dataset_size']))
        self.B_paths = sorted(make_dataset(self.dir_B, self.opt['max_dataset_size']))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

        # With direction 'BtoA' the input/output channel counts swap roles.
        btoA = self.opt['direction'] == 'BtoA'
        input_nc = self.opt['output_nc'] if btoA else self.opt['input_nc']
        output_nc = self.opt['input_nc'] if btoA else self.opt['output_nc']
        self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
        self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as RGB and close the file handle right away."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``{'A', 'B', 'A_paths', 'B_paths'}`` for item ``index``."""
        A_path = self.A_paths[index % self.A_size]  # wrap index into domain A
        if self.opt['serial_batches']:
            index_B = index % self.B_size  # deterministic pairing
        else:
            # randomize the index for domain B to avoid fixed pairs
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]

        A = self.transform_A(self._load_rgb(A_path))
        B = self.transform_B(self._load_rgb(B_path))
        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the size of the larger of the two domains."""
        return max(self.A_size, self.B_size)

class CustomDatasetDataLoader():
    """Wrap an ``UnalignedDataset`` in a ``torch.utils.data.DataLoader``.

    Iterating yields the DataLoader's collated batches and stops once
    ``opt['max_dataset_size']`` samples' worth of batches have been produced.
    """

    def __init__(self):
        self.dataset = UnalignedDataset()
        print("dataset [%s] was created" % type(self.dataset).__name__)
        # num_workers is pinned to 0; opt['num_threads'] is not used here.
        # NOTE(review): presumably to avoid worker-process issues when running
        # inside a notebook -- confirm before raising it.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt['batch_size'],
            shuffle=not opt['serial_batches'],
            num_workers=0)

    def load_data(self):
        """Return self; this object is itself the iterable dataset."""
        return self

    def __len__(self):
        """Number of samples, capped at ``opt['max_dataset_size']``.

        ``len()`` rejects non-int results, so a finite float cap such as
        ``100.0`` must be converted rather than returned as-is.
        """
        return int(min(len(self.dataset), opt['max_dataset_size']))

    def __iter__(self):
        # Named ``batch`` (not ``data``) so the module alias for
        # torch.utils.data is not shadowed.
        for i, batch in enumerate(self.dataloader):
            if i * opt['batch_size'] >= opt['max_dataset_size']:
                break
            yield batch

def create_dataset():
    """Build the unaligned dataset loader configured by the global ``opt``."""
    return CustomDatasetDataLoader().load_data()

## Dataset Analysis

In [39]:
# Build the training data pipeline and report how many images it holds.
dataset = create_dataset()
dataset_size = len(dataset)
print(f'The number of training images = {dataset_size:d}')



dataset [UnalignedDataset] was created
The number of training images = 1334


## Network Helper Functions

In [40]:
class Identity(nn.Module):
    """No-op layer whose ``forward`` returns its input unchanged."""

    def forward(self, x):
        return x

def get_norm_layer(norm_type='instance'):
    """Return a factory ``f(num_channels) -> nn.Module`` for a norm layer.

    'batch'    -> BatchNorm2d with affine params and running statistics.
    'instance' -> InstanceNorm2d without affine params or running statistics.
    'none'     -> a factory that ignores its argument and returns ``Identity()``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        return lambda x: Identity()
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

def get_scheduler(optimizer, opt):
    """Create the learning-rate scheduler selected by ``opt['lr_policy']``.

    Policies:
        'linear'  -- factor 1.0 until epoch ``niter - epoch_count``, then
                     decreasing linearly by ``1 / (niter_decay + 1)`` per epoch.
        'step'    -- multiply the LR by 0.1 every ``opt['lr_decay_iters']``
                     epochs (the ``opt`` dict defined above has no such key).
        'plateau' -- ReduceLROnPlateau(mode='min', factor=0.2,
                     threshold=0.01, patience=5).
        'cosine'  -- cosine annealing to 0 over ``opt['niter']`` epochs.

    Raises:
        NotImplementedError: for an unknown policy.
    """
    policy = opt['lr_policy']
    if policy == 'linear':
        def lambda_rule(epoch):
            return 1.0 - max(0, epoch + opt['epoch_count'] - opt['niter']) / float(opt['niter_decay'] + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=opt['lr_decay_iters'], gamma=0.1)
    if policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    if policy == 'cosine':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt['niter'], eta_min=0)
    # Raise (not return) the error: returning it handed callers an exception
    # object in place of a scheduler. %-format so the policy name appears.
    raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialise ``net``'s Conv/Linear and BatchNorm2d parameters in place.

    * Modules whose class name contains 'Conv' or 'Linear': the weight is set
      according to ``init_type`` ('normal' -> N(0, init_gain), 'xavier',
      'kaiming', 'orthogonal'), and the bias (if any) is zeroed.
    * Modules whose class name contains 'BatchNorm2d': the weight is drawn
      from N(1, init_gain) and the bias zeroed.

    Modules without a weight tensor (``weight is None``, e.g. BatchNorm built
    with ``affine=False``) are skipped instead of crashing.

    Raises:
        NotImplementedError: for an unknown ``init_type``, raised when the
        first Conv/Linear module is reached.
    """
    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is None:
            return  # nothing to initialise (or affine=False norm layer)

        if classname.find('Conv') != -1 or classname.find('Linear') != -1:
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm's weight is a per-channel scale, so centre it on 1.
            init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02):
    """Initialise ``net`` in place via ``init_weights`` and return it.

    No device placement happens here; the module is returned where it was
    created.
    """
    init_weights(net, init_type, init_gain=init_gain)
    initialised = net
    return initialised


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Build and initialise a generator network.

    Args:
        input_nc: channels of the input images.
        output_nc: channels of the output images.
        ngf: number of filters in the first conv layer.
        netG: 'resnet_9blocks', 'resnet_6blocks', 'unet_128', 'unet_256' or
            'our' (``ResnetGenerator_our``; it ignores ``norm`` and
            ``use_dropout``).
        norm: 'batch' | 'instance' | 'none'.
        use_dropout: enable dropout inside the blocks.
        init_type, init_gain: forwarded to ``init_net``.
        gpu_ids: accepted for API compatibility but unused. Defaults to
            ``None`` rather than a shared mutable ``[]``.

    Raises:
        NotImplementedError: if ``netG`` is not recognised.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'our':
        net = ResnetGenerator_our(input_nc, output_nc, ngf, n_blocks=9)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=None):
    """Build and initialise a discriminator network.

    Args:
        input_nc: channels of the images fed to the discriminator.
        ndf: number of filters in the first conv layer.
        netD: 'basic' (3-layer PatchGAN), 'n_layers' (PatchGAN with
            ``n_layers_D`` layers) or 'pixel' (1x1 PatchGAN).
        n_layers_D: depth, used only when ``netD == 'n_layers'``.
        norm: 'batch' | 'instance' | 'none'.
        init_type, init_gain: forwarded to ``init_net``.
        gpu_ids: accepted for API compatibility but unused. Defaults to
            ``None`` rather than a shared mutable ``[]``.

    Raises:
        NotImplementedError: if ``netD`` is not recognised.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain)


################################################
#                  Classes                     #
################################################
class GANLoss(nn.Module):
    """GAN objective applied to a discriminator's prediction.

    ``gan_mode``:
        'lsgan'   -- MSE against a constant target label.
        'vanilla' -- BCE-with-logits against the same target (expects logits).
        'wgangp'  -- Wasserstein critic loss: ``-mean`` for real, ``mean`` for fake.

    The real/fake labels are registered as buffers, so they move with the
    module on ``.to(device)``.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp':
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Broadcast the real or fake label to ``prediction``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Return the scalar loss of ``prediction`` against the real/fake target.

        Defined as ``forward`` rather than overriding ``__call__`` so that
        ``criterion(pred, flag)`` still works but now runs ``nn.Module``'s
        hook machinery like any other module.
        """
        if self.gan_mode == 'wgangp':
            return -prediction.mean() if target_is_real else prediction.mean()
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """WGAN-GP penalty ``lambda_gp * mean((||d netD(x) / dx||_2 - constant)^2)``.

    ``x`` is the real batch (``type='real'``), the fake batch (``'fake'``) or
    a per-sample random interpolation of the two (``'mixed'``). For 'real'
    and 'fake', ``requires_grad_`` is switched on for the caller's tensor.

    Returns:
        ``(penalty, per-sample gradients flattened to (batch, -1))``, or
        ``(0.0, None)`` when ``lambda_gp`` is not positive.

    Raises:
        NotImplementedError: for an unknown ``type``.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'real':
        points = real_data
    elif type == 'fake':
        points = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        alpha = torch.rand(batch, 1, device=device)
        # one mixing weight per sample, broadcast over all of its elements
        alpha = alpha.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
        points = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    points.requires_grad_(True)
    scores = netD(points)
    grads = torch.autograd.grad(outputs=scores, inputs=points,
                                grad_outputs=torch.ones(scores.size()).to(device),
                                create_graph=True, retain_graph=True, only_inputs=True)[0]
    grads = grads.view(real_data.size(0), -1)  # one row per sample
    # small eps added before the norm (kept from the original), presumably to
    # keep the norm's gradient defined when the gradients are exactly zero
    penalty = (((grads + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return penalty, grads


class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout:
      1. 7x7 conv stem;
      2. two stride-2 downsampling convs;
      3. ``n_blocks`` ``ResnetBlock``s at ``ngf * 4`` channels;
      4. two stride-2 transposed convs;
      5. 7x7 conv to ``output_nc`` channels, then Tanh.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        # Convs get a bias only for InstanceNorm (possibly wrapped in a partial).
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        channels = ngf
        for _ in range(n_downsampling):  # halve resolution, double channels
            layers += [nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(channels * 2),
                       nn.ReLU(True)]
            channels *= 2

        for _ in range(n_blocks):  # residual bottleneck
            layers.append(ResnetBlock(channels, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias))

        for _ in range(n_downsampling):  # double resolution, halve channels
            layers += [nn.ConvTranspose2d(channels, channels // 2,
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1,
                                          bias=use_bias),
                       norm_layer(channels // 2),
                       nn.ReLU(True)]
            channels //= 2

        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                   nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the generator on ``input``."""
        return self.model(input)

class ResnetGenerator_our(nn.Module):
    """Attention-guided ResNet generator (AttentionGAN-style).

    A shared encoder (reflect-padded 7x7 conv, two stride-2 convs and
    ``n_blocks`` ``resnet_block``s) feeds two decoders:

    * content decoder: 27 channels -> Tanh -> nine 3-channel candidate images;
    * attention decoder: 10 channels -> softmax over channels -> ten masks
      that sum to 1 at every pixel.

    The result is ``sum_k image_k * mask_k`` for k = 1..9 plus
    ``input * mask_10``. The input is multiplied by a 3-channel mask, so it
    is assumed to have 3 channels.

    ``forward`` returns a flat 30-tuple
    ``(o, output1..output10, attention1..attention10, image1..image9)``,
    where every attention mask is repeated to 3 channels.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9):
        super(ResnetGenerator_our, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.nb = n_blocks

        # encoder: 7x7 stem (padded in forward) + two stride-2 convs
        self.conv1 = nn.Conv2d(input_nc, ngf, 7, 1, 0)
        self.conv1_norm = nn.InstanceNorm2d(ngf)
        self.conv2 = nn.Conv2d(ngf, ngf * 2, 3, 2, 1)
        self.conv2_norm = nn.InstanceNorm2d(ngf * 2)
        self.conv3 = nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1)
        self.conv3_norm = nn.InstanceNorm2d(ngf * 4)

        # bottleneck residual blocks, each re-initialised with N(0, 0.02)
        blocks = []
        for _ in range(n_blocks):
            block = resnet_block(ngf * 4, 3, 1, 1)
            block.weight_init(0, 0.02)
            blocks.append(block)
        self.resnet_blocks = nn.Sequential(*blocks)

        # content decoder -> 27 channels (nine RGB candidates)
        self.deconv1_content = nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1)
        self.deconv1_norm_content = nn.InstanceNorm2d(ngf * 2)
        self.deconv2_content = nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1)
        self.deconv2_norm_content = nn.InstanceNorm2d(ngf)
        self.deconv3_content = nn.Conv2d(ngf, 27, 7, 1, 0)

        # attention decoder -> 10 channels (one mask per candidate + input)
        self.deconv1_attention = nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1)
        self.deconv1_norm_attention = nn.InstanceNorm2d(ngf * 2)
        self.deconv2_attention = nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1)
        self.deconv2_norm_attention = nn.InstanceNorm2d(ngf)
        self.deconv3_attention = nn.Conv2d(ngf, 10, 1, 1, 0)

        self.tanh = torch.nn.Tanh()

    def weight_init(self, mean, std):
        """Apply ``normal_init(mean, std)`` to every direct child module."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, input):
        # shared encoder
        x = F.pad(input, (3, 3, 3, 3), 'reflect')
        for conv, norm in ((self.conv1, self.conv1_norm),
                           (self.conv2, self.conv2_norm),
                           (self.conv3, self.conv3_norm)):
            x = F.relu(norm(conv(x)))
        x = self.resnet_blocks(x)

        # content branch: 27 channels -> nine 3-channel images in [-1, 1]
        c = F.relu(self.deconv1_norm_content(self.deconv1_content(x)))
        c = F.relu(self.deconv2_norm_content(self.deconv2_content(c)))
        c = F.pad(c, (3, 3, 3, 3), 'reflect')
        image = self.tanh(self.deconv3_content(c))
        images = [image[:, 3 * k:3 * k + 3, :, :] for k in range(9)]

        # attention branch: 10 channels, softmax across the channel axis
        a = F.relu(self.deconv1_norm_attention(self.deconv1_attention(x)))
        a = F.relu(self.deconv2_norm_attention(self.deconv2_attention(a)))
        attention = torch.nn.Softmax(dim=1)(self.deconv3_attention(a))
        masks = [attention[:, k:k + 1, :, :].repeat(1, 3, 1, 1) for k in range(10)]

        # blend: candidates 1..9 with masks 1..9, the input with mask 10
        outputs = [img * mask for img, mask in zip(images, masks)]
        outputs.append(input * masks[9])

        o = outputs[0]
        for extra in outputs[1:]:
            o = o + extra

        return (o, *outputs, *masks, *images)

# resnet block with reflect padding
class resnet_block(nn.Module):
    """Residual block with explicit reflect padding.

    ``x + IN(conv2(pad(relu(IN(conv1(pad(x)))))))`` where ``pad`` reflects
    ``padding`` pixels on every side. The skip connection only lines up when
    the convs preserve the spatial size (e.g. ``stride=1`` and
    ``padding = kernel // 2``).
    """

    def __init__(self, channel, kernel, stride, padding):
        super(resnet_block, self).__init__()
        self.channel = channel
        self.kernel = kernel
        self.stride = stride
        # Misspelled attribute kept as an alias so existing readers of
        # ``.strdie`` keep working.
        self.strdie = stride
        self.padding = padding
        self.conv1 = nn.Conv2d(channel, channel, kernel, stride, 0)
        self.conv1_norm = nn.InstanceNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, kernel, stride, 0)
        self.conv2_norm = nn.InstanceNorm2d(channel)

    def weight_init(self, mean, std):
        """Apply ``normal_init(mean, std)`` to every direct child module."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward(self, input):
        pad = (self.padding, self.padding, self.padding, self.padding)
        x = F.pad(input, pad, 'reflect')
        x = F.relu(self.conv1_norm(self.conv1(x)))
        x = F.pad(x, pad, 'reflect')
        x = self.conv2_norm(self.conv2(x))
        return input + x

def normal_init(m, mean, std):
    """Re-initialise a ``Conv2d`` / ``ConvTranspose2d`` module in place.

    The weight is drawn from N(mean, std) and the bias zeroed. Layers built
    with ``bias=False`` (``m.bias is None``) are handled instead of raising
    ``AttributeError``. Any other module type is left untouched.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()

class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: [pad] -> 3x3 conv -> norm -> ReLU -> [Dropout(0.5)]
    -> [pad] -> 3x3 conv -> norm. ``padding_type`` selects between a
    ReflectionPad2d or ReplicationPad2d layer, or (``'zero'``) the conv's own
    ``padding=1``.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return ``(pad_layers, conv_padding)`` for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble and return the ``nn.Sequential`` body of the block."""
        layers = []
        for stage in range(2):
            pad_layers, p = self._padding(padding_type)
            layers += pad_layers
            layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
            if stage == 0:
                # activation (and optional dropout) only between the two convs
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass with the skip connection."""
        return x + self.conv_block(x)


class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock``s.

    ``num_downs`` blocks are stacked, each downsampling by 2, so e.g. with
    ``num_downs == 7`` a 128x128 input reaches 1x1 at the bottleneck.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # innermost (bottleneck) block
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # extra intermediate blocks that stay at ngf * 8 filters
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)
        # step the filter count down: ngf*8 -> ngf*4 -> ngf*2 -> ngf
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None, submodule=block, norm_layer=norm_layer)
        # outermost block maps input_nc -> output_nc
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run the U-Net on ``input``."""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample -> ``submodule`` -> upsample, plus skip.

    Non-outermost blocks return ``cat([x, block(x)], dim=1)`` (the skip
    connection). The outermost block returns its output directly and ends in
    Tanh. Innermost blocks have no submodule.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convs get a bias only for InstanceNorm (possibly wrapped in a partial).
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d
        in_channels = outer_nc if input_nc is None else input_nc

        downconv = nn.Conv2d(in_channels, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # input of the upconv is the concatenated skip from the submodule
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        out = self.model(x)
        return out if self.outermost else torch.cat([x, out], 1)


class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convolutions.

    Layout: one stride-2 conv + LeakyReLU, then (n_layers - 1) stride-2
    conv/norm/LeakyReLU blocks, one stride-1 conv/norm/LeakyReLU block, and a
    final stride-1 conv producing a 1-channel prediction map. Channel counts
    grow as ndf * min(2**k, 8).

    Parameters:
        input_nc (int)  -- number of channels of the input images
        ndf (int)       -- number of filters in the first conv layer
        n_layers (int)  -- number of stride-2 conv layers
        norm_layer      -- normalization layer class, or a functools.partial of one
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        # Convs get a bias only with InstanceNorm2d; BatchNorm2d's affine shift
        # makes a conv bias redundant.
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        kernel, pad = 4, 1

        def conv_norm_act(in_ch, out_ch, stride):
            """Conv -> norm -> LeakyReLU(0.2) as a list of layers."""
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad, bias=use_bias),
                norm_layer(out_ch),
                nn.LeakyReLU(0.2, True),
            ]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        mult = 1  # channel multiplier of the most recently added layer
        for k in range(1, n_layers):  # gradually increase the number of filters
            next_mult = min(2 ** k, 8)
            layers += conv_norm_act(ndf * mult, ndf * next_mult, stride=2)
            mult = next_mult

        last_mult = min(2 ** n_layers, 8)
        layers += conv_norm_act(ndf * mult, ndf * last_mult, stride=1)

        # 1-channel prediction map
        layers.append(nn.Conv2d(ndf * last_mult, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


class PixelDiscriminator(nn.Module):
    """1x1 PatchGAN ("pixelGAN") discriminator.

    Uses only 1x1 convolutions, so each pixel is scored independently and the
    1-channel output keeps the input's spatial size.

    Parameters:
        input_nc (int) -- number of channels of the input images
        ndf (int)      -- number of filters in the first conv layer
        norm_layer     -- normalization layer class, or a functools.partial of one
    """

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        super(PixelDiscriminator, self).__init__()
        # Conv biases are only used with InstanceNorm2d (BatchNorm2d has its own affine shift).
        inner_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = inner_norm == nn.InstanceNorm2d

        def pointwise(in_ch, out_ch, bias=True):
            """1x1 convolution with stride 1 and no padding."""
            return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=bias)

        self.net = nn.Sequential(
            pointwise(input_nc, ndf),
            nn.LeakyReLU(0.2, True),
            pointwise(ndf, ndf * 2, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            pointwise(ndf * 2, 1, bias=use_bias),
        )

    def forward(self, input):
        """Standard forward."""
        return self.net(input)

class ImagePool():
    """Buffer of previously generated images for discriminator updates.

    Lets the discriminators see a mix of current and past generator outputs
    instead of only the latest batch.
    """

    def __init__(self, pool_size):
        """
        Parameters:
            pool_size (int) -- buffer capacity; a value <= 0 disables the buffer
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of images, possibly swapping some for buffered ones.

        Parameters:
            images (tensor) -- batch of generated images, iterated along dim 0

        Returns the images concatenated along dim 0 (gradient history is dropped
        via ``.data``). While the pool is filling, each image is stored and
        returned unchanged. Once the pool is full, each image has a 50% chance
        of being replaced by a randomly chosen stored image, which is then
        overwritten by the new one.
        """
        # Must mirror __init__: the buffer only exists for pool_size > 0, so a
        # negative size previously crashed on self.num_imgs.
        if self.pool_size <= 0:
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:   # buffer not full: keep inserting
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # return a stored image and store the current one in its place
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:        # return the current image
                    return_images.append(image)
        return torch.cat(return_images, 0)   # collect all the images and return


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save every visual as a PNG in the webpage's image dir and add them to the page.

    Parameters:
        webpage              -- page object providing get_image_dir(), add_header() and add_images()
        visuals (dict)       -- label -> image tensor, converted with tensor2im
        image_path (list)    -- source image paths; only the first names the outputs
        aspect_ratio (float) -- > 1 widens the saved image by this factor, < 1 makes it taller by 1/aspect_ratio
        width (int)          -- display width passed to webpage.add_images
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        # cv2.resize takes dsize as (width, height). The module itself is not callable.
        if aspect_ratio > 1.0:
            im = cv2.resize(im, (int(w * aspect_ratio), h), interpolation=cv2.INTER_CUBIC)
        elif aspect_ratio < 1.0:
            im = cv2.resize(im, (w, int(h / aspect_ratio)), interpolation=cv2.INTER_CUBIC)
        save_image(im, save_path)

        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)

def tensor2im(input_image, imtype=np.uint8):
    """Convert an image tensor into a numpy image array.

    Parameters:
        input_image (tensor) -- image batch; only the first image ([0]) is converted.
                                Values are assumed to lie in [-1, 1] (TODO confirm for every visual)
        imtype (type)        -- the desired type of the converted numpy array

    Returns an H x W x C array mapped from [-1, 1] to [0, 255] and clipped to
    that range. Single-channel images are tiled to 3 channels. numpy arrays are
    only cast to ``imtype``; any other non-tensor input is returned unchanged.
    """
    if isinstance(input_image, np.ndarray):  # already an image array: only cast
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    image_numpy = input_image.data[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # CHW -> HWC, rescale
    # Clip before casting: out-of-range floats would otherwise wrap around in uint8.
    image_numpy = np.clip(image_numpy, 0.0, 255.0)
    return image_numpy.astype(imtype)

def diagnose_network(net, name='network'):
    """Print ``name`` and then the average over parameters of mean(|grad|).

    Parameters whose ``grad`` is None are ignored. If no parameter has a
    gradient, 0.0 is printed.

    Parameters:
        net (torch network) -- Torch network
        name (str) -- label printed before the value
    """
    grad_means = [torch.mean(torch.abs(param.grad.data))
                  for param in net.parameters() if param.grad is not None]
    average = sum(grad_means, 0.0)
    if grad_means:
        average = average / len(grad_means)
    print(name)
    print(average)


def save_image(image_numpy, image_path):
    """Write a numpy image array to disk using PIL.

    Parameters:
        image_numpy (numpy array) -- image data accepted by PIL.Image.fromarray
        image_path (str)          -- destination file path
    """
    Image.fromarray(image_numpy).save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the shape and/or summary statistics of a numpy array.

    Parameters:
        x (numpy array) -- array to describe (converted to float64 first)
        val (bool)      -- print mean, min, max, median and std of all elements
        shp (bool)      -- print the array's shape
    """
    arr = x.astype(np.float64)
    if shp:
        print('shape,', arr.shape)
    if not val:
        return
    flat = arr.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)


## Model Definitions

In [41]:
class BaseModel(ABC):
    """Abstract base class for the models in this notebook.

    Subclasses implement set_input, forward and optimize_parameters and fill in
    the name lists created in __init__. Networks are looked up as
    ``'net' + name`` and losses as ``'loss_' + name``.
    """

    def __init__(self, opt):
        """Store the options and initialise bookkeeping.

        Parameters:
            opt (dict) -- options; reads 'isTrain', 'checkpoints_dir', 'name' and 'preprocess'
        """
        self.opt = opt
        self.device = torch.device('cpu')
        self.isTrain = opt['isTrain']
        self.save_dir = os.path.join(opt['checkpoints_dir'], opt['name'])
        if opt['preprocess'] != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # passed to the scheduler when lr_policy == 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    @abstractmethod
    def set_input(self, input):
        pass

    @abstractmethod
    def forward(self):
        pass

    @abstractmethod
    def optimize_parameters(self):
        pass

    def setup(self, opt):
        """Initialize networks and optimizers.

        Creates one scheduler per optimizer when training. Loads saved networks
        when not training or when opt['continue_train'] is set. Then prints the
        networks.
        """
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt['continue_train']:
            load_suffix = 'iter_%d' % opt['load_iter'] if opt['load_iter'] > 0 else opt['epoch']
            self.load_networks(load_suffix)
        self.print_networks(opt['verbose'])

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Run forward and compute_visuals without tracking gradients."""
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt['lr_policy'] == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return an OrderedDict mapping each name in visual_names to the attribute of that name."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return an OrderedDict mapping each loss name to float(self.loss_<name>)."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save each network's state_dict to <save_dir>/<epoch>_net_<name>.pth.

        Parameters:
            epoch -- file prefix, e.g. an epoch number, 'latest' or 'iter_N'
        """
        # torch.save does not create missing parent directories.
        os.makedirs(self.save_dir, exist_ok=True)
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                torch.save(net.state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load each network's state_dict from <save_dir>/<epoch>_net_<name>.pth.

        Parameters:
            epoch -- file prefix used when the networks were saved
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                state_dict = torch.load(load_path, map_location=self.device)
                # Drop stale InstanceNorm entries from old checkpoints. Iterate
                # over a copy of the keys because the patch pops entries.
                for key in list(state_dict.keys()):
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print each network's parameter count (and the full module if verbose)."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad on all parameters of the given network(s); None entries are skipped."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


class AttentionGANModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        
        # Visual names setup
        visual_names_A = ['real_A', 'fake_B', 'rec_A', 'o1_b', 'o2_b', 'o3_b', 'o4_b', 'o5_b', 'o6_b', 'o7_b', 'o8_b', 'o9_b', 'o10_b',
                         'a1_b', 'a2_b', 'a3_b', 'a4_b', 'a5_b', 'a6_b', 'a7_b', 'a8_b', 'a9_b', 'a10_b', 'i1_b', 'i2_b', 'i3_b', 'i4_b', 'i5_b', 
                         'i6_b', 'i7_b', 'i8_b', 'i9_b']
        visual_names_B = ['real_B', 'fake_A', 'rec_B', 'o1_a', 'o2_a', 'o3_a', 'o4_a', 'o5_a', 'o6_a', 'o7_a', 'o8_a', 'o9_a', 'o10_a', 
                         'a1_a', 'a2_a', 'a3_a', 'a4_a', 'a5_a', 'a6_a', 'a7_a', 'a8_a', 'a9_a', 'a10_a', 'i1_a', 'i2_a', 'i3_a', 'i4_a', 'i5_a', 
                         'i6_a', 'i7_a', 'i8_a', 'i9_a']

        if self.isTrain and self.opt['lambda_identity'] > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B if not self.opt['saveDisk'] else ['real_A', 'fake_B', 'a10_b', 'real_B', 'fake_A', 'a10_a']
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] if self.isTrain else ['G_A', 'G_B']

        # Initialize networks
        self.netG_A = define_G(opt['input_nc'], opt['output_nc'], opt['ngf'], 'our', opt['norm'],
                              not opt['no_dropout'], opt['init_type'], opt['init_gain'])
        self.netG_B = define_G(opt['output_nc'], opt['input_nc'], opt['ngf'], 'our', opt['norm'],
                              not opt['no_dropout'], opt['init_type'], opt['init_gain'])

        if self.isTrain:
            self.netD_A = define_D(opt['output_nc'], opt['ndf'], opt['netD'],
                                 opt['n_layers_D'], opt['norm'], opt['init_type'], opt['init_gain'])
            self.netD_B = define_D(opt['input_nc'], opt['ndf'], opt['netD'],
                                 opt['n_layers_D'], opt['norm'], opt['init_type'], opt['init_gain'])

            if opt['lambda_identity'] > 0.0:  # only works when input and output images have the same number of channels
                assert(opt['input_nc'] == opt['output_nc'])
            self.fake_A_pool = ImagePool(opt['pool_size'])  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt['pool_size'])  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = GANLoss(opt['gan_mode']).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt['lr'], betas=(opt['beta1'], 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt['lr'], betas=(opt['beta1'], 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt['direction'] == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B, self.o1_b, self.o2_b, self.o3_b, self.o4_b, self.o5_b, self.o6_b, self.o7_b, self.o8_b, self.o9_b, self.o10_b, \
        self.a1_b, self.a2_b, self.a3_b, self.a4_b, self.a5_b, self.a6_b, self.a7_b, self.a8_b, self.a9_b, self.a10_b, \
        self.i1_b, self.i2_b, self.i3_b, self.i4_b, self.i5_b, self.i6_b, self.i7_b, self.i8_b, self.i9_b = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A, _, _, _, _, _, _, _, _, _, _, \
        _, _, _, _, _, _, _, _, _, _, \
        _, _, _, _, _, _, _, _, _ = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A, self.o1_a, self.o2_a, self.o3_a, self.o4_a, self.o5_a, self.o6_a, self.o7_a, self.o8_a, self.o9_a, self.o10_a, \
        self.a1_a, self.a2_a, self.a3_a, self.a4_a, self.a5_a, self.a6_a, self.a7_a, self.a8_a, self.a9_a, self.a10_a, \
        self.i1_a, self.i2_a, self.i3_a, self.i4_a, self.i5_a, self.i6_a, self.i7_a, self.i8_a, self.i9_a = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B, _, _, _, _, _, _, _, _, _, _, \
        _, _, _, _, _, _, _, _, _, _, \
        _, _, _, _, _, _, _, _, _ = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator
        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator
        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt['lambda_identity']
        lambda_A = self.opt['lambda_A']
        lambda_B = self.opt['lambda_B']
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A, _, _, _, _, _, _, _, _, _, _, \
            _, _, _, _, _, _, _, _, _, _, \
            _, _, _, _, _, _, _, _, _  = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B, _, _, _, _, _, _, _, _, _, _, \
            _, _, _, _, _, _, _, _, _, _, \
            _, _, _, _, _, _, _, _, _  = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate graidents for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights

def create_model(opt):
    """Build an AttentionGANModel from ``opt``, print its class name and return it."""
    model = AttentionGANModel(opt)
    class_name = model.__class__.__name__
    print("model [%s] was created" % class_name)
    return model


## Model Setup

In [42]:
model = create_model(opt)      # build the AttentionGAN model from the options dict
model.setup(opt)               # create LR schedulers (training), load checkpoints if requested, print networks

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
model [AttentionGANModel] was created
---------- Networks initialized -------------
ResnetGenerator_our(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
  (conv1_norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2_norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3_norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (resnet_blocks): Sequential(
    (0): resnet_block(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (conv1_norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (conv2): Conv2d(

## Display and Plotting Functions

In [43]:
def display_epoch_results(epoch, visuals, num_cols=4):
    """Report the current visuals and optionally save them as PNG files.

    Prints a header for the epoch. When opt['save_result'] is true, every visual
    is converted with tensor2im and written to
    <checkpoints_dir>/epochNNN_<label>.png. Because the name depends only on
    the epoch and label, later calls in the same epoch overwrite earlier ones.
    No figures are drawn.

    Parameters:
        epoch (int)    -- current epoch, used in the header and file names
        visuals (dict) -- label -> image tensor, e.g. model.get_current_visuals()
        num_cols (int) -- unused; kept so existing calls keep working
    """
    print(f"\n=== Epoch {epoch} Visual Results ===")

    if not opt['save_result']:
        return

    # PIL's save() does not create missing directories.
    os.makedirs(opt['checkpoints_dir'], exist_ok=True)
    for label, image in visuals.items():
        image_numpy = tensor2im(image)
        img_path = os.path.join(opt['checkpoints_dir'], 'epoch%.3d_%s.png' % (epoch, label))
        save_image(image_numpy, img_path)

def update_and_display_losses(loss_history, epoch, current_losses, display_freq=1):
    """Append the latest losses to ``loss_history`` and optionally redraw a summary.

    Each call appends one value per loss, so the plot's x-axis counts calls
    (logged steps), not epochs. The table and plot are redrawn whenever
    ``epoch % display_freq == 0``.

    Parameters:
        loss_history (dict)   -- loss name -> list of logged values; updated in place
        epoch (int)           -- current epoch, shown in the header and used for display_freq
        current_losses (dict) -- loss name -> float, e.g. model.get_current_losses()
        display_freq (int)    -- redraw period in epochs
    """
    for k, v in current_losses.items():
        loss_history.setdefault(k, []).append(v)

    if epoch % display_freq != 0:
        return

    clear_output(wait=True)

    # Current loss values with mean/std over the last (up to) 10 logged points
    print(f"\n=== Epoch {epoch} Loss Values ===")
    print("-" * 50)
    print(f"{'Loss Type':<15} {'Current':<10} {'Mean':<10} {'Std':<10}")
    print("-" * 50)

    for k, v in current_losses.items():
        recent_values = loss_history[k][-10:]  # slicing already handles shorter histories
        print(f"{k:<15} {v:10.4f} {np.mean(recent_values):10.4f} {np.std(recent_values):10.4f}")

    print("-" * 50)

    # Plot loss trends
    plt.figure(figsize=(12, 6))
    for k, v in loss_history.items():
        plt.plot(v, label=k)

    plt.xlabel('Logged step')
    plt.ylabel('Loss Value')
    plt.title('Training Loss Evolution')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

## Main Training Loop

In [45]:
# Initialize loss history
loss_history = {}
total_iters = 0
total_epochs = opt['niter'] + opt['niter_decay']

print("Starting training...")
print("=" * 80)

# Training epochs
for epoch in range(opt['epoch_count'], total_epochs + 1):
    epoch_start_time = time.time()
    epoch_iter = 0

    print(f"\nEpoch {epoch}/{total_epochs}")
    print("-" * 40)

    # Named `batch`, not `data`: `data` is the torch.utils.data module imported
    # at the top of the notebook, and rebinding it would shadow that module.
    for batch in dataset:
        total_iters += opt['batch_size']
        epoch_iter += opt['batch_size']

        # Training step
        model.set_input(batch)
        model.optimize_parameters()

        # Display results
        if total_iters % opt['display_freq'] == 0:
            model.compute_visuals()
            display_epoch_results(epoch, model.get_current_visuals())

        # Update and display losses
        if total_iters % opt['print_freq'] == 0:
            update_and_display_losses(loss_history, epoch, model.get_current_losses())

        # Cache the latest model every <save_latest_freq> iterations
        if total_iters % opt['save_latest_freq'] == 0:
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            # opt is a dict, so options are read by key, not attribute
            save_suffix = 'iter_%d' % total_iters if opt['save_by_iter'] else 'latest'
            model.save_networks(save_suffix)

    # Cache the model every <save_epoch_freq> epochs
    if epoch % opt['save_epoch_freq'] == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)

    # Print epoch timing
    time_taken = time.time() - epoch_start_time
    print(f'\nEpoch {epoch} completed in {time_taken:.2f} seconds')
    print(f"Learning rate: {model.optimizers[0].param_groups[0]['lr']:.7f}")

    # Update learning rates
    model.update_learning_rate()

# End of training - display final analysis
print("\nTraining completed!")
print("=" * 80)

Starting training...

Epoch 1/60
----------------------------------------


KeyboardInterrupt: 

## Result Visualisation

In [None]:
# Prefer the seaborn look. Matplotlib 3.6 renamed its bundled seaborn style to
# 'seaborn-v0_8' and 3.8 removed the old 'seaborn' name, so use whichever exists.
for _style_name in ('seaborn-v0_8', 'seaborn'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

# Convert loss history to numpy arrays for easier analysis.
# Each entry is one logged step (one update_and_display_losses call), not one epoch.
loss_arrays = {k: np.asarray(v, dtype=np.float64) for k, v in loss_history.items()}


def _percent_drop(start, end, default=float('nan')):
    """Return the percentage decrease from ``start`` to ``end``, or ``default`` when ``start`` is 0."""
    return (start - end) / start * 100 if start != 0 else default


def print_loss_statistics(arrays):
    """Print mean/std/min/max/final/initial values and % improvement for each loss."""
    print("Loss Statistics:")
    print("-" * 100)
    headers = ['Loss Type', 'Mean', 'Std Dev', 'Min', 'Max', 'Final', 'Initial', 'Improvement (%)']
    print(f"{headers[0]:<15} {headers[1]:>10} {headers[2]:>10} {headers[3]:>10} {headers[4]:>10} {headers[5]:>10} {headers[6]:>10} {headers[7]:>15}")
    print("-" * 100)

    for loss_type, values in arrays.items():
        final, initial = values[-1], values[0]
        improvement = _percent_drop(initial, final, default=0)
        print(f"{loss_type:<15} {np.mean(values):10.4f} {np.std(values):10.4f} {np.min(values):10.4f} {np.max(values):10.4f} "
              f"{final:10.4f} {initial:10.4f} {improvement:15.2f}")
    print("-" * 100)


def plot_loss_dashboard(arrays):
    """Draw loss evolution (linear/log), distributions, correlation matrix and box plots."""
    loss_types = list(arrays)
    n = len(loss_types)

    fig = plt.figure(figsize=(20, 15))
    gs = fig.add_gridspec(3, 2)

    # Loss evolution
    ax1 = fig.add_subplot(gs[0, :])
    for loss_type, values in arrays.items():
        ax1.plot(values, label=loss_type, alpha=0.8, linewidth=2)
    ax1.set_xlabel('Logged step')
    ax1.set_ylabel('Loss Value')
    ax1.set_title('Loss Evolution Over Training', pad=20)
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax1.grid(True)

    # Log-scale evolution
    ax2 = fig.add_subplot(gs[1, 0])
    for loss_type, values in arrays.items():
        ax2.semilogy(values, label=loss_type, alpha=0.8, linewidth=2)
    ax2.set_xlabel('Logged step')
    ax2.set_ylabel('Loss Value (log scale)')
    ax2.set_title('Loss Evolution (Log Scale)')
    ax2.grid(True)

    # Loss distributions (histogram + KDE)
    ax3 = fig.add_subplot(gs[1, 1])
    colors = plt.cm.tab10(np.linspace(0, 1, n))
    for (loss_type, values), color in zip(arrays.items(), colors):
        ax3.hist(values, bins=50, alpha=0.5, label=loss_type, color=color, density=True)
        # gaussian_kde needs >= 2 points with non-zero spread. A constant series,
        # e.g. identity losses fixed at 0 when lambda_identity is 0, makes its
        # covariance singular.
        if values.size > 1 and np.ptp(values) > 0:
            kde = gaussian_kde(values)
            x_range = np.linspace(values.min(), values.max(), 200)
            ax3.plot(x_range, kde(x_range), color=color, linewidth=2)
    ax3.set_xlabel('Loss Value')
    ax3.set_ylabel('Density')
    ax3.set_title('Loss Value Distributions')
    ax3.legend()

    # Correlation matrix heatmap (constant series yield nan)
    ax4 = fig.add_subplot(gs[2, 0])
    corr_matrix = np.zeros((n, n))
    with np.errstate(invalid='ignore', divide='ignore'):
        for i, loss1 in enumerate(loss_types):
            for j, loss2 in enumerate(loss_types):
                corr_matrix[i, j] = np.corrcoef(arrays[loss1], arrays[loss2])[0, 1]

    im = ax4.imshow(corr_matrix, cmap='RdYlBu', vmin=-1, vmax=1)
    plt.colorbar(im, ax=ax4)
    for i in range(n):
        for j in range(n):
            ax4.text(j, i, f'{corr_matrix[i, j]:.2f}',
                     ha='center', va='center',
                     color='black' if abs(corr_matrix[i, j]) < 0.5 else 'white')
    ax4.set_xticks(range(n))
    ax4.set_yticks(range(n))
    ax4.set_xticklabels(loss_types, rotation=45)
    ax4.set_yticklabels(loss_types)
    ax4.set_title('Loss Correlation Matrix')

    # Box plots. Tick labels are set separately because boxplot's `labels=`
    # keyword was renamed `tick_labels` in matplotlib 3.9.
    ax5 = fig.add_subplot(gs[2, 1])
    ax5.boxplot(list(arrays.values()))
    ax5.set_xticks(range(1, n + 1))
    ax5.set_xticklabels(loss_types, rotation=45)
    ax5.set_ylabel('Loss Value')
    ax5.set_title('Loss Distribution Box Plots')

    plt.tight_layout()
    plt.show()


def print_progress_summary(arrays):
    """Print, per loss, its best value, when it first reached its final level, and early/late % improvement."""
    print("\nTraining Progress Summary:")
    print("-" * 50)
    steps = len(next(iter(arrays.values())))
    print(f"Total logged steps: {steps}")
    # Offset of the 10% mark, at least 1 step and never past the last index.
    tenth = min(max(steps // 10, 1), steps - 1)

    for loss_type, values in arrays.items():
        print(f"\n{loss_type}:")
        best_step = int(np.argmin(values))
        # First step at which the loss is at most 5% above its final value.
        reached = np.flatnonzero(values <= values[-1] * 1.05)
        convergence = str(reached[0] + 1) if reached.size else 'n/a'
        early = _percent_drop(values[0], values[tenth])
        late = _percent_drop(values[-1 - tenth], values[-1])

        print(f"  Best Value: {values[best_step]:.4f} (Step {best_step + 1})")
        print(f"  Convergence Step: {convergence}")
        print(f"  Early Improvement (first 10%): {early:.1f}%")
        print(f"  Late Improvement (last 10%): {late:.1f}%")


print("\n=== Final Training Analysis ===\n")
if not loss_arrays:
    print("No losses were logged, so there is nothing to analyse.")
else:
    print_loss_statistics(loss_arrays)
    plot_loss_dashboard(loss_arrays)
    print_progress_summary(loss_arrays)

## Test Result Display Function

In [None]:
def display_test_results(visuals, img_path, index):
    """
    Display test results in notebook console.

    Prints summary statistics (mean, std dev, range) for the first batch item
    of every visual, a before/after comparison for each generated ``fake_*``
    image, and coverage statistics for attention maps.

    Args:
        visuals: Mapping of visual name -> tensor (e.g. the output of
            ``model.get_current_visuals()``). Only index 0 of the leading
            (batch) dimension is inspected.
        img_path: Path(s) of the test image; printed as-is.
        index: Zero-based position of this image in the test loop; shown
            1-based in the header.
    """
    print(f"\n=== Test Result {index + 1} ===")
    print(f"Image Path: {img_path}")
    print("-" * 50)

    def _first_item(tensor):
        # Detach first so no autograd history is carried into the NumPy copy.
        return tensor[0].detach().cpu().numpy()

    # NOTE(review): assumes the CycleGAN / AttentionGAN naming convention,
    # where fake_B is generated from real_A and fake_A from real_B. The
    # generator's input is therefore the real image of the *opposite* domain.
    # Confirm against the model's forward().
    opposite = {'A': 'B', 'B': 'A'}

    # Print image statistics for each visual
    for name, tensor in visuals.items():
        arr = _first_item(tensor)

        print(f"\n{name}:")
        print(f"  Mean: {np.mean(arr):.4f}")
        print(f"  Std Dev: {np.std(arr):.4f}")
        print(f"  Range: [{np.min(arr):.4f}, {np.max(arr):.4f}]")

        if not name.startswith('fake'):
            continue

        # Transformation summary: how far the generator moved its input.
        print("  Transformation Analysis:")
        source_name = f"real_{opposite.get(name[-1], name[-1])}"
        if source_name not in visuals:
            print(f"  (no {source_name} visual to compare against)")
            continue
        abs_diff = np.abs(arr - _first_item(visuals[source_name]))
        print(f"  Mean Absolute Change: {np.mean(abs_diff):.4f}")
        print(f"  Max Absolute Change: {np.max(abs_diff):.4f}")

    print("\nAttention Analysis:")
    # Attention masks are the visuals whose names start with 'a' and end in
    # '_a' or '_b'.
    for name, att in visuals.items():
        if not (name.startswith('a') and name.endswith(('_a', '_b'))):
            continue
        att_map = att[0][0].detach().cpu().numpy()  # first batch item, first channel
        print(f"\n{name}:")
        print(f"  Mean Attention: {np.mean(att_map):.4f}")
        print(f"  Max Attention: {np.max(att_map):.4f}")
        print(f"  Active Regions: {np.mean(att_map > 0.5):.1%} of image")



In [None]:

# Test configuration
opt['num_threads'] = 0
opt['batch_size'] = 1
opt['serial_batches'] = True
opt['no_flip'] = True
opt['display_id'] = -1
# NOTE(review): opt['phase'] and opt['isTrain'] are not switched here; in the
# config at the top of the notebook they are 'train' / True. If they still hold
# those values, create_dataset() may read the training split and model.setup()
# may not restore saved weights -- confirm before trusting these results.

# Create dataset and model
print("Initializing test setup...")
dataset = create_dataset()
model = create_model(opt)
model.setup(opt)

# Set to evaluation mode if specified
if opt.get('eval', False):
    print("Setting model to evaluation mode...")
    model.eval()

# Testing loop
print("\nStarting testing...")
print("=" * 80)

num_test = opt.get('num_test', float('inf'))
total_processed = 0
test_stats = {
    'total_time': 0,
    'processing_times': []
}

# The loop variable is `batch`, not `data`: at notebook top level, `data` would
# overwrite the `torch.utils.data as data` module alias imported at the top.
for i, batch in enumerate(dataset):
    if i >= num_test:
        break

    start_time = time.perf_counter()  # monotonic, high-resolution interval timer

    # Process image
    model.set_input(batch)
    model.test()

    # Get results
    visuals = model.get_current_visuals()
    img_path = model.get_image_paths()

    # CUDA kernels run asynchronously; wait for them so the measured time
    # covers the forward pass itself rather than just its launch.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Record processing time
    processing_time = time.perf_counter() - start_time
    test_stats['processing_times'].append(processing_time)
    test_stats['total_time'] += processing_time

    # Display results
    display_test_results(visuals, img_path, i)

    total_processed += 1

    if (i + 1) % 5 == 0:
        print(f"\nProcessed {i + 1} images...")
        print(f"Average processing time: {np.mean(test_stats['processing_times']):.3f} seconds")

# Final Statistics
print("\n=== Testing Complete ===")
print(f"Total images processed: {total_processed}")
print(f"Total time: {test_stats['total_time']:.2f} seconds")
if total_processed:
    print(f"Average time per image: {test_stats['total_time']/total_processed:.2f} seconds")
    print(f"Fastest processing: {min(test_stats['processing_times']):.2f} seconds")
    print(f"Slowest processing: {max(test_stats['processing_times']):.2f} seconds")
else:
    # Empty dataset or num_test == 0: avoid ZeroDivisionError / min([]) errors.
    print("No images were processed.")
print("=" * 80)
