In [None]:
import os
"""
Initial imports for a GAN (Generative Adversarial Network) implementation.

This module imports necessary libraries and frameworks for implementing a GAN:

Core Libraries:
    - os: Operating system interface
    - numpy: Numerical computing
    - pandas: Data manipulation and analysis

PyTorch Components:
    - torch: Main PyTorch library
    - torch.nn: Neural network modules
    - torch.nn.functional: Neural network functions
    - torch.optim: Optimization algorithms
    - torch.utils.data: Data loading utilities

Computer Vision Tools:
    - torchvision.transforms: Image transformation utilities
    - torchvision.models: Pre-trained models
    - torchvision.utils: Image handling utilities
    - PIL: Python Imaging Library

Data Processing & Visualization:
    - sklearn.preprocessing: Data preprocessing tools
    - matplotlib.pyplot: Plotting library

The module sets up the foundation for building and training a GAN model
with image processing capabilities.
"""
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import shutil

In [None]:
class HAM10000Dataset(Dataset):
    """Map-style dataset over the HAM10000 skin-lesion images.

    Each row of the metadata CSV describes one image. ``__getitem__`` looks the
    image up as ``<image_id>.jpg`` in ``img_dirs`` (in order) and returns it
    together with its integer-encoded ``dx`` label.

    Args:
        csv_file (str): Path to the metadata CSV; must contain ``image_id`` and
            ``dx`` columns.
        img_dirs (list): Directories searched, in order, for the image files.
        transform (callable, optional): Applied to the loaded RGB PIL image.
            Defaults to None.
        device (str, optional): Stored as ``self.device``; this class does not
            move any tensor to it. Defaults to 'cuda'.

    Attributes:
        data (DataFrame): The metadata, with an added ``encoded_label`` column.
        img_dirs (list): Image search directories.
        transform (callable): Image transform, or None.
        device (str): The ``device`` argument as given.
        label_encoder (LabelEncoder): Encoder fitted on the ``dx`` column; its
            ``classes_`` map encoded labels back to class names.
    """

    def __init__(self, csv_file, img_dirs, transform=None, device='cuda'):
        self.data = pd.read_csv(csv_file)
        self.img_dirs = img_dirs
        self.transform = transform
        self.device = device

        # Encode the string diagnosis ('dx') into integer class ids.
        self.label_encoder = LabelEncoder()
        self.data['encoded_label'] = self.label_encoder.fit_transform(self.data['dx'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is the RGB PIL image, or whatever ``transform`` returns for
        it. ``label`` is the row's ``encoded_label`` value.

        Raises:
            FileNotFoundError: If the image exists in none of ``img_dirs``.
        """
        row = self.data.iloc[idx]
        img_name = row['image_id'] + '.jpg'
        for img_dir in self.img_dirs:
            img_path = os.path.join(img_dir, img_name)
            if os.path.exists(img_path):
                # Close the file handle right away; convert() returns an
                # independent image, so it stays valid after the file closes.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                return image, row['encoded_label']
        raise FileNotFoundError(f"Image {img_name} not found in directories {self.img_dirs}")

In [None]:
def plot_data_distribution_comparison(original_csv, synthetic_images_dir):
    """Compare per-class counts of the original HAM10000 data and the synthetic images.

    Draws a side-by-side bar chart (original vs. synthetic count per ``dx``
    class). It then prints the original distribution, the synthetic
    distribution, and the per-class augmentation ratio (synthetic / original).

    Synthetic images are counted as the files ending in ``.png`` or ``.jpg``
    directly inside ``<synthetic_images_dir>/<class_name>``. A missing class
    directory counts as 0.

    Args:
        original_csv (str): Path to the original metadata CSV (needs a ``dx`` column).
        synthetic_images_dir (str): Directory holding one sub-directory per class.

    Returns:
        None. Shows the plot and prints statistics to stdout.

    Example:
        >>> plot_data_distribution_comparison('metadata.csv', 'synthetic_images/')
    """
    metadata = pd.read_csv(original_csv)

    # value_counts() orders classes by descending frequency. Every per-class
    # sequence below is aligned to this same index so bars and labels match.
    original_class_counts = metadata['dx'].value_counts()
    class_names = original_class_counts.index

    synthetic_counts = {}
    for class_name in class_names:
        class_dir = os.path.join(synthetic_images_dir, class_name)
        if os.path.isdir(class_dir):
            synthetic_counts[class_name] = sum(
                1 for f in os.listdir(class_dir) if f.endswith(('.png', '.jpg'))
            )
        else:
            synthetic_counts[class_name] = 0
    synthetic_class_counts = pd.Series(synthetic_counts, dtype='int64').reindex(class_names, fill_value=0)

    plt.figure(figsize=(15, 6))

    # Side-by-side bars: original at x - width/2, synthetic at x + width/2.
    x = np.arange(len(class_names))
    width = 0.4

    plt.bar(x - width / 2, original_class_counts.values, width, label='Original Dataset', color='blue', alpha=0.7)
    plt.bar(x + width / 2, synthetic_class_counts.values, width, label='Synthetic Images', color='orange', alpha=0.7)

    plt.title('Comparison of Original HAM10000 Dataset and Synthetic Images', fontsize=16)
    plt.xlabel('Skin Lesion Type', fontsize=14)
    plt.ylabel('Number of Samples', fontsize=14)
    plt.xticks(x, class_names, rotation=90, ha='right')
    plt.legend()

    # Count labels slightly above each bar.
    for i, (orig, synth) in enumerate(zip(original_class_counts.values, synthetic_class_counts.values)):
        plt.text(i - width / 2, orig + 50, str(int(orig)), ha='center', va='bottom', fontsize=8)
        plt.text(i + width / 2, synth + 50, str(int(synth)), ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    # plt.show() takes no filename. To keep a copy, call
    # plt.savefig('dataset_distribution_comparison.png', dpi=300, bbox_inches='tight')
    # before showing.
    plt.show()
    plt.close()

    print("\nOriginal Dataset Distribution:")
    total_original = len(metadata)
    for cls, count in original_class_counts.items():
        percentage = (count / total_original) * 100
        print(f"{cls}: {count} samples ({percentage:.2f}%)")

    print("\nSynthetic Images Distribution:")
    total_synthetic = synthetic_class_counts.sum()
    for cls, count in synthetic_class_counts.items():
        percentage = (count / total_synthetic) * 100 if total_synthetic > 0 else 0
        print(f"{cls}: {count} synthetic images ({percentage:.2f}%)")

    print("\nAugmentation Ratio:")
    for cls in class_names:
        orig_count = original_class_counts.get(cls, 0)
        synth_count = synthetic_class_counts.get(cls, 0)
        augmentation_ratio = synth_count / orig_count if orig_count > 0 else 0
        print(f"{cls}: {augmentation_ratio:.2f}x")

In [None]:
class SLEBlock(nn.Module):
    """Skip-Layer Excitation (SLE) style channel gate.

    A per-channel gate is computed from ``x``:
    global average pooling, then a 1x1 conv that halves the channels, ReLU,
    a 1x1 conv back to ``in_channels``, and a sigmoid.
    The gate is then multiplied onto ``y``. Because ``x`` is pooled to 1x1,
    ``x`` and ``y`` may differ spatially, but both must have ``in_channels``
    channels.

    Args:
        in_channels (int): Channel count of ``x`` and ``y``.

    Attributes:
        global_pool (nn.AdaptiveAvgPool2d): Pools ``x`` to 1x1.
        fc1 (nn.Conv2d): 1x1 conv, ``in_channels -> in_channels // 2``.
        fc2 (nn.Conv2d): 1x1 conv, ``in_channels // 2 -> in_channels``.
        sigmoid (nn.Sigmoid): Squashes the gate to (0, 1).

    Forward:
        x (torch.Tensor): Tensor the gate is computed from.
        y (torch.Tensor): Tensor the gate is applied to.
        Returns ``y * gate``.
    """

    def __init__(self, in_channels):
        super().__init__()
        bottleneck = in_channels // 2
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, bottleneck, 1)
        self.fc2 = nn.Conv2d(bottleneck, in_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        squeezed = self.global_pool(x)
        gate = self.sigmoid(self.fc2(F.relu(self.fc1(squeezed))))
        return y * gate
        

In [None]:
class FASTGANGenerator(nn.Module):
    """Generator half of the FASTGAN-style model used in this notebook.

    Maps a latent tensor of shape ``(batch, latent_dim, 1, 1)`` to an RGB image
    with values in ``[-1, 1]`` (final ``Tanh``). Five 4x4 transposed
    convolutions grow the spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64, and the
    channel widths go ``ngf*16 -> ngf*8 -> ngf*4 -> ngf*2 -> 3``.

    After ``layer1`` and ``layer2`` an ``SLEBlock`` re-weights the channels of
    the feature map using a gate computed from that same feature map.

    Args:
        latent_dim (int, optional): Channels of the input latent tensor. Defaults to 256.
        ngf (int, optional): Base width; the first layer has ``ngf * 16``
            channels. Defaults to 64.
        output_size (int, optional): Stored as ``self.output_size``.
            NOTE(review): it does not influence the layers. A 1x1 latent
            always yields a 64x64 image.

    Returns (forward):
        torch.Tensor: ``(batch, 3, 64, 64)`` for a 1x1 latent input.
    """

    def __init__(self, latent_dim=256, ngf=64, output_size=64):
        super().__init__()
        self.output_size = output_size

        def upsample(in_ch, out_ch, stride, padding):
            # ConvTranspose -> BatchNorm -> ReLU, at Sequential indices 0, 1, 2.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            )

        # Submodules are created in the same order as before, so a fixed seed
        # reproduces the same initial weights and state_dict keys.
        self.initial = upsample(latent_dim, ngf * 16, 1, 0)   # 1x1 -> 4x4
        self.layer1 = upsample(ngf * 16, ngf * 8, 2, 1)       # 4x4 -> 8x8
        self.sle1 = SLEBlock(ngf * 8)
        self.layer2 = upsample(ngf * 8, ngf * 4, 2, 1)        # 8x8 -> 16x16
        self.sle2 = SLEBlock(ngf * 4)
        self.layer3 = upsample(ngf * 4, ngf * 2, 2, 1)        # 16x16 -> 32x32
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, 3, 4, 2, 1),           # 32x32 -> 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        h = self.layer1(self.initial(z))
        h = self.sle1(h, h)
        h = self.layer2(h)
        h = self.sle2(h, h)
        return self.layer4(self.layer3(h))

class FASTGANDiscriminator(nn.Module):
    """Convolutional real/fake discriminator paired with ``FASTGANGenerator``.

    Four 4x4 stride-2 convolutions (channels ``3 -> ndf -> 2*ndf -> 4*ndf ->
    8*ndf``) halve the spatial size four times. Each is followed by
    LeakyReLU(0.2), with BatchNorm on all but the first. Global average pooling
    and a 1x1 convolution then reduce the features to one score per image,
    which a sigmoid maps into (0, 1).

    Args:
        ndf (int, optional): Width of the first convolution. Defaults to 64.
        input_size (int, optional): Stored as ``self.input_size``. The layers do
            not depend on it, because the adaptive pooling makes the head
            independent of the final feature-map size.

    Returns (forward):
        torch.Tensor: Shape ``(batch, 1)``, values in (0, 1).
    """

    def __init__(self, ndf=64, input_size=64):
        super().__init__()
        self.input_size = input_size

        widths = [3, ndf, ndf * 2, ndf * 4, ndf * 8]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            if stage > 0:
                # The first downsampling stage has no BatchNorm.
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ndf * 8, 1, 1),
            nn.Flatten(),
            nn.Sigmoid(),
        ]
        # Indices inside ``main`` match the original layout (state_dict compatible).
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

In [None]:
class SelfSupervisedDiscriminator(nn.Module):
    """Discriminator with an auxiliary reconstruction (self-supervised) head.

    ``main`` downsamples a 16x16 RGB image to a 1x1 feature map with
    ``ndf * 16`` channels, then scores it with a 1x1 conv and a sigmoid.
    ``decoder1`` upsamples that same 1x1 feature map back to a 16x16 RGB
    reconstruction in ``[-1, 1]`` (final ``Tanh``).

    Args:
        ndf (int, optional): Width of the first convolution. Default: 64.

    Architecture (for a 16x16 input):
        main:     16 -> 8 -> 4 -> 2 -> 1 (stride-2 4x4 convs), then two 1x1 convs;
                  BatchNorm on all but the first conv, LeakyReLU(0.2), final Sigmoid.
        decoder1: 1 -> 2 -> 4 -> 8 -> 16 (4x4 transposed convs), BatchNorm + ReLU,
                  final Tanh.

    Returns (forward):
        tuple:
            - validity (Tensor): ``(batch, 1, 1, 1)`` scores in (0, 1).
            - reconstruction (Tensor): ``(batch, 3, 16, 16)``.
    """

    def __init__(self, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1),              # 16x16 -> 8x8
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),        # 8x8 -> 4x4
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),    # 4x4 -> 2x2
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),    # 2x2 -> 1x1
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 8, ndf * 16, 1, 1, 0),   # 1x1 -> 1x1
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 16, 1, 1, 1, 0),         # 1x1 -> 1x1 (validity head)
            nn.Sigmoid()
        )

        # Small decoder for the self-supervised reconstruction task.
        self.decoder1 = nn.Sequential(
            # Padding 1 makes a 4x4 kernel map 1x1 -> 2x2. Padding 0 would give
            # 4x4 and a 32x32 reconstruction. The weight shape is unchanged.
            nn.ConvTranspose2d(ndf * 16, ndf * 8, 4, 1, 1),   # 1x1 -> 2x2
            nn.BatchNorm2d(ndf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 8, ndf * 4, 4, 2, 1),    # 2x2 -> 4x4
            nn.BatchNorm2d(ndf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 4, ndf * 2, 4, 2, 1),    # 4x4 -> 8x8
            nn.BatchNorm2d(ndf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 2, 3, 4, 2, 1),          # 8x8 -> 16x16
            nn.Tanh()
        )

    def forward(self, x):
        # main[:-2] ends at the ndf*16-channel LeakyReLU, which is what decoder1
        # expects. The last two modules (1x1 conv to one channel + Sigmoid)
        # form the real/fake head.
        features = self.main[:-2](x)
        validity = self.main[-2:](features)
        reconstruction = self.decoder1(features)
        return validity, reconstruction

In [None]:
def train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer,
               device, autoencoder_loss, latent_dim=256):
    """Run one discriminator update followed by one generator update.

    Uses the binary-cross-entropy GAN objective: the discriminator is pushed
    towards 1 on real and 0 on fake images. The generator is trained to make
    the (just updated) discriminator output 1 on its samples. ``discriminator``
    must therefore return probabilities in (0, 1).

    Args:
        real_imgs (torch.Tensor): Batch of real images, already on ``device``.
        generator (nn.Module): Maps ``(batch, latent_dim, 1, 1)`` noise to images.
        discriminator (nn.Module): Maps images to probabilities.
        g_optimizer (torch.optim.Optimizer): Optimizer for the generator.
        d_optimizer (torch.optim.Optimizer): Optimizer for the discriminator.
        device (torch.device | str): Device the noise is sampled on.
        autoencoder_loss: Accepted for interface compatibility; not used.
        latent_dim (int, optional): Channels of the sampled noise; must match
            the generator's input. Defaults to 256.

    Returns:
        tuple:
            - float: Discriminator loss for this batch.
            - float: Generator loss for this batch.
            - torch.Tensor: The generated images, detached from the graph.
    """
    batch_size = real_imgs.size(0)

    # --- Discriminator update ---
    d_optimizer.zero_grad()
    real_validity = discriminator(real_imgs)

    z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
    fake_imgs = generator(z)
    # detach(): this loss must not push gradients into the generator.
    fake_validity = discriminator(fake_imgs.detach())

    d_loss = (F.binary_cross_entropy(real_validity, torch.ones_like(real_validity))
              + F.binary_cross_entropy(fake_validity, torch.zeros_like(fake_validity)))
    d_loss.backward()
    d_optimizer.step()

    # --- Generator update (re-scored with the updated discriminator) ---
    # This backward also writes grads into the discriminator's parameters.
    # They are cleared by d_optimizer.zero_grad() at the start of the next call.
    g_optimizer.zero_grad()
    fake_validity = discriminator(fake_imgs)
    g_loss = F.binary_cross_entropy(fake_validity, torch.ones_like(fake_validity))
    g_loss.backward()
    g_optimizer.step()

    # Detach so callers holding the samples do not keep the autograd graph alive.
    return d_loss.item(), g_loss.item(), fake_imgs.detach()

In [None]:
def train_fastgan(generator, discriminator, dataloader, num_epochs, device='cuda',
                  lr=0.0002, betas=(0.5, 0.999)):
    """Train ``generator`` and ``discriminator`` adversarially with Adam.

    Every batch is handed to ``train_step`` (one discriminator update, then one
    generator update). Every 100 batches the current losses are printed. Every
    500 batches the first 16 generated samples are mapped from [-1, 1] to
    [0, 1] and saved to ``generated_images_epoch_{epoch}_batch_{i}.png`` in the
    working directory (0-based epoch and batch indices in the file name).

    Args:
        generator (nn.Module): Generator model, already on ``device``.
        discriminator (nn.Module): Discriminator model, already on ``device``.
        dataloader (DataLoader): Yields ``(images, labels)``; labels are ignored.
        num_epochs (int): Number of passes over ``dataloader``.
        device (str, optional): Device batches are moved to. Defaults to 'cuda'.
        lr (float, optional): Adam learning rate for both networks. Defaults to 0.0002.
        betas (tuple, optional): Adam betas for both networks. Defaults to (0.5, 0.999).

    Returns:
        None. The models are updated in place.

    Example:
        >>> train_fastgan(generator, discriminator, train_loader, num_epochs=100)
    """
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    # Forwarded to train_step to satisfy its signature.
    autoencoder_loss = nn.MSELoss()

    for epoch in range(num_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)

            d_loss, g_loss, fake_imgs = train_step(
                real_imgs, generator, discriminator, g_optimizer, d_optimizer,
                device, autoencoder_loss,
            )

            if i % 100 == 0:
                # 1-based epoch for display, plus the batch index so several
                # log lines from the same epoch can be told apart.
                print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i}], '
                      f'D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')

            if i % 500 == 0:
                # Generator output is in [-1, 1]; rescale to [0, 1] for saving.
                save_image(fake_imgs[:16].detach() * 0.5 + 0.5,
                           f'generated_images_epoch_{epoch}_batch_{i}.png',
                           normalize=False)

In [None]:
class ProgressiveGrowingManager:
    """Track image size and fade-in factor for progressive GAN growing.

    Each ``step()`` raises ``alpha`` by ``alpha_step`` (capped at 1.0). When
    ``alpha`` reaches 1.0 while ``current_size`` is below ``target_size``, the
    size doubles (capped at ``target_size``) and ``alpha`` restarts at 0.0.
    Once the target size is reached, ``alpha`` climbs to 1.0 and stays there.

    Args:
        start_size (int): Initial image size. Defaults to 16.
        target_size (int): Largest image size. Defaults to 64.
        n_steps (int): Stored as ``self.n_steps``; the schedule does not use it.
            Defaults to 3.
        alpha_step (float): Increment of ``alpha`` per ``step()``; must be > 0.
            Defaults to 0.1, i.e. ten calls per resolution.

    Attributes:
        current_size (int): Current image size.
        target_size (int): Largest image size.
        n_steps (int): As passed in.
        alpha_step (float): As passed in.
        alpha (float): Blending factor in [0.0, 1.0].

    Example:
        >>> pg = ProgressiveGrowingManager(start_size=16, target_size=64)
        >>> pg.get_size()
        16
        >>> for _ in range(10):
        ...     pg.step()
        >>> pg.get_size()
        32
    """

    def __init__(self, start_size=16, target_size=64, n_steps=3, alpha_step=0.1):
        if alpha_step <= 0:
            raise ValueError(f"alpha_step must be positive, got {alpha_step}")
        self.current_size = start_size
        self.target_size = target_size
        self.n_steps = n_steps
        self.alpha_step = alpha_step
        self.alpha = 0.0

    def step(self):
        """Advance ``alpha`` and grow the size when a fade-in completes."""
        # Round to cancel floating-point drift: ten additions of 0.1 give
        # 0.9999999999999999, which would delay growth by one extra call.
        self.alpha = min(1.0, round(self.alpha + self.alpha_step, 12))
        if self.alpha >= 1.0 and self.current_size < self.target_size:
            self.current_size = min(self.current_size * 2, self.target_size)
            self.alpha = 0.0

    def get_size(self):
        """Return the current image size."""
        return self.current_size

In [None]:
class SyntheticImageClassifier:
    """Agreement filter for synthetic images using EfficientNetV2-S and ShuffleNetV2.

    Both ImageNet-pretrained backbones get a fresh ``num_classes``-way output
    layer. ``load_pretrained_weights`` then loads fine-tuned state dicts.
    ``classify_synthetic_images`` returns a mask of the images on which the two
    models predict the same class. It does not compare against any target
    label.

    Args:
        num_classes (int): Number of output classes for both heads.
        device (str | torch.device, optional): Where the models live and where
            input batches are moved. Defaults to 'cuda'.

    Attributes:
        device: As passed in.
        num_classes (int): As passed in.
        efficientnet (torch.nn.Module): EfficientNetV2-S with replaced head.
        shufflenet (torch.nn.Module): ShuffleNetV2 x1.0 with replaced head.
        transform (torchvision.transforms.Compose): PIL pipeline (resize to
            224x224, to tensor, ImageNet normalisation). It is not used by
            ``classify_synthetic_images``, which works on tensors directly.

    Example:
        >>> classifier = SyntheticImageClassifier(num_classes=10)
        >>> classifier.load_pretrained_weights('efficientnet.pth', 'shufflenet.pth')
        >>> mask = classifier.classify_synthetic_images(synthetic_images)
    """

    # ImageNet channel statistics used for input normalisation.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, num_classes, device='cuda'):
        self.device = device
        self.num_classes = num_classes

        # NOTE(review): newer torchvision deprecates ``pretrained=True`` in
        # favour of ``weights=...``; kept here for compatibility.
        # Models are put in eval mode so BatchNorm/Dropout behave
        # deterministically during inference, even before weights are loaded.
        self.efficientnet = models.efficientnet_v2_s(pretrained=True)
        self.efficientnet.classifier[1] = nn.Linear(self.efficientnet.classifier[1].in_features, num_classes)
        self.efficientnet = self.efficientnet.to(device).eval()

        self.shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
        self.shufflenet.fc = nn.Linear(self.shufflenet.fc.in_features, num_classes)
        self.shufflenet = self.shufflenet.to(device).eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._IMAGENET_MEAN), std=list(self._IMAGENET_STD))
        ])

    def load_pretrained_weights(self, efficientnet_path, shufflenet_path):
        """Load state dicts for both models and switch them to eval mode.

        Args:
            efficientnet_path (str): Path to the EfficientNetV2 state dict.
            shufflenet_path (str): Path to the ShuffleNetV2 state dict.
        """
        # map_location lets checkpoints saved on one device (e.g. GPU) load on
        # another (e.g. a CPU-only host).
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        self.efficientnet.load_state_dict(torch.load(efficientnet_path, map_location=self.device))
        self.shufflenet.load_state_dict(torch.load(shufflenet_path, map_location=self.device))

        self.efficientnet.eval()
        self.shufflenet.eval()

    def classify_synthetic_images(self, synthetic_images):
        """Return a boolean mask of images on which both classifiers agree.

        The batch is moved to ``self.device`` and resized to 224x224
        (bilinear). It is then min-max scaled to [0, 1] using the minimum and
        maximum of the whole batch, ImageNet-normalised, and scored by both
        models without gradient tracking.

        Args:
            synthetic_images (torch.Tensor): Batch shaped ``(batch, 3, H, W)``.

        Returns:
            torch.Tensor: Bool tensor of shape ``(batch,)``; True where the
            two models' argmax classes match.
        """
        images = synthetic_images.to(self.device)
        with torch.no_grad():
            resized = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
            lo, hi = resized.min(), resized.max()
            # A constant batch has hi == lo; clamp the range to avoid 0/0 NaNs.
            scaled = (resized - lo) / torch.clamp(hi - lo, min=1e-8)

            mean = torch.tensor(self._IMAGENET_MEAN, dtype=scaled.dtype, device=scaled.device).view(1, -1, 1, 1)
            std = torch.tensor(self._IMAGENET_STD, dtype=scaled.dtype, device=scaled.device).view(1, -1, 1, 1)
            normalized = (scaled - mean) / std

            efficientnet_classes = torch.argmax(self.efficientnet(normalized), dim=1)
            shufflenet_classes = torch.argmax(self.shufflenet(normalized), dim=1)

        return efficientnet_classes == shufflenet_classes

In [None]:
def copy_original_images_by_class(csv_file, img_dirs, output_base_dir='synthetic_images'):
    """Copy the original images into ``<output_base_dir>/<dx>/`` sub-directories.

    For every row of the metadata CSV, ``<image_id>.jpg`` is looked up in
    ``img_dirs`` in order. The first match is copied with ``shutil.copy2``
    (which preserves file metadata) into the directory of the row's ``dx``
    class. A given source file is copied at most once, even if several rows
    reference it. Images found in none of the directories are skipped and
    counted.

    Args:
        csv_file (str): Metadata CSV with ``image_id`` and ``dx`` columns.
        img_dirs (list): Directories searched, in order, for the original images.
        output_base_dir (str, optional): Destination root; it and one
            sub-directory per class are created if missing.
            Defaults to 'synthetic_images'.

    Returns:
        None. Prints a short summary of what was copied (and what was missing).

    Example:
        >>> copy_original_images_by_class('metadata.csv', ['images/part1', 'images/part2'], 'output_directory')
    """
    metadata = pd.read_csv(csv_file)
    os.makedirs(output_base_dir, exist_ok=True)

    # Source paths already copied, to avoid copying the same file twice.
    copied_images = set()
    missing = 0

    # groupby(sort=False) visits classes in order of first appearance and
    # hands over each class's rows without re-filtering the frame per class.
    for class_name, class_metadata in metadata.groupby('dx', sort=False):
        class_output_dir = os.path.join(output_base_dir, class_name)
        os.makedirs(class_output_dir, exist_ok=True)

        for image_id in class_metadata['image_id']:
            img_filename = image_id + '.jpg'
            candidates = (os.path.join(img_dir, img_filename) for img_dir in img_dirs)
            img_path = next((p for p in candidates if os.path.exists(p)), None)

            if img_path is None:
                missing += 1
                continue
            if img_path not in copied_images:
                shutil.copy2(img_path, os.path.join(class_output_dir, img_filename))
                copied_images.add(img_path)

    print(f"Original images copied to {output_base_dir}")
    print(f"Total unique images copied: {len(copied_images)}")
    if missing:
        print(f"Images not found in any source directory: {missing}")


In [None]:
def _generate_filtered_images(generator, classifier, num_images, batch_size, latent_dim, device):
    """Sample images from ``generator`` and keep those accepted by ``classifier``.

    At most ``ceil(num_images / batch_size)`` batches of latent noise are drawn,
    so the result holds fewer than ``num_images`` images when the classifier
    rejects many samples.

    Args:
        generator: Trained generator, called as ``generator(z)`` with
            ``z`` of shape ``(batch_size, latent_dim, 1, 1)``.
        classifier: Object exposing ``classify_synthetic_images(images)``, which
            returns a mask used to index the generated batch.
        num_images (int): Maximum number of images to return (must be >= 1).
        batch_size (int): Number of latent vectors sampled per batch.
        latent_dim (int): Size of the latent noise vector.
        device (torch.device): Device the noise is moved to.

    Returns:
        torch.Tensor: Accepted images, concatenated along dim 0 and trimmed to
        at most ``num_images`` entries (still on the generator's device).

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Ceiling division: enough batches to reach num_images if nothing is rejected.
    num_batches = (num_images + batch_size - 1) // batch_size
    kept_batches = []
    num_kept = 0

    # NOTE(review): the generator is not switched to eval() here, so any
    # BatchNorm layers use batch statistics while sampling. Confirm this is intended.
    with torch.no_grad():
        for _ in range(num_batches):
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            synthetic_images = generator(z)

            valid_image_mask = classifier.classify_synthetic_images(synthetic_images)
            valid_images = synthetic_images[valid_image_mask]
            kept_batches.append(valid_images)

            # Running count avoids re-concatenating the whole list every batch.
            num_kept += len(valid_images)
            if num_kept >= num_images:
                break

    return torch.cat(kept_batches)[:num_images]


def _save_images(images, class_dir):
    """Write each image to ``class_dir`` as ``synthetic_image_<i>.png``.

    Images are mapped back from the [-1, 1] normalization to [0, 1] before
    saving.
    """
    for i, img in enumerate(images):
        save_path = os.path.join(class_dir, f'synthetic_image_{i}.png')
        save_image(img * 0.5 + 0.5, save_path)


def main(*, save_images=False):
    """
    Train a FASTGAN on HAM10000 and generate classifier-filtered synthetic images.

    Steps:
    1. Select the device (CPU/GPU) and seed torch and numpy for reproducibility.
    2. Copy the original images into per-class folders under ``synthetic_images``
       and load the HAM10000 dataset (64x64 resize, normalized to [-1, 1]).
    3. Create a FASTGAN generator/discriminator pair and train it for 150 epochs.
    4. For every class except 'nv' and 'vasc', sample up to 1000 images and keep
       the ones accepted by ``SyntheticImageClassifier``.
    5. Optionally save the kept images, then plot the class distribution of the
       original metadata against the ``synthetic_images`` folder.

    Args:
        save_images (bool, optional): If True, write the kept images to
            ``synthetic_images/<class_name>/synthetic_image_<i>.png``.
            Defaults to False, in which case no synthetic files are written and
            the class folders contain only the copied original images.

    Returns:
        dict: Maps class index to a CPU tensor of the kept synthetic images for
        that class. A class may hold fewer than 1000 images if the filter
        rejects many samples.

    Note:
        - Input paths are hard-coded for the Kaggle HAM10000 dataset.
        - Training falls back to CPU when CUDA is unavailable, which is much slower.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    torch.manual_seed(42)
    np.random.seed(42)

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # Map [0, 1] pixel values to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    csv_file = '/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_metadata.csv'
    img_dirs = ['/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1', '/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_2']

    # Single output folder shared by the copy, generation and plot steps. The
    # plot step previously used a hard-coded '/kaggle/working/...' path that
    # only matched this folder when the cwd was /kaggle/working.
    synthetic_dir = 'synthetic_images'

    # Seed the output folder with the original images, organized by class.
    copy_original_images_by_class(csv_file, img_dirs, output_base_dir=synthetic_dir)
    dataset = HAM10000Dataset(csv_file, img_dirs, transform=transform, device=device)

    num_classes = len(dataset.label_encoder.classes_)
    print("Unique Classes:", dataset.label_encoder.classes_)
    print("Number of Classes:", num_classes)

    batch_size = 64
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    latent_dim = 256
    generator = FASTGANGenerator(latent_dim, output_size=64).to(device)
    discriminator = FASTGANDiscriminator(input_size=64).to(device)

    num_epochs = 150
    train_fastgan(generator, discriminator, data_loader, num_epochs, device)

    os.makedirs(synthetic_dir, exist_ok=True)

    # NOTE: no fine-tuned classifier weights are loaded here. In a real run,
    # load them first, e.g. classifier.load_pretrained_weights(efficientnet_path,
    # shufflenet_path); otherwise the filter is not meaningful.
    classifier = SyntheticImageClassifier(num_classes=num_classes, device=device)

    num_images_to_generate = 1000
    skipped_classes = {'nv', 'vasc'}
    synthetic_images_by_class = {}

    for class_idx in range(num_classes):
        class_name = dataset.label_encoder.inverse_transform([class_idx])[0]
        if class_name in skipped_classes:
            continue

        # NOTE(review): neither the generator nor the filter call receives
        # class_idx, so every class is drawn from the same filtered
        # distribution. TODO confirm this is intended.
        valid_synthetic_images = _generate_filtered_images(
            generator, classifier, num_images_to_generate, batch_size, latent_dim, device
        ).cpu()

        if len(valid_synthetic_images) < num_images_to_generate:
            print(f"Warning: only {len(valid_synthetic_images)} of {num_images_to_generate} "
                  f"synthetic images passed the filter for class '{class_name}'")

        synthetic_images_by_class[class_idx] = valid_synthetic_images

        class_dir = os.path.join(synthetic_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        if save_images:
            _save_images(valid_synthetic_images, class_dir)

    print("Synthetic image generation, classification, and filtering complete!")
    plot_data_distribution_comparison(csv_file, synthetic_dir)
    return synthetic_images_by_class

if __name__ == "__main__":
    main()

This section explains Step 2 of the code: the FastGAN architecture and its training components. The key classes and functions are described below:

1. **SLEBlock (Skip-Layer Excitation Block)**
```python
class SLEBlock(nn.Module):
    def __init__(self, in_channels):
        super(SLEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // 2, 1)
        self.fc2 = nn.Conv2d(in_channels // 2, in_channels, 1)
        self.sigmoid = nn.Sigmoid()
```
This block implements a channel-attention mechanism in the generator:
- Uses global average pooling to capture channel-wise statistics
- Has two 1x1 convolutions that act as fully connected layers
- Applies sigmoid activation to generate attention weights
- The output modulates the skip connection features

2. **FASTGANGenerator**
```python
class FASTGANGenerator(nn.Module):
    def __init__(self, latent_dim=256, ngf=64, output_size=64): ...
```
The generator has a progressive architecture:
- Takes 256-dimensional noise vector as input
- Uses transposed convolutions to progressively upscale the image
- Incorporates SLE blocks after certain layers for better feature refinement
- The output is a 3×64×64 RGB image (PyTorch channels-first layout)
- Uses BatchNorm and ReLU activations throughout
- Final Tanh activation to normalize output to [-1, 1]

3. **FASTGANDiscriminator**
```python
class FASTGANDiscriminator(nn.Module):
    def __init__(self, ndf=64, input_size=64): ...
```
The discriminator:
- Takes 3×64×64 RGB images as input
- Uses regular convolutions to progressively downsample
- Includes BatchNorm and LeakyReLU activations
- Ends with adaptive average pooling and final convolution
- Outputs a single value through sigmoid for real/fake classification

4. **Training Functions**

`train_step()` handles a single training iteration:
```python
def train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer, device, autoencoder_loss): ...
```
- First trains discriminator:
  - Gets predictions for real images
  - Generates fake images
  - Calculates discriminator loss using binary cross entropy
  - Updates discriminator weights

- Then trains generator:
  - Generates fake images
  - Gets discriminator predictions
  - Calculates generator loss
  - Updates generator weights

`train_fastgan()` manages the overall training process:
```python
def train_fastgan(generator, discriminator, dataloader, num_epochs, device='cuda'): ...
```
- Sets up Adam optimizers for both networks
- Runs training for specified number of epochs
- Prints progress every 100 batches
- Saves sample generated images every 500 batches

5. **Image Classification Verification**

```python
class SyntheticImageClassifier:
    def __init__(self, num_classes, device='cuda'): ...
```
This class uses two pre-trained models (EfficientNetV2 and ShuffleNetV2) to verify the quality of generated images:
- Both models are modified for the specific number of classes
- Images are only kept if both models agree on the classification
- Helps filter out poor quality or ambiguous generated images

The overall pipeline works as follows:
1. The generator creates images from random noise
2. The discriminator learns to distinguish real from fake
3. Generated images are verified by classification models
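4. Only images on which both classifiers agree are kept (for this filter to be meaningful, the classifiers' new class-specific heads must first be fine-tuned or loaded from trained weights; `main()` does not do this)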

This architecture is specifically designed for fast training while maintaining good image quality, making it suitable for medical image synthesis tasks.

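Below is a summary of the reference FastGAN design on which this code is based. The implementation above may differ in details. For example, the training step described earlier uses binary cross-entropy, whereas the reference design uses the hinge loss.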

1. Architecture Components:
- Generator: uses a single convolution layer per resolution, with a restricted number of channels
- Skip-Layer Excitation (SLE) modules to strengthen gradient flow
- Self-supervised discriminator with small decoders
- Two CNN classifiers (EfficientNetV2 and ShuffleNetV2) for synthetic image filtering

2. Key Features:
- SLE module: 
  - Implements channel-wise multiplications
  - Creates skip-connections between distant resolutions
  - Helps disentangle content and style attributes

3. Training Process:
- Uses hinge version of adversarial loss
- Iterative training between discriminator and generator
- Discriminator regularization through auto-encoding
- Two-stage filtering:
  1. Generate synthetic images from original medical data
  2. Filter generations through dual CNN validation

4. Data Flow:
```
Original Images → FASTGAN Generator → Synthetic Images → Dual CNN Filtering → Filtered Dataset + Original Images → Few-shot Classifier
```