In [1]:
import os
import gym
import numpy as np
import cv2
from gym.spaces import Box
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback
from stable_baselines3.common.utils import set_random_seed

# 1. DIRECTORIES AND PATHS
CHECKPOINT_DIR = './train/ttlivelli'  # where periodic and flag-triggered models are saved
LOG_DIR = './logs/'                  # log directory (created below, not otherwise used here)
# Previously trained model to resume from (absolute, machine-specific path).
LOAD_MODEL_PATH = r"C:\Users\matte\Desktop\marioia\mario_stabile\train\PP02\mario_model_VITTORIA.zip"

# Make sure both output directories exist; existing ones are left alone.
for _output_dir in (CHECKPOINT_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# --- CUSTOM CALLBACK THAT SAVES THE MODEL WHEN THE FLAG IS REACHED ---
class SaveOnFlagCallback(BaseCallback):
    """Save the model whenever any environment reports ``info['flag_get']``.

    The model is written to ``<save_path>/mario_model_VITTORIA`` (``.zip`` is
    appended by ``model.save``); every new flag overwrites the previous file.

    Args:
        check_freq: inspect the step infos only every ``check_freq`` calls to
            ``_on_step`` (1 = every step). Values below 1 are treated as 1.
        save_path: directory the model file is written into.
        verbose: print progress messages when > 0.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        # Guard against 0/negative values, which would break the modulo below.
        self.check_freq = max(1, int(check_freq))
        self.save_path = save_path

    def _on_step(self) -> bool:
        # Honour check_freq (previously stored but never used).
        if self.n_calls % self.check_freq != 0:
            return True

        # self.locals['infos'] holds one info dict per vectorized environment.
        infos = self.locals.get('infos', [])
        if not any(info.get('flag_get') for info in infos):
            return True

        # Save once per step even if several environments hit the flag at the
        # same time: they would all overwrite the same file anyway.
        save_file = os.path.join(self.save_path, "mario_model_VITTORIA")
        if self.verbose > 0:
            print("\n--- !!! BANDIERA RAGGIUNTA !!! ---")
            print("Salvataggio modello speciale in corso...")
        self.model.save(save_file)
        if self.verbose > 0:
            print(f"Modello salvato come: {save_file}.zip")
        return True

# --- ENVIRONMENT CLASS ---
class MarioModernized(gym.Env):
    """Super Mario Bros environment with frame skipping and shaped rewards.

    Wraps ``SuperMarioBros-v0`` with the ``SIMPLE_MOVEMENT`` action set. It
    returns one 84x84x1 uint8 grayscale frame per step; frame stacking is left
    to the caller (e.g. ``VecFrameStack``). It uses the classic gym API:
    ``reset()`` returns only the observation and ``step()`` returns
    ``(obs, reward, done, info)``.

    Reward per agent step (before the final division by 10):
      * the emulator rewards summed over the skipped frames,
      * plus the amount by which ``x_pos`` exceeds the best position reached so
        far in the current stage,
      * -50 when the episode ends without the flag,
      * +500 when ``info['flag_get']`` is set,
      * -20, with ``done`` forced, after more than 150 steps without new
        progress.

    Args:
        render: refresh the emulator window and the OpenCV 4-frame preview on
            every step. Defaults to True (the original behaviour); pass False
            to train headless and faster.
        skip_frames: emulator frames executed per agent action (default 4,
            must be >= 1).
    """

    def __init__(self, render=True, skip_frames=4):
        super().__init__()
        # Validate before creating the emulator so nothing leaks on error.
        if skip_frames < 1:
            raise ValueError("skip_frames must be >= 1")
        inner_env = gym_super_mario_bros.make('SuperMarioBros-v0')
        self.mario = JoypadSpace(inner_env, SIMPLE_MOVEMENT)
        self._closed = False
        self.observation_space = Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
        self.action_space = self.mario.action_space

        self.render_enabled = render
        self.skip_frames = skip_frames
        self.max_x = 0             # best x_pos reached in the current stage
        self.stagnant_steps = 0    # consecutive steps without beating max_x
        self.current_stage = None  # (world, stage) seen on the previous step
        self.view_stack = self._blank_view_stack()

    @staticmethod
    def _blank_view_stack():
        """Return four independent black 84x84 frames for the preview strip."""
        return [np.zeros((84, 84), dtype=np.uint8) for _ in range(4)]

    def _process_frame(self, frame):
        """Convert an RGB emulator frame to an (84, 84, 1) uint8 grayscale observation.

        The frame is also pushed into the 4-frame preview strip. When rendering
        is enabled, the strip is shown enlarged 2x in an OpenCV window. A
        ``None`` frame yields an all-black observation.
        """
        if frame is None:
            return np.zeros((84, 84, 1), dtype=np.uint8)

        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)

        self.view_stack.pop(0)
        self.view_stack.append(resized)
        if self.render_enabled:
            film_strip = np.hstack(self.view_stack)
            big_strip = cv2.resize(film_strip, (84 * 4 * 2, 84 * 2), interpolation=cv2.INTER_NEAREST)
            cv2.imshow("Memoria IA (4 Frame)", big_strip)
            cv2.waitKey(1)

        return resized[:, :, np.newaxis]

    def reset(self, seed=None, options=None):
        """Reset the game and the shaping state; return the first observation.

        ``seed`` and ``options`` are accepted for API compatibility but ignored.
        """
        self.max_x = 0
        self.stagnant_steps = 0
        self.current_stage = None
        self.view_stack = self._blank_view_stack()
        obs = self.mario.reset()
        if self.render_enabled:
            self.mario.render()
        return self._process_frame(obs)

    def step(self, action):
        """Repeat ``action`` for ``skip_frames`` frames and return the shaped transition."""
        total_reward = 0
        for _ in range(self.skip_frames):
            obs, reward, done, info = self.mario.step(action)
            total_reward += reward
            if done:
                break

        if self.render_enabled:
            self.mario.render()

        # x_pos may be a NumPy integer built from RAM bytes; use a Python int
        # so the arithmetic below cannot wrap around.
        current_x = int(info['x_pos'])

        # In the multi-stage 'SuperMarioBros-v0' game the episode can continue
        # into the next level, where x restarts near the beginning. On a level
        # change, measure progress from the new position; otherwise the
        # previous level's max_x would withhold progress reward and trigger
        # the stagnation cut-off.
        stage = (info.get('world'), info.get('stage'))
        if self.current_stage is not None and stage != self.current_stage:
            self.max_x = current_x
            self.stagnant_steps = 0
        elif current_x > self.max_x:
            total_reward += current_x - self.max_x
            self.max_x = current_x
            self.stagnant_steps = 0
        else:
            self.stagnant_steps += 1
        self.current_stage = stage

        # flag_get may be a NumPy bool (np.False_ is not False), so test
        # truthiness instead of identity with False.
        flag = bool(info['flag_get'])
        if done and not flag:
            total_reward -= 50
        if flag:
            total_reward += 500

        if self.stagnant_steps > 150:
            total_reward -= 20
            done = True

        obs = self._process_frame(obs)
        return obs, float(total_reward / 10.0), done, info

    def close(self):
        """Close the wrapped NES emulator and the preview window (safe to call twice)."""
        if getattr(self, "_closed", True):
            return
        self._closed = True
        if self.render_enabled:
            cv2.destroyAllWindows()
        self.mario.close()

if __name__ == "__main__":
    SEED = 42
    set_random_seed(SEED)

    # Check for the model before creating the emulator, so nothing is left
    # open on failure. SystemExit (unlike the site helper `exit()`) works in
    # scripts and notebooks alike and reports a non-zero status.
    if not os.path.exists(LOAD_MODEL_PATH):
        print("ERRORE: Modello non trovato.")
        raise SystemExit(1)

    env = DummyVecEnv([lambda: MarioModernized()])
    env.seed(SEED)
    env = VecFrameStack(env, n_stack=4, channels_order='last')

    # 1. Standard callback: save a checkpoint every 100k steps
    checkpoint_callback = CheckpointCallback(
        save_freq=100000,
        save_path=CHECKPOINT_DIR,
        name_prefix='mario_model',
    )

    # 2. Special callback: save whenever Mario reaches the flag
    flag_callback = SaveOnFlagCallback(
        check_freq=1,
        save_path=CHECKPOINT_DIR,
    )

    # LOADING: resume from the previously trained model
    print(f"Caricamento modello da: {LOAD_MODEL_PATH}")
    model = PPO.load(LOAD_MODEL_PATH, env=env)

    print("Ripresa allenamento. Salvataggio automatico attivo e monitoraggio bandiera ON.")

    try:
        try:
            # Pass BOTH callbacks as a list; keep the loaded model's timestep counter.
            model.learn(
                total_timesteps=1000000,
                callback=[checkpoint_callback, flag_callback],
                reset_num_timesteps=False,
            )
        except KeyboardInterrupt:
            pass  # manual stop: fall through and save what has been learned
        # Save the final model after a manual stop AND after normal completion
        # (previously a completed run was never saved under this name).
        print("Salvataggio modello finale...")
        model.save("mario_model_final_v2")
    finally:
        cv2.destroyAllWindows()
        env.close()  # releases the NES emulator(s) held by the vectorized env

Caricamento modello da: C:\Users\matte\Desktop\marioia\mario_stabile\train\PP02\mario_model_VITTORIA.zip
Wrapping the env in a VecTransposeImage.
Ripresa allenamento. Salvataggio automatico attivo e monitoraggio bandiera ON.
Logging to ./logs/PPO_8


  return (self.ram[0x86] - self.ram[0x071c]) % 256



--- !!! BANDIERA RAGGIUNTA !!! ---
Salvataggio modello speciale in corso...
Modello salvato come: ./train/ttlivelli\mario_model_VITTORIA.zip
--------------------------------
| time/              |         |
|    fps             | 56      |
|    iterations      | 1       |
|    time_elapsed    | 36      |
|    total_timesteps | 2443772 |
--------------------------------

--- !!! BANDIERA RAGGIUNTA !!! ---
Salvataggio modello speciale in corso...
Modello salvato come: ./train/ttlivelli\mario_model_VITTORIA.zip
----------------------------------------
| time/                   |            |
|    fps                  | 56         |
|    iterations           | 2          |
|    time_elapsed         | 73         |
|    total_timesteps      | 2445820    |
| train/                  |            |
|    approx_kl            | 0.08334087 |
|    clip_fraction        | 0.252      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.535     |
|    explained_variance   | 0.728 