In [None]:
import argparse
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision import transforms, datasets, utils as vutils
from torchvision.datasets import ImageFolder
from torchvision.models.inception import inception_v3
from torchvision.utils import save_image
from tqdm import tqdm

In [None]:
# Identifier used to name every output folder and log file (e.g. "LB-EGAN_Tapered").
ModelId = "<Constructed Model ID>"

In [None]:
import os
import torch
import torchvision.utils as vutils
from PIL import Image

def save_class_0_images(generator, latent_dim, save_dir=f"{ModelId}/{ModelId}_out", n_images=500, batch_size=16):
    """
    Generate ``n_images`` images with ``generator`` and save each one as a JPEG file.

    The name is historical: the generator is called with noise only (no class label),
    so every image is an unconditional sample.

    Parameters:
    - generator: The generator model, called as ``generator(z)`` with ``z`` of shape
      (batch, latent_dim). Its output is assumed to be Tanh-bounded in [-1, 1].
    - latent_dim: Dimensionality of the generator's latent input.
    - save_dir: Directory where images will be saved (created if missing).
    - n_images: Number of images to generate and save. Files are named ``image_1.jpg`` ...
      ``image_<n_images>.jpg``, and existing files with those names are overwritten.
    - batch_size: Number of images generated per forward pass (must be >= 1).

    The generator is sampled in eval mode. Its previous train/eval mode is restored afterwards.

    Raises:
    - ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    os.makedirs(save_dir, exist_ok=True)

    # Sample on the generator's own device rather than assuming CUDA.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    was_training = generator.training
    generator.eval()  # BatchNorm uses (and does not update) its running statistics
    images_saved = 0
    try:
        with torch.no_grad():  # Disable gradient calculation for efficiency
            while images_saved < n_images:
                count = min(batch_size, n_images - images_saved)
                z = torch.randn(count, latent_dim, device=device)
                gen_imgs = generator(z)
                for img_idx in range(count):
                    img_path = os.path.join(save_dir, f"image_{images_saved + img_idx + 1}.jpg")
                    # Use a fixed [-1, 1] -> [0, 1] mapping. normalize=True alone would min/max-stretch
                    # every image independently and distort the images later used for FID.
                    vutils.save_image(gen_imgs[img_idx], img_path, normalize=True, value_range=(-1, 1))
                images_saved += count
    finally:
        generator.train(was_training)  # restore the caller's mode instead of forcing train mode

    print(f"Total {n_images} images saved in '{save_dir}' directory.")
    

def load_images_in_batches(folder, transform, batch_size=64, extensions=(".jpg", ".bmp")):
    """Yield stacked, transformed images from ``folder`` in batches of up to ``batch_size``.

    Files are visited in sorted name order. They are selected by extension, compared
    case-insensitively, so ``.JPG`` and ``.bmp`` match as well as ``.jpg`` and ``.BMP``.
    Sub-directories are skipped. A file that fails to open or transform is reported and
    skipped. The last batch may be smaller than ``batch_size``.

    Args:
        folder: Directory to scan (not recursive).
        transform: Callable applied to each RGB PIL image. It must return tensors of a common
            shape so they can be stacked.
        batch_size: Maximum number of images per yielded batch.
        extensions: File suffix or tuple of suffixes to accept.

    Yields:
        ``torch.stack`` of up to ``batch_size`` transformed images.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)
        if not filename.lower().endswith(suffixes) or not os.path.isfile(img_path):
            continue
        try:
            # The context manager closes the file handle. convert() loads the pixels before that.
            with Image.open(img_path) as pil_img:
                img = pil_img.convert('RGB')
            images.append(transform(img))
        except Exception as e:
            # Best effort: a single unreadable file should not abort the whole FID computation.
            print(f"Error loading image: {filename}, Error: {e}")

        if len(images) == batch_size:
            yield torch.stack(images)
            images = []
    if images:
        yield torch.stack(images)



def save_model(epoch, save_dir=f"{ModelId}/{ModelId}_out", generator=None):
    """Save a generator's ``state_dict`` to ``<save_dir>/generator_model.pth``.

    Args:
        epoch: Unused. Kept so existing ``save_model(epoch)`` calls keep working.
        save_dir: Output folder, created if missing. Each call overwrites the same file.
        generator: Module to save. When omitted, falls back to a global named ``generator``.
            The original body relied on that global, but the visible notebook never defines it.

    Raises:
        NameError: if ``generator`` is omitted and no global ``generator`` exists.

    NOTE(review): the training cell defines another ``save_model`` that replaces this one
    once that cell has run.
    """
    if generator is None:
        generator = globals().get("generator")
        if generator is None:
            raise NameError("save_model() needs a `generator` argument: no global `generator` is defined")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(save_dir, "generator_model.pth"))

In [None]:
def calculate_inception_score(fake_data_loader, splits=10, eps=1e-12):
    """
    Compute the Inception Score of the images yielded by ``fake_data_loader``.

    Batches are passed unchanged to a pretrained torchvision Inception-v3 built with
    ``transform_input=False``. The caller is responsible for resizing to 299x299 and for the
    pixel range (presumably [-1, 1] for this model configuration; TODO confirm).

    Args:
        fake_data_loader: Iterable of image batches (tensors).
        splits: Number of disjoint chunks the predictions are split into. One score is computed per chunk.
        eps: Small constant added inside the logarithms so zero probabilities do not yield NaN.

    Returns:
        (mean, std) of the per-chunk scores.

    Raises:
        ValueError: if fewer images than ``splits`` were provided, which would leave some chunks empty.
    """
    model = inception_v3(pretrained=True, transform_input=False).eval().to(torch.device('cuda'))
    preds = []

    with torch.no_grad():
        for images in tqdm(fake_data_loader, desc="Calculating Inception Score"):
            images = images.to(torch.device('cuda'))
            preds.append(F.softmax(model(images), dim=1).cpu().numpy())

    # float64 keeps the log/KL arithmetic stable for near-zero probabilities.
    preds = np.concatenate(preds, axis=0).astype(np.float64)
    n = len(preds)
    if n < splits:
        raise ValueError(f"need at least {splits} images for {splits} splits, got {n}")

    scores = []
    for i in range(splits):
        # i * n // splits spreads the remainder over the chunks instead of dropping it.
        part = preds[i * n // splits:(i + 1) * n // splits]
        marginal = np.mean(part, axis=0, keepdims=True)
        kl_div = part * (np.log(part + eps) - np.log(marginal + eps))
        scores.append(np.exp(np.mean(np.sum(kl_div, axis=1))))

    return np.mean(scores), np.std(scores)

def create_fake_data_loader(generator, latent_dim, batch_size=32, num_samples=1000):
    """Sample exactly ``num_samples`` images from ``generator`` and wrap them in a DataLoader.

    Images are resized to 299x299 for Inception-v3 and otherwise left in the generator's output
    range. The generator ends in Tanh, so that range is [-1, 1]. The previous
    ``Normalize((0.5,)*3, (0.5,)*3)`` remapped that range to [-3, 1].

    The generator is sampled in eval mode, so BatchNorm running statistics are not modified, and
    its previous mode is restored. The resized images are kept on the CPU; the consumer moves
    each batch to its own device.

    Args:
        generator: Module called as ``generator(z)`` with ``z`` of shape (batch, latent_dim).
        latent_dim: Size of the latent vector.
        batch_size: Generation and loader batch size.
        num_samples: Total number of images to generate.

    Returns:
        Non-shuffling ``DataLoader`` over a tensor of shape (num_samples, C, 299, 299).
    """
    resize = transforms.Resize((299, 299))
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    was_training = generator.training
    generator.eval()
    fake_images = []
    remaining = num_samples
    try:
        with torch.no_grad():
            while remaining > 0:
                count = min(batch_size, remaining)
                z = torch.randn(count, latent_dim, device=device)
                fake_images.append(resize(generator(z)).cpu())
                remaining -= count
    finally:
        generator.train(was_training)

    fake_images = torch.cat(fake_images)
    return torch.utils.data.DataLoader(fake_images, batch_size=batch_size, shuffle=False)


In [None]:
os.makedirs(f"{ModelId}/{ModelId}", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=20000, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=128, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=10, help="interval between image sampling")
# parse_known_args ignores extra arguments the kernel may receive (e.g. Jupyter's `-f <connection file>`).
opt, _ = parser.parse_known_args()
print(opt)

cuda = torch.cuda.is_available()


def weights_init_normal(m):
    """DCGAN-style initialisation, meant to be used with ``module.apply``.

    Any module whose class name contains "Conv" gets weights ~ N(0, 0.02). Any module whose class
    name contains "BatchNorm2d" gets weights ~ N(1, 0.02) and zero bias. Other modules are left
    untouched, and conv biases are not re-initialised.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Ensemble_Generator(nn.Module):
    """Generator that holds the per-batch weighted average of the DCGAN and DRAGAN generators.

    Maps a latent vector of size ``opt.latent_dim`` to an ``opt.channels``-channel square image
    in [-1, 1] (Tanh output). The side length is ``4 * (opt.img_size // 4)``. The layer layout
    must stay identical to ``GeneratorDC``/``GeneratorDG`` because the training loop averages
    their ``state_dict`` entries key by key.

    NOTE(review): ``nn.BatchNorm2d(C, 0.8)`` passes 0.8 as ``eps`` (the second positional
    argument), not as ``momentum``. It is kept because changing it would change model behaviour.
    """

    def __init__(self):
        super().__init__()

        # Two 2x upsamplings grow the seed feature map back to the output size.
        self.init_size = opt.img_size // 4
        seed_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, seed_features))

        layers = [
            nn.BatchNorm2d(128),
            # stage 1: init_size -> 2 * init_size
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # stage 2: 2 * init_size -> 4 * init_size
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # projection to image channels
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, noise):
        seed = self.l1(noise)
        seed = seed.view(seed.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(seed)


############# DRAGAN ###############

class GeneratorDG(nn.Module):
    """DRAGAN generator: latent vector (``opt.latent_dim``) -> ``opt.channels`` image in [-1, 1].

    The layer layout mirrors ``GeneratorDC`` and ``Ensemble_Generator`` so their state dicts can
    be averaged key by key. NOTE(review): ``nn.BatchNorm2d(C, 0.8)`` sets ``eps=0.8``, not momentum.
    """

    def __init__(self):
        super().__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        def upsample_stage(in_ch, out_ch):
            # 2x upsample (nearest by default), 3x3 conv, BatchNorm (eps=0.8), LeakyReLU.
            return [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            *upsample_stage(128, 128),
            *upsample_stage(128, 64),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        projected = self.l1(z)
        feature_map = projected.view(projected.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(feature_map)


class DiscriminatorDG(nn.Module):
    """DRAGAN discriminator: image batch -> probability of being real, shape (batch, 1).

    Four stride-2 conv stages (16, 32, 64, 128 filters). Each stage is followed by LeakyReLU and
    Dropout2d, and every stage except the first also ends with BatchNorm (eps=0.8). The linear
    head assumes the final feature map is ``opt.img_size // 16`` pixels per side, which holds
    when ``opt.img_size`` is a multiple of 16.
    """

    def __init__(self):
        super().__init__()

        widths = [opt.channels, 16, 32, 64, 128]
        stages = []
        for stage_idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages += [nn.Conv2d(c_in, c_out, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if stage_idx > 0:  # the first stage has no BatchNorm
                stages.append(nn.BatchNorm2d(c_out, 0.8))
        self.model = nn.Sequential(*stages)

        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, images):
        features = self.model(images)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat)


######## DCGAN ###########

class GeneratorDC(nn.Module):
    """DCGAN generator: latent vector (``opt.latent_dim``) -> ``opt.channels`` image in [-1, 1].

    The layer layout mirrors ``GeneratorDG`` and ``Ensemble_Generator`` so their state dicts can
    be averaged key by key. NOTE(review): ``nn.BatchNorm2d(C, 0.8)`` sets ``eps=0.8``, not momentum.
    """

    def __init__(self):
        super().__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        stack = [nn.BatchNorm2d(128)]
        for in_ch, out_ch in ((128, 128), (128, 64)):
            stack.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        stack.extend([nn.Conv2d(64, opt.channels, 3, stride=1, padding=1), nn.Tanh()])
        self.conv_blocks = nn.Sequential(*stack)

    def forward(self, noise):
        hidden = self.l1(noise)
        return self.conv_blocks(hidden.view(hidden.shape[0], 128, self.init_size, self.init_size))


class DiscriminatorDC(nn.Module):
    """DCGAN discriminator: image batch -> probability of being real, shape (batch, 1).

    Same layout as ``DiscriminatorDG``: four stride-2 conv stages, with BatchNorm (eps=0.8) on all
    but the first, then a linear + sigmoid head sized for ``opt.img_size // 16`` pixels per side.
    """

    def __init__(self):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            conv = nn.Conv2d(in_filters, out_filters, 3, 2, 1)
            layers = [conv, nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            return layers + [nn.BatchNorm2d(out_filters, 0.8)] if bn else layers

        plan = [(opt.channels, 16, False), (16, 32, True), (32, 64, True), (64, 128, True)]
        self.model = nn.Sequential(
            *[layer for c_in, c_out, use_bn in plan for layer in discriminator_block(c_in, c_out, bn=use_bn)]
        )

        ds_size = opt.img_size // 2 ** 4  # four stride-2 convolutions
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        conv_out = self.model(img)
        return self.adv_layer(conv_out.view(conv_out.shape[0], -1))

# Loss function: binary cross-entropy on the discriminators' sigmoid outputs, shared by both GANs.
crit = nn.BCELoss()



########### Ensemble GAN ###########

# NOTE(review): this rebinds `Ensemble_Generator` from the class to its only instance. All later
# code uses the name as a module instance, so the class is no longer reachable under that name.
Ensemble_Generator = Ensemble_Generator()

if cuda:
    Ensemble_Generator.cuda()  # moves parameters in place


########### DRAGAN Func. ###########

# Weight of the DRAGAN gradient penalty (read as a global by compute_gradient_penalty).
lambda_gp = 0.1

# Build the DRAGAN generator/discriminator pair (generator first, as before).
generatorDG, discriminatorDG = GeneratorDG(), DiscriminatorDG()

if cuda:
    # .cuda() moves parameters in place; the shared loss module is moved here as well.
    generatorDG.cuda()
    discriminatorDG.cuda()
    crit.cuda()

# Conv / BatchNorm layers get the DCGAN-style N(0, 0.02) / N(1, 0.02) initialisation.
generatorDG.apply(weights_init_normal)
discriminatorDG.apply(weights_init_normal)


########## DCGAN Func. ############

# Build the DCGAN generator/discriminator pair (generator first, as before).
generatorDC, discriminatorDC = GeneratorDC(), DiscriminatorDC()

if cuda:
    generatorDC.cuda()
    discriminatorDC.cuda()

# Same DCGAN-style initialisation as the DRAGAN pair.
generatorDC.apply(weights_init_normal)
discriminatorDC.apply(weights_init_normal)


############ DATASET ###############

# Configure data loader
data_path = ""  # Should be the path of the original dataset (ImageFolder layout: one sub-folder per class).

dataset = ImageFolder(
    root=data_path,
    transform=transforms.Compose([
        transforms.Resize((opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1], the range of the generators' Tanh output. The single
        # mean/std value is broadcast over every channel. Normalize's third positional
        # argument is `inplace`, so a per-channel value must not be passed there.
        transforms.Normalize([0.5], [0.5]),
    ])
)


dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)



############### DRAGAN OPT. #################

# Optimizers
optimizer_DG = torch.optim.Adam(generatorDG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_DD = torch.optim.Adam(discriminatorDG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


############### DCGAN OPT. #################

# Optimizers: one for the DCGAN generator and one for the DCGAN discriminator. Both used to be
# assigned to `optimizer_DC`, so the generator optimizer was discarded and the "generator"
# update in the training loop stepped the discriminator instead.
optimizer_DCG = torch.optim.Adam(generatorDC.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_DCD = torch.optim.Adam(discriminatorDC.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_DC = optimizer_DCD  # legacy name; it was previously bound to the discriminator optimizer


############## Other Func. #################

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# Device for every tensor created below (matches where the networks were moved).
device = torch.device("cuda" if cuda else "cpu")


def compute_gradient_penalty(D, X):
    """Return the DRAGAN gradient penalty ``lambda_gp * mean((||grad_x D(x_hat)||_2 - 1) ** 2)``.

    ``x_hat`` is an element-wise random interpolation between the real batch ``X`` and a
    perturbed copy ``X + 0.5 * std(X) * U[0, 1)``. The gradient norm is taken per sample over
    all non-batch dimensions. The previous ``norm(2, dim=1)`` reduced over channels only,
    which penalised per-pixel norms.

    Args:
        D: Discriminator, called on a tensor shaped like ``X``.
        X: Real batch (detached), batch dimension first.

    Returns:
        0-dim tensor, differentiable w.r.t. D's parameters (``create_graph=True``).
    """
    alpha = torch.rand_like(X)
    perturbed = X + 0.5 * X.std() * torch.rand_like(X)
    interpolates = (alpha * X + (1 - alpha) * perturbed).detach().requires_grad_(True)

    d_interpolates = D(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    per_sample_norm = gradients.reshape(gradients.shape[0], -1).norm(2, dim=1)
    return lambda_gp * ((per_sample_norm - 1) ** 2).mean()


def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of Ensemble_Generator samples as ``{ModelId}/{ModelId}/<batches_done>.png``.

    Sampling runs without gradients and in eval mode, so BatchNorm statistics are untouched.
    The generator's previous mode is restored afterwards.
    """
    was_training = Ensemble_Generator.training
    Ensemble_Generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_row ** 2, opt.latent_dim, device=next(Ensemble_Generator.parameters()).device)
            gen_imgs = Ensemble_Generator(z)
    finally:
        Ensemble_Generator.train(was_training)
    # Fixed [-1, 1] range (Tanh output) instead of min/max stretching of the whole grid.
    save_image(gen_imgs, f"{ModelId}/{ModelId}/{int(batches_done)}.png", nrow=n_row, normalize=True, value_range=(-1, 1))


def FIDCalc(original_dataset_paths=None, synthetic_dataset_path=None):
    """Calculate the FID score between the original and generated images on disk.

    Args:
        original_dataset_paths: Folders with real images. Defaults to the placeholder list below,
            which must be edited.
        synthetic_dataset_path: Folder with generated images. Defaults to ``{ModelId}/{ModelId}_out``,
            the folder ``save_class_0_images`` writes to by default.

    Returns:
        The value of ``FrechetInceptionDistance.compute()``.
    """
    if original_dataset_paths is None:
        original_dataset_paths = ["<Original Dataset Path>"]  # Should be the path of the original dataset.
    if synthetic_dataset_path is None:
        synthetic_dataset_path = f"{ModelId}/{ModelId}_out"  # Should be the path of the generated images.

    # FrechetInceptionDistance(normalize=True) expects float images in [0, 1] and rescales them
    # itself. Only resize (to a common size, so images can be stacked) and convert to a tensor.
    # An ImageNet Normalize here would push values outside [0, 1].
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
    ])

    fid = FrechetInceptionDistance(normalize=True).to(device)

    for original_dataset_path in original_dataset_paths:
        for batch in load_images_in_batches(original_dataset_path, transform):
            fid.update(batch.to(device), real=True)

    for batch in load_images_in_batches(synthetic_dataset_path, transform):
        fid.update(batch.to(device), real=False)

    return fid.compute()


def save_model(epoch, save_dir=f"{ModelId}/{ModelId}_Out", generator=None):
    """Save ``generator`` (default: Ensemble_Generator) weights to ``<save_dir>/generator_model.pth``.

    ``epoch`` is accepted for call compatibility but unused, so each call overwrites the same file.
    NOTE(review): the default folder ends in ``_Out`` (capital O), while generated images go to
    ``..._out``. On case-sensitive filesystems these are different folders.
    """
    if generator is None:
        generator = Ensemble_Generator
    os.makedirs(save_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(save_dir, "generator_model.pth"))


# ----------
#  Training
# ----------
fid_p = 9999  # best FID so far; a checkpoint is written whenever the FID improves

# Evaluate on the third batch of every `sample_interval`-th epoch. If an epoch has fewer than
# three batches, use its last batch, otherwise evaluation would never run.
eval_batch = min(2, len(dataloader) - 1)

for epoch in range(opt.n_epochs):
    for i, (imgs, _labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # The same latent batch feeds both generators.
        z = torch.randn(batch_size, opt.latent_dim, device=device)

        ############ DRAGAN TR ###############

        # -- Generator: make the DRAGAN discriminator label generated images as real.
        optimizer_DG.zero_grad()
        gen_imgsDG = generatorDG(z)
        dg_loss = crit(discriminatorDG(gen_imgsDG), valid)
        dg_loss.backward()
        optimizer_DG.step()

        # -- Discriminator: real vs. generated, plus the gradient penalty around real data.
        optimizer_DD.zero_grad()  # also clears grads left by dg_loss.backward()
        d_real_loss = crit(discriminatorDG(real_imgs), valid)
        d_fake_loss = crit(discriminatorDG(gen_imgsDG.detach()), fake)
        dd_loss = (d_real_loss + d_fake_loss) / 2
        gradient_penalty = compute_gradient_penalty(discriminatorDG, real_imgs.detach())
        dd_loss = gradient_penalty + dd_loss
        dd_loss.backward()
        optimizer_DD.step()

        ############## DCGAN TR. #################

        # -- Generator
        optimizer_DCG.zero_grad()
        gen_imgsDC = generatorDC(z)
        DCg_loss = crit(discriminatorDC(gen_imgsDC), valid)
        DCg_loss.backward()
        optimizer_DCG.step()

        # -- Discriminator
        optimizer_DCD.zero_grad()  # also clears grads left by DCg_loss.backward()
        DCreal_loss = crit(discriminatorDC(real_imgs), valid)
        DCfake_loss = crit(discriminatorDC(gen_imgsDC.detach()), fake)
        DCd_loss = (DCreal_loss + DCfake_loss) / 2
        DCd_loss.backward()
        optimizer_DCD.step()

        ############## Ensemble GAN ##############

        # Weight each generator by the inverse of its discriminator's loss. The losses are
        # detached so the averaged weights carry no autograd graph.
        w1 = 1 / DCd_loss.detach()
        w2 = 1 / dd_loss.detach()
        total_weight = w1 + w2
        w1 = w1 / total_weight
        w2 = w2 / total_weight

        dc_state = generatorDC.state_dict()
        dg_state = generatorDG.state_dict()
        weighted_state_dict = {}
        for key, ens_value in Ensemble_Generator.state_dict().items():
            if torch.is_floating_point(ens_value):
                weighted_state_dict[key] = w1 * dc_state[key] + w2 * dg_state[key]
            else:
                # Integer buffers (BatchNorm num_batches_tracked) would be truncated after float
                # averaging; both generators advance them identically, so copy DCGAN's value.
                weighted_state_dict[key] = dc_state[key]

        Ensemble_Generator.load_state_dict(weighted_state_dict)

        # Both generators restart the next batch from the averaged weights.
        generatorDC.load_state_dict(Ensemble_Generator.state_dict())
        generatorDG.load_state_dict(Ensemble_Generator.state_dict())

        print(
            "[Epoch %d/%d] [Batch %d/%d] [DCD loss: %f] [DD loss: %f] [DCG loss: %f] [DG loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), DCd_loss.item(), dd_loss.item(), DCg_loss.item(), dg_loss.item())
        )

        if epoch % opt.sample_interval == 0 and i == eval_batch:
            sample_image(n_row=4, batches_done=epoch)
            save_class_0_images(generator=Ensemble_Generator, latent_dim=opt.latent_dim)
            fid_c = float(FIDCalc())  # plain float, so the log holds a number, not "tensor(...)"
            with open(f"{ModelId}/FID_Scores_{ModelId}.txt", "a") as fid_log:
                fid_log.write("\n" + str(fid_c))
            if fid_p > fid_c:
                save_model(epoch)
                fid_p = fid_c
                print(fid_p)
            fake_loader = create_fake_data_loader(Ensemble_Generator, latent_dim=opt.latent_dim)
            is_mean, is_std = calculate_inception_score(fake_loader)
            print(f"Inception Score: {is_mean:.2f} ± {is_std:.2f}")
            with open(f"{ModelId}/IS_{ModelId}.txt", "a") as is_log:
                is_log.write("\n" + str(is_mean))