# Setup

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
from torchvision.transforms.functional import InterpolationMode

import matplotlib.pyplot as plt
from PIL import Image

import shutil

In [None]:
seed = 142

# Seed every RNG the notebook draws from: NumPy (dataset photo pairing and the
# image buffer) and PyTorch on the CPU and on every visible GPU.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# `deterministic` alone is not sufficient: with `benchmark=True` cuDNN may
# autotune and choose a different algorithm from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data preparation

In [None]:
class ImageDataset(Dataset):
    """Unpaired Monet/photo dataset.

    Item ``index`` pairs the Monet painting at that index with a randomly
    chosen photo. Photos are drawn only from the first ``len(files_monet)``
    photo files (in sorted name order), so both domains contribute the same
    number of distinct images.

    Args:
        data_dir: directory containing ``monet_jpg/`` and ``photo_jpg/``.
        transforms: optional callable applied to each loaded PIL image.
    """

    def __init__(self, data_dir, transforms=None):
        monet_dir = os.path.join(data_dir, 'monet_jpg')
        photo_dir = os.path.join(data_dir, 'photo_jpg')

        # sorted() makes the file order, and hence the photo subset, reproducible
        self.files_monet = [os.path.join(monet_dir, name) for name in sorted(os.listdir(monet_dir))]
        self.files_photo = [os.path.join(photo_dir, name) for name in sorted(os.listdir(photo_dir))]

        self.transforms = transforms

    def __len__(self):
        # one epoch is one pass over the Monet paintings
        return len(self.files_monet)

    @staticmethod
    def _load(path):
        """Read an image fully into memory as RGB and close its file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        # Pick a random photo among the first len(files_monet) photos. The
        # bound is also capped by the photo count so a smaller photo set
        # cannot cause an IndexError.
        n_candidates = min(len(self.files_monet), len(self.files_photo))
        random_index = np.random.randint(0, n_candidates)

        image_monet = self._load(self.files_monet[index])
        image_photo = self._load(self.files_photo[random_index])

        if self.transforms is not None:
            image_monet = self.transforms(image_monet)
            image_photo = self.transforms(image_photo)

        return image_monet, image_photo

In [None]:
# Root of the Kaggle input data and the training batch size.
batch_size = 5
data_dir = '/kaggle/input/gan-getting-started'

In [None]:
# Augmentation + normalization shared by both domains. No Resize step: the
# images already have a common size.
transforms_ = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.ToTensor(),                                  # PIL -> float tensor in [0, 1]
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),  # [0, 1] -> [-1, 1]
    ]
)

dataloader = DataLoader(
    dataset=ImageDataset(data_dir, transforms=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Invert ``transforms.Normalize`` in place: ``x * std + mean`` per channel.

    Args:
        img: tensor whose first dimension is the channel axis (e.g. C x H x W).
        mean: per-channel means that were used for normalization.
        std: per-channel standard deviations that were used for normalization.

    Returns:
        The same tensor object, modified in place, so it can be passed straight
        to ``plt.imshow``.
    """
    for channel, m, s in zip(img, mean, std):
        # Normalize computed (x - m) / s, so the inverse is x * s + m.
        channel.mul_(s).add_(m)
    return img

# Build the ASA style-transfer model

## Auxiliary blocks

In [None]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) followed by optional
    InstanceNorm, optional Dropout(0.5), and LeakyReLU(0.2) or ReLU.

    With ``transpose=True`` an ``nn.ConvTranspose2d`` with ``output_padding=1``
    is used, so ``kernel_size=3, stride=2, padding=1`` doubles H and W.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                 transpose=False, use_leaky=True, use_dropout=False, normalize=True):
        super().__init__()

        if transpose:
            conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                      stride, padding, output_padding=1)
        else:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                             stride, padding, bias=True)
        layers = [conv]

        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels, affine=True))
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        activation = (nn.LeakyReLU(negative_slope=0.2, inplace=True) if use_leaky
                      else nn.ReLU(inplace=True))
        layers.append(activation)

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
    
    
class ResidualBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    F = reflect-pad -> ConvBlock (ReLU, dropout) -> reflect-pad -> Conv2d ->
    InstanceNorm. For odd ``kernel_size`` the spatial size is preserved, as the
    skip connection requires.
    """

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        half = (kernel_size - 1) // 2
        layers = [
            nn.ReflectionPad2d(half),
            ConvBlock(in_channels=channels, out_channels=channels, padding=0,
                      kernel_size=kernel_size, use_leaky=False, use_dropout=True),
            nn.ReflectionPad2d(half),
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=kernel_size, padding=0),
            nn.InstanceNorm2d(channels, affine=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
    
    
class transformer_(nn.Module):
    """A single learnable ``k x k`` convolution from 3 to 3 channels.

    With stride 1 and padding ``k // 2``, odd kernel sizes keep H and W unchanged.
    """

    def __init__(self, transformer_kernel_size=9):
        super().__init__()
        half = transformer_kernel_size // 2
        self.model = nn.Conv2d(in_channels=3, out_channels=3,
                               kernel_size=transformer_kernel_size,
                               stride=1, padding=half)

    def forward(self, x):
        return self.model(x)
    
    
    

## Define Discriminator

In [None]:
class discriminator_(nn.Module):
    """Multi-scale discriminator modelled on the ASA code.

    Several prediction heads are constructed (conv0/1/3/4/5_pred), but
    ``forward`` returns only the conv4 head's map, wrapped in a one-element list
    so loss helpers can iterate over it. The unused heads and conv5 are still
    built so parameter counts and checkpoints stay the same.

    Args:
        in_channels: number of input image channels.
        pad_type: unused; kept for signature compatibility (the ASA reference
            code uses "valid" padding, whereas these layers pad explicitly).
        n_filter_discriminator: base channel width ``nf``.

    Shape comments assume a 3 x 256 x 256 input and ``nf = 64``.
    """

    def __init__(self, in_channels=3, pad_type='valid', n_filter_discriminator=64):
        super(discriminator_, self).__init__()
        nf = n_filter_discriminator

        # 3*256*256 -> 128*128*128 ; head -> 1*128*128
        self.conv0 = ConvBlock(in_channels=in_channels, out_channels=nf * 2, kernel_size=5,
                               stride=2, padding=2, normalize=True)
        self.conv0_pred = nn.Conv2d(in_channels=nf * 2, out_channels=1, kernel_size=5,
                                    stride=1, padding=2, bias=True)

        # 128*128*128 -> 128*64*64 ; head -> 1*65*65
        self.conv1 = ConvBlock(in_channels=nf * 2, out_channels=nf * 2, kernel_size=5,
                               stride=2, padding=2, normalize=True)
        self.conv1_pred = nn.Conv2d(in_channels=nf * 2, out_channels=1, kernel_size=10,
                                    stride=1, padding=5, bias=True)

        # 128*64*64 -> 256*32*32
        self.conv2 = ConvBlock(in_channels=nf * 2, out_channels=nf * 4, kernel_size=5,
                               stride=2, padding=2, normalize=True)

        # 256*32*32 -> 512*16*16 ; head -> 1*17*17
        self.conv3 = ConvBlock(in_channels=nf * 4, out_channels=nf * 8, kernel_size=5,
                               stride=2, padding=2, normalize=True)
        self.conv3_pred = nn.Conv2d(in_channels=nf * 8, out_channels=1, kernel_size=10,
                                    stride=1, padding=5, bias=True)

        # 512*16*16 -> 512*16*16 ; head -> 1*16*16
        # NOTE: unlike the other heads, this one is a ConvBlock, so its output
        # also passes through InstanceNorm + LeakyReLU.
        self.conv4 = ConvBlock(in_channels=nf * 8, out_channels=nf * 8, kernel_size=5,
                               stride=1, padding=2, normalize=True)
        self.conv4_pred = ConvBlock(in_channels=nf * 8, out_channels=1, kernel_size=5,
                                    stride=1, padding=2, normalize=True)

        # 512*16*16 -> 1024*16*16 ; head -> 1*17*17 (currently unused)
        self.conv5 = ConvBlock(in_channels=nf * 8, out_channels=nf * 16, kernel_size=5,
                               stride=1, padding=2, normalize=True)
        self.conv5_pred = nn.Conv2d(in_channels=nf * 16, out_channels=1, kernel_size=6,
                                    stride=1, padding=3, bias=True)

    def forward(self, x):
        """Return ``[conv4_pred(h4)]``, skipping the heads whose outputs are discarded."""
        h0 = self.conv0(x)
        h1 = self.conv1(h0)
        h2 = self.conv2(h1)
        h3 = self.conv3(h2)
        h4 = self.conv4(h3)
        # The previous version also evaluated conv0/1/3_pred, conv5 and
        # conv5_pred, then threw the results away. The returned value is identical.
        return [self.conv4_pred(h4)]

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator: maps a 3*256*256 image to a 1*30*30 map of raw
    real/fake scores (no final activation)."""

    def __init__(self, in_channels):
        super().__init__()

        # 3*256*256 -> 64*128*128 (no normalization on the first layer)
        layers = [
            ConvBlock(in_channels=in_channels, out_channels=64, kernel_size=4,
                      stride=2, padding=1, normalize=False),
        ]
        # 64*128*128 -> 128*64*64 -> 256*32*32
        for c_in, c_out in ((64, 128), (128, 256)):
            layers.append(ConvBlock(in_channels=c_in, out_channels=c_out,
                                    kernel_size=4, stride=2, padding=1))
        # 256*32*32 -> 512*31*31
        layers.append(ConvBlock(in_channels=256, out_channels=512,
                                kernel_size=4, stride=1, padding=1))
        # 512*31*31 -> 1*30*30
        layers.append(nn.Conv2d(in_channels=512, out_channels=1,
                                kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

## Define Generator

In [None]:
class encoder_(nn.Module):
    """Encoder: InstanceNorm on the input, a 3-pixel reflection pad, then three
    ReLU ConvBlocks with 4x4 kernels.

    With the defaults this maps 3*256*256 -> 256*64*64.

    Args:
        in_channels: input image channels.
        n_filter_generator: base width ``nf``; the output has ``8 * nf`` channels.
        pad_type: integer zero-padding of the first ConvBlock (applied on top of
            the reflection pad).
    """

    def __init__(self, in_channels=3, n_filter_generator=32, pad_type=0):
        super().__init__()
        nf = n_filter_generator

        self.model = nn.Sequential(
            nn.InstanceNorm2d(in_channels, affine=True),
            # 3*256*256 -> 3*262*262
            nn.ReflectionPad2d(3),
            # 3*262*262 -> 64*259*259
            ConvBlock(in_channels=in_channels, out_channels=2 * nf, kernel_size=4,
                      stride=1, padding=pad_type, normalize=True, use_leaky=False),
            # 64*259*259 -> 128*129*129
            ConvBlock(in_channels=2 * nf, out_channels=4 * nf, kernel_size=4,
                      stride=2, padding=1, normalize=True, use_leaky=False),
            # 128*129*129 -> 256*64*64
            ConvBlock(in_channels=4 * nf, out_channels=8 * nf, kernel_size=4,
                      stride=2, padding=1, normalize=True, use_leaky=False),
        )

    def forward(self, x):
        return self.model(x)
    
class decoder_(nn.Module):
    """Decoder following https://arxiv.org/pdf/1807.10201.pdf.

    Nine ResidualBlocks at ``8 * nf`` channels, then three stages of
    (nearest-neighbour resize -> 3x3 conv -> InstanceNorm -> LeakyReLU) with
    hard-coded target sizes 64, 128 and 256, and finally a reflection-padded
    7x7 conv + tanh to 3 channels. With the defaults: 256*64*64 -> 3*256*256.

    Args:
        input_shape: unused; kept so existing call sites keep working.
        n_filter_generator: base width ``nf`` (must match the encoder's).
    """

    def __init__(self, input_shape, n_filter_generator=32):
        super().__init__()
        nf = n_filter_generator
        width = nf * 8

        layers = [ResidualBlock(width) for _ in range(9)]

        # stage 0: 256*64*64  -> 128*64*64   (resize to 64 is a no-op here)
        # stage 1: 128*64*64  -> 64*128*128
        # stage 2: 64*128*128 -> 32*256*256
        c_in = width
        for stage in range(3):
            side = 32 * 2 ** (stage + 1)
            c_out = nf * 2 ** (2 - stage)
            layers += [
                transforms.Resize(size=(side, side), interpolation=InterpolationMode.NEAREST,
                                  max_size=None, antialias=None),
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                          stride=1, padding=1, bias=True),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(inplace=True),
            ]
            c_in = c_out

        # nf*256*256 -> 3*256*256
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=nf, out_channels=3, kernel_size=7,
                      stride=1, padding=0, bias=True),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout: 7x7 conv (64) -> two stride-2 downsamplings (128, 256) ->
    ``num_residual_blocks`` residual blocks -> two stride-2 upsamplings
    (128, 64) -> 7x7 conv to ``out_channels`` + tanh.

    Maps ``in_channels`` x H x W -> ``out_channels`` x H x W for H, W divisible by 4
    (e.g. 3*256*256 -> 3*256*256).
    """

    def __init__(self, in_channels, out_channels, num_residual_blocks=9):
        super(Generator, self).__init__()

        ''' Encoder '''
        # Initial layer: 3*256*256 -> 64*256*256.
        # The reflection pad already supplies the 3-pixel border a 7x7 kernel
        # needs, so the conv itself must use padding=0. ConvBlock's default
        # padding=1 grew the map to 258 (and the final output to 260). The
        # kernel size is also fixed at 7; it no longer depends on the channel
        # count, and it is identical to before for the usual in_channels=3.
        self.initial = nn.Sequential(
            nn.ReflectionPad2d(3),
            ConvBlock(in_channels=in_channels, out_channels=64,
                      kernel_size=7, padding=0, use_leaky=False),
        )

        # Downsampling: 64*256*256 -> 128*128*128 -> 256*64*64
        self.down = nn.Sequential(
            ConvBlock(in_channels=64, out_channels=128, kernel_size=3,
                      stride=2, padding=1, use_leaky=False),
            ConvBlock(in_channels=128, out_channels=256, kernel_size=3,
                      stride=2, padding=1, use_leaky=False),
        )

        ''' Transformer '''
        # ResNet: 256*64*64 -> 256*64*64
        self.transform = nn.Sequential(*[ResidualBlock(256) for _ in range(num_residual_blocks)])

        ''' Decoder '''
        # Upsampling: 256*64*64 -> 128*128*128 -> 64*256*256
        self.up = nn.Sequential(
            ConvBlock(in_channels=256, out_channels=128, kernel_size=3, stride=2,
                      padding=1, transpose=True, use_leaky=False),
            ConvBlock(in_channels=128, out_channels=64, kernel_size=3, stride=2,
                      padding=1, transpose=True, use_leaky=False),
        )

        # Out layer: 64*256*256 -> 3*256*256 (reflection pad 3 + 7x7 conv keeps size)
        self.out = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.down(self.initial(x))
        x = self.transform(x)
        x = self.out(self.up(x))
        return x

## Define Losses

In [None]:
# Least-squares adversarial loss, plus L1 losses for cycle consistency and identity.
criterion_GAN, criterion_cycle, criterion_identity = nn.MSELoss(), nn.L1Loss(), nn.L1Loss()

In [None]:
def _discriminator_loss(logits, labels):
    """Mean-squared error between discriminator outputs and their target labels."""
    loss = criterion_GAN(logits, labels)
    return loss
def discriminator_loss(discriminate_encoded_picture=None, discriminate_picture=None, discriminate_art=None):
    """Loss for a discriminator that must score real art as 1 and everything else as 0.

    Each argument is an optional iterable of discriminator prediction maps
    (e.g. the list returned by ``discriminator_``). An argument left as ``None``
    contributes nothing.

    Args:
        discriminate_encoded_picture: predictions on generated images; target 0.
        discriminate_picture: predictions on real photos; target 0.
        discriminate_art: predictions on real artworks; target 1.

    Returns:
        Sum of the per-map losses, or ``0`` when no predictions are given.
    """
    art_loss = 0
    picture_loss = 0
    encoded_picture_loss = 0
    # `is not None` is the unambiguous None check; `!=` dispatches to the
    # operand's own __ne__. The *_like targets already share pred's device and
    # dtype, so no explicit .to(device) is needed.
    if discriminate_art is not None:
        for pred in discriminate_art:
            art_loss += _discriminator_loss(pred, torch.ones_like(pred))
    if discriminate_picture is not None:
        for pred in discriminate_picture:
            picture_loss += _discriminator_loss(pred, torch.zeros_like(pred))
    if discriminate_encoded_picture is not None:
        for pred in discriminate_encoded_picture:
            encoded_picture_loss += _discriminator_loss(pred, torch.zeros_like(pred))

    return art_loss + picture_loss + encoded_picture_loss

def discriminator_acc(discriminate_encoded_picture, discriminate_picture, discriminate_art):
    """Summed per-map accuracy of a discriminator, using 0 as the decision threshold.

    Real art counts as correct where its score is > 0. Photos and generated
    images count as correct where their score is < 0. Returns the SUM (not the
    mean) of the per-map accuracies as a 0-d tensor.
    """
    def hit_rate(pred, expect_positive):
        zero = torch.zeros_like(pred)
        hits = pred > zero if expect_positive else pred < zero
        return torch.mean(hits.float())

    scores = [hit_rate(p, True) for p in discriminate_art]
    scores += [hit_rate(p, False) for p in discriminate_picture]
    scores += [hit_rate(p, False) for p in discriminate_encoded_picture]
    return torch.stack(scores, dim=0).sum(dim=0)

In [None]:
def abs_criterion(logits, target):
    """Mean absolute (L1) error between ``logits`` and ``target``."""
    diff = logits - target
    return diff.abs().mean()


def mse_criterion(logits, target):
    """Mean squared error between ``logits`` and ``target``."""
    diff = logits - target
    return diff.pow(2).mean()

def generator_loss(disc_output=None, transformed_input_image=None, input_features=None, transformed_output_image=None, output_features=None, img_loss_weight=1., feature_loss_weight=1.,
                   generator_weight=1.):
    """Combined generator loss: adversarial + image + feature terms.

    Args:
        disc_output: optional iterable of discriminator prediction maps on
            generated images. Each is pushed towards 1 ("real"), which is the
            correct target for the generator.
        transformed_input_image / transformed_output_image: when the output
            image is given, adds ``img_loss_weight * MSE(output, input)``.
        input_features / output_features: when the input features are given,
            adds ``feature_loss_weight * L1(output, input)``.
        img_loss_weight, feature_loss_weight, generator_weight: term weights.

    Returns:
        ``img_loss + feature_loss + generator_weight * adversarial``. Terms whose
        inputs are missing contribute 0; previously a missing or empty
        ``disc_output`` crashed in ``torch.stack([])``.
    """
    losses = []
    if disc_output is not None:
        losses = [_discriminator_loss(pred, torch.ones_like(pred)) for pred in disc_output]
    generator_global_loss = torch.stack(losses, dim=0).sum(dim=0) * generator_weight if losses else 0

    # Image loss.
    img_loss = 0
    if transformed_output_image is not None:
        img_loss = mse_criterion(transformed_output_image, transformed_input_image) * img_loss_weight

    # Features loss.
    feature_loss = 0
    if input_features is not None:
        feature_loss = abs_criterion(output_features, input_features) * feature_loss_weight

    return img_loss + feature_loss + generator_global_loss

def generator_acc(disc_output):
    """Sum over prediction maps of the fraction of scores > 0 (i.e. judged real)."""
    per_map = [(pred > torch.zeros_like(pred)).float().mean() for pred in disc_output]
    return torch.stack(per_map, dim=0).sum(dim=0)

# Model Initialization

In [None]:
#encoder=encoder_().to(device)
#decoder=decoder_(input_shape=(batch_size,256,16,16)).to(device)
#discriminator = discriminator_().to(device)
#transformer = transformer_().to(device)

In [None]:
# Disabled alternative (an unevaluated string literal, so it has no effect):
# each generator built as a single nn.Sequential(encoder_, decoder_) pair.
'''
generator_monet2photo = nn.Sequential(encoder_(),decoder_(input_shape=(batch_size,256,64,64))).to(device)
generator_photo2monet = nn.Sequential(encoder_(),decoder_(input_shape=(batch_size,256,64,64))).to(device)

discriminator_monet = Discriminator(in_channels=3).to(device)
discriminator_photo = Discriminator(in_channels=3).to(device)
'''

In [None]:
# Two encoder/decoder pairs: *_m2p translates Monet -> photo, *_p2m translates
# photo -> Monet. NOTE: construction order matters for reproducibility, because
# each constructor draws its initial weights from the seeded torch RNG.
encoder_m2p=encoder_().to(device)
encoder_p2m=encoder_().to(device)
# input_shape is not used by decoder_; it is passed only to satisfy the signature.
decoder_m2p=decoder_(input_shape=(batch_size,256,64,64)).to(device)
decoder_p2m=decoder_(input_shape=(batch_size,256,64,64)).to(device)

#generator_monet2photo = nn.Sequential(encoder1,decoder1).to(device)
#generator_photo2monet = nn.Sequential(encoder2,decoder2).to(device)

# One discriminator per domain: discriminator_monet scores Monet-style images,
# discriminator_photo scores photos.
discriminator_monet = Discriminator(in_channels=3).to(device)
discriminator_photo = Discriminator(in_channels=3).to(device)

## Optimization Setup

In [None]:
# Disabled (unevaluated string literal): optimizers for a single
# encoder/transformer/decoder experiment.
# NOTE(review): `encoder`, `transformer`, `decoder` and `discriminator` are only
# created in a commented-out cell, so this would raise NameError if re-enabled as-is.
'''
lr1=2e-5
lr2 = 2e-5
b1 = 0.9
b2 = 0.999

optim_generators = torch.optim.Adam(
    list(encoder.parameters()) + list(transformer.parameters()) + list(decoder.parameters()),
    lr=lr1, betas=(b1, b2)
)

optim_discriminators = torch.optim.Adam(
    list(discriminator.parameters()),
    lr=lr2, betas=(b1, b2)
)
'''


In [None]:
# Disabled (unevaluated string literal): LR multiplier held at 1 until
# decay_epoch, then decayed linearly to 0 at num_epochs (80-epoch variant).
'''
num_epochs = 80
decay_epoch = 25

lr_sched_step = lambda epoch: 1 - max(0, epoch - decay_epoch) / (num_epochs - decay_epoch)

lr_sched_generators = torch.optim.lr_scheduler.LambdaLR(optim_generators, lr_lambda=lr_sched_step)
lr_sched_discriminators = torch.optim.lr_scheduler.LambdaLR(optim_discriminators, lr_lambda=lr_sched_step)
'''


In [None]:
# Disabled (unevaluated string literal): Adam optimizers for the
# generator_monet2photo / generator_photo2monet nn.Sequential variant.
'''
lr = 2e-4
b1 = 0.5
b2 = 0.999

optim_generators = torch.optim.Adam(
    list(generator_monet2photo.parameters()) + list(generator_photo2monet.parameters()),
    lr=lr, betas=(b1, b2)
)

optim_discriminators = torch.optim.Adam(
    list(discriminator_monet.parameters()) + list(discriminator_photo.parameters()),
    lr=lr, betas=(b1, b2)
)
'''


In [None]:
# Disabled (unevaluated string literal): linear LR decay schedule (80 epochs)
# for the nn.Sequential generator variant.
'''
num_epochs = 80
decay_epoch = 25

lr_sched_step = lambda epoch: 1 - max(0, epoch - decay_epoch) / (num_epochs - decay_epoch)

lr_sched_generators = torch.optim.lr_scheduler.LambdaLR(optim_generators, lr_lambda=lr_sched_step)
lr_sched_discriminators = torch.optim.lr_scheduler.LambdaLR(optim_discriminators, lr_lambda=lr_sched_step)
'''



In [None]:
lr, b1, b2 = 2e-4, 0.5, 0.999

# One Adam optimizer over all four generator-side modules, one over both
# discriminators (beta1 = 0.5).
optim_generators = torch.optim.Adam(
    [p for module in (encoder_m2p, encoder_p2m, decoder_m2p, decoder_p2m)
     for p in module.parameters()],
    lr=lr, betas=(b1, b2),
)

optim_discriminators = torch.optim.Adam(
    [p for module in (discriminator_monet, discriminator_photo)
     for p in module.parameters()],
    lr=lr, betas=(b1, b2),
)

In [None]:
num_epochs = 60
decay_epoch = 25


def lr_sched_step(epoch):
    """LR multiplier: 1.0 up to ``decay_epoch``, then linear decay reaching 0 at ``num_epochs``."""
    return 1 - max(0, epoch - decay_epoch) / (num_epochs - decay_epoch)


lr_sched_generators = torch.optim.lr_scheduler.LambdaLR(optim_generators, lr_lambda=lr_sched_step)
lr_sched_discriminators = torch.optim.lr_scheduler.LambdaLR(optim_discriminators, lr_lambda=lr_sched_step)

## Define auxiliary tools (loss history, image buffer, gradient toggling)

In [None]:
class History:
    """Per-epoch record of average generator and discriminator losses."""

    def __init__(self):
        self.generators_loss, self.discriminators_loss = [], []

    def update(self, gen_loss, discr_loss):
        """Append one epoch's average losses."""
        self.generators_loss.append(gen_loss)
        self.discriminators_loss.append(discr_loss)

    def show(self, title='Losses'):
        """Plot both loss curves against the epoch index."""
        plt.figure(figsize=(20, 8))
        plt.title(title)
        curves = (
            (self.generators_loss, 'r', 'Generators Loss'),
            (self.discriminators_loss, 'b', 'Discriminators Loss'),
        )
        for values, colour, label in curves:
            plt.plot(values, 'o-', color=colour,
                     linewidth=2, markersize=3, label=label)
        plt.legend(loc='best')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.show()
        
class Buffer:
    """Fixed-size pool of past generated images, stored as CPU NumPy arrays.

    The discriminators are trained on samples drawn from this pool rather than
    only on the latest generator outputs.
    """

    def __init__(self, max_images=50):
        self.max_images = max_images
        self.images = []

    def update(self, images):
        """Add a batch. Fill the pool until it is full; after that, each new
        image overwrites a random slot with probability 1/2."""
        for image in images.detach().cpu().numpy():
            if len(self.images) >= self.max_images:
                if np.random.rand() > 0.5:
                    slot = np.random.randint(0, self.max_images)
                    self.images[slot] = image
                continue
            self.images.append(image)

    def sample(self, num_images):
        """Draw ``num_images`` pool entries uniformly with replacement, as one tensor."""
        picks = [self.images[np.random.randint(0, len(self.images))]
                 for _ in range(num_images)]
        return torch.tensor(np.array(picks))
def update_req_grad(models, requires_grad=True):
    """Set ``requires_grad`` on every parameter of every model in ``models``."""
    all_params = (p for model in models for p in model.parameters())
    for param in all_params:
        param.requires_grad = requires_grad

# Training

In [None]:

history = History()
buffer_monet = Buffer()
buffer_photo = Buffer()

# Module groups that are frozen/unfrozen together in each phase.
generators = [encoder_m2p, encoder_p2m, decoder_m2p, decoder_p2m]
discriminators = [discriminator_monet, discriminator_photo]

for epoch in range(num_epochs):
    avg_generators_loss = 0
    avg_discriminators_loss = 0

    for real_monet, real_photo in tqdm(dataloader, leave=False, total=len(dataloader)):
        real_monet, real_photo = real_monet.to(device), real_photo.to(device)

        # ---------------- Train generators ----------------
        # freeze the discriminators so only generator parameters get gradients
        update_req_grad(generators, True)
        update_req_grad(discriminators, False)

        optim_generators.zero_grad()

        # forward pass: translate each domain into the other
        encoded_monet = encoder_m2p(real_monet)
        encoded_photo = encoder_p2m(real_photo)

        fake_photo = decoder_m2p(encoded_monet)
        fake_monet = decoder_p2m(encoded_photo)

        # Re-encode the fakes. The same latents feed both the latent loss and
        # the cycle reconstructions. The encoders contain no dropout, so
        # encoding the fakes a second time would only repeat identical work.
        latent_photo = encoder_p2m(fake_photo)
        latent_monet = encoder_m2p(fake_monet)

        cycle_photo = decoder_m2p(latent_monet)
        cycle_monet = decoder_p2m(latent_photo)

        identity_photo = decoder_m2p(encoder_m2p(real_photo))
        identity_monet = decoder_p2m(encoder_p2m(real_monet))

        # keep a pool of recent fakes for the discriminator phase
        buffer_photo.update(fake_photo)
        buffer_monet.update(fake_monet)

        # discriminator scores on the fakes, used in the adversarial loss
        discriminator_outputs_photo = discriminator_photo(fake_photo)
        discriminator_outputs_monet = discriminator_monet(fake_monet)

        # target label maps (same shape, dtype and device as the scores)
        labels_real = torch.ones_like(discriminator_outputs_monet)
        labels_fake = torch.zeros_like(discriminator_outputs_monet)

        # adversarial loss: the generators try to make their fakes score as real
        loss_GAN_monet2photo = criterion_GAN(discriminator_outputs_photo, labels_real)
        loss_GAN_photo2monet = criterion_GAN(discriminator_outputs_monet, labels_real)
        loss_GAN = (loss_GAN_monet2photo + loss_GAN_photo2monet) / 2

        # cycle-consistency loss: translating there and back should recover the input
        loss_cycle_photo = criterion_cycle(cycle_photo, real_photo)
        loss_cycle_monet = criterion_cycle(cycle_monet, real_monet)
        loss_cycle = (loss_cycle_photo + loss_cycle_monet) / 2

        # identity loss: an image already in the target domain should pass
        # through unchanged (helps preserve colours)
        loss_identity_photo = criterion_identity(identity_photo, real_photo)
        loss_identity_monet = criterion_identity(identity_monet, real_monet)
        loss_identity = (loss_identity_photo + loss_identity_monet) / 2

        # latent loss: the encoding of a fake should match the encoding of the
        # real image it came from (the two sides use different encoders)
        loss_latent_photo = criterion_cycle(latent_photo, encoded_monet)
        loss_latent_monet = criterion_cycle(latent_monet, encoded_photo)

        loss_generators_total = (loss_GAN + 0.2 * loss_latent_photo + 0.2 * loss_latent_monet
                                 + 10 * loss_cycle + 5 * loss_identity)

        loss_generators_total.backward()
        # clip_grad_norm_ rescales the existing .grad tensors, so it must run
        # between backward() and step(). After step() (as before) it had no
        # effect on the update.
        for model in generators:
            clip_grad_norm_(model.parameters(), 100)
        optim_generators.step()

        # ---------------- Train discriminators ----------------
        # unfreeze the discriminators, freeze the generators
        update_req_grad(discriminators, True)
        update_req_grad(generators, False)

        optim_discriminators.zero_grad()

        # sample fakes from the pool of stored past generator outputs
        fake_photo = buffer_photo.sample(num_images=batch_size).to(device)
        fake_monet = buffer_monet.sample(num_images=batch_size).to(device)

        # Noisy real labels: each entry is 1 with probability `threshold`. The
        # threshold rises from 0.85 to 1 over the first half of training, so the
        # discriminators do not overpower the generators early on.
        threshold = min(1, 0.85 + (1 - 0.85) * epoch / (num_epochs // 2))
        noisy_labels_real = (torch.rand(discriminator_outputs_monet.size()) < threshold).float().to(device)

        loss_real_photo = criterion_GAN(discriminator_photo(real_photo), noisy_labels_real)
        loss_fake_photo = criterion_GAN(discriminator_photo(fake_photo.detach()), labels_fake)
        loss_photo = (loss_real_photo + loss_fake_photo) / 2

        loss_real_monet = criterion_GAN(discriminator_monet(real_monet), noisy_labels_real)
        loss_fake_monet = criterion_GAN(discriminator_monet(fake_monet.detach()), labels_fake)
        loss_monet = (loss_real_monet + loss_fake_monet) / 2

        loss_discriminators_total = loss_monet + loss_photo

        loss_discriminators_total.backward()
        # as above: clip before stepping so the clipped gradients are the ones applied
        clip_grad_norm_(discriminator_monet.parameters(), 100)
        clip_grad_norm_(discriminator_photo.parameters(), 100)
        optim_discriminators.step()

        avg_generators_loss += loss_generators_total.item()
        avg_discriminators_loss += loss_discriminators_total.item()

    # per-epoch averages
    avg_generators_loss /= len(dataloader)
    avg_discriminators_loss /= len(dataloader)
    history.update(avg_generators_loss, avg_discriminators_loss)

    print("Epoch: %d/%d | Generators Loss: %.4f | Discriminators Loss: %.4f"
          % (epoch + 1, num_epochs, avg_generators_loss, avg_discriminators_loss))

    # every 5 epochs, show photo -> Monet translations of one batch
    if (epoch + 1) % 5 == 0:
        _, sample_real_photo = next(iter(dataloader))

        # inference only: no_grad guarantees no autograd graph is recorded
        with torch.no_grad():
            sample_fake_monet = decoder_p2m(encoder_p2m(sample_real_photo.to(device))).cpu()

        num_photos = min(batch_size, 5)
        plt.figure(figsize=(20, 8))
        for k in range(num_photos):
            plt.subplot(2, num_photos, k + 1)
            plt.imshow(unnorm(sample_real_photo[k]).permute(1, 2, 0))
            plt.title('Input photo')
            plt.axis('off')

            plt.subplot(2, num_photos, k + num_photos + 1)
            plt.imshow(unnorm(sample_fake_monet[k]).permute(1, 2, 0))
            plt.title('Output image')
            plt.axis('off')
        plt.show()

    lr_sched_generators.step()
    lr_sched_discriminators.step()


In [None]:
# Disabled alternative (kept inside a bare string literal, so it never runs):
# a classic CycleGAN loop with two full generators (generator_monet2photo /
# generator_photo2monet), two discriminators, replay buffers of generated
# images, and noisy "real" labels for the discriminators whose keep-probability
# (threshold) rises from 0.85 to 1 over the first num_epochs // 2 epochs.
# NOTE(review): clip_grad_norm_ is called *after* optim_*.step(), so it cannot
# limit the update that was just applied; if this loop is re-enabled, move the
# clipping between .backward() and .step().
# NOTE(review): labels_real / labels_fake / noisy_labels_real are sized from
# discriminator_monet's output but are also used for discriminator_photo, which
# assumes both discriminators produce same-shaped outputs -- confirm.
'''
history = History()
buffer_monet = Buffer()
buffer_photo = Buffer()

for epoch in range(num_epochs):
    avg_generators_loss = 0
    avg_discriminators_loss = 0
    
    for i, (real_monet, real_photo) in enumerate(tqdm(dataloader, leave=False, total=len(dataloader))):
        real_monet, real_photo = real_monet.to(device), real_photo.to(device)
                
        """ Train Generators """
        # switching models parameters so that only generators are trained
        update_req_grad([generator_monet2photo, generator_photo2monet], True)
        update_req_grad([discriminator_monet, discriminator_photo], False)
        
        # zero the parameters gradients
        optim_generators.zero_grad()
        
        # forward-pass
        fake_photo = generator_monet2photo(real_monet)
        fake_monet = generator_photo2monet(real_photo)
        
        cycle_photo = generator_monet2photo(fake_monet)
        cycle_monet = generator_photo2monet(fake_photo)
        
        identity_photo = generator_monet2photo(real_photo)
        identity_monet = generator_photo2monet(real_monet)
        
        # update photos that are used to feed up discriminators
        buffer_photo.update(fake_photo)
        buffer_monet.update(fake_monet)
        
        # discriminators outputs that are used in adversarial loss
        discriminator_outputs_photo = discriminator_photo(fake_photo)
        discriminator_outputs_monet = discriminator_monet(fake_monet)
        
        # labels that are used as ground truth
        labels_real = torch.ones(discriminator_outputs_monet.size()).to(device)
        labels_fake = torch.zeros(discriminator_outputs_monet.size()).to(device)
        
        # adversarial loss - enforces that the generated output be of the appropriate domain
        loss_GAN_monet2photo = criterion_GAN(discriminator_outputs_photo, labels_real)
        loss_GAN_photo2monet = criterion_GAN(discriminator_outputs_monet, labels_real)
        loss_GAN = (loss_GAN_monet2photo + loss_GAN_photo2monet) / 2
        #loss_GAN=generator_loss(disc_output=discriminator_outputs_photo)+generator_loss(disc_output=discriminator_outputs_monet)
        
        # cycle consistency loss - enforces that the input and output are recognizably the same
        loss_cycle_photo = criterion_cycle(cycle_photo, real_photo)
        loss_cycle_monet = criterion_cycle(cycle_monet, real_monet)
        loss_cycle = (loss_cycle_photo + loss_cycle_monet) / 2
        
        # identity mapping loss - helps preserve the color of the input images
        loss_identity_photo = criterion_identity(identity_photo, real_photo)
        loss_identity_monet = criterion_identity(identity_monet, real_monet)
        loss_identity = (loss_identity_photo + loss_identity_monet) / 2
        
        # total loss
        loss_generators_total = loss_GAN + 10 * loss_cycle + 5 * loss_identity
        
        # backward-pass
        loss_generators_total.backward()
        optim_generators.step()
        
        # limiting gradient norms - if they exceed 100, something went wrong
        clip_grad_norm_(generator_photo2monet.parameters(), 100)
        clip_grad_norm_(generator_monet2photo.parameters(), 100)
        
        
        """ Train Discriminators """
        # switching models parameters so that only discriminators are trained
        update_req_grad([discriminator_monet, discriminator_photo], True)
        update_req_grad([generator_monet2photo, generator_photo2monet], False)
        
        # zero the parameters gradients
        optim_discriminators.zero_grad()
        
        # sample images from 50 stored
        fake_photo = buffer_photo.sample(num_images=batch_size).to(device)
        fake_monet = buffer_monet.sample(num_images=batch_size).to(device)
        
        # making labels noisy for discriminators so that they don't prevail over generators
        threshold = min(1, 0.85 + (1 - 0.85) * epoch / (num_epochs // 2))
        noisy_labels_real = (torch.rand(discriminator_outputs_monet.size()) < threshold).float().to(device)
        
        # forward-pass + losses
        loss_real_photo = criterion_GAN(discriminator_photo(real_photo), noisy_labels_real)
        loss_fake_photo = criterion_GAN(discriminator_photo(fake_photo.detach()), labels_fake)
        loss_photo = (loss_real_photo + loss_fake_photo) / 2
        #loss_photo = discriminator_loss(discriminate_encoded_picture=discriminator_photo(fake_photo.detach()), discriminate_art=discriminator_photo(real_photo))

        loss_real_monet = criterion_GAN(discriminator_monet(real_monet), noisy_labels_real)
        loss_fake_monet = criterion_GAN(discriminator_monet(fake_monet.detach()), labels_fake)
        loss_monet = (loss_real_monet + loss_fake_monet) / 2
        #loss_monet = discriminator_loss(discriminate_encoded_picture=discriminator_monet(fake_monet.detach()), discriminate_art=discriminator_monet(real_monet))

        loss_discriminators_total = loss_monet + loss_photo
        
        # backward-pass
        loss_discriminators_total.backward()
        optim_discriminators.step()
        
        # clipping gradients to avoid gradients explosion
        clip_grad_norm_(discriminator_monet.parameters(), 100)
        clip_grad_norm_(discriminator_photo.parameters(), 100)
        
        # updating intermediate results
        avg_generators_loss += loss_generators_total.item()
        avg_discriminators_loss += loss_discriminators_total.item()
        
    # saving intermediate results
    avg_generators_loss /= len(dataloader)
    avg_discriminators_loss /= len(dataloader)
    history.update(avg_generators_loss, avg_discriminators_loss)
    
    # showing intermediate results
    print("Epoch: %d/%d | Generators Loss: %.4f | Discriminators Loss: %.4f"
              % (epoch+1, num_epochs, avg_generators_loss, avg_discriminators_loss))
    
    # showing generated images
    if (epoch + 1) % 10 == 0:
        _, sample_real_photo = next(iter(dataloader))
        
        sample_fake_monet = generator_photo2monet(sample_real_photo.to(device)).detach().cpu()
        
        num_photos = min(batch_size, 5)
        plt.figure(figsize=(20, 8))
        for k in range(num_photos):
            plt.subplot(2, num_photos, k + 1)
            plt.imshow(unnorm(sample_real_photo[k]).permute(1, 2, 0))
            plt.title('Input photo')
            plt.axis('off')

            plt.subplot(2, num_photos, k + num_photos + 1)
            plt.imshow(unnorm(sample_fake_monet[k]).permute(1, 2, 0))
            plt.title('Output image')
            plt.axis('off')
        plt.show()
    
    lr_sched_generators.step()
    lr_sched_discriminators.step()
    
'''

In [None]:
# Disabled experiment (kept inside a bare string literal, so it never runs):
# an encoder/decoder "generator" plus a `transformer` network trained against a
# single `discriminator`, with scaffolding for alternating generator and
# discriminator steps based on a running discriminator success rate.
# NOTE(review): the alternation condition has been replaced by `if True:`, so
# the `else:` branch is unreachable -- the discriminator is never trained,
# avg_discriminators_loss stays 0, and discr_success / discr_success_rate have
# no effect.
# NOTE(review): the LR schedulers are stepped once per batch here (inside the
# batch loop), whereas the other loops in this notebook step them once per
# epoch; print(loss_generators_total) also fires on every batch.
# NOTE(review): clip_grad_norm_ runs after optimizer.step(), so it has no
# effect on the update that was just applied.
'''
history = History()
discr_success_rate=0.8

for epoch in range(num_epochs):
    avg_generators_loss = 0
    avg_discriminators_loss = 0
    
    discr_success = 0
    alpha = 0.05
    train_generator_at_previous_step = True
    for i, (real_monet, real_photo) in enumerate(tqdm(dataloader, leave=False, total=len(dataloader))):
        real_monet, real_photo = real_monet.to(device), real_photo.to(device)
        #print(real_monet)
        #---------------------train generator-------------------------------------------
        #if discr_success >= discr_success_rate:
        #if train_generator_at_previous_step==False:
        if True:
            #print('1')
            if not train_generator_at_previous_step:
                train_generator_at_previous_step = True
            
            update_req_grad([discriminator], False)
            update_req_grad([encoder, decoder, transformer], True)

            # zero the parameters gradients
            optim_generators.zero_grad()
            
            encoded_picture = encoder(real_photo)
            decoded_picture = decoder(encoded_picture)
            encoded_decoded_picture = encoder(decoded_picture)
            discriminate_encoded_picture = discriminator(decoded_picture)
            transformed_picture = transformer(real_photo)
            transformed_decoded_picture = transformer(decoded_picture)
            #print(('generator',encoded_picture,decoded_picture,encoded_decoded_picture,
            #      discriminate_encoded_picture, transformed_picture, transformed_decoded_picture))
            accuracy = generator_acc(discriminate_encoded_picture)
            loss_generators_total = generator_loss(discriminate_encoded_picture, transformed_picture, encoded_picture, transformed_decoded_picture, encoded_decoded_picture)
            print(loss_generators_total)
            loss_generators_total.backward()
            optim_generators.step()
        
            # limiting gradient norms - if they exceed 100, something went wrong
            clip_grad_norm_(encoder.parameters(), 100)
            clip_grad_norm_(decoder.parameters(), 100)
            clip_grad_norm_(transformer.parameters(), 100)
            
            discr_success = discr_success * (1. - alpha) + alpha * (1 - accuracy)
            lr_sched_generators.step()
            # updating intermediate results
            avg_generators_loss += loss_generators_total.item()
        #----------------train discriminator--------------------------------------
        else:
            #print('2')
            if train_generator_at_previous_step:
                train_generator_at_previous_step = False
                
            update_req_grad([discriminator], True)
            update_req_grad([encoder, decoder, transformer], False)

            # zero the parameters gradients
            optim_discriminators.zero_grad()
            
            encoded_picture = encoder(real_photo)
            decoded_picture = decoder(encoded_picture)
            
            discriminate_encoded_picture = discriminator(decoded_picture.detach())
            discriminate_picture = discriminator(real_photo.detach())
            discriminate_art = discriminator(real_monet.detach())
            
            loss_discriminators_total = discriminator_loss(discriminate_encoded_picture, discriminate_picture, discriminate_art)
            accuracy = discriminator_acc(discriminate_encoded_picture, discriminate_picture, discriminate_art)
            
            loss_discriminators_total.backward()
            optim_discriminators.step()

            # limiting gradient norms - if they exceed 100, something went wrong
            clip_grad_norm_(discriminator.parameters(), 100)
            discr_success = discr_success * (1. - alpha) + alpha * accuracy
            #print(loss_discriminators_total.item())
            lr_sched_discriminators.step()
            # updating intermediate results
            avg_discriminators_loss += loss_discriminators_total.item()

    # saving intermediate results
    avg_generators_loss /= len(dataloader)
    avg_discriminators_loss /= len(dataloader)
    history.update(avg_generators_loss, avg_discriminators_loss)
    
    # showing intermediate results
    print("Epoch: %d/%d | Generators Loss: %.4f | Discriminators Loss: %.4f"
              % (epoch+1, num_epochs, avg_generators_loss, avg_discriminators_loss))
    
    # showing generated images
    if (epoch + 1) % 10 == 0:
        _, sample_real_photo = next(iter(dataloader))
        
        sample_fake_monet = decoder(encoder(sample_real_photo.to(device))).detach().cpu()
        
        num_photos = min(batch_size, 5)
        plt.figure(figsize=(20, 8))
        for k in range(num_photos):
            plt.subplot(2, num_photos, k + 1)
            plt.imshow(unnorm(sample_real_photo[k]).permute(1, 2, 0))
            plt.title('Input photo')
            plt.axis('off')

            plt.subplot(2, num_photos, k + num_photos + 1)
            plt.imshow(unnorm(sample_fake_monet[k]).permute(1, 2, 0))
            plt.title('Output image')
            plt.axis('off')
        plt.show()
    
''' 
    

In [None]:
# Disabled experiment (kept inside a bare string literal, so it never runs):
# a self-contained CycleGAN set-up and training loop in which each generator is
# nn.Sequential(encoder_(), decoder_(...)). It builds its own models, losses
# (MSE adversarial, L1 cycle/identity), Adam optimizers (lr=2e-4,
# betas=(0.5, 0.999)) and LambdaLR schedules whose factor stays 1 up to
# `decay_epoch` and then falls linearly to 0 at `num_epochs` (stepped per epoch).
# NOTE(review): `transformer` is created but never used in this loop.
# NOTE(review): decoder_ receives an input_shape built from a fixed batch_size;
# confirm how decoder_ uses it if a batch can ever be smaller.
# NOTE(review): clip_grad_norm_ runs after optimizer.step(), so it has no
# effect on the update that was just applied.
'''
history = History()
buffer_monet = Buffer()
buffer_photo = Buffer()

transformer = transformer_().to(device)

generator_monet2photo = nn.Sequential(encoder_(),decoder_(input_shape=(batch_size,256,16,16))).to(device)
generator_photo2monet = nn.Sequential(encoder_(),decoder_(input_shape=(batch_size,256,16,16))).to(device)

discriminator_monet = Discriminator(in_channels=3).to(device)
discriminator_photo = Discriminator(in_channels=3).to(device)

criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

lr = 2e-4
b1 = 0.5
b2 = 0.999

optim_generators = torch.optim.Adam(
    list(generator_monet2photo.parameters()) + list(generator_photo2monet.parameters()),
    lr=lr, betas=(b1, b2)
)

optim_discriminators = torch.optim.Adam(
    list(discriminator_monet.parameters()) + list(discriminator_photo.parameters()),
    lr=lr, betas=(b1, b2)
)

num_epochs = 150
decay_epoch = 25

lr_sched_step = lambda epoch: 1 - max(0, epoch - decay_epoch) / (num_epochs - decay_epoch)

lr_sched_generators = torch.optim.lr_scheduler.LambdaLR(optim_generators, lr_lambda=lr_sched_step)
lr_sched_discriminators = torch.optim.lr_scheduler.LambdaLR(optim_discriminators, lr_lambda=lr_sched_step)
for epoch in range(num_epochs):
    avg_generators_loss = 0
    avg_discriminators_loss = 0
    
    for i, (real_monet, real_photo) in enumerate(tqdm(dataloader, leave=False, total=len(dataloader))):
        real_monet, real_photo = real_monet.to(device), real_photo.to(device)
                
        """ Train Generators """
        # switching models parameters so that only generators are trained
        update_req_grad([generator_monet2photo, generator_photo2monet], True)
        update_req_grad([discriminator_monet, discriminator_photo], False)
        
        # zero the parameters gradients
        optim_generators.zero_grad()
        
        # forward-pass
        fake_photo = generator_monet2photo(real_monet)
        fake_monet = generator_photo2monet(real_photo)
        
        cycle_photo = generator_monet2photo(fake_monet)
        cycle_monet = generator_photo2monet(fake_photo)
        
        identity_photo = generator_monet2photo(real_photo)
        identity_monet = generator_photo2monet(real_monet)
        
        # update photos that are used to feed up discriminators
        buffer_photo.update(fake_photo)
        buffer_monet.update(fake_monet)
        
        # discriminators outputs that are used in adversarial loss
        discriminator_outputs_photo = discriminator_photo(fake_photo)
        discriminator_outputs_monet = discriminator_monet(fake_monet)
        
        # labels that are used as ground truth
        labels_real = torch.ones(discriminator_outputs_monet.size()).to(device)
        labels_fake = torch.zeros(discriminator_outputs_monet.size()).to(device)
        
        # adversarial loss - enforces that the generated output be of the appropriate domain
        loss_GAN_monet2photo = criterion_GAN(discriminator_outputs_photo, labels_real)
        loss_GAN_photo2monet = criterion_GAN(discriminator_outputs_monet, labels_real)
        loss_GAN = (loss_GAN_monet2photo + loss_GAN_photo2monet) / 2
        
        # cycle consistency loss - enforces that the input and output are recognizably the same
        loss_cycle_photo = criterion_cycle(cycle_photo, real_photo)
        loss_cycle_monet = criterion_cycle(cycle_monet, real_monet)
        loss_cycle = (loss_cycle_photo + loss_cycle_monet) / 2
        
        # identity mapping loss - helps preserve the color of the input images
        loss_identity_photo = criterion_identity(identity_photo, real_photo)
        loss_identity_monet = criterion_identity(identity_monet, real_monet)
        loss_identity = (loss_identity_photo + loss_identity_monet) / 2
        
        # total loss
        loss_generators_total = loss_GAN + 10 * loss_cycle + 5 * loss_identity
        
        # backward-pass
        loss_generators_total.backward()
        optim_generators.step()
        
        # limiting gradient norms - if they exceed 100, something went wrong
        clip_grad_norm_(generator_photo2monet.parameters(), 100)
        clip_grad_norm_(generator_monet2photo.parameters(), 100)
        
        
        """ Train Discriminators """
        # switching models parameters so that only discriminators are trained
        update_req_grad([discriminator_monet, discriminator_photo], True)
        update_req_grad([generator_monet2photo, generator_photo2monet], False)
        
        # zero the parameters gradients
        optim_discriminators.zero_grad()
        
        # sample images from 50 stored
        fake_photo = buffer_photo.sample(num_images=batch_size).to(device)
        fake_monet = buffer_monet.sample(num_images=batch_size).to(device)
        
        # making labels noisy for discriminators so that they don't prevail over generators
        threshold = min(1, 0.85 + (1 - 0.85) * epoch / (num_epochs // 2))
        noisy_labels_real = (torch.rand(discriminator_outputs_monet.size()) < threshold).float().to(device)
        
        # forward-pass + losses
        loss_real_photo = criterion_GAN(discriminator_photo(real_photo), noisy_labels_real)
        loss_fake_photo = criterion_GAN(discriminator_photo(fake_photo.detach()), labels_fake)
        loss_photo = (loss_real_photo + loss_fake_photo) / 2
        
        loss_real_monet = criterion_GAN(discriminator_monet(real_monet), noisy_labels_real)
        loss_fake_monet = criterion_GAN(discriminator_monet(fake_monet.detach()), labels_fake)
        loss_monet = (loss_real_monet + loss_fake_monet) / 2
        
        loss_discriminators_total = loss_monet + loss_photo
        
        # backward-pass
        loss_discriminators_total.backward()
        optim_discriminators.step()
        
        # clipping gradients to avoid gradients explosion
        clip_grad_norm_(discriminator_monet.parameters(), 100)
        clip_grad_norm_(discriminator_photo.parameters(), 100)
        
        # updating intermediate results
        avg_generators_loss += loss_generators_total.item()
        avg_discriminators_loss += loss_discriminators_total.item()
        
    # saving intermediate results
    avg_generators_loss /= len(dataloader)
    avg_discriminators_loss /= len(dataloader)
    history.update(avg_generators_loss, avg_discriminators_loss)
    
    # showing intermediate results
    print("Epoch: %d/%d | Generators Loss: %.4f | Discriminators Loss: %.4f"
              % (epoch+1, num_epochs, avg_generators_loss, avg_discriminators_loss))
    
    # showing generated images
    if (epoch + 1) % 10 == 0:
        _, sample_real_photo = next(iter(dataloader))
        
        sample_fake_monet = generator_photo2monet(sample_real_photo.to(device)).detach().cpu()
        
        num_photos = min(batch_size, 5)
        plt.figure(figsize=(20, 8))
        for k in range(num_photos):
            plt.subplot(2, num_photos, k + 1)
            plt.imshow(unnorm(sample_real_photo[k]).permute(1, 2, 0))
            plt.title('Input photo')
            plt.axis('off')

            plt.subplot(2, num_photos, k + num_photos + 1)
            plt.imshow(unnorm(sample_fake_monet[k]).permute(1, 2, 0))
            plt.title('Output image')
            plt.axis('off')
        plt.show()
    
    lr_sched_generators.step()
    lr_sched_discriminators.step()
'''

In [None]:
# Display the loss history that the training loop recorded via
# history.update(avg_generators_loss, avg_discriminators_loss) once per epoch.
# (History is defined earlier in the notebook.)
history.show()

# Create submission file

In [None]:
# Collect every photo that will be translated into a Monet-style image for the
# submission. os.listdir returns entries in arbitrary order, so the listing is
# sorted (as ImageDataset also does). This makes the processing order -- and
# therefore the numbering of the generated output files -- reproducible.
photo_dir = os.path.join(data_dir, 'photo_jpg')
files = [os.path.join(photo_dir, name) for name in sorted(os.listdir(photo_dir))]
len(files)

In [None]:
# Output directory for the generated Monet-style images.
# NOTE: the path is relative to the current working directory (presumably
# /kaggle/working on Kaggle, which makes this /kaggle/images).
# exist_ok=True avoids the check-then-create race and makes re-running the cell safe.
save_dir = '../images'
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Translate every photo in `files` into a Monet-style image with
# encoder_p2m/decoder_p2m and save it as <save_dir>/<n>.jpg, where n counts
# from 1 in the order of `files`.
# NB: the intent is to use the generator in training mode so that dropout adds
# noise. NOTE(review): this cell never calls .train()/.eval(); it relies on the
# models still being in training mode after the training loop -- confirm.
# Inference runs under torch.no_grad(): the outputs are only written to disk,
# so building an autograd graph for every batch would just waste memory.

generate_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
to_pil = transforms.ToPILImage()  # stateless converter, built once

with torch.no_grad():
    for i in range(0, len(files), batch_size):
        images = []
        for path in files[i:i + batch_size]:
            # The context manager closes the file handle. convert("RGB") keeps
            # grayscale/RGBA inputs from breaking the 3-channel Normalize and
            # torch.stack.
            with Image.open(path) as image:
                images.append(generate_transforms(image.convert("RGB")))
        real_photo = torch.stack(images, 0)

        fake_images = decoder_p2m(encoder_p2m(real_photo.to(device))).detach().cpu()
        for j in range(fake_images.size(0)):
            img = unnorm(fake_images[j])
            img = to_pil(img).convert("RGB")
            img.save(os.path.join(save_dir, f"{i + j + 1}.jpg"))

In [None]:
# Bundle the generated images into /kaggle/working/images.zip for submission.
# Archive `save_dir` (the directory the images were actually written to)
# rather than a second hard-coded path. The two could otherwise diverge when
# the notebook runs from a working directory other than /kaggle/working.
shutil.make_archive("/kaggle/working/images", 'zip', save_dir)