In [16]:
from gymnasium.wrappers import TimeLimit
from env_hiv import HIVPatient
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from replay_buffer import ReplayBuffer
import torch.optim.lr_scheduler as lr_scheduler
from DQN import DQN
from copy import deepcopy
import random
import os
from pathlib import Path

In [17]:
env = TimeLimit(
    env=HIVPatient(domain_randomization=False), max_episode_steps=200
)

In [27]:
def greedy_action(network, state):
    device = next(network.parameters()).device
    with torch.no_grad():
        Q = network(torch.Tensor(state).unsqueeze(0).to(device))
        return torch.argmax(Q).item()

class ProjectAgent:
    def __init__(self):
        self.n_actions = 4
        self.state_dim = 6
        self.gamma = 0.85 
        self.device = "cuda" if next(model.parameters()).is_cuda else "cpu"
        self.save_path = "agent.pt"

        self.replay_buffer = ReplayBuffer(capacity=60000,device = self.device)
        self.model = DQN(self.state_dim, self.n_actions).to(self.device)
        self.lr = 1e-3 
        self.batch_size = 1024

        self.epsilon_max = 1
        self.epsilon_min =  0.01
        self.epsilon_stop = 1000
        self.epsilon_delay = 20
        self.epsilon_step = (self.epsilon_max-self.epsilon_min)/self.epsilon_stop

        
        self.update_count = 0
        self.target_update_freq = 100
        self.update_target_tau = 0.005
        self.update_target_strategy = 'ema'
        self.nb_gradient_steps = 1
 
        #self.q_network = QNetwork(self.state_dim, self.n_actions).to(self.device)
        self.target_network = DQN(self.state_dim, self.n_actions).to(self.device)
           
        self.target_network.load_state_dict(self.model.state_dict())
        self.target_network.eval()
        
        self.monitoring_nb_trials = 40
        self.monitor_every =  40

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(),lr=self.lr)                     
        #self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=350, gamma=0.5)
    
    def MC_eval(self, env, nb_trials):   
        MC_total_reward = []
        MC_discounted_reward = []
        for _ in range(nb_trials):
            x,_ = env.reset()
            done = False
            trunc = False
            total_reward = 0
            discounted_reward = 0
            step = 0
            while not (done or trunc):
                a = greedy_action(self.model, x)
                y,r,done,trunc,_ = env.step(a)
                x = y
                total_reward += r
                discounted_reward += self.gamma**step * r
                step += 1
            MC_total_reward.append(total_reward)
            MC_discounted_reward.append(discounted_reward)
        return np.mean(MC_discounted_reward), np.mean(MC_total_reward)

    def V_initial_state(self, env, nb_trials):   
        with torch.no_grad():
            for _ in range(nb_trials):
                val = []
                x,_ = env.reset()
                val.append(self.model(torch.Tensor(x).unsqueeze(0).to(self.device)).max().item())
        return np.mean(val)
    
    def gradient_step(self):
        if len(self.replay_buffer) > self.batch_size:
            X, A, R, Y, D = self.replay_buffer.sample(self.batch_size)
            QYmax = self.target_network(Y).max(1)[0].detach()
            update = torch.addcmul(R, 1-D, QYmax, value=self.gamma)
            QXA = self.model(X).gather(1, A.to(torch.long).unsqueeze(1))
            loss = self.criterion(QXA, update.unsqueeze(1))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step() 

    def act(self, observation, use_random=False):
        if use_random:
            return env.action_space.sample()
        else:
            return greedy_action(self.model, observation)

    def train(self, env, max_episode):
        episode_return = []
        MC_avg_total_reward = []   # NEW NEW NEW
        MC_avg_discounted_reward = []   # NEW NEW NEW
        V_init_state = []   # NEW NEW NEW
        episode = 0
        episode_cum_reward = 0
        state, _ = env.reset()
        epsilon = self.epsilon_max
        step = 0
        while episode < max_episode:
            # update epsilon
            if step > self.epsilon_delay:
                epsilon = max(self.epsilon_min, epsilon-self.epsilon_step)
            # select epsilon-greedy action
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = greedy_action(self.model, state)
            # step
            next_state, reward, done, trunc, _ = env.step(action)
            self.replay_buffer.append(state, action, reward, next_state, done)
            episode_cum_reward += reward
            # train
            for _ in range(self.nb_gradient_steps): 
                self.gradient_step()
            # update target network if needed
            if self.update_target_strategy == 'replace':
                if step % self.update_target_freq == 0: 
                    self.target_network.load_state_dict(self.model.state_dict())
            if self.update_target_strategy == 'ema':
                target_state_dict = self.target_network.state_dict()
                model_state_dict = self.model.state_dict()
                tau = self.update_target_tau
                for key in model_state_dict:
                    target_state_dict[key] = tau*model_state_dict[key] + (1-tau)*target_state_dict[key]
                self.target_network.load_state_dict(target_state_dict)
            # next transition
            step += 1
            if done or trunc:
                episode += 1
                # Monitoring
                if self.monitoring_nb_trials>0:
                    MC_dr, MC_tr = self.MC_eval(env, self.monitoring_nb_trials)    # NEW NEW NEW
                    V0 = self.V_initial_state(env, self.monitoring_nb_trials)   # NEW NEW NEW
                    MC_avg_total_reward.append(MC_tr)   # NEW NEW NEW
                    MC_avg_discounted_reward.append(MC_dr)   # NEW NEW NEW
                    V_init_state.append(V0)   # NEW NEW NEW
                    episode_return.append(episode_cum_reward)   # NEW NEW NEW
                    print("Episode ", '{:2d}'.format(episode), 
                          ", epsilon ", '{:6.2f}'.format(epsilon), 
                          ", batch size ", '{:4d}'.format(len(self.replay_buffer)), 
                          ", ep return ", '{:4.1f}'.format(episode_cum_reward), 
                          ", MC tot ", '{:6.2f}'.format(MC_tr),
                          ", MC disc ", '{:6.2f}'.format(MC_dr),
                          ", V0 ", '{:6.2f}'.format(V0),
                          sep='')
                else:
                    episode_return.append(episode_cum_reward)
                    print("Episode ", '{:2d}'.format(episode), 
                          ", epsilon ", '{:6.2f}'.format(epsilon), 
                          ", batch size ", '{:4d}'.format(len(self.replay_buffer)), 
                          ", ep return ", '{:4.1f}'.format(episode_cum_reward), 
                          sep='')

                
                state, _ = env.reset()
                episode_cum_reward = 0
            else:
                state = next_state
        return episode_return, MC_avg_discounted_reward, MC_avg_total_reward, V_init_state

    def save(self, path):
        torch.save(self.model.state_dict(), path)
        print(f"Model saved in {path}")

    def load(self):
        self.model.load_state_dict(torch.load(self.save_path, map_location=self.device))
        self.target_network = deepcopy(self.model).to(self.device)
        self.model.eval()


    def collect_sample(self,nb_sample):
        s, _ = env.reset()
        for _ in range(nb_sample):
            a = self.act(s)
            s2, r, done, trunc, _ = env.step(a)
            self.replay_buffer.append(s, a, r, s2, done)
            if done or trunc :
                s, _ = env.reset()
            else:
                s = s2
        print('end of collection')

In [22]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(seed)

In [23]:
seed_everything()
print('Start')
model = DQN(6, 4)
agent = ProjectAgent(env, model)
print('Agent created')
agent.collect_sample(1000)
agent.train(env, 50) 

Start
Agent created
end of collection
Episode  1, epsilon   0.82, batch size 1200, ep return 9718711.5, MC tot 2767769.38, MC disc 121578.42, V0 31670.08
Episode  2, epsilon   0.62, batch size 1400, ep return 11394152.3, MC tot 6787976.82, MC disc 81821.88, V0 52685.08
Episode  3, epsilon   0.43, batch size 1600, ep return 7664525.7, MC tot 4068046.30, MC disc 181834.56, V0 63737.53
Episode  4, epsilon   0.23, batch size 1800, ep return 10174233.5, MC tot 6787976.82, MC disc 81821.88, V0 72945.48
Episode  5, epsilon   0.03, batch size 2000, ep return 15983697.0, MC tot 12904030.71, MC disc 297553.38, V0 87885.02
Episode  6, epsilon   0.01, batch size 2200, ep return 28171138.5, MC tot 6787976.82, MC disc 81821.88, V0 108611.08
Episode  7, epsilon   0.01, batch size 2400, ep return 7167058.9, MC tot 6787976.82, MC disc 81821.88, V0 118914.97
Episode  8, epsilon   0.01, batch size 2600, ep return 7036090.3, MC tot 6787976.82, MC disc 81821.88, V0 142226.66
Episode  9, epsilon   0.01, bat

KeyboardInterrupt: 