<a href="https://colab.research.google.com/github/AdithyaSean/synthetic-medical-scan-generator/blob/main/Medical_Image_Synthesis_Ensemble_Model_Strategies_MoE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This guide covers how to develop a medical image synthesis system using ensemble models, drawing on insights from recent research.

### Ensemble Model Architectures

Ensemble models are a powerful approach for your multi-modal synthesis task, as they can improve performance, robustness, and generalizability by combining the strengths of multiple models. Here's a breakdown of suitable ensemble strategies:

*   **Stacking:** This involves training a meta-model to combine the outputs of several base generative models. For instance, you could use a convolutional neural network (CNN) to learn how to best combine the outputs of a Generative Adversarial Network (GAN), a Variational Autoencoder (VAE), and a diffusion model to produce a final, high-fidelity image. One study successfully used a stacking ensemble of a VGG-19 network and a Siamese neural network for medical image fusion.
*   **Mixture of Experts (MoE):** In an MoE architecture, you would train multiple "expert" generative models, each specializing in a particular modality (e.g., one expert for MRI, one for CT, and one for X-ray). A "gating network" would then learn to determine which expert (or combination of experts) to use for a given input condition. This is highly suitable for your conditional synthesis task.
*   **Bagging and Boosting:** While less common for generative tasks, these techniques could be adapted. For instance, in a bagging-like approach, you could train multiple generative models on different subsets of your training data and average their outputs.
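
The Mixture of Experts idea above can be sketched in a few lines. This is a minimal, framework-agnostic NumPy illustration, not a trained model: the expert outputs are random stand-ins for images produced by per-modality generators, and the linear gate with its `gate_weights` matrix is a hypothetical choice made for the example.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def mixture_of_experts(condition, expert_outputs, gate_weights):
    """Blend expert images with gate probabilities computed from the condition.

    condition:      (batch, cond_dim) one-hot or embedded modality
    expert_outputs: (num_experts, batch, H, W), one image per expert and sample
    gate_weights:   (cond_dim, num_experts), parameters of a linear gate (assumed)
    """
    gate_probs = softmax(condition @ gate_weights)  # (batch, num_experts)
    # Weighted sum over the expert axis for every sample.
    blended = np.einsum("be,ebhw->bhw", gate_probs, expert_outputs)
    return blended, gate_probs

rng = np.random.default_rng(0)
expert_outputs = rng.normal(size=(3, 2, 8, 8))  # stand-ins for MRI, CT, X-ray experts
condition = np.eye(3)[[0, 2]]                   # sample 0 requests MRI, sample 1 X-ray
gate_weights = 10.0 * np.eye(3)                 # gate that strongly prefers the matching expert
blended, gate_probs = mixture_of_experts(condition, expert_outputs, gate_weights)
print(gate_probs.round(3))
```

In a real system the gate would be trained jointly with the experts, and a sparse (top-k) gate could skip running the experts that receive negligible weight.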

**Promising Combinations of Generative Models:**

Recent research has shown great promise in using state-of-the-art generative models for medical imaging:

*   **Diffusion Models:** These models have emerged as the leading approach for generating high-quality, diverse, and anatomically coherent medical images, often outperforming GANs in terms of training stability and image fidelity.
*   **Generative Adversarial Networks (GANs):** GANs, particularly conditional GANs (cGANs), are a powerful tool for image-to-image translation and synthesis. They can be conditioned on various inputs, including image modality, to generate specific types of scans.
*   **Vision Transformers (ViT):** For capturing long-range spatial relationships in medical images, ViTs can be combined with CNNs in an ensemble to enhance performance.

A recommended approach would be to create a stacking ensemble that combines the outputs of a conditional diffusion model and a conditional GAN. This would leverage the strengths of both architectures – the high-fidelity output of the diffusion model and the sharp details often produced by GANs.
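
As a rough illustration of stacking, the sketch below fits a linear meta-model that blends two base outputs. It is a deliberately simple stand-in for the CNN meta-model described earlier: the arrays `diffusion_out` and `gan_out` are random placeholders for real generator outputs, and the helper names are made up for this example.

```python
import numpy as np

def fit_blend_weights(base_outputs, target):
    """Least-squares weights w so that sum_k w[k] * base_outputs[k] approximates target.

    base_outputs: (num_models, H, W) images from the base generators
    target:       (H, W) reference image
    """
    features = base_outputs.reshape(base_outputs.shape[0], -1).T  # (pixels, num_models)
    weights, *_ = np.linalg.lstsq(features, target.ravel(), rcond=None)
    return weights

def blend(base_outputs, weights):
    """Apply the learned weights to produce the fused image."""
    return np.tensordot(weights, base_outputs, axes=1)

rng = np.random.default_rng(0)
diffusion_out = rng.random((16, 16))  # placeholder for a diffusion model output
gan_out = rng.random((16, 16))        # placeholder for a GAN output
base_outputs = np.stack([diffusion_out, gan_out])
target = 0.7 * diffusion_out + 0.3 * gan_out  # synthetic target with a known blend

weights = fit_blend_weights(base_outputs, target)
fused = blend(base_outputs, weights)
print(weights.round(3))
```

In practice the meta-model should be fitted on held-out data that the base generators did not train on, otherwise it tends to favor whichever base model overfits most.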

### Conditional Generation Techniques

Robustly conditioning the generative process on the input scan type is crucial. Here are some best practices:

*   **Concatenation-Based Conditioning:** A simple and effective method is to encode the modality information as a one-hot vector (e.g., `[1, 0, 0]` for MRI, `[0, 1, 0]` for CT, `[0, 0, 1]` for X-ray) and concatenate it with the input noise vector or the image representation at each step of the generative process. This technique has been used in conditional diffusion models for 3D medical image synthesis.
*   **Semantic Mask Conditioning:** For more fine-grained control, you can condition the model on a semantic segmentation mask. This allows you to not only specify the modality but also the desired anatomical structures and pathologies. For example, you could provide a mask of a brain with a tumor and have the model generate a realistic MRI of that brain.
*   **Cross-Attention Mechanisms:** In Transformer-based architectures like ViTs and some diffusion models, you can use cross-attention to condition the generation process on the encoded modality information. This allows the model to learn more complex relationships between the condition and the generated image.
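
Concatenation-based conditioning is the easiest of these to show concretely. The snippet below is a minimal NumPy sketch; the modality order in `MODALITIES` and the noise dimension of 100 are assumptions chosen for illustration.

```python
import numpy as np

MODALITIES = ["MRI", "CT", "X-ray"]  # assumed label order: 0, 1, 2

def one_hot(labels, num_classes):
    """Turn integer labels of shape (batch,) into one-hot rows (batch, num_classes)."""
    return np.eye(num_classes, dtype=np.float32)[labels]

def condition_noise(noise, labels, num_classes):
    """Append the one-hot modality code to each noise vector."""
    return np.concatenate([noise, one_hot(labels, num_classes)], axis=1)

rng = np.random.default_rng(0)
noise = rng.standard_normal((4, 100)).astype(np.float32)
labels = np.array([0, 1, 2, 1])  # MRI, CT, X-ray, CT
generator_input = condition_noise(noise, labels, len(MODALITIES))
print(generator_input.shape)  # (4, 103)
```

The generator's first layer then takes 103 input features instead of 100. For image-shaped inputs, the same code is usually broadcast to one constant channel per class and concatenated along the channel axis.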

### Type Identification/Classification

For identifying the input scan type, you have two main options:

1.  **Separate Pre-processing Step:** This involves training a dedicated, high-performance image classification model to first identify the modality of the input scan. This model can be an ensemble of CNNs, ViTs, or a combination of both to achieve high accuracy. This approach is modular and allows you to use a highly specialized classification model.
2.  **Integrated Pipeline:** You can integrate the classification task into the generative pipeline. For instance, the encoder of your generative model could be trained to produce a latent representation that is not only used for generation but also fed into a small classification head to predict the modality. This can be more efficient but may require careful balancing of the generation and classification losses.

For robustness, especially with noisy or varied datasets, an ensemble-based classifier is recommended. One approach is to use a two-stage selective ensemble of CNN branches, which has been shown to mitigate overfitting and the vanishing gradient problem.
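
One common way to build an ensemble classifier is soft voting, where the class probabilities of several models are averaged before taking the argmax. The sketch below shows only that combination step and is not the two-stage selective ensemble mentioned above; the probabilities are hand-written placeholders rather than real model outputs.

```python
import numpy as np

def soft_vote(model_probs):
    """Average class probabilities across models, then pick the top class.

    model_probs: (num_models, batch, num_classes)
    """
    mean_probs = model_probs.mean(axis=0)
    return mean_probs.argmax(axis=1), mean_probs

# Placeholder outputs of three classifiers for two scans (classes: MRI, CT, X-ray).
model_probs = np.array([
    [[0.6, 0.3, 0.1], [0.1, 0.1, 0.8]],
    [[0.2, 0.7, 0.1], [0.2, 0.2, 0.6]],
    [[0.5, 0.4, 0.1], [0.3, 0.3, 0.4]],
])
predictions, mean_probs = soft_vote(model_probs)
print(predictions)  # [1 2]
```

Note that for the first scan two of the three models individually prefer class 0, yet the averaged probabilities favor class 1. That is the difference between soft voting and a majority (hard) vote.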

### Dataset Considerations and Augmentation

A high-quality, diverse, and representative dataset is the foundation of your system.

**Dataset Curation:**

*   **Diversity:** Your dataset should include images from a wide range of patients, pathologies, and acquisition settings to ensure your model generalizes well.
*   **Data Quality:** The images should be of high quality, with accurate and consistent annotations.
*   **Data Balance:** Ensure a balanced representation of each modality to prevent the model from being biased towards the most frequent class.

**Advanced Data Augmentation:**

Beyond traditional augmentation techniques like rotation, flipping, and scaling, consider these advanced methods:

*   **Generative Data Augmentation:** Use a pre-trained generative model (e.g., a GAN or diffusion model) to synthesize additional training data. This is particularly useful for augmenting underrepresented classes.
*   **Learned Transformations:** Instead of random transformations, learn the transformations from the data itself. This can involve learning spatial deformation fields and intensity changes to create more realistic augmentations.
*   **Semantic Data Augmentation (SDA):** This involves manipulating the latent space representations of images to create semantic variations while preserving the label.

### Evaluation Metrics

To assess the clinical utility and realism of your synthetic scans, you need to go beyond standard image quality metrics.

*   **Downstream Task Evaluation:** The most crucial evaluation is to assess the utility of your synthetic data for a downstream clinical task. For example, train a segmentation or classification model on your synthetic data and evaluate its performance on a real test set using metrics like the Dice Similarity Coefficient (DSC).
*   **Visual Turing Test:** Have clinical experts try to distinguish between your synthetic images and real ones. This provides a qualitative assessment of realism.
*   **Domain-Specific Metrics:** Instead of using FID with a feature extractor trained on natural images, use a feature extractor pre-trained on a large medical image dataset. The Fréchet MedicalNet Distance (FMD) is a good example of this.
*   **Standard Metrics:** Continue to use standard metrics like Peak Signal-to-Noise Ratio (PSNR), Structural Similarity Index (SSIM), and Fréchet Inception Distance (FID) as a baseline, but be aware of their limitations in the medical context.
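
Two of the metrics named above are short enough to write out directly. This is a minimal NumPy sketch of PSNR and the Dice Similarity Coefficient; it assumes images scaled to `[0, 1]` and binary masks, and the small `eps` term is an assumption added to avoid division by zero on empty masks.

```python
import numpy as np

def psnr(reference, generated, data_range=1.0):
    """Peak Signal-to-Noise Ratio in dB for images with the given value range."""
    mse = np.mean((reference - generated) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range ** 2 / mse)

def dice(pred_mask, true_mask, eps=1e-8):
    """Dice Similarity Coefficient between two binary masks."""
    pred = pred_mask.astype(bool)
    true = true_mask.astype(bool)
    intersection = np.logical_and(pred, true).sum()
    return (2.0 * intersection + eps) / (pred.sum() + true.sum() + eps)

reference = np.zeros((4, 4))
generated = np.full((4, 4), 0.1)          # constant error of 0.1, so MSE = 0.01
psnr_value = psnr(reference, generated)   # about 20 dB

pred_mask = np.array([1, 1, 0, 0])
true_mask = np.array([1, 0, 0, 0])
dice_value = dice(pred_mask, true_mask)   # about 2/3
print(round(psnr_value, 2), round(dice_value, 3))
```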

### Challenges and Mitigation Strategies

*   **Mode Collapse and Memorization:**
    *   **Challenge:** The model may generate a limited variety of images or simply copy training examples.
    *   **Mitigation:** Diffusion models are generally less prone to mode collapse than GANs. To avoid memorization, use large and diverse datasets, and employ regularization techniques.
*   **Anatomical Inconsistency and Artifacts:**
    *   **Challenge:** The model may generate anatomically incorrect images or introduce artifacts.
    *   **Mitigation:** Conditioning on semantic masks can help enforce anatomical consistency. Using multi-scale discriminators in GANs can help the model learn local details and reduce artifacts.
*   **Data Scarcity:**
    *   **Challenge:** Limited access to large, annotated medical datasets.
    *   **Mitigation:** Use data augmentation techniques, especially generative data augmentation, to expand your dataset. Federated learning can also be used to train on data from multiple institutions without compromising privacy.

### Computational Resources

Training complex ensemble models for high-resolution medical image generation is computationally intensive.

*   **GPU:** High-end GPUs with large VRAM (e.g., NVIDIA A100 or H100) are essential for training these models in a reasonable timeframe.
*   **Memory:** The memory requirements will depend on the model complexity, image resolution (especially for 3D images), and batch size. Expect to need a significant amount of RAM and VRAM.
*   **Training Time:** Training time ranges from hours for small 2D GANs to days or weeks for high-resolution or 3D ensembles, depending on the size of your dataset and the complexity of your model. One study noted that a DCGAN model took 4 to 7 hours to train, while more complex models like CycleGAN took 7 to 10 hours on the same datasets.
*   **Federated Learning:** This can distribute the computational load across multiple institutions, but it requires a robust infrastructure for communication and model aggregation.
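
To get a feel for how resolution and dimensionality drive memory use, the arithmetic can be done by hand. The sketch below only counts the raw input tensor in 32-bit floats; activations, gradients, and optimizer state add to this substantially, so treat the numbers as a lower bound. The shapes are illustrative assumptions.

```python
def tensor_mebibytes(shape, bytes_per_element=4):
    """Size of a dense tensor in MiB (float32 by default)."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements * bytes_per_element / 1024 ** 2

# A batch of 64 grayscale 64x64 images: 64 * 1 * 64 * 64 * 4 bytes.
batch_2d = tensor_mebibytes((64, 1, 64, 64))
# A batch of 4 single-channel 256^3 volumes: 4 * 256^3 * 4 bytes.
batch_3d = tensor_mebibytes((4, 1, 256, 256, 256))
print(batch_2d, batch_3d)  # 1.0 256.0
```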

Creating a full-fledged, production-ready ensemble model for medical image synthesis is a significant undertaking, so the notebook below serves as a starting point. It focuses on a core component: training a Conditional Deep Convolutional Generative Adversarial Network (cDCGAN) to generate medical images of a specific type. You can build on this foundation to create a more complex ensemble system.

This guide is inspired by several excellent resources and tutorials on GANs and medical imaging.[1][2][3] It will walk you through the process step-by-step, explaining each part of the code. For more advanced applications, including 3D medical image synthesis and readily available tools, you might find the MONAI framework by Project MONAI to be very useful.[4]

Python Notebook: Conditional GAN for Medical Image Synthesis

This notebook demonstrates how to train a Conditional GAN (cGAN) to generate 2D medical images (e.g., X-rays) based on a class condition. We'll use the popular PyTorch library. You can adapt this notebook to use different medical imaging datasets and modalities.

# 1. Setup and Imports

First, let's install the necessary libraries and import them. This notebook is designed to be run in a Google Colab environment to take advantage of its free GPU resources.[5]

In [1]:
# Install necessary libraries
!pip install torch torchvision torcheval
!pip install matplotlib


In [24]:
# Import libraries
import os
import subprocess
import re
import torch
import torch.nn as nn
import torch.optim as optim
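from torch.utils.data import DataLoader, Dataset  # Dataset is the base class for ImageFolderDataset
from torchvision import datasets, transforms
from PIL import Image  # Used by ImageFolderDataset to open image files
import matplotlib.pyplot as plt
import numpy as np
from google.colab import drive, files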

# 2. Dataset and DataLoader

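For this example, we'll use public medical imaging datasets from Kaggle. The configuration cell further down supports the Chest X-Ray Images (Pneumonia) dataset and a large COVID-19 CT slice dataset, and the download cells fetch the selected one through the Kaggle API. If you download a dataset manually instead, place it in a known directory.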

This notebook will assume you have a directory structure like this:

```
./kaggle_data/
    xray_dataset/
        chest_xray/
            train/
                NORMAL/
                    ... (normal x-ray images)
                PNEUMONIA/
                    ... (pneumonia x-ray images)
```

We will create a custom dataset class to load the images and their corresponding labels (0 for NORMAL, 1 for PNEUMONIA).
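
The label assignment follows from sorting the class folder names alphabetically. The following standalone sketch reproduces that mapping on a temporary directory; `class_to_index` is a hypothetical helper written for this illustration and is not part of the notebook.

```python
import os
import tempfile

def class_to_index(root_dir):
    """Map each class subfolder to an integer label in alphabetical order."""
    class_names = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))  # ignore stray files
    )
    return {name: idx for idx, name in enumerate(class_names)}

with tempfile.TemporaryDirectory() as root:
    for name in ("PNEUMONIA", "NORMAL"):
        os.makedirs(os.path.join(root, name))
    # A stray file next to the class folders must not consume a label.
    open(os.path.join(root, "README.txt"), "w").close()
    mapping = class_to_index(root)

print(mapping)  # {'NORMAL': 0, 'PNEUMONIA': 1}
```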

In [25]:
# ==============================================================================
#  HELPER FUNCTION TO RUN SHELL COMMANDS
# ==============================================================================
def run_command(command):
    """A helper function to run shell commands and print their output."""
    try:
        print(f"Executing: {command}")
        result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
        if result.stdout:
            print("Output:\n", result.stdout)
        if result.stderr:
            print("Error output:\n", result.stderr) # Kaggle API often prints to stderr
        print("Command executed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: {command}")
        print(f"Return code: {e.returncode}")
        if e.stdout:
            print(f"Output: \n{e.stdout}")
        if e.stderr:
            print(f"Error output: \n{e.stderr}")

In [26]:
# ==============================================================================
#  GOOGLE DRIVE MOUNT
# ==============================================================================
try:
    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")
except Exception as e:
    print(f"Error mounting Google Drive: {e}")
    print("Proceeding without Google Drive, checkpoints will only be saved locally.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully.


In [28]:
# ==============================================================================
#  DATASET DOWNLOAD METHODS (MedMNIST & Kaggle)
# ==============================================================================
# --- METHOD 1: MedMNIST v2 ---
def download_medmnist_datasets():
    print("="*60)
    print("METHOD 1: DOWNLOADING DATASETS WITH MedMNIST")
    print("="*60)
    run_command("pip install --quiet medmnist")
    try:
        # MedMNIST v2 names its axial abdominal CT set OrganAMNIST.
        # It does not provide a BrainMNIST class, so an MRI set must come from another source.
        from medmnist import ChestMNIST, OrganAMNIST
        # Create the target directory up front; some medmnist versions raise an error if `root` is missing.
        os.makedirs('./medmnist_data/', exist_ok=True)

        print("\n--- Downloading ChestMNIST (X-Ray) ---")
        ChestMNIST(split='train', download=True, root='./medmnist_data/')
        ChestMNIST(split='test', download=True, root='./medmnist_data/')
        print("ChestMNIST downloaded successfully to ./medmnist_data/")

        print("\n--- Downloading OrganAMNIST (axial CT) ---")
        OrganAMNIST(split='train', download=True, root='./medmnist_data/')
        OrganAMNIST(split='test', download=True, root='./medmnist_data/')
        print("OrganAMNIST downloaded successfully to ./medmnist_data/")
        print("\nAll selected MedMNIST datasets downloaded.")
    except ImportError as e:
        print(f"Could not import from medmnist. Please ensure it is installed correctly. Error: {e}")
    except Exception as e:
        print(f"An error occurred during MedMNIST download: {e}")

# --- METHOD 2: KAGGLE ---
KAGGLE_DATA_DIR = "./kaggle_data" # Base directory for Kaggle datasets

def download_kaggle_dataset(dataset_slug, unzip_name):
    print(f"\n--- Downloading {dataset_slug} from Kaggle ---")
    run_command(f"kaggle datasets download -d {dataset_slug} -p {KAGGLE_DATA_DIR} --force") # --force to overwrite if exists
    zip_file_name = dataset_slug.split('/')[-1] + ".zip"
    zip_path = os.path.join(KAGGLE_DATA_DIR, zip_file_name)
    unzip_target_path = os.path.join(KAGGLE_DATA_DIR, unzip_name)
    run_command(f"mkdir -p {unzip_target_path}")
    run_command(f"unzip -q -o {zip_path} -d {unzip_target_path}") # -o to overwrite
    run_command(f"rm {zip_path}")
    print(f"{dataset_slug} dataset ready at: {unzip_target_path}")
    return unzip_target_path

def setup_kaggle_api():
    print("\n" + "="*60)
    print("METHOD 2: DOWNLOADING DATASETS FROM KAGGLE")
    print("="*60)
    run_command("pip install --quiet kaggle")
    try:
        print("\nPlease upload your 'kaggle.json' file (if not already configured):")
        # os.path.exists does not expand "~", so resolve the home directory explicitly.
        kaggle_json_home = os.path.expanduser("~/.kaggle/kaggle.json")
        # Check if kaggle.json already exists to avoid re-uploading if the kernel restarts
        if not os.path.exists(kaggle_json_home) and not os.path.exists("kaggle.json"):
            files.upload() # This will prompt for upload
        else:
            print("'kaggle.json' or ~/.kaggle/kaggle.json seems to exist. Skipping upload.")

        if os.path.exists("kaggle.json"): # If uploaded to current dir
            run_command("mkdir -p ~/.kaggle")
            run_command("cp kaggle.json ~/.kaggle/")
            run_command("chmod 600 ~/.kaggle/kaggle.json")
        elif not os.path.exists(kaggle_json_home):
            print("Kaggle API token not found. Please ensure kaggle.json is uploaded or placed in ~/.kaggle/")
            return False
        print("\nKaggle API configured successfully (or was already configured).")
        return True
    except Exception as e:
        print(f"\nAn error occurred during Kaggle setup: {e}")
        print("If running locally, please place your 'kaggle.json' file in '~/.kaggle/' manually.")
        return False

In [29]:
# ==============================================================================
#  DATASET CONFIGURATION
# ==============================================================================
# Define configurations for different datasets
# 'image_data_path_relative': Path relative to the dataset's base_path where class folders or images are.
# 'num_classes': For cGAN, number of distinct classes. For unconditional GAN, can be 1.
# 'img_channels': Number of image channels (1 for grayscale, 3 for RGB).
# 'img_size': Target image size after resizing.

DATASET_CONFIGS = {
    "covid_ct": {
        "needs_kaggle_api": True,
        "kaggle_slug": "maedemaftouni/large-covid19-ct-slice-dataset",
        "base_path_segment": "covid_ct_dataset", # Folder name under KAGGLE_DATA_DIR
        # This dataset often unzips with a nested structure.
        # The actual image classes (e.g., '1NonCOVID', '2COVID') are typically here:
        "image_data_path_relative": ["curated_data/curated_data", "Large-COVID-19-CT-slice-dataset/curated_data/curated_data"],
        "num_classes": 2, # Example: COVID vs Non-COVID
        "img_channels": 1, # Will be transformed to grayscale
        "img_size": 64
    },
    "chest_xray": {
        "needs_kaggle_api": True,
        "kaggle_slug": "paultimothymooney/chest-xray-pneumonia",
        "base_path_segment": "xray_dataset",
        # Assumes training images are in 'chest_xray/train' relative to base_path_segment
        "image_data_path_relative": ["chest_xray/train", "chest-xray-pneumonia/train"],
        "num_classes": 2, # PNEUMONIA vs NORMAL
        "img_channels": 1, # Will be transformed to grayscale
        "img_size": 64
    },
    # Add more dataset configs here if needed
    # "brats_mri": { ... }
}

# --- SELECT THE DATASET TO USE ---
# Change this to "chest_xray", or other keys you add to DATASET_CONFIGS
SELECTED_DATASET_NAME = "covid_ct"
# SELECTED_DATASET_NAME = "chest_xray"

if SELECTED_DATASET_NAME not in DATASET_CONFIGS:
    raise ValueError(f"Dataset '{SELECTED_DATASET_NAME}' is not configured in DATASET_CONFIGS.")

ACTIVE_DATASET_CONFIG = DATASET_CONFIGS[SELECTED_DATASET_NAME]
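
Each config lists several candidates under `image_data_path_relative` because the folder layout of an unzipped Kaggle archive can vary. The lookup amounts to returning the first candidate that exists as a directory, as in this standalone sketch. The helper name `first_existing_dir` and the temporary folders are assumptions for illustration only.

```python
import os
import tempfile

def first_existing_dir(base_path, candidates):
    """Return the first candidate (joined to base_path) that is a directory, else None."""
    for rel_path in candidates:
        path = os.path.join(base_path, rel_path)
        if os.path.isdir(path):
            return path
    return None

with tempfile.TemporaryDirectory() as base:
    # Simulate an archive that unzipped into the second candidate layout.
    os.makedirs(os.path.join(base, "chest-xray-pneumonia", "train"))
    candidates = ["chest_xray/train", "chest-xray-pneumonia/train"]
    found = first_existing_dir(base, candidates)
    found_relative = os.path.relpath(found, base)
    missing = first_existing_dir(base, ["does/not/exist"])

print(found_relative, missing)
```

Returning `None` rather than raising lets the caller decide whether to abort or to list the directory contents for debugging.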

In [30]:
# ==============================================================================
#  DOWNLOAD SELECTED DATASET (if applicable)
# ==============================================================================
# This block will attempt to download the selected dataset if it's from Kaggle
# and if the Kaggle API is configured.

dataset_base_path = None
actual_image_data_root = None

if ACTIVE_DATASET_CONFIG.get("needs_kaggle_api", False):
    if setup_kaggle_api():
        dataset_base_path = download_kaggle_dataset(
            ACTIVE_DATASET_CONFIG["kaggle_slug"],
            ACTIVE_DATASET_CONFIG["base_path_segment"]
        )

        # Determine the actual image data root by checking potential relative paths
        found_path = False
        for rel_path in ACTIVE_DATASET_CONFIG["image_data_path_relative"]:
            potential_path = os.path.join(dataset_base_path, rel_path)
            if os.path.exists(potential_path) and os.path.isdir(potential_path):
                actual_image_data_root = potential_path
                print(f"Image data found at: {actual_image_data_root}")
                found_path = True
                break
        if not found_path:
            print(f"ERROR: Could not find image data directory for {SELECTED_DATASET_NAME} at expected relative paths within {dataset_base_path}.")
            print(f"Please check the 'image_data_path_relative' in DATASET_CONFIGS for {SELECTED_DATASET_NAME} and the unzipped Kaggle dataset structure.")
            # Optionally, list the contents of dataset_base_path to help debug
            print(f"Contents of {dataset_base_path}:")
            run_command(f"ls -lR {dataset_base_path}")
            # Set actual_image_data_root to None to prevent dataloader creation if path is not found
            actual_image_data_root = None
    else:
        print("Kaggle API setup failed. Cannot download Kaggle dataset.")
# Add MedMNIST download trigger if a MedMNIST dataset is selected (not implemented in this example config)
# elif ACTIVE_DATASET_CONFIG.get("source") == "medmnist":
#     download_medmnist_datasets()
#     actual_image_data_root = os.path.join("./medmnist_data", ACTIVE_DATASET_CONFIG["medmnist_name"], "train") # Adjust as needed


METHOD 2: DOWNLOADING DATASETS FROM KAGGLE
Executing: pip install --quiet kaggle
Command executed successfully!

Please upload your 'kaggle.json' file (if not already configured):
'kaggle.json' or ~/.kaggle/kaggle.json seems to exist. Skipping upload.
Executing: mkdir -p ~/.kaggle
Command executed successfully!
Executing: cp kaggle.json ~/.kaggle/
Command executed successfully!
Executing: chmod 600 ~/.kaggle/kaggle.json
Command executed successfully!

Kaggle API configured successfully (or was already configured).

--- Downloading maedemaftouni/large-covid19-ct-slice-dataset from Kaggle ---
Executing: kaggle datasets download -d maedemaftouni/large-covid19-ct-slice-dataset -p ./kaggle_data --force
Output:
 Dataset URL: https://www.kaggle.com/datasets/maedemaftouni/large-covid19-ct-slice-dataset
License(s): other
Downloading large-covid19-ct-slice-dataset.zip to ./kaggle_data



In [32]:
# ==============================================================================
#  CUSTOM DATASET CLASSES
# ==============================================================================
# Generic Dataset Class for image folders (can be used for ChestXRay, COVID-CT if structured with class subfolders)
class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None, expected_channels=3):
        """
        Args:
            root_dir (string): Directory with all the image class folders (e.g., 'train' folder containing 'NORMAL' and 'PNEUMONIA').
            transform (callable, optional): Optional transform to be applied on a sample.
            expected_channels (int): 1 for grayscale, 3 for RGB. For .convert()
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = []
        self.labels = []
        self.class_names = []

        if not os.path.exists(root_dir):
            print(f"Error: Root directory for dataset not found at {root_dir}")
            return

        # Assuming structure: root_dir/class_name/image.png
        sorted_classes = sorted(os.listdir(root_dir)) # Sort for consistent label assignment
        for class_name in sorted_classes:
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue # Skip stray files so labels stay contiguous (0 .. num_classes - 1)
            label = len(self.class_names) # Index among class folders only
            self.class_names.append(class_name)
            for img_name in os.listdir(class_path):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.bmp')):
                    self.image_files.append(os.path.join(class_path, img_name))
                    self.labels.append(label) # Assign label based on class folder index

        if not self.image_files:
            print(f"Warning: No image files found in {root_dir} or its subdirectories.")
        else:
            print(f"Found {len(self.image_files)} images in {len(self.class_names)} classes from {root_dir}.")
            print(f"Classes found: {self.class_names}")


    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.image_files[idx]
        try:
            # Ensure consistent number of channels based on what the transform expects
            # For this GAN, we usually convert to RGB first, then Grayscale in transform if needed.
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image and label or raise error
            dummy_shape = (ACTIVE_DATASET_CONFIG['img_size'], ACTIVE_DATASET_CONFIG['img_size'])
            if ACTIVE_DATASET_CONFIG['img_channels'] == 1:
                return torch.zeros((1, *dummy_shape)), torch.tensor(0, dtype=torch.long)
            else:
                return torch.zeros((3, *dummy_shape)), torch.tensor(0, dtype=torch.long)

        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# Note: BraTS2020Dataset is more complex due to NIfTI files and 3D data.
# The placeholder is kept here, but a full implementation is beyond this scope.
# class BraTS2020Dataset(Dataset): ...
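
The transform in the next cell scales pixels to `[0, 1]` and then normalizes with mean 0.5 and standard deviation 0.5, which maps them to `[-1, 1]`. That range matches a generator ending in `tanh`, which is a common GAN convention and an assumption here. The arithmetic and its inverse, needed before displaying generated images, look like this in a small NumPy sketch with hypothetical helper names:

```python
import numpy as np

def normalize(pixels, mean=0.5, std=0.5):
    """Map values in [0, 1] to [-1, 1] via (x - mean) / std."""
    return (pixels - mean) / std

def denormalize(values, mean=0.5, std=0.5):
    """Invert normalize so images can be displayed in [0, 1] again."""
    return values * std + mean

pixels = np.array([0.0, 0.5, 1.0])
normalized = normalize(pixels)      # [-1.0, 0.0, 1.0]
restored = denormalize(normalized)  # back to [0.0, 0.5, 1.0]
print(normalized, restored)
```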

In [34]:
# ==============================================================================
#  DATA TRANSFORMATION AND DATALOADER CREATION
# ==============================================================================
# These should be available from your ACTIVE_DATASET_CONFIG
IMG_SIZE = ACTIVE_DATASET_CONFIG['img_size']
IMG_CHANNELS = ACTIVE_DATASET_CONFIG['img_channels']

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(num_output_channels=IMG_CHANNELS), # Collapse to grayscale; accepts 1 or 3 output channels (3 replicates the gray channel, color is not preserved)
    transforms.ToTensor(), # Converts PIL image or numpy.ndarray to tensor and scales to [0, 1]
    transforms.Normalize((0.5,) * IMG_CHANNELS, (0.5,) * IMG_CHANNELS) # Normalizes to [-1, 1] for each channel
])

# --- Attempt to create DataLoader for the selected dataset ---
main_dataloader = None # Initialize to None

if actual_image_data_root and os.path.exists(actual_image_data_root) and os.path.isdir(actual_image_data_root):
    # The ImageFolderDataset class itself prints messages about found images/classes or errors during its initialization.
    print(f"Attempting to load dataset from: {actual_image_data_root}")
    selected_dataset = ImageFolderDataset(root_dir=actual_image_data_root, transform=transform)

    if len(selected_dataset) > 0: # Proceed only if the dataset found some images
        # --- Dynamically update num_classes based on dataset scan ---
        actual_num_classes_found = len(selected_dataset.class_names)

        if actual_num_classes_found > 0:
            # Case 1: Classes (subdirectories with images) were found in the dataset directory
            if actual_num_classes_found != ACTIVE_DATASET_CONFIG['num_classes']:
                print(f"INFO: Dynamically updating num_classes for '{SELECTED_DATASET_NAME}'.")
                print(f"  Config had: {ACTIVE_DATASET_CONFIG['num_classes']} classes.")
                print(f"  Dataset scan found: {actual_num_classes_found} classes ({selected_dataset.class_names}).")
                ACTIVE_DATASET_CONFIG['num_classes'] = actual_num_classes_found
                print(f"  ACTIVE_DATASET_CONFIG['num_classes'] is now: {ACTIVE_DATASET_CONFIG['num_classes']}.")
            else:
                # Number of classes found matches the config, no update needed, but good to confirm.
                print(f"INFO: Confirmed num_classes for '{SELECTED_DATASET_NAME}'.")
                print(f"  Config and dataset scan both indicate: {actual_num_classes_found} classes ({selected_dataset.class_names}).")

        elif actual_num_classes_found == 0:
            # Case 2: No class subdirectories with images were found by ImageFolderDataset
            # This means root_dir itself might contain images, or it's truly empty of class structures.
            # ImageFolderDataset as written expects class subfolders. If it finds 0 class_names,
            # it implies no such subfolders were processed.
            print(f"WARNING: No class subdirectories with images were found by ImageFolderDataset in '{actual_image_data_root}'.")
            if ACTIVE_DATASET_CONFIG['num_classes'] > 0: # If config expected classes
                print(f"  Config for '{SELECTED_DATASET_NAME}' expected {ACTIVE_DATASET_CONFIG['num_classes']} classes.")
                print(f"  This discrepancy might lead to issues if models require conditional input based on num_classes > 0.")
                # Decision point: What to do?
                # Option A: Abort if classes are strictly necessary:
                #   raise ValueError(f"Dataset at {actual_image_data_root} has no class subfolders, but config expects {ACTIVE_DATASET_CONFIG['num_classes']}.")
                # Option B: Fallback to treating it as unconditional (num_classes = 1 or effectively ignored by model)
                #   ACTIVE_DATASET_CONFIG['num_classes'] = 1 # Or handle this in model appropriately
                #   print(f"  Setting num_classes to 1 as a fallback for unconditional processing.")
                # For now, we'll just warn and proceed with the config's num_classes.
                # The models should be robust or this case should be handled by specific dataset logic if needed.
                print(f"  Proceeding with config's num_classes = {ACTIVE_DATASET_CONFIG['num_classes']}. Verify model compatibility.")

            else: # actual_num_classes_found == 0 and ACTIVE_DATASET_CONFIG['num_classes'] == 0 (or was already 0)
                print(f"  Config for '{SELECTED_DATASET_NAME}' also has num_classes = 0. Assuming unconditional GAN setup or this is intended.")
        # --- End of dynamic num_classes update ---

        # Instantiate DataLoader
        # Ensure BATCH_SIZE and NUM_WORKERS are defined (e.g., in your Hyperparameters section)
        # Using typical values here as placeholders if they aren't globally defined for this snippet.
        current_batch_size = 64 # Replace with your BATCH_SIZE variable if defined elsewhere
        current_num_workers = 2  # Replace with your NUM_WORKERS variable if defined elsewhere

        main_dataloader = DataLoader(
            selected_dataset,
            batch_size=current_batch_size,
            shuffle=True, # Shuffle data for training
            num_workers=current_num_workers,
            pin_memory=True # Can speed up data transfer to GPU if available
        )

        # Report the final state after attempting DataLoader creation
        print(f"--- DataLoader Creation Summary for '{SELECTED_DATASET_NAME}' ---")
        print(f"  Dataset path: {actual_image_data_root}")
        print(f"  Total images found by Dataset class: {len(selected_dataset)}")
        print(f"  Number of classes for model (from ACTIVE_DATASET_CONFIG): {ACTIVE_DATASET_CONFIG['num_classes']}")
        if selected_dataset.class_names:
             print(f"  Classes detected in dataset structure: {selected_dataset.class_names}")
        print(f"  DataLoader created successfully.")

        # Optional: Display an example batch (uncomment to use for verification)
        # print(f"\n--- Verifying DataLoader: Example Batch from '{SELECTED_DATASET_NAME}' ---")
        # try:
        #     example_images, example_labels = next(iter(main_dataloader))
        #     print(f"  Batch details - Images shape: {example_images.shape}, Labels shape: {example_labels.shape}")
        #     print(f"  Example labels in batch: {example_labels[:5].tolist()}...") # Show first 5 labels
        #     if IMG_CHANNELS == 1 and example_images.nelement() > 0:
        #         plt.figure(figsize=(3,3))
        #         plt.imshow(example_images[0].squeeze().cpu().numpy(), cmap='gray')
        #         plt.title(f"Label: {example_labels[0].item()}")
        #         plt.axis('off')
        #         plt.show()
        #     elif example_images.nelement() > 0: # For RGB
        #          plt.figure(figsize=(3,3))
        #          # Unnormalize for display: (data * 0.5) + 0.5
        #          img_to_show = (example_images[0].permute(1,2,0).cpu().numpy() * 0.5) + 0.5
        #          plt.imshow(img_to_show.clip(0,1)) # Clip to ensure valid range for imshow
        #          plt.title(f"Label: {example_labels[0].item()}")
        #          plt.axis('off')
        #          plt.show()
        # except StopIteration:
        #     print("  ERROR: Could not retrieve an example batch (DataLoader might be empty unexpectedly).")
        # except Exception as e:
        #     print(f"  ERROR displaying example batch: {e}")

    else: # This 'else' corresponds to 'if len(selected_dataset) > 0'
        print(f"ERROR: Dataset for '{SELECTED_DATASET_NAME}' was initialized but found 0 images in '{actual_image_data_root}'.")
        print(f"  DataLoader cannot be created with an empty dataset.")
        if selected_dataset.class_names: # If class folders existed but were empty of valid images
            print(f"  Note: Class folders were detected by ImageFolderDataset: {selected_dataset.class_names}, but they contained no loadable image files.")
        elif not os.listdir(actual_image_data_root):
             print(f"  Note: The directory '{actual_image_data_root}' appears to be empty.")
        else:
            print(f"  Note: Please check the contents of '{actual_image_data_root}' for valid image files and expected class subfolder structure.")

else: # This 'else' corresponds to 'if actual_image_data_root and os.path.exists(actual_image_data_root) ...'
    if not actual_image_data_root:
        print(f"ERROR: Image data root path for '{SELECTED_DATASET_NAME}' was not determined (variable is None). DataLoader not created.")
    elif not os.path.exists(actual_image_data_root):
        print(f"ERROR: Image data root path for '{SELECTED_DATASET_NAME}' ('{actual_image_data_root}') does not exist. DataLoader not created.")
    elif not os.path.isdir(actual_image_data_root):
        print(f"ERROR: Image data root path for '{SELECTED_DATASET_NAME}' ('{actual_image_data_root}') is not a directory. DataLoader not created.")

# The rest of your script (Model Definitions, Hyperparameters & Setup, etc.)
# will proceed. If main_dataloader is None, the training loop should have a check.
# The num_classes used for model initialization should be taken from ACTIVE_DATASET_CONFIG['num_classes']
# E.g., in your "Hyperparameters & Setup" section:
#   num_classes_for_model = ACTIVE_DATASET_CONFIG['num_classes']
#   generator = Generator(latent_dim, num_classes_for_model, img_shape_tuple)
#   discriminator = BCEDiscriminator(num_classes_for_model, img_shape_tuple)

Attempting to load dataset from: ./kaggle_data/covid_ct_dataset/curated_data/curated_data
Found 17104 images in 3 classes from ./kaggle_data/covid_ct_dataset/curated_data/curated_data.
Classes found: ['1NonCOVID', '2COVID', '3CAP']
INFO: Dynamically updating num_classes for 'covid_ct'.
  Config had: 2 classes.
  Dataset scan found: 3 classes (['1NonCOVID', '2COVID', '3CAP']).
  ACTIVE_DATASET_CONFIG['num_classes'] is now: 3.
--- DataLoader Creation Summary for 'covid_ct' ---
  Dataset path: ./kaggle_data/covid_ct_dataset/curated_data/curated_data
  Total images found by Dataset class: 17104
  Number of classes for model (from ACTIVE_DATASET_CONFIG): 3
  Classes detected in dataset structure: ['1NonCOVID', '2COVID', '3CAP']
  DataLoader created successfully.
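
As a quick sanity check on the summary above, the number of batches per epoch follows from the dataset size and the batch size. The sketch below assumes the DataLoader keeps its default `drop_last=False` (the call above does not set it), so the final, smaller batch is kept. The helper name `batches_per_epoch` is illustrative only and does not appear elsewhere in this notebook.

```python
import math

def batches_per_epoch(num_images: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_images // batch_size
    return math.ceil(num_images / batch_size)

num_images = 17104   # "Total images found by Dataset class" in the summary above
batch_size = 64      # current_batch_size used when building the DataLoader

n_batches = batches_per_epoch(num_images, batch_size)
last_batch = num_images - (n_batches - 1) * batch_size
print(f"{n_batches} batches per epoch, last batch has {last_batch} images")
```

With these numbers the result is 268 batches, which agrees with the `[Batch 0/268]` counter printed by the training loop later in this notebook; the last batch holds 16 images.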


# 3. Generator Network

The Generator takes a random noise vector and a class label as input and generates an image.

In [35]:
# ==============================================================================
#  MODEL DEFINITIONS (Generator)
# ==============================================================================
# --- Generator ---
class Generator(nn.Module):
    def __init__(self, latent_dim, num_classes, img_shape_tuple): # img_shape_tuple (C, H, W)
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.img_shape_tuple = img_shape_tuple
        # The generator is a plain MLP: it maps the concatenated label embedding and noise
        # vector to a flattened image with C*H*W values, which forward() reshapes to (C, H, W).
        # CNN-based (DCGAN-style) generators are generally better suited to image data.
        self.mlp_model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256), # BatchNorm after Linear, before activation
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, int(np.prod(self.img_shape_tuple))),
            nn.Tanh() # To scale output to [-1, 1]
        )

    def forward(self, noise, labels):
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img_flat = self.mlp_model(gen_input)
        img = img_flat.view(img_flat.size(0), *self.img_shape_tuple)
        return img
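
To get a feel for the size of this MLP, the sketch below counts the trainable parameters of the label embedding and the layers used in `forward`. It takes `latent_dim = 100` and `num_classes = 3` from this notebook and assumes a 1×64×64 image purely for illustration, since the real shape comes from `ACTIVE_DATASET_CONFIG`. The helper `mlp_generator_params` is hypothetical and not part of the notebook.

```python
import math

def mlp_generator_params(latent_dim, num_classes, img_shape, hidden=(128, 256, 512, 1024)):
    """Count parameters of the label embedding plus the MLP path used in forward()."""
    out_features = math.prod(img_shape)
    sizes = [latent_dim + num_classes, *hidden, out_features]
    # Each Linear layer has in*out weights plus out biases.
    linear = sum(i * o + o for i, o in zip(sizes[:-1], sizes[1:]))
    # BatchNorm1d follows every hidden layer except the first:
    # one learnable scale and one learnable shift per feature.
    batchnorm = sum(2 * h for h in hidden[1:])
    # nn.Embedding(num_classes, num_classes) stores a num_classes x num_classes table.
    embedding = num_classes * num_classes
    return linear + batchnorm + embedding

# Assumed image shape (1, 64, 64); replace with the shape from your dataset config.
total = mlp_generator_params(latent_dim=100, num_classes=3, img_shape=(1, 64, 64))
print(f"Trainable parameters: {total:,}")
```

Most of the count comes from the final `Linear(1024, C*H*W)` layer, so the size of this MLP grows roughly in proportion to the number of output pixels.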

# 4. Discriminator Network

The Discriminator takes an image and a class label as input and determines if the image is real or fake.

In [36]:
# ==============================================================================
#  MODEL DEFINITIONS (Discriminator)
# ==============================================================================
# --- Discriminator (for BCEWithLogitsLoss) ---
class BCEDiscriminator(nn.Module):
    def __init__(self, num_classes, img_shape_tuple): # img_shape_tuple (C, H, W)
        super(BCEDiscriminator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.img_shape_tuple = img_shape_tuple

        self.model = nn.Sequential(
            nn.Linear(num_classes + int(np.prod(self.img_shape_tuple)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1) # Output layer (no sigmoid, BCEWithLogitsLoss handles it)
        )

    def forward(self, img, labels):
        img_flat = img.view(img.size(0), -1)
        d_in = torch.cat((img_flat, self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity
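
The Discriminator returns raw logits because `BCEWithLogitsLoss` applies the sigmoid internally in a numerically stable way. The pure-Python sketch below illustrates the idea for a single logit. It uses a common stable formulation and is meant as an illustration under that assumption, not as a copy of PyTorch's implementation; the function names are made up for this example.

```python
import math

def bce_naive(logit: float, target: float) -> float:
    """Sigmoid followed by binary cross-entropy; breaks down for large |logit|."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

def bce_with_logits(logit: float, target: float) -> float:
    """Stable form: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# Both forms agree for moderate logits.
print(bce_naive(2.0, 1.0), bce_with_logits(2.0, 1.0))

# An undecided discriminator (logit 0) gives ln(2) for either target.
print(bce_with_logits(0.0, 1.0), math.log(2))

# For an extreme logit the sigmoid saturates to exactly 1.0 and the naive form
# tries to take log(0); the stable form still returns a finite loss.
try:
    bce_naive(40.0, 0.0)
except ValueError as err:
    print("naive form failed:", err)
print(bce_with_logits(40.0, 0.0))
```

The value ln 2 ≈ 0.693 is a useful reference point: the first D loss logged by the training loop (0.6987) is consistent with an untrained discriminator producing logits close to zero.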

# 5. Training Loop

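This is the core of the GAN training process. The Generator and Discriminator are trained adversarially: the Discriminator learns to separate real images from generated ones, while the Generator learns to produce images the Discriminator accepts as real. The cells below set the hyperparameters, load any existing checkpoints, and then run the loop.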

In [37]:
# ==============================================================================
#  HYPERPARAMETERS & SETUP
# ==============================================================================
latent_dim = 100
num_classes = ACTIVE_DATASET_CONFIG['num_classes']
# Image shape tuple: (channels, height, width)
img_shape_tuple = (ACTIVE_DATASET_CONFIG['img_channels'], ACTIVE_DATASET_CONFIG['img_size'], ACTIVE_DATASET_CONFIG['img_size'])
epochs = 100
lr = 0.0002
b1 = 0.5  # Adam optimizer beta1
b2 = 0.999 # Adam optimizer beta2
batch_size = 64 # Already used in DataLoader, ensure consistency or pass from here

# --- Directories for saving ---
LOCAL_CHECKPOINT_DIR_NAME = f"gan_checkpoints_{SELECTED_DATASET_NAME}"
GDRIVE_CHECKPOINT_DIR_NAME = f"gan_checkpoints_{SELECTED_DATASET_NAME}" # Same name, but on Drive
LOCAL_IMAGE_DIR = f"gan_images_{SELECTED_DATASET_NAME}"

os.makedirs(LOCAL_CHECKPOINT_DIR_NAME, exist_ok=True)
os.makedirs(LOCAL_IMAGE_DIR, exist_ok=True)

# Google Drive path (ensure Drive is mounted)
gdrive_base_path = '/content/drive/My Drive/'
gdrive_save_path = None
if os.path.exists(gdrive_base_path): # Check if drive was mounted
    gdrive_save_path = os.path.join(gdrive_base_path, GDRIVE_CHECKPOINT_DIR_NAME)
    os.makedirs(gdrive_save_path, exist_ok=True)
    print(f"Checkpoints will be saved to Google Drive at: {gdrive_save_path}")
else:
    print("Google Drive not mounted. Checkpoints will only be saved locally.")


# --- Initialize Models, Loss, and Optimizers ---
generator = Generator(latent_dim, num_classes, img_shape_tuple)
discriminator = BCEDiscriminator(num_classes, img_shape_tuple)
adversarial_loss = torch.nn.BCEWithLogitsLoss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)
print(f"Using device: {device}")

Checkpoints will be saved to Google Drive at: /content/drive/My Drive/gan_checkpoints_covid_ct
Using device: cpu


In [39]:
# ==============================================================================
#  CHECKPOINT LOADING CONFIGURATION & FUNCTIONALITY
# ==============================================================================
LOAD_FROM_CHECKPOINT = True  # SET TO True TO ENABLE LOADING
START_EPOCH = 0 # Will be updated if checkpoint is loaded

def find_latest_checkpoint_epoch(checkpoint_dir, model_prefix):
    """Finds the latest epoch number for a given model prefix in a checkpoint directory."""
    latest_epoch = -1
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory not found: {checkpoint_dir}")
        return latest_epoch # Directory doesn't exist

    print(f"Scanning for checkpoints in: {checkpoint_dir} with prefix: {model_prefix}")
    found_files = []
    for f in os.listdir(checkpoint_dir):
        if f.startswith(model_prefix) and f.endswith(".pth"):
            # Extract the epoch number from names such as 'generator_epoch_10.pth'
            match = re.search(r"_epoch_(\d+)\.pth$", f)
            if match:
                epoch_num = int(match.group(1))
                found_files.append({'file': f, 'epoch': epoch_num})
                if epoch_num > latest_epoch:
                    latest_epoch = epoch_num
    if not found_files:
        print("No matching checkpoint files found.")
    else:
        print(f"Found {len(found_files)} potential checkpoint files. Latest epoch determined: {latest_epoch}")
    return latest_epoch

if LOAD_FROM_CHECKPOINT:
    if gdrive_save_path and os.path.exists(gdrive_save_path):
        print(f"Attempting to load latest checkpoints from Google Drive: {gdrive_save_path}")
        checkpoint_source_dir = gdrive_save_path
    elif os.path.exists(LOCAL_CHECKPOINT_DIR_NAME): # Fallback to local if gdrive not available but local exists
        print(f"Google Drive path not available or does not exist. Attempting to load from local directory: {LOCAL_CHECKPOINT_DIR_NAME}")
        checkpoint_source_dir = LOCAL_CHECKPOINT_DIR_NAME
    else:
        print("Checkpoint loading enabled, but neither Google Drive path nor local checkpoint directory is available or specified.")
        checkpoint_source_dir = None

    if checkpoint_source_dir:
        latest_g_epoch = find_latest_checkpoint_epoch(checkpoint_source_dir, "generator_epoch_")
        latest_d_epoch = find_latest_checkpoint_epoch(checkpoint_source_dir, "discriminator_epoch_")

        # Determine the epoch to load from. Prefer a matched pair.
        # If epochs are different, decide on a strategy (e.g., use the minimum of the two, or generator's, or error out)
        # For this implementation, we'll try to load from the latest generator epoch and find a matching discriminator.
        # If no matching discriminator, we'll use the discriminator's own latest epoch.

        load_epoch_g = latest_g_epoch
        load_epoch_d = latest_d_epoch

        if load_epoch_g != -1: # Generator checkpoint(s) found
            gen_checkpoint_path = os.path.join(checkpoint_source_dir, f"generator_epoch_{load_epoch_g}.pth")
            try:
                print(f"Loading Generator from: {gen_checkpoint_path}")
                generator.load_state_dict(torch.load(gen_checkpoint_path, map_location=device))
                print(f"Successfully loaded Generator from epoch {load_epoch_g}.")

                # Try to load discriminator from the same epoch as generator for consistency
                disc_checkpoint_path_matched = os.path.join(checkpoint_source_dir, f"discriminator_epoch_{load_epoch_g}.pth")
                if os.path.exists(disc_checkpoint_path_matched):
                    print(f"Loading matching Discriminator from: {disc_checkpoint_path_matched}")
                    discriminator.load_state_dict(torch.load(disc_checkpoint_path_matched, map_location=device))
                    print(f"Successfully loaded Discriminator from epoch {load_epoch_g}.")
                    START_EPOCH = load_epoch_g + 1
                elif load_epoch_d != -1: # No matched discriminator, but discriminator checkpoints exist
                    disc_checkpoint_path_latest_d = os.path.join(checkpoint_source_dir, f"discriminator_epoch_{load_epoch_d}.pth")
                    print(f"No Discriminator checkpoint found for generator's epoch {load_epoch_g}. Trying latest D epoch: {load_epoch_d}")
                    print(f"Loading Discriminator from: {disc_checkpoint_path_latest_d}")
                    discriminator.load_state_dict(torch.load(disc_checkpoint_path_latest_d, map_location=device))
                    print(f"Successfully loaded Discriminator from its latest epoch {load_epoch_d}.")
                    # Determine start epoch carefully if G and D epochs are different
                    # Simplest is to start from max(load_epoch_g, load_epoch_d) + 1 or min(...) + 1
                    # Or just use the generator's epoch as the primary reference for resuming
                    START_EPOCH = load_epoch_g + 1
                    if load_epoch_g != load_epoch_d:
                        print(f"Warning: Generator loaded from epoch {load_epoch_g}, Discriminator from epoch {load_epoch_d}. Resuming from epoch {START_EPOCH}.")
                else: # No discriminator checkpoints found at all
                    print(f"No Discriminator checkpoints found in {checkpoint_source_dir}. Discriminator will be initialized from scratch.")
                    START_EPOCH = load_epoch_g + 1 # Still resume based on generator

                print(f"Resuming training from epoch {START_EPOCH}.")

            except Exception as e:
                # This handler also catches failures raised while loading the Discriminator.
                print(f"Error loading checkpoint(s): {e}. Training will start from epoch 0.")
                START_EPOCH = 0

        elif load_epoch_d != -1: # Only discriminator checkpoints found, no generator
            print("Found Discriminator checkpoints but no Generator checkpoints.")
            print("Cannot resume training without a Generator checkpoint. Training will start from epoch 0.")
            START_EPOCH = 0
        else:
            print(f"No suitable Generator or Discriminator checkpoints found in {checkpoint_source_dir}.")
            START_EPOCH = 0 # Ensure it's 0 if nothing is loaded

    else: # checkpoint_source_dir was None
        print("No checkpoint directory specified or found. Training will start from epoch 0.")
        START_EPOCH = 0
else:
    print("LOAD_FROM_CHECKPOINT is False. Training will start from epoch 0.")
    START_EPOCH = 0

print(f"Final START_EPOCH for training: {START_EPOCH}")

Attempting to load latest checkpoints from Google Drive: /content/drive/My Drive/gan_checkpoints_covid_ct
Scanning for checkpoints in: /content/drive/My Drive/gan_checkpoints_covid_ct with prefix: generator_epoch_
No matching checkpoint files found.
Scanning for checkpoints in: /content/drive/My Drive/gan_checkpoints_covid_ct with prefix: discriminator_epoch_
No matching checkpoint files found.
No suitable Generator or Discriminator checkpoints found in /content/drive/My Drive/gan_checkpoints_covid_ct.
Final START_EPOCH for training: 0
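
The checkpoint scan relies on the `<prefix><epoch>.pth` naming used when the models are saved. The self-contained sketch below reproduces that lookup against a temporary directory. The file names it creates and the helper `latest_epoch` are illustrative assumptions rather than part of the notebook.

```python
import os
import re
import tempfile

def latest_epoch(checkpoint_dir: str, prefix: str) -> int:
    """Return the highest epoch among '<prefix><N>.pth' files, or -1 if none exist."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)\.pth$")
    epochs = [
        int(m.group(1))
        for name in os.listdir(checkpoint_dir)
        if (m := pattern.match(name))
    ]
    return max(epochs, default=-1)

with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files; only the names matter for the scan.
    for name in ["generator_epoch_0.pth", "generator_epoch_10.pth",
                 "generator_epoch_9.pth", "discriminator_epoch_10.pth", "notes.txt"]:
        open(os.path.join(tmp, name), "w").close()

    g_latest = latest_epoch(tmp, "generator_epoch_")
    d_latest = latest_epoch(tmp, "discriminator_epoch_")
    missing = latest_epoch(tmp, "vae_epoch_")
    print(g_latest, d_latest, missing)
```

Parsing the epoch as an integer matters here: a plain string comparison would rank `epoch_9` above `epoch_10`.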


In [None]:
# ==============================================================================
#  TRAINING LOOP
# ==============================================================================
print("Starting Training Loop...")
save_interval = 10 # Save models and images every N epochs

# Check if the dataloader was created successfully
if main_dataloader is None:
    print("FATAL: Dataloader not found or empty. Aborting training.")
    # Exit or handle as appropriate, script will end if this is None due to earlier logic
else:
    for epoch in range(START_EPOCH, epochs):
        for i, (imgs, labels) in enumerate(main_dataloader):
            real_imgs = imgs.to(device)
            labels = labels.to(device)

            # Adversarial ground truths
            valid = torch.ones(imgs.size(0), 1, requires_grad=False).to(device)
            fake = torch.zeros(imgs.size(0), 1, requires_grad=False).to(device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            z = torch.randn(imgs.size(0), latent_dim).to(device)
            gen_labels = torch.randint(0, num_classes, (imgs.size(0),), device=device, dtype=torch.long)
            gen_imgs = generator(z, gen_labels)

            validity_real = discriminator(real_imgs, labels)
            d_real_loss = adversarial_loss(validity_real, valid)

            validity_fake = discriminator(gen_imgs.detach(), gen_labels)
            d_fake_loss = adversarial_loss(validity_fake, fake)

            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Generate a fresh batch of fake images for the generator's turn
            z_g = torch.randn(imgs.size(0), latent_dim).to(device)
            gen_labels_g = torch.randint(0, num_classes, (imgs.size(0),), device=device, dtype=torch.long)
            gen_imgs_g = generator(z_g, gen_labels_g)

            validity = discriminator(gen_imgs_g, gen_labels_g)
            g_loss = adversarial_loss(validity, valid)

            g_loss.backward()
            optimizer_G.step()

            if i % 100 == 0: # Print progress
                print(
                    f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(main_dataloader)}] "
                    f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                )

        # --- Save generated images and model checkpoints periodically ---
        if epoch % save_interval == 0 or epoch == epochs - 1:  # Save on interval or last epoch
            print(f"--- Saving models and images for epoch {epoch} ---")

            # Generate and save sample images.
            # Note: the noise is re-drawn at every save, so grids from different epochs are not
            # sample-for-sample comparable. Create fixed_z once before the loop if that is needed.
            fixed_z_samples = num_classes * 5
            fixed_z = torch.randn(fixed_z_samples, latent_dim, device=device)
            # Group labels by class (0,0,0,0,0, 1,1,1,1,1, ...) so that row r of the grid,
            # which reads indices r*5 .. r*5+4 below, really contains class r.
            fixed_sample_labels = torch.arange(num_classes, device=device).repeat_interleave(5)

            with torch.no_grad():  # No gradients are needed for sampling
                gen_imgs_sample = generator(fixed_z, fixed_sample_labels).cpu()

            fig, axs = plt.subplots(num_classes, 5, figsize=(10, num_classes * 2 if num_classes > 1 else 3))
            fig.suptitle(f"Generated Images at Epoch {epoch} ({SELECTED_DATASET_NAME})", fontsize=16)
            for r_idx in range(num_classes):
                for c_idx in range(5):
                    img_index = r_idx * 5 + c_idx
                    if img_index < len(gen_imgs_sample):
                        # Transpose from (C, H, W) to (H, W, C) for display
                        # Handle single channel (grayscale) or multi-channel
                        img_display = gen_imgs_sample[img_index].permute(1, 2, 0) * 0.5 + 0.5 # Un-normalize
                        current_ax = axs[r_idx, c_idx] if num_classes > 1 else axs[c_idx]

                        if img_display.shape[2] == 1: # Grayscale
                            current_ax.imshow(img_display.squeeze(), cmap='gray')
                        else: # RGB
                            current_ax.imshow(img_display)
                        current_ax.set_title(f"Class {r_idx}")
                        current_ax.axis('off')

            image_save_path = os.path.join(LOCAL_IMAGE_DIR, f"epoch_{epoch}.png")
            plt.savefig(image_save_path)
            plt.close(fig) # Close the plot
            print(f"Sample images saved to {image_save_path}")

            # Save model checkpoints (locally and to Google Drive if available)
            local_g_path = os.path.join(LOCAL_CHECKPOINT_DIR_NAME, f"generator_epoch_{epoch}.pth")
            local_d_path = os.path.join(LOCAL_CHECKPOINT_DIR_NAME, f"discriminator_epoch_{epoch}.pth")
            torch.save(generator.state_dict(), local_g_path)
            torch.save(discriminator.state_dict(), local_d_path)
            print(f"Successfully saved Generator locally to: {local_g_path}")
            print(f"Successfully saved Discriminator locally to: {local_d_path}")

            if gdrive_save_path:
                gdrive_g_path = os.path.join(gdrive_save_path, f"generator_epoch_{epoch}.pth")
                gdrive_d_path = os.path.join(gdrive_save_path, f"discriminator_epoch_{epoch}.pth")
                try:
                    torch.save(generator.state_dict(), gdrive_g_path)
                    torch.save(discriminator.state_dict(), gdrive_d_path)
                    print(f"Successfully saved Generator to Google Drive: {gdrive_g_path}")
                    print(f"Successfully saved Discriminator to Google Drive: {gdrive_d_path}")
                except Exception as e:
                    print(f"Error saving checkpoints to Google Drive: {e}")

    print("Training finished.")

    # No separate final save is needed here: the `epoch == epochs - 1` condition inside
    # the loop already saves the models and sample images for the last epoch.

Starting Training Loop...
[Epoch 0/100] [Batch 0/268] [D loss: 0.6987] [G loss: 0.7231]
[Epoch 0/100] [Batch 100/268] [D loss: 0.6050] [G loss: 0.8805]
[Epoch 0/100] [Batch 200/268] [D loss: 0.5998] [G loss: 1.0380]
--- Saving models and images for epoch 0 ---
Sample images saved to gan_images_covid_ct/epoch_0.png
Successfully saved Generator locally to: gan_checkpoints_covid_ct/generator_epoch_0.pth
Successfully saved Discriminator locally to: gan_checkpoints_covid_ct/discriminator_epoch_0.pth
Successfully saved Generator to Google Drive: /content/drive/My Drive/gan_checkpoints_covid_ct/generator_epoch_0.pth
Successfully saved Discriminator to Google Drive: /content/drive/My Drive/gan_checkpoints_covid_ct/discriminator_epoch_0.pth
[Epoch 1/100] [Batch 0/268] [D loss: 0.7358] [G loss: 1.3538]
[Epoch 1/100] [Batch 100/268] [D loss: 0.5973] [G loss: 0.9650]
[Epoch 1/100] [Batch 200/268] [D loss: 0.6987] [G loss: 1.3842]
[Epoch 2/100] [Batch 0/268] [D loss: 0.5804] [G loss: 1.0161]
[Epoch
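
The save condition in the loop (`epoch % save_interval == 0 or epoch == epochs - 1`) determines which checkpoints exist when a run is resumed. A small sketch, using the notebook's values `epochs = 100` and `save_interval = 10`; the function name `saved_epochs` is only illustrative.

```python
def saved_epochs(start_epoch: int, epochs: int, save_interval: int) -> list[int]:
    """Epochs at which the training loop writes sample images and checkpoints."""
    return [
        epoch
        for epoch in range(start_epoch, epochs)
        if epoch % save_interval == 0 or epoch == epochs - 1
    ]

fresh_run = saved_epochs(start_epoch=0, epochs=100, save_interval=10)
print(fresh_run)

# Resuming after loading the epoch-40 checkpoints sets START_EPOCH to 41.
resumed_run = saved_epochs(start_epoch=41, epochs=100, save_interval=10)
print(resumed_run)
```

If a session stops between two save points, up to `save_interval - 1` epochs of progress are lost, so a smaller interval may be worth the extra disk space on sessions that can disconnect.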

## Repeat the above steps for each dataset to create each expert


# 6. Sample Usage: Generating Images

Once the model is trained, you can use the generator to create new, synthetic medical images.

In [None]:
# ==============================================================================
#  LOAD PRE-TRAINED GENERATOR AND GENERATE IMAGES
# ==============================================================================

# --- Configuration (Ensure these match the training configuration of the loaded model) ---
# These would ideally be loaded from the same configuration used for training,
# or ensure ACTIVE_DATASET_CONFIG is still in scope from previous cells.

# If ACTIVE_DATASET_CONFIG is not in scope, you might need to redefine it or parts of it:
# Example:
# ACTIVE_DATASET_CONFIG = {
#     "num_classes": 3,  # Replace with the actual number of classes for your saved model
#     "img_channels": 1, # Replace with actual image channels
#     "img_size": 64,    # Replace with actual image size
#     # Add any other relevant keys if your Generator class uses them differently
# }
# latent_dim = 100 # Should be the same as during training

# Ensure model class definition is available (e.g., from a previous cell or redefine here)
# If not defined, you'd paste the Generator class definition here:
# class Generator(nn.Module):
#     def __init__(self, latent_dim, num_classes, img_shape_tuple):
#         super(Generator, self).__init__()
#         # ... (rest of the Generator class definition) ...
#     def forward(self, noise, labels):
#         # ... (rest of the forward pass) ...

# --- Helper function to find the latest checkpoint (if needed) ---
# Same lookup logic as find_latest_checkpoint_epoch in the checkpoint-loading cell,
# without the progress messages.
def find_latest_epoch_in_dir(checkpoint_dir, model_prefix="generator_epoch_"):
    latest_epoch = -1
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory not found: {checkpoint_dir}")
        return latest_epoch
    for f in os.listdir(checkpoint_dir):
        if f.startswith(model_prefix) and f.endswith(".pth"):
            match = re.search(r"_epoch_(\d+)\.pth$", f)  # Raw string keeps \d and \. intact
            if match:
                epoch_num = int(match.group(1))
                if epoch_num > latest_epoch:
                    latest_epoch = epoch_num
    return latest_epoch

# --- Specify Checkpoint Path and Epoch to Load ---
# Option 1: Load a specific epoch
# LOAD_EPOCH = 90 # Example: if you want to load epoch 90
# CHECKPOINT_BASE_DIR = gdrive_save_path # Or LOCAL_CHECKPOINT_DIR_NAME
# generator_checkpoint_filename = f"generator_epoch_{LOAD_EPOCH}.pth"
# generator_checkpoint_path = os.path.join(CHECKPOINT_BASE_DIR, generator_checkpoint_filename)

# Option 2: Load the latest available generator epoch
print("Attempting to load the latest generator checkpoint...")
# Determine which directory to search: Google Drive first, then local.
checkpoint_load_dir = None
if 'gdrive_save_path' in globals() and gdrive_save_path and os.path.exists(gdrive_save_path):
    print(f"Searching in Google Drive path: {gdrive_save_path}")
    checkpoint_load_dir = gdrive_save_path
elif 'LOCAL_CHECKPOINT_DIR_NAME' in globals() and os.path.exists(LOCAL_CHECKPOINT_DIR_NAME):
    print(f"Google Drive path not found or specified. Searching in local path: {LOCAL_CHECKPOINT_DIR_NAME}")
    checkpoint_load_dir = LOCAL_CHECKPOINT_DIR_NAME
else:
    print("ERROR: Neither Google Drive checkpoint path nor local checkpoint directory is defined or found.")
    checkpoint_load_dir = None

loaded_generator = None
if checkpoint_load_dir:
    latest_epoch_to_load = find_latest_epoch_in_dir(checkpoint_load_dir, "generator_epoch_")

    if latest_epoch_to_load != -1:
        generator_checkpoint_filename = f"generator_epoch_{latest_epoch_to_load}.pth"
        generator_checkpoint_path = os.path.join(checkpoint_load_dir, generator_checkpoint_filename)
        print(f"Found latest generator checkpoint: {generator_checkpoint_path}")

        # --- Initialize Model ---
        # Ensure these parameters match the saved model's training configuration.
        # These should be accessible from previous cells (e.g., ACTIVE_DATASET_CONFIG)
        current_num_classes = ACTIVE_DATASET_CONFIG['num_classes']
        current_img_channels = ACTIVE_DATASET_CONFIG['img_channels']
        current_img_size = ACTIVE_DATASET_CONFIG['img_size']
        current_img_shape_tuple = (current_img_channels, current_img_size, current_img_size)
        # latent_dim should also be available from training setup

        # Instantiate the generator
        loaded_generator = Generator(latent_dim, current_num_classes, current_img_shape_tuple)

        # --- Load State Dictionary ---
        try:
            # Ensure `device` is defined (e.g., torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            loaded_generator.load_state_dict(torch.load(generator_checkpoint_path, map_location=device))
            loaded_generator.to(device) # Move model to the device
            loaded_generator.eval() # Set the model to evaluation mode (important for layers like BatchNorm, Dropout)
            print(f"Generator successfully loaded from {generator_checkpoint_path} and set to evaluation mode.")
        except FileNotFoundError:
            print(f"ERROR: Checkpoint file not found at {generator_checkpoint_path}")
            loaded_generator = None
        except Exception as e:
            print(f"ERROR: Failed to load generator checkpoint: {e}")
            loaded_generator = None
    else:
        print(f"No generator checkpoints found in {checkpoint_load_dir}.")
else:
    print("No checkpoint directory available to load from.")


# --- Generate and Display Images (if generator was loaded) ---
if loaded_generator:
    num_images_to_generate_per_class = 5
    total_images_to_generate = current_num_classes * num_images_to_generate_per_class

    # Prepare noise and labels
    # Use a fixed seed for reproducible generation if desired for testing
    # torch.manual_seed(42)

    # Generate random noise
    z_generate = torch.randn(total_images_to_generate, latent_dim).to(device)

    # Generate labels: (0,0,0,0,0, 1,1,1,1,1, 2,2,2,2,2, ... for each class)
    gen_labels_list = []
    for i in range(current_num_classes):
        gen_labels_list.extend([i] * num_images_to_generate_per_class)

    generated_labels = torch.LongTensor(gen_labels_list).to(device)

    # Generate images
    with torch.no_grad(): # No need to track gradients during generation
        fake_images = loaded_generator(z_generate, generated_labels).detach().cpu()

    print(f"\nGenerated {fake_images.shape[0]} images.")

    # --- Display Generated Images ---
    # squeeze=False keeps axs_gen 2-D even when there is only one row or one column
    fig_gen, axs_gen = plt.subplots(current_num_classes,
                                    num_images_to_generate_per_class,
                                    figsize=(num_images_to_generate_per_class * 2, current_num_classes * 2.2),
                                    squeeze=False)
    epoch_label = latest_epoch_to_load if 'latest_epoch_to_load' in locals() else 'Unknown'
    fig_gen.suptitle(f"Images Generated by Loaded Model (Epoch {epoch_label})", fontsize=16)

    for r_idx in range(current_num_classes):
        for c_idx in range(num_images_to_generate_per_class):
            current_ax = axs_gen[r_idx, c_idx]
            current_ax.axis('off')

            img_index = r_idx * num_images_to_generate_per_class + c_idx
            if img_index >= len(fake_images):  # Should not happen if total_images_to_generate is correct
                continue

            # Un-normalize: images were normalized to [-1, 1], so (data * 0.5) + 0.5 brings them to [0, 1].
            # Clip to guarantee a valid range for imshow.
            img_display = (fake_images[img_index].permute(1, 2, 0) * 0.5 + 0.5).clip(0, 1)

            if current_img_channels == 1:  # Grayscale
                # Fix vmin/vmax so imshow does not rescale contrast separately for each image
                current_ax.imshow(img_display.squeeze(), cmap='gray', vmin=0, vmax=1)
            else:  # RGB
                current_ax.imshow(img_display)

            current_ax.set_title(f"Class {r_idx}")


    plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
    plt.show()
else:
    print("\nGenerator not loaded. Cannot generate images.")

# Next Steps and Building an Ensemble

This notebook provides a foundational cGAN. To build the full system you envision, you would:

*   **Expand to More Modalities:** Adapt the dataset and model to handle MRI, CT, and X-ray images, increasing the `num_classes` parameter accordingly.
*   **Build an Ensemble:**
    *   **Mixture of Experts:** Train separate cGANs, one for each modality, and then use a classification model to decide which generator to use.
    *   **Stacking:** Train several different generative models (e.g., this cGAN, a VAE, and a diffusion model) and then use another neural network to combine their outputs.
*   **Improve the Classifier:** To identify the input scan type, train a dedicated, high-performance classification model on a labeled dataset of medical images.
*   **Incorporate Advanced Models:** For higher-fidelity results, consider more advanced generative models such as diffusion models, which currently produce many of the strongest image synthesis results.
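
As a rough illustration of the Mixture of Experts option, the sketch below routes each request to exactly one generator per modality (hard routing). `MODALITIES`, `TinyGenerator`, and `generate_for_modality` are hypothetical names introduced here. The single-layer generator is only a stand-in for trained cGAN generators, and the modality index is assumed to come from a label or an upstream classifier.

```python
import torch
import torch.nn as nn

MODALITIES = ["MRI", "CT", "X-ray"]  # hypothetical modality names


class TinyGenerator(nn.Module):
    """Stand-in for a trained per-modality generator (hypothetical)."""

    def __init__(self, latent_dim=16, img_shape=(1, 8, 8)):
        super().__init__()
        self.img_shape = img_shape
        n_pixels = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(nn.Linear(latent_dim, n_pixels), nn.Tanh())

    def forward(self, z):
        return self.net(z).view(z.size(0), *self.img_shape)


# One expert generator per modality, in the same order as MODALITIES
experts = nn.ModuleList([TinyGenerator() for _ in MODALITIES])


def generate_for_modality(modality_idx, n_images, latent_dim=16):
    """Hard routing: pick exactly one expert based on the modality index."""
    expert = experts[modality_idx]
    expert.eval()
    z = torch.randn(n_images, latent_dim)
    with torch.no_grad():
        return expert(z)


ct_images = generate_for_modality(MODALITIES.index("CT"), n_images=4)
print(ct_images.shape)  # torch.Size([4, 1, 8, 8])
```

Hard routing keeps the experts fully independent, so each one can be trained on its own modality. The trade-off is that nothing is shared or blended between them.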

This notebook should give you a solid starting point for your project. Good luck!


# Stacking Models to Create a Mixture of Experts (MoE) Model

Here is a guide on how you can stack models to create a Mixture of Experts (MoE) model, complete with a Python code example using PyTorch.

## Conceptual Overview of a Mixture of Experts (MoE) Model

A Mixture of Experts model is a powerful ensemble technique with three parts:

*   **Multiple "Expert" Models:** Each expert is a neural network (or any other model) that is trained to become proficient in a specific subset of the problem space. For instance, in your medical imaging case, one expert might specialize in generating MRIs, another in CT scans, and a third in X-rays.
*   **A "Gating Network":** This is a separate model that acts as a traffic controller. It takes the same input as the experts and decides how much to trust each expert for that specific input. It outputs a set of "weights" or "probabilities" that sum to 1.
*   **Final Output:** The final output of the MoE model is a weighted sum of the outputs from all the expert models, with the weights determined by the gating network.

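This architecture lets the model learn that different experts are better suited to different inputs, which can improve performance. When only the top-weighted experts are evaluated for each input (sparse routing), it can also use parameters more efficiently; the dense version built in this guide evaluates every expert on every input.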
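
In symbols, the weighted sum described above can be written as follows. The notation is introduced here for illustration: $f_i$ is expert $i$, $g_i$ is the gating weight assigned to it, and $N$ is the number of experts.

```latex
y(x) = \sum_{i=1}^{N} g_i(x)\, f_i(x),
\qquad g_i(x) \ge 0,
\qquad \sum_{i=1}^{N} g_i(x) = 1
```

For example, with three experts producing the scalar outputs $1.0$, $2.0$, and $4.0$ and gating weights $0.7$, $0.2$, and $0.1$, the combined output is $0.7 \cdot 1.0 + 0.2 \cdot 2.0 + 0.1 \cdot 4.0 = 1.5$.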

## Python Notebook: Stacking Models to Create an MoE

Here’s a practical, step-by-step guide to building an MoE model. We'll create a simple example where different experts learn to model different parts of a synthetic dataset. You can then adapt this architecture to your more complex generative models.

## 1. Setup and Imports

First, let's set up the environment and import the necessary libraries.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# for creating a synthetic dataset
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

## 2. Defining the "Expert" Models

Let's define a simple neural network that will serve as our expert. In a real-world application, these experts could be your complex generative models (like a GAN generator or a diffusion model). For this example, we'll use a simple feed-forward network.

In [None]:
class Expert(nn.Module):
    """A simple expert model."""
    def __init__(self, input_dim, output_dim):
        super(Expert, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, output_dim)
        )

    def forward(self, x):
        return self.net(x)

## 3. Defining the "Gating Network"

The gating network is responsible for deciding which expert to rely on for a given input. It will output a probability distribution over the experts. We use a Softmax activation to ensure the weights sum to 1.
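
To see the sum-to-one property concretely, the small check below applies `nn.Softmax(dim=1)` to arbitrary scores for a batch of 4 inputs and 3 experts. The tensor sizes and the seed are arbitrary choices made for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Raw (unnormalized) gate scores: 4 inputs, 3 experts
scores = torch.randn(4, 3)

# dim=1 normalizes across experts, independently for each input row
weights = nn.Softmax(dim=1)(scores)

row_sums = weights.sum(dim=1)
print(weights.shape)  # torch.Size([4, 3])
print(torch.allclose(row_sums, torch.ones(4)))  # True
```

The choice of `dim=1` matters: with `dim=0`, the scores would be normalized across the batch rather than across the experts, and the rows would no longer be valid mixing weights.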

In [None]:
class GatingNetwork(nn.Module):
    """A gating network that decides which expert to use."""
    def __init__(self, input_dim, num_experts):
        super(GatingNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_experts),
            nn.Softmax(dim=1)  # Output a probability distribution over experts
        )

    def forward(self, x):
        return self.net(x)

## 4. Stacking the Experts and Gating Network into an MoE Model

Now, we'll combine the experts and the gating network into a single MoE model. This class will manage the forward pass, routing the input to the experts and combining their outputs based on the gating network's decisions.
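
The combination step relies on broadcasting, which can be easier to follow on standalone tensors. The sketch below uses random stand-ins for the expert outputs and gating weights (sizes chosen arbitrarily) and cross-checks the broadcasted weighted sum against an equivalent `torch.einsum` expression.

```python
import torch

torch.manual_seed(0)
batch_size, output_dim, num_experts = 4, 2, 3

# Random stand-ins for the stacked expert outputs and the gating weights
expert_outputs = torch.randn(batch_size, output_dim, num_experts)
gate = torch.softmax(torch.randn(batch_size, num_experts), dim=1)

# Broadcasting: (B, 1, E) * (B, D, E) -> (B, D, E), then sum over the expert axis
combined = (expert_outputs * gate.unsqueeze(1)).sum(dim=2)

# The same weighted sum written as an einsum, as a cross-check
combined_einsum = torch.einsum("bde,be->bd", expert_outputs, gate)

print(combined.shape)  # torch.Size([4, 2])
print(torch.allclose(combined, combined_einsum, atol=1e-6))  # True
```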

In [None]:
class MoEModel(nn.Module):
    """The main Mixture of Experts model."""
    def __init__(self, input_dim, output_dim, num_experts):
        super(MoEModel, self).__init__()
        self.num_experts = num_experts

        # Create the expert models
        self.experts = nn.ModuleList([Expert(input_dim, output_dim) for _ in range(num_experts)])

        # Create the gating network
        self.gating = GatingNetwork(input_dim, num_experts)

    def forward(self, x):
        # Get the weights from the gating network
        gating_weights = self.gating(x)  # Shape: (batch_size, num_experts)

        # Get the outputs from each expert
        expert_outputs = [expert(x) for expert in self.experts]
        expert_outputs_tensor = torch.stack(expert_outputs, dim=2) # Shape: (batch_size, output_dim, num_experts)

        # Weight the expert outputs
        # We need to reshape the gating weights to multiply them with the expert outputs
        gating_weights = gating_weights.unsqueeze(1) # Shape: (batch_size, 1, num_experts)

        # The final output is the weighted sum of the expert outputs
        weighted_expert_outputs = expert_outputs_tensor * gating_weights
        final_output = torch.sum(weighted_expert_outputs, dim=2)

        return final_output, gating_weights.squeeze(1)

## 5. Training and Usage

Let's test our MoE model on a synthetic dataset. We'll create a dataset with 3 distinct clusters, which is a good scenario for an MoE with 3 experts.

# --- 1. Create a Synthetic Dataset ---

In [None]:
# We'll create a dataset with 3 centers, ideal for 3 experts
X, y = make_blobs(n_samples=5000, centers=3, n_features=2, random_state=42, cluster_std=1.5)

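# Convert to PyTorch tensors (explicit float32, since make_blobs returns float64 features and integer labels)
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1) # Reshape for regression-like loss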

# Split data
X_train, X_test, y_train, y_test = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

# Create DataLoaders
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# --- 2. Initialize and Train the Model ---

In [None]:
# Hyperparameters
input_dim = 2
output_dim = 1
num_experts = 3
epochs = 50
learning_rate = 0.001

# Initialize the MoE model
moe_model = MoEModel(input_dim, output_dim, num_experts)
criterion = nn.MSELoss()
optimizer = optim.Adam(moe_model.parameters(), lr=learning_rate)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moe_model.to(device)

# --- Training Loop ---
for epoch in range(epochs):
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        # Forward pass
        optimizer.zero_grad()
        output, _ = moe_model(X_batch)
        loss = criterion(output, y_batch)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# --- 3. Inspect the Gating Network's Decisions ---

In [None]:
moe_model.eval()
with torch.no_grad():
    # Take a few samples from the test set
    sample_data = X_test[:15].to(device)
    true_labels = y_test[:15]

    # Get predictions and gating weights
    predictions, gating_weights = moe_model(sample_data)

    # For each sample, find out which expert was most influential
    most_influential_expert = torch.argmax(gating_weights, dim=1)

    print("\n--- Gating Network Decisions ---")
    for i in range(len(sample_data)):
        print(f"Sample {i+1} (True Cluster: {int(true_labels[i].item())}): "
              f"Predicted Value: {predictions[i].item():.2f}, "
              f"Most Influential Expert: {most_influential_expert[i].item()}")
        print(f"   Gating Weights: {[f'{w:.2f}' for w in gating_weights[i].cpu().numpy()]}")

# --- 4. Visualize Gating Decisions (Optional) ---

In [None]:
plt.figure(figsize=(12, 5))
plt.suptitle("Gating Network Specialization", fontsize=16)

# Get gating weights for the entire test set (no gradients are needed for plotting)
with torch.no_grad():
    _, gating_weights_all = moe_model(X_test.to(device))
most_influential_expert_all = torch.argmax(gating_weights_all, dim=1).cpu().numpy()

# Plot for each expert
for i in range(num_experts):
    plt.subplot(1, num_experts, i + 1)
    plt.title(f'Expert {i} Specialization')
    plt.scatter(X_test[:, 0], X_test[:, 1], c=(most_influential_expert_all == i), cmap='viridis', alpha=0.5)
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

## How to Adapt This to Your Medical Image Synthesis Task

*   **Replace the Experts:** Your `Expert` class would be replaced by your generative models (e.g., a `Generator` from a GAN, or a `DiffusionModel`). The `forward` method of your expert would take a conditional input (like a noise vector and a class label) and output a generated image.
*   **Adapt the Gating Network:** The `GatingNetwork` would take as input the condition you want to generate (e.g., a one-hot encoded vector representing "MRI", "CT", or "X-ray"). It would then output the weights for combining the outputs of your expert generators.
*   **Combine the Outputs:** The most complex part is combining the outputs of generative models. Instead of a simple weighted sum of pixel values, you might want to perform this weighting in the latent space, or use a "meta-learner" (as in a stacking ensemble) to learn the best way to combine the generated images into a final, high-fidelity output.
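
A minimal sketch of these three steps is shown below, under strong simplifying assumptions. `ToyConditionalGenerator` and `GeneratorMoE` are hypothetical names introduced here, the toy expert is a single linear layer standing in for a real generator, and the outputs are blended with the simple pixel-wise weighted sum (the baseline that the last point suggests improving on).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyConditionalGenerator(nn.Module):
    """Hypothetical stand-in for a trained conditional generator expert."""

    def __init__(self, latent_dim, num_classes, img_shape):
        super().__init__()
        self.img_shape = img_shape
        n_pixels = img_shape[0] * img_shape[1] * img_shape[2]
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + num_classes, n_pixels),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, z, labels):
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(x).view(z.size(0), *self.img_shape)


class GeneratorMoE(nn.Module):
    """Gates on the one-hot condition and blends expert images pixel-wise."""

    def __init__(self, latent_dim, num_classes, img_shape, num_experts):
        super().__init__()
        self.num_classes = num_classes
        self.experts = nn.ModuleList(
            [ToyConditionalGenerator(latent_dim, num_classes, img_shape)
             for _ in range(num_experts)]
        )
        self.gate = nn.Sequential(
            nn.Linear(num_classes, num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, z, labels):
        one_hot = F.one_hot(labels, self.num_classes).float()
        weights = self.gate(one_hot)  # (B, E)
        # Stack expert images along a new expert axis: (B, E, C, H, W)
        images = torch.stack([e(z, labels) for e in self.experts], dim=1)
        # Reshape weights to (B, E, 1, 1, 1) so they broadcast over pixels
        blended = (images * weights.view(-1, weights.size(1), 1, 1, 1)).sum(dim=1)
        return blended, weights


moe = GeneratorMoE(latent_dim=16, num_classes=3, img_shape=(1, 8, 8), num_experts=3)
z = torch.randn(5, 16)
labels = torch.tensor([0, 1, 2, 0, 1])
images, weights = moe(z, labels)
print(images.shape, weights.shape)  # torch.Size([5, 1, 8, 8]) torch.Size([5, 3])
```

Because the gating weights are non-negative and sum to 1, the blended image stays within the experts' output range. Averaging pixels from different generators may still blur fine detail, which is one reason to consider latent-space weighting or a learned combiner instead.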

This MoE structure offers a flexible way to build systems from specialized components, with the gating network deciding how much each expert contributes to the final output.