In [1]:
# =====================================
# 🧩 BLOCCO 1 — FIX PATH IMPORTS + LOGGER SETUP
# =====================================
import sys, os
from pathlib import Path
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.logging import RichHandler
import logging

# === Rich console and logging ===
# record=True keeps everything printed by this console so it can be exported later.
console = Console(record=True)
LOG_DIR = Path("artifacts/logs")
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Configure Python logging: Rich handler on screen + plain-text copy in LOG_DIR.
# force=True (Python 3.8+) first removes handlers already attached to the root logger:
# without it basicConfig is a silent no-op when the kernel (or a previous run of this cell)
# already configured logging, and re-runs would keep stale handlers / open log files.
logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[
        RichHandler(console=console, markup=True),
        logging.FileHandler(LOG_DIR / "rich_log.txt", encoding="utf-8"),
    ],
    force=True,
)
log = logging.getLogger("snake_rl")

# === PATH SETUP ===
# Path() is the current working directory (normally the notebook's folder): make modules
# living there importable.
BASE_DIR = Path().resolve()
ENV_DIR = BASE_DIR / "envs"
ENV_DIR.mkdir(exist_ok=True)

if str(BASE_DIR) not in sys.path:
    sys.path.append(str(BASE_DIR))

log.info(f"🔧 [bold green]Path aggiunto:[/bold green] {BASE_DIR}")

# Small summary table of the working directories
table = Table(title="📁 Directory di lavoro")
table.add_column("Chiave", style="cyan", no_wrap=True)
table.add_column("Percorso", style="magenta")
table.add_row("BASE_DIR", str(BASE_DIR))
table.add_row("ENV_DIR", str(ENV_DIR))
table.add_row("LOG_DIR", str(LOG_DIR))
console.print(Panel(table, title="[bold yellow]Setup Path Completato[/bold yellow]", expand=False))


In [2]:
# =====================================
# 🐍 SNAKE ENVIRONMENT — versione configurabile con Rich logging
# =====================================
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from rich.console import Console

console = Console()


class SnakeEnv(gym.Env):
    """Snake on a ``size x size`` grid, following the Gymnasium ``Env`` API.

    Observation: float32 array of shape ``(4, size, size)`` indexed ``[channel, x, y]``:
      - channel 0: head (1.0)
      - channel 1: body segments (1.0)
      - channel 2: apple (1.0)
      - channel 3: current heading at the head cell, encoded as ``(dir_idx + 1) / 4``
    Actions: ``0=UP, 1=RIGHT, 2=DOWN, 3=LEFT`` (UP decreases ``y``). A 180° reversal is
    ignored: the snake keeps its current heading.

    An episode terminates when the head hits a wall or the body (``cause`` "wall"/"self"),
    after more than ``size * idle_steps_factor`` steps without eating ("idle_timeout"), or
    when the snake fills the whole board ("win"). On termination ``info`` carries
    ``episode = {"r": episode return, "l": episode length}``, ``apples`` and ``cause``.

    Args:
        size: side length of the square grid.
        render_mode: stored for API compatibility; ``render()`` always returns an RGB array.
        verbose: log game events to the Rich console.
        reward_cfg: optional dict overriding entries of the default reward table.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 10}

    def __init__(self, size=10, render_mode=None, verbose=False, reward_cfg=None):
        super().__init__()
        self.size = size
        self.render_mode = render_mode
        self.verbose = verbose

        # === REWARD CONFIGURATION ===
        default_rewards = {
            # Per-step cost: mild incentive to solve the task quickly
            "step_penalty": -0.02,

            # Move direction relative to the apple (dot-product shaping)
            "good_direction": +0.05,
            "bad_direction": -0.02,

            # Apple reward plus a light bonus growing with the apples already eaten
            "apple_reward": +1.0,
            "streak_bonus": 0.03,

            # Terminal penalty (wall or body)
            "death_penalty": -1.0,

            # Anti-idle: end episodes that make no progress
            "idle_penalty": -0.3,
            "idle_steps_factor": 4,  # max_idle = size * 4

            # Legacy keys, not read by step(): kept at 0 so older reward_cfg dicts still merge
            "wall_penalty": 0.0,
            "self_penalty": 0.0,
            "timeout_penalty": 0.0,
            "approach_bonus": 0.0,
            "retreat_penalty": 0.0,
            "combo_bonus": 0.0,
        }

        # Defaults first, caller overrides win
        self.rewards = {**default_rewards, **(reward_cfg or {})}

        # === GYM SPACES ===
        self.action_space = spaces.Discrete(4)   # 0:UP, 1:RIGHT, 2:DOWN, 3:LEFT
        self.observation_space = spaces.Box(0, 1, (4, size, size), np.float32)
        self._dirs = np.array([[0, -1], [1, 0], [0, 1], [-1, 0]])  # (dx, dy) for U, R, D, L

        self.reset()

    # ------------------------------
    # Helpers
    # ------------------------------
    def _get_obs(self):
        """Encode the current state as a ``(4, size, size)`` float32 array indexed ``[c, x, y]``."""
        obs = np.zeros((4, self.size, self.size), dtype=np.float32)
        head = self.snake[0]
        obs[0, head[0], head[1]] = 1.0
        for x, y in self.snake[1:]:
            obs[1, x, y] = 1.0
        obs[2, self.apple[0], self.apple[1]] = 1.0
        obs[3, head[0], head[1]] = (self.dir_idx + 1) / 4
        return obs

    def _spawn_apple(self):
        """Place the apple on a uniformly random free cell.

        Uses ``self.np_random``, which ``reset(seed=...)`` seeds, so apple positions are
        reproducible (the global ``np.random`` state is never seeded by Gymnasium).
        On a board with no free cell the apple is left under the head.
        """
        occupied = {(int(x), int(y)) for x, y in self.snake}
        free = [
            (x, y)
            for x in range(self.size)
            for y in range(self.size)
            if (x, y) not in occupied
        ]
        if not free:
            self.apple = np.array(self.snake[0])
            return
        self.apple = np.array(free[self.np_random.integers(len(free))])
        if self.verbose:
            console.log(f"🍎 [bold green]Nuova mela a:[/bold green] {tuple(self.apple)}")

    def _finish(self, reward, cause):
        """Return a terminal transition, adding ``reward`` to the episode return.

        ``info["episode"]["r"]`` is the sum of every step reward of the episode and
        ``info["episode"]["l"]`` its length in steps.
        """
        reward = float(reward)
        self.ep_return += reward
        info = {
            "episode": {"r": self.ep_return, "l": self.steps},
            "apples": self.apples,
            "cause": cause,
        }
        return self._get_obs(), reward, True, False, info

    def reset(self, seed=None, options=None):
        """Start a new episode: a length-1 snake in the centre heading right, plus an apple."""
        super().reset(seed=seed)
        mid = self.size // 2
        self.snake = [np.array([mid, mid])]
        self.dir_idx = 1  # start heading right
        self._spawn_apple()
        self.steps = 0
        self.apples = 0
        self.steps_since_last_apple = 0
        self.ep_return = 0.0  # running sum of rewards, reported in info["episode"]["r"]

        if self.verbose:
            console.log(f"🔄 [cyan]Reset ambiente Snake[/cyan] (size={self.size})")
        return self._get_obs(), {}

    # ------------------------------
    # STEP WITH CONFIGURABLE REWARD SHAPING
    # ------------------------------
    def step(self, action):
        """Advance one move and return ``(obs, reward, terminated, truncated, info)``.

        Reward = step_penalty + (good_direction if the move has a positive dot product with
        the vector to the apple, else bad_direction), plus apple_reward + apples * streak_bonus
        when eating. Collisions, idle timeout and a full board terminate the episode;
        ``truncated`` is always False.
        """
        action = int(action)
        # Ignore 180° reversals: the snake keeps its current heading
        if (action + 2) % 4 == self.dir_idx:
            action = self.dir_idx
        self.dir_idx = action

        prev_head = self.snake[0]
        head = prev_head + self._dirs[self.dir_idx]
        self.steps += 1
        self.steps_since_last_apple += 1

        # === COLLISIONS (terminal) ===
        hit_wall = bool((head < 0).any() or (head >= self.size).any())
        # The apple is always on the board, so a match also implies an in-bounds head.
        eats = bool((head == self.apple).all())
        # Without eating, the tail leaves its cell during this same move, so entering it is
        # legal; when eating, the tail stays where it is and counts as body.
        body = self.snake if eats else self.snake[:-1]
        hit_self = any((head == s).all() for s in body)

        if hit_wall or hit_self:
            cause = "wall" if hit_wall else "self"
            if self.verbose:
                console.log(
                    f"💥 [red]Morte per collisione {cause}[/red] | passi={self.steps}, mele={self.apples}"
                )
            return self._finish(self.rewards.get("death_penalty", -1.0), cause)

        # === BASE SHAPING ===
        reward = float(self.rewards.get("step_penalty", -0.02))

        # Direction shaping (dot product between "towards apple" and the actual move)
        vec_to_apple = self.apple - prev_head
        vec_move = head - prev_head
        dot = np.dot(vec_to_apple, vec_move)
        if dot > 0:
            reward += self.rewards.get("good_direction", +0.05)
        else:
            reward += self.rewards.get("bad_direction", -0.02)

        # === APPLE ===
        self.snake.insert(0, head)
        if eats:
            board_full = len(self.snake) == self.size * self.size
            if not board_full:
                self._spawn_apple()
            self.apples += 1
            self.steps_since_last_apple = 0

            reward += self.rewards.get("apple_reward", 1.0)
            reward += self.apples * self.rewards.get("streak_bonus", 0.03)

            if self.verbose:
                console.log(
                    f"🍏 [bold green]Mela mangiata![/bold green] Totale={self.apples}"
                )

            if board_full:
                # No free cell is left for a new apple: the game is won.
                return self._finish(reward, "win")
        else:
            self.snake.pop()

        # === IDLE TIMEOUT (terminal) ===
        max_idle = self.size * self.rewards.get("idle_steps_factor", 4)
        if self.steps_since_last_apple > max_idle:
            reward += self.rewards.get("idle_penalty", -0.3)
            if self.verbose:
                console.log(
                    f"⏳ [yellow]Idle timeout[/yellow]: nessuna mela trovata da {self.steps_since_last_apple} step"
                )
            return self._finish(reward, "idle_timeout")

        # === NON-TERMINAL STEP ===
        self.ep_return += reward
        info = {
            "apples": self.apples,
            "dot": float(dot),
        }
        return self._get_obs(), float(reward), False, False, info

    # ------------------------------
    # Rendering
    # ------------------------------
    def render(self):
        """Return an RGB ``uint8`` image of the board, each cell upscaled to 16x16 pixels.

        NOTE(review): pixels are indexed ``[x, y]`` (x = image row), i.e. the transpose of the
        UP = y-1 screen convention — confirm this is what consumers expect.
        """
        grid = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        grid[self.apple[0], self.apple[1]] = (200, 30, 30)
        for (x, y) in self.snake[1:]:
            grid[x, y] = (40, 160, 40)
        hx, hy = self.snake[0]
        grid[hx, hy] = (20, 220, 20)
        return np.kron(grid, np.ones((16, 16, 1), dtype=np.uint8))


In [3]:
from stable_baselines3.common.callbacks import BaseCallback
import json
from pathlib import Path
from collections import defaultdict

class TrajectoryCallbackByApples(BaseCallback):
    """Save step-by-step trajectories of training env 0, bucketed by the apples eaten.

    For each apple count ``k`` (``k <= max_apples``), at most ``max_per_class`` finished
    episodes are written to ``save_dir/apples_{k:03d}_ep_{i:02d}.json``. Each file holds the
    per-step snake/apple/reward frames plus an ``end`` summary. Only the first env of the
    VecEnv is tracked.

    Args:
        save_dir: output directory (created if missing).
        grid_size: board side, stored in every file for the replay tools.
        max_per_class: episodes kept per apple count.
        max_apples: highest apple count that is still saved.
    """

    def __init__(
        self,
        save_dir: Path,
        grid_size: int,
        max_per_class: int = 2,   # episodes kept for each apple count
        max_apples: int = 200     # upper bound on the saved apple counts (safety)
    ):
        super().__init__()
        self.save_dir = Path(save_dir)
        self.grid_size = grid_size
        self.max_per_class = max_per_class
        self.max_apples = max_apples

        # number of episodes already saved for each apple count
        self.class_counts = defaultdict(int)

        # frames of the episode in progress
        self.current = []

        self.save_dir.mkdir(parents=True, exist_ok=True)

    def _unwrap(self, env):
        """Strip wrapper layers (anything exposing ``.env``) down to the innermost env."""
        while hasattr(env, "env"):
            env = env.env
        return env

    def _terminal_state(self, info):
        """Describe the last state of the episode that just ended.

        Returns a dict with ``step`` (the episode length) and, when recoverable, ``snake`` and
        ``apple``. These are decoded from ``info["terminal_observation"]`` (set by SB3 VecEnvs
        when they auto-reset), assuming SnakeEnv's channel layout ``[head, body, apple,
        heading]``; body cells come out in grid order, not in snake order. Without a terminal
        observation, the previously recorded frame is reused.
        """
        import numpy as np

        state = {"step": int(info["episode"]["l"])}
        term_obs = info.get("terminal_observation")
        if term_obs is not None:
            term_obs = np.asarray(term_obs)
            head = np.argwhere(term_obs[0] > 0)
            apple = np.argwhere(term_obs[2] > 0)
            if len(head) and len(apple):
                body = np.argwhere(term_obs[1] > 0)
                state["snake"] = [head[0].tolist()] + body.tolist()
                state["apple"] = apple[0].tolist()
                return state
        if len(self.current) >= 2:
            state["snake"] = self.current[-2]["snake"]
            state["apple"] = self.current[-2]["apple"]
        return state

    def _on_step(self):
        env_raw = self._unwrap(self.locals["env"].envs[0])
        info = self.locals["infos"][0]
        reward = float(self.locals["rewards"][0])

        # record every step of env 0
        self.current.append({
            "step": env_raw.steps,
            "snake": [s.tolist() for s in env_raw.snake],
            "apple": env_raw.apple.tolist(),
            "reward": reward,
            "event": None
        })

        # episode finished?
        if "episode" in info:
            apples = int(info.get("apples", 0))

            # save only while this apple count has fewer than max_per_class episodes
            if apples <= self.max_apples and self.class_counts[apples] < self.max_per_class:

                # The VecEnv auto-resets a finished env inside step(), so the frame just
                # recorded shows the FIRST state of the next episode: replace it with the
                # real terminal state before marking it as the terminal event.
                self.current[-1].update(self._terminal_state(info))
                self.current[-1]["event"] = info.get("cause", "end")

                data = {
                    "episode_idx": self.class_counts[apples],
                    "apples": apples,
                    "grid_size": self.grid_size,
                    "steps": self.current,
                    "end": {
                        "cause": info.get("cause", None),
                        "total_steps": int(info["episode"]["l"]),
                        "total_apples": apples,
                    }
                }

                out_path = self.save_dir / f"apples_{apples:03d}_ep_{self.class_counts[apples]:02d}.json"
                with open(out_path, "w") as f:
                    json.dump(data, f, indent=2)

                self.class_counts[apples] += 1

            # start collecting the next episode
            self.current = []

        return True


In [4]:
# =====================================
# 🧩 BLOCCO TEST AMBIENTE SNAKE
# =====================================
from rich.console import Console
from rich.table import Table
from time import sleep

console = Console()

# Small env (8x8) with verbose=True so the env logs its own events while stepping.
env = SnakeEnv(size=8, verbose=True)
obs, _ = env.reset()

table = Table(title="🐍 Test Ambiente Snake", show_header=True, header_style="bold cyan")
table.add_column("Step", justify="right")
for column_name in ("Action", "Reward", "Apples", "Terminated"):
    table.add_column(column_name, justify="center")

# Five random actions; stop early if the episode ends.
for step in range(1, 6):
    action = env.action_space.sample()
    obs, r, term, trunc, info = env.step(action)
    done = term or trunc
    table.add_row(str(step), str(action), f"{r:.2f}", str(info["apples"]), str(done))

    sleep(0.2)  # only to let the log scroll slowly (safe to remove)
    if not done:
        continue
    console.log("[red bold]⚠️ Episodio terminato prematuramente![/red bold]")
    break

console.print(table)


In [5]:
# =====================================
# ⚙️ BLOCCO 3 — CONFIGURAZIONE ESPERIMENTO (Rich-enhanced)
# =====================================
from pathlib import Path
import torch
from rich.console import Console
from rich.table import Table
from rich.panel import Panel

console = Console()

# ==== EXPERIMENT IDENTITY ====
run_name = "ppo_snake_v1"
project_name = "snake_rl"
algo = "PPO"

# ==== SEED & DEVICE ====
seed = 123
device = "cuda" if torch.cuda.is_available() else "cpu"
console.log(f"🧠 [bold cyan]Device:[/bold cyan] {device}")

# ==== ENVIRONMENT ====
env_cfg = {
    "grid_size": 15,
    # NOTE(review): frame_stack and max_steps are not read by SnakeEnv / make_env as written
    # (episode length is bounded by SnakeEnv's idle timeout) — confirm before relying on them.
    "frame_stack": 4,
    "max_steps": 400,
    "num_envs": 8,      # parallel workers in the training DummyVecEnv
}


# ==== TRAINING ==== (frequencies intended as environment timesteps)
train_cfg = {
    "total_timesteps": 5_000_000,
    "eval_freq": 20_000,
    "n_eval_episodes": 10,
    "save_freq": 100_000,
}

# ==== PPO HYPERPARAMETERS ====
ppo_cfg = {
    "learning_rate": 2.5e-4,
    "n_steps": 2048,        # per env: one rollout = n_steps * num_envs transitions
    "batch_size": 512,
    "n_epochs": 4,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_range": 0.1,
    "vf_coef": 0.5,
    "ent_coef": 0.01,
    "max_grad_norm": 0.5,
}

# ==== POLICY ====
# Separate actor (pi) and critic (vf) heads. Since SB3 1.8 net_arch takes a plain dict; the
# old [dict(...)] list form is deprecated and only unwrapped with a warning.
policy_cfg = {
    "policy": "MlpPolicy",
    "net_arch": dict(pi=[256, 256], vf=[256, 256]),
}

# ==== LOGGING ====
base_dir = Path("artifacts")
paths = {
    "logs": base_dir / "logs" / run_name,
    "checkpoints": base_dir / "checkpoints" / run_name,
    "trajectories": base_dir / "trajectories" / run_name,
    "videos": base_dir / "videos" / run_name,
}
for p in paths.values():
    p.mkdir(parents=True, exist_ok=True)

logging_cfg = {
    "tensorboard": True,
    "csv": True,
    "json": True,
    "episode_log": True,
    "trajectory_log": True,
    "episode_log_freq": 1,
    "record_videos_every": 500_000,
}

# ==== FULL CONFIGURATION ====
config = {
    "run_name": run_name,
    "algo": algo,
    "seed": seed,
    "device": device,
    "env": env_cfg,
    "train": train_cfg,
    "ppo": ppo_cfg,
    "policy": policy_cfg,
    "paths": paths,
    "logging": logging_cfg,
}

# ==== VISUALIZZA CONFIG CON RICH ====
def show_dict_as_table(title, d):
    """Print the key/value pairs of ``d`` as a two-column Rich table titled ``title``."""
    table_view = Table(title=title, show_header=True, header_style="bold magenta")
    table_view.add_column("Chiave", style="cyan", no_wrap=True)
    table_view.add_column("Valore", style="green")
    for cells in ((str(key), str(value)) for key, value in d.items()):
        table_view.add_row(*cells)
    console.print(table_view)

console.rule(f"[bold yellow]🏁 Configurazione Esperimento — {run_name}[/bold yellow]")

# One table per config section, in display order.
config_sections = [
    ("🧩 ENV", env_cfg),
    ("🎯 TRAIN", train_cfg),
    ("🧠 PPO", ppo_cfg),
    ("🏗️ POLICY", policy_cfg),
    ("🗂️ PATHS", {name: str(path) for name, path in paths.items()}),
]
for section_title, section_values in config_sections:
    show_dict_as_table(section_title, section_values)

ready_message = (
    f"[bold green]Esperimento pronto:[/bold green] [cyan]{run_name}[/cyan]\n"
    f"Algoritmo: [magenta]{algo}[/magenta] • Seed: [yellow]{seed}[/yellow]"
)
console.print(Panel.fit(ready_message, title="✅ Setup Completato", border_style="green"))


In [6]:
# =====================================
# 🧩 BLOCCO 2 — SETUP ENV + LOGGER + CALLBACK
# =====================================
import os
import json
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import (
    EvalCallback,
    CheckpointCallback,
    CallbackList,
)
from stable_baselines3.common.logger import (
    Logger,
    JSONOutputFormat,
    CSVOutputFormat,
    TensorBoardOutputFormat,
)

# === 2.1 Ambiente ===

def make_env(seed: int, grid_size: int):
    """Return a zero-argument factory, as expected by ``DummyVecEnv``.

    The factory builds ``SnakeEnv(size=grid_size, render_mode="rgb_array")``, wraps it in
    SB3's ``Monitor`` (which tracks episodic reward/length and fills ``info["episode"]``
    for SB3) and resets it once with ``seed`` before returning it.
    """
    def _build_monitored_snake():
        wrapped = Monitor(SnakeEnv(size=grid_size, render_mode="rgb_array"))
        wrapped.reset(seed=seed)
        return wrapped

    return _build_monitored_snake

# Training and evaluation envs: same configuration, non-overlapping seeds.
# Training workers get seed, seed + 1, ..., seed + num_envs - 1; the evaluation env gets
# seed + num_envs so it does not share a seed with any training worker.
num_envs = config["env"].get("num_envs", 8)
env_train = DummyVecEnv([
    make_env(config["seed"] + i, config["env"]["grid_size"])
    for i in range(num_envs)
])

env_eval = DummyVecEnv([make_env(config["seed"] + num_envs, config["env"]["grid_size"])])

print("✅ Environment initialized — grid:", config["env"]["grid_size"])


# === 2.2 Logger custom ===

def make_logger(log_dir):
    """Create an SB3 ``Logger`` in ``log_dir`` with the outputs enabled in ``config["logging"]``.

    JSON -> ``progress.json``, CSV -> ``progress.csv``, TensorBoard -> event files in
    ``log_dir``. The directory is created if needed.
    """
    os.makedirs(log_dir, exist_ok=True)
    enabled = config["logging"]
    # Builders are lazy so that an output file is only opened when its flag is on.
    builders = (
        ("json", lambda: JSONOutputFormat(os.path.join(log_dir, "progress.json"))),
        ("csv", lambda: CSVOutputFormat(os.path.join(log_dir, "progress.csv"))),
        ("tensorboard", lambda: TensorBoardOutputFormat(log_dir)),
    )
    output_formats = [build() for flag, build in builders if enabled[flag]]
    return Logger(folder=log_dir, output_formats=output_formats)

logger = make_logger(config["paths"]["logs"])
print("✅ Logger set up in:", config["paths"]["logs"])


# === 2.3 Callbacks: Checkpoint + Eval + Trajectories ===
# SB3 callbacks count calls to _on_step, i.e. VecEnv steps, and each VecEnv step consumes
# num_envs timesteps: convert the timestep-based frequencies of the config accordingly
# (max(..., 1) keeps them valid when a config value is smaller than num_envs).
n_envs = env_train.num_envs

checkpoint_cb = CheckpointCallback(
    save_freq=max(config["train"]["save_freq"] // n_envs, 1),
    save_path=config["paths"]["checkpoints"],
    name_prefix="ppo_snake",
)

eval_cb = EvalCallback(
    env_eval,
    best_model_save_path=config["paths"]["checkpoints"],
    log_path=config["paths"]["logs"],
    eval_freq=max(config["train"]["eval_freq"] // n_envs, 1),
    n_eval_episodes=config["train"]["n_eval_episodes"],
    deterministic=True,
    render=False,
)

traj_cb = TrajectoryCallbackByApples(
    save_dir=config["paths"]["trajectories"],
    grid_size=config["env"]["grid_size"],
    max_per_class=2,       # two episodes per apple count
    max_apples=200         # save up to 200 apples
)

# === 2.4 Callback list (built once, with every callback) ===
callbacks = CallbackList([checkpoint_cb, eval_cb, traj_cb])
print("✅ Callbacks ready")


# === 2.5 PPO model ===
model = PPO(
    policy=config["policy"]["policy"],
    env=env_train,
    verbose=0,
    seed=config["seed"],
    tensorboard_log=str(config["paths"]["logs"]),
    device=config["device"],
    **config["ppo"],
    policy_kwargs=dict(net_arch=config["policy"]["net_arch"]),
)

# route SB3 metrics to the custom JSON/CSV/TensorBoard logger
model.set_logger(logger)
print("✅ PPO model ready on", config["device"])


# === 2.6 (Optional) save a snapshot of the config ===
cfg_copy = config.copy()  # shallow copy: only the "paths" entry is replaced below
cfg_copy["paths"] = {k: str(v) for k, v in cfg_copy["paths"].items()}
config_path = config["paths"]["logs"] / "config.json"
with open(config_path, "w") as f:
    json.dump(cfg_copy, f, indent=2)

print("💾 Config saved to", config_path)


✅ Environment initialized — grid: 15
✅ Logger set up in: artifacts\logs\ppo_snake_v1
✅ Callbacks ready




✅ PPO model ready on cpu
💾 Config saved to artifacts\logs\ppo_snake_v1\config.json


In [None]:
# =====================================
# 🧩 BLOCCO 3 — TRAINING + EPISODE LOGGER + EXPORT (Rich-enhanced + persistent model)
# =====================================
import json
import numpy as np
import pandas as pd
from pathlib import Path
from time import time
from rich.console import Console
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, TimeElapsedColumn
from rich.table import Table
from stable_baselines3.common.callbacks import BaseCallback, CallbackList
from stable_baselines3 import PPO

console = Console()

# === 3.0 Resume from the most recent saved policy, if any ===
checkpoint_dir = config["paths"]["checkpoints"]
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Candidates: periodic CheckpointCallback dumps (ppo_snake_*.zip) and the final model of a
# previous run. Taking the newest by modification time means a final_model.zip written after
# the last periodic checkpoint is the one resumed.
candidates = list(checkpoint_dir.glob("ppo_snake_*.zip"))
final_candidate = checkpoint_dir / "final_model.zip"
if final_candidate.exists():
    candidates.append(final_candidate)
latest_checkpoint = max(candidates, key=lambda p: p.stat().st_mtime, default=None)

if latest_checkpoint:
    console.rule(f"[bold yellow]♻️ Ripristino modello precedente[/bold yellow]")
    console.log(f"📂 Caricamento da: [cyan]{latest_checkpoint.name}[/cyan]")
    model = PPO.load(latest_checkpoint, env=env_train, device=config["device"])
    # PPO.load returns a new model without the custom logger, so learn() would fall back to
    # SB3's default one: re-attach it so progress.csv/json keep being written.
    model.set_logger(logger)
else:
    console.rule(f"[bold yellow]🆕 Nessun modello trovato — training da zero[/bold yellow]")

# === 3.1 Callback custom per loggare ogni episodio ===
class EpisodeLoggerCallback(BaseCallback):
    """Collect a summary record for every finished episode of every VecEnv worker.

    Each record (``env_id``, ``episode``, ``reward``, ``steps``, ``apples``) is appended to
    ``episode_summaries`` and echoed to the console; every ``log_freq``-th episode is also
    written to ``log_dir/episode_{n:06d}.json`` (``log_freq <= 0`` disables those files).
    At the end of training all records are dumped to ``log_dir/episode_log.json``.

    Args:
        log_dir: directory for the per-episode files and ``episode_log.json``.
        trajectory_dir: directory created at init; this callback does not write to it.
        log_freq: write a per-episode JSON file every ``log_freq`` episodes.
        console: Rich console used for output (a new one is created if None).
        num_envs: stored only; episodes are detected from the ``infos`` of every worker.
    """

    def __init__(self, log_dir, trajectory_dir, log_freq=1, console=None, num_envs=1):
        super().__init__()
        self.num_envs = num_envs
        self.log_dir = Path(log_dir)
        self.trajectory_dir = Path(trajectory_dir)
        self.log_freq = log_freq
        self.episode_count = 0
        self.episode_summaries = []
        self.console = console or Console()
        self.start_time = None

        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.trajectory_dir.mkdir(parents=True, exist_ok=True)

    def _on_training_start(self) -> None:
        # Imported locally on purpose: a later notebook cell runs `import time`, which rebinds
        # the global name `time` from the function to the module and would break `time()`.
        from time import perf_counter

        self.start_time = perf_counter()
        self.console.rule("[bold yellow]🚀 Inizio Training PPO Snake[/bold yellow]")

    def _on_step(self) -> bool:
        for env_i, info in enumerate(self.locals.get("infos", [])):
            if "episode" not in info:
                continue
            ep_r = info["episode"]["r"]
            ep_l = info["episode"]["l"]
            apples = info.get("apples", 0)

            record = {
                "env_id": env_i,
                "episode": self.episode_count,
                "reward": float(ep_r),
                "steps": int(ep_l),
                "apples": int(apples),
            }
            self.episode_summaries.append(record)

            # log_freq <= 0 disables the per-episode files (and avoids a modulo by zero)
            if self.log_freq > 0 and self.episode_count % self.log_freq == 0:
                ep_path = self.log_dir / f"episode_{self.episode_count:06d}.json"
                with open(ep_path, "w") as f:
                    json.dump(record, f, indent=2)

            # "\[" escapes the bracket: an unescaped "[env 0]" is parsed as Rich markup and
            # silently dropped from the output.
            self.console.log(
                f"🏁 \\[env {env_i}] Episodio {self.episode_count:<6} | "
                f"R: [green]{float(ep_r):.2f}[/green] | Steps: {int(ep_l):<4} | Apples: {apples}"
            )

            self.episode_count += 1
        return True

    def _on_training_end(self) -> None:
        from time import perf_counter

        # start_time is None only if the training-start hook never ran
        duration = perf_counter() - self.start_time if self.start_time is not None else 0.0
        all_path = self.log_dir / "episode_log.json"
        with open(all_path, "w") as f:
            json.dump(self.episode_summaries, f, indent=2)
        self.console.rule("[bold green]✅ Training Completato[/bold green]")
        self.console.log(f"📘 Episodi totali: {len(self.episode_summaries)}")
        self.console.log(f"🕒 Durata: {duration/60:.1f} min")
        self.console.log(f"🗂️ Log episodi → {all_path}")

# === 3.2 Additional callback ===
episode_logger = EpisodeLoggerCallback(
    log_dir=config["paths"]["logs"],
    trajectory_dir=config["paths"]["trajectories"],
    log_freq=config["logging"]["episode_log_freq"],
    console=console,
)
# Remove loggers added by earlier runs of this cell: otherwise each re-run appends another
# EpisodeLoggerCallback and every episode is logged/written several times. Compare by class
# name because re-running the cell also redefines the class object.
callbacks.callbacks = [
    cb for cb in callbacks.callbacks if type(cb).__name__ != "EpisodeLoggerCallback"
]
callbacks.callbacks.append(episode_logger)
console.log("✅ EpisodeLogger aggiunto alle callback.")


class _RichProgressCallback(BaseCallback):
    """Advance a Rich progress task by the timesteps consumed at each VecEnv step."""

    def __init__(self, progress, task_id):
        super().__init__()
        self._progress = progress
        self._task_id = task_id

    def _on_step(self) -> bool:
        self._progress.update(self._task_id, advance=self.training_env.num_envs)
        return True


# === 3.3 Training ===
total_steps = config["train"]["total_timesteps"]
console.log(f"🚀 Training per [bold cyan]{total_steps:,}[/bold cyan] timesteps")

with Progress(
    TextColumn("[bold blue]{task.description}"),
    BarColumn(),
    "[progress.percentage]{task.percentage:>3.1f}%",
    "•",
    TimeElapsedColumn(),
    "•",
    TimeRemainingColumn(),
    console=console,
    transient=True,
) as progress:
    task = progress.add_task("🏋️‍♂️ PPO Snake Training", total=total_steps)
    # The bar is advanced during learn() by _RichProgressCallback (nested next to the
    # experiment callbacks) instead of jumping to 100% only after training returns.
    model.learn(
        total_timesteps=total_steps,
        callback=CallbackList([callbacks, _RichProgressCallback(progress, task)]),
        progress_bar=False,
    )

console.rule("[bold green]🏁 Training completato[/bold green]")

# === 3.4 Final save ===
final_path = checkpoint_dir / "final_model.zip"
model.save(final_path)
console.log(f"💾 Modello finale salvato in: [green]{final_path}[/green]")

# === 3.5 Summary metrics ===
ep_data = episode_logger.episode_summaries
if ep_data:
    rewards = [e["reward"] for e in ep_data]
    lengths = [e["steps"] for e in ep_data]
    apples = [e["apples"] for e in ep_data]

    metrics = {
        "algo": config["algo"],
        "run_name": config["run_name"],
        "episodes": len(rewards),
        "mean_reward": float(np.mean(rewards)),
        "std_reward": float(np.std(rewards)),
        "mean_length": float(np.mean(lengths)),
        "mean_apples": float(np.mean(apples)),
        "best_reward": float(np.max(rewards)),
        "config": {
            "seed": config["seed"],
            "grid_size": config["env"]["grid_size"],
            "total_timesteps": config["train"]["total_timesteps"],
            "learning_rate": config["ppo"]["learning_rate"],
        },
    }

    # json_normalize flattens the nested "config" dict into config.* columns so the CSV
    # stays tabular instead of holding a dict repr in one cell.
    pd.json_normalize(metrics).to_csv(config["paths"]["logs"] / "metrics.csv", index=False)
    with open(config["paths"]["logs"] / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    table = Table(title="📊 Metriche Globali", show_header=True, header_style="bold magenta")
    table.add_column("Metrica", style="cyan")
    table.add_column("Valore", style="green")
    for k, v in metrics.items():
        if k != "config":
            table.add_row(str(k), str(v))
    console.print(table)
    console.log(f"📈 Metriche salvate in: {config['paths']['logs']}")
else:
    console.log("[red]⚠️ Nessun episodio loggato![/red]")


Output()

In [None]:
# =====================================
# 🧩 BLOCCO 4 — ANALISI + REPLAY (Rich-enhanced)
# =====================================
import json, numpy as np, pandas as pd, matplotlib.pyplot as plt
from IPython.display import HTML, clear_output
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.progress import track
import time

console = Console()

# === 4.1 Load the SB3 progress log written during training ===
log_csv = config["paths"]["logs"] / "progress.csv"

if not log_csv.exists():
    console.log(f"[red]❌ Nessun file di log trovato in {log_csv}[/red]")
else:
    df = pd.read_csv(log_csv)
    console.rule("[bold cyan]📈 Analisi training SB3[/bold cyan]")
    console.log(f"✅ Log caricato: [green]{len(df)}[/green] step di training")

    # === 4.2 Training curves ===
    def plot_training_curves(df):
        """Plot mean episode reward, mean episode length and a loss curve vs. timesteps.

        Each panel keeps only the rows where its metric was logged, so curves are drawn as
        continuous lines; a panel whose column is absent is titled as unavailable instead of
        raising. ``df`` is not modified.
        """
        fig, ax = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
        # Series.ffill returns a new Series (df stays untouched); fillna(method=...) is
        # deprecated in recent pandas.
        timesteps = df["time/total_timesteps"].ffill()

        def _plot_metric(axis, column, **plot_kwargs):
            if column not in df.columns:
                axis.set_title(f"{column}: non disponibile")
                return
            logged = df[column].notna()
            axis.plot(timesteps[logged], df.loc[logged, column], **plot_kwargs)
            axis.legend()

        _plot_metric(ax[0], "rollout/ep_rew_mean", label="Reward medio", color="green")
        ax[0].set_ylabel("Reward medio")

        _plot_metric(ax[1], "rollout/ep_len_mean", color="orange", label="Lunghezza media")
        ax[1].set_ylabel("Lunghezza")

        metric = "train/value_loss" if "train/value_loss" in df.columns else "train/entropy_loss"
        label = "Value Loss" if metric == "train/value_loss" else "Entropy"
        _plot_metric(ax[2], metric, color="purple", label=label)
        ax[2].set_ylabel(label)
        ax[2].set_xlabel("Timesteps")

        plt.tight_layout()
        plt.show()

    console.log("📊 Generazione grafici di training…")
    plot_training_curves(df)

# === 4.3 Aggregate per-episode analysis (reads episode_log.json from EpisodeLoggerCallback) ===
ep_log_path = config["paths"]["logs"] / "episode_log.json"

if ep_log_path.exists():
    with open(ep_log_path) as f:
        ep_data = json.load(f)

    ep_rewards, ep_lengths, ep_apples = (
        [episode[key] for episode in ep_data] for key in ("reward", "steps", "apples")
    )

    console.rule("[bold yellow]📘 Analisi Episodi[/bold yellow]")
    table = Table(title="📈 Metriche Episodio", show_header=True, header_style="bold magenta")
    table.add_column("Totale episodi", justify="right")
    for header in ("Reward medio", "Lunghezza media", "Mele medie"):
        table.add_column(header, justify="center")

    summary_cells = [
        str(len(ep_data)),
        f"{np.mean(ep_rewards):.3f}",
        f"{np.mean(ep_lengths):.2f}",
        f"{np.mean(ep_apples):.2f}",
    ]
    table.add_row(*summary_cells)
    console.print(table)

    fig_rewards, ax_rewards = plt.subplots(figsize=(10, 4))
    ax_rewards.plot(ep_rewards, label="Reward episodio", color="green")
    ax_rewards.set_xlabel("Episodio")
    ax_rewards.set_ylabel("Reward")
    ax_rewards.set_title("Andamento Reward per Episodio")
    ax_rewards.legend()
    fig_rewards.tight_layout()
    plt.show()
else:
    console.log(f"[red]⚠️ Nessun file episodio trovato in {ep_log_path}[/red]")


# === 4.4 Text replay (Rich) ===
def replay_episode(file_path, fps=5):
    """Replay a saved trajectory JSON (as written by TrajectoryCallbackByApples) as text.

    Each frame is printed as a grid with ``y`` as the row and ``x`` as the column, so the
    env's UP action (``y - 1``) moves the head up on screen. Cells: ``.`` empty, ``A`` apple,
    ``H`` head, ``o`` body.

    Args:
        file_path: full path of the ``.json`` trajectory file.
        fps: frames per second; ``fps <= 0`` replays without any delay.
    """
    ep_file = Path(file_path)
    if not ep_file.exists():
        console.log(f"[red]❌ File non trovato: {ep_file}[/red]")
        return

    with open(ep_file) as f:
        traj = json.load(f)

    steps = traj.get("steps", [])
    grid_size = traj.get("grid_size", config["env"]["grid_size"])
    apples = traj.get("apples", "?")
    episode_idx = traj.get("episode_idx", "?")

    console.rule(
        f"[bold cyan]▶️ Replay episodio {episode_idx} — {apples} mele[/bold cyan]"
    )

    for s in track(steps, description="🎬 Riproduzione", console=console):
        # Plain Python lists: np.full(..., ".", dtype=str) creates a fixed-width '<U1' array
        # that silently truncates every markup string (e.g. "[red]A[/red]" -> "[").
        grid = [["."] * grid_size for _ in range(grid_size)]

        # Apple
        ax, ay = s["apple"]
        grid[ay][ax] = "[red]A[/red]"

        # Snake: head first, then body
        snake = s["snake"]
        hx, hy = snake[0]
        grid[hy][hx] = "[bold yellow]H[/bold yellow]"
        for (x, y) in snake[1:]:
            grid[y][x] = "[green]o[/green]"

        clear_output(wait=True)
        length = len(snake)

        console.log(
            f"[blue]Step {s['step']}[/blue] — Reward: "
            f"[green]{s['reward']:+.2f}[/green] | Length: {length}"
        )

        for row in grid:
            console.print(" ".join(row))

        if fps > 0:
            time.sleep(1 / fps)

    end = traj.get("end", {})
    cause = end.get("cause", "unknown")
    console.print(
        f"[bold green]🏁 Fine episodio — "
        f"Cause: [yellow]{cause}[/yellow], Steps: {end.get('total_steps')}, "
        f"Mele: {end.get('total_apples')}[/bold green]"
    )



In [None]:
# =====================================
# 🧩 BLOCCO 4.5 — REPORT RIASSUNTIVO (Rich-enhanced)
# =====================================
import json, numpy as np
from rich.console import Console
from rich.table import Table
from rich.panel import Panel

console = Console()

# === 4.5 Save the summary report ===
# The per-episode lists are rebuilt from ep_data here instead of reusing ep_rewards /
# ep_lengths / ep_apples: ep_data can also come from the training cell, in which case those
# lists are undefined (NameError) or stale from a different ep_data.
report_episodes = globals().get("ep_data") or []

if len(report_episodes) == 0:
    console.log("[red]⚠️ Nessun dato episodio disponibile, impossibile generare report.[/red]")
else:
    report_rewards = [e["reward"] for e in report_episodes]
    report_lengths = [e["steps"] for e in report_episodes]
    report_apples = [e["apples"] for e in report_episodes]

    report = {
        "run_name": config["run_name"],
        "algo": config["algo"],
        "episodes": len(report_episodes),
        "reward_mean": float(np.mean(report_rewards)),
        "reward_std": float(np.std(report_rewards)),
        "length_mean": float(np.mean(report_lengths)),
        "apples_mean": float(np.mean(report_apples)),
    }

    report_path = config["paths"]["logs"] / "summary_report.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)

    # === Rich output ===
    console.rule("[bold yellow]📄 Report Riassuntivo Esperimento[/bold yellow]")
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Parametro", style="cyan", justify="right")
    table.add_column("Valore", style="green")

    for k, v in report.items():
        table.add_row(str(k), f"{v}")

    console.print(table)
    console.print(Panel.fit(f"✅ Report salvato in [bold green]{report_path}[/bold green]", border_style="green"))

    # Formatted JSON preview
    console.log("🗂️ Anteprima JSON:")
    console.print_json(data=report)


In [None]:
# =====================================
# 🧩 BLOCCO 5 — REPLAY GRAFICO + EXPORT VIDEO (Compatibile SmartSampler)
# =====================================
import json, numpy as np, matplotlib.pyplot as plt, matplotlib.patches as patches
from rich.console import Console
from rich.progress import track
from rich.panel import Panel
from pathlib import Path
import imageio.v2 as imageio

console = Console()

# === 5.1 Funzione per disegnare un singolo frame ===
def draw_snake_frame(snake, apple, grid_size, ax=None):
    """Draw a single Snake frame on a matplotlib Axes.

    Cell (x, y) is drawn as the unit square whose lower-left corner is (x, y).

    Args:
        snake: sequence of (x, y) cells; element 0 is drawn as the head,
            the rest as the body. May be empty (nothing is drawn for it).
        apple: (x, y) cell of the apple, or None to draw no apple.
        grid_size: number of cells per side of the square board.
        ax: Axes to draw on; if None a new 5x5-inch figure is created.

    Returns:
        The Axes that was drawn on (cleared and redrawn from scratch).
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    ax.clear()
    ax.set_xlim(0, grid_size)
    ax.set_ylim(0, grid_size)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.set_facecolor("#0f0f0f")

    # Grid: 2*(grid_size+1) line segments instead of grid_size**2 outlined
    # Rectangle patches. It is the same picture with far fewer artists to render
    # on every frame. zorder=0 keeps the lines underneath the apple/snake
    # patches, as in the original draw order.
    grid_coords = np.arange(grid_size + 1)
    ax.vlines(grid_coords, 0, grid_size, colors="#1c1c1c", linewidth=0.4, zorder=0)
    ax.hlines(grid_coords, 0, grid_size, colors="#1c1c1c", linewidth=0.4, zorder=0)

    # Apple (skipped if the trajectory has none, e.g. a full board)
    if apple is not None:
        ax.add_patch(patches.Rectangle((apple[0], apple[1]), 1, 1, color="#e74c3c"))

    if len(snake) > 0:
        # Snake body
        for seg in snake[1:]:
            ax.add_patch(patches.Rectangle((seg[0], seg[1]), 1, 1, color="#27ae60"))

        # Snake head, drawn last so it sits on top of the body
        head = snake[0]
        ax.add_patch(patches.Rectangle((head[0], head[1]), 1, 1, color="#2ecc71"))

    return ax


# === Utility per trovare un file trajectory ===
def find_episode_file(pattern="snake_ep_0apples_000.json", traj_dir=None):
    """Return the path of a trajectory file matching ``pattern``.

    Args:
        pattern: glob pattern, relative to the trajectory directory.
        traj_dir: directory to search; defaults to
            ``config["paths"]["trajectories"]``.

    Returns:
        The first match in sorted (lexicographic) order, or None if nothing
        matches. A warning is logged in that case.
    """
    if traj_dir is None:
        traj_dir = config["paths"]["trajectories"]
    # Path.glob yields entries in filesystem order, which is arbitrary. Sort so
    # that the same pattern always resolves to the same file.
    matches = sorted(Path(traj_dir).glob(pattern))
    if not matches:
        console.log(f"[red]❌ Nessun file trovato con pattern: {pattern}[/red]")
        return None
    return matches[0]


# === 5.2 Replay grafico interattivo ===
def replay_episode_plot(file_path, fps=6):
    """Replay a saved trajectory frame by frame in a matplotlib figure.

    Args:
        file_path: path to a trajectory JSON containing "steps", "grid_size"
            and "apples"; each step must provide "snake", "apple", "step"
            and "reward".
        fps: playback speed in frames per second; must be > 0.

    Raises:
        ValueError: if ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be > 0, got {fps!r}")

    file_path = Path(file_path)
    if not file_path.exists():
        console.log(f"[red]❌ File trajectory non trovato:[/red] {file_path}")
        return

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(file_path, encoding="utf-8") as f:
        traj = json.load(f)

    steps = traj["steps"]
    grid_size = traj["grid_size"]
    apples = traj["apples"]

    console.rule(f"[bold cyan]🎬 Replay — {file_path.name} — Mele: {apples}[/bold cyan]")

    fig, ax = plt.subplots(figsize=(5, 5))
    # try/finally: release the figure even if a step is malformed or the
    # replay is interrupted (e.g. kernel interrupt).
    try:
        for s in track(steps, description="🎥 Riproduzione", console=console):
            snake = s["snake"]
            apple = s["apple"]

            draw_snake_frame(snake, apple, grid_size, ax)
            ax.set_title(
                f"Step {s['step']} | Len {len(snake)} | r={s['reward']:+.2f}"
            )
            plt.pause(1 / fps)
    finally:
        plt.close(fig)

    console.log("[green]🏁 Fine episodio[/green]")


# === 5.3 Esporta video ===
def export_episode_video(file_path, fps=8):
    """Render a saved trajectory to an MP4 file in ``config["paths"]["videos"]``.

    The output file is named after the trajectory file's stem.

    Args:
        file_path: path to a trajectory JSON containing "steps" and
            "grid_size"; each step must provide "snake", "apple", "step"
            and "reward".
        fps: frame rate of the exported video.
    """
    file_path = Path(file_path)
    if not file_path.exists():
        console.log(f"[red]❌ File trajectory non trovato:[/red] {file_path}")
        return

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(file_path, encoding="utf-8") as f:
        traj = json.load(f)

    steps = traj["steps"]
    grid_size = traj["grid_size"]

    # mimsave cannot write a video with zero frames.
    if not steps:
        console.log(f"[red]❌ Nessuno step nel file trajectory:[/red] {file_path}")
        return

    out_path = config["paths"]["videos"] / (file_path.stem + ".mp4")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    console.rule(f"[bold yellow]📽️ Esportazione video: {file_path.name}[/bold yellow]")

    frames = []
    fig, ax = plt.subplots(figsize=(4, 4))
    # try/finally: release the figure even if rendering fails mid-episode.
    try:
        for s in track(steps, description="🎞️ Frame", console=console):
            snake = s["snake"]
            apple = s["apple"]

            ax = draw_snake_frame(snake, apple, grid_size, ax)
            ax.set_title(
                f"Step {s['step']} | Len {len(snake)} | r={s['reward']:+.2f}"
            )

            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was deprecated in
            # Matplotlib 3.8 and removed in 3.10. It is already shaped
            # (H, W, 4) at the canvas's real pixel size, so it stays correct
            # on HiDPI canvases without a manual reshape. Drop alpha and copy,
            # because the buffer is reused on the next draw.
            frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
    finally:
        plt.close(fig)

    imageio.mimsave(out_path, frames, fps=fps)
    console.print(Panel.fit(
        f"🎉 Video salvato in [green]{out_path}[/green]", border_style="green"
    ))
