# In [None]:
#!/usr/bin/env python3
"""
Fully-GPU Batched Sculpt3DEnv + DQN Agent
• N envs stepped in parallel (no Python loops per step)
• Per-step router paths are NOT tracked (removed for performance); only the
  final stock state of each env can be inspected after an episode
• Batched replay buffer insertion
• Periodic rendering during training is currently a placeholder that reports the
  best env of the batch; evaluate_agent_stats() renders the final carved state in 3D
"""

import os
# Suppress TensorFlow INFO/WARNING messages
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Temporarily comment out for more verbose TF logs if needed

import tensorflow as tf
import numpy as np
import random
# from torch.utils.tensorboard import SummaryWriter # Replaced with tf.summary
import datetime
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import time # For basic timing

# Optional: Explicitly check for GPU and log device placement
# print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
# tf.debugging.set_log_device_placement(True) # Uncomment for verbose device placement logs

# -----------------------------------------------------------------------------
# 1) Batched Sculpt3DEnvTF
# -----------------------------------------------------------------------------
class BatchedSculpt3DEnvTF:
    """N independent voxel-carving environments stored entirely in TF variables.

    Each env is a cubic grid of ``grid_size**3`` voxels addressed by the flat
    index ``x*G*G + y*G + z``. Voxels inside a central sphere form the
    protected shape; all voxels start with stock (material). A router moves one
    voxel per step along one of six axis-aligned directions and removes any
    stock it enters.

    Observations are ``[N, 6]`` float32: the router's (x, y, z) followed by the
    grid center coordinates.
    """

    # Axis-aligned unit moves (dx, dy, dz), indexed by action id 0..5.
    _MOVES = ((1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1))

    # Increased default grid size slightly for potentially more interesting tasks
    def __init__(self, grid_size=32, max_steps=200, n_envs=16):
        G, N = grid_size, n_envs  # G is a Python int here
        self.G, self.N, self.max_steps = G, N, max_steps
        self.flat_dim = G * G * G

        coords_range = tf.range(G, dtype=tf.float32)
        coords = tf.stack(tf.meshgrid(
            coords_range, coords_range, coords_range,
            indexing='ij'
        ), axis=-1)  # [G,G,G,3]

        # Center of voxel indices 0..G-1 is (G-1)/2 = G/2 - 0.5 on every axis.
        center = tf.constant([G / 2 - 0.5, G / 2 - 0.5, G / 2 - 0.5], tf.float32)
        dist2 = tf.reduce_sum(tf.square(coords - center), axis=-1)
        radius_sq = tf.square(tf.cast(G // 2 - 1, tf.float32))
        mask_flat = tf.reshape(dist2 <= radius_sq, [-1])  # [G^3] bool, True = protected

        # Protected-shape mask (identical for every env), flat layout.
        self.shape_mask = tf.Variable(
            tf.tile(mask_flat[None, :], [N, 1]), trainable=False, dtype=tf.bool,
            name="shape_mask"
        )
        # True where material is still present.
        self.stock = tf.Variable(
            tf.ones([N, self.flat_dim], dtype=tf.bool),
            trainable=False, name="stock"
        )

        # Router positions (flat index), steps taken, done flags.
        self.pos = tf.Variable(tf.zeros([N], dtype=tf.int32), trainable=False, name="pos")
        self.steps = tf.Variable(tf.zeros([N], dtype=tf.int32), trainable=False, name="steps")
        self.done = tf.Variable(tf.zeros([N], dtype=tf.bool), trainable=False, name="done")

        # Per-action (dx, dy, dz) deltas used for per-axis bounds checking.
        self.moves = tf.constant(self._MOVES, dtype=tf.int32, name="moves")  # [6, 3]
        # Equivalent flat-index offsets, kept for callers that read self.shifts.
        self.shifts = tf.constant(
            [dx * G * G + dy * G + dz for dx, dy, dz in self._MOVES],
            dtype=tf.int32, name="shifts"
        )  # [6]

    def _unflatten(self, flat):
        """Split flat indices ``x*G*G + y*G + z`` into an ``[N, 3]`` int32 (x, y, z) tensor."""
        G = self.G
        return tf.stack([flat // (G * G), (flat // G) % G, flat % G], axis=1)

    @tf.function  # Compile reset logic
    def reset(self):
        """Refill all stock, clear steps/done, and place each router at a random non-shape voxel.

        Returns:
            Observation tensor ``[N, 6]`` float32.
        """
        self.stock.assign(tf.ones_like(self.stock, dtype=tf.bool))
        self.steps.assign(tf.zeros_like(self.steps, dtype=tf.int32))
        self.done.assign(tf.zeros_like(self.done, dtype=tf.bool))

        # Candidate starts: flat indices outside the protected shape.
        safe_indices = tf.where(~self.shape_mask[0])[:, 0]  # [#safe] int64
        num_safe = tf.shape(safe_indices)[0]

        num_envs_to_start = tf.minimum(self.N, num_safe)
        tf.Assert(num_safe >= self.N, ["Not enough safe starting positions available."])

        # Distinct start positions per env (sampling without replacement).
        shuffled_safe_indices = tf.random.shuffle(safe_indices)[:num_envs_to_start]

        self.pos.assign(tf.cast(shuffled_safe_indices, tf.int32))
        return self._get_obs()

    @tf.function  # Compile step logic
    def step(self, actions):
        """Advance every env by one move.

        Args:
            actions: ``[N]`` int32 action ids in 0..5 (rows of ``self.moves``).

        Returns:
            ``(obs [N, 6] float32, reward [N] float32, done [N] bool)``.

        A move that would leave the grid or enter the protected shape is
        blocked (the router stays put) and scores -5. Entering a voxel that still
        has stock removes it and scores +1. Every step also costs 0.1.
        """
        G = self.G
        # Apply the move per axis so each of x, y, z is bounds-checked on its own.
        # A flat-index range check alone accepts e.g. z = G-1 -> z+1, which would
        # silently wrap to (x, y+1, 0).
        new_xyz = self._unflatten(self.pos) + tf.gather(self.moves, actions)  # [N, 3]
        in_bounds = tf.reduce_all((new_xyz >= 0) & (new_xyz < G), axis=1)   # [N] bool
        new_pos = new_xyz[:, 0] * (G * G) + new_xyz[:, 1] * G + new_xyz[:, 2]  # [N]

        # Clamp so the gathers stay legal; out-of-bounds entries are masked by in_bounds.
        safe_new_pos = tf.clip_by_value(new_pos, 0, self.flat_dim - 1)
        shape_mask_at_new = tf.gather(self.shape_mask, safe_new_pos, axis=1, batch_dims=1)  # [N]
        stock_at_new = tf.gather(self.stock, safe_new_pos, axis=1, batch_dims=1)  # [N]

        blocked = ~in_bounds | shape_mask_at_new       # [N] bool
        can_remove = ~blocked & stock_at_new            # [N] bool

        reward = tf.where(blocked, -5.0, 0.0)
        reward = tf.where(can_remove, reward + 1.0, reward)
        reward = reward - 0.1  # constant step cost

        # 1. Clear stock at the voxels entered this step.
        remove_indices = tf.where(can_remove)  # [num_removals, 1] int64 env indices
        num_removals = tf.shape(remove_indices)[0]

        def perform_update():
            env_idx = tf.cast(remove_indices, tf.int32)                # [num_removals, 1]
            pos_to_remove = tf.gather(new_pos, env_idx[:, 0])          # [num_removals]
            scatter_indices = tf.concat([env_idx, pos_to_remove[:, None]], axis=1)  # [num_removals, 2]
            updates = tf.zeros([num_removals], dtype=tf.bool)
            return tf.tensor_scatter_nd_update(self.stock, scatter_indices, updates)

        self.stock.assign(tf.cond(
            num_removals > 0,
            true_fn=perform_update,
            false_fn=lambda: tf.identity(self.stock),
        ))

        # 2. Move only where the move was not blocked.
        self.pos.assign(tf.where(blocked, self.pos, new_pos))

        # 3. Advance step counters and latch done flags.
        self.steps.assign_add(tf.ones_like(self.steps))
        done_now = self.done | (self.steps >= self.max_steps)
        self.done.assign(done_now)

        return self._get_obs(), tf.cast(reward, tf.float32), done_now

    @tf.function  # Compile observation calculation
    def _get_obs(self):
        """Return ``[N, 6]`` float32: router (x, y, z) then the grid center."""
        xyz = tf.cast(self._unflatten(self.pos), tf.float32)  # [N, 3]
        center = tf.fill([self.N, 3], tf.constant(self.G / 2 - 0.5, tf.float32))  # [N, 3]
        return tf.concat([xyz, center], axis=1)


# -----------------------------------------------------------------------------
# 2) Batched DQN Agent + Replay Buffer
# -----------------------------------------------------------------------------
class BatchedReplayBuffer:
    """Fixed-capacity FIFO store of transition *batches*, one entry per env step.

    Each entry is an ``(S, A, R, S2, D)`` tuple whose elements hold one row per
    parallel env. ``cap`` therefore counts parallel env steps, not individual
    transitions: ``cap * n_envs`` transitions are kept at most.
    ``cap`` may be reassigned after construction.
    """

    # Increased default capacity
    def __init__(self, capacity=100000):
        self.cap = capacity
        self.buf = []    # list of (S, A, R, S2, D) tuples
        self._next = 0   # slot that holds the oldest entry once the buffer is full

    def add_batch(self, S, A, R, S2, D):
        """Store one step-batch, overwriting the oldest entry when at capacity."""
        if self.cap <= 0:
            return  # zero capacity: store nothing
        item = (S, A, R, S2, D)
        if len(self.buf) < self.cap:
            self.buf.append(item)
            return
        # Full: overwrite the oldest entry in place. This is O(1), unlike
        # list.pop(0), which shifts all remaining entries.
        slot = self._next % len(self.buf)
        self.buf[slot] = item
        self._next = (slot + 1) % len(self.buf)

    def sample(self, batch_size=128):
        """Sample ``batch_size`` stored step-batches and concatenate them row-wise.

        While fewer than ``batch_size`` entries are stored, entries are drawn
        with replacement so the output always has ``batch_size * n_envs`` rows.
        A constant shape also avoids retracing a compiled train step.

        Returns:
            ``(S, A, R, S2, D)`` tensors, or ``None`` if the buffer is empty or
            ``batch_size <= 0``.
        """
        n_stored = len(self.buf)
        if n_stored == 0 or batch_size <= 0:
            return None

        if n_stored < batch_size:
            batch = random.choices(self.buf, k=batch_size)
        else:
            batch = random.sample(self.buf, batch_size)

        S_sampled, A_sampled, R_sampled, S2_sampled, D_sampled = (
            tf.concat(parts, axis=0) for parts in zip(*batch)
        )
        return (S_sampled, A_sampled, R_sampled, S2_sampled, D_sampled)

    def __len__(self):
        return len(self.buf)


class BatchedDQNAgentTF:
    """Double-DQN agent with a Polyak-averaged target network, acting on env batches.

    Args:
        state_dim: Observation width. Note that ``act_batch``'s input_signature
            is fixed to 6 features.
        action_dim: Number of discrete actions.
        lr: Adam learning rate.
        gamma: Discount factor.
        tau: Polyak coefficient for the soft target update.
        n_envs: Number of parallel envs (stored for reference only).
        buffer_capacity: Replay-buffer capacity in step-batches.
    """

    def __init__(self, state_dim=6, action_dim=6,
                 lr=1e-4, gamma=0.99, tau=0.005,  # Adjusted defaults
                 n_envs=16, buffer_capacity=100000):
        self.gamma, self.tau = gamma, tau
        self.action_dim = action_dim
        self.n_envs = n_envs  # Store n_envs used in training

        self.model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(state_dim,)),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(action_dim, activation='linear')  # Linear output for Q-values
        ], name="Q_Model")
        self.target = tf.keras.models.clone_model(self.model)
        self.target.set_weights(self.model.get_weights())  # Initial sync (eager, so set_weights is fine)

        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.loss_fn = tf.keras.losses.MeanSquaredError()  # built once, reused by every trace
        self.buffer = BatchedReplayBuffer(buffer_capacity)

        logdir = f"runs/batched_tf_{datetime.datetime.now():%Y%m%d_%H%M%S}"
        self.writer = tf.summary.create_file_writer(logdir)
        self.train_step_count = tf.Variable(0, dtype=tf.int64, trainable=False, name="train_step_count")

    # reduce_retracing: the sampled batch size can vary (B*N rows), so let TF
    # generalize shapes instead of compiling a new graph per size (TF >= 2.9).
    @tf.function(reduce_retracing=True)
    def train_step(self, S, A, R, S2, D):
        """One Double-DQN gradient step followed by a soft target update.

        Shapes: S/S2 ``[B, state]`` float32, A ``[B]`` int32, R ``[B]`` float32,
        D ``[B]`` bool. Returns the scalar MSE loss.
        """
        # Double DQN: the online net picks a' = argmax_a Q(s', a); the target net scores it.
        best_actions_next = tf.argmax(self.model(S2), axis=1, output_type=tf.int32)   # [B]
        q_next = tf.gather(self.target(S2), best_actions_next, axis=1, batch_dims=1)  # [B]

        # TD target y = R + gamma * Q'(s', a') * (1 - D). It is computed outside the
        # tape, so no gradient flows through it.
        target_Q = R + self.gamma * q_next * (1.0 - tf.cast(D, tf.float32))

        with tf.GradientTape() as tape:
            q_taken = tf.gather(self.model(S), A, axis=1, batch_dims=1)  # Q(s, A), [B]
            loss = self.loss_fn(target_Q, q_taken)

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.model.trainable_variables))

        # Polyak-average the target toward the online net. Assign each variable in
        # place: Model.set_weights() converts values to NumPy, which fails on the
        # symbolic tensors inside a tf.function under tf.keras 2.x.
        for w_online, w_target in zip(self.model.weights, self.target.weights):
            w_target.assign(self.tau * w_online + (1.0 - self.tau) * w_target)

        return loss

    # Pass eps as a tf.Tensor to avoid retracing
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 6], dtype=tf.float32),  # S
        tf.TensorSpec(shape=[], dtype=tf.float32)          # eps
    ])
    def act_batch(self, S, eps_tf):
        """Epsilon-greedy actions for a batch of observations ``S [N, 6]``; returns ``[N]`` int32."""
        batch_size = tf.shape(S)[0]
        q_values = self.model(S)  # [N, action_dim]

        random_actions = tf.random.uniform(shape=[batch_size], minval=0, maxval=self.action_dim, dtype=tf.int32)
        greedy_actions = tf.argmax(q_values, axis=1, output_type=tf.int32)
        choose_random = tf.random.uniform(shape=[batch_size], minval=0.0, maxval=1.0) < eps_tf
        return tf.where(choose_random, random_actions, greedy_actions)

    def remember_batch(self, S, A, R, S2, D):
        """Store one env step-batch of transitions in the replay buffer."""
        self.buffer.add_batch(S, A, R, S2, D)

    def learn(self, batch_size=128):
        """Sample ``batch_size`` step-batches, train once, and log the loss.

        Returns:
            The loss as a NumPy scalar, or ``None`` if the buffer had nothing to sample.
        """
        sampled_data = self.buffer.sample(batch_size)
        if sampled_data is None:
            return None  # Not enough data

        S_sampled, A_sampled, R_sampled, S2_sampled, D_sampled = sampled_data
        loss = self.train_step(S_sampled, A_sampled, R_sampled, S2_sampled, D_sampled)

        step = self.train_step_count.numpy()  # Read step value for logging
        with self.writer.as_default(step=step):
            tf.summary.scalar("Train/Loss", loss)
        self.train_step_count.assign_add(1)
        return loss.numpy()

# -----------------------------------------------------------------------------
# 3) Training Loop
# -----------------------------------------------------------------------------
# import time # Ensure time is imported (should be done in main block)

def train_gpu_batched(
        grid_size=32, max_steps=400,       # Env params
        n_envs=64, episodes=1000,          # Batching/Loop params
        buffer_capacity=100000,             # Buffer param (passed to agent/buffer)
        learn_batch_size=64, learn_freq=4, # Learning control (learn every X env steps)
        eps0=1.0, eps_end=0.05, eps_steps=200000, # Epsilon decay over agent steps
        gamma=0.99, lr=1e-4, tau=0.005,     # Agent HPs
        log_every=20, render_every=100      # Logging/Rendering frequency (episodes)
        ):
    """Train a BatchedDQNAgentTF on a BatchedSculpt3DEnvTF and return the agent.

    Each "episode" is one parallel rollout of all ``n_envs`` environments, run
    until every env reports done. Because all envs advance their step counters
    together, they all finish at ``max_steps``.

    Epsilon decays linearly from ``eps0`` to ``eps_end`` over ``eps_steps``
    individual env interactions (``n_envs`` per parallel step). The agent
    learns once every ``learn_freq`` parallel steps, sampling
    ``learn_batch_size`` stored step-batches.

    Returns:
        The trained agent. Its TensorBoard writer is closed before returning,
        including when training is interrupted by an exception.
    """
    print(f"Starting training with N_Envs={n_envs}, Grid={grid_size}, MaxSteps={max_steps}")
    env = BatchedSculpt3DEnvTF(grid_size, max_steps, n_envs)
    print("DEBUG: Environment initialized.")
    agent = BatchedDQNAgentTF(state_dim=6, action_dim=6, lr=lr, gamma=gamma, tau=tau, n_envs=n_envs)
    print("DEBUG: Agent initialized.")
    agent.buffer.cap = buffer_capacity

    total_steps_taken = 0  # individual env interactions (n_envs per parallel step)
    eps = eps0
    episode_rewards_history = []  # mean reward per episode batch
    episode_lengths_history = []  # mean length per episode batch
    start_time = time.time()

    try:
        for ep in range(1, episodes + 1):
            obs = env.reset()                              # [N, 6]
            done = tf.zeros([n_envs], tf.bool)             # [N]
            ep_rewards = tf.zeros([n_envs], tf.float32)    # per-env return this episode
            ep_steps = tf.zeros([n_envs], tf.int32)        # per-env length this episode
            current_ep_step = 0

            while not tf.reduce_all(done):
                # Linear decay; eps_steps <= 0 means "use eps_end from the start".
                decay_frac = total_steps_taken / eps_steps if eps_steps > 0 else 1.0
                eps = max(eps_end, eps0 - (eps0 - eps_end) * decay_frac)

                # eps is passed as a tensor: act_batch has a fixed input_signature,
                # so this never triggers a retrace.
                A = agent.act_batch(obs, tf.constant(eps, dtype=tf.float32))
                S2, R, next_done = env.step(A)  # [N, 6], [N], [N]
                agent.remember_batch(obs, A, R, S2, next_done)

                # Only envs still running before this step accumulate statistics.
                active = ~done
                ep_rewards += R * tf.cast(active, tf.float32)
                ep_steps += tf.cast(active, tf.int32)

                obs, done = S2, next_done
                total_steps_taken += n_envs
                current_ep_step += 1

                if learn_freq > 0 and current_ep_step % learn_freq == 0:
                    agent.learn(learn_batch_size)

            episode_rewards_history.append(tf.reduce_mean(ep_rewards).numpy())
            episode_lengths_history.append(tf.reduce_mean(tf.cast(ep_steps, tf.float32)).numpy())

            # Log on the first episode and then every log_every episodes.
            if ep == 1 or (log_every > 0 and ep % log_every == 0):
                elapsed_time = time.time() - start_time
                recent_r = episode_rewards_history[-log_every:] if log_every > 0 else episode_rewards_history
                recent_l = episode_lengths_history[-log_every:] if log_every > 0 else episode_lengths_history
                avg_r = np.mean(recent_r)
                avg_l = np.mean(recent_l)
                print(f"Ep {ep}/{episodes} | Avg R: {avg_r:.2f} | Avg Len: {avg_l:.1f} | Eps: {eps:.3f} | Steps: {total_steps_taken} | Time: {elapsed_time:.1f}s")

                agent_step = agent.train_step_count.numpy()
                with agent.writer.as_default(step=agent_step):
                    tf.summary.scalar("Episode/AvgReward", avg_r)
                    tf.summary.scalar("Episode/AvgLength", avg_l)
                    tf.summary.scalar("Params/Epsilon", eps)

            # Rendering is a placeholder: paths are not tracked per step.
            if render_every > 0 and ep % render_every == 0:
                print(f"--- Rendering Placeholder at episode {ep} ---")
                best_env_index = int(tf.argmax(ep_rewards).numpy())
                best_reward = ep_rewards[best_env_index].numpy()
                print(f"Rendering requires path reconstruction (currently disabled for performance).")
                print(f"Best env in last batch: Index {best_env_index}, Reward {best_reward:.2f}")
    finally:
        # Flush and close TensorBoard logs even if training is interrupted.
        agent.writer.close()

    print(f"Training finished. Total steps: {total_steps_taken}")
    return agent


if __name__ == "__main__":
    # All libraries are imported once at the top of the file.

    # Set random seeds for reproducibility (optional)
    # seed = 42
    # np.random.seed(seed)
    # random.seed(seed)
    # tf.random.set_seed(seed)

    # --- Parameters for training ---
    # Consider reducing n_envs significantly first if suspecting OOM or slowness
    N_ENVS_TO_RUN = 128    # Original: 128. Try 32 or 16 for debugging.
    GRID_SIZE_TO_RUN = 32  # Original: 32. Try 20 for debugging.
    MAX_STEPS_TRAIN = 500  # Module-level so the evaluation cell can reuse it.

    # Bound as `trained_agent` because the evaluation cell looks up that name,
    # together with GRID_SIZE_TO_RUN and MAX_STEPS_TRAIN.
    trained_agent = train_gpu_batched(
        grid_size=GRID_SIZE_TO_RUN,
        max_steps=MAX_STEPS_TRAIN,
        n_envs=N_ENVS_TO_RUN,
        episodes=1000,
        buffer_capacity=200000,
        learn_batch_size=64,
        learn_freq=4,
        eps_steps=500000,
        eps_end=0.02,
        lr=5e-5,
        tau=0.005,
        log_every=20,
        render_every=500
    )
    agent = trained_agent  # Backward-compatible alias for code that used `agent`.
    # Optional: Save the trained model
    # save_path = "sculpt_dqn_model.keras"
    # agent.model.save(save_path)
    # print(f"Model saved to {save_path}")

# In [None]:
# %% imports (ensure these are imported in your notebook)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import time
import os

# Assume BatchedSculpt3DEnvTF and BatchedDQNAgentTF classes are defined
# Assume trained_agent, GRID_SIZE_TO_RUN, MAX_STEPS_TRAIN are available
# from the previous training cell.

# %% Evaluation Function Definition

def _run_greedy_batch(agent, env, n_envs):
    """Roll out one greedy (eps=0) episode in each env.

    Returns:
        ``(rewards [n_envs] np.float32, lengths [n_envs] np.int32, parallel_steps int)``.
    """
    obs = env.reset()
    done = tf.zeros([n_envs], tf.bool)
    ep_rewards = tf.zeros([n_envs], tf.float32)
    ep_steps = tf.zeros([n_envs], tf.int32)
    eps_zero = tf.constant(0.0, dtype=tf.float32)
    parallel_steps = 0

    while not tf.reduce_all(done):
        actions = agent.act_batch(obs, eps_zero)
        obs, rewards, next_done = env.step(actions)
        # Only envs still running before this step accumulate statistics.
        active = ~done
        ep_rewards += rewards * tf.cast(active, tf.float32)
        ep_steps += tf.cast(active, tf.int32)
        done = next_done
        parallel_steps += 1

    return ep_rewards.numpy(), ep_steps.numpy(), parallel_steps


def _render_eval_state(final_stock_flat, shape_mask_flat, grid_size, env_index, reward, removal_pct):
    """Save a 3D voxel plot of the protected shape (blue) and removed material (red)."""
    G = grid_size
    final_stock_3d = final_stock_flat.reshape((G, G, G))
    shape_3d = shape_mask_flat.reshape((G, G, G))
    removed_3d = ~final_stock_3d & ~shape_3d

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')
        x, y, z = np.indices(np.array(shape_3d.shape) + 1)  # voxel edges

        ax.voxels(x, y, z, shape_3d, facecolors='blue', alpha=0.15, edgecolor=None)
        ax.voxels(x, y, z, removed_3d, facecolors='red', alpha=0.5, edgecolor=None)

        title_reward = f"{reward:.2f}"
        ax.set_title(f"Eval Render Env #{env_index} (R={title_reward}, Removed={removal_pct:.1f}%)")
        ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
        ax.set_xlim(0, G); ax.set_ylim(0, G); ax.set_zlim(0, G)
        ax.set_aspect('auto')

        plt.tight_layout()
        render_dir = "renders_eval"
        os.makedirs(render_dir, exist_ok=True)
        save_path = os.path.join(render_dir, f"eval_render_env_{env_index}.png")
        plt.savefig(save_path)
        print(f"Saved evaluation render to {save_path}")
    finally:
        plt.close(fig)  # Release the figure even if plotting fails.


def evaluate_agent_stats(agent, grid_size, max_steps, num_test_episodes=10, render_env_index=0):
    """
    Evaluate a trained agent greedily, print performance statistics, and render
    the final state of one environment.

    All ``num_test_episodes`` episodes run in parallel as ONE batch of
    environments (one episode per env). The rendered env's stock, reward, and
    removal percentage therefore all come from the same episode.

    Args:
        agent: The trained BatchedDQNAgentTF object.
        grid_size: The grid size used during training.
        max_steps: The max steps per episode used during training.
        num_test_episodes: How many episodes (parallel envs) to evaluate.
        render_env_index: The index of the environment (0 to num_test_episodes-1)
                           whose final state will be rendered.

    Returns:
        None. Results are printed and the render is saved under ``renders_eval/``.
    """
    print(f"\n--- Starting Agent Evaluation ---")
    print(f"Running {num_test_episodes} evaluation episodes...")

    if num_test_episodes <= 0:
        print("\nNo evaluation episodes were run.")
        return

    test_env = BatchedSculpt3DEnvTF(grid_size=grid_size, max_steps=max_steps, n_envs=num_test_episodes)

    # reset() fills stock everywhere, so the carvable material is exactly the
    # complement of the (per-env identical) protected-shape mask.
    shape_mask_flat_np = test_env.shape_mask[0].numpy()
    carvable_mask_flat = ~shape_mask_flat_np
    initial_carvable_count = int(np.sum(carvable_mask_flat))
    print(f"Initial number of carvable voxels: {initial_carvable_count}")
    if initial_carvable_count == 0:
        print("Warning: No carvable material initially defined by the shape mask.")

    ep_rewards, ep_lengths, parallel_steps = _run_greedy_batch(agent, test_env, num_test_episodes)
    print(f"Evaluation batch of {num_test_episodes} episodes finished. Steps: {parallel_steps}")

    # Carvable voxels whose stock is gone, per env (vectorized over the batch).
    final_stock_np = test_env.stock.numpy()  # [num_test_episodes, G^3] bool
    removed_counts = np.sum(~final_stock_np & carvable_mask_flat[None, :], axis=1)
    if initial_carvable_count > 0:
        removal_percentages = removed_counts / initial_carvable_count * 100.0
    else:
        removal_percentages = np.zeros(num_test_episodes)  # avoid division by zero

    print(f"\n--- Evaluation Results ({num_test_episodes} episodes) ---")
    print(f"Average Reward: {np.mean(ep_rewards):.2f} (+/- {np.std(ep_rewards):.2f})")
    print(f"Average Length: {np.mean(ep_lengths):.1f}")
    print(f"Average Carvable Voxels Removed: {np.mean(removed_counts):.1f} / {initial_carvable_count}")
    print(f"Average Removal Percentage: {np.mean(removal_percentages):.2f}% (+/- {np.std(removal_percentages):.2f}%)")

    print(f"\n--- Rendering final state for Eval Env Index: {render_env_index} ---")
    if render_env_index < 0 or render_env_index >= num_test_episodes:
        print(f"Error: render_env_index ({render_env_index}) out of bounds for {num_test_episodes} test environments.")
        return

    try:
        _render_eval_state(
            final_stock_np[render_env_index],
            shape_mask_flat_np,
            grid_size,
            render_env_index,
            float(ep_rewards[render_env_index]),
            float(removal_percentages[render_env_index]),
        )
    except Exception as e:  # Rendering is best-effort: report the error but keep the stats.
        print(f"Error during rendering: {e}")
        import traceback
        traceback.print_exc()


# %% Run Evaluation (Example Call)
# Make sure trained_agent, GRID_SIZE_TO_RUN, and MAX_STEPS_TRAIN are defined
# from your training process before running this cell.

NUM_EVAL_EPISODES = 10  # Number of episodes to average over for evaluation
RENDER_INDEX = 0        # Which episode's final state to render (0 to NUM_EVAL_EPISODES-1)

# At module level locals() is globals(), so one lookup covers both. Every
# required name is checked, not only the agent, so that a missing config
# variable is reported instead of raising NameError.
_REQUIRED_EVAL_NAMES = ("trained_agent", "GRID_SIZE_TO_RUN", "MAX_STEPS_TRAIN")
_missing_eval_names = [name for name in _REQUIRED_EVAL_NAMES if name not in globals()]

if not _missing_eval_names:
    evaluate_agent_stats(
        agent=trained_agent,
        grid_size=GRID_SIZE_TO_RUN,
        max_steps=MAX_STEPS_TRAIN,
        num_test_episodes=NUM_EVAL_EPISODES,
        render_env_index=RENDER_INDEX
    )
elif "trained_agent" in _missing_eval_names:
    print("Variable 'trained_agent' not found. Please train the agent first.")
else:
    print(f"Missing variable(s) {', '.join(_missing_eval_names)} required for evaluation. "
          "Please define them (e.g. by running the training cell) first.")

