In [1]:
import gymnasium as gym
import numpy as np


class Swimmer:
    """Black-box objective built on the Gymnasium ``Swimmer-v4`` environment.

    A candidate vector ``x`` is reshaped into a matrix of shape
    ``policy_shape`` = (action dim, observation dim) and used as a linear
    policy: the action at every step is ``M @ obs``.  Calling the instance
    runs ``num_rollouts`` episodes and returns the *negated* mean episode
    return, so smaller values correspond to better policies.
    """

    def __init__(
        self,
        seed=None,
        num_rollouts: int = 3,
        render: bool = False,
    ):
        """Create the environment and derive the search-space description.

        Args:
            seed: Base seed for environment resets, or ``None`` for unseeded
                resets.  Rollout ``i`` of every call is reset with
                ``seed + i``.
            num_rollouts: Number of episodes averaged per evaluation (>= 1).
            render: If true, the environment is created with
                ``render_mode="human"``; otherwise no render mode is set.

        Raises:
            ValueError: If ``num_rollouts`` is smaller than 1.
        """
        # Validate before touching the environment: with zero rollouts the
        # mean of an empty list would silently evaluate to NaN.
        if num_rollouts < 1:
            raise ValueError(f"num_rollouts must be >= 1, got {num_rollouts}")

        render_mode = "human" if render else None
        self.env = gym.make(
            "Swimmer-v4",
            render_mode=render_mode,
        )
        # Linear policy matrix shape: (action dim, observation dim).
        self.policy_shape = (
            self.env.action_space.shape[0],
            self.env.observation_space.shape[0],
        )
        # Number of entries in the flattened policy matrix.
        self.dim = int(np.prod(self.policy_shape, dtype=int))
        self.seed = seed
        # Number of times the objective has been called.
        self.counter = 0
        self.num_rollouts = num_rollouts

        # Box bounds enforced on every candidate vector in __call__.
        self.lb = -1 * np.ones(self.dim)
        self.ub = 1 * np.ones(self.dim)

    def _rollout(self, policy, seed):
        """Run a single episode with the linear ``policy`` and return its total reward."""
        obs, _ = self.env.reset(seed=seed)
        terminated = False
        truncated = False
        total_reward = 0.0

        while not (terminated or truncated):
            action = np.dot(policy, obs)
            obs, reward, terminated, truncated, _ = self.env.step(action)
            total_reward += reward
            if self.env.render_mode is not None:
                self.env.render()

        return total_reward

    def __call__(self, x):
        """Evaluate the flattened linear policy ``x``.

        Args:
            x: One-dimensional array-like of length ``self.dim`` whose entries
                lie within ``[self.lb, self.ub]``.

        Returns:
            The negated mean of the total rewards over ``self.num_rollouts``
            episodes.

        Raises:
            AssertionError: If ``x`` has the wrong rank, length, or violates
                the bounds.
        """
        self.counter += 1
        # Accept plain sequences as well as arrays.
        x = np.asarray(x, dtype=float)
        assert x.ndim == 1
        assert len(x) == self.dim
        assert np.all(x <= self.ub) and np.all(x >= self.lb)

        policy = x.reshape(self.policy_shape)

        returns = []
        for rollout in range(self.num_rollouts):
            # Re-using one identical seed for every reset would make all
            # rollouts exact copies of each other; offset it per rollout so
            # the average is meaningful while each call stays deterministic.
            rollout_seed = None if self.seed is None else self.seed + rollout
            returns.append(self._rollout(policy, rollout_seed))

        return np.mean(returns) * -1

    def close(self):
        """Close the underlying environment, releasing its resources."""
        self.env.close()

In [19]:
# Build the objective with rendering enabled (Swimmer maps render=True to render_mode="human").
swim = Swimmer(render=True)

In [13]:
# Display the linear policy's matrix shape: (action dim, observation dim).
swim.policy_shape

(2, 8)

In [10]:
# Hand-written 16-entry vector; 16 equals the product of the (2, 8) policy shape
# printed in this notebook — presumably a candidate input for `swim`, TODO confirm.
a = np.array(
    [
        0.3, 0.4, 0.5, 0.2,
        0.2, 0.4, 0.5, 0.6,
        0.7, 0.8, 0.1, 0.2,
        0.3, 0.4, 0.5, 0.6,
    ]
)

In [21]:
# Display the per-dimension lower bounds that Swimmer.__call__ enforces on candidate vectors.
swim.lb

array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1.])