In [None]:
#!/usr/bin/env python3 # CURRENT
"""
Fully‑GPU Batched Sculpt3DEnv + DQN Agent (Hybrid State & Noisy Nets)

• State Representation: Hybrid - CNN processes Stock/Mask grids, concatenates
                       with normalized XYZ coordinates.
• Agent Network: Uses a 3D CNN feature extractor + Dense head with Noisy Layers.
• Exploration: Uses Noisy Networks instead of epsilon-greedy.
• N envs stepped in parallel during training.
• Batched replay buffer insertion.
• Includes periodic evaluation during training and final evaluation/rendering.
• Plots carving performance trend at the end of training.

NOTE: Monitor memory usage. Tune hyperparameters (LR, network, etc.).
      Periodic evaluation adds runtime overhead.
"""

import os
# Suppress TensorFlow INFO/WARNING messages
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import numpy as np
import random
import datetime
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import time
import math # For noisy layer initialization

# --- Configure GPU Memory Growth ---
# Enable on-demand GPU memory allocation instead of reserving all VRAM up front.
# TensorFlow only allows this before the GPUs have been initialized; otherwise
# set_memory_growth raises RuntimeError (e.g. when this cell is re-run).
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU detected by TensorFlow.")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs configured with Memory Growth.")
    except RuntimeError as e:
        print(f"Memory growth error: {e}") # Might happen if GPU already initialized


# -----------------------------------------------------------------------------
# 1) Batched Sculpt3DEnvTF (Hybrid Observation)
# -----------------------------------------------------------------------------
class BatchedSculpt3DEnvTF:
    """Batch of N identical voxel-carving environments whose state lives in TF variables.

    Each env is a G x G x G grid, stored flattened in x-major order
    (flat = x*G*G + y*G + z). A fixed spherical shape mask marks voxels that
    must not be carved; ``stock`` marks voxels that still contain material.
    An agent at ``pos`` moves one voxel per step along +/-x, +/-y or +/-z.
    Entering an in-bounds, non-shape voxel that still has stock removes it.

    Observations are ``(grid_obs, coord_obs)``:
      * grid_obs:  float32 [N, G, G, G, 2] (channel 0 = stock, 1 = shape mask)
      * coord_obs: float32 [N, 3] (x, y, z of ``pos`` normalized by G - 1)
    """
    def __init__(self, grid_size=16, max_steps=200, n_envs=16):
        G, N = grid_size, n_envs
        if N <= 0: raise ValueError("n_envs must be positive.")
        self.G, self.N, self.max_steps = G, N, max_steps
        self.flat_dim = G*G*G
        self.grid_obs_shape = (G, G, G, 2) # Channels: Stock, ShapeMask
        self.coord_obs_shape = (3,)        # Channels: X, Y, Z (normalized)

        # Shape mask: voxels within radius (G//2 - 1) of the grid center.
        coords_range = tf.range(G, dtype=tf.float32)
        coords = tf.stack(tf.meshgrid(coords_range, coords_range, coords_range, indexing='ij'), axis=-1)
        center = tf.constant([G/2 - 0.5, G/2 - 0.5, G/2 - 0.5], tf.float32)
        dist2 = tf.reduce_sum(tf.square(coords - center), axis=-1)
        radius_sq = tf.square(tf.cast(G // 2 - 1, tf.float32))
        mask3d = dist2 <= radius_sq
        mask_flat = tf.reshape(mask3d, [-1])

        self.shape_mask = tf.Variable(tf.tile(mask_flat[None, :], [N, 1]), trainable=False, dtype=tf.bool, name="shape_mask")
        self.stock = tf.Variable(tf.ones([N, self.flat_dim], dtype=tf.bool), trainable=False, name="stock")
        self.pos = tf.Variable(tf.zeros([N], dtype=tf.int32), trainable=False, name="pos")
        self.steps = tf.Variable(tf.zeros([N], dtype=tf.int32), trainable=False, name="steps")
        self.done = tf.Variable(tf.zeros([N], dtype=tf.bool), trainable=False, name="done")

        # Action i moves by moves[i] = (dx, dy, dz).
        moves = [(1,0,0),(-1,0,0),(0,1,0),(0,-1,0),(0,0,1),(0,0,-1)]
        self.moves = tf.constant(moves, dtype=tf.int32, name="moves")
        # Flat-index equivalent of each move (kept for backward compatibility;
        # step() uses per-axis moves so that edges are not wrapped across).
        G_py = grid_size
        def to_flat_py(dx, dy, dz): return dx*G_py*G_py + dy*G_py + dz
        shifts_py = [to_flat_py(*m) for m in moves]
        self.shifts = tf.constant(shifts_py, dtype=tf.int32, name="shifts")

    @tf.function
    def reset(self):
        """Refill all stock, clear step/done counters and return the first observation.

        Each env starts at a distinct, randomly chosen voxel outside the shape
        mask. Asserts that there are at least N such voxels.
        """
        self.stock.assign(tf.ones_like(self.stock))
        self.steps.assign(tf.zeros_like(self.steps))
        self.done.assign(tf.zeros_like(self.done))
        safe_indices = tf.where(~self.shape_mask[0])[:, 0]
        num_safe = tf.shape(safe_indices)[0]
        tf.Assert(num_safe >= self.N, ["Not enough safe starting positions available."])
        shuffled_safe_indices = tf.random.shuffle(safe_indices)[:self.N]
        self.pos.assign(tf.cast(shuffled_safe_indices, tf.int32))
        return self._get_obs()

    @tf.function
    def step(self, actions):
        """Apply one move per env.

        Args:
            actions: int32 tensor [N] of action ids in [0, 6), indexing ``moves``.

        Returns:
            (next_obs, reward, done): observation tuple, float32 [N] rewards and
            a bool [N] snapshot of the done flags.

        Rewards: -5 for a move into the shape or off the grid (the agent stays
        in place), +1 for removing a stock voxel, and -0.1 on every step.
        """
        G = self.G
        # Decompose into per-axis coordinates (same layout as _get_obs) and check
        # bounds per axis. Checking only the flat index would let +/-y and +/-z
        # moves wrap silently onto the neighbouring row/plane at grid edges.
        x = self.pos // (G * G)
        y = (self.pos // G) % G
        z = self.pos % G
        delta = tf.gather(self.moves, actions)  # [N, 3]
        nx = x + delta[:, 0]
        ny = y + delta[:, 1]
        nz = z + delta[:, 2]
        in_bounds = ((nx >= 0) & (nx < G) & (ny >= 0) & (ny < G)
                     & (nz >= 0) & (nz < G))
        # Out-of-bounds targets are replaced by the current position so every
        # gather below uses a valid index; those envs are excluded via in_bounds.
        new_pos = tf.where(in_bounds, nx * G * G + ny * G + nz, self.pos)
        shape_mask_at_new = tf.gather(self.shape_mask, new_pos, axis=1, batch_dims=1)
        stock_at_new = tf.gather(self.stock, new_pos, axis=1, batch_dims=1)
        hit_shape_or_oob = ~in_bounds | shape_mask_at_new
        can_remove = ~hit_shape_or_oob & stock_at_new

        reward = tf.where(hit_shape_or_oob, -5.0, 0.0)
        reward = tf.where(can_remove, reward + 1.0, reward)
        reward = reward - 0.1

        remove_indices = tf.where(can_remove)  # [k, 1] int64 env ids
        num_removals = tf.shape(remove_indices)[0]

        def perform_update():
            env_ids = tf.cast(remove_indices, tf.int32)
            pos_to_remove = tf.gather(new_pos, env_ids[:, 0])
            scatter_indices = tf.concat([env_ids, tf.expand_dims(pos_to_remove, axis=1)], axis=1)
            updates = tf.zeros([num_removals], dtype=tf.bool)
            return tf.tensor_scatter_nd_update(self.stock, scatter_indices, updates)

        maybe_updated_stock = tf.cond(num_removals > 0, true_fn=perform_update,
                                      false_fn=lambda: tf.identity(self.stock))
        self.stock.assign(maybe_updated_stock)

        # Blocked moves leave the agent where it was.
        self.pos.assign(tf.where(hit_shape_or_oob, self.pos, new_pos))
        self.steps.assign_add(tf.ones_like(self.steps))
        self.done.assign(self.done | (self.steps >= self.max_steps))
        next_obs = self._get_obs()
        # Return an explicit snapshot rather than the variable object itself.
        return next_obs, tf.cast(reward, tf.float32), tf.identity(self.done)

    @tf.function
    def _get_obs(self):
        """Build the (grid_obs, coord_obs) observation tuple from the current state."""
        G = self.G; N = self.N
        stock_grid = tf.reshape(self.stock, [N, G, G, G])
        shape_mask_grid = tf.reshape(self.shape_mask, [N, G, G, G])
        stock_float = tf.cast(stock_grid, tf.float32)
        shape_mask_float = tf.cast(shape_mask_grid, tf.float32)
        grid_obs = tf.stack([stock_float, shape_mask_float], axis=-1)
        g_tf = tf.constant(G, dtype=tf.int32)
        z = self.pos % g_tf
        y = (self.pos // g_tf) % g_tf
        x = self.pos // (g_tf * g_tf)
        # max(1, .) avoids division by zero for a degenerate 1x1x1 grid.
        g_minus_1_float = tf.cast(tf.maximum(1, G - 1), tf.float32)
        x_norm = tf.cast(x, tf.float32) / g_minus_1_float
        y_norm = tf.cast(y, tf.float32) / g_minus_1_float
        z_norm = tf.cast(z, tf.float32) / g_minus_1_float
        coord_obs = tf.stack([x_norm, y_norm, z_norm], axis=-1)
        return (grid_obs, coord_obs)

# -----------------------------------------------------------------------------
# 2) Replay Buffer (Handling Tuple State)
# -----------------------------------------------------------------------------
class BatchedReplayBuffer:
    """FIFO replay buffer whose entries are whole env-batches of transitions.

    Each ``add_batch`` call stores ONE entry holding the tuple
    ``(S_tuple, A, R, S2_tuple, D)`` for all parallel envs of one step. So
    ``capacity`` counts env-steps, not individual transitions.
    ``sample(batch_size)`` draws ``batch_size`` distinct entries uniformly and
    concatenates each field along axis 0.

    Storage is a ring buffer: once full, the oldest entry is overwritten in
    O(1), where ``list.pop(0)`` would cost O(capacity) per insert. ``cap`` may
    be reassigned after construction and FIFO order is preserved.
    """
    def __init__(self, capacity=50000):
        if capacity <= 0:
            raise ValueError("capacity must be positive.")
        self.cap = capacity
        self.buf = []
        # Index of the oldest entry once the ring has wrapped (0 until then).
        self._next = 0

    def add_batch(self, S_tuple, A, R, S2_tuple, D):
        """Store one env-batch entry, evicting the oldest entry when full."""
        entry = (S_tuple, A, R, S2_tuple, D)
        if len(self.buf) < self.cap:
            if self._next == 0:
                self.buf.append(entry)
            else:
                # Only reachable if `cap` was raised after the ring wrapped:
                # insert just before the oldest entry to keep FIFO order.
                self.buf.insert(self._next, entry)
                self._next += 1
        else:
            # Full: overwrite the oldest entry in place.
            self.buf[self._next] = entry
            self._next = (self._next + 1) % len(self.buf)

    def sample(self, batch_size=32):
        """Return 7 concatenated tensors (S_grid, S_coord, A, R, S2_grid, S2_coord, D).

        Returns None if fewer than ``batch_size`` entries are stored.
        """
        num = len(self.buf)
        if num < batch_size: return None
        indices = random.sample(range(num), batch_size)
        batch = [self.buf[i] for i in indices]
        S_tuple_list, A_list, R_list, S2_tuple_list, D_list = zip(*batch)
        S_grid_list, S_coord_list = zip(*S_tuple_list)
        S2_grid_list, S2_coord_list = zip(*S2_tuple_list)
        return (tf.concat(S_grid_list, axis=0), tf.concat(S_coord_list, axis=0),
                tf.concat(A_list, axis=0), tf.concat(R_list, axis=0),
                tf.concat(S2_grid_list, axis=0), tf.concat(S2_coord_list, axis=0),
                tf.concat(D_list, axis=0))

    def __len__(self): return len(self.buf)

# -----------------------------------------------------------------------------
# 3) Noisy Dense Layer (Factorized Gaussian Noise)
# -----------------------------------------------------------------------------
class NoisyDense(tf.keras.layers.Layer):
    """Dense layer with factorized Gaussian parameter noise (NoisyNet).

    With ``training`` truthy, every call draws fresh e_in ~ N(0, 1)^in and
    e_out ~ N(0, 1)^out and computes
        kernel = W_mu + W_sigma * outer(f(e_in), f(e_out))
        bias   = b_mu + b_sigma * f(e_out)
    with f(e) = sign(e) * sqrt(|e|). One noise sample is shared by the whole
    batch. Otherwise only the mean parameters are used.

    Args:
        units: number of output features.
        activation: anything accepted by ``tf.keras.activations.get``.
        sigma0: initial noise scale; both sigmas start at sigma0 / sqrt(in_features).
    """
    def __init__(self, units, activation=None, sigma0=0.5, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.sigma0 = sigma0

    def build(self, input_shape):
        in_features = input_shape[-1]; out_features = self.units; dtype = tf.float32
        if in_features is None:
            raise ValueError("NoisyDense requires a known last input dimension.")
        sigma_init_val = self.sigma0 / math.sqrt(float(in_features))
        sigma_init = tf.constant_initializer(sigma_init_val)
        self.kernel_mean = self.add_weight(name="kernel_mean", shape=(in_features, out_features), initializer="he_uniform", trainable=True, dtype=dtype)
        self.bias_mean = self.add_weight(name="bias_mean", shape=(out_features,), initializer="zeros", trainable=True, dtype=dtype)
        self.kernel_sigma = self.add_weight(name="kernel_sigma", shape=(in_features, out_features), initializer=sigma_init, trainable=True, dtype=dtype)
        self.bias_sigma = self.add_weight(name="bias_sigma", shape=(out_features,), initializer=sigma_init, trainable=True, dtype=dtype)
        super().build(input_shape)

    def call(self, inputs, training=None):
        # Keras propagates `training` when the model is called with it (as
        # train_step / act_batch do). An unspecified value means inference; the
        # global tf.keras.backend.learning_phase() is deprecated and absent in Keras 3.
        if training:
            noise_in = self._get_noise(self.kernel_mean.shape[0])
            noise_out = self._get_noise(self.units)
            # Factorized noise: outer product of the input and output noise vectors.
            kernel_noise = tf.expand_dims(noise_in, -1) * tf.expand_dims(noise_out, 0)
            kernel = self.kernel_mean + self.kernel_sigma * kernel_noise
            bias = self.bias_mean + self.bias_sigma * noise_out
        else:
            kernel = self.kernel_mean; bias = self.bias_mean
        output = tf.matmul(inputs, kernel) + bias
        if self.activation is not None: output = self.activation(output)
        return output

    def _get_noise(self, num_elements):
        """Return f(e) = sign(e) * sqrt(|e|) for a rank-1 standard-normal sample e."""
        noise = tf.random.normal(shape=[num_elements])
        return tf.sign(noise) * tf.sqrt(tf.abs(noise))

    def compute_output_shape(self, input_shape): return tuple(input_shape[:-1]) + (self.units,)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
            "sigma0": self.sigma0,
        })
        return config

# -----------------------------------------------------------------------------
# 4) Batched DQN Agent (Hybrid State CNN + Noisy Nets)
# -----------------------------------------------------------------------------
class BatchedDQNAgentTF:
    """Double-DQN agent with a hybrid 3D-CNN + coordinate network and NoisyDense head.

    Exploration comes from NoisyDense parameter noise, which is active whenever
    a network is called with ``training=True``; there is no epsilon-greedy.
    After every gradient step the target network is Polyak-averaged toward the
    online network with rate ``tau``.
    """
    def __init__(self, grid_shape, coord_shape, action_dim=6,
                 lr=1e-4, gamma=0.99, tau=0.005):
        self.grid_shape = grid_shape; self.coord_shape = coord_shape
        self.action_dim = action_dim; self.gamma = gamma; self.tau = tau
        def build_hybrid_noisy_model():
            # Conv3D feature extractor on the voxel grid, concatenated with the
            # normalized coordinates, then two NoisyDense layers -> Q-values.
            grid_input = tf.keras.layers.Input(shape=self.grid_shape, name="grid_input")
            coord_input = tf.keras.layers.Input(shape=self.coord_shape, name="coord_input")
            x_cnn = tf.keras.layers.Conv3D(filters=32, kernel_size=5, strides=2, activation='relu', padding='same')(grid_input)
            x_cnn = tf.keras.layers.Conv3D(filters=64, kernel_size=3, strides=2, activation='relu', padding='same')(x_cnn)
            x_cnn = tf.keras.layers.Conv3D(filters=64, kernel_size=3, strides=1, activation='relu', padding='same')(x_cnn)
            cnn_features = tf.keras.layers.Flatten()(x_cnn)
            concat_features = tf.keras.layers.Concatenate()([cnn_features, coord_input])
            x = NoisyDense(512, activation='relu')(concat_features)
            outputs = NoisyDense(action_dim, activation='linear')(x)
            return tf.keras.Model(inputs=[grid_input, coord_input], outputs=outputs)
        self.model = build_hybrid_noisy_model()
        self.target = build_hybrid_noisy_model()
        self.target.set_weights(self.model.get_weights())
        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.buffer = BatchedReplayBuffer()
        logdir = f"runs/hybrid_noisy_dqn_{datetime.datetime.now():%Y%m%d_%H%M%S}"
        self.writer = tf.summary.create_file_writer(logdir)
        self.train_step_count = tf.Variable(0, dtype=tf.int64, trainable=False)

    @tf.function
    def train_step(self, S_grid, S_coord, A, R, S2_grid, S2_coord, D):
        """Perform one Double-DQN gradient step on a minibatch, then soft-update the target.

        Returns the scalar MSE loss between Q(s, a) and the bootstrapped target.
        """
        # Double DQN: the online net selects a* = argmax Q_online(s', .), and the
        # target net evaluates it. Both use fresh noise (training=True).
        Q2_target = self.target([S2_grid, S2_coord], training=True)
        Q2_online = self.model([S2_grid, S2_coord], training=True)
        best_actions_next = tf.argmax(Q2_online, axis=1, output_type=tf.int32)
        action_indices = tf.stack([tf.range(tf.shape(best_actions_next)[0], dtype=tf.int32), best_actions_next], axis=1)
        Q2_best_target = tf.gather_nd(Q2_target, action_indices)
        # NOTE(review): in BatchedSculpt3DEnvTF, `done` only comes from the
        # max_steps time limit and the observation carries no step count, so
        # cutting the bootstrap here treats truncation as a true terminal. Confirm intent.
        target_Q = R + self.gamma * Q2_best_target * (1.0 - tf.cast(D, tf.float32))
        target_Q = tf.stop_gradient(target_Q)
        with tf.GradientTape() as tape:
            Q_online = self.model([S_grid, S_coord], training=True)
            action_indices_taken = tf.stack([tf.range(tf.shape(A)[0], dtype=tf.int32), A], axis=1)
            Q_online_taken = tf.gather_nd(Q_online, action_indices_taken)
            loss = tf.reduce_mean(tf.square(target_Q - Q_online_taken))
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.model.trainable_variables))
        self._soft_update_target()
        return loss

    def _soft_update_target(self):
        """Polyak update target <- tau * online + (1 - tau) * target, variable by variable.

        Uses Variable.assign so it works inside tf.function. Model.set_weights
        converts values through NumPy and fails on symbolic tensors.
        """
        for w_online, w_target in zip(self.model.weights, self.target.weights):
            w_target.assign(self.tau * w_online + (1.0 - self.tau) * w_target)

    @tf.function
    def act_batch(self, S_tuple, deterministic=False):
        """Greedy actions w.r.t. Q; noise is applied unless ``deterministic``."""
        q_values = self.model(S_tuple, training=not deterministic)
        actions = tf.argmax(q_values, axis=1, output_type=tf.int32)
        return actions

    def remember_batch(self, S_tuple, A, R, S2_tuple, D):
        """Store one env-batch of transitions in the replay buffer."""
        self.buffer.add_batch(S_tuple, A, R, S2_tuple, D)

    def learn(self, batch_size=32):
        """Sample ``batch_size`` buffer entries, train once and log the loss.

        Returns the loss as a NumPy scalar, or None if the buffer holds too few entries.
        """
        if len(self.buffer) < batch_size: return None
        sampled_data = self.buffer.sample(batch_size)
        if sampled_data is None: return None
        S_grid_s, S_coord_s, A_s, R_s, S2_grid_s, S2_coord_s, D_s = sampled_data
        loss = self.train_step(S_grid_s, S_coord_s, A_s, R_s, S2_grid_s, S2_coord_s, D_s)
        step = self.train_step_count.numpy()
        with self.writer.as_default(step=step): tf.summary.scalar("Train/Loss", loss)
        self.train_step_count.assign_add(1)
        return loss.numpy()

# -----------------------------------------------------------------------------
# 5) Evaluation Function (Moved Before Training Loop & Returns Stats)
# -----------------------------------------------------------------------------
def evaluate_agent_performance(agent, grid_size, max_steps, num_eval_episodes=10, render=True, render_env_index=0):
    """Evaluate ``agent`` with noise disabled and return summary statistics.

    Runs ``num_eval_episodes`` parallel environments for one episode each in a
    fresh BatchedSculpt3DEnvTF and prints aggregate results. If ``render`` is
    set, it also saves a 3D voxel plot of env ``render_env_index``'s final state
    under ``renders_eval/``.

    Args:
        agent: object with ``act_batch(obs_tuple, deterministic=True)``.
        grid_size: grid edge length used to build the evaluation env.
        max_steps: per-episode step limit.
        num_eval_episodes: number of parallel evaluation envs (one episode each).
        render: whether to save a render of one env.
        render_env_index: index of the env to render.

    Returns:
        dict with avg_reward, std_reward, avg_length, avg_removed_count,
        avg_incorrect_removed, avg_removal_percentage, std_removal_percentage
        and initial_carvable_count.
    """
    print(f"\n--- Running Evaluation ({num_eval_episodes} episodes) ---")
    eval_start_time = time.time()

    eval_env = BatchedSculpt3DEnvTF(grid_size=grid_size, max_steps=max_steps, n_envs=num_eval_episodes)

    # The shape mask is identical for every env. reset() refills the stock
    # everywhere, so the carvable material is exactly the mask's complement.
    initial_shape_mask_flat_np = eval_env.shape_mask[0].numpy()
    initial_carvable_mask_flat = ~initial_shape_mask_flat_np
    initial_carvable_count = np.sum(initial_carvable_mask_flat)
    print(f"  Initial number of carvable voxels: {initial_carvable_count}")
    if initial_carvable_count == 0: print("  Warning: No carvable material defined.")

    # Roll out at most max_steps steps. Rewards and lengths accumulate only for
    # envs that were still active before the step.
    obs_tuple = eval_env.reset()
    done = tf.zeros([num_eval_episodes], tf.bool)
    ep_rewards = tf.zeros([num_eval_episodes], tf.float32)
    ep_steps = tf.zeros([num_eval_episodes], tf.int32)
    for _ in range(max_steps):
        if tf.reduce_all(done): break
        A = agent.act_batch(obs_tuple, deterministic=True) # Use deterministic actions
        obs_tuple, R, next_done = eval_env.step(A)
        active_mask = ~done
        ep_rewards += R * tf.cast(active_mask, tf.float32)
        ep_steps += tf.cast(active_mask, tf.int32)
        done = next_done

    final_stock_batch_np = eval_env.stock.numpy()  # [num_eval_episodes, G**3] bool
    all_ep_rewards = ep_rewards.numpy().tolist()
    all_ep_lengths = ep_steps.numpy().tolist()
    # Removed = carvable voxels whose stock is gone. Incorrect = shape voxels
    # whose stock is gone; step() refuses moves into the shape, so this is a sanity check.
    all_ep_removed_counts = np.sum(initial_carvable_mask_flat & ~final_stock_batch_np, axis=1)
    all_ep_incorrect_removed_counts = np.sum(initial_shape_mask_flat_np & ~final_stock_batch_np, axis=1)

    # Aggregate and Print Statistics
    avg_reward = np.mean(all_ep_rewards); std_reward = np.std(all_ep_rewards)
    avg_length = np.mean(all_ep_lengths)
    avg_removed_count = np.mean(all_ep_removed_counts)
    avg_incorrect_removed = np.mean(all_ep_incorrect_removed_counts)
    if initial_carvable_count > 0:
        removal_percentages = all_ep_removed_counts / initial_carvable_count * 100.0
        avg_removal_percentage = (avg_removed_count / initial_carvable_count) * 100.0
    else:
        removal_percentages = np.zeros(len(all_ep_removed_counts))
        avg_removal_percentage = 0.0
    std_removal_percentage = np.std(removal_percentages)

    print("\n--- Evaluation Results ---")
    print(f"  Avg Reward : {avg_reward:.2f} (+/- {std_reward:.2f})")
    print(f"  Avg Length : {avg_length:.1f}")
    print(f"  Avg Removed: {avg_removed_count:.1f} / {initial_carvable_count} ({avg_removal_percentage:.2f}% +/- {std_removal_percentage:.2f}%)")
    print(f"  Avg Incorrect: {avg_incorrect_removed:.1f}")
    eval_duration = time.time() - eval_start_time
    print(f"  Evaluation Duration: {eval_duration:.2f}s")

    if render:
        print(f"\n--- Rendering final state for Eval Env Index: {render_env_index} ---")
        if render_env_index < 0 or render_env_index >= num_eval_episodes:
            print(f"Error: render_env_index ({render_env_index}) out of bounds.")
        else:
            _render_eval_final_state(
                final_stock_batch_np[render_env_index], initial_shape_mask_flat_np, grid_size,
                render_env_index, all_ep_rewards[render_env_index],
                removal_percentages[render_env_index])

    return {
        "avg_reward": avg_reward,
        "std_reward": std_reward,
        "avg_length": avg_length,
        "avg_removed_count": avg_removed_count,
        "avg_incorrect_removed": avg_incorrect_removed,
        "avg_removal_percentage": avg_removal_percentage,
        "std_removal_percentage": std_removal_percentage,
        "initial_carvable_count": initial_carvable_count
    }


def _render_eval_final_state(final_stock_flat_np, shape_mask_flat_np, grid_size,
                             env_index, env_reward, env_removed_perc):
    """Save a voxel plot (shape, removed, incorrectly removed) of one env; errors are printed, not raised."""
    import traceback
    from matplotlib.patches import Patch

    fig = None
    try:
        G = grid_size
        final_stock_3d = final_stock_flat_np.reshape((G, G, G))
        shape_mask_3d = shape_mask_flat_np.reshape((G, G, G))
        removed_mask_render = ~shape_mask_3d & ~final_stock_3d
        incorrectly_removed_mask_render = shape_mask_3d & ~final_stock_3d

        fig = plt.figure(figsize=(9, 7)); ax = fig.add_subplot(111, projection='3d')
        ax.set_facecolor('whitesmoke')
        x_vox, y_vox, z_vox = np.indices(np.array(shape_mask_3d.shape) + 1)
        ax.voxels(x_vox, y_vox, z_vox, shape_mask_3d, facecolors='blue', alpha=0.1, edgecolor=None)
        ax.voxels(x_vox, y_vox, z_vox, removed_mask_render, facecolors='red', alpha=0.6, edgecolor=None)
        if np.any(incorrectly_removed_mask_render):
            ax.voxels(x_vox, y_vox, z_vox, incorrectly_removed_mask_render, facecolors='yellow', alpha=0.7, edgecolor='orange')
            # voxels() forwards kwargs (including label) to each per-voxel collection,
            # which would add one legend entry per voxel; use a single proxy artist.
            ax.legend(handles=[Patch(facecolor='yellow', edgecolor='orange', alpha=0.7, label='Incorrect Removal')])

        ax.set_title(f"Eval Render Env #{env_index} (R={env_reward:.1f}, Removed={env_removed_perc:.1f}%)")
        ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
        ax.set_xlim(0, G); ax.set_ylim(0, G); ax.set_zlim(0, G); ax.set_aspect('auto')
        fig.tight_layout()
        render_dir = "renders_eval"
        os.makedirs(render_dir, exist_ok=True)
        save_path = os.path.join(render_dir, f"eval_render_env_{env_index}_final.png")
        fig.savefig(save_path); print(f"Saved evaluation render to {save_path}")
    except Exception as e:
        # Rendering is best-effort: report the failure but keep the evaluation stats.
        print(f"Error during rendering: {e}"); traceback.print_exc()
    finally:
        if fig is not None: plt.close(fig)


# -----------------------------------------------------------------------------
# 6) Training Loop (Using Hybrid State & Noisy Agent)
# -----------------------------------------------------------------------------
def train_gpu_batched(
        grid_size=16, max_steps=300, n_envs=32, episodes=10000,
        buffer_capacity=100000, learn_batch_size=32, learn_freq=4,
        gamma=0.99, lr=1e-4, tau=0.005, log_every=50,
        evaluate_every=100,
        num_eval_episodes_periodic=10,
        render_intermediate_eval=False
        ):
    """Train a BatchedDQNAgentTF on BatchedSculpt3DEnvTF and return the trained agent.

    Each episode resets all ``n_envs`` envs together and steps them until every
    env is done. Every env-step is stored in the replay buffer, and every
    ``learn_freq`` steps one learn step runs. Averages are printed and logged
    every ``log_every`` episodes. Every ``evaluate_every`` episodes (if > 0) the
    agent is evaluated with noise off. If any evaluations ran, their trend is
    plotted to ``plots/`` at the end.
    """
    print("--- Training Hybrid Noisy DQN Agent ---")
    print(f"Params: Grid={grid_size}, N_Envs={n_envs}, MaxSteps={max_steps}, Episodes={episodes}")
    print(f"Learn Batch Size (B): {learn_batch_size}, Total samples/learn step (B*N): {learn_batch_size * n_envs}")
    print(f"Periodic Evaluation every {evaluate_every} episodes ({num_eval_episodes_periodic} eps each).")
    print("Memory Warning: Ensure sufficient CPU RAM and GPU VRAM.")

    env = BatchedSculpt3DEnvTF(grid_size, max_steps, n_envs)
    agent = BatchedDQNAgentTF(grid_shape=env.grid_obs_shape, coord_shape=env.coord_obs_shape,
                              action_dim=6, lr=lr, gamma=gamma, tau=tau)
    # One buffer entry per env.step() call (all n_envs transitions together).
    agent.buffer.cap = buffer_capacity

    total_steps_taken = 0
    episode_rewards_history = []
    episode_lengths_history = []
    start_time = time.time()

    eval_episodes_list = []
    eval_avg_rewards_list = []
    eval_avg_removal_perc_list = []

    for ep in range(1, episodes + 1):
        obs_tuple = env.reset()
        done = tf.zeros([n_envs], tf.bool)
        ep_rewards = tf.zeros([n_envs], tf.float32)
        ep_steps = tf.zeros([n_envs], tf.int32)
        current_ep_step = 0

        while not tf.reduce_all(done):
            A = agent.act_batch(obs_tuple, deterministic=False)
            S2_tuple, R, next_done = env.step(A)
            agent.remember_batch(obs_tuple, A, R, S2_tuple, next_done)
            # Only count rewards/steps for envs that were active before this step.
            active_mask = ~done
            ep_rewards += R * tf.cast(active_mask, tf.float32)
            ep_steps += tf.cast(active_mask, tf.int32)
            obs_tuple = S2_tuple
            done = next_done
            total_steps_taken += n_envs
            current_ep_step += 1
            if current_ep_step % learn_freq == 0:
                agent.learn(learn_batch_size)

        avg_reward_batch = tf.reduce_mean(ep_rewards).numpy()
        avg_steps_batch = tf.reduce_mean(tf.cast(ep_steps, tf.float32)).numpy()
        episode_rewards_history.append(avg_reward_batch)
        episode_lengths_history.append(avg_steps_batch)

        if ep % log_every == 0 or ep == 1:
            elapsed_time = time.time() - start_time
            avg_r = np.mean(episode_rewards_history[-log_every:]) if episode_rewards_history else 0.0
            avg_l = np.mean(episode_lengths_history[-log_every:]) if episode_lengths_history else 0.0
            print(f"Ep {ep}/{episodes} | Avg R: {avg_r:.2f} | Avg Len: {avg_l:.1f} | Steps: {total_steps_taken} | Time: {elapsed_time:.1f}s")
            agent_step = agent.train_step_count.numpy()
            with agent.writer.as_default(step=agent_step):
                tf.summary.scalar("Episode/AvgReward", avg_r)
                tf.summary.scalar("Episode/AvgLength", avg_l)

        # --- Periodic Evaluation Call & Data Storage ---
        if evaluate_every > 0 and ep % evaluate_every == 0:
            eval_stats = evaluate_agent_performance(
                agent=agent,
                grid_size=grid_size,
                max_steps=max_steps,
                num_eval_episodes=num_eval_episodes_periodic,
                render=render_intermediate_eval,
                render_env_index=0
            )
            if eval_stats:
                eval_episodes_list.append(ep)
                eval_avg_rewards_list.append(eval_stats["avg_reward"])
                eval_avg_removal_perc_list.append(eval_stats["avg_removal_percentage"])
                # Log against the agent's train step for consistency with Train/Loss.
                agent_step = agent.train_step_count.numpy()
                with agent.writer.as_default(step=agent_step):
                    tf.summary.scalar("Evaluate/AvgReward", eval_stats["avg_reward"])
                    tf.summary.scalar("Evaluate/AvgRemovalPercentage", eval_stats["avg_removal_percentage"])
                    tf.summary.scalar("Evaluate/AvgIncorrectRemoved", eval_stats["avg_incorrect_removed"])

            print("-" * 60) # Separator after evaluation

    # --- End of Training ---
    agent.writer.close()
    print(f"Training finished. Total steps: {total_steps_taken}")

    if eval_episodes_list:
        _plot_evaluation_trend(eval_episodes_list, eval_avg_removal_perc_list,
                               eval_avg_rewards_list, grid_size, n_envs)

    # Return the agent trained here. The previous `return trained_agent` looked up
    # a global that is only assigned after this call returns.
    return agent


def _plot_evaluation_trend(eval_episodes, removal_percentages, avg_rewards, grid_size, n_envs):
    """Save a dual-axis plot of eval removal % (left) and eval reward (right) vs. episode."""
    print("\n--- Plotting Evaluation Trend ---")
    fig = None
    try:
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color = 'tab:red'
        ax1.set_xlabel('Training Episode')
        ax1.set_ylabel('Avg Carvable Material Removed (%)', color=color)
        ax1.plot(eval_episodes, removal_percentages, color=color, marker='o', linestyle='-', label='Removal %')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.grid(True, axis='y', linestyle=':')

        ax2 = ax1.twinx()
        color = 'tab:blue'
        ax2.set_ylabel('Avg Evaluation Reward', color=color)
        ax2.plot(eval_episodes, avg_rewards, color=color, marker='x', linestyle='--', label='Avg Reward')
        ax2.tick_params(axis='y', labelcolor=color)

        # Combine both axes' handles into a single legend.
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines1 + lines2, labels1 + labels2, loc='center right')

        fig.suptitle('Agent Evaluation Performance During Training')
        fig.tight_layout(rect=[0, 0.03, 1, 0.95]) # Leave room for the suptitle

        plot_dir = "plots"
        os.makedirs(plot_dir, exist_ok=True)
        plot_path = os.path.join(plot_dir, f"evaluation_trend_g{grid_size}_n{n_envs}.png")
        fig.savefig(plot_path)
        print(f"Saved evaluation trend plot to {plot_path}")
    except Exception as e:
        # Plotting is best-effort; never lose the trained agent over it.
        print(f"Error plotting evaluation trend: {e}")
    finally:
        if fig is not None: plt.close(fig)

# -----------------------------------------------------------------------------
# 7) Main Execution Block (Adjusted Parameters)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # --- Parameters (Adjust based on your hardware!) ---
    GRID_SIZE_RUN = 8
    N_ENVS_RUN = 16
    MAX_STEPS_RUN = 500
    EPISODES_RUN = 5000
    BUFFER_CAP_RUN = 100000 # Buffer entries = env-steps; each holds N_ENVS_RUN transitions
    LEARN_BATCH_RUN = 32
    LEARNING_RATE = 1e-4
    EVAL_FREQ_RUN = 50 # Evaluate every 50 training episodes
    NUM_EVAL_EPISODES_RUN = 10 # Run 10 episodes for each evaluation

    print("Starting run with Hybrid State + Noisy Nets")
    print(f"ACTUAL Params: Grid={GRID_SIZE_RUN}, N_Envs={N_ENVS_RUN}, LearnBatch(B)={LEARN_BATCH_RUN}")
    print(f"Total samples per learn step (B*N): {LEARN_BATCH_RUN * N_ENVS_RUN}")

    # Train the agent, with periodic evaluation
    trained_agent = train_gpu_batched(
        grid_size=GRID_SIZE_RUN,
        max_steps=MAX_STEPS_RUN,
        n_envs=N_ENVS_RUN,
        episodes=EPISODES_RUN,
        buffer_capacity=BUFFER_CAP_RUN,
        learn_batch_size=LEARN_BATCH_RUN,
        learn_freq=4,
        gamma=0.99,
        lr=LEARNING_RATE,
        tau=0.005,
        log_every=50,
        evaluate_every=EVAL_FREQ_RUN,
        num_eval_episodes_periodic=NUM_EVAL_EPISODES_RUN,
        render_intermediate_eval=False # Keep intermediate rendering off by default
    )

    # --- Run Final Evaluation After Training ---
    print("\n" + "="*70)
    print("      RUNNING FINAL EVALUATION ON TRAINED AGENT")
    print("="*70)
    if trained_agent:
        evaluate_agent_performance(
            agent=trained_agent,
            grid_size=GRID_SIZE_RUN,
            max_steps=MAX_STEPS_RUN,
            num_eval_episodes=100, # More episodes for the final assessment (needs >= 100 start voxels outside the shape)
            render=True,           # Render the final plot for one episode
            render_env_index=0
        )

    # --- Optional: Save Model Weights ---
    # if trained_agent:
    #     save_path = f"sculpt_hybrid_noisy_dqn_g{GRID_SIZE_RUN}_weights.h5"
    #     try:
    #         trained_agent.model.save_weights(save_path)
    #         print(f"\nModel weights saved to {save_path}")
    #     except Exception as e:
    #         print(f"\nError saving model weights: {e}")



In [None]:
# %% imports (ensure these are imported in your notebook)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import time
import os

# Assume BatchedSculpt3DEnvTF and BatchedDQNAgentTF classes are defined
# Assume trained_agent, GRID_SIZE_TO_RUN, MAX_STEPS_TRAIN are available
# from the previous training cell (note: that cell names them GRID_SIZE_RUN
# and MAX_STEPS_RUN).

# %% Evaluation Function Definition

def _save_eval_render(final_stock_flat, shape_mask_flat, grid_size, title, save_path):
    """Save a 3D voxel plot: target shape in translucent blue, removed carvable material in red."""
    G = grid_size
    final_stock_3d = np.asarray(final_stock_flat).reshape((G, G, G))
    shape_mask_3d = np.asarray(shape_mask_flat).reshape((G, G, G))
    # Removed carvable material = no longer in stock AND not part of the target shape.
    removed_3d = ~final_stock_3d & ~shape_mask_3d

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')
        x, y, z = np.indices(np.array(shape_mask_3d.shape) + 1)  # voxel edge coordinates
        ax.voxels(x, y, z, shape_mask_3d, facecolors='blue', alpha=0.15, edgecolor=None)
        ax.voxels(x, y, z, removed_3d, facecolors='red', alpha=0.5, edgecolor=None)
        ax.set_title(title)
        ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
        ax.set_xlim(0, G); ax.set_ylim(0, G); ax.set_zlim(0, G)
        ax.set_aspect('auto')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        plt.close(fig)  # release the figure even if drawing/saving fails


def evaluate_agent_stats(agent, grid_size, max_steps, num_test_episodes=10, render_env_index=0):
    """
    Evaluate a trained agent greedily, print performance statistics, and render
    the final state of one environment.

    All ``num_test_episodes`` episodes run in parallel as ONE batch of
    ``num_test_episodes`` environments, so every statistic and the rendered
    environment come from the same rollout.

    Args:
        agent: The trained BatchedDQNAgentTF object (must provide
            ``act_batch(obs, eps)``).
        grid_size: The grid size used during training.
        max_steps: The max steps per episode used during training.
        num_test_episodes: How many episodes (parallel environments) to evaluate.
        render_env_index: Index (0 to num_test_episodes-1) of the environment
            whose final state is rendered to ``renders_eval/``.

    Returns:
        dict with per-episode numpy arrays ``'rewards'``, ``'lengths'``,
        ``'removed_counts'``, ``'removal_percentages'`` and the scalar
        ``'initial_carvable_count'``.
    """
    print(f"\n--- Starting Agent Evaluation ---")
    print(f"Running {num_test_episodes} evaluation episodes...")

    # One batch of num_test_episodes parallel envs == num_test_episodes episodes.
    test_env = BatchedSculpt3DEnvTF(grid_size=grid_size, max_steps=max_steps, n_envs=num_test_episodes)
    obs = test_env.reset()

    # The shape mask is shared by all envs (row 0 is representative). Assumes
    # reset() leaves stock True everywhere, so everything outside the mask is
    # initially carvable. Read after reset() in case reset() touches the mask.
    shape_mask_flat_np = test_env.shape_mask[0].numpy()
    initial_carvable_mask_flat = ~shape_mask_flat_np
    initial_carvable_count = np.sum(initial_carvable_mask_flat)
    print(f"Initial number of carvable voxels: {initial_carvable_count}")
    if initial_carvable_count == 0:
        print("Warning: No carvable material initially defined by the shape mask.")

    # --- Greedy batched rollout ---
    done = tf.zeros([num_test_episodes], tf.bool)
    ep_rewards = tf.zeros([num_test_episodes], tf.float32)
    ep_steps = tf.zeros([num_test_episodes], tf.int32)
    eps_greedy = tf.constant(0.0, dtype=tf.float32)
    # Envs should terminate within max_steps; one step of slack covers both the
    # "done at max_steps" and "done after max_steps" conventions. The cap only
    # guards against a misbehaving env hanging the evaluation.
    step_cap = max_steps + 1
    loop_steps = 0
    while not tf.reduce_all(done):
        if loop_steps >= step_cap:
            print(f"Warning: not all episodes finished after {step_cap} steps; stopping early.")
            break
        actions = agent.act_batch(obs, eps_greedy)
        obs, rewards, next_done = test_env.step(actions)
        active = tf.logical_not(done)
        ep_rewards += rewards * tf.cast(active, tf.float32)
        ep_steps += tf.cast(active, tf.int32)
        # Sticky done: an env that already finished is never counted as active again.
        done = tf.logical_or(done, next_done)
        loop_steps += 1
    print(f"Evaluation batch finished. Steps: {loop_steps}")

    # --- Per-episode statistics (vectorized over the batch) ---
    final_stock_np = test_env.stock.numpy()  # one flat stock row per env
    rewards_np = ep_rewards.numpy()
    lengths_np = ep_steps.numpy()
    # Initially carvable AND now removed, counted per env.
    removed_counts = np.sum(initial_carvable_mask_flat[None, :] & ~final_stock_np, axis=1)
    if initial_carvable_count > 0:
        removal_pcts = removed_counts / initial_carvable_count * 100.0
    else:
        removal_pcts = np.zeros(num_test_episodes)  # avoid division by zero

    print(f"\n--- Evaluation Results ({num_test_episodes} episodes) ---")
    print(f"Average Reward: {np.mean(rewards_np):.2f} (+/- {np.std(rewards_np):.2f})")
    print(f"Average Length: {np.mean(lengths_np):.1f}")
    print(f"Average Carvable Voxels Removed: {np.mean(removed_counts):.1f} / {initial_carvable_count}")
    print(f"Average Removal Percentage: {np.mean(removal_pcts):.2f}% (+/- {np.std(removal_pcts):.2f}%)")

    stats = {
        'rewards': rewards_np,
        'lengths': lengths_np,
        'removed_counts': removed_counts,
        'removal_percentages': removal_pcts,
        'initial_carvable_count': initial_carvable_count,
    }

    # --- Render final state of the selected env (best effort) ---
    print(f"\n--- Rendering final state for Eval Env Index: {render_env_index} ---")
    if not 0 <= render_env_index < num_test_episodes:
        print(f"Error: render_env_index ({render_env_index}) out of bounds for {num_test_episodes} test environments.")
        return stats

    try:
        render_dir = "renders_eval"
        os.makedirs(render_dir, exist_ok=True)
        save_path = os.path.join(render_dir, f"eval_render_env_{render_env_index}.png")
        # Reward, percentage and stock all come from the same env of the same rollout.
        title = (f"Eval Render Env #{render_env_index} "
                 f"(R={rewards_np[render_env_index]:.2f}, "
                 f"Removed={removal_pcts[render_env_index]:.1f}%)")
        _save_eval_render(final_stock_np[render_env_index], shape_mask_flat_np,
                          grid_size, title, save_path)
        print(f"Saved evaluation render to {save_path}")
    except Exception as e:
        import traceback
        print(f"Error during rendering: {e}")
        traceback.print_exc()

    return stats


# %% Run Evaluation (Example Call)
# Requires `trained_agent` from the training cell. Grid size / max steps are looked
# up as GRID_SIZE_TO_RUN / MAX_STEPS_TRAIN first and, failing that, as
# GRID_SIZE_RUN / MAX_STEPS_RUN (the names the training cell above passes to
# train_gpu_batched), so this cell does not NameError under either naming scheme.

NUM_EVAL_EPISODES = 10 # Number of episodes to average over for evaluation
RENDER_INDEX = 0       # Which episode's final state to render (0 to NUM_EVAL_EPISODES-1)

_eval_grid_size = globals().get('GRID_SIZE_TO_RUN', globals().get('GRID_SIZE_RUN'))
_eval_max_steps = globals().get('MAX_STEPS_TRAIN', globals().get('MAX_STEPS_RUN'))

# Training may yield a falsy/None agent (the training cell guards with `if trained_agent:`),
# so treat None the same as "not defined".
if globals().get('trained_agent') is None:
    print("Variable 'trained_agent' not found. Please train the agent first.")
elif _eval_grid_size is None or _eval_max_steps is None:
    print("Grid size / max steps not found: define GRID_SIZE_TO_RUN (or GRID_SIZE_RUN) "
          "and MAX_STEPS_TRAIN (or MAX_STEPS_RUN) first.")
else:
    evaluate_agent_stats(
        agent=trained_agent,
        grid_size=_eval_grid_size,
        max_steps=_eval_max_steps,
        num_test_episodes=NUM_EVAL_EPISODES,
        render_env_index=RENDER_INDEX
    )



In [None]:
# %%capture
# # Install widgets if needed (uncomment and run once)
# !pip install ipywidgets

# %% Imports and Setup
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np
import tensorflow as tf
import time
import os

# Prerequisites (defined by earlier cells, not in this one):
#   - the BatchedSculpt3DEnvTF and BatchedDQNAgentTF classes;
#   - `trained_agent`, plus GRID_SIZE_TO_RUN / MAX_STEPS_TRAIN as referenced below.
#     NOTE(review): the training cell names these values GRID_SIZE_RUN / MAX_STEPS_RUN;
#     alias them if GRID_SIZE_TO_RUN / MAX_STEPS_TRAIN are not defined elsewhere.

print("Imports and setup complete.")

# %% --- Helper Functions for Predefined Shapes ---
# These functions create the 1D flat boolean mask for different shapes

def create_sphere_mask(grid_size, radius=None):
    """
    Generate a flat boolean mask for a sphere centred in a G x G x G grid.

    Computed with numpy only (no TensorFlow round-trip needed for this small
    CPU-side mask).

    Args:
        grid_size: Edge length G of the cubic grid.
        radius: Sphere radius in voxels. Defaults to ``grid_size // 2 - 1``.

    Returns:
        numpy bool array of shape (G**3,), flattened in C order with 'ij'
        indexing (first axis slowest); True for voxels whose centre lies
        within ``radius`` of the grid centre.
    """
    G = grid_size
    if radius is None:
        radius = G // 2 - 1
    # coords[a][i, j, k] is the index along axis a -> same layout as meshgrid(indexing='ij').
    coords = np.indices((G, G, G), dtype=np.float64)
    centre = G / 2 - 0.5
    dist2 = np.sum((coords - centre) ** 2, axis=0)
    mask3d = dist2 <= float(radius) ** 2
    return mask3d.reshape(-1)

def create_cube_mask(grid_size, cube_ratio=0.5):
    """
    Generate a flat boolean mask for a cube centred in a G x G x G grid.

    Args:
        grid_size: Edge length G of the cubic grid.
        cube_ratio: Cube edge as a fraction of G. Values are clamped so that
            ratios <= 0 give an empty mask and ratios >= 1 fill the whole grid.

    Returns:
        numpy bool array of shape (G**3,), True inside the cube.
    """
    G = grid_size
    # Clamp: without it a ratio > 1 gives a negative slice start, which numpy
    # interprets from the end and silently selects only a corner slab.
    cube_size = min(max(int(G * cube_ratio), 0), G)
    start = (G - cube_size) // 2
    end = start + cube_size
    mask3d = np.zeros((G, G, G), dtype=bool)
    mask3d[start:end, start:end, start:end] = True
    return mask3d.ravel()

# Add more shape functions here if desired (e.g., cylinder, multiple objects)

# %% --- Simulation, Statistics, and Plotting Function ---

def _flat_index_to_xyz(flat_index, grid_size):
    """Decode a flat voxel index into [x, y, z] with x = i // G**2, y = (i // G) % G, z = i % G."""
    idx = int(flat_index)
    return [idx // (grid_size * grid_size), (idx // grid_size) % grid_size, idx % grid_size]


def _plot_simulation_result(shape_mask_3d, removed_mask, incorrectly_removed_mask,
                            path_coords, selected_shape_name, grid_size,
                            total_reward, removal_percentage):
    """Draw target shape, removed voxels, wrongly removed target voxels and the agent path."""
    fig = plt.figure(figsize=(10, 8))  # slightly wider figure
    ax = fig.add_subplot(111, projection='3d')
    ax.set_facecolor('whitesmoke')

    # Voxel edge coordinates.
    x_vox, y_vox, z_vox = np.indices(np.array(shape_mask_3d.shape) + 1)
    # Target shape (transparent blue).
    ax.voxels(x_vox, y_vox, z_vox, shape_mask_3d, facecolors='blue', alpha=0.1, edgecolor=None)
    # Removed carvable material (red).
    ax.voxels(x_vox, y_vox, z_vox, removed_mask, facecolors='red', alpha=0.6, edgecolor=None)
    # Target voxels that were carved away (yellow warning).
    if incorrectly_removed_mask.any():
        ax.voxels(x_vox, y_vox, z_vox, incorrectly_removed_mask, facecolors='yellow', alpha=0.7, edgecolor='orange')

    if path_coords:
        # +0.5 so the path runs through voxel centres rather than corners.
        path_np = np.asarray(path_coords, dtype=float) + 0.5
        if len(path_np) > 1:
            ax.plot(path_np[:, 0], path_np[:, 1], path_np[:, 2], color='green', linewidth=1.5, label='Agent Path', alpha=0.8)
        ax.scatter(*path_np[0], color='lime', s=80, edgecolor='black', depthshade=False, label='Start', alpha=1.0)
        if len(path_np) > 1:
            ax.scatter(*path_np[-1], color='magenta', s=80, edgecolor='black', depthshade=False, label='End', alpha=1.0)
        ax.legend()

    ax.set_title(f"Agent Result: '{selected_shape_name}' | R={total_reward:.1f} | Removed={removal_percentage:.1f}%")
    ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
    ax.set_xlim(0, grid_size); ax.set_ylim(0, grid_size); ax.set_zlim(0, grid_size)
    # 'auto' is the safe default for 3D axes; 'equal' may fail on some matplotlib versions.
    ax.set_aspect('auto')

    plt.tight_layout()
    plt.show()  # display inline in the notebook output


def run_simulation_and_display(selected_shape_name, grid_size, max_steps, agent):
    """
    Run the trained agent greedily on one selected target shape, then print
    statistics and plot the final state plus the agent's path.

    All output goes to the module-level ``output_widget`` created by the
    widget cell.

    Args:
        selected_shape_name: 'Sphere' or 'Cube'.
        grid_size: Edge length G of the voxel grid (must match training).
        max_steps: Episode length limit (must match training); also used as a
            safety cap on the simulation loop.
        agent: Object providing ``act_batch(obs, eps)``.
    """
    # --- 1. Clear output & generate shape mask ---
    with output_widget:
        clear_output(wait=True)
        print(f"Selected Target Shape: {selected_shape_name}")
        print(f"Grid Size: {grid_size}x{grid_size}x{grid_size}, Max Steps: {max_steps}")
        print("Generating shape mask...")
        time.sleep(0.1)  # tiny pause so the UI updates

        if selected_shape_name == 'Sphere':
            shape_mask_flat = create_sphere_mask(grid_size)
        elif selected_shape_name == 'Cube':
            shape_mask_flat = create_cube_mask(grid_size, cube_ratio=0.6)
        else:
            print(f"Error: Shape generation function for '{selected_shape_name}' not found.")
            return

        try:
            shape_mask_3d = shape_mask_flat.reshape((grid_size, grid_size, grid_size))
            print("Shape mask generated.")
        except ValueError as e:
            print(f"Error reshaping mask: {e}")
            return

    # --- 2. Fresh single-env environment with the custom target mask ---
    # A new instance keeps runs isolated (may trigger tf.function retracing per shape).
    try:
        test_env = BatchedSculpt3DEnvTF(grid_size=grid_size, max_steps=max_steps, n_envs=1)
        # NOTE(review): assumes shape_mask is an assignable variable of shape (1, G**3)
        # and that reset() does not regenerate it — confirm in BatchedSculpt3DEnvTF.
        test_env.shape_mask.assign(tf.constant(shape_mask_flat, dtype=tf.bool)[None, :])
    except Exception as e:
        with output_widget:
            print(f"Error creating environment: {e}")
            import traceback
            traceback.print_exc()
        return

    # --- 3. Greedy episode with path tracking ---
    with output_widget:
        print("Running agent simulation (greedy)...")
        time.sleep(0.1)

    sim_start_time = time.time()
    obs = test_env.reset()
    done = tf.zeros([1], tf.bool)  # batch size is 1
    eps_tf_zero = tf.constant(0.0, dtype=tf.float32)
    path_coords = []
    total_reward = 0.0
    steps_taken = 0

    # The path is always read from test_env.pos (flat grid index) so every point,
    # the start included, is in grid units. The observation's XYZ part is
    # normalized (per the module header), so mixing it in would distort the path.
    # NOTE(review): assumes step() does not auto-reset pos when an episode ends.
    try:
        path_coords.append(_flat_index_to_xyz(test_env.pos.numpy()[0], grid_size))
    except Exception as e:
        with output_widget:
            print(f"Error getting initial position: {e}")
        # Continue without the start point.

    while not tf.reduce_all(done):
        if steps_taken >= max_steps:  # safety break
            with output_widget:
                print(f"Warning: Reached max_steps ({max_steps}) during simulation loop.")
            break
        try:
            A = agent.act_batch(obs, eps_tf_zero)
            obs, R, done = test_env.step(A)
            total_reward += float(R.numpy()[0])
            steps_taken += 1
            path_coords.append(_flat_index_to_xyz(test_env.pos.numpy()[0], grid_size))
        except Exception as e:
            with output_widget:
                print(f"\nError during simulation step {steps_taken}: {e}")
                import traceback
                traceback.print_exc()
            return  # stop simulation on error

    sim_duration = time.time() - sim_start_time
    with output_widget:
        print(f"Simulation finished in {sim_duration:.2f} seconds.")

    # --- 4. Final state ---
    try:
        final_stock_3d = test_env.stock.numpy()[0].reshape((grid_size, grid_size, grid_size))
    except Exception as e:
        with output_widget:
            print(f"Error getting final stock state: {e}")
        return

    # --- 5. Statistics ---
    initial_carvable_mask = ~shape_mask_3d  # stock assumed True everywhere at start
    initial_carvable_count = np.sum(initial_carvable_mask)
    removed_mask = initial_carvable_mask & ~final_stock_3d  # carvable AND now removed
    removed_count = np.sum(removed_mask)
    incorrectly_removed_mask = shape_mask_3d & ~final_stock_3d  # target AND now removed (error)
    incorrectly_removed_count = np.sum(incorrectly_removed_mask)
    removal_percentage = (removed_count / initial_carvable_count) * 100.0 if initial_carvable_count > 0 else 0.0

    # --- 6. Report & plot ---
    with output_widget:
        print("\n--- Simulation Results ---")
        print(f"  Total Reward: {total_reward:.2f}")
        print(f"  Steps Taken: {steps_taken}")
        print(f"  Carvable Voxels Removed: {removed_count} / {initial_carvable_count} ({removal_percentage:.1f}%)")
        print(f"  Target Voxels Incorrectly Removed: {incorrectly_removed_count}")
        print("\nPlotting final state and path...")
        try:
            _plot_simulation_result(shape_mask_3d, removed_mask, incorrectly_removed_mask,
                                    path_coords, selected_shape_name, grid_size,
                                    total_reward, removal_percentage)
            print("Plotting complete.")
        except Exception as e:
            print(f"\nError during plotting: {e}")
            import traceback
            traceback.print_exc()


# %% --- Create and Display Widgets ---

# Check if agent variable exists (replace 'trained_agent' if yours has a different name)
agent_variable_name = 'trained_agent' # <<< MAKE SURE THIS MATCHES YOUR TRAINED AGENT VARIABLE
if globals().get(agent_variable_name) is None:
    print(f"ERROR: Trained agent variable '{agent_variable_name}' not found.")
    print("Please ensure the agent is trained and the variable is available before running this cell.")
    # You might want to stop execution here in a real notebook
    # raise NameError(f"Variable '{agent_variable_name}' not defined.")
else:
    print("Trained agent found.")
    # Define Widgets
    shape_selector = widgets.Dropdown(
        options=['Sphere', 'Cube'], # Add more names if functions are defined above
        value='Sphere',
        description='Target Shape:',
        style={'description_width': 'initial'}, # Prevent label truncation
        disabled=False,
    )

    run_button = widgets.Button(
        description='Run Simulation & Render',
        disabled=False,
        button_style='info', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Run the trained agent on the selected shape and display results',
        icon='cube' # Example icon
    )

    # Output widget to capture prints and plots (read by run_simulation_and_display).
    output_widget = widgets.Output()

    def on_run_button_clicked(button_instance):
        """Run one greedy simulation for the selected shape; the button is locked while it runs."""
        run_button.disabled = True
        run_button.icon = 'spinner'
        try:
            shape_name = shape_selector.value
            # Looked up at click time so re-training in another cell is picked up.
            agent_obj = globals().get(agent_variable_name)
            # The training cell names these GRID_SIZE_RUN / MAX_STEPS_RUN; accept both schemes.
            grid_size = globals().get('GRID_SIZE_TO_RUN', globals().get('GRID_SIZE_RUN'))
            max_steps = globals().get('MAX_STEPS_TRAIN', globals().get('MAX_STEPS_RUN'))

            if agent_obj is None:
                with output_widget:
                    clear_output(wait=True)
                    print(f"Error: '{agent_variable_name}' not found.")
                return
            if grid_size is None or max_steps is None:
                with output_widget:
                    clear_output(wait=True)
                    print("Error: grid size / max steps not found. Define GRID_SIZE_TO_RUN "
                          "(or GRID_SIZE_RUN) and MAX_STEPS_TRAIN (or MAX_STEPS_RUN).")
                return

            run_simulation_and_display(
                selected_shape_name=shape_name,
                grid_size=grid_size,
                max_steps=max_steps,
                agent=agent_obj
            )
        except Exception:
            # Surface unexpected failures in the UI instead of losing them.
            import traceback
            with output_widget:
                traceback.print_exc()
        finally:
            # Always re-enable, otherwise one failure leaves the button stuck on 'spinner'.
            run_button.disabled = False
            run_button.icon = 'cube'

    run_button.on_click(on_run_button_clicked)

    # Display the UI Layout
    print("\n--- Agent Evaluation Interface ---")
    ui_box = widgets.VBox([
        widgets.HTML("<h3>Select a target shape and run the simulation:</h3>"),
        shape_selector,
        run_button,
        widgets.HTML("<hr><h4>Results:</h4>"), # Separator
        output_widget
    ])
    display(ui_box)