# 🔥 PROMETHEUS-TQFD v2.0
### Dual-AI Tabula Rasa Chess Training System

Dieses Notebook trainiert zwei verschiedene Schach-KIs:
1. **ATLAS**: AlphaZero-Stil (ResNet + MCTS)
2. **ENTROPY v2.0**: Physik-inspirierter Hybrid-Ansatz

In [None]:
# @title 🔧 1. Setup & Installation
!pip install --quiet python-chess numpy torch psutil lz4 safetensors plotly streamlit pyngrok

import multiprocessing as mp
try:
    # Force the 'spawn' start method for all worker processes created later
    # (presumably because the workers use CUDA, which does not survive fork - TODO confirm).
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    # Tolerate a RuntimeError and keep whatever start method is already active.
    pass
else:
    print("✅ Multiprocessing set to 'spawn'")

In [None]:
# @title 📁 2. Module Generation
import os
# Create the package root plus every sub-package directory that the following
# %%writefile cells write into; existing directories are left untouched.
for _subdir in ('', 'atlas', 'dashboard', 'entropy', 'evaluation', 'orchestration', 'utils'):
    os.makedirs(os.path.join('prometheus_tqfd', _subdir), exist_ok=True)
del _subdir


In [None]:
%%writefile prometheus_tqfd/__init__.py



In [None]:
%%writefile prometheus_tqfd/config.py
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import torch
import os
from typing import Optional, Tuple

@dataclass
class PrometheusConfig:
    """Central hyper-parameter set for a PROMETHEUS-TQFD run (ATLAS + ENTROPY).

    Several fields (batch sizes, residual blocks, MCTS simulations, replay sizes,
    self-play worker counts) are overwritten by ``adjust_config_for_hardware``.
    """
    # === SYSTEM ===
    run_id: str = field(default_factory=lambda: datetime.now().strftime('%Y%m%d_%H%M%S'))  # local timestamp at construction time
    base_dir: Path = Path('/content/prometheus_runs')  # '/content' is the Colab working directory
    use_drive: bool = True  # presumably mirrors run data to Google Drive - TODO confirm
    seed: int = 42
    
    # === ATLAS (AlphaZero style) ===
    atlas_input_channels: int = 19         # matches BoardEncoder's 19 planes
    atlas_res_blocks: int = 8              # reduced for Colab
    atlas_channels: int = 256
    atlas_learning_rate: float = 1e-3
    atlas_weight_decay: float = 1e-4
    atlas_batch_size: int = 128            # hardware-adaptive (see adjust_config_for_hardware)
    atlas_replay_size: int = 300_000
    atlas_mcts_simulations: int = 100
    atlas_mcts_cpuct: float = 2.5          # PUCT exploration constant
    atlas_dirichlet_alpha: float = 0.3     # root-noise concentration
    atlas_dirichlet_epsilon: float = 0.25  # share of noise mixed into root priors
    atlas_temperature_moves: int = 30      # plies played with temperature_init before switching to temperature_final
    atlas_temperature_init: float = 1.0
    atlas_temperature_final: float = 0.1
    
    # === ENTROPY v2.0 (hybrid physics) ===
    entropy_input_channels: int = 22       # 19 board planes + 3 physics field planes
    entropy_res_blocks: int = 6
    entropy_channels: int = 128
    entropy_learning_rate: float = 3e-4
    entropy_batch_size: int = 64
    entropy_replay_size: int = 200_000
    
    # Mini-rollout parameters
    entropy_rollout_depth: int = 3
    entropy_rollout_count: int = 5
    entropy_temperature_start: float = 3.0
    entropy_temperature_end: float = 0.3
    entropy_temperature_decay: float = 0.9999  # presumably a per-step multiplicative decay - TODO confirm
    
    # Loss weights (should sum to 1.0; this class does not enforce it)
    entropy_loss_outcome: float = 0.30     # game result
    entropy_loss_mobility: float = 0.25    # own options
    entropy_loss_pressure: float = 0.20    # pressure on the opponent
    entropy_loss_stability: float = 0.15   # TD on energy
    entropy_loss_novelty: float = 0.10     # RND exploration
    
    # Physics constants (per-piece "energies" used by PhysicsFieldCalculator)
    physics_energy_king: float = 1000.0
    physics_energy_queen: float = 9.5
    physics_energy_rook: float = 5.25
    physics_energy_bishop: float = 3.33
    physics_energy_knight: float = 3.05
    physics_energy_pawn: float = 1.0
    physics_diffusion_sigma: float = 2.5   # std-dev (in squares) of the Gaussian mass kernel
    physics_field_alpha: float = 1.0       # pressure = alpha * mass + beta * mobility
    physics_field_beta: float = 0.5
    
    # === TACTICS DETECTOR ===
    tactics_boost_strength: float = 1.0      # full strength
    tactics_boost_decay: float = 1.0         # multiplier per decay_step(); 1.0 = no decay (permanent)
    tactics_mate_boost: float = 50.0         # boost for mate-in-1
    tactics_hanging_boost: float = 5.0       # boost for rescuing hanging pieces
    tactics_threat_boost: float = 3.0        # boost for defending against a mate threat

    # === TRAINING ===
    min_buffer_before_training: int = 5_000  # replay-buffer size required before training steps start
    weight_publish_interval: int = 50        # presumably trainer steps between weight broadcasts - TODO confirm
    gpu_priority_atlas: float = 0.6        # 60% GPU priority share for ATLAS (usage not shown here)
    num_atlas_selfplay_workers: int = 1
    num_entropy_selfplay_workers: int = 1
    
    # === CHECKPOINTS ===
    checkpoint_micro_interval: int = 5     # minutes
    checkpoint_light_interval: int = 15    # minutes
    checkpoint_full_interval: int = 60     # minutes
    checkpoint_keep_n: int = 3
    
    # === EVALUATION ===
    eval_interval_games: int = 5_000
    eval_games_atlas_entropy: int = 50
    eval_games_vs_random: int = 20
    eval_games_vs_heuristic: int = 20
    elo_initial: float = 1000.0
    elo_k_factor: float = 32.0
    
    # === DASHBOARD ===
    dashboard_port: int = 8501
    dashboard_refresh_seconds: float = 2.0
    
    # === RESILIENCE ===
    heartbeat_timeout: float = 60.0        # presumably seconds, since heartbeats are time.time() stamps - TODO confirm
    max_process_restarts: int = 3
    oom_batch_reduction: float = 0.5       # presumably the batch-size multiplier applied after an OOM - TODO confirm
    oom_min_batch_size: int = 8

@dataclass
class HardwareConfig:
    """Detected hardware description, consumed by ``adjust_config_for_hardware``."""
    device: str                # compared against 'cpu' by adjust_config_for_hardware
    gpu_name: Optional[str]    # presumably None when no GPU is present - TODO confirm against the detector
    vram_gb: Optional[float]   # GPU memory in GB; None is treated as 0.0 by adjust_config_for_hardware
    ram_gb: float              # system RAM in GB (not read by adjust_config_for_hardware)
    cpu_cores: int             # selects the number of self-play workers
    is_colab: bool             # presumably True when running inside Google Colab - TODO confirm

def adjust_config_for_hardware(config: 'PrometheusConfig', hw: 'HardwareConfig') -> 'PrometheusConfig':
    """Scale network size, search budget and buffer sizes to the detected hardware.

    ``config`` is modified in place and also returned for convenience.

    Tiers (by ``hw.vram_gb``; ``None`` counts as 0 GB):
      * ``hw.device == 'cpu'``: smallest network and search; returns early, so the
        self-play worker counts are left untouched.
      * >= 40 GB (A100 class): large network, 400 simulations, larger buffers.
      * >= 12 GB (T4 / V100 class): the T4-tuned ATLAS defaults; ENTROPY values are
        not modified in this tier.
      * < 12 GB: reduced network, search and buffers for both agents.

    Args:
        config: Configuration to adjust (mutated).
        hw: Detected hardware description.

    Returns:
        The same ``config`` object.
    """
    if hw.device == 'cpu':
        config.atlas_batch_size = 32
        config.entropy_batch_size = 16
        config.atlas_res_blocks = 5
        config.atlas_mcts_simulations = 50
        config.atlas_replay_size = 100_000
        return config

    vram_gb = hw.vram_gb or 0.0

    if vram_gb >= 40:  # A100 class
        config.atlas_batch_size = 512
        config.atlas_mcts_simulations = 400
        config.atlas_res_blocks = 12
        config.atlas_replay_size = 1_000_000
        config.entropy_batch_size = 256
        config.entropy_replay_size = 500_000
    elif vram_gb >= 12:  # T4 / V100 class: the dataclass defaults are already tuned for this tier
        config.atlas_batch_size = 128
        config.atlas_mcts_simulations = 100
        config.atlas_res_blocks = 8
        config.atlas_replay_size = 300_000
    else:  # < 12 GB (e.g. K80 or small local GPUs)
        config.atlas_batch_size = 32
        config.atlas_mcts_simulations = 50
        config.atlas_res_blocks = 5
        config.atlas_replay_size = 100_000
        config.entropy_batch_size = 32
        config.entropy_replay_size = 100_000

    # Two self-play workers per agent only when there are enough CPU cores.
    workers = 2 if hw.cpu_cores >= 8 else 1
    config.num_atlas_selfplay_workers = workers
    config.num_entropy_selfplay_workers = workers

    return config



In [None]:
%%writefile prometheus_tqfd/encoding.py
import chess
import numpy as np
import torch

class BoardEncoder:
    """Encode a python-chess board as a stack of 8x8 float32 feature planes.

    Planes produced by :meth:`encode` (always 19, indexed [plane, rank, file]):
      0-5    white pawn, knight, bishop, rook, queen, king
      6-11   black pieces in the same order
      12-15  castling rights: white king-side, white queen-side,
             black king-side, black queen-side (whole plane set to 1.0)
      16     en-passant target square
      17     halfmove clock / 100
      18     1.0 everywhere when white is to move

    NOTE(review): ``use_history``/``history_len`` only change ``num_channels``;
    ``encode`` never emits history planes, so ``num_channels`` does not match the
    returned tensor when ``use_history=True``.
    """

    def __init__(self, use_history=False, history_len=8):
        self.use_history = use_history
        self.history_len = history_len
        if use_history:
            # Current position (19 planes) plus 12 piece planes per earlier position.
            self.num_channels = 19 + 12 * (history_len - 1)
        else:
            self.num_channels = 19

    def encode(self, board: chess.Board) -> torch.Tensor:
        """Return a [19, 8, 8] float32 tensor describing ``board``."""
        planes = np.zeros((19, 8, 8), dtype=np.float32)

        # One-hot piece planes: white uses 0-5, black uses 6-11.
        for sq, piece in board.piece_map().items():
            row, col = divmod(sq, 8)
            color_offset = 0 if piece.color == chess.WHITE else 6
            planes[color_offset + piece.piece_type - 1, row, col] = 1.0

        castling_flags = (
            board.has_kingside_castling_rights(chess.WHITE),
            board.has_queenside_castling_rights(chess.WHITE),
            board.has_kingside_castling_rights(chess.BLACK),
            board.has_queenside_castling_rights(chess.BLACK),
        )
        for plane, allowed in enumerate(castling_flags, start=12):
            if allowed:
                planes[plane] = 1.0

        ep = board.ep_square
        if ep is not None:
            row, col = divmod(ep, 8)
            planes[16, row, col] = 1.0

        planes[17] = board.halfmove_clock / 100.0

        if board.turn == chess.WHITE:
            planes[18] = 1.0

        return torch.from_numpy(planes)

class MoveEncoder:
    """AlphaZero-style 8x8x73 move encoding (4672 actions).

    ``index = from_square * 73 + plane`` where ``plane`` is:
      0-55   sliding / king / pawn moves and queen promotions:
             ``direction * 7 + (distance - 1)`` with directions in ``_QUEEN_DIRECTIONS`` order
      56-63  knight moves, in ``_KNIGHT_DELTAS`` order
      64-72  underpromotions: ``(piece - KNIGHT) * 3 + (file_delta + 1)``

    Squares are absolute; the board is not flipped for black. The table order
    defines the action layout of trained networks and must not change.
    """

    NUM_PLANES = 73
    ACTION_SIZE = 64 * NUM_PLANES  # 4672

    # (rank delta, file delta) unit steps; order defines plane numbering.
    _QUEEN_DIRECTIONS = ((1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1))
    _KNIGHT_DELTAS = ((2, 1), (1, 2), (-1, 2), (-2, 1), (-2, -1), (-1, -2), (1, -2), (2, -1))
    _QUEEN_DIR_INDEX = {d: i for i, d in enumerate(_QUEEN_DIRECTIONS)}
    _KNIGHT_PLANE = {d: 56 + i for i, d in enumerate(_KNIGHT_DELTAS)}

    def __init__(self):
        # Kept for backward compatibility only: indices are computed arithmetically,
        # so these dictionaries are never populated.
        self.move_to_idx = {}
        self.idx_to_move = {}
        self._create_mapping()

    def _create_mapping(self):
        """No-op kept for compatibility; the mapping is computed on the fly."""
        # The previous implementation iterated every square/direction/distance
        # without storing anything.
        return None

    def get_plane_and_sq(self, move: chess.Move):
        """Return ``(plane, from_square)`` for ``move``.

        Raises:
            ValueError: if the move fits none of the 73 planes (e.g. a null move).
        """
        from_sq = move.from_square
        from_rank, from_file = divmod(from_sq, 8)
        to_rank, to_file = divmod(move.to_square, 8)
        dr, df = to_rank - from_rank, to_file - from_file

        # Underpromotions (N, B, R); queen promotions fall through to the queen planes.
        if move.promotion and move.promotion != chess.QUEEN:
            piece_idx = move.promotion - chess.KNIGHT  # N=0, B=1, R=2
            return 64 + piece_idx * 3 + (df + 1), from_sq

        knight_plane = self._KNIGHT_PLANE.get((dr, df))
        if knight_plane is not None:
            return knight_plane, from_sq

        # Straight or diagonal move: exact unit step after dividing by the distance.
        dist = max(abs(dr), abs(df))
        if dist and (dr == 0 or df == 0 or abs(dr) == abs(df)):
            d_idx = self._QUEEN_DIR_INDEX[(dr // dist, df // dist)]
            return d_idx * 7 + (dist - 1), from_sq

        raise ValueError(f"Invalid move for encoding: {move}")

    def move_to_index(self, move: chess.Move) -> int:
        """Return the flat action index in ``[0, ACTION_SIZE)`` for ``move``."""
        plane, from_sq = self.get_plane_and_sq(move)
        return from_sq * self.NUM_PLANES + plane

    def index_to_move(self, index: int, board: chess.Board) -> chess.Move:
        """Decode ``index`` back into a move for ``board``.

        ``board`` is used to add queen promotion for pawns reaching the last rank
        and to pick the target rank of underpromotions. Indices are assumed to come
        from :meth:`move_to_index`; off-board targets are not checked.
        """
        from_sq, plane = divmod(index, self.NUM_PLANES)
        from_rank, from_file = divmod(from_sq, 8)

        if plane < 56:
            d_idx, dist_m1 = divmod(plane, 7)
            dr, df = self._QUEEN_DIRECTIONS[d_idx]
            dist = dist_m1 + 1
            to_rank, to_file = from_rank + dr * dist, from_file + df * dist
            move = chess.Move(from_sq, to_rank * 8 + to_file)
            # A pawn reaching the last rank on a queen plane promotes to a queen.
            piece = board.piece_at(from_sq)
            if piece and piece.piece_type == chess.PAWN:
                if (to_rank == 7 and board.turn == chess.WHITE) or (to_rank == 0 and board.turn == chess.BLACK):
                    move.promotion = chess.QUEEN
            return move

        if plane < 64:
            dr, df = self._KNIGHT_DELTAS[plane - 56]
            return chess.Move(from_sq, (from_rank + dr) * 8 + (from_file + df))

        piece_idx, dir_idx = divmod(plane - 64, 3)
        to_rank = 7 if board.turn == chess.WHITE else 0
        to_file = from_file + (dir_idx - 1)
        return chess.Move(from_sq, to_rank * 8 + to_file, promotion=chess.KNIGHT + piece_idx)

    def get_legal_mask(self, board: chess.Board) -> torch.Tensor:
        """Return a bool tensor of length ``ACTION_SIZE`` marking the legal moves of ``board``."""
        mask = torch.zeros(self.ACTION_SIZE, dtype=torch.bool)
        for move in board.legal_moves:
            mask[self.move_to_index(move)] = True
        return mask



In [None]:
%%writefile prometheus_tqfd/physics.py
import chess
import numpy as np
import torch
from prometheus_tqfd.config import PrometheusConfig

class PhysicsFieldCalculator:
    """
    Compute physics-inspired feature planes from a board position.

    :meth:`compute` returns a float32 tensor of shape [3, 8, 8] indexed [channel, rank, file]:
    - channel 0: mass field M - Gaussian-weighted piece energies (white +, black -),
      divided by its maximum absolute value
    - channel 1: mobility field F - attacker count per square (white minus black),
      min-max rescaled to [-1, 1]
    - channel 2: pressure field P = alpha * M + beta * F (no smoothing is applied)
    """

    # 2 * 7 + 1: covers every (rank, file) offset between two squares of an 8x8 board.
    _KERNEL_SIZE = 15

    def __init__(self, config: PrometheusConfig):
        self.energies = {
            chess.KING: config.physics_energy_king,
            chess.QUEEN: config.physics_energy_queen,
            chess.ROOK: config.physics_energy_rook,
            chess.BISHOP: config.physics_energy_bishop,
            chess.KNIGHT: config.physics_energy_knight,
            chess.PAWN: config.physics_energy_pawn,
        }
        self.sigma = config.physics_diffusion_sigma
        self.alpha = config.physics_field_alpha
        self.beta = config.physics_field_beta

        self._precompute_gaussian_kernel()

    def _precompute_gaussian_kernel(self):
        """Build the 15x15 Gaussian ``exp(-d^2 / (2 sigma^2))`` centred at (7, 7)."""
        center = self._KERNEL_SIZE // 2
        offsets = np.arange(self._KERNEL_SIZE) - center
        dist_sq = offsets[:, None] ** 2 + offsets[None, :] ** 2
        self.kernel = np.exp(-dist_sq / (2 * self.sigma ** 2))

    def _gaussian_at(self, x0, y0):
        """Return a fresh 8x8 grid holding the Gaussian centred on file ``x0``, rank ``y0``.

        ``x0`` and ``y0`` must be board coordinates in ``0..7``.
        """
        center = self._KERNEL_SIZE // 2
        # grid[r, c] = kernel[center + r - y0, center + c - x0]; for on-board centres
        # every offset lies inside the kernel, so the window is a plain slice.
        top, left = center - y0, center - x0
        return self.kernel[top:top + 8, left:left + 8].copy()

    def compute(self, board: chess.Board) -> torch.Tensor:
        """Compute the mass, mobility and pressure fields as a [3, 8, 8] float32 tensor."""
        mass = self._compute_mass_field(board)
        mobility = self._compute_mobility_field(board)
        pressure = self.alpha * mass + self.beta * mobility

        fields = np.stack([mass, mobility, pressure])
        return torch.from_numpy(fields).float()

    def _compute_mass_field(self, board: chess.Board) -> np.ndarray:
        """Sum signed, energy-weighted Gaussians of all pieces, scaled by the peak magnitude."""
        field = np.zeros((8, 8))
        for square, piece in board.piece_map().items():
            sign = 1 if piece.color == chess.WHITE else -1
            energy = self.energies[piece.piece_type]
            field += sign * energy * self._gaussian_at(square % 8, square // 8)

        peak = np.max(np.abs(field))
        if peak > 0:
            field = field / peak
        return field

    def _compute_mobility_field(self, board: chess.Board) -> np.ndarray:
        """Net attacker count per square (white minus black), rescaled to [-1, 1]."""
        field = np.zeros((8, 8))
        for color in (chess.WHITE, chess.BLACK):
            sign = 1 if color == chess.WHITE else -1
            for square in chess.SQUARES:
                field[square // 8, square % 8] += sign * len(board.attackers(color, square))

        # Min-max rescale; a constant field is left unchanged.
        span = field.max() - field.min()
        if span > 0:
            field = 2 * (field - field.min()) / span - 1
        return field



In [None]:
%%writefile prometheus_tqfd/tactics.py
import chess
import torch
from typing import Dict, Any, List, Optional, Tuple
from prometheus_tqfd.encoding import MoveEncoder
from prometheus_tqfd.config import PrometheusConfig

class TacticsDetector:
    """
    Rule-based detector for critical tactical patterns.
    No machine learning - pure chess logic via python-chess.

    Detects:
    - mate in 1 for the side to move
    - mate threat: the opponent could mate in 1 if it were their turn (null-move probe)
    - hanging pieces: own pieces attacked by more enemy pieces than they are defended by
    - available checks and captures
    """

    # Material values used to scale the hanging-piece rescue boost.
    _PIECE_VALUES = {
        chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3,
        chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 100,
    }

    def __init__(self, config: PrometheusConfig):
        self.config = config
        self.move_encoder = MoveEncoder()
        self.boost_strength = config.tactics_boost_strength
        self.decay_rate = config.tactics_boost_decay
        self.current_strength = self.boost_strength

    def decay_step(self):
        """Multiply the current boost strength by the configured decay rate."""
        self.current_strength *= self.decay_rate

    @staticmethod
    def _find_mate_in_one(board: chess.Board, moves) -> Optional[chess.Move]:
        """Return the first move in ``moves`` that checkmates, or None.

        Every trial move is popped again, even if an exception is raised.
        """
        for move in moves:
            board.push(move)
            try:
                if board.is_checkmate():
                    return move
            finally:
                board.pop()
        return None

    def detect(self, board: chess.Board) -> Dict[str, Any]:
        """Analyse ``board`` for the side to move; the board is restored afterwards."""
        threats = {
            'mate_in_1': None,           # the mating move, if any
            'mate_threat': False,        # opponent threatens mate
            'hanging_pieces': [],        # under-defended own pieces as (square, piece)
            'checks_available': 0,       # number of checking moves
            'captures_available': [],    # capturing moves
        }
        us = board.turn
        them = not us
        legal_moves = list(board.legal_moves)

        threats['mate_in_1'] = self._find_mate_in_one(board, legal_moves)

        # Mate threat: pass the move (null move) and see whether the opponent mates.
        # Skipped while in check, where a null move would be illegal.
        if not board.is_check():
            board.push(chess.Move.null())
            try:
                threats['mate_threat'] = self._find_mate_in_one(board, list(board.legal_moves)) is not None
            finally:
                board.pop()

        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece and piece.color == us and board.is_attacked_by(them, square):
                defenders = len(board.attackers(us, square))
                attackers = len(board.attackers(them, square))
                if defenders < attackers:
                    threats['hanging_pieces'].append((square, piece))

        for move in legal_moves:
            if board.gives_check(move):
                threats['checks_available'] += 1
            if board.is_capture(move):
                threats['captures_available'].append(move)

        return threats

    def get_tactical_boost(self, board: chess.Board) -> torch.Tensor:
        """
        Build a length-4672 boost tensor for tactically critical moves.
        Intended to be added to the policy logits before the softmax.
        """
        boost = torch.zeros(4672)
        threats = self.detect(board)
        encode = self.move_encoder.move_to_index
        legal_moves = list(board.legal_moves)

        # Mate in 1: massive boost (assignment, so it is the base value for that move).
        if threats['mate_in_1'] is not None:
            boost[encode(threats['mate_in_1'])] = self.config.tactics_mate_boost

        # Mate threat: boost defensive candidates.
        if threats['mate_threat']:
            threat_boost = self.config.tactics_threat_boost
            for move in legal_moves:
                # Giving check is often a defensive resource.
                if board.gives_check(move):
                    boost[encode(move)] += threat_boost
                # King moves (half weight).
                piece = board.piece_at(move.from_square)
                if piece and piece.piece_type == chess.KING:
                    boost[encode(move)] += threat_boost * 0.5

        # Rescue hanging pieces: boost every move of a hanging piece by its value.
        rescue_bonus = {
            square: self._piece_value(piece.piece_type) * (self.config.tactics_hanging_boost / 10.0)
            for square, piece in threats['hanging_pieces']
        }
        if rescue_bonus:
            for move in legal_moves:
                bonus = rescue_bonus.get(move.from_square)
                if bonus is not None:
                    boost[encode(move)] += bonus

        return boost * self.current_strength

    def _piece_value(self, piece_type: int) -> float:
        """Material value of ``piece_type`` (0 for unknown types)."""
        return self._PIECE_VALUES.get(piece_type, 0)



In [None]:
%%writefile prometheus_tqfd/atlas/__init__.py



In [None]:
%%writefile prometheus_tqfd/atlas/network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from prometheus_tqfd.config import PrometheusConfig

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an identity skip connection.

    Shape-preserving: [N, channels, H, W] -> [N, channels, H, W].
    """

    def __init__(self, channels: int):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable for checkpoint loading.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.relu(self.bn1(self.conv1(x)))
        return F.relu(x + self.bn2(self.conv2(hidden)))

class AtlasNetwork(nn.Module):
    """
    Standard AlphaZero architecture:
    - Input: [batch, config.atlas_input_channels (default 19), 8, 8]
    - Residual tower: config.atlas_res_blocks blocks with skip connections
    - Policy head: logits over 4672 moves (64 from-squares x 73 planes)
    - Value head: scalar in [-1, +1] (tanh)
    """

    def __init__(self, config: PrometheusConfig):
        super().__init__()
        C = config.atlas_channels
        in_channels = config.atlas_input_channels  # was hard-coded to 19
        num_actions = 4672  # 64 from-squares x 73 move planes

        # Sequential layouts are kept exactly so state_dict keys stay compatible.
        self.input_block = nn.Sequential(
            nn.Conv2d(in_channels, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU()
        )

        self.res_blocks = nn.ModuleList([
            ResidualBlock(C)
            for _ in range(config.atlas_res_blocks)
        ])

        self.policy_head = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 64, num_actions)
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(C, 4, kernel_size=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(4 * 64, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits [B, 4672], value [B, 1])`` for a batch of encoded boards."""
        x = self.input_block(x)
        for block in self.res_blocks:
            x = block(x)
        return self.policy_head(x), self.value_head(x)



In [None]:
%%writefile prometheus_tqfd/atlas/mcts.py
import math
import torch
import torch.nn.functional as F
import numpy as np
import chess
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Tuple
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.atlas.network import AtlasNetwork
from prometheus_tqfd.encoding import BoardEncoder, MoveEncoder

@dataclass(eq=False)
class MCTSNode:
    """One search-tree node: a board position plus its visit statistics.

    ``eq=False`` gives identity equality and hashing. The generated dataclass
    ``__eq__`` would compare ``parent``/``children`` recursively through the tree
    and make nodes unhashable. ``repr=False`` on the links keeps ``repr()`` from
    dumping the whole tree.
    """
    state: chess.Board
    parent: Optional['MCTSNode'] = field(default=None, repr=False)
    parent_action: Optional[chess.Move] = None
    children: Dict[chess.Move, 'MCTSNode'] = field(default_factory=dict, repr=False)
    visit_count: int = 0
    total_value: float = 0.0
    prior: float = 0.0

    @property
    def q_value(self) -> float:
        """Mean backed-up value, or 0.0 for an unvisited node."""
        if self.visit_count == 0:
            return 0.0
        return self.total_value / self.visit_count

class MCTS:
    """
    Monte Carlo Tree Search with PUCT selection (AlphaZero style).

    Value convention: a node's ``total_value``/``q_value`` is stored from the
    perspective of the player who made the move leading INTO that node, so a
    parent picks the child with the highest Q. Terminal nodes are scored that way
    directly. Network values are trained from the side-to-move's perspective (see
    the self-play value targets), so they are negated before backup.
    """

    def __init__(self, config: PrometheusConfig, network: AtlasNetwork, device: str):
        self.config = config
        self.network = network
        self.device = device
        self.encoder = BoardEncoder()
        self.move_encoder = MoveEncoder()

    def search(self, root_board: chess.Board, num_simulations: int = None) -> MCTSNode:
        """Run PUCT simulations from a copy of ``root_board`` and return the root node.

        ``num_simulations`` defaults to ``config.atlas_mcts_simulations``.
        """
        if num_simulations is None:
            num_simulations = self.config.atlas_mcts_simulations

        root = MCTSNode(state=root_board.copy())

        # Expand the root first so Dirichlet noise can be mixed into its priors.
        self._evaluate_and_expand(root)
        self._add_dirichlet_noise(root)

        for _ in range(num_simulations):
            node = self._select(root)
            value = self._evaluate_and_expand(node)
            self._backpropagate(node, value)

        return root

    def _ucb_score(self, parent: MCTSNode, child: MCTSNode) -> float:
        """PUCT score of ``child`` as seen from ``parent``."""
        c = self.config.atlas_mcts_cpuct
        # The root's own expansion is never backed up, so its visit count is 0 on the
        # first simulation. Flooring at 1 keeps the priors from being ignored there.
        u = c * child.prior * math.sqrt(max(parent.visit_count, 1)) / (1 + child.visit_count)
        return child.q_value + u

    def _select(self, node: MCTSNode) -> MCTSNode:
        """Descend by maximal PUCT score until reaching a node without children."""
        # Game-over positions are never expanded (see _evaluate_and_expand), so any node
        # with children is not game-over and no extra is_game_over() test is needed.
        while node.children:
            node = max(node.children.values(), key=lambda c: self._ucb_score(node, c))
        return node

    def _evaluate_and_expand(self, node: MCTSNode) -> float:
        """Evaluate ``node`` and expand its children.

        Returns the value from the perspective of the player who moved into ``node``.
        """
        if node.state.is_game_over():
            result = node.state.result()
            if result == "1-0":
                return 1.0 if node.state.turn == chess.BLACK else -1.0
            if result == "0-1":
                return -1.0 if node.state.turn == chess.BLACK else 1.0
            return 0.0

        tensor = self.encoder.encode(node.state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            policy_logits, value = self.network(tensor)

        # Mask illegal moves, then softmax. Move to CPU once so the per-child
        # .item() calls below do not each synchronise with the GPU.
        legal_mask = self.move_encoder.get_legal_mask(node.state).to(self.device)
        policy_logits = policy_logits.masked_fill(~legal_mask.unsqueeze(0), float('-inf'))
        policy = F.softmax(policy_logits, dim=1).squeeze(0).cpu()

        for move in node.state.legal_moves:
            idx = self.move_encoder.move_to_index(move)
            new_board = node.state.copy()
            new_board.push(move)
            node.children[move] = MCTSNode(
                state=new_board,
                parent=node,
                parent_action=move,
                prior=policy[idx].item()
            )

        # The network predicts the side-to-move's value; convert to the mover-into-node view.
        return -value.item()

    def _backpropagate(self, node: MCTSNode, value: float):
        """Add ``value`` along the path to the root, flipping sign at every ply."""
        while node is not None:
            node.visit_count += 1
            node.total_value += value
            value = -value
            node = node.parent

    def _add_dirichlet_noise(self, root: MCTSNode):
        """Mix Dirichlet exploration noise into the root's child priors."""
        if not root.children:
            return
        alpha = self.config.atlas_dirichlet_alpha
        epsilon = self.config.atlas_dirichlet_epsilon
        noise = np.random.dirichlet([alpha] * len(root.children))
        for i, child in enumerate(root.children.values()):
            child.prior = (1 - epsilon) * child.prior + epsilon * noise[i]

    def get_policy_target(self, root: MCTSNode, temperature: float) -> torch.Tensor:
        """Temperature-scaled, normalised visit counts as a length-4672 policy target.

        ``temperature <= 0.01`` yields a one-hot on the most visited move. If no
        visits were recorded, the legal moves get a uniform distribution. A root
        without children yields all zeros.
        """
        policy = torch.zeros(4672)
        if not root.children:
            return policy

        indices = [self.move_encoder.move_to_index(move) for move in root.children]
        visits = torch.tensor([child.visit_count for child in root.children.values()], dtype=torch.float32)

        if temperature <= 0.01:
            policy[indices[int(visits.argmax())]] = 1.0
            return policy

        probs = visits ** (1 / temperature)
        total = probs.sum()
        if total > 0:
            probs = probs / total
        else:
            probs = torch.full_like(probs, 1.0 / len(indices))
        policy[torch.tensor(indices)] = probs
        return policy

    def select_move(self, root: MCTSNode, temperature: float) -> chess.Move:
        """Pick a root move: argmax of visits when ``temperature <= 0.01``, else sample."""
        moves = list(root.children.keys())
        visits = torch.tensor([root.children[m].visit_count for m in moves], dtype=torch.float32)

        if temperature <= 0.01:
            return moves[int(visits.argmax())]

        probs = visits ** (1 / temperature)
        if probs.sum() <= 0:
            # No visits recorded (e.g. zero simulations): sample uniformly.
            probs = torch.ones_like(probs)
        probs = probs / (probs.sum() + 1e-8)
        idx = torch.multinomial(probs, 1).item()
        return moves[idx]



In [None]:
%%writefile prometheus_tqfd/atlas/selfplay.py
import time
import torch
import chess
from typing import List, Tuple, Dict
from multiprocessing import Queue, Event
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.atlas.network import AtlasNetwork
from prometheus_tqfd.atlas.mcts import MCTS
from prometheus_tqfd.encoding import BoardEncoder, MoveEncoder

class AtlasSelfPlayWorker:
    """
    Generate self-play games with MCTS and push them onto ``data_queue``.

    Each queued item is a list of ``(state_tensor, policy_target, value_target)``
    tuples, one per ply. ``value_target`` is the game result from the perspective
    of the side to move in that state.
    """

    # A game is cut off (and scored as a draw) once more than this many plies were played.
    MAX_PLIES = 400

    def __init__(self, config: PrometheusConfig, weights_queue: Queue,
                 data_queue: Queue, device: str, worker_id: int):
        self.config = config
        self.weights_queue = weights_queue
        self.data_queue = data_queue
        self.device = device
        self.worker_id = worker_id

        self.network = AtlasNetwork(config).to(device)
        self.network.eval()
        self.mcts = MCTS(config, self.network, device)
        self.encoder = BoardEncoder()
        self.move_encoder = MoveEncoder()

        self.games_played = 0
        self.weights_version = 0
        self._heartbeat_dict = None  # attached by run()

    def run(self, stop_event: Event, heartbeat_dict: Dict):
        """Main loop: refresh weights, play a game, enqueue it, until ``stop_event`` is set."""
        self._heartbeat_dict = heartbeat_dict
        while not stop_event.is_set():
            self._beat()
            self._maybe_update_weights()
            trajectory = self._play_game()
            self.data_queue.put(trajectory)
            self.games_played += 1

    def _beat(self):
        """Refresh this worker's heartbeat timestamp, if a heartbeat dict is attached."""
        if self._heartbeat_dict is not None:
            self._heartbeat_dict[f'atlas_selfplay_{self.worker_id}'] = time.time()

    def _maybe_update_weights(self):
        """Drain ``weights_queue`` and load the newest ``(state_dict, version)`` if it is newer.

        Best-effort: a malformed message or an incompatible state_dict is reported
        and the current weights are kept.
        """
        import queue  # multiprocessing queues raise queue.Empty from get_nowait()

        latest = None
        while True:
            try:
                latest = self.weights_queue.get_nowait()
            except queue.Empty:
                break

        if latest is None:
            return
        try:
            weights, version = latest
            if version <= self.weights_version:
                return
            self.network.load_state_dict(weights)
        except (RuntimeError, ValueError, TypeError) as err:
            print(f"⚠️ atlas_selfplay_{self.worker_id}: ignoring weight update: {err}")
            return
        self.weights_version = version

    def _play_game(self) -> List[Tuple[torch.Tensor, torch.Tensor, float]]:
        """Play one complete self-play game from the initial position."""
        board = chess.Board()
        steps = []  # (state_tensor, policy_target, side_to_move)

        while not board.is_game_over():
            # A single game can take a long time; keep the heartbeat fresh per ply.
            self._beat()

            ply = len(steps)
            if ply < self.config.atlas_temperature_moves:
                temperature = self.config.atlas_temperature_init
            else:
                temperature = self.config.atlas_temperature_final

            root = self.mcts.search(board)
            steps.append((
                self.encoder.encode(board),
                self.mcts.get_policy_target(root, temperature),
                board.turn,
            ))

            board.push(self.mcts.select_move(root, temperature))

            if len(steps) > self.MAX_PLIES:
                break

        # Unfinished (cut-off) games give result "*", which _get_result scores as 0.
        result = self._get_result(board)
        return [
            (state, policy, result * (1 if turn == chess.WHITE else -1))
            for state, policy, turn in steps
        ]

    def _get_result(self, board: chess.Board) -> float:
        """Game result from white's perspective: 1.0, -1.0, or 0.0 (draw / unfinished)."""
        result = board.result()
        if result == "1-0":
            return 1.0
        elif result == "0-1":
            return -1.0
        return 0.0



In [None]:
%%writefile prometheus_tqfd/atlas/trainer.py
import time
import torch
import torch.nn.functional as F
from multiprocessing import Queue, Event, Lock
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.atlas.network import AtlasNetwork
from prometheus_tqfd.utils.replay_buffer import ReplayBuffer

class AtlasTrainer:
    """
    Trains the ATLAS network on data produced by self-play.

    Trajectories arrive on ``data_queue`` as iterables of
    ``(state, policy, value)`` triples and are stored in a replay buffer.
    The network is optimized on random mini-batches. Every
    ``config.weight_publish_interval`` steps, a CPU snapshot of the weights
    is pushed to ``weights_queue`` and mirrored into ``shared_values`` for
    checkpointing.
    """

    def __init__(self, config: PrometheusConfig, data_queue: Queue,
                 weights_queue: Queue, device: str, shared_values: dict):
        self.config = config
        self.data_queue = data_queue
        self.weights_queue = weights_queue
        self.device = device
        self.shared_values = shared_values
        # Use the device *type* so indexed devices such as 'cuda:0' get the
        # same mixed-precision setup as plain 'cuda'.
        self.device_type = torch.device(device).type

        self.network = AtlasNetwork(config).to(device)
        self.optimizer = torch.optim.AdamW(
            self.network.parameters(),
            lr=config.atlas_learning_rate,
            weight_decay=config.atlas_weight_decay
        )
        # Loss scaling is only used together with CUDA autocast.
        self.scaler = torch.cuda.amp.GradScaler() if self.device_type == 'cuda' else None

        self.replay_buffer = ReplayBuffer(config.atlas_replay_size)
        self.global_step = 0
        self.weights_version = 0

    def run(self, stop_event: Event, pause_event: Event,
            gpu_lock: Lock, heartbeat_dict: dict, metrics_queue: Queue):
        """Main trainer loop; runs until ``stop_event`` is set.

        Each iteration does the following:
        - Refreshes the heartbeat timestamp.
        - Honours ``pause_event``.
        - Drains new trajectories into the replay buffer.
        - Once the buffer holds ``config.min_buffer_before_training``
          samples, performs one optimization step under ``gpu_lock``.
        """
        while not stop_event.is_set():
            # Heartbeat timestamp (presumably watched by a supervisor to
            # detect stalled processes).
            heartbeat_dict['atlas_trainer'] = time.time()

            if pause_event.is_set():
                time.sleep(1)
                continue

            self._collect_data()

            if len(self.replay_buffer) < self.config.min_buffer_before_training:
                time.sleep(1)  # Not enough data yet; wait for self-play.
                continue

            # Only the GPU work needs the lock; reporting metrics does not.
            with gpu_lock:
                metrics = self._train_step()
            metrics_queue.put({
                'type': 'atlas_train',
                'step': self.global_step,
                **metrics
            })

            if self.global_step % self.config.weight_publish_interval == 0:
                self._publish_weights()

    def _collect_data(self):
        """Move every trajectory currently queued into the replay buffer.

        Reads with ``get_nowait`` until ``queue.Empty`` rather than polling
        ``Queue.empty()``, which is not reliable for multiprocessing queues.
        Only the "nothing left to read" condition is swallowed. Malformed
        trajectories now raise instead of silently discarding data.
        """
        import queue

        while True:
            try:
                trajectory = self.data_queue.get_nowait()
            except queue.Empty:
                return
            for state, policy, value in trajectory:
                self.replay_buffer.add(state, policy, value)

    def _train_step(self) -> dict:
        """Run one optimization step and return scalar metrics.

        Loss = cross-entropy between the stored policy target and the
        predicted policy, plus MSE between the predicted value and the
        game-outcome target. Gradients are clipped to a global norm of 1.0.
        """
        self.network.train()
        states, policies, values = self.replay_buffer.sample(self.config.atlas_batch_size)
        states = states.to(self.device)
        policies = policies.to(self.device)
        values = values.to(self.device)

        self.optimizer.zero_grad()

        # CUDA devices use CUDA autocast; everything else falls back to CPU
        # autocast, as before.
        autocast_type = 'cuda' if self.device_type == 'cuda' else 'cpu'

        with torch.autocast(device_type=autocast_type):
            policy_logits, value_pred = self.network(states)

            # Policy loss: cross-entropy against the target distribution.
            log_probs = F.log_softmax(policy_logits, dim=1)
            policy_loss = -torch.sum(policies * log_probs, dim=1).mean()

            # Value loss: MSE.
            value_loss = F.mse_loss(value_pred.squeeze(-1), values)

            loss = policy_loss + value_loss

        if self.scaler:
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.optimizer.step()

        self.global_step += 1
        self.shared_values['atlas_steps'] = self.global_step

        return {
            'loss': loss.item(),
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'grad_norm': grad_norm.item() if grad_norm is not None else 0.0
        }

    def _publish_weights(self):
        """Send a fresh weight snapshot to the self-play workers.

        Stale snapshots still sitting in ``weights_queue`` are discarded
        first. The new snapshot and the optimizer state are also mirrored
        into ``shared_values`` for checkpointing.
        """
        import queue

        self.weights_version += 1
        # ``Tensor.cpu()`` returns the *same* tensor when it already lives
        # on the CPU. Copy explicitly so the snapshot does not keep changing
        # as training continues.
        weights = {
            k: v.detach().to('cpu', copy=True)
            for k, v in self.network.state_dict().items()
        }

        # Drop outdated snapshots and enqueue the new one.
        while True:
            try:
                self.weights_queue.get_nowait()
            except queue.Empty:
                break

        self.weights_queue.put((weights, self.weights_version))

        # For checkpointing.
        self.shared_values['atlas_weights'] = weights
        self.shared_values['atlas_optimizer'] = self.optimizer.state_dict()
        self.shared_values['atlas_version'] = self.weights_version



In [None]:
%%writefile prometheus_tqfd/entropy/__init__.py



In [None]:
%%writefile prometheus_tqfd/entropy/network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.atlas.network import ResidualBlock

class EntropyNetworkV2(nn.Module):
    """
    Simplified CNN-only network for ENTROPY v2.0.

    Input: 22 channels (19 board planes + 3 physics-field planes) on 8x8.
    Output:
    - Policy logits: [batch, 4672]
    - Energy: [batch, 1], soft-clipped into (-10, 10) via ``10 * tanh(x / 10)``

    ``forward`` and ``get_features`` share one trunk implementation
    (``_trunk``). This guarantees the RND features come from exactly the
    same computation that feeds the policy and energy heads.
    """

    def __init__(self, config: PrometheusConfig):
        super().__init__()
        C = config.entropy_channels

        # Input: 19 board channels + 3 physics-field channels.
        self.input_block = nn.Sequential(
            nn.Conv2d(22, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU()
        )

        # Residual tower (intended to be smaller than ATLAS's).
        self.res_blocks = nn.ModuleList([
            ResidualBlock(C) for _ in range(config.entropy_res_blocks)
        ])

        # Policy head: 1x1 conv down to 32 planes, then a linear layer over
        # all 64 squares.
        self.policy_head = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 64, 4672)
        )

        # Energy head: scalar evaluation of the position.
        self.energy_head = nn.Sequential(
            nn.Conv2d(C, 4, kernel_size=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(4 * 64, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def _trunk(self, board_tensor: torch.Tensor,
               field_tensor: torch.Tensor) -> torch.Tensor:
        """Concatenate the inputs on dim 1, then run the input block and residual tower."""
        x = torch.cat([board_tensor, field_tensor], dim=1)
        x = self.input_block(x)
        for block in self.res_blocks:
            x = block(x)
        return x

    def forward(self, board_tensor: torch.Tensor,
                field_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits, energy)`` for a batch of positions."""
        x = self._trunk(board_tensor, field_tensor)

        policy_logits = self.policy_head(x)
        energy = self.energy_head(x)

        # Soft-clip the energy for stability: tanh(x/10)*10.
        energy = torch.tanh(energy / 10.0) * 10.0

        return policy_logits, energy

    def get_features(self, board_tensor, field_tensor):
        """Return the flattened trunk output (before the heads), used for RND."""
        return torch.flatten(self._trunk(board_tensor, field_tensor), 1)



In [None]:
%%writefile prometheus_tqfd/entropy/rollout.py
import torch
import torch.nn.functional as F
import numpy as np
import chess
from typing import Tuple, Dict
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.entropy.network import EntropyNetworkV2
from prometheus_tqfd.tactics import TacticsDetector
from prometheus_tqfd.encoding import BoardEncoder, MoveEncoder
from prometheus_tqfd.physics import PhysicsFieldCalculator

class MiniRolloutSelector:
    """
    Looks a few plies ahead (``config.entropy_rollout_depth``, 3-5 by design)
    using policy-guided rollouts.

    Positions at the rollout horizon are scored with the network's energy
    head instead of a classical value head. Positions where the game has
    ended are scored exactly via ``_terminal_value``.
    """

    def __init__(self, config: PrometheusConfig, network: EntropyNetworkV2,
                 tactics: TacticsDetector, device: str):
        self.config = config
        self.network = network
        self.tactics = tactics
        self.device = device

        self.encoder = BoardEncoder()
        self.field_calc = PhysicsFieldCalculator(config)
        self.move_encoder = MoveEncoder()

        self.depth = config.entropy_rollout_depth
        self.num_rollouts = config.entropy_rollout_count

    def select_move(self, board: chess.Board, temperature: float) -> Tuple[chess.Move, float]:
        """
        Choose a move for the side to move on ``board``; returns ``(move, score)``.

        Cases:
        - A mate-in-one reported by the tactics detector is returned at once
          with score 100.0.
        - ``(None, 0.0)`` means there is no legal move.
        - A single legal move is returned with score 0.0.
        - Otherwise each legal move is scored by the mean of
          ``num_rollouts`` rollouts plus 0.1 times its tactical boost. The
          move is then taken greedily (``temperature <= 0.01``) or by
          Boltzmann sampling.
        """
        # Tactics check: play a mate in one immediately.
        threats = self.tactics.detect(board)
        if threats['mate_in_1']:
            return threats['mate_in_1'], 100.0

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            return None, 0.0
        if len(legal_moves) == 1:
            return legal_moves[0], 0.0

        # Mean rollout score for every candidate move.
        move_scores = {}
        for move in legal_moves:
            scores = [self._rollout(board, move, self.depth)
                      for _ in range(self.num_rollouts)]
            move_scores[move] = np.mean(scores)

        # Add a small tactical bonus.
        tactic_boost = self.tactics.get_tactical_boost(board)
        for move in legal_moves:
            idx = self.move_encoder.move_to_index(move)
            move_scores[move] += tactic_boost[idx].item() * 0.1

        # Boltzmann sampling (greedy when the temperature is ~0).
        moves = list(move_scores.keys())
        scores = torch.tensor([move_scores[m] for m in moves])

        if temperature <= 0.01:
            chosen = moves[int(scores.argmax())]
        else:
            probs = F.softmax(scores / temperature, dim=0)
            chosen = moves[torch.multinomial(probs, 1).item()]

        return chosen, move_scores[chosen]

    def _rollout(self, board: chess.Board, first_move: chess.Move, depth: int) -> float:
        """
        Play ``first_move`` and up to ``depth - 1`` policy-sampled plies.

        Returns the value of the final position from the perspective of the
        side that played ``first_move``.
        """
        sim_board = board.copy()
        our_color = board.turn
        sim_board.push(first_move)

        for _ in range(depth - 1):
            if sim_board.is_game_over():
                break
            # Fast move choice (policy sampling, no nested rollout).
            sim_board.push(self._fast_select(sim_board))

        # Score finished games exactly. Checking only *before* each ply
        # missed a mate/stalemate produced by the last ply (or by
        # first_move when depth == 1); such positions used to be sent
        # through the network instead.
        if sim_board.is_game_over():
            return self._terminal_value(sim_board, our_color)

        # Energy at the horizon is read relative to the side to move, so it
        # is negated when the opponent is to move.
        energy = self._get_energy(sim_board)
        if sim_board.turn != our_color:
            energy = -energy

        return energy

    def _fast_select(self, board: chess.Board) -> chess.Move:
        """Fast move choice without rollouts: mate-in-one, else sample the masked policy."""
        # Check tactics first.
        threats = self.tactics.detect(board)
        if threats['mate_in_1']:
            return threats['mate_in_1']

        # Otherwise sample from the policy restricted to legal moves.
        board_tensor = self.encoder.encode(board).unsqueeze(0).to(self.device)
        field_tensor = self.field_calc.compute(board).unsqueeze(0).to(self.device)
        legal_mask = self.move_encoder.get_legal_mask(board).to(self.device)

        with torch.no_grad():
            policy_logits, _ = self.network(board_tensor, field_tensor)

        policy_logits[~legal_mask.unsqueeze(0)] = float('-inf')
        probs = F.softmax(policy_logits, dim=1).squeeze(0)

        idx = torch.multinomial(probs, 1).item()
        return self.move_encoder.index_to_move(idx, board)

    def _get_energy(self, board: chess.Board) -> float:
        """Network energy of a single position as a Python float."""
        board_tensor = self.encoder.encode(board).unsqueeze(0).to(self.device)
        field_tensor = self.field_calc.compute(board).unsqueeze(0).to(self.device)

        with torch.no_grad():
            _, energy = self.network(board_tensor, field_tensor)

        return energy.item()

    def _terminal_value(self, board: chess.Board, our_color: chess.Color) -> float:
        """Value of a finished game for ``our_color``: +10 win, -10 loss, 0 otherwise."""
        result = board.result()
        if result == "1-0":
            return 10.0 if our_color == chess.WHITE else -10.0
        elif result == "0-1":
            return -10.0 if our_color == chess.WHITE else 10.0
        return 0.0



In [None]:
%%writefile prometheus_tqfd/entropy/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
from prometheus_tqfd.config import PrometheusConfig

class EntropyV2Loss:
    """
    Hybrid loss system for ENTROPY v2.0.

    Computes a weighted sum (weights from ``config.entropy_loss_*``) of
    five terms: outcome, mobility, pressure, stability (TD) and novelty
    (RND).
    """

    def __init__(self, config: PrometheusConfig, device: str):
        self.config = config
        self.device = device
        self.weights = {
            'outcome': config.entropy_loss_outcome,
            'mobility': config.entropy_loss_mobility,
            'pressure': config.entropy_loss_pressure,
            'stability': config.entropy_loss_stability,
            'novelty': config.entropy_loss_novelty,
        }

        # RND networks consume the flattened residual-tower output
        # (entropy_channels x 8 x 8 values per position; the 3x3 convs with
        # padding 1 keep the board at 8x8).
        self.feature_dim = config.entropy_channels * 64
        self.rnd_target = self._make_rnd_net(frozen=True).to(device)
        self.rnd_predictor = self._make_rnd_net(frozen=False).to(device)

    def _make_rnd_net(self, frozen: bool) -> nn.Module:
        """Small MLP feature_dim -> 256 -> 128; parameters frozen when ``frozen``."""
        net = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        if frozen:
            for param in net.parameters():
                param.requires_grad = False
        return net

    def compute(self, batch: Dict, game_results: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the total loss.

        ``batch`` must provide the keys ``'energy'``, ``'policy_logits'``,
        ``'legal_counts_self'``, ``'legal_counts_opponent'``,
        ``'energy_next'`` and ``'features'``. Returns ``(total_loss,
        {term_name: float})``.
        """
        losses = {}

        # 1. Outcome loss (sparse).
        losses['outcome'] = self._outcome_loss(batch['energy'], game_results)

        # 2. Mobility loss.
        losses['mobility'] = self._mobility_loss(batch['policy_logits'], batch['legal_counts_self'])

        # 3. Pressure loss.
        losses['pressure'] = self._pressure_loss(batch['legal_counts_self'], batch['legal_counts_opponent'])

        # 4. Stability loss (TD).
        losses['stability'] = self._stability_loss(batch['energy'], batch['energy_next'])

        # 5. Novelty loss (RND).
        losses['novelty'] = self._novelty_loss(batch['features'])

        # Weighted sum.
        total = sum(self.weights[k] * losses[k] for k in losses)

        return total, {k: v.item() for k, v in losses.items()}

    def _outcome_loss(self, energy: torch.Tensor, results: torch.Tensor) -> torch.Tensor:
        """Smooth-L1 between energy and the game result scaled by 5."""
        target = results.float().view(-1, 1) * 5.0  # Scaling.
        return F.smooth_l1_loss(energy, target)

    def _mobility_loss(self, policy_logits: torch.Tensor, legal_counts: torch.Tensor) -> torch.Tensor:
        """Negative policy entropy normalized by log(number of legal moves).

        Minimizing this rewards a spread-out policy. The entropy is taken
        over all policy logits (no legal-move mask is applied here).

        Legal counts are clamped to >= 2 before the log. With a single legal
        move, log(1) = 0 and the previous ``+ 1e-8`` guards divided by about
        2e-8, blowing the term up by roughly 1e8.
        """
        probs = F.softmax(policy_logits, dim=1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)

        max_entropy = torch.log(legal_counts.float().clamp(min=2.0))
        normalized_entropy = entropy / max_entropy

        return -normalized_entropy.mean()

    def _pressure_loss(self, our_legal: torch.Tensor, opp_legal: torch.Tensor) -> torch.Tensor:
        """Mean ratio of opponent mobility to our mobility (+1).

        NOTE(review): this depends only on the two count tensors. When they
        are integer counts (as the trainer passes them), the term is constant
        w.r.t. the model and only shifts the reported loss.
        """
        ratio = opp_legal.float() / (our_legal.float() + 1.0)
        return ratio.mean()

    def _stability_loss(self, energy_now: torch.Tensor, energy_next: torch.Tensor) -> torch.Tensor:
        """TD consistency between a position and its successor.

        ``energy_now`` is evaluated with the mover to play. ``energy_next``,
        as produced by self-play (``energy_after``), is evaluated on the
        position *after* the move, i.e. with the opponent to play. Energy is
        side-to-move relative (the outcome target is perspective-adjusted per
        ply), so in this zero-sum game the bootstrap target is
        ``-gamma * energy_next``. The previous ``+gamma`` target pulled each
        evaluation toward the opponent's and fought the outcome loss.
        """
        gamma = 0.99
        target = -gamma * energy_next.detach()
        return F.mse_loss(energy_now, target)

    def _novelty_loss(self, features: torch.Tensor) -> torch.Tensor:
        """RND error: predictor vs. frozen random target on trunk features."""
        with torch.no_grad():
            target_out = self.rnd_target(features)
        predictor_out = self.rnd_predictor(features)

        error = ((target_out - predictor_out) ** 2).mean(dim=1)
        return error.mean()



In [None]:
%%writefile prometheus_tqfd/entropy/selfplay.py
import time
import torch
import chess
from typing import List, Dict, Tuple
from multiprocessing import Queue, Event
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.entropy.network import EntropyNetworkV2
from prometheus_tqfd.entropy.rollout import MiniRolloutSelector
from prometheus_tqfd.tactics import TacticsDetector
from prometheus_tqfd.encoding import BoardEncoder, MoveEncoder
from prometheus_tqfd.physics import PhysicsFieldCalculator

class EntropySelfPlayWorker:
    """
    Generates self-play games using mini-rollout move selection.

    Each finished game is put on ``data_queue`` as a list of per-move dicts.
    Each dict holds the keys from ``_collect_step_data`` plus ``move_idx``,
    ``energy_before``, ``energy_after``, ``legal_count_opponent`` and
    ``game_result``.
    """

    def __init__(self, config: PrometheusConfig, weights_queue: Queue,
                 data_queue: Queue, device: str, worker_id: int):
        self.config = config
        self.weights_queue = weights_queue
        self.data_queue = data_queue
        self.device = device
        self.worker_id = worker_id

        self.network = EntropyNetworkV2(config).to(device)
        self.network.eval()
        self.tactics = TacticsDetector(config)
        self.selector = MiniRolloutSelector(config, self.network, self.tactics, device)

        self.encoder = BoardEncoder()
        self.field_calc = PhysicsFieldCalculator(config)
        self.move_encoder = MoveEncoder()

        self.temperature = config.entropy_temperature_start
        self.games_played = 0
        self.weights_version = 0

    def run(self, stop_event: Event, heartbeat_dict: Dict):
        """Main loop: refresh weights, play a game, ship it, decay temperature."""
        while not stop_event.is_set():
            heartbeat_dict[f'entropy_selfplay_{self.worker_id}'] = time.time()

            self._maybe_update_weights()
            trajectory = self._play_game()
            self.data_queue.put(trajectory)

            self.games_played += 1
            self._decay_temperature()

    def _play_game(self) -> List[Dict]:
        """Play one full game from the initial position (capped at 400 plies)."""
        trajectory = []
        board = chess.Board()

        while not board.is_game_over() and len(trajectory) < 400:
            # Training data for this move, taken before it is played.
            step_data = self._collect_step_data(board)

            move, energy_before = self.selector.select_move(board, self.temperature)
            if move is None:
                break

            step_data['move_idx'] = self.move_encoder.move_to_index(move)
            step_data['energy_before'] = energy_before

            board.push(move)

            # Energy of the position after the move (the opponent is now to move).
            step_data['energy_after'] = self._get_energy(board)
            step_data['legal_count_opponent'] = board.legal_moves.count()

            trajectory.append(step_data)

        # Label every step with the result from the mover's perspective.
        # Games start from the initial position, so even steps are White's.
        result = self._get_result(board)
        for i, step in enumerate(trajectory):
            perspective = 1 if i % 2 == 0 else -1
            step['game_result'] = result * perspective

        return trajectory

    def _collect_step_data(self, board: chess.Board) -> Dict:
        """Encode ``board`` and record the network outputs stored for training."""
        board_tensor = self.encoder.encode(board)
        field_tensor = self.field_calc.compute(board)

        # Batch-of-one inputs on the inference device, built once for both passes.
        board_in = board_tensor.unsqueeze(0).to(self.device)
        field_in = field_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Trunk features for the RND novelty loss.
            features = self.network.get_features(board_in, field_in).squeeze(0).cpu()

            # Policy logits at collection time (stored alongside the step).
            policy_logits, _ = self.network(board_in, field_in)
            policy_logits = policy_logits.squeeze(0).cpu()

        return {
            'board_tensor': board_tensor,
            'field_tensor': field_tensor,
            'features': features,
            'policy_logits': policy_logits,
            'legal_count_self': board.legal_moves.count(),
        }

    def _get_energy(self, board: chess.Board) -> float:
        """Network energy of ``board`` as a Python float."""
        board_tensor = self.encoder.encode(board).unsqueeze(0).to(self.device)
        field_tensor = self.field_calc.compute(board).unsqueeze(0).to(self.device)

        with torch.no_grad():
            _, energy = self.network(board_tensor, field_tensor)
        return energy.item()

    def _get_result(self, board: chess.Board) -> float:
        """Game outcome from White's perspective: 1.0, -1.0, or 0.0 otherwise."""
        result = board.result()
        if result == "1-0":
            return 1.0
        elif result == "0-1":
            return -1.0
        return 0.0

    def _decay_temperature(self):
        """Multiplicative temperature decay, floored at ``entropy_temperature_end``."""
        self.temperature = max(
            self.config.entropy_temperature_end,
            self.temperature * self.config.entropy_temperature_decay
        )

    def _maybe_update_weights(self):
        """Load the newest ``(weights, version)`` snapshot from ``weights_queue``, if newer.

        The queue is drained with ``get_nowait`` until ``queue.Empty``, so
        only the latest snapshot is considered. A snapshot that fails to load
        is reported and skipped; the worker keeps its current weights rather
        than failing silently.
        """
        import queue

        latest = None
        while True:
            try:
                latest = self.weights_queue.get_nowait()
            except queue.Empty:
                break

        if latest is None:
            return

        weights, version = latest
        if version <= self.weights_version:
            return

        try:
            self.network.load_state_dict(weights)
        except RuntimeError as err:
            print(f"[entropy_selfplay_{self.worker_id}] could not load weights v{version}: {err}")
            return
        self.weights_version = version



In [None]:
%%writefile prometheus_tqfd/entropy/trainer.py
import time
import torch
from multiprocessing import Queue, Event, Lock
from typing import Dict
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.entropy.network import EntropyNetworkV2
from prometheus_tqfd.entropy.loss import EntropyV2Loss
from prometheus_tqfd.utils.replay_buffer import ReplayBuffer

class EntropyTrainer:
    """
    Trains the ENTROPY network (and the RND predictor) on self-play data.

    Per-move step dicts (as produced by the ENTROPY self-play worker) arrive
    on ``data_queue``. Each is stored in the replay buffer as the triple
    ``((board, fields, features, policy_logits),
    (legal_count_self, legal_count_opponent, energy_after), game_result)``.
    """

    def __init__(self, config: PrometheusConfig, data_queue: Queue,
                 weights_queue: Queue, device: str, shared_values: dict):
        self.config = config
        self.data_queue = data_queue
        self.weights_queue = weights_queue
        self.device = device
        self.shared_values = shared_values
        # Use the device *type* so indexed devices such as 'cuda:0' get the
        # same mixed-precision setup as plain 'cuda'.
        self.device_type = torch.device(device).type

        self.network = EntropyNetworkV2(config).to(device)
        self.loss_fn = EntropyV2Loss(config, device)

        # One optimizer for the network and the RND predictor.
        self.optimizer = torch.optim.AdamW(
            list(self.network.parameters()) + list(self.loss_fn.rnd_predictor.parameters()),
            lr=config.entropy_learning_rate
        )
        # Loss scaling is only used together with CUDA autocast.
        self.scaler = torch.cuda.amp.GradScaler() if self.device_type == 'cuda' else None

        self.replay_buffer = ReplayBuffer(config.entropy_replay_size)
        self.global_step = 0
        self.weights_version = 0

    def run(self, stop_event: Event, pause_event: Event,
            gpu_lock: Lock, heartbeat_dict: dict, metrics_queue: Queue):
        """Main trainer loop; runs until ``stop_event`` is set."""
        while not stop_event.is_set():
            heartbeat_dict['entropy_trainer'] = time.time()

            if pause_event.is_set():
                time.sleep(1)
                continue

            self._collect_data()

            if len(self.replay_buffer) < self.config.min_buffer_before_training:
                time.sleep(1)  # Not enough data yet; wait for self-play.
                continue

            # Only the GPU work needs the lock; reporting metrics does not.
            with gpu_lock:
                metrics = self._train_step()
            metrics_queue.put({
                'type': 'entropy_train',
                'step': self.global_step,
                **metrics
            })

            if self.global_step % self.config.weight_publish_interval == 0:
                self._publish_weights()

    def _collect_data(self):
        """Move every trajectory currently queued into the replay buffer.

        Reads with ``get_nowait`` until ``queue.Empty`` (``Queue.empty()`` is
        unreliable for multiprocessing queues). Missing keys or other data
        errors now raise instead of silently discarding whole trajectories.
        """
        import queue

        while True:
            try:
                trajectory = self.data_queue.get_nowait()
            except queue.Empty:
                return
            for step in trajectory:
                self.replay_buffer.add(
                    (step['board_tensor'], step['field_tensor'], step['features'], step['policy_logits']),
                    (step['legal_count_self'], step['legal_count_opponent'], step['energy_after']),
                    step['game_result']
                )

    def _train_step(self) -> dict:
        """Run one optimization step and return total loss, grad norm and per-term losses.

        Entries are sampled directly from ``replay_buffer.buffer`` because
        they hold tuples, which ``ReplayBuffer.sample`` (it ``torch.stack``s
        the stored states) cannot batch. The earlier extra call to
        ``sample`` discarded its result and raised on such entries, so it
        has been removed.
        """
        import random

        self.network.train()
        batch_size = self.config.entropy_batch_size

        samples = random.sample(self.replay_buffer.buffer,
                                min(len(self.replay_buffer), batch_size))
        state_parts, count_parts, result_values = zip(*samples)
        # Stored policy logits come from the (stale) self-play network; the
        # mobility loss uses logits from the current forward pass instead.
        board_list, field_list, feature_list, _stale_logits = zip(*state_parts)
        legal_self_list, legal_opp_list, energy_after_list = zip(*count_parts)

        states = torch.stack(board_list).to(self.device)
        fields = torch.stack(field_list).to(self.device)
        features = torch.stack(feature_list).to(self.device)
        energy_next = torch.tensor(energy_after_list).float().to(self.device).view(-1, 1)
        legal_self = torch.tensor(legal_self_list).to(self.device)
        legal_opp = torch.tensor(legal_opp_list).to(self.device)
        results = torch.tensor(result_values).float().to(self.device)

        self.optimizer.zero_grad()
        # CUDA devices use CUDA autocast; everything else falls back to CPU
        # autocast, as before.
        autocast_type = 'cuda' if self.device_type == 'cuda' else 'cpu'

        with torch.autocast(device_type=autocast_type):
            policy_logits, energy = self.network(states, fields)

            batch_data = {
                'states': states,
                'policy_logits': policy_logits,
                'energy': energy,
                'energy_next': energy_next,
                'legal_counts_self': legal_self,
                'legal_counts_opponent': legal_opp,
                'features': features
            }

            loss, loss_dict = self.loss_fn.compute(batch_data, results)

        if self.scaler:
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.optimizer.step()

        self.global_step += 1
        self.shared_values['entropy_steps'] = self.global_step

        return {
            'loss': loss.item(),
            'grad_norm': grad_norm.item() if grad_norm is not None else 0.0,
            **loss_dict
        }

    def _publish_weights(self):
        """Send a fresh weight snapshot to the self-play workers and mirror it for checkpoints."""
        import queue

        self.weights_version += 1
        # ``Tensor.cpu()`` returns the same tensor when already on the CPU;
        # copy explicitly so the snapshot does not track ongoing training.
        weights = {
            k: v.detach().to('cpu', copy=True)
            for k, v in self.network.state_dict().items()
        }

        # Drop outdated snapshots and enqueue the new one.
        while True:
            try:
                self.weights_queue.get_nowait()
            except queue.Empty:
                break
        self.weights_queue.put((weights, self.weights_version))

        self.shared_values['entropy_weights'] = weights
        self.shared_values['entropy_optimizer'] = self.optimizer.state_dict()
        self.shared_values['entropy_version'] = self.weights_version

import random



In [None]:
%%writefile prometheus_tqfd/evaluation/__init__.py



In [None]:
%%writefile prometheus_tqfd/evaluation/arena.py
import chess
import torch
import random
from typing import Dict, Tuple, List, Optional
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.atlas.network import AtlasNetwork
from prometheus_tqfd.atlas.mcts import MCTS
from prometheus_tqfd.entropy.network import EntropyNetworkV2
from prometheus_tqfd.entropy.rollout import MiniRolloutSelector
from prometheus_tqfd.tactics import TacticsDetector
from prometheus_tqfd.evaluation.baselines import RandomPlayer, HeuristicPlayer

class AtlasArenaPlayer:
    """Arena adapter for ATLAS: picks a move from a reduced MCTS search."""

    # Fewer simulations than self-play, to keep evaluation games fast.
    EVAL_SIMULATIONS = 50

    def __init__(self, config, network, device):
        self.mcts = MCTS(config, network, device)

    def select_move(self, board):
        """Search ``board`` and return the MCTS move chosen with temperature 0."""
        search_root = self.mcts.search(board, num_simulations=self.EVAL_SIMULATIONS)
        chosen = self.mcts.select_move(search_root, temperature=0)
        return chosen

class EntropyArenaPlayer:
    """Arena adapter for ENTROPY: rollout-based move choice at temperature 0."""

    def __init__(self, config, network, device):
        self.tactics = TacticsDetector(config)
        self.selector = MiniRolloutSelector(config, network, self.tactics, device)

    def select_move(self, board):
        """Return the selector's move for ``board``; the accompanying score is dropped."""
        # temperature=0 makes the selector take the argmax of its move scores.
        selection = self.selector.select_move(board, temperature=0)
        return selection[0]

class Arena:
    """
    Runs duels between players and maintains Elo ratings.

    ``random`` (400) and ``heuristic`` (800) act as fixed anchors: matches
    against them update the learner's rating but never theirs.
    """

    def __init__(self, config: PrometheusConfig):
        self.config = config
        start_rating = config.elo_initial
        self.elo_ratings = {
            'atlas': start_rating,
            'entropy': start_rating,
            'random': 400.0,
            'heuristic': 800.0,
        }
        self.match_history = []

    def run_evaluation(self, atlas_network: AtlasNetwork, 
                       entropy_network: EntropyNetworkV2,
                       device: str) -> Dict:
        """Run one full evaluation round.

        Plays ATLAS vs ENTROPY, then each of them against the random and
        heuristic baselines, updating Elo after every match. Both networks are
        switched to eval mode. Returns per-match tallies plus a snapshot of all
        ratings under ``'elo'``.
        """
        for net in (atlas_network, entropy_network):
            net.eval()

        atlas = AtlasArenaPlayer(self.config, atlas_network, device)
        entropy = EntropyArenaPlayer(self.config, entropy_network, device)
        baselines = (
            ('random', RandomPlayer(), 'eval_games_vs_random'),
            ('heuristic', HeuristicPlayer(), 'eval_games_vs_heuristic'),
        )

        results = {}

        # Head-to-head: both ratings move.
        a_wins, e_wins, n_draws = self._play_match(atlas, entropy, self.config.eval_games_atlas_entropy)
        results['atlas_vs_entropy'] = {'atlas': a_wins, 'entropy': e_wins, 'draws': n_draws}
        self._update_elo('atlas', 'entropy', a_wins, e_wins, n_draws)

        # Baselines: only the learner's rating moves.
        for learner_name, learner in (('atlas', atlas), ('entropy', entropy)):
            for baseline_name, baseline, games_attr in baselines:
                won, lost, drawn = self._play_match(learner, baseline, getattr(self.config, games_attr))
                results[f'{learner_name}_vs_{baseline_name}'] = {'wins': won, 'losses': lost, 'draws': drawn}
                self._update_elo(learner_name, baseline_name, won, lost, drawn, update_b=False)

        results['elo'] = self.elo_ratings.copy()
        return results

    def _play_match(self, p1, p2, num_games: int) -> Tuple[int, int, int]:
        """Play *num_games* games with alternating colours (p1 is White in even games).

        Returns ``(p1 wins, p2 wins, draws)``.
        """
        tally = [0, 0, 0]  # p1 wins, p2 wins, draws
        for game_idx in range(num_games):
            if game_idx % 2 == 0:
                outcome = self._play_game(p1, p2)
            else:
                # Flip the sign so the outcome is always from p1's point of view.
                outcome = -self._play_game(p2, p1)
            if outcome == 1.0:
                tally[0] += 1
            elif outcome == -1.0:
                tally[1] += 1
            else:
                tally[2] += 1
        return tuple(tally)

    def _play_game(self, white_player, black_player, max_moves: int = 200) -> float:
        """Play a single game and return +1.0 (White wins), -1.0 (Black wins) or 0.0.

        Play stops when the game is over or once ``fullmove_number`` exceeds
        *max_moves*; an unfinished game (result ``"*"``) scores 0.0. A player
        returning ``None`` or an illegal move is replaced by a random legal move.
        """
        seats = {chess.WHITE: white_player, chess.BLACK: black_player}
        board = chess.Board()
        while not board.is_game_over() and board.fullmove_number <= max_moves:
            move = seats[board.turn].select_move(board)
            acceptable = move is not None and move in board.legal_moves
            if not acceptable:
                move = random.choice(list(board.legal_moves))
            board.push(move)

        return {"1-0": 1.0, "0-1": -1.0}.get(board.result(), 0.0)

    def _update_elo(self, p_a: str, p_b: str, wins_a: int, wins_b: int, draws: int, update_b: bool = True):
        """Apply a single Elo update for a whole match.

        *p_a*'s actual score is its average match score (wins + half the draws,
        divided by games), compared against the expected score from the current
        ratings and scaled by ``config.elo_k_factor``. *p_b* gets the mirrored
        update only when *update_b* is true. An empty match changes nothing.
        """
        games = wins_a + wins_b + draws
        if games == 0:
            return

        rating_a = self.elo_ratings[p_a]
        rating_b = self.elo_ratings[p_b]

        expected_a = 1 / (1 + 10 ** ((rating_b - rating_a) / 400))
        score_a = (wins_a + 0.5 * draws) / games

        k = self.config.elo_k_factor
        self.elo_ratings[p_a] = rating_a + k * (score_a - expected_a)
        if not update_b:
            return
        expected_b = 1 - expected_a
        score_b = 1 - score_a
        self.elo_ratings[p_b] = rating_b + k * (score_b - expected_b)



In [None]:
%%writefile prometheus_tqfd/evaluation/baselines.py
import random
import chess

class RandomPlayer:
    """Baseline player that picks uniformly at random among the legal moves."""

    def select_move(self, board: chess.Board) -> chess.Move:
        moves = list(board.legal_moves)
        return random.choice(moves)

class HeuristicPlayer:
    """
    Simple rule-based baseline. It plays the greedy one-ply move under a static
    evaluation (centipawns, positive = good for White) made of:
    - material
    - control of the four centre squares
    - mobility (legal-move count of each side)
    With probability 0.1 it plays a uniformly random legal move instead.
    """
    
    PIECE_VALUES = {
        chess.PAWN: 100, chess.KNIGHT: 320, chess.BISHOP: 330,
        chess.ROOK: 500, chess.QUEEN: 900, chess.KING: 0
    }
    
    CENTER_SQUARES = [chess.D4, chess.D5, chess.E4, chess.E5]
    
    def select_move(self, board: chess.Board) -> chess.Move:
        """Return the chosen move, or ``None`` if there are no legal moves.

        The board is temporarily pushed/popped while scoring candidates.
        """
        candidates = list(board.legal_moves)
        if not candidates:
            return None

        if random.random() < 0.1:  # exploration
            return random.choice(candidates)

        # _evaluate scores from White's point of view; the mover wants that
        # maximised as White and minimised as Black.
        sign = 1 if board.turn == chess.WHITE else -1

        def score_for_mover(move):
            board.push(move)
            value = self._evaluate(board)
            board.pop()
            return sign * value

        # max() keeps the first of equally scored moves.
        best = max(candidates, key=score_for_mover)
        return best or random.choice(candidates)
    
    def _evaluate(self, board: chess.Board) -> float:
        """Static evaluation from White's perspective (positive favours White)."""
        if board.is_checkmate():
            # The side to move has been mated.
            return 10000 if board.turn == chess.BLACK else -10000
        if board.is_stalemate() or board.is_insufficient_material():
            return 0

        material = sum(
            self.PIECE_VALUES[piece.piece_type] * (1 if piece.color == chess.WHITE else -1)
            for piece in board.piece_map().values()
        )

        centre = sum(
            10 * (board.is_attacked_by(chess.WHITE, sq) - board.is_attacked_by(chess.BLACK, sq))
            for sq in self.CENTER_SQUARES
        )

        # Mobility: count each side's legal moves by temporarily handing it the
        # move; the original side to move is restored afterwards.
        saved_turn = board.turn
        board.turn = chess.WHITE
        white_moves = len(list(board.legal_moves))
        board.turn = chess.BLACK
        black_moves = len(list(board.legal_moves))
        board.turn = saved_turn

        return material + centre + 2 * white_moves - 2 * black_moves



In [None]:
%%writefile prometheus_tqfd/orchestration/__init__.py



In [None]:
%%writefile prometheus_tqfd/orchestration/checkpoint.py
import torch
import json
import time
import random
import numpy as np
import pickle
import lz4.frame
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict
from prometheus_tqfd.config import PrometheusConfig

class CheckpointManager:
    """
    Tiered checkpoint system with micro, light and full tiers. Their intervals
    (in minutes) come from ``config.checkpoint_{micro,light,full}_interval``.

    Every tier writes into ``<base_dir>/<run_id>/checkpoints/latest``:

    * micro: network weights plus ``metadata.json`` (steps, versions, timestamp)
    * light: micro + optimizer states + Python/NumPy/torch (and CUDA) RNG states
    * full:  light + lz4-compressed pickles of the replay buffers. ``latest`` is
      then copied to ``full_<timestamp>`` and, when enabled, mirrored to Drive.
    """

    # Google Drive root as mounted in Colab (the location this project mirrors to).
    DRIVE_ROOT = Path('/content/drive/MyDrive')

    def __init__(self, config: PrometheusConfig):
        self.config = config
        self.base_path = config.base_dir / config.run_id / 'checkpoints'
        self.base_path.mkdir(parents=True, exist_ok=True)

        # Only mirror when Drive is actually mounted. Creating folders under an
        # unmounted /content/drive would silently "mirror" to local disk and
        # leave files in the mount point, which can make a later drive.mount() fail.
        self.drive_path = None
        if config.use_drive:
            if self.DRIVE_ROOT.is_dir():
                self.drive_path = self.DRIVE_ROOT / 'prometheus_chess' / config.run_id / 'checkpoints'
                self.drive_path.mkdir(parents=True, exist_ok=True)
            else:
                print(f"⚠️ Google Drive not mounted at {self.DRIVE_ROOT}; Drive mirroring disabled")

        now = time.time()
        self.last_micro = now
        self.last_light = now
        self.last_full = now

    def maybe_checkpoint(self, shared_values: dict, replay_buffers: dict = None):
        """Save at most one tier, the highest one that is due.

        A higher tier contains the lower tiers' data, so saving it also resets
        their timers.
        """
        now = time.time()

        if now - self.last_full > self.config.checkpoint_full_interval * 60:
            self.save_full(shared_values, replay_buffers)
            self.last_full = now
            self.last_light = now
            self.last_micro = now
        elif now - self.last_light > self.config.checkpoint_light_interval * 60:
            self.save_light(shared_values)
            self.last_light = now
            self.last_micro = now
        elif now - self.last_micro > self.config.checkpoint_micro_interval * 60:
            self.save_micro(shared_values)
            self.last_micro = now

    @staticmethod
    def _set_metadata_type(path: Path, tier: str):
        """Rewrite the ``type`` field of ``path/metadata.json`` in place."""
        meta_file = path / 'metadata.json'
        with open(meta_file, 'r') as f:
            metadata = json.load(f)
        metadata['type'] = tier
        with open(meta_file, 'w') as f:
            json.dump(metadata, f)

    def save_micro(self, shared_values: dict):
        """Write the network weights present in *shared_values* and fresh metadata."""
        path = self.base_path / 'latest'
        path.mkdir(exist_ok=True)

        if 'atlas_weights' in shared_values:
            torch.save(shared_values['atlas_weights'], path / 'atlas_weights.pt')
        if 'entropy_weights' in shared_values:
            torch.save(shared_values['entropy_weights'], path / 'entropy_weights.pt')

        metadata = {
            'type': 'micro',
            'timestamp': datetime.now().isoformat(),
            'atlas_steps': shared_values.get('atlas_steps', 0),
            'entropy_steps': shared_values.get('entropy_steps', 0),
            'atlas_version': shared_values.get('atlas_version', 0),
            'entropy_version': shared_values.get('entropy_version', 0),
        }
        with open(path / 'metadata.json', 'w') as f:
            json.dump(metadata, f)

    def save_light(self, shared_values: dict):
        """Micro checkpoint plus optimizer states and RNG states of this process."""
        self.save_micro(shared_values)
        path = self.base_path / 'latest'

        if 'atlas_optimizer' in shared_values:
            torch.save(shared_values['atlas_optimizer'], path / 'atlas_optimizer.pt')
        if 'entropy_optimizer' in shared_values:
            torch.save(shared_values['entropy_optimizer'], path / 'entropy_optimizer.pt')

        rng_states = {
            'python': random.getstate(),
            'numpy': np.random.get_state(),
            'torch': torch.get_rng_state(),
        }
        if torch.cuda.is_available():
            rng_states['cuda'] = torch.cuda.get_rng_state()
        torch.save(rng_states, path / 'rng_states.pt')

        self._set_metadata_type(path, 'light')

    def save_full(self, shared_values: dict, replay_buffers: dict = None):
        """Light checkpoint plus replay buffers, then archive and mirror to Drive."""
        import shutil

        self.save_light(shared_values)
        path = self.base_path / 'latest'

        if replay_buffers:
            for name, buffer in replay_buffers.items():
                with lz4.frame.open(path / f'{name}_replay.lz4', 'wb') as f:
                    pickle.dump(buffer.get_data(), f)

        self._set_metadata_type(path, 'full')

        # Archive current 'latest' to a timestamped folder. dirs_exist_ok avoids
        # a crash when two full saves land in the same second (e.g. a periodic
        # save immediately followed by the final shutdown save).
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        archive_path = self.base_path / f'full_{timestamp}'
        shutil.copytree(path, archive_path, dirs_exist_ok=True)

        if self.drive_path:
            drive_latest = self.drive_path / 'latest'
            if drive_latest.exists():
                shutil.rmtree(drive_latest)
            shutil.copytree(path, drive_latest)
            print(f"💾 Checkpoint mirrored to Google Drive")

    def load_latest(self) -> Optional[dict]:
        """Load the ``latest`` checkpoint, first restoring it from Drive if only Drive has it.

        Returns a dict with whichever of ``atlas_weights``, ``entropy_weights``,
        ``atlas_optimizer``, ``entropy_optimizer``, ``rng_states`` and
        ``metadata`` exist (tensors mapped to CPU). Returns ``None`` if there is
        no checkpoint or loading fails.
        """
        path = self.base_path / 'latest'
        if not path.exists() and self.drive_path:
            drive_latest = self.drive_path / 'latest'
            if drive_latest.exists():
                import shutil
                shutil.copytree(drive_latest, path)
                print("📂 Restored checkpoint from Google Drive")

        if not path.exists():
            return None

        data = {}
        try:
            for key in ('atlas_weights', 'entropy_weights', 'atlas_optimizer', 'entropy_optimizer'):
                file = path / f'{key}.pt'
                if file.exists():
                    # Optimizer states may hold CUDA tensors. Loading to CPU also works
                    # on CPU-only hosts, and Optimizer.load_state_dict moves the state
                    # back onto the parameters' device.
                    data[key] = torch.load(file, map_location='cpu')
            rng_file = path / 'rng_states.pt'
            if rng_file.exists():
                # Holds Python tuples and NumPy arrays, which torch's weights-only
                # unpickler (the torch.load default since 2.6) rejects. Without this,
                # the broad except below would throw away the whole checkpoint.
                # These files are written by this class; never load untrusted ones.
                data['rng_states'] = torch.load(rng_file, map_location='cpu', weights_only=False)
            meta_file = path / 'metadata.json'
            if meta_file.exists():
                with open(meta_file, 'r') as f:
                    data['metadata'] = json.load(f)
            return data
        except Exception as e:
            print(f"⚠️ Error loading checkpoint: {e}")
            return None

    def load_replay_buffer(self, name: str) -> Optional[list]:
        """Return the pickled replay data saved for *name*, or ``None`` if absent.

        Uses pickle: only load files produced by ``save_full``.
        """
        path = self.base_path / 'latest' / f'{name}_replay.lz4'
        if path.exists():
            with lz4.frame.open(path, 'rb') as f:
                return pickle.load(f)
        return None



In [None]:
%%writefile prometheus_tqfd/orchestration/recovery.py
import torch
import gc
import time
from multiprocessing import Event
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.orchestration.checkpoint import CheckpointManager

class OOMHandler:
    """
    Handles out-of-memory situations.

    On an OOM it pauses the system via ``pause_event``, frees memory, shrinks
    the batch sizes it tracks, reloads the latest checkpoint into the shared
    state and resumes.
    """

    def __init__(self, config: PrometheusConfig):
        self.config = config
        self.current_batch_size_atlas = config.atlas_batch_size
        self.current_batch_size_entropy = config.entropy_batch_size
        self.oom_count = 0

    def _shrink(self, batch_size: int) -> int:
        """Scale *batch_size* by ``oom_batch_reduction``, never below ``oom_min_batch_size``."""
        return max(
            self.config.oom_min_batch_size,
            int(batch_size * self.config.oom_batch_reduction)
        )

    def handle_oom(self, process_name: str, checkpoint_manager: CheckpointManager,
                   shared_values: dict, pause_event: Event):
        """Recover from an OOM reported by *process_name*.

        ``pause_event`` is always cleared again, even if recovery itself fails,
        so workers are never left paused indefinitely. Exceptions from recovery
        propagate after that.
        """
        self.oom_count += 1
        print(f"⚠️ OOM #{self.oom_count} in {process_name}")

        # 1. Pause
        pause_event.set()
        try:
            # 2. Free memory: collect unreachable Python objects first so the
            # tensors they hold go back to PyTorch's caching allocator, then
            # release the cached blocks to the driver.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # 3. Shrink batch sizes
            self.current_batch_size_atlas = self._shrink(self.current_batch_size_atlas)
            self.current_batch_size_entropy = self._shrink(self.current_batch_size_entropy)
            print(f"   Reduced batch sizes: ATLAS={self.current_batch_size_atlas}, ENTROPY={self.current_batch_size_entropy}")

            # 4. Give other processes a moment to react to the pause
            time.sleep(5)

            # 5. Load the latest checkpoint into the shared state
            checkpoint = checkpoint_manager.load_latest()
            if checkpoint:
                shared_values.update(checkpoint)
        finally:
            # 6. Resume
            pause_event.clear()
        print("✅ System resumed after OOM.")



In [None]:
%%writefile prometheus_tqfd/orchestration/supervisor.py
import time
import os
import signal
import multiprocessing as mp
from typing import Dict, List
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.orchestration.checkpoint import CheckpointManager
from prometheus_tqfd.orchestration.recovery import OOMHandler
from prometheus_tqfd.evaluation.arena import Arena

class Supervisor:
    """
    Main orchestrator of the system.

    Starts the trainer and self-play processes and restarts any whose
    heartbeat goes stale. It drains the metrics queue to disk, writes tiered
    checkpoints and periodically runs an arena evaluation.

    The bound ``_run_*`` methods are the process targets. The notebook forces
    the ``spawn`` start method, so every target, and therefore ``self``, is
    pickled into the child. ``__getstate__`` removes the members that cannot
    be pickled, and lambdas are avoided entirely.
    """

    def __init__(self, config: PrometheusConfig):
        self.config = config
        self.device = 'cuda' if torch_is_cuda() else 'cpu'

        # Shared state (Manager proxies can be passed to child processes)
        self.manager = mp.Manager()
        self.shared_values = self.manager.dict()
        self.heartbeats = self.manager.dict()

        # Events
        self.stop_event = mp.Event()
        self.pause_event = mp.Event()

        # Locks
        self.gpu_lock = mp.Lock()

        # Queues
        self.atlas_data_queue = mp.Queue(maxsize=100)
        self.atlas_weights_queue = mp.Queue(maxsize=1)
        self.entropy_data_queue = mp.Queue(maxsize=100)
        self.entropy_weights_queue = mp.Queue(maxsize=1)
        self.metrics_queue = mp.Queue(maxsize=10000)

        # Components
        self.checkpoint_manager = CheckpointManager(config)
        # NOTE(review): each child gets its own pickled copy of this handler, so
        # batch-size reductions made during OOM recovery in a child are not
        # visible to the parent or to other processes.
        self.oom_handler = OOMHandler(config)
        self.arena = Arena(config)
        from prometheus_tqfd.utils.logging import MetricsLogger
        self.metrics_logger = MetricsLogger(config)

        # Processes: name -> mp.Process
        self.processes = {}
        self._stopped = False

    def __getstate__(self):
        """Pickle support for ``spawn`` targets.

        The Manager handle and the started ``Process`` objects wrap running
        processes and are not picklable. Children never use them, so they are
        dropped from the pickled copy.
        """
        state = self.__dict__.copy()
        state.pop('manager', None)
        state['processes'] = {}
        return state

    def _target_for(self, name: str):
        """Return a picklable process target for process *name*, or ``None`` if unknown."""
        from functools import partial

        if 'atlas_trainer' in name:
            return self._run_atlas_trainer
        if 'entropy_trainer' in name:
            return self._run_entropy_trainer
        # functools.partial over a bound method pickles by reference, whereas a
        # lambda cannot be pickled at all under spawn.
        if 'atlas_selfplay' in name:
            return partial(self._run_atlas_selfplay, int(name.split('_')[-1]))
        if 'entropy_selfplay' in name:
            return partial(self._run_entropy_selfplay, int(name.split('_')[-1]))
        return None

    def start(self):
        """Start both trainers and all self-play workers."""
        names = ['atlas_trainer', 'entropy_trainer']
        names += [f'atlas_selfplay_{i}' for i in range(self.config.num_atlas_selfplay_workers)]
        names += [f'entropy_selfplay_{i}' for i in range(self.config.num_entropy_selfplay_workers)]
        for name in names:
            self._start_process(name, self._target_for(name))

    def run(self):
        """Main loop: start everything, then poll every 5 s until stopped or interrupted."""
        self.start()
        last_eval = 0

        try:
            while not self.stop_event.is_set():
                # 1. Restart processes with stale heartbeats
                self._check_heartbeats()

                # 2. Drain the metrics queue to disk
                self._collect_metrics()

                # 3. Checkpoints
                self.checkpoint_manager.maybe_checkpoint(dict(self.shared_values))

                # 4. Evaluation, triggered by the combined trainer step count
                atlas_steps = self.shared_values.get('atlas_steps', 0)
                entropy_steps = self.shared_values.get('entropy_steps', 0)
                total_steps = atlas_steps + entropy_steps

                if total_steps - last_eval >= self.config.eval_interval_games:
                    self._run_evaluation()
                    last_eval = total_steps

                time.sleep(5)
        except KeyboardInterrupt:
            print("\n🛑 Supervisor: Shutdown requested...")
        finally:
            self.stop()

    def stop(self):
        """Signal stop, save a final full checkpoint and terminate every process.

        Processes are terminated even if the final checkpoint fails. Calling
        ``stop`` more than once is a no-op after the first call.
        """
        if self._stopped:
            return
        self._stopped = True
        self.stop_event.set()
        try:
            print("💾 Saving final checkpoint...")
            self.checkpoint_manager.save_full(dict(self.shared_values))
        finally:
            print("🧹 Terminating processes...")
            for proc in self.processes.values():
                if proc.is_alive():
                    proc.terminate()
                    proc.join(timeout=5)
            # Release the Manager's server process; nothing uses the proxies now.
            self.manager.shutdown()
            print("✅ Supervisor: Shutdown complete.")

    def _start_process(self, name: str, target_fn):
        p = mp.Process(target=target_fn, name=name, daemon=True)
        p.start()
        self.processes[name] = p
        print(f"🚀 Started {name} (PID: {p.pid})")

    def _check_heartbeats(self):
        """Restart every process whose last heartbeat is older than ``heartbeat_timeout``."""
        now = time.time()
        for name, last_beat in dict(self.heartbeats).items():
            if now - last_beat > self.config.heartbeat_timeout:
                print(f"⚠️ {name} timed out, restarting...")
                self._restart_process(name)
                # Give the replacement a full timeout window to start beating;
                # otherwise the stale timestamp triggers another restart on
                # every poll while the new process is still starting up.
                self.heartbeats[name] = time.time()

    def _collect_metrics(self):
        """Drain the metrics queue and write each record to the metrics file."""
        import queue

        while True:
            try:
                m = self.metrics_queue.get_nowait()
            except queue.Empty:
                break
            try:
                self.metrics_logger.log(m)
            except (TypeError, ValueError, OSError) as e:
                # One bad record (e.g. not JSON-serializable) must not stall the drain.
                print(f"⚠️ Failed to log metrics record: {e}")

    def _run_evaluation(self):
        """Evaluate the latest published weights in the arena (on CPU, in this process)."""
        print("⚔️ Running Arena Evaluation...")
        # Fresh CPU networks loaded from the weights the trainers published to
        # shared_values; randomly initialised networks are used until then.
        from prometheus_tqfd.atlas.network import AtlasNetwork
        from prometheus_tqfd.entropy.network import EntropyNetworkV2

        atlas_net = AtlasNetwork(self.config).to('cpu')
        entropy_net = EntropyNetworkV2(self.config).to('cpu')

        if 'atlas_weights' in self.shared_values:
            atlas_net.load_state_dict(self.shared_values['atlas_weights'])
        if 'entropy_weights' in self.shared_values:
            entropy_net.load_state_dict(self.shared_values['entropy_weights'])

        results = self.arena.run_evaluation(atlas_net, entropy_net, 'cpu')
        results['type'] = 'evaluation'
        self.metrics_logger.log(results)

        # Publish ELO ratings to the shared state
        for k, v in results['elo'].items():
            self.shared_values[f'elo_{k}'] = v
        print(f"📊 New ELOs: {results['elo']}")

    def _restart_process(self, name: str):
        """Terminate process *name* if it is still alive and start a fresh one under the same name."""
        if name in self.processes:
            p = self.processes[name]
            if p.is_alive():
                p.terminate()
                p.join(timeout=2)

        target = self._target_for(name)
        if target is not None:
            self._start_process(name, target)

    # Runner functions that build the components inside the child processes
    def _run_atlas_trainer(self):
        from prometheus_tqfd.atlas.trainer import AtlasTrainer
        trainer = AtlasTrainer(self.config, self.atlas_data_queue, self.atlas_weights_queue, self.device, self.shared_values)
        try:
            trainer.run(self.stop_event, self.pause_event, self.gpu_lock, self.heartbeats, self.metrics_queue)
        except Exception as e:
            if "out of memory" in str(e).lower():
                self.oom_handler.handle_oom('atlas_trainer', self.checkpoint_manager, self.shared_values, self.pause_event)
            else:
                # Surface the traceback instead of exiting silently.
                raise

    def _run_entropy_trainer(self):
        from prometheus_tqfd.entropy.trainer import EntropyTrainer
        trainer = EntropyTrainer(self.config, self.entropy_data_queue, self.entropy_weights_queue, self.device, self.shared_values)
        try:
            trainer.run(self.stop_event, self.pause_event, self.gpu_lock, self.heartbeats, self.metrics_queue)
        except Exception as e:
            if "out of memory" in str(e).lower():
                self.oom_handler.handle_oom('entropy_trainer', self.checkpoint_manager, self.shared_values, self.pause_event)
            else:
                raise

    def _run_atlas_selfplay(self, i: int):
        from prometheus_tqfd.atlas.selfplay import AtlasSelfPlayWorker
        # Self-play runs on CPU, leaving GPU memory to the trainers
        worker = AtlasSelfPlayWorker(self.config, self.atlas_weights_queue, self.atlas_data_queue, 'cpu', i)
        worker.run(self.stop_event, self.heartbeats)

    def _run_entropy_selfplay(self, i: int):
        from prometheus_tqfd.entropy.selfplay import EntropySelfPlayWorker
        worker = EntropySelfPlayWorker(self.config, self.entropy_weights_queue, self.entropy_data_queue, 'cpu', i)
        worker.run(self.stop_event, self.heartbeats)

def torch_is_cuda():
    """Return whether PyTorch reports an available CUDA device (torch is imported locally)."""
    from torch import cuda
    return cuda.is_available()



In [None]:
%%writefile prometheus_tqfd/utils/__init__.py



In [None]:
%%writefile prometheus_tqfd/utils/hardware.py
import torch
import psutil
from prometheus_tqfd.config import HardwareConfig

def detect_hardware() -> HardwareConfig:
    """
    Detect the available hardware and describe it as a ``HardwareConfig``.

    Reports the torch device (``'cuda'`` or ``'cpu'``), the GPU name and total
    VRAM in GiB (both ``None`` without CUDA), total system RAM in GiB, the CPU
    count from ``psutil.cpu_count()`` (which may be ``None`` if it cannot be
    determined), and whether ``google.colab`` is importable.
    """
    gib = 1024 ** 3
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if device == 'cuda':
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / gib
    else:
        gpu_name = vram_gb = None

    ram_gb = psutil.virtual_memory().total / gib
    cpu_cores = psutil.cpu_count()

    # Colab is detected by whether its package can be imported.
    try:
        import google.colab  # noqa: F401
    except ImportError:
        is_colab = False
    else:
        is_colab = True

    return HardwareConfig(
        device=device,
        gpu_name=gpu_name,
        vram_gb=vram_gb,
        ram_gb=ram_gb,
        cpu_cores=cpu_cores,
        is_colab=is_colab,
    )



In [None]:
%%writefile prometheus_tqfd/utils/logging.py
import json
import time
from pathlib import Path

class MetricsLogger:
    """Append-only JSON-lines sink at ``<base_dir>/<run_id>/metrics/metrics.jsonl``."""

    def __init__(self, config):
        self.config = config
        self.log_file = config.base_dir / config.run_id / 'metrics' / 'metrics.jsonl'
        self.log_file.parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _json_default(obj):
        """Serialize objects exposing ``tolist()`` (NumPy scalars/arrays, torch tensors)."""
        tolist = getattr(obj, 'tolist', None)
        if callable(tolist):
            return tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def log(self, metrics):
        """Append *metrics* as one JSON line with ``timestamp`` set to ``time.time()``.

        The caller's dict is left unmodified; an existing ``timestamp`` key is
        overwritten in the written record. Raises ``TypeError`` for values that
        are neither JSON-native nor convertible via ``tolist()``.
        """
        record = {**metrics, 'timestamp': time.time()}
        line = json.dumps(record, default=self._json_default)
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(line + '\n')



In [None]:
%%writefile prometheus_tqfd/utils/replay_buffer.py
import random
import torch
import numpy as np
from collections import deque

class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(state, policy, value)`` training samples.

    Once ``capacity`` samples are held, each new sample evicts the oldest one.
    """

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, policy, value):
        """Append one sample. *state* and *policy* must be tensors (``sample`` stacks them)."""
        self.buffer.append((state, policy, value))

    def sample(self, batch_size: int):
        """Draw ``min(len(self), batch_size)`` distinct samples uniformly at random.

        Returns ``(states, policies, values)``: the stacked state and policy
        tensors plus a float32 tensor of values.

        Raises:
            ValueError: if the buffer is empty or *batch_size* < 1 (previously
                this surfaced as an obscure tuple-unpacking error).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        batch = random.sample(self.buffer, min(len(self.buffer), batch_size))
        states, policies, values = zip(*batch)

        return (
            torch.stack(states),
            torch.stack(policies),
            torch.tensor(values, dtype=torch.float32)
        )

    def __len__(self):
        return len(self.buffer)

    def get_data(self):
        """Return a list snapshot of all samples, oldest first (used for checkpointing)."""
        return list(self.buffer)

    def set_data(self, data):
        """Append *data* to the buffer (it does not clear first); capacity still applies."""
        self.buffer.extend(data)



In [None]:
%%writefile prometheus_tqfd/utils/tunneling.py
import os
import subprocess
import time
import re
from typing import Optional, Tuple

class TunnelManager:
    """
    Manages public access to the dashboard.
    Priority: ngrok → cloudflared → localtunnel

    Every provider is best effort: any failure falls through to the next one.
    """

    @staticmethod
    def start(port: int = 8501) -> Tuple[Optional[str], str]:
        """Open a public tunnel to ``localhost:port``.

        Returns ``(public_url, provider)``, or ``(None, 'none')`` if every
        provider fails.
        """
        attempts = (
            ('ngrok', TunnelManager._try_ngrok),
            ('cloudflared', TunnelManager._try_cloudflared),
            ('localtunnel', TunnelManager._try_localtunnel),
        )
        for provider, attempt in attempts:
            url = attempt(port)
            if url:
                return url, provider
        return None, 'none'

    @staticmethod
    def _extract_cloudflared_url(line: str) -> Optional[str]:
        """Return the ``https://*.trycloudflare.com`` URL in *line*, if any."""
        if 'trycloudflare.com' not in line:
            return None
        match = re.search(r'https://[^\s]+\.trycloudflare\.com', line)
        return match.group(0) if match else None

    @staticmethod
    def _extract_localtunnel_url(line: str) -> Optional[str]:
        """Return the URL from localtunnel's ``your url is: <url>`` line, if *line* is that line."""
        if 'your url is:' in line.lower():
            return line.split()[-1]
        return None

    @staticmethod
    def _watch_stream(stream, extract, timeout: float) -> Optional[str]:
        """Return the first truthy ``extract(line)`` read from *stream* within *timeout* seconds.

        ``readline`` blocks, so lines are read on a daemon thread and the caller
        waits with a real timeout. The thread keeps draining *stream* after a
        match, so the child process never stalls writing into a full pipe.
        """
        import queue
        import threading

        found = queue.Queue(maxsize=1)

        def pump():
            matched = False
            for line in iter(stream.readline, ''):
                if not matched:
                    value = extract(line)
                    if value:
                        matched = True
                        found.put(value)

        threading.Thread(target=pump, daemon=True).start()
        try:
            return found.get(timeout=timeout)
        except queue.Empty:
            return None

    @staticmethod
    def _terminate(proc) -> None:
        """Best-effort stop of a helper process that produced no URL."""
        if proc is None:
            return
        try:
            proc.terminate()
        except OSError:
            pass

    @staticmethod
    def _try_ngrok(port: int) -> Optional[str]:
        try:
            from pyngrok import ngrok

            # Token lookup: environment first, then Colab secrets.
            token = os.environ.get('NGROK_TOKEN')
            if not token:
                try:
                    from google.colab import userdata
                    token = userdata.get('NGROK_TOKEN')
                except Exception:
                    # Not on Colab, or the secret is missing / access not granted.
                    pass

            if token:
                ngrok.set_auth_token(token)

            tunnel = ngrok.connect(port, "http")
            print(f"✅ ngrok Tunnel: {tunnel.public_url}")
            return tunnel.public_url
        except Exception:
            return None

    @staticmethod
    def _try_cloudflared(port: int) -> Optional[str]:
        proc = None
        try:
            if not os.path.exists('cloudflared'):
                try:
                    subprocess.run([
                        'wget', '-q', 
                        'https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64',
                        '-O', 'cloudflared'
                    ], check=True)
                except subprocess.CalledProcessError:
                    # `wget -O` can leave an empty/partial file behind; remove it so
                    # the next attempt downloads again instead of exec'ing junk.
                    if os.path.exists('cloudflared'):
                        os.remove('cloudflared')
                    raise
                subprocess.run(['chmod', '+x', 'cloudflared'], check=True)

            # The URL is read from stderr; stdout was never read, so discard it
            # rather than let an unread pipe fill up.
            proc = subprocess.Popen(
                ['./cloudflared', 'tunnel', '--url', f'http://localhost:{port}'],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                text=True
            )
            url = TunnelManager._watch_stream(
                proc.stderr, TunnelManager._extract_cloudflared_url, timeout=15)
            if url:
                print(f"✅ cloudflared Tunnel: {url}")
                return url
        except Exception:
            pass
        TunnelManager._terminate(proc)
        return None

    @staticmethod
    def _try_localtunnel(port: int) -> Optional[str]:
        proc = None
        try:
            # Requires npx / node. The URL is printed on stdout.
            proc = subprocess.Popen(
                ['npx', 'localtunnel', '--port', str(port)],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                text=True
            )
            url = TunnelManager._watch_stream(
                proc.stdout, TunnelManager._extract_localtunnel_url, timeout=10)
            if url:
                print(f"✅ localtunnel: {url}")
                return url
        except Exception:
            pass
        TunnelManager._terminate(proc)
        return None

def setup_tunnel(port, ngrok_token=None):
    """Open a public tunnel to *port* and return its URL, or ``None`` if every provider failed.

    A given *ngrok_token* is exported as ``NGROK_TOKEN`` for the ngrok attempt.
    """
    if ngrok_token:
        os.environ['NGROK_TOKEN'] = ngrok_token
    public_url, _provider = TunnelManager.start(port)
    return public_url



In [None]:
%%writefile prometheus_tqfd/dashboard/__init__.py



In [None]:
%%writefile prometheus_tqfd/dashboard/app.py
import streamlit as st
import plotly.graph_objects as go
import json
import os
import time
from pathlib import Path
from typing import List, Dict

# Page-level settings: browser tab title/icon and a full-width layout.
st.set_page_config(
    layout="wide",
    page_title="PROMETHEUS-TQFD Dashboard",
    page_icon="♟️",
)

def find_latest_metrics_file(base_dir: Path):
    """Return the most recently modified ``*/metrics/*.jsonl`` below *base_dir*, or ``None``."""
    candidates = base_dir.glob("**/metrics/*.jsonl")
    return max(candidates, key=os.path.getmtime, default=None)

@st.cache_data(ttl=2)
def load_metrics(metrics_file):
    """Load the newest (up to 10 000) records from a JSON-lines metrics file.

    Returns ``[]`` when *metrics_file* is falsy or does not exist. Lines that
    are not valid JSON objects (e.g. a partially written last line) are
    skipped. Streamlit caches the result for 2 seconds per argument.
    """
    from collections import deque

    if not metrics_file or not os.path.exists(metrics_file):
        return []

    # A bounded deque keeps only the newest records, so memory stays flat
    # however long the run gets.
    recent = deque(maxlen=10000)
    with open(metrics_file, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Downstream helpers call record.get(...), so keep only objects.
            if isinstance(record, dict):
                recent.append(record)
    return list(recent)

def get_elo(name: str, metrics: List[Dict]) -> float:
    """Return *name*'s rating from the newest evaluation record (1000.0 if unavailable)."""
    latest_eval = next(
        (m for m in reversed(metrics) if m.get('type') == 'evaluation' and 'elo' in m),
        None,
    )
    if latest_eval is None:
        return 1000.0
    return latest_eval['elo'].get(name, 1000.0)

def get_games(name: str, metrics: List[Dict]) -> int:
    """Count ``<name>_train`` records, used as a progress proxy (training steps, not games)."""
    wanted = f'{name}_train'
    return sum(1 for m in metrics if m.get('type') == wanted)

def main():
    """Render the PROMETHEUS-TQFD training dashboard.

    Loads the most recently modified ``metrics/*.jsonl`` file found below the
    runs directory and shows the latest ELO ratings, training-step counts and
    loss curves for ATLAS and ENTROPY.

    The runs directory defaults to ``./prometheus_runs`` (relative to the
    directory Streamlit was launched from) and can be overridden with the
    ``PROMETHEUS_BASE_DIR`` environment variable.
    """
    st.title("♟️ PROMETHEUS-TQFD")
    st.markdown("### Dual-AI Tabula Rasa Chess Training")

    # An empty PROMETHEUS_BASE_DIR would become Path("") (the cwd), so treat it as unset.
    base_dir = Path(os.environ.get("PROMETHEUS_BASE_DIR") or "./prometheus_runs")
    metrics_file = find_latest_metrics_file(base_dir)
    metrics = load_metrics(metrics_file)

    # Sidebar
    with st.sidebar:
        st.markdown("### System Status")
        st.metric("ATLAS ELO", f"{get_elo('atlas', metrics):.0f}")
        st.metric("ENTROPY ELO", f"{get_elo('entropy', metrics):.0f}")
        st.markdown("---")
        # NOTE: load_metrics returns only the tail of the metrics file, so these
        # step counts cover that window rather than the whole run.
        st.metric("ATLAS Progress", f"{get_games('atlas', metrics)} steps")
        st.metric("ENTROPY Progress", f"{get_games('entropy', metrics)} steps")

        if st.button("Refresh"):
            st.rerun()

    # Tabs
    tab1, tab2, tab3 = st.tabs(["📈 Lernkurven", "🔥 Heatmaps", "🎮 Live-Spiel"])

    with tab1:
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("ATLAS Training")
            atlas_window = [m for m in metrics if m.get('type') == 'atlas_train'][-1000:]
            if atlas_window:
                fig = go.Figure()
                # Entries without 'loss' become None (a gap in the curve)
                # instead of raising KeyError and breaking the whole page.
                fig.add_trace(go.Scatter(
                    y=[m.get('loss') for m in atlas_window],
                    name='Total Loss'
                ))
                fig.update_layout(yaxis_type="log", title="Atlas Total Loss")
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("No ATLAS metrics yet.")

        with col2:
            st.subheader("ENTROPY Training")
            entropy_window = [m for m in metrics if m.get('type') == 'entropy_train'][-1000:]
            if entropy_window:
                fig = go.Figure()
                for key in ['outcome', 'mobility', 'pressure', 'stability', 'novelty']:
                    # Plot a component if it appears anywhere in the plotted
                    # window, not only in the very first entry.
                    if any(key in m for m in entropy_window):
                        fig.add_trace(go.Scatter(
                            # None leaves a gap; 0 has no position on a log axis.
                            y=[m.get(key) for m in entropy_window],
                            name=key
                        ))
                fig.update_layout(yaxis_type="log", title="Entropy Hybrid Losses")
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("No ENTROPY metrics yet.")

    with tab2:
        st.info("Heatmaps will be implemented soon.")

    with tab3:
        st.info("Live-game view coming soon.")

if __name__ == '__main__':
    main()



In [None]:
%%writefile prometheus_tqfd/tests.py
import torch
import chess
import numpy as np
from prometheus_tqfd.config import PrometheusConfig
from prometheus_tqfd.encoding import BoardEncoder, MoveEncoder
from prometheus_tqfd.atlas.network import AtlasNetwork
from prometheus_tqfd.atlas.mcts import MCTS
from prometheus_tqfd.entropy.network import EntropyNetworkV2
from prometheus_tqfd.entropy.rollout import MiniRolloutSelector
from prometheus_tqfd.tactics import TacticsDetector
from prometheus_tqfd.physics import PhysicsFieldCalculator

def test_encoding_roundtrip():
    """Encode the starting position and check the tensor shape is (19, 8, 8).

    Despite the name, no decode step is performed; only the encoder's output
    shape is verified. Returns True on success (raises on failure).
    """
    encoder = BoardEncoder()
    encoded = encoder.encode(chess.Board())
    expected_shape = (19, 8, 8)
    assert encoded.shape == expected_shape
    return True

def test_mcts_basic():
    """Run a small MCTS search from the start position on CPU.

    Uses 10 simulations to keep the smoke test fast and checks that the
    returned root node was visited at least once. Returns True on success.
    """
    cfg = PrometheusConfig()
    cfg.atlas_mcts_simulations = 10  # keep the smoke test cheap
    search_tree = MCTS(cfg, AtlasNetwork(cfg), 'cpu')
    root_node = search_tree.search(chess.Board())
    assert root_node.visit_count > 0
    return True

def test_tactics_detector():
    """Check that the tactics detector reports a 'mate_in_1' entry.

    The position (1.e4 e5 2.Nf3 Nc6 3.Bc4, Black to move) is not a
    mate-in-one; the test only checks that the key is present in the
    detector's result. Returns True on success.
    """
    position_fen = "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"
    detector = TacticsDetector(PrometheusConfig())
    found = detector.detect(chess.Board(position_fen))
    assert 'mate_in_1' in found
    return True

def test_physics_symmetry():
    """Compute physics fields for the start position and check shape (3, 8, 8).

    NOTE(review): despite the name, no symmetry property is asserted yet.
    Intended check: since White and Black start in mirrored positions,
    channel 0 ("mass") should initially sum to roughly zero or be symmetric.
    Returns True on success.
    """
    calculator = PhysicsFieldCalculator(PrometheusConfig())
    field_stack = calculator.compute(chess.Board())
    assert field_stack.shape == (3, 8, 8)
    return True

def run_smoke_tests(tests=None):
    """Run quick sanity checks and report whether all of them passed.

    Each test is a zero-argument callable. A truthy return value counts as a
    pass, a falsy one as a failure, and any raised ``Exception`` as an error.
    A failing test does not stop the remaining ones from running.

    Args:
        tests: Optional sequence of ``(name, callable)`` pairs. Defaults to the
            built-in encoding / MCTS / tactics / physics checks.

    Returns:
        True if every test passed, False otherwise.
    """
    print("🧪 Running Smoke Tests...")
    if tests is None:
        tests = [
            ("Encoding", test_encoding_roundtrip),
            ("MCTS", test_mcts_basic),
            ("Tactics", test_tactics_detector),
            ("Physics", test_physics_symmetry),
        ]

    all_passed = True
    for name, fn in tests:
        try:
            # bool() inside the try so an ambiguous truth value also counts as an error.
            passed = bool(fn())
        except Exception as e:
            # A bare `assert` raises AssertionError with an empty message, so
            # always include the exception type to keep the report readable.
            detail = f"{type(e).__name__}: {e}" if str(e) else type(e).__name__
            print(f"  ❌ {name} error: {detail}")
            all_passed = False
            continue

        if passed:
            print(f"  ✅ {name} passed")
        else:
            print(f"  ❌ {name} failed")
            all_passed = False

    return all_passed

if __name__ == "__main__":
    import sys
    # Report the outcome through the exit status so scripts/CI can detect failures.
    sys.exit(0 if run_smoke_tests() else 1)



In [None]:
%%writefile main.py
import os
import sys
import subprocess
import time
import multiprocessing as mp
from pathlib import Path

# Make the `prometheus_tqfd` package importable by adding the current working
# directory (expected to be the project root the notebook cells write into)
# to the module search path. This runs at import time, before main() performs
# its lazy `from prometheus_tqfd...` imports.
sys.path.append(os.getcwd())

def setup_directories(config):
    """Choose the runs root and create the per-run directory tree.

    Sets ``config.base_dir`` to ``/content/prometheus_runs`` when a
    ``/content`` directory exists (Colab), otherwise to ``./prometheus_runs``,
    then creates ``checkpoints``, ``metrics``, ``games`` and ``logs`` below
    ``config.base_dir / config.run_id``. Existing directories are left as-is.

    Args:
        config: Object with ``run_id`` and a writable ``base_dir`` attribute;
            ``base_dir`` is overwritten in place.
    """
    in_colab = os.path.exists("/content")
    root = Path("/content/prometheus_runs") if in_colab else Path("./prometheus_runs")
    config.base_dir = root

    run_dir = root / config.run_id
    for subdir in ('checkpoints', 'metrics', 'games', 'logs'):
        (run_dir / subdir).mkdir(parents=True, exist_ok=True)

def start_dashboard(port, run_dir):
    """Launch the Streamlit dashboard as a headless background subprocess.

    Args:
        port: Port passed to Streamlit via ``--server.port``.
        run_dir: Currently unused by this function.

    Returns:
        The ``subprocess.Popen`` handle; the caller must terminate it.
    """
    print(f"🚀 Starting Streamlit Dashboard on port {port}...")
    # The app path is relative, so this assumes the cwd is the project root.
    command = [sys.executable, "-m", "streamlit", "run", "prometheus_tqfd/dashboard/app.py"]
    command += [f"--server.port={port}", "--server.headless=true"]
    return subprocess.Popen(command)

def main():
    """Entry point: set up the run, start the dashboard and train.

    Steps: force the ``spawn`` multiprocessing start method, detect hardware
    and adapt the config, create the run directories, run the smoke tests
    (returning early if any fails), optionally mount Google Drive in Colab,
    start the Streamlit dashboard plus a public tunnel, and finally run the
    training supervisor until it returns or Ctrl+C is pressed.

    The dashboard subprocess is always terminated on exit, including when the
    tunnel or supervisor setup raises.
    """
    print("=" * 60)
    print("🔥 PROMETHEUS-TQFD v2.0")
    print("   Dual-AI Tabula Rasa Chess Training System")
    print("=" * 60)

    # 1. mp setup
    mp.set_start_method('spawn', force=True)

    # 2. Hardware Detection
    from prometheus_tqfd.utils.hardware import detect_hardware
    hw = detect_hardware()
    print(f"\n📊 Hardware erkannt:")
    print(f"   Device: {hw.device}")
    if hw.gpu_name:
        print(f"   GPU: {hw.gpu_name} ({hw.vram_gb:.1f} GB VRAM)")
    print(f"   RAM: {hw.ram_gb:.1f} GB")

    # 3. Config
    from prometheus_tqfd.config import PrometheusConfig, adjust_config_for_hardware
    config = PrometheusConfig()
    config = adjust_config_for_hardware(config, hw)

    # 4. Setup Directories
    setup_directories(config)
    run_dir = config.base_dir / config.run_id

    # 5. Smoke Tests
    from prometheus_tqfd.tests import run_smoke_tests
    if not run_smoke_tests():
        print("⛔ Smoke Tests failed. Training aborted.")
        return

    # 6. Google Drive Mount (if in Colab)
    if hw.is_colab and config.use_drive:
        try:
            from google.colab import drive
            drive.mount('/content/drive')
            print("✅ Google Drive gemountet")
        except Exception:
            # Best effort: continue without Drive. (`except Exception` rather
            # than a bare except so Ctrl+C / SystemExit are not swallowed.)
            print("⚠️ Google Drive nicht verfügbar")
            config.use_drive = False

    # 7. Start Dashboard
    dashboard_proc = start_dashboard(config.dashboard_port, run_dir)

    # Everything after the dashboard has started lives inside try/finally so
    # the subprocess is cleaned up even if tunnel or supervisor setup fails.
    try:
        # 8. Setup Tunnel
        from prometheus_tqfd.utils.tunneling import setup_tunnel
        ngrok_token = os.environ.get('NGROK_TOKEN')
        try:
            from google.colab import userdata
            ngrok_token = ngrok_token or userdata.get('NGROK_TOKEN')
        except Exception:
            # Optional: e.g. not running in Colab, or the secret is unavailable.
            pass

        url = setup_tunnel(config.dashboard_port, ngrok_token=ngrok_token)
        if url:
            print(f"🌐 Dashboard URL: {url}")

        # 9. Start Supervisor
        from prometheus_tqfd.orchestration.supervisor import Supervisor
        supervisor = Supervisor(config)

        print("\n🚀 Starting Training Loop...")
        supervisor.run()
    except KeyboardInterrupt:
        print("\n🛑 Shutdown requested.")
    finally:
        if dashboard_proc is not None:
            dashboard_proc.terminate()
            try:
                # Reap the child so it does not linger as a zombie.
                dashboard_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                dashboard_proc.kill()
                dashboard_proc.wait()
        print("✅ Prometheus finished.")

if __name__ == "__main__":
    main()



## 🚀 3. Start Training

In [None]:
# @title ⚙️ Configuration
NGROK_TOKEN = "" # @param {type:"string"}
USE_DRIVE = True # @param {type:"boolean"}

# Launch cell: export the ngrok token to the environment, then run main().
# Prefer keeping the token in Colab secrets instead of pasting it above:
# main() falls back to `google.colab.userdata.get('NGROK_TOKEN')` when the
# environment value is empty.
# NOTE(review): USE_DRIVE is not forwarded anywhere; main() constructs its own
# PrometheusConfig(), so config.use_drive is never set from this form field.
# TODO wire it through.
import os
os.environ['NGROK_TOKEN'] = NGROK_TOKEN

# NOTE: if main.py was already imported in this runtime, this import reuses
# the cached module; restart the runtime (or importlib.reload) after
# rewriting main.py.
from main import main
main()