In [None]:
!git clone https://github.com/Federico6419/Mask-CycleGAN          #It clones our github repository
%cd Mask-CycleGAN

from google.colab import drive
drive.mount('/content/drive')

#IMPORTS
from dataset import Dataset
import config
from discriminator import Discriminator
from generator import Generator
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch import Tensor
from tqdm import tqdm
from torch.autograd import Variable
import torch.autograd as autograd
from torch.utils.tensorboard import SummaryWriter


######### FUNCTIONS FOR SAVE AND LOAD MODELS #########
#This function saves the weights of the model and the state of its optimizer in a checkpoint file
def save_model(model, optimizer, epoch, filename="my_checkpoint.pth.tar"):
    """Write a checkpoint holding the model weights and the optimizer state.

    Args:
        model: Object exposing ``state_dict()``; its result is stored under
            the ``"state_dict"`` key.
        optimizer: Object exposing ``state_dict()``; its result is stored
            under the ``"optimizer"`` key.
        epoch: Only used for the progress message, it is not stored in the file.
        filename: Destination handed to ``torch.save``.
    """
    print(f"Saving model for epoch : {epoch!s}")

    checkpoint = {}
    checkpoint["state_dict"] = model.state_dict()
    checkpoint["optimizer"] = optimizer.state_dict()
    torch.save(checkpoint, filename)


#This function loads the saved weights of the model and the state of its optimizer from a checkpoint file, then sets the learning rate
def load_model(file, model, optimizer, lr):
    """Restore a model and its optimizer from a checkpoint file.

    Args:
        file: Checkpoint location handed to ``torch.load``; tensors are mapped
            onto ``config.DEVICE``.
        model: Receives the ``"state_dict"`` entry of the checkpoint.
        optimizer: Receives the ``"optimizer"`` entry of the checkpoint.
        lr: Learning rate written into every parameter group afterwards, so the
            value stored in the checkpoint is overridden.
    """
    print("Loading model: ")
    checkpoint = torch.load(file, map_location=config.DEVICE)

    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The checkpoint carries the learning rate it was saved with; force the requested one
    for group in optimizer.param_groups:
        group["lr"] = lr

######### END FUNCTIONS FOR SAVE AND LOAD MODELS ##########



########################### TRAIN FUNCTION #########################
def train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, LAMBDA_IDENTITY, LAMBDA_CYCLE, GAMMA_CYCLE):
    """Train both discriminators and both generators for one pass over ``loader``.

    For every batch the discriminators are updated first (real images against
    all ones, detached generated images against all zeros), then the generators
    are updated with the adversarial, cycle-consistency and identity losses.

    Args:
        disc_A: Discriminator applied to real and generated domain A images.
        disc_B: Discriminator applied to real and generated domain B images.
        gen_B: Generator producing fake domain B images from domain A images.
        gen_A: Generator producing fake domain A images from domain B images.
        loader: Iterable yielding ``(domainB, domainA)`` batches, in that order.
        opt_disc: Optimizer stepped with the discriminator loss.
        opt_gen: Optimizer stepped with the generator loss.
        l1: Loss callable used for the cycle and identity terms.
        mse: Loss callable used for the adversarial terms.
        d_scaler: Gradient scaler (``scale``/``step``/``update``) for the discriminator step.
        g_scaler: Gradient scaler (``scale``/``step``/``update``) for the generator step.
        LAMBDA_IDENTITY: Weight of the two identity losses.
        LAMBDA_CYCLE: Weight of the two cycle-consistency losses.
        GAMMA_CYCLE: Currently unused. It belongs to a disabled variant of the
            cycle loss that mixes a discriminator-feature distance with the
            pixel-level L1 distance.
    """
    import os  #Local import: only needed to make sure the preview folder exists

    os.makedirs("Saved_Images", exist_ok=True)

    loop = tqdm(loader, leave=True)           #leave=True keeps the finished bar on screen

    #Running sums of the mean discriminator outputs, shown in the progress bar as averages.
    #They must live outside the loop: resetting them per batch while dividing by (idx+1)
    #would display a meaningless value.
    A_is_real = 0
    A_is_fake = 0
    B_is_real = 0
    B_is_fake = 0

    for idx, (domainB, domainA) in enumerate(loop):                             #It loops over the images from domain A and domain B
        domainA = domainA.to(config.DEVICE)                                     #It puts the images from the two domains on the device
        domainB = domainB.to(config.DEVICE)

        with torch.cuda.amp.autocast():

            ############## TRAIN DISCRIMINATOR DOMAIN B #############
            fake_B = gen_B(domainA)              #Fake domain B image generated from a domain A image

            #Discriminator predictions; the fake is detached so this loss does not reach the generator
            D_B_real = disc_B(domainB)
            D_B_fake = disc_B(fake_B.detach())

            #Accumulate the mean predictions for the progress bar
            B_is_real += D_B_real.mean().item()
            B_is_fake += D_B_fake.mean().item()

            D_B_real_loss = mse(D_B_real, torch.ones_like(D_B_real))    #Real images should be predicted as ones
            D_B_fake_loss = mse(D_B_fake, torch.zeros_like(D_B_fake))   #Fake images should be predicted as zeros
            D_B_loss = D_B_real_loss + D_B_fake_loss

            ########### TRAIN DISCRIMINATOR OF THE DOMAIN A ##############
            fake_A = gen_A(domainB)             #Fake domain A image generated from a domain B image

            D_A_real = disc_A(domainA)
            D_A_fake = disc_A(fake_A.detach())

            A_is_real += D_A_real.mean().item()
            A_is_fake += D_A_fake.mean().item()

            D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))
            D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake))
            D_A_loss = D_A_real_loss + D_A_fake_loss

            #Put together the loss of the two discriminators
            D_loss = (D_A_loss + D_B_loss)/2

        opt_disc.zero_grad()
        #No retain_graph: this graph is never reused, the generator step below runs
        #the discriminators again on the non-detached fakes.
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        ########################## TRAIN GENERATORS #########################
        with torch.cuda.amp.autocast():

            #Discriminator predictions on the fakes, this time with the generator graph attached
            D_A_fake = disc_A(fake_A)
            D_B_fake = disc_B(fake_B)

            #Adversarial losses: the generators want their fakes to be predicted as ones
            loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))
            loss_G_B = mse(D_B_fake, torch.ones_like(D_B_fake))

            #NOTE: a "better cycle consistency" variant weighted by GAMMA_CYCLE used to be
            #sketched here as disabled code; only the plain L1 cycle loss is implemented.

            #CYCLE LOSS
            #A full round trip must give back the input: A -> fake_B -> A and B -> fake_A -> B
            cycle_A = gen_A(fake_B)
            cycle_B = gen_B(fake_A)
            cycle_A_loss = l1(domainA, cycle_A)
            cycle_B_loss = l1(domainB, cycle_B)

            #IDENTITY LOSS
            #A generator fed with an image already in its target domain should leave it unchanged
            identity_A = gen_A(domainA)
            identity_B = gen_B(domainB)
            identity_loss_A = l1(domainA, identity_A)
            identity_loss_B = l1(domainB, identity_B)

            #Add all losses together, multiplied by their relative parameter
            G_loss = (
                loss_G_B
                + loss_G_A
                + cycle_B_loss * LAMBDA_CYCLE
                + cycle_A_loss * LAMBDA_CYCLE
                + identity_loss_A * LAMBDA_IDENTITY
                + identity_loss_B * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        ##########################  END TRAIN GENERATORS #########################

        #Save tensors into images every 150 batches to see in real time the progress of the net
        #(the *0.5+0.5 presumably undoes a [-1, 1] normalization - confirm against config.transforms)
        if idx % 150 == 0:
            save_image(fake_B*0.5+0.5, f"Saved_Images/domainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"Saved_Images/domainA_{idx}.png")

        #Show the running averages of the discriminator predictions in the tqdm progress bar
        loop.set_postfix(A_real=A_is_real/(idx+1), A_fake=A_is_fake/(idx+1),B_real=B_is_real/(idx+1), B_fake=B_is_fake/(idx+1))

########################### END TRAIN FUNCTION ######################



#TEST FUNCTIONS
#Test function for Domain A
def test_fn_A(gen_B, gen_A, test_loader):
    """Translate every domain A test image to domain B and back, saving all three images.

    Args:
        gen_B: Generator applied to the domain A image to produce the fake domain B image.
        gen_A: Generator applied to that fake to reconstruct a domain A image.
        test_loader: Iterable yielding ``(domainB, domainA)`` batches; only the
            domain A element is used.

    Files are written to ``test_images/`` with the batch index in their name.
    NOTE(review): the generators are used in whatever train/eval mode the caller left them.
    """
    import os  #Local import: only needed to make sure the output folder exists

    os.makedirs("test_images", exist_ok=True)

    loop = tqdm(test_loader, leave=True)

    #Inference only: without no_grad an autograd graph would be built for every image
    with torch.no_grad():
        for idx, (_, domainA) in enumerate(loop):
            domainA = domainA.to(config.DEVICE)
            fake_B = gen_B(domainA)
            fake_A = gen_A(fake_B)

            save_image(domainA*0.5+0.5, f"test_images/testoriginal_{idx}.png")
            save_image(fake_B*0.5+0.5, f"test_images/testdomainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"test_images/testdomainA_{idx}.png")

#Test function for Domain B
def test_fn_B(gen_B,gen_A,test_loader):
    """Translate every domain B test image to domain A and back, saving all three images.

    Args:
        gen_B: Generator applied to the fake domain A image to reconstruct a domain B image.
        gen_A: Generator applied to the domain B image to produce the fake domain A image.
        test_loader: Iterable yielding ``(domainB, domainA)`` batches; only the
            domain B element is used.

    Files are written to ``test_images/`` with the same names ``test_fn_A`` uses,
    so running both overwrites the earlier results.
    NOTE(review): the generators are used in whatever train/eval mode the caller left them.
    """
    import os  #Local import: only needed to make sure the output folder exists

    os.makedirs("test_images", exist_ok=True)

    loop = tqdm(test_loader, leave=True)

    #Inference only: without no_grad an autograd graph would be built for every image
    with torch.no_grad():
        for idx, (domainB, _) in enumerate(loop):
            domainB = domainB.to(config.DEVICE)
            fake_A = gen_A(domainB)
            fake_B = gen_B(fake_A)

            save_image(domainB*0.5+0.5, f"test_images/testoriginal_{idx}.png")
            save_image(fake_B*0.5+0.5, f"test_images/testdomainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"test_images/testdomainA_{idx}.png")


###################### MAIN FUNCTION #######################
def main():
    """Build the networks, optimizers and data loaders, then train or test.

    Behaviour is driven by ``config``: ``LOAD_MODEL`` restores the four
    checkpoints, ``TRANSFORMATION`` selects the dataset folders,
    ``TRAIN_MODEL`` chooses between training for ``NUM_EPOCHS`` epochs
    (saving checkpoints after each epoch when ``SAVE_MODEL`` is set) and
    running the domain A test function.

    Raises:
        ValueError: If ``config.TRANSFORMATION`` is not a supported value.
    """
    disc_A = Discriminator(in_channels=3).to(config.DEVICE)
    disc_B = Discriminator(in_channels=3).to(config.DEVICE)
    gen_A = Generator(img_channels=3).to(config.DEVICE)
    gen_B = Generator(img_channels=3).to(config.DEVICE)

    #One optimizer drives both discriminators, another one drives both generators
    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_B.parameters()) + list(gen_A.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if config.LOAD_MODEL:
        #Each checkpoint stores the state of the shared optimizer as well, so the
        #optimizer is restored once per network it drives
        checkpoints_to_load = (
            (config.CHECKPOINT_GEN_A, gen_A, opt_gen),
            (config.CHECKPOINT_GEN_B, gen_B, opt_gen),
            (config.CHECKPOINT_DISC_A, disc_A, opt_disc),
            (config.CHECKPOINT_DISC_B, disc_B, opt_disc),
        )
        for checkpoint_file, model, optimizer in checkpoints_to_load:
            load_model(checkpoint_file, model, optimizer, config.LEARNING_RATE)

    ############## CHOICE OF THE DATASET ###############
    #Maps each supported transformation to the (domain A, domain B) folder suffixes
    dataset_folders = {
        "WinterToSummer": ("Winter", "Summer"),
        "HorseToZebra": ("Horse", "Zebra"),
    }
    if config.TRANSFORMATION not in dataset_folders:
        #Fail here with a clear message instead of an unbound ``dataset`` further down
        raise ValueError(
            f"Unsupported config.TRANSFORMATION {config.TRANSFORMATION!r}; "
            f"expected one of {sorted(dataset_folders)}"
        )
    name_A, name_B = dataset_folders[config.TRANSFORMATION]

    dataset = Dataset(
        domainA_dir=config.TRAIN_DIR+"/train"+name_A, domainB_dir=config.TRAIN_DIR+"/train"+name_B, transform=config.transforms
    )
    test_dataset = Dataset(
        domainA_dir=config.TEST_DIR+"/test"+name_A, domainB_dir=config.TEST_DIR+"/test"+name_B, transform=config.transforms
    )
    ############# CHOICE OF THE DATASET ##############

    ############# DATALOADER #############
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True  #for faster training(non-paged cpu memory)
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
    ############# DATALOADER ###############

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if config.TRAIN_MODEL:

        for epoch in range(config.NUM_EPOCHS):
            #GAMMA_CYCLE is passed as 0: train_fn does not use it at the moment
            train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler,config.LAMBDA_IDENTITY, config.LAMBDA_CYCLE, 0)

            if config.SAVE_MODEL: #if SAVE_MODEL is set to true save the models on the configured paths
                checkpoints_to_save = (
                    (gen_A, opt_gen, config.NEW_CHECKPOINT_GEN_A),
                    (gen_B, opt_gen, config.NEW_CHECKPOINT_GEN_B),
                    (disc_A, opt_disc, config.NEW_CHECKPOINT_DISC_A),
                    (disc_B, opt_disc, config.NEW_CHECKPOINT_DISC_B),
                )
                for model, optimizer, checkpoint_file in checkpoints_to_save:
                    save_model(model, optimizer, epoch, filename=checkpoint_file)
    else:

        test_fn_A(gen_B,gen_A,test_loader)
        #test_fn_B(gen_B,gen_A,test_loader)

if __name__ == "__main__":
    main()



Cloning into 'Mask-CycleGAN'...
remote: Enumerating objects: 5410, done.[K
remote: Counting objects: 100% (94/94), done.[K
remote: Compressing objects: 100% (92/92), done.[K
