In [None]:
import timm
import torch
from torch import nn

import gymnasium as gym
import time
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
import ale_py
import random
import torch
from torchvision import transforms
import numpy as np
from PIL import Image , ImageFilter
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
gym.register_envs(ale_py)

ENV_NAME = "PongNoFrameskip-v4"
TOTAL_TIMESTEPS = 1_000_000
MODEL_PATH = "ppo_pong"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


BATCH_SIZE = 8
ALLOWED_ACTIONS = [0, 2, 3]
action_to_index = {a: i for i, a in enumerate(ALLOWED_ACTIONS)}



env = make_atari_env(ENV_NAME, n_envs=1)
env.reset()

model = PPO("MlpPolicy", env, verbose=1)

obs_stack = []


In [None]:
def get_game_status(frame):

    arr = np.array(frame)

    # === Define ROI coordinates ===
    x_start, y_start = 0, 14
    x_end, y_end = 83, 76

    #

    ai_col = 9
    oppo_col = 74


    ai_col -= x_start
    oppo_col -= x_start

    th_up = 150
    th_down = 130

    # Extract ROI from original image
    roi = arr[y_start:y_end + 1, x_start:x_end + 1]  # shape: (63, 84)

    max_val = np.max(roi)
    coords = np.argwhere((roi == max_val) & (roi >= 180))
    if (len(coords) == 0):
        ball_pos_global = None
    else:
        ball_pos_roi = coords[0]  # Take the first match
        ball_pos_global = (ball_pos_roi[1] + x_start, ball_pos_roi[0] + y_start)


    indecies_ai = np.where((roi[:, ai_col] >= th_down) & (roi[:, ai_col] <= th_up))[0]
    indecies_oppo = np.where((roi[:, oppo_col] >= th_down) & (roi[:, oppo_col] <= th_up))[0]

    ai_loc = int(np.median(indecies_ai) + y_start) if len(indecies_ai) > 0 else None
    oppo_loc = int(np.median(indecies_oppo) + y_start) if len(indecies_oppo) > 0 else None
    if (ball_pos_global is None):
        return np.array([-1,-1, ai_loc, oppo_loc],dtype=object)
    return np.array([int(ball_pos_global[0]) ,int(ball_pos_global[1]) , ai_loc, oppo_loc], dtype=object)

In [None]:
# import matplotlib.pyplot as plt
# import numpy as np
# import random
# import cv2
# from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# frames = []

# # === Run and collect rendered frames ===
# for s in range(180):
#     action = [random.choice(ALLOWED_ACTIONS)]
#     obs, reward, done, info = env.step(action)
#     obs = obs[0, :, :, 0]

#     status = get_game_status(obs)
#     ball_x, ball_y, ai_y, oppo_y = status

#     # Create figure and attach canvas
#     fig, ax = plt.subplots(figsize=(3, 3), dpi=336)
#     canvas = FigureCanvas(fig)

#     ax.imshow(obs, cmap='gray')
#     ax.axis('off')

#     if (ball_x != -1):
#         ax.scatter(ball_x, ball_y, c='red', s=10, label='Ball')
#     if ai_y is not None:
#         ax.scatter(9, ai_y, c='blue', s=10, label='AI Paddle')
#     if oppo_y is not None:
#         ax.scatter(74, oppo_y, c='green', s=10, label='Opponent Paddle')

#     canvas.draw()
#     img = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
#     img = img.reshape(canvas.get_width_height()[::-1] + (4,))
#     img = img[:, :, :3]  # Drop alpha channel

#     frames.append(img)
#     plt.close(fig)

#     if done:
#         env.reset()

# # === Save frames to mp4 using OpenCV ===
# height, width, _ = frames[0].shape
# out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 20, (width, height))

# for frame in frames:
#     out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

# out.release()
# print("Saved video to output.mp4")


In [None]:
class PongWrapperHistory(gym.Wrapper):
        
        def __init__(self, env):
            super().__init__(env , history = 4)
            self.observation_space = gym.spaces.Box(
                low=np.array([-np.inf, -np.inf, -np.inf, -np.inf]*history, dtype=np.float32),
                high=np.array([np.inf, np.inf, np.inf, np.inf]*history, dtype=np.float32),
                dtype=np.float32
            )
        
        def step(self, action):
            obs, reward, done, info = self.env.step(action)

            obs = obs[0, :, :, 0]

            obs = get_game_status(obs)

            return obs, reward, done, info
                

In [None]:
env = make_atari_env(ENV_NAME, n_envs=1, seed=0)
env = PongWrapperHistory(env.envs[0])

model.learn(total_timesteps=TOTAL_TIMESTEPS)

In [8]:
import numpy as np
from collections import deque
from stable_baselines3.common.env_util import make_atari_env

class PongWrapperHistory(gym.ObservationWrapper):
    def __init__(self, env, history=4):
        super().__init__(env)
        self.history = history
        self.obs_shape = (84, 84)
        self.frames = deque(maxlen=self.history)

        flat_dim = np.prod(self.obs_shape) * history
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(flat_dim,), dtype=np.float32
        )

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        obs = obs[0, :, :, 0]
        processed = get_game_status(obs)
        self.frames.clear()
        for _ in range(self.history):
            self.frames.append(processed)
        return self._get_obs()

    def observation(self, obs):
        obs = obs[0, :, :, 0]
        processed = get_game_status(obs)
        self.frames.append(processed)
        return self._get_obs()

    def _get_obs(self):
        return np.concatenate([f.flatten() for f in self.frames]).astype(np.float32)


In [11]:

env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)
env = PongWrapperHistory(env)


AssertionError: Expected env to be a `gymnasium.Env` but got <class 'stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv'>