See the paper for details:<br>
Paper (MedGAN: Medical Image Translation using GANs): https://arxiv.org/pdf/1806.06397

imports

In [1]:
# review , instanse normalization ?
import torch
from torchvision import models
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision.transforms import v2 as T

import os
import matplotlib.pyplot as plt
import datetime

from casnet import CasNet
from patchGAN import PatchGAN
from loss_functions import CGANGeneratorLoss,CGANDiscreminatorLoss,PerceptualLoss,StyleTransferLoss

Hyperparameters<br>
Note: only some hyperparameters are collected here; others are set directly in the cells that use them (e.g. `num_ublocks` for the generator).

In [2]:
# Central hyperparameter table. Only a subset of the knobs lives here; the
# rest are passed directly where they are used.
hp = dict(
    # hardware: when True, prefer CUDA, then MPS, otherwise CPU
    gpu=False,
    # optimisation
    g_lr=0.0002,  # generator Adam learning rate
    d_lr=0.0002,  # discriminator Adam learning rate
    batch_size=1,
    epochs=1,  # 200 in paper
    g_iter=3,  # generator updates per batch (before each discriminator update)
    # loss weights, forwarded to the loss functions
    lambda_perceptual=20,
    lambda_style=0.0001,
    lambda_content=0.0001,
    lambda_perceptual_layer=[1, 1],
    lambda_content_layer=[1, 2, 3, 4, 5],
    lambda_style_layer=[1, 0, 0, 0, 1],
    # bookkeeping
    save_model_every_n_epoch=50,
    print_every_n_batch=1,
    percent_of_data_used=1,  # fraction of the dataset to use; 1 means all
    train_test_split=1,  # fraction (0..1) of the used data assigned to training
)

data

In [3]:
def _timestamp_dirname(moment):
    """Return a filesystem-safe run name such as ``2024-01-02 03-04-05``.

    ``strftime`` is used instead of slicing ``str(moment)``: ``str`` omits the
    ``.ffffff`` fraction when the microsecond is 0, which made the old
    ``[:-7]`` slice cut into the time part. Colons are replaced by dashes
    because they are not allowed in Windows file names.
    """
    return moment.strftime("%Y-%m-%d %H-%M-%S")


root_dir = os.getcwd()
dataset_dir = os.path.join(root_dir, "dataset")
domain_A_data_dir = os.path.join(dataset_dir, "A")
domain_B_data_dir = os.path.join(dataset_dir, "B")
logs_dir = os.path.join(root_dir, "logs")

# Every run writes its checkpoints into its own timestamped sub-directory.
current_run_log_fname = _timestamp_dirname(datetime.datetime.now())
current_logs_dir = os.path.join(logs_dir, current_run_log_fname)
os.makedirs(current_logs_dir, exist_ok=True)

In [4]:
class CustomDataset(Dataset):
    """Paired dataset: item ``i`` is the same file name read from domain A and domain B.

    File names are listed from the domain-A directory. The domain-B directory is
    expected to contain an image with each of those names.

    Args:
        transform: optional callable applied to each image separately. This is
            fine for deterministic pipelines; random augmentations would be
            sampled independently for A and B and de-pair the images.
        domain_a_dir: directory of domain-A images (defaults to the notebook's
            ``domain_A_data_dir``).
        domain_b_dir: directory of domain-B images (defaults to the notebook's
            ``domain_B_data_dir``).
    """

    def __init__(self, transform=None, domain_a_dir=None, domain_b_dir=None):
        # Resolve the defaults lazily so the class works with or without the globals.
        self.domain_a_dir = domain_A_data_dir if domain_a_dir is None else domain_a_dir
        self.domain_b_dir = domain_B_data_dir if domain_b_dir is None else domain_b_dir
        # Sort for a deterministic index -> file mapping (os.listdir order is
        # arbitrary). Skip hidden entries (e.g. macOS .DS_Store) and
        # sub-directories, which read_image cannot decode.
        self.file_names = sorted(
            name
            for name in os.listdir(self.domain_a_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.domain_a_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(imageA, imageB)`` for the ``idx``-th file name, transformed if a transform is set."""
        name = self.file_names[idx]
        imageA = read_image(os.path.join(self.domain_a_dir, name))
        imageB = read_image(os.path.join(self.domain_b_dir, name))
        if self.transform:
            imageA = self.transform(imageA)
            imageB = self.transform(imageB)
        return imageA, imageB

In [5]:
# transforms
def get_transform(size=(1024, 1024)):
    """Build the preprocessing pipeline applied to every image.

    Args:
        size: target ``(height, width)`` passed to ``T.Resize``. Defaults to
            the previously hard-coded 1024x1024.

    Returns:
        A ``T.Compose`` that converts to grayscale, resizes to ``size``,
        converts to ``torch.float`` (``scale=True`` rescales integer inputs
        to [0, 1]) and unwraps the result into a plain tensor.
    """
    return T.Compose(
        [
            T.Grayscale(),
            T.Resize(size),
            T.ToDtype(torch.float, scale=True),
            T.ToPureTensor(),
        ]
    )

In [6]:
dataset = CustomDataset(transform=get_transform())

# Randomly choose which images are used at all, then split those into a
# training part and a test part (the test part is empty when train_test_split == 1).
indices = torch.randperm(len(dataset)).tolist()
n_used = int(len(indices) * hp["percent_of_data_used"])
indices = indices[:n_used]
split_idx = int(hp["train_test_split"] * len(indices))
dataset_train = torch.utils.data.Subset(dataset, indices[:split_idx])
dataset_test = torch.utils.data.Subset(dataset, indices[split_idx:])

# The training loader reshuffles every epoch. The one-time randperm above only
# fixes the split, so shuffle=False would replay the same sample order in
# every epoch. The test loader keeps a fixed order for reproducible evaluation.
data_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=hp["batch_size"],
    shuffle=True,
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=hp["batch_size"],
    shuffle=False,
)

device

In [7]:
# Choose the compute device: CPU unless hp["gpu"] is set. When it is, prefer
# CUDA, then Apple MPS, and fall back to CPU if neither is available.
device = "cpu"
if hp["gpu"]:
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

print(f"Using {device} device")

Using cpu device


models

In [8]:
# Generator (CasNet, built with num_ublocks=6) and PatchGAN discriminator,
# both moved to the selected device. Module.to returns the module itself.
generator = CasNet(num_ublocks=6)
generator = generator.to(device)
discriminator = PatchGAN()
discriminator = discriminator.to(device)

In [9]:
# Pretrained VGG19 convolutional trunk, passed to the style/content loss as a
# fixed feature extractor. All its weights are frozen so only the GAN trains.
vgg = models.vgg19(weights='DEFAULT').features
vgg = vgg.to(device).requires_grad_(False)

training

In [10]:
# One Adam optimizer per network, each with its own learning rate from hp.
generator_optimizer = torch.optim.Adam(params=generator.parameters(), lr=hp["g_lr"])
discriminator_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=hp["d_lr"])

In [11]:
# Loss objects used by the training loop:
#   - adversarial loss for the generator / for the discriminator,
#   - perceptual loss on discriminator feature maps,
#   - VGG-based content + style loss (constructed for the selected device).
cgan_loss_generator_fn = CGANGeneratorLoss()
cgan_loss_discriminator_fn = CGANDiscreminatorLoss()
perceptual_loss_fn = PerceptualLoss()
style_transfer_loss_fn = StyleTransferLoss(device)

In [12]:
# Per-batch loss history, stored as plain floats so no autograd graphs stay
# alive across training and the values can be plotted directly.
# Generator entries record the last of the hp["g_iter"] updates of each batch.
losses = {"g_loss": [], "d_loss": [], "g_cgan": [], "g_perceptual": [], "g_style": [], "g_content": []}
size = len(data_loader.dataset)
generator.train()
discriminator.train()

for epoch in range(hp["epochs"]):
    # Single quotes inside the f-string keep this valid before Python 3.12.
    print(f"epoch [{epoch+1}/{hp['epochs']}]")
    for batch, (A, B) in enumerate(data_loader):
        A, B = A.to(device), B.to(device)

        # ---- generator: hp["g_iter"] updates per batch ----
        for _ in range(hp["g_iter"]):
            # Start from clean gradients so nothing from earlier passes is applied.
            generator_optimizer.zero_grad()
            generated = generator(A)
            patch_y, fmap_y = discriminator(generated)
            patch_x, fmap_x = discriminator(B)

            g_cgan_loss = cgan_loss_generator_fn(patch_y)
            perceptual_loss = perceptual_loss_fn(
                fmap_x, fmap_y, hp["lambda_perceptual"], hp["lambda_perceptual_layer"]
            )
            content_loss, style_loss = style_transfer_loss_fn(
                B, generated, vgg,
                hp["lambda_content"], hp["lambda_style"],
                hp["lambda_content_layer"], hp["lambda_style_layer"],
            )
            g_loss = g_cgan_loss + perceptual_loss + content_loss + style_loss

            g_loss.backward()
            generator_optimizer.step()

        # ---- discriminator: one update per batch ----
        # g_loss.backward() above also wrote gradients into the discriminator's
        # parameters; clear them so only d_loss drives this update.
        discriminator_optimizer.zero_grad()
        # The generator is frozen for this step: no_grad stops d_loss.backward()
        # from accumulating gradients into the generator's parameters.
        with torch.no_grad():
            generated = generator(A)
        patch_y, fmap_y = discriminator(generated)
        patch_x, fmap_x = discriminator(B)
        d_loss = cgan_loss_discriminator_fn(B, generated, patch_x, patch_y)

        d_loss.backward()
        discriminator_optimizer.step()

        # float() works for 0-dim tensors (on any device) and plain numbers.
        losses["g_loss"].append(float(g_loss))
        losses["d_loss"].append(float(d_loss))
        losses["g_cgan"].append(float(g_cgan_loss))
        losses["g_perceptual"].append(float(perceptual_loss))
        losses["g_style"].append(float(style_loss))
        losses["g_content"].append(float(content_loss))

        if batch % hp["print_every_n_batch"] == 0:
            current = (batch + 1) * len(A)
            print(f"----g_loss: {float(g_loss):>7f} & d_loss: {float(d_loss):>7f} [{current:>5d}/{size:>5d}]")

    # Checkpoint after every n-th completed epoch (n, 2n, ...). The file name
    # carries the 1-based epoch number.
    if (epoch + 1) % hp["save_model_every_n_epoch"] == 0:
        torch.save(generator.state_dict(), os.path.join(current_logs_dir, "generator-" + str(epoch + 1) + ".pth"))
        torch.save(discriminator.state_dict(), os.path.join(current_logs_dir, "discriminator-" + str(epoch + 1) + ".pth"))

epoch [1/1]


In [None]:
# Persist the weights reached at the end of training, next to the periodic checkpoints.
for _final_name, _final_model in (
    ("final_generator.pth", generator),
    ("final_discriminator.pth", discriminator),
):
    torch.save(_final_model.state_dict(), os.path.join(current_logs_dir, _final_name))

Visualize losses

In [None]:
# Plot each recorded loss curve against the batch index, labelled by its key.
# losses.items() yields (name, values) pairs, so unpack them rather than
# handing the tuple itself to plt.plot.
for name, values in losses.items():
    # float() accepts plain numbers and 0-dim tensors, including ones that
    # require grad (which plt.plot cannot convert on its own).
    plt.plot([float(v) for v in values], label=name)
plt.legend()
plt.show()