In [1]:
import numpy as np
def setstate(env,state):
    """Overwrite the six state attributes of ``env`` with the entries of ``state``.

    ``state[0]`` .. ``state[5]`` are written, in order, to the attributes
    ``T1``, ``T1star``, ``T2``, ``T2star``, ``V`` and ``E``.

    Returns:
        The very same ``env`` object (it is mutated in place).
    """
    attribute_order = ("T1", "T1star", "T2", "T2star", "V", "E")
    # Index explicitly (rather than zip) so a too-short ``state`` still fails
    # with IndexError at the first missing entry.
    for index, attribute in enumerate(attribute_order):
        setattr(env, attribute, state[index])
    return env
from gymnasium.wrappers import TimeLimit
from env_hiv import * 
class MazeState(object):
    """Search state wrapping one observation of the ``HIVPatient`` simulator.

    Attributes:
        pos: the observation, stored as a NumPy array.
        actions: the action indices offered from every state (0, 1, 2, 3).
        env: a private ``HIVPatient`` whose state variables are set from
            ``pos``; it is stepped to simulate an action and then restored.
    """

    def __init__(self, pos):
        """Store ``pos`` and build a simulator positioned at it."""
        self.pos = np.array(pos)
        self.actions = [0,1,2,3]
        self.env = setstate(HIVPatient(), pos)

    def _simulate(self, action):
        """Step the private simulator once with ``action`` and rewind it.

        Returns:
            The ``(observation, reward)`` pair reported by ``env.step``.
        """
        try:
            next_obs, step_reward, _, _, _ = self.env.step(action)
        finally:
            # Rewind even if ``step`` raised, so this state object stays
            # positioned at ``pos`` for later calls.
            self.env = setstate(self.env, self.pos)
        return next_obs, step_reward

    def perform(self, action):
        """Return a new ``MazeState`` for the observation reached via ``action``."""
        next_obs, _ = self._simulate(action)
        return MazeState(next_obs)

    def reward(self, parent, action):
        """Return the reward ``env.step`` reports for ``action`` taken from this state.

        ``parent`` is accepted but never read.
        NOTE(review): the reward is simulated from *this* state, not from
        ``parent`` -- confirm that matches what the search library expects.
        """
        _, step_reward = self._simulate(action)
        return step_reward

    def is_terminal(self):
        """Always ``False``: no state is treated as terminal."""
        return False

    def __eq__(self, other):
        """Two states are equal when their ``pos`` arrays match element-wise."""
        if not isinstance(other, MazeState):
            # Let Python fall back to its default comparison instead of
            # failing on a missing ``pos`` attribute.
            return NotImplemented
        # array_equal yields False (not an error) for mismatched shapes.
        return bool(np.array_equal(self.pos, other.pos))

    def __hash__(self):
        """Hash the components of ``pos``; equal states hash equally."""
        # A tuple of Python scalars hashes consistently with ``__eq__``
        # (1 == 1.0, 0.0 == -0.0) and does not fail on inf/NaN entries.
        return hash(tuple(self.pos.ravel().tolist()))

In [48]:
from joblib import Parallel, delayed
import os
from __future__ import print_function
from mcts import  tree_policies,default_policies,backups,graph
import random
from mcts import utils
from mcts.mcts import MCTS,_get_next_node
class ParallelMCTS(MCTS):
    """``MCTS`` subclass that dispatches its search iterations via joblib threads.

    NOTE(review): every iteration works on the same ``root`` tree from several
    threads and no locking is visible here -- confirm that the selection and
    backup routines tolerate concurrent mutation.
    """

    def __init__(self, tree_policy, default_policy, backup, n_jobs = 20):
        """Keep the three search policies and the joblib worker count ``n_jobs``."""
        super().__init__(tree_policy, default_policy, backup)
        self.tree_policy, self.default_policy, self.backup = (
            tree_policy,
            default_policy,
            backup,
        )
        self.n_jobs = n_jobs

    def iteration(self,root):
        """Run one search pass from ``root``: select a node, score it, back it up."""
        selected = _get_next_node(root, self.tree_policy)
        selected.reward = self.default_policy(selected)
        self.backup(selected)

    def __call__(self, root, n=18*2):
        """Run ``n`` iterations from ``root`` and return the chosen action.

        The action is taken from the child of ``root`` with the highest ``q``,
        as picked by ``utils.rand_max``.

        Raises:
            ValueError: if ``root`` has a parent.
        """
        if root.parent is not None:
            raise ValueError("Root's parent must be None.")

        def action_value(action_node):
            return action_node.q

        worker_pool = Parallel(n_jobs = self.n_jobs, backend = "threading")
        jobs = (delayed(self.iteration)(root) for _ in range(n))
        worker_pool(jobs)

        best_child = utils.rand_max(root.children.values(), key=action_value)
        return best_child.action
    
    
# Planner: UCB1 tree policy, immediate-reward default policy, Monte Carlo backup.
mcts = ParallelMCTS(
    tree_policy=tree_policies.UCB1(c=1.41),
    default_policy=default_policies.immediate_reward,
    backup=backups.monte_carlo,
)

# Smoke run: plan once from the initial observation of a fresh patient.
initial_observation, _ = HIVPatient().reset()
start = MazeState(initial_observation)
root = graph.StateNode(parent=None, state=start)
best_action = mcts(root, n=20)

In [54]:
from tqdm import tqdm

N_STEPS = 200              # number of environment steps to play
N_SEARCH_ITERATIONS = 200  # search iterations spent on each decision
# presumably an assumed upper bound on the full-episode return -- TODO confirm
REWARD_CEILING = 7e10

env = HIVPatient()
x, _ = env.reset()
total_reward = 0
pi = []          # actions chosen so far, one per step
parent = None    # every per-step search tree starts from a parentless root

pbar = tqdm(range(N_STEPS))
for t in pbar:
    # Reward collected so far plus the ceiling pro-rated over the remaining steps.
    max_possible = (N_STEPS - t) * REWARD_CEILING / N_STEPS + total_reward

    # Re-plan from scratch at the current observation.
    start = MazeState(x)
    root = graph.StateNode(parent=parent, state=start)
    best_action = mcts(root, n=N_SEARCH_ITERATIONS)
    pi.append(best_action)

    # Apply the chosen action to the real environment.
    x, r, _, _, _ = env.step(best_action)
    total_reward += r
    pbar.set_postfix(
        total_reward=total_reward,
        immediate_reward=r,
        max_possible=max_possible,
    )
total_reward

 48%|████▊     | 96/200 [33:40<36:29, 21.05s/it, immediate_reward=2.1e+8, max_possible=4.05e+10, total_reward=3.91e+9]   


KeyboardInterrupt: 