In [1]:
!export PYTHONWARNINGS="default"

In [1]:
a=9709249010.93793
b=2610906993.0028586
print('indiv',"{:e}".format(a),end = "\n")
print('pop',"{:e}".format(b),end = "\n")

indiv 9.709249e+09
pop 2.610907e+09


In [2]:
%pip install swig  # Must be installed before box2d.
%pip install ribs[visualize] gymnasium[box2d]==0.29.1 "moviepy>=1.0.0" dask distributed

import importlib
import decorator
importlib.reload(decorator)

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


<module 'decorator' from '/opt/mamba/lib/python3.11/site-packages/decorator.py'>

In [3]:

from env_hiv import HIVPatient
import numpy as np
import torch
import torch.nn as nn
from ribs.archives import GridArchive
from ribs.emitters import EvolutionStrategyEmitter
from ribs.schedulers import Scheduler
import time
from dask.distributed import Client
from tqdm import trange,tqdm
import sys
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
class Policy(nn.Module) :
    def __init__(self,state_n,action_n,hsize = 32) :
        super().__init__()
        self.fc1= nn.Linear(state_n, hsize)
        self.fc2= nn.Linear(hsize, hsize)
        self.fc3= nn.Linear(hsize, action_n)
    def forward(self,x) : 
        x= torch.Tensor(x)
        if len(x.shape)==1 :
            x= x.unsqueeze(0)
        x = nn.ReLU()(self.fc1(x))
        x = nn.ReLU()(self.fc2(x))
        return nn.ReLU()(self.fc3(x))
    def get_action(self,x) :
        logits = nn.Softmax(dim=1)(self.forward(x))
        return torch.argmax(logits, dim=1).item()
    def get_params(self):
        p = np.empty((0,))
        for n in self.parameters():
            p = np.append(p, n.flatten().cpu().detach().numpy())
        return p
    def set_params(self, x):
        start = 0
        for p in self.parameters():
            e = start + np.prod(p.shape)
            p.data = torch.FloatTensor(x[start:e]).reshape(p.shape)
            start = e

In [4]:
def simulate(model, periode_echantillonage,seed=None):
    domain_randomization = False#np.random.binomial(1,0.5)
    env = HIVPatient(clipping=True, domain_randomization=domain_randomization)
    env = TimeLimit(env,200)
    policy = Policy(state_n=env.observation_space.shape[0] , action_n=env.action_space.n)
    policy.set_params(model)
    total_reward = 0
    states = [[] for t in range(6)]
    s,_ = env.reset(seed=seed)
    s = np.log(s+1e-9)
    done = False 
    it =0
    while not done :
        
        action = policy.get_action(s)  # Linear policy.
        s, reward, terminated, truncated, _ = env.step(action)
        s = np.log(s+1e-9)
        done = terminated or truncated
        total_reward+=reward
        if it%periode_echantillonage ==0 :
            for i,a in enumerate(s) : 
                states[i].append(a)
        it+=1
    traj = []
    for i in range(len(states)): 
        traj.append(np.mean(states[i]))
    return total_reward, traj

In [5]:
env = HIVPatient(clipping=True, domain_randomization=False)
periode_echantillonage =1
#env= gym.make("Acrobot")

env = TimeLimit(env,200)
ranges = list(zip(env.lower, env.upper))
for i in range(len(ranges)) :
    a,b = ranges[i]
    ranges[i] = (np.log(a+1e-9), np.log(b+1e-9))
policy = Policy(env.observation_space.shape[0] , env.action_space.n)
solution_dim=len(policy.get_params())
dims = [30 for i in range(len(ranges))]
env_seed=  None
archive = GridArchive(
    solution_dim=solution_dim,
    dims=dims,
    ranges=ranges,
    qd_score_offset=0,
    learning_rate = 0.01,
    threshold_min =0
)
result_archive = GridArchive(solution_dim=solution_dim,
                             dims=dims,
                             ranges=ranges)
emitters = [
    EvolutionStrategyEmitter(
        archive,
        x0=policy.get_params(),
        sigma0=0.5,
        ranker="imp",
        selection_rule="mu",
        restart_rule="basic",
        batch_size = 32,
    ) for _ in range(5)
]
scheduler = Scheduler(archive, emitters, result_archive=result_archive)

start_time = time.time()
total_itrs = 450
workers = 128  # Adjust the number of workers based on your available CPUs.

client = Client(
    n_workers=workers,  # Create this many worker processes using Dask LocalCluster.
    threads_per_worker=1,  # Each worker process is single-threaded.
)

for itr in trange(1, total_itrs + 1, file=sys.stdout, desc='Iterations'):
    # Request models from the scheduler.
    sols = scheduler.ask()

    # Evaluate the models and record the objectives and measuress.
    futures = client.map(lambda model: simulate(model,periode_echantillonage, env_seed), sols)
    results = client.gather(futures)
    
    objs , meas = [],[]
    for obj, traj in results :
        objs.append(obj)
        meas.append(traj)
    # Send the results back to the scheduler.
    scheduler.tell(objs, meas)

    # Logging.
    if itr % 1 == 0:
        tqdm.write(f"> {itr} itrs completed after {time.time() - start_time:.2f}s")
        tqdm.write(f"  - Size: {archive.stats.num_elites}")    # Number of elites in the archive. len(archive) also provides this info.
        tqdm.write(f"  - Coverage: {archive.stats.coverage}")  # Proportion of archive cells which have an elite.
        tqdm.write(f"  - QD Score: {archive.stats.qd_score}")  # QD score, i.e. sum of objective values of all elites in the archive.
                                                               # Accounts for qd_score_offset as described in the GridArchive section.
        objmax = "{:e}".format(archive.stats.obj_max)
        objmean  = "{:e}".format(archive.stats.obj_mean)
        tqdm.write(f"  - Max Obj: {objmax}")    # Maximum objective value in the archive.
        tqdm.write(f"  - Mean Obj: {objmean}")  # Mean objective value of elites in the archive.

  logger.warn(
  logger.warn(


> 1 itrs completed after 102.69s                   
  - Size: 14                                       
  - Coverage: 1.9204389574759946e-08               
  - QD Score: 159954596.92118236                   
  - Max Obj: 3.115946e+07                          
  - Mean Obj: 1.142533e+07                         
> 2 itrs completed after 185.35s                              
  - Size: 19                                                  
  - Coverage: 2.6063100137174212e-08                          
  - QD Score: 170467739.90372968                              
  - Max Obj: 3.115946e+07                                     
  - Mean Obj: 8.971986e+06                                    
> 3 itrs completed after 262.29s                              
  - Size: 22                                                  
  - Coverage: 3.017832647462277e-08                           
  - QD Score: 174954208.08611214                              
  - Max Obj: 3.115946e+07                                 

In [7]:
import pickle
print("{:e}".format(archive.best_elite['objective']))
#SI PROBLEME : ESSAYER AVEC RESULT ARCHIVE DE CMA MAE
serialized= {"params" : archive.best_elite['solution']}
with open('saved.pkl', 'wb') as f:  # open a text file
    pickle.dump(serialized, f) # serialize the list

9.709249e+09
