# Reinforcement Learning Model

# Godot Connector

In [None]:
import codecs
import json
import random
import socket
from collections import deque
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

# ===========================
#  Connection settings
# ===========================

# Address of the Godot-side server that receive_loop() connects to.
HOST = "127.0.0.1"
PORT = 5000

# ===========================
#  RL / model config
# ===========================

# Observation fed to the network, every entry normalized by the arena size:
#   tank_x, tank_y, goal_x, goal_y, dx, dy
STATE_SIZE = 6

# Discrete action table. Each entry is a (turn, throttle) pair whose
# components lie in {-1, 0, 1}; Godot will feed these into setTurn() and
# setDirection(). The position in the list is the action index the
# Q-network scores.
_ACTION_PAIRS = (
    (-1.0, 1.0),   # left + forward
    (0.0, 1.0),    # straight + forward
    (1.0, 1.0),    # right + forward
    (0.0, 0.0),    # stop
    (0.0, -1.0),   # straight + backward
)
ACTIONS: List[Dict[str, float]] = [
    {"turn": turn, "throttle": throttle} for turn, throttle in _ACTION_PAIRS
]
N_ACTIONS = len(ACTIONS)

GAMMA = 0.99            # discount applied to the bootstrapped next-state value
EPSILON_START = 0.2     # initial exploration probability
EPSILON_END = 0.01      # exploration probability never decays below this
EPSILON_DECAY = 1e-4    # subtracted from epsilon on every action selection
BATCH_SIZE = 32         # transitions sampled per gradient step
REPLAY_SIZE = 10_000    # replay-buffer capacity (oldest entries are evicted)
LEARNING_RATE = 1e-3    # Adam step size


# ===========================
#  Q-network
# ===========================

class QNetwork(nn.Module):
    """Fully connected network that scores every discrete action for a state.

    Layout: Linear(input_dim, 64) -> ReLU -> Linear(64, 64) -> ReLU ->
    Linear(64, output_dim). The layers live in ``self.net`` (an
    ``nn.Sequential``), so state_dict keys are ``net.0.*``, ``net.2.*`` and
    ``net.4.*``.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (last dimension ``input_dim``) to Q-values.

        Callers in this file pass ``[batch, input_dim]`` and get back
        ``[batch, output_dim]``: one Q-value per action.
        """
        q_values = self.net(x)
        return q_values


# All tensors and the model are placed on the CPU.
device = torch.device("cpu")
print("Using device:", device)

# A single Q-network is used both to act and to compute the bootstrapped
# targets in train_step(); it is fitted with Adam on a mean-squared error.
q_net = QNetwork(STATE_SIZE, N_ACTIONS).to(device)
optimizer = optim.Adam(q_net.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# One stored transition: (state, action_idx, reward, next_state, done).
ReplayEntry = Tuple[List[float], int, float, List[float], bool]
# Bounded FIFO memory: once REPLAY_SIZE entries are held, each append drops
# the oldest one.
replay_buffer: deque[ReplayEntry] = deque(maxlen=REPLAY_SIZE)

# Mutable agent state, rebound by select_action() and on_message().
epsilon: float = EPSILON_START
prev_state: List[float] | None = None     # state the last action was chosen in
prev_action_idx: int | None = None        # index into ACTIONS of that action


# ===========================
#  Networking helpers
# ===========================

def send_json(sock: socket.socket, payload: Dict[str, Any]) -> None:
    """Serialize *payload* to JSON and write every byte of it to *sock*.

    The text is UTF-8 encoded and sent as-is, with no length prefix or
    trailing delimiter.
    """
    encoded = json.dumps(payload).encode("utf-8")
    sock.sendall(encoded)


# ===========================
#  State / RL helpers
# ===========================

def build_state_vector(message: Dict[str, Any]) -> List[float]:
    """Turn a Godot observation message into the network's input vector.

    Reads ``arena.width``/``arena.height``, ``tank.x``/``tank.y`` and
    ``goal.x``/``goal.y`` from *message*. Missing positions default to 0.0
    and missing sizes to 1.0. A sub-object that is absent or ``null`` is
    treated as empty.

    Returns:
        ``[tank_x, tank_y, goal_x, goal_y, dx, dy]``. The x components are
        divided by the arena width and the y components by the arena height.
        ``dx``/``dy`` is the tank-to-goal offset.
    """
    # ``or {}`` also covers an explicit JSON null, which would otherwise
    # fail on ``.get``.
    arena = message.get("arena") or {}
    tank = message.get("tank") or {}
    goal = message.get("goal") or {}

    arena_w = float(arena.get("width", 1.0))
    arena_h = float(arena.get("height", 1.0))
    # A zero or negative arena size cannot be used for normalization (zero
    # would raise ZeroDivisionError and drop the whole step). Fall back to
    # the same 1.0 used when the size is missing.
    if arena_w <= 0.0:
        arena_w = 1.0
    if arena_h <= 0.0:
        arena_h = 1.0

    tank_x = float(tank.get("x", 0.0))
    tank_y = float(tank.get("y", 0.0))
    goal_x = float(goal.get("x", 0.0))
    goal_y = float(goal.get("y", 0.0))

    tank_x_n = tank_x / arena_w
    tank_y_n = tank_y / arena_h
    goal_x_n = goal_x / arena_w
    goal_y_n = goal_y / arena_h

    dx = (goal_x - tank_x) / arena_w
    dy = (goal_y - tank_y) / arena_h

    return [tank_x_n, tank_y_n, goal_x_n, goal_y_n, dx, dy]


def select_action(state: List[float]) -> int:
    """Choose an index into ACTIONS for *state* with an epsilon-greedy policy.

    Every call first lowers the module-level ``epsilon`` by EPSILON_DECAY,
    never below EPSILON_END. With probability ``epsilon`` it then returns a
    uniformly random action index. Otherwise it returns the action with the
    highest Q-value under ``q_net``.
    """
    global epsilon
    epsilon = max(EPSILON_END, epsilon - EPSILON_DECAY)

    explore = random.random() < epsilon
    if explore:
        return random.randrange(N_ACTIONS)

    # Add a leading batch axis: [STATE_SIZE] -> [1, STATE_SIZE].
    batch = torch.tensor(state, dtype=torch.float32, device=device)[None, :]
    with torch.no_grad():
        best = q_net(batch).argmax(dim=1)
    return int(best.item())


def store_transition(s: List[float], a: int, r: float, s_next: List[float], done: bool) -> None:
    """Append one ``(state, action_idx, reward, next_state, done)`` tuple to the replay buffer."""
    entry: ReplayEntry = (s, a, r, s_next, done)
    replay_buffer.append(entry)


def train_step() -> float | None:
    """Run one Q-learning gradient step on a random replay minibatch.

    Returns:
        ``None`` without touching the network while the buffer holds fewer
        than BATCH_SIZE transitions. Otherwise returns the MSE loss between
        Q(s, a) and ``r + GAMMA * max_a' Q(s', a')``, where the bootstrap
        term is zeroed for transitions flagged ``done``.
    """
    if len(replay_buffer) < BATCH_SIZE:
        return None

    minibatch = random.sample(replay_buffer, BATCH_SIZE)
    # Transpose the (s, a, r, s', done) tuples into five columns.
    col_s, col_a, col_r, col_s2, col_done = zip(*minibatch)

    states = torch.tensor(col_s, dtype=torch.float32, device=device)
    actions = torch.tensor(col_a, dtype=torch.int64, device=device)
    rewards = torch.tensor(col_r, dtype=torch.float32, device=device)
    next_states = torch.tensor(col_s2, dtype=torch.float32, device=device)
    dones = torch.tensor(col_done, dtype=torch.float32, device=device)

    # Q-value of the action actually taken in each sampled state -> [B].
    taken_q = q_net(states).gather(1, actions[:, None]).squeeze(1)

    # NOTE(review): targets come from q_net itself (no separate target
    # network); consider one if training proves unstable.
    with torch.no_grad():
        best_next_q = q_net(next_states).max(dim=1).values
        targets = rewards + GAMMA * best_next_q * (1.0 - dones)

    loss = loss_fn(taken_q, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


# ===========================
#  Message handler
# ===========================

def on_message(sock: socket.socket, message: Dict[str, Any]) -> None:
    """Handle one observation from Godot: learn from it, pick an action, reply.

    The ``reward`` and ``done`` fields of *message* are paired with the
    previously chosen (state, action) and the current state to form a replay
    transition. One training step runs, then a new action is selected for
    the current state and sent back together with debug information.

    Args:
        sock: Connected socket the JSON response is written to.
        message: Decoded JSON object from Godot. ``reward`` (default 0.0) and
            ``done`` (default False) are read here; positions are read by
            build_state_vector().
    """
    global prev_state, prev_action_idx

    state = build_state_vector(message)
    reward = float(message.get("reward", 0.0))
    done = bool(message.get("done", False))

    # RL: use (prev_state, prev_action) -> current state as next_state.
    loss: float | None = None
    if prev_state is not None and prev_action_idx is not None:
        store_transition(prev_state, prev_action_idx, reward, state, done)
        loss = train_step()

    # Choose the next action for the current state.
    action_idx = select_action(state)
    action = ACTIONS[action_idx]

    # Update episode bookkeeping BEFORE any printing/sending. If that I/O
    # raised on a terminal step, a terminal state must not be left behind as
    # prev_state; otherwise the next episode's first message would store a
    # transition that spans two episodes.
    if done:
        # Episode ended - the next message starts fresh.
        prev_state = None
        prev_action_idx = None
    else:
        prev_state = state
        prev_action_idx = action_idx

    # For debug: Q-value of the chosen action (after this step's update).
    state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        q_vals = q_net(state_t)
        chosen_q = float(q_vals[0, action_idx].item())

    response = {
        "value": chosen_q,
        "action": {
            "turn": action["turn"],          # -1, 0, 1
            "throttle": action["throttle"],  # -1, 0, 1
        },
        "debug": {
            "state": state,
            "reward": reward,
            "done": done,
            "epsilon": epsilon,
            "loss": loss,
            "action_idx": action_idx,
        },
    }

    print("\nReceived from Godot:")
    print(json.dumps(message, indent=4))
    print("\nSending back to Godot:")
    print(json.dumps(response, indent=4))

    send_json(sock, response)


# ===========================
#  Main receive loop
# ===========================

def receive_loop() -> None:
    """Connect to HOST:PORT and pass every received JSON object to on_message().

    The incoming byte stream is treated as concatenated JSON values, which
    may be separated by whitespace. Complete values are cut out of the
    accumulated text with ``JSONDecoder.raw_decode``. A value that does not
    parse yet stays buffered until more data arrives. Exceptions raised by
    on_message() are printed and the loop continues. The function returns
    when the server closes the connection.
    """
    print(f"Connecting to {HOST}:{PORT} ...")
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((HOST, PORT))
        print("Connected! Waiting for JSON messages...\n")

        buffer = ""
        decoder = json.JSONDecoder()
        # recv() can split a multi-byte UTF-8 character across two chunks.
        # Decoding each chunk on its own would then raise UnicodeDecodeError
        # (outside the on_message guard) and kill the client. The
        # incremental decoder keeps the partial bytes until the rest arrives.
        utf8 = codecs.getincrementaldecoder("utf-8")()

        while True:
            chunk = sock.recv(4096)
            if not chunk:
                print("Connection closed by server.")
                break

            buffer += utf8.decode(chunk)

            while buffer:
                buffer = buffer.lstrip()
                if not buffer:
                    break

                try:
                    obj, idx = decoder.raw_decode(buffer)
                except json.JSONDecodeError:
                    # Most likely an incomplete object: wait for more bytes.
                    break

                buffer = buffer[idx:]

                try:
                    on_message(sock, obj)
                except Exception as e:
                    print(f"Error in on_message: {e}")


if __name__ == "__main__":
    # Script entry point: run the Godot client until the server disconnects.
    receive_loop()