In [1]:
!export PYTHONWARNINGS="default"

In [2]:
import importlib
import decorator
importlib.reload(decorator)
from env_hiv import HIVPatient
import numpy as np
import torch
import torch.nn as nn
from ribs.archives import GridArchive
from ribs.emitters import EvolutionStrategyEmitter
from ribs.schedulers import Scheduler
import time
from dask.distributed import Client
from tqdm import trange,tqdm
import sys
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
class Policy(nn.Module) :
    def __init__(self,state_n,action_n,hsize = 32) :
        super().__init__()
        self.fc1= nn.Linear(state_n, hsize)
        self.fc2= nn.Linear(hsize, hsize)
        self.fc3= nn.Linear(hsize, action_n)
    def forward(self,x) : 
        x= torch.Tensor(x)
        if len(x.shape)==1 :
            x= x.unsqueeze(0)
        x = nn.ReLU()(self.fc1(x))
        x = nn.ReLU()(self.fc2(x))
        return nn.ReLU()(self.fc3(x))
    def get_action(self,x) :
        logits = nn.Softmax(dim=1)(self.forward(x))
        return torch.argmax(logits, dim=1).item()
    def get_params(self):
        p = np.empty((0,))
        for n in self.parameters():
            p = np.append(p, n.flatten().cpu().detach().numpy())
        return p
    def set_params(self, x):
        start = 0
        for p in self.parameters():
            e = start + np.prod(p.shape)
            p.data = torch.FloatTensor(x[start:e]).reshape(p.shape)
            start = e

In [3]:
def simulate(model, periode_echantillonage,seed=None, ):
    env = HIVPatient(clipping=True, domain_randomization=True)
    env = TimeLimit(env,200)
    policy = Policy(state_n=env.observation_space.shape[0] , action_n=env.action_space.n)
    policy.set_params(model)
    total_reward = 0
    states = []
    s,_ = env.reset(seed=seed)
    s = np.log(s+1e-9)
    done = False 
    it =0
    while not done :
        
        action = policy.get_action(s)  # Linear policy.
        s, reward, terminated, truncated, _ = env.step(action)
        s = np.log(s+1e-9)
        done = terminated or truncated
        total_reward+=reward
        if it%periode_echantillonage ==0 :
            states.append(s)
        it+=1
    return total_reward, states

In [8]:
env = HIVPatient(clipping=True, domain_randomization=True)
periode_echantillonage =10
#env= gym.make("Acrobot")

env = TimeLimit(env,200)
ranges = list(zip(env.observation_space.low, env.observation_space.high))
policy = Policy(env.observation_space.shape[0] , env.action_space.n)
solution_dim=len(policy.get_params())
dims = [10 for i in range(len(ranges))]
env_seed=  None
archive = GridArchive(
    solution_dim=solution_dim,
    dims=dims,
    ranges=ranges,
    qd_score_offset=-600
)
emitters = [
    EvolutionStrategyEmitter(
        archive,
        x0=policy.get_params(),
        sigma0=0.1,
        ranker="2imp",
        batch_size = 30,
    ) for _ in range(3)
]
scheduler = Scheduler(archive, emitters)

start_time = time.time()
total_itrs = 300
workers = 19  # Adjust the number of workers based on your available CPUs.

client = Client(
    n_workers=workers,  # Create this many worker processes using Dask LocalCluster.
    threads_per_worker=1,  # Each worker process is single-threaded.
)

for itr in trange(1, total_itrs + 1, file=sys.stdout, desc='Iterations'):
    # Request models from the scheduler.
    sols = scheduler.ask()

    # Evaluate the models and record the objectives and measuress.
    futures = client.map(lambda model: simulate(model,periode_echantillonage, env_seed), sols)
    results = client.gather(futures)
    
    objs , meas = [],[]
    for obj, traj in results :
        objs.append(obj)
        meas.append(traj[-1])
    # Send the results back to the scheduler.
    scheduler.tell(objs, meas)

    # Logging.
    if itr % 1 == 0:
        tqdm.write(f"> {itr} itrs completed after {time.time() - start_time:.2f}s")
        tqdm.write(f"  - Size: {archive.stats.num_elites}")    # Number of elites in the archive. len(archive) also provides this info.
        tqdm.write(f"  - Coverage: {archive.stats.coverage}")  # Proportion of archive cells which have an elite.
        tqdm.write(f"  - QD Score: {archive.stats.qd_score}")  # QD score, i.e. sum of objective values of all elites in the archive.
                                                               # Accounts for qd_score_offset as described in the GridArchive section.
        objmax = "{:e}".format(archive.stats.obj_max)
        objmean  = "{:e}".format(archive.stats.obj_mean)
        tqdm.write(f"  - Max Obj: {objmax}")    # Maximum objective value in the archive.
        tqdm.write(f"  - Mean Obj: {objmean}")  # Mean objective value of elites in the archive.

  y *= step
  y += start
Perhaps you already have a cluster running?
Hosting the HTTP server on port 45953 instead


> 1 itrs completed after 31.67s                    
  - Size: 1                                        
  - Coverage: 1e-06                                
  - QD Score: 649331749.0398                       
  - Max Obj: 6.493311e+08                          
  - Mean Obj: 6.493311e+08                         
Iterations:   0%|          | 1/300 [00:30<2:34:03, 30.91s/it]

  grid_indices = ((self._dims *
  grid_indices = ((self._dims *


> 2 itrs completed after 61.12s                              
  - Size: 1                                                  
  - Coverage: 1e-06                                          
  - QD Score: 649331749.0398                                 
  - Max Obj: 6.493311e+08                                    
  - Mean Obj: 6.493311e+08                                   
> 3 itrs completed after 91.20s                              
  - Size: 1                                                  
  - Coverage: 1e-06                                          
  - QD Score: 5212627645.34801                               
  - Max Obj: 5.212627e+09                                    
  - Mean Obj: 5.212627e+09                                   
> 4 itrs completed after 120.76s                             
  - Size: 1                                                  
  - Coverage: 1e-06                                          
  - QD Score: 5212627645.34801                               
  - Max 

In [3]:
from train import *
agent = ProjectAgent()
agent.save()
agent.load()