# Lab 6 - Problemy wieloagentowe

**Wykonali:** Eryk Mikołajek, Jakub Kubicki, Marcin Zub

In [4]:
import os

# Directory that holds the Atari ROM files; it is passed to the environment
# constructor as `auto_rom_install_path`. The path can be overridden with the
# AUTOROM_PATH environment variable so the notebook is not tied to a single
# machine; when the variable is unset the original location is used.
autorom_path = os.environ.get(
    'AUTOROM_PATH',
    '/Users/marcin/.pyenv/versions/3.10.13/lib/python3.10/site-packages/AutoROM/roms',
)

In [10]:
import os
from pathlib import Path

import supersuit as ss
from pettingzoo.atari import pong_v3
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import DummyVecEnv
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnvWrapper
from pettingzoo.utils.conversions import parallel_wrapper_fn
import numpy as np

### Środowisko i proces uczenia

In [None]:
# Fail fast when the ROM directory has not been provided: raising stops the
# cell with a clear error instead of printing and calling exit().
if not autorom_path:
    raise ValueError('Podaj ścieżkę do pliku autorom')

# Environment settings.
# The parallel-API environment is built directly. parallel_wrapper_fn() takes
# an environment *factory* and returns a new factory, so calling it on an
# environment instance produced a function rather than an environment.
env = pong_v3.parallel_env(
    num_players=2,
    max_cycles=1000,
    render_mode='human',
    auto_rom_install_path=Path(autorom_path),
)
env = ss.max_observation_v0(env, 2)
env = ss.sticky_actions_v0(env, repeat_action_probability=0.25)
env = ss.frame_skip_v0(env, 4)
env = ss.frame_stack_v1(env, 4)
# Convert the multi-agent environment into a vectorized environment that
# Stable-Baselines3 accepts (one sub-environment per agent, shared policy).
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 1, num_cpus=1, base_class='stable_baselines3')

#### Callback for recording rewards
class RewardCallback(BaseCallback):
    """Record the total reward of every episode that finishes during training.

    The rewards and done flags of each training step are read from
    ``self.locals``. Rewards are summed separately for every vectorized
    sub-environment, and a sum is moved to ``episode_rewards`` (and reset)
    when that sub-environment reports the end of an episode.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # One float per finished episode, in completion order.
        self.episode_rewards = []
        # Running reward sum per sub-environment; allocated on the first step
        # because the number of sub-environments is not known here.
        self._running_returns = None

    def _on_step(self) -> bool:
        """Accumulate this step's rewards and store finished-episode totals.

        Returns:
            Always ``True`` so that training continues.
        """
        step_rewards = np.asarray(self.locals["rewards"], dtype=float)
        step_dones = np.asarray(self.locals["dones"], dtype=bool)

        if self._running_returns is None:
            self._running_returns = np.zeros_like(step_rewards)
        self._running_returns += step_rewards

        # NOTE(review): with two agents of a zero-sum game the per-agent
        # totals presumably cancel when averaged together -- confirm which
        # agent's return the learning curve is meant to show.
        for env_idx in np.flatnonzero(step_dones):
            self.episode_rewards.append(float(self._running_returns[env_idx]))
            self._running_returns[env_idx] = 0.0
        return True

# Reward recorder; its `episode_rewards` list is read again by the plotting
# code after training.
callback = RewardCallback()

# PPO agent with the CNN policy, trained for 10000 timesteps while the
# callback is invoked by the training loop.
model = PPO(policy='CnnPolicy', env=env, verbose=1)
model.learn(callback=callback, total_timesteps=10000)

### Testowanie modelu

In [None]:
# Roll the trained model out for 1000 steps, rendering every step.
try:
    obs = env.reset()
    for _ in range(1000):
        # The second value returned by predict() (recurrent state) is unused.
        action, _state = model.predict(obs)
        obs, _rewards, _dones, _info = env.step(action)
        env.render()
finally:
    # Close the environment even when the loop raises or is interrupted, so
    # its resources are always released.
    env.close()

### Wykresy i krzywa uczenia

In [None]:
import matplotlib.pyplot as plt

# Average the recorded rewards over consecutive, non-overlapping windows of
# 100 entries; the last window may be shorter.
window = 100
recorded = callback.episode_rewards
mean_rewards = []
for start in range(0, len(recorded), window):
    mean_rewards.append(np.mean(recorded[start:start + window]))

# Draw the learning curve from the windowed means.
plt.plot(mean_rewards)
plt.title('Learning Curve')
plt.xlabel('Episodes (x100)')
plt.ylabel('Mean Total Reward')
plt.show()