In [None]:
#!/usr/bin/env python3
"""
Fully TensorFlow‑Accelerated Sculpt3DEnv + DQN Agent
• Environment ops in @tf.function
• All tensors stay in TF until logging, where we convert to Python floats/arrays
• Extensive TensorBoard logging (via torch.utils.tensorboard.SummaryWriter)
• Periodic 3D rendering of path + object

Usage:
  python sculpt_tf_dqn.py
  tensorboard --logdir=runs
"""

import tensorflow as tf
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter
import datetime
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from collections import deque

# --- 1) Sculpt3DEnvTF ---
class Sculpt3DEnvTF:
    """Voxel "sculpting" environment whose state lives in TensorFlow variables.

    A router moves one voxel per step along one of six axis directions inside a
    ``grid_size`` x ``grid_size`` x ``grid_size`` block of stock. A successful
    move clears the stock at the current and target voxels. Leaving the grid or
    touching the protected sphere (``shape``) ends the episode with a penalty.

    Attributes:
        stock: bool Variable (created on CPU), True where material remains.
        shape: bool Variable (created on CPU), True inside the protected sphere.
        router_pos: int32 Variable with the router's (x, y, z) voxel.
        steps: int32 Variable counting steps in the current episode.
        done: bool Variable, True once the episode has terminated.
        moves: int32 constant of shape (6, 3); row ``a`` is the displacement
            applied for action ``a``.
    """

    def __init__(self, grid_size=20, max_steps=200):
        self.grid_size = grid_size
        self.max_steps = max_steps

        # stock & shape on CPU so the gather/scatter ops in step() run there
        with tf.device('/CPU:0'):
            self.stock = tf.Variable(
                tf.ones([grid_size, grid_size, grid_size], dtype=tf.bool),
                trainable=False)
            # (G, G, G, 3) int32 grid of voxel coordinates; 'ij' => [x, y, z]
            coords = tf.stack(tf.meshgrid(
                tf.range(grid_size), tf.range(grid_size), tf.range(grid_size),
                indexing='ij'), axis=-1)
            center = tf.constant([grid_size // 2] * 3, dtype=tf.int32)
            dist2 = tf.reduce_sum(tf.square(coords - center), axis=-1)
            r = tf.cast(grid_size // 2 - 1, tf.int32)
            # protected region: voxels within radius r of the grid centre
            self.shape = tf.Variable(dist2 <= r * r, trainable=False)

        # router & control variables on the default device (GPU if available)
        self.router_pos = tf.Variable([0, 0, 0], dtype=tf.int32)
        self.steps = tf.Variable(0, dtype=tf.int32)
        self.done = tf.Variable(False, dtype=tf.bool)

        # action index -> unit displacement (+x, -x, +y, -y, +z, -z)
        self.moves = tf.constant(
            [[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]],
            dtype=tf.int32
        )

    def reset(self):
        """Start a new episode and return the initial observation.

        Restores all stock, clears the step counter and ``done`` flag, and
        places the router on a random voxel outside the protected shape. The
        voxel is drawn with Python's ``random`` module.

        Raises:
            ValueError: if every voxel lies inside the protected shape.
        """
        with tf.device('/CPU:0'):
            self.stock.assign(tf.ones_like(self.stock))

        self.steps.assign(0)
        self.done.assign(False)

        # flat indices of voxels outside the protected shape
        free = np.flatnonzero(~self.shape.numpy())
        if free.size == 0:
            raise ValueError("no voxel outside the protected shape to start from")
        idx = random.choice(free)
        # flat index -> (x, y, z) in C order, matching the 'ij' meshgrid layout
        x, y, z = np.unravel_index(idx, (self.grid_size,) * 3)
        self.router_pos.assign([int(x), int(y), int(z)])

        return self._get_obs()

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32)])
    def step(self, action):
        """Apply ``action`` (row index into ``moves``) and return ``(obs, reward, done)``.

        Reward rules:
          * -5 for leaving the grid or touching the protected shape (both end
            the episode).
          * Otherwise, the number of stock voxels still present at the current
            and target positions before they are cleared.
          * A -0.1 step cost is always added.

        The episode also ends once ``steps`` reaches ``max_steps``. The fixed
        scalar-int32 signature converts Python ints to a tensor. Without it,
        each distinct action value would trigger a new trace.
        """
        move = tf.gather(self.moves, action)
        newpos = self.router_pos + move
        inb = tf.reduce_all((newpos >= 0) & (newpos < self.grid_size))

        def _oob():
            # left the grid: terminate with a penalty, router does not move
            self.done.assign(True)
            return tf.constant(-5.0)

        def _carve():
            # voxels touched by this move: current and target positions
            idxs = tf.stack([self.router_pos, newpos], axis=0)
            with tf.device('/CPU:0'):
                shape_vals = tf.gather_nd(self.shape, idxs)

            def _hit():
                # touched the protected shape: terminate with a penalty
                self.done.assign(True)
                return tf.constant(-5.0)

            def _remove():
                # reward = stock present at both voxels before clearing them
                with tf.device('/CPU:0'):
                    stock_vals = tf.gather_nd(self.stock, idxs)
                    updated = tf.tensor_scatter_nd_update(
                        self.stock, idxs, tf.zeros([2], dtype=tf.bool))
                    self.stock.assign(updated)
                self.router_pos.assign(newpos)
                return tf.reduce_sum(tf.cast(stock_vals, tf.float32))

            return tf.cond(tf.reduce_any(shape_vals), _hit, _remove)

        reward = tf.cond(inb, _carve, _oob)
        self.steps.assign_add(1)
        reward = reward - 0.1  # per-step cost
        self.done.assign(tf.logical_or(self.done, self.steps >= self.max_steps))

        # Return a snapshot of `done`, not the Variable itself, so callers that
        # keep the value (e.g. in a replay buffer) never alias mutable state.
        return self._get_obs(), reward, self.done.read_value()

    @tf.function
    def _get_obs(self):
        """Return float32 ``[x, y, z, c, c, c]``: router position plus grid centre."""
        rx, ry, rz = tf.unstack(self.router_pos)
        c = self.grid_size // 2
        return tf.cast(tf.stack([rx, ry, rz, c, c, c]), tf.float32)


# --- 2) TF Replay Buffer ---
class ReplayBufferTF:
    """FIFO experience-replay buffer that returns minibatches as TF tensors.

    Transitions are stored as Python tuples ``(s, a, r, ns, d)``. Once
    ``capacity`` is reached, each new transition evicts the oldest one.
    """

    def __init__(self, capacity=10000):
        self.cap = capacity
        # deque(maxlen=...) evicts the oldest entry in O(1); list.pop(0) is O(n).
        # deque is a registered Sequence, so random.sample accepts it.
        self.buf = deque(maxlen=capacity)

    def add(self, s, a, r, ns, d):
        """Append one transition, dropping the oldest if the buffer is full."""
        self.buf.append((s, a, r, ns, d))

    def sample(self, bs):
        """Draw ``bs`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. States and next
            states are stacked along a new leading axis; actions are int32,
            rewards float32 and dones bool tensors.

        Raises:
            ValueError: if ``bs`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        batch = random.sample(self.buf, bs)
        s, a, r, ns, d = zip(*batch)
        return (
            tf.stack(s),
            tf.convert_to_tensor(a, tf.int32),
            tf.convert_to_tensor(r, tf.float32),
            tf.stack(ns),
            tf.convert_to_tensor(d, tf.bool)
        )

    def __len__(self):
        return len(self.buf)


# --- 3) TF DQN Agent with Extensive Logging ---
class DQNAgentTF:
    """DQN agent with a soft-updated target network and TensorBoard logging.

    Args:
        state_dim: Length of the observation vector fed to the Q-network.
        action_dim: Number of discrete actions (Q-network outputs).
        lr: Adam learning rate.
        gamma: Discount factor for the bootstrapped TD target.
        tau: Polyak coefficient for the per-update target-network blend.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, tau=0.01):
        self.gamma, self.tau = gamma, tau

        # online Q-network: state -> one Q-value per action
        inputs = tf.keras.Input(shape=(state_dim,))
        x = tf.keras.layers.Dense(128, activation='relu')(inputs)
        x = tf.keras.layers.Dense(128, activation='relu')(x)
        outputs = tf.keras.layers.Dense(action_dim)(x)
        self.model = tf.keras.Model(inputs, outputs)
        # target network: same architecture, starts with identical weights
        self.target = tf.keras.models.clone_model(self.model)
        self.target.set_weights(self.model.get_weights())

        self.optimizer = tf.keras.optimizers.Adam(lr)
        self.buffer = ReplayBufferTF()

        # TensorBoard via torch's SummaryWriter; one timestamped run dir per agent
        logdir = f"runs/tf_sculpt_full_{datetime.datetime.now():%Y%m%d_%H%M%S}"
        self.writer = SummaryWriter(logdir)
        self.step = 0  # completed learn() updates; x-axis for Train/* logs

    @tf.function
    def train_step(self, states, actions, rewards, next_states, dones):
        """Perform one TD(0) gradient step, then soft-update the target network.

        Loss is the mean squared error between Q(s, a) and
        ``r + gamma * max_a' Q_target(s', a') * (1 - done)``.

        Returns:
            ``(loss, grad_norm)`` where ``grad_norm`` is the global L2 norm of
            the gradients before they are applied.
        """
        with tf.GradientTape() as tape:
            q = self.model(states, training=True)
            # pick Q(s, a) for the taken action via a one-hot mask
            q_sa = tf.reduce_sum(q * tf.one_hot(actions, tf.shape(q)[1]), axis=1)
            qn = self.target(next_states, training=False)
            max_n = tf.reduce_max(qn, axis=1)
            # no bootstrap from terminal transitions
            target = rewards + self.gamma * max_n * (1 - tf.cast(dones, tf.float32))
            loss = tf.reduce_mean(tf.square(q_sa - target))

        grads = tape.gradient(loss, self.model.trainable_variables)
        grad_norm = tf.linalg.global_norm(grads)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

        # soft update: target <- tau * online + (1 - tau) * target
        for w, tw in zip(self.model.weights, self.target.weights):
            tw.assign(self.tau * w + (1 - self.tau) * tw)

        return loss, grad_norm

    def remember(self, *args):
        """Store one ``(s, a, r, ns, d)`` transition in the replay buffer."""
        self.buffer.add(*args)

    def act(self, state, eps=0.1):
        """Choose an action index for ``state``.

        With probability ``eps``, return a uniformly random action. Otherwise,
        return the argmax of the Q-values after adding Gaussian noise with
        stddev ``eps``.
        """
        if random.random() < eps:
            return random.randrange(self.model.output_shape[-1])
        qv = self.model(tf.expand_dims(state, 0), training=False)[0]
        noise = tf.random.normal(tf.shape(qv), stddev=eps)
        return int(tf.argmax(qv + noise).numpy())

    def learn(self, batch_size=64, hist_every=1):
        """Run one gradient step on a replay minibatch and log diagnostics.

        Does nothing until the buffer holds at least ``batch_size`` transitions.

        Args:
            batch_size: Minibatch size sampled from the replay buffer.
            hist_every: Write weight histograms every N learning steps (0
                disables them). Each histogram copies a weight tensor to host,
                so values > 1 reduce logging overhead. The default of 1 logs
                every step, as before.
        """
        if len(self.buffer) < batch_size:
            return
        s, a, r, ns, d = self.buffer.sample(batch_size)
        loss, grad_norm = self.train_step(s, a, r, ns, d)

        # log scalars
        self.step += 1
        self.writer.add_scalar("Train/Loss", loss.numpy(), self.step)
        self.writer.add_scalar("Train/GradNorm", grad_norm.numpy(), self.step)

        # log histograms
        if hist_every and self.step % hist_every == 0:
            for var in self.model.trainable_variables:
                # Keras 3 variables have short names ("kernel", "bias") that
                # repeat across layers; `path` (e.g. "dense_1/kernel") is unique.
                # Older tf.keras variables lack `path`, so fall back to `name`.
                tag = getattr(var, "path", var.name).replace(':', '_')
                self.writer.add_histogram(tag, var.numpy(), self.step)


# --- 4) Training Loop with Episode‐level Logging & Rendering ---
def train_tf(env, episodes=500, eps_start=1.0, eps_end=0.05, eps_decay=0.995,
             render_every=100):
    """Train a ``DQNAgentTF`` on ``env`` with per-episode logging and renders.

    Args:
        env: A ``Sculpt3DEnvTF`` (uses ``reset``, ``step`` and ``shape``).
        episodes: Number of episodes to run.
        eps_start: Initial exploration rate.
        eps_end: Lower bound on the exploration rate.
        eps_decay: Multiplicative epsilon decay applied after each episode.
        render_every: Plot the router path every N episodes; a falsy value
            (0/None) disables rendering.

    Returns:
        The trained agent. Its TensorBoard writer is closed on return, or if
        training raises.
    """
    # obs = router (x, y, z) + centre (c, c, c); one action per axis move
    state_dim, action_dim = 6, 6
    agent = DQNAgentTF(state_dim, action_dim)
    eps = eps_start
    recent_rewards = deque(maxlen=20)

    try:
        for ep in range(1, episodes + 1):
            state = env.reset()
            done = False
            total = 0.0
            path = []

            while not done:
                path.append(state[:3].numpy().tolist())
                a = agent.act(state, eps)
                ns, r, done = env.step(a)
                # Snapshot reward/done as Python scalars so stored transitions
                # can never alias mutable env state and `total` stays a float.
                r, done = float(r), bool(done)
                agent.remember(state, a, r, ns, done)
                agent.learn()
                state = ns
                total += r
            # record the terminal position so the rendered path is complete
            path.append(state[:3].numpy().tolist())

            recent_rewards.append(total)
            avg20 = float(np.mean(recent_rewards))

            # episode scalars
            agent.writer.add_scalar("Episode/Reward", total, ep)
            agent.writer.add_scalar("Episode/Avg20Reward", avg20, ep)
            agent.writer.add_scalar("Episode/Epsilon", eps, ep)
            agent.writer.add_scalar("Episode/LearningRate",
                                    float(agent.optimizer.learning_rate.numpy()), ep)

            # periodic 3D render (guarded so render_every=0 disables it)
            if render_every and ep % render_every == 0:
                _render_episode(env, path, ep, total)

            # report the epsilon used this episode (matches Episode/Epsilon),
            # then decay it for the next one
            print(f"Ep {ep}/{episodes}  R={total:.2f}  Avg20={avg20:.2f}  Eps={eps:.3f}")
            eps = max(eps_end, eps * eps_decay)
    finally:
        agent.writer.close()
    return agent


def _render_episode(env, path, ep, total):
    """Plot one episode's router path over the protected shape, then close the figure."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    xs, ys, zs = zip(*path)
    ax.plot(xs, ys, zs, '-o', label=f"Ep{ep} R={total:.2f}")

    # plot protected shape as a faint point cloud
    pts = np.argwhere(env.shape.numpy())
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=1, alpha=0.1, color='red')

    ax.set_title(f"Episode {ep}")
    ax.legend()
    plt.show()
    # release the figure so repeated renders don't accumulate open figures
    plt.close(fig)


if __name__ == "__main__":
    # Script entry: 20-voxel cube, 200-step episode cap, render every 100 episodes.
    env = Sculpt3DEnvTF(max_steps=200, grid_size=20)
    train_tf(env, render_every=100, episodes=1000)