# Data Generation

In [None]:
## Blob growth simulation

import numpy as np
import random
from scipy.ndimage import gaussian_filter, binary_dilation, uniform_filter, binary_erosion
import matplotlib.pyplot as plt # ADDED for visualization
from tqdm import tqdm # ADDED for progress tracking
from pathlib import Path

# --- Set seed ---
# Seeds both RNGs used by the generators below: NumPy (np.random.*) and Python's `random` module.
# SEED = 2024 # Used for test data
SEED = 2025 # Used for training data
np.random.seed(SEED)
random.seed(SEED)
print(f"Random seeds set to {SEED} for reproducibility.")

# --- FAF Configuration ---
TARGET_SIZE = 256   # Side length (pixels) of every generated square image
NUM_PEAKS = 50      # Number of Gaussian mounds summed into the initial lesion field
GROWTH_RATE_MIN = 1 # Not used anymore, kept for reference
GROWTH_RATE_MAX = 10 # Not used anymore, kept for reference
NUM_SAMPLES = 1 # Target dataset size (Set to 1 for quick test)

# --- Anisotropic Growth Configuration ---
ANISO_GROWTH_STEPS = 5              # Number of one-pixel directional dilation steps per frame
ANISO_DIR_STRENGTH = 0.6            # Probability a step is forced to the lesion's main direction (otherwise a random direction is drawn)
ANISO_JAGGEDNESS_FREQ = 0.75        # Probability per frame of applying the erosion+dilation cycle

# --- Vein Configuration ---
VEIN_VALUE_FLOAT = 0.01      # Darkest vein-overlay value (at the strongest point of the smoothed vein map); lower = darker veins
BACKGROUND_VALUE_FLOAT = 1.0 # Vein-overlay value where there is no vein
VEIN_SMOOTHING_SIGMA = 6.0   # Gaussian sigma applied to the vein paths; larger = thicker veins
VEIN_WALK_LENGTH = 75        # Maximum number of random-walk steps per vein
VEIN_STEP_LENGTH = 5.0       # Distance (pixels) moved per random-walk step

# --- HELPER FUNCTIONS ---

def generate_vein_mask(size=TARGET_SIZE):
    """Generate a static vein overlay from smoothed random walks.

    Twelve random walks start in five zones, visited in turn: the four
    quadrants, then a central square. Each walk wanders with small random
    heading changes. The visited points are blurred with VEIN_SMOOTHING_SIGMA
    and max-normalised. The result is then mapped so that the strongest vein
    point equals VEIN_VALUE_FLOAT and vein-free pixels equal
    BACKGROUND_VALUE_FLOAT.

    Args:
        size: Side length of the square output. The start zones scale with it.
            At size == 256 they equal the original hard-coded zones.

    Returns:
        float32 array of shape (size, size).
    """
    vein_scratchpad = np.zeros((size, size), dtype=np.float32)
    num_start_locations = 12
    half, quarter = size // 2, size // 4
    # Start zones as (x_min, x_max, y_min, y_max): four quadrants, then a central
    # square. Derived from `size` so walks start on-image for any resolution;
    # for size == 256 these are (1,127,...), (128,255,...), (64,192,...).
    zones = [
        (1, half - 1, 1, half - 1),
        (half, size - 1, 1, half - 1),
        (1, half - 1, half, size - 1),
        (half, size - 1, half, size - 1),
        (quarter, 3 * quarter, quarter, 3 * quarter),
    ]
    for i in range(num_start_locations):
        x_min, x_max, y_min, y_max = zones[i % len(zones)]
        x, y = random.uniform(x_min, x_max), random.uniform(y_min, y_max)
        angle = random.uniform(0, 2 * np.pi)
        for _ in range(VEIN_WALK_LENGTH):
            # Small heading perturbations keep veins gently curving.
            angle += random.uniform(-np.pi / 8, np.pi / 8)
            ix, iy = int(np.clip(x, 0, size - 1)), int(np.clip(y, 0, size - 1))
            vein_scratchpad[iy, ix] = 1.0
            x += VEIN_STEP_LENGTH * np.sin(angle)
            y -= VEIN_STEP_LENGTH * np.cos(angle)
            # Stop this vein once it walks off the image.
            if not (0 <= x < size and 0 <= y < size):
                break

    vein_filtered = gaussian_filter(vein_scratchpad, sigma=VEIN_SMOOTHING_SIGMA)
    if vein_filtered.max() > 0:
        vein_filtered /= vein_filtered.max()
    # Linear map: 0 -> BACKGROUND_VALUE_FLOAT, 1 (strongest vein) -> VEIN_VALUE_FLOAT.
    vein_mask = BACKGROUND_VALUE_FLOAT - (vein_filtered * (BACKGROUND_VALUE_FLOAT - VEIN_VALUE_FLOAT))
    return vein_mask

def generate_initial_solid_blob(size=TARGET_SIZE):
    """Build a smooth random field and threshold it into the initial lesion mask.

    NUM_PEAKS Gaussian mounds are summed at random centres inside the central
    70% of the image. Each mound has a random height in [0.3, 1.2] and a
    random sigma in [3, 25] px. The field is then max-normalised.

    The threshold is a random percentile, drawn from [5, 90], of the clearly
    non-zero field values, clipped to [0.03, 0.45]. Higher thresholds give
    smaller lesions, so initial lesion sizes vary widely.

    Args:
        size: Side length of the square output.

    Returns:
        (base_field, mask): the float32 field normalised so its maximum is 1,
        and a float32 0/1 mask of pixels where the field exceeds the threshold.
    """
    base_field = np.zeros((size, size), dtype=np.float32)
    # Integer bounds for np.random.randint: peak centres lie in the central 70%.
    low, high = int(size * 0.15), int(size * 0.85)
    for _ in range(NUM_PEAKS):
        center_x = np.random.randint(low, high)
        center_y = np.random.randint(low, high)
        peak_height = np.random.uniform(0.3, 1.2)
        peak_sigma = np.random.uniform(3, 25)
        impulse = np.zeros((size, size))
        impulse[center_y, center_x] = peak_height
        base_field += gaussian_filter(impulse, sigma=peak_sigma)

    if base_field.max() > 0:
        base_field /= base_field.max()

    # Percentiles are taken over clearly non-zero pixels only. In the degenerate
    # case where none exceed 0.01, fall back to the whole field, since an empty
    # array cannot give a meaningful percentile.
    foreground = base_field[base_field > 0.01]
    if foreground.size == 0:
        foreground = base_field.ravel()
    random_threshold = np.percentile(foreground, np.random.uniform(5, 90))
    # Clip so the mask is never vanishingly small nor the entire field.
    threshold = np.clip(random_threshold, 0.03, 0.45)

    mask = (base_field > threshold).astype(np.float32)
    return base_field, mask

def generate_growing_masks(initial_mask, num_frames=4):
    """Generate a sequence of monotonically growing, anisotropic lesion masks.

    Each new frame is built in two stages:

    1. ANISO_GROWTH_STEPS one-pixel binary dilations, each with a
       single-direction 3x3 kernel. With probability ANISO_DIR_STRENGTH a step
       uses the lesion's pre-selected main direction; otherwise a direction is
       drawn at random.
    2. With probability ANISO_JAGGEDNESS_FREQ, an erosion followed by a
       dilation with a cross kernel (a morphological opening) roughens the
       front.

    An opening can delete thin lesion structures, including pixels that were
    already lesion in the previous frame. The previous frame is therefore
    OR-ed back in after it, so each frame is a superset of the one before. This
    keeps the residual (new-growth) masks computed from consecutive frames
    consistent with the cumulative masks.

    Args:
        initial_mask: 2-D array; non-zero pixels are lesion.
        num_frames: Total number of frames returned, including the initial one.
            At least the initial frame is always returned.

    Returns:
        List of float32 0/1 arrays with the shape of initial_mask.
    """
    current_mask = np.asarray(initial_mask).astype(bool)
    masks = [current_mask.astype(np.float32)]

    # Single-direction growth kernels (each also keeps the centre pixel, so a
    # dilation never removes existing lesion pixels).
    k_north = np.array([[0, 1, 0], [0, 1, 0], [0, 0, 0]])
    k_south = np.array([[0, 0, 0], [0, 1, 0], [0, 1, 0]])
    k_east = np.array([[0, 0, 0], [0, 1, 1], [0, 0, 0]])
    k_west = np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]])
    k_jagged = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])

    directional_kernels = [k_north, k_south, k_east, k_west]

    # One dominant growth direction per lesion produces the anisotropy.
    main_direction_idx = random.randint(0, 3)

    for _ in range(1, num_frames):
        previous_mask = current_mask

        # 1. Non-uniform, anisotropic growth.
        for _step in range(ANISO_GROWTH_STEPS):
            if random.random() < ANISO_DIR_STRENGTH:
                kernel = directional_kernels[main_direction_idx]
            else:
                kernel = random.choice(directional_kernels)
            current_mask = binary_dilation(current_mask, structure=kernel, iterations=1)

        # 2. Jaggedness cycle (opening), then restore last frame's lesion so
        #    the mask never shrinks between frames.
        if random.random() < ANISO_JAGGEDNESS_FREQ:
            eroded = binary_erosion(current_mask, structure=k_jagged, iterations=1)
            opened = binary_dilation(eroded, structure=k_jagged, iterations=1)
            current_mask = opened | previous_mask

        masks.append(current_mask.astype(np.float32))

    return masks

def generate_edge_artifact(size=TARGET_SIZE, intensity=0.3):
    """Generate a smooth Gaussian blob centred in the border band of one random side.

    generate_static_faf_background subtracts this from the tissue background,
    so it renders as a dark, lesion-like artefact near the image edge.

    Args:
        size: Side length of the square output.
        intensity: Peak value of the returned blob.

    Returns:
        float64 array of shape (size, size). Its maximum equals `intensity`
        when intensity > 0; it is all zeros when intensity == 0.
    """
    side = np.random.randint(4)  # 0 = top, 1 = bottom, 2 = left, 3 = right

    # Width of the border band (10% of the side) that holds the blob centre.
    # Kept at least 1 px so np.random.randint never gets an empty range on tiny images.
    border_zone = max(1, int(size * 0.10))

    # Pick the centre: inside the band on the chosen side, away from the corners
    # along the other axis.
    if side == 0:  # Top
        center_y = np.random.randint(0, border_zone)
        center_x = np.random.randint(border_zone, size - border_zone)
    elif side == 1:  # Bottom
        center_y = np.random.randint(size - border_zone, size)
        center_x = np.random.randint(border_zone, size - border_zone)
    elif side == 2:  # Left
        center_y = np.random.randint(border_zone, size - border_zone)
        center_x = np.random.randint(0, border_zone)
    else:  # Right
        center_y = np.random.randint(border_zone, size - border_zone)
        center_x = np.random.randint(size - border_zone, size)

    # Blur a single impulse into a smooth blob with sigma in [8, 18] px.
    peak_mask = np.zeros((size, size))
    peak_mask[center_y, center_x] = intensity
    peak_sigma = np.random.uniform(8, 18)
    artifact_mound = gaussian_filter(peak_mask, sigma=peak_sigma)

    # Normalise to a unit peak, then scale so the peak equals `intensity`.
    if artifact_mound.max() > 0:
        artifact_mound /= artifact_mound.max()

    return artifact_mound * intensity

def generate_static_faf_background(size=TARGET_SIZE, base_field=None):
    """Build the frame-invariant tissue background shared by all frames of a simulation.

    The background has three parts:
    - A base level between 0.65 and 0.75. It is higher where `base_field` is
      low; a missing base_field is treated as all ones.
    - A zero-mean, amplified sum of five low-frequency noise layers.
    - With 50% probability, a dark edge artefact from generate_edge_artifact
      subtracted from it.

    Returns:
        float32 array of shape (size, size), clipped to [0, 1].
    """
    field = np.ones((size, size)) if base_field is None else base_field

    # Five smooth noise layers with random smoothness and gain, then zero-centred
    # and amplified into a visible tissue texture.
    texture = np.zeros((size, size))
    for _layer in range(5):
        layer_sigma = np.random.uniform(20, 100)
        layer_gain = np.random.uniform(0.05, 0.15)
        texture += gaussian_filter(np.random.rand(size, size) * layer_gain, sigma=layer_sigma)
    texture = (texture - texture.mean()) * 8.0

    MIN_TISSUE_FAF = 0.65
    MAX_TISSUE_FAF = 0.75
    tissue = MIN_TISSUE_FAF + (1 - field) * (MAX_TISSUE_FAF - MIN_TISSUE_FAF) + texture

    # Coin flip: half of the backgrounds get a dark blob near one image edge.
    if np.random.rand() < 0.5:
        darkness = np.random.uniform(0.15, 0.50)
        tissue = tissue - generate_edge_artifact(size=size, intensity=darkness)

    # Keep intensities in the valid range after the subtraction.
    return np.clip(tissue, 0.0, 1.0).astype(np.float32)

def add_fine_speckle_noise(image, noise_factor=0.02):
    """Return `image` plus independent per-pixel uniform noise, clipped to [0, 1].

    The noise is uniform in [-noise_factor/2, noise_factor/2) and drawn from
    NumPy's global RNG (np.random.rand).
    """
    centred_uniform = np.random.rand(*image.shape) - 0.5
    noisy = image + centred_uniform * noise_factor
    return np.clip(noisy, 0.0, 1.0)

def generate_synthetic_faf(static_faf_background, mask, size=TARGET_SIZE, lesion_contrast_factor=0.7):
    """Render one FAF frame from the static background and a binary lesion mask.

    Steps:
    1. Darken the lesion with a Gaussian-blurred copy of the mask. The blur
       sigma is 0.5 * cbrt(area), clamped to [8, 15], so larger lesions get
       softer edges.
    2. Add faint random splotches inside the lesion. NOTE(review): each
       splotch is an unnormalised blurred impulse, so its amplitude is small
       (on the order of 1e-3); confirm whether stronger splotches were intended.
    3. Add a broad intensity map plus low- and high-frequency texture noise.
    4. Linearly remap background pixels (mask == 0) so their 10th-90th
       percentile range spans [0.55, 0.75].
    5. Clip to [0, 1] and apply the (currently identity) contrast power.

    Args:
        static_faf_background: Background image; it is copied, not modified.
        mask: Binary lesion mask of shape (size, size).
        size: Side length used for the generated noise fields; must match `mask`.
        lesion_contrast_factor: Currently unused; kept for interface compatibility.

    Returns:
        float32 image of shape (size, size) with values in [0, 1].
    """
    faf = static_faf_background.copy()
    lesion_area = mask.sum()

    # Dynamic edge softness: cbrt scales more gently than sqrt. The lower clamp
    # keeps small lesions soft; the upper clamp avoids over-smoothing large ones.
    if lesion_area > 0:
        dynamic_sigma = np.clip(np.cbrt(lesion_area) * 0.5, 8.0, 15.0)
    else:
        dynamic_sigma = 8.0

    # Internal lesion splotches, restricted to the lesion by multiplying with the mask.
    lesion_splotches = np.zeros_like(mask)
    if lesion_area > 0:
        num_splotches = np.random.randint(3, 6)
        for _ in range(num_splotches):
            center_x = np.random.randint(0, size)
            center_y = np.random.randint(0, size)
            splotch_sigma = np.random.uniform(8, 15)
            splotch_val = np.random.uniform(-0.2, 0.2)
            impulse = np.zeros((size, size))
            impulse[center_y, center_x] = splotch_val
            lesion_splotches += gaussian_filter(impulse, sigma=splotch_sigma)
        lesion_splotches = np.clip(lesion_splotches, -0.1, 0.15) * mask

    # Lesion darkening with a soft, size-dependent edge, then the interior texture.
    DARKNESS_FACTOR = 0.55
    smoothed_mask = np.clip(gaussian_filter(mask, sigma=dynamic_sigma), 0.0, 1.0)
    faf -= smoothed_mask * DARKNESS_FACTOR
    faf += lesion_splotches

    # Broad illumination variation plus fine structured texture.
    faf += gaussian_filter(np.random.randn(size, size), sigma=50) * 0.1
    TEXTURE_AMPLITUDE = 0.08
    # Box size equals size // 64 (4 px at 256), at least 1 so small images still work.
    box_size = max(1, size // 64)
    structured_noise_low = uniform_filter(np.random.rand(size, size), size=box_size) * TEXTURE_AMPLITUDE * 0.5
    structured_noise_high = gaussian_filter(np.random.randn(size, size), sigma=0.7) * TEXTURE_AMPLITUDE * 0.5
    faf += structured_noise_low + structured_noise_high

    # Background contrast stretch: map the 10th-90th percentile range to [0.55, 0.75].
    background = mask == 0
    background_values = faf[background]
    if background_values.size > 0:
        bg_min = np.percentile(background_values, 10)
        bg_max = np.percentile(background_values, 90)
        if bg_max > bg_min:
            faf[background] = (background_values - bg_min) / (bg_max - bg_min) * 0.2 + 0.55

    # Clip before the power so a non-unit factor can never see negative inputs (NaN).
    CONTRAST_POWER_FACTOR = 1.0
    faf = np.power(np.clip(faf, 0.0, 1.0), CONTRAST_POWER_FACTOR)

    return np.clip(faf, 0.0, 1.0).astype(np.float32)

def apply_soft_vignette(image, strength=0.08, sigma_ratio=0.3):
    """Apply a smooth radial vignette that slightly brightens the centre and dims the edges.

    The radial distance from the image centre is normalised by the
    half-diagonal and passed through a Gaussian with standard deviation
    `sigma_ratio`. The result is min-max rescaled to the multiplicative range
    [1 - strength/2, 1 + strength/2].

    Args:
        image: Array of shape (H, W) or (H, W, ...). Trailing axes (e.g.
            channels) share the same spatial weight. Non-square images are
            supported.
        strength: Total width of the multiplicative weight range.
        sigma_ratio: Gaussian falloff width, relative to the normalised radius.

    Returns:
        image * weight, clipped to [0, 1]. A 1x1 image, where the weight range
        is degenerate, gets the maximum (centre) weight instead of NaN.
    """
    height, width = image.shape[:2]
    center_y, center_x = height // 2, width // 2
    y, x = np.ogrid[-center_y:height - center_y, -center_x:width - center_x]

    max_dist = np.sqrt((height / 2) ** 2 + (width / 2) ** 2)
    distance_map = np.sqrt(x * x + y * y) / max_dist
    weight = np.exp(-(distance_map ** 2) / (2 * sigma_ratio ** 2))

    spread = weight.max() - weight.min()
    if spread > 0:
        normalized = (weight - weight.min()) / spread
    else:
        # Degenerate case: every pixel is the centre, so use the peak weight.
        normalized = np.ones_like(weight)
    weight = normalized * strength + (1.0 - strength / 2)

    # Broadcast the 2-D weight over any trailing (e.g. channel) axes.
    weight = weight.reshape(weight.shape + (1,) * (image.ndim - 2))
    return np.clip(image * weight, 0.0, 1.0)

def generate_single_simulation(num_frames=4):
    """Run one complete lesion-growth simulation.

    Args:
        num_frames: Number of time points to simulate (default 4).

    Returns:
        float32 array of shape (3, num_frames, TARGET_SIZE, TARGET_SIZE):
        - channel 0: rendered FAF frames;
        - channel 1: cumulative lesion masks;
        - channel 2: residual masks, i.e. pixels that are lesion in a frame but
          not in the previous one (all zeros for frame 0).
    """
    # 1. Initial lesion field and binary starting mask.
    base_field, initial_mask = generate_initial_solid_blob(TARGET_SIZE)

    # 2. Components shared by every frame of this simulation.
    static_faf_background = generate_static_faf_background(TARGET_SIZE, base_field)
    vein_mask = generate_vein_mask(TARGET_SIZE)

    # 3. Growing masks and per-frame new-growth residuals.
    all_masks = generate_growing_masks(initial_mask, num_frames=num_frames)
    residual_masks = [np.zeros_like(initial_mask)]
    for prev_mask, curr_mask in zip(all_masks[:-1], all_masks[1:]):
        residual_masks.append(((curr_mask - prev_mask) > 0).astype(np.float32))

    # 4. Render each frame (dynamic lesion-edge sigma is chosen inside generate_synthetic_faf).
    final_images = []
    for mask in all_masks:
        faf_image = generate_synthetic_faf(static_faf_background, mask)
        # Veins: keep the darker of the rendered image and the vein overlay.
        final_overlay = np.minimum(faf_image, vein_mask)
        final_artifact_image = add_fine_speckle_noise(final_overlay, noise_factor=0.02)
        # Vignette softens the global image border.
        final_images.append(apply_soft_vignette(final_artifact_image))

    # 5. Stack channels on axis 0. Cast to float32 explicitly: the noise and
    #    vignette steps return float64, which would otherwise double the memory
    #    footprint of every sample.
    return np.stack([
        np.asarray(final_images, dtype=np.float32),   # Channel 0: FAF images
        np.asarray(all_masks, dtype=np.float32),      # Channel 1: lesion masks
        np.asarray(residual_masks, dtype=np.float32), # Channel 2: residual masks
    ], axis=0)

# --- MAIN DATASET GENERATION FUNCTION ---

def generate_synthetic_faf_dataset(num_samples=NUM_SAMPLES, max_failures=None):
    """
    Generates a dataset of synthetic FAF growth simulations.

    A simulation that raises is logged and retried until `num_samples`
    simulations have succeeded. Previously, failures were silently dropped,
    which produced a short dataset.

    Args:
        num_samples (int): The number of successful simulations to collect.
        max_failures (int | None): Maximum number of failed simulations
            tolerated before aborting. Defaults to max(10, num_samples).

    Returns:
        np.ndarray: A dataset array of shape (num_samples, 3, 4, 256, 256).
                    Channels: [FAF_Image, Lesion_Mask, Residual_Mask].

    Raises:
        ValueError: If num_samples < 1 (there would be nothing to stack).
        RuntimeError: If more than `max_failures` simulations fail.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if max_failures is None:
        max_failures = max(10, num_samples)

    dataset_list = []
    failures = 0

    print(f"Generating {num_samples} synthetic FAF simulations with ANISOTROPIC growth...")

    # Manual progress bar: it advances only on successful simulations, so
    # retries do not count towards the target.
    with tqdm(total=num_samples, desc="Building FAF Dataset") as progress:
        while len(dataset_list) < num_samples:
            attempt = len(dataset_list) + failures
            try:
                sample_data = generate_single_simulation()
            except Exception as e:
                # Random generation can occasionally fail; log without breaking
                # the progress bar, then retry up to the failure cap.
                failures += 1
                tqdm.write(f"Warning: Simulation {attempt} failed with error: {e}. Retrying.")
                if failures > max_failures:
                    raise RuntimeError(
                        f"Aborting dataset generation: {failures} simulations failed "
                        f"(limit {max_failures})."
                    ) from e
                continue
            dataset_list.append(sample_data)
            progress.update(1)

    # Stack all simulations along a new leading sample axis.
    final_dataset = np.stack(dataset_list, axis=0)

    tqdm.write(f"\nGeneration complete. Final dataset shape: {final_dataset.shape}")

    return final_dataset

In [None]:
# Example usage: generate the dataset and print its final shape.
synthetic_dataset = generate_synthetic_faf_dataset(num_samples=NUM_SAMPLES)

# --- Visualization of an Example ---
if synthetic_dataset.shape[0] > 0:
    sample_index = 0
    example_sample = synthetic_dataset[sample_index]  # Shape (3, num_frames, H, W)
    num_channels, num_frames = example_sample.shape[:2]

    print(f"\nDisplaying Example Sample {sample_index} from the dataset (FAF, Mask, Residual over {num_frames} frames):")

    # squeeze=False keeps `axes` 2-D even for a single frame.
    fig, axes = plt.subplots(num_channels, num_frames, figsize=(4 * num_frames, 4 * num_channels), squeeze=False)
    fig.suptitle(f'Example Synthetic FAF Sample {sample_index} | ANISOTROPIC Growth', fontsize=16)

    titles = ["FAF Image (Channel 0)", "Lesion Mask (Channel 1)", "Residual Mask (Channel 2)"]

    for channel in range(num_channels):
        for frame in range(num_frames):
            ax = axes[channel, frame]
            # FAF images are clipped to [0, 1] and masks are 0/1, so one fixed
            # grayscale range displays every channel comparably.
            ax.imshow(example_sample[channel, frame], cmap='gray', vmin=0.0, vmax=1.0)
            ax.set_title(f'{titles[channel]} - Frame {frame + 1}')
            ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
else:
    print("Dataset generation completed, but no samples were available to display.")

## Generate Test Set

In [None]:
# # --- Execution and Saving Logic ---

# # 1. Generate the dataset
# NUM_SAMPLES = 100

# synthetic_dataset = generate_synthetic_faf_dataset(num_samples=NUM_SAMPLES)

# # 2. Define the save path
# save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Synthetic Data')
# file_path = save_dir / 'synthetic_faf_dataset.npy'

# # 3. Create the directory if it doesn't exist
# save_dir.mkdir(parents=True, exist_ok=True)
# print(f"Saving dataset to: {file_path}")

# # 4. Save the dataset using NumPy
# np.save(file_path, synthetic_dataset)
# print("Dataset successfully saved!")

# # --- Visualization of an Example (Optional) ---
# if synthetic_dataset.shape[0] > 0:
#     sample_index = 0
#     example_sample = synthetic_dataset[sample_index] # Shape (3, 4, 256, 256)

#     print(f"\nDisplaying Example Sample {sample_index} from the dataset (FAF, Mask, Residual over 4 frames):")

#     fig, axes = plt.subplots(3, 4, figsize=(16, 12))
#     fig.suptitle(f'Example Synthetic FAF Sample {sample_index} | ANISOTROPIC Growth + Enhanced Noise', fontsize=16)

#     titles = ["FAF Image (Channel 0)", "Lesion Mask (Channel 1)", "Residual Mask (Channel 2)"]

#     for channel in range(3):
#         for frame in range(4):
#             img = example_sample[channel, frame]
#             vmin = 0.0
#             vmax = 1.0

#             if channel == 2:
#                 axes[channel, frame].imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
#             elif channel == 1:
#                 axes[channel, frame].imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
#             else:
#                 axes[channel, frame].imshow(img, cmap='gray', vmin=vmin, vmax=vmax)

#             axes[channel, frame].set_title(f'{titles[channel]} - Frame {frame+1}')
#             axes[channel, frame].axis('off')

#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#     plt.show()
# else:
#     print("Dataset generation completed, but no samples were available to display.")

In [None]:
# # --- Code to Display Multiple Samples ---

# # Define how many samples you want to visualize
# NUM_SAMPLES_TO_DISPLAY = NUM_SAMPLES

# for sample_index in range(min(NUM_SAMPLES_TO_DISPLAY, synthetic_dataset.shape[0])):
#     example_sample = synthetic_dataset[sample_index] # Shape (3, 4, 256, 256)

#     print(f"\nDisplaying Example Sample {sample_index} from the dataset (FAF, Mask, Residual over 4 frames):")

#     fig, axes = plt.subplots(3, 4, figsize=(16, 12))
#     fig.suptitle(f'Example Synthetic FAF Sample {sample_index} | ANISOTROPIC Growth + Enhanced Noise', fontsize=16)

#     titles = ["FAF Image (Channel 0)", "Lesion Mask (Channel 1)", "Residual Mask (Channel 2)"]

#     for channel in range(3):
#         for frame in range(4):
#             img = example_sample[channel, frame]
#             vmin = 0.0
#             vmax = 1.0

#             if channel == 2:
#                 # Residual mask
#                 axes[channel, frame].imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
#             elif channel == 1:
#                 # Lesion mask
#                 axes[channel, frame].imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
#             else:
#                 # FAF image
#                 axes[channel, frame].imshow(img, cmap='gray', vmin=vmin, vmax=vmax)

#             axes[channel, frame].set_title(f'{titles[channel]} - Frame {frame+1}')
#             axes[channel, frame].axis('off')

#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#     plt.show()

# if synthetic_dataset.shape[0] == 0:
#     print("Dataset generation completed, but no samples were available to display.")

# print(f"Finished displaying {min(NUM_SAMPLES_TO_DISPLAY, synthetic_dataset.shape[0])} samples.")

# Pretraining

In [None]:
## Setup

%cd /Users/Pracioppo/Desktop/GA Forecasting

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import functools
from torch.nn import init
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau 

import numpy as np
import argparse
import random
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from pathlib import Path
from datetime import datetime
import scipy.io as sio
from PIL import Image
import cv2
from tabulate import tabulate
from collections import defaultdict

# Assuming these utilities are available as imported
from torch.utils.data import Dataset, DataLoader, random_split, Subset 

# ---------------------------------------

from preprocessing_utils import f_rescale_dataset, f_Residuals, f_reshape_training_data, f_rotate_and_zoom, f_random_crop, f_rotate_and_zoom_all, f_crop_all, f_flip_all, f_augment_dataset2

from data_utils import DataWrapper, visualize_sample, compare_split_masks

from visualization import f_display_autoencoder, plot_log_loss, f_display_frames

from models import init_weights, count_parameters
from models import rotate_half, RotaryPositionalEmbedding, RoPEMultiheadAttention, RoPETransformerEncoderLayer, ResidualBlock, ChannelReducer, Unet_Enc, Unet_Dec, U_Net_AE

from augmentation_utils import f_augment_spatial_and_intensity
from training_utils import dsc, dice_loss, GDLoss
from training_utils import freeze_batch_norm, f_single_epoch_AE, f_single_epoch_spatiotemporal, calculate_total_loss, f_single_epoch_spatiotemporal_accumulated
from training_utils import save_model_weights, load_model
from eval_utils import f_eval_pred_dice_test_set, f_eval_pred_dice_train_set, plot_train_test_dice_history, soft_dice_score, f_get_individual_dice, f_plot_individual_dice

from models import DynNet, CausalConvAggregator, UPredNet, FusionBlockBottleneck, ChannelFusionBlock, LocalSpatioTemporalMixer, SpatioTemporalGatedMixer, AxialTemporalSWAInterleavedLayer, InterleavedAxialTemporalSWAIntegrator, SlidingWindowAttention, SWAU_Net, SWAU_CFB_Ablation, SWAU_DynNet_Ablation
from models import ConvLSTMCell, ConvLSTMCore, ConvLSTMBaseline, ConvLSTM_Simple
from models import InterleavedAxialTemporalRKAIntegrator, RKAFeatureAggregator, RKAU_Net
from models import AxialMultiheadAttention, StandardAxialInterleavedLayer, StandardAxialIntegrator, AxialU_Net
from models import RKA_MultiheadAttention_Fast, AxialTemporalRKAInterleavedLayer_Fast, InterleavedAxialTemporalRKAIntegrator_Fast, RKAFeatureAggregator_Fast, RKAU_Net_Fast
from models import CNN_Unet_Enc, CNN_Unet_Dec, CNN_U_Net_AE, CNN_DynNet, SWAU_Net_CNN
    
# ---------------------------------------

ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models')
tensorboard_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models')

start_epoch = 0

resume_ckpt = None

summary_writer = SummaryWriter(tensorboard_save_dir.absolute().as_posix())


# --- SETUP ---
# argparse is used only as an attribute container: parsing an empty argument
# list yields a Namespace to which the settings below are attached.
parser = argparse.ArgumentParser('AE Model Args')
args = parser.parse_args(args=[])

# Model settings for the lightweight AE setup.
args.N = 4                       # Batch size (kept at 4 for memory safety)
args.nhead = 4                   # Attention heads (reduced from 8 to 4)
args.d_attn1 = 192               # FFN dimension for L3 (112 channels)
args.d_attn2 = 384               # FFN dimension for L4 (224 channels)
args.img_channels = 3            # Three channels: FAF image, lesion mask, growth (residual) mask
args.img_sz = 256                # Image size 256x256
# Use CUDA when available, otherwise fall back to CPU.
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
BASE_CHANNELS = 24               # Reduced from 28 to 24 for extra parameter savings

# Training loop arguments
args.num_epochs = 1
args.show_example_epochs = 5
args.batch_size = args.N        # Batch size for iteration is args.N
args.num_t_steps = 4            # Time steps (used only for data simulation/flattening)

# Initialize paths (for saving checkpoints)
resume_AE_ckpt = Path('./ae_checkpoints')

print(torch.__version__)
print(f"CUDA available: {torch.cuda.is_available()}")

# args.device is deliberately not overridden to 'cuda:0' here: forcing it would
# make every later .to(args.device) fail on machines without CUDA.
print(f"Using {args.device} device")

## Load the Data

In [None]:
DATA_DIR = Path('/Users/Pracioppo/Desktop//GA Forecasting//Synthetic Data')
print("Loading synthetic test data...")

# Path to the saved synthetic test set.
SYNTHETIC_TEST_PATH = DATA_DIR / 'synthetic_faf_dataset.npy'

# Load the single NumPy array and convert it to a float32 tensor.
synthetic_data_np = np.load(SYNTHETIC_TEST_PATH)
synthetic_data_tensor = torch.from_numpy(synthetic_data_np.astype(np.float32))

# Expected layout: (N_test_samples, 3, T, H, W) with channels [FAF, Mask, Residual].
N_TEST_SAMPLES = synthetic_data_tensor.shape[0]
C = synthetic_data_tensor.shape[1]  # 3: [FAF, Mask, Residual]
T = synthetic_data_tensor.shape[2]  # Frames

# Split channels, keeping a singleton channel axis: each is [N, 1, T, H, W].
FAFs_test = synthetic_data_tensor[:, 0:1, :, :, :]
masks_test = synthetic_data_tensor[:, 1:2, :, :, :]
residuals_test = synthetic_data_tensor[:, 2:3, :, :, :]

# Global max-normalisation (dividing by 1 when the max is already 1 is a no-op).
# These are in-place divisions on views, so synthetic_data_tensor is rescaled too.
masks_test /= torch.max(masks_test) if torch.max(masks_test) > 0 else 1.0
FAFs_test /= torch.max(FAFs_test) if torch.max(FAFs_test) > 0 else 1.0
# Residuals are already 0/1 growth masks; no max-normalisation needed.


# --- Direct Test Dataset Creation ---
test_dataset = DataWrapper(FAFs_test, masks_test, residuals_test)
print(f"Test Dataset size (samples): {len(test_dataset)}")
print(f"Test FAFs tensor size: {FAFs_test.size()}")

print("--- Data Loading Result (Synthetic Test Only) ---")
print(f"Test Dataset Samples: {len(test_dataset)}")

# Placeholder: only the test set is loaded in this notebook section.
train_dataset = None


# --- DataLoader Setup (Test/Validation Set Only) ---
test_loader = DataLoader(
    test_dataset,
    batch_size=args.N,  # Use args.N (batch size for GPU)
    shuffle=False,
    num_workers=0,
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    pin_memory=args.device.type == 'cuda',
)

print(f"\nTest DataLoader setup complete with batch size: {args.N}")


# --- Visualization Example (Side-by-Side) ---
# Draw the index from the loaded test set's actual size so it is always in
# range (assumes visualize_sample indexes test_dataset for dataset_name='test').
visualize_sample(train_dataset=None, test_dataset=test_dataset, sample_idx=np.random.choice(len(test_dataset)), dataset_name='test')

2,000 videos × 4 frames per video = 8,000 unique images.

## Pretrain Baseline Conv LSTM Model

In [None]:
# # --- Configuration Update (Reconfirmed) ---
# BASE_CHANNELS = 16 # 16 channels
# args.img_channels = 3 # 3 channels (FAF, Mask, Residual)

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the ConvLSTMBaseline model...")

# # Instantiate the baseline model (assumes Unet_Enc, Unet_Dec, ConvLSTMCore are defined)
# model_baseline = ConvLSTMBaseline(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(model_baseline.E1)
# p_lstm_params = count_parameters(model_baseline.P_LSTM) # ConvLSTM Core
# d1_params = count_parameters(model_baseline.D1)
# total_params = e1_params + p_lstm_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor", f"{e1_params:,}"],
#     ["ConvLSTMCore (P_LSTM)", "**Recurrent Temporal Core**", f"**{p_lstm_params:,}**"],
#     ["Unet_Dec (D1)", "Frame Reconstructor", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**ConvLSTMBaseline Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### ConvLSTMBaseline Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- INITIALIZATION AND HYPERPARAMETER SETUP ---

# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# optimizer_baseline = torch.optim.Adam(model_baseline.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# scheduler_baseline = ReduceLROnPlateau(
#     optimizer_baseline, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# # LLR_WEIGHT and BOTTLENECK_L2_WEIGHT are correctly used below
# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(model_baseline)

# # --- BASELINE HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median)
# baseline_train_t1, baseline_train_t2, baseline_train_t3 = [], [], []
# baseline_test_t1, baseline_test_t2, baseline_test_t3 = [], [], []
# # Residual SDs
# baseline_train_sd_t1, baseline_train_sd_t2, baseline_train_sd_t3 = [], [], []
# baseline_test_sd_t1, baseline_test_sd_t2, baseline_test_sd_t3 = [], [], []

# # Mask Scores (Mean/Median)
# baseline_train_mask_t1, baseline_train_mask_t2, baseline_train_mask_t3 = [], [], []
# baseline_test_mask_t1, baseline_test_mask_t2, baseline_test_mask_t3 = [], [], []
# # Mask SDs
# baseline_train_mask_sd_t1, baseline_train_mask_sd_t2, baseline_train_mask_sd_t3 = [], [], []
# baseline_test_mask_sd_t1, baseline_test_mask_sd_t2, baseline_test_mask_sd_t3 = [], [], []


# print(f"\n Starting ConvLSTM Baseline Training for {args.num_epochs} epoch(s)...")

# # print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
# # current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
# # current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
# # print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
# current_train_data = None

# # --- TRAINING LOOP (args.num_epochs epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         # NOTE: Replace 'generate_synthetic_faf_dataset' with your actual function.
#         # It must return a NumPy array (converted below with torch.from_numpy) of shape (N_samples, 3, 4, 256, 256).
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device) # Ensure it's on device
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using SWAU_Net's accumulated loss function) ---
#     # The training function uses the newly generated 'current_train_data'
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, model_baseline, optimizer_baseline, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, model_baseline, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, model_baseline, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation ---
#     # Residual Scores
#     baseline_train_t1.append(res_train_scores[0]); baseline_train_t2.append(res_train_scores[1]); baseline_train_t3.append(res_train_scores[2])
#     baseline_test_t1.append(res_test_scores[0]); baseline_test_t2.append(res_test_scores[1]); baseline_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     baseline_train_sd_t1.append(res_train_sds[0]); baseline_train_sd_t2.append(res_train_sds[1]); baseline_train_sd_t3.append(res_train_sds[2])
#     baseline_test_sd_t1.append(res_test_sds[0]); baseline_test_sd_t2.append(res_test_sds[1]); baseline_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     baseline_train_mask_t1.append(msk_train_scores[0]); baseline_train_mask_t2.append(msk_train_scores[1]); baseline_train_mask_t3.append(msk_train_scores[2])
#     baseline_test_mask_t1.append(msk_test_scores[0]); baseline_test_mask_t2.append(msk_test_scores[1]); baseline_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     baseline_train_mask_sd_t1.append(msk_train_sds[0]); baseline_train_mask_sd_t2.append(msk_train_sds[1]); baseline_train_mask_sd_t3.append(msk_train_sds[2])
#     baseline_test_mask_sd_t1.append(msk_test_sds[0]); baseline_test_mask_sd_t2.append(msk_test_sds[1]); baseline_test_mask_sd_t3.append(msk_test_sds[2]) # Corrected logic using test_sds

#     # --- 4. Scheduler & Logging ---
#     scheduler_baseline.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_baseline.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, model_baseline, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         baseline_train_t1, baseline_train_t2, baseline_train_t3,
#         baseline_test_t1, baseline_test_t2, baseline_test_t3,
#         baseline_train_sd_t1, baseline_train_sd_t2, baseline_train_sd_t3,
#         baseline_test_sd_t1, baseline_test_sd_t2, baseline_test_sd_t3,
#         plot_title='ConvLSTM Baseline Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         baseline_train_mask_t1, baseline_train_mask_t2, baseline_train_mask_t3,
#         baseline_test_mask_t1, baseline_test_mask_t2, baseline_test_mask_t3,
#         baseline_train_mask_sd_t1, baseline_train_mask_sd_t2, baseline_train_mask_sd_t3,
#         baseline_test_mask_sd_t1, baseline_test_mask_sd_t2, baseline_test_mask_sd_t3,
#         plot_title='ConvLSTM Baseline Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # --- Setup for Plotting ---

# model = model_baseline

# ## Plot with soft DICE (Residual and Mask)
# soft_dice = True
# metric_type_str = "Soft Dice"

# (res_scores_test, msk_scores_test), _ = f_get_individual_dice(
#     test_dataset, model, args, is_train_set=False, soft_dice=soft_dice
# )

# # 2. Plot Residuals (Soft) - (Plotting logic remains correct)
# f_plot_individual_dice(res_scores_test, res_scores_test, metric_type_str, channel_name='Residual Mask')

# # 3. Plot Masks (Soft)
# f_plot_individual_dice(msk_scores_test, msk_scores_test, metric_type_str, channel_name='Full Mask')

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=model_baseline, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "ConvLSTM_baseline_pretrain"
    
# )

# # del model_baseline
# # torch.cuda.empty_cache()
# # gc.collect()

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "ConvLSTM_baseline_synthetic_pretrain_epoch50_20251104_163556.pth" 
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=model_baseline, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

## Pretrain Simple ConvLSTM

In [None]:
# --- Model configuration for the simple ConvLSTM ---
# Base channel width handed to ConvLSTM_Simple(base_channels=...).
BASE_CHANNELS = 16
# Number of stacked input channels: FAF, Mask, Residual.
setattr(args, "img_channels", 3)

# Helper: trainable-parameter counter used by the summary table below.
def count_parameters(model):
    """Return the number of scalar elements in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is False (e.g. frozen layers) are
    skipped; a model with no trainable parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# --- Instantiation and Parameter Calculation ---

print("\nInstantiating the ConvLSTM_Simple model...")

# Instantiate the simple CNN-based model and move it to the configured device.
# NOTE: This model uses CNN_Unet_Enc and CNN_Unet_Dec
model_simple = ConvLSTM_Simple(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# Trainable-parameter counts for the three named submodules.
e1_params = count_parameters(model_simple.E1)
p_lstm_params = count_parameters(model_simple.P_LSTM)  # ConvLSTM Core
d1_params = count_parameters(model_simple.D1)

# The TOTAL is counted over the whole model rather than summed from the three
# components, so parameters owned outside E1 / P_LSTM / D1 are not silently
# dropped from the reported total. Any discrepancy is surfaced as "Other".
total_params = count_parameters(model_simple)
other_params = total_params - (e1_params + p_lstm_params + d1_params)

# --- Create Table Data ---
param_data = [
    ["CNN_Unet_Enc (E1)", "Feature Extractor (Ablated CNN)", f"{e1_params:,}"],
    ["ConvLSTMCore (P_LSTM)", "**Recurrent Temporal Core**", f"**{p_lstm_params:,}**"],
    ["CNN_Unet_Dec (D1)", "Frame Reconstructor (Ablated CNN)", f"{d1_params:,}"],
]
if other_params:
    # Only shown when the per-component sum disagrees with the whole-model count.
    param_data.append(["Other / unaccounted", "Parameters outside E1 / P_LSTM / D1", f"{other_params:,}"])
param_data += [
    ["", "", ""],  # Separator
    ["**TOTAL**", "**ConvLSTM_Simple Model**", f"**{total_params:,}**"],
]

# --- Print Table ---
print("\n### ConvLSTM_Simple Parameter Summary\n")
print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# --- INITIALIZATION AND HYPERPARAMETER SETUP ---

# --- DYNAMIC DATA GENERATION CONFIG ---
# Number of synthetic samples produced each time the training set is regenerated.
NUM_TRAIN_SAMPLES = 100  # Example size, adjust as needed
# The training set is regenerated every DATA_GENERATION_EPOCH_CYCLE epochs.
DATA_GENERATION_EPOCH_CYCLE = 5
# ------------------------------------------

# HYPERPARAMETERS
args.num_epochs = 50
ACCUMULATION_STEPS = 8  # forwarded to the training step as accumulation_steps
soft_dice = True  # Use Soft Dice for stability
lr = 1E-3  # Initial LR

optimizer_simple = torch.optim.Adam(model_simple.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# Halve the LR (down to 1e-6) once the monitored loss has not improved for
# more than 5 consecutive scheduler steps. `verbose` is intentionally not
# passed: it is deprecated since PyTorch 2.2 (emits a warning), and the
# training loop already prints the current LR every epoch.
scheduler_simple = ReduceLROnPlateau(
    optimizer_simple,
    mode='min', factor=0.5, patience=5, min_lr=1e-6
)

# Loss functions (dice_loss and GDLoss are expected to be defined in an earlier cell)
loss_fn_bce = nn.BCELoss(reduction='mean')
loss_fn_l1 = nn.L1Loss(reduction='mean')
loss_fn_l2 = nn.MSELoss(reduction='mean')
loss_fn_dice = dice_loss  # This relies on your custom dice_loss function
loss_fn_gdl = GDLoss(alpha=1, beta=1)

# Weight of the bottleneck L2 penalty; forwarded to the training step as lambda_bottleneck.
BOTTLENECK_L2_WEIGHT = 1e-6

# Freeze Batch Norm layers (essential for small batches)
freeze_batch_norm(model_simple)

# --- SIMPLE-MODEL HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# Every training iteration's loss, plus how many iterations each epoch contributed.
all_iteration_losses = []
epoch_iteration_counts = []

# Residual-channel Dice history (one list per timestep t1..t3), scores then SDs.
simple_train_t1, simple_train_t2, simple_train_t3 = ([] for _ in range(3))
simple_test_t1, simple_test_t2, simple_test_t3 = ([] for _ in range(3))
simple_train_sd_t1, simple_train_sd_t2, simple_train_sd_t3 = ([] for _ in range(3))
simple_test_sd_t1, simple_test_sd_t2, simple_test_sd_t3 = ([] for _ in range(3))

# Full-mask Dice history (one list per timestep t1..t3), scores then SDs.
simple_train_mask_t1, simple_train_mask_t2, simple_train_mask_t3 = ([] for _ in range(3))
simple_test_mask_t1, simple_test_mask_t2, simple_test_mask_t3 = ([] for _ in range(3))
simple_train_mask_sd_t1, simple_train_mask_sd_t2, simple_train_mask_sd_t3 = ([] for _ in range(3))
simple_test_mask_sd_t1, simple_test_mask_sd_t2, simple_test_mask_sd_t3 = ([] for _ in range(3))

print(f"\n Starting ConvLSTM Simple Training for {args.num_epochs} epoch(s)...")

current_train_data = None

# Training sample shown by f_display_frames each epoch; clamped below so a
# small NUM_TRAIN_SAMPLES (<= 20) does not raise an IndexError.
DISPLAY_SAMPLE_IDX = 20


def _append_per_timestep(histories, values):
    """Append ``values[k]`` to ``histories[k]`` for every history list given.

    Indexing (not zip) is used so that a ``values`` shorter than ``histories``
    still raises, exactly like explicit ``values[0..2]`` lookups would.
    """
    for k, history in enumerate(histories):
        history.append(values[k])


# --- TRAINING LOOP (args.num_epochs) ---
for epoch in tqdm(range(args.num_epochs), desc="Epoch Progress"):

    # --- DYNAMIC DATA GENERATION ---
    # Epoch 0 always satisfies this test, so current_train_data is populated before first use.
    if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
        print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
        # generate_synthetic_faf_dataset must return a NumPy array (it is converted with
        # torch.from_numpy below); presumably shaped (N_samples, 3, 4, 256, 256) — confirm.
        current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
        current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)  # Ensure it's on device
        print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")

    # --- 1. Training Step (accumulated-gradient epoch over current_train_data) ---
    epoch_losses = f_single_epoch_spatiotemporal_accumulated(
        current_train_data, model_simple, optimizer_simple, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size,
        accumulation_steps=ACCUMULATION_STEPS,
        lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0,
        lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
    )

    all_iteration_losses.extend(epoch_losses.tolist())
    epoch_iteration_counts.append(len(epoch_losses))
    mean_epoch_loss = np.mean(epoch_losses)

    # --- 2. Evaluation Step (Median/SD) ---
    (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
        f_eval_pred_dice_test_set(test_loader, model_simple, args, soft_dice=soft_dice, use_median=True)
    (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
        f_eval_pred_dice_train_set(current_train_data, model_simple, args, args.batch_size, soft_dice=soft_dice, use_median=True)

    # --- 3. Accumulation: index k of each score/SD vector feeds the *_t(k+1) list ---
    # Residual scores and SDs
    _append_per_timestep((simple_train_t1, simple_train_t2, simple_train_t3), res_train_scores)
    _append_per_timestep((simple_test_t1, simple_test_t2, simple_test_t3), res_test_scores)
    _append_per_timestep((simple_train_sd_t1, simple_train_sd_t2, simple_train_sd_t3), res_train_sds)
    _append_per_timestep((simple_test_sd_t1, simple_test_sd_t2, simple_test_sd_t3), res_test_sds)
    # Mask scores and SDs
    _append_per_timestep((simple_train_mask_t1, simple_train_mask_t2, simple_train_mask_t3), msk_train_scores)
    _append_per_timestep((simple_test_mask_t1, simple_test_mask_t2, simple_test_mask_t3), msk_test_scores)
    _append_per_timestep((simple_train_mask_sd_t1, simple_train_mask_sd_t2, simple_train_mask_sd_t3), msk_train_sds)
    _append_per_timestep((simple_test_mask_sd_t1, simple_test_mask_sd_t2, simple_test_mask_sd_t3), msk_test_sds)

    # --- 4. Scheduler & Logging ---
    # The plateau scheduler monitors the mean training loss of this epoch.
    scheduler_simple.step(mean_epoch_loss)

    print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_simple.param_groups[0]['lr']:.2e}) ---")
    print(f"Mean Loss: **{mean_epoch_loss:.6f}**")

    print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))

    # --- Per-Epoch Visualizations ---
    print("\n--- Generating Per-Epoch Visualizations ---")

    # A. Plot Loss History
    plot_log_loss(all_iteration_losses, epoch_iteration_counts)

    # B. Plot Sample Prediction (index clamped to the current training-set size)
    f_display_frames(
        current_train_data, model_simple, args,
        sample_idx=min(DISPLAY_SAMPLE_IDX, current_train_data.shape[0] - 1), T_total=4
    )

    # C. Plot Residual History
    plot_train_test_dice_history(
        simple_train_t1, simple_train_t2, simple_train_t3,
        simple_test_t1, simple_test_t2, simple_test_t3,
        simple_train_sd_t1, simple_train_sd_t2, simple_train_sd_t3,
        simple_test_sd_t1, simple_test_sd_t2, simple_test_sd_t3,
        plot_title='ConvLSTM Simple Residual Dice History (Median ± SD)'
    )

    # D. Plot Mask History
    plot_train_test_dice_history(
        simple_train_mask_t1, simple_train_mask_t2, simple_train_mask_t3,
        simple_test_mask_t1, simple_test_mask_t2, simple_test_mask_t3,
        simple_train_mask_sd_t1, simple_train_mask_sd_t2, simple_train_mask_sd_t3,
        simple_test_mask_sd_t1, simple_test_mask_sd_t2, simple_test_mask_sd_t3,
        plot_title='ConvLSTM Simple Full Mask Dice History (Median ± SD)'
    )

# --- Final Message ---
print("\n--- Training Complete ---")

In [None]:
# Destination folder for the pretrained checkpoint (user-specific absolute path).
ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')

# Configured epoch count, handed to save_model_weights as final_epoch.
FINAL_EPOCH = args.num_epochs

saved_path = save_model_weights(model=model_simple, final_epoch=FINAL_EPOCH,
                                save_dir=ckpt_save_dir, model_name="ConvLSTM_simple_pretrain")

# del model_simple
# torch.cuda.empty_cache()
# gc.collect()

## Pretrain SWAU Net

In [None]:
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the full SWAU_Net model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # The model is SWAU_Net, which owns E1, CFB_enc, CFB_dec, SWA, P, and D1.
# swau_model = SWAU_Net(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(swau_model.E1)

# # CORRECTED: Calculate parameters for both CFB modules separately
# cfb_enc_params = count_parameters(swau_model.CFB_enc) 
# cfb_dec_params = count_parameters(swau_model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# swa_params = count_parameters(swau_model.SWA) 
# p_params = count_parameters(swau_model.P)
# d1_params = count_parameters(swau_model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"], # NEW LINE: Aggregate CFB
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator**", f"**{swa_params:,}**"], 
#     ["DynNet (P)", "Temporal Feature Predictor (M_t → Evolved_t)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full SWAU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL, OPTIMIZER, SCHEDULER RENAMING ---
# # Assuming the instantiated model is now referred to as 'swau_model'
# # (The old 'model_baseline' reference will be replaced by 'swau_model' in function calls)
# optimizer_swau = torch.optim.Adam(swau_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# scheduler_swau = ReduceLROnPlateau(
#     optimizer_swau, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(swau_model) # Use swau_model

# # --- HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median) - RENAMED to swau_...
# swau_train_t1, swau_train_t2, swau_train_t3 = [], [], []
# swau_test_t1, swau_test_t2, swau_test_t3 = [], [], []
# # Residual SDs
# swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3 = [], [], []
# swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3 = [], [], []
# swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3 = [], [], []
# # Mask SDs
# swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3 = [], [], []
# swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3 = [], [], []


# print(f"\n Starting SWAU Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using SWAU_Net's accumulated loss function) ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, swau_model, optimizer_swau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, swau_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, swau_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED for consistency) ---
#     # Residual Scores
#     swau_train_t1.append(res_train_scores[0]); swau_train_t2.append(res_train_scores[1]); swau_train_t3.append(res_train_scores[2])
#     swau_test_t1.append(res_test_scores[0]); swau_test_t2.append(res_test_scores[1]); swau_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     swau_train_sd_t1.append(res_train_sds[0]); swau_train_sd_t2.append(res_train_sds[1]); swau_train_sd_t3.append(res_train_sds[2])
#     swau_test_sd_t1.append(res_test_sds[0]); swau_test_sd_t2.append(res_test_sds[1]); swau_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     swau_train_mask_t1.append(msk_train_scores[0]); swau_train_mask_t2.append(msk_train_scores[1]); swau_train_mask_t3.append(msk_train_scores[2])
#     swau_test_mask_t1.append(msk_test_scores[0]); swau_test_mask_t2.append(msk_test_scores[1]); swau_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     swau_train_mask_sd_t1.append(msk_train_sds[0]); swau_train_mask_sd_t2.append(msk_train_sds[1]); swau_train_mask_sd_t3.append(msk_train_sds[2])
#     swau_test_mask_sd_t1.append(msk_test_sds[0]); swau_test_mask_sd_t2.append(msk_test_sds[1]); swau_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     scheduler_swau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_swau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, swau_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         swau_train_t1, swau_train_t2, swau_train_t3,
#         swau_test_t1, swau_test_t2, swau_test_t3,
#         swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3,
#         swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3,
#         plot_title='SWAU Model Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3,
#         swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3,
#         swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3,
#         swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3,
#         plot_title='SWAU Model Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=swau_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "SWAU_synthetic_pretrain"
# )

# # # del swau_model
# # ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # # Load the model
# # MODEL_FILENAME = "SWAU_synthetic_pretrain_epoch50_20251104_185416.pth" 
# # MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# # loaded_model, loaded_epoch = load_model(
# #     model=swau_model, 
# #     model_path=MODEL_PATH, 
# #     device=args.device
# # )
# # # torch.cuda.empty_cache()
# # # gc.collect()

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "SWAU_synthetic_pretrain_epoch50_20251104_185416.pth" 
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=swau_model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

## Pretrain CFB Ablation

In [None]:
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# args.num_attn_layers = 2

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the full SWAU_Net model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # The model is SWAU_CFB_Ablation (SWAU_Net without the CFB modules), which owns E1, SWA, P, and D1.
# swau_model = SWAU_CFB_Ablation(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(swau_model.E1)

# swa_params = count_parameters(swau_model.SWA) 
# p_params = count_parameters(swau_model.P)
# d1_params = count_parameters(swau_model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator**", f"**{swa_params:,}**"], 
#     ["DynNet (P)", "Temporal Feature Predictor (M_t → Evolved_t)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full SWAU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL, OPTIMIZER, SCHEDULER RENAMING ---
# # Assuming the instantiated model is now referred to as 'swau_model'
# # (The old 'model_baseline' reference will be replaced by 'swau_model' in function calls)
# optimizer_swau = torch.optim.Adam(swau_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# scheduler_swau = ReduceLROnPlateau(
#     optimizer_swau, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(swau_model) # Use swau_model

# # --- HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median) - RENAMED to swau_...
# swau_train_t1, swau_train_t2, swau_train_t3 = [], [], []
# swau_test_t1, swau_test_t2, swau_test_t3 = [], [], []
# # Residual SDs
# swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3 = [], [], []
# swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3 = [], [], []
# swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3 = [], [], []
# # Mask SDs
# swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3 = [], [], []
# swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3 = [], [], []


# print(f"\n Starting SWAU Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using SWAU_Net's accumulated loss function) ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, swau_model, optimizer_swau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, swau_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, swau_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED for consistency) ---
#     # Residual Scores
#     swau_train_t1.append(res_train_scores[0]); swau_train_t2.append(res_train_scores[1]); swau_train_t3.append(res_train_scores[2])
#     swau_test_t1.append(res_test_scores[0]); swau_test_t2.append(res_test_scores[1]); swau_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     swau_train_sd_t1.append(res_train_sds[0]); swau_train_sd_t2.append(res_train_sds[1]); swau_train_sd_t3.append(res_train_sds[2])
#     swau_test_sd_t1.append(res_test_sds[0]); swau_test_sd_t2.append(res_test_sds[1]); swau_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     swau_train_mask_t1.append(msk_train_scores[0]); swau_train_mask_t2.append(msk_train_scores[1]); swau_train_mask_t3.append(msk_train_scores[2])
#     swau_test_mask_t1.append(msk_test_scores[0]); swau_test_mask_t2.append(msk_test_scores[1]); swau_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     swau_train_mask_sd_t1.append(msk_train_sds[0]); swau_train_mask_sd_t2.append(msk_train_sds[1]); swau_train_mask_sd_t3.append(msk_train_sds[2])
#     swau_test_mask_sd_t1.append(msk_test_sds[0]); swau_test_mask_sd_t2.append(msk_test_sds[1]); swau_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     scheduler_swau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_swau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, swau_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         swau_train_t1, swau_train_t2, swau_train_t3,
#         swau_test_t1, swau_test_t2, swau_test_t3,
#         swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3,
#         swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3,
#         plot_title='SWAU Model Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3,
#         swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3,
#         swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3,
#         swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3,
#         plot_title='SWAU Model Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=swau_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "CFB_Ablation_synthetic_pretrain"
# )

# # del swau_model
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "SWAU_synthetic_pretrain_epoch50_20251104_185416.pth" 
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=swau_model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )
# # torch.cuda.empty_cache()
# # gc.collect()

## Pretrain DynNet Ablation

In [None]:
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Base channel width
# args.d_attn1 = 128 # Feed-forward dim for L3
# args.d_attn2 = 256 # Feed-forward dim for L4/L5

# args.num_attn_layers = 2 # Number of SWA layers

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the SWAU_DynNet_Ablation model with the updated configuration...")

# # Instantiate the ablation model and move it to the device
# # NOTE: The CFB components are defined explicitly on the model (CFB_enc / CFB_dec) and are counted separately below.
# # Using the SWAU_DynNet_Ablation class.
# swau_dynnet_ablation_model = SWAU_DynNet_Ablation(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(swau_dynnet_ablation_model.E1)
# cfb_enc_params = count_parameters(swau_dynnet_ablation_model.CFB_enc)
# cfb_dec_params = count_parameters(swau_dynnet_ablation_model.CFB_dec)
# swa_params = count_parameters(swau_dynnet_ablation_model.SWA)

# # DynNet is removed, so its parameter count is 0.
# p_params = 0 

# d1_params = count_parameters(swau_dynnet_ablation_model.D1)

# # Ensure all components are summed up for the total count (including CFB blocks)
# total_params = e1_params + cfb_enc_params + cfb_dec_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["CFB_enc", "Pre-Aggregation Channel Refinement", f"{cfb_enc_params:,}"],
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Estimator**", f"**{swa_params:,}**"],
#     ["DynNet (P)", "Temporal Evolution Module", f"**{p_params:,} (Removed)**"], # P_params = 0
#     ["CFB_dec", "Post-Aggregation Channel Refinement", f"{cfb_dec_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**SWAU_DynNet_Ablation Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_DynNet_Ablation Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL INSTANTIATION REFERENCE ---
# # NOTE: The model swau_dynnet_ablation_model is already defined in the preceding block.
# # swau_model below is simply an alias for that same object, so the optimizer, scheduler,
# # batch-norm freezing, and training loop can all refer to it as swau_model.

# # Renaming the existing instance for clarity within the loop context:
# swau_model = swau_dynnet_ablation_model

# # --- MODEL, OPTIMIZER, SCHEDULER ---
# optimizer_swau = torch.optim.Adam(swau_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# # NOTE: Assume ReduceLROnPlateau is defined elsewhere
# scheduler_swau = ReduceLROnPlateau(
#     optimizer_swau, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# # NOTE: Assume dice_loss and GDLoss are defined elsewhere.
# loss_fn_dice = dice_loss 
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# # NOTE: Assume freeze_batch_norm is defined elsewhere.
# freeze_batch_norm(swau_model) 

# # --- HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median) - RENAMED to swau_...
# swau_train_t1, swau_train_t2, swau_train_t3 = [], [], []
# swau_test_t1, swau_test_t2, swau_test_t3 = [], [], []
# # Residual SDs
# swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3 = [], [], []
# swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3 = [], [], []
# swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3 = [], [], []
# # Mask SDs
# swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3 = [], [], []
# swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3 = [], [], []


# print(f"\n Starting SWAU Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# # NOTE: Assume tqdm, generate_synthetic_faf_dataset, f_single_epoch_spatiotemporal_accumulated, 
# # f_eval_pred_dice_test_set, f_eval_pred_dice_train_set, plot_log_loss, and f_display_frames are defined elsewhere.
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         # NOTE: Assume generate_synthetic_faf_dataset is available
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using SWAU_Net's accumulated loss function) ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, swau_model, optimizer_swau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     # NOTE: Assume f_eval_pred_dice_* functions and test_loader are available
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, swau_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, swau_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED for consistency) ---
#     # Residual Scores
#     swau_train_t1.append(res_train_scores[0]); swau_train_t2.append(res_train_scores[1]); swau_train_t3.append(res_train_scores[2])
#     swau_test_t1.append(res_test_scores[0]); swau_test_t2.append(res_test_scores[1]); swau_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     swau_train_sd_t1.append(res_train_sds[0]); swau_train_sd_t2.append(res_train_sds[1]); swau_train_sd_t3.append(res_train_sds[2])
#     swau_test_sd_t1.append(res_test_sds[0]); swau_test_sd_t2.append(res_test_sds[1]); swau_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     swau_train_mask_t1.append(msk_train_scores[0]); swau_train_mask_t2.append(msk_train_scores[1]); swau_train_mask_t3.append(msk_train_scores[2])
#     swau_test_mask_t1.append(msk_test_scores[0]); swau_test_mask_t2.append(msk_test_scores[1]); swau_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     swau_train_mask_sd_t1.append(msk_train_sds[0]); swau_train_mask_sd_t2.append(msk_train_sds[1]); swau_train_mask_sd_t3.append(msk_train_sds[2])
#     swau_test_mask_sd_t1.append(msk_test_sds[0]); swau_test_mask_sd_t2.append(msk_test_sds[1]); swau_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     # NOTE: Assume scheduler_swau and plotting functions are available
#     scheduler_swau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_swau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, swau_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         swau_train_t1, swau_train_t2, swau_train_t3,
#         swau_test_t1, swau_test_t2, swau_test_t3,
#         swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3,
#         swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3,
#         plot_title='SWAU Model Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3,
#         swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3,
#         swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3,
#         swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3,
#         plot_title='SWAU Model Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=swau_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "DynNet_Ablation_synthetic_pretrain"
# )

## Pretrain RKAU Net

In [None]:
# ## Pretrain RKAU-Net

# # --- Configuration Update ---
# BASE_CHANNELS = 16 # Retaining C=16
# args.d_attn1 = 128 # Retaining d_attn1=128
# args.d_attn2 = 256 # Retaining d_attn2=256
# args.img_channels = 3
# args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# args.num_attn_layers = 2

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the FULL RKAU_Net model (E1 -> CFB -> RKA -> DynNet -> CFB -> D1)...")

# # Instantiate the full model and move it to the device
# rkau_model = RKAU_Net(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # --- Calculate parameters for each main component ---

# e1_params = count_parameters(rkau_model.E1)
# d1_params = count_parameters(rkau_model.D1)

# # RKA is the Feature Aggregation module
# rka_agg_params = count_parameters(rkau_model.RKA_Aggregator) 

# # P is the DynNet State Evolutionary Predictor
# p_params = count_parameters(rkau_model.P) 

# # CFBs (Enc and Dec)
# cfb_enc_params = count_parameters(rkau_model.CFB_enc)
# cfb_dec_params = count_parameters(rkau_model.CFB_dec)
# cfb_total_params = cfb_enc_params + cfb_dec_params

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_enc_params + rka_agg_params + p_params + cfb_dec_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Spatial)", f"{e1_params:,}"],
#     ["CFB_enc", "Pre-Dynamics Mixer (Refinement)", f"{cfb_enc_params:,}"],
#     ["RKA Aggregator", "Temporal Aggregation Core (RKA)", f"**{rka_agg_params:,}**"],
#     ["DynNet (P)", "State Evolutionary Predictor", f"**{p_params:,}**"],
#     ["CFB_dec", "Post-Dynamics Mixer (Refinement)", f"{cfb_dec_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Spatial)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full RKAU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### RKAU_Net Component Parameter Summary (Full Architecture)\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))


In [None]:
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL, OPTIMIZER, SCHEDULER RENAMING ---
# # NOTE: Assuming 'rkau_model' is the instantiated RKAU_Net model object
# optimizer_rkau = torch.optim.Adam(rkau_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# scheduler_rkau = ReduceLROnPlateau(
#     optimizer_rkau, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss 
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(rkau_model) # Use rkau_model

# # --- HISTORY INITIALIZATION (RENAMED to rkau_...) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median)
# rkau_train_t1, rkau_train_t2, rkau_train_t3 = [], [], []
# rkau_test_t1, rkau_test_t2, rkau_test_t3 = [], [], []
# # Residual SDs
# rkau_train_sd_t1, rkau_train_sd_t2, rkau_train_sd_t3 = [], [], []
# rkau_test_sd_t1, rkau_test_sd_t2, rkau_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# rkau_train_mask_t1, rkau_train_mask_t2, rkau_train_mask_t3 = [], [], []
# rkau_test_mask_t1, rkau_test_mask_t2, rkau_test_mask_t3 = [], [], []
# # Mask SDs
# rkau_train_mask_sd_t1, rkau_train_mask_sd_t2, rkau_train_mask_sd_t3 = [], [], []
# rkau_test_mask_sd_t1, rkau_test_mask_sd_t2, rkau_test_mask_sd_t3 = [], [], []


# print(f"\n Starting RKAU Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         # NOTE: Assuming generate_synthetic_faf_dataset function is available
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES) 
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using RKAU_Net's accumulated loss function) ---
#     # NOTE: Function name f_single_epoch_spatiotemporal_accumulated retained, but called with rkau_model/optimizer_rkau
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, rkau_model, optimizer_rkau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, rkau_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, rkau_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED for consistency) ---
#     # Residual Scores
#     rkau_train_t1.append(res_train_scores[0]); rkau_train_t2.append(res_train_scores[1]); rkau_train_t3.append(res_train_scores[2])
#     rkau_test_t1.append(res_test_scores[0]); rkau_test_t2.append(res_test_scores[1]); rkau_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     rkau_train_sd_t1.append(res_train_sds[0]); rkau_train_sd_t2.append(res_train_sds[1]); rkau_train_sd_t3.append(res_train_sds[2])
#     rkau_test_sd_t1.append(res_test_sds[0]); rkau_test_sd_t2.append(res_test_sds[1]); rkau_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     rkau_train_mask_t1.append(msk_train_scores[0]); rkau_train_mask_t2.append(msk_train_scores[1]); rkau_train_mask_t3.append(msk_train_scores[2])
#     rkau_test_mask_t1.append(msk_test_scores[0]); rkau_test_mask_t2.append(msk_test_scores[1]); rkau_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     rkau_train_mask_sd_t1.append(msk_train_sds[0]); rkau_train_mask_sd_t2.append(msk_train_sds[1]); rkau_train_mask_sd_t3.append(msk_train_sds[2])
#     rkau_test_mask_sd_t1.append(msk_test_sds[0]); rkau_test_mask_sd_t2.append(msk_test_sds[1]); rkau_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     scheduler_rkau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_rkau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, rkau_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         rkau_train_t1, rkau_train_t2, rkau_train_t3,
#         rkau_test_t1, rkau_test_t2, rkau_test_t3,
#         rkau_train_sd_t1, rkau_train_sd_t2, rkau_train_sd_t3,
#         rkau_test_sd_t1, rkau_test_sd_t2, rkau_test_sd_t3,
#         plot_title='RKAU Model Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         rkau_train_mask_t1, rkau_train_mask_t2, rkau_train_mask_t3,
#         rkau_test_mask_t1, rkau_test_mask_t2, rkau_test_mask_t3,
#         rkau_train_mask_sd_t1, rkau_train_mask_sd_t2, rkau_train_mask_sd_t3,
#         rkau_test_mask_sd_t1, rkau_test_mask_sd_t2, rkau_test_mask_sd_t3,
#         plot_title='RKAU Model Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=rkau_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "RKAU_synthetic_pretrain"
# )

# # ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # # Load the model
# # MODEL_FILENAME = "RKAU_synthetic_pretrain_epoch50_20251104_185416.pth" 
# # MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# # loaded_model, loaded_epoch = load_model(
# #     model=rkau_model, 
# #     model_path=MODEL_PATH, 
# #     device=args.device
# # )

## Pretrain RKAU Fast Net

In [None]:
# ## Pretrain RKAU Fast Net

# # --- Configuration Update ---
# BASE_CHANNELS = 16 # Retaining C=16
# args.d_attn1 = 128 # Retaining d_attn1=128
# args.d_attn2 = 256 # Retaining d_attn2=256
# args.img_channels = 3
# args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# args.num_attn_layers = 2

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the RKAU_Net_Fast model (E1 -> CFB -> RKA -> DynNet -> CFB -> D1)...")

# # Instantiate the full model and move it to the device
# rkau_fast_model = RKAU_Net_Fast(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # --- Calculate parameters for each main component ---

# e1_params = count_parameters(rkau_fast_model.E1)
# d1_params = count_parameters(rkau_fast_model.D1)

# # RKA is the Feature Aggregation module
# rka_agg_params = count_parameters(rkau_fast_model.RKA_Aggregator) 

# # P is the DynNet State Evolutionary Predictor
# p_params = count_parameters(rkau_fast_model.P) 

# # CFBs (Enc and Dec)
# cfb_enc_params = count_parameters(rkau_fast_model.CFB_enc)
# cfb_dec_params = count_parameters(rkau_fast_model.CFB_dec)
# cfb_total_params = cfb_enc_params + cfb_dec_params

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_enc_params + rka_agg_params + p_params + cfb_dec_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Spatial)", f"{e1_params:,}"],
#     ["CFB_enc", "Pre-Dynamics Mixer (Refinement)", f"{cfb_enc_params:,}"],
#     ["RKA Aggregator", "Temporal Aggregation Core (RKA)", f"**{rka_agg_params:,}**"],
#     ["DynNet (P)", "State Evolutionary Predictor", f"**{p_params:,}**"],
#     ["CFB_dec", "Post-Dynamics Mixer (Refinement)", f"{cfb_dec_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Spatial)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**RKAU_Net_Fast Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### RKAU_Net_Fast Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL, OPTIMIZER, SCHEDULER RENAMING ---
# optimizer_rkau = torch.optim.Adam(rkau_fast_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# scheduler_rkau = ReduceLROnPlateau(
#     optimizer_rkau, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss 
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(rkau_fast_model)

# # --- HISTORY INITIALIZATION (RENAMED to rkau_...) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median)
# rkau_train_t1, rkau_train_t2, rkau_train_t3 = [], [], []
# rkau_test_t1, rkau_test_t2, rkau_test_t3 = [], [], []
# # Residual SDs
# rkau_train_sd_t1, rkau_train_sd_t2, rkau_train_sd_t3 = [], [], []
# rkau_test_sd_t1, rkau_test_sd_t2, rkau_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# rkau_train_mask_t1, rkau_train_mask_t2, rkau_train_mask_t3 = [], [], []
# rkau_test_mask_t1, rkau_test_mask_t2, rkau_test_mask_t3 = [], [], []
# # Mask SDs
# rkau_train_mask_sd_t1, rkau_train_mask_sd_t2, rkau_train_mask_sd_t3 = [], [], []
# rkau_test_mask_sd_t1, rkau_test_mask_sd_t2, rkau_test_mask_sd_t3 = [], [], []


# print(f"\n Starting RKAU Fast Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         # NOTE: Assuming generate_synthetic_faf_dataset function is available
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES) 
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using RKAU_Net's accumulated loss function) ---
#     # NOTE: Function name f_single_epoch_spatiotemporal_accumulated retained, but called with rkau_fast_model/optimizer_rkau
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, rkau_fast_model, optimizer_rkau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, rkau_fast_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, rkau_fast_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED for consistency) ---
#     # Residual Scores
#     rkau_train_t1.append(res_train_scores[0]); rkau_train_t2.append(res_train_scores[1]); rkau_train_t3.append(res_train_scores[2])
#     rkau_test_t1.append(res_test_scores[0]); rkau_test_t2.append(res_test_scores[1]); rkau_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     rkau_train_sd_t1.append(res_train_sds[0]); rkau_train_sd_t2.append(res_train_sds[1]); rkau_train_sd_t3.append(res_train_sds[2])
#     rkau_test_sd_t1.append(res_test_sds[0]); rkau_test_sd_t2.append(res_test_sds[1]); rkau_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     rkau_train_mask_t1.append(msk_train_scores[0]); rkau_train_mask_t2.append(msk_train_scores[1]); rkau_train_mask_t3.append(msk_train_scores[2])
#     rkau_test_mask_t1.append(msk_test_scores[0]); rkau_test_mask_t2.append(msk_test_scores[1]); rkau_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     rkau_train_mask_sd_t1.append(msk_train_sds[0]); rkau_train_mask_sd_t2.append(msk_train_sds[1]); rkau_train_mask_sd_t3.append(msk_train_sds[2])
#     rkau_test_mask_sd_t1.append(msk_test_sds[0]); rkau_test_mask_sd_t2.append(msk_test_sds[1]); rkau_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     scheduler_rkau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_rkau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, rkau_fast_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         rkau_train_t1, rkau_train_t2, rkau_train_t3,
#         rkau_test_t1, rkau_test_t2, rkau_test_t3,
#         rkau_train_sd_t1, rkau_train_sd_t2, rkau_train_sd_t3,
#         rkau_test_sd_t1, rkau_test_sd_t2, rkau_test_sd_t3,
#         plot_title='RKAU Model Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         rkau_train_mask_t1, rkau_train_mask_t2, rkau_train_mask_t3,
#         rkau_test_mask_t1, rkau_test_mask_t2, rkau_test_mask_t3,
#         rkau_train_mask_sd_t1, rkau_train_mask_sd_t2, rkau_train_mask_sd_t3,
#         rkau_test_mask_sd_t1, rkau_test_mask_sd_t2, rkau_test_mask_sd_t3,
#         plot_title='RKAU Model Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

## Axial U-Net Baseline

In [None]:
# # --- AxialU_Net baseline: instantiation + parameter summary (cell currently disabled) ---
# # Every line of this cell is commented out, so running it is a no-op.
# # When re-enabled it: sets BASE_CHANNELS and args.d_attn1 / d_attn2 / img_channels /
# # device / num_attn_layers, builds AxialU_Net on args.device, counts trainable
# # parameters of E1, CFB_enc, Axial_Aggregator, P, CFB_dec and D1, and prints them
# # as a table via tabulate.
# # Names used but not defined in this cell: args, torch, AxialU_Net, count_parameters,
# # tabulate (presumably defined/imported in earlier cells — confirm before enabling).
# # NOTE(review): the TOTAL row is the sum of the six listed submodules, not
# # count_parameters(axialu_model); the two differ if the model owns any other
# # trainable parameters.
# # --- Configuration Update ---
# BASE_CHANNELS = 16 # Retaining C=16
# args.d_attn1 = 128 # Retaining d_attn1=128
# args.d_attn2 = 256 # Retaining d_attn2=256
# args.img_channels = 3
# args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# args.num_attn_layers = 2 # N=2 stacked layers for Axial Integrators

# print("\nInstantiating the AXIALU_NET Baseline (N=2 Layers)...")

# # Instantiate the full model
# axialu_model = AxialU_Net(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # --- Calculate parameters for each main component ---

# e1_params = count_parameters(axialu_model.E1)
# d1_params = count_parameters(axialu_model.D1)

# # Axial_Aggregator is the feature aggregation module using Standard Axial Attention
# axial_agg_params = count_parameters(axialu_model.Axial_Aggregator) 

# # P is the DynNet State Evolutionary Predictor
# p_params = count_parameters(axialu_model.P) 

# # CFBs (Enc and Dec)
# cfb_enc_params = count_parameters(axialu_model.CFB_enc)
# cfb_dec_params = count_parameters(axialu_model.CFB_dec)
# # NOTE(review): cfb_total_params is not used below — the table lists CFB_enc and
# # CFB_dec as separate rows.
# cfb_total_params = cfb_enc_params + cfb_dec_params

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_enc_params + axial_agg_params + p_params + cfb_dec_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Spatial)", f"{e1_params:,}"],
#     ["CFB_enc", "Pre-Dynamics Mixer (Refinement)", f"{cfb_enc_params:,}"],
#     ["Axial Aggregator", "Temporal Aggregation Core (Standard Axial)", f"**{axial_agg_params:,}**"],
#     ["DynNet (P)", "State Evolutionary Predictor", f"**{p_params:,}**"],
#     ["CFB_dec", "Post-Dynamics Mixer (Refinement)", f"{cfb_dec_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Spatial)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full AxialU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### AxialU_Net Baseline Parameter Summary (N=2 Stacked Layers)\n")
# # Using the tabulate function as requested in the context
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- AxialU_Net baseline: training loop (cell currently disabled) ---
# # Every line of this cell is commented out, so running it is a no-op.
# # When re-enabled it builds an Adam optimizer and a ReduceLROnPlateau scheduler for
# # axialu_model, (re)defines the loss functions, freezes BatchNorm via
# # freeze_batch_norm, then for each of args.num_epochs epochs:
# #   1. every DATA_GENERATION_EPOCH_CYCLE epochs (starting at epoch 0) regenerates
# #      NUM_TRAIN_SAMPLES synthetic samples and moves them to args.device,
# #   2. runs one accumulated-gradient epoch with f_single_epoch_spatiotemporal_accumulated,
# #   3. evaluates residual and full-mask Dice (median and SD) on test_loader and on the
# #      current synthetic training set,
# #   4. appends the three per-horizon values (indices 0..2 -> t1..t3) to the axialu_* lists,
# #   5. steps the scheduler on the mean training loss and redraws loss / sample / Dice plots.
# # Requires axialu_model from the previous cell and helpers defined elsewhere
# # (generate_synthetic_faf_dataset, f_single_epoch_spatiotemporal_accumulated,
# # f_eval_pred_dice_test_set, f_eval_pred_dice_train_set, plot_log_loss, f_display_frames,
# # plot_train_test_dice_history, test_loader, dice_loss, GDLoss, freeze_batch_norm,
# # ReduceLROnPlateau, nn, torch, args.batch_size) — confirm they exist before enabling.
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Samples per dynamic training set
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS (Matched to RKAU-Net for comparison)
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR
# BOTTLENECK_L2_WEIGHT = 1e-6 

# # --- MODEL, OPTIMIZER, SCHEDULER (Renamed for AxialU-Net) ---
# optimizer_axialu = torch.optim.Adam(axialu_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# # NOTE(review): `verbose=True` on ReduceLROnPlateau is deprecated in recent PyTorch
# # releases (it emits a deprecation warning) — confirm against the installed version;
# # the LR is already printed each epoch from optimizer_axialu.param_groups below.
# scheduler_axialu = ReduceLROnPlateau(
#     optimizer_axialu, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Assuming they are correctly defined/imported)
# # These are only used internally within the utility functions.
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_dice = dice_loss 
# loss_fn_gdl = GDLoss(alpha=1, beta=1)
# loss_fn_l1 = nn.L1Loss(reduction='mean')
# loss_fn_l2 = nn.MSELoss(reduction='mean')

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(axialu_model) 

# # --- HISTORY INITIALIZATION (Renamed to axialu_...) ---

# # Loss/Iteration Tracking
# axialu_all_iteration_losses = [] 
# axialu_epoch_iteration_counts = []

# # Residual Scores (Mean/Median)
# axialu_train_t1, axialu_train_t2, axialu_train_t3 = [], [], []
# axialu_test_t1, axialu_test_t2, axialu_test_t3 = [], [], []
# # Residual SDs
# axialu_train_sd_t1, axialu_train_sd_t2, axialu_train_sd_t3 = [], [], []
# axialu_test_sd_t1, axialu_test_sd_t2, axialu_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# axialu_train_mask_t1, axialu_train_mask_t2, axialu_train_mask_t3 = [], [], []
# axialu_test_mask_t1, axialu_test_mask_t2, axialu_test_mask_t3 = [], [], []
# # Mask SDs
# axialu_train_mask_sd_t1, axialu_train_mask_sd_t2, axialu_train_mask_sd_t3 = [], [], []
# axialu_test_mask_sd_t1, axialu_test_mask_sd_t2, axialu_test_mask_sd_t3 = [], [], []


# print(f"\n Starting AxialU_Net Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- 1. DYNAMIC DATA GENERATION ---
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         # NOTE: Assuming generate_synthetic_faf_dataset function is available and returns a NumPy array
#         # This function generates the synthetic data on the CPU
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES) 
#         # Move the entire synthetic dataset to the device (e.g., CUDA)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 2. Training Step (Using Accumulated Loss Function) ---
#     # f_single_epoch_spatiotemporal_accumulated handles batching, augmentation, forward pass, and gradient steps
#     # NOTE(review): the return value is used via .tolist() / len() / np.mean below, so it
#     # is presumably a 1-D NumPy array of per-iteration losses — confirm.
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, axialu_model, optimizer_axialu, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     axialu_all_iteration_losses.extend(epoch_losses.tolist())
#     axialu_epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 3. Evaluation Step (Median/SD) ---
#     # Evaluate the current state of the model on the test set
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, axialu_model, args, soft_dice=soft_dice, use_median=True)
    
#     # Evaluate the current state of the model on the fresh synthetic training data
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, axialu_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 4. Accumulation (History Tracking) ---
#     # Residual Scores
#     axialu_train_t1.append(res_train_scores[0]); axialu_train_t2.append(res_train_scores[1]); axialu_train_t3.append(res_train_scores[2])
#     axialu_test_t1.append(res_test_scores[0]); axialu_test_t2.append(res_test_scores[1]); axialu_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     axialu_train_sd_t1.append(res_train_sds[0]); axialu_train_sd_t2.append(res_train_sds[1]); axialu_train_sd_t3.append(res_train_sds[2])
#     axialu_test_sd_t1.append(res_test_sds[0]); axialu_test_sd_t2.append(res_test_sds[1]); axialu_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     axialu_train_mask_t1.append(msk_train_scores[0]); axialu_train_mask_t2.append(msk_train_scores[1]); axialu_train_mask_t3.append(msk_train_scores[2])
#     axialu_test_mask_t1.append(msk_test_scores[0]); axialu_test_mask_t2.append(msk_test_scores[1]); axialu_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     axialu_train_mask_sd_t1.append(msk_train_sds[0]); axialu_train_mask_sd_t2.append(msk_train_sds[1]); axialu_train_mask_sd_t3.append(msk_train_sds[2])
#     axialu_test_mask_sd_t1.append(msk_test_sds[0]); axialu_test_mask_sd_t2.append(msk_test_sds[1]); axialu_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 5. Scheduler & Logging ---
#     # NOTE(review): the plateau signal is the mean *training* loss, and the training data
#     # is replaced every DATA_GENERATION_EPOCH_CYCLE epochs, so a loss jump caused by fresh
#     # data can count toward `patience`. Consider stepping on a held-out metric instead.
#     scheduler_axialu.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_axialu.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(axialu_all_iteration_losses, axialu_epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     # NOTE: current_train_data is the full dataset on device, f_display_frames expects the full dataset tensor
#     # NOTE(review): sample_idx=20 requires NUM_TRAIN_SAMPLES > 20 (it is 100 above).
#     f_display_frames(current_train_data, axialu_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         axialu_train_t1, axialu_train_t2, axialu_train_t3,
#         axialu_test_t1, axialu_test_t2, axialu_test_t3,
#         axialu_train_sd_t1, axialu_train_sd_t2, axialu_train_sd_t3,
#         axialu_test_sd_t1, axialu_test_sd_t2, axialu_test_sd_t3,
#         plot_title='AxialU_Net Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         axialu_train_mask_t1, axialu_train_mask_t2, axialu_train_mask_t3,
#         axialu_test_mask_t1, axialu_test_mask_t2, axialu_test_mask_t3,
#         axialu_train_mask_sd_t1, axialu_train_mask_sd_t2, axialu_train_mask_sd_t3,
#         axialu_test_mask_sd_t1, axialu_test_mask_sd_t2, axialu_test_mask_sd_t3,
#         plot_title='AxialU_Net Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- AxialU_Net Training Complete ---")

In [None]:
# # --- Save AxialU_Net weights (cell currently disabled) ---
# # When re-enabled, calls save_model_weights (defined elsewhere) on axialu_model with
# # final_epoch=args.num_epochs and model_name="AxialU_Net_pretrain", writing under
# # ckpt_save_dir; the returned value is kept in saved_path.
# # NOTE(review): FINAL_EPOCH is the *configured* epoch count, not the number of epochs
# # actually completed if training was interrupted.
# # NOTE(review): the directory is a user-specific absolute path; pathlib collapses the
# # doubled '//' separators, so they are harmless but could simply be '/'.
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=axialu_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "AxialU_Net_pretrain"
# )

## Ablate Spatial Attention

In [None]:
# ## Ablate Spatial Attention
# # --- SWAU_Net-CNN ablation: instantiation + parameter summary (cell currently disabled) ---
# # Every line of this cell is commented out, so running it is a no-op.
# # When re-enabled it sets BASE_CHANNELS / args.d_attn1 / args.d_attn2, (re)defines
# # count_parameters, builds SWAU_Net_CNN on args.device, and prints the trainable
# # parameter counts of E1, CFB_enc + CFB_dec, SWA, P and D1 via tabulate.
# # NOTE(review): args.img_channels and args.device are read but not set here — they
# # come from an earlier cell (e.g. the AxialU_Net configuration cell).
# # NOTE(review): count_parameters is defined identically in more than one cell;
# # defining it once in a shared setup cell would avoid drift.
# # NOTE(review): the training cell re-instantiates swau_model, so the instance built
# # here is only used for counting parameters.


# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the **SWAU_Net-CNN Ablation** model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # The model is SWAU_Net_CNN, which uses CNN_Unet_Enc/Dec and CNN_DynNet.
# swau_model = SWAU_Net_CNN(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(swau_model.E1)

# # CORRECTED: Calculate parameters for both CFB modules separately
# cfb_enc_params = count_parameters(swau_model.CFB_enc) 
# cfb_dec_params = count_parameters(swau_model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# swa_params = count_parameters(swau_model.SWA) 
# p_params = count_parameters(swau_model.P)
# d1_params = count_parameters(swau_model.D1)

# # Ensure all components are summed up for the total count
# # NOTE(review): this sums the listed submodules only; it equals
# # count_parameters(swau_model) only if the model has no other trainable parameters.
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["CNN_Unet_Enc (E1)", "Feature Extractor (No Spatial Attention)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"],
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator (Temporal Axial Only)**", f"**{swa_params:,}**"], 
#     ["CNN_DynNet (P)", "Temporal Feature Predictor (No Attention)", f"{p_params:,}"],
#     ["CNN_Unet_Dec (D1)", "Frame Reconstructor", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**SWAU_Net-CNN Ablation**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net-CNN Ablation Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# # --- SWAU_Net-CNN ablation: training loop (cell currently disabled) ---
# # Every line of this cell is commented out, so running it is a no-op.
# # When re-enabled it follows the same per-epoch pipeline as the AxialU_Net training
# # cell: periodic synthetic-data regeneration, one accumulated-gradient epoch,
# # test/train Dice evaluation (median and SD), history accumulation into swau_* lists,
# # scheduler step on the mean training loss, and per-epoch plots.
# # NOTE(review): swau_model is re-instantiated below with fresh weights, replacing any
# # instance created by the parameter-summary cell.
# # NOTE(review): the loss history uses the unprefixed names all_iteration_losses /
# # epoch_iteration_counts, which other cells (RKAU, UPredNet) also assign; running
# # another cell overwrites them. The AxialU_Net cell uses axialu_-prefixed names instead.
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL INSTANTIATION (MODIFIED TO SWAU_Net_CNN) ---
# # NOTE: BASE_CHANNELS, args.img_channels, args.device, etc., assumed defined globally.
# # Instantiate SWAU_Net_CNN to ablate all spatial/DynNet attention
# swau_model = SWAU_Net_CNN(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # --- OPTIMIZER, SCHEDULER RENAMING ---
# # (The old 'model_baseline' reference will be replaced by 'swau_model' in function calls)
# optimizer_swau = torch.optim.Adam(swau_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# # NOTE(review): `verbose=True` on ReduceLROnPlateau is deprecated in recent PyTorch
# # releases — confirm against the installed version.
# scheduler_swau = ReduceLROnPlateau(
#     optimizer_swau, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(swau_model) # Use swau_model

# # --- HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median) - RENAMED to swau_...
# swau_train_t1, swau_train_t2, swau_train_t3 = [], [], []
# swau_test_t1, swau_test_t2, swau_test_t3 = [], [], []
# # Residual SDs
# swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3 = [], [], []
# swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3 = [], [], []
# swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3 = [], [], []
# # Mask SDs
# swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3 = [], [], []
# swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3 = [], [], []


# print(f"\n Starting **SWAU_Net-CNN** Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     # Regenerates at epochs 0, DATA_GENERATION_EPOCH_CYCLE, 2*DATA_GENERATION_EPOCH_CYCLE, ...
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using SWAU_Net's accumulated loss function) ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, swau_model, optimizer_swau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, swau_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, swau_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED for consistency) ---
#     # Residual Scores
#     swau_train_t1.append(res_train_scores[0]); swau_train_t2.append(res_train_scores[1]); swau_train_t3.append(res_train_scores[2])
#     swau_test_t1.append(res_test_scores[0]); swau_test_t2.append(res_test_scores[1]); swau_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     swau_train_sd_t1.append(res_train_sds[0]); swau_train_sd_t2.append(res_train_sds[1]); swau_train_sd_t3.append(res_train_sds[2])
#     swau_test_sd_t1.append(res_test_sds[0]); swau_test_sd_t2.append(res_test_sds[1]); swau_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     swau_train_mask_t1.append(msk_train_scores[0]); swau_train_mask_t2.append(msk_train_scores[1]); swau_train_mask_t3.append(msk_train_scores[2])
#     swau_test_mask_t1.append(msk_test_scores[0]); swau_test_mask_t2.append(msk_test_scores[1]); swau_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     swau_train_mask_sd_t1.append(msk_train_sds[0]); swau_train_mask_sd_t2.append(msk_train_sds[1]); swau_train_mask_sd_t3.append(msk_train_sds[2])
#     swau_test_mask_sd_t1.append(msk_test_sds[0]); swau_test_mask_sd_t2.append(msk_test_sds[1]); swau_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     # NOTE(review): steps on mean training loss over periodically regenerated data (see
#     # the AxialU_Net cell for the same caveat).
#     scheduler_swau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_swau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     # NOTE(review): sample_idx=20 requires NUM_TRAIN_SAMPLES > 20 (it is 100 above).
#     f_display_frames(current_train_data, swau_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         swau_train_t1, swau_train_t2, swau_train_t3,
#         swau_test_t1, swau_test_t2, swau_test_t3,
#         swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3,
#         swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3,
#         plot_title='SWAU_Net-CNN Residual Dice History (Median ± SD)' # MODIFIED TITLE
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3,
#         swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3,
#         swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3,
#         swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3,
#         plot_title='SWAU_Net-CNN Full Mask Dice History (Median ± SD)' # MODIFIED TITLE
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # --- Save SWAU_Net-CNN weights (cell currently disabled) ---
# # When re-enabled, calls save_model_weights (defined elsewhere) on swau_model with
# # final_epoch=args.num_epochs and model_name="SWAU_CNN_pretrain", writing under
# # ckpt_save_dir; the returned value is kept in saved_path.
# # NOTE(review): FINAL_EPOCH is the configured epoch count, not necessarily the number
# # of epochs actually completed. The path is user-specific; pathlib collapses the
# # doubled '//' separators.
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=swau_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "SWAU_CNN_pretrain"
    
# )

# # NOTE(review): model_baseline is not defined in any of the visible cells; if this
# # cleanup is re-enabled it probably should target swau_model — confirm.
# # del model_baseline
# # torch.cuda.empty_cache()
# # gc.collect()

## Ablate Spatiotemporal Attention

In [None]:
# # --- UPredNet (attention ablation): instantiation + parameter summary (cell currently disabled) ---
# # Every line of this cell is commented out, so running it is a no-op.
# # When re-enabled it sets BASE_CHANNELS / args.d_attn1 / args.d_attn2, (re)defines
# # count_parameters, builds UPredNet on args.device, and prints the trainable
# # parameter counts of E1, CFB_enc + CFB_dec, P and D1. The SWA row is hard-coded to 0
# # because this model has no SWA submodule.
# # NOTE(review): the model is bound to `upred_model` here, while the training cell
# # builds and trains a separate `uprednet_model`; this instance is only used for
# # counting parameters.
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the **UPredNet (SWA Ablation)** model with the updated configuration...")

# # Instantiate the UPredNet model (UPredNet uses E1, CFB_enc, P=DynNet, CFB_dec, D1)
# # We assume UPredNet is accessible, and base_channels defaults to 16 if not set.
# upred_model = UPredNet(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(upred_model.E1)

# # Calculate parameters for both CFB modules separately
# cfb_enc_params = count_parameters(upred_model.CFB_enc) 
# cfb_dec_params = count_parameters(upred_model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# # The UPredNet model does NOT have an 'SWA' module. This parameter should be 0.
# swa_params = 0 

# # P is the DynNet
# p_params = count_parameters(upred_model.P)
# d1_params = count_parameters(upred_model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (No Spatial Attention)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"],
#     ["**SWA Module**", "**Ablated**", f"**{swa_params:,}**"], # SWA is 0
#     ["DynNet (P)", "Temporal Feature Predictor (Evolution)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**UPredNet (SWA Ablation)**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### UPredNet (SWA Ablation) Component Parameter Summary\n")
# # Assuming the 'tabulate' library and 'BASE_CHANNELS' constant are available in the execution environment
# # Unlike the other summary cells, this one falls back to printing raw rows if tabulate is missing.
# try:
#     from tabulate import tabulate
#     print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))
# except ImportError:
#     print("Tabulate library not available. Printing raw data.")
#     for row in param_data:
#         print(row)

In [None]:
# # --- UPredNet (attention ablation): training loop (cell currently disabled) ---
# # Every line of this cell is commented out, so running it is a no-op.
# # When re-enabled it builds a fresh UPredNet as `uprednet_model` and runs the same
# # per-epoch pipeline as the other training cells: periodic synthetic-data regeneration,
# # one accumulated-gradient epoch, test/train Dice evaluation (median and SD), history
# # accumulation into uprednet_* lists, scheduler step on mean training loss, and plots.
# # NOTE(review): the loss history uses the unprefixed names all_iteration_losses /
# # epoch_iteration_counts, shared with the RKAU and SWAU cells.
# # NOTE(review): the two plot titles disagree ("SWA Ablation" vs "Spatiotemporal
# # Ablation") while the section heading says "Ablate Spatiotemporal Attention" —
# # pick one label before generating figures.
# # --- DYNAMIC DATA GENERATION CONFIG ---
# # Define the size of the synthetic training set to generate each time
# NUM_TRAIN_SAMPLES = 100 # Example size, adjust as needed
# DATA_GENERATION_EPOCH_CYCLE = 5 
# # ------------------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 50
# ACCUMULATION_STEPS = 8 
# soft_dice = True # Use Soft Dice for stability
# lr = 1E-3 # Initial LR

# # --- MODEL INSTANTIATION (MODIFIED TO UPredNet - SWA Ablation) ---
# # NOTE: BASE_CHANNELS, args.img_channels, args.device, etc., assumed defined globally.

# # Instantiate UPredNet to ablate all explicit Spatio-Temporal Attention (SWA/Axial)
# uprednet_model = UPredNet(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # --- OPTIMIZER, SCHEDULER RENAMING (Correcting 'swau' to 'uprednet' references) ---
# # NOTE: The provided code had 'uprednet.parameters()' which is incorrect. Using the instantiated model name.
# optimizer_uprednet = torch.optim.Adam(uprednet_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# # NOTE(review): `verbose=True` on ReduceLROnPlateau is deprecated in recent PyTorch
# # releases — confirm against the installed version.
# scheduler_uprednet = ReduceLROnPlateau(
#     optimizer_uprednet, 
#     mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(uprednet_model) # Use uprednet_model

# # --- HISTORY INITIALIZATION (REVISING ALL 'swau' REFERENCES TO 'uprednet') ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median) - Corrected Naming
# uprednet_train_t1, uprednet_train_t2, uprednet_train_t3 = [], [], []
# uprednet_test_t1, uprednet_test_t2, uprednet_test_t3 = [], [], []
# # Residual SDs
# uprednet_train_sd_t1, uprednet_train_sd_t2, uprednet_train_sd_t3 = [], [], []
# uprednet_test_sd_t1, uprednet_test_sd_t2, uprednet_test_sd_t3 = [], [], []
# # Mask Scores (Mean/Median)
# uprednet_train_mask_t1, uprednet_train_mask_t2, uprednet_train_mask_t3 = [], [], []
# uprednet_test_mask_t1, uprednet_test_mask_t2, uprednet_test_mask_t3 = [], [], []
# # Mask SDs
# uprednet_train_mask_sd_t1, uprednet_train_mask_sd_t2, uprednet_train_mask_sd_t3 = [], [], []
# uprednet_test_mask_sd_t1, uprednet_test_mask_sd_t2, uprednet_test_mask_sd_t3 = [], [], []


# print(f"\n Starting **UPredNet (SWA Ablation)** Model Training for {args.num_epochs} epoch(s)...")

# current_train_data = None

# # --- TRAINING LOOP (50 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- DYNAMIC DATA GENERATION ---
#     # Regenerates at epochs 0, DATA_GENERATION_EPOCH_CYCLE, 2*DATA_GENERATION_EPOCH_CYCLE, ...
#     if epoch % DATA_GENERATION_EPOCH_CYCLE == 0:
#         print(f"\n[Epoch {epoch+1}] Regenerating fresh synthetic training data...")
#         # Assuming generate_synthetic_faf_dataset and current_train_data_np are defined elsewhere
#         current_train_data_np = generate_synthetic_faf_dataset(num_samples=NUM_TRAIN_SAMPLES)
#         current_train_data = torch.from_numpy(current_train_data_np.astype(np.float32)).to(args.device)
#         print(f"Data generation complete. New training set size: {current_train_data.shape[0]} samples.")
    
#     # --- 1. Training Step (Using UPredNet's accumulated loss function) ---
#     # NOTE: Model and optimizer references are corrected from 'swau' to 'uprednet'
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, uprednet_model, optimizer_uprednet, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=1e-2, lambda_faf=0.5, lambda_mask=2.0, lambda_residual=5.0, 
#         lambda_recon=0.5, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     # NOTE: Model reference is corrected
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, uprednet_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, uprednet_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (REVISED to use 'uprednet' history variables) ---
#     # Residual Scores
#     uprednet_train_t1.append(res_train_scores[0]); uprednet_train_t2.append(res_train_scores[1]); uprednet_train_t3.append(res_train_scores[2])
#     uprednet_test_t1.append(res_test_scores[0]); uprednet_test_t2.append(res_test_scores[1]); uprednet_test_t3.append(res_test_scores[2])
#     # Residual SDs
#     uprednet_train_sd_t1.append(res_train_sds[0]); uprednet_train_sd_t2.append(res_train_sds[1]); uprednet_train_sd_t3.append(res_train_sds[2])
#     uprednet_test_sd_t1.append(res_test_sds[0]); uprednet_test_sd_t2.append(res_test_sds[1]); uprednet_test_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     uprednet_train_mask_t1.append(msk_train_scores[0]); uprednet_train_mask_t2.append(msk_train_scores[1]); uprednet_train_mask_t3.append(msk_train_scores[2])
#     uprednet_test_mask_t1.append(msk_test_scores[0]); uprednet_test_mask_t2.append(msk_test_scores[1]); uprednet_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     uprednet_train_mask_sd_t1.append(msk_train_sds[0]); uprednet_train_mask_sd_t2.append(msk_train_sds[1]); uprednet_train_mask_sd_t3.append(msk_train_sds[2])
#     uprednet_test_mask_sd_t1.append(msk_test_sds[0]); uprednet_test_mask_sd_t2.append(msk_test_sds[1]); uprednet_test_mask_sd_t3.append(msk_test_sds[2])

#     # --- 4. Scheduler & Logging ---
#     # NOTE: Scheduler reference is corrected
#     # NOTE(review): steps on mean training loss over periodically regenerated data.
#     scheduler_uprednet.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_uprednet.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     # NOTE: Model reference is corrected
#     # NOTE(review): sample_idx=20 requires NUM_TRAIN_SAMPLES > 20 (it is 100 above).
#     f_display_frames(current_train_data, uprednet_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     # NOTE: History variables and plot title are corrected
#     plot_train_test_dice_history(
#         uprednet_train_t1, uprednet_train_t2, uprednet_train_t3,
#         uprednet_test_t1, uprednet_test_t2, uprednet_test_t3,
#         uprednet_train_sd_t1, uprednet_train_sd_t2, uprednet_train_sd_t3,
#         uprednet_test_sd_t1, uprednet_test_sd_t2, uprednet_test_sd_t3,
#         plot_title='UPredNet (SWA Ablation) Residual Dice History (Median ± SD)' # MODIFIED TITLE
#     )

#     # D. Plot Mask History
#     # NOTE: History variables and plot title are corrected
#     # NOTE(review): this title says "Spatiotemporal Ablation" but the residual plot above
#     # says "SWA Ablation" — make the two consistent.
#     plot_train_test_dice_history(
#         uprednet_train_mask_t1, uprednet_train_mask_t2, uprednet_train_mask_t3,
#         uprednet_test_mask_t1, uprednet_test_mask_t2, uprednet_test_mask_t3,
#         uprednet_train_mask_sd_t1, uprednet_train_mask_sd_t2, uprednet_train_mask_sd_t3,
#         uprednet_test_mask_sd_t1, uprednet_test_mask_sd_t2, uprednet_test_mask_sd_t3,
#         plot_title='UPredNet (Spatiotemporal Ablation) Full Mask Dice History (Median ± SD)' # MODIFIED TITLE
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # --- Save UPredNet weights (cell currently disabled) ---
# # When re-enabled, calls save_model_weights (defined elsewhere) on uprednet_model (the
# # instance trained in the preceding cell, not upred_model from the summary cell) with
# # final_epoch=args.num_epochs and model_name="UPredNet_pretrain"; the returned value
# # is kept in saved_path.
# # NOTE(review): FINAL_EPOCH is the configured epoch count, not necessarily the number
# # of epochs actually completed. The path is user-specific; pathlib collapses the
# # doubled '//' separators.
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# FINAL_EPOCH = args.num_epochs
# saved_path = save_model_weights(
#     model=uprednet_model, 
#     final_epoch=FINAL_EPOCH, 
#     save_dir=ckpt_save_dir,
#     model_name = "UPredNet_pretrain"
    
# )

# # NOTE(review): model_baseline is not defined in any of the visible cells; if this
# # cleanup is re-enabled it probably should target uprednet_model / upred_model — confirm.
# # del model_baseline
# # torch.cuda.empty_cache()
# # gc.collect()