In [1]:
import gymnasium as gym
import torch
import pygame
import numpy as np
import time
import os

# Settings
env_name = 'LunarLander-v3'
model_path = 'main_net.pth'
reload_interval = 10  # seconds between model reloads
wait_poll_interval = 0.5  # seconds between checks while waiting for the first checkpoint

# Quick fix: block until the checkpoint file exists, then load it.
# Sleep between checks instead of spinning (a bare `while True` poll pins a
# CPU core for as long as the file is missing), and test `model_path` rather
# than a duplicated literal so changing the setting above changes what we
# wait for.
while not os.path.exists(model_path):
    time.sleep(wait_poll_interval)
# Kept as a module-level global for other notebook cells; render_with_overlay()
# loads its own copy via load_model().
main_net = torch.load(model_path)

def load_model(path):
    """Load the object saved at *path* with ``torch.load`` and call ``.eval()`` on it.

    On success the loaded object is returned and a confirmation line is
    printed. Any exception raised while loading, calling ``.eval()`` or
    printing is reported on stdout and turned into a ``None`` return, so a
    caller can keep using whatever model it already had (e.g. when the file
    is missing or only partially written).
    """
    try:
        net = torch.load(path)
        net.eval()
        print(f"Loaded model from {path}")
    except Exception as err:
        print(f"Could not load model: {err}")
        net = None
    return net

def _model_mtime(path):
    """Return the last-modification time of *path*, or None if it can't be read.

    Used to detect whether the checkpoint was rewritten since the last
    successful load.
    """
    try:
        return os.path.getmtime(path)
    except OSError:
        return None


def _draw_overlay(screen, font, frame, total_reward):
    """Blit an RGB *frame* onto *screen* with the running reward drawn on top."""
    # env.render() yields (height, width, 3); pygame.surfarray wants (width, height, 3).
    surf = pygame.surfarray.make_surface(np.transpose(frame, (1, 0, 2)))
    screen.blit(surf, (0, 0))
    text = font.render(f"Total reward: {total_reward:.1f}", True, (255, 255, 255))
    screen.blit(text, (10, 10))
    pygame.display.flip()


def render_with_overlay():
    """Play ``env_name`` with the network at ``model_path`` and display it live.

    Opens a pygame window showing the environment's RGB frames with the
    running episode reward drawn in the top-left corner. Each step takes the
    argmax of the network's output for the current observation. Episodes
    restart automatically after a one-second pause.

    Every ``reload_interval`` seconds the checkpoint's modification time is
    checked, and the model is reloaded only if the file has changed. A failed
    reload keeps the previously loaded model. If no model could be loaded at
    all, the window stays open and reloading is retried instead of crashing.

    Returns when the window is closed. pygame and the environment are shut
    down on every exit path, including exceptions such as KeyboardInterrupt.
    """
    env = gym.make(env_name, render_mode="rgb_array")
    try:
        obs, _ = env.reset()
        total_reward = 0

        # Use the first frame to size the window.
        frame = env.render()
        height, width, _ = frame.shape

        pygame.init()
        screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("LunarLander with Score Overlay (Live Model Reload)")
        font = pygame.font.SysFont("Arial", 24)
        clock = pygame.time.Clock()

        # Read the mtime *before* loading so that a write landing during the
        # load is treated as a newer checkpoint at the next check.
        loaded_mtime = _model_mtime(model_path)
        main_net = load_model(model_path)
        if main_net is None:
            loaded_mtime = None
        # Start the timer now: initialising it to 0 reloaded the same
        # checkpoint again on the very first frame.
        last_check = time.time()

        while True:
            # Periodically reload, but only when the checkpoint has changed.
            now = time.time()
            if now - last_check > reload_interval:
                last_check = now
                mtime = _model_mtime(model_path)
                if mtime is not None and mtime != loaded_mtime:
                    new_model = load_model(model_path)
                    if new_model is not None:
                        main_net = new_model
                        loaded_mtime = mtime

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    return  # cleanup happens in the finally clause

            if main_net is None:
                # No usable model yet: keep the window responsive and wait
                # for a later reload instead of calling None.
                _draw_overlay(screen, font, frame, total_reward)
                clock.tick(60)
                continue

            # Greedy action: argmax over the network's outputs.
            obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
            with torch.no_grad():
                q_values = main_net(obs_tensor)
                action = torch.argmax(q_values).item()

            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward

            frame = env.render()
            _draw_overlay(screen, font, frame, total_reward)
            clock.tick(60)

            if terminated or truncated:
                time.sleep(1)  # brief pause so the final frame stays visible
                obs, _ = env.reset()
                total_reward = 0
    finally:
        pygame.quit()
        env.close()

# Start the viewer only when this file/cell is executed directly; in a
# notebook __name__ is "__main__", so the cell behaves as before.
if __name__ == "__main__":
    render_with_overlay()

Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded model from main_net.pth
Loaded m

KeyboardInterrupt: 