In [None]:
import sys

import gymnasium as gym
import numpy as np
import wandb
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from wandb.integration.sb3 import WandbCallback

sys.path.append("/home/martina/codi2/4year/tfg")  # add parent folder of general.py

from general import prepare, GlioblastomaPositionalEncoding, testing

In [None]:
# Open a W&B run; SB3's TensorBoard output (written under runs/<run.id>) is
# mirrored into it because sync_tensorboard is enabled.
run = wandb.init(
    sync_tensorboard=True,
    config=dict(envs=8, algo="PPO"),
    project="TFG-glioblastoma-ppo",
)

In [None]:
class DatasetWrapper(gym.Wrapper):
    """Gymnasium wrapper that switches to a random image/mask pair on every reset.

    On each ``reset`` the inner ``GlioblastomaPositionalEncoding`` env is
    rebuilt from a randomly chosen (image, mask) pair, so one worker cycles
    through the whole dataset over many episodes.

    Args:
        image_paths: Sequence of image paths (length N).
        mask_paths: Sequence of mask paths aligned index-by-index with
            ``image_paths`` (length N).
        **env_kwargs: Keyword arguments forwarded to every
            ``GlioblastomaPositionalEncoding`` this wrapper constructs.

    Raises:
        ValueError: If the two path lists differ in length or are empty.
    """

    def __init__(self, image_paths, mask_paths, **env_kwargs):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                "image_paths and mask_paths must have the same length "
                f"({len(image_paths)} != {len(mask_paths)})"
            )
        if len(image_paths) == 0:
            raise ValueError("DatasetWrapper needs at least one image/mask pair")

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.n = len(image_paths)
        self.env_kwargs = env_kwargs  # arguments to pass to inner env
        # Wrapper-local RNG used to pick the pair; re-seeded by reset(seed=...)
        # so episode-to-image assignment is reproducible when a seed is given.
        self._pair_rng = np.random.default_rng()

        # Build an env up-front so observation/action spaces exist before the
        # first reset (the wrapper delegates its spaces to the inner env).
        super().__init__(self._build_env(0))

    def _build_env(self, idx):
        """Construct a fresh inner env for the pair at position ``idx``."""
        return GlioblastomaPositionalEncoding(
            self.image_paths[idx],
            self.mask_paths[idx],
            **self.env_kwargs
        )

    def reset(self, **kwargs):
        """Pick a random pair, rebuild the inner env, and reset it.

        All keyword arguments (e.g. ``seed``, ``options``) are forwarded to the
        new inner env's ``reset``; a non-None ``seed`` also re-seeds the
        pair-selection RNG.
        """
        seed = kwargs.get("seed")
        if seed is not None:
            self._pair_rng = np.random.default_rng(seed)
        idx = int(self._pair_rng.integers(self.n))

        # Release the previous inner env before replacing it.
        self.env.close()
        self.env = self._build_env(idx)
        return self.env.reset(**kwargs)

    def step(self, action):
        """Delegate the step to the current inner env."""
        return self.env.step(action)


In [None]:
# Build N_ENVS training workers. Each worker samples a random image/mask pair
# from the training split at every episode reset (see DatasetWrapper).
N_ENVS = 8  # keep in sync with the "envs" value logged to W&B

train_pairs = prepare(dataset=200)
assert len(train_pairs) > 0, "prepare(dataset=200) returned no training pairs"

# Each pair is indexed as (image_path, mask_path).
image_paths = [p[0] for p in train_pairs]
mask_paths = [p[1] for p in train_pairs]


def make_env():
    """Return a thunk that builds one Monitor-wrapped DatasetWrapper env.

    DummyVecEnv takes callables, not env instances. The env kwargs are created
    inside the thunk, so every worker gets its own rewards list and
    action-space object.
    """
    def _init():
        env = DatasetWrapper(
            image_paths=image_paths,
            mask_paths=mask_paths,
            grid_size=6,
            tumor_threshold=0.01,
            rewards=[100.0, -10.0, 0.5, -0.1],
            action_space=spaces.Discrete(5),
            max_steps=50,
        )
        # Monitor records per-episode reward/length for SB3's logger.
        return Monitor(env)
    return _init


env_fns = [make_env() for _ in range(N_ENVS)]
env = DummyVecEnv(env_fns)

# Rollout buffer per update = n_steps * N_ENVS = 256 * 8 = 2048 transitions,
# i.e. 2 minibatches of batch_size=1024.
model = PPO(
    "CnnPolicy",
    env,
    verbose=2,
    n_steps=256,
    batch_size=1024,
    tensorboard_log=f"runs/{run.id}",
    # Observations reach the CNN as-is (SB3 would otherwise divide by 255).
    policy_kwargs={"normalize_images": False},
)

In [None]:
# Train the agent. WandbCallback logs gradients every 100 steps and saves the
# model under models/<run.id>/ every 10 000 steps.
model.learn(
    total_timesteps=1_000_000,
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_freq=10000,
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)

# Close the W&B run so the final metrics and synced TensorBoard logs are flushed.
run.finish()

# Testing

In [None]:
# Held-out evaluation split (the original note says this loads 50 pairs).
test_pairs = prepare(mode="test")

In [None]:
# Load a trained checkpoint and evaluate it on the held-out test pairs.
# NOTE: machine-specific absolute path. Point it at the model.zip saved by
# WandbCallback (models/<run id>/model.zip) for the run you want to evaluate.
model_path = "/home/martina/codi2/4year/tfg/ppo/models/601mft6k/model.zip" # <--- UPDATE THIS PATH
loaded_model = PPO.load(model_path)

# Mirror the env kwargs used during training, so the agent is evaluated with
# the same tumor threshold and episode horizon it was trained with.
# Presumably `testing` forwards these to GlioblastomaPositionalEncoding — TODO confirm.
config = {
    'grid_size': 6,
    'tumor_threshold': 0.01,
    'rewards': [100.0, -10.0, 0.5, -0.1],
    'action_space': spaces.Discrete(5),
    'max_steps': 50,
}
test_results = testing(
    agent=loaded_model,
    test_pairs=test_pairs,
    agent_type="ppo",
    num_episodes=100,
    env_config=config,
    save_gifs=True,
    gif_folder="TEST_RESULTS_GIFS"
)