In [None]:
import gymnasium as gym
from gymnasium.spaces import Dict, Box, MultiDiscrete
import numpy as np
import heapq
import torch
import torch.nn as nn
from flask import Flask, request
import threading
import time
from stable_baselines3 import PPO
from stable_baselines3.common.buffers import DictRolloutBuffer
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import MultiInputActorCriticPolicy
import cv2

# Custom DummyVecEnv
class CustomDummyVecEnv(DummyVecEnv):
    """DummyVecEnv whose ``reset`` forwards ``seed``/``options`` to each env.

    Unlike the base class, ``reset`` returns an ``(observations, info)`` pair
    so callers can pass an explicit initial state through ``options``.
    """

    def reset(self, seed=None, options=None):
        """Reset every wrapped env and return the stacked observations.

        Args:
            seed: Seed forwarded unchanged to each wrapped env's ``reset``.
            options: Options dict forwarded unchanged to each wrapped env.

        Returns:
            A tuple ``(obs, info)``. ``obs`` maps each observation key to an
            array with one row per wrapped env, cast to that key's space
            dtype. ``info`` is the info dict of the last env that was reset
            (an empty dict when there are no envs).
        """
        # Collect into a local dict: ``self.buf_obs`` holds the base class's
        # preallocated arrays and must not be replaced by Python lists.
        collected = {key: [] for key in self.observation_space.spaces}
        info = {}
        for wrapped_env in self.envs:
            obs, info = wrapped_env.reset(seed=seed, options=options)
            for key, rows in collected.items():
                rows.append(obs[key])
        stacked = {
            key: np.array(rows, dtype=self.observation_space[key].dtype)
            for key, rows in collected.items()
        }
        return stacked, info

# Flask app and the module-level state shared between the request handler,
# the environment and the PPO setup code.
app = Flask(__name__)
# Heap of (-timestamp, payload) tuples pushed by receive_data; the negated
# timestamp makes heapq pop the most recently received payload first.
data_heap = []
# Guards data_heap.
heap_lock = threading.Lock()
# Last predicted action as {"action": int, "weight": float}; set by receive_data.
current_action = None
# Guards current_action.
action_lock = threading.Lock()
# Observation the last prediction was made from; set by receive_data.
current_obs = None
# Guards current_obs.
obs_lock = threading.Lock()
# Populated by initialize_ppo().
rollout_buffer = None
model = None
env = None
# Number of environment steps processed so far by receive_data.
step_counter = 0
# receive_data saves the model once step_counter reaches this value.
total_steps = 100000
# Rollout length: buffer size and number of steps between policy updates.
n_steps = 2048
# True when the next request should predict an action, False when it should
# be treated as the result of the previously predicted action.
is_first_request = True

# Simulator request handler
@app.route('/simulator_data', methods=['POST'])
def receive_data():
    """Handle one simulator POST, alternating between "predict" and "step".

    Every payload is pushed onto ``data_heap``. Requests are then processed
    in pairs:

    * First request of a pair: the newest payload is popped, turned into an
      observation, and an action is sampled from the policy. The sampled
      action, value and log-probability are remembered for the second phase
      and the decoded action is published in ``current_action``.
    * Second request of a pair: the payload is left on the heap so that the
      environment's ``step`` can consume it as the outcome of the remembered
      action. The transition is stored in ``rollout_buffer``; every
      ``n_steps`` transitions the policy is updated, and once
      ``total_steps`` is reached the model is saved.

    Returns:
        A ``(json_dict, 200)`` tuple whose ``"status"`` describes what was done.

    Raises:
        TimeoutError: Propagated from the environment's ``step`` when no
            result payload is available.
    """
    global step_counter, current_action, current_obs, is_first_request
    data = request.json
    timestamp = time.time()
    with heap_lock:
        # Negated timestamp: heapq is a min-heap, so the newest pops first.
        heapq.heappush(data_heap, (-timestamp, data))
        print(f"Data received at: {timestamp}, Heap size: {len(data_heap)}")

    if model is None or env is None:
        return {"status": "Model not initialized"}, 200

    if is_first_request:
        # First request: predict an action from the newest payload.
        with heap_lock:
            if not data_heap:
                return {"status": "No data"}, 200
            _, latest_data = heapq.heappop(data_heap)

        obs = {
            "image": np.array(latest_data["image"], dtype=np.uint8),
            "sensors": np.array(latest_data["sensors"], dtype=np.float32)
        }
        with obs_lock:
            current_obs = obs

        # The policy operates on batched tensors, not raw numpy arrays.
        obs_tensor, _ = model.policy.obs_to_tensor(obs)
        with torch.no_grad():
            action, value, log_prob = model.policy(obs_tensor, deterministic=False)
        action_np = action.cpu().numpy()

        with action_lock:
            action_idx, weight_idx = action_np[0]
            weight = model.env.envs[0].weight_bins[weight_idx]
            current_action = {"action": int(action_idx), "weight": float(weight)}

        # Remember exactly what was sampled; re-sampling in the step phase
        # would record a different action than the one published above.
        receive_data._pending_sample = (action_np, value, log_prob)
        is_first_request = False
        return {"status": "Action predicted"}, 200

    # Second request: the payload pushed above is the result of the pending
    # action. It stays on the heap so that step() below can pop it.
    with obs_lock:
        obs = current_obs
    action_np, value, log_prob = receive_data._pending_sample
    # The very first transition, and the one after each reset, starts an episode.
    episode_start = getattr(receive_data, "_episode_start", True)

    # Step the wrapped env directly: a VecEnv's step() returns a 4-tuple and
    # resets finished envs on its own, which does not fit the explicit
    # reset-with-options below.
    raw_env = getattr(env, "envs", [env])[0]
    new_obs, reward, terminated, truncated, info = raw_env.step(action_np[0])
    done = bool(terminated or truncated)

    # Store the transition (positional: obs, action, reward, episode_start,
    # value, log_prob).
    rollout_buffer.add(
        obs,
        action_np,
        np.array([reward]),
        np.array([episode_start]),
        value,
        log_prob,
    )
    receive_data._episode_start = done
    step_counter += 1

    # Policy update
    if step_counter % n_steps == 0:
        new_obs_tensor, _ = model.policy.obs_to_tensor(new_obs)
        with torch.no_grad():
            next_value = model.policy.predict_values(new_obs_tensor)
        rollout_buffer.compute_returns_and_advantage(last_values=next_value, dones=np.array([done]))
        # train() reads from model.rollout_buffer, so point it at the buffer
        # that was just filled.
        model.rollout_buffer = rollout_buffer
        model.train()
        # train() leaves the policy in training mode; switch back for inference.
        model.policy.set_training_mode(False)
        rollout_buffer.reset()

    # The next request always starts a new predict/step pair.
    is_first_request = True

    # End of training
    if step_counter >= total_steps:
        model.save("ppo_custom_model")
        print("Learning completed")
        return {"status": "Learning completed"}, 200

    # Episode reset
    if done:
        obs, _ = env.reset(options={
            "image": new_obs["image"],
            "sensor_data": new_obs["sensors"].tolist()
        })
        with obs_lock:
            current_obs = obs

    return {"status": "Step processed"}, 200

# Custom environment
class CustomEnv(gym.Env):
    """Gymnasium environment fed by simulator payloads queued in ``data_heap``.

    Observations are dicts with an ``"image"`` array and a ``"sensors"``
    vector. The action space is ``MultiDiscrete([4, 10])``; ``weight_bins``
    holds the ten weight values that the second action component indexes.
    """

    def __init__(self, simulator_url="http://localhost:5000"):
        super().__init__()
        self.observation_space = Dict({
            "image": Box(low=0, high=255, shape=(128, 128, 2), dtype=np.uint8),
            "sensors": Box(low=-np.inf, high=np.inf, shape=(9,), dtype=np.float32)
        })
        self.action_space = MultiDiscrete([4, 10])
        self.simulator_url = simulator_url
        self.max_steps = 1000
        self.step_count = 0
        self.weight_bins = np.linspace(0.0, 0.9, 10)
        self.render_mode = None

    def _pop_latest_data(self, timeout, poll_interval=0.25):
        """Pop the newest payload from ``data_heap``, polling until ``timeout``.

        The lock is held only while inspecting the heap, never while
        sleeping, so producers can push data during the wait.

        Returns:
            The payload, or ``None`` if nothing arrived within ``timeout``
            seconds.
        """
        deadline = time.monotonic() + timeout
        while True:
            with heap_lock:
                if data_heap:
                    return heapq.heappop(data_heap)[1]
            if time.monotonic() >= deadline:
                return None
            time.sleep(poll_interval)

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Optional dict with ``"image"`` and ``"sensor_data"``;
                when given, it is used as the initial observation instead of
                waiting for simulator data.

        Raises:
            TimeoutError: If no simulator payload arrives within one second.
        """
        super().reset(seed=seed)
        self.step_count = 0
        if options:
            # Initial state supplied explicitly through options.
            image = np.array(options["image"], dtype=np.uint8)
            sensors = np.array(options["sensor_data"], dtype=np.float32)
            return {"image": image, "sensors": sensors}, {}

        # Wait up to 1 second. heap_lock is a plain (non-reentrant) Lock, so
        # it must not be re-acquired while already held.
        data = self._pop_latest_data(timeout=1.0)
        if data is None:
            raise TimeoutError("No initial data from simulator")
        image = np.array(data["image"], dtype=np.uint8)
        sensors = np.array(data["sensors"], dtype=np.float32)
        return {"image": image, "sensors": sensors}, {}

    def step(self, action):
        """Consume the next simulator payload as the outcome of ``action``.

        ``action`` itself is not used here; the observation, reward and
        termination flag all come from the payload.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``info["step_time"]`` is the time spent in this call.

        Raises:
            TimeoutError: If no payload arrives within 0.5 seconds.
        """
        start_time = time.time()
        # Wait for the result payload (up to 0.5 seconds).
        data = self._pop_latest_data(timeout=0.5)
        if data is None:
            raise TimeoutError("No result data from simulator")
        # .get(): a payload without a timestamp must not break a log line.
        print(f"Step data at: {data.get('timestamp')}")

        image = np.array(data["image"], dtype=np.uint8)
        sensors = np.array(data["sensors"], dtype=np.float32)
        reward = float(data["reward"])
        self.step_count += 1
        # Reaching max_steps is reported as termination; truncated stays False.
        terminated = bool(self.step_count >= self.max_steps or data.get("terminated", False))
        truncated = False
        info = {"step_time": time.time() - start_time}
        return {"image": image, "sensors": sensors}, reward, terminated, truncated, info

    def render(self):
        """Show the first channel of an image in an OpenCV window.

        Returns:
            The displayed image when ``render_mode`` is ``"human"``,
            otherwise ``None``.
        """
        if self.render_mode != "human":
            return None
        # NOTE(review): this displays a random sample from the image space,
        # not the latest observation.
        image = self.observation_space["image"].sample()
        cv2.imshow("Environment", image[:, :, 0])
        cv2.waitKey(1)
        return image

    def close(self):
        """Close any OpenCV windows opened by ``render``."""
        cv2.destroyAllWindows()

# Custom feature extractor
class CustomFeaturesExtractor(BaseFeaturesExtractor):
    """Encode the dict observation into a single feature vector.

    The ``"image"`` entry goes through a three-layer CNN, the ``"sensors"``
    entry through a two-layer MLP; both encodings are concatenated and
    projected to ``features_dim``.
    """

    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        # Image branch: 2 input channels, flattened at the end.
        conv_stack = [
            nn.Conv2d(2, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push a dummy 128x128 image through the CNN to learn its output width.
        with torch.no_grad():
            probe = torch.zeros(1, 2, 128, 128)
            cnn_width = self.cnn(probe).shape[1]

        # Sensor branch: 9 inputs -> 64 features.
        sensor_stack = [
            nn.Linear(9, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        ]
        self.mlp = nn.Sequential(*sensor_stack)

        # Fusion head over the concatenated image and sensor features.
        fusion_stack = [
            nn.Linear(cnn_width + 64, features_dim),
            nn.ReLU(),
        ]
        self.linear = nn.Sequential(*fusion_stack)

    def forward(self, observations):
        """Return the fused features for a batch of observations."""
        # Move channels in front of the spatial axes and scale by 1/255.
        pixels = observations["image"].permute(0, 3, 1, 2).float()
        visual = self.cnn(pixels / 255.0)
        sensed = self.mlp(observations["sensors"])
        fused = torch.cat((visual, sensed), dim=1)
        return self.linear(fused)

# PPO initialization
def initialize_ppo():
    """Create the environment, the PPO model and the rollout buffer.

    Populates the module globals ``env`` (a ``CustomDummyVecEnv`` wrapping a
    single ``CustomEnv``), ``model`` and ``rollout_buffer``. The buffer is
    attached to the model so that ``model.train()`` consumes the same
    transitions that are written into ``rollout_buffer``.
    """
    # Local import keeps the new dependency inside this block.
    from stable_baselines3.common.logger import configure

    global model, env, rollout_buffer
    # Bind the raw env to a local name: a lambda closing over the global
    # ``env`` would see the vectorized wrapper once ``env`` is rebound below.
    base_env = CustomEnv()
    env = CustomDummyVecEnv([lambda: base_env])
    model = PPO(
        policy=MultiInputActorCriticPolicy,
        env=env,
        policy_kwargs={"features_extractor_class": CustomFeaturesExtractor},
        learning_rate=3e-4,
        n_steps=n_steps,
        batch_size=64,
        n_epochs=10,
        verbose=1,
    )
    rollout_buffer = DictRolloutBuffer(
        buffer_size=n_steps,
        observation_space=env.observation_space,
        action_space=env.action_space,
        # Sample on the model's device so minibatches match the policy's.
        device=model.device,
        gae_lambda=0.95,
        gamma=0.99,
        n_envs=1,
    )
    # train() reads from model.rollout_buffer; share the externally filled one.
    model.rollout_buffer = rollout_buffer
    # The logger is normally created by learn(); train() is called directly
    # in this file, so configure one up front.
    model.set_logger(configure(None, ["stdout"]))

# Run the Flask server
if __name__ == "__main__":
    # Build env, model and rollout buffer before any request is served;
    # receive_data answers "Model not initialized" until they exist.
    initialize_ppo()
    # Listen on all interfaces, port 5000.
    app.run(host="0.0.0.0", port=5000, threaded=True)

In [7]:
# Scratch check: snap a value down to the nearest lower multiple of 0.05.
a = 0.184069
# Floor division yields the number of whole 0.05 steps in `a`, as a float.
b = a // 0.05

# Scale back to the original units; rounding to 2 decimals hides float noise.
round(b * 0.05,2)

0.15