In [None]:
# NOTE(review): hard-coded absolute local path. The relative paths used later
# (the PyTorch.* local imports, the config file, checkpoints/ and samples/)
# all resolve against this working directory.
%cd C:/Users/alapa/funiegan/funiegan_1.0/

In [2]:
import os
import sys
import yaml
import argparse
import numpy as np
from PIL import Image
# pytorch libs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.transforms as transforms
# local libs
from PyTorch.nets.commons import Weights_Normal, VGG19_PercepLoss
from PyTorch.nets.SelfAttentionfuniegan import GeneratorFunieGAN, DiscriminatorFunieGAN
from PyTorch.utils.data_utils import GetTrainingPairs, GetValImage
import math
from skimage.metrics import structural_similarity as ssim
from torchvision import models
from torchvision.models import vgg19, VGG19_Weights
from torch_optimizer import Lookahead,NovoGrad
from adabelief_pytorch import AdaBelief


In [3]:
# Run options. parse_known_args() is used instead of parse_args() so that
# unrecognised arguments (e.g. those a Jupyter kernel is launched with) are
# collected in `unknown` instead of aborting the cell.
parser = argparse.ArgumentParser()
_option_specs = (
    # (flag, type, default, help)
    ("--cfg_file", str, r"PyTorch/configs/train_euvp.yaml", None),
    ("--epoch", int, 72, "which epoch to start from"),
    ("--num_epochs", int, 400, "number of epochs of training"),
    ("--batch_size", int, 8, "size of the batches"),
    ("--lr", float, 0.0001, "adam: learning rate"),
    ("--b1", float, 0.5, "adam: decay of 1st order momentum"),
    ("--b2", float, 0.999, "adam: decay of 2nd order momentum"),
)
for _flag, _type, _default, _help in _option_specs:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
args, unknown = parser.parse_known_args()



In [4]:
# Training hyper-parameters come from the parsed command-line options.
epoch, num_epochs, batch_size = args.epoch, args.num_epochs, args.batch_size
lr_rate, lr_b1, lr_b2 = args.lr, args.b1, args.b2

# Dataset and logging settings come from the YAML config file.
with open(args.cfg_file) as cfg_stream:
    cfg = yaml.load(cfg_stream, Loader=yaml.FullLoader)

dataset_name, dataset_path = cfg["dataset_name"], cfg["dataset_path"]
channels = cfg["chans"]
img_width, img_height = cfg["im_width"], cfg["im_height"]
val_interval, ckpt_interval = cfg["val_interval"], cfg["ckpt_interval"]

In [None]:

# Echo the key dataset settings so the run configuration is visible in the output.
for setting in (dataset_name, dataset_path, channels):
    print(setting)


In [6]:
## Output directories: validation sample images and model checkpoints (per dataset)
samples_dir, checkpoint_dir = (
    os.path.join("samples/FunieGAN/", dataset_name),
    os.path.join("checkpoints/FunieGAN(NLSelf)/", dataset_name),
)
for out_dir in (samples_dir, checkpoint_dir):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Loss criteria.
Adv_cGAN = torch.nn.MSELoss()  # adversarial loss: MSE against 1/0 targets (least-squares GAN)
L1_G = torch.nn.L1Loss()       # similarity loss (l1)
L_vgg = VGG19_PercepLoss()     # content loss (vgg)
# NOTE(review): L1_G, L_vgg, lambda_1 and lambda_con are not used by the training
# loop in this notebook, which calls pixel_loss / VGG_loss with hard-coded weights.
lambda_1 = 7    # l1 weight   (7:3 as in paper)
lambda_con = 3  # content weight
# Discriminator output patch (1, H/16, W/16): 16x16 for 256x256 inputs.
patch = (1, img_height // 16, img_width // 16)

In [8]:

def compute_gradient_penalty(D, real_samples, fake_samples, distorted_samples, Tensor):
    """
    Calculates the gradient penalty loss for WGAN-GP.

    A random point on the line between each real/fake pair is scored by the
    conditional discriminator. The penalty pushes the L2 norm of the gradient of
    that score w.r.t. the interpolated image towards 1.

    Args:
        D: conditional discriminator, called as ``D(images, condition)``.
        real_samples: batch of ground-truth images; assumed 4-D (B, C, H, W),
            since the interpolation weights are shaped (B, 1, 1, 1).
        fake_samples: generated images, same shape as ``real_samples``.
        distorted_samples: conditioning images, passed unchanged to ``D``.
        Tensor: tensor constructor (e.g. torch.FloatTensor or
            torch.cuda.FloatTensor) used to create the interpolation weights on
            the right device.

    Returns:
        Scalar tensor: batch mean of (||grad||_2 - 1)^2. The graph is kept
        (create_graph=True) so the penalty can be back-propagated into D.
    """
    # One interpolation weight per sample, broadcast over the remaining dims.
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates, distorted_samples)
    # Seed the backward pass with ones shaped exactly like D's output, so this
    # does not depend on a global patch shape matching the discriminator.
    grad_seed = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


In [None]:
generator = GeneratorFunieGAN()
discriminator = DiscriminatorFunieGAN()

# Use the GPU when available. `Tensor` is the constructor the rest of the
# notebook uses to create/cast tensors on the chosen device.
use_cuda = torch.cuda.is_available()
if use_cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    Adv_cGAN = Adv_cGAN.cuda()
    L1_G = L1_G.cuda()
    L_vgg = L_vgg.cuda()
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# Initialize weights or load pretrained models.
# args.epoch is used for BOTH checkpoints and copied back into the global `epoch`.
# The training loop reassigns `epoch`, so after a re-run, mixing the two would
# pair a generator and a discriminator from different epochs.
epoch = args.epoch
if epoch == 0:
    generator.apply(Weights_Normal)
    discriminator.apply(Weights_Normal)
else:
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt_location = "cuda" if use_cuda else "cpu"
    gen_ckpt = os.path.join(checkpoint_dir, "generator_%d.pth" % epoch)
    disc_ckpt = os.path.join(checkpoint_dir, "discriminator_%d.pth" % epoch)
    generator.load_state_dict(torch.load(gen_ckpt, map_location=ckpt_location))
    discriminator.load_state_dict(torch.load(disc_ckpt, map_location=ckpt_location))
    print("Loaded model from epoch %d" % (epoch))
    # NOTE(review): only model weights are checkpointed, so a resumed run restarts
    # optimizer moments and the StepLR schedules from scratch.

# Optimizers and learning-rate schedules.
# Previously tried configuration (kept for reference):
#   base_optim_G = torch.optim.SGD(generator.parameters(), lr=3*lr_rate, momentum=0.9)
#   optimizer_G = Lookahead(base_optim_G, k=5, alpha=0.5)
#   optimizer_D = torch.optim.RAdam(discriminator.parameters(), lr=lr_rate, betas=(lr_b1, lr_b2))
#   scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=15, gamma=0.85)
#   scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=20, gamma=0.8)
# Generator: AdaBelief at 3x the base lr, wrapped in Lookahead.
# Discriminator: Adam at 0.5x the base lr.
base_optim_G = AdaBelief(generator.parameters(), lr=3*lr_rate, betas=(lr_b1, lr_b2), eps=1e-16, weight_decay=1e-4, rectify=True)
optimizer_G = Lookahead(base_optim_G, k=5, alpha=0.5)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.5*lr_rate, betas=(lr_b1, lr_b2))
# The schedulers are stepped once per epoch by the training loop.
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=25, gamma=0.85)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=25, gamma=0.8)



In [10]:
# Preprocessing: resize to the configured size, convert to a tensor in [0, 1],
# then map to [-1, 1] (per-channel mean = std = 0.5).
# NOTE(review): the 3-tuples assume 3-channel images; `channels` from the config
# is not consulted here.
transforms_ = [
    # InterpolationMode is torchvision's preferred form of the PIL integer constant.
    transforms.Resize((img_height, img_width), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Paired (distorted, good) training images.
dataloader = DataLoader(
    GetTrainingPairs(dataset_path, dataset_name, transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Validation images. Shuffled, so each `next(iter(val_dataloader))` in the
# training loop presumably shows a different batch.
val_dataloader = DataLoader(
    GetValImage(dataset_path, dataset_name, transforms_=transforms_, sub_dir='validation'),
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [None]:
print("Number of samples in dataloader: %d" % len(dataloader.dataset))

In [12]:

def calculate_psnr(img1, img2, pixel_max=1.0):
    """
    Peak signal-to-noise ratio between two images/arrays, in dB.

    Args:
        img1, img2: array-likes of the same shape.
        pixel_max: dynamic range of the pixel values, e.g. 1.0 for images in
            [0, 1] (the default), 255 for 8-bit images, 2.0 for images in [-1, 1].

    Returns:
        PSNR in dB, or 100 when the inputs are identical (MSE is zero).
    """
    # Compute in float64: subtracting unsigned integer arrays (e.g. uint8) would wrap.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:  # no noise: PSNR is unbounded, so return a fixed sentinel
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mse))

def validate_metrics(generator, val_dataloader, Tensor):
    """
    Average PSNR and SSIM of the generator over the validation set.

    The data pipeline normalises images to [-1, 1] (Normalize with mean = std = 0.5).
    Generated and ground-truth images are therefore mapped back to [0, 1] before
    scoring, so PSNR uses a peak of 1.0 and SSIM a data_range of 1.0.

    Args:
        generator: image-enhancement network, called as ``generator(batch)``.
        val_dataloader: iterable of dicts with "val" (input) and "gt"
            (reference) image batches, assumed (B, C, H, W).
        Tensor: tensor constructor used to move inputs onto the model's device.

    Returns:
        (mean_psnr, mean_ssim). The generator is put back in train mode afterwards,
        even if scoring raises.
    """
    def _to_unit_range(images):
        # [-1, 1] -> [0, 1]; clipping guards against outputs slightly out of range.
        return np.clip((images + 1.0) / 2.0, 0.0, 1.0)

    generator.eval()
    psnr_values, ssim_values = [], []
    try:
        with torch.no_grad():  # inference only: don't build an autograd graph
            for batch in val_dataloader:
                imgs_val = batch["val"].type(Tensor)
                imgs_gen = _to_unit_range(generator(imgs_val).cpu().numpy())
                imgs_gt = _to_unit_range(batch["gt"].float().cpu().numpy())
                for gt, gen in zip(imgs_gt, imgs_gen):
                    psnr_values.append(calculate_psnr(gt, gen))
                    # (C, H, W) -> (H, W, C). channel_axis replaces the removed
                    # `multichannel=True`; float inputs need an explicit data_range.
                    ssim_values.append(
                        ssim(gt.transpose(1, 2, 0), gen.transpose(1, 2, 0), channel_axis=-1, data_range=1.0)
                    )
    finally:
        generator.train()
    return np.mean(psnr_values), np.mean(ssim_values)


In [13]:

# Define the VGG Feature Extractor
class VGGFeatureExtractor(nn.Module):
    """
    Frozen, ImageNet-pretrained VGG19 trunk used for the perceptual (content) loss.

    Args:
        num_layers: number of leading modules of ``vgg19.features`` to keep.
            The default 16 keeps conv1_1 ... relu3_3, so the output is the
            relu3_3 activation.
    """

    def __init__(self, num_layers=16):
        super().__init__()
        vgg19 = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:num_layers])
        # Freeze VGG: it only provides a fixed feature space and is never trained.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, img):
        # NOTE(review): torchvision's VGG weights expect ImageNet-normalised inputs,
        # but this notebook's pipeline produces [-1, 1] images. Confirm this is intended.
        return self.feature_extractor(img)

# Build the frozen VGG feature extractor on the GPU when one is available, else the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
vgg_extractor = VGGFeatureExtractor()
vgg_extractor = vgg_extractor.to(device)

# VGG Method
def VGG_loss(imgs_fake, imgs_good_gt, extractor=None):
    """
    Perceptual (content) loss: MSE between feature maps of the two image batches.

    Args:
        imgs_fake: generated images.
        imgs_good_gt: reference images, same shape as ``imgs_fake``.
        extractor: feature network to compare in. Defaults to the module-level
            ``vgg_extractor``. Passing another module reuses the same loss in a
            different feature space.

    Returns:
        Scalar tensor: mean squared difference of the two feature maps.
    """
    if extractor is None:
        extractor = vgg_extractor
    features_fake = extractor(imgs_fake)
    features_real = extractor(imgs_good_gt)
    return F.mse_loss(features_fake, features_real)


In [14]:
# Pixel Loss Method
def pixel_loss(imgs_fake, imgs_good_gt):
    """
    Pixel-wise L1 loss: the mean absolute difference between the generated
    batch and the ground-truth batch, returned as a scalar tensor.
    """
    return F.l1_loss(imgs_fake, imgs_good_gt)
    

In [None]:

def _save_val_samples(batches_done):
    """
    Enhance one validation batch and save it as <samples_dir>/<batches_done>.png.

    Each tile shows the input image above the generator output (concatenated
    along the height axis). Uses the notebook globals generator, val_dataloader,
    Tensor and samples_dir.
    """
    imgs = next(iter(val_dataloader))
    imgs_val = imgs["val"].type(Tensor)
    # Inference only. eval() makes any BatchNorm/Dropout layers behave as at test
    # time, so BN running stats are not updated with validation images.
    # no_grad() avoids building an autograd graph.
    generator.eval()
    try:
        with torch.no_grad():
            imgs_gen = generator(imgs_val)
    finally:
        generator.train()
    img_sample = torch.cat((imgs_val, imgs_gen), -2)
    save_image(img_sample, os.path.join(samples_dir, "%s.png" % batches_done), nrow=5, normalize=True)


# Resume after `epoch` (the start epoch) and train through epoch num_epochs - 1.
# Each batch gets one discriminator update, then one generator update.
for epoch in range(epoch + 1, num_epochs):
    sampled_this_epoch = False  # did any batch in this epoch hit a val_interval boundary?
    for i, batch in enumerate(dataloader):
        # Model inputs
        imgs_distorted = batch["A"].type(Tensor)
        imgs_good_gt = batch["B"].type(Tensor)

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()
        imgs_fake = generator(imgs_distorted).detach()  # no gradients flow into G here
        # Least-squares targets shaped like D's output: 1 for real, 0 for fake.
        pred_real = discriminator(imgs_good_gt, imgs_distorted)
        loss_real = Adv_cGAN(pred_real, torch.ones_like(pred_real))
        pred_fake = discriminator(imgs_fake, imgs_distorted)
        loss_fake = Adv_cGAN(pred_fake, torch.zeros_like(pred_fake))
        gradient_penalty = compute_gradient_penalty(discriminator, imgs_good_gt, imgs_fake, imgs_distorted, Tensor)
        loss_D = 0.5 * (loss_real + loss_fake) + 5.0 * gradient_penalty
        loss_D.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        optimizer_G.zero_grad()
        imgs_fake = generator(imgs_distorted)
        pred_fake = discriminator(imgs_fake, imgs_distorted)
        loss_GAN = Adv_cGAN(pred_fake, torch.ones_like(pred_fake))  # G is rewarded when D says "real"
        loss_pixel = pixel_loss(imgs_fake, imgs_good_gt)
        loss_content = VGG_loss(imgs_fake, imgs_good_gt)
        # NOTE(review): hard-coded weights; lambda_1 / lambda_con from the loss cell are unused.
        loss_G = 0.1 * loss_GAN + 8.0 * loss_pixel + 0.5 * loss_content
        loss_G.backward()
        optimizer_G.step()

        ## Print log (every batch, overwriting the same console line)
        sys.stdout.write(
            "\r[Epoch %d/%d: batch %d/%d] [DLoss: %.3f, GLoss: %.3f, AdvLoss: %.3f]"
            % (epoch, num_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), loss_GAN.item())
        )

        ## If at sample interval save image
        batches_done = epoch * len(dataloader) + i
        if batches_done % val_interval == 0:
            _save_val_samples(batches_done)
            sampled_this_epoch = True

    scheduler_G.step()
    scheduler_D.step()
    # Validation metrics run once per epoch, whenever the epoch crossed a
    # val_interval boundary. A flag is used because testing only the last batch
    # index would fire by coincidence (and fails on an empty loader).
    if sampled_this_epoch:
        psnr, ssim_val = validate_metrics(generator, val_dataloader, Tensor)
        print(f"Validation Metrics - PSNR: {psnr:.2f}, SSIM: {ssim_val:.2f}")

    ## Save model checkpoints (every epoch; the config's ckpt_interval is not applied)
    torch.save(generator.state_dict(), os.path.join(checkpoint_dir, "generator_%d.pth" % epoch))
    torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, "discriminator_%d.pth" % epoch))
