In [1]:
import minigrid
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3.common.torch_layers import MinigridFeaturesExtractor
from stable_baselines3.common.buffers import *
from stable_baselines3 import PPO
from torch import nn
import gymnasium as gym

pygame 2.5.2 (SDL 2.28.2, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Extra arguments for PPO's policy: swap in MinigridFeaturesExtractor as the
# observation features extractor, asking it for a 128-dimensional output
# (`features_dim`).
policy_kwargs = {
    "features_extractor_class": MinigridFeaturesExtractor,
    "features_extractor_kwargs": {"features_dim": 128},
}

In [None]:
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.LLmodules.tfcl import Task_free_continual_learning

# Module-level Task_free_continual_learning object. The callback classes below
# invoke its methods (evaluate_hard_buffer, MAS_regularization, ...) through a
# `tfcl_instance` attribute.
tfcl_instance = Task_free_continual_learning()

class EvaluateHardBufferCallback(BaseCallback):
    """SB3 callback that calls TFCL ``evaluate_hard_buffer`` at the end of each rollout.

    Stable-Baselines3 ignores the return value of ``_on_rollout_end``, so the
    result is also kept in ``self.total_loss`` for later inspection.

    :param use_hard_buffer: forwarded to ``evaluate_hard_buffer``.
    :param hard_buffer: forwarded to ``evaluate_hard_buffer``.
    :param xh: forwarded to ``evaluate_hard_buffer``.
    :param yh: forwarded to ``evaluate_hard_buffer``.
    :param tfcl: object providing ``evaluate_hard_buffer``; defaults to the
        module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, use_hard_buffer: bool, hard_buffer, xh, yh, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned
        # (AttributeError at the first rollout end); bind it explicitly.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        self.use_hard_buffer = use_hard_buffer
        self.hard_buffer = hard_buffer
        self.xh = xh
        self.yh = yh
        # Result of the most recent evaluation (None until the first rollout ends).
        self.total_loss = None

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; without this override the
        # class cannot be instantiated. Returning True never aborts training.
        return True

    def _on_rollout_end(self):
        """Evaluate the hard buffer; store the loss in ``self.total_loss`` and return it."""
        self.total_loss = self.tfcl_instance.evaluate_hard_buffer(
            self.use_hard_buffer, self.hard_buffer, self.xh, self.yh
        )
        return self.total_loss

class MASRegularizationCallback(BaseCallback):
    """SB3 callback that calls TFCL ``MAS_regularization`` at the end of each rollout.

    Stable-Baselines3 ignores the return value of ``_on_rollout_end``, so the
    result is kept in ``self.regularized_loss``. It is deliberately not written
    back into ``self.total_loss``, so the configured input stays unchanged
    between rollouts.

    :param total_loss: forwarded to ``MAS_regularization``.
    :param continual_learning: forwarded to ``MAS_regularization``.
    :param star_varibles: forwarded to ``MAS_regularization`` (spelling kept
        for compatibility with keyword callers).
    :param tfcl: object providing ``MAS_regularization``; defaults to the
        module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, total_loss, continual_learning, star_varibles, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        self.total_loss = total_loss
        self.continual_learning = continual_learning
        self.star_varibles = star_varibles
        # Output of the most recent call (None until the first rollout ends).
        self.regularized_loss = None

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; required for instantiation.
        return True

    def _on_rollout_end(self):
        """Apply MAS regularization; store the result in ``self.regularized_loss`` and return it."""
        self.regularized_loss = self.tfcl_instance.MAS_regularization(
            self.total_loss, self.continual_learning, self.star_varibles
        )
        return self.regularized_loss

class SaveTrainingAccuracyCallback(BaseCallback):
    """SB3 callback that calls TFCL ``save_training_accuracy`` at the end of each rollout.

    Stable-Baselines3 ignores the return value of ``_on_rollout_end``, so the
    returned triple is also stored as ``self.msg``, ``self.losses`` and
    ``self.accuracy``.

    :param use_hard_buffer: forwarded to ``save_training_accuracy``.
    :param hard_buffer: forwarded to ``save_training_accuracy``.
    :param x: forwarded to ``save_training_accuracy``.
    :param y: forwarded to ``save_training_accuracy``.
    :param xh: forwarded to ``save_training_accuracy``.
    :param yh: forwarded to ``save_training_accuracy``.
    :param recent_loss: forwarded to ``save_training_accuracy``.
    :param hard_loss: forwarded to ``save_training_accuracy`` (default None).
    :param tfcl: object providing ``save_training_accuracy``; defaults to the
        module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, use_hard_buffer, hard_buffer, x, y, xh, yh, recent_loss, hard_loss=None, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        self.use_hard_buffer = use_hard_buffer
        self.hard_buffer = hard_buffer
        self.x = x
        self.y = y
        self.xh = xh
        self.yh = yh
        self.recent_loss = recent_loss
        self.hard_loss = hard_loss
        # Results of the most recent call (None until the first rollout ends).
        self.msg = None
        self.losses = None
        self.accuracy = None

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; required for instantiation.
        return True

    def _on_rollout_end(self) -> tuple:
        """Call ``save_training_accuracy``; store and return ``(msg, losses, accuracy)``."""
        self.msg, self.losses, self.accuracy = self.tfcl_instance.save_training_accuracy(
            self.use_hard_buffer, self.hard_buffer, self.x, self.y,
            self.xh, self.yh, self.recent_loss, self.hard_loss,
        )
        return self.msg, self.losses, self.accuracy
    
class update_loss_window_and_detect_peak_callback(BaseCallback):
    """SB3 callback that calls TFCL ``update_loss_window_and_detect_peak`` after each rollout.

    The class name is kept as-is (not PascalCase) for compatibility. Because
    Stable-Baselines3 ignores the return value of ``_on_rollout_end``, the
    returned triple is also stored as ``self.new_peak_detected``,
    ``self.loss_window_mean`` and ``self.loss_window_variance``.

    :param first_train_loss: stored as ``self.first_train_loss``.
        NOTE(review): it is never passed to ``update_loss_window_and_detect_peak``;
        confirm whether that method should receive it.
    :param tfcl: object providing ``update_loss_window_and_detect_peak``;
        defaults to the module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, first_train_loss, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        self.first_train_loss = first_train_loss
        # Results of the most recent call (None until the first rollout ends).
        self.new_peak_detected = None
        self.loss_window_mean = None
        self.loss_window_variance = None

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; required for instantiation.
        return True

    def _on_rollout_end(self) -> tuple:
        """Update the loss window; store and return ``(new_peak_detected, mean, variance)``."""
        (
            self.new_peak_detected,
            self.loss_window_mean,
            self.loss_window_variance,
        ) = self.tfcl_instance.update_loss_window_and_detect_peak()
        return self.new_peak_detected, self.loss_window_mean, self.loss_window_variance
    
class UpdateImportanceWeightsCallback(BaseCallback):
    """SB3 callback that calls TFCL ``update_importance_weights_and_variables`` after each rollout.

    Stable-Baselines3 ignores the return value of ``_on_rollout_end``, so the
    value returned by the TFCL method is also stored in ``self.dict_val``.

    NOTE(review): ``continual_learning``, ``new_peak_detected``,
    ``loss_window_mean``, ``loss_window_variance`` and ``hard_buffer`` are
    stored but never passed to ``update_importance_weights_and_variables``;
    confirm whether that method should receive them.

    :param tfcl: object providing ``update_importance_weights_and_variables``;
        defaults to the module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, continual_learning, new_peak_detected, loss_window_mean, loss_window_variance, hard_buffer, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        self.continual_learning = continual_learning
        self.new_peak_detected = new_peak_detected
        self.loss_window_mean = loss_window_mean
        self.loss_window_variance = loss_window_variance
        self.hard_buffer = hard_buffer
        # Result of the most recent call (None until the first rollout ends).
        self.dict_val = None

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; required for instantiation.
        return True

    def _on_rollout_end(self) -> tuple:
        """Update importance weights; store the result in ``self.dict_val`` and return ``(True, dict_val)``."""
        self.dict_val = self.tfcl_instance.update_importance_weights_and_variables()
        return True, self.dict_val

class UpdateHardBufferCallback(BaseCallback):
    """SB3 callback that calls TFCL ``update_hard_buffer`` at the end of each rollout.

    NOTE(review): ``use_hard_buffer``, ``recent_loss``, ``hard_loss``, ``xt``
    and ``yt`` are stored but never passed to ``update_hard_buffer``; confirm
    whether that method should receive them.

    :param tfcl: object providing ``update_hard_buffer``; defaults to the
        module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, use_hard_buffer, recent_loss, hard_loss, xt, yt, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        self.use_hard_buffer = use_hard_buffer
        self.recent_loss = recent_loss
        self.hard_loss = hard_loss
        self.xt = xt
        self.yt = yt

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; required for instantiation.
        return True

    def _on_rollout_end(self) -> bool:
        """Update the hard buffer and return True."""
        self.tfcl_instance.update_hard_buffer()
        return True

class EvaluateTestAccuracyCallback(BaseCallback):
    """SB3 callback that calls TFCL ``evaluate_test_accuracy`` at the end of each rollout.

    Stable-Baselines3 ignores the return value of ``_on_rollout_end``, so the
    returned pair is also stored as ``self.msg`` and ``self.test_loss``.

    :param tfcl: object providing ``evaluate_test_accuracy``; defaults to the
        module-level ``tfcl_instance`` (keyword-only).
    :param args: forwarded to ``BaseCallback`` (e.g. ``verbose``).
    :param kwargs: forwarded to ``BaseCallback``.
    """

    def __init__(self, *args, tfcl=None, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.tfcl_instance` used to be read without ever being assigned.
        self.tfcl_instance = tfcl if tfcl is not None else tfcl_instance
        # Results of the most recent call (None until the first rollout ends).
        self.msg = None
        self.test_loss = None

    def _on_step(self) -> bool:
        # BaseCallback declares `_on_step` abstract; required for instantiation.
        return True

    def _on_rollout_end(self) -> tuple:
        """Evaluate test accuracy; store and return ``(msg, test_loss)``."""
        self.msg, self.test_loss = self.tfcl_instance.evaluate_test_accuracy()
        return self.msg, self.test_loss


In [None]:
# Plain container, presumably meant for passing values such as total_loss
# between callbacks.
class SharedData:
    """Wrap a single mutable ``data`` dict so several objects can share it."""

    def __init__(self):
        # Each instance gets its own fresh dict; two SharedData objects never alias.
        self.data = dict()


In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList

# Intended to wire the TFCL callbacks into one CallbackList for PPO.
# FIXME(review): this cell is a non-runnable template (it fails under
# Restart & Run All); see the notes below.

# Initialize shared data
shared_data = SharedData()

# Initialize callbacks with shared data
# FIXME(review): the constructors take specific positional inputs.
# EvaluateHardBufferCallback needs (use_hard_buffer, hard_buffer, xh, yh);
# MASRegularizationCallback needs (total_loss, continual_learning, star_varibles).
# Here they instead receive `shared_data` plus a literal `...` (the Ellipsis
# object, not a placeholder). Construction therefore raises TypeError. Pass
# the real inputs before running this cell.
callback1 = EvaluateHardBufferCallback(shared_data, ...)
callback2 = MASRegularizationCallback(shared_data, ...)
# Add more callbacks as needed

# Combine callbacks into a CallbackList
# FIXME(review): the trailing `...` is an Ellipsis object inside the list.
# SB3's CallbackList expects every element to be a callback, so remove it.
# Note also that `callback_list` is never passed to `model.learn(callback=...)`.
callback_list = CallbackList([callback1, callback2, ...])


In [3]:
# MiniGrid 16x16 empty-room task, wrapped by ImgObsWrapper so that only the
# image part of the observation reaches the CNN policy.
env = ImgObsWrapper(
    gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
)

In [4]:
# Training budget and seed as named constants so a reader can tune them.
# SB3 types `total_timesteps` as int, so use an int literal (not the float 2e5).
TOTAL_TIMESTEPS = 200_000
# Passing `seed` makes PPO seed Python/NumPy/PyTorch and the env, so reruns are
# repeatable (bit-exact determinism on CUDA is still not guaranteed).
SEED = 42

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1, seed=SEED)
# Trailing `;` suppresses the `<PPO at 0x...>` repr that `learn` returns.
model.learn(total_timesteps=TOTAL_TIMESTEPS);

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.02e+03 |
|    ep_rew_mean     | 0        |
| time/              |          |
|    fps             | 408      |
|    iterations      | 1        |
|    time_elapsed    | 5        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.02e+03    |
|    ep_rew_mean          | 0           |
| time/                   |             |
|    fps                  | 227         |
|    iterations           | 2           |
|    time_elapsed         | 17          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.012974362 |
|    clip_fraction        | 0.0774      |
|    clip_range     

<stable_baselines3.ppo.ppo.PPO at 0x7f023394d310>