In [1]:
import os
import sys
import yaml
import argparse
import numpy as np
from PIL import Image
# pytorch libs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.transforms as transforms
# local libs
from nets.commons import Weights_Normal, VGG19_PercepLoss
from nets.funiegan import GeneratorFunieGAN, DiscriminatorFunieGAN
from utils.data_utils import GetTrainingPairs, GetValImage

In [2]:
class Args:
    cfg_file = "configs/train_euvp.yaml"
    epoch = 1
    num_epochs = 60
    batch_size = 8
    lr = 0.0003
    b1 = 0.5
    b2 = 0.99

args = Args()

In [3]:
import yaml

epoch = args.epoch
num_epochs = args.num_epochs
batch_size =  args.batch_size
lr_rate, lr_b1, lr_b2 = args.lr, args.b1, args.b2 
# load the data config file
with open(args.cfg_file) as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
# get info from config file
dataset_name = cfg["dataset_name"] 
dataset_path = cfg["dataset_path"]
channels = cfg["chans"]
img_width = cfg["im_width"]
img_height = cfg["im_height"] 
val_interval = cfg["val_interval"]
ckpt_interval = cfg["ckpt_interval"]

In [4]:
import os

samples_dir = os.path.join("samples/FunieGAN/", dataset_name)
checkpoint_dir = os.path.join("checkpoints/FunieGAN/", dataset_name)
os.makedirs(samples_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
Adv_cGAN = torch.nn.MSELoss()
L1_G  = torch.nn.L1Loss() # similarity loss (l1)
L_vgg = VGG19_PercepLoss() # content loss (vgg)
lambda_1, lambda_con = 6, 2 # 7:3 (as in paper)
patch = (1, img_height//16, img_width//16) # 16x16 for 256x256

# Initialize generator and discriminator
generator = GeneratorFunieGAN()
discriminator = DiscriminatorFunieGAN()



In [6]:
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    Adv_cGAN.cuda()
    L1_G = L1_G.cuda()
    L_vgg = L_vgg.cuda()
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [7]:
# if args.epoch == 0:
#     generator.apply(Weights_Normal)
#     discriminator.apply(Weights_Normal)
# else:
#     generator.load_state_dict(torch.load("checkpoints/FunieGAN/%s/generator_%d.pth" % (dataset_name, args.epoch)))
#     discriminator.load_state_dict(torch.load("checkpoints/FunieGAN/%s/discriminator_%d.pth" % (dataset_name, epoch)))
#     print ("Loaded model from epoch %d" %(epoch))

In [8]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr_rate, betas=(lr_b1, lr_b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_rate, betas=(lr_b1, lr_b2))

In [9]:
print("Dataset Path:", dataset_path)
print("Dataset Name:", dataset_name)

Dataset Path: output
Dataset Name: paired


In [10]:
import glob
filesA = sorted(glob.glob("output/paired/input/*.*"))
filesB = sorted(glob.glob("output/paired/target/*.*"))
print("FilesA:", len(filesA))
print("FilesB:", len(filesB))

FilesA: 905
FilesB: 905


In [11]:
transforms_ = [
    transforms.Resize((img_height, img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataset_path = "output/"
dataset_name = "paired"

dataloader = DataLoader(
    GetTrainingPairs(dataset_path, dataset_name, transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

val_dataloader = DataLoader(
    GetValImage(dataset_path, dataset_name, transforms_=transforms_),
    batch_size=4,
    shuffle=True,
    num_workers=0,
)



Loaded 905 input images and 905 target images.
Loaded 905 input images.


In [12]:
dataset = GetValImage(dataset_path, dataset_name, transforms_=transforms_)
print("Length from self.len:", dataset.len)


Loaded 905 input images.
Length from self.len: 905


In [13]:
# import timeit

# start_time = timeit.default_timer()
# for epoch in range(epoch, num_epochs):
#     for i, batch in enumerate(dataloader):
#         # Model inputs
#         imgs_distorted = Variable(batch["A"].type(Tensor))
#         imgs_good_gt = Variable(batch["B"].type(Tensor))
#         # Adversarial ground truths
#         valid = Variable(Tensor(np.ones((imgs_distorted.size(0), *patch))), requires_grad=False)
#         fake = Variable(Tensor(np.zeros((imgs_distorted.size(0), *patch))), requires_grad=False)

#         ## Train Discriminator
#         optimizer_D.zero_grad()
#         imgs_fake = generator(imgs_distorted)
#         pred_real = discriminator(imgs_good_gt, imgs_distorted)
#         loss_real = Adv_cGAN(pred_real, valid)
#         pred_fake = discriminator(imgs_fake, imgs_distorted)
#         loss_fake = Adv_cGAN(pred_fake, fake)
#         # Total loss: real + fake (standard PatchGAN)
#         loss_D = 0.5 * (loss_real + loss_fake) * 10.0 # 10x scaled for stability
#         loss_D.backward()
#         optimizer_D.step()

#         ## Train Generator
#         optimizer_G.zero_grad()
#         imgs_fake = generator(imgs_distorted)
#         pred_fake = discriminator(imgs_fake, imgs_distorted)
#         loss_GAN =  Adv_cGAN(pred_fake, valid) # GAN loss
#         loss_1 = L1_G(imgs_fake, imgs_good_gt) # similarity loss
#         loss_con = L_vgg(imgs_fake, imgs_good_gt)# content loss
#         # Total loss (Section 3.2.1 in the paper)
#         loss_G = loss_GAN + lambda_1 * loss_1  + lambda_con * loss_con 
#         loss_G.backward()
#         optimizer_G.step()

#         ## Print log
#         if not i%50:
#             sys.stdout.write("\r[Epoch %d/%d: batch %d/%d] [DLoss: %.3f, GLoss: %.3f, AdvLoss: %.3f]"
#                               %(
#                                 epoch, num_epochs, i, len(dataloader),
#                                 loss_D.item(), loss_G.item(), loss_GAN.item(),
#                                )
#             )
#         ## If at sample interval save image
#         batches_done = epoch * len(dataloader) + i
#         if batches_done % val_interval == 0:
#             imgs = next(iter(val_dataloader))
#             imgs_val = Variable(imgs["val"].type(Tensor))
#             imgs_gen = generator(imgs_val)
#             img_sample = torch.cat((imgs_val.data, imgs_gen.data), -2)
#             save_image(img_sample, "samples/FunieGAN/%s/%s.png" % (dataset_name, batches_done), nrow=5, normalize=True)

#     ## Save model checkpoints
#     if (epoch % ckpt_interval == 0):
#         torch.save(generator.state_dict(), "checkpoints/FunieGAN/%s/generator_%d.pth" % (dataset_name, epoch))
#         torch.save(discriminator.state_dict(), "checkpoints/FunieGAN/%s/discriminator_%d.pth" % (dataset_name, epoch))

#     end_time = timeit.default_timer()
#     elapsed_time = end_time - start_time
#     print(f"Elapsed time: {elapsed_time} seconds")

In [None]:
from pytorch_msssim import ssim
import timeit

start_time = timeit.default_timer()
for epoch in range(epoch, num_epochs):
    for i, batch in enumerate(dataloader):
        # === Model inputs ===
        imgs_distorted = Variable(batch["A"].type(Tensor))
        imgs_good_gt = Variable(batch["B"].type(Tensor))

        # === Ground truths for PatchGAN ===
        valid = Variable(Tensor(np.ones((imgs_distorted.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_distorted.size(0), *patch))), requires_grad=False)

        # === Train Discriminator ===
        optimizer_D.zero_grad()
        imgs_fake = generator(imgs_distorted).detach()
        pred_real = discriminator(imgs_good_gt, imgs_distorted)
        loss_real = Adv_cGAN(pred_real, valid)
        pred_fake = discriminator(imgs_fake, imgs_distorted)
        loss_fake = Adv_cGAN(pred_fake, fake)
        loss_D = 0.5 * (loss_real + loss_fake) * 10.0
        loss_D.backward()
        optimizer_D.step()

        # === Train Generator ===
        optimizer_G.zero_grad()
        imgs_fake = generator(imgs_distorted)
        pred_fake = discriminator(imgs_fake, imgs_distorted)
        loss_GAN = Adv_cGAN(pred_fake, valid)
        loss_1 = L1_G(imgs_fake, imgs_good_gt)
        loss_con = L_vgg(imgs_fake, imgs_good_gt)
        loss_ssim = 1 - ssim(imgs_fake, imgs_good_gt, data_range=1.0, size_average=True)

        # === Total Generator Loss ===
        lambda_ssim = 2  # you can tune this
        loss_G = loss_GAN + lambda_1 * loss_1 + lambda_con * loss_con + lambda_ssim * loss_ssim
        loss_G.backward()
        optimizer_G.step()

        # === Logging ===
        if not i % 50:
            sys.stdout.write(
                "\r[Epoch %d/%d: batch %d/%d] [DLoss: %.3f, GLoss: %.3f, AdvLoss: %.3f, SSIMLoss: %.3f]"
                % (
                    epoch, num_epochs, i, len(dataloader),
                    loss_D.item(), loss_G.item(), loss_GAN.item(), loss_ssim.item()
                )
            )

        # === Validation ===
        batches_done = epoch * len(dataloader) + i
        if batches_done % val_interval == 0:
            imgs = next(iter(val_dataloader))
            imgs_val = Variable(imgs["val"].type(Tensor))
            imgs_gen = generator(imgs_val)
            img_sample = torch.cat((imgs_val.data, imgs_gen.data), -2)
            save_image(img_sample, f"samples/FunieGAN/{dataset_name}/{batches_done}.png", nrow=5, normalize=True)

    # === Checkpointing ===
    if (epoch % ckpt_interval == 0):
        torch.save(generator.state_dict(), f"checkpoints/FunieGAN/{dataset_name}/generator_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"checkpoints/FunieGAN/{dataset_name}/discriminator_{epoch}.pth")

    end_time = timeit.default_timer()
    elapsed_time = end_time - start_time
    print(f"\nElapsed time: {elapsed_time:.2f} seconds")


  valid = Variable(Tensor(np.ones((imgs_distorted.size(0), *patch))), requires_grad=False)


[Epoch 1/60: batch 100/114] [DLoss: 2.556, GLoss: 1.355, AdvLoss: 0.278, SSIMLoss: 0.202]
Elapsed time: 96.39 seconds
[Epoch 2/60: batch 100/114] [DLoss: 2.484, GLoss: 1.055, AdvLoss: 0.339, SSIMLoss: 0.139]
Elapsed time: 213.40 seconds
[Epoch 3/60: batch 100/114] [DLoss: 2.594, GLoss: 0.924, AdvLoss: 0.282, SSIMLoss: 0.119]
Elapsed time: 323.25 seconds
[Epoch 4/60: batch 100/114] [DLoss: 2.451, GLoss: 0.806, AdvLoss: 0.247, SSIMLoss: 0.122]
Elapsed time: 432.25 seconds
[Epoch 5/60: batch 100/114] [DLoss: 2.533, GLoss: 0.979, AdvLoss: 0.360, SSIMLoss: 0.110]
Elapsed time: 529.91 seconds
[Epoch 6/60: batch 100/114] [DLoss: 2.275, GLoss: 0.940, AdvLoss: 0.365, SSIMLoss: 0.105]
Elapsed time: 630.13 seconds
[Epoch 7/60: batch 100/114] [DLoss: 2.422, GLoss: 0.709, AdvLoss: 0.265, SSIMLoss: 0.086]
Elapsed time: 719.87 seconds
[Epoch 8/60: batch 100/114] [DLoss: 2.558, GLoss: 0.647, AdvLoss: 0.253, SSIMLoss: 0.072]
Elapsed time: 793.93 seconds
[Epoch 9/60: batch 100/114] [DLoss: 2.570, GLoss:

In [16]:
torch.save(generator.state_dict(), "checkpoints/FunieGAN/paired/generator_60.pth")
torch.save(discriminator.state_dict(), "checkpoints/FunieGAN/paired/discriminator_60.pth")