In [1]:
!git clone https://github.com/Federico6419/Mask-CycleGAN          #It clones our github repository
%cd Mask-CycleGAN

from google.colab import drive
drive.mount('/content/drive')

#IMPORTS
from dataset import Dataset
import config
from discriminator import Discriminator
from generator import Generator
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch import Tensor
from tqdm import tqdm
from torch.autograd import Variable
import torch.autograd as autograd
from torch.utils.tensorboard import SummaryWriter
import random
from PIL import Image


######### FUNCTIONS FOR SAVE AND LOAD MODELS #########
#This function saves the weights of the model in a file
def save_model(model, optimizer, epoch, filename="my_checkpoint.pth.tar"):
    """Write a checkpoint holding the model weights and the optimizer state.

    Args:
        model: module whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: optimizer whose ``state_dict()`` is stored under
            ``"optimizer"``.
        epoch: only used for the progress message; it is not written to the
            checkpoint file.
        filename: destination path handed to ``torch.save``.
    """
    print(f"Saving model for epoch : {epoch}")

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


#This function loads the precomputed weights of the model from a file
def load_model(file, model, optimizer, lr):
    """Restore model weights and optimizer state from a checkpoint file.

    Args:
        file: path of a checkpoint containing the keys ``"state_dict"`` and
            ``"optimizer"`` (the layout written by ``save_model``).
        model: module that receives the stored weights.
        optimizer: optimizer that receives the stored state.
        lr: learning rate forced onto every parameter group after loading, so
            the value stored in the checkpoint does not override the
            configured one.
    """
    print("Loading model: ")

    # map_location moves the stored tensors onto the configured device.
    checkpoint = torch.load(file, map_location=config.DEVICE)

    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The optimizer state just loaded carries its own learning rate; replace it.
    for group in optimizer.param_groups:
        group["lr"] = lr

######### END FUNCTIONS FOR SAVE AND LOAD MODELS ##########


######### FUNCTION THAT GENERATES A RANDOM RECTANGULAR MASK ##########
def _mask_height_width(image_size):
    """Return the (height, width) a mask should have for ``image_size``.

    ``image_size`` may be an int (square image) or a size sequence such as
    ``tensor.size()``, whose last two entries are taken as height and width.
    ``None`` or an empty sequence falls back to the historical 256x256 default.
    """
    if image_size is None:
        return 256, 256
    if isinstance(image_size, int):
        dims = (image_size, image_size)
    else:
        dims = tuple(int(d) for d in image_size)
        if not dims:
            return 256, 256
        if len(dims) == 1:
            dims = dims * 2

    height, width = dims[-2], dims[-1]
    if height < 1 or width < 1:
        raise ValueError(f"image_size must describe a non-empty image, got {image_size!r}")
    return height, width


def mask_generator(image_size):
    """Generate a random binary mask made of filled rectangles.

    Rectangles are drawn until both conditions hold: at least ``min_rects``
    rectangles were drawn (a random count from ``int(random.uniform(1, 5))``)
    and the summed relative area of the drawn rectangles reaches 15% of the
    image. Overlaps are counted more than once, so the real coverage can be
    lower than 15%.

    Args:
        image_size: int, or size sequence whose last two entries are the image
            height and width (e.g. ``images.size()``). Previously this
            argument was ignored and the mask was always 256x256; for
            256x256 inputs the result is unchanged.

    Returns:
        A float tensor of shape ``(3, height, width)`` holding 1.0 inside the
        rectangles and 0.0 elsewhere, with ``requires_grad=False``. The channel
        count is fixed at 3 as in the original implementation.

    Raises:
        ValueError: if the height or width derived from ``image_size`` is < 1.
    """
    height, width = _mask_height_width(image_size)

    min_rects = int(random.uniform(1, 5))    #Minimum number of rectangles to draw
    min_area = 0.15                          #Minimal accumulated area relative to the whole image area

    #Every rectangle side is at least one tenth of the corresponding image side
    min_rect_h = height / 10
    min_rect_w = width / 10

    mask = torch.zeros(3, height, width, requires_grad=False)
    num_rects = 0
    sum_rel_area = 0.0

    while num_rects < min_rects or sum_rel_area < min_area:
        #Randomly generate the top left corner of the rectangle
        i0 = int(random.uniform(0, height - min_rect_h))
        j0 = int(random.uniform(0, width - min_rect_w))

        #Randomly generate the bottom right corner of the rectangle
        i1 = int(random.uniform(i0 + min_rect_h, height))
        j1 = int(random.uniform(j0 + min_rect_w, width))

        #On very small images truncation could yield an empty rectangle, which
        #would never increase the area and loop forever; force at least 1 pixel
        i1 = min(max(i1, i0 + 1), height)
        j1 = min(max(j1, j0 + 1), width)

        #Draw rectangle on the mask
        mask[:, i0:i1, j0:j1] = 1.0
        num_rects += 1
        sum_rel_area += ((i1 - i0) * (j1 - j0)) / (height * width)

    return mask




########################### TRAIN FUNCTION #########################
def _discriminator_loss(disc, real, fake, mse):
    """Least-squares discriminator loss: real predictions -> 1, fake -> 0."""
    pred_real = disc(real)
    pred_fake = disc(fake)
    real_loss = mse(pred_real, torch.ones_like(pred_real))
    fake_loss = mse(pred_fake, torch.zeros_like(pred_fake))
    return real_loss + fake_loss


def _generator_adversarial_loss(disc, fake, mse):
    """Least-squares generator loss: the discriminator should output 1 on fakes."""
    pred_fake = disc(fake)
    return mse(pred_fake, torch.ones_like(pred_fake))


def train_fn(
    disc_A, disc_B, disc_AM, disc_BM, gen_B, gen_A, loader, opt_disc, opt_gen,
    l1, mse, d_scaler, g_scaler,
    LAMBDA_IDENTITY, LAMBDA_CYCLE, LAMBDA_MASK, LAMBDA_CYCLE_MASK,
):
    """Run one pass over ``loader``, updating discriminators and generators.

    For every batch a fresh random mask is drawn. The four discriminators are
    updated first (on detached fakes), then both generators are updated with
    adversarial, cycle, masked-cycle and identity terms.

    Args:
        disc_A, disc_B: discriminators fed whole images of domain A / B.
        disc_AM, disc_BM: discriminators fed images multiplied by the mask.
        gen_B, gen_A: generators called as ``gen(image, mask)``; ``gen_B`` is
            applied to domain-A images and ``gen_A`` to domain-B images.
        loader: iterable yielding ``(domainB, domainA)`` batches.
        opt_disc, opt_gen: optimizers for all discriminators / both generators.
        l1, mse: loss callables used for reconstruction / adversarial terms.
        d_scaler, g_scaler: gradient scalers for the mixed-precision steps.
        LAMBDA_IDENTITY: weight of the identity loss.
        LAMBDA_CYCLE: weight of the combined cycle loss.
        LAMBDA_MASK: blend between masked (``LAMBDA_MASK``) and whole-image
            (``1 - LAMBDA_MASK``) adversarial losses.
        LAMBDA_CYCLE_MASK: blend between the plain cycle loss
            (``LAMBDA_CYCLE_MASK``) and the inverted-mask cycle loss
            (``1 - LAMBDA_CYCLE_MASK``).

    Side effects:
        Every 150 batches writes the mask and both fake images to
        ``Saved_Images/`` (the directory must already exist); the file names
        depend only on the batch index, so later epochs overwrite them.

    Changes relative to the previous version (training maths unchanged):
        * The per-batch ``*_is_real`` / ``*_is_fake`` counters were reset on
          every iteration and never read, but each cost a ``.item()`` host
          synchronisation; they are gone.
        * The generator step no longer runs the discriminators on real images,
          because those predictions were never used.
        * ``retain_graph=True`` was dropped from both ``backward`` calls: each
          graph is back-propagated exactly once (the discriminator graph only
          sees detached fakes), so retaining it only held memory longer.
    """
    loop = tqdm(loader, leave=True)           #leave=True keeps the finished bar on screen

    for idx, (domainB, domainA) in enumerate(loop):
        #Move the images of the two domains to the device
        domainA = domainA.to(config.DEVICE)
        domainB = domainB.to(config.DEVICE)

        #Create a random mask for this batch
        mask = mask_generator(domainA.size()).to(config.DEVICE)

        ######################## TRAIN DISCRIMINATORS ########################
        with torch.cuda.amp.autocast():
            #Domain B: whole-image and masked discriminators
            fake_B = gen_B(domainA, mask)
            D_B_loss = _discriminator_loss(disc_B, domainB, fake_B.detach(), mse)
            DM_B_loss = _discriminator_loss(disc_BM, domainB * mask, (fake_B * mask).detach(), mse)
            D_B_total_loss = LAMBDA_MASK * DM_B_loss + (1 - LAMBDA_MASK) * D_B_loss

            #Domain A: whole-image and masked discriminators
            fake_A = gen_A(domainB, mask)
            D_A_loss = _discriminator_loss(disc_A, domainA, fake_A.detach(), mse)
            DM_A_loss = _discriminator_loss(disc_AM, domainA * mask, (fake_A * mask).detach(), mse)
            D_A_total_loss = LAMBDA_MASK * DM_A_loss + (1 - LAMBDA_MASK) * D_A_loss

            #Average of the two domains (the original author left an open
            #question on whether the division by 2 should stay)
            D_loss = (D_A_total_loss + D_B_total_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        ########################## TRAIN GENERATORS #########################
        with torch.cuda.amp.autocast():
            #Adversarial losses against the whole-image discriminators
            loss_G_A = _generator_adversarial_loss(disc_A, fake_A, mse)
            loss_G_B = _generator_adversarial_loss(disc_B, fake_B, mse)

            #Adversarial losses against the mask discriminators
            loss_GM_A = _generator_adversarial_loss(disc_AM, fake_A * mask, mse)
            loss_GM_B = _generator_adversarial_loss(disc_BM, fake_B * mask, mse)

            #CYCLE LOSS: translate the fakes back and compare with the originals
            cycle_A = gen_A(fake_B, mask)
            cycle_B = gen_B(fake_A, mask)
            cycle_A_loss = l1(domainA, cycle_A)
            cycle_B_loss = l1(domainB, cycle_B)

            #CYCLE MASK LOSS: same round trip driven by the inverted mask,
            #compared against the originals multiplied by the inverted mask
            inverse_mask = 1 - mask
            fake_AM = gen_A(domainB, inverse_mask)
            fake_BM = gen_B(domainA, inverse_mask)
            cycle_AM = gen_A(fake_BM, inverse_mask)
            cycle_BM = gen_B(fake_AM, inverse_mask)
            cycle_AM_loss = l1(domainA * inverse_mask, cycle_AM)
            cycle_BM_loss = l1(domainB * inverse_mask, cycle_BM)

            #SUM THE CYCLE LOSSES
            cycle_total_loss_A = LAMBDA_CYCLE_MASK * cycle_A_loss + (1 - LAMBDA_CYCLE_MASK) * cycle_AM_loss
            cycle_total_loss_B = LAMBDA_CYCLE_MASK * cycle_B_loss + (1 - LAMBDA_CYCLE_MASK) * cycle_BM_loss

            #IDENTITY LOSS: a generator fed an image of its own target domain
            #should return it unchanged
            identity_A = gen_A(domainA, mask)
            identity_B = gen_B(domainB, mask)
            identity_loss_A = l1(domainA, identity_A)
            identity_loss_B = l1(domainB, identity_B)

            #Add all losses together, multiplied by their relative parameter
            G_loss = (
                loss_G_B * (1 - LAMBDA_MASK)
                + loss_G_A * (1 - LAMBDA_MASK)
                + loss_GM_B * LAMBDA_MASK
                + loss_GM_A * LAMBDA_MASK
                + cycle_total_loss_B * LAMBDA_CYCLE
                + cycle_total_loss_A * LAMBDA_CYCLE
                + identity_loss_A * LAMBDA_IDENTITY
                + identity_loss_B * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        #Save tensors as images every 150 batches to follow the progress of the net.
        #The *0.5+0.5 rescaling presumably undoes a [-1, 1] normalisation - confirm
        #against config.transforms.
        if idx % 150 == 0:
            save_image(mask, f"Saved_Images/Mask_{idx}.png")
            save_image(fake_B*0.5+0.5, f"Saved_Images/domainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"Saved_Images/domainA_{idx}.png")

        #Show the current losses in the tqdm progress bar
        loop.set_postfix(G_loss=G_loss.item(), D_loss=D_loss.item(), cycle_A_loss=cycle_A_loss.item(), cycle_B_loss=cycle_B_loss.item())

########################### END TRAIN FUNCTION ######################

########################### END TRAIN FUNCTION ######################



#TEST FUNCTIONS
#Test function: translates the test images of both domains (A and B) and saves the results
def test_fn(gen_B, gen_A, test_loader, mask_type):
    """Translate every test pair with the fixed test masks and save the images.

    For each ``(domainB, domainA)`` batch the originals are saved once; then,
    for each of the fixed masks 1, 5 and 6 under ``Test/Test_Masks/``, the
    A -> B -> A and B -> A -> B round trips are generated and saved.

    Args:
        gen_B, gen_A: generators called as ``gen(image, mask)``.
        test_loader: iterable yielding ``(domainB, domainA)`` batches.
        mask_type: kept for interface compatibility; it is currently not used
            (the previous loop variable of the same name shadowed it).

    Fixes relative to the previous version:
        * Masks are converted to float ``(3, H, W)`` tensors in [0, 1] on
          ``config.DEVICE``, the same kind of object ``train_fn`` passes to the
          generators, instead of raw ``(H, W, 3)`` uint8 NumPy arrays.
          NOTE(review): assumes the mask PNGs have the same height/width as the
          test images and are black/white - confirm.
        * Output file names now carry the mask index; before, every mask wrote
          to the same path so only the last mask's result survived.
        * Inference runs under ``torch.no_grad()``.
        * Mask files are read once, not once per image, and are closed.
    """
    loop = tqdm(test_loader, leave=True)

    #Same selection as the old `range(7)` filter (mask_type == 1 or mask_type > 4)
    mask_ids = (1, 5, 6)

    masks = {}
    for mask_id in mask_ids:
        with Image.open(f"Test/Test_Masks/Mask_{mask_id}.png") as mask_image:
            mask_array = np.array(mask_image.convert("RGB"))
        #HWC uint8 in [0, 255] -> CHW float in [0, 1]
        mask_tensor = torch.from_numpy(mask_array).permute(2, 0, 1).float() / 255.0
        masks[mask_id] = mask_tensor.to(config.DEVICE)

    with torch.no_grad():
        for idx, (domainB, domainA) in enumerate(loop):
            domainA = domainA.to(config.DEVICE)
            domainB = domainB.to(config.DEVICE)

            save_image(domainA*0.5+0.5, f"Test/test_images_A/testoriginal_{idx}.png")
            save_image(domainB*0.5+0.5, f"Test/test_images_B/testoriginal_{idx}.png")

            for mask_id, mask in masks.items():
                #Domain A image: A -> B -> A
                fake_B = gen_B(domainA, mask)
                fake_A = gen_A(fake_B, mask)

                save_image(fake_B*0.5+0.5, f"Test/test_images_A/testdomainB_{idx}_mask{mask_id}.png")
                save_image(fake_A*0.5+0.5, f"Test/test_images_A/testdomainA_{idx}_mask{mask_id}.png")

                #Domain B image: B -> A -> B
                fake_A = gen_A(domainB, mask)
                fake_B = gen_B(fake_A, mask)

                save_image(fake_B*0.5+0.5, f"Test/test_images_B/testdomainB_{idx}_mask{mask_id}.png")
                save_image(fake_A*0.5+0.5, f"Test/test_images_B/testdomainA_{idx}_mask{mask_id}.png")



###################### MAIN FUNCTION #######################
def main():
    """Build the models, optimizers and data loaders, then train or test.

    Behaviour is driven by ``config``:
        * ``LOAD_MODEL``: restore all six networks from their checkpoints.
        * ``TRANSFORMATION``: selects the pair of dataset folders.
        * ``TRAIN_MODEL``: train for ``NUM_EPOCHS`` epochs (saving checkpoints
          after each epoch when ``SAVE_MODEL`` is set); otherwise run the test
          function.

    Raises:
        ValueError: if ``config.TRANSFORMATION`` is not one of the supported
            names. Previously this case left the datasets undefined and failed
            later with an unrelated-looking ``UnboundLocalError``.
    """
    #Initialize Discriminators and Generators
    disc_A = Discriminator(in_channels=3).to(config.DEVICE)
    disc_B = Discriminator(in_channels=3).to(config.DEVICE)
    disc_AM = Discriminator(in_channels=3).to(config.DEVICE)
    disc_BM = Discriminator(in_channels=3).to(config.DEVICE)
    gen_A = Generator(img_channels=3).to(config.DEVICE)
    gen_B = Generator(img_channels=3).to(config.DEVICE)

    #Adam for Discriminators (one optimizer shared by all four)
    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()) + list(disc_AM.parameters()) + list(disc_BM.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    #Adam for Generators (one optimizer shared by both)
    opt_gen = optim.Adam(
        list(gen_B.parameters()) + list(gen_A.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    #Define L1 and Mean Squared Error loss
    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    #Load pretrained model
    if config.LOAD_MODEL:
        checkpoints_to_load = (
            (config.CHECKPOINT_GEN_A, gen_A, opt_gen),
            (config.CHECKPOINT_GEN_B, gen_B, opt_gen),
            (config.CHECKPOINT_DISC_A, disc_A, opt_disc),
            (config.CHECKPOINT_DISC_B, disc_B, opt_disc),
            (config.CHECKPOINT_DISC_AM, disc_AM, opt_disc),
            (config.CHECKPOINT_DISC_BM, disc_BM, opt_disc),
        )
        for checkpoint_file, network, optimizer in checkpoints_to_load:
            load_model(checkpoint_file, network, optimizer, config.LEARNING_RATE)

    ############## CHOICE OF THE DATASET ###############
    #Transformation name -> (domain A folder suffix, domain B folder suffix)
    domain_folders = {
        "WinterToSummer": ("Winter", "Summer"),
        "HorseToZebra": ("Horse", "Zebra"),
        "MonetToPhoto": ("Monet", "PhotoMonet"),
        "AppleToOrange": ("Apple", "Orange"),
    }
    if config.TRANSFORMATION not in domain_folders:
        raise ValueError(
            f"Unsupported config.TRANSFORMATION {config.TRANSFORMATION!r}; "
            f"expected one of {sorted(domain_folders)}"
        )
    folder_A, folder_B = domain_folders[config.TRANSFORMATION]

    dataset = Dataset(
        domainA_dir=config.TRAIN_DIR+"/train"+folder_A, domainB_dir=config.TRAIN_DIR+"/train"+folder_B, transform=config.transforms
    )
    test_dataset = Dataset(
        domainA_dir=config.TEST_DIR+"/test"+folder_A, domainB_dir=config.TEST_DIR+"/test"+folder_B, transform=config.transforms
    )

    ############# DATALOADER #############
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True  #page-locked host memory for faster transfers to the device
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )

    #Define the scalers
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    networks = (disc_A, disc_B, disc_AM, disc_BM, gen_A, gen_B)

    #Train the model
    if(config.TRAIN_MODEL):

        for epoch in range(config.NUM_EPOCHS):

            #Set the models in training mode
            for network in networks:
                network.train()

            train_fn(disc_A, disc_B, disc_AM, disc_BM, gen_B, gen_A, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, config.LAMBDA_IDENTITY, config.LAMBDA_CYCLE, config.LAMBDA_MASK, config.LAMBDA_CYCLE_MASK)

            #If SAVE_MODEL is set to True save the current model
            if config.SAVE_MODEL:
                checkpoints_to_save = (
                    (gen_A, opt_gen, config.NEW_CHECKPOINT_GEN_A),
                    (gen_B, opt_gen, config.NEW_CHECKPOINT_GEN_B),
                    (disc_A, opt_disc, config.NEW_CHECKPOINT_DISC_A),
                    (disc_B, opt_disc, config.NEW_CHECKPOINT_DISC_B),
                    (disc_AM, opt_disc, config.NEW_CHECKPOINT_DISC_AM),
                    (disc_BM, opt_disc, config.NEW_CHECKPOINT_DISC_BM),
                )
                for network, optimizer, checkpoint_file in checkpoints_to_save:
                    save_model(network, optimizer, epoch, filename=checkpoint_file)

    #Test the model
    else:
        #Set the models in evaluation mode
        for network in networks:
            network.eval()

        test_fn(gen_B, gen_A, test_loader, config.TEST_MASK)        #Start Test

if __name__ == "__main__":
    main()



Cloning into 'Mask-CycleGAN'...
remote: Enumerating objects: 16073, done.
remote: Counting objects: 100% (124/124), done.
remote: Compressing objects: 100% (118/118), done.
remote: Total 16073 (delta 69), reused 3 (delta 0), pack-reused 15949
Receiving objects: 100% (16073/16073), 590.21 MiB | 16.96 MiB/s, done.
Resolving deltas: 100% (73/73), done.
Updating files: 100% (16172/16172), done.
/content/Mask-CycleGAN
Mounted at /content/drive
Loading model: 
Loading model: 
Loading model: 
Loading model: 
Loading model: 
Loading model: 


100%|██████████| 1019/1019 [07:29<00:00,  2.27it/s, D_loss=1.26, G_loss=3.07, cycle_A_loss=0.052, cycle_B_loss=0.0281]


Saving model for epoch : 0
Saving model for epoch : 0
Saving model for epoch : 0
Saving model for epoch : 0
Saving model for epoch : 0
Saving model for epoch : 0


100%|██████████| 1019/1019 [07:22<00:00,  2.30it/s, D_loss=0.869, G_loss=3.29, cycle_A_loss=0.039, cycle_B_loss=0.0756]


Saving model for epoch : 1
Saving model for epoch : 1
Saving model for epoch : 1
Saving model for epoch : 1
Saving model for epoch : 1
Saving model for epoch : 1


100%|██████████| 1019/1019 [07:22<00:00,  2.30it/s, D_loss=0.874, G_loss=4.68, cycle_A_loss=0.104, cycle_B_loss=0.0503]


Saving model for epoch : 2
Saving model for epoch : 2
Saving model for epoch : 2
Saving model for epoch : 2
Saving model for epoch : 2
Saving model for epoch : 2


100%|██████████| 1019/1019 [07:21<00:00,  2.31it/s, D_loss=0.992, G_loss=5.26, cycle_A_loss=0.0908, cycle_B_loss=0.105]


Saving model for epoch : 3
Saving model for epoch : 3
Saving model for epoch : 3
Saving model for epoch : 3
Saving model for epoch : 3
Saving model for epoch : 3


100%|██████████| 1019/1019 [07:21<00:00,  2.31it/s, D_loss=0.856, G_loss=2.9, cycle_A_loss=0.0373, cycle_B_loss=0.0334]


Saving model for epoch : 4
Saving model for epoch : 4
Saving model for epoch : 4
Saving model for epoch : 4
Saving model for epoch : 4
Saving model for epoch : 4


100%|██████████| 1019/1019 [07:22<00:00,  2.30it/s, D_loss=0.706, G_loss=3.27, cycle_A_loss=0.0583, cycle_B_loss=0.0437]


Saving model for epoch : 5
Saving model for epoch : 5
Saving model for epoch : 5
Saving model for epoch : 5
Saving model for epoch : 5
Saving model for epoch : 5


100%|██████████| 1019/1019 [07:21<00:00,  2.31it/s, D_loss=0.74, G_loss=2.92, cycle_A_loss=0.0377, cycle_B_loss=0.0288]


Saving model for epoch : 6
Saving model for epoch : 6
Saving model for epoch : 6
Saving model for epoch : 6
Saving model for epoch : 6
Saving model for epoch : 6


100%|██████████| 1019/1019 [07:21<00:00,  2.31it/s, D_loss=0.664, G_loss=2.96, cycle_A_loss=0.0581, cycle_B_loss=0.0497]


Saving model for epoch : 7
Saving model for epoch : 7
Saving model for epoch : 7
Saving model for epoch : 7
Saving model for epoch : 7
Saving model for epoch : 7


100%|██████████| 1019/1019 [07:20<00:00,  2.31it/s, D_loss=1.02, G_loss=5.04, cycle_A_loss=0.152, cycle_B_loss=0.107]


Saving model for epoch : 8
Saving model for epoch : 8
Saving model for epoch : 8
Saving model for epoch : 8
Saving model for epoch : 8
Saving model for epoch : 8


100%|██████████| 1019/1019 [07:21<00:00,  2.31it/s, D_loss=0.742, G_loss=2.97, cycle_A_loss=0.0319, cycle_B_loss=0.0381]


Saving model for epoch : 9
Saving model for epoch : 9
Saving model for epoch : 9
Saving model for epoch : 9
Saving model for epoch : 9
Saving model for epoch : 9
