# In [None]:
# CURR V2 V2

#!/usr/bin/env python3
# CURRENT
"""
Fully‑GPU Batched Sculpt3DEnv + DQN Agent (Hybrid State & Noisy Nets)


• State Representation: Hybrid - CNN processes Stock/Mask grids, concatenates
                      with normalized XYZ coordinates.
• Agent Network: Uses a 3D CNN feature extractor + Dense head with Noisy Layers.
• Exploration: Uses Noisy Networks instead of epsilon-greedy.
• N envs stepped in parallel during training.
• Batched replay buffer insertion.
• Includes periodic evaluation during training and final evaluation/rendering.
• Plots carving performance trend at the end of training.
• **NEW:** Periodic saving of model weights during training.


NOTE: Monitor memory usage. Tune hyperparameters (LR, network, etc.).
     Periodic evaluation adds runtime overhead.
"""


import datetime
import io   # For TensorBoard image logging
import math # For noisy layer initialization
import os
import random
import time
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401


# --- GPU setup: ask TensorFlow to grow device memory on demand ---
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPU detected by TensorFlow.")
else:
    try:
        # Memory growth has to be enabled on every visible physical device.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs configured with Memory Growth.")
    except RuntimeError as err:
        # Raised e.g. when the GPU was already initialised before this point.
        print(f"Memory growth error: {err}")




# -----------------------------------------------------------------------------
# 1) Batched Sculpt3DEnvTF (Hybrid Observation)
# -----------------------------------------------------------------------------
class BatchedSculpt3DEnvTF:
    """N voxel-carving environments stepped in lockstep, with state in TF variables.

    Every environment owns a G x G x G block of stock (stored flattened) and a
    spherical target shape that must not be entered.  The agent occupies one
    voxel and moves one cell per step along +/-X, +/-Y or +/-Z; moving onto a
    voxel outside the target that still holds stock removes that stock.

    Flat voxel index convention: ``flat = x*G*G + y*G + z``.

    Observations are a tuple ``(grid_obs, coord_obs)``:
      * ``grid_obs``  float32 [N, G, G, G, 2] - channel 0 stock, channel 1 target mask.
      * ``coord_obs`` float32 [N, 3] - agent x, y, z divided by max(1, G - 1).
    """

    def __init__(self, grid_size=16, max_steps=200, n_envs=16):
        """Create the batched state variables.

        Args:
            grid_size: edge length G of the cubic voxel grid.
            max_steps: number of steps after which every environment is done.
            n_envs: number N of environments simulated in parallel.

        Raises:
            ValueError: if ``n_envs`` is not positive.
        """
        G, N = grid_size, n_envs
        if N <= 0:
            raise ValueError("n_envs must be positive.")
        self.G, self.N, self.max_steps = G, N, max_steps
        self.flat_dim = G * G * G
        self.grid_obs_shape = (G, G, G, 2)  # Channels: Stock, ShapeMask
        self.coord_obs_shape = (3,)         # X, Y, Z (normalized)

        # Target shape: all voxels within (G // 2 - 1) of the grid centre.
        coords_range = tf.range(G, dtype=tf.float32)
        coords = tf.stack(
            tf.meshgrid(coords_range, coords_range, coords_range, indexing='ij'),
            axis=-1)  # [G, G, G, 3]
        center = tf.constant([G / 2 - 0.5, G / 2 - 0.5, G / 2 - 0.5], tf.float32)
        dist2 = tf.reduce_sum(tf.square(coords - center), axis=-1)  # [G, G, G]
        radius_sq = tf.square(tf.cast(G // 2 - 1, tf.float32))
        mask3d = dist2 <= radius_sq
        mask_flat = tf.reshape(mask3d, [-1])  # [G^3]

        # Per-environment state.
        self.shape_mask = tf.Variable(tf.tile(mask_flat[None, :], [N, 1]),
                                      trainable=False, dtype=tf.bool, name="shape_mask")
        self.stock = tf.Variable(tf.ones([N, self.flat_dim], dtype=tf.bool),
                                 trainable=False, name="stock")  # Start with full stock
        self.pos = tf.Variable(tf.zeros([N], dtype=tf.int32), trainable=False, name="pos")
        self.steps = tf.Variable(tf.zeros([N], dtype=tf.int32), trainable=False, name="steps")
        self.done = tf.Variable(tf.zeros([N], dtype=tf.bool), trainable=False, name="done")

        # Action table.  ``move_deltas`` holds the per-axis step of each action and
        # ``shifts`` the matching offset in the flattened index.
        moves = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
        self.move_deltas = tf.constant(moves, dtype=tf.int32, name="move_deltas")  # [6, 3]
        self.shifts = tf.constant([dx * G * G + dy * G + dz for dx, dy, dz in moves],
                                  dtype=tf.int32, name="shifts")

    @tf.function
    def reset(self):
        """Restore full stock, zero the counters and draw new start voxels.

        Start positions are N distinct voxels sampled from those outside the
        target shape of environment 0.

        Returns:
            The initial observation tuple ``(grid_obs, coord_obs)``.
        """
        self.stock.assign(tf.ones_like(self.stock))
        self.steps.assign(tf.zeros_like(self.steps))
        self.done.assign(tf.zeros_like(self.done))

        # Flat indices of voxels that are not part of the target shape.
        safe_indices = tf.where(tf.logical_not(self.shape_mask[0]))[:, 0]
        num_safe = tf.shape(safe_indices)[0]
        tf.Assert(num_safe >= self.N,
                  ["Not enough safe starting positions available.", num_safe, self.N])

        # Shuffle and keep the first N, so no two environments share a start voxel.
        shuffled_safe_indices = tf.random.shuffle(safe_indices)[:self.N]
        self.pos.assign(tf.cast(shuffled_safe_indices, tf.int32))

        return self._get_obs()

    @tf.function
    def step(self, actions):
        """Advance every environment by one action.

        Args:
            actions: int32 tensor [N] of indices into the six-move action table.

        Returns:
            ``(next_obs, reward, done)`` where ``reward`` is float32 [N] and
            ``done`` is bool [N].  Reward per step is -0.1, plus +1.0 when stock
            is removed, plus -5.0 when the move is rejected (leaves the grid or
            enters the target shape).  A rejected move leaves the agent in place.
        """
        G = self.G
        cur_pos = self.pos.read_value()

        # Decode the flat position so each axis can be bounds-checked on its own.
        # Checking only the flat index would let a +/-Y or +/-Z move at a grid
        # face wrap around onto the neighbouring row/slab.
        cur_xyz = tf.stack([cur_pos // (G * G), (cur_pos // G) % G, cur_pos % G], axis=1)  # [N, 3]
        new_xyz = cur_xyz + tf.gather(self.move_deltas, actions)                            # [N, 3]
        in_bounds = tf.reduce_all(tf.logical_and(new_xyz >= 0, new_xyz < G), axis=1)        # [N]

        new_pos = cur_pos + tf.gather(self.shifts, actions)  # [N] candidate flat indices
        # Clipped copy so the gathers below stay valid for out-of-grid candidates.
        safe_new_pos = tf.clip_by_value(new_pos, 0, self.flat_dim - 1)

        shape_mask_at_new = tf.gather(self.shape_mask, safe_new_pos, axis=1, batch_dims=1)  # [N]
        stock_at_new = tf.gather(self.stock, safe_new_pos, axis=1, batch_dims=1)            # [N]

        # Rejected move: outside the grid or into the target shape.
        hit_shape_or_oob = tf.logical_or(tf.logical_not(in_bounds), shape_mask_at_new)
        # Stock is removed only on an accepted move onto a voxel that still has stock.
        can_remove = tf.logical_and(tf.logical_not(hit_shape_or_oob), stock_at_new)

        reward = tf.where(hit_shape_or_oob, -5.0, 0.0)       # Penalty for invalid move
        reward = tf.where(can_remove, reward + 1.0, reward)  # Reward for removing stock
        reward = reward - 0.1                                # Step penalty

        # Clear the stock bit at (env, new_pos) for every environment that carved.
        remove_indices = tf.where(can_remove)  # [k, 1]
        num_removals = tf.shape(remove_indices)[0]

        def perform_update():
            env_ids = tf.squeeze(tf.cast(remove_indices, tf.int32), axis=1)  # [k]
            carved_pos = tf.gather(new_pos, env_ids)                         # [k]
            scatter_indices = tf.stack([env_ids, carved_pos], axis=1)        # [k, 2]
            updates = tf.zeros(num_removals, dtype=tf.bool)
            return tf.tensor_scatter_nd_update(self.stock, scatter_indices, updates)

        # Skip the scatter entirely when nothing was removed.
        maybe_updated_stock = tf.cond(num_removals > 0,
                                      true_fn=perform_update,
                                      false_fn=lambda: self.stock)
        self.stock.assign(maybe_updated_stock)

        # Move only where the move was accepted.
        is_valid_move = tf.logical_not(hit_shape_or_oob)
        self.pos.assign(tf.where(is_valid_move, new_pos, cur_pos))

        # Episodes end purely on the step budget.
        self.steps.assign_add(tf.ones_like(self.steps))
        newly_done = (self.steps >= self.max_steps)
        self.done.assign(tf.logical_or(self.done, newly_done))

        next_obs = self._get_obs()
        return next_obs, tf.cast(reward, tf.float32), self.done

    @tf.function
    def _get_obs(self):
        """Build the ``(grid_obs, coord_obs)`` observation from the current state."""
        G = self.G
        N = self.N

        # Flattened boolean grids -> float channels of a [N, G, G, G, 2] volume.
        stock_float = tf.cast(tf.reshape(self.stock, [N, G, G, G]), tf.float32)
        shape_mask_float = tf.cast(tf.reshape(self.shape_mask, [N, G, G, G]), tf.float32)
        grid_obs = tf.stack([stock_float, shape_mask_float], axis=-1)

        # Decode the flat position into integer x, y, z.
        g_tf = tf.constant(G, dtype=tf.int32)
        z = self.pos % g_tf
        y = (self.pos // g_tf) % g_tf
        x = self.pos // (g_tf * g_tf)

        # Scale to [0, 1]; the max() avoids dividing by zero when G == 1.
        g_minus_1_float = tf.cast(tf.maximum(1, G - 1), tf.float32)
        x_norm = tf.cast(x, tf.float32) / g_minus_1_float
        y_norm = tf.cast(y, tf.float32) / g_minus_1_float
        z_norm = tf.cast(z, tf.float32) / g_minus_1_float

        coord_obs = tf.stack([x_norm, y_norm, z_norm], axis=-1)  # [N, 3]

        return (grid_obs, coord_obs)


# -----------------------------------------------------------------------------
# 2) Replay Buffer (Handling Tuple State)
# -----------------------------------------------------------------------------
class BatchedReplayBuffer:
    """Bounded FIFO replay memory whose entries are whole batched transitions.

    Each entry is the tuple ``(S_tuple, A, R, S2_tuple, D)`` exactly as passed to
    ``add_batch``; once the capacity is reached the oldest entry is dropped.
    """

    def __init__(self, capacity=50000):
        """Create an empty buffer holding at most ``capacity`` entries."""
        # deque(maxlen=...) evicts from the left automatically when full.
        self.buf = deque(maxlen=capacity)

    @property
    def cap(self):
        """Maximum number of entries the buffer can hold."""
        return self.buf.maxlen

    @cap.setter
    def cap(self, capacity):
        # A deque's maxlen is read-only, so changing the capacity means rebuilding
        # the deque.  When shrinking, the most recent entries are the ones kept.
        if capacity != self.buf.maxlen:
            self.buf = deque(self.buf, maxlen=capacity)

    def add_batch(self, S_tuple, A, R, S2_tuple, D):
        """Append one batched transition; ``S_tuple``/``S2_tuple`` are (grid, coord) pairs."""
        self.buf.append((S_tuple, A, R, S2_tuple, D))

    def sample(self, batch_size=32):
        """Draw ``batch_size`` distinct entries and merge them along axis 0.

        Returns:
            ``None`` when fewer than ``batch_size`` entries are stored, otherwise
            the 7-tuple ``(S_grid, S_coord, A, R, S2_grid, S2_coord, D)`` where
            each element is the ``tf.concat`` (axis 0) of that component over
            the sampled entries.
        """
        num_stored = len(self.buf)
        if num_stored < batch_size:
            return None

        indices = random.sample(range(num_stored), batch_size)
        batch = [self.buf[i] for i in indices]

        # Transpose list-of-entries into per-component tuples.
        S_tuple_list, A_list, R_list, S2_tuple_list, D_list = zip(*batch)
        S_grid_list, S_coord_list = zip(*S_tuple_list)
        S2_grid_list, S2_coord_list = zip(*S2_tuple_list)

        return (
            tf.concat(S_grid_list, axis=0), tf.concat(S_coord_list, axis=0),    # State S
            tf.concat(A_list, axis=0),                                          # Action A
            tf.concat(R_list, axis=0),                                          # Reward R
            tf.concat(S2_grid_list, axis=0), tf.concat(S2_coord_list, axis=0),  # Next state S2
            tf.concat(D_list, axis=0),                                          # Done D
        )

    def __len__(self):
        """Number of stored entries (batched transitions, not individual ones)."""
        return len(self.buf)




# -----------------------------------------------------------------------------
# 3) Noisy Dense Layer (Factorized Gaussian Noise)
# -----------------------------------------------------------------------------
class NoisyDense(tf.keras.layers.Layer):
    """Dense layer whose kernel and bias carry learnable factorized Gaussian noise.

    With ``training`` truthy the layer computes
    ``inputs @ (kernel_mean + kernel_sigma * eps_w) + (bias_mean + bias_sigma * eps_b)``
    with freshly sampled ``eps``; otherwise only the mean parameters are used.
    """

    def __init__(self, units, activation=None, sigma0=0.5, **kwargs):
        """Store the layer width, the activation and the initial noise scale ``sigma0``."""
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.sigma0 = sigma0

    def build(self, input_shape):
        """Create mean and sigma variables for the kernel and the bias."""
        fan_in = input_shape[-1]
        fan_out = self.units
        # Every sigma starts at sigma0 / sqrt(fan_in).
        sigma_init = tf.constant_initializer(self.sigma0 / math.sqrt(float(fan_in)))

        def new_weight(name, shape, initializer):
            return self.add_weight(name=name, shape=shape, initializer=initializer,
                                   trainable=True, dtype=tf.float32)

        # Creation order is kept stable because it defines the get_weights() order.
        self.kernel_mean = new_weight("kernel_mean", (fan_in, fan_out), "he_uniform")
        self.kernel_sigma = new_weight("kernel_sigma", (fan_in, fan_out), sigma_init)
        self.bias_mean = new_weight("bias_mean", (fan_out,), "zeros")
        self.bias_sigma = new_weight("bias_sigma", (fan_out,), sigma_init)

        super().build(input_shape)

    def call(self, inputs, training=None):
        """Apply the layer; noise is injected only when ``training`` is truthy."""
        if training is None:
            training = tf.keras.backend.learning_phase()

        if training:
            # One noise vector per input feature and one per output unit.
            eps_in = self._factorized_noise(tf.shape(inputs)[-1])  # [in_features]
            eps_out = self._factorized_noise(self.units)           # [out_features]

            # Outer product eps_in x eps_out -> [in_features, out_features].
            eps_kernel = tf.matmul(tf.expand_dims(eps_in, -1), tf.expand_dims(eps_out, 0))

            kernel = self.kernel_mean + self.kernel_sigma * eps_kernel
            # The bias reuses the output-side noise vector.
            bias = self.bias_mean + self.bias_sigma * eps_out
        else:
            kernel = self.kernel_mean
            bias = self.bias_mean

        result = tf.matmul(inputs, kernel) + bias
        return result if self.activation is None else self.activation(result)

    def _factorized_noise(self, num_elements):
        """Return ``sign(x) * sqrt(|x|)`` for ``num_elements`` standard-normal samples."""
        raw = tf.random.normal(shape=[num_elements])
        return tf.sign(raw) * tf.sqrt(tf.abs(raw))

    def compute_output_shape(self, input_shape):
        """Input shape with the last dimension replaced by ``units``."""
        return (*input_shape[:-1], self.units)




# -----------------------------------------------------------------------------
# 4) Batched DQN Agent (Hybrid State CNN + Noisy Nets)
# -----------------------------------------------------------------------------
class BatchedDQNAgentTF:
    """Double-DQN agent with a 3D-CNN + NoisyDense Q-network over (grid, coord) states.

    Holds an online network, a target network tracked by soft (Polyak) updates,
    an Adam optimizer, a replay buffer and a TensorBoard writer.
    """

    def __init__(self, grid_shape, coord_shape, action_dim=6,
                 lr=1e-4, gamma=0.99, tau=0.005):
        """Build both networks and the training utilities.

        Args:
            grid_shape: shape of one grid observation (without the batch axis).
            coord_shape: shape of one coordinate observation (without the batch axis).
            action_dim: number of discrete actions / Q-value outputs.
            lr: Adam learning rate.
            gamma: discount factor used in the TD target.
            tau: soft-update rate for the target network.
        """
        self.grid_shape = grid_shape
        self.coord_shape = coord_shape
        self.action_dim = action_dim
        self.gamma = gamma
        self.tau = tau

        def build_hybrid_noisy_model(name="Model"):
            grid_input = tf.keras.layers.Input(shape=self.grid_shape, name="grid_input")
            coord_input = tf.keras.layers.Input(shape=self.coord_shape, name="coord_input")

            # CNN feature extractor over the voxel grid.
            x_cnn = tf.keras.layers.Conv3D(filters=32, kernel_size=5, strides=2, activation='relu', padding='same')(grid_input)
            x_cnn = tf.keras.layers.Conv3D(filters=64, kernel_size=3, strides=2, activation='relu', padding='same')(x_cnn)
            x_cnn = tf.keras.layers.Conv3D(filters=64, kernel_size=3, strides=1, activation='relu', padding='same')(x_cnn)
            cnn_features = tf.keras.layers.Flatten()(x_cnn)

            # Append the agent coordinates to the flattened CNN features.
            concat_features = tf.keras.layers.Concatenate()([cnn_features, coord_input])

            # Noisy head producing one linear Q-value per action.
            x = NoisyDense(512, activation='relu')(concat_features)
            outputs = NoisyDense(action_dim, activation='linear')(x)

            return tf.keras.Model(inputs=[grid_input, coord_input], outputs=outputs, name=name)

        self.model = build_hybrid_noisy_model(name="Online_Model")
        self.target = build_hybrid_noisy_model(name="Target_Model")
        # Start the target network as an exact copy of the online network.
        self.target.set_weights(self.model.get_weights())

        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.buffer = BatchedReplayBuffer()

        logdir = f"runs/hybrid_noisy_dqn_{datetime.datetime.now():%Y%m%d_%H%M%S}"
        self.writer = tf.summary.create_file_writer(logdir)
        print(f"TensorBoard log directory: {logdir}")

        # Counts completed learn() calls; used as the TensorBoard step.
        self.train_step_count = tf.Variable(0, dtype=tf.int64, trainable=False, name="train_steps")

    def _soft_update_target(self):
        """Move each target weight a fraction ``tau`` towards its online counterpart."""
        # Variables are assigned directly: Model.set_weights expects NumPy values
        # and cannot take the symbolic tensors that exist inside a tf.function.
        for w_online, w_target in zip(self.model.weights, self.target.weights):
            w_target.assign(self.tau * w_online + (1.0 - self.tau) * w_target)

    @tf.function
    def train_step(self, S_grid, S_coord, A, R, S2_grid, S2_coord, D):
        """Run one Double-DQN gradient step followed by a soft target update.

        Args:
            S_grid, S_coord: current-state observations.
            A: int32 actions taken, one per sample.
            R: float rewards, one per sample.
            S2_grid, S2_coord: next-state observations.
            D: done flags, one per sample.

        Returns:
            The scalar mean-squared TD loss.
        """
        # Double DQN: the online network picks the next action ...
        # (training=True keeps the NoisyDense layers in their noisy mode)
        Q2_online = self.model([S2_grid, S2_coord], training=True)
        best_actions_next = tf.argmax(Q2_online, axis=1, output_type=tf.int32)

        # ... and the target network evaluates it.
        Q2_target = self.target([S2_grid, S2_coord], training=True)
        action_indices = tf.stack(
            [tf.range(tf.shape(best_actions_next)[0], dtype=tf.int32), best_actions_next], axis=1)
        Q2_best_target = tf.gather_nd(Q2_target, action_indices)

        # TD target: R + gamma * Q_target(S', argmax_a Q_online(S', a)) * (1 - D)
        target_Q = R + self.gamma * Q2_best_target * (1.0 - tf.cast(D, tf.float32))

        with tf.GradientTape() as tape:
            Q_online = self.model([S_grid, S_coord], training=True)
            # Q-value of the action that was actually taken in each sample.
            action_indices_taken = tf.stack([tf.range(tf.shape(A)[0], dtype=tf.int32), A], axis=1)
            Q_online_taken = tf.gather_nd(Q_online, action_indices_taken)
            loss = tf.keras.losses.MeanSquaredError()(target_Q, Q_online_taken)

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.model.trainable_variables))

        self._soft_update_target()

        return loss

    @tf.function
    def act_batch(self, S_tuple, deterministic=False):
        """Return the greedy int32 action for each state in ``S_tuple``.

        ``deterministic=True`` calls the network with ``training=False`` so the
        NoisyDense layers use their mean weights; otherwise they stay noisy,
        which is what provides exploration.
        """
        q_values = self.model(S_tuple, training=not deterministic)
        return tf.argmax(q_values, axis=1, output_type=tf.int32)

    def remember_batch(self, S_tuple, A, R, S2_tuple, D):
        """Store one batched transition in the replay buffer."""
        self.buffer.add_batch(S_tuple, A, R, S2_tuple, D)

    def learn(self, batch_size=32):
        """Sample ``batch_size`` buffer entries, train on them and log the loss.

        Returns:
            The loss as a NumPy scalar, or ``None`` if the buffer holds fewer
            than ``batch_size`` entries.
        """
        if len(self.buffer) < batch_size:
            return None

        sampled_data = self.buffer.sample(batch_size)
        if sampled_data is None:  # Defensive: sample() applies the same size check.
            return None

        S_grid_s, S_coord_s, A_s, R_s, S2_grid_s, S2_coord_s, D_s = sampled_data
        loss = self.train_step(S_grid_s, S_coord_s, A_s, R_s, S2_grid_s, S2_coord_s, D_s)

        step = self.train_step_count.numpy()
        with self.writer.as_default(step=step):
            tf.summary.scalar("Train/Loss", loss)

        self.train_step_count.assign_add(1)

        return loss.numpy()


# -----------------------------------------------------------------------------
# 5) Evaluation Function (Moved Before Training Loop & Returns Stats)
# -----------------------------------------------------------------------------
def _render_eval_final_state(final_stock_flat, shape_mask_flat, grid_size,
                             render_env_index, env_reward, removal_percentage):
    """Save a 3D voxel plot of one evaluation environment's final state.

    Rendering problems are reported on stdout and never propagate; the figure
    is closed whether or not saving succeeded.
    """
    fig = None
    try:
        G = grid_size
        final_stock_3d = final_stock_flat.reshape((G, G, G))
        shape_mask_3d = shape_mask_flat.reshape((G, G, G))

        removed_mask_render = ~shape_mask_3d & (~final_stock_3d)               # Correctly removed
        incorrectly_removed_mask_render = shape_mask_3d & (~final_stock_3d)    # Incorrectly removed

        fig = plt.figure(figsize=(9, 7))
        ax = fig.add_subplot(111, projection='3d')
        ax.set_facecolor('whitesmoke')
        # Voxel corner coordinates: one more than the voxel count along each axis.
        x_vox, y_vox, z_vox = np.indices(np.array(shape_mask_3d.shape) + 1)

        # Target shape (faint blue), removed stock (red), wrongly removed (yellow).
        ax.voxels(x_vox, y_vox, z_vox, shape_mask_3d, facecolors='blue', alpha=0.1, edgecolor=None)
        ax.voxels(x_vox, y_vox, z_vox, removed_mask_render, facecolors='red', alpha=0.6, edgecolor=None)
        if np.sum(incorrectly_removed_mask_render) > 0:
            ax.voxels(x_vox, y_vox, z_vox, incorrectly_removed_mask_render, facecolors='yellow',
                      alpha=0.7, edgecolor='orange', label='Incorrect Removal')
            ax.legend()

        ax.set_title(f"Eval Render Env #{render_env_index} (R={env_reward:.1f}, Removed={removal_percentage:.1f}%)")
        ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
        ax.set_xlim(0, G); ax.set_ylim(0, G); ax.set_zlim(0, G); ax.set_aspect('auto')
        plt.tight_layout()

        render_dir = "renders_eval"
        os.makedirs(render_dir, exist_ok=True)
        save_path = os.path.join(render_dir, f"eval_render_env_{render_env_index}_final.png")
        plt.savefig(save_path)
        print(f"Saved evaluation render to {save_path}")
    except Exception as e:
        # Rendering is best-effort: report the problem and keep the evaluation result.
        print(f"Error during rendering: {e}")
        import traceback; traceback.print_exc()
    finally:
        # Always release the figure, including when plotting or saving failed.
        if fig is not None:
            plt.close(fig)


def evaluate_agent_performance(agent, grid_size, max_steps, num_eval_episodes=10, render=True, render_env_index=0):
    """Roll out the agent without exploration noise and report carving statistics.

    A fresh ``BatchedSculpt3DEnvTF`` with ``num_eval_episodes`` parallel
    environments is created and each environment is run for one episode.

    Args:
        agent: object exposing ``act_batch(obs_tuple, deterministic=True)``.
        grid_size: grid edge length passed to the evaluation environment.
        max_steps: step budget per episode.
        num_eval_episodes: number of parallel evaluation environments.
        render: when True, save a voxel plot of one environment's final state.
        render_env_index: which evaluation environment to render.

    Returns:
        dict with keys ``avg_reward``, ``std_reward``, ``avg_length``,
        ``avg_removed_count``, ``avg_incorrect_removed``,
        ``avg_removal_percentage``, ``std_removal_percentage`` and
        ``initial_carvable_count``.
    """
    print(f"\n--- Running Evaluation ({num_eval_episodes} episodes) ---")
    eval_start_time = time.time()

    eval_env = BatchedSculpt3DEnvTF(grid_size=grid_size, max_steps=max_steps, n_envs=num_eval_episodes)

    # Carvable material = every voxel outside the target shape (env 0 is the reference).
    initial_shape_mask_flat_np = eval_env.shape_mask[0].numpy()
    initial_carvable_mask_flat = ~initial_shape_mask_flat_np
    initial_carvable_count = np.sum(initial_carvable_mask_flat)
    print(f"  Initial number of carvable voxels: {initial_carvable_count}")
    if initial_carvable_count == 0:
        print("  Warning: No carvable material defined.")

    # --- Rollout ---
    obs_tuple = eval_env.reset()
    done = tf.zeros([num_eval_episodes], tf.bool)
    ep_rewards = tf.zeros([num_eval_episodes], tf.float32)
    ep_steps = tf.zeros([num_eval_episodes], tf.int32)

    for _ in range(max_steps):
        if tf.reduce_all(done):
            break

        A = agent.act_batch(obs_tuple, deterministic=True)
        S2_tuple, R, next_done = eval_env.step(A)

        # Only environments that were still running accumulate reward and length.
        active_mask = ~done
        ep_rewards += R * tf.cast(active_mask, tf.float32)
        ep_steps += tf.cast(active_mask, tf.int32)

        obs_tuple = S2_tuple
        done = next_done

    # --- Per-episode results ---
    final_stock_batch_np = eval_env.stock.numpy()  # [episodes, G^3] bool
    all_ep_rewards = ep_rewards.numpy().tolist()
    all_ep_lengths = ep_steps.numpy().tolist()

    gone = ~final_stock_batch_np
    # Correctly removed: carvable and now gone.  Incorrect: part of the shape and now gone.
    all_ep_removed_counts = np.sum(initial_carvable_mask_flat[None, :] & gone, axis=1).tolist()
    all_ep_incorrect_removed_counts = np.sum(initial_shape_mask_flat_np[None, :] & gone, axis=1).tolist()

    # --- Aggregate statistics ---
    avg_reward = np.mean(all_ep_rewards); std_reward = np.std(all_ep_rewards)
    avg_length = np.mean(all_ep_lengths)
    avg_removed_count = np.mean(all_ep_removed_counts)
    avg_incorrect_removed = np.mean(all_ep_incorrect_removed_counts)

    if initial_carvable_count > 0:
        removal_percentages = [(c / initial_carvable_count) * 100.0 for c in all_ep_removed_counts]
        avg_removal_percentage = np.mean(removal_percentages)
        std_removal_percentage = np.std(removal_percentages)
    else:
        # Nothing to carve: report zero instead of dividing by zero.
        avg_removal_percentage = 0.0
        std_removal_percentage = 0.0

    print(f"\n--- Evaluation Results ---")
    print(f"  Avg Reward : {avg_reward:.2f} (+/- {std_reward:.2f})")
    print(f"  Avg Length : {avg_length:.1f}")
    print(f"  Avg Removed: {avg_removed_count:.1f} / {initial_carvable_count} ({avg_removal_percentage:.2f}% +/- {std_removal_percentage:.2f}%)")
    print(f"  Avg Incorrect: {avg_incorrect_removed:.1f}")
    eval_duration = time.time() - eval_start_time
    print(f"  Evaluation Duration: {eval_duration:.2f}s")

    # --- Optional render of one environment ---
    if render:
        print(f"\n--- Rendering final state for Eval Env Index: {render_env_index} ---")
        if render_env_index < 0 or render_env_index >= num_eval_episodes:
            print(f"Error: render_env_index ({render_env_index}) out of bounds for {num_eval_episodes} eval envs.")
        else:
            rendered_env_removed = all_ep_removed_counts[render_env_index]
            rendered_env_perc = (rendered_env_removed / initial_carvable_count) * 100.0 if initial_carvable_count > 0 else 0.0
            _render_eval_final_state(
                final_stock_batch_np[render_env_index],
                initial_shape_mask_flat_np,
                grid_size,
                render_env_index,
                all_ep_rewards[render_env_index],
                rendered_env_perc,
            )

    return {
        "avg_reward": avg_reward,
        "std_reward": std_reward,
        "avg_length": avg_length,
        "avg_removed_count": avg_removed_count,
        "avg_incorrect_removed": avg_incorrect_removed,
        "avg_removal_percentage": avg_removal_percentage,
        "std_removal_percentage": std_removal_percentage,
        "initial_carvable_count": initial_carvable_count
    }




# -----------------------------------------------------------------------------
# 6) Training Loop (Using Hybrid State & Noisy Agent) with Periodic Saving
# -----------------------------------------------------------------------------
def _log_train_render_to_tensorboard(env, agent, grid_size, ep):
    """Render environment #0 as a voxel plot and write it to TensorBoard.

    Best-effort: any failure is reported on stdout and swallowed so that a
    rendering problem never aborts training. The matplotlib figure is always
    closed, including on the error path, so repeated failures cannot
    accumulate open figures.

    Args:
        env: Batched environment exposing ``stock`` and ``shape_mask`` tensors;
            row 0 of each is reshaped to ``(grid_size, grid_size, grid_size)``.
        agent: Agent exposing ``writer`` (summary writer) and
            ``train_step_count`` (used as the summary step).
        grid_size: Edge length of the cubic voxel grid.
        ep: Current training episode number (used in the title / error text).
    """
    render_env_index = 0
    fig = None
    try:
        cube = (grid_size, grid_size, grid_size)
        stock_np = env.stock.numpy()[render_env_index].reshape(cube)
        shape_np = env.shape_mask.numpy()[render_env_index].reshape(cube)

        # NOTE(review): `~` is used as a logical NOT here, which assumes both
        # arrays are boolean -- confirm against the environment definition.
        removed_mask = (~shape_np) & (~stock_np)  # outside target shape and now gone
        incorrect_mask = shape_np & (~stock_np)   # inside target shape and now gone

        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection='3d')
        # Voxel corner coordinates: one more point than cells along each axis.
        x_vox, y_vox, z_vox = np.indices(np.array(stock_np.shape) + 1)

        ax.voxels(x_vox, y_vox, z_vox, shape_np, facecolors='blue', alpha=0.1)
        ax.voxels(x_vox, y_vox, z_vox, removed_mask, facecolors='red', alpha=0.6)
        if np.sum(incorrect_mask) > 0:
            ax.voxels(x_vox, y_vox, z_vox, incorrect_mask, facecolors='yellow', alpha=0.7)

        ax.set_title(f"Ep {ep} - Render Env #{render_env_index}")
        ax.set_axis_off()  # Clean look for TensorBoard
        fig.tight_layout()

        # Serialize this specific figure (not pyplot's implicit "current" one).
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            png_bytes = buf.getvalue()

        # Decode the PNG and add a leading batch dimension for tf.summary.image.
        image_tensor = tf.expand_dims(tf.image.decode_png(png_bytes, channels=4), 0)

        with agent.writer.as_default(step=agent.train_step_count.numpy()):
            tf.summary.image("EnvRender/Env0_Train", image_tensor)
    except Exception as e:
        print(f"Render log error at ep {ep}: {e}")
    finally:
        if fig is not None:
            plt.close(fig)  # Always release the figure, even after an error


def _plot_evaluation_trend(eval_episodes_list, eval_avg_rewards_list,
                           eval_avg_removal_perc_list, grid_size, n_envs):
    """Save a dual-axis plot of periodic evaluation results to ``plots/``.

    Removal percentage is drawn on the left axis and average reward on the
    right axis, both against the training episode at which the evaluation ran.
    Best-effort: errors are printed and swallowed; the figure is always closed.
    """
    print("\n--- Plotting Evaluation Trend ---")
    fig = None
    try:
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color = 'tab:red'
        ax1.set_xlabel('Training Episode')
        ax1.set_ylabel('Avg Carvable Material Removed (%)', color=color)
        ax1.plot(eval_episodes_list, eval_avg_removal_perc_list, color=color,
                 marker='o', linestyle='-', label='Removal %')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.grid(True, axis='y', linestyle=':')

        ax2 = ax1.twinx()  # Second y-axis sharing the same x-axis
        color = 'tab:blue'
        ax2.set_ylabel('Avg Evaluation Reward', color=color)
        ax2.plot(eval_episodes_list, eval_avg_rewards_list, color=color,
                 marker='x', linestyle='--', label='Avg Reward')
        ax2.tick_params(axis='y', labelcolor=color)

        fig.suptitle('Agent Evaluation Performance During Training')
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle

        plot_dir = "plots"
        os.makedirs(plot_dir, exist_ok=True)
        plot_path = os.path.join(plot_dir, f"evaluation_trend_g{grid_size}_n{n_envs}.png")
        fig.savefig(plot_path)
        print(f"Saved evaluation trend plot to {plot_path}")
    except Exception as e:
        print(f"Error plotting evaluation trend: {e}")
    finally:
        if fig is not None:
            plt.close(fig)


def train_gpu_batched(
        grid_size=16, max_steps=300, n_envs=32, episodes=10000,
        buffer_capacity=100000, learn_batch_size=32, learn_freq=4,
        gamma=0.99, lr=1e-4, tau=0.005, log_every=50,
        evaluate_every=100,
        num_eval_episodes_periodic=10,
        render_intermediate_eval=False,
        # <<< NEW Parameters for Saving >>>
        save_every_episodes=5,
        checkpoint_dir="checkpoints"
        ):
    """Train a batched DQN agent on ``BatchedSculpt3DEnvTF``.

    Each "episode" steps all ``n_envs`` environments in lock-step for at most
    ``max_steps`` steps (stopping early once every environment reports done),
    storing every batched transition and calling ``agent.learn`` every
    ``learn_freq`` batched steps. Progress is logged to stdout and TensorBoard,
    weights are checkpointed periodically, and the agent is evaluated
    periodically via ``evaluate_agent_performance``.

    Args:
        grid_size: Edge length of the cubic voxel grid.
        max_steps: Maximum number of batched steps per episode.
        n_envs: Number of environments stepped in parallel.
        episodes: Number of training episodes to run.
        buffer_capacity: Value assigned to ``agent.buffer.cap`` after the agent
            is built. NOTE(review): this only rebinds the attribute; confirm
            the buffer does not pre-allocate storage from its own default.
        learn_batch_size: Batch size forwarded to ``agent.learn``.
        learn_freq: Learn once every this many batched environment steps.
        gamma, lr, tau: Forwarded unchanged to ``BatchedDQNAgentTF``.
        log_every: Log every this many episodes (and at episode 1);
            values <= 0 disable periodic logging.
        evaluate_every: Evaluate every this many episodes; <= 0 disables.
        num_eval_episodes_periodic: Episodes per periodic evaluation.
        render_intermediate_eval: Forwarded as ``render`` to the evaluator.
        save_every_episodes: Save weights every this many episodes; <= 0 disables.
        checkpoint_dir: Directory that receives periodic weight files.

    Returns:
        The trained agent instance.
    """
    print("--- Training Hybrid Noisy DQN Agent ---")
    print(f"Params: Grid={grid_size}, N_Envs={n_envs}, MaxSteps={max_steps}, Episodes={episodes}")
    print(f"Learn Batch Size (B): {learn_batch_size}, Total samples/learn step (B*N): {learn_batch_size * n_envs}")
    print(f"Periodic Evaluation every {evaluate_every} episodes ({num_eval_episodes_periodic} eps each).")
    print(f"Periodic Weight Saving every {save_every_episodes} episodes to '{checkpoint_dir}'.")
    print("Memory Warning: Ensure sufficient CPU RAM and GPU VRAM.")

    env = BatchedSculpt3DEnvTF(grid_size, max_steps, n_envs)
    agent = BatchedDQNAgentTF(grid_shape=env.grid_obs_shape, coord_shape=env.coord_obs_shape,
                              action_dim=6, lr=lr, gamma=gamma, tau=tau)
    agent.buffer.cap = buffer_capacity  # Set buffer capacity

    total_steps_taken = 0
    episode_rewards_history = []
    episode_lengths_history = []
    start_time = time.time()

    # Evaluation results collected for the final trend plot
    eval_episodes_list = []
    eval_avg_rewards_list = []
    eval_avg_removal_perc_list = []

    # The writer is closed in `finally` so it is released even if training
    # is interrupted or raises part-way through.
    try:
        # --- Main Training Loop ---
        for ep in range(1, episodes + 1):
            obs_tuple = env.reset()
            done = tf.zeros([n_envs], tf.bool)
            ep_rewards = tf.zeros([n_envs], tf.float32)
            ep_steps = tf.zeros([n_envs], tf.int32)

            # --- Episode Step Loop (bounded by max_steps) ---
            for _ in range(max_steps):
                if tf.reduce_all(done):
                    break  # Every environment has finished

                # Act (use noise during training)
                A = agent.act_batch(obs_tuple, deterministic=False)

                # Step environment and store the batched transition
                S2_tuple, R, next_done = env.step(A)
                agent.remember_batch(obs_tuple, A, R, S2_tuple, next_done)

                # Only environments that were not already done before this
                # step contribute to the episode statistics.
                active_mask = ~done
                ep_rewards += R * tf.cast(active_mask, tf.float32)
                ep_steps += tf.cast(active_mask, tf.int32)

                obs_tuple = S2_tuple
                done = next_done

                total_steps_taken += n_envs

                # total_steps_taken grows by n_envs per batched step, so this
                # fires once every `learn_freq` batched steps.
                if total_steps_taken % (learn_freq * n_envs) == 0:
                    agent.learn(learn_batch_size)
            # --- End Episode Step Loop ---

            # Per-episode statistics averaged over the batch of environments
            avg_reward_batch = tf.reduce_mean(ep_rewards).numpy()
            avg_steps_batch = tf.reduce_mean(tf.cast(ep_steps, tf.float32)).numpy()
            episode_rewards_history.append(avg_reward_batch)
            episode_lengths_history.append(avg_steps_batch)

            # --- Logging ---
            if log_every > 0 and (ep % log_every == 0 or ep == 1):
                elapsed_time = time.time() - start_time
                # A negative slice longer than the list yields the whole list,
                # so this covers the "fewer than log_every episodes" case too.
                avg_r = np.mean(episode_rewards_history[-log_every:])
                avg_l = np.mean(episode_lengths_history[-log_every:])
                agent_step = agent.train_step_count.numpy()
                print(f"Ep {ep}/{episodes} | Avg R (last {log_every}): {avg_r:.2f} | Avg Len: {avg_l:.1f} | Steps: {total_steps_taken} | TrainSteps: {agent_step} | Time: {elapsed_time:.1f}s")

                # Log scalars to TensorBoard
                with agent.writer.as_default(step=agent_step):
                    tf.summary.scalar("Episode/AvgReward_Roll", avg_r)
                    tf.summary.scalar("Episode/AvgLength_Roll", avg_l)
                    tf.summary.scalar("System/TotalEnvSteps", total_steps_taken)

                # TensorBoard image of one environment (best-effort)
                _log_train_render_to_tensorboard(env, agent, grid_size, ep)
            # --- End Logging ---

            # --- Periodic Model Saving ---
            if save_every_episodes > 0 and ep % save_every_episodes == 0:
                try:
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    save_path = os.path.join(checkpoint_dir, f"model_ep{ep}_g{grid_size}.weights.h5")
                    # Save the weights of the online model
                    agent.model.save_weights(save_path)
                    print(f"\n--- Saved model weights at episode {ep} to {save_path} ---")
                except Exception as e:
                    # A failed checkpoint must not abort a long training run.
                    print(f"\n--- Error saving weights at episode {ep}: {e} ---")
            # --- End Periodic Model Saving ---

            # --- Periodic Evaluation ---
            if evaluate_every > 0 and ep % evaluate_every == 0:
                eval_stats = evaluate_agent_performance(
                    agent=agent,
                    grid_size=grid_size,
                    max_steps=max_steps,
                    num_eval_episodes=num_eval_episodes_periodic,
                    render=render_intermediate_eval,
                    render_env_index=0  # Render env 0 if rendering is enabled
                )
                if eval_stats:
                    # Keep results for the final trend plot
                    eval_episodes_list.append(ep)
                    eval_avg_rewards_list.append(eval_stats["avg_reward"])
                    eval_avg_removal_perc_list.append(eval_stats["avg_removal_percentage"])

                    # Log evaluation stats to TensorBoard
                    agent_step = agent.train_step_count.numpy()
                    with agent.writer.as_default(step=agent_step):
                        tf.summary.scalar("Evaluate/AvgReward", eval_stats["avg_reward"])
                        tf.summary.scalar("Evaluate/AvgRemovalPercentage", eval_stats["avg_removal_percentage"])
                        tf.summary.scalar("Evaluate/AvgIncorrectRemoved", eval_stats["avg_incorrect_removed"])

                print("-" * 60)  # Separator after evaluation output
            # --- End Periodic Evaluation ---
    finally:
        agent.writer.close()  # Close the TensorBoard writer

    total_training_time = time.time() - start_time
    print(f"\nTraining finished. Total steps: {total_steps_taken}, Total Time: {total_training_time:.2f}s")

    # --- Plot Evaluation Trend ---
    if eval_episodes_list:
        _plot_evaluation_trend(eval_episodes_list, eval_avg_rewards_list,
                               eval_avg_removal_perc_list, grid_size, n_envs)

    return agent




# -----------------------------------------------------------------------------
# 7) Main Execution Block (Adjusted Parameters)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Run configuration -- tune these to the available hardware.
    # ------------------------------------------------------------------
    # Environment
    GRID_SIZE_RUN = 8           # Smaller grid for potentially faster runs
    N_ENVS_RUN = 16             # Number of parallel environments
    MAX_STEPS_RUN = 200         # Max steps per episode

    # Learning
    EPISODES_RUN = 5000         # Total training episodes
    BUFFER_CAP_RUN = 50000      # Replay buffer capacity (in batches)
    LEARN_BATCH_RUN = 32        # Number of *batches* to sample for learning
    LEARNING_RATE = 1e-4        # Learning rate for Adam

    # Evaluation and checkpointing cadence
    EVAL_FREQ_RUN = 50          # Evaluate every N training episodes
    NUM_EVAL_EPISODES_RUN = 10  # Number of episodes per evaluation run
    SAVE_FREQ_RUN = 1000        # Save weights every N episodes
    CHECKPOINT_DIR_RUN = "checkpoints_hybrid_noisy"  # Directory for saved weights

    print("Starting run with Hybrid State + Noisy Nets")
    print(f"Params: Grid={GRID_SIZE_RUN}, N_Envs={N_ENVS_RUN}, LearnBatch(B)={LEARN_BATCH_RUN}")
    print(f"Total samples per learn step (B*N): {LEARN_BATCH_RUN * N_ENVS_RUN}")

    # Train the agent; evaluation and checkpointing happen inside the loop.
    trained_agent = train_gpu_batched(
        # Environment
        grid_size=GRID_SIZE_RUN,
        max_steps=MAX_STEPS_RUN,
        n_envs=N_ENVS_RUN,
        # Learning
        episodes=EPISODES_RUN,
        buffer_capacity=BUFFER_CAP_RUN,
        learn_batch_size=LEARN_BATCH_RUN,
        learn_freq=4,    # Learn approx every 4 agent steps
        gamma=0.99,
        lr=LEARNING_RATE,
        tau=0.005,       # Target network update rate
        # Logging / evaluation
        log_every=5,
        evaluate_every=EVAL_FREQ_RUN,
        num_eval_episodes_periodic=NUM_EVAL_EPISODES_RUN,
        render_intermediate_eval=False,  # Intermediate rendering stays off by default
        # Checkpointing
        save_every_episodes=SAVE_FREQ_RUN,
        checkpoint_dir=CHECKPOINT_DIR_RUN,
    )

    # Banner announcing the post-training evaluation.
    print(f"\n{'=' * 70}")
    print("      RUNNING FINAL EVALUATION ON TRAINED AGENT")
    print("=" * 70)

    if trained_agent:
        # Final assessment: more episodes than the periodic runs, and render
        # one environment.
        evaluate_agent_performance(
            agent=trained_agent,
            grid_size=GRID_SIZE_RUN,
            max_steps=MAX_STEPS_RUN,
            num_eval_episodes=50,
            render=True,
            render_env_index=0,
        )

        # Store the final weights alongside the periodic checkpoints.
        final_save_path = os.path.join(
            CHECKPOINT_DIR_RUN,
            f"FINAL_model_ep{EPISODES_RUN}_g{GRID_SIZE_RUN}.weights.h5",
        )
        try:
            os.makedirs(CHECKPOINT_DIR_RUN, exist_ok=True)  # Directory may not exist yet
            trained_agent.model.save_weights(final_save_path)
            print(f"\nFinal model weights saved to {final_save_path}")
        except Exception as e:
            print(f"\nError saving final model weights: {e}")
