In [1]:
from src.Game import Game
from src.SACAgent import Actor
import torch
import numpy as np
import torch
import math
import random

In [2]:
# Build the SAC actor and restore its trained weights for inference.
# - state_dim=9 must match the 9 features produced by build_state().
# - action_dim=2 matches the two values passed to game.step(action[0], action[1]).
# - lstm_hidden_dim=128 must match the (h, c) tensors created for get_action().
# .eval() switches the module to inference mode (matters for dropout /
# batch-norm layers, if Actor has any); this notebook never trains it.
actor = Actor(state_dim=9, action_dim=2, lstm_hidden_dim=128).to("cuda").eval()
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects from the checkpoint, which can execute code — only load trusted files.
# map_location="cpu" loads tensors on CPU; load_state_dict copies them into the
# CUDA-resident parameters.
_ = actor.load_state_dict(
    torch.load("training/backup/ckpt.pth", map_location="cpu", weights_only=False)[
        "actor"
    ]
)

In [3]:
def reset_hidden():
    """Reset the global recurrent state (h, c) used by get_action to zeros.

    Call this at the start of every new episode so the policy does not carry
    memory over from a previous run. The shape (1, 1, 128) must agree with
    lstm_hidden_dim=128 used when constructing the actor.
    """
    global h, c
    # Allocate directly on the GPU instead of building on CPU and copying.
    h = torch.zeros(1, 1, 128, device="cuda")
    c = torch.zeros(1, 1, 128, device="cuda")


reset_hidden()


def get_action(state):
    """Run one step of the actor and return its deterministic action.

    Args:
        state: 1-D numpy observation (e.g. from build_state); a batch
            dimension of size 1 is added before the forward pass.

    Returns:
        numpy array holding the actor's third output (a_det, presumably the
        deterministic action) for this single observation, batch dim removed.

    Side effects:
        Replaces the global (h, c) with the LSTM state returned by the actor,
        so consecutive calls form one continuous sequence.
    """
    global h, c
    obs = torch.from_numpy(state).float().to("cuda").unsqueeze(0)
    with torch.no_grad():
        _, _, a_det, (h, c) = actor(obs, (h, c))
    # [0] drops the batch dimension added by unsqueeze(0) above.
    return a_det.cpu().numpy()[0]

In [4]:
def build_state(env: Game, desired_vx: float, desired_vy: float) -> np.ndarray:
    """Assemble the 9-element float32 observation consumed by the actor.

    Layout (in order):
        0-1  drone velocity (x, y), divided by 5
        2    angular velocity, divided by 10
        3-4  cos / sin of the drone angle
        5-6  left / right propeller speeds, unscaled
        7-8  commanded velocity (x, y), divided by 5

    The divisor 5 presumably normalizes by the maximum commanded speed (the
    control loop issues targets of +/-5) — confirm against the training code,
    since the scaling must match what the policy was trained on.
    """
    sim = env.env
    vel_x, vel_y = sim.drone_velocity / 5
    spin = sim.ang_vel / 10
    heading = env.drone_angle
    features = (
        vel_x,
        vel_y,
        spin,
        math.cos(heading),
        math.sin(heading),
        sim.drone.L_speed,
        sim.drone.R_speed,
        desired_vx / 5,
        desired_vy / 5,
    )
    return np.array(features, dtype=np.float32)

In [7]:
# Interactive session: the trained policy flies the drone while the arrow keys
# set the target velocity. The drone starts from a random angle, velocity,
# propeller speeds and angular velocity.
MAX_SPEED = 5.0  # commanded speed per axis while an arrow key is held


def keys_to_desired_velocity(keys, speed=MAX_SPEED):
    """Map arrow-key states to a (desired_vx, desired_vy) command.

    Each axis is +speed, -speed, or 0.0; holding both opposing keys cancels
    out to 0.0. RIGHT/UP are the positive directions.
    """

    def axis(positive, negative):
        if positive and not negative:
            return speed
        if negative and not positive:
            return -speed
        return 0.0

    return axis(keys["RIGHT"], keys["LEFT"]), axis(keys["UP"], keys["DOWN"])


# Start every session from a zeroed LSTM state. Without this, re-running the
# cell continues from the hidden state the previous session left in the global
# (h, c) read by get_action. zeros_like keeps the existing shape and device.
h, c = torch.zeros_like(h), torch.zeros_like(c)

game = Game(gui=True, human_player=True, dt=1 / 144, wind=True, rain=True)
game.set_drone_angle(random.uniform(0, 2 * math.pi))
game.set_drone_velocity(random.uniform(-5, 5), random.uniform(-5, 5))
game.set_drone_propeller_speeds(random.uniform(-1, 1), random.uniform(-1, 1))
game.set_drone_angular_velocity(random.uniform(-math.pi, math.pi))
while game.is_running:
    keys = game.handle_events(control_type="arrow")
    if keys is None:
        # No key state this frame: skip stepping and rendering.
        continue
    desired_vx, desired_vy = keys_to_desired_velocity(keys)
    state = build_state(game, desired_vx, desired_vy)
    action = get_action(state)
    game.step(action[0], action[1])
    game.render()