In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file in the attached (read-only) Kaggle input datasets.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for full_path in (os.path.join(root, name) for name in files):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient
# Authenticate this session with Weights & Biases. The API key is read from the
# Kaggle Secret labelled "wandb_api" so it never appears in the notebook source.
user_secrets = UserSecretsClient()
wan = user_secrets.get_secret("wandb_api")
wandb.login(key=wan)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import nibabel as nib
import numpy as np
# from torchvision import transforms # Not explicitly needed for this setup
import wandb
import time
import os
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch.nn.functional as F # For interpolation
import traceback # For detailed error printing

# --- Configuration ---
# Data and volume geometry
DATASET_PATH = '/kaggle/input/aligned-train/processed_images'  # <<< YOUR DATASET PATH HERE
TARGET_SIZE = (64, 64, 64)   # every volume is resized to this; the GAN layer stacks assume exactly 64^3
OUTPUT_CHANNELS = 1          # single-channel (grayscale) MR/CT volumes
LATENT_DIM = 100             # length of the generator's input noise vector

# Optimisation
BATCH_SIZE = 16              # reduce if the GPU runs out of memory
LEARNING_RATE = 2e-4         # Adam learning rate for both networks
BETA1 = 0.5                  # Adam beta1
NUM_EPOCHS = 2500            # deliberately high: the wall-clock save/stop below normally ends training first

# Checkpointing and logging
SAVE_INTERVAL_HOURS = 8      # train_gan saves a checkpoint and stops once this much time has elapsed
SAVE_DIR = 'saved_models_gan'
WANDB_PROJECT_NAME = "Simple-GAN-MRI-3D"

# 3D DCGAN generator; the layer stack is hard-wired to emit 64x64x64 volumes (TARGET_SIZE).
class Generator(nn.Module):
    """Map a latent noise vector to an (output_channels, 64, 64, 64) volume in [-1, 1].

    Four ConvTranspose3d -> BatchNorm3d -> ReLU stages (512, 256, 128, 64 feature maps)
    grow the spatial size 1 -> 4 -> 8 -> 16 -> 32; a final ConvTranspose3d + Tanh
    produces the 64^3 output.
    """

    def __init__(self, latent_dim, output_channels):
        super().__init__()
        base_width = 64  # feature maps in the last hidden stage (DCGAN "ngf")
        stages = []
        in_ch = latent_dim
        for stage_idx, mult in enumerate((8, 4, 2, 1)):
            out_ch = base_width * mult
            # Stage 0 projects 1^3 -> 4^3 (stride 1, no padding); later stages double the size.
            stride, padding = (1, 0) if stage_idx == 0 else (2, 1)
            stages += [
                nn.ConvTranspose3d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        # 32^3 -> 64^3, squashed into [-1, 1] (the range NiftiDataset normalises real volumes to).
        stages += [nn.ConvTranspose3d(in_ch, output_channels, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Accept (N, latent_dim) noise or 5D noise (normally (N, latent_dim, 1, 1, 1))."""
        rank = input.dim()
        if rank == 2:
            # (N, latent_dim) -> (N, latent_dim, 1, 1, 1)
            input = input[:, :, None, None, None]
        elif rank != 5:
            raise ValueError(f"Generator expected 2D or 5D input, got {rank}D")
        return self.main(input)

# 3D DCGAN discriminator; the layer stack is hard-wired for 64x64x64 inputs (TARGET_SIZE).
class Discriminator(nn.Module):
    """Score an (N, input_channels, 64, 64, 64) batch with real/fake probabilities.

    Four stride-2 Conv3d stages (64, 128, 256, 512 feature maps; BatchNorm3d on all but the
    first; LeakyReLU(0.2) throughout) shrink 64^3 to 4^3, then a 4^3 Conv3d + Sigmoid gives
    an (N, 1, 1, 1, 1) output in (0, 1) for 64^3 inputs.
    """

    def __init__(self, input_channels):
        super().__init__()
        base_width = 64  # feature maps in the first conv stage (DCGAN "ndf")
        layers = [
            nn.Conv3d(input_channels, base_width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for mult in (1, 2, 4):
            in_ch, out_ch = base_width * mult, base_width * mult * 2
            layers.extend([
                nn.Conv3d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm3d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # 4^3 -> 1^3 single-logit map, squashed to a probability.
        layers.extend([nn.Conv3d(base_width * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return per-sample probabilities; raise ValueError unless `input` is 5D (N, C, D, H, W)."""
        if input.dim() == 5:
            return self.main(input)
        raise ValueError(f"Discriminator expected 5D input (N, C, D, H, W), but got {input.dim()}D input with shape {input.shape}")


# Custom Dataset for .nii files
class NiftiDataset(Dataset):
    """Single-domain NIfTI volume dataset for the 3D GAN.

    `root_dir` must contain one sub-folder per image pair. The pair ID is parsed from the
    folder name: its last ``_``-separated token if numeric, else the second-to-last token if
    numeric (e.g. ``pair_12_extra`` -> ``12``), else the whole folder name. Inside each folder
    the file ``mr_image_<id>.nii`` (``image_domain='mr'``) or ``registered_ct_image_<id>.nii``
    (``image_domain='ct'``) is used; folders without that file are skipped silently.

    Args:
        root_dir: directory holding the per-pair sub-folders.
        target_size: (D, H, W) every volume is trilinearly resized to.
        image_domain: 'mr' (default, the previous hard-coded choice) or 'ct'.

    Raises:
        ValueError: if `image_domain` is not 'mr' or 'ct'.
        FileNotFoundError: if `root_dir` is not a directory.
    """

    _FILENAME_TEMPLATES = {
        'mr': 'mr_image_{}.nii',
        'ct': 'registered_ct_image_{}.nii',
    }

    def __init__(self, root_dir, target_size=(64, 64, 64), image_domain='mr'):
        self.root_dir = root_dir
        self.target_size = target_size
        self.image_domain = image_domain
        if image_domain not in self._FILENAME_TEMPLATES:
            raise ValueError("image_domain must be 'mr' or 'ct'")
        print(f"NiftiDataset: Loading '{self.image_domain}' images from {root_dir} and resizing to {target_size}")

        self.file_paths = []
        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Dataset directory not found: {root_dir}")

        template = self._FILENAME_TEMPLATES[image_domain]
        # os.listdir order is arbitrary; sort so sample order (and index-based lookups) is reproducible.
        for pair_folder in sorted(os.listdir(root_dir)):
            pair_dir_path = os.path.join(root_dir, pair_folder)
            if not os.path.isdir(pair_dir_path):
                continue
            img_path = os.path.join(pair_dir_path, template.format(self._parse_pair_id(pair_folder)))
            if os.path.exists(img_path):
                self.file_paths.append(img_path)

        if not self.file_paths:
            print(f"Warning: No '{self.image_domain}' image files found in the expected format within {root_dir}")

        print(f"NiftiDataset: Found {len(self.file_paths)} image files.")

    @staticmethod
    def _parse_pair_id(folder_name):
        """Return the numeric pair ID embedded in `folder_name`, falling back to the name itself."""
        parts = folder_name.split('_')
        if parts[-1].isdigit():
            return parts[-1]
        if len(parts) > 1 and parts[-2].isdigit():  # e.g. 'pair_12_extra'
            return parts[-2]
        return folder_name

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load volume `idx` as a (1, *target_size) float32 tensor scaled to [-1, 1].

        Returns None (after printing a warning) if the file cannot be loaded, has an
        unsupported shape, or cannot be resized; `collate_fn` drops such items.
        """
        img_path = self.file_paths[idx]

        try:
            img_nii = nib.load(img_path, mmap=False)  # read fully into memory (no memory map)
            img_data = img_nii.get_fdata(dtype=np.float32, caching='unchanged')
        except Exception as e:
            print(f"Warning: Error loading NIfTI file {img_path}: {e}")
            return None

        # 1. Convert to tensor (already float32 from get_fdata(dtype=...)).
        img_tensor = torch.from_numpy(img_data)

        # 2. Bring to (C=1, D, H, W).
        if img_tensor.dim() == 3:
            img_tensor = img_tensor.unsqueeze(0)
        elif img_tensor.dim() == 4 and img_tensor.shape[0] != 1:
            if img_tensor.shape[-1] != 1:
                # NIfTI keeps extra volumes/timepoints on the LAST axis (X, Y, Z, T): keep the
                # first volume. Slicing axis 0 here would keep one spatial slab instead.
                print(f"Warning: 4D tensor has >1 volume on its last axis for {img_path}, shape {img_tensor.shape}. Using first volume only.")
                img_tensor = img_tensor[..., 0:1]
            img_tensor = img_tensor.permute(3, 0, 1, 2)  # (X, Y, Z, 1) -> (1, X, Y, Z)
        if img_tensor.dim() != 4 or img_tensor.shape[0] != 1:
            print(f"Warning: Unexpected dimensions {img_tensor.dim()} or channels {img_tensor.shape[0]} for {img_path}. Skipping.")
            return None

        # 3. Resize to target_size (interpolate needs a batch dimension).
        try:
            img_tensor_resized = F.interpolate(img_tensor.unsqueeze(0).float(), size=self.target_size, mode='trilinear', align_corners=False)
        except Exception as e:
            print(f"Warning: Error interpolating image {img_path} (shape: {img_tensor.unsqueeze(0).shape}) to {self.target_size}: {e}")
            return None
        img_tensor = img_tensor_resized.squeeze(0)

        # 4. Min-max scale to [-1, 1] to match the generator's Tanh output range.
        min_val = torch.min(img_tensor)
        max_val = torch.max(img_tensor)
        denominator = max_val - min_val
        if denominator > 1e-6:
            img_tensor = 2.0 * (img_tensor - min_val) / denominator - 1.0
        else:
            img_tensor = torch.zeros_like(img_tensor) - 1.0  # constant volume -> all -1

        return img_tensor

# Custom collate function to handle None values from dataset
def collate_fn(batch):
    """DataLoader collate hook that tolerates failed samples.

    Drops `None` entries (NiftiDataset.__getitem__ returns None on load/resize errors) and
    stacks the rest along a new leading batch dimension. Returns None, which the training
    loop skips, when nothing is left, when sample shapes disagree, or when stacking fails.
    """
    kept = [sample for sample in batch if sample is not None]
    if len(kept) == 0:
        return None
    try:
        ref_shape = kept[0].shape
        bad = next(((pos, sample) for pos, sample in enumerate(kept) if sample.shape != ref_shape), None)
        if bad is not None:
            pos, sample = bad
            print(f"Error in collate_fn: Tensor shape mismatch. Item {pos} shape {sample.shape} != first shape {ref_shape}. Skipping batch.")
            return None
        return torch.stack(kept, dim=0)
    except Exception as err:
        print(f"Error during torch.stack in collate_fn: {err}")
        for pos, sample in enumerate(kept):
            if isinstance(sample, torch.Tensor):
                print(f" Stack error - Batch item {pos} shape: {sample.shape}")
            else:
                print(f" Stack error - Batch item {pos} type: {type(sample)}")
        return None


# Function to calculate metrics (ensure shapes match)
def calculate_metrics(real, fake):
    """Compute (SSIM, PSNR) between one real volume and one generated volume.

    Args:
        real: torch tensor holding a single volume, shaped (D, H, W), (1, D, H, W) or
            (1, 1, D, H, W). Its own min/max define the intensity range used for scoring.
        fake: torch tensor holding a single generated volume (same accepted shapes). It is
            treated as lying in [-1, 1] and linearly mapped onto `real`'s [min, max].

    Returns:
        (ssim, psnr) as Python floats. Returns (0.0, 0.0) for non-tensor input, for shapes
        that do not reduce to 3D (including batches of more than one volume) or on any
        scoring error. Returns (1.0, 100.0) if `real` is constant and `fake` matches it.
        Infinite PSNR is capped at 100.0.
    """
    try:
        if not isinstance(real, torch.Tensor) or not isinstance(fake, torch.Tensor):
            print("Warning: Non-tensor input to calculate_metrics.")
            return 0.0, 0.0
        real_np = real.detach().cpu().numpy()
        fake_np = fake.detach().cpu().numpy()
    except AttributeError as e:
        print(f"Warning: Could not convert tensors to numpy for metrics calculation: {e}")
        return 0.0, 0.0

    def _drop_leading_singletons(arr):
        # Strip leading size-1 batch/channel axes until 3D. Only singleton axes are removed:
        # ndarray.squeeze(0) raises ValueError on a non-singleton axis (e.g. a batch of 2).
        while arr.ndim > 3 and arr.shape[0] == 1:
            arr = arr[0]
        return arr

    real_np = _drop_leading_singletons(real_np)
    fake_np = _drop_leading_singletons(fake_np)

    if real_np.ndim != 3 or fake_np.ndim != 3:
        print(f"Warning: Cannot calculate metrics. Unexpected shapes after squeeze - Real: {real_np.shape}, Fake: {fake_np.shape}")
        return 0.0, 0.0

    try:
        real_np = real_np.astype(np.float64)
        real_min = real_np.min()
        data_range = real_np.max() - real_min

        if data_range < 1e-9:
            # Constant reference: SSIM/PSNR are undefined, so score exact agreement only.
            if np.allclose(real_np, fake_np.astype(np.float64), atol=1e-6):
                return 1.0, 100.0
            return 0.0, 0.0

        # Map fake from [-1, 1] onto the real volume's intensity range.
        fake_np_denorm = (fake_np.astype(np.float64) + 1.0) / 2.0 * data_range + real_min

        # SSIM window must be odd and no larger than the smallest volume dimension.
        win_size = min(7, min(real_np.shape))
        if win_size % 2 == 0:
            win_size -= 1

        if win_size < 3:
            ssim_value = 0.0
        else:
            ssim_value = ssim(real_np, fake_np_denorm, data_range=data_range, channel_axis=None, win_size=win_size, gaussian_weights=True, use_sample_covariance=False)

        psnr_value = psnr(real_np, fake_np_denorm, data_range=data_range)
        if np.isinf(psnr_value):
            psnr_value = 100.0

    except Exception as e:
        print(f"Error calculating metrics: {e}")
        print(f"Metric calculation error details - Real shape: {real_np.shape}, Fake shape: {fake_np.shape}, Denorm fake shape: {fake_np_denorm.shape if 'fake_np_denorm' in locals() else 'N/A'}, Data Range: {data_range if 'data_range' in locals() else 'N/A'}")
        return 0.0, 0.0

    return float(ssim_value), float(psnr_value)

# Function to save the model
def save_model(generator, discriminator, optimizer_G, optimizer_D, epoch, elapsed_time, latent_dim, target_size, save_dir=SAVE_DIR):
    """Write one .pth checkpoint holding both networks, both optimizers and run metadata.

    The file is ``gan_model_epoch<epoch>_time<hours>h_<YYYYmmdd-HHMMSS>.pth`` inside
    `save_dir`, which is created if missing. All errors are printed and swallowed so a
    failed save never interrupts the caller.
    """
    if not os.path.exists(save_dir):
        try:
            os.makedirs(save_dir)
            print(f"Created save directory: {save_dir}")
        except OSError as dir_err:
            print(f"Error creating save directory {save_dir}: {dir_err}")
            return

    stamp = time.strftime("%Y%m%d-%H%M%S")
    hours = elapsed_time / 3600
    model_path = os.path.join(save_dir, f'gan_model_epoch{epoch}_time{hours:.2f}h_{stamp}.pth')

    try:
        # state_dict tensors are saved from whatever device the models currently live on.
        checkpoint = {'epoch': epoch}
        checkpoint['generator_state_dict'] = generator.state_dict()
        checkpoint['discriminator_state_dict'] = discriminator.state_dict()
        checkpoint.update(
            optimizer_G_state_dict=optimizer_G.state_dict(),
            optimizer_D_state_dict=optimizer_D.state_dict(),
            elapsed_time_seconds=elapsed_time,
            latent_dim=latent_dim,
            target_size=target_size,
            wandb_project=WANDB_PROJECT_NAME,
            dataset_path=DATASET_PATH,
        )
        torch.save(checkpoint, model_path)
        print(f"Model saved successfully to {model_path}")
    except Exception as save_err:
        print(f"Error saving model to {model_path}: {save_err}")


# Training loop
def train_gan(generator, discriminator, dataloader, num_epochs, device, latent_dim, target_size, save_interval_hours=8):
    """Train the generator/discriminator pair with BCE losses and Adam (DCGAN-style).

    Args:
        generator: network mapping (N, latent_dim, 1, 1, 1) noise to volumes.
        discriminator: network mapping volume batches to Sigmoid probabilities.
        dataloader: yields batches of real volumes (expected 5D), or None for batches that
            collate_fn dropped; those are skipped.
        num_epochs: maximum number of epochs.
        device: device that noise and real batches are moved to. The models are NOT moved
            here; presumably the caller already placed them on `device`.
        latent_dim: length of the noise vector.
        target_size: only recorded in the W&B config and in checkpoints.
        save_interval_hours: wall-clock budget. Once exceeded (checked at the start and end
            of each epoch) a checkpoint is saved and training stops.

    Side effects: sets the module-level `global_start_time`; logs to W&B if `wandb.init`
    succeeds; saves checkpoints via `save_model` on time-out, KeyboardInterrupt, any other
    exception, or after completing every epoch. Exceptions are caught and not re-raised.
    """
    global global_start_time # Make start_time accessible

    # Initialize W&B (best effort: failures only disable logging)
    wandb_active = False
    run = None
    try:
        run = wandb.init(project=WANDB_PROJECT_NAME, config={
            "learning_rate": LEARNING_RATE, "beta1": BETA1, "batch_size": dataloader.batch_size if dataloader.batch_size else BATCH_SIZE,
            "num_epochs": num_epochs, "latent_dim": latent_dim, "target_image_size": target_size,
            "dataset_path": DATASET_PATH, "architecture": "SimpleDCGAN3D", "device": str(device)
        })
        print(f"WandB run initialized: {run.name} - {run.url}")
        wandb_active = True
    except Exception as e:
        print(f"Error initializing WandB: {e}. Training will continue without logging.")

    # Optimizers (created before the training try-block, so they always exist in the handlers below)
    optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))

    # Fixed noise for visualization: the same latents every logging epoch, so progress is comparable
    vis_batch_size = min(8, dataloader.batch_size if dataloader.batch_size else BATCH_SIZE)
    if vis_batch_size <= 0: vis_batch_size = 1
    fixed_noise = torch.randn(vis_batch_size, latent_dim, 1, 1, 1, device=device)

    global_start_time = time.time() # Record start time
    print(f"Starting training on {device}...")
    print(f"Target image size: {target_size}")
    print(f"Number of epochs: {num_epochs}")
    print(f"Batch size: {dataloader.batch_size if dataloader.batch_size else BATCH_SIZE}")
    print(f"Save interval: {save_interval_hours} hours ({save_interval_hours * 3600} seconds)")
    # Only one time-triggered save happens (training stops right after it), so no last-save time is tracked.

    # --- Training Loop ---
    current_epoch = 0 # Track epoch number (1-based)
    training_interrupted_by_time = False # True if training stopped early (time budget, Ctrl+C or error); suppresses the final save
    try:
        for epoch in range(num_epochs):
            current_epoch = epoch + 1
            epoch_start_time = time.time()
            g_loss_epoch_total = 0.0
            d_loss_epoch_total = 0.0
            num_batches_processed = 0
            last_real_images_for_metrics = None # Despite the name, holds the FIRST non-empty batch of this epoch

            generator.train()
            discriminator.train()

            print(f"\n--- Starting Epoch {current_epoch}/{num_epochs} ---")

            # Check time before starting epoch batches
            elapsed_time = time.time() - global_start_time
            if elapsed_time >= save_interval_hours * 3600:
                 print(f"\n--- Save interval ({save_interval_hours} hours) reached BEFORE starting Epoch {current_epoch}. Saving model now. ---")
                 # Save as the last fully completed epoch
                 save_model(generator, discriminator, optimizer_G, optimizer_D, current_epoch -1 , elapsed_time, latent_dim, target_size)
                 print("--- Exiting training loop after saving ---")
                 training_interrupted_by_time = True
                 break # Exit the main training loop

            # --- Batch Loop ---
            for i, data in enumerate(dataloader, 0):
                if data is None: continue # Skip empty batches from collate_fn

                real_images = data.to(device)
                current_batch_size = real_images.size(0)
                if current_batch_size == 0: continue

                # Store the first valid batch for end-of-epoch metrics if needed.
                # NOTE(review): this happens before the 5D check below, so a malformed batch that
                # is then skipped could still be the one kept for metrics.
                if last_real_images_for_metrics is None:
                     last_real_images_for_metrics = real_images.detach() # Detach if only used for metrics

                if real_images.dim() != 5:
                     print(f"Warning: Skipping batch {i}. Expected 5D tensor, got {real_images.dim()}D with shape {real_images.shape}")
                     continue

                # --- Train Discriminator ---
                # Real and fake losses are backpropagated separately; their gradients accumulate
                # in D's .grad before the single optimizer_D.step() below.
                discriminator.zero_grad()
                # Real
                label_real = torch.full((current_batch_size,), 1.0, dtype=torch.float, device=device)
                output_real = discriminator(real_images).view(-1)
                errD_real = nn.BCELoss()(output_real, label_real)
                errD_real.backward()
                D_x = output_real.mean().item()
                # Fake (detached so the D step does not backprop into the generator)
                noise = torch.randn(current_batch_size, latent_dim, 1, 1, 1, device=device)
                fake_images = generator(noise)
                label_fake = torch.full((current_batch_size,), 0.0, dtype=torch.float, device=device)
                output_fake = discriminator(fake_images.detach()).view(-1)
                errD_fake = nn.BCELoss()(output_fake, label_fake)
                errD_fake.backward()
                D_G_z1 = output_fake.mean().item()
                # Update D
                errD = errD_real + errD_fake
                optimizer_D.step()

                # --- Train Generator ---
                # Non-saturating loss: G is trained to make D label its fakes as real.
                generator.zero_grad()
                label_real_for_G = torch.full((current_batch_size,), 1.0, dtype=torch.float, device=device)
                # fake_images is NOT detached, so gradients reach G. This also writes D's .grad,
                # which discriminator.zero_grad() clears at the start of the next iteration.
                output_fake_for_G = discriminator(fake_images).view(-1)
                errG = nn.BCELoss()(output_fake_for_G, label_real_for_G)
                errG.backward()
                D_G_z2 = output_fake_for_G.mean().item()
                # Update G
                optimizer_G.step()

                # --- Logging & Tracking ---
                g_loss_epoch_total += errG.item()
                d_loss_epoch_total += errD.item()
                num_batches_processed += 1

                # Print progress periodically
                if i % 50 == 0 or i == len(dataloader) - 1:
                    print(f'  [{current_epoch}/{num_epochs}][{i}/{len(dataloader)-1}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} -> {D_G_z2:.4f}')
            # --- End of Batch Loop ---

            # --- End of Epoch Processing ---
            if num_batches_processed == 0:
                 print(f"Epoch [{current_epoch}/{num_epochs}] skipped - no valid batches processed.")
                 # Also skips the end-of-epoch time check; the next epoch's start-of-epoch check covers it.
                 continue # Skip logging/saving for this epoch

            avg_g_loss = g_loss_epoch_total / num_batches_processed
            avg_d_loss = d_loss_epoch_total / num_batches_processed
            epoch_time = time.time() - epoch_start_time

            print(f"Epoch [{current_epoch}/{num_epochs}] completed in {epoch_time:.2f}s. Avg Loss_D: {avg_d_loss:.4f}, Avg Loss_G: {avg_g_loss:.4f}")

            # --- W&B Logging ---
            if wandb_active:
                log_dict = {
                    "Epoch": current_epoch, "Average Generator Loss": avg_g_loss, "Average Discriminator Loss": avg_d_loss,
                    "Epoch Time (s)": epoch_time, "D(x) (last batch avg)": D_x if 'D_x' in locals() else 0.0,
                    "D(G(z)) (last batch avg)": D_G_z2 if 'D_G_z2' in locals() else 0.0
                }

                # Log Images & Metrics periodically (every 5th epoch and the final epoch)
                if current_epoch % 5 == 0 or epoch == num_epochs - 1:
                     generator.eval()
                     with torch.no_grad(): fake_fixed = generator(fixed_noise).detach()
                     generator.train()

                     # Log images: middle slice along the first spatial axis of each fixed-noise sample
                     try:
                         if fake_fixed.shape[0] > 0:
                            middle_slice_idx = fake_fixed.shape[2] // 2
                            wandb_images = [wandb.Image(img[0, middle_slice_idx, :, :].cpu().numpy(), caption=f"Epoch {current_epoch}") for img in fake_fixed]
                            log_dict["Generated Images (Fixed Noise, Middle Slice)"] = wandb_images
                     except Exception as img_log_e: print(f"Warning: Error creating WandB images for epoch {current_epoch}: {img_log_e}")

                     # Calculate/log metrics
                     # Use the first image from the *first* valid batch stored, and first fixed fake image.
                     # The two volumes are unpaired (random real sample vs. fixed-noise fake), so
                     # SSIM/PSNR here is only a loose quality signal.
                     if last_real_images_for_metrics is not None and last_real_images_for_metrics.shape[0] > 0 and fake_fixed.shape[0] > 0:
                         try:
                             ssim_val, psnr_val = calculate_metrics(last_real_images_for_metrics[0], fake_fixed[0])
                             log_dict["SSIM (Sample)"] = ssim_val; log_dict["PSNR (Sample)"] = psnr_val
                             print(f"Epoch {current_epoch} Sample Metrics - SSIM: {ssim_val:.4f}, PSNR: {psnr_val:.4f} dB")
                         except Exception as metric_e: print(f"Error calculating metrics for epoch {current_epoch}: {metric_e}"); log_dict["SSIM (Sample)"] = 0.0; log_dict["PSNR (Sample)"] = 0.0
                     else: print(f"Skipping metrics calculation for epoch {current_epoch}, no valid real/fake images available.")

                # Log dictionary to W&B, using the epoch number as the W&B step
                try: wandb.log(log_dict, step=current_epoch)
                except Exception as log_e: print(f"  Error during wandb.log for Epoch {current_epoch}: {log_e}")

            # --- Check Save Interval (End of Epoch) ---
            # Same condition as the start-of-epoch check. This one normally fires first (it records
            # the just-finished epoch); the start-of-epoch check mainly catches epochs that hit the
            # `continue` above because no batch was processed.
            elapsed_time = time.time() - global_start_time
            if elapsed_time >= save_interval_hours * 3600:
                print(f"\n--- Save interval ({save_interval_hours} hours) reached END of Epoch {current_epoch}. Saving model. ---")
                save_model(generator, discriminator, optimizer_G, optimizer_D, current_epoch, elapsed_time, latent_dim, target_size)
                print("--- Exiting training loop after saving ---")
                training_interrupted_by_time = True
                break # Exit the main training loop

        # --- End of Epoch Loop (Natural Completion or Break) ---
        # Reached after all epochs ran, or via `break` once the time budget was exceeded.

    # --- End of Training Function Try Block ---
    except KeyboardInterrupt: # Handle manual interruption (Ctrl+C)
        print("\n--- Training interrupted manually (KeyboardInterrupt) ---")
        print("Attempting to save final model state...")
        current_elapsed_time = time.time() - global_start_time
        # Always true here: both optimizers are created before the try block.
        if 'optimizer_G' in locals() and 'optimizer_D' in locals():
             save_model(generator, discriminator, optimizer_G, optimizer_D, current_epoch, current_elapsed_time, latent_dim, target_size)
        else: print("Cannot save model state - optimizers not fully initialized.")
        training_interrupted_by_time = True # Set flag to indicate not a normal finish

    except Exception as train_e:
        # Error is reported and a checkpoint attempted; the exception is not re-raised.
        print(f"\n--- An error occurred during training at Epoch {current_epoch}: {train_e} ---")
        traceback.print_exc() # Print detailed traceback
        print("Attempting to save model state due to error...")
        current_elapsed_time = time.time() - global_start_time
        if 'optimizer_G' in locals() and 'optimizer_D' in locals():
            save_model(generator, discriminator, optimizer_G, optimizer_D, current_epoch, current_elapsed_time, latent_dim, target_size)
        else: print("Cannot save model state - optimizers not fully initialized.")
        training_interrupted_by_time = True # Set flag to indicate not a normal finish
    finally:
        # This block executes whether the loop finished normally, broke, or had an exception
        print("\n--- Training Function Ended ---")
        # Save final model if training finished all epochs *without* being interrupted by time or error
        if not training_interrupted_by_time and current_epoch == num_epochs:
             print("Saving final model after completing all epochs...")
             final_elapsed_time = time.time() - global_start_time
             if 'optimizer_G' in locals() and 'optimizer_D' in locals():
                save_model(generator, discriminator, optimizer_G, optimizer_D, current_epoch, final_elapsed_time, latent_dim, target_size)
             else: print("Cannot save final model - optimizers not available.")

        if wandb_active and run:
            try:
                wandb.finish()
                print("WandB run finished.")
            except Exception as e:
                print(f"Error finishing WandB run: {e}")

# --- Main Execution ---
# In a notebook `__name__ == "__main__"` is always true, so this cell runs the full training.
# Fatal setup errors raise SystemExit(1). In an IPython/Jupyter kernel, `exit(1)` only requests
# a frontend shutdown and lets the rest of the cell keep running (straight into train_gan with
# undefined names). Raising SystemExit stops the cell (and "Run All"); in a plain script it
# exits with status 1, as before.
if __name__ == "__main__":
    # --- Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    global_start_time = 0.0  # Initialize; train_gan() overwrites it via `global`

    # --- Models ---
    print("Initializing models...")
    try:
        generator = Generator(LATENT_DIM, output_channels=OUTPUT_CHANNELS).to(device)
        discriminator = Discriminator(input_channels=OUTPUT_CHANNELS).to(device)
        print(f"Generator parameters: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
        print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    except Exception as e:
        print(f"Error initializing models: {e}")
        raise SystemExit(1) from e

    # --- Data ---
    print(f"\nLoading dataset from: {DATASET_PATH}, Target size: {TARGET_SIZE}")
    try:
        dataset = NiftiDataset(root_dir=DATASET_PATH, target_size=TARGET_SIZE)
        if len(dataset) == 0:
            raise ValueError("Dataset initialization resulted in zero valid files.")

        dataloader = DataLoader(dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                num_workers=2,
                                pin_memory=True if device.type == 'cuda' else False,
                                collate_fn=collate_fn,  # drops samples that failed to load
                                drop_last=True)

        print(f"Dataset loaded. Number of samples: {len(dataset)}. Number of batches: {len(dataloader)}")
        # With drop_last=True, fewer samples than BATCH_SIZE yields zero batches.
        if len(dataloader) == 0:
            raise ValueError("DataLoader is empty after collation. Check dataset/processing steps.")

    except Exception as e:
        print(f"Error loading dataset or creating dataloader: {e}")
        traceback.print_exc()
        raise SystemExit(1) from e

    # --- Train ---
    print("\nStarting training process...")
    train_gan(generator=generator,
              discriminator=discriminator,
              dataloader=dataloader,
              num_epochs=NUM_EPOCHS,
              device=device,
              latent_dim=LATENT_DIM,
              target_size=TARGET_SIZE,
              save_interval_hours=SAVE_INTERVAL_HOURS)

    print("\nMain script execution finished.")


In [None]:
# Test-set evaluation of the trained generator.
#
# `Generator` is unconditional (latent noise -> volume): it cannot translate an MR volume into
# a CT volume. `NiftiDataset` yields one unpaired volume per item. So each real test volume is
# scored against one freshly generated sample with the same `calculate_metrics` used during
# training. Because the pairs are unmatched, treat the averages as a coarse sanity check, not
# as a reconstruction score.
TEST_DATASET_PATH = '/kaggle/input/mrict-test/processed_images'
TEST_RESULTS_DIR = 'test_results'
TEST_NOISE_SEED = 0  # fixed so the generated test samples are reproducible

print("\nLoading test dataset...")
test_dataset = NiftiDataset(root_dir=TEST_DATASET_PATH, target_size=TARGET_SIZE)
print(f"Test samples: {len(test_dataset)}")
os.makedirs(TEST_RESULTS_DIR, exist_ok=True)

generator.eval()
noise_rng = torch.Generator().manual_seed(TEST_NOISE_SEED)
ssim_scores = []
psnr_scores = []

with torch.no_grad():
    # Iterate items directly (not via a DataLoader) so each score stays tied to its source file,
    # even when some volumes fail to load and are skipped.
    for idx, img_path in enumerate(test_dataset.file_paths):
        real_vol = test_dataset[idx]  # (1, D, H, W) in [-1, 1], or None if loading failed
        if real_vol is None:
            continue
        noise = torch.randn(1, LATENT_DIM, 1, 1, 1, generator=noise_rng).to(device)
        fake_vol = generator(noise)[0]  # (OUTPUT_CHANNELS, D, H, W)

        ssim_val, psnr_val = calculate_metrics(real_vol, fake_vol)
        ssim_scores.append(ssim_val)
        psnr_scores.append(psnr_val)

        # Save the generated volume (identity affine; no spatial metadata) named after its reference file.
        stem = os.path.basename(img_path)
        if stem.endswith('.nii'):
            stem = stem[:-len('.nii')]
        nib.save(nib.Nifti1Image(fake_vol[0].cpu().numpy(), affine=np.eye(4)),
                 os.path.join(TEST_RESULTS_DIR, f"test_result_{stem}.nii"))

total_samples = len(ssim_scores)
if total_samples:
    avg_ssim = sum(ssim_scores) / total_samples
    avg_psnr = sum(psnr_scores) / total_samples
    print(f"\nTest Results (n={total_samples}) - SSIM: {avg_ssim:.4f}, PSNR: {avg_psnr:.2f} dB")
else:
    print("\nTest Results: no valid test volumes could be evaluated.")