In [1]:
from sb3_contrib import TRPO
from Client_diff_Emb import FRLClient
from Agent import SB3Agent
import copy
from stable_baselines3.common.evaluation import evaluate_policy
import torch
import numpy as np


from H_Envs.pendulum import PendulumEnv
from H_Envs.pendulum_emb import PendulumEnvEmb
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
from MBEnvs.mb_pendulum_emb import MB_PendulumEnv
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

In [2]:
def train(env_paras, device, save_dir, timesteps_real_per_round = 500, timesteps_fc_per_round = 20000, epoch_per_round = 100, rounds_num = 30, batch_size_env_model = 128, max_episode_steps = 200, n_eval_episodes = 20):
    """Run the round-based client/server training loop and log evaluation rewards.

    Each round, every client fits its prediction model, the collected models are
    placed in a fresh ``MB_PendulumEnv``, the shared TRPO policy is reloaded from
    its checkpoint and trained on that env, and the policy is finally evaluated
    on one ``PendulumEnvEmb`` per client.

    Args:
        env_paras: Sequence with one value per client. Each value is passed as
            ``g`` to ``PendulumEnv`` / ``PendulumEnvEmb`` and, wrapped in a
            one-element array, used as that client's embedding.
        device: Forwarded to ``MB_PendulumEnv`` and ``FRLClient``.
        save_dir: Path prefix; the policy checkpoint is written to
            ``save_dir + "/models/tmp"`` and reloaded from there every round.
        timesteps_real_per_round: First positional argument of ``FRLClient.learn``.
        timesteps_fc_per_round: ``total_timesteps`` of the TRPO update per round.
        epoch_per_round: Second positional argument of ``FRLClient.learn``.
        rounds_num: Number of rounds to run.
        batch_size_env_model: Third positional argument of ``FRLClient.learn``.
        max_episode_steps: ``TimeLimit`` applied to every environment built here
            (default 200, the value that used to be hard-coded).
        n_eval_episodes: Episodes per client passed to ``evaluate_policy``
            (default 20, the value that used to be hard-coded).

    Returns:
        A list with one entry per round; each entry is a list holding the mean
        evaluation reward of every client, in ``env_paras`` order.
    """

    def _time_limited(env):
        # Single place that applies the episode cap to any environment.
        return TimeLimit(env, max_episode_steps=max_episode_steps)

    model_tmp_path = save_dir + "/models/tmp"

    clients_num = len(env_paras)
    embs = [np.array([x]) for x in env_paras]

    # This list is handed to MB_PendulumEnv while still empty and filled in place
    # by the client loop below, so it must not be rebound before that loop ends.
    env_models = []
    MB_env = _time_limited(MB_PendulumEnv(env_models, embs, device))

    Global_RL = TRPO("MlpPolicy", MB_env, verbose=1)

    real_envs = []
    eva_envs = []
    Clients = []
    for i in range(clients_num):
        real_envs.append(_time_limited(PendulumEnv(g=env_paras[i])))
        eva_envs.append(_time_limited(PendulumEnvEmb(g=env_paras[i])))
        agent = SB3Agent(Global_RL)
        client = FRLClient(real_envs[i], agent, lr = 3e-4, hidden_size = 256, device = device, emb=embs[i])
        Clients.append(client)
        env_models.append(copy.deepcopy(client.model))

    # NOTE(review): SB3 wraps MB_env in a vectorized env, so this attribute is set
    # on that wrapper; confirm whether the underlying MB_PendulumEnv is the
    # intended target.
    Global_RL.env.models = env_models

    # Initial checkpoint: every round starts by reloading the policy from disk.
    Global_RL.save(model_tmp_path)

    rewards_log = []
    for round_idx in range(rounds_num):
        print('------------------------------')
        print("round: " + str(round_idx))

        # Client phase: push the current policy, fit each prediction model.
        env_models = []
        for client_idx, client in enumerate(Clients):
            print('------------------------------')
            print("client: " + str(client_idx))
            client.agent.policy_net = Global_RL
            client.learn(timesteps_real_per_round, epoch_per_round, batch_size_env_model)
            env_models.append(client.get_prediction_model())

        # Server phase: train the policy on an env built from this round's models.
        MB_env = _time_limited(MB_PendulumEnv(env_models, embs, device))
        Global_RL = TRPO.load(model_tmp_path, env = MB_env)
        Global_RL.learn(total_timesteps=timesteps_fc_per_round)
        Global_RL.save(model_tmp_path)

        # Evaluation phase: one mean reward per client evaluation env.
        round_reward = []
        for eva_env in eva_envs:
            mean_reward, _ = evaluate_policy(Global_RL, eva_env, n_eval_episodes=n_eval_episodes)
            round_reward.append(mean_reward)
        rewards_log.append(round_reward)
        print("mean_reward in real env:" + str(round_reward))

    return rewards_log

In [6]:
# Experiment configuration: per-round budgets, client environment parameters, device, and output directory.
# Per-round budgets.
timesteps_real_per_round = 1000                          # passed to each client's learn()
timesteps_fc_per_round = 30 * timesteps_real_per_round   # total_timesteps of the policy update
epoch_per_round = 10
batch_size_env_model = 128
rounds_num = 50

# One entry per client; used both as the env's `g` and as the client's embedding.
env_paras = [7.0] * 3 + [10.0] * 3 + [13.0] * 3

# Output location (the policy checkpoint is stored underneath save_dir).
test_dir = "TestEmb"
save_dir = test_dir

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# One-round, notebook-level run of the training loop. Everything is kept as a
# kernel global (Clients, eva_envs, MB_env, Global_RL, ...) so that the inspection
# cells that follow can look at the resulting objects.
rounds_num = 1  # deliberately overrides the configured number of rounds
model_tmp_path = save_dir + "/models/tmp"

CLIENTS_NUM = len(env_paras)
embs = [np.array([para]) for para in env_paras]

# The (still empty) list is shared with MB_PendulumEnv and filled in place below.
env_models = []
MB_env = TimeLimit(MB_PendulumEnv(env_models, embs, device), max_episode_steps=200)
Global_RL = TRPO("MlpPolicy", MB_env, verbose=1)

# Build one real env, one evaluation env and one client per entry of env_paras.
real_envs, eva_envs, Clients = [], [], []
for i in range(CLIENTS_NUM):
    real_env = TimeLimit(PendulumEnv(g=env_paras[i]), max_episode_steps=200)
    eva_env = TimeLimit(PendulumEnvEmb(g=env_paras[i]), max_episode_steps=200)
    real_envs.append(real_env)
    eva_envs.append(eva_env)

    policy_net = Global_RL
    agent = SB3Agent(policy_net)
    client = FRLClient(
        real_env,
        agent,
        lr=3e-4,
        hidden_size=256,
        device=device,
        emb=embs[i],
    )
    Clients.append(client)

    env_model = copy.deepcopy(client.model)
    env_models.append(env_model)

Global_RL.env.models = env_models

# Checkpoint the freshly created policy; each round reloads it from this path.
Global_RL.save(model_tmp_path)

rewards_log = []

env_models = []
for round_idx in range(rounds_num):
    print('------------------------------')
    print("round: " + str(round_idx))

    # Client phase: hand out the current policy and refit every prediction model.
    env_models = []
    for client_idx, client in enumerate(Clients):
        print('------------------------------')
        print("client: " + str(client_idx))
        client.agent.policy_net = Global_RL
        client.learn(timesteps_real_per_round, epoch_per_round, batch_size_env_model)
        env_model = client.get_prediction_model()
        env_models.append(env_model)

    # Server phase: reload the policy on an env built from this round's models,
    # train it, and write the checkpoint back.
    MB_env = TimeLimit(MB_PendulumEnv(env_models, embs, device), max_episode_steps=200)
    Global_RL = TRPO.load(model_tmp_path, env=MB_env)
    Global_RL.learn(total_timesteps=timesteps_fc_per_round)
    Global_RL.save(model_tmp_path)

    # Evaluation phase: mean reward of the updated policy on every client's eval env.
    round_reward = []
    for client_idx, eva_env in enumerate(eva_envs):
        mean_reward, std_reward = evaluate_policy(Global_RL, eva_env, n_eval_episodes=20)
        round_reward.append(mean_reward)
    rewards_log.append(round_reward)
    print("mean_reward in real env:" + str(round_reward))

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




------------------------------
round: 0
------------------------------
client: 0
Avg loss: 0.059428532811192175!
Avg loss: 0.03293476407258519!
Avg loss: 0.023983481511628876!
Avg loss: 0.018789916601684428!
Avg loss: 0.018249193139757456!
Avg loss: 0.017707207884929327!
Avg loss: 0.017114631159972003!
Avg loss: 0.016514584048236428!
Avg loss: 0.01635221907910515!
Avg loss: 0.01634298132273519!
------------------------------
client: 1
Avg loss: 0.053227601621765645!
Avg loss: 0.024563304519542726!
Avg loss: 0.011105171824795737!
Avg loss: 0.005387888080846703!
Avg loss: 0.0024734056072702516!
Avg loss: 0.002442291295725075!
Avg loss: 0.002377772452573481!
Avg loss: 0.001721375538521291!
Avg loss: 0.001491951504874578!
Avg loss: 0.0012957964165555798!
------------------------------
client: 2
Avg loss: 0.05578530824499467!
Avg loss: 0.028420251820546884!
Avg loss: 0.01836653086715766!
Avg loss: 0.014014349858007336!
Avg loss: 0.013225900448912096!
Avg loss: 0.013822014175117752!
Avg loss

-----------------------------------------
| rollout/                  |           |
|    ep_len_mean            | 200       |
|    ep_rew_mean            | -1.16e+03 |
| time/                     |           |
|    fps                    | 304       |
|    iterations             | 8         |
|    time_elapsed           | 53        |
|    total_timesteps        | 16384     |
| train/                    |           |
|    explained_variance     | 1.5e-05   |
|    is_line_search_success | 1         |
|    kl_divergence_loss     | 0.00732   |
|    learning_rate          | 0.001     |
|    n_updates              | 7         |
|    policy_objective       | 0.00768   |
|    std                    | 0.955     |
|    value_loss             | 6.93e+03  |
-----------------------------------------
-----------------------------------------
| rollout/                  |           |
|    ep_len_mean            | 200       |
|    ep_rew_mean            | -1.15e+03 |
| time/                     |     



mean_reward in real env:[-1219.391234071972, -1181.3146120980382, -1166.449192996323, -1174.0073819140903, -1159.3624025997008, -1157.299137605878, -1058.3642891133204, -1158.4174241259693, -1080.9330974409356]


In [56]:
# Reset the model-based env left over from the last training round and display
# the (observation, info) pair it returns.
MB_env.reset()

(array([ 0.87111247, -0.49108353,  0.805978  , 13.        ], dtype=float32),
 {})

In [74]:
# Take a single step in the model-based env with a randomly sampled action and
# display the returned 5-tuple.
MB_env.step(MB_env.action_space.sample())

(array([ 0.30767384,  0.7691297 , -3.5322733 , 13.        ], dtype=float32),
 -3.186168909072876,
 False,
 False,
 {})

In [87]:
# Reset the evaluation env of client 3 (env_paras[3] == 10.0) and display the
# (observation, info) pair.
eva_envs[3].reset()

(array([-0.12367555, -0.9923227 ,  0.9276148 , 10.        ], dtype=float32),
 {})

In [89]:
# Show the embedding client 0 was constructed with (np.array([env_paras[0]])).
Clients[0].emb

array([7.])

In [90]:
# Display client 0's `model` attribute (the object that was deep-copied into
# env_models when the clients were created).
Clients[0].model

PredictionModel(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=3, bias=True)
  )
)

In [92]:
# Display client 0's `dataset_X` attribute.
# NOTE(review): presumably the inputs gathered for fitting the prediction model;
# confirm against Client_diff_Emb.
Clients[0].dataset_X

Unnamed: 0,0,1,2,actions
0,0.994386,0.105812,-0.139859,[-1.4192818]
1,0.995849,0.091024,-0.297200,[-0.8089152]
2,0.997365,0.072549,-0.370749,[-0.8214741]
3,0.998759,0.049798,-0.455882,[1.2252804]
4,0.999296,0.037513,-0.245946,[1.1295936]
...,...,...,...,...
994,-0.785173,0.619277,-3.779884,[1.5387541]
995,-0.675600,0.737269,-3.223950,[0.54706544]
996,-0.567969,0.823050,-2.754824,[-0.26217845]
997,-0.467035,0.884239,-2.362050,[1.6102653]


In [93]:
# Grab client 0's env for a manual, step-by-step rollout.
# NOTE(review): presumably the real env passed to FRLClient; confirm in Client_diff_Emb.
env = Clients[0].env

In [94]:
# Start a new episode and keep the initial observation and info dict.
observation, info = env.reset()

In [95]:
# Display the initial observation returned by env.reset().
observation

array([ 0.2996944 , -0.9540353 ,  0.21686848], dtype=float32)

In [96]:
 # Build the policy input: the env observation with client 0's embedding appended,
 # cast to float32.
 obs_tensor = np.concatenate((observation, Clients[0].emb)).astype(np.float32)

In [97]:
# Display the concatenated observation + embedding array.
obs_tensor

array([ 0.2996944 , -0.9540353 ,  0.21686848,  7.        ], dtype=float32)

In [98]:
# Ask client 0's agent for an action given the embedded observation.
# NOTE(review): presumably this queries the wrapped SB3 policy; confirm in Agent.SB3Agent.
action = Clients[0].agent.act(obs_tensor)

In [99]:
# Display the action returned by the agent.
action

array([0.37384626], dtype=float32)

In [100]:
# Apply the action to client 0's env; `observation` is overwritten with the next one.
observation, reward, done, truncated, info = env.step(action)

In [101]:
# Display the observation returned by env.step().
observation

array([ 0.28880283, -0.9573886 , -0.2279231 ], dtype=float32)

In [None]:
# Rebuild the policy input from the latest observation by appending client 0's
# embedding and casting to float32 (the same construction used before calling
# agent.act). `self` does not exist at notebook level, so the embedding is read
# from the client object directly.
obs_tensor = np.concatenate((observation, Clients[0].emb)).astype(np.float32)