# Train BeeWorld with TD3 model

### Preparation for Colab

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive; the output
# paths used by later cells ("drive/MyDrive/...") live on that mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the RL dependencies, fetch the project repository, and move its
# bee.py module into the working directory.
!pip install gymnasium
!pip install stable_baselines3
!git clone https://github.com/alTeska/rl-bee-multimodal-sensing.git
# NOTE(review): only bee.py is moved here, yet a later cell imports `train` and
# `utils` — confirm those modules are importable from the working directory.
!mv rl-bee-multimodal-sensing/bee.py ./

## Load and setup

In [None]:
import os

import gymnasium as gym
from stable_baselines3 import TD3
from stable_baselines3.common.vec_env import DummyVecEnv
from torch import nn

from train import init_gym, init_model, setup_logging
from utils import create_directory, set_device

# Result of the project helper `set_device()`; presumably the compute device
# (CPU/GPU) to train on — TODO confirm against utils.py. Not referenced again
# in the cells visible here.
DEVICE = set_device()

In [None]:
# Path to a YAML config file on the mounted Drive.
# NOTE(review): "model-blablabla" looks like a placeholder, and this variable
# is not used by any of the cells visible here — confirm before relying on it.
config_path = "/content/drive/model-blablabla/config.yaml"

In [None]:
# Run configuration shared by the setup and training cells.
gym_name = "BeeWorld"  # environment id handed to init_gym
base_path = "drive/MyDrive/neuromatch/"  # root for models/, logs/ and replay_buffer/
model_algo = "TD3"  # names the checkpoint sub-folder and checkpoint files
timesteps = 10_000  # timesteps trained per iteration of the training loop
iters_max = 10  # number of iterations the training loop runs
learning_rate = 1e-2
# Passed to init_model together with the learning rate.
policy_kwargs = dict(net_arch=[100, 100], activation_fn=nn.ReLU)

## Setup model and environment

In [None]:
# Output locations, all derived from base_path.
models_path = f"{base_path}models/"
logs_path = f"{base_path}logs/"
replay_buffer_path = f"{base_path}replay_buffer/"
best_model_save_path = f"{models_path}{model_algo}"

# Hand each output location to the project's create_directory helper, in the
# same order in which the paths are defined above.
for _output_dir in (
    models_path,
    logs_path,
    replay_buffer_path,
    best_model_save_path,
):
    create_directory(_output_dir)

# Build the environment, the logging callback/logger pair, and the model
# through the project helpers imported from `train`.
env = init_gym(gym_name, logs_path)
callback, logger = setup_logging(env, logs_path, best_model_save_path)
model = init_model(env, policy_kwargs, learning_rate, logger=logger)

# Reset the environment the model is attached to and keep the first observation.
vec_env = model.get_env()
obs = vec_env.reset()

## Training loop (+ save model at each iteration)

In [None]:
# Train in `iters_max` rounds of `timesteps` steps each, saving the model and
# its replay buffer after every round. A round whose model checkpoint already
# exists on disk is loaded instead of being trained again.
model_dir = models_path + model_algo + "/"
replay_buffer_dir = replay_buffer_path + model_algo + "/"
# The setup cell creates replay_buffer/ but not its per-algorithm sub-folder.
os.makedirs(replay_buffer_dir, exist_ok=True)

iters = 0
for iters in range(1, iters_max + 1):
    # Checkpoints are named after the cumulative number of timesteps.
    model_name = model_algo + "_" + str(timesteps * iters)
    model_path = model_dir + model_name
    # Built from the fixed base directory on every pass. Previously
    # `replay_buffer_path` was reassigned from itself here, so the path grew
    # longer with each iteration and buffers ended up in garbled locations.
    cur_replay_buffer_path = replay_buffer_dir + model_name

    cur_model_zip_path = model_path + ".zip"

    if os.path.exists(cur_model_zip_path):
        # A checkpoint for this stage already exists: load it.
        # TODO: loading rebinds `model`, so anything learned earlier in this
        # run is replaced by the checkpoint on disk — confirm this is intended.
        print("Loading this model:", cur_model_zip_path)
        model = TD3.load(cur_model_zip_path)
        # NOTE(review): the reloaded model gets a freshly made environment, not
        # the `env` built by init_gym in the setup cell — verify that nothing
        # init_gym adds to the environment is needed here.
        model.set_env(
            DummyVecEnv([lambda: gym.make(gym_name, render_mode="rgb_array")])
        )
        try:
            model.load_replay_buffer(cur_replay_buffer_path)
        except FileNotFoundError:
            # Buffers written before the path fix above may live elsewhere;
            # continue with the model's current buffer rather than aborting.
            print("No replay buffer found at:", cur_replay_buffer_path)
    else:
        # No checkpoint for this stage yet: train, then save model and buffer.
        model.learn(
            total_timesteps=timesteps,
            reset_num_timesteps=False,  # keep the timestep counter running across rounds
            callback=callback,
        )
        model.save(model_path)
        model.save_replay_buffer(cur_replay_buffer_path)

env.close()

In [None]:
# Load the TensorBoard notebook extension and open it on the logs of one run.
# NOTE(review): the log directory is hard-coded to a specific run under
# "2023Neuromatch", which differs from the `base_path`
# ("drive/MyDrive/neuromatch/") used by the training cells — confirm it points
# at the logs you want to inspect.
%load_ext tensorboard
%tensorboard --logdir 'drive/MyDrive/2023Neuromatch/logs/TD3LR_0p01_100_100_20230719_114053/' --port=80