In [None]:
import os
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecMonitor
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, BaseCallback

# Callback that makes the training console output easier to follow.
class EpisodeLoggerCallback(BaseCallback):
    """Print the reward and length of every episode that finishes during training.

    Episode statistics are read from the ``"episode"`` entry that a Monitor
    wrapper puts into an env's ``info`` dict when an episode ends
    (``'r'`` = episode reward, ``'l'`` = episode length).
    """

    def __init__(self, verbose=0):
        # __init__ must only set up state: it is called once, before training,
        # when ``self.locals`` is still empty, and it must return None.
        super().__init__(verbose)

    def _on_step(self) -> bool:
        """Log finished episodes after each environment step.

        Returns:
            Always True, so training continues (SB3 stops training when a
            callback's ``_on_step`` returns False).
        """
        infos = self.locals.get("infos")
        if infos is None:
            return True
        for info in infos:
            if info is not None and 'episode' in info:
                # 'r' is the episode reward, 'l' is the episode length.
                episode_reward = info['episode']['r']
                episode_length = info['episode']['l']
                print(f"Zakończono epizod. Nagroda: {episode_reward:.2f}, Długość: {episode_length}")
        return True


def main():
    """Fine-tune a pretrained PPO agent on ``BoxingNoFrameskip-v4``.

    Loads the model stored at ``logs/best_model/best_model.zip``, overrides its
    learning rate and clip range, and continues training for
    ``total_timesteps_finetune`` additional timesteps. During training it
    saves periodic checkpoints, evaluates the agent and keeps the best model,
    and prints one line per finished training episode. All artifacts are
    written under ``./logs_finetuned_better``.

    Prints a message and returns early (before creating any environment) when
    the pretrained model file does not exist.
    """
    env_id = 'BoxingNoFrameskip-v4'
    num_envs = 8  # number of parallel training environments
    seed = 42
    total_timesteps_finetune = 2_000_000

    # Location of the model we continue training from.
    pretrained_model_path = 'logs/best_model/best_model.zip'

    # Check for the model first: returning after the environments were built
    # would leave them open.
    if not os.path.exists(pretrained_model_path):
        print("Nie ma modelu")
        return

    print("Model istnieje.")

    # Output directories for monitor CSVs, checkpoints (model + optimizer
    # state), the best model, and TensorBoard logs.
    finetune_log_dir_base = './logs_finetuned_better'
    monitor_csv_path = os.path.join(finetune_log_dir_base, 'monitor_csv_logs')
    finetune_checkpoint_dir = os.path.join(finetune_log_dir_base, 'checkpoints')
    finetune_best_model_dir = os.path.join(finetune_log_dir_base, 'best_model')
    tensorboard_log_path = os.path.join(finetune_log_dir_base, 'tensorboard_logs')
    finetune_final_model_path = os.path.join(finetune_log_dir_base, 'ppo_boxing_finetuned_final_model.zip')
    for directory in (finetune_log_dir_base, monitor_csv_path, finetune_checkpoint_dir,
                      finetune_best_model_dir, tensorboard_log_path):
        os.makedirs(directory, exist_ok=True)

    # Every parallel environment gets its own Monitor (one "<rank>.monitor.csv"
    # per env). make_atari_env inserts the Monitor beneath the Atari wrappers,
    # so it records the real (unclipped) game score. An extra outer VecMonitor
    # would overwrite those stats with sums of clipped rewards.
    env = make_atari_env(env_id, n_envs=num_envs, seed=seed, monitor_dir=monitor_csv_path)
    env = VecFrameStack(env, n_stack=4)

    eval_env = None
    try:
        # Evaluation environment; a different seed so evaluation episodes
        # differ from the training ones.
        eval_env = make_atari_env(env_id, n_envs=1, seed=seed + 100)
        eval_env = VecFrameStack(eval_env, n_stack=4)

        custom_hyperparameters = {
            "learning_rate": 5e-5,
            "clip_range": 0.1,
            # Uncommenting any of these also overrides the saved value on load.
            # "n_steps": 2048,
            # "batch_size": 64,
            # "ent_coef": 0.01,
        }

        # custom_objects replaces the matching attributes stored in the zip;
        # scalar values are turned into constant schedules when the model is set up.
        model = PPO.load(
            pretrained_model_path,
            env=env,
            custom_objects=dict(custom_hyperparameters),
            tensorboard_log=tensorboard_log_path,
            device='auto',
            verbose=1
        )

        # clip_range may be stored as a schedule (callable of remaining
        # progress); evaluate it so a number is printed, not a function repr.
        clip_range = model.clip_range(1.0) if callable(model.clip_range) else model.clip_range
        print(f"Model wczytany. Liczba kroków przed dostrajaniem: {model.num_timesteps}")
        print(f"Rozpoczynanie dostrajania. Nowy learning rate: {model.learning_rate}, Clip range: {clip_range}")

        # Frequencies are counted in calls per environment, so divide the
        # desired total-timestep interval by the number of parallel envs.
        checkpoint_save_freq = max(1, 100_000 // num_envs)
        checkpoint_callback_finetune = CheckpointCallback(
            save_freq=checkpoint_save_freq,
            save_path=finetune_checkpoint_dir,
            name_prefix='ppo_boxing_finetuned_ckpt'
        )

        # Periodic evaluation; the best-scoring model is saved.
        eval_freq = max(1, 25_000 // num_envs)
        eval_callback_finetune = EvalCallback(
            eval_env,
            best_model_save_path=finetune_best_model_dir,
            log_path=os.path.join(finetune_log_dir_base, 'eval_logs'),
            eval_freq=eval_freq,
            n_eval_episodes=5,
            deterministic=True,
            render=False
        )

        episode_logger_callback = EpisodeLoggerCallback(verbose=1)
        print(f"Rozpoczynanie dostrajania na {total_timesteps_finetune} dodatkowych kroków.")

        # reset_num_timesteps=False continues the step counter (and TensorBoard
        # curves) from the pretrained model instead of starting at 0.
        model.learn(
            total_timesteps=total_timesteps_finetune,
            callback=[checkpoint_callback_finetune, eval_callback_finetune, episode_logger_callback],
            reset_num_timesteps=False
        )

        # Save the final model after fine-tuning.
        model.save(finetune_final_model_path)
        print("Zakończono proces dostrajania.")
    finally:
        # Release the emulator environments even if loading/training fails.
        env.close()
        if eval_env is not None:
            eval_env.close()
