In [6]:
import pygame
import threading
import gymnasium as gym
import torch.nn.functional as F
import torch
from torch.distributions import Categorical

In [5]:
class Display:
    def __init__(self, env):
        self.env = env
        self.screen = None
        self.clock = pygame.time.Clock()
        self.is_running = False
        self.display_thread = None

    def initialize_display(self):
        self.env.reset()
        # Get the environment rendering size
        render_size = self.env.render().shape[1::-1]

        pygame.init()
        self.screen = pygame.display.set_mode(render_size)
        pygame.display.set_caption("Environment Display")
        self.is_running = True

    def display_loop(self):
        self.initialize_display()

        while self.is_running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.is_running = False

            # Get the environment rendering
            rendered_frame = self.env.render()

            # Convert the frame to a Pygame surface
            frame_surface = pygame.surfarray.make_surface(rendered_frame.swapaxes(0, 1))

            # Display the frame on the screen
            self.screen.blit(frame_surface, (0, 0))
            pygame.display.flip()

            # Limit the frame rate
            self.clock.tick(30)

        pygame.quit()

    def start_display_thread(self):
        self.display_thread = threading.Thread(target=self.display_loop)
        self.display_thread.start()

    def join_display_thread(self):
        if self.display_thread is not None and self.display_thread.is_alive():
            self.display_thread.join()

In [6]:
# Example usage with the LunarLander environment
lunar_lander_env = gym.make('LunarLander-v2',render_mode='rgb_array')
display = Display(lunar_lander_env)
# Start the display in a new thread
display.start_display_thread()
# Continue with other tasks, e.g., interacting with the environment
# ...
# Wait for the display thread to finish
display.join_display_thread()

In [10]:
logits = torch.randn(5,7)
actions = torch.randn(5,7)
new_policy = F.softmax(logits, dim=1)
new_m = Categorical(new_policy)
#new_log_policy = new_m.log_prob(actions)
print(logits)
print(new_policy)
print(new_m)
#print(new_log_policy)

tensor([[-0.2185,  0.7709, -0.3901, -0.2150, -0.1516,  0.0532, -0.7860],
        [ 1.4378,  2.2149, -1.6656, -0.2960, -1.5264,  1.2136, -0.5608],
        [-1.5065,  0.6141, -0.6813,  0.8452, -0.9935,  0.3178,  1.2464],
        [-0.4298,  0.4591, -0.0485, -0.7450,  0.2609,  1.7190,  1.5103],
        [ 0.7857, -0.9246, -1.6470,  0.7054,  2.2700, -0.5954, -0.2216]])
tensor([[0.1179, 0.3170, 0.0993, 0.1183, 0.1260, 0.1547, 0.0668],
        [0.2282, 0.4963, 0.0102, 0.0403, 0.0118, 0.1823, 0.0309],
        [0.0219, 0.1825, 0.0500, 0.2299, 0.0366, 0.1357, 0.3434],
        [0.0432, 0.1050, 0.0632, 0.0315, 0.0862, 0.3703, 0.3006],
        [0.1385, 0.0250, 0.0122, 0.1278, 0.6111, 0.0348, 0.0506]])
Categorical(probs: torch.Size([5, 7]))
