<a href="https://colab.research.google.com/github/Arkajeet7/warehouse-robotics-using-reinforcement-learning-/blob/main/training/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
import matplotlib.pyplot as plt
from collections import deque
import random
import torch
from torch import nn
import torch.nn.functional as F
import importlib.util
import pandas as pd

In [None]:
class WarehouseDQL():
    """Double-DQN trainer and evaluator for the warehouse environment.

    The class relies on three names defined elsewhere in the notebook:
    ``WarehouseEnv`` (a Gymnasium-style environment exposing ``rows``,
    ``colm``, ``agent_pos``, ``start_pos``, ``goal_pos``,
    ``intermediate_pos`` and ``warehouse_layout``), ``DQN`` (the Q-network)
    and ``ReplayMemory`` (a buffer supporting ``append``, ``sample`` and
    ``len``).
    """

    # Hyperparameters
    learning_rate_a = 0.0001      # Adam learning rate
    discount_factor_g = 0.99      # discount factor (gamma) used in the TD target
    network_sync_rate = 1000      # optimisation steps between target-network syncs
    replay_memory_size = 100000   # capacity handed to ReplayMemory
    mini_batch_size = 64          # transitions sampled per optimisation step
    epsilon_min = 0.05            # exploration floor: never below 5% random actions
    epsilon_decay = 0.9995        # multiplicative epsilon decay applied per episode

    # File the trained policy weights are written to by train() and read
    # from by test().
    model_path = "warehouse_dql.pt"

    # Neural Network
    loss_fn = nn.SmoothL1Loss()
    optimizer = None              # created in train() once the policy network exists

    # Action labels; not referenced by the methods of this class.
    ACTIONS = ['L', 'D', 'R', 'U']

    def train(self, episodes, render=False):
        """Train the policy network for ``episodes`` episodes.

        Saves the trained weights to ``self.model_path``, writes the reward
        and epsilon curves to ``warehouse_training.png`` and stores the
        collected per-episode metrics in ``self.training_history``.

        Args:
            episodes: Number of training episodes to run.
            render: When true the environment is created with
                ``render_mode='human'``; otherwise with ``None``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Create Warehouse environment
        env = WarehouseEnv(render_mode='human' if render else None)

        # Training metrics
        rewards_per_episode = []
        epsilon_history = []
        step_count_log = []          # number of steps taken in each episode
        all_episode_position = []    # agent positions visited, one list per episode

        try:
            rows, cols = env.rows, env.colm
            num_actions = env.action_space.n

            epsilon = 1  # Start with 100% random actions
            memory = ReplayMemory(self.replay_memory_size)

            # Create policy and target networks; the target starts as an
            # exact copy of the policy.
            policy_dqn = DQN(rows=rows, cols=cols, h1_nodes=256, out_actions=num_actions).to(device)
            target_dqn = DQN(rows=rows, cols=cols, h1_nodes=256, out_actions=num_actions).to(device)
            target_dqn.load_state_dict(policy_dqn.state_dict())

            # Optimizer
            self.optimizer = torch.optim.Adam(policy_dqn.parameters(), lr=self.learning_rate_a)

            train_steps = 0

            for i in range(episodes):
                state, _ = env.reset()
                # NOTE(review): assumes env.agent_pos is rebound (not mutated
                # in place) on every step; otherwise all entries alias.
                episode_positions = [env.agent_pos]
                episode_reward = 0
                step_count = 0
                terminated = False
                truncated = False

                while not (terminated or truncated):
                    # ---------------------------------------
                    # 1. epsilon-greedy action selection
                    # ---------------------------------------
                    action = self._select_action(env, policy_dqn, state, epsilon, device)

                    # ------------------------------------
                    # 2. step taken based on warehouse env
                    # ------------------------------------
                    new_state, reward, terminated, truncated, _ = env.step(action)
                    episode_reward += reward
                    episode_positions.append(env.agent_pos)

                    # ---------------------------
                    # 3. memory save
                    # ---------------------------
                    # The reward is clipped for learning only; episode_reward
                    # above accumulates the raw value. The per-episode print
                    # below shows this clipped last-step reward.
                    reward = float(np.clip(reward, -10, 10))
                    # Only a true termination marks the transition as done.
                    # A truncated episode did not reach a terminal state, so
                    # its next state must still be bootstrapped from.
                    memory.append((state, action, new_state, reward, terminated))

                    # ----------------------------
                    # 4. next state update
                    # ----------------------------
                    state = new_state
                    step_count += 1

                    # -------------------------
                    # 5. training on mini batch
                    # -------------------------
                    if len(memory) > self.mini_batch_size:
                        mini_batch = memory.sample(self.mini_batch_size)
                        self.optimize(mini_batch, policy_dqn, target_dqn, rows, cols, device)
                        train_steps += 1

                        # ---------------------------------
                        # 6. syncing the target and policy
                        # ---------------------------------
                        if train_steps % self.network_sync_rate == 0:
                            target_dqn.load_state_dict(policy_dqn.state_dict())

                # -------------------------------------
                # 7. appending the per-episode metrics
                # -------------------------------------
                rewards_per_episode.append(episode_reward)
                all_episode_position.append(episode_positions)
                step_count_log.append(step_count)

                # --------------------
                # 8. epsilon decay
                # --------------------
                epsilon = max(self.epsilon_min, epsilon * self.epsilon_decay)
                epsilon_history.append(epsilon)

                # -----------------------------
                # 9. result printing
                # -----------------------------
                print(f"\nEpisode {i+1}/{episodes}")
                print(f"  final Reward: {reward}")
                print(f"episodic reward: {episode_reward}")
                print(f"  Epsilon: {epsilon:.3f}")
                print(f"termination status={terminated}")
                print(f"truncation status={truncated}")
                print(f"steps taken={step_count}")
        finally:
            # Close the environment even if training is interrupted.
            env.close()

        # Keep the metrics reachable after training instead of discarding them.
        self.training_history = {
            "rewards_per_episode": rewards_per_episode,
            "epsilon_history": epsilon_history,
            "steps_per_episode": step_count_log,
            "episode_positions": all_episode_position,
        }

        # -----------------
        # 10. save model
        # -----------------
        torch.save(policy_dqn.state_dict(), self.model_path)

        # ------------------------------------
        # 11. plotting total reward vs episode
        # ------------------------------------
        self._plot_training_curves(rewards_per_episode, epsilon_history)

    def _select_action(self, env, policy_dqn, state, epsilon, device):
        """Return an epsilon-greedy action for ``state``."""
        if random.random() < epsilon:
            return env.action_space.sample()

        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        # Evaluate in eval mode, then restore train mode for optimisation.
        policy_dqn.eval()
        with torch.no_grad():
            action = policy_dqn(state_tensor).argmax().item()
        policy_dqn.train()
        return action

    def _plot_training_curves(self, rewards_per_episode, epsilon_history):
        """Save the reward and epsilon curves to ``warehouse_training.png``."""
        plt.figure(figsize=(12, 4))

        plt.subplot(121)
        plt.plot(rewards_per_episode)
        plt.title('Rewards per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')

        plt.subplot(122)
        plt.plot(epsilon_history)
        plt.title('Epsilon Decay')
        plt.xlabel('Episode')
        plt.ylabel('Epsilon')

        plt.tight_layout()
        plt.savefig('warehouse_training.png')

    def optimize(self, mini_batch, policy_dqn, target_dqn, rows, cols, device):
        """Run one Double-DQN optimisation step on ``mini_batch``.

        Args:
            mini_batch: Iterable of ``(state, action, next_state, reward,
                done)`` transitions.
            policy_dqn: Network being trained; also selects the next action.
            target_dqn: Network that evaluates the selected next action.
            rows, cols: Unused; kept so existing callers keep working.
            device: Torch device the batch tensors are moved to.

        Returns:
            The loss as a Python float, or ``None`` when the loss carries no
            gradient and no update was performed.
        """
        policy_dqn.train()

        # Separate batch into individual per-field lists
        states, actions, next_states, rewards, dones = map(list, zip(*mini_batch))

        # Convert to tensors; scalar fields get a trailing dimension so they
        # line up with the gathered Q-values.
        states = torch.FloatTensor(np.array(states)).to(device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(device)
        next_states = torch.FloatTensor(np.array(next_states)).to(device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(device)
        dones = torch.BoolTensor(dones).unsqueeze(1).to(device)

        # --------------------------
        # 1. compute current q value
        # --------------------------
        current_q = policy_dqn(states).gather(1, actions)

        # ------------------------------
        # 2. compute the next q-values
        # ------------------------------
        with torch.no_grad():
            # The policy network picks the action, the target network scores it.
            next_actions = policy_dqn(next_states).argmax(dim=1, keepdim=True)
            next_q = target_dqn(next_states).gather(1, next_actions)
            # No future value after a done transition.
            next_q[dones] = 0.0

        # ----------------------------------
        # 3. Compute TD target
        # ----------------------------------
        gamma = self.discount_factor_g
        target_q = rewards + gamma * next_q

        # ----------------------------------
        # 4. loss and backprop
        # ----------------------------------
        loss = self.loss_fn(current_q, target_q)

        if not loss.requires_grad:
            return None

        self.optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(policy_dqn.parameters(), max_norm=10.0)

        self.optimizer.step()

        return loss.item()

    def test(self, episodes):
        """Run the saved policy greedily for ``episodes`` episodes.

        Loads the weights from ``self.model_path``, prints each episode's
        total reward and plots the path taken via ``plot_path``.

        Args:
            episodes: Number of evaluation episodes to run.
        """
        # Create environment
        env = WarehouseEnv(render_mode='human')

        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            rows, cols = env.rows, env.colm
            num_actions = env.action_space.n

            # Load trained policy
            policy_dqn = DQN(rows=rows, cols=cols, h1_nodes=256, out_actions=num_actions).to(device)
            ckpt = torch.load(self.model_path, map_location=device)
            policy_dqn.load_state_dict(ckpt)
            policy_dqn.eval()

            for i in range(episodes):
                state, _ = env.reset()
                total_reward = 0
                terminated = False
                truncated = False
                test_agent_pos = [env.agent_pos]

                while not (terminated or truncated):
                    # Select best action
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                        action = int(policy_dqn(state_tensor).argmax().item())

                    # Execute action
                    state, reward, terminated, truncated, _ = env.step(action)
                    total_reward += reward
                    test_agent_pos.append(env.agent_pos)

                print(f"Episode {i+1}: Total Reward = {total_reward}")
                self.plot_path(env, test_agent_pos)
        finally:
            # Close once, after every episode has run; the environment is
            # reused across episodes and must stay open inside the loop.
            env.close()

    def plot_path(self, env, path):
        """Plot ``path`` on top of the warehouse layout, then save and show it.

        The environment is only read here; closing it is the caller's job.

        Args:
            env: Environment providing ``warehouse_layout``, ``start_pos``,
                ``goal_pos``, ``intermediate_pos``, ``rows`` and ``colm``.
            path: Sequence of positions whose first two items are read as
                row and column.
        """
        layout = env.warehouse_layout.copy()

        # Create a figure
        plt.figure(figsize=(8, 6))
        plt.imshow(layout, cmap='Greys', origin='upper')

        # Unpack path coordinates
        rows = [pos[0] for pos in path]
        cols = [pos[1] for pos in path]

        # Plot path (x is the column, y is the row)
        plt.plot(cols, rows, color='blue', linewidth=2, marker='o', alpha=0.7, markersize=4, label='Agent Path')

        # Mark start, goal and intermediate positions
        start_r, start_c = env.start_pos
        goal_r, goal_c = env.goal_pos
        intermediate_r, intermediate_c = env.intermediate_pos
        plt.scatter(start_c, start_r, color='green', s=300, label='Start', marker='s', zorder=5)
        plt.scatter(goal_c, goal_r, color='red', s=300, label='Goal', marker='*', zorder=5)
        plt.scatter(intermediate_c, intermediate_r, color='yellow', label='intermediate', marker='*', zorder=5)

        # Customize grid
        plt.title("Agent Path in Warehouse")
        plt.legend(loc='lower right')
        # NOTE(review): imshow(origin='upper') already puts row 0 at the top,
        # so this inversion appears to flip the plot to row 0 at the bottom.
        # Confirm the intended orientation before changing it.
        plt.gca().invert_yaxis()
        plt.xticks(range(env.colm))
        plt.yticks(range(env.rows))
        plt.grid(True, linestyle='--', linewidth=0.5)
        plt.savefig('path for 1 agent with intermediate, variable goal and intermediate ')
        plt.show()
