In [5]:
import random
import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt
from hiv_patient import HIVPatient
from sklearn.ensemble import ExtraTreesRegressor 

In [6]:
class ReplayBuffer:
    def __init__(self, patient, nb_patients=5, nb_steps=10):
        self.nb_patients = nb_patients
        self.nb_steps  = nb_steps
        self.capacity = nb_patients * nb_steps # capacity of the buffer
        self.data = np.zeros((self.capacity, 14))
        self.index = 0 # index of the next cell to be filled
        self.patient = patient

    def fill(self, reg=None, epsilon=.15):
        k=0
        for i in trange(self.nb_patients):
            s = self.patient.reset(mode="healthy")
            s[5] *= .75
            self.patient.E *= .75
            for step in range(self.nb_steps):
                if reg == None:
                    a = np.random.randint(4)
                else:
                    threshold = np.random.rand()
                    if threshold < epsilon:
                        a = np.random.randint(4)
                    else:
                        a = np.argmax([reg.predict(np.append(s,a).reshape(1,-1)) for a in range(4)])
                s = self.patient.state()
                s_, r,_ ,_  = self.patient.step(a)
                self.data[k,:6] = s
                self.data[k,6:7] = a
                self.data[k,7:8] = r
                self.data[k,8:] = s_
                k+=1

    def concatenate(self, replay_buffer):
        self.data = np.concatenate((self.data, replay_buffer.data))

    def sample(self, batch_size):
        return random.sample(self.data, batch_size)

    def __len__(self):
        return len(self.data)

In [None]:
class FittedQExtraTree:
    def __init__(self, initial_buffer, nb_iterations=3, gamma=.99, epsilon=.15):
        self.Qtree = ExtraTreesRegressor()
        self.buffers = [initial_buffer]
        self.nb_iterations = nb_iterations
        self.nb_patients = initial_buffer.nb_patients
        self.nb_steps = initial_buffer.nb_steps
        self.patient = initial_buffer.patient
        self.gamma=gamma
    
    def fit(self):

        print("Training on the initial replay buffer:")
        X, y = self.buffers[0].data[:,:7], self.buffers[0].data[:,7] 
        self.Qtree.fit(X,y)

        for i in range(1,self.nb_iterations):
            print("Creating replay buffer n°{:}".format(i+1))
            self.buffers.append(ReplayBuffer(self.patient, self.nb_patients, self.nb_steps))
            self.buffers[i].fill(reg = self.Qtree)
            self.buffers[i].concatenate(self.buffers[i-1])
            print("Training on replay buffer n°{:}".format(i+1))
            X, y = self.buffers[i].data[:,:7], self.buffers[i].data[:,7] + self.gamma * np.max([self.Qtree.predict(np.append((self.buffers[i].data[:, :6], a),axis=1)) for a in range(4)])
            self.Qtree.fit(X,y)
        
        print("Training done!")
    
    def predict(self, X):
        return self.Qtree.predict(X)


In [17]:
class ErnstModel:
    def __init__(self, patient=HIVPatient(), nb_patients=5, nb_steps=10, gamma=.99, nb_iterations=3, epsilon=.15):
        self.nb_patients = nb_patients
        self.nb_steps = nb_steps
        self.gamma = gamma
        self.nb_iterations = nb_iterations 
        self.Qtree = None
        self.patient = patient
        self.epsilon = epsilon
    
    def fit(self):

        print("Creating replay buffer n°1")
        RB0 = ReplayBuffer(self.patient, nb_patients=self.nb_patients, nb_steps=self.nb_steps)
        RB0.fill()

        self.Qtree = FittedQExtraTree(RB0, nb_iterations=self.nb_iterations, gamma=self.gamma, epsilon=self.epsilon)
        self.Qtree.fit()
    
    def predict(self, X):
        return self.Qtree.predict(X)

In [18]:
em = ErnstModel()
em.fit()

Creating replay buffer n°1


100%|██████████| 5/5 [00:03<00:00,  1.32it/s]


Training on the initial replay buffer:
Creating replay buffer n°2


100%|██████████| 5/5 [00:04<00:00,  1.03it/s]


Training on replay buffer n°2
Creating replay buffer n°3


100%|██████████| 5/5 [00:03<00:00,  1.39it/s]

Training on replay buffer n°3
Training done!





In [20]:
em.predict(np.array([10, 10, 10, 10, 10, 10, 0]).reshape(1,-1))

array([14564472.4566539])