In [None]:
import os
"""
Core dependencies and libraries for the ExplanationGAN implementation.

This module imports required packages for:
- File operations (os, shutil)
- Numerical and data processing (numpy, pandas)
- Deep learning with PyTorch (torch, nn, optim)
- Data loading and batching (Dataset, DataLoader)
- Image processing and transformations (transforms, PIL)
- Pre-trained models (models)
- Visualization (matplotlib)
- Data preprocessing (LabelEncoder)

Dependencies:
    os: Operating system interfaces
    numpy: Numerical computing
    pandas: Data manipulation and analysis
    torch: PyTorch deep learning framework
    PIL: Python Imaging Library
    sklearn: Machine learning utilities
    matplotlib: Plotting library
    shutil: High-level file operations
    torchvision: Computer vision utilities
"""
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import shutil

In [None]:

class HAM10000Dataset(Dataset):
    """A PyTorch Dataset class for the HAM10000 skin lesion dataset.

    Reads image metadata from a CSV file, label-encodes the ``dx`` column and
    loads the matching ``<image_id>.jpg`` file from the first directory in
    ``img_dirs`` that contains it.

    Args:
        csv_file (str): Path to the CSV file containing image metadata and labels.
            Must provide ``image_id`` and ``dx`` columns.
        img_dirs (list): List of directory paths searched (in order) for the image files.
        transform (callable, optional): Transform applied to the loaded PIL RGB image.
            Defaults to None.
        device (str, optional): Stored on the instance for callers; this class does not
            move any tensors to it. Defaults to 'cuda'.

    Attributes:
        data (pandas.DataFrame): The loaded CSV data plus an ``encoded_label`` column.
        img_dirs (list): List of directories containing image files.
        transform (callable): Transform to be applied to images.
        device (str): Device value passed at construction (not used internally).
        label_encoder (LabelEncoder): Scikit-learn label encoder fitted on ``dx``.

    Returns (from ``__getitem__``):
        tuple: A tuple containing:
            - image: The (optionally transformed) RGB image.
            - label (int): The encoded label.

    Raises:
        FileNotFoundError: If an image file cannot be found in any of the provided directories.
    """
    def __init__(self, csv_file, img_dirs, transform=None, device='cuda'):
        self.data = pd.read_csv(csv_file)
        self.img_dirs = img_dirs
        self.transform = transform
        # Kept for API compatibility; samples are returned on the CPU.
        self.device = device

        # Encode string diagnoses ('dx') into integer class ids.
        self.label_encoder = LabelEncoder()
        self.data['encoded_label'] = self.label_encoder.fit_transform(self.data['dx'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Look the row up once instead of re-indexing the frame for each field.
        row = self.data.iloc[idx]
        img_name = row['image_id'] + '.jpg'
        for img_dir in self.img_dirs:
            img_path = os.path.join(img_dir, img_name)
            if os.path.exists(img_path):
                # Image.open keeps the file handle open until the image is closed;
                # convert() produces an independent copy, so close the source right away
                # to avoid leaking file descriptors across many DataLoader iterations.
                with Image.open(img_path) as source_image:
                    image = source_image.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                # Return a plain Python int, as documented.
                return image, int(row['encoded_label'])
        raise FileNotFoundError(f"Image {img_name} not found in directories {self.img_dirs}")


In [None]:

class EnhancedSLEBlock(nn.Module):
    """Skip-Layer Excitation (SLE) style block that gates one feature map with another.

    ``x`` is average-pooled to a per-channel descriptor, passed through two 1x1
    convolutions (ReLU in between) and a sigmoid, and the resulting per-channel
    gates rescale ``skip_x``. A style branch computed from ``skip_x``
    (1x1 conv -> InstanceNorm -> ReLU) is then added, scaled by the learnable
    scalar ``gamma`` and shifted by the learnable scalar ``beta``. Both scalars
    are initialised to zero, so a freshly built block only applies the gating.

    Args:
        in_channels (int): Channel count of ``x`` (the gating source).
        out_channels (int): Channel count of ``skip_x`` and of the result.

    Forward Args:
        x (torch.Tensor): Gating source with ``in_channels`` channels.
        skip_x (torch.Tensor): Feature map to modulate, with ``out_channels`` channels.

    Returns:
        torch.Tensor: Tensor with the same shape as ``skip_x``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodule names and creation order are deliberately stable: they fix the
        # state_dict keys and the order in which seeded initialisation draws numbers.
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Gating branch: pooled x (in_channels) -> per-channel gates (out_channels).
        self.content_fc1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.content_fc2 = nn.Conv2d(out_channels, out_channels, kernel_size=1)

        # Style branch: works on skip_x at its own resolution.
        style_layers = [
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.style_modulation = nn.Sequential(*style_layers)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x, skip_x):
        pooled = self.global_pool(x)
        gate = torch.sigmoid(self.content_fc2(F.relu(self.content_fc1(pooled))))

        style = self.style_modulation(skip_x)

        gated = skip_x * gate
        return gated + self.gamma * style + self.beta


In [None]:

class EnhancedFASTGANGenerator(nn.Module):
    """Enhanced FASTGAN Generator module.

    Maps a latent tensor to an RGB image through a stack of transposed
    convolutions with BatchNorm + ReLU. Two Skip-Layer Excitation (SLE) blocks
    gate intermediate feature maps using the output of the preceding stage.

    Args:
        latent_dim (int, optional): Channel count of the latent input ``z``. Defaults to 256.
        ngf (int, optional): Base filter count; the stages use ngf*8, ngf*4, ngf*2 and
            ngf channels. Defaults to 32.
        output_size (int, optional): Side length of the returned images. The convolutional
            stack natively yields 64x64 images for a (N, latent_dim, 1, 1) input; any other
            size is produced by bilinearly resizing that native output. Defaults to 64.

    Architecture (for z of shape (N, latent_dim, 1, 1)):
        - initial: ConvTranspose2d(k=4, s=1, p=0) -> (N, ngf*8, 4, 4)
        - layer1 -> (N, ngf*4, 8, 8); sle1 gates it using the ``initial`` output
        - layer2 -> (N, ngf*2, 16, 16); sle2 gates it using the sle1 output
        - layer3 -> (N, ngf, 32, 32)
        - layer4 -> (N, 3, 64, 64) followed by Tanh

    Returns:
        torch.Tensor: Images of shape (N, 3, output_size, output_size) with values in
        [-1, 1] (bilinear resizing is a convex combination, so it keeps the Tanh range).

    Raises:
        ValueError: If ``output_size`` is not a positive integer.
    """
    def __init__(self, latent_dim=256, ngf=32, output_size=64):
        super(EnhancedFASTGANGenerator, self).__init__()
        if output_size <= 0 or int(output_size) != output_size:
            raise ValueError(f"output_size must be a positive integer, got {output_size!r}")
        self.output_size = int(output_size)

        # Module definitions and their order are unchanged, so state_dict keys and
        # seeded initialisation match previously trained checkpoints.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True)
        )

        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True)
        )

        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )

        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )

        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1),
            nn.Tanh()
        )

        # sle1: gate source = initial output (ngf*8 ch), modulated map = layer1 output (ngf*4 ch)
        self.sle1 = EnhancedSLEBlock(ngf * 8, ngf * 4)
        # sle2: gate source = sle1 output (ngf*4 ch), modulated map = layer2 output (ngf*2 ch)
        self.sle2 = EnhancedSLEBlock(ngf * 4, ngf * 2)

    def forward(self, z):
        x0 = self.initial(z)
        x1 = self.layer1(x0)
        x1_sle = self.sle1(x0, x1)
        x2 = self.layer2(x1_sle)
        x2_sle = self.sle2(x1_sle, x2)
        x3 = self.layer3(x2_sle)
        x4 = self.layer4(x3)
        # Honour output_size: resize only when the native resolution differs
        # (a no-op for the default 64 with 1x1 latents).
        target = (self.output_size, self.output_size)
        if tuple(x4.shape[-2:]) != target:
            x4 = F.interpolate(x4, size=target, mode='bilinear', align_corners=False)
        return x4


In [None]:

class EnhancedFASTGANDiscriminator(nn.Module):
    """Enhanced FASTGAN Discriminator Network.

    A shared convolutional feature extractor feeds two heads: a real/fake
    scoring head and a decoder that reconstructs the input image from the
    features (self-supervision).

    Args:
        ndf (int, optional): Number of discriminator filters in the first conv layer.
            Defaults to 64.
        input_size (int, optional): Nominal input image size. Stored only; the scoring
            head uses adaptive pooling. The decoder upsamples by 16, so the reconstruction
            matches the input size when it is a multiple of 16. Defaults to 64.
        use_sigmoid (bool, optional): Apply a Sigmoid to the score. Defaults to False,
            because the training step in this notebook uses hinge loss
            (``relu(1 - D(real)) + relu(1 + D(fake))``), which needs unbounded scores:
            with scores confined to (0, 1) neither ``relu`` ever clips and the hinge
            margin is meaningless. Set True only if a probability is required.

    Attributes:
        features: Shared convolutional feature extractor (4 stride-2 convolutions).
        discriminator: Scoring head (global average pool -> 1x1 conv -> flatten).
        decoder: Transposed-conv decoder for reconstruction, ending in Tanh.

    Returns:
        tuple: A tuple containing:
            - validity (torch.Tensor): Score of shape (N, 1); a raw score unless
              ``use_sigmoid`` is True.
            - reconstruction (torch.Tensor): Reconstructed image in [-1, 1].

    Example:
        >>> discriminator = EnhancedFASTGANDiscriminator(ndf=64, input_size=64)
        >>> validity, reconstruction = discriminator(images)
    """
    def __init__(self, ndf=64, input_size=64, use_sigmoid=False):
        super(EnhancedFASTGANDiscriminator, self).__init__()
        self.input_size = input_size

        # Shared feature extractor: each stride-2 conv halves the spatial size.
        self.features = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2)
        )

        # Scoring head. The Sigmoid (parameter-free) is optional; omitting it keeps the
        # conv at index 1, so state_dict keys are the same either way.
        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ndf * 8, 1, 1),
            nn.Flatten(),
        ]
        if use_sigmoid:
            head_layers.append(nn.Sigmoid())
        self.discriminator = nn.Sequential(*head_layers)

        # Decoder for self-supervision: each stride-2 transposed conv doubles the size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(ndf * 8, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 4, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 2, ndf, 4, 2, 1),
            nn.BatchNorm2d(ndf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        features = self.features(x)
        validity = self.discriminator(features)
        reconstruction = self.decoder(features)
        return validity, reconstruction


In [None]:

class SyntheticImageClassifier:
    """Filters synthetic images by checking whether two CNN classifiers agree.

    Holds an EfficientNetV2-S and a ShuffleNetV2 (x1.0) backbone with ImageNet
    weights. Each backbone's final Linear layer is replaced with a new
    ``num_classes`` head. Those heads are freshly initialised here.
    NOTE(review): unless they are fine-tuned elsewhere, the predictions
    (and hence the agreement mask) are effectively random.

    Attributes:
        device (str): The device the models live on ('cuda' or 'cpu').
        efficientnet (torch.nn.Module): EfficientNetV2-S model instance (eval mode).
        shufflenet (torch.nn.Module): ShuffleNetV2 model instance (eval mode).
        transform (torchvision.transforms.Compose): PIL-image preprocessing pipeline
            (resize, to-tensor, ImageNet normalisation). Not used by
            ``classify_synthetic_images``, which works on tensors directly.

    Args:
        num_classes (int): Number of output classes for classification.
        device (str, optional): Device to run the models on. Defaults to 'cuda'.

    Methods:
        classify_synthetic_images(synthetic_images): Returns a boolean mask that is True
            where both models predict the same class.
    """
    def __init__(self, num_classes, device='cuda'):
        self.device = device

        # EfficientNetV2-S. IMAGENET1K_V1 is exactly what the deprecated pretrained=True
        # loaded; the weights enum exists in every torchvision that has this model.
        self.efficientnet = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        self.efficientnet.classifier[1] = nn.Linear(self.efficientnet.classifier[1].in_features, num_classes)
        # eval(): Dropout/BatchNorm must be deterministic and must not update
        # running statistics while classifying.
        self.efficientnet = self.efficientnet.to(device).eval()

        # ShuffleNetV2 x1.0 (same weights as pretrained=True).
        self.shufflenet = models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)
        self.shufflenet.fc = nn.Linear(self.shufflenet.fc.in_features, num_classes)
        self.shufflenet = self.shufflenet.to(device).eval()

        # Transformation for PIL input images (kept for external callers).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def classify_synthetic_images(self, synthetic_images):
        """Return a bool mask (N,) that is True where both classifiers agree.

        Args:
            synthetic_images (torch.Tensor): Batch of shape (N, 3, H, W) with values in
                [-1, 1], i.e. the Tanh output of the generator.

        Returns:
            torch.Tensor: Boolean tensor of shape (N,) on ``self.device``.
        """
        images = synthetic_images.to(self.device)
        images = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)

        # Map [-1, 1] -> [0, 1] per pixel (the same `* 0.5 + 0.5` convention used when the
        # images are saved). The previous batch-global min/max scaling made each image's
        # input depend on the rest of the batch and divided by zero on constant batches.
        images = (images * 0.5 + 0.5).clamp(0.0, 1.0)

        # ImageNet normalisation, applied channel-wise.
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device, dtype=images.dtype).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device, dtype=images.dtype).view(1, 3, 1, 1)
        images = (images - mean) / std

        # Force eval mode for inference (no_grad alone does not disable Dropout or stop
        # BatchNorm from updating running stats), then restore the caller's modes in case
        # the heads are being fine-tuned elsewhere.
        backbones = (self.efficientnet, self.shufflenet)
        previous_modes = [backbone.training for backbone in backbones]
        try:
            for backbone in backbones:
                backbone.eval()
            with torch.no_grad():
                efficientnet_classes = torch.argmax(self.efficientnet(images), dim=1)
                shufflenet_classes = torch.argmax(self.shufflenet(images), dim=1)
        finally:
            for backbone, was_training in zip(backbones, previous_modes):
                backbone.train(was_training)

        return efficientnet_classes == shufflenet_classes


In [None]:
"""Performs a single training step for both Generator and Discriminator in a GAN with enhanced training features.
                    This function implements one training iteration including:
                    - Training the discriminator with real and fake images using hinge loss
                    - Self-supervision through reconstruction loss
                    - Training the generator with hinge loss
                    Args:
                        real_imgs (torch.Tensor): Batch of real images
                        generator (nn.Module): The generator model
                        discriminator (nn.Module): The discriminator model
                        g_optimizer (torch.optim.Optimizer): Optimizer for the generator
                        d_optimizer (torch.optim.Optimizer): Optimizer for the discriminator
                        device (torch.device): Device to run the computation on
                        lambda_rec (float, optional): Weight for reconstruction loss. Defaults to 10.0
                    Returns:
                        tuple:                   
                            - dict: Dictionary containing various loss values:
                                - 'd_loss': Total discriminator loss
                                - 'd_loss_adv': Adversarial component of discriminator loss
                                - 'd_loss_rec': Reconstruction component of discriminator loss
                                - 'g_loss': Generator loss
                            - torch.Tensor: Batch of generated fake images
                    """
def enhanced_train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer, 
                       device, lambda_rec=10.0):
    batch_size = real_imgs.size(0)
    
    # Train Discriminator
    d_optimizer.zero_grad()
    
    # Real images
    real_validity, real_reconstruction = discriminator(real_imgs)
    
    # Generate fake images
    z = torch.randn(batch_size, 256, 1, 1, device=device)
    fake_imgs = generator(z)
    fake_validity, _ = discriminator(fake_imgs.detach())
    
    # Hinge loss
    d_loss_real = torch.mean(F.relu(1.0 - real_validity))
    d_loss_fake = torch.mean(F.relu(1.0 + fake_validity))
    d_loss_adv = d_loss_real + d_loss_fake
    
    # Reconstruction loss for self-supervision
    d_loss_rec = F.mse_loss(real_reconstruction, real_imgs)
    
    # Total discriminator loss
    d_loss = d_loss_adv + lambda_rec * d_loss_rec
    
    d_loss.backward()
    d_optimizer.step()
    
    # Train Generator
    g_optimizer.zero_grad()
    
    fake_validity, _ = discriminator(fake_imgs)
    g_loss = -torch.mean(fake_validity)  # Hinge loss for generator
    
    g_loss.backward()
    g_optimizer.step()
    
    return {
        'd_loss': d_loss.item(),
        'd_loss_adv': d_loss_adv.item(),
        'd_loss_rec': d_loss_rec.item(),
        'g_loss': g_loss.item()
    }, fake_imgs


In [None]:

def plot_data_distribution_comparison(original_csv, synthetic_images_dir):
    """
    Plots and compares the distribution of samples between original and synthetic datasets.

    Creates a side-by-side bar chart of samples per class in the original HAM10000
    metadata versus image files found in per-class folders, and saves it as
    'dataset_distribution_comparison.png' in the current directory.

    Parameters
    ----------
    original_csv : str
        Path to the CSV file containing metadata of the original HAM10000 dataset.
        Must include a 'dx' column with class labels.
    synthetic_images_dir : str
        Directory containing one subdirectory per class, named after the 'dx' label.

    Returns
    -------
    None

    Notes
    -----
    - Classes are ordered by original frequency (descending), and both bar series
      are aligned to that same order so each bar sits over its own class label.
    - Files ending in .png or .jpg are counted; a missing class folder counts as 0.
    - Numerical annotations are drawn above each bar.
    - NOTE(review): generate_synthetic_images in this notebook writes to
      'class_<idx>' folders, which this function does not count. Confirm the intended layout.
    """
    metadata = pd.read_csv(original_csv)
    # value_counts() sorts by frequency; this order drives both series and the x labels.
    original_class_counts = metadata['dx'].value_counts()
    class_names = list(original_class_counts.index)

    def _count_images(class_name):
        # Number of .png/.jpg files in the class folder (0 if it does not exist).
        class_dir = os.path.join(synthetic_images_dir, class_name)
        if not os.path.isdir(class_dir):
            return 0
        return sum(1 for f in os.listdir(class_dir) if f.endswith(('.png', '.jpg')))

    # Built in the same order as original_class_counts. Previously this followed the
    # alphabetical LabelEncoder order, so synthetic bars were drawn over the wrong labels.
    synthetic_class_counts = pd.Series({name: _count_images(name) for name in class_names}, dtype=int)

    fig, ax = plt.subplots(figsize=(15, 6))
    x = np.arange(len(class_names))
    width = 0.4

    ax.bar(x - width / 2, original_class_counts.values, width, label='Original Dataset', color='blue', alpha=0.7)
    ax.bar(x + width / 2, synthetic_class_counts.values, width, label='Synthetic Images', color='orange', alpha=0.7)

    ax.set_title('Comparison of Original HAM10000 Dataset and Synthetic Images', fontsize=16)
    ax.set_xlabel('Skin Lesion Type', fontsize=14)
    ax.set_ylabel('Number of Samples', fontsize=14)
    ax.set_xticks(x)
    # Vertical labels should be centred on their tick (ha='right' shifted them left).
    ax.set_xticklabels(class_names, rotation=90, ha='center')
    ax.legend()

    for i, (orig, synth) in enumerate(zip(original_class_counts.values, synthetic_class_counts.values)):
        ax.text(i - width / 2, orig + 50, str(int(orig)), ha='center', va='bottom', fontsize=8)
        ax.text(i + width / 2, synth + 50, str(int(synth)), ha='center', va='bottom', fontsize=8)

    fig.tight_layout()
    fig.savefig('dataset_distribution_comparison.png')
    plt.close(fig)


In [None]:

def copy_original_images_by_class(csv_file, img_dirs, output_base_dir='synthetic_images'):
    """
    Copy the original dataset images into one sub-folder per diagnosis class.

    For every distinct value of the CSV's 'dx' column a folder
    ``<output_base_dir>/<dx>`` is created. Each row's ``<image_id>.jpg`` is looked
    up in ``img_dirs`` in order, and the first match is copied (with metadata, via
    ``shutil.copy2``) into its class folder under the same filename. A source path
    that has already been copied is never copied again.

    Args:
        csv_file (str): Path to a CSV with 'dx' (diagnosis) and 'image_id' columns.
        img_dirs (list): Source directories, searched in the given order.
        output_base_dir (str, optional): Destination root. Defaults to 'synthetic_images'.

    Prints:
        - The destination directory once copying has finished.
        - The number of distinct source files copied.

    Note:
        - Destination folders are created as needed.
        - Images whose file is not found in any source directory are silently skipped.
    """
    metadata = pd.read_csv(csv_file)
    os.makedirs(output_base_dir, exist_ok=True)
    already_copied = set()

    for diagnosis in metadata['dx'].unique():
        target_dir = os.path.join(output_base_dir, diagnosis)
        os.makedirs(target_dir, exist_ok=True)

        for image_id in metadata.loc[metadata['dx'] == diagnosis, 'image_id']:
            filename = image_id + '.jpg'
            candidates = (os.path.join(folder, filename) for folder in img_dirs)
            # First directory (in img_dirs order) that actually holds the file.
            source = next((path for path in candidates if os.path.exists(path)), None)
            if source is None or source in already_copied:
                continue
            shutil.copy2(source, os.path.join(target_dir, filename))
            already_copied.add(source)

    print(f"Original images copied to {output_base_dir}")
    print(f"Total unique images copied: {len(already_copied)}")


In [None]:
"""
                        Trains an enhanced FastGAN model with both generator and discriminator networks.
                        Parameters:
                            generator (nn.Module): Generator neural network model
                            discriminator (nn.Module): Discriminator neural network model 
                            dataloader (DataLoader): DataLoader containing the training data
                            num_epochs (int): Number of training epochs
                            device (str): Device to run training on ('cuda' or 'cpu'), defaults to 'cuda'
                            lambda_rec (float): Weight for reconstruction loss, defaults to 10.0
                            save_interval (int): Number of batches between saving sample images, defaults to 100
                        Returns:
                            tuple: Trained (generator, discriminator) models
                        The function performs the following:
                        - Sets up Adam optimizers for both networks
                        - Creates a directory for saving training progress images
                        - For each epoch:
                            - Processes batches from the dataloader
                            - Performs training step using enhanced_train_step()
                            - Prints loss metrics every 100 batches
                            - Saves sample generated images at specified intervals
                        """
def train_enhanced_fastgan(generator, discriminator, dataloader, num_epochs, device='cuda', lambda_rec=10.0, save_interval=100):
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    os.makedirs('training_progress', exist_ok=True)
    
    for epoch in range(num_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            
            losses, fake_imgs = enhanced_train_step(
                real_imgs, generator, discriminator,
                g_optimizer, d_optimizer, device, lambda_rec
            )
            
            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Batch [{i}], '
                      f'D_loss: {losses["d_loss"]:.4f}, '
                      f'D_adv: {losses["d_loss_adv"]:.4f}, '
                      f'D_rec: {losses["d_loss_rec"]:.4f}, '
                      f'G_loss: {losses["g_loss"]:.4f}')
                
                # Save sample images
                if i % save_interval == 0:
                    save_image(fake_imgs[:16] * 0.5 + 0.5,
                             f'training_progress/epoch_{epoch}_batch_{i}.png',
                             nrow=4, normalize=False)
    
    return generator, discriminator


In [None]:
"""
                            Generates and saves synthetic images using a trained generator and classifier.
                            This function generates synthetic images for each class using a trained generator,
                            filters them using a classifier, and saves the valid images to specified directories.
                            Args:
                                generator (torch.nn.Module): Trained generator model
                                classifier (torch.nn.Module): Trained classifier model used to filter generated images
                                num_classes (int): Number of classes to generate images for
                                num_images_per_class (int): Number of images to generate per class
                                device (str, optional): Device to run the models on. Defaults to 'cuda'
                                batch_size (int, optional): Batch size for image generation. Defaults to 64
                                output_dir (str, optional): Base directory to save generated images. Defaults to 'synthetic_images'
                            Returns:
                                None
                            Notes:
                                - Creates directories for each class under output_dir
                                - Generated images are saved as PNG files
                                - Images are normalized from [-1,1] to [0,1] range before saving
                                - Progress is printed during generation
                            """
def generate_synthetic_images(generator, classifier, num_classes, num_images_per_class,device='cuda', batch_size=64, output_dir='synthetic_images'):
    os.makedirs(output_dir, exist_ok=True)
    generator.eval()
    
    with torch.no_grad():
        for class_idx in range(num_classes):
            class_dir = os.path.join(output_dir, f'class_{class_idx}')
            os.makedirs(class_dir, exist_ok=True)
            
            num_generated = 0
            while num_generated < num_images_per_class:
                # Generate images
                z = torch.randn(batch_size, 256, 1, 1, device=device)
                fake_imgs = generator(z)
                
                # Filter images using classifier
                valid_mask = classifier.classify_synthetic_images(fake_imgs)
                valid_images = fake_imgs[valid_mask]
                
                # Save valid images
                for idx, img in enumerate(valid_images):
                    if num_generated >= num_images_per_class:
                        break
                    save_image(img * 0.5 + 0.5,
                             os.path.join(class_dir, f'synthetic_{num_generated}.png'))
                    num_generated += 1
                
                print(f'Class {class_idx}: Generated {num_generated}/{num_images_per_class} images')


In [None]:

def main():
    """
    Main function for training and evaluating an Enhanced FASTGAN model on the HAM10000 skin cancer dataset.
    This function performs the following key operations:
    1. Sets up the device (CPU/GPU) and random seeds for reproducibility
    2. Loads and preprocesses the HAM10000 dataset
    3. Initializes the Generator, Discriminator, and Classifier models
    4. Trains the GAN model
    5. Generates synthetic images for each class
    6. Plots distribution comparison between real and synthetic data
    The function uses an Enhanced FASTGAN architecture specifically adapted for
    generating synthetic skin cancer images across multiple classes.

    Shared settings (image size, latent size, batch size, output directory, data
    root) are defined once as locals. This keeps the transform, the models and
    the generation/plotting steps in agreement.

    Dependencies:
        - torch
        - numpy
        - torchvision
        - Custom modules (HAM10000Dataset, EnhancedFASTGANGenerator,
          EnhancedFASTGANDiscriminator, SyntheticImageClassifier)
    Returns:
        None
    Example:
        >>> main()
        Using device: cuda
        Number of classes: 7
        Starting training...
        Training and generation complete!
    """
    # Set device and random seeds
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    torch.manual_seed(42)
    np.random.seed(42)

    # Dataset parameters
    data_root = '/kaggle/input/skin-cancer-mnist-ham10000'
    csv_file = os.path.join(data_root, 'HAM10000_metadata.csv')
    img_dirs = [
        os.path.join(data_root, 'HAM10000_images_part_1'),
        os.path.join(data_root, 'HAM10000_images_part_2'),
    ]

    # Shared hyper-parameters. Each is used in more than one place below.
    image_size = 64           # transform resize, generator output, discriminator input
    latent_dim = 256          # NOTE(review): generate_synthetic_images draws its own latents; keep sizes in sync
    batch_size = 64
    num_epochs = 100
    num_images_per_class = 1000  # Adjust as needed
    output_dir = 'synthetic_images'

    # Data preprocessing: resize, then scale pixels to [-1, 1] (mean 0.5, std 0.5 per channel)
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Initialize dataset and dataloader
    dataset = HAM10000Dataset(csv_file, img_dirs, transform=transform, device=device)
    num_classes = len(dataset.label_encoder.classes_)
    print(f"Number of classes: {num_classes}")
    print("Class labels:", dataset.label_encoder.classes_)

    # NOTE(review): the dataset receives `device`. If __getitem__ moves samples to
    # CUDA, doing so inside forked DataLoader workers (num_workers > 0) fails.
    # Confirm; otherwise use num_workers=0 or move batches to the device in the training loop.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Initialize models
    generator = EnhancedFASTGANGenerator(latent_dim=latent_dim, output_size=image_size).to(device)
    discriminator = EnhancedFASTGANDiscriminator(input_size=image_size).to(device)
    # NOTE(review): the classifier is not trained in this function. Confirm that
    # SyntheticImageClassifier loads weights trained on the original images, since
    # the filtering step depends on it.
    classifier = SyntheticImageClassifier(num_classes=num_classes, device=device)

    # Train the model
    print("Starting training...")
    generator, discriminator = train_enhanced_fastgan(
        generator, discriminator, dataloader,
        num_epochs=num_epochs, device=device
    )

    # Generate synthetic images for each class
    print("Generating synthetic images...")
    generate_synthetic_images(
        generator, classifier,
        num_classes=num_classes,
        num_images_per_class=num_images_per_class,
        device=device,
        batch_size=batch_size,
        output_dir=output_dir
    )

    # Plot distribution comparison
    plot_data_distribution_comparison(csv_file, output_dir)

    print("Training and generation complete!")

if __name__ == "__main__":
    main()

1. Architecture Components:
- Generator: Uses a single convolution layer per resolution with a restricted number of channels
- Skip-Layer Excitation (SLE) module for gradient flow enhancement
- Self-supervised discriminator with small decoders
- Two CNN classifiers (EfficientNetV2 and ShuffleNetV2) for synthetic image filtering

2. Key Features:
- SLE module: 
  - Implements channel-wise multiplications
  - Creates skip-connections between distant resolutions
  - Handles content/style attribute disentanglement

3. Training Process:
- Uses the hinge version of the adversarial loss
- Iterative training between discriminator and generator
- Discriminator regularization through auto-encoding
- Two-stage filtering:
  1. Generate synthetic images from original medical data
  2. Filter generations through dual CNN validation

4. Data Flow:
```
Original Images → FASTGAN Generator → Synthetic Images → Dual CNN Filtering → Filtered Dataset + Original Images → Few-shot Classifier
```


Title: Enhanced FASTGAN Architecture for Synthetic Medical Image Generation 

Abstract:
This implementation presents an enhanced FASTGAN architecture for generating synthetic medical images from the HAM10000 skin lesion dataset. It is based on FastGAN (Liu et al., 2021, "Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis"). The system incorporates a novel dual-classifier validation mechanism and self-supervised learning components to ensure high-quality, class-consistent synthetic image generation.

1. Architecture Overview

The system consists of three primary components:

a) Generator Architecture
- Implements an enhanced FASTGAN generator with a latent dimension of 256
- Utilizes transposed convolutions for upsampling
- Incorporates Enhanced Skip-Layer Excitation (SLE) blocks for improved feature propagation
- Employs batch normalization and ReLU activation functions

b) Discriminator Architecture
- Features a shared feature extractor backbone
- Implements dual-head architecture:
  * Classification head for real/fake discrimination
  * Decoder head for self-supervised reconstruction
- Utilizes hinge loss for adversarial training

c) Synthetic Image Classifier
- Employs ensemble approach with EfficientNetV2 and ShuffleNetV2
- Provides validation through classification agreement
- Performs image normalization and size standardization

2. Dataset Implementation

The HAM10000Dataset class implements a custom PyTorch Dataset with the following features:
- Supports multiple image directory sources
- Performs label encoding for categorical diagnoses
- Implements on-the-fly image transformation
- Handles missing image scenarios with appropriate error handling

3. Training Methodology

The training process incorporates several key components:

a) Enhanced Training Step
```python
d_loss = d_loss_adv + lambda_rec * d_loss_rec
```
- Combines adversarial and reconstruction losses for the discriminator
- Implements hinge loss for improved stability
- Uses self-supervised reconstruction for feature learning

b) Generator Training
- Utilizes adversarial hinge loss for generator optimization
- Implements adaptive learning rates via the Adam optimizer
- Includes periodic progress visualization

4. Synthetic Image Generation Process

The generation pipeline includes:
- Batch-wise image generation
- Multi-classifier validation filtering
- Class-specific output organization
- Comparison of the synthetic class distribution with the original dataset

5. Quality Control Mechanisms

Several quality control measures are implemented:
- Dual-model classification verification
- Image normalization and standardization
- Distribution comparison visualization
- Automated storage and organization of synthetic images

6. Key Innovations

The implementation includes several novel components:
- Enhanced SLE blocks for improved feature propagation
- Dual-classifier validation mechanism
- Self-supervised reconstruction in discriminator
- Distribution-aware generation process

7. Experimental Setup

The main function orchestrates the following workflow:
- Device and seed initialization
- Dataset preparation and loading
- Model initialization and training
- Synthetic image generation and validation
- Distribution analysis and visualization

8. Performance Monitoring

The system includes comprehensive monitoring:
- Loss tracking for generator and discriminator
- Image quality assessment through periodic sampling
- Class distribution visualization
- Generation success rate tracking

Conclusion:
This implementation provides a robust framework for generating synthetic medical images with built-in validation mechanisms. The architecture's emphasis on quality control and class consistency makes it particularly suitable for medical image synthesis applications.

Future Work:
Potential improvements could include:
- Implementation of additional validation metrics
- Integration of domain-specific medical image quality assessments
- Extension to support multi-modal medical imaging data
- Enhancement of the classification validation mechanism

The approach integrates FASTGAN to generate additional medical images from the available training data. Using two CNNs, it identifies synthetic images prone to misclassification and removes them from the dataset. This systematic elimination aims to improve the accuracy and reliability of the classification task. It also showcases the potential of few-shot learning (FSL) techniques in medical image classification.
This study employs FASTGAN [4] for generating synthetic images in the training dataset. FASTGAN employs a single convolution layer per resolution with restricted channels, 
yielding a smaller and quicker-to-train model. Moreover, it integrates the Skip-Layer Excitation (SLE) module, enhancing gradient fow by modifying skip-connections for efcient 
gradient signal propagation across resolutions. SLE utilizes channel-wise multiplications 
and extends skip-connections between distant resolutions to improve gradient fow without notable computational overhead. While resembling the Squeeze-and-Excitation module, SLE operates between distant feature-maps, aiding gradient fow and channel-wise 
feature recalibration essential for disentangling content and style attributes in generated 
images. FASTGAN also incorporates a self-supervised discriminator trained with small 
decoders, enhancing image feature extraction via auto-encoding. This approach improves 
the discriminator’s capacity to extract comprehensive representations from inputs, thereby 
enhancing model robustness and synthesis quality. This approach maintains a pure GAN 
framework, using auto-encoding solely for discriminator regularization. Additionally, the 
method employs the hinge version of the adversarial loss for iterative training of the discriminator and generator. Overall, these methodological advancements contribute to more 
efcient and efective GAN training, which has implications for various image synthesis 
tasks.
Figure 3 illustrates the integration of the GAN with the few-shot classifier network. The paper first generates synthetic medical images from the original images. Some synthetic images may not be completely realistic or may be assigned to the wrong category. Such inaccuracies would normally require manual examination and correction by domain experts.

To avoid the need for expert intervention, the study employs two CNNs, EfficientNetV2 and ShuffleNetV2, trained on the original medical images to classify the synthetic images. Only the synthetic images that these CNNs classify correctly are used. These mutually agreed-upon images, together with the original medical images, form the training set for the Few-Shot Learning (FSL) classifier. This approach ensures that only high-quality synthetic images enter the training dataset, thereby enhancing the performance and reliability of the FSL classifier.

By leveraging pretrained CNNs, the study automates the classification process. This removes the need for manual intervention and allows synthetic images to be integrated seamlessly into the training pipeline.