In [None]:
"""
KL-GAN with Multiple Adversarial Objectives
--------------------------------------------
This script showcases training KL-GAN on CelebA with options for:
    KL-GAN, LS-GAN, WGAN-GP, Hinge-GAN, R1-GAN.

We run 5 different seeds for each method.
"""
import math
import os
import gc
import cv2
import random
import wandb
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import albumentations as A
import albumentations.pytorch as AP

from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar, RichProgressBarTheme
from torchvision import transforms, models
from torch.distributions import Normal, kl_divergence, Independent


# -----------------------
# 1. Feature extraction functions for FID-like calculations
# -----------------------
# Frozen Inception-v3 feature extractor. Its classifier head ``fc`` is swapped
# for Identity, so a forward pass returns the activations that would otherwise
# have been fed to ``fc`` instead of class logits.
inception_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
inception_model = inception_model.cuda().eval()
inception_model.fc = nn.Identity()
inception_model.requires_grad_(False)  # never trained: disable grads on every parameter

def get_features(real_image, fake_image, eps=1e-3):
    """
    Embed real and fake image batches with Inception v3 and summarise each set
    of embeddings as a Gaussian (mean vector + covariance matrix).

    Args:
        real_image: batch of real images accepted by ``inception_model``.
        fake_image: batch of generated images accepted by ``inception_model``.
        eps: value added to the covariance diagonal so both matrices stay
            positive definite (``MultivariateNormal`` requires it and it keeps
            the FID eigen-decomposition well conditioned).

    Returns:
        ``(mu_real, cov_real, mu_fake, cov_fake)`` as float32 tensors.
    """
    real_features = inception_model(real_image)
    fake_features = inception_model(fake_image)

    # Compute the statistics in float32 with autocast disabled. Previously the
    # mean/cov were computed in whatever dtype the backbone produced (bfloat16
    # under bf16 autocast, which the bf16 data tensors suggest is in use) and
    # only the *result* was cast to float32, after the precision was already
    # lost. An active autocast context would also down-cast torch.cov's
    # internal matmul, hence the explicit disable.
    stats = []
    with torch.autocast(device_type=real_features.device.type, enabled=False):
        for features in (real_features, fake_features):
            features = features.float()
            mu = features.mean(0)
            cov = torch.cov(features.permute(1, 0))
            cov = cov + eps * torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)
            stats.extend((mu, cov))

    mu_real, cov_real, mu_fake, cov_fake = stats
    return mu_real, cov_real, mu_fake, cov_fake


# Frozen RegNet-X 3.2GF feature extractor. Its classifier head ``fc`` is
# replaced by Identity, so a forward pass returns the activations that would
# otherwise have been fed to ``fc``.
RegNet_model = models.regnet_x_3_2gf(weights="RegNet_X_3_2GF_Weights.DEFAULT")
RegNet_model = RegNet_model.cuda().eval()
RegNet_model.fc = nn.Identity()
RegNet_model.requires_grad_(False)  # never trained: disable grads on every parameter

def get_features_RegNet(real_image, fake_image, eps=1e-3):
    """
    Embed real and fake image batches with RegNet-X 3.2GF and summarise each
    set of embeddings as a Gaussian (mean vector + covariance matrix).

    Args:
        real_image: batch of real images accepted by ``RegNet_model``.
        fake_image: batch of generated images accepted by ``RegNet_model``.
        eps: value added to the covariance diagonal so both matrices stay
            positive definite (``MultivariateNormal`` requires it).

    Returns:
        ``(mu_real, cov_real, mu_fake, cov_fake)`` as float32 tensors.
    """
    real_features = RegNet_model(real_image)
    fake_features = RegNet_model(fake_image)

    # Compute the statistics in float32 with autocast disabled. Previously the
    # mean/cov were computed in the backbone's output dtype (bfloat16 under
    # bf16 autocast) and only the result was cast to float32, after the
    # precision loss had already happened. Autocast would also down-cast
    # torch.cov's internal matmul, hence the explicit disable.
    stats = []
    with torch.autocast(device_type=real_features.device.type, enabled=False):
        for features in (real_features, fake_features):
            features = features.float()
            mu = features.mean(0)
            cov = torch.cov(features.permute(1, 0))
            cov = cov + eps * torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)
            stats.extend((mu, cov))

    mu_real, cov_real, mu_fake, cov_fake = stats
    return mu_real, cov_real, mu_fake, cov_fake


def calculate_kl_divergence(mu_real, cov_real, mu_fake, cov_fake):
    """
    Return the symmetrised KL divergence between two multivariate Gaussians.

    Both Gaussians are built in float32 from the given means and covariance
    matrices. The result is ``0.5 * (KL(fake || real) + KL(real || fake))``
    as a scalar tensor.
    """
    def _mvn(mu, cov):
        return torch.distributions.MultivariateNormal(mu.to(torch.float32), cov.to(torch.float32))

    p_real = _mvn(mu_real, cov_real)
    p_fake = _mvn(mu_fake, cov_fake)

    kl = torch.distributions.kl_divergence
    forward_kl = kl(p_fake, p_real)
    reverse_kl = kl(p_real, p_fake)
    return 0.5 * (forward_kl + reverse_kl)

def calculate_fid(mu_real, cov_real, mu_fake, cov_fake):
    """
    Compute the Frechet distance between two Gaussians.

        FID = ||mu_real - mu_fake||^2
              + Tr(cov_real + cov_fake - 2 * (cov_real @ cov_fake)^(1/2))

    Tr(sqrtm(A)) is obtained as the real part of the sum of the square roots
    of the (complex) eigenvalues of ``cov_real @ cov_fake``, so no explicit
    matrix square root is formed. All arithmetic runs in float64.

    Returns:
        The distance as a Python float.
    """
    m1, c1, m2, c2 = (t.to(torch.float64) for t in (mu_real, cov_real, mu_fake, cov_fake))

    mean_term = torch.sum((m1 - m2) ** 2, dim=-1)
    trace_term = torch.trace(c1) + torch.trace(c2)
    eigenvalues = torch.linalg.eigvals(c1 @ c2)
    sqrt_trace = torch.real(eigenvalues.sqrt().sum())

    return float(mean_term + trace_term - 2 * sqrt_trace)


# -----------------------
# 2. Data loading: CelebA
# -----------------------
class CelebADataset(Dataset):
    """
    Map-style dataset over an already materialised tensor of images.

    Item ``idx`` is ``data_tensor[idx]`` returned as-is, and the length is the
    size of the tensor's first dimension. No copying, device transfer or
    transformation happens here.
    """
    def __init__(self, data_tensor):
        self.data = data_tensor

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample


class DataModule(pl.LightningDataModule):
    """
    LightningDataModule for CelebA.

    ``setup`` decodes every readable image in ``data_dir`` with OpenCV,
    resizes it to 32x32 and normalises it (mean/std 0.5 per channel). It then
    stacks everything into one bfloat16 tensor on the GPU and splits that
    90/10 into train/val subsets with ``random_split``.
    """
    def __init__(self, batch_size, val_batch_size, data_dir="./img_align_celeba/img_align_celeba"):
        super().__init__()
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.data_dir = data_dir

        self.train_transform = A.Compose([
            A.Resize(32, 32),
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            AP.ToTensorV2(),
        ])

        # Kept for API compatibility. setup() applies train_transform to every
        # image, which is the same pipeline as this one.
        self.val_transform = A.Compose([
            A.Resize(32, 32),
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            AP.ToTensorV2(),
        ])

    def setup(self, stage=None):
        """Load, preprocess and split the whole image folder (first call only)."""
        # Lightning may call setup() once per stage (fit, validate, ...).
        # Re-decoding the folder is slow. It would also draw a *new* random
        # train/val split from the global RNG, so the first result is kept.
        if getattr(self, "dataset", None) is not None:
            return

        # os.listdir order is filesystem-dependent. Sorting makes the image
        # order, and therefore the seeded random_split, reproducible.
        image_files = sorted(os.listdir(self.data_dir))
        images = []

        for img_file in image_files:
            image_path = os.path.join(self.data_dir, img_file)
            image = cv2.imread(image_path)
            if image is None:  # unreadable / non-image file
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.train_transform(image=image)["image"]
            images.append(image)

        if not images:
            raise RuntimeError(f"No readable images found in {self.data_dir!r}")

        data_tensor = torch.stack(images)
        data_tensor = data_tensor.to(device=torch.device('cuda'), dtype=torch.bfloat16)

        self.dataset = CelebADataset(data_tensor)
        split = int(0.9 * len(self.dataset))
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            self.dataset, [split, len(self.dataset) - split]
        )

    def train_dataloader(self):
        # The data already lives on the GPU: no worker processes are used and
        # pin_memory (which applies to CPU tensors) is off.
        return DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=False
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset_val,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )


# -----------------------
# 3. (Optional) Latent code for storing precomputed latents
# -----------------------
def scale_latents_to_minus_one_one(x, scale=3.0):
    """
    Map raw latents to [-1, 1].

    Values in ``[-scale, scale]`` are mapped linearly onto ``[-1, 1]``;
    values outside that range are clamped.

    Args:
        x: tensor of raw latent values.
        scale: magnitude of raw latents that maps to +/-1 (default 3.0).

    Returns:
        Tensor of the same shape as ``x`` with values in [-1, 1].
    """
    x_scaled = x.div(2 * scale).add(0.5).clamp(0, 1)  # to [0, 1]
    return x_scaled.mul(2).sub(1)  # to [-1, 1]

def unscale_latents_from_minus_one_one(x, scale=3.0):
    """
    Map latents in [-1, 1] back to the raw range ``[-scale, scale]``.

    This is the inverse of ``scale_latents_to_minus_one_one`` for values that
    were not clamped there.

    Args:
        x: tensor with values in [-1, 1].
        scale: raw magnitude that +/-1 maps back to (default 3.0).

    Returns:
        Tensor of the same shape as ``x``.
    """
    x_zero_one = x.add(1).div(2)  # to [0, 1]
    return x_zero_one.sub(0.5).mul(2 * scale)

class LatentDataset(Dataset):
    """
    Map-style dataset over the ``*.pt`` files in ``latent_dir``.

    Files are indexed in sorted filename order. Each item is whatever
    ``torch.load`` returns for one file, loaded onto the CPU.
    """
    def __init__(self, latent_dir):
        self.latent_dir = latent_dir
        # Sorted so the index -> file mapping does not depend on the
        # filesystem's directory-listing order.
        self.latent_files = sorted(f for f in os.listdir(latent_dir) if f.endswith('.pt'))

    def __len__(self):
        return len(self.latent_files)

    def __getitem__(self, idx):
        latent_path = os.path.join(self.latent_dir, self.latent_files[idx])
        # map_location="cpu": LatentDataModule reads this dataset with
        # num_workers=8. Deserialising GPU-saved tensors inside those worker
        # processes would try to initialise CUDA in a forked child and fail.
        # SECURITY NOTE: torch.load unpickles the file. Only point latent_dir
        # at trusted data, and prefer weights_only=True where supported.
        latent = torch.load(latent_path, map_location="cpu")
        return latent

class LatentDataModule(pl.LightningDataModule):
    """
    LightningDataModule that serves latent ``*.pt`` files from one directory.

    The train and validation datasets are both built from ``latent_dir``, so
    validation runs over the same latents as training.
    """
    def __init__(self, batch_size, val_batch_size, latent_dir):
        super().__init__()
        self.latent_dir = latent_dir
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size

    def setup(self, stage=None):
        # Two independent dataset objects over the same directory.
        self.dataset_train, self.dataset_val = (
            LatentDataset(latent_dir=self.latent_dir) for _ in range(2)
        )

    def _loader(self, dataset, batch_size, shuffle):
        """Build a DataLoader with the 8 worker processes used for both splits."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)

    def train_dataloader(self):
        return self._loader(self.dataset_train, self.batch_size, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.dataset_val, self.val_batch_size, shuffle=False)


# -----------------------
# 4. Helper functions for KL-based training
# -----------------------
def compute_mean_std(features, epsilon=1e-10):
    """
    Per-dimension mean and standard deviation of a batch of feature vectors.

    Args:
        features: tensor whose dim 0 indexes samples; statistics are taken
            over that dimension.
        epsilon: added to the population variance (``unbiased=False``) before
            the square root, so the returned std is strictly positive and can
            be used as a ``Normal`` scale.

    Returns:
        ``(mu, std)``, each shaped like ``features`` without dim 0.
    """
    mu = features.mean(dim=0)
    var = features.var(dim=0, unbiased=False) + epsilon
    # The function name, and the caller that passes the second value as
    # ``Normal(loc, scale)``, both expect a standard deviation. Returning the
    # variance made every Normal's scale sigma**2 instead of sigma.
    std = var.sqrt()
    return mu, std

def symmetric_kl_divergence(real_features, fake_features):
    """
    Symmetric, log-damped KL divergence between two feature batches.

    Each batch is summarised by ``compute_mean_std``. Its second output is
    used as the scale of an independent ``Normal`` per feature dimension,
    giving a diagonal Gaussian. The result is
    ``0.5 * (log1p(KL(real || fake)) + log1p(KL(fake || real)))``.
    ``log1p`` makes the value grow only logarithmically for large divergences.
    """
    def _diag_gaussian(features):
        loc, scale = compute_mean_std(features)
        return Independent(Normal(loc, scale), 1)

    p_real = _diag_gaussian(real_features)
    p_fake = _diag_gaussian(fake_features)
    return 0.5 * (
        torch.log1p(kl_divergence(p_real, p_fake))
        + torch.log1p(kl_divergence(p_fake, p_real))
    )


# -----------------------
# 5. MinibatchDiscrimination module
# -----------------------
class MinibatchDiscrimination(nn.Module):
    """
    Minibatch discrimination layer.

    Each sample is projected through the learned tensor ``T`` into
    ``out_features`` kernels of ``kernel_dims`` dimensions. For every kernel,
    the L1 distance between every pair of samples is measured and
    ``exp(-distance)`` is summed over the *other* samples. These per-kernel
    similarities are appended to the input, so downstream layers can see how
    close a sample is to the rest of the batch.

    Input:  (N, in_features)
    Output: (N, in_features + out_features)

    Args:
        in_features: size of each input vector.
        out_features: number of similarity kernels (columns appended).
        kernel_dims: dimensionality of each kernel's projection.
        mean: if True, divide the summed similarities by the number of other
            samples (N - 1).
    """
    def __init__(self, in_features, out_features, kernel_dims, mean=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_dims = kernel_dims
        self.mean = mean
        self.T = nn.Parameter(torch.empty(in_features, out_features, kernel_dims))
        nn.init.normal_(self.T, 0, 1)

    def forward(self, x):
        n = x.size(0)
        # (N, A) @ (A, B*C) -> (N, B, C)
        matrices = x.mm(self.T.view(self.in_features, -1))
        matrices = matrices.view(-1, self.out_features, self.kernel_dims)
        M = matrices.unsqueeze(0)    # 1xNxBxC
        M_T = M.permute(1, 0, 2, 3)  # Nx1xBxC

        norm = torch.abs(M - M_T).sum(3)  # NxNxB pairwise L1 distances
        expnorm = torch.exp(-norm)
        # Each sample's similarity to itself is exp(0) = 1; remove it so only
        # the other samples contribute.
        o_b = expnorm.sum(0) - 1  # NxB
        if self.mean:
            # Average over the N - 1 other samples. A single-sample batch has
            # no others (o_b is all zeros), so clamp the divisor to 1 instead
            # of producing 0/0 = NaN.
            o_b = o_b / max(n - 1, 1)
        return torch.cat([x, o_b], 1)


# -----------------------
# 6. Generator and Discriminator with DIM as a hyperparameter
# -----------------------
class Generator(nn.Module):
    """
    Convolutional KL-GAN generator mapping noise to 3x32x32 images in [-1, 1].

    The incoming noise is reparameterised with learnable per-dimension
    parameters, ``z = mu + exp(log_sigma) * eps``, so the latent distribution
    is trained along with the network. ``z`` then goes through a fully
    connected stem, reshaped to 2x2, and convolutional stages:
    2x2 -> 4x4 (nearest upsample) -> 8x8 (5x5 transposed conv)
    -> 16x16 -> 32x32 (stride-2 transposed convs). A final 3-channel conv and
    ``tanh`` produce the image.
    """
    def __init__(self, latent_dim=128, dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.dim = dim

        # Learnable location / log-scale of the latent distribution.
        self.mu = nn.Parameter(torch.zeros(latent_dim))
        self.log_sigma = nn.Parameter(torch.zeros(latent_dim))

        stem_width = 4 * 4 * dim
        self.preprocess = nn.Sequential(
            nn.Linear(latent_dim, stem_width),
            nn.Mish(),
            nn.LayerNorm(stem_width),
        )
        self.block1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(4 * dim, 2 * dim, 3, 1, 1),
            nn.Mish(),
            nn.ConvTranspose2d(2 * dim, 2 * dim, 5),
            nn.Mish(),
            nn.InstanceNorm2d(2 * dim, affine=True),
            nn.Conv2d(2 * dim, 2 * dim, 3, 1, 1),
            nn.Mish(),
        )
        self.block3 = self._upsample_block(2 * dim, 2 * dim)
        self.block4 = self._upsample_block(2 * dim, dim)
        self.deconv_out = nn.Conv2d(dim, 3, 3, 1, 1)
        self.tanh = nn.Tanh()

    @staticmethod
    def _upsample_block(c_in, c_out):
        """Stride-2 transposed conv (doubles H and W) followed by a 3x3 refinement conv."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
            nn.Mish(),
            nn.InstanceNorm2d(c_out, affine=True),
            nn.Conv2d(c_out, c_out, 3, 1, 1),
            nn.Mish(),
        )

    def forward(self, eps):
        # Reparameterised latent sample.
        z = torch.exp(self.log_sigma) * eps + self.mu

        h = self.preprocess(z).view(-1, 4 * self.dim, 2, 2)
        for stage in (self.block1, self.block3, self.block4, self.deconv_out, self.tanh):
            h = stage(h)
        return h

from torch.nn.utils import spectral_norm as SpectralNorm
class Discriminator(nn.Module):
    """
    Convolutional discriminator for 3x32x32 images.

    The conv trunk reduces 32x32 to 1x1 with ``16 * dim`` channels. A linear
    layer projects that to an 8-dim feature vector, optionally followed by
    minibatch discrimination.

    * ``type_model == "KL-GAN"``: the input batch must hold real samples in
      its first half and fake samples in its second. The output is the scalar
      ``symmetric_kl_divergence`` between the two halves' features.
    * any other objective: returns one score per sample, shape (N, 1).

    ``minibatch_shader`` (KL-GAN only) selects whether minibatch statistics
    are computed over the joint real+fake batch (True) or within each half
    separately (False).
    """
    def __init__(self, type_model, minibatch_shader=False, dim=128, use_minibatch=True):
        super().__init__()
        self.type_model = type_model
        self.minibatch_shader = minibatch_shader
        self.dim = dim
        self.use_minibatch = use_minibatch

        # (in_ch, out_ch, kernel, stride, padding)
        # 32 -> 16 -> 8 -> 4 with 5x5 stride-2 convs, then 4 -> 2 -> 1 with 2x2 stride-2 convs.
        conv_specs = [
            (3, dim, 5, 2, 2),
            (dim, 2 * dim, 5, 2, 2),
            (2 * dim, 4 * dim, 5, 2, 2),
            (4 * dim, 8 * dim, 2, 2, 0),
            (8 * dim, 16 * dim, 2, 2, 0),
        ]
        layers = []
        for c_in, c_out, k, s, p in conv_specs:
            layers += [
                nn.Conv2d(c_in, c_out, k, stride=s, padding=p),
                nn.LeakyReLU(negative_slope=0.2),
            ]
        self.main = nn.Sequential(*layers)

        self.output = nn.Linear(16 * dim, 8)
        if use_minibatch:
            # Minibatch discrimination appends 8 similarity features -> 16 inputs.
            self.minibatch = MinibatchDiscrimination(8, 8, 1)
            self.final = nn.Linear(16, 1)
        else:
            self.final = nn.Linear(8, 1)

    def forward(self, x):
        route = self.forward_kl if self.type_model == "KL-GAN" else self.forward_average
        return route(x)

    def _embed(self, x):
        """Conv trunk followed by the projection to the 8-dim feature vector."""
        return self.output(self.main(x).flatten(1))

    def forward_average(self, x):
        feats = self._embed(x)
        if self.use_minibatch:
            feats = self.minibatch(feats)
        return self.final(feats)

    def forward_kl(self, x):
        """
        For KL-GAN, we chunk real/fake images from x along the batch dimension.
        """
        feats = self._embed(x)

        if self.use_minibatch and self.minibatch_shader:
            # Minibatch statistics over the joint real+fake batch, then split.
            real_feats, fake_feats = self.minibatch(feats).chunk(2, dim=0)
            return symmetric_kl_divergence(real_feats, fake_feats)

        real_feats, fake_feats = feats.chunk(2, dim=0)
        if self.use_minibatch:
            # Minibatch statistics computed separately within each half.
            real_feats = self.minibatch(real_feats)
            fake_feats = self.minibatch(fake_feats)
        return symmetric_kl_divergence(real_feats, fake_feats)



# -----------------------
# 8. Multi-scale generator and discriminator
# -----------------------
class StableFlowGenerator(nn.Module):
    """
    Multi-scale generator with a ``to_rgb`` head after every upsampling block.

    The reparameterised latent ``z = mu + exp(log_sigma) * eps`` is projected
    to a 4x4 feature map. Each of the ``log2(resolution) - 2`` blocks then
    doubles the spatial size, and an RGB image (tanh, in [-1, 1]) is emitted
    after every block. For ``resolution=32`` the output is a list of three
    images at 8x8, 16x16 and 32x32, smallest first. The 4x4 stem itself does
    not produce an image.

    Raises:
        ValueError: if ``resolution`` is below 8, which would give no levels
            and therefore no output images.
    """
    def __init__(
        self,
        latent_dim=128,
        resolution=32,
        dim=128
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.resolution = resolution

        # One block per doubling from 4x4 up to `resolution`
        # (e.g. resolution=32 -> log2(32) - 2 = 3 blocks: 4->8->16->32).
        self.num_levels = int(math.log2(self.resolution)) - 2
        if self.num_levels < 1:
            raise ValueError(
                f"resolution must be at least 8 to produce any output, got {resolution}"
            )

        # Learnable latent distribution parameters (as in the KL-GAN Generator).
        self.mu = nn.Parameter(torch.zeros(latent_dim))
        self.log_sigma = nn.Parameter(torch.zeros(latent_dim))

        # Stem: project the latent to a (dim * 4) x 4 x 4 tensor.
        initial_channels = dim * 4
        self.initial_channels = initial_channels
        self.initial = nn.Sequential(
            nn.Linear(latent_dim, 4 * 4 * initial_channels),
            nn.Mish()
        )
        self.initial_norm = nn.LayerNorm([initial_channels, 4, 4])

        # Upsampling blocks, each paired with its own to_rgb head.
        self.blocks = nn.ModuleList()
        self.to_rgb = nn.ModuleList()

        in_channels = initial_channels
        for level in range(self.num_levels):
            # Halve the channel count per level, but never go below `dim`.
            out_channels = max(dim, in_channels // 2)
            block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.Mish(),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.Mish(),
                nn.InstanceNorm2d(out_channels, affine=True),
            )
            self.blocks.append(block)
            self.to_rgb.append(nn.Conv2d(out_channels, 3, kernel_size=1))
            in_channels = out_channels

        self.tanh = nn.Tanh()

    def forward(self, eps):
        """
        Args:
            eps: standard noise, (N, latent_dim).

        Returns:
            List of ``num_levels`` images, each (N, 3, H, H), ordered from the
            smallest (8x8) to the largest (``resolution``).
        """
        sigma = torch.exp(self.log_sigma)
        z = self.mu + sigma * eps  # sample from the parameterised distribution

        out = self.initial(z)
        N = out.size(0)
        out = out.view(N, self.initial_channels, 4, 4)
        out = self.initial_norm(out)

        images = []
        current = out
        for block, to_rgb in zip(self.blocks, self.to_rgb):
            current = block(current)
            images.append(self.tanh(to_rgb(current)))

        # [8x8, 16x16, ..., resolution]
        return images


# -----------------------
# Multi-scale discriminator
# -----------------------
class StableFlowDiscriminator(nn.Module):
    """
    Multi-scale discriminator matching ``StableFlowGenerator``'s outputs.

    Takes a list of ``log2(resolution) - 2`` images ordered smallest first
    (for ``resolution=32``: 8x8, 16x16, 32x32). Processing starts from the
    largest image. At each level a 1x1 ``from_rgb`` conv turns that level's
    image into features, which are concatenated with the previous (larger)
    level's downsampled features. A conv block then halves the spatial size.
    The final 4x4 map goes through ``final_conv`` and ``fc`` to an 8-dim
    feature vector, which is scored (``forward_average``) or turned into a
    real-vs-fake symmetric KL (``forward_kl``).

    Raises:
        ValueError: if ``resolution`` is below 8 (no levels), or if the input
            list does not contain exactly one image per level.
    """
    def __init__(
        self,
        resolution=32,
        dim=128,
        type_model="KL-GAN",
        minibatch_shader=False,
        use_minibatch=True,
    ):
        super().__init__()
        self.resolution = resolution
        self.type_model = type_model
        self.minibatch_shader = minibatch_shader
        self.use_minibatch = use_minibatch

        self.num_levels = int(math.log2(self.resolution)) - 2
        if self.num_levels < 1:
            raise ValueError(f"resolution must be at least 8, got {resolution}")

        # from_rgb blocks + discriminator blocks
        self.from_rgb = nn.ModuleList()
        self.blocks = nn.ModuleList()

        # Channel schedule mirrors the generator: start from dim * 4 and halve
        # (never below dim) at every level.
        in_channels = dim * 4
        prev_channels = 0  # channels carried over from the previous (larger) level

        for level in range(self.num_levels):
            out_channels = max(dim, in_channels // 2)

            frgb = nn.Conv2d(3, out_channels, kernel_size=1)
            self.from_rgb.append(frgb)

            # Input = this level's from_rgb features concatenated with the
            # previous level's output; AvgPool halves H and W.
            block = nn.Sequential(
                nn.Conv2d(out_channels + prev_channels, out_channels, 3, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.LeakyReLU(0.2),
                nn.AvgPool2d(2)
            )
            self.blocks.append(block)

            prev_channels = out_channels
            in_channels = out_channels

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Linear(in_channels * 4 * 4, 8)  # equivalent to Discriminator.output
        if self.use_minibatch:
            self.minibatch = MinibatchDiscrimination(8, 8, 1)
            self.final = nn.Linear(16, 1)
        else:
            self.final = nn.Linear(8, 1)

    def forward(self, multi_res_images):
        """
        multi_res_images: list of images [8x8, ..., resolution], smallest first
        (the order StableFlowGenerator returns). For KL-GAN each tensor holds
        2N samples: N real followed by N fake.
        """
        if self.type_model == "KL-GAN":
            return self.forward_kl(multi_res_images)
        else:
            return self.forward_average(multi_res_images)

    def forward_average(self, x):
        """Return one score per sample, shape (batch, 1)."""
        features = self.forward_multiscale(x)
        if self.use_minibatch:
            features = self.minibatch(features)
        out = self.final(features)
        return out

    def forward_kl(self, x):
        """
        x[i] is [2N, 3, H, W]. After forward_multiscale the features are
        [2N, 8]; they are split into real/fake halves and compared with
        symmetric_kl_divergence.
        """
        features = self.forward_multiscale(x)
        if self.use_minibatch:
            if self.minibatch_shader:
                real_f, fake_f = self.minibatch(features).chunk(2, dim=0)
                return symmetric_kl_divergence(real_f, fake_f)
            else:
                real_chunk, fake_chunk = features.chunk(2, dim=0)
                return symmetric_kl_divergence(
                    self.minibatch(real_chunk),
                    self.minibatch(fake_chunk)
                )
        else:
            real_chunk, fake_chunk = features.chunk(2, dim=0)
            return symmetric_kl_divergence(real_chunk, fake_chunk)

    def forward_multiscale(self, multi_res_images):
        """Fuse all scales into an 8-dim feature vector per sample."""
        # zip() below would silently drop extra images or levels; fail loudly instead.
        if len(multi_res_images) != self.num_levels:
            raise ValueError(
                f"expected {self.num_levels} images (one per scale, smallest first), "
                f"got {len(multi_res_images)}"
            )

        # Walk from the largest resolution to the smallest.
        multi_res_images = multi_res_images[::-1]

        x = None
        for img, frgb, block in zip(multi_res_images, self.from_rgb, self.blocks):
            feat = frgb(img)
            x = feat if x is None else torch.cat([x, feat], dim=1)
            x = block(x)  # convs + 2x average-pool downsampling
        x = self.final_conv(x)  # more convolutions at 4x4
        x = x.flatten(1)
        x = self.fc(x)          # -> [batch, 8]
        return x
        
# -----------------------
# 7. PyTorch Lightning Module for training
# -----------------------
class GAN_Training(pl.LightningModule):
    """
    LightningModule for training KL-GAN (and other variants).
    Added hyperparameter use_multiscale to enable multi-scale mode.
    """
    def __init__(
        self,
        learning_rate: float = 0.00002,
        batch_size: int = 256,
        seed_value: int = 1,
        type_model: str = "KL-GAN",
        latent_dim: int = 128,
        dim: int = 128,
        use_minibatch: bool = True,
        # New hyperparameter
        use_multiscale: bool = False,
        resolution: int = 32,   # to pass to multi-scale G/D
        **kwargs
    ):
        """
        Store hyperparameters, seed the RNGs and build the generator and
        discriminator.

        With ``use_multiscale`` the multi-resolution StableFlow G/D pair is
        used; otherwise the single-scale 32x32 Generator/Discriminator.
        Optimisation is manual (``automatic_optimization = False``);
        ``training_step`` steps both optimisers itself.
        """
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        if seed_value is not None:
            pl.seed_everything(seed_value)
            torch.manual_seed(seed_value)
            torch.cuda.manual_seed(seed_value)
            random.seed(seed_value)

        if not self.hparams.use_multiscale:
            # Single-scale networks.
            self.generator = Generator(latent_dim=latent_dim, dim=dim)
            self.discriminator = Discriminator(
                type_model=type_model,
                dim=dim,
                use_minibatch=use_minibatch
            )
        else:
            # Multi-scale networks exchanging lists of images.
            self.generator = StableFlowGenerator(
                latent_dim=latent_dim,
                resolution=resolution,
                dim=dim
            )
            self.discriminator = StableFlowDiscriminator(
                resolution=resolution,
                dim=dim,
                type_model=type_model,
                use_minibatch=use_minibatch
            )
        

    def compute_gradient_penalty(self, real_data, fake_data):
        """
        WGAN-GP gradient penalty.

        Samples points on straight lines between paired real and fake samples.
        It then penalises the squared deviation of the discriminator's
        input-gradient norm from 1, with the norm taken per sample over all
        non-batch dimensions.

        Args:
            real_data: real batch; dim 0 indexes samples.
            fake_data: fake batch with the same shape as ``real_data``.

        Returns:
            Scalar tensor ``mean_i (||grad_x D(x_hat_i)||_2 - 1)^2``, built with
            ``create_graph=True`` so it can be backpropagated into D.
        """
        batch_size = real_data.size(0)
        # One interpolation coefficient per sample, broadcast over the other dims.
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)), device=self.device)
        alpha = alpha.expand(real_data.size())

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        # Fresh leaf tensor that requires grad (replaces deprecated autograd.Variable).
        interpolates = interpolates.detach().requires_grad_(True)

        disc_interpolates = self.discriminator(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        # Norm over *all* non-batch dims, one value per sample. The previous
        # norm over dim=1 only (channels) gave a per-pixel norm, which is not
        # the WGAN-GP penalty.
        gradients = gradients.reshape(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def r1_penalty(self, real_data):
        """
        R1 regulariser on real data: the mean over samples of the squared L2
        norm of the discriminator's gradient with respect to its input.

        Args:
            real_data: real batch; dim 0 indexes samples.

        Returns:
            Scalar tensor, built with ``create_graph=True`` so it can be
            backpropagated into the discriminator's parameters.
        """
        # Differentiate w.r.t. a detached leaf copy. Setting requires_grad on
        # the caller's tensor in place leaked into every later use of that
        # batch (e.g. the following D(real) forward builds an input-grad graph).
        real_data = real_data.detach().requires_grad_(True)
        real_pred = self.discriminator(real_data)
        (grad_real,) = torch.autograd.grad(
            outputs=real_pred.sum(),
            inputs=real_data,
            create_graph=True
        )
        per_sample = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
        return per_sample.mean()

    def on_validation_epoch_end(self):
        """
        Run Python garbage collection after each validation epoch, and save a
        checkpoint every 50 epochs (including epoch 0).

        The checkpoint path depends only on the objective and seed, so each
        save overwrites the previous one.
        """
        gc.collect()
        # Lightning's pre-training validation sanity check also ends with this
        # hook at current_epoch == 0; don't write an untrained checkpoint then.
        if self.trainer.sanity_checking:
            return
        if self.current_epoch % 50 == 0:
            self.trainer.save_checkpoint(
                filepath=f"./checkpoint_{self.hparams.type_model}_seed{self.hparams.seed_value}.ckpt"
            )

    def diversity_loss(self, fake_images):
        """
        Negative mean pairwise L2 distance between flattened samples.

        Minimising the returned value pushes samples apart. A batch with fewer
        than two samples has no pairs, so a zero tensor is returned instead of
        the NaN that ``mean()`` of an empty ``pdist`` would give.

        Args:
            fake_images: batch of samples; dim 0 indexes samples.

        Returns:
            0-dim tensor ``-mean_{i<j} ||x_i - x_j||_2``.
        """
        batch_size = fake_images.size(0)
        if batch_size < 2:
            return fake_images.new_zeros(())
        # reshape (unlike view) also accepts non-contiguous inputs.
        fake_images_flat = fake_images.reshape(batch_size, -1)
        distances = torch.pdist(fake_images_flat, p=2)
        return -distances.mean()

    def training_step(self, batch, batch_idx):
        """
        Manual-optimisation step that dispatches to the update routine for the
        configured adversarial objective, in single- or multi-scale mode.

        Args:
            batch: batch of real images.
            batch_idx: batch index; monitoring metrics are logged every 8th batch.

        Returns:
            Whatever the selected step routine returns. For the single-scale
            routines and multi-scale KL-GAN this is the detached fake batch
            (the highest-resolution level in multi-scale mode).

        Raises:
            ValueError: if ``hparams.type_model`` is not a supported objective.
        """
        supported = ("KL-GAN", "LS-GAN", "WGAN-GP", "Hinge-GAN", "R1-GAN")
        type_model = self.hparams.type_model
        if type_model not in supported:
            # An unknown name used to fall through every branch: it returned
            # None in multi-scale mode and raised UnboundLocalError otherwise.
            raise ValueError(f"Unknown type_model {type_model!r}; expected one of {supported}")

        true = batch  # [N, 3, H, W]
        noise = torch.randn((true.shape[0], self.hparams.latent_dim), device=self.device)

        # Optimizer 0 drives the generator, optimizer 1 the discriminator.
        optimizers = self.optimizers()
        optimizer_gen, optimizer_dis = optimizers[0], optimizers[1]
        optimizer_dis.zero_grad()
        optimizer_gen.zero_grad()

        if self.hparams.use_multiscale:
            real_list = tensor_to_multiscale(true, max_resolution=self.hparams.resolution, min_resolution=8)
            fake_list = self.generator(noise)  # [8x8, ..., resolution]
            if batch_idx % 8 == 0:
                self.log_data(real_list[-1], fake_list[-1])

            if type_model == "KL-GAN":
                combined_list = combine_real_fake_for_kl(real_list, fake_list)
                kl = self.discriminator(combined_list)  # forward_kl
                self.log('Fake_dist/Train', kl, prog_bar=True, on_epoch=True)

                # Single backward of -kl: the discriminator step then descends
                # -kl (maximising kl), while negating the generator's gradients
                # makes the generator step descend kl.
                (-kl).backward()
                for param in self.generator.parameters():
                    if param.grad is not None:
                        param.grad.data = -param.grad.data
                optimizer_gen.step()
                optimizer_dis.step()
                return fake_list[-1].detach()  # highest-resolution level
            elif type_model == "LS-GAN":
                return self.ls_gan_step_multiscale(optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx)
            elif type_model == "WGAN-GP":
                return self.wgan_gp_step_multiscale(optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx)
            elif type_model == "Hinge-GAN":
                return self.hinge_gan_step_multiscale(optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx)
            else:  # "R1-GAN" (validated above)
                return self.r1_regularized_hinge_step_multiscale(optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx)

        single_scale_steps = {
            "KL-GAN": self.kl_gan_step,
            "LS-GAN": self.ls_gan_step,
            "WGAN-GP": self.wgan_gp_step,
            "Hinge-GAN": self.hinge_gan_step,
            "R1-GAN": self.r1_regularized_hinge_step,
        }
        return single_scale_steps[type_model](optimizer_gen, optimizer_dis, noise, true, batch_idx)

    def log_data(self, true, fake):
        """
        Log monitoring metrics for a real/fake batch pair, without gradients.

        * ``KL_divergence/Train``: symmetric KL between Gaussians fitted to the
          RegNet features of the two batches (covariance diagonal loaded with
          eps=10).
        * ``Diversity/Train``: mean pairwise L2 distance among fake samples.
        """
        with torch.no_grad():
            gaussian_stats = get_features_RegNet(true, fake, eps=10)
            kl_value = calculate_kl_divergence(*gaussian_stats)
            self.log('KL_divergence/Train', kl_value, prog_bar=True, on_epoch=True)

            spread = -self.diversity_loss(fake)
            self.log('Diversity/Train', spread, prog_bar=True, on_epoch=True)

    # Different adversarial objectives (original)
    def ls_gan_step(self, optimizer_gen, optimizer_dis, noise, true, batch_idx):
        """
        Least-squares GAN update: one generator step, then one discriminator step.

        Targets used here: the discriminator pushes D(real) -> 0 and
        D(fake) -> 1, and the generator pushes D(fake) -> 0. Each step runs
        inside its optimizer's ``toggle_model()`` context.

        Returns:
            The detached fake batch.
        """
        # --- generator ---
        with optimizer_gen.toggle_model():
            fake = self.generator(noise)
            if batch_idx % 8 == 0:
                self.log_data(true, fake)
            d_on_fake = self.discriminator(fake)
            self.log('g_loss_ls_gan/Train', d_on_fake.mean(), prog_bar=True, on_epoch=True)
            g_objective = d_on_fake.pow(2).mean()
            g_objective.backward()
            optimizer_gen.step()

        # --- discriminator ---
        fake = fake.detach()
        with optimizer_dis.toggle_model():
            d_real = self.discriminator(true)
            d_fake = self.discriminator(fake)
            self.log('true_loss_ls_gan/Train', d_real.mean(), prog_bar=True, on_epoch=True)
            self.log('fake_loss_ls_gan/Train', d_fake.mean(), prog_bar=True, on_epoch=True)

            d_objective = d_real.pow(2).mean() + (d_fake - 1).pow(2).mean()
            d_objective.backward()
            optimizer_dis.step()

        return fake

    def kl_gan_step(self, optimizer_gen, optimizer_dis, noise, true, batch_idx):
        """
        Single-scale KL-GAN update: both networks step from one backward pass.

        The discriminator sees real and fake samples stacked along the batch
        dimension (real first) and returns a scalar symmetric KL between their
        features. Backpropagating ``-kl`` makes the discriminator's step
        maximise kl. Flipping the generator's gradients then makes the
        generator's step minimise it.

        Returns:
            The detached fake batch.
        """
        fake = self.generator(noise)
        if batch_idx % 8 == 0:
            self.log_data(true, fake)

        joint_batch = torch.cat([true, fake], dim=0)
        kl = self.discriminator(joint_batch)
        self.log('Fake_dist/Train', kl, prog_bar=True, on_epoch=True)

        kl.neg().backward()
        for p in self.generator.parameters():
            if p.grad is not None:
                p.grad.neg_()

        optimizer_gen.step()
        optimizer_dis.step()
        return fake.detach()

    def wgan_gp_step(self, optimizer_gen, optimizer_dis, noise, true, batch_idx):
        """
        WGAN-GP update: one generator step, then one critic step.

        The generator maximises the critic score on fakes. The critic minimises
        ``mean D(fake) - mean D(real) + 10 * gradient_penalty``. Both losses
        are logged after the two steps.

        Returns:
            The detached fake batch.
        """
        gp_weight = 10.0

        # --- generator ---
        with optimizer_gen.toggle_model():
            fake = self.generator(noise)
            if batch_idx % 8 == 0:
                self.log_data(true, fake)
            critic_on_fake = self.discriminator(fake)
            gen_loss = -critic_on_fake.mean()
            gen_loss.backward()
            optimizer_gen.step()

        # --- critic ---
        fake = fake.detach()
        with optimizer_dis.toggle_model():
            critic_real = self.discriminator(true)
            critic_fake = self.discriminator(fake)
            gp = self.compute_gradient_penalty(true, fake)
            wasserstein_gap = critic_fake.mean() - critic_real.mean()
            dis_loss = wasserstein_gap + gp_weight * gp
            dis_loss.backward()
            optimizer_dis.step()

        self.log('g_loss_wgan_gp/Train', gen_loss, prog_bar=True, on_epoch=True)
        self.log('d_loss_wgan_gp/Train', dis_loss, prog_bar=True, on_epoch=True)
        return fake

    def hinge_gan_step(self, optimizer_gen, optimizer_dis, noise, true, batch_idx):
        """Run one hinge-loss generator update followed by one discriminator update.

        Generator loss: ``-mean(D(G(z)))``.
        Discriminator loss: ``mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))``,
        with the generated batch detached.

        Returns:
            The generated batch, detached from the generator graph.
        """
        with optimizer_gen.toggle_model():
            samples = self.generator(noise)
            if batch_idx % 8 == 0:
                self.log_data(true, samples)
            gen_loss = self.discriminator(samples).mean().neg()
            gen_loss.backward()
            optimizer_gen.step()

        samples = samples.detach()
        with optimizer_dis.toggle_model():
            score_real = self.discriminator(true)
            score_fake = self.discriminator(samples)
            real_term = F.relu(1 - score_real).mean()
            fake_term = F.relu(1 + score_fake).mean()
            d_loss = real_term + fake_term
            d_loss.backward()
            optimizer_dis.step()

        self.log('g_loss_hinge/Train', gen_loss, prog_bar=True, on_epoch=True)
        self.log('d_loss_hinge/Train', d_loss, prog_bar=True, on_epoch=True)
        return samples

    def r1_regularized_hinge_step(self, optimizer_gen, optimizer_dis, noise, true, batch_idx):
        """Hinge-loss GAN step whose discriminator loss also carries an R1 penalty.

        Generator loss: ``-mean(D(G(z)))``.
        Discriminator loss: hinge loss on real / detached fake batches plus
        ``10 * self.r1_penalty(real)``. The logged ``d_loss`` is the hinge part
        only; the penalty is logged separately.

        Returns:
            The generated batch, detached from the generator graph.
        """
        r1_weight = 10.0

        with optimizer_gen.toggle_model():
            samples = self.generator(noise)
            if batch_idx % 8 == 0:
                self.log_data(true, samples)
            gen_loss = self.discriminator(samples).mean().neg()
            gen_loss.backward()
            optimizer_gen.step()

        samples = samples.detach()
        with optimizer_dis.toggle_model():
            score_real = self.discriminator(true)
            score_fake = self.discriminator(samples)
            d_loss = F.relu(1 - score_real).mean() + F.relu(1 + score_fake).mean()
            r1 = self.r1_penalty(true) * r1_weight
            (d_loss + r1).backward()
            optimizer_dis.step()

        self.log('g_loss_r1_hinge/Train', gen_loss, prog_bar=True, on_epoch=True)
        self.log('d_loss_r1_hinge/Train', d_loss, prog_bar=True, on_epoch=True)
        self.log('r1_penalty/Train', r1, prog_bar=True, on_epoch=True)
        return samples

    # New multiscale methods (following similar logic at the "last" level)
    def ls_gan_step_multiscale(self, optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx):
        """One LS-GAN generator update and one discriminator update on multiscale inputs.

        Same objective as the single-scale LS-GAN step:
        - The discriminator pushes real outputs toward 0 and fake outputs toward 1.
        - The generator pushes fake outputs toward 0.

        The discriminator receives the whole list of resolution levels.

        Args:
            optimizer_gen: Generator optimizer exposing ``toggle_model()``.
            optimizer_dis: Discriminator optimizer exposing ``toggle_model()``.
            noise: Not used here; ``fake_list`` is supplied by the caller.
                Kept so every step function has the same signature.
            real_list: Real batch at each resolution level; the last entry is logged.
            fake_list: Generated batch at each level, still attached to the
                generator graph (required for the generator update).
            batch_idx: Training batch index; images are logged every 8 batches.

        Returns:
            The detached last level of ``fake_list``.
        """
        with optimizer_gen.toggle_model():
            fake_high = fake_list[-1]
            true_high = real_list[-1]
            if batch_idx % 8 == 0:
                self.log_data(true_high, fake_high)
            g_loss = self.discriminator(fake_list)
            self.log('g_loss_ls_gan/Train', g_loss.mean(), prog_bar=True, on_epoch=True)
            torch.mean(g_loss ** 2).backward()
            optimizer_gen.step()

        # Detach EVERY level, not just the last one. The backward above freed
        # the generator graph behind fake_list. Feeding the attached list to the
        # discriminator again would make the next backward traverse that freed
        # graph.
        fake_list = [fake.detach() for fake in fake_list]
        fake_high = fake_list[-1]
        with optimizer_dis.toggle_model():
            true_loss = self.discriminator(real_list)
            fake_loss = self.discriminator(fake_list)
            self.log('true_loss_ls_gan/Train', true_loss.mean(), prog_bar=True, on_epoch=True)
            self.log('fake_loss_ls_gan/Train', fake_loss.mean(), prog_bar=True, on_epoch=True)

            true_loss = torch.mean(true_loss ** 2)
            fake_loss = torch.mean((fake_loss - 1) ** 2)
            loss = true_loss + fake_loss
            loss.backward()
            optimizer_dis.step()

        return fake_high

    def wgan_gp_step_multiscale(self, optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx):
        """One WGAN-GP generator update and one critic update on multiscale inputs.

        The critic scores the whole list of resolution levels. The gradient
        penalty is computed on the last level only.

        Args:
            optimizer_gen: Generator optimizer exposing ``toggle_model()``.
            optimizer_dis: Critic optimizer exposing ``toggle_model()``.
            noise: Not used here; ``fake_list`` is supplied by the caller.
            real_list: Real batch at each resolution level.
            fake_list: Generated batch at each level, attached to the generator graph.
            batch_idx: Training batch index; images are logged every 8 batches.

        Returns:
            The detached last level of ``fake_list``.
        """
        with optimizer_gen.toggle_model():
            fake_high = fake_list[-1]
            true_high = real_list[-1]
            if batch_idx % 8 == 0:
                self.log_data(true_high, fake_high)
            gen_loss = -self.discriminator(fake_list).mean()
            gen_loss.backward()
            optimizer_gen.step()

        # Detach every level: the generator graph was freed by the backward
        # above, so the critic must not see the attached tensors again.
        fake_list = [fake.detach() for fake in fake_list]
        fake_high = fake_list[-1]
        with optimizer_dis.toggle_model():
            real_pred = self.discriminator(real_list)
            fake_pred = self.discriminator(fake_list)
            # GP at the last level:
            gp = self.compute_gradient_penalty(true_high, fake_high)
            dis_loss = (fake_pred.mean() - real_pred.mean()) + 10.0 * gp
            dis_loss.backward()
            optimizer_dis.step()

        self.log('g_loss_wgan_gp/Train', gen_loss, prog_bar=True, on_epoch=True)
        self.log('d_loss_wgan_gp/Train', dis_loss, prog_bar=True, on_epoch=True)
        return fake_high

    def hinge_gan_step_multiscale(self, optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx):
        """One hinge-loss generator update and one discriminator update on multiscale inputs.

        Args:
            optimizer_gen: Generator optimizer exposing ``toggle_model()``.
            optimizer_dis: Discriminator optimizer exposing ``toggle_model()``.
            noise: Not used here; ``fake_list`` is supplied by the caller.
            real_list: Real batch at each resolution level.
            fake_list: Generated batch at each level, attached to the generator graph.
            batch_idx: Training batch index; images are logged every 8 batches.

        Returns:
            The detached last level of ``fake_list``.
        """
        with optimizer_gen.toggle_model():
            fake_high = fake_list[-1]
            true_high = real_list[-1]
            if batch_idx % 8 == 0:
                self.log_data(true_high, fake_high)
            gen_loss = -self.discriminator(fake_list).mean()
            gen_loss.backward()
            optimizer_gen.step()

        # Detach every level: the generator graph was freed by the backward
        # above, so the discriminator must not see the attached tensors again.
        fake_list = [fake.detach() for fake in fake_list]
        fake_high = fake_list[-1]
        with optimizer_dis.toggle_model():
            real_pred = self.discriminator(real_list)
            fake_pred = self.discriminator(fake_list)
            d_loss_real = torch.mean(F.relu(1 - real_pred))
            d_loss_fake = torch.mean(F.relu(1 + fake_pred))
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_dis.step()

        self.log('g_loss_hinge/Train', gen_loss, prog_bar=True, on_epoch=True)
        self.log('d_loss_hinge/Train', d_loss, prog_bar=True, on_epoch=True)
        return fake_high

    def r1_regularized_hinge_step_multiscale(self, optimizer_gen, optimizer_dis, noise, real_list, fake_list, batch_idx):
        """Hinge-loss GAN step with an R1 penalty, on multiscale inputs.

        The discriminator scores the whole list of resolution levels. The R1
        penalty is computed on the last real level. Logging matches the
        single-scale R1 step.

        Args:
            optimizer_gen: Generator optimizer exposing ``toggle_model()``.
            optimizer_dis: Discriminator optimizer exposing ``toggle_model()``.
            noise: Not used here; ``fake_list`` is supplied by the caller.
            real_list: Real batch at each resolution level.
            fake_list: Generated batch at each level, attached to the generator graph.
            batch_idx: Training batch index; images are logged every 8 batches.

        Returns:
            The detached last level of ``fake_list``.
        """
        with optimizer_gen.toggle_model():
            fake_high = fake_list[-1]
            true_high = real_list[-1]
            if batch_idx % 8 == 0:
                self.log_data(true_high, fake_high)
            gen_loss = -self.discriminator(fake_list).mean()
            gen_loss.backward()
            optimizer_gen.step()

        # Detach every level: the generator graph was freed by the backward
        # above, so the discriminator must not see the attached tensors again.
        fake_list = [fake.detach() for fake in fake_list]
        fake_high = fake_list[-1]
        with optimizer_dis.toggle_model():
            real_pred = self.discriminator(real_list)
            fake_pred = self.discriminator(fake_list)
            d_loss_real = torch.mean(F.relu(1 - real_pred))
            d_loss_fake = torch.mean(F.relu(1 + fake_pred))
            d_loss = d_loss_real + d_loss_fake
            r1 = self.r1_penalty(true_high) * 10.0
            total_loss = d_loss + r1
            total_loss.backward()
            optimizer_dis.step()

        self.log('g_loss_r1_hinge/Train', gen_loss, prog_bar=True, on_epoch=True)
        self.log('d_loss_r1_hinge/Train', d_loss, prog_bar=True, on_epoch=True)
        self.log('r1_penalty/Train', r1, prog_bar=True, on_epoch=True)
        return fake_high

    def validation_step(self, batch, batch_idx):
        """Accumulate Inception features for FID, then log FID and a sample grid.

        Real features come from the validation batches. Fake features come from
        a fixed bank of 10000 latent vectors (``self.noise_fixed``), so
        successive FID values are measured on the same latents.

        Once 10000 samples have been accumulated:
        - FID is computed and logged.
        - The feature buffers are released.
        - A 10x5 grid of fresh random samples is logged.

        Remaining batches of the same validation epoch are ignored.

        Args:
            batch: Real images for this step; assumed to be an ``(N, C, H, W)``
                tensor (TODO confirm against the DataModule).
            batch_idx: Index of the batch within the validation epoch.
        """
        num_fid_samples = 10000
        resize_transform = transforms.Resize((299, 299), antialias=True)
        batch_size = batch.shape[0]

        if batch_idx == 0:
            # Start a fresh accumulation for this validation epoch.
            self.val_features_real = []
            self.val_features_fake = []
            self.val_num_seen = 0

            if self.current_epoch == 0:
                real_preview = resize_transform(batch[:1])
                self.logger.experiment.log({
                    "Real": [wandb.Image(
                        real_preview[0].permute(1, 2, 0).detach().float().cpu().numpy(),
                        caption=" "
                    )]
                })

            # Create the fixed latent bank once and reuse it afterwards. It is
            # created lazily, not only at epoch 0, so validation also works when
            # the epoch-0 sanity check was skipped or training was resumed.
            if getattr(self, "noise_fixed", None) is None:
                self.noise_fixed = torch.randn(
                    (num_fid_samples, self.hparams.latent_dim), device=self.device
                )

        # Buffers are None once FID has been computed for this epoch.
        if getattr(self, "val_features_real", None) is None or self.val_num_seen >= num_fid_samples:
            return

        # Slice the noise bank by the number of samples seen so far rather
        # than batch_idx * batch_size, which is wrong if batch sizes vary.
        start = self.val_num_seen
        current_noise = self.noise_fixed[start:start + batch_size]

        with torch.no_grad():
            fake = self._generate_last_level(current_noise)
            real_features = inception_model(resize_transform(batch))
            fake_features = inception_model(resize_transform(fake))

        # Keep float32 copies on the CPU. Under bf16-mixed autocast the features
        # may come back as bfloat16, which is far too coarse for the covariance
        # estimates FID relies on.
        self.val_features_real.append(real_features.float().cpu())
        self.val_features_fake.append(fake_features.float().cpu())
        self.val_num_seen = start + batch_size

        if self.val_num_seen >= num_fid_samples:
            self._log_fid_from_features(num_fid_samples)
            self._log_validation_grid()

    def _generate_last_level(self, noise):
        """Run the generator; for multiscale models return only its last output level."""
        generated = self.generator(noise)
        return generated[-1] if self.hparams.use_multiscale else generated

    def _log_fid_from_features(self, num_fid_samples):
        """Compute FID from the accumulated feature buffers, log it, then free the buffers."""
        all_real_features = torch.cat(self.val_features_real, dim=0)[:num_fid_samples]
        all_fake_features = torch.cat(self.val_features_fake, dim=0)[:num_fid_samples]

        mu_real = all_real_features.mean(0)
        mu_fake = all_fake_features.mean(0)
        # torch.cov expects one variable per row and one observation per column.
        cov_real = torch.cov(all_real_features.permute(1, 0))
        cov_fake = torch.cov(all_fake_features.permute(1, 0))

        # Small diagonal jitter for numerical stability.
        eps = 1e-8
        cov_real += eps * torch.eye(cov_real.size(0), device=cov_real.device)
        cov_fake += eps * torch.eye(cov_fake.size(0), device=cov_fake.device)

        fid_value = calculate_fid(mu_real, cov_real, mu_fake, cov_fake)
        self.log('FID/Validation', fid_value, prog_bar=True, on_epoch=True)

        # Release the (large) feature buffers; None also marks "FID done" for this epoch.
        self.val_features_real = None
        self.val_features_fake = None
        torch.cuda.empty_cache()

    def _log_validation_grid(self, rows=10, cols=5):
        """Log a rows x cols grid of random samples, each resized to 64x64, to W&B."""
        noise = torch.randn(rows * cols, self.hparams.latent_dim, device=self.device)
        with torch.no_grad():
            images = self._generate_last_level(noise)
        images = F.interpolate(images, size=(64, 64))
        images_grid = torchvision.utils.make_grid(images, nrow=cols)
        self.logger.experiment.log({
            "Validation_panel": [wandb.Image(
                images_grid.permute(1, 2, 0).detach().float().cpu().numpy(),
                caption="All"
            )]
        })

    def configure_optimizers(self):
        """Create AdamW optimizers for the generator and the discriminator.

        The discriminator learning rate is halved when ``use_minibatch`` is
        set, except for the R1-GAN and LS-GAN variants, which keep the full
        rate. Each optimizer gets a ``LinearLR`` scheduler that decays its rate
        from 100% to 1% over 300000 scheduler steps, registered with
        ``interval="step"``.

        NOTE(review): the step functions call ``optimizer.step()`` themselves
        (manual optimization), and Lightning does not step schedulers
        automatically in that mode. Confirm that ``training_step`` calls
        ``scheduler.step()``.

        Returns:
            ``[generator_config, discriminator_config]`` in Lightning's
            optimizer / lr_scheduler dictionary format.
        """
        hp = self.hparams
        gen_lr = hp.learning_rate
        halve_disc_lr = hp.use_minibatch and hp.type_model not in ("R1-GAN", "LS-GAN")
        disc_lr = gen_lr * 0.5 if halve_disc_lr else gen_lr

        optimizer_gen = torch.optim.AdamW(self.generator.parameters(), lr=gen_lr)
        optimizer_disc = torch.optim.AdamW(self.discriminator.parameters(), lr=disc_lr)

        def with_linear_decay(optimizer):
            decay = torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=1.0,
                end_factor=0.01,
                total_iters=300000,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": decay, "interval": "step"},
            }

        return [with_linear_decay(optimizer_gen), with_linear_decay(optimizer_disc)]

# Let PyTorch use faster, reduced-precision internal kernels for float32
# matrix multiplications. 'medium' is the least precise of the settings;
# see the torch.set_float32_matmul_precision docs.
torch.set_float32_matmul_precision('medium')


# Helpers for the multiscale variant: turn a real image batch into a list of
# resolution levels, and concatenate real/fake levels batch-wise.
import torch.nn.functional as F

def tensor_to_multiscale(real, max_resolution=32, min_resolution=8):
    """
    Build a coarse-to-fine image pyramid from a batch of images.

    The batch is bilinearly resized to ``min_resolution`` and then to every
    doubling of it that does not exceed ``max_resolution``. With the defaults
    (8 and 32), an ``[N, C, 32, 32]`` input becomes the list
    ``[N, C, 8, 8], [N, C, 16, 16], [N, C, 32, 32]``. The last level equals
    ``max_resolution`` only when ``max_resolution / min_resolution`` is a
    power of two.

    Args:
        real: 4-D image tensor ``[N, C, H, W]``.
        max_resolution: Inclusive upper bound on the side length of a level.
        min_resolution: Side length of the first (coarsest) level; must be positive.

    Returns:
        List of square tensors ``[N, C, r, r]``, ordered from coarsest to
        finest. Empty if ``min_resolution > max_resolution``.

    Raises:
        ValueError: If ``min_resolution`` is not positive.
    """
    if min_resolution <= 0:
        raise ValueError(f"min_resolution must be positive, got {min_resolution}")

    images = []
    current_res = min_resolution
    while current_res <= max_resolution:
        images.append(
            F.interpolate(
                real,
                size=(current_res, current_res),
                mode='bilinear',
                align_corners=False,
            )
        )
        current_res *= 2
    return images

def combine_real_fake_for_kl(real_list, fake_list):
    """
    Concatenate real and fake batches level by level along the batch dimension.

    For each level ``i``, the result is ``torch.cat([real_list[i], fake_list[i]], dim=0)``.
    Real samples come first and fake samples second; for ``[N, C, H, W]``
    inputs the output is ``[2N, C, H, W]``.

    Args:
        real_list: Sequence of real tensors, one per resolution level.
        fake_list: Sequence of fake tensors with the same number of levels.

    Returns:
        List of concatenated tensors, one per level.

    Raises:
        ValueError: If the two inputs have a different number of levels.
            ``zip`` would otherwise silently drop the extra levels.
    """
    real_list = list(real_list)
    fake_list = list(fake_list)
    if len(real_list) != len(fake_list):
        raise ValueError(
            f"real_list has {len(real_list)} levels but fake_list has {len(fake_list)}"
        )
    return [torch.cat([r, f], dim=0) for r, f in zip(real_list, fake_list)]


# Running multiple seeds for each method
if __name__ == "__main__":
    # Read the API key from the environment instead of pasting it into the
    # source. With no key, wandb falls back to its usual credential lookup.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    # Common config
    Project_name = "KL-GAN CelebA Experiment"
    methods = ["KL-GAN", "R1-GAN", "LS-GAN", "WGAN-GP", "Hinge-GAN"]
    seeds = [1, 2, 3, 4, 5]

    for method in methods:
        for seed in seeds:
            for use_minibatch in [True, False]:
                for use_multiscale in [False]:
                    minibatch_suffix = "with_minibatch" if use_minibatch else "no_minibatch"
                    scale_suffix = "multiscale" if use_multiscale else "single_scale"
                    # Include the seed so the repetitions of one configuration
                    # get distinct W&B run names.
                    run_name = f"{method}_{minibatch_suffix}_{scale_suffix}_seed{seed}"
                    wandb_logger = WandbLogger(
                        name=run_name,
                        project=Project_name,
                        save_dir="./wandb_logs",
                        version=None,
                        reinit=True
                    )

                    progress_bar = RichProgressBar(
                        refresh_rate=20,
                        theme=RichProgressBarTheme(
                            description="green_yellow",
                            progress_bar="green1",
                            progress_bar_finished="green1",
                            progress_bar_pulse="#6206E0",
                            batch_progress="green_yellow",
                            time="grey82",
                            processing_speed="grey82",
                            metrics="grey82",
                        )
                    )

                    trainer = pl.Trainer(
                        accelerator="gpu",
                        devices="auto",
                        precision="bf16-mixed",
                        log_every_n_steps=40,
                        callbacks=[progress_bar],
                        logger=[wandb_logger],
                        max_epochs=300,
                        limit_train_batches=1.0,
                        check_val_every_n_epoch=50
                    )

                    model = GAN_Training(
                        learning_rate=0.00008,
                        batch_size=1024,
                        seed_value=seed,
                        type_model=method,
                        latent_dim=128,
                        dim=128,
                        use_minibatch=use_minibatch,
                        use_multiscale=use_multiscale,
                    )

                    datamodule = DataModule(
                        batch_size=model.hparams.batch_size,
                        val_batch_size=1024,
                        data_dir="./img_align_celeba/img_align_celeba"
                    )

                    try:
                        trainer.fit(model, datamodule)
                    finally:
                        # Always close the W&B run, even if training raised, so
                        # the next configuration does not log into a stale run.
                        wandb_logger.experiment.finish()

                    # Free this run's model/trainer before building the next one,
                    # so GPU memory does not accumulate across the sweep.
                    del trainer, model, datamodule
                    gc.collect()
                    torch.cuda.empty_cache()
