In [1]:
import gymnasium as gym
from gymnasium.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete

import numpy as np
import random
import os

from stable_baselines3 import PPO
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_checker import check_env

In [None]:
from cost_calculator import CostCalculator

In [None]:
# 64-bit integer dtype shared by the notebook.
# NOTE(review): presumably passed as the `dtype` argument of the operator
# classes defined further down — confirm at the call site.
dtype = np.int64

In [None]:
# Problem dimensions. The literal matrices defined in this notebook have shapes
# that line up with these values (T: n x m, D: m x m, C: n x n, E: l x n).
# NOTE(review): the notebook does not say what each dimension represents —
# TODO document the domain meaning.
l = 1
n = 5
m = 4
# NOTE(review): appears to be the column count of the m x k matrix A built by
# the operator classes — confirm.
k = 3

In [2]:
# Large constant (one million). It is not referenced by any cell visible here.
# NOTE(review): presumably a "big-M" style penalty value — confirm where it is used.
M = 10 ** 6

In [None]:
# 5 x 4 float matrix; every row sums to 1.
# NOTE(review): presumably the `T` argument handed to CostCalculatorFactory,
# with rows indexed by n and columns by m — confirm.
T = np.array([
    [  1,   0,   0,   0],
    [0.5, 0.5,   0,   0],
    [  0, 0.5, 0.5,   0],
    [  0,   0, 0.5, 0.5],
    [  0,   0,   0,   1]
])

In [None]:
# 4 x 4 matrix of 0/1 flags; only entries (0, 2), (1, 2) and (1, 3) are set.
# NOTE(review): presumably the `D` argument handed to CostCalculatorFactory
# (m x m) — confirm what a set entry means.
D = np.array([
    [0, 0, 1, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0]
])

In [None]:
# 5 x 5 matrix filled with ones.
# NOTE(review): presumably the `C` argument handed to CostCalculatorFactory
# (n x n) — confirm.
C = np.array([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1]
])

In [None]:
# 1 x 5 row matrix with a single 1 in the first position.
# NOTE(review): presumably the `E` argument handed to CostCalculatorFactory
# (l x n) — confirm.
E = np.array([
    [1, 0, 0, 0, 0]
])

In [3]:
# MultiBinary

In [None]:
class CostCalculatorFactory:
    """Builds ``CostCalculator`` objects that share the same T, D, C and E inputs.

    The four inputs are stored once at construction time; only the fifth
    argument of ``CostCalculator`` (``A``) varies between calls.
    """

    def __init__(self, T, D, C, E):
        """Remember the four inputs that every created calculator receives."""
        self.T, self.D, self.C, self.E = T, D, C, E

    def get_cost_calculator(self, A):
        """Return a new ``CostCalculator`` for ``A`` combined with the stored inputs."""
        return CostCalculator(self.T, self.D, self.C, self.E, A)

In [None]:
class MultiBinaryOperator:
    """Action/observation spaces and reward for a flat vector of ``m * k`` bits.

    Attributes:
        m, k: dimensions used to reshape a flat vector into an ``m x k`` matrix.
        cost_calculator_factory: object exposing ``get_cost_calculator(A)``.
        dtype: dtype of the observation space.
        action_space: ``MultiBinary(m * k)``.
        observation_space: ``Box`` in [0, 1] with ``m * k`` entries.
    """

    def __init__(self, m, k, cost_calculator_factory, dtype):
        """Store the dimensions and factory, then build both spaces."""
        self.m, self.k = m, k
        self.cost_calculator_factory = cost_calculator_factory
        self.dtype = dtype

        n_bits = m * k
        self.action_space = MultiBinary(n_bits)
        self.observation_space = Box(low=0, high=1, shape=(n_bits,), dtype=dtype)

    def calculate_reward(self, array):
        """Reshape ``array`` to ``(m, k)`` and return the calculator's result for it.

        The value returned by ``calculate()`` is passed through unchanged.
        """
        matrix = array.reshape(self.m, self.k)
        calculator = self.cost_calculator_factory.get_cost_calculator(matrix)
        return calculator.calculate()

In [None]:
class MultiDiscreteOperator:
    """Action/observation spaces and reward for ``m`` discrete choices out of ``k``.

    An action is a vector of ``m`` integers, each in ``[0, k)``. For the reward
    it is expanded into an ``m x k`` one-hot matrix and handed to a cost
    calculator.

    Attributes:
        m, k: number of choices and number of options per choice.
        cost_calculator_factory: object exposing ``get_cost_calculator(A)``.
        dtype: dtype of the observation space.
        action_space: ``MultiDiscrete([k] * m)``.
        observation_space: ``Box`` in [0, m] with ``k`` entries.
    """

    def __init__(self, m, k, cost_calculator_factory, dtype):
        """Store the dimensions and factory, then build both spaces."""
        self.m = m
        self.k = k
        self.cost_calculator_factory = cost_calculator_factory
        self.dtype = dtype

        self.action_space = MultiDiscrete([k for i in range(m)])
        # NOTE(review): actions have m entries with values below k, whereas this
        # observation space has k entries with values up to m — confirm that the
        # observation is not meant to mirror the action layout.
        self.observation_space = Box(low=0, high=m, shape=(k, ), dtype=self.dtype)

    def calculate_reward(self, array):
        """Return the calculator's result for the one-hot expansion of ``array``.

        Args:
            array: indexable of at least ``m`` integers, each in ``[0, k)``.

        Returns:
            Whatever ``calculate()`` of the created cost calculator returns.
        """
        A = self.__convert_A(array)
        cost_calculator = self.cost_calculator_factory.get_cost_calculator(A)
        cost = cost_calculator.calculate()
        return cost

    def __convert_A(self, array):
        """Expand ``array`` into an ``m x k`` matrix with one 1 per row.

        Row ``i`` has its 1 in column ``array[i]``. The matrix is created with
        ``np.zeros`` and therefore has NumPy's default float dtype.
        """
        A = np.zeros((self.m, self.k))
        # Use the instance's own m; the notebook-level `m` may be undefined or
        # hold a different value than the one this operator was built with.
        for i in range(self.m):
            active_gen = array[i]
            A[i][active_gen] = 1
        return A

In [9]:
class ShowerEnv(gym.Env):
    """Toy environment whose observation is simply the most recent action.

    The action is a vector of ``count`` integers, each in ``[0, count)``. The
    reward is ``1 / (sum(action) + 0.1)``, so smaller sums earn more reward.
    ``step`` never sets ``terminated`` or ``truncated``, so an episode only
    ends when the caller stops stepping.
    """

    def __init__(self):
        super().__init__()
        self.dtype = np.int64
        self.count = 4
        self.action_space = MultiDiscrete([self.count for i in range(self.count)])
        # For MultiBinary(self.count) the upper bound would be high = 1.
        self.observation_space = Box(low=0, high=self.count, shape=(self.count, ), dtype=self.dtype)
        self.default_observation = np.zeros(self.count, dtype=self.dtype)
        # Copy so that the current observation never aliases the default array.
        self.observation = self.default_observation.copy()
        self.info = {}

    def init_observation(self):
        """Set the current observation to a fresh all-zeros vector."""
        self.observation = np.zeros(self.count, dtype=self.dtype)

    def step(self, action):
        """Store ``action`` as the new observation and score it.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where both
            flags are always ``False``.
        """
        # Fitness
        self.observation = np.array(action).astype(self.dtype)

        reward = 1 / (np.sum(self.observation) + 0.1)

        truncated = False
        terminated = False

        return self.observation, float(reward), terminated, truncated, self.info

    def reset(self, seed=None, options=None):
        """Reset to the default observation and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset`` so that the environment's
        random generator is seeded as the Gymnasium API expects.
        """
        super().reset(seed=seed)
        # Hand out a copy so callers cannot modify the stored default in place.
        self.observation = self.default_observation.copy()
        return (self.observation, self.info)

In [10]:
# Create the single environment instance used for checking and training below.
env = ShowerEnv()

In [11]:
# Run the Stable-Baselines3 environment checker on the instance before training.
check_env(env)

In [12]:
# policies: MlpPolicy
# algorithms: A2C, PPO

In [13]:
# Build an A2C agent with an MLP policy. The seed is fixed so that training
# gives the same result on every Restart & Run All.
model = A2C("MlpPolicy", env, seed=0)

In [14]:
# Train the agent for 1000 environment steps.
model.learn(total_timesteps=1000)

<stable_baselines3.a2c.a2c.A2C at 0x18f30d8dbe0>

In [15]:
# Show the last observation stored on the underlying ShowerEnv. `unwrapped`
# reaches the base environment through any number of wrappers, instead of
# assuming exactly one `.env` hop.
model.env.envs[0].unwrapped.observation

array([3, 2, 3, 0])