<a href="https://colab.research.google.com/github/PandisDP/Deep-Reinforcement-Learning/blob/main/DRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
import torch
import numpy as np
import torch as th
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from itertools import count
import torch.nn.functional as F
from collections import namedtuple
from IPython import display
import torch
import random
import math

Class Field

In [None]:

class Field:
    '''
    Square grid world in which an agent has to pick up an item and deliver it.

    Positions are (x, y) tuples with 0 <= x, y < size. The episode is flagged
    as done once the item is dropped at item_dropoff.
    '''
    # Rewards returned by make_action
    REWARD_STEP = -1        # regular move onto a free cell
    REWARD_INVALID = -10    # move off the grid, or a pickup/dropoff that is not allowed
    REWARD_BLOCKED = -100   # move onto a blocked cell
    REWARD_SUCCESS = 20     # successful pickup or dropoff
    # Movement actions mapped to their (dx, dy) displacement
    # 0: down, 1: up, 2: left, 3: right (names used by the original author)
    MOVES = {0: (0, 1), 1: (0, -1), 2: (-1, 0), 3: (1, 0)}
    ACTION_PICKUP = 4
    ACTION_DROPOFF = 5

    def __init__(self,device,size,item_pickup,item_dropoff,
                start_position,zones_blocks=None,path_predicts='Episodes'):
        '''
        Constructor of the class Field
        Parameters:
        device: torch device on which state and reward tensors are created
        size: side length of the square field
        item_pickup: (x, y) position of the item to pickup
        item_dropoff: (x, y) position where the item must be dropped
        start_position: (x, y) initial position of the agent
        zones_blocks: iterable of (x, y) tuples with the blocked cells
                      (defaults to no blocked cells)
        path_predicts: folder where the episode images are saved
        '''
        self.device = device
        self.size = size
        self.item_pickup = item_pickup
        self.item_dropoff = item_dropoff
        self.position = start_position
        self.position_start = start_position
        # A fresh list per instance: a mutable default argument would be
        # shared by every Field created without explicit block zones.
        self.block_zones = [] if zones_blocks is None else list(zones_blocks)
        self.item_in_car = False
        self.number_of_actions = 6
        self.allposicions = []
        self.path_predicts = path_predicts
        self.done = False
        self.save_state()
        # Snapshot used by reset() to restore the beginning of an episode
        self.initial_state = {
            'device': self.device,
            'position': self.position,
            'item_pickup': self.item_pickup,
            'item_dropoff': self.item_dropoff,
            'item_in_car': self.item_in_car
        }

    def reset(self):
        '''
        Reset the game to the state captured by the constructor and restart
        the list of visited positions.
        '''
        self.device = self.initial_state['device']
        self.position = self.initial_state['position']
        self.item_pickup = self.initial_state['item_pickup']
        self.item_dropoff = self.initial_state['item_dropoff']
        self.item_in_car = self.initial_state['item_in_car']
        self.done = False
        self.allposicions = []
        self.save_state()

    def get_number_of_actions(self):
        '''
        Get the number of actions of the game
        Returns: number of actions
        '''
        return self.number_of_actions

    def get_number_of_states(self):
        '''
        Get the number of states of the game
        Returns: number of states, (size**4)*2
        '''
        return (self.size**4)*2

    def get_state(self):
        '''
        Get the state of the game
        Returns: one-element tensor with an integer that encodes the agent
        position, the pickup position and whether the item is in the car
        '''
        state = self.position[0]*self.size*self.size*self.size*2
        state += self.position[1]*self.size*self.size*2
        state += self.item_pickup[0]*self.size*2
        state += self.item_pickup[1]*2
        if self.item_in_car:
            state += 1
        return torch.tensor([state],device=self.device)

    def save_state(self):
        '''
        Append the current position to the list of visited positions
        '''
        self.allposicions.append(self.position)

    def graphics(self,puntos,name_fig):
        '''
        Create a plot of the game and save it inside path_predicts
        Parameters:
        puntos: list of (x, y) tuples with the positions to mark
        name_fig: file name of the figure
        '''
        # The grid follows the size of the field instead of a fixed 10x10
        cuadricula = np.zeros((self.size, self.size))
        # Mark the given points on the grid
        for punto in puntos:
            cuadricula[punto] = 1
        fig, ax = plt.subplots()
        # Show the grid as a black and white image
        ax.imshow(cuadricula, cmap='Greys', origin='lower')
        # Put the ticks on the cell borders so the grid lines frame each cell
        ax.set_xticks(np.arange(-.5, self.size, 1))
        ax.set_yticks(np.arange(-.5, self.size, 1))
        ax.grid(color='black', linestyle='-', linewidth=2)
        # Limits chosen so the outer cells are not cut
        ax.set_xlim(-0.5, self.size - 0.5)
        ax.set_ylim(-0.5, self.size - 0.5)
        for punto in self.block_zones:
            ax.scatter(punto[1], punto[0], color='red', marker='X', s=100)
        for punto in puntos:
            ax.text(punto[1], punto[0], '✔', color='white', ha='center', va='center', fontsize=10)

        lst_start = [self.position_start, self.item_pickup, self.item_dropoff]
        for punto in lst_start:
            ax.scatter(punto[1], punto[0], color='blue', marker='*', s=100)
        # Saving fails if the target folder is missing, so create it on demand
        os.makedirs(self.path_predicts, exist_ok=True)
        plt.savefig(os.path.join(self.path_predicts, name_fig))
        plt.close()

    def empty_predict_data(self):
        '''
        Empty the folder of the predictions. Entries that cannot be removed
        are reported and skipped; a missing folder is treated as already empty.
        '''
        path = self.path_predicts
        if not os.path.isdir(path):
            return
        for nombre in os.listdir(path):
            ruta_completa = os.path.join(path, nombre)
            try:
                if os.path.isfile(ruta_completa) or os.path.islink(ruta_completa):
                    os.remove(ruta_completa)
                elif os.path.isdir(ruta_completa):
                    shutil.rmtree(ruta_completa)
            except Exception as e:
                print(f'Error {ruta_completa}. reason: {e}')

    def block_zones_evaluation(self,position):
        '''
        Evaluate if the position is in a block zone
        Parameters:
        position: position to evaluate
        Returns: True if the position is in a block zone, False otherwise
        '''
        return position in self.block_zones

    def _move(self,dx,dy):
        '''
        Try to displace the agent by (dx, dy).
        Returns: the reward of the move (plain number)
        '''
        x, y = self.position
        new_x, new_y = x + dx, y + dy
        if not (0 <= new_x < self.size and 0 <= new_y < self.size):
            # Leaving the grid is refused: the agent stays where it is
            return self.REWARD_INVALID
        self.position = (new_x, new_y)
        self.save_state()
        if self.block_zones_evaluation(self.position):
            # The agent does enter the blocked cell, it is only penalised
            return self.REWARD_BLOCKED
        return self.REWARD_STEP

    def _pickup(self):
        '''
        Try to load the item at the current position.
        Returns: the reward of the attempt (plain number)
        '''
        if self.item_in_car or self.item_pickup != self.position:
            return self.REWARD_INVALID
        self.item_in_car = True
        return self.REWARD_SUCCESS

    def _dropoff(self):
        '''
        Try to deliver the item at the current position; a successful
        delivery ends the episode.
        Returns: the reward of the attempt (plain number)
        '''
        if not self.item_in_car or self.item_dropoff != self.position:
            return self.REWARD_INVALID
        self.item_in_car = False
        self.done = True
        return self.REWARD_SUCCESS

    def make_action(self,action):
        '''
        Make an action in the game
        Parameters:
        action: action to make, as an int or a one-element tensor
                (0 down, 1 up, 2 left, 3 right, 4 pickup, 5 dropoff)
        Returns: reward (one-element tensor on the field device), done
        Raises: ValueError if the action is not one of the six known codes
        '''
        # Tensors and numpy scalars are unwrapped so the value can be used
        # as a dictionary key
        if hasattr(action, 'item'):
            action = action.item()
        if action in self.MOVES:
            val_return = self._move(*self.MOVES[action])
        elif action == self.ACTION_PICKUP:
            val_return = self._pickup()
        elif action == self.ACTION_DROPOFF:
            val_return = self._dropoff()
        else:
            raise ValueError(f'Unknown action: {action!r}')
        return torch.tensor([val_return],device=self.device),self.done




Class QValues

In [None]:

class QValues():
    '''
    Helper that reads Q-values out of the policy and target networks
    '''
    def __init__(self,device):
        '''
        Params:
        device: The device on which index and result tensors are placed'''
        self.device= device

    def get_current(self,policy_net,states,actions):
        '''
        Q-values that the policy network assigns to the actions taken
        Params:
        policy_net: The neural network used to calculate the Q-values
        states: Batch of states
        actions: Batch with the action chosen in each state
        Returns: One Q-value per row, selected along dim 1'''
        q_all= policy_net(states)
        chosen= actions.unsqueeze(-1)
        return q_all.gather(dim=1,index=chosen)

    def get_current_i(self,policy_net,state,action):
        '''
        Q-value that the policy network assigns to one action in one state
        Params:
        policy_net: The neural network used to calculate the Q-values
        state: The state of the agent
        action: The action of the agent
        Returns: The selected Q-value'''
        index= th.tensor([[action]], device=self.device)
        return policy_net(state).gather(1, index)

    def get_next_i(self, target_net, next_state):
        '''
        Highest Q-value that the target network predicts for one next state
        Params:
        target_net: The target neural network used to calculate the Q-values
        next_state: The next state of the agent
        Returns: The maximum along dim 1 with a trailing dimension added'''
        best, _ = target_net(next_state).max(1)
        return best.unsqueeze(1)

    def get_next(self,target_net,next_states,is_done):
        '''
        Highest target-network Q-value for every next state of a batch
        Params:
        target_net: The target neural network used to calculate the Q-values
        next_states: The next states of the agent
        is_done: Boolean tensor that indicates which transitions ended the episode
        Returns: One value per next state; entries flagged as done stay at zero'''
        values= torch.zeros(len(next_states)).to(self.device)
        alive= ~is_done
        pending= next_states[alive]
        # Nothing to evaluate when every transition of the batch is terminal
        if len(pending)==0:
            return values
        values[alive]= target_net(pending).max(dim=1)[0]
        return values
        
class DQN(nn.Module):
    '''
    Fully connected network that maps input features to one Q-value per action.
    Layer widths are feature_size -> 128 -> 256 -> 128 -> num_actions; the
    architecture can be swapped for other problems or games.
    '''
    def __init__(self,feature_size, num_actions):
        '''
        Params:
        feature_size: The size of the features
        num_actions: The number of actions that the agent can take'''
        super().__init__()
        narrow, wide = 128, 256
        self.fc1 = nn.Linear(feature_size, narrow)
        self.fc2 = nn.Linear(narrow, wide)
        self.fc3 = nn.Linear(wide, narrow)
        self.out = nn.Linear(narrow, num_actions)

    def forward(self,t):
        '''
        This method calculates the Q-values of the agent.

        Params:
        t (Tensor): The input features of the agent. A 1-D tensor is treated
        as a batch of single-feature rows.
        Returns:
        Tensor: The Q-values of the agent.
        '''
        x = t.unsqueeze(1) if t.dim()==1 else t
        x = x.float()
        # Three hidden layers, each followed by a ReLU
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden(x))
        return self.out(x)


Class Agent

In [None]:

# Epsilon-greedy agent: picks actions and keeps track of the steps taken
class Agent():
    '''
    Agent that chooses between a random action and the best action
    according to the policy network
    '''
    def __init__(self,strategy,num_actions,device):
        '''
        Params:
        strategy: Object whose get_exploration_rate(step) gives the exploration rate
        num_actions: The number of actions that the agent can take
        device: The device on which the action tensors are placed'''
        self.step=0
        self.strategy=strategy
        self.num_actions= num_actions
        self.device=device

    def select_action(self,state,policy_net):
        '''
        Select an action through exploration (random action) or
        exploitation (action with the highest Q-value). Every call
        advances the internal step counter by one.
        Params:
        state: The state of the agent
        policy_net: The neural network used to calculate the Q-values
        Returns: Tensor with the action selected by the agent'''
        epsilon= self.strategy.get_exploration_rate(self.step)
        self.step+=1
        exploit= not (random.random()<epsilon)
        if exploit:
            # No gradients are needed just to read the best action
            with torch.no_grad():
                return policy_net(state).argmax(dim=1).to(self.device)
        choice= random.randrange(self.num_actions)
        return torch.tensor([choice]).to(self.device)

class EpsilonGreedyStrategy():
    '''
    Exponentially decaying exploration rate for an epsilon-greedy agent
    '''
    def __init__(self,start,end,decay):
        '''
        Params:
        start: The exploration rate at step 0
        end: The value the exploration rate decays towards
        decay: The decay constant applied per step'''
        self.start= start
        self.end= end
        self.decay= decay

    def get_exploration_rate(self,step):
        '''
        Exploration rate for a given step:
        end + (start - end) * exp(-step * decay)
        Params:
        step: The step of the training process
        Returns: The exploration rate of the agent'''
        span= self.start - self.end
        return self.end + span*math.exp(-step*self.decay)

class ReplayMemory():
    '''
    Fixed-size buffer of experiences; once full, the oldest entry is overwritten
    '''
    def __init__(self,capacity):
        '''
        Params:
        capacity: The maximum number of experiences that the memory can store
        Attributes:
        memory: List with the experiences
        count: The total number of experiences pushed so far'''
        self.capacity= capacity
        self.memory= []
        self.count=0

    def push(self,exp):
        '''
        This method is used to store an experience in the memory
        Params:
        exp: The experience to store in the memory
        '''
        if len(self.memory)< self.capacity:
            self.memory.append(exp)
        else:
            # Buffer full: reuse the slots in a circular fashion
            self.memory[self.count%self.capacity]=exp
        self.count+=1

    def sample(self,batch_size):
        '''
        This method is used to get a sample of experiences from the memory
        Params:
        batch_size: The number of experiences to get from the memory
        Returns: A tuple (experiences, 0, 0); the two zeros stand in for the
        indexes and weights that PrioritizedReplayMemory.sample returns
        '''
        return random.sample(self.memory,batch_size),0,0

    def can_provide_sample(self,batch_size):
        '''
        This method is used to check if the memory has enough experiences to provide a sample
        Params:
        batch_size: The number of experiences to get from the memory
        Returns: True if the memory has enough experiences to provide a sample, False otherwise
        '''
        # random.sample only needs batch_size stored items, so equality is
        # enough (the former strict '>' waited for one extra experience)
        return len(self.memory)>= batch_size

class PrioritizedReplayMemory():
    '''
    This class is used to store the experiences of the agent with priorities
    '''
    def __init__(self, capacity, alpha=0.6):
        '''
        Params:
        capacity: The maximum number of experiences that the memory can store
        alpha: The exponent used to calculate the priority of the experiences
        Attributes:
        beta: The exponent used to calculate the importance sampling weights
        beta_increment_per_sampling: The increment of beta for each sampling
        epsilon: A small value that keeps priorities and probabilities above zero
        '''
        self.tree = SumTree(capacity)
        self.capacity = capacity
        self.alpha = alpha
        self.beta = 0.4
        self.beta_increment_per_sampling = 0.001
        self.epsilon = 1e-6

    def _get_priority(self, error):
        '''
        This method is used to calculate the priority of an experience
        Params:
        error: The error of the experience
        Returns: The priority of the experience'''
        # abs() keeps the result real if a signed error is ever passed in
        return (abs(error) + self.epsilon) ** self.alpha

    def push(self, error, sample):
        '''
        This method is used to store an experience in the memory
        Params:
        error: The error of the experience
        sample: The experience to store in the memory
        '''
        self.tree.add(self._get_priority(error), sample)

    def sample(self, batch_size):
        '''
        This method is used to get a sample of experiences from the memory.
        The total priority is split in batch_size equal segments and one
        value is drawn uniformly from each of them. Every call also moves
        beta towards 1.
        Params:
        batch_size: The number of experiences to get from the memory
        Returns: (experiences, tree indexes, importance sampling weights
        normalised so that the largest weight is 1)'''
        batch = []
        idxs = []
        priorities = []
        total = self.tree.total()
        segment = total / batch_size
        self.beta = min(1., self.beta + self.beta_increment_per_sampling)
        for i in range(batch_size):
            s = random.uniform(segment * i, segment * (i + 1))
            s = min(max(s, 0), total)
            (idx, p, data) = self.tree.get(s)
            batch.append(data)
            idxs.append(idx)
            priorities.append(p)

        # Explicit array: the arithmetic no longer depends on numpy silently
        # coercing a Python list
        sampling_probabilities = np.asarray(priorities, dtype=float) / total
        sampling_probabilities += self.epsilon
        is_weights = np.power(total * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()
        return batch, idxs, is_weights

    def update(self, idx, error):
        '''
        This method is used to update the priority of an experience
        Params:
        idx: The tree index of the experience, as returned by sample
        error: The error of the experience
        '''
        self.tree.update(idx, self._get_priority(error))

    def can_provide_sample(self, batch_size):
        '''
        This method is used to check if the memory has enough experiences to provide a sample
        Params:
        batch_size: The number of experiences to get from the memory
        Returns: True if the memory has enough experiences to provide a sample, False otherwise'''
        # The write cursor wraps to 0 when the buffer is full, so it cannot
        # be used here; the tree keeps a separate count of stored entries
        return self.tree.n_entries >= batch_size


class SumTree:
    '''
    Binary tree stored in an array whose leaves hold the priorities of the
    experiences and whose inner nodes hold the sum of their children'''
    def __init__(self, capacity):
        '''
        Params:
        capacity: The maximum number of experiences that the memory can store
        Attributes:
        tree: The array (2*capacity-1 nodes) used to store the priorities
        data: The experiences stored in the memory
        write: The position in data where the next experience is written
        n_entries: The number of experiences currently stored
        visited_nodes: Indexes of the tree nodes that have received a priority
        '''
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.write = 0
        self.n_entries = 0
        self.visited_nodes = set()

    def _propagate(self, idx, change):
        '''
        This method adds a priority change to every ancestor of a node
        Params:
        idx: The tree index of the node whose priority changed
        change: The change in the priority of the experience'''
        # Stops at the root; with capacity 1 the only leaf is the root and
        # there is nothing to propagate
        while idx > 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change
            self.visited_nodes.add(idx)

    def _retrieve(self, idx, s):
        '''
        This method walks down the tree to the leaf that covers the value s
        Params:
        idx: The tree index where the search starts
        s: A value between 0 and the total priority below idx
        Returns: The tree index that was reached. Only nodes that have
        received a priority are entered; if neither child of a node has,
        the node itself is returned.'''
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            right = left + 1
            if left >= n_nodes:
                return idx
            left_used = left in self.visited_nodes
            right_used = right in self.visited_nodes
            if left_used and (s <= self.tree[left] or not right_used):
                # Also taken when rounding pushes s slightly past the left
                # sum and there is no right subtree to absorb the excess
                idx = left
            elif right_used:
                s = s - self.tree[left]
                idx = right
            else:
                return idx

    def total(self):
        '''
        This method is used to get the total priority of the experiences
        Returns: The total priority of the experiences'''
        return self.tree[0]

    def add(self, p, data):
        '''
        This method is used to store an experience in the memory. When the
        memory is full the oldest slot is overwritten.
        Params:
        p: The priority of the experience
        data: The experience to store in the memory
        '''
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, p)
        self.write += 1
        if self.write >= self.capacity:
            self.write = 0
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        '''
        This method is used to update the priority of an experience and
        propagate the change to the parent nodes
        Params:
        idx: The tree index of the experience
        p: The priority of the experience
        '''
        change = p - self.tree[idx]
        self.tree[idx] = p
        self.visited_nodes.add(idx)
        self._propagate(idx, change)

    def get(self, s):
        '''
        This method is used to get an experience from the memory
        Params:
        s: A value between 0 and the total priority
        Returns: The tree index, the priority, and the experience
        '''
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        return idx, self.tree[idx], self.data[dataIdx]

Class QDQ

In [None]:
Experience= namedtuple('Experience',('state','action','next_state','reward','is_done'))

class QDQ:
    '''
    This class is used to manage the training process of the agent
    '''
    def __init__(self,device,enviroment,agent,memory_strategy,features):
        '''
        Params:
        device: The device used to run the agent
        enviroment: The enviroment of the game
        agent: The agent of the game
        memory_strategy: The memory strategy used to store the experiences
        features: The input size handed to the DQN networks'''
        self.device= device
        self.env= enviroment
        self.memory= memory_strategy
        self.features= features
        self.agent= agent
        self.policy_net= DQN(self.features,self.env.get_number_of_actions()).to(self.device)
        self.target_net= DQN(self.features,self.env.get_number_of_actions()).to(self.device)
        self.qvalue= QValues(device)

    def training_priorized_memory(self,batch_size,gamma,target_update,
                                learning_rate,num_episodes,checkpoint_file='checkpoint.pth'):
        '''
        This method is used to train the agent using the Prioritized Memory
        Params:
        batch_size: The number of experiences to get from the memory
        gamma: The discount factor
        target_update: The number of optimization steps between two
                       updates of the target network
        learning_rate: The learning rate of the training process
        num_episodes: The number of episodes to train the agent
        checkpoint_file: The file used to load and save the checkpoint
        Returns: List with the duration of every episode run by this call
        '''
        print('Training Process with Prioritized Memory')
        optimizer,start_episode= self.__prepare_training(learning_rate,checkpoint_file)
        episode_durations=[]
        total_timesteps = 0
        last_loss= None
        for episode in range(start_episode,num_episodes):
            self.env.reset()
            for timestep in count():
                state= self.env.get_state()
                action= self.agent.select_action(state,self.policy_net)
                reward,done= self.env.make_action(action)
                next_state= self.env.get_state()
                done = th.tensor([done], device=self.device, dtype=th.bool)
                # TD error of the new transition, used as its initial priority
                with th.no_grad():
                    current_q_value = self.qvalue.get_current_i(self.policy_net, state, action)
                    next_q_value = self.qvalue.get_next_i(self.target_net, next_state)
                    target_q_value = reward + (gamma * next_q_value * (1 - done.float()))
                    error = abs(current_q_value - target_q_value).item()
                self.memory.push(error, Experience(state, action, next_state, reward, done))
                if self.memory.can_provide_sample(batch_size):
                    experiences,idxs,is_weights= self.memory.sample(batch_size)
                    states,actions,rewards,next_states,is_done= self.__extract_tensors(experiences)
                    current_q_values= self.qvalue.get_current(self.policy_net,states,actions)
                    with th.no_grad():
                        next_q_values= self.qvalue.get_next(self.target_net,next_states,is_done)
                    target_q_values= ((next_q_values*gamma)+rewards).unsqueeze(1)
                    is_weights= th.tensor(is_weights,dtype=th.float).unsqueeze(1).to(self.device)
                    # Squared errors are weighted to compensate the biased sampling
                    loss = (is_weights * F.mse_loss(current_q_values, target_q_values,
                                                    reduction='none')).mean()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    last_loss= loss
                    total_timesteps += 1
                    # Refresh the priorities of the sampled experiences
                    errors = th.abs(current_q_values - target_q_values).detach()
                    for idx, error in zip(idxs, errors.flatten().tolist()):
                        self.memory.update(idx, error)
                    if total_timesteps % target_update == 0:
                        self.__sync_target()
                if done:
                    episode_durations.append(timestep)
                    self.__finish_episode(episode,timestep,last_loss,optimizer,checkpoint_file)
                    break
        return episode_durations

    def training_replay_memory(self,batch_size,gamma,target_update,learning_rate,
                            num_episodes,checkpoint_file='checkpoint.pth'):
        '''
        This method is used to train the agent using the Replay Memory
        Params:
        batch_size: The number of experiences to get from the memory
        gamma: The discount factor
        target_update: The number of optimization steps between two
                       updates of the target network
        learning_rate: The learning rate of the training process
        num_episodes: The number of episodes to train the agent
        checkpoint_file: The file used to load and save the checkpoint
        Returns: List with the duration of every episode run by this call
        '''
        print('Training Process with Replay Memory')
        optimizer,start_episode= self.__prepare_training(learning_rate,checkpoint_file)
        episode_durations=[]
        total_timesteps = 0
        last_loss= None
        for episode in range(start_episode,num_episodes):
            self.env.reset()
            for timestep in count():
                state= self.env.get_state()
                action= self.agent.select_action(state,self.policy_net)
                reward,done= self.env.make_action(action)
                next_state= self.env.get_state()
                done = th.tensor([done], device=self.device, dtype=th.bool)
                self.memory.push(Experience(state,action,next_state,reward,done))
                if self.memory.can_provide_sample(batch_size):
                    experiences,*_= self.memory.sample(batch_size)
                    states,actions,rewards,next_states,is_done= self.__extract_tensors(experiences)
                    current_q_values= self.qvalue.get_current(self.policy_net,states,actions)
                    with th.no_grad():
                        next_q_values= self.qvalue.get_next(self.target_net,next_states,is_done)
                    target_q_values= (next_q_values*gamma)+rewards
                    loss= F.mse_loss(current_q_values,target_q_values.unsqueeze(1))
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    last_loss= loss
                    total_timesteps += 1
                    if total_timesteps % target_update == 0:
                        self.__sync_target()
                if done:
                    episode_durations.append(timestep)
                    self.__finish_episode(episode,timestep,last_loss,optimizer,checkpoint_file)
                    break
        return episode_durations

    def __prepare_training(self,learning_rate,checkpoint_file):
        '''
        Align the target network with the policy network, create the
        optimizer and resume from the checkpoint when one exists
        Returns: (optimizer, first episode to run)
        '''
        self.__sync_target()
        self.target_net.eval()
        optimizer= th.optim.Adam(self.policy_net.parameters(),lr=learning_rate)
        start_episode=0
        try:
            # The stored episode was completed, so training resumes after it
            start_episode= self.load_checkpoint(optimizer,checkpoint_file)+1
        except FileNotFoundError:
            print('No checkpoint found starting from scratch')
        return optimizer,start_episode

    def __sync_target(self):
        '''Copy the weights of the policy network into the target network'''
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def __finish_episode(self,episode,timestep,last_loss,optimizer,checkpoint_file):
        '''
        Print the summary of a finished episode and save the checkpoint.
        The reported loss is the one of the most recent optimization step
        (0 if no step has been made yet).
        '''
        avg_loss= last_loss.item() if last_loss is not None else 0
        print("Episode: ",episode," Average_Losses: ",avg_loss,
            " Duration: ",timestep)
        self.save_checkpoint(episode,optimizer,checkpoint_file)

    def save_checkpoint(self, episode, optimizer, filename='checkpoint.pth'):
        '''
        Save the episode number, both networks and the optimizer state
        Params:
        episode: The episode that has just finished
        optimizer: The optimizer of the training process
        filename: The file where the checkpoint is written
        '''
        checkpoint = {
            'episode': episode,
            'model_state_dict': self.policy_net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'target_net_state_dict': self.target_net.state_dict()
        }
        th.save(checkpoint, filename)
        print(f'Checkpoint saved at episode {episode}')

    def load_checkpoint(self, optimizer, filename='checkpoint.pth'):
        '''
        Restore both networks and the optimizer from a checkpoint
        Params:
        optimizer: The optimizer that receives the stored state
        filename: The file to read the checkpoint from
        Returns: The episode number stored in the checkpoint
        Raises: FileNotFoundError if the file does not exist
        '''
        # Map straight onto the training device, whichever backend it is
        checkpoint = th.load(filename, map_location=self.device)

        self.policy_net.load_state_dict(checkpoint['model_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_episode = checkpoint['episode']
        print(f'Checkpoint loaded from episode {start_episode}')
        return start_episode

    def __extract_tensors(self,experiences):
        '''
        Turn a list of Experience tuples into one batch tensor per field
        Returns: states, actions, rewards, next_states, final_states
        '''
        batch = Experience(*zip(*experiences))
        states = th.cat(batch.state).to(self.device)
        actions = th.cat(batch.action).to(self.device)
        rewards = th.cat(batch.reward).to(self.device)
        next_states = th.cat(batch.next_state).to(self.device)
        # is_done entries are one-element bool tensors, like the other fields
        final_states = th.cat(batch.is_done).to(self.device)
        return states, actions, rewards, next_states, final_states

    def __get_moving_avg(self,values,period):
        '''
        Moving average of values over a window of period items; positions
        without a full window are reported as 0
        '''
        values = th.tensor(values,dtype=th.float)
        if len(values)>=period:
            moving_avg= values.unfold(dimension=0,size=period,step=1).mean(dim=1).flatten(start_dim=0)
            moving_avg= th.cat((th.zeros(period-1),moving_avg))
            return moving_avg
        else:
            moving_avg= th.zeros(len(values))
            return moving_avg

    def __plot(self,values,moving_avg_period):
        '''
        Plot the episode durations together with their moving average
        '''
        plt.figure(2)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Duration')
        plt.plot(values)
        # The helper is private (name-mangled); it has to be called with
        # the double underscore
        moving_avg= self.__get_moving_avg(values,moving_avg_period)
        plt.plot(moving_avg)
        plt.pause(0.001)
        display.clear_output(wait=True)



In [None]:

def training_process(params_env,prms_tra,type_memory):
    '''
    Build the environment, the agent and the memory, then run the training
    Params:
    params_env: Dictionary with the parameters of the game
    prms_tra: Dictionary with the parameters of the training
    type_memory: 0 for ReplayMemory and 1 for PrioritizedReplayMemory
    Returns: 0 when the training finished, 'Error' if type_memory is invalid
    '''
    # Reject a wrong memory type before anything is built
    if type_memory not in (0,1):
        print('Error: type_memory must be 0 or 1')
        return 'Error'
    # Use an accelerator when there is one: CUDA first, then Apple MPS.
    # Builds of torch without the mps backend fall through to the CPU.
    mps_backend= getattr(th.backends,'mps',None)
    if th.cuda.is_available():
        device= th.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device= th.device("mps")
    else:
        device= th.device("cpu")
    print('The training process used this Device: ',device)
    # Keyword arguments: the positional order of Field is
    # (item_pickup, item_dropoff, start_position), which is easy to mix up
    env= Field(device,params_env['size'],
                item_pickup=params_env['item_pickup'],
                item_dropoff=params_env['item_dropoff'],
                start_position=params_env['start_position'],
                zones_blocks=params_env['zones_block'],
                path_predicts=params_env['Path'])
    eps= EpsilonGreedyStrategy(prms_tra['eps_start'],prms_tra['eps_end'],prms_tra['eps_decay'])
    agent= Agent(eps,env.get_number_of_actions(),device)
    if type_memory==0:
        memory= ReplayMemory(prms_tra['memory_size'])
        q= QDQ(device,env,agent,memory,prms_tra['features'])
        q.training_replay_memory(prms_tra['batch_size'],prms_tra['gamma'],
                            prms_tra['target_update'],prms_tra['lr'],prms_tra['num_episodes'])
    else:
        memory= PrioritizedReplayMemory(prms_tra['memory_size'])
        q= QDQ(device,env,agent,memory,prms_tra['features'])
        q.training_priorized_memory(prms_tra['batch_size'],prms_tra['gamma'],
                            prms_tra['target_update'],prms_tra['lr'],prms_tra['num_episodes'])
    return 0
if __name__ == '__main__':
    # Game setup: a 10x10 field with its start, pickup and dropoff cells
    # plus the list of blocked cells
    params_game = dict(
        size=10,
        start_position=(9, 0),
        item_pickup=(1, 1),
        item_dropoff=(8, 8),
        zones_block=[
            (4, 0), (4, 1), (4, 2), (4, 3),
            (2, 6), (2, 7), (2, 8), (2, 9),
            (4, 8), (5, 8), (6, 8),
            (7, 6), (8, 6), (9, 6),
        ],
        Path='Episodes',
    )
    # Hyper-parameters handed to the training routine
    params_training = dict(
        batch_size=1000,
        features=1,
        gamma=0.99,
        eps_start=1,
        eps_end=0.01,
        eps_decay=0.001,
        target_update=5000,
        memory_size=100000,
        lr=0.001,
        num_episodes=10000,
    )
    # Third argument 1 selects the prioritized replay memory
    training_process(params_game, params_training, 1)