# Reinforcement Learning Modell 

# Godot Connector

In [2]:
import socket
import json
from typing import Any, Dict, List, Tuple
import random
from collections import deque
import os
import select
import math

import torch
import torch.nn as nn
import torch.optim as optim

# ===========================
#      ENVIRONMENT CONFIG
# ===========================

HOST = "127.0.0.1"

# How many parallel Godot environments to train with
N_ENVS = 16 

# Port base for Godot instances
BASE_PORT = 5000

# Generate ports based on environments
ENV_PORTS = [BASE_PORT + i for i in range(N_ENVS)]

print(f"Training with {N_ENVS} environments.")
print(f"Expecting Godot instances on ports: {ENV_PORTS}")

# ===========================
#  RL / model config
# ===========================

# State vector: tank_x, tank_y, goal_x, goal_y, dx, dy  (all normalized)
STATE_SIZE = 8

# Discrete actions: (turn, throttle) with values in {-1, 0, 1}
# Godot will feed these into setTurn() and setDirection().
ACTIONS: List[Dict[str, float]] = [
    {"turn": -1.0, "throttle":  1.0},  # left + forward
    {"turn":  0.0, "throttle":  1.0},  # straight + forward
    {"turn":  1.0, "throttle":  1.0},  # right + forward
    {"turn":  0.0, "throttle":  0.0},  # stop
    {"turn":  0.0, "throttle": -1.0},  # straight + backward
]
N_ACTIONS = len(ACTIONS)

GAMMA = 0.99
EPSILON_START = 0.2
EPSILON_END = 0.01
EPSILON_DECAY = 1e-4
BATCH_SIZE = 32
REPLAY_SIZE = 10_000
LEARNING_RATE = 1e-3

CHECKPOINT_PATH = "tank_dqn_checkpoint.pt"
LOG_PATH = "training_log.csv"

# ===========================
#  Q-network
# ===========================

class QNetwork(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # [batch, output_dim] (Q-values for each action)


device = torch.device("cpu")
print("Using device:", device)

q_net = QNetwork(STATE_SIZE, N_ACTIONS).to(device)
optimizer = optim.Adam(q_net.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

ReplayEntry = Tuple[List[float], int, float, List[float], bool]
replay_buffer: deque[ReplayEntry] = deque(maxlen=REPLAY_SIZE)

# Globals for RL bookkeeping (shared across envs)
epsilon = EPSILON_START
total_steps = 0

# Per-env bookkeeping
num_envs = len(ENV_PORTS)
buffers: List[str] = [""] * num_envs
prev_states: List[List[float] | None] = [None] * num_envs
prev_actions: List[int | None] = [None] * num_envs
episode_returns: List[float] = [0.0] * num_envs
episode_idxs: List[int] = [0] * num_envs

decoder = json.JSONDecoder()

# ===========================
#  Checkpoint + logging
# ===========================

def save_checkpoint() -> None:
    state = {
        "q_net": q_net.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epsilon": epsilon,
        "total_steps": total_steps,
        "episode_idxs": episode_idxs,
    }
    torch.save(state, CHECKPOINT_PATH)
    print(f"[CHECKPOINT] Saved to {CHECKPOINT_PATH}")


def load_checkpoint() -> None:
    global epsilon, total_steps, episode_idxs
    if not os.path.exists(CHECKPOINT_PATH):
        print("[CHECKPOINT] No checkpoint found, starting fresh.")
        return
    state = torch.load(CHECKPOINT_PATH, map_location=device)
    q_net.load_state_dict(state["q_net"])
    optimizer.load_state_dict(state["optimizer"])
    epsilon = state.get("epsilon", epsilon)
    total_steps = state.get("total_steps", 0)
    episode_idxs = state.get("episode_idxs", episode_idxs)
    print(f"[CHECKPOINT] Loaded from {CHECKPOINT_PATH}, steps={total_steps}")


def log_episode_to_file(env_idx: int, ep_idx: int, ep_return: float, eps: float) -> None:
    header_needed = not os.path.exists(LOG_PATH)
    with open(LOG_PATH, "a") as f:
        if header_needed:
            f.write("env,episode,return,epsilon\n")
        f.write(f"{env_idx},{ep_idx},{ep_return},{eps}\n")


# ===========================
#  Networking helpers
# ===========================

def send_json(sock: socket.socket, payload: Dict[str, Any]) -> None:
    data = json.dumps(payload)
    sock.sendall(data.encode("utf-8"))


def connect_all_envs() -> List[socket.socket]:
    socks: List[socket.socket] = []
    for p in ENV_PORTS:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.connect((HOST, p))
        s.setblocking(False)
        print(f"Connected to env on port {p}")
        socks.append(s)
    return socks


# ===========================
#  State / RL helpers
# ===========================

def build_state_vector(message: Dict[str, Any]) -> List[float]:
    arena = message.get("arena", {})
    tank = message.get("tank", {})
    goal = message.get("goal", {})

    arena_w = float(arena.get("width", 1.0))
    arena_h = float(arena.get("height", 1.0))

    tank_x = float(tank.get("x", 0.0))
    tank_y = float(tank.get("y", 0.0))
    goal_x = float(goal.get("x", 0.0))
    goal_y = float(goal.get("y", 0.0))

    # NEW: orientation
    theta = float(tank.get("rot", 0.0))  # radians from Godot
    cos_theta = math.cos(theta)
    sin_theta = math.sin(theta)

    # Normalize positions to [0, 1]
    tank_x_n = tank_x / arena_w
    tank_y_n = tank_y / arena_h
    goal_x_n = goal_x / arena_w
    goal_y_n = goal_y / arena_h

    # Relative position tank -> goal, also normalized
    dx = (goal_x - tank_x) / arena_w
    dy = (goal_y - tank_y) / arena_h

    # State: position, relative goal, orientation (cos, sin)
    return [tank_x_n, tank_y_n, goal_x_n, goal_y_n, dx, dy, cos_theta, sin_theta]


def select_action(state: List[float]) -> int:
    global epsilon

    epsilon = max(EPSILON_END, epsilon - EPSILON_DECAY)

    if random.random() < epsilon:
        return random.randrange(N_ACTIONS)

    state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        q_vals = q_net(state_t)
    return int(torch.argmax(q_vals, dim=1).item())


def store_transition(s: List[float], a: int, r: float, s_next: List[float], done: bool) -> None:
    replay_buffer.append((s, a, r, s_next, done))


def train_step() -> float | None:
    global total_steps

    if len(replay_buffer) < BATCH_SIZE:
        return None

    batch = random.sample(replay_buffer, BATCH_SIZE)

    states = torch.tensor([b[0] for b in batch], dtype=torch.float32, device=device)
    actions = torch.tensor([b[1] for b in batch], dtype=torch.int64, device=device)
    rewards = torch.tensor([b[2] for b in batch], dtype=torch.float32, device=device)
    next_states = torch.tensor([b[3] for b in batch], dtype=torch.float32, device=device)
    dones = torch.tensor([b[4] for b in batch], dtype=torch.float32, device=device)

    q_values = q_net(states)  # [B, N_ACTIONS]
    q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)  # [B]

    with torch.no_grad():
        next_q = q_net(next_states)
        max_next_q = next_q.max(dim=1)[0]
        targets = rewards + GAMMA * max_next_q * (1.0 - dones)

    loss = loss_fn(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_steps += 1
    if total_steps % 1000 == 0:
        save_checkpoint()

    return loss.item()


# ===========================
#  Per-env message handling
# ===========================

def on_message_multi_env(env_idx: int, sock: socket.socket, message: Dict[str, Any]) -> None:
    state = build_state_vector(message)
    reward = float(message.get("reward", 0.0))
    done = bool(message.get("done", False))

    episode_returns[env_idx] += reward

    prev_s = prev_states[env_idx]
    prev_a = prev_actions[env_idx]

    if prev_s is not None and prev_a is not None:
        store_transition(prev_s, prev_a, reward, state, done)
        loss = train_step()
    else:
        loss = None

    action_idx = select_action(state)
    prev_states[env_idx] = state
    prev_actions[env_idx] = action_idx

    action = ACTIONS[action_idx]

    state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        q_vals = q_net(state_t)
        chosen_q = float(q_vals[0, action_idx].item())

    response = {
        "value": chosen_q,
        "action": {
            "turn": action["turn"],
            "throttle": action["throttle"],
        },
        "debug": {
            "env": env_idx,
            "state": state,
            "reward": reward,
            "done": done,
            "epsilon": epsilon,
            "loss": loss,
            "action_idx": action_idx,
            "episode": episode_idxs[env_idx],
            "episode_return": episode_returns[env_idx],
        },
    }

    send_json(sock, response)

    if done:
        print(f"[ENV {env_idx}] episode {episode_idxs[env_idx]} return = {episode_returns[env_idx]:.3f}, eps={epsilon:.3f}")
        log_episode_to_file(env_idx, episode_idxs[env_idx], episode_returns[env_idx], epsilon)

        episode_idxs[env_idx] += 1
        episode_returns[env_idx] = 0.0
        prev_states[env_idx] = None
        prev_actions[env_idx] = None


def _process_buffer_for_env(env_idx: int, sock: socket.socket) -> None:
    buf = buffers[env_idx]

    while buf:
        buf = buf.lstrip()
        if not buf:
            break

        try:
            obj, idx = decoder.raw_decode(buf)
        except json.JSONDecodeError:
            break

        buf = buf[idx:]
        try:
            on_message_multi_env(env_idx, sock, obj)
        except Exception as e:
            print(f"Error in on_message for env {env_idx}: {e}")

    buffers[env_idx] = buf


# ===========================
#  Multi-env receive loop
# ===========================

def receive_loop_multi_env() -> None:
    socks = connect_all_envs()
    print("All envs connected, starting training loop...")

    while True:
        if not socks:
            print("All envs disconnected, stopping.")
            break

        readable, _, _ = select.select(socks, [], [], 0.1)

        for s in readable:
            env_idx = socks.index(s)
            try:
                chunk = s.recv(4096)
            except BlockingIOError:
                continue

            if not chunk:
                print(f"Env {env_idx} disconnected.")
                socks.remove(s)
                continue

            buffers[env_idx] += chunk.decode("utf-8")
            _process_buffer_for_env(env_idx, s)


if __name__ == "__main__":
    load_checkpoint()
    receive_loop_multi_env()

Training with 16 environments.
Expecting Godot instances on ports: [5000, 5001, 5002, 5003, 5004, 5005, 5006, 5007, 5008, 5009, 5010, 5011, 5012, 5013, 5014, 5015]
Using device: cuda


AssertionError: Torch not compiled with CUDA enabled