## Setup

In [None]:
## Setup

%cd /Users/Pracioppo/Desktop/GA Forecasting

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import functools
from torch.nn import init
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau 

import numpy as np
import argparse
import random
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from pathlib import Path
from datetime import datetime
import scipy.io as sio
from PIL import Image
import cv2
from tabulate import tabulate
from collections import defaultdict

# Assuming these utilities are available as imported
from torch.utils.data import Dataset, DataLoader, random_split, Subset 

# ---------------------------------------

from preprocessing_utils import f_rescale_dataset, f_Residuals, f_reshape_training_data, f_rotate_and_zoom, f_random_crop, f_rotate_and_zoom_all, f_crop_all, f_flip_all, f_augment_dataset2

from data_utils import DataWrapper, visualize_sample, compare_split_masks

from visualization import f_display_autoencoder, plot_log_loss, f_display_frames

from models import init_weights, count_parameters
from models import rotate_half, RotaryPositionalEmbedding, RoPEMultiheadAttention, RoPETransformerEncoderLayer, ResidualBlock, ChannelReducer, Unet_Enc, Unet_Dec, U_Net_AE

from augmentation_utils import f_augment_spatial_and_intensity
from training_utils import dsc, dice_loss, GDLoss
from training_utils import f_single_epoch_AE, f_single_epoch_UPredNet, calculate_total_loss, f_single_epoch_UPredNet_Accumulated

# ---------------------------------------

# --- PATHS ---
# Single project root so the checkpoint / TensorBoard locations cannot drift apart.
PROJECT_DIR = Path('/Users/Pracioppo/Desktop/GA Forecasting')
ckpt_save_dir = PROJECT_DIR / 'Saved Models'
tensorboard_save_dir = PROJECT_DIR / 'Saved Models'  # currently shares the checkpoint folder

start_epoch = 0      # NOTE(review): not referenced in the visible cells
resume_ckpt = None   # NOTE(review): not referenced in the visible cells

summary_writer = SummaryWriter(tensorboard_save_dir.absolute().as_posix())


# --- SETUP ---
# argparse is used only as an attribute container; parse_args([]) ignores the
# notebook kernel's own sys.argv.
parser = argparse.ArgumentParser('AE Model Args')
args = parser.parse_args(args=[])

# Model hyperparameters (lightweight AE setup)
args.N = 4                       # Batch size; kept small for memory safety
args.nhead = 4                   # Number of attention heads
args.d_attn1 = 192               # FFN dimension for L3 (112 channels)
args.d_attn2 = 384               # FFN dimension for L4 (224 channels)
args.img_channels = 3            # Three grayscale images (FAF, masks, growth masks)
args.img_sz = 256                # Image size 256x256
BASE_CHANNELS = 24               # Base channel width passed to U_Net_AE

# Training loop arguments
args.num_epochs = 10
args.show_example_epochs = 5     # Visualize reconstructions every N epochs
args.batch_size = args.N         # Batch size for iteration is args.N
args.num_t_steps = 4             # Time steps per sample

# Directory for AE checkpoints (relative to the working directory set by %cd)
resume_AE_ckpt = Path('./ae_checkpoints')

# Use the first CUDA GPU when present; otherwise fall back to CPU instead of
# failing later at model.to(device) on machines without CUDA.
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(torch.__version__)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Using {args.device} device")

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch random generators for reproducible runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Additionally seed every visible CUDA device when CUDA is present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(2025)

## Load the Data

In [None]:
# --- Data Loading and Global Preprocessing ---

# WARNING: Replace this with your actual data directory
DATA_DIR = Path('/Users/Pracioppo/Desktop/GA Forecasting/GA_Zubens_data')
print("Loading data...")

FAFS_PATH = DATA_DIR / 'fafs_reg3.mat'
MASKS_PATH = DATA_DIR / 'masks_reg3.mat'

# Reshape each .mat variable to [N, 1, 4, 256, 256]. N is inferred from the file
# (660 for the current data) instead of being hard-coded.
FAFs = torch.from_numpy(sio.loadmat(FAFS_PATH)["fafs_reg3"].astype(np.float32)).reshape(-1, 1, 4, 256, 256)
masks = torch.from_numpy(sio.loadmat(MASKS_PATH)["masks_reg3"].astype(np.float32)).reshape(-1, 1, 4, 256, 256)

if FAFs.shape[0] != masks.shape[0]:
    raise ValueError(
        f"FAF and mask files disagree on sample count: {FAFs.shape[0]} vs {masks.shape[0]}"
    )

# Global normalization to [0, 1] by the tensor-wide maximum (skipped if the max is 0).
# NOTE(review): the max is taken over ALL samples, including those later placed
# in the test fold.
mask_max = torch.max(masks)
masks /= mask_max if mask_max > 0 else 1.0
faf_max = torch.max(FAFs)
FAFs /= faf_max if faf_max > 0 else 1.0

# Residual calculation (applied to all samples before splitting)
all_residuals = f_Residuals(masks)

# --- 3. K-Fold Split using Subsets ---

# Instantiate the Dataset with all pre-processed feature tensors
all_data = DataWrapper(FAFs, masks, all_residuals)
print(f"\nTotal Dataset size (samples): {len(all_data)}")

# K-FOLD GROUP SETUP (to avoid data leakage from augmented samples): every
# consecutive run of SAMPLES_PER_GROUP samples is kept in the same split.
SAMPLES_PER_GROUP = 10
N_SAMPLES = len(all_data)  # 660 for the current data
if N_SAMPLES % SAMPLES_PER_GROUP != 0:
    # Otherwise the trailing samples would silently land in neither split.
    raise ValueError(
        f"N_SAMPLES={N_SAMPLES} is not a multiple of SAMPLES_PER_GROUP={SAMPLES_PER_GROUP}"
    )
N_GROUPS = N_SAMPLES // SAMPLES_PER_GROUP  # 66
K_FOLDS = 5

# Fold sizes in groups: the first `group_fold_remainder` folds get one extra group
# (66 groups -> one fold of 14 groups, four folds of 13 groups).
group_fold_size_base = N_GROUPS // K_FOLDS  # 13
group_fold_remainder = N_GROUPS % K_FOLDS   # 1

k = 0  # The current fold index (fixed at 0 for this script)
if not 0 <= k < K_FOLDS:
    raise ValueError(f"Fold index k={k} must be in [0, {K_FOLDS})")

G_all_indices = np.arange(N_GROUPS)

# For k=0 this is 13 + 1 = 14 groups
current_group_fold_size = group_fold_size_base + (1 if k < group_fold_remainder else 0)

G_start_idx = k * group_fold_size_base + min(k, group_fold_remainder)  # 0 for k=0
G_end_idx = G_start_idx + current_group_fold_size                      # 14 for k=0

G_test_indices = G_all_indices[G_start_idx:G_end_idx]
G_train_indices = np.concatenate([G_all_indices[:G_start_idx], G_all_indices[G_end_idx:]])

def map_group_to_sample_indices(group_indices, samples_per_group):
    """Expand group indices into the contiguous sample indices each group owns.

    Group ``g`` owns samples ``[g * samples_per_group, (g + 1) * samples_per_group)``.

    Args:
        group_indices: 1-D sequence/array of integer group indices.
        samples_per_group: Number of consecutive samples per group.

    Returns:
        1-D int64 NumPy array of sample indices, grouped in the order of
        ``group_indices``. Empty when ``group_indices`` is empty (the previous
        ``np.concatenate`` implementation raised ``ValueError`` in that case).
    """
    groups = np.asarray(group_indices, dtype=np.int64).reshape(-1)
    offsets = np.arange(samples_per_group, dtype=np.int64)
    # (n_groups, 1) + (samples_per_group,) -> (n_groups, samples_per_group);
    # row-major ravel keeps each group's samples contiguous.
    return (groups[:, None] * samples_per_group + offsets).ravel()

train_indices = map_group_to_sample_indices(G_train_indices, SAMPLES_PER_GROUP)
test_indices = map_group_to_sample_indices(G_test_indices, SAMPLES_PER_GROUP)

sz = len(test_indices)  # Should be 140 for k=0

# Create Subsets for training and testing
train_dataset = Subset(all_data, train_indices)
test_dataset = Subset(all_data, test_indices)

print("--- K-Fold Split Result (Group-Aware) ---")
print(f"Total Groups: {N_GROUPS}, Samples/Group: {SAMPLES_PER_GROUP}")
print(f"Fold {k} Groups: {len(G_test_indices)} (Fold Size: {sz})")
print(f"Train Dataset Samples: {len(train_dataset)} (Expected 520)")
print(f"Test Dataset Samples: {len(test_dataset)} (Expected 140)")

# --- 4. DATA ASSEMBLY FOR MANUAL ITERATION ---

# A DataLoader is used only to stack every training Subset item into one tensor
# (a single batch containing the whole training split, in index order).
N_TRAIN_SAMPLES = len(train_dataset)
if N_TRAIN_SAMPLES == 0:
    raise ValueError("Training split is empty; check K_FOLDS and the fold index k.")

temp_loader = DataLoader(
    train_dataset,
    batch_size=N_TRAIN_SAMPLES,
    shuffle=False,  # Keep sequential index order
    num_workers=0
)

# The single batch holds the whole training split on CPU.
# Shape (assuming DataWrapper items are [C, T, H, W] tensors): [N_TRAIN_SAMPLES, C, T, H, W]
full_clean_data_tensor_cpu = next(iter(temp_loader))

print(f"\nASSEMBLED TENSOR: full_clean_data_tensor_cpu size: {full_clean_data_tensor_cpu.size()}")

# --- DataLoader Setup (Only for the Test/Validation Set) ---

test_loader = DataLoader(
    test_dataset,
    batch_size=args.N,  # GPU batch size
    shuffle=False,
    num_workers=0,
    # Pinned host memory only speeds up host->CUDA copies; skip it otherwise.
    pin_memory=args.device.type == 'cuda'
)


# --- VISUALIZATION EXAMPLE (Side-by-Side) ---

visualize_sample(train_dataset, test_dataset, sample_idx=20, dataset_name='test')
# visualize_sample(train_dataset, test_dataset, sample_idx=500, dataset_name='train')
# visualize_sample(sample_idx=500, dataset_name='train')

In [None]:
# Optional per-group inspection: shows the first sample of every train and test
# group via visualize_sample. Disabled by default because it makes one call per
# group (52 train + 14 test groups for fold 0).
SHOW_ALL_GROUP_SAMPLES = False
if SHOW_ALL_GROUP_SAMPLES:
    for group_idx in range(len(train_dataset) // SAMPLES_PER_GROUP):
        visualize_sample(train_dataset, test_dataset,
                         sample_idx=group_idx * SAMPLES_PER_GROUP, dataset_name='train')
    for group_idx in range(len(test_dataset) // SAMPLES_PER_GROUP):
        visualize_sample(train_dataset, test_dataset,
                         sample_idx=group_idx * SAMPLES_PER_GROUP, dataset_name='test')

# Run the full visual split comparison
compare_split_masks(train_dataset, test_dataset, channel_type='mask')

In [None]:
# Side-by-side train/test comparison of the residual channel at time step 1.
compare_split_masks(
    train_dataset,
    test_dataset,
    channel_type='residual',
    time_step=1,
)

## Train the autoencoder

In [None]:
# --- TRAINING EXECUTION ---

# Create necessary path for checkpointing
resume_AE_ckpt.mkdir(exist_ok=True)
print(f"\nCheckpoints will be saved to: {resume_AE_ckpt.resolve()}")

# Model Initialization
AE_model = U_Net_AE(args, img_channels=args.img_channels, base_channels=BASE_CHANNELS).to(args.device)

total_params = count_parameters(AE_model)
encoder_params = count_parameters(AE_model.E1)
decoder_params = count_parameters(AE_model.D1)

# Create a table for clean output
param_data = [
    ["Total AE Model (U_Net_AE)", f"{total_params:,}"],
    ["Encoder (E1)", f"{encoder_params:,}"],
    ["Decoder (D1)", f"{decoder_params:,}"],
]

print(tabulate(param_data, headers=["Component", "Parameters (Trainable)"], tablefmt="fancy_grid"))

print(f"\nModel Initialized with a total of {total_params:,} trainable parameters.")

# -------------------------------------

# Per-epoch log of the mean loss. Epochs whose mean loss is non-positive or
# non-finite store INVALID_LOG_LOSS and are marked False in `valid_epochs`.
INVALID_LOG_LOSS = -100
log_mean_epoch_losses = np.zeros(args.num_epochs)
valid_epochs = np.zeros(args.num_epochs, dtype=bool)

# Loss functions
loss_fn_bce = nn.BCELoss(reduction='mean')  # FAF BCE Loss
loss_fn_l1 = nn.L1Loss(reduction='mean')    # Used for LLR-Loss
loss_fn_l2 = nn.MSELoss(reduction='mean')   # Used for Bottleneck Regularization
loss_fn_dice = dice_loss                    # Custom Dice Loss from training_utils
loss_fn_gdl = GDLoss(alpha=1, beta=1)       # GDLoss from training_utils, alpha=1, beta=1

lr = 1E-3
optimizer = torch.optim.Adam(AE_model.parameters(), lr=lr, betas=(0.9, 0.999))

print(f"\nStarting training on device: {args.device} for {args.num_epochs} epochs using real data...")

LLR_WEIGHT = 1e-5            # L1 penalty weight (passed as lambda_llr)
BOTTLENECK_L2_WEIGHT = 1e-6  # Bottleneck L2 weight (passed as lambda_bottleneck)

# Training loop
for epoch in tqdm(range(args.num_epochs), desc="Training progress..."):
    # One pass over the pre-assembled CPU training tensor, in batches of args.batch_size.
    epoch_losses = f_single_epoch_AE(
        full_clean_data_tensor_cpu, AE_model, optimizer, loss_fn_bce, loss_fn_dice, loss_fn_gdl, loss_fn_l1, loss_fn_l2, args, args.batch_size,
        lambda_gdl=1e-3, lambda_residual=5.0, lambda_llr=LLR_WEIGHT, lambda_bottleneck=BOTTLENECK_L2_WEIGHT
    )

    # Log of the mean loss; np.isfinite rejects NaN/inf, and log needs a positive input.
    mean_epoch_loss = np.mean(epoch_losses)
    if np.isfinite(mean_epoch_loss) and mean_epoch_loss > 0:
        log_mean_epoch_losses[epoch] = np.log(mean_epoch_loss)
        valid_epochs[epoch] = True
    else:
        log_mean_epoch_losses[epoch] = INVALID_LOG_LOSS

    # Visualization and Plotting (at set intervals)
    if epoch % args.show_example_epochs == 0:
        print(f"\n--- Epoch {epoch}: Mean Loss (log): {log_mean_epoch_losses[epoch]:.4f} ---")
        # Reconstructions for every time step.
        for t in range(args.num_t_steps):
            f_display_autoencoder(train_dataset, AE_model, args, time_step=t)

        # Plot against the true epoch numbers so skipped (invalid) epochs do not
        # shift later points to the left.
        epochs_so_far = np.arange(epoch + 1)
        shown = valid_epochs[:epoch + 1]
        plt.figure()
        plt.plot(epochs_so_far[shown], log_mean_epoch_losses[:epoch + 1][shown])
        plt.title('Log Mean Epoch Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Log Loss')
        plt.grid()
        plt.show()

print("\nTraining complete. Saving final checkpoints...")

# Final Checkpoint Saving
F_Enc_path_save = resume_AE_ckpt.joinpath('U_F_Enc_real_data_ckpt1.pth')
F_Dec_path_save = resume_AE_ckpt.joinpath('U_F_Dec_real_data_ckpt1.pth')

# Save the encoder/decoder state_dicts (parameters and buffers, not module objects).
torch.save(AE_model.E1.state_dict(), F_Enc_path_save)
torch.save(AE_model.D1.state_dict(), F_Dec_path_save)

print(f"Checkpoints saved: {F_Enc_path_save.name} and {F_Dec_path_save.name}")

# Final reconstructions for every time step.
for t in range(args.num_t_steps):
    f_display_autoencoder(train_dataset, AE_model, args, time_step=t)

print("Script finished execution.")
