## Setup

In [None]:
## Setup

%cd /Users/Pracioppo/Desktop/GA Forecasting

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import functools
from torch.nn import init
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau 

import numpy as np
import argparse
import random
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from pathlib import Path
from datetime import datetime
import scipy.io as sio
from PIL import Image
import cv2
from tabulate import tabulate
from collections import defaultdict
import pickle

from torch.utils.data import Dataset, DataLoader, random_split, Subset 

# ---------------------------------------

from preprocessing_utils import f_rescale_dataset, f_Residuals, f_reshape_training_data, f_rotate_and_zoom, f_random_crop, f_rotate_and_zoom_all, f_crop_all, f_flip_all, f_augment_dataset2

from data_utils import DataWrapper, visualize_sample, compare_split_masks

from visualization import f_display_autoencoder, plot_log_loss, f_display_frames

from models import init_weights, count_parameters
from models import rotate_half, RotaryPositionalEmbedding, RoPEMultiheadAttention, RoPETransformerEncoderLayer, ResidualBlock, ChannelReducer, Unet_Enc, Unet_Dec, U_Net_AE

from augmentation_utils import f_augment_spatial_and_intensity
from training_utils import dsc, dice_loss, GDLoss
from training_utils import freeze_batch_norm, f_single_epoch_AE, f_single_epoch_spatiotemporal, calculate_total_loss, f_single_epoch_spatiotemporal_accumulated
from training_utils import save_model_weights, load_model
from training_utils import save_final_experiment_data
from eval_utils import f_eval_pred_dice_test_set, f_eval_pred_dice_train_set, plot_train_test_dice_history, soft_dice_score, f_get_individual_dice, f_plot_individual_dice
from eval_utils import analyze_kfold_results, corrected_paired_ttest_nadeau, paired_ttest_student, compare_models_nadeau_bengio, compare_models

# from models import TemporalDeltaBlock
from models import DynNet, CausalConvAggregator, UPredNet
from models import FusionBlockBottleneck, ChannelFusionBlock, LocalSpatioTemporalMixer, SpatioTemporalGatedMixer, AxialTemporalSWAInterleavedLayer, InterleavedAxialTemporalSWAIntegrator, FullGlobalSWAIntegrator, SlidingWindowAttention, SWAU_Net, SWAU_CFB_Ablation, SWAU_DynNet_Ablation
from models import ConvLSTMCell, ConvLSTMCore, ConvLSTMBaseline, ConvLSTM_Simple
from models import create_causal_mask, create_block_causal_mask, RKA_MultiheadAttention, AxialTemporalRKAInterleavedLayer, plot_attention_matrix
from models import AxialMultiheadAttention, StandardAxialInterleavedLayer, StandardAxialIntegrator, AxialU_Net
from models import CNN_Unet_Enc, CNN_Unet_Dec, CNN_U_Net_AE, CNN_DynNet, SWAU_Net_CNN

# ---------------------------------------

# --- Output locations ---
# Checkpoints and TensorBoard logs share one directory (pathlib collapses the
# doubled '//' separators).
ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models')
tensorboard_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models')

start_epoch = 0

resume_ckpt = None  # None -> train from scratch

summary_writer = SummaryWriter(tensorboard_save_dir.absolute().as_posix())


# --- SETUP ---
# argparse is used only as a lightweight attribute container: parsing an empty
# argument list means no real command-line arguments are read (safe in Jupyter).
parser = argparse.ArgumentParser('AE Model Args')
args = parser.parse_args(args=[])

# Model hyper-parameters (lightweight AE setup)
args.N = 4                       # Batch size, kept at 4 for GPU memory safety
args.nhead = 4                   # Number of attention heads
args.d_attn1 = 192               # FFN dimension for L3 (112 channels)
args.d_attn2 = 384               # FFN dimension for L4 (224 channels)
args.img_channels = 3            # Three grayscale images (FAF, masks, growth masks)
args.img_sz = 256                # Image size 256x256
BASE_CHANNELS = 24               # Base U-Net width (reduced from 28 to save parameters)

# Training loop arguments
args.num_epochs = 1
args.show_example_epochs = 5
args.batch_size = args.N         # Batch size for iteration is args.N
args.num_t_steps = 4             # Time steps (used only for data simulation/flattening)

# Select the device exactly once, falling back to CPU when CUDA is unavailable.
# Do not re-assign a hard-coded 'cuda:0' afterwards: that discards the fallback
# and makes every later `.to(args.device)` fail on CPU-only machines.
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(torch.__version__)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Using {args.device} device")

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator, plus all CUDA devices when CUDA is available.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(2025)

## Load the Data

In [None]:
# --- Data Loading and Global Preprocessing ---

# WARNING: Replace this with your actual data directory
DATA_DIR = Path('/Users/Pracioppo/Desktop/GA Forecasting/GA_Zubens_data') 
print("Loading data...")

FAFS_PATH = DATA_DIR / 'fafs_reg3.mat'
MASKS_PATH = DATA_DIR / 'masks_reg3.mat'

# Per-sample layout [C=1, T=4, H=256, W=256]. The sample count is inferred (-1)
# instead of being hard-coded, so datasets of other sizes load without edits.
SAMPLE_SHAPE = (1, 4, 256, 256)


def _load_mat_tensor(path, key):
    """Load variable ``key`` from a .mat file as a float32 tensor of shape [N, *SAMPLE_SHAPE]."""
    array = sio.loadmat(path)[key].astype(np.float32)
    return torch.from_numpy(array).reshape(-1, *SAMPLE_SHAPE)


def _normalize_by_max(tensor):
    """Scale ``tensor`` by its global maximum; returned unchanged when the maximum is <= 0."""
    peak = tensor.max()  # computed once instead of twice
    return tensor / peak if peak > 0 else tensor


FAFs = _load_mat_tensor(FAFS_PATH, "fafs_reg3")
masks = _load_mat_tensor(MASKS_PATH, "masks_reg3")

# The FAF images and masks are paired sample by sample, so the counts must agree.
if FAFs.shape[0] != masks.shape[0]:
    raise ValueError(
        f"FAFs ({FAFs.shape[0]} samples) and masks ({masks.shape[0]} samples) differ in size."
    )

# Global normalization
masks = _normalize_by_max(masks)
FAFs = _normalize_by_max(FAFs)

# Residual calculation and normalization (applied globally to all samples)
all_residuals = f_Residuals(masks)

# --- Dataset construction ---

# Instantiate the Dataset with all pre-processed feature tensors
all_data = DataWrapper(FAFs, masks, all_residuals)
print(f"\nTotal Dataset size (samples): {len(all_data)}")

In [None]:
## WITH DATA AUGMENTATIONS

# K-FOLD GROUP SETUP (To avoid data leakage from augmented samples)
# Samples are assumed to be stored in contiguous groups of SAMPLES_PER_GROUP
# (an original image plus its augmented copies). Whole groups are assigned to
# one fold, so copies of a test image never leak into training.
SAMPLES_PER_GROUP = 10
N_SAMPLES = len(all_data) # N_SAMPLES = 660
if N_SAMPLES % SAMPLES_PER_GROUP != 0:
    raise ValueError(
        f"N_SAMPLES={N_SAMPLES} is not a multiple of SAMPLES_PER_GROUP={SAMPLES_PER_GROUP}; "
        "groups would be split across folds."
    )
N_GROUPS = N_SAMPLES // SAMPLES_PER_GROUP # N_GROUPS = 66
K_FOLDS = 5

# Fold sizes in groups (66 groups in 5 folds: fold 0 has 14 groups, folds 1-4 have 13)
group_fold_size_base = N_GROUPS // K_FOLDS # 13
group_fold_remainder = N_GROUPS % K_FOLDS # 1

k = 4 # The current fold index, in [0, K_FOLDS)
if not 0 <= k < K_FOLDS:
    raise ValueError(f"Fold index k={k} must be in [0, {K_FOLDS}).")

G_all_indices = np.arange(N_GROUPS)

# The first `group_fold_remainder` folds each receive one extra group.
current_group_fold_size = group_fold_size_base + (1 if k < group_fold_remainder else 0)

# Folds before k occupy k*base groups plus one extra group for each earlier fold
# that got a remainder group (66 groups: k=0 -> [0, 14), k=4 -> [53, 66)).
G_start_idx = k * group_fold_size_base + min(k, group_fold_remainder)
G_end_idx = G_start_idx + current_group_fold_size

G_test_indices = G_all_indices[G_start_idx:G_end_idx]
G_train_indices = np.concatenate([G_all_indices[:G_start_idx], G_all_indices[G_end_idx:]])

# Map Group indices back to Sample indices (0-659)
def map_group_to_sample_indices(group_indices, samples_per_group):
    """Expand group indices into the sample indices each group contains.

    Group ``g`` owns the contiguous samples
    ``[g * samples_per_group, (g + 1) * samples_per_group)``.

    Args:
        group_indices: 1-D sequence/array of integer group indices.
        samples_per_group: number of consecutive samples in each group.

    Returns:
        1-D integer ``np.ndarray`` of sample indices, listed group by group in
        the order the groups were given. An empty ``group_indices`` yields an
        empty array (rather than the ``ValueError`` raised by concatenating an
        empty list).
    """
    groups = np.asarray(group_indices, dtype=np.int64).reshape(-1)
    offsets = np.arange(samples_per_group, dtype=np.int64)
    # [n_groups, 1] + [samples_per_group] broadcasts to [n_groups, samples_per_group];
    # flattening row-major keeps the group-by-group ordering.
    return (groups[:, None] * samples_per_group + offsets).reshape(-1)

train_indices = map_group_to_sample_indices(G_train_indices, SAMPLES_PER_GROUP)
test_indices = map_group_to_sample_indices(G_test_indices, SAMPLES_PER_GROUP)

sz = len(test_indices) # current_group_fold_size * SAMPLES_PER_GROUP (140 for k=0, 130 for k=1..4)

# A group-aware split must be disjoint and cover every sample exactly once.
assert np.intersect1d(train_indices, test_indices).size == 0, "train/test sample overlap"
assert len(train_indices) + len(test_indices) == N_SAMPLES, "split does not cover all samples"

# Create Subsets for training and testing
train_dataset = Subset(all_data, train_indices)
test_dataset = Subset(all_data, test_indices)

# Expected sizes follow from the fold size instead of being hard-coded for k=0.
expected_test = current_group_fold_size * SAMPLES_PER_GROUP
expected_train = N_SAMPLES - expected_test

print("--- K-Fold Split Result (Group-Aware) ---")
print(f"Total Groups: {N_GROUPS}, Samples/Group: {SAMPLES_PER_GROUP}")
print(f"Fold {k} Groups: {len(G_test_indices)} (Fold Size: {sz})")
print(f"Train Dataset Samples: {len(train_dataset)} (Expected {expected_train})")
print(f"Test Dataset Samples: {len(test_dataset)} (Expected {expected_test})")

In [None]:
# # REMOVE ALL DATA AUGMENTATIONS

# # --- Original Data Parameters (DO NOT CHANGE) ---
# SAMPLES_PER_GROUP = 10
# N_SAMPLES = 660 # Total dataset size
# N_GROUPS = 66   # Total number of unique patients/original images
# K_FOLDS = 5

# # Instantiate the Dataset with all pre-processed feature tensors (660 samples)
# # all_data = DataWrapper(FAFs, masks, all_residuals) # Keep your original DataWrapper setup

# # --- Identify and Map to Original 66 Samples ---
# # The original 66 images are the first sample in each group of 10.
# # Indices: 0, 10, 20, ..., 650
# original_sample_indices = np.arange(0, N_SAMPLES, SAMPLES_PER_GROUP)
# # len(original_sample_indices) is 66

# # --- K-Fold Group Setup (Preserving the Original Logic) ---
# k = 4 # The current fold index (Example: k=2)

# # Calculate fold sizes based on groups (66 groups in 5 folds: 4 folds of 13 groups, 1 fold of 14 groups)
# group_fold_size_base = N_GROUPS // K_FOLDS # 13
# group_fold_remainder = N_GROUPS % K_FOLDS # 1

# # Calculate group indices for the current fold (k=0 gets the larger remainder group)
# G_all_indices = np.arange(N_GROUPS) # [0, 1, ..., 65]

# # Determine the size of the current fold (13 or 14 groups)
# current_group_fold_size = group_fold_size_base + (1 if k < group_fold_remainder else 0)

# # Calculate start and end index for the test group indices
# G_start_idx = k * group_fold_size_base + min(k, group_fold_remainder)
# G_end_idx = G_start_idx + current_group_fold_size

# G_test_indices = G_all_indices[G_start_idx:G_end_idx]
# G_train_indices = np.concatenate([G_all_indices[:G_start_idx], G_all_indices[G_end_idx:]])

# # --- Map Group Indices to Sample Indices (THE MODIFICATION) ---

# def map_group_to_repeated_original_sample_indices(group_indices, samples_per_group):
#     """
#     Maps group indices (0-65) to sample indices (0-659) such that:
#     1. It identifies the ORIGINAL sample index (g * SAMPLES_PER_GROUP).
#     2. It creates 'samples_per_group' copies of that SINGLE index.
    
#     This ensures that each group (10 samples) in the final dataset
#     is a repetition of the single original image corresponding to that group.
#     """
    
#     # Identify the ORIGINAL sample index for each group
#     original_indices = group_indices * samples_per_group
    
#     # Repeat each original index 'samples_per_group' times
#     # e.g., [0, 10] -> [0,0,0,0,0,0,0,0,0,0, 10,10,10,10,10,10,10,10,10,10]
#     repeated_indices = np.repeat(original_indices, samples_per_group)
    
#     return repeated_indices

# # Create the new training and testing indices
# # These indices *now refer to the original samples, repeated 10 times*
# train_indices_orig_repeated = map_group_to_repeated_original_sample_indices(G_train_indices, SAMPLES_PER_GROUP)
# test_indices_orig_repeated = map_group_to_repeated_original_sample_indices(G_test_indices, SAMPLES_PER_GROUP)

# sz_test = len(test_indices_orig_repeated) # current_group_fold_size * SAMPLES_PER_GROUP (140 for k=0, 130 for k=1..4)
# sz_train = len(train_indices_orig_repeated) # Will still be 520

# # Create Subsets for training and testing using the new indices
# # NOTE: The Subset needs to be applied to the original 'all_data' which holds the 660 features
# train_dataset = Subset(all_data, train_indices_orig_repeated)
# test_dataset = Subset(all_data, test_indices_orig_repeated)

# # --- K-Fold Split Result (Original-Sample-Aware) ---
# print("\n--- K-Fold Split Result (Original-Sample-Aware) ---")
# print(f"Current Fold k: {k}")
# print(f"Test Group Indices: {G_test_indices}")
# print(f"Test Samples: {sz_test}")
# print(f"Train Dataset Samples: {sz_train} (Expected {N_SAMPLES - sz_test})") # Corrected Expected value
# print(f"Test Dataset Samples: {sz_test} (Expected {current_group_fold_size * SAMPLES_PER_GROUP})") # Corrected Expected value

In [None]:
# --- DATA ASSEMBLY FOR MANUAL ITERATION ---

# We use a DataLoader only to efficiently stack the training Subset items into one tensor.
N_TRAIN_SAMPLES = len(train_dataset)
if N_TRAIN_SAMPLES == 0:
    # DataLoader rejects batch_size=0; fail with a clearer message instead.
    raise ValueError("train_dataset is empty; check the K-fold split.")

# A single batch of size N_TRAIN_SAMPLES collates every training Subset item
# into one tensor (the training loop then slices it manually).
temp_loader = DataLoader(
    train_dataset,
    batch_size=N_TRAIN_SAMPLES,
    shuffle=False, # Must be False so rows follow train_indices order
    num_workers=0 
)

# Fetch the entire training dataset (the loader's only batch) into a single CPU tensor
# Shape: [N_TRAIN_SAMPLES, C, T, H, W]
full_clean_data_tensor_cpu = next(iter(temp_loader))

print(f"\nASSEMBLED TENSOR: full_clean_data_tensor_cpu size: {full_clean_data_tensor_cpu.size()}")

# --- DataLoader Setup (Only for the Test/Validation Set) ---

test_loader = DataLoader(
    test_dataset,
    batch_size=args.N, # Use args.N (batch_size for GPU)
    shuffle=False,
    num_workers=0,
    # Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
    pin_memory=args.device.type == 'cuda'
)

# --- Visualization example (train/test side-by-side) ---

visualize_sample(train_dataset, test_dataset, sample_idx=20, dataset_name='test')
# visualize_sample(train_dataset, test_dataset, sample_idx=500, dataset_name='train')

In [None]:
# for idx in range(52):
#     visualize_sample(sample_idx=idx*10, dataset_name='train')
    
# for idx in range(14):
#     visualize_sample(sample_idx=idx*10, dataset_name='test')

# Run the full visual split comparison: train vs. test subsets, 'mask' channel.
# compare_split_masks comes from data_utils; its exact plot layout is defined there.
compare_split_masks(train_dataset, test_dataset, channel_type='mask')

In [None]:
# Same train vs. test comparison for the 'residual' channel at time_step=1.
# NOTE(review): whether time_step is 0- or 1-based is defined in data_utils — confirm.
compare_split_masks(train_dataset, test_dataset, channel_type='residual', time_step=1)

## Sliding Window Attention U-Net (SWAU-Net)

In [None]:
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# args.num_attn_layers = 2

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the full SWAU_Net model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # The model is SWAU_Net, which owns E1, CFB_enc, CFB_dec, SWA, P, and D1.
# model = SWAU_Net(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(model.E1)

# # CORRECTED: Calculate parameters for both CFB modules separately
# cfb_enc_params = count_parameters(model.CFB_enc) 
# cfb_dec_params = count_parameters(model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# swa_params = count_parameters(model.SWA) 
# p_params = count_parameters(model.P)
# d1_params = count_parameters(model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"], # NEW LINE: Aggregate CFB
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator**", f"**{swa_params:,}**"], 
#     ["DynNet (P)", "Temporal Feature Predictor (M_t → Evolved_t)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full SWAU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# ## Load Pretrained Model

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "SWAU_synthetic_pretrain_epoch50_20251119_193158.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

In [None]:
# # --- INITIALIZATION FOR LOGGING ---
# # Set to higher epochs to clearly observe overfitting
# args.num_epochs = 50

# # # --- OVERFITTING TEST ACTIVATED ---
# # # Create an overfit dataset by replicating the first sample (index 0) 100 times
# # overfit_data = torch.cat([full_clean_data_tensor_cpu[0].unsqueeze(0)] * 100, dim=0)
# # current_train_data = overfit_data 
# # # ----------------------------------

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # History lists
# all_iteration_losses = [] 
# epoch_iteration_counts = [] # Stores number of iterations in each epoch

# # --- RESIDUAL HISTORY LISTS (Scores and SDs) ---
# train_dice_t1, train_dice_t2, train_dice_t3 = [], [], []
# test_dice_t1, test_dice_t2, test_dice_t3 = [], [], []

# # Standard Deviation (SD) Lists for Residuals
# train_sd_t1, train_sd_t2, train_sd_t3 = [], [], []
# test_sd_t1, test_sd_t2, test_sd_t3 = [], [], []

# # --- MASK HISTORY LISTS (Scores and SDs) ---
# train_mask_t1, train_mask_t2, train_mask_t3 = [], [], []
# test_mask_t1, test_mask_t2, test_mask_t3 = [], [], []

# # Standard Deviation (SD) Lists for Masks
# train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3 = [], [], []
# test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3 = [], [], []

# # --- TRAINING SETUP (Reconfirming) ---
# # Loss functions (re-instantiated if necessary, matching previous definitions)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_dice = dice_loss
# loss_fn_gdl = GDLoss(alpha=1, beta=1)
# loss_fn_l1 = nn.L1Loss()
# loss_fn_l2 = nn.MSELoss()
# ACCUMULATION_STEPS = 8
# soft_dice = False

# lr = 1E-4
# # --- OPTIMIZER DEFINITION ---
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1E-5)

# # --- NEW: LEARNING RATE SCHEDULER INITIALIZATION ---
# # Use ReduceLROnPlateau to decrease LR when training loss plateaus.
# scheduler = ReduceLROnPlateau(
#     optimizer, 
#     mode='min',         # Monitor minimum loss
#     factor=0.5,         # Reduce LR by 50%
#     patience=15,         # Wait 15 epochs for improvement before reducing
#     verbose=True, 
#     min_lr=1e-6         # Stop reducing LR at 1e-6
# )
# # ---------------------------------------------------

# # -----------------
# # FREEZE BATCH NORM
# freeze_batch_norm(model)
# # ----------------

# print(f"\nStarting SWAU_Net training on device: {args.device} for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
# #     # --- 1. Training Step ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, model, optimizer, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, lambda_recon=0.5, use_augmentation=True
#     )
    
#     # Store iteration loss and count for plot_log_loss
#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
    
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- NEW: SCHEDULER STEP ---
#     # Tell the scheduler to check the current loss and adjust the LR if necessary
#     scheduler.step(mean_epoch_loss)
    
#     # --- 3. Logging and Checkpoint ---
    
#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary ---")
#     print(f"Mean Loss: {mean_epoch_loss:.6f}")

#     # --- 4. Per-Epoch Visualizations (MOVED INSIDE LOOP) ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # C. Plot a sample prediction clip (e.g., test sample index 20)
#     f_display_frames(current_train_data, model, args, sample_idx=20, T_total=4)
    
#      # EVALUATION
#     use_median = True
#     if use_median==True:
#         print('Using DICE Median.')
#     else:
#         print('Using DICE Aggregate Mean.')
        
#     if soft_dice==True:
#         print('Using Soft DICE.')
#     else:
#         print('Using hard DICE.')
        
#     # --- Evaluation Step ---
    
#     # 1. Unpack Test Results: Returns ((Res Scores, Res SDs), (Msk Scores, Msk SDs))
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#     f_eval_pred_dice_test_set(test_loader, model, args, soft_dice=soft_dice, use_median=use_median)

#     # 2. Unpack Train Results
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#     f_eval_pred_dice_train_set(current_train_data, model, args, args.batch_size, soft_dice=soft_dice, use_median=use_median)
    
#     # --- 1. Residual Scores and SDs Accumulation ---

#     # Accumulate Scores
#     train_dice_t1.append(res_train_scores[0]); train_dice_t2.append(res_train_scores[1]); train_dice_t3.append(res_train_scores[2])
#     test_dice_t1.append(res_test_scores[0]); test_dice_t2.append(res_test_scores[1]); test_dice_t3.append(res_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     train_sd_t1.append(res_train_sds[0]); train_sd_t2.append(res_train_sds[1]); train_sd_t3.append(res_train_sds[2])
#     test_sd_t1.append(res_test_sds[0]); test_sd_t2.append(res_test_sds[1]); test_sd_t3.append(res_test_sds[2])

#     # --- 2. Mask Scores and SDs Accumulation ---

#     # Accumulate Scores
#     train_mask_t1.append(msk_train_scores[0]); train_mask_t2.append(msk_train_scores[1]); train_mask_t3.append(msk_train_scores[2])
#     test_mask_t1.append(msk_test_scores[0]); test_mask_t2.append(msk_test_scores[1]); test_mask_t3.append(msk_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     train_mask_sd_t1.append(msk_train_sds[0]); train_mask_sd_t2.append(msk_train_sds[1]); train_mask_sd_t3.append(msk_train_sds[2])
#     test_mask_sd_t1.append(msk_test_sds[0]); test_mask_sd_t2.append(msk_test_sds[1]); test_mask_sd_t3.append(msk_test_sds[2])


#     # 1. Plot Residual History
#     plot_train_test_dice_history(
#         train_dice_t1, train_dice_t2, train_dice_t3,
#         test_dice_t1, test_dice_t2, test_dice_t3,
#         train_sd_t1, train_sd_t2, train_sd_t3,       # PASSING SDs HERE
#         test_sd_t1, test_sd_t2, test_sd_t3,         # PASSING SDs HERE
#         plot_title='Residual Mask Dice Score History (T=1, T=2, T=3)'
#     )

#     # 2. Plot Mask History (NEW PLOT)
#     plot_train_test_dice_history(
#         train_mask_t1, train_mask_t2, train_mask_t3,
#         test_mask_t1, test_mask_t2, test_mask_t3,
#         train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3, # PASSING SDs HERE
#         test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3,   # PASSING SDs HERE
#         plot_title='Full Mask Dice Score History (T=1, T=2, T=3)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//SWAU_Net')
# # ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//No Pretraining')
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//No Augmentation')
# FINAL_EPOCH = args.num_epochs

# # saved_path = save_model_weights(
# #     model=model, 
# #     final_epoch=FINAL_EPOCH, 
# #     save_dir=ckpt_save_dir,
# #     model_name = "SWAU_without pretraining" + '_k_fold_' + str(k) + '_'
# # )

# saved_path = save_final_experiment_data(
#     model=model, 
#     final_epoch=FINAL_EPOCH, 
#     base_save_dir=ckpt_save_dir, 
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual
#     train_dice_t1=train_dice_t1, train_dice_t2=train_dice_t2, train_dice_t3=train_dice_t3,
#     test_dice_t1=test_dice_t1, test_dice_t2=test_dice_t2, test_dice_t3=test_dice_t3,
#     train_sd_t1=train_sd_t1, train_sd_t2=train_sd_t2, train_sd_t3=train_sd_t3,
#     test_sd_t1=test_sd_t1, test_sd_t2=test_sd_t2, test_sd_t3=test_sd_t3,
    
#     # Mask
#     train_mask_t1=train_mask_t1, train_mask_t2=train_mask_t2, train_mask_t3=train_mask_t3,
#     test_mask_t1=test_mask_t1, test_mask_t2=test_mask_t2, test_mask_t3=test_mask_t3,
#     train_mask_sd_t1=train_mask_sd_t1, train_mask_sd_t2=train_mask_sd_t2, train_mask_sd_t3=train_mask_sd_t3,
#     test_mask_sd_t1=test_mask_sd_t1, test_mask_sd_t2=test_mask_sd_t2, test_mask_sd_t3=test_mask_sd_t3,

#     # Set model name prefix
#     model_name_prefix="SWAU_WITHOUT_augmentation"
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

In [None]:
# if torch.cuda.is_available():
#     torch.cuda.empty_cache() # Clears unused cached memory
#     # Also clear any gradients potentially left from the interrupted batch
#     optimizer.zero_grad()

In [None]:
# # --- 4. Final Visualizations ---
# print("\n--- Generating Final Training Visualizations ---")

# # A. Plot Loss History
# plot_log_loss(all_iteration_losses, epoch_iteration_counts)

# # 1. Plot Residual History
# plot_train_test_dice_history(
#     train_dice_t1, train_dice_t2, train_dice_t3,
#     test_dice_t1, test_dice_t2, test_dice_t3,
#     train_sd_t1, train_sd_t2, train_sd_t3,       # PASSING SDs HERE
#     test_sd_t1, test_sd_t2, test_sd_t3,         # PASSING SDs HERE
#     plot_title='Residual Mask Dice Score History (T=1, T=2, T=3)'
# )

# # 2. Plot Mask History (NEW PLOT)
# plot_train_test_dice_history(
#     train_mask_t1, train_mask_t2, train_mask_t3,
#     test_mask_t1, test_mask_t2, test_mask_t3,
#     train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3, # PASSING SDs HERE
#     test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3,   # PASSING SDs HERE
#     plot_title='Full Mask Dice Score History (T=1, T=2, T=3)'
# )

# # C. Plot a sample prediction clip (e.g., test sample index 20)
# f_display_frames(current_train_data, model, args, sample_idx=20, T_total=4)

In [None]:
# for j in range(len(current_train_data)):
#     f_display_frames(current_train_data, model, args, sample_idx=j, T_total=4)

In [None]:
# # --- Setup for Plotting ---

# ## Plot with soft DICE (Residual and Mask)
# soft_dice = True
# metric_type_str = "Soft Dice"

# # 1. Evaluate: Unpack the score arrays and discard the Mean/SD tuples
# (res_scores_train, msk_scores_train), _ = f_get_individual_dice(
#     train_dataset, model, args, is_train_set=True, soft_dice=soft_dice
# )
# (res_scores_test, msk_scores_test), _ = f_get_individual_dice(
#     test_dataset, model, args, is_train_set=False, soft_dice=soft_dice
# )

# # 2. Plot Residuals (Soft) - (Plotting logic remains correct)
# f_plot_individual_dice(res_scores_train, res_scores_test, metric_type_str, channel_name='Residual Mask')

# # 3. Plot Masks (Soft)
# f_plot_individual_dice(msk_scores_train, msk_scores_test, metric_type_str, channel_name='Full Mask')


# # --- Horizontal Rule to Separate Soft/Hard Plots ---
# print("\n" + "="*50 + "\n")

# ## Plot with hard DICE (Residual and Mask)
# soft_dice = False
# metric_type_str = "Hard Dice"

# # 1. Evaluate: Unpack the score arrays and discard the Mean/SD tuples
# (res_scores_train, msk_scores_train), _ = f_get_individual_dice(
#     train_dataset, model, args, is_train_set=True, soft_dice=soft_dice
# )
# (res_scores_test, msk_scores_test), _ = f_get_individual_dice(
#     test_dataset, model, args, is_train_set=False, soft_dice=soft_dice
# )

# # 2. Plot Residuals (Hard)
# f_plot_individual_dice(res_scores_train, res_scores_test, metric_type_str, channel_name='Residual Mask')

# # 3. Plot Masks (Hard)
# f_plot_individual_dice(msk_scores_train, msk_scores_test, metric_type_str, channel_name='Full Mask')

In [None]:
# model.eval()

# is_train_set = 1

# # Create a DataLoader specifically for batch size 1 iteration
# data_loader = DataLoader(
#     train_dataset,
#     batch_size=1, # CRITICAL: Batch size of 1 for individual scores
#     shuffle=False,
#     num_workers=0
# )

# all_clip_dice_scores = []
# device = args.device
# T_pred = 3

# # Determine which metric tensor to use for evaluation
# if soft_dice:
#     desc_label = "Soft Dice (Individual)"
# else:
#     desc_label = "Hard Dice (Individual)"

# with torch.no_grad():
#     # tqdm description is based on whether we are processing the train tensor or test loader
#     if is_train_set:
#         iterator = tqdm(data_loader, desc=f"Evaluating Train Clips ({desc_label})")
#     else:
#         iterator = tqdm(data_loader, desc=f"Evaluating Test Clips ({desc_label})")

#     for data_batch in iterator:
#         # data_batch shape: [1, C, T, H, W]

#         predictions, targets, *discards = model(data_batch.to(device))
#         break
        
# predictions.size()

# targets.size()

# print(targets[0,0,2,:,:])

# print(predictions[0,0,2,:,:])

# plt.imshow(targets[0,2,2,:,:].detach().cpu().numpy())
# plt.show()
# plt.imshow(predictions[0,2,2,:,:].detach().cpu().numpy())
# plt.show()

In [None]:
##########################################################################################################################
##########################################################################################################################

##########################################################################################################################
##########################################################################################################################

## Train CFB Ablation

In [None]:
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# args.num_attn_layers = 2

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the full SWAU_Net model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # The model is SWAU_Net, which owns E1, CFB_enc, CFB_dec, SWA, P, and D1.
# model = SWAU_CFB_Ablation(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(model.E1)

# swa_params = count_parameters(model.SWA) 
# p_params = count_parameters(model.P)
# d1_params = count_parameters(model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator**", f"**{swa_params:,}**"], 
#     ["DynNet (P)", "Temporal Feature Predictor (M_t → Evolved_t)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full SWAU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# ## Load Pretrained Model

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "CFB_Ablation_synthetic_pretrain_epoch50_20251120_145522.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

In [None]:
# # --- INITIALIZATION FOR LOGGING ---
# # Training cell for the CFB ablation (disabled: uncomment to run). It relies on `model`
# # (the SWAU_CFB_Ablation instance created above), `args`, `full_clean_data_tensor_cpu`
# # and `test_loader` being defined by earlier cells.
# # Number of training epochs for this run.
# args.num_epochs = 50

# # # --- OVERFITTING TEST (disabled; uncomment to sanity-check that the model can overfit) ---
# # # Create an overfit dataset by replicating the first sample (index 0) 100 times
# # overfit_data = torch.cat([full_clean_data_tensor_cpu[0].unsqueeze(0)] * 100, dim=0)
# # current_train_data = overfit_data 
# # # ----------------------------------

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # History lists
# all_iteration_losses = [] 
# epoch_iteration_counts = [] # Stores number of iterations in each epoch

# # --- RESIDUAL HISTORY LISTS (Scores and SDs) ---
# train_dice_t1, train_dice_t2, train_dice_t3 = [], [], []
# test_dice_t1, test_dice_t2, test_dice_t3 = [], [], []

# # Standard Deviation (SD) Lists for Residuals
# train_sd_t1, train_sd_t2, train_sd_t3 = [], [], []
# test_sd_t1, test_sd_t2, test_sd_t3 = [], [], []

# # --- MASK HISTORY LISTS (Scores and SDs) ---
# train_mask_t1, train_mask_t2, train_mask_t3 = [], [], []
# test_mask_t1, test_mask_t2, test_mask_t3 = [], [], []

# # Standard Deviation (SD) Lists for Masks
# train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3 = [], [], []
# test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3 = [], [], []

# # --- TRAINING SETUP (Reconfirming) ---
# # Loss functions (re-instantiated if necessary, matching previous definitions)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_dice = dice_loss
# loss_fn_gdl = GDLoss(alpha=1, beta=1)
# loss_fn_l1 = nn.L1Loss()
# loss_fn_l2 = nn.MSELoss()
# ACCUMULATION_STEPS = 8
# soft_dice = False

# lr = 1E-4
# # --- OPTIMIZER DEFINITION ---
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1E-5)

# # --- LEARNING RATE SCHEDULER INITIALIZATION ---
# # Use ReduceLROnPlateau to decrease LR when the mean training loss plateaus.
# # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases —
# # confirm against the installed version.
# scheduler = ReduceLROnPlateau(
#     optimizer, 
#     mode='min',         # Monitor minimum loss
#     factor=0.5,         # Reduce LR by 50%
#     patience=15,        # Wait 15 epochs without improvement before reducing
#     verbose=True, 
#     min_lr=1e-6         # Stop reducing LR at 1e-6
# )
# # ---------------------------------------------------

# # -----------------
# # FREEZE BATCH NORM
# freeze_batch_norm(model)
# # ----------------

# print(f"\nStarting SWAU_Net training on device: {args.device} for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
# #     # --- 1. Training Step ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, model, optimizer, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, lambda_recon=0.5, use_augmentation=True
#     )
    
#     # Store iteration loss and count for plot_log_loss
#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
    
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- SCHEDULER STEP ---
#     # Tell the scheduler to check the current loss and adjust the LR if necessary
#     scheduler.step(mean_epoch_loss)
    
#     # --- 3. Logging and Checkpoint ---
    
#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary ---")
#     print(f"Mean Loss: {mean_epoch_loss:.6f}")

#     # --- 4. Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # C. Plot a sample prediction clip (training sample index 20 from current_train_data)
#     f_display_frames(current_train_data, model, args, sample_idx=20, T_total=4)
    
#      # EVALUATION
#     use_median = True
#     if use_median==True:
#         print('Using DICE Median.')
#     else:
#         print('Using DICE Aggregate Mean.')
        
#     if soft_dice==True:
#         print('Using Soft DICE.')
#     else:
#         print('Using hard DICE.')
        
#     # --- Evaluation Step ---
    
#     # 1. Unpack Test Results: Returns ((Res Scores, Res SDs), (Msk Scores, Msk SDs))
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#     f_eval_pred_dice_test_set(test_loader, model, args, soft_dice=soft_dice, use_median=use_median)

#     # 2. Unpack Train Results
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#     f_eval_pred_dice_train_set(current_train_data, model, args, args.batch_size, soft_dice=soft_dice, use_median=use_median)
    
#     # --- 1. Residual Scores and SDs Accumulation ---

#     # Accumulate Scores
#     train_dice_t1.append(res_train_scores[0]); train_dice_t2.append(res_train_scores[1]); train_dice_t3.append(res_train_scores[2])
#     test_dice_t1.append(res_test_scores[0]); test_dice_t2.append(res_test_scores[1]); test_dice_t3.append(res_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     train_sd_t1.append(res_train_sds[0]); train_sd_t2.append(res_train_sds[1]); train_sd_t3.append(res_train_sds[2])
#     test_sd_t1.append(res_test_sds[0]); test_sd_t2.append(res_test_sds[1]); test_sd_t3.append(res_test_sds[2])

#     # --- 2. Mask Scores and SDs Accumulation ---

#     # Accumulate Scores
#     train_mask_t1.append(msk_train_scores[0]); train_mask_t2.append(msk_train_scores[1]); train_mask_t3.append(msk_train_scores[2])
#     test_mask_t1.append(msk_test_scores[0]); test_mask_t2.append(msk_test_scores[1]); test_mask_t3.append(msk_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     # Train mask SDs must come from the train evaluation (msk_train_sds), not the test one.
#     train_mask_sd_t1.append(msk_train_sds[0]); train_mask_sd_t2.append(msk_train_sds[1]); train_mask_sd_t3.append(msk_train_sds[2])
#     test_mask_sd_t1.append(msk_test_sds[0]); test_mask_sd_t2.append(msk_test_sds[1]); test_mask_sd_t3.append(msk_test_sds[2])


#     # 1. Plot Residual History
#     plot_train_test_dice_history(
#         train_dice_t1, train_dice_t2, train_dice_t3,
#         test_dice_t1, test_dice_t2, test_dice_t3,
#         train_sd_t1, train_sd_t2, train_sd_t3,       # PASSING SDs HERE
#         test_sd_t1, test_sd_t2, test_sd_t3,         # PASSING SDs HERE
#         plot_title='Residual Mask Dice Score History (T=1, T=2, T=3)'
#     )

#     # 2. Plot Mask History
#     plot_train_test_dice_history(
#         train_mask_t1, train_mask_t2, train_mask_t3,
#         test_mask_t1, test_mask_t2, test_mask_t3,
#         train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3, # PASSING SDs HERE
#         test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3,   # PASSING SDs HERE
#         plot_title='Full Mask Dice Score History (T=1, T=2, T=3)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # Saves the CFB-ablation weights and all per-epoch histories from the training cell above.
# # `k` (the k-fold index) must already be defined by an earlier cell.
# # ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//SWAU_Net')
# # ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//No Pretraining')
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//CFB Ablation')
# FINAL_EPOCH = args.num_epochs

# # saved_path = save_model_weights(
# #     model=model, 
# #     final_epoch=FINAL_EPOCH, 
# #     save_dir=ckpt_save_dir,
# #     model_name = "SWAU_without pretraining" + '_k_fold_' + str(k) + '_'
# # )

# saved_path = save_final_experiment_data(
#     model=model, 
#     final_epoch=FINAL_EPOCH, 
#     base_save_dir=ckpt_save_dir, 
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual
#     train_dice_t1=train_dice_t1, train_dice_t2=train_dice_t2, train_dice_t3=train_dice_t3,
#     test_dice_t1=test_dice_t1, test_dice_t2=test_dice_t2, test_dice_t3=test_dice_t3,
#     train_sd_t1=train_sd_t1, train_sd_t2=train_sd_t2, train_sd_t3=train_sd_t3,
#     test_sd_t1=test_sd_t1, test_sd_t2=test_sd_t2, test_sd_t3=test_sd_t3,
    
#     # Mask
#     train_mask_t1=train_mask_t1, train_mask_t2=train_mask_t2, train_mask_t3=train_mask_t3,
#     test_mask_t1=test_mask_t1, test_mask_t2=test_mask_t2, test_mask_t3=test_mask_t3,
#     train_mask_sd_t1=train_mask_sd_t1, train_mask_sd_t2=train_mask_sd_t2, train_mask_sd_t3=train_mask_sd_t3,
#     test_mask_sd_t1=test_mask_sd_t1, test_mask_sd_t2=test_mask_sd_t2, test_mask_sd_t3=test_mask_sd_t3,

#     # Set model name prefix
#     # NOTE(review): this prefix says "WITHOUT_augmentation", but the training cell above passes
#     # use_augmentation=True and this cell saves into the "CFB Ablation" directory (other save
#     # cells use "SWAU_DYNNET_ABLATION" / "Axial_UNet"). Confirm the prefix is intended before
#     # re-running; changing it may break lookups of folds already saved under this name.
#     model_name_prefix="SWAU_WITHOUT_augmentation"
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Train DynNet Ablation

In [None]:
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# # Builds the DynNet-ablation model and prints a per-component trainable-parameter table.
# BASE_CHANNELS = 16 # Base channel width
# args.d_attn1 = 128 # Feed-forward dim for L3
# args.d_attn2 = 256 # Feed-forward dim for L4/L5

# args.num_attn_layers = 2 # Number of SWA layers

# # Function to count trainable parameters.
# # NOTE(review): this redefinition shadows `count_parameters` imported from `models` in the
# # setup cell — confirm the two are identical and keep only one.
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the SWAU_DynNet_Ablation model with the updated configuration...")

# # Instantiate the ablation model and move it to the device.
# # CFB_enc and CFB_dec are separate submodules and are counted separately below.
# # The class instantiated here is SWAU_DynNet_Ablation.
# swau_dynnet_ablation_model = SWAU_DynNet_Ablation(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(swau_dynnet_ablation_model.E1)
# cfb_enc_params = count_parameters(swau_dynnet_ablation_model.CFB_enc)
# cfb_dec_params = count_parameters(swau_dynnet_ablation_model.CFB_dec)
# swa_params = count_parameters(swau_dynnet_ablation_model.SWA)

# # DynNet is removed, so its parameter count is 0.
# p_params = 0 

# d1_params = count_parameters(swau_dynnet_ablation_model.D1)

# # Ensure all components are summed up for the total count (including CFB blocks)
# total_params = e1_params + cfb_enc_params + cfb_dec_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["CFB_enc", "Pre-Aggregation Channel Refinement", f"{cfb_enc_params:,}"],
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Estimator**", f"**{swa_params:,}**"],
#     ["DynNet (P)", "Temporal Evolution Module", f"**{p_params:,} (Removed)**"], # P_params = 0
#     ["CFB_dec", "Post-Aggregation Channel Refinement", f"{cfb_dec_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**SWAU_DynNet_Ablation Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_DynNet_Ablation Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# ## Load Pretrained Model
# # Loads the synthetic-pretraining checkpoint for the DynNet ablation into
# # `swau_dynnet_ablation_model` (created above). load_model returns (model, epoch).
# # NOTE(review): hard-coded absolute user path — consider a configurable base directory.

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "DynNet_Ablation_synthetic_pretrain_epoch50_20251123_123558.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=swau_dynnet_ablation_model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

In [None]:
# # --- INITIALIZATION FOR LOGGING ---
# # Training cell for the DynNet ablation (disabled: uncomment to run). It relies on
# # `swau_dynnet_ablation_model`, `args`, `full_clean_data_tensor_cpu` and `test_loader`
# # being defined by earlier cells.
# # Number of training epochs for this run.
# args.num_epochs = 50

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # History lists
# all_iteration_losses = [] 
# epoch_iteration_counts = [] # Stores number of iterations in each epoch

# # --- RESIDUAL HISTORY LISTS (Scores and SDs) - 'swau_' NAMING ---
# # The 'swau_' prefix keeps these histories separate from the un-prefixed lists used by the
# # other experiment cells; they are passed to save_final_experiment_data in the save cell below.
# swau_train_t1, swau_train_t2, swau_train_t3 = [], [], []
# swau_test_t1, swau_test_t2, swau_test_t3 = [], [], []
# swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3 = [], [], []
# swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3 = [], [], []

# # --- MASK HISTORY LISTS (Scores and SDs) - 'swau_' NAMING ---
# swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3 = [], [], []
# swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3 = [], [], []
# swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3 = [], [], []
# swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3 = [], [], []


# # --- TRAINING SETUP (Reconfirming) ---
# # Loss functions (re-instantiated if necessary, matching previous definitions)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_dice = dice_loss
# loss_fn_gdl = GDLoss(alpha=1, beta=1)
# loss_fn_l1 = nn.L1Loss()
# loss_fn_l2 = nn.MSELoss()
# ACCUMULATION_STEPS = 8
# soft_dice = False

# lr = 1E-4
# # --- OPTIMIZER DEFINITION (Using the working model instance: swau_dynnet_ablation_model) ---
# optimizer = torch.optim.Adam(swau_dynnet_ablation_model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1E-5)

# # --- LEARNING RATE SCHEDULER INITIALIZATION ---
# # Halves the LR when the mean training loss has not improved for 15 epochs (floor 1e-6).
# # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases —
# # confirm against the installed version.
# scheduler = ReduceLROnPlateau(
#     optimizer, 
#     mode='min',         # Monitor minimum loss
#     factor=0.5,         
#     patience=15,        
#     verbose=True, 
#     min_lr=1e-6         
# )
# # ---------------------------------------------------

# # -----------------
# # FREEZE BATCH NORM (Using the working model instance: swau_dynnet_ablation_model)
# freeze_batch_norm(swau_dynnet_ablation_model)
# # ----------------

# print(f"\nStarting SWAU_DynNet_Ablation training on device: {args.device} for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
# #     # --- 1. Training Step ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, swau_dynnet_ablation_model, optimizer, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, lambda_recon=0.5, use_augmentation=True
#     )
    
#     # Store iteration loss and count for plot_log_loss
#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
    
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- SCHEDULER STEP (driven by the mean training loss of this epoch) ---
#     scheduler.step(mean_epoch_loss)
    
#     # --- 3. Logging and Checkpoint ---
    
#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary ---")
#     print(f"Mean Loss: {mean_epoch_loss:.6f}")

#     # --- 4. Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # C. Plot a sample prediction clip (training sample index 20 from current_train_data)
#     f_display_frames(current_train_data, swau_dynnet_ablation_model, args, sample_idx=20, T_total=4)
    
#      # EVALUATION
#     use_median = True
#     if use_median==True:
#         print('Using DICE Median.')
#     else:
#         print('Using DICE Aggregate Mean.')
        
#     if soft_dice==True:
#         print('Using Soft DICE.')
#     else:
#         print('Using hard DICE.')
        
#     # --- Evaluation Step (Using the working model instance: swau_dynnet_ablation_model) ---
    
#     # 1. Unpack Test Results: Returns ((Res Scores, Res SDs), (Msk Scores, Msk SDs))
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#     f_eval_pred_dice_test_set(test_loader, swau_dynnet_ablation_model, args, soft_dice=soft_dice, use_median=use_median)

#     # 2. Unpack Train Results
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#     f_eval_pred_dice_train_set(current_train_data, swau_dynnet_ablation_model, args, args.batch_size, soft_dice=soft_dice, use_median=use_median)
    
#     # --- 1. Residual Scores and SDs Accumulation (Using swau_ prefixes) ---

#     # Accumulate Scores
#     swau_train_t1.append(res_train_scores[0]); swau_train_t2.append(res_train_scores[1]); swau_train_t3.append(res_train_scores[2])
#     swau_test_t1.append(res_test_scores[0]); swau_test_t2.append(res_test_scores[1]); swau_test_t3.append(res_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     swau_train_sd_t1.append(res_train_sds[0]); swau_train_sd_t2.append(res_train_sds[1]); swau_train_sd_t3.append(res_train_sds[2])
#     swau_test_sd_t1.append(res_test_sds[0]); swau_test_sd_t2.append(res_test_sds[1]); swau_test_sd_t3.append(res_test_sds[2])

#     # --- 2. Mask Scores and SDs Accumulation (Using swau_ prefixes) ---

#     # Accumulate Scores
#     swau_train_mask_t1.append(msk_train_scores[0]); swau_train_mask_t2.append(msk_train_scores[1]); swau_train_mask_t3.append(msk_train_scores[2])
#     swau_test_mask_t1.append(msk_test_scores[0]); swau_test_mask_t2.append(msk_test_scores[1]); swau_test_mask_t3.append(msk_test_scores[2])

#     # Accumulate Standard Deviations (SDs) — train SDs from the train evaluation, test SDs from the test evaluation
#     swau_train_mask_sd_t1.append(msk_train_sds[0]); swau_train_mask_sd_t2.append(msk_train_sds[1]); swau_train_mask_sd_t3.append(msk_train_sds[2])
#     swau_test_mask_sd_t1.append(msk_test_sds[0]); swau_test_mask_sd_t2.append(msk_test_sds[1]); swau_test_mask_sd_t3.append(msk_test_sds[2])


#     # 1. Plot Residual History
#     plot_train_test_dice_history(
#         swau_train_t1, swau_train_t2, swau_train_t3,
#         swau_test_t1, swau_test_t2, swau_test_t3,
#         swau_train_sd_t1, swau_train_sd_t2, swau_train_sd_t3,       
#         swau_test_sd_t1, swau_test_sd_t2, swau_test_sd_t3,         
#         plot_title='SWAU_DynNet_Ablation Residual Dice History (Median ± SD)'
#     )

#     # 2. Plot Mask History 
#     plot_train_test_dice_history(
#         swau_train_mask_t1, swau_train_mask_t2, swau_train_mask_t3,
#         swau_test_mask_t1, swau_test_mask_t2, swau_test_mask_t3,
#         swau_train_mask_sd_t1, swau_train_mask_sd_t2, swau_train_mask_sd_t3, 
#         swau_test_mask_sd_t1, swau_test_mask_sd_t2, swau_test_mask_sd_t3,   
#         plot_title='SWAU_DynNet_Ablation Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # --- DIRECTORY PATH (Change this for your execution environment) ---
# # Saves the DynNet-ablation weights and the 'swau_'-prefixed histories from the training cell
# # above, mapping them onto the generic keyword names save_final_experiment_data expects.
# # `k` (the k-fold index) must already be defined by an earlier cell.
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//DynNet_Ablation')
# FINAL_EPOCH = args.num_epochs

# # --- SAVE FINAL EXPERIMENT DATA ---

# # NOTE: We use the working model instance 'swau_dynnet_ablation_model' for saving.
# saved_path = save_final_experiment_data(
#     model=swau_dynnet_ablation_model, # Use the model instance from the training loop
#     final_epoch=FINAL_EPOCH, 
#     base_save_dir=ckpt_save_dir, 
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop (using 'swau_' prefixes) ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual
#     train_dice_t1=swau_train_t1, train_dice_t2=swau_train_t2, train_dice_t3=swau_train_t3,
#     test_dice_t1=swau_test_t1, test_dice_t2=swau_test_t2, test_dice_t3=swau_test_t3,
#     train_sd_t1=swau_train_sd_t1, train_sd_t2=swau_train_sd_t2, train_sd_t3=swau_train_sd_t3,
#     test_sd_t1=swau_test_sd_t1, test_sd_t2=swau_test_sd_t2, test_sd_t3=swau_test_sd_t3,
    
#     # Mask
#     train_mask_t1=swau_train_mask_t1, train_mask_t2=swau_train_mask_t2, train_mask_t3=swau_train_mask_t3,
#     test_mask_t1=swau_test_mask_t1, test_mask_t2=swau_test_mask_t2, test_mask_t3=swau_test_mask_t3,
#     train_mask_sd_t1=swau_train_mask_sd_t1, train_mask_sd_t2=swau_train_mask_sd_t2, train_mask_sd_t3=swau_train_mask_sd_t3,
#     test_mask_sd_t1=swau_test_mask_sd_t1, test_mask_sd_t2=swau_test_mask_sd_t2, test_mask_sd_t3=swau_test_mask_sd_t3,

#     # Set model name prefix to reflect the specific ablation
#     model_name_prefix="SWAU_DYNNET_ABLATION" 
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Train Axial U-Net

In [None]:
# ## Train Axial U-Net
# # Builds the AxialU_Net model and prints a per-component trainable-parameter table.

# args.num_attn_layers = 2

# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# # Function to count trainable parameters.
# # NOTE(review): this redefinition shadows `count_parameters` imported from `models` in the
# # setup cell — confirm the two are identical and keep only one.
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the full AxialU_Net model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # We use AxialU_Net for the instantiation but maintain SWAU_Net names in comments/output
# model = AxialU_Net(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(model.E1)

# # Count the two CFB modules separately, then combine them into one table row.
# cfb_enc_params = count_parameters(model.CFB_enc) 
# cfb_dec_params = count_parameters(model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# # AxialU_Net's aggregator is `Axial_Aggregator`; its count is stored in `swa_params` so the
# # table keeps the SWAU_Net row names (the "SWA" row below reports the axial aggregator).
# swa_params = count_parameters(model.Axial_Aggregator) 

# p_params = count_parameters(model.P)
# d1_params = count_parameters(model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (Time t)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"], # Aggregate of CFB_enc + CFB_dec
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator**", f"**{swa_params:,}**"], # Using SWA name, but value is Axial Aggregator
#     ["DynNet (P)", "Temporal Feature Predictor (M_t → Evolved_t)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor (Time t+1)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**Full SWAU_Net Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net Component Parameter Summary\n")
# # `tabulate` is the third-party function imported in the setup cell (from tabulate import tabulate).
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# ## Load Pretrained Model
# # Loads the pretraining checkpoint for AxialU_Net into `model` (the AxialU_Net instance
# # created above). load_model returns (model, epoch).
# # NOTE(review): hard-coded absolute user path — consider a configurable base directory.

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "AxialU_Net_pretrain_epoch50_20251118_223001.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

In [None]:
# # --- INITIALIZATION FOR LOGGING ---
# # Training cell for AxialU_Net (disabled: uncomment to run). It relies on `model`
# # (the AxialU_Net instance created above), `args`, `full_clean_data_tensor_cpu`
# # and `test_loader` being defined by earlier cells.
# # Number of training epochs for this run.
# args.num_epochs = 60

# # # --- OVERFITTING TEST (disabled; uncomment to sanity-check that the model can overfit) ---
# # # Create an overfit dataset by replicating the first sample (index 0) 100 times
# # overfit_data = torch.cat([full_clean_data_tensor_cpu[0].unsqueeze(0)] * 100, dim=0)
# # current_train_data = overfit_data 
# # # ----------------------------------

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # History lists
# all_iteration_losses = [] 
# epoch_iteration_counts = [] # Stores number of iterations in each epoch

# # --- RESIDUAL HISTORY LISTS (Scores and SDs) ---
# train_dice_t1, train_dice_t2, train_dice_t3 = [], [], []
# test_dice_t1, test_dice_t2, test_dice_t3 = [], [], []

# # Standard Deviation (SD) Lists for Residuals
# train_sd_t1, train_sd_t2, train_sd_t3 = [], [], []
# test_sd_t1, test_sd_t2, test_sd_t3 = [], [], []

# # --- MASK HISTORY LISTS (Scores and SDs) ---
# train_mask_t1, train_mask_t2, train_mask_t3 = [], [], []
# test_mask_t1, test_mask_t2, test_mask_t3 = [], [], []

# # Standard Deviation (SD) Lists for Masks
# train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3 = [], [], []
# test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3 = [], [], []

# # --- TRAINING SETUP (Reconfirming) ---
# # Loss functions (re-instantiated if necessary, matching previous definitions)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_dice = dice_loss
# loss_fn_gdl = GDLoss(alpha=1, beta=1)
# loss_fn_l1 = nn.L1Loss()
# loss_fn_l2 = nn.MSELoss()
# ACCUMULATION_STEPS = 8
# soft_dice = False

# lr = 1E-4
# # --- OPTIMIZER DEFINITION ---
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1E-5)

# # --- LEARNING RATE SCHEDULER INITIALIZATION ---
# # Use ReduceLROnPlateau to decrease LR when the mean training loss plateaus.
# # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases —
# # confirm against the installed version.
# scheduler = ReduceLROnPlateau(
#     optimizer, 
#     mode='min',         # Monitor minimum loss
#     factor=0.5,         # Reduce LR by 50%
#     patience=15,        # Wait 15 epochs without improvement before reducing
#     verbose=True, 
#     min_lr=1e-6         # Stop reducing LR at 1e-6
# )
# # ---------------------------------------------------

# # -----------------
# # FREEZE BATCH NORM
# freeze_batch_norm(model)
# # ----------------

# print(f"\nStarting AxialU_Net training on device: {args.device} for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
# #     # --- 1. Training Step ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, model, optimizer, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, lambda_recon=0.5, use_augmentation=True
#     )
    
#     # Store iteration loss and count for plot_log_loss
#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
    
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- SCHEDULER STEP ---
#     # Tell the scheduler to check the current loss and adjust the LR if necessary
#     scheduler.step(mean_epoch_loss)
    
#     # --- 3. Logging and Checkpoint ---
    
#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary ---")
#     print(f"Mean Loss: {mean_epoch_loss:.6f}")

#     # --- 4. Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # C. Plot a sample prediction clip (training sample index 20 from current_train_data)
#     f_display_frames(current_train_data, model, args, sample_idx=20, T_total=4)
    
#      # EVALUATION
#     use_median = True
#     if use_median==True:
#         print('Using DICE Median.')
#     else:
#         print('Using DICE Aggregate Mean.')
        
#     if soft_dice==True:
#         print('Using Soft DICE.')
#     else:
#         print('Using hard DICE.')
        
#     # --- Evaluation Step ---
    
#     # 1. Unpack Test Results: Returns ((Res Scores, Res SDs), (Msk Scores, Msk SDs))
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#     f_eval_pred_dice_test_set(test_loader, model, args, soft_dice=soft_dice, use_median=use_median)

#     # 2. Unpack Train Results
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#     f_eval_pred_dice_train_set(current_train_data, model, args, args.batch_size, soft_dice=soft_dice, use_median=use_median)
    
#     # --- 1. Residual Scores and SDs Accumulation ---

#     # Accumulate Scores
#     train_dice_t1.append(res_train_scores[0]); train_dice_t2.append(res_train_scores[1]); train_dice_t3.append(res_train_scores[2])
#     test_dice_t1.append(res_test_scores[0]); test_dice_t2.append(res_test_scores[1]); test_dice_t3.append(res_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     train_sd_t1.append(res_train_sds[0]); train_sd_t2.append(res_train_sds[1]); train_sd_t3.append(res_train_sds[2])
#     test_sd_t1.append(res_test_sds[0]); test_sd_t2.append(res_test_sds[1]); test_sd_t3.append(res_test_sds[2])

#     # --- 2. Mask Scores and SDs Accumulation ---

#     # Accumulate Scores
#     train_mask_t1.append(msk_train_scores[0]); train_mask_t2.append(msk_train_scores[1]); train_mask_t3.append(msk_train_scores[2])
#     test_mask_t1.append(msk_test_scores[0]); test_mask_t2.append(msk_test_scores[1]); test_mask_t3.append(msk_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     # Train mask SDs must come from the train evaluation (msk_train_sds), not the test one.
#     train_mask_sd_t1.append(msk_train_sds[0]); train_mask_sd_t2.append(msk_train_sds[1]); train_mask_sd_t3.append(msk_train_sds[2])
#     test_mask_sd_t1.append(msk_test_sds[0]); test_mask_sd_t2.append(msk_test_sds[1]); test_mask_sd_t3.append(msk_test_sds[2])


#     # 1. Plot Residual History
#     plot_train_test_dice_history(
#         train_dice_t1, train_dice_t2, train_dice_t3,
#         test_dice_t1, test_dice_t2, test_dice_t3,
#         train_sd_t1, train_sd_t2, train_sd_t3,       # PASSING SDs HERE
#         test_sd_t1, test_sd_t2, test_sd_t3,         # PASSING SDs HERE
#         plot_title='Residual Mask Dice Score History (T=1, T=2, T=3)'
#     )

#     # 2. Plot Mask History
#     plot_train_test_dice_history(
#         train_mask_t1, train_mask_t2, train_mask_t3,
#         test_mask_t1, test_mask_t2, test_mask_t3,
#         train_mask_sd_t1, train_mask_sd_t2, train_mask_sd_t3, # PASSING SDs HERE
#         test_mask_sd_t1, test_mask_sd_t2, test_mask_sd_t3,   # PASSING SDs HERE
#         plot_title='Full Mask Dice Score History (T=1, T=2, T=3)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# # Saves the AxialU_Net weights and all per-epoch histories from the training cell above.
# # `k` (the k-fold index) must already be defined by an earlier cell.
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Axial_UNet')
# FINAL_EPOCH = args.num_epochs

# saved_path = save_final_experiment_data(
#     model=model, 
#     final_epoch=FINAL_EPOCH, 
#     base_save_dir=ckpt_save_dir, 
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual
#     train_dice_t1=train_dice_t1, train_dice_t2=train_dice_t2, train_dice_t3=train_dice_t3,
#     test_dice_t1=test_dice_t1, test_dice_t2=test_dice_t2, test_dice_t3=test_dice_t3,
#     train_sd_t1=train_sd_t1, train_sd_t2=train_sd_t2, train_sd_t3=train_sd_t3,
#     test_sd_t1=test_sd_t1, test_sd_t2=test_sd_t2, test_sd_t3=test_sd_t3,
    
#     # Mask
#     train_mask_t1=train_mask_t1, train_mask_t2=train_mask_t2, train_mask_t3=train_mask_t3,
#     test_mask_t1=test_mask_t1, test_mask_t2=test_mask_t2, test_mask_t3=test_mask_t3,
#     train_mask_sd_t1=train_mask_sd_t1, train_mask_sd_t2=train_mask_sd_t2, train_mask_sd_t3=train_mask_sd_t3,
#     test_mask_sd_t1=test_mask_sd_t1, test_mask_sd_t2=test_mask_sd_t2, test_mask_sd_t3=test_mask_sd_t3,

#     # Set model name prefix
#     model_name_prefix="Axial_UNet"
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Ablate Spatial Attention

In [None]:
# ## Ablate Spatial Attention
# # Builds the SWAU_Net_CNN ablation model and counts trainable parameters per component.


# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# # Function to count trainable parameters.
# # NOTE(review): this redefinition shadows `count_parameters` imported from `models` in the
# # setup cell — confirm the two are identical and keep only one.
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the **SWAU_Net-CNN Ablation** model with the updated configuration...")

# # Instantiate the full model and move it to the device
# # The model is SWAU_Net_CNN, which uses CNN_Unet_Enc/Dec and CNN_DynNet.
# swau_model = SWAU_Net_CNN(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(swau_model.E1)

# # Count the two CFB modules separately, then combine them into one table row.
# cfb_enc_params = count_parameters(swau_model.CFB_enc) 
# cfb_dec_params = count_parameters(swau_model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# swa_params = count_parameters(swau_model.SWA) 
# p_params = count_parameters(swau_model.P)
# d1_params = count_parameters(swau_model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["CNN_Unet_Enc (E1)", "Feature Extractor (No Spatial Attention)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"],
#     ["SlidingWindowAttention (SWA)", "**Feature Aggregator/Integrator (Temporal Axial Only)**", f"**{swa_params:,}**"], 
#     ["CNN_DynNet (P)", "Temporal Feature Predictor (No Attention)", f"{p_params:,}"],
#     ["CNN_Unet_Dec (D1)", "Frame Reconstructor", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**SWAU_Net-CNN Ablation**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### SWAU_Net-CNN Ablation Component Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# ## Load Pretrained Model

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "SWAU_CNN_pretrain_epoch50_20251115_205035.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=swau_model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )


In [None]:
# # --- INITIALIZATION AND HYPERPARAMETER SETUP ---

# # # --- OVERFITTING TEST ACTIVATED ---
# # # Create an overfit dataset by replicating the first sample (index 0) 100 times
# # overfit_data = torch.cat([full_clean_data_tensor_cpu[0].unsqueeze(0)] * 400, dim=0)
# # current_train_data = overfit_data 
# # # ----------------------------------

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 60
# ACCUMULATION_STEPS = 8 
# soft_dice = False # Use Soft Dice for stability
# lr = 1E-4 # Initial LR

# # --- MODEL SWAP: Instantiate SWAU_Net_CNN instead of model_baseline ---
# # This is the SWAU_Net Ablation (SWA temporal regularization, no spatial attention)
# swau_cnn_model = SWAU_Net_CNN(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# use_median = True
# if use_median==True:
#     print('Using DICE Median.')
# else:
#     print('Using DICE Aggregate Mean.')

# if soft_dice==True:
#     print('Using Soft DICE.')
# else:
#     print('Using hard DICE.')

# # Update optimizer and scheduler to use the new model instance
# optimizer_swau = torch.optim.Adam(swau_cnn_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# scheduler_swau = ReduceLROnPlateau(
#     optimizer_swau, 
#     mode='min', factor=0.5, patience=15, verbose=True, min_lr=1e-6
# )

# # Loss functions (Ensure these are correctly instantiated elsewhere)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# # LLR_WEIGHT and BOTTLENECK_L2_WEIGHT are correctly used below
# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(swau_cnn_model)

# # --- BASELINE HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median) - Use generic 'ablated' naming now
# ablated_train_residual_t1, ablated_train_residual_t2, ablated_train_residual_t3 = [], [], []
# ablated_test_residual_t1, ablated_test_residual_t2, ablated_test_residual_t3 = [], [], []
# # Residual SDs
# ablated_train_residual_sd_t1, ablated_train_residual_sd_t2, ablated_train_residual_sd_t3 = [], [], []
# ablated_test_residual_sd_t1, ablated_test_residual_sd_t2, ablated_test_residual_sd_t3 = [], [], []

# # Mask Scores (Mean/Median)
# ablated_train_mask_t1, ablated_train_mask_t2, ablated_train_mask_t3 = [], [], []
# ablated_test_mask_t1, ablated_test_mask_t2, ablated_test_mask_t3 = [], [], []
# # Mask SDs
# ablated_train_mask_sd_t1, ablated_train_mask_sd_t2, ablated_train_mask_sd_t3 = [], [], []
# ablated_test_mask_sd_t1, ablated_test_mask_sd_t2, ablated_test_mask_sd_t3 = [], [], []


# print(f"\n Starting SWAU-Net-CNN (No Spatial Attention) Training for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP (100 Epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- 1. Training Step (Using SWAU_Net's accumulated loss function) ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, swau_cnn_model, optimizer_swau, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, 
#         lambda_recon=0.2, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, swau_cnn_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, swau_cnn_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation ---
#     # Residual Scores
#     ablated_train_residual_t1.append(res_train_scores[0]); ablated_train_residual_t2.append(res_train_scores[1]); ablated_train_residual_t3.append(res_train_scores[2])
#     ablated_test_residual_t1.append(res_test_scores[0]); ablated_test_residual_t2.append(res_test_scores[1]); ablated_test_residual_t3.append(res_test_scores[2])
#     # Residual SDs
#     ablated_train_residual_sd_t1.append(res_train_sds[0]); ablated_train_residual_sd_t2.append(res_train_sds[1]); ablated_train_residual_sd_t3.append(res_train_sds[2])
#     ablated_test_residual_sd_t1.append(res_test_sds[0]); ablated_test_residual_sd_t2.append(res_test_sds[1]); ablated_test_residual_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     ablated_train_mask_t1.append(msk_train_scores[0]); ablated_train_mask_t2.append(msk_train_scores[1]); ablated_train_mask_t3.append(msk_train_scores[2])
#     ablated_test_mask_t1.append(msk_test_scores[0]); ablated_test_mask_t2.append(msk_test_scores[1]); ablated_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     ablated_train_mask_sd_t1.append(msk_train_sds[0]); ablated_train_mask_sd_t2.append(msk_train_sds[1]); ablated_train_mask_sd_t3.append(msk_train_sds[2])
#     ablated_test_mask_sd_t1.append(msk_test_sds[0]); ablated_test_mask_sd_t2.append(msk_test_sds[1]); ablated_test_mask_sd_t3.append(msk_test_sds[2]) 

#     # --- 4. Scheduler & Logging ---
#     scheduler_swau.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_swau.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, swau_cnn_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History
#     plot_train_test_dice_history(
#         ablated_train_residual_t1, ablated_train_residual_t2, ablated_train_residual_t3,
#         ablated_test_residual_t1, ablated_test_residual_t2, ablated_test_residual_t3,
#         ablated_train_residual_sd_t1, ablated_train_residual_sd_t2, ablated_train_residual_sd_t3,
#         ablated_test_residual_sd_t1, ablated_test_residual_sd_t2, ablated_test_residual_sd_t3,
#         plot_title='SWAU-Net-CNN (No Spatial Attn) Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History
#     plot_train_test_dice_history(
#         ablated_train_mask_t1, ablated_train_mask_t2, ablated_train_mask_t3,
#         ablated_test_mask_t1, ablated_test_mask_t2, ablated_test_mask_t3,
#         ablated_train_mask_sd_t1, ablated_train_mask_sd_t2, ablated_train_mask_sd_t3,
#         ablated_test_mask_sd_t1, ablated_test_mask_sd_t2, ablated_test_mask_sd_t3,
#         plot_title='SWAU-Net-CNN (No Spatial Attn) Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//CNN_Ablation')
# FINAL_EPOCH = args.num_epochs

# saved_path = save_final_experiment_data(
#     model=swau_cnn_model,  # <-- Using the ablated model instance
#     final_epoch=FINAL_EPOCH,  
#     base_save_dir=ckpt_save_dir,  # <-- Updated save directory name
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual (Using ablated history variables)
#     train_dice_t1=ablated_train_residual_t1, train_dice_t2=ablated_train_residual_t2, train_dice_t3=ablated_train_residual_t3,
#     test_dice_t1=ablated_test_residual_t1, test_dice_t2=ablated_test_residual_t2, test_dice_t3=ablated_test_residual_t3,
#     train_sd_t1=ablated_train_residual_sd_t1, train_sd_t2=ablated_train_residual_sd_t2, train_sd_t3=ablated_train_residual_sd_t3,
#     test_sd_t1=ablated_test_residual_sd_t1, test_sd_t2=ablated_test_residual_sd_t2, test_sd_t3=ablated_test_residual_sd_t3,
    
#     # Mask (Using ablated history variables)
#     train_mask_t1=ablated_train_mask_t1, train_mask_t2=ablated_train_mask_t2, train_mask_t3=ablated_train_mask_t3,
#     test_mask_t1=ablated_test_mask_t1, test_mask_t2=ablated_test_mask_t2, test_mask_t3=ablated_test_mask_t3,
#     train_mask_sd_t1=ablated_train_mask_sd_t1, train_mask_sd_t2=ablated_train_mask_sd_t2, train_mask_sd_t3=ablated_train_mask_sd_t3,
#     test_mask_sd_t1=ablated_test_mask_sd_t1, test_mask_sd_t2=ablated_test_mask_sd_t2, test_mask_sd_t3=ablated_test_mask_sd_t3,

#     # Set model name prefix
#     model_name_prefix="SWAUNet_CNN_ablation" # <-- Updated prefix
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Train Conv LSTM Baseline

In [None]:
# --- Configuration Update (Reconfirmed) ---
# Settings consumed by the ConvLSTMBaseline instantiation in this cell.
BASE_CHANNELS = 16        # passed to ConvLSTMBaseline as `base_channels`
args.img_channels = 3     # 3 input channels: FAF, Mask, Residual

# Trainable-parameter counter.
# NOTE: this re-definition shadows the `count_parameters` imported from `models`
# in the setup cell; from here on the notebook uses this version.
def count_parameters(model):
    """Return how many trainable scalars a PyTorch model holds.

    Only parameter tensors with ``requires_grad=True`` contribute; each one
    contributes its ``numel()``.
    """
    trainable = (param for param in model.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)

# --- Instantiation and Parameter Calculation ---

print("\nInstantiating the ConvLSTMBaseline model...")

# Instantiate the baseline model on the configured device.
# ConvLSTMBaseline (and the Unet_Enc / Unet_Dec / ConvLSTMCore it builds on)
# is assumed to be defined in an earlier cell of this notebook.
model_baseline = ConvLSTMBaseline(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# Trainable-parameter counts for each main component
e1_params = count_parameters(model_baseline.E1)
p_lstm_params = count_parameters(model_baseline.P_LSTM)  # ConvLSTM Core
d1_params = count_parameters(model_baseline.D1)

# Count the whole model directly instead of summing the three components:
# a sum would silently omit trainable parameters that live outside
# E1 / P_LSTM / D1, and would double-count any parameter shared between them
# (Module.parameters() yields each shared tensor only once).
total_params = count_parameters(model_baseline)
other_params = total_params - (e1_params + p_lstm_params + d1_params)

# --- Create Table Data ---
param_data = [
    ["Unet_Enc (E1)", "Feature Extractor", f"{e1_params:,}"],
    ["ConvLSTMCore (P_LSTM)", "**Recurrent Temporal Core**", f"**{p_lstm_params:,}**"],
    ["Unet_Dec (D1)", "Frame Reconstructor", f"{d1_params:,}"],
]
if other_params != 0:
    # Positive: parameters outside the listed components.
    # Negative: parameters shared between listed components (counted twice above).
    param_data.append(["(other / shared)", "Not attributed to E1 / P_LSTM / D1", f"{other_params:,}"])
param_data += [
    ["", "", ""], # Separator
    ["**TOTAL**", "**ConvLSTMBaseline Model**", f"**{total_params:,}**"],
]

# --- Print Table ---
print("\n### ConvLSTMBaseline Parameter Summary\n")
print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
## Load Pretrained Model
# Restore pretrained ConvLSTM-baseline weights into `model_baseline`.
ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')

# Checkpoint to restore. Earlier pretraining runs, kept for reference:
#   "ConvLSTM_baseline_pretrain_epoch50_20251107_105643.pth"
#   "ConvLSTM_baseline_pretrain_epoch50_20251109_130141.pth"
MODEL_FILENAME = "ConvLSTM_baseline_pretrain_epoch50_20251112_235650.pth"
MODEL_PATH = ckpt_save_dir.joinpath(MODEL_FILENAME)

# load_model's result is unpacked as (model, epoch).
loaded_model, loaded_epoch = load_model(model=model_baseline, model_path=MODEL_PATH, device=args.device)

In [None]:
# --- INITIALIZATION AND HYPERPARAMETER SETUP ---
# Trains `model_baseline` (ConvLSTM baseline) and records per-epoch residual /
# full-mask Dice histories (T=1..3) on the train and test sets.

# # --- OVERFITTING TEST (disabled) ---
# # Create an overfit dataset by replicating the first sample (index 0) 400 times
# overfit_data = torch.cat([full_clean_data_tensor_cpu[0].unsqueeze(0)] * 400, dim=0)
# current_train_data = overfit_data 
# # ----------------------------------

# --- TRAIN WITH FULL DATASET ---
current_train_data = full_clean_data_tensor_cpu
# ----------------------------------

# HYPERPARAMETERS
args.num_epochs = 100
ACCUMULATION_STEPS = 8 
soft_dice = False   # Evaluation metric: False -> hard Dice, True -> soft Dice
lr = 1E-4           # Initial LR
use_median = True   # Evaluation statistic: True -> median, False -> aggregate mean

# `use_median` is forwarded to BOTH evaluation calls below (previously they
# hard-coded use_median=True, so this flag only changed the printed message).
# `dice_stat` keeps the log line and plot titles consistent with it.
dice_stat = 'Median' if use_median else 'Mean'

if use_median:
    print('Using DICE Median.')
else:
    print('Using DICE Aggregate Mean.')

if soft_dice:
    print('Using Soft DICE.')
else:
    print('Using hard DICE.')

optimizer_baseline = torch.optim.Adam(model_baseline.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# Halve the LR after `patience` epochs without improvement in the mean
# training loss (stepped once per epoch below), never going below 1e-6.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; the LR is
# already printed in the per-epoch summary, so it can be dropped if it warns.
scheduler_baseline = ReduceLROnPlateau(
    optimizer_baseline, 
    mode='min', factor=0.5, patience=15, verbose=True, min_lr=1e-6
)

# Loss functions
loss_fn_bce = nn.BCELoss(reduction='mean')
loss_fn_l1 = nn.L1Loss(reduction='mean') 
loss_fn_l2 = nn.MSELoss(reduction='mean')
loss_fn_dice = dice_loss # custom dice_loss imported from training_utils
loss_fn_gdl = GDLoss(alpha=1, beta=1)

# Weight of the bottleneck L2 term (passed as lambda_bottleneck below)
BOTTLENECK_L2_WEIGHT = 1e-6 

# Freeze Batch Norm layers (essential for small batches)
freeze_batch_norm(model_baseline)

# --- BASELINE HISTORY INITIALIZATION (REQUIRED FOR THIS SCOPE) ---
# These names are read by the save cell that follows.

# Loss/Iteration Tracking
all_iteration_losses = [] 
epoch_iteration_counts = []  # one entry per completed training epoch

# Residual Scores (Mean/Median)
baseline_train_residual_t1, baseline_train_residual_t2, baseline_train_residual_t3 = [], [], []
baseline_test_residual_t1, baseline_test_residual_t2, baseline_test_residual_t3 = [], [], []
# Residual SDs
baseline_train_residual_sd_t1, baseline_train_residual_sd_t2, baseline_train_residual_sd_t3 = [], [], []
baseline_test_residual_sd_t1, baseline_test_residual_sd_t2, baseline_test_residual_sd_t3 = [], [], []

# Mask Scores (Mean/Median)
baseline_train_mask_t1, baseline_train_mask_t2, baseline_train_mask_t3 = [], [], []
baseline_test_mask_t1, baseline_test_mask_t2, baseline_test_mask_t3 = [], [], []
# Mask SDs
baseline_train_mask_sd_t1, baseline_train_mask_sd_t2, baseline_train_mask_sd_t3 = [], [], []
baseline_test_mask_sd_t1, baseline_test_mask_sd_t2, baseline_test_mask_sd_t3 = [], [], []


print(f"\n Starting ConvLSTM Baseline Training for {args.num_epochs} epoch(s)...")

# --- TRAINING LOOP (args.num_epochs epochs) ---
for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
    # --- 1. Training Step (same accumulated training step used for SWAU_Net) ---
    epoch_losses = f_single_epoch_spatiotemporal_accumulated(
        current_train_data, model_baseline, optimizer_baseline, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
        accumulation_steps=ACCUMULATION_STEPS, 
        lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, 
        lambda_recon=0.2, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
    )

    all_iteration_losses.extend(epoch_losses.tolist())
    epoch_iteration_counts.append(len(epoch_losses))
    mean_epoch_loss = np.mean(epoch_losses)
    
    # --- 2. Evaluation Step (statistic controlled by `use_median`) ---
    (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
        f_eval_pred_dice_test_set(test_loader, model_baseline, args, soft_dice=soft_dice, use_median=use_median)
    (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
        f_eval_pred_dice_train_set(current_train_data, model_baseline, args, args.batch_size, soft_dice=soft_dice, use_median=use_median)

    # --- 3. Accumulation (index 0/1/2 -> T=1/2/3) ---
    # Residual Scores
    baseline_train_residual_t1.append(res_train_scores[0]); baseline_train_residual_t2.append(res_train_scores[1]); baseline_train_residual_t3.append(res_train_scores[2])
    baseline_test_residual_t1.append(res_test_scores[0]); baseline_test_residual_t2.append(res_test_scores[1]); baseline_test_residual_t3.append(res_test_scores[2])
    # Residual SDs
    baseline_train_residual_sd_t1.append(res_train_sds[0]); baseline_train_residual_sd_t2.append(res_train_sds[1]); baseline_train_residual_sd_t3.append(res_train_sds[2])
    baseline_test_residual_sd_t1.append(res_test_sds[0]); baseline_test_residual_sd_t2.append(res_test_sds[1]); baseline_test_residual_sd_t3.append(res_test_sds[2])
    
    # Mask Scores
    baseline_train_mask_t1.append(msk_train_scores[0]); baseline_train_mask_t2.append(msk_train_scores[1]); baseline_train_mask_t3.append(msk_train_scores[2])
    baseline_test_mask_t1.append(msk_test_scores[0]); baseline_test_mask_t2.append(msk_test_scores[1]); baseline_test_mask_t3.append(msk_test_scores[2])
    # Mask SDs
    baseline_train_mask_sd_t1.append(msk_train_sds[0]); baseline_train_mask_sd_t2.append(msk_train_sds[1]); baseline_train_mask_sd_t3.append(msk_train_sds[2])
    baseline_test_mask_sd_t1.append(msk_test_sds[0]); baseline_test_mask_sd_t2.append(msk_test_sds[1]); baseline_test_mask_sd_t3.append(msk_test_sds[2])

    # --- 4. Scheduler & Logging ---
    # NOTE: the plateau scheduler monitors the mean *training* loss.
    scheduler_baseline.step(mean_epoch_loss)

    print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_baseline.param_groups[0]['lr']:.2e}) ---")
    print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
    print("\nResidual T=3 Test {} Dice: {:.4f} (SD: {:.4f})".format(dice_stat, res_test_scores[2], res_test_sds[2]))
    
    # --- Per-Epoch Visualizations ---
    print("\n--- Generating Per-Epoch Visualizations ---")
    
    # A. Plot Loss History
    plot_log_loss(all_iteration_losses, epoch_iteration_counts)

    # B. Plot Sample Prediction
    f_display_frames(current_train_data, model_baseline, args, sample_idx=20, T_total=4)
    
    # C. Plot Residual History
    plot_train_test_dice_history(
        baseline_train_residual_t1, baseline_train_residual_t2, baseline_train_residual_t3,
        baseline_test_residual_t1, baseline_test_residual_t2, baseline_test_residual_t3,
        baseline_train_residual_sd_t1, baseline_train_residual_sd_t2, baseline_train_residual_sd_t3,
        baseline_test_residual_sd_t1, baseline_test_residual_sd_t2, baseline_test_residual_sd_t3,
        plot_title=f'ConvLSTM Baseline Residual Dice History ({dice_stat} ± SD)'
    )

    # D. Plot Mask History
    plot_train_test_dice_history(
        baseline_train_mask_t1, baseline_train_mask_t2, baseline_train_mask_t3,
        baseline_test_mask_t1, baseline_test_mask_t2, baseline_test_mask_t3,
        baseline_train_mask_sd_t1, baseline_train_mask_sd_t2, baseline_train_mask_sd_t3,
        baseline_test_mask_sd_t1, baseline_test_mask_sd_t2, baseline_test_mask_sd_t3,
        plot_title=f'ConvLSTM Baseline Full Mask Dice History ({dice_stat} ± SD)'
    )

# --- Final Message ---
print("\n--- Training Complete ---")

In [None]:
# --- Save final ConvLSTM baseline weights + full training history ---
ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Conv LSTM Baseline')

# Use the number of epochs that actually completed rather than the configured
# args.num_epochs: the training cell appends one entry to
# `epoch_iteration_counts` per finished training epoch, so an interrupted run
# is no longer saved/labelled as if it had trained for the full schedule.
FINAL_EPOCH = len(epoch_iteration_counts)
if FINAL_EPOCH != args.num_epochs:
    print(f"Warning: only {FINAL_EPOCH}/{args.num_epochs} epochs completed; saving with final_epoch={FINAL_EPOCH}.")

saved_path = save_final_experiment_data(
    model=model_baseline, 
    final_epoch=FINAL_EPOCH, 
    base_save_dir=ckpt_save_dir, 
    k_fold_index=k,  # NOTE(review): `k` is assumed to be set by an earlier k-fold split cell
    
    # --- Pass all history lists from the main training loop ---
    all_iteration_losses=all_iteration_losses,
    epoch_iteration_counts=epoch_iteration_counts,
    
    # Residual
    train_dice_t1=baseline_train_residual_t1, train_dice_t2=baseline_train_residual_t2, train_dice_t3=baseline_train_residual_t3,
    test_dice_t1=baseline_test_residual_t1, test_dice_t2=baseline_test_residual_t2, test_dice_t3=baseline_test_residual_t3,
    train_sd_t1=baseline_train_residual_sd_t1, train_sd_t2=baseline_train_residual_sd_t2, train_sd_t3=baseline_train_residual_sd_t3,
    test_sd_t1=baseline_test_residual_sd_t1, test_sd_t2=baseline_test_residual_sd_t2, test_sd_t3=baseline_test_residual_sd_t3,
    
    # Mask
    train_mask_t1=baseline_train_mask_t1, train_mask_t2=baseline_train_mask_t2, train_mask_t3=baseline_train_mask_t3,
    test_mask_t1=baseline_test_mask_t1, test_mask_t2=baseline_test_mask_t2, test_mask_t3=baseline_test_mask_t3,
    train_mask_sd_t1=baseline_train_mask_sd_t1, train_mask_sd_t2=baseline_train_mask_sd_t2, train_mask_sd_t3=baseline_train_mask_sd_t3,
    test_mask_sd_t1=baseline_test_mask_sd_t1, test_mask_sd_t2=baseline_test_mask_sd_t2, test_mask_sd_t3=baseline_test_mask_sd_t3,

    # Set model name prefix
    model_name_prefix="ConvLSTM_baseline"
)

if saved_path:
    print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Train Simple ConvLSTM

In [None]:
# # --- Configuration Update (Reconfirmed) ---
# BASE_CHANNELS = 16 # 16 channels
# args.img_channels = 3 # 3 channels (FAF, Mask, Residual)

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the ConvLSTM_Simple model...")

# # Instantiate the simple CNN-based model
# # NOTE: This model uses CNN_Unet_Enc and CNN_Unet_Dec
# model_simple = ConvLSTM_Simple(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(model_simple.E1)
# p_lstm_params = count_parameters(model_simple.P_LSTM) # ConvLSTM Core
# d1_params = count_parameters(model_simple.D1)
# total_params = e1_params + p_lstm_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["CNN_Unet_Enc (E1)", "Feature Extractor (Ablated CNN)", f"{e1_params:,}"],
#     ["ConvLSTMCore (P_LSTM)", "**Recurrent Temporal Core**", f"**{p_lstm_params:,}**"],
#     ["CNN_Unet_Dec (D1)", "Frame Reconstructor (Ablated CNN)", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**ConvLSTM_Simple Model**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### ConvLSTM_Simple Parameter Summary\n")
# print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))

In [None]:
# ## Load Pretrained Model

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "ConvLSTM_simple_pretrain_epoch50_20251202_103250.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=model_simple, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

In [None]:
# # --- INITIALIZATION FOR LOGGING ---
# # Set to higher epochs to clearly observe overfitting
# args.num_epochs = 50

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # History lists
# all_iteration_losses = [] 
# epoch_iteration_counts = [] # Stores number of iterations in each epoch

# # --- RESIDUAL HISTORY LISTS (Scores and SDs) - USING NEW NAMING CONVENTION (simple_) ---
# # The lists below are initialized with the 'simple_' prefix to store ConvLSTM_Simple results.
# simple_train_t1, simple_train_t2, simple_train_t3 = [], [], []
# simple_test_t1, simple_test_t2, simple_test_t3 = [], [], []
# simple_train_sd_t1, simple_train_sd_t2, simple_train_sd_t3 = [], [], []
# simple_test_sd_t1, simple_test_sd_t2, simple_test_sd_t3 = [], [], []

# # --- MASK HISTORY LISTS (Scores and SDs) - USING NEW NAMING CONVENTION (simple_) ---
# simple_train_mask_t1, simple_train_mask_t2, simple_train_mask_t3 = [], [], []
# simple_test_mask_t1, simple_test_mask_t2, simple_test_mask_t3 = [], [], []
# simple_train_mask_sd_t1, simple_train_mask_sd_t2, simple_train_mask_sd_t3 = [], [], []
# simple_test_mask_sd_t1, simple_test_mask_sd_t2, simple_test_mask_sd_t3 = [], [], []


# # --- TRAINING SETUP (Reconfirming) ---
# # Loss functions (re-instantiated if necessary, matching previous definitions)
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_dice = dice_loss
# loss_fn_gdl = GDLoss(alpha=1, beta=1)
# loss_fn_l1 = nn.L1Loss()
# loss_fn_l2 = nn.MSELoss()
# ACCUMULATION_STEPS = 8
# soft_dice = False # Hard Dice

# lr = 1E-4
# # --- OPTIMIZER DEFINITION (Using the working model instance: model_simple) ---
# optimizer = torch.optim.Adam(model_simple.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1E-5)

# # --- LEARNING RATE SCHEDULER INITIALIZATION ---
# scheduler = ReduceLROnPlateau(
#     optimizer, 
#     mode='min',         # Monitor minimum loss
#     factor=0.5,         
#     patience=15,        
#     verbose=True, 
#     min_lr=1e-6         
# )
# # ---------------------------------------------------

# # -----------------
# # FREEZE BATCH NORM (Using the working model instance: model_simple)
# freeze_batch_norm(model_simple)
# # ----------------

# print(f"\nStarting ConvLSTM Simple training on device: {args.device} for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
# #    # --- 1. Training Step ---
#     # Update model instance to model_simple
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, model_simple, optimizer, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, lambda_recon=0.5, use_augmentation=True
#     )
    
#     # Store iteration loss and count for plot_log_loss
#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
    
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- NEW: SCHEDULER STEP ---
#     scheduler.step(mean_epoch_loss)
    
#     # --- 3. Logging and Checkpoint ---
    
#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary ---")
#     print(f"Mean Loss: {mean_epoch_loss:.6f}")

#     # --- 4. Per-Epoch Visualizations (MOVED INSIDE LOOP) ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # C. Plot a sample prediction clip 
#     # Update model instance to model_simple
#     f_display_frames(current_train_data, model_simple, args, sample_idx=20, T_total=4)
    
#      # EVALUATION
#     use_median = True
#     if use_median==True:
#         print('Using DICE Median.')
#     else:
#         print('Using DICE Aggregate Mean.')
        
#     if soft_dice==True:
#         print('Using Soft DICE.')
#     else:
#         print('Using hard DICE.')
        
#     # --- Evaluation Step (Using the working model instance: model_simple) ---
    
#     # 1. Unpack Test Results: Returns ((Res Scores, Res SDs), (Msk Scores, Msk SDs))
#     # Update model instance to model_simple
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#     f_eval_pred_dice_test_set(test_loader, model_simple, args, soft_dice=soft_dice, use_median=use_median)

#     # 2. Unpack Train Results
#     # Update model instance to model_simple
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#     f_eval_pred_dice_train_set(current_train_data, model_simple, args, args.batch_size, soft_dice=soft_dice, use_median=use_median)
    
#     # --- 1. Residual Scores and SDs Accumulation (Using simple_ prefixes) ---

#     # Accumulate Scores
#     simple_train_t1.append(res_train_scores[0]); simple_train_t2.append(res_train_scores[1]); simple_train_t3.append(res_train_scores[2])
#     simple_test_t1.append(res_test_scores[0]); simple_test_t2.append(res_test_scores[1]); simple_test_t3.append(res_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     simple_train_sd_t1.append(res_train_sds[0]); simple_train_sd_t2.append(res_train_sds[1]); simple_train_sd_t3.append(res_train_sds[2])
#     simple_test_sd_t1.append(res_test_sds[0]); simple_test_sd_t2.append(res_test_sds[1]); simple_test_sd_t3.append(res_test_sds[2])

#     # --- 2. Mask Scores and SDs Accumulation (Using simple_ prefixes) ---

#     # Accumulate Scores
#     simple_train_mask_t1.append(msk_train_scores[0]); simple_train_mask_t2.append(msk_train_scores[1]); simple_train_mask_t3.append(msk_train_scores[2])
#     simple_test_mask_t1.append(msk_test_scores[0]); simple_test_mask_t2.append(msk_test_scores[1]); simple_test_mask_t3.append(msk_test_scores[2])

#     # Accumulate Standard Deviations (SDs)
#     simple_train_mask_sd_t1.append(msk_train_sds[0]); simple_train_mask_sd_t2.append(msk_train_sds[1]); simple_train_mask_sd_t3.append(msk_train_sds[2])
#     simple_test_mask_sd_t1.append(msk_test_sds[0]); simple_test_mask_sd_t2.append(msk_test_sds[1]); simple_test_mask_sd_t3.append(msk_test_sds[2])


#     # 1. Plot Residual History
#     # Update variable names and plot title
#     plot_train_test_dice_history(
#         simple_train_t1, simple_train_t2, simple_train_t3,
#         simple_test_t1, simple_test_t2, simple_test_t3,
#         simple_train_sd_t1, simple_train_sd_t2, simple_train_sd_t3,          
#         simple_test_sd_t1, simple_test_sd_t2, simple_test_sd_t3,            
#         plot_title='ConvLSTM Simple Residual Dice History (Median ± SD)'
#     )

#     # 2. Plot Mask History 
#     # Update variable names and plot title
#     plot_train_test_dice_history(
#         simple_train_mask_t1, simple_train_mask_t2, simple_train_mask_t3,
#         simple_test_mask_t1, simple_test_mask_t2, simple_test_mask_t3,
#         simple_train_mask_sd_t1, simple_train_mask_sd_t2, simple_train_mask_sd_t3, 
#         simple_test_mask_sd_t1, simple_test_mask_sd_t2, simple_test_mask_sd_t3,    
#         plot_title='ConvLSTM Simple Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# NOTE: Disabled cell (fully commented out) that saved the ConvLSTM Simple
# Baseline run with save_final_experiment_data(). If re-enabled it needs
# `args`, `model_simple`, the fold index `k`, `all_iteration_losses`,
# `epoch_iteration_counts` and the `simple_*` history lists from earlier
# cells; none of them are defined in this cell.
# # --- DIRECTORY PATH (Change this for your execution environment) ---
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//ConvLSTM_Simple_Baseline') # Updated Path Name
# FINAL_EPOCH = args.num_epochs

# # --- SAVE FINAL EXPERIMENT DATA ---

# # NOTE: Using the working model instance 'model_simple' for saving.
# saved_path = save_final_experiment_data(
#     model=model_simple,
#     final_epoch=FINAL_EPOCH, 
#     base_save_dir=ckpt_save_dir, 
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop (using 'simple_' prefixes) ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual
#     train_dice_t1=simple_train_t1, train_dice_t2=simple_train_t2, train_dice_t3=simple_train_t3,
#     test_dice_t1=simple_test_t1, test_dice_t2=simple_test_t2, test_dice_t3=simple_test_t3,
#     train_sd_t1=simple_train_sd_t1, train_sd_t2=simple_train_sd_t2, train_sd_t3=simple_train_sd_t3,
#     test_sd_t1=simple_test_sd_t1, test_sd_t2=simple_test_sd_t2, test_sd_t3=simple_test_sd_t3,
    
#     # Mask
#     train_mask_t1=simple_train_mask_t1, train_mask_t2=simple_train_mask_t2, train_mask_t3=simple_train_mask_t3,
#     test_mask_t1=simple_test_mask_t1, test_mask_t2=simple_test_mask_t2, test_mask_t3=simple_test_mask_t3,
#     train_mask_sd_t1=simple_train_mask_sd_t1, train_mask_sd_t2=simple_train_mask_sd_t2, train_mask_sd_t3=simple_train_mask_sd_t3,
#     test_mask_sd_t1=simple_test_mask_sd_t1, test_mask_sd_t2=simple_test_mask_sd_t2, test_mask_sd_t3=simple_test_mask_sd_t3,

#     # Set model name prefix to reflect the specific ablation
#     model_name_prefix="CONVLSTM_SIMPLE" # <--- Changed prefix
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Ablate Spatiotemporal Attention

In [None]:
# NOTE: Disabled cell (fully commented out) that builds a UPredNet and prints a
# per-component trainable-parameter table.
# NOTE(review): if re-enabled, (1) the local `count_parameters` below shadows the
# one imported from `models` in the setup cell, and (2) `UPredNet` is not among
# the names imported in the setup cell, so it must be defined in another cell
# or imported separately.
# # --- Configuration Update for Memory Reduction (Confirmed from previous turn) ---
# BASE_CHANNELS = 16 # Reduced from 24 to 16
# args.d_attn1 = 128 # Reduced from 192 to 128
# args.d_attn2 = 256 # Reduced from 384 to 256

# # Function to count trainable parameters (Provided in setup)
# def count_parameters(model):
#     """Counts the total number of trainable parameters in a PyTorch model."""
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # --- Instantiation and Parameter Calculation ---

# print("\nInstantiating the **UPredNet (SWA Ablation)** model with the updated configuration...")

# # Instantiate the UPredNet model (UPredNet uses E1, CFB_enc, P=DynNet, CFB_dec, D1)
# # We assume UPredNet is accessible, and base_channels defaults to 16 if not set.
# upred_model = UPredNet(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # Calculate parameters for each main component
# e1_params = count_parameters(upred_model.E1)

# # Calculate parameters for both CFB modules separately
# cfb_enc_params = count_parameters(upred_model.CFB_enc) 
# cfb_dec_params = count_parameters(upred_model.CFB_dec) 
# cfb_total_params = cfb_enc_params + cfb_dec_params

# # The UPredNet model does NOT have an 'SWA' module. This parameter should be 0.
# swa_params = 0 

# # P is the DynNet
# p_params = count_parameters(upred_model.P)
# d1_params = count_parameters(upred_model.D1)

# # Ensure all components are summed up for the total count
# total_params = e1_params + cfb_total_params + swa_params + p_params + d1_params

# # --- Create Table Data ---
# param_data = [
#     ["Unet_Enc (E1)", "Feature Extractor (No Spatial Attention)", f"{e1_params:,}"],
#     ["CFB (Total, 2x Modules)", "**Pre/Post-Dynamics Mixer**", f"**{cfb_total_params:,}**"],
#     ["**SWA Module**", "**Ablated**", f"**{swa_params:,}**"], # SWA is 0
#     ["DynNet (P)", "Temporal Feature Predictor (Evolution)", f"{p_params:,}"],
#     ["Unet_Dec (D1)", "Frame Reconstructor", f"{d1_params:,}"],
#     ["", "", ""], # Separator
#     ["**TOTAL**", "**UPredNet (SWA Ablation)**", f"**{total_params:,}**"],
# ]

# # --- Print Table ---
# print("\n### UPredNet (SWA Ablation) Component Parameter Summary\n")
# try:
#     from tabulate import tabulate
#     print(tabulate(param_data, headers=["Component", "Role", "Parameters (Trainable)"], tablefmt="fancy_grid", colalign=("left", "left", "right")))
# except ImportError:
#     print("Tabulate library not available. Printing raw data.")
#     for row in param_data:
#         print(row)

In [None]:
# NOTE: Disabled cell (fully commented out) that loads pretrained UPredNet
# weights from MODEL_PATH into `upred_model` (built in the parameter-count cell).
# NOTE(review): the training cell that follows instantiates a fresh
# `uprednet_model` and trains that, so the weights loaded here are not used by
# it as written — confirm whether training was meant to start from `loaded_model`.
# ## Load Pretrained Model

# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//Pretrained Models')
# # Load the model
# MODEL_FILENAME = "UPredNet_pretrain_epoch50_20251121_033305.pth"
# MODEL_PATH = ckpt_save_dir / MODEL_FILENAME

# loaded_model, loaded_epoch = load_model(
#     model=upred_model, 
#     model_path=MODEL_PATH, 
#     device=args.device
# )

In [None]:
# NOTE: Disabled cell (fully commented out) that trains UPredNet for
# args.num_epochs epochs with gradient accumulation, evaluating train/test
# Dice and plotting loss, a sample prediction and Dice histories every epoch.
# NOTE(review): `uprednet_model` is created fresh below, so any weights loaded
# into `upred_model` by the "Load Pretrained Model" cell are not used here.
# # --- INITIALIZATION AND HYPERPARAMETER SETUP ---

# # # --- OVERFITTING TEST ACTIVATED ---
# # # Create an overfit dataset by replicating the first sample (index 0) 400 times
# # overfit_data = torch.cat([full_clean_data_tensor_cpu[0].unsqueeze(0)] * 400, dim=0)
# # current_train_data = overfit_data 
# # # ----------------------------------

# # --- TRAIN WITH FULL DATASET ---
# current_train_data = full_clean_data_tensor_cpu
# # ----------------------------------

# # HYPERPARAMETERS
# args.num_epochs = 60
# ACCUMULATION_STEPS = 8 
# soft_dice = False # False -> hard Dice (set True for Soft Dice)
# lr = 1E-4 # Initial LR

# # NOTE(review): use_median only selects the message printed below; both
# # evaluation calls in the loop pass use_median=True explicitly.
# use_median = True
# if use_median==True:
#     print('Using DICE Median.')
# else:
#     print('Using DICE Aggregate Mean.')

# if soft_dice==True:
#     print('Using Soft DICE.')
# else:
#     print('Using hard DICE.')

# # --- MODEL INSTANTIATION (REPLACED CONVLSTM with UPredNet) ---
# uprednet_model = UPredNet(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

# # --- OPTIMIZER, SCHEDULER RENAMING (Corrected all references) ---
# optimizer_uprednet = torch.optim.Adam(uprednet_model.parameters(), lr=lr, betas=(0.95, 0.999), weight_decay=1E-5)

# # NOTE(review): ReduceLROnPlateau's `verbose` argument is deprecated in recent
# # PyTorch releases; log optimizer.param_groups[0]['lr'] instead (done below).
# scheduler_uprednet = ReduceLROnPlateau(
#     optimizer_uprednet, 
#     mode='min', factor=0.5, patience=15, verbose=True, min_lr=1e-6
# )

# # Loss functions passed to f_single_epoch_spatiotemporal_accumulated in the loop below
# loss_fn_bce = nn.BCELoss(reduction='mean')
# loss_fn_l1 = nn.L1Loss(reduction='mean') 
# loss_fn_l2 = nn.MSELoss(reduction='mean')
# loss_fn_dice = dice_loss # This relies on your custom dice_loss function
# loss_fn_gdl = GDLoss(alpha=1, beta=1)

# # BOTTLENECK_L2_WEIGHT is passed to the training step below as lambda_bottleneck
# BOTTLENECK_L2_WEIGHT = 1e-6 

# # Freeze Batch Norm layers (essential for small batches)
# freeze_batch_norm(uprednet_model)

# # --- UPREDNET HISTORY INITIALIZATION (Corrected all references) ---

# # Loss/Iteration Tracking
# all_iteration_losses = [] 
# epoch_iteration_counts = []

# # Residual Scores (Mean/Median)
# uprednet_train_residual_t1, uprednet_train_residual_t2, uprednet_train_residual_t3 = [], [], []
# uprednet_test_residual_t1, uprednet_test_residual_t2, uprednet_test_residual_t3 = [], [], []
# # Residual SDs
# uprednet_train_residual_sd_t1, uprednet_train_residual_sd_t2, uprednet_train_residual_sd_t3 = [], [], []
# uprednet_test_residual_sd_t1, uprednet_test_residual_sd_t2, uprednet_test_residual_sd_t3 = [], [], []

# # Mask Scores (Mean/Median)
# uprednet_train_mask_t1, uprednet_train_mask_t2, uprednet_train_mask_t3 = [], [], []
# uprednet_test_mask_t1, uprednet_test_mask_t2, uprednet_test_mask_t3 = [], [], []
# # Mask SDs
# uprednet_train_mask_sd_t1, uprednet_train_mask_sd_t2, uprednet_train_mask_sd_t3 = [], [], []
# uprednet_test_mask_sd_t1, uprednet_test_mask_sd_t2, uprednet_test_mask_sd_t3 = [], [], []


# print(f"\n Starting **UPredNet (SWA Ablation)** Training for {args.num_epochs} epoch(s)...")

# # --- TRAINING LOOP (args.num_epochs epochs) ---
# for epoch in tqdm(np.arange(args.num_epochs), desc="Epoch Progress"):
    
#     # --- 1. Training Step (Using UPredNet's accumulated loss function) ---
#     epoch_losses = f_single_epoch_spatiotemporal_accumulated(
#         current_train_data, uprednet_model, optimizer_uprednet, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size, 
#         accumulation_steps=ACCUMULATION_STEPS, 
#         lambda_gdl=0, lambda_faf=0.5, lambda_mask=1.0, lambda_residual=5.0, 
#         lambda_recon=0.2, lambda_bottleneck=BOTTLENECK_L2_WEIGHT, use_augmentation=True
#     )

#     all_iteration_losses.extend(epoch_losses.tolist())
#     epoch_iteration_counts.append(len(epoch_losses))
#     mean_epoch_loss = np.mean(epoch_losses)
    
#     # --- 2. Evaluation Step (Median/SD) ---
#     (res_test_scores, res_test_sds), (msk_test_scores, msk_test_sds) = \
#         f_eval_pred_dice_test_set(test_loader, uprednet_model, args, soft_dice=soft_dice, use_median=True)
#     (res_train_scores, res_train_sds), (msk_train_scores, msk_train_sds) = \
#         f_eval_pred_dice_train_set(current_train_data, uprednet_model, args, args.batch_size, soft_dice=soft_dice, use_median=True)

#     # --- 3. Accumulation (Corrected all history variable assignment) ---
#     # Residual Scores
#     uprednet_train_residual_t1.append(res_train_scores[0]); uprednet_train_residual_t2.append(res_train_scores[1]); uprednet_train_residual_t3.append(res_train_scores[2])
#     uprednet_test_residual_t1.append(res_test_scores[0]); uprednet_test_residual_t2.append(res_test_scores[1]); uprednet_test_residual_t3.append(res_test_scores[2])
#     # Residual SDs
#     uprednet_train_residual_sd_t1.append(res_train_sds[0]); uprednet_train_residual_sd_t2.append(res_train_sds[1]); uprednet_train_residual_sd_t3.append(res_train_sds[2])
#     uprednet_test_residual_sd_t1.append(res_test_sds[0]); uprednet_test_residual_sd_t2.append(res_test_sds[1]); uprednet_test_residual_sd_t3.append(res_test_sds[2])
    
#     # Mask Scores
#     uprednet_train_mask_t1.append(msk_train_scores[0]); uprednet_train_mask_t2.append(msk_train_scores[1]); uprednet_train_mask_t3.append(msk_train_scores[2])
#     uprednet_test_mask_t1.append(msk_test_scores[0]); uprednet_test_mask_t2.append(msk_test_scores[1]); uprednet_test_mask_t3.append(msk_test_scores[2])
#     # Mask SDs
#     uprednet_train_mask_sd_t1.append(msk_train_sds[0]); uprednet_train_mask_sd_t2.append(msk_train_sds[1]); uprednet_train_mask_sd_t3.append(msk_train_sds[2])
#     uprednet_test_mask_sd_t1.append(msk_test_sds[0]); uprednet_test_mask_sd_t2.append(msk_test_sds[1]); uprednet_test_mask_sd_t3.append(msk_test_sds[2]) 

#     # --- 4. Scheduler & Logging ---
#     scheduler_uprednet.step(mean_epoch_loss)

#     print(f"\n--- Epoch {epoch+1}/{args.num_epochs} Summary (LR: {optimizer_uprednet.param_groups[0]['lr']:.2e}) ---")
#     print(f"Mean Loss: **{mean_epoch_loss:.6f}**")
    
#     print("\nResidual T=3 Test Median Dice: {:.4f} (SD: {:.4f})".format(res_test_scores[2], res_test_sds[2]))
    
#     # --- Per-Epoch Visualizations ---
#     print("\n--- Generating Per-Epoch Visualizations ---")
    
#     # A. Plot Loss History
#     plot_log_loss(all_iteration_losses, epoch_iteration_counts)

#     # B. Plot Sample Prediction
#     f_display_frames(current_train_data, uprednet_model, args, sample_idx=20, T_total=4)
    
#     # C. Plot Residual History (Corrected all history variable references and title)
#     plot_train_test_dice_history(
#         uprednet_train_residual_t1, uprednet_train_residual_t2, uprednet_train_residual_t3,
#         uprednet_test_residual_t1, uprednet_test_residual_t2, uprednet_test_residual_t3,
#         uprednet_train_residual_sd_t1, uprednet_train_residual_sd_t2, uprednet_train_residual_sd_t3,
#         uprednet_test_residual_sd_t1, uprednet_test_residual_sd_t2, uprednet_test_residual_sd_t3,
#         plot_title='UPredNet (SWA Ablation) Residual Dice History (Median ± SD)'
#     )

#     # D. Plot Mask History (Corrected all history variable references and title)
#     plot_train_test_dice_history(
#         uprednet_train_mask_t1, uprednet_train_mask_t2, uprednet_train_mask_t3,
#         uprednet_test_mask_t1, uprednet_test_mask_t2, uprednet_test_mask_t3,
#         uprednet_train_mask_sd_t1, uprednet_train_mask_sd_t2, uprednet_train_mask_sd_t3,
#         uprednet_test_mask_sd_t1, uprednet_test_mask_sd_t2, uprednet_test_mask_sd_t3,
#         plot_title='UPredNet (SWA Ablation) Full Mask Dice History (Median ± SD)'
#     )

# # --- Final Message ---
# print("\n--- Training Complete ---")

In [None]:
# NOTE: Disabled cell (fully commented out) that saves the trained UPredNet and
# its Dice histories with save_final_experiment_data(). If re-enabled it needs
# `uprednet_model`, `args`, the fold index `k` and the `uprednet_*` history
# lists from the training cell; `k` is not defined in this section.
# ckpt_save_dir = Path('/Users/Pracioppo/Desktop//GA Forecasting//Saved Models//UPredNet')
# FINAL_EPOCH = args.num_epochs

# saved_path = save_final_experiment_data(
#     model=uprednet_model,  # Model reference changed to uprednet_model
#     final_epoch=FINAL_EPOCH,  
#     base_save_dir=ckpt_save_dir,  
#     k_fold_index=k,
    
#     # --- Pass all history lists from the main training loop (Variables changed from 'baseline' to 'uprednet') ---
#     all_iteration_losses=all_iteration_losses,
#     epoch_iteration_counts=epoch_iteration_counts,
    
#     # Residual
#     train_dice_t1=uprednet_train_residual_t1, train_dice_t2=uprednet_train_residual_t2, train_dice_t3=uprednet_train_residual_t3,
#     test_dice_t1=uprednet_test_residual_t1, test_dice_t2=uprednet_test_residual_t2, test_dice_t3=uprednet_test_residual_t3,
#     train_sd_t1=uprednet_train_residual_sd_t1, train_sd_t2=uprednet_train_residual_sd_t2, train_sd_t3=uprednet_train_residual_sd_t3,
#     test_sd_t1=uprednet_test_residual_sd_t1, test_sd_t2=uprednet_test_residual_sd_t2, test_sd_t3=uprednet_test_residual_sd_t3,
    
#     # Mask
#     train_mask_t1=uprednet_train_mask_t1, train_mask_t2=uprednet_train_mask_t2, train_mask_t3=uprednet_train_mask_t3,
#     test_mask_t1=uprednet_test_mask_t1, test_mask_t2=uprednet_test_mask_t2, test_mask_t3=uprednet_test_mask_t3,
#     train_mask_sd_t1=uprednet_train_mask_sd_t1, train_mask_sd_t2=uprednet_train_mask_sd_t2, train_mask_sd_t3=uprednet_train_mask_sd_t3,
#     test_mask_sd_t1=uprednet_test_mask_sd_t1, test_mask_sd_t2=uprednet_test_mask_sd_t2, test_mask_sd_t3=uprednet_test_mask_sd_t3,

#     # Set model name prefix
#     model_name_prefix="UPredNet"
# )

# if saved_path:
#     print(f"\n All experiment data saved successfully to: {saved_path.name}")

## Experiment Analysis

In [None]:
## Experiment Analysis (SWAU NET)
# K-fold summary of the full SWAU-Net model's TEST Dice history.

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\SWAU_Net'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (CONV LSTM)
# K-fold summary of the ConvLSTM baseline's TEST Dice history.

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\Conv LSTM Baseline'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (AXIAL UNet, i.e. ablate SWA)
# K-fold summary of the Axial U-Net (SWA ablation) TEST Dice history.

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\Axial_UNet'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (Ablate spatial attention)
# K-fold summary of the CNN (spatial-attention ablation) TEST Dice history.

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\CNN_Ablation'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (UPredNet, i.e. ablate spatiotemporal attention)
# K-fold summary of the UPredNet (spatiotemporal-attention ablation) TEST Dice history.

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\UPredNet'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (SWAU NET, No Pretraining)
# K-fold summary of SWAU-Net trained without pretraining (TEST Dice history).

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\No Pretraining'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (SWAU NET, No Augmentations)
# K-fold summary of SWAU-Net trained without data augmentation (TEST Dice history).

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\No Augmentation'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (CFB Ablation)
# K-fold summary of the CFB-ablation run's TEST Dice history.
# (Header previously read "SWAU NET, No Augmentations", a copy-paste leftover;
# this cell analyzes the 'CFB Ablation' folder.)

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\CFB Ablation'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (DynNet Ablation)
# K-fold summary of the DynNet-ablation run's TEST Dice history.
# NOTE(review): this cell was a verbatim copy of the CFB Ablation cell (same
# folder, mislabeled "No Augmentations"). 'DynNet_Ablation' is the only Model-B
# directory listed in the comparison cell without its own analysis cell, so it
# is analyzed here instead — confirm this is the intended target.

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\DynNet_Ablation'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
## Experiment Analysis (ConvLSTM Simple Baseline)
# K-fold summary of the simple ConvLSTM baseline's TEST Dice history.
# (Header previously read "SWAU NET, No Augmentations", a copy-paste leftover;
# this cell analyzes the 'ConvLSTM_Simple_Baseline' folder.)

# Folder containing one sub-folder per fold (e.g. 'Fold_0', 'Fold_1', ...).
BASE_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\ConvLSTM_Simple_Baseline'

# Per-fold NumPy dictionaries holding the TEST Dice history.
RESIDUAL_HISTORY_FILE, MASK_HISTORY_FILE = 'residual_dice_history.npy', 'mask_dice_history.npy'

# Dictionary keys: T=3 median-score series and its standard-deviation series.
MEAN_SCORE_KEY, STD_SCORE_KEY = 'test_t3', 'test_sd_t3'

# Number of final epochs averaged for stability; LAST_EPOCH is presumably the
# end of that window (see analyze_kfold_results).
LAST_EPOCH, N_EPOCHS_TO_AVERAGE = 50, 10

# Position of T=3 in a [T1, T2, T3] sequence. The keys above already select
# T=3, so this is presumably not used for slicing; it is passed for the signature.
T_CRITICAL_INDEX = 2

# In a Jupyter kernel __name__ is '__main__', so this guard always passes.
if __name__ == "__main__":
    analyze_kfold_results(
        BASE_DIR,
        RESIDUAL_HISTORY_FILE,
        MASK_HISTORY_FILE,
        MEAN_SCORE_KEY,
        STD_SCORE_KEY,
        N_EPOCHS_TO_AVERAGE,
        T_CRITICAL_INDEX,
        LAST_EPOCH,
    )


In [None]:
# --- Statistical comparison: Model A (full SWAU-Net) vs. one selected Model B ---

# --- 1. SAVED-HISTORY CONFIGURATION ---
# Expected filenames for the NumPy array dictionaries containing the TEST dice history.
RESIDUAL_HISTORY_FILE = 'residual_dice_history.npy'
MASK_HISTORY_FILE = 'mask_dice_history.npy'

# Keys used to access the relevant arrays inside the saved dictionaries.
# These keys are assumed to hold the scores for the critical T=3 prediction step.
MEAN_SCORE_KEY = 'test_t3'
STD_SCORE_KEY = 'test_sd_t3'

# Index corresponding to the critical T=3 prediction step (index 2 for a sequence [T1, T2, T3])
T_CRITICAL_INDEX = 2

# Epoch window, named like (and equal to) the per-model analysis cells above.
# Presumably n_epochs/last_epoch in compare_models mean the same averaging window.
LAST_EPOCH = 50
N_EPOCHS_TO_AVERAGE = 10

# Fixed statistical parameters based on the study design.
N_TOTAL_SAMPLES = 66
K_FOLDS = 5

# --- 2. MODEL PATHS ---
# NOTE(review): these are Windows paths, while the setup cell `%cd`s into a POSIX
# path; adjust them for the machine that runs this notebook.

# Model A: Full SWAU-Net (the proposed model).
MODEL_A_DIR = r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\SWAU_Net'

# Candidate Model B runs (ablations and baselines). Choose one via MODEL_B_NAME
# instead of commenting/uncommenting path lines.
MODEL_B_DIRS = {
    'CFB Ablation': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\CFB Ablation',
    'No Augmentation': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\No Augmentation',
    'No Pretraining': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\No Pretraining',
    'UPredNet': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\UPredNet',
    'CNN_Ablation': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\CNN_Ablation',
    'Axial_UNet': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\Axial_UNet',
    'Conv LSTM Baseline': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\Conv LSTM Baseline',
    'DynNet_Ablation': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\DynNet_Ablation',
    'ConvLSTM_Simple_Baseline': r'C:\Users\Pracioppo\Desktop\GA Forecasting\Saved Models\ConvLSTM_Simple_Baseline',
}

# Model B currently compared against SWAU-Net: the simple ConvLSTM baseline.
MODEL_B_NAME = 'ConvLSTM_Simple_Baseline'
MODEL_B_DIR = MODEL_B_DIRS[MODEL_B_NAME]

# --- 3. EXECUTION OF COMPARISON ---

print("Starting statistical analysis...")
compare_models(
    model_a_dir=MODEL_A_DIR,
    model_b_dir=MODEL_B_DIR,
    RESIDUAL_HISTORY_FILE=RESIDUAL_HISTORY_FILE,
    MASK_HISTORY_FILE=MASK_HISTORY_FILE,
    MEAN_SCORE_KEY=MEAN_SCORE_KEY,
    STD_SCORE_KEY=STD_SCORE_KEY,
    T_CRITICAL_INDEX=T_CRITICAL_INDEX,
    N_TOTAL_SAMPLES=N_TOTAL_SAMPLES,
    K_FOLDS=K_FOLDS,
    n_epochs=N_EPOCHS_TO_AVERAGE,
    last_epoch=LAST_EPOCH,
)