### Training the GAN

# In [None]:
import os
import random
import logging
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from torch.nn.utils import spectral_norm

# --- Configuration ---

# Route INFO-and-above log records for the whole script to the root logger's
# default handler; each line shows the timestamp, the level name and the message.
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO)


# --- Reproducibility ---

def set_seed(seed):
    """
    Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's RNG, and
    additionally every CUDA device when CUDA is available, so the random streams
    are repeatable from run to run.

    Args:
        seed (int): Seed value applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # GPU generators are seeded explicitly whenever a CUDA runtime is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Set random seed to {seed}")


# Seed all RNGs once, at import time, so every run of this script starts from the same state.
set_seed(seed=42)


# --- Dataset Definition ---

class NPYDataset(Dataset):
    """
    PyTorch Dataset serving 3D volumes stored in a single NPY file.

    The file must hold a 4D array laid out as [Depth, Height, Width, N]
    (N samples, each a Depth x Height x Width volume) whose values are expected
    to be {0, 1}. The whole array is loaded into memory once, re-ordered to
    [N, Depth, Height, Width], and individual volumes are returned as float32
    tensors.
    """

    def __init__(self, data_path: str):
        """
        Load and re-order the dataset.

        Args:
            data_path (str): Path to the .npy file containing the data.
                             Expected shape: [Depth, Height, Width, N]
                             (e.g. (16, 128, 128, 146) for 'twocat.npy').
                             Expected values: {0, 1}; other values only log a warning.

        Raises:
            FileNotFoundError: If the data file does not exist.
            Exception: For any other error while loading or inspecting the data;
                the underlying error is chained as ``__cause__``.
            ValueError: If the loaded array is not 4-dimensional.
        """
        try:
            logging.info(f"Loading data from {data_path}...")
            data = np.load(data_path)
            logging.info(f"Original data shape: {data.shape}")
            # Sanity-check the value set; non-binary data is allowed but flagged.
            unique_values = np.unique(data)
            logging.info(f"Unique values in data: {unique_values}")
            if not np.all(np.isin(unique_values, [0, 1])):
                logging.warning(f"Data contains values other than {{0, 1}}: {unique_values}. Ensure this is expected.")
        except FileNotFoundError as e:
            logging.error(f"Data file not found at {data_path}")
            # Chain the cause so the original traceback is not lost.
            raise FileNotFoundError(f"Data file not found at {data_path}") from e
        except Exception as e:
            logging.error(f"Error loading data: {e}")
            raise Exception(f"Error loading data: {e}") from e

        # The transpose below needs exactly four axes; fail with a clear message
        # instead of NumPy's generic "axes don't match array".
        if data.ndim != 4:
            raise ValueError(
                f"Expected a 4D array [Depth, Height, Width, N] in {data_path}, "
                f"got shape {data.shape}"
            )

        # (3, 0, 1, 2) maps (D, H, W, N) -> (N, D, H, W), so indexing the first
        # axis yields one sample. np.transpose returns a view; no copy is made.
        self.data = np.transpose(data, (3, 0, 1, 2))
        logging.info(f"Transposed data shape (N, D, H, W): {self.data.shape}")
        # Values are left untouched here and only converted to float32 in
        # __getitem__; the Generator's Sigmoid output range [0, 1] matches them.

    def __len__(self):
        """Return the number of samples N."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as a float32 tensor of shape [Depth, Height, Width].

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            torch.Tensor: The volume, dtype float32 (e.g. shape [16, 128, 128]).
        """
        return torch.from_numpy(self.data[idx]).float()


# --- Generator Network ---

class Generator(nn.Module):
    """
    Generator (G): maps a latent noise vector to a 3D volume in (0, 1).

    Pipeline (channels, depth, height, width):
        z [B, latent_dim]
          -> Linear + LeakyReLU(0.2)                 -> [B, 512 * 2 * 8 * 8]
          -> Unflatten                               -> [B, 512, 2, 8, 8]
          -> ConvTranspose3d(k4, s2, p1) + BN + LReLU -> [B, 256, 4, 16, 16]
          -> ConvTranspose3d(k4, s2, p1) + BN + LReLU -> [B, 128, 8, 32, 32]
          -> ConvTranspose3d(k4, s2, p1) + BN + LReLU -> [B, 64, 16, 64, 64]
          -> ConvTranspose3d(k(3,4,4), s(1,2,2), p1)  -> [B, 1, 16, 128, 128]
          -> Sigmoid
    Rounding the output gives discrete {0, 1} volumes.

    All layers live in one ``nn.Sequential`` called ``net``; its indices
    (net.0 ... net.13) are the state_dict keys stored in checkpoints, so the
    layer order must stay as is.
    """

    def __init__(self, latent_dim=100):
        """
        Build the layer stack.

        Args:
            latent_dim (int): Size of the input noise vector.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Project the noise and fold it into a small 512-channel seed volume.
        layers = [
            nn.Linear(latent_dim, 512 * 2 * 8 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Unflatten(1, (512, 2, 8, 8)),
        ]

        # Three identical upsampling stages; with k=4, s=2, p=1 each output size
        # is (in - 1) * 2 - 2 + 4 = 2 * in, i.e. D, H and W all double.
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            layers.extend((
                nn.ConvTranspose3d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))

        # Output stage: depth stays 16 ((16-1)*1 - 2 + 3 = 16) while height and
        # width double ((64-1)*2 - 2 + 4 = 128); Sigmoid squashes to (0, 1).
        layers.append(nn.ConvTranspose3d(64, 1, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1))
        layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """
        Generate volumes from noise.

        Args:
            z (torch.Tensor): Noise of shape [batch_size, latent_dim].

        Returns:
            torch.Tensor: Volumes of shape [batch_size, 1, 16, 128, 128].
        """
        return self.net(z)


# --- Discriminator Network ---

class Discriminator(nn.Module):
    """
    Critic (D) for the WGAN-GP.

    Maps a volume of shape [B, 1, 16, 128, 128] to one unbounded score per
    sample. There is no final activation, which is the usual choice for a WGAN
    critic. Every Conv3d/Linear is wrapped in spectral normalisation.

    Four Conv3d(k=4, s=2, p=1) + LeakyReLU(0.2) stages halve D, H and W:
        [B, 1, 16, 128, 128] -> [B, 64, 8, 64, 64] -> [B, 128, 4, 32, 32]
        -> [B, 256, 2, 16, 16] -> [B, 512, 1, 8, 8]
    The result is flattened to [B, 32768] and passed through Linear(32768, 1).

    Layer order inside ``self.net`` defines the state_dict keys (net.0 ... net.9)
    used by saved checkpoints and must be kept.
    """

    def __init__(self):
        """Build the spectrally-normalised layer stack."""
        super().__init__()

        stages = []
        for c_in, c_out in ((1, 64), (64, 128), (128, 256), (256, 512)):
            stages.append(spectral_norm(nn.Conv3d(c_in, c_out, kernel_size=4, stride=2, padding=1)))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        # The last feature map is [512, 1, 8, 8]; the Linear input size is fixed
        # to that, so the critic is sized for [*, 1, 16, 128, 128] inputs.
        stages.append(nn.Flatten())
        stages.append(spectral_norm(nn.Linear(512 * 1 * 8 * 8, 1)))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """
        Score a batch of volumes.

        Args:
            x (torch.Tensor): Volumes of shape [batch_size, 1, 16, 128, 128].

        Returns:
            torch.Tensor: Critic scores of shape [batch_size, 1].
        """
        return self.net(x)


# --- WGAN-GP Trainer ---

class WGANTrainer:
    """
    Trainer for a Wasserstein GAN with Gradient Penalty (WGAN-GP).

    Owns one Adam optimizer per network and implements a training step
    (``n_critic`` critic updates followed by one generator update), the
    gradient penalty, and sampling from the generator.
    """

    def __init__(self, generator, discriminator, device, lambda_gp=10, n_critic=5, latent_dim=100):
        """
        Move both networks to ``device`` and create their optimizers.

        Args:
            generator (nn.Module): Generator, called as ``generator(z)`` with
                ``z`` of shape [batch_size, latent_dim].
            discriminator (nn.Module): Critic, called on real, fake and
                interpolated batches to produce critic scores.
            device (torch.device): Device used for training and sampling.
            lambda_gp (float): Weight of the gradient-penalty term in the critic loss.
            n_critic (int): Number of critic updates per generator update.
            latent_dim (int): Size of the latent noise vector.
        """
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.device = device
        self.lambda_gp = lambda_gp
        self.n_critic = n_critic
        self.latent_dim = latent_dim

        # Adam with lr=1e-4 and betas=(0.5, 0.9) for both networks.
        self.g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
        self.d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))
        logging.info(f"WGANTrainer initialized with lambda_gp={lambda_gp}, n_critic={n_critic}")

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """
        Compute the WGAN-GP gradient penalty.

        Builds ``x_hat = eps * real + (1 - eps) * fake`` with one
        ``eps ~ U[0, 1)`` per sample. It then returns the batch mean of
        ``(||grad_x D(x_hat)||_2 - 1)^2``. The autograd graph is kept
        (``create_graph=True``) so the penalty can be back-propagated into the
        critic's weights.

        Args:
            real_samples (torch.Tensor): Real batch of shape [B, ...]; should not
                require grad (train_step passes a detached tensor).
            fake_samples (torch.Tensor): Generated batch with the same shape,
                likewise detached.

        Returns:
            torch.Tensor: Scalar penalty.
        """
        batch_size = real_samples.size(0)
        # One weight per sample, shaped [B, 1, ..., 1] so it broadcasts over every
        # remaining dimension whatever the sample rank is.
        eps_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
        epsilon = torch.rand(eps_shape, device=self.device)

        interpolates = (epsilon * real_samples + (1 - epsilon) * fake_samples).requires_grad_(True)
        d_interpolates = self.discriminator(interpolates)

        # Seeding the backward pass with ones differentiates sum(D(x_hat)); for a
        # critic that scores each sample independently this yields every sample's
        # own input gradient.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,  # retain_graph defaults to create_graph, so the graph is kept too
        )[0]

        # reshape rather than view: the returned gradient is not guaranteed contiguous.
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def train_step(self, real_samples):
        """
        Run ``n_critic`` critic updates followed by one generator update.

        The same ``real_samples`` batch is reused for every critic update in
        this step; a fresh noise batch is drawn each time.

        Args:
            real_samples (torch.Tensor): Real batch already on ``self.device``,
                shape [batch_size, 1, Depth, Height, Width].

        Returns:
            tuple[float, float]: (critic loss averaged over the ``n_critic``
            updates, generator loss).
        """
        batch_size = real_samples.size(0)
        real_detached = real_samples.detach()

        # --- Critic updates ---
        d_loss_total = 0.0
        for _ in range(self.n_critic):
            self.d_optimizer.zero_grad()

            z = torch.randn(batch_size, self.latent_dim, device=self.device)
            # detach(): the critic update must not build a graph through G.
            fake_samples = self.generator(z).detach()

            real_validity = self.discriminator(real_samples)
            fake_validity = self.discriminator(fake_samples)
            gradient_penalty = self.compute_gradient_penalty(real_detached, fake_samples)

            # Critic minimises E[D(fake)] - E[D(real)] + lambda * GP,
            # i.e. it maximises the real/fake score gap.
            d_loss = fake_validity.mean() - real_validity.mean() + self.lambda_gp * gradient_penalty
            d_loss_total += d_loss.item()

            d_loss.backward()
            self.d_optimizer.step()

        d_loss_avg = d_loss_total / self.n_critic

        # --- Generator update ---
        self.g_optimizer.zero_grad()
        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_samples = self.generator(z)  # not detached: gradients must reach G
        fake_validity = self.discriminator(fake_samples)

        # Generator maximises the critic's score on fakes, i.e. minimises -E[D(fake)].
        g_loss = -fake_validity.mean()
        # This backward also fills the critic's .grad buffers; they are cleared by
        # d_optimizer.zero_grad() at the start of the next critic update.
        g_loss.backward()
        self.g_optimizer.step()

        return d_loss_avg, g_loss.item()

    def generate_samples(self, num_samples: int) -> torch.Tensor:
        """
        Sample the generator and round its output to discrete {0, 1} values.

        The generator is switched to eval mode for sampling. Its previous
        train/eval mode is restored afterwards, even if the forward pass raises.

        Args:
            num_samples (int): Number of samples to generate.

        Returns:
            torch.Tensor: Rounded samples on ``self.device``, e.g. shape
            [num_samples, 1, 16, 128, 128] for the Generator in this file. With a
            Sigmoid output the values are {0., 1.}.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(num_samples, self.latent_dim, device=self.device)
                # Sigmoid output lies in (0, 1), so rounding maps it to {0., 1.}.
                samples = self.generator(z).round()
        finally:
            self.generator.train(was_training)
        return samples


# --- Visualization ---

class VisualizationManager:
    """
    Saves PNG grids of depth slices from generated samples during training.

    Each instance writes into its own ``<save_dir>/<YYYYmmdd_HHMMSS>/`` directory.
    """

    def __init__(self, save_dir: str = "results"):
        """
        Create the run directory.

        Args:
            save_dir (str): Base directory; a timestamped subdirectory is created
                inside it (parents included).
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.save_dir = Path(save_dir) / timestamp
        # exist_ok=True: managers created within the same second share one directory.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Saving results to {self.save_dir}")

    def save_samples(self, samples: torch.Tensor, epoch: int, batch_idx: int):
        """
        Save a grid of 4 evenly spaced depth slices for up to 4 samples.

        Args:
            samples (torch.Tensor): Samples of shape [num_samples, 1, Depth, Height,
                Width], on any device; values are expected in [0, 1] (the colour
                scale is fixed to that range).
            epoch (int): Current epoch number (used in the title and file name).
            batch_idx (int): Current batch index (used in the title and file name).

        An empty batch is skipped with a warning instead of raising.
        """
        # detach() so tensors that still require grad can be converted to NumPy.
        volumes = samples.detach().cpu().numpy()
        if volumes.shape[0] == 0:
            logging.warning(f"No samples to visualize for epoch {epoch}, batch {batch_idx}; skipping.")
            return
        volumes = np.squeeze(volumes, axis=1)  # [N, 1, D, H, W] -> [N, D, H, W]

        num_samples_to_plot = min(4, volumes.shape[0])
        num_slices_per_sample = 4

        # squeeze=False keeps axes 2D even when only one row is plotted.
        fig, axes = plt.subplots(num_samples_to_plot, num_slices_per_sample,
                                 figsize=(num_slices_per_sample * 3, num_samples_to_plot * 3),
                                 squeeze=False)
        try:
            depth_indices = np.linspace(0, volumes.shape[1] - 1, num_slices_per_sample, dtype=int)

            for i in range(num_samples_to_plot):
                for j, depth_idx in enumerate(depth_indices):
                    ax = axes[i, j]
                    # vmin/vmax pin the colour scale to [0, 1] for every panel.
                    im = ax.imshow(volumes[i, depth_idx], cmap='viridis', vmin=0, vmax=1)
                    ax.set_title(f"Sample {i+1}, Depth {depth_idx}")
                    ax.axis('off')

            fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.7)
            fig.suptitle(f'Generated Samples - Epoch {epoch}, Batch {batch_idx}')
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])

            save_path = self.save_dir / f'samples_e{epoch}_b{batch_idx}.png'
            # Save this figure explicitly rather than pyplot's implicit "current" one.
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)


# --- Checkpoint Loading ---

def load_checkpoint(checkpoint_path, generator, discriminator, trainer):
    """
    Restore model weights, optimizer states and the epoch counter from a checkpoint.

    All required keys are checked before anything is loaded. A checkpoint with a
    missing key therefore leaves the models and optimizers untouched, which makes
    the "starting from scratch" message truthful. A later failure inside
    ``load_state_dict`` (e.g. a shape mismatch) can still leave state partially
    loaded.

    Args:
        checkpoint_path (str or Path): Path to the checkpoint file (.pt).
        generator (nn.Module): Generator that receives 'generator_state_dict'.
        discriminator (nn.Module): Discriminator that receives 'discriminator_state_dict'.
        trainer (WGANTrainer): Provides ``device`` (used as ``map_location``) and
            the ``g_optimizer`` / ``d_optimizer`` to restore.

    Returns:
        int: Epoch to resume from (saved 'epoch' + 1; 0 if the checkpoint has no
        'epoch' entry). Returns 0 if the file is missing or cannot be loaded.
    """
    required_keys = (
        'generator_state_dict',
        'discriminator_state_dict',
        'g_optimizer_state_dict',
        'd_optimizer_state_dict',
    )

    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.is_file():
        logging.warning(f"Checkpoint path '{checkpoint_path}' not found. Starting training from scratch.")
        return 0

    try:
        logging.info(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(security): torch.load unpickles the file, which can execute code.
        # Only load checkpoints from trusted sources (newer PyTorch versions offer
        # weights_only=True for this).
        checkpoint = torch.load(checkpoint_path, map_location=trainer.device)

        # Validate before mutating anything, so a malformed checkpoint cannot
        # leave the models half-loaded.
        for key in required_keys:
            if key not in checkpoint:
                raise KeyError(key)

        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        trainer.g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        trainer.d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])

        # Resume at the epoch after the saved one; -1 default keeps old
        # checkpoints without an 'epoch' entry starting at 0.
        start_epoch = checkpoint.get('epoch', -1) + 1
        logging.info(f"Successfully loaded checkpoint. Resuming from epoch {start_epoch}")
        return start_epoch

    except FileNotFoundError:  # file vanished between is_file() and torch.load
        logging.warning(f"Checkpoint file not found at {checkpoint_path}. Starting training from scratch.")
    except KeyError as e:
        logging.error(f"Checkpoint file is missing key: {e}. Could not load checkpoint properly. Starting from scratch.")
    except Exception as e:
        logging.error(f"Error loading checkpoint: {e}. Starting training from scratch.")
    return 0


# --- Main Training Function ---

def _checkpoint_payload(trainer, epoch, **extra):
    """Collect the epoch plus model/optimizer state dicts (and any ``extra`` entries) for torch.save."""
    payload = {
        'epoch': epoch,
        'generator_state_dict': trainer.generator.state_dict(),
        'discriminator_state_dict': trainer.discriminator.state_dict(),
        'g_optimizer_state_dict': trainer.g_optimizer.state_dict(),
        'd_optimizer_state_dict': trainer.d_optimizer.state_dict(),
    }
    payload.update(extra)
    return payload


def train(data_path: str, num_epochs: int, batch_size: int = 96, checkpoint_path=None, save_interval: int = 50, log_interval: int = 100):
    """
    Set up data, models and trainer, then run the WGAN-GP training loop.

    Args:
        data_path (str): Path to the NPY data file.
        num_epochs (int): Number of epochs to run in this call. When resuming,
            epochs are numbered from the checkpoint's epoch + 1, so this is the
            number of *additional* epochs, not a total.
        batch_size (int): Number of samples per batch.
        checkpoint_path (str, optional): Checkpoint to resume from. Defaults to None.
        save_interval (int): Save a checkpoint every ``save_interval`` epochs and
            after the final epoch. Must be >= 1.
        log_interval (int): Every ``log_interval`` batches within an epoch
            (starting with batch 0), log the losses and save a sample grid.
            Also save ``best_model.pt`` if the generator loss improved.
            Must be >= 1.

    Raises:
        ValueError: If ``save_interval`` or ``log_interval`` is less than 1.

    Returns None. Outputs go to the VisualizationManager's run directory. The
    function returns early, after logging, if the dataset cannot be loaded.
    """
    # Both intervals are used as modulo divisors; reject bad values up front
    # rather than crashing with ZeroDivisionError partway through training.
    if save_interval < 1 or log_interval < 1:
        raise ValueError(
            f"save_interval and log_interval must be >= 1, got "
            f"save_interval={save_interval}, log_interval={log_interval}"
        )

    # 1. Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    # 2. Dataset and DataLoader
    try:
        dataset = NPYDataset(data_path)
    except Exception as e:
        logging.error(f"Failed to initialize dataset: {e}")
        return

    # Worker processes on POSIX systems, in-process loading elsewhere.
    num_workers = 4 if os.name == 'posix' else 0
    # Page-locked memory only speeds up host->GPU copies, so pin only for CUDA.
    pin_memory = device.type == 'cuda'
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    logging.info(f"DataLoader created with batch size {batch_size}, num workers {num_workers}")

    # 3. Models, trainer and output directory
    generator = Generator()
    discriminator = Discriminator()
    trainer = WGANTrainer(generator, discriminator, device)  # n_critic=5 by default
    vis_manager = VisualizationManager()  # writes to ./results/<timestamp>/

    # 4. Optional resume
    start_epoch = 0
    if checkpoint_path:
        start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, trainer)

    # 5. Training loop
    logging.info(f"Starting training from epoch {start_epoch} for {num_epochs} epochs...")
    last_epoch = start_epoch + num_epochs - 1
    d_losses = []  # one entry per train_step (critic loss averaged over n_critic updates)
    g_losses = []  # one entry per train_step
    # Lowest generator loss seen at a logging step; a rough heuristic only, since
    # WGAN losses do not directly measure sample quality.
    best_g_loss = float('inf')

    for epoch in range(start_epoch, start_epoch + num_epochs):
        generator.train()
        discriminator.train()

        for batch_idx, real_samples in enumerate(dataloader):
            # [B, D, H, W] -> [B, 1, D, H, W] on the training device.
            real_samples = real_samples.unsqueeze(1).float().to(device)

            d_loss, g_loss = trainer.train_step(real_samples)
            d_losses.append(d_loss)
            g_losses.append(g_loss)

            if batch_idx % log_interval == 0:
                logging.info(
                    f"Epoch [{epoch}/{last_epoch}] "
                    f"Batch [{batch_idx}/{len(dataloader)}] "
                    f"D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}"
                )
                samples_to_visualize = trainer.generate_samples(4)
                vis_manager.save_samples(samples_to_visualize, epoch, batch_idx)

                if g_loss < best_g_loss:
                    best_g_loss = g_loss
                    best_model_path = vis_manager.save_dir / 'best_model.pt'
                    torch.save(_checkpoint_payload(trainer, epoch, g_loss=g_loss, d_loss=d_loss), best_model_path)
                    logging.info(f"Saved new best model (G Loss: {g_loss:.4f}) at Epoch {epoch}, Batch {batch_idx} to {best_model_path}")

        # Periodic checkpoint, plus one after the final epoch.
        if (epoch + 1) % save_interval == 0 or epoch == last_epoch:
            checkpoint_save_path = vis_manager.save_dir / f'checkpoint_epoch_{epoch}.pt'
            torch.save(_checkpoint_payload(trainer, epoch), checkpoint_save_path)
            logging.info(f"Saved checkpoint at {checkpoint_save_path}")

    # 6. Loss curves
    logging.info("Training finished. Plotting loss curves...")
    plt.figure(figsize=(10, 5))
    plt.plot(d_losses, label='Discriminator Loss (Avg per G step)')
    plt.plot(g_losses, label='Generator Loss')
    plt.xlabel('Training Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('WGAN-GP Training Losses')
    loss_curve_path = vis_manager.save_dir / 'loss_curve.png'
    plt.savefig(loss_curve_path)
    plt.close()
    logging.info(f"Loss curve saved to {loss_curve_path}")
    logging.info(f"All results saved in {vis_manager.save_dir}")


# --- Script Execution ---

if __name__ == "__main__":
    # --- Parameters ---
    DATA_FILE = "twocat.npy"       # Input data file (relative to the working directory, or a full path)
    NUM_EPOCHS = 2000              # Epochs to run in this session; when resuming they are added after the checkpoint's epoch
    BATCH_SIZE = 64                # Adjust to fit GPU memory (train()'s default of 96 may be too large for some GPUs)
    # Checkpoint (.pt) to resume from, or None to train from scratch. Example:
    # CHECKPOINT_TO_LOAD = "results/20231115_103000/checkpoint_epoch_1999.pt"
    CHECKPOINT_TO_LOAD = None
    SAVE_EVERY_N_EPOCHS = 100      # How often to save a checkpoint
    LOG_EVERY_N_BATCHES = 100      # How often (in batches within an epoch) to log losses and save sample images

    # --- Start Training ---
    # Fail fast with a clear message if the data file is missing.
    if not Path(DATA_FILE).is_file():
        logging.error(f"Data file '{DATA_FILE}' not found. Please ensure it is in the correct directory.")
    else:
        train(
            data_path=DATA_FILE,
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            checkpoint_path=CHECKPOINT_TO_LOAD,
            save_interval=SAVE_EVERY_N_EPOCHS,
            log_interval=LOG_EVERY_N_BATCHES
        )

    # --- Example: How to load the best model and generate samples later ---
    # best_model_path = "results/<your_timestamp_folder>/best_model.pt" # Replace with actual path
    # if Path(best_model_path).is_file():
    #     logging.info(f"Loading best model from {best_model_path} for generation...")
    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #     generator = Generator().to(device)
    #     checkpoint = torch.load(best_model_path, map_location=device)
    #     generator.load_state_dict(checkpoint['generator_state_dict'])
    #
    #     # Need a WGANTrainer instance just to use the generate_samples method conveniently
    #     # Or create a standalone generation function
    #     dummy_discriminator = Discriminator().to(device) # Not actually used for generation
    #     trainer_for_generation = WGANTrainer(generator, dummy_discriminator, device)
    #
    #     num_generated = 8
    #     generated_samples = trainer_for_generation.generate_samples(num_generated)
    #     logging.info(f"Generated {num_generated} samples with shape: {generated_samples.shape}")
    #     # You can now save or further process 'generated_samples'
    #     # Example: Save as numpy array
    #     # np.save("generated_output.npy", generated_samples.cpu().numpy())
    # else:
    #     logging.info("Best model checkpoint not found, skipping generation example.")