In [None]:
!git clone https://github.com/Federico6419/Mask-CycleGAN
%cd Mask-CycleGAN

from dataset import Dataset
import config
from discriminator import Discriminator
from generator import Generator

import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch import Tensor
from tqdm import tqdm
from torch.autograd import Variable
import torch.autograd as autograd
from torch.utils.tensorboard import SummaryWriter

######### FUNCTIONS FOR SAVING AND LOADING MODELS #########

def save_model(model, optimizer, epoch, filename="my_checkpoint.pth.tar"):
    """Write a training checkpoint for ``model`` and ``optimizer`` to ``filename``.

    The checkpoint is a dict with the keys ``"state_dict"`` (model weights),
    ``"optimizer"`` (optimizer state) and ``"epoch"`` (the epoch number passed
    in), so that a later run can tell at which epoch the weights were saved.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a torch optimizer).
        epoch: Epoch number the checkpoint belongs to; printed and stored.
        filename: Destination path handed to ``torch.save``.
    """
    print("Saving model for epoch : "+ str(epoch))

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # Previously the epoch was only printed; storing it is an additive,
        # backward-compatible key that existing loaders simply ignore.
        "epoch": epoch,
    }
    torch.save(checkpoint, filename)


def load_model(file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from the checkpoint stored at ``file``.

    The checkpoint is expected to hold the keys ``"state_dict"`` and
    ``"optimizer"``. Tensors are mapped onto ``config.DEVICE`` while loading.
    After the optimizer state is restored, the learning rate of every
    parameter group is overwritten with ``lr`` so the checkpoint's stored
    learning rate does not leak into the current run.

    Args:
        file: Path (or file-like object) accepted by ``torch.load``.
        model: Module whose weights are replaced in place.
        optimizer: Optimizer whose state is replaced in place.
        lr: Learning rate to force on all parameter groups.
    """
    print("Loading model: ")
    checkpoint = torch.load(file, map_location=config.DEVICE)

    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The restored optimizer carries the old learning rate; replace it.
    for group in optimizer.param_groups:
        group["lr"] = lr


######### END OF MODEL SAVE/LOAD FUNCTIONS ##########

def gradient_penalty(model, real_images, fake_images, device):
    """Compute a gradient penalty on random interpolations of real and fake images.

    For every sample a mixing weight ``alpha`` is drawn uniformly from [0, 1)
    and the point ``alpha * real + (1 - alpha) * fake`` is fed through
    ``model``. The penalty is the batch mean of ``(||grad||_2 - 1) ** 2``,
    where ``grad`` is the gradient of the model output with respect to the
    interpolated input, flattened per sample.

    Args:
        model: Callable mapping a batch of images to a tensor of scores.
        real_images: Batch of real images; its first dimension is the batch size.
        fake_images: Batch of generated images, same shape as ``real_images``.
        device: Device on which the helper tensors are created.

    Returns:
        A scalar tensor. The gradient is taken with ``create_graph=True`` so
        the penalty can itself be back-propagated.
    """
    batch_size = real_images.size(0)

    # Uniform weights keep the sample on the segment between real and fake.
    # (A normal draw, as used before, yields weights outside [0, 1] and
    # therefore extrapolates instead of interpolating.)
    alpha = torch.rand((batch_size, 1, 1, 1), device=device)

    # Random interpolation between real and fake data, tracked by autograd.
    interpolates = (alpha * real_images + ((1 - alpha) * fake_images)).requires_grad_(True)

    model_interpolates = model(interpolates)
    grad_outputs = torch.ones(model_interpolates.size(), device=device, requires_grad=False)

    # Gradient of the model output w.r.t. the interpolated inputs.
    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # One flat gradient vector per sample, then penalise deviation from unit norm.
    gradients = gradients.view(gradients.size(0), -1)
    penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
    return penalty



                ########################### TRAIN FUNCTION #########################
def train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, l1, mse, BCE, d_scaler, g_scaler, LAMBDA_IDENTITY, LAMBDA_CYCLE, GAMMA_CYCLE):
    """Train both discriminators and both generators for one pass over ``loader``.

    ``gen_B`` translates domain-A images into domain B and ``gen_A`` translates
    domain-B images into domain A. ``disc_A`` / ``disc_B`` score images of the
    respective domain. Each batch performs one discriminator update (least
    squares loss against ones/zeros) followed by one generator update
    (adversarial + cycle-consistency + identity losses), both under CUDA
    autocast with separate gradient scalers.

    Args:
        disc_A, disc_B: Discriminators for domain A and domain B.
        gen_B, gen_A: Generators producing domain-B and domain-A images.
        loader: Iterable yielding ``(domainB, domainA)`` batches.
        opt_disc: Optimizer stepping the discriminators.
        opt_gen: Optimizer stepping the generators.
        l1: Loss used for the cycle and identity terms.
        mse: Loss used for the adversarial terms.
        BCE: Unused; kept for interface compatibility.
        d_scaler, g_scaler: Gradient scalers for the discriminator and
            generator updates.
        LAMBDA_IDENTITY: Weight of the identity loss terms.
        LAMBDA_CYCLE: Weight of the cycle-consistency loss terms.
        GAMMA_CYCLE: Unused; it belonged to an experimental feature-level
            cycle-consistency variant that is not implemented here.
    """
    loop = tqdm(loader, leave=True)           #leave=True to avoid print newline

    # Running sums of the mean discriminator outputs, shown (averaged over the
    # batches seen so far) in the progress bar. They must be initialised once
    # per epoch, outside the batch loop: resetting them every iteration while
    # still dividing by (idx + 1) made the displayed values shrink towards 0.
    A_is_real = 0
    A_is_fake = 0
    B_is_real = 0
    B_is_fake = 0

    for idx, (domainB, domainA) in enumerate(loop):
        domainA = domainA.to(config.DEVICE)
        domainB = domainB.to(config.DEVICE)

        ########################## TRAIN DISCRIMINATORS #########################
        with torch.cuda.amp.autocast():

            ############## DISCRIMINATOR OF DOMAIN B #############
            fake_B = gen_B(domainA)              #Generate a fake image from domain B starting from an image from domain A

            # Discriminator scores for a real and a generated domain-B image.
            # The fake is detached so this loss does not reach the generator.
            D_B_real = disc_B(domainB)
            D_B_fake = disc_B(fake_B.detach())

            B_is_real += D_B_real.mean().item()
            B_is_fake += D_B_fake.mean().item()

            D_B_real_loss = mse(D_B_real, torch.ones_like(D_B_real))
            D_B_fake_loss = mse(D_B_fake, torch.zeros_like(D_B_fake))
            D_B_loss = D_B_real_loss + D_B_fake_loss

            ############## DISCRIMINATOR OF DOMAIN A #############
            fake_A = gen_A(domainB)
            D_A_real = disc_A(domainA)
            D_A_fake = disc_A(fake_A.detach())

            A_is_real += D_A_real.mean().item()
            A_is_fake += D_A_fake.mean().item()

            D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))
            D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake))
            D_A_loss = D_A_real_loss + D_A_fake_loss

            # Put together the losses of the two discriminators
            D_loss = (D_A_loss + D_B_loss)/2

        # The discriminator graph is not reused afterwards (the fakes were
        # detached), so there is no need to retain it after backward.
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        ########################## TRAIN GENERATORS #########################
        with torch.cuda.amp.autocast():

            # ADVERSARIAL LOSS: the generators want their fakes scored as real.
            D_A_fake = disc_A(fake_A)
            D_B_fake = disc_B(fake_B)
            loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))
            loss_G_B = mse(D_B_fake, torch.ones_like(D_B_fake))

            # CYCLE LOSS: translate each fake back to its source domain and
            # compare with the original (A -> B -> A and B -> A -> B).
            # fake_B was produced from domainA, so gen_A must reconstruct
            # domainA from fake_B; symmetrically for fake_A / domainB.
            cycle_A = gen_A(fake_B)
            cycle_B = gen_B(fake_A)
            cycle_A_loss = l1(domainA, cycle_A)
            cycle_B_loss = l1(domainB, cycle_B)

            # IDENTITY LOSS: a generator fed an image of its own target
            # domain should leave it unchanged.
            identity_A = gen_A(domainA)
            identity_B = gen_B(domainB)
            identity_loss_A = l1(domainA, identity_A)
            identity_loss_B = l1(domainB, identity_B)

            #Add all losses together
            G_loss = (
                loss_G_B
                + loss_G_A
                + cycle_B_loss * LAMBDA_CYCLE
                + cycle_A_loss * LAMBDA_CYCLE
                + identity_loss_A * LAMBDA_IDENTITY
                + identity_loss_B * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        ##########################  END TRAIN GENERATORS #########################

        if idx % 150 == 0:    #save tensor into images every 150 to see in real time the progress of the net
            # *0.5+0.5 rescales the generator output before writing it to disk.
            save_image(fake_B*0.5+0.5, f"Saved_Images/domainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"Saved_Images/domainA_{idx}.png")

        #set postfixes to the progress bar of tqdm (running averages over the epoch)
        loop.set_postfix(A_real=A_is_real/(idx+1), A_fake=A_is_fake/(idx+1),B_real=B_is_real/(idx+1), B_fake=B_is_fake/(idx+1))

                ########################### END TRAIN FUNCTION ######################



#TEST FUNCTIONS
def test_fn_A(gen_B,gen_A,test_loader):
    """Translate every domain-A test image to domain B and back, saving all three.

    For each batch the original domain-A image, its translation ``gen_B(A)``
    and the reconstruction ``gen_A(gen_B(A))`` are written to ``test_images/``
    using the batch index in the file name.

    Args:
        gen_B: Generator producing domain-B images.
        gen_A: Generator producing domain-A images.
        test_loader: Iterable yielding ``(domainB, domainA)`` batches; only the
            domain-A element is used here.
    """
    loop = tqdm(test_loader, leave=True)

    # Pure inference: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for idx, (_, domainA) in enumerate(loop):
            domainA = domainA.to(config.DEVICE)

            fake_B = gen_B(domainA)
            fake_A = gen_A(fake_B)

            # *0.5+0.5 rescales the tensors before they are written to disk.
            save_image(domainA*0.5+0.5, f"test_images/testoriginal_{idx}.png")
            save_image(fake_B*0.5+0.5, f"test_images/testdomainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"test_images/testdomainA_{idx}.png")

def test_fn_B(gen_B,gen_A,test_loader):
    """Translate every domain-B test image to domain A and back, saving all three.

    For each batch the original domain-B image, its translation ``gen_A(B)``
    and the reconstruction ``gen_B(gen_A(B))`` are written to ``test_images/``
    using the batch index in the file name.

    NOTE: the output file names are the same ones ``test_fn_A`` writes, so
    running both functions into the same directory overwrites results.

    Args:
        gen_B: Generator producing domain-B images.
        gen_A: Generator producing domain-A images.
        test_loader: Iterable yielding ``(domainB, domainA)`` batches; only the
            domain-B element is used here.
    """
    loop = tqdm(test_loader, leave=True)

    # Pure inference: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for idx, (domainB, _) in enumerate(loop):
            domainB = domainB.to(config.DEVICE)

            fake_A = gen_A(domainB)
            fake_B = gen_B(fake_A)

            # *0.5+0.5 rescales the tensors before they are written to disk.
            save_image(domainB*0.5+0.5, f"test_images/testoriginal_{idx}.png")
            save_image(fake_B*0.5+0.5, f"test_images/testdomainB_{idx}.png")
            save_image(fake_A*0.5+0.5, f"test_images/testdomainA_{idx}.png")

                        ###################### MAIN FUNCTION #######################
def main():
    """Build models, optimizers and data loaders, then train or test.

    Behaviour is driven by the ``config`` module: ``TRANSFORMATION`` selects
    the dataset, ``TRAIN_MODEL`` chooses between training for ``NUM_EPOCHS``
    epochs and running ``test_fn_A`` on the test set, and ``SAVE_MODEL``
    enables checkpointing after every training epoch.

    Raises:
        ValueError: If ``config.TRANSFORMATION`` is not one of the supported
            values.
    """
    disc_A = Discriminator(in_channels=3).to(config.DEVICE)
    disc_B = Discriminator(in_channels=3).to(config.DEVICE)
    gen_A = Generator(img_channels=3).to(config.DEVICE)
    gen_B = Generator(img_channels=3).to(config.DEVICE)

    # One optimizer for both discriminators, one for both generators.
    opt_disc = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_B.parameters()) + list(gen_A.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()
    BCE = torch.nn.BCEWithLogitsLoss()

    #GAMMA_CYCLE = config.GAMMA_CYCLE # ratio between discriminator CNN feature level and pixel level loss   !!!!!

    # NOTE(review): checkpoint loading is currently disabled, so
    # config.LOAD_MODEL has no effect and the test branch below runs with
    # freshly initialised generators. To re-enable, restore:
    #
    # if config.LOAD_MODEL:
    #     load_model(config.CHECKPOINT_GEN_A, gen_A, opt_gen, config.LEARNING_RATE)
    #     load_model(config.CHECKPOINT_GEN_B, gen_B, opt_gen, config.LEARNING_RATE)
    #     load_model(config.CHECKPOINT_DISC_A, disc_A, opt_disc, config.LEARNING_RATE)
    #     load_model(config.CHECKPOINT_DISC_B, disc_B, opt_disc, config.LEARNING_RATE)

    ############## CHOICE OF THE DATASET ###############  !!!!
    if config.TRANSFORMATION == "WinterToSummer":
        dataset = Dataset(
            winter_dir=config.TRAIN_DIR+"/trainWinter", summer_dir=config.TRAIN_DIR+"/trainSummer", transform=config.transforms
        )
        test_dataset = Dataset(
            winter_dir=config.TEST_DIR+"/testWinter", summer_dir=config.TEST_DIR+"/testSummer", transform=config.transforms
        )
    elif config.TRANSFORMATION == "HorseToZebra":
        dataset = Dataset(
            horse_dir=config.TRAIN_DIR+"/trainHorse", zebra_dir=config.TRAIN_DIR+"/trainZebra", transform=config.transforms
        )
        test_dataset = Dataset(
            horse_dir=config.TEST_DIR+"/testHorse", zebra_dir=config.TEST_DIR+"/testZebra", transform=config.transforms
        )
    else:
        # Fail here with a clear message instead of a NameError on `dataset`
        # further down when the loaders are built.
        raise ValueError(
            f"Unsupported config.TRANSFORMATION: {config.TRANSFORMATION!r} "
            "(expected 'WinterToSummer' or 'HorseToZebra')"
        )
    ############# CHOICE OF THE DATASET ##############

    ############# DATALOADER #############
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True  #for faster training(non-paged cpu memory)
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
    ############# DATALOADER ###############

    # Separate gradient scalers for the generator and discriminator updates.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if config.TRAIN_MODEL:

        for epoch in range(config.NUM_EPOCHS):
            # GAMMA_CYCLE is passed as 0; train_fn does not use it.
            train_fn(disc_A, disc_B, gen_B, gen_A, loader, opt_disc, opt_gen, L1, mse, BCE, d_scaler, g_scaler,config.LAMBDA_IDENTITY, config.LAMBDA_CYCLE, 0)

            if config.SAVE_MODEL: #if save_Model is set to true save model on the specific path
                save_model(gen_A, opt_gen, epoch ,filename=config.CHECKPOINT_GEN_A)
                save_model(gen_B, opt_gen, epoch , filename=config.CHECKPOINT_GEN_B)
                save_model(disc_A, opt_disc, epoch , filename=config.CHECKPOINT_DISC_A)
                save_model(disc_B, opt_disc, epoch , filename=config.CHECKPOINT_DISC_B)
    else:

        test_fn_A(gen_B,gen_A,test_loader)
        #test_fn_B(gen_B,gen_A,test_loader)

# Entry point: call main() only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()



Cloning into 'Mask-CycleGAN'...
remote: Enumerating objects: 2671, done.[K
remote: Counting objects: 100% (104/104), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 2671 (delta 50), reused 52 (delta 20), pack-reused 2567[K
Receiving objects: 100% (2671/2671), 117.69 MiB | 22.55 MiB/s, done.
Resolving deltas: 100% (51/51), done.
/content/Mask-CycleGAN


 41%|████      | 499/1231 [02:09<03:10,  3.85it/s, A_fake=0.000929, A_real=0.00119, B_fake=0.000876, B_real=0.00117]


KeyboardInterrupt: ignored