In [22]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import namedtuple

# Define the DQN network
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Define replay buffer
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

# Define the DQN agent
class DQNAgent:
    def __init__(self, state_size, action_size, lr=0.001, gamma=0.95, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995):
        self.state_size = state_size
        self.action_size = action_size
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.policy_net = DQN(state_size, action_size)
        self.target_net = DQN(state_size, action_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(2000)

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        with torch.no_grad():
            return self.policy_net(state).argmax().item()

    def remember(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)

    def replay(self, batch_size):
        print("inside replay: ")
        print("self.memory: ", len(self.memory))
        if len(self.memory) < batch_size:
            return
        transitions = self.memory.sample(batch_size)
        batch = Transition(*zip(*transitions))
        # Convert batch.state into a tensor with the appropriate shape and type
        state_batch = torch.tensor(np.stack(batch.state), dtype=torch.float32)
        # state_batch = torch.tensor(batch.state, dtype=torch.float32)
        action_batch = torch.tensor(batch.action, dtype=torch.int64).unsqueeze(-1)
        reward_batch = torch.tensor(batch.reward, dtype=torch.float32)
        # Convert each individual state in batch.next_state to a tensor and stack them
        next_state_batch = torch.stack([torch.tensor(state, dtype=torch.float32) for state in batch.next_state])

        # next_state_batch = torch.tensor(batch.next_state, dtype=torch.float32)
        done_mask = torch.tensor(batch.done, dtype=torch.bool)

        state_action_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1))

        # state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        
        # Compute next state values only for non-terminal states
        next_state_values = torch.zeros(batch_size, dtype=torch.float32)
        non_terminal_mask = ~done_mask
        if torch.any(non_terminal_mask):
            non_terminal_next_states = next_state_batch[non_terminal_mask]
            next_state_values[non_terminal_mask] = self.target_net(non_terminal_next_states).max(1)[0].detach()
        
        expected_state_action_values = reward_batch + self.gamma * next_state_values.unsqueeze(1)
        
        loss = nn.functional.smooth_l1_loss(state_action_values, expected_state_action_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

# Define the environment
class SimpleEnv(gym.Env):
    def __init__(self):
        self.grid_size = 5
        self.observation_space = gym.spaces.Discrete(self.grid_size)
        self.action_space = gym.spaces.Discrete(4)  # 4 possible actions: up, down, left, right
        self.agent_pos = np.array([0, 0])
        self.goal_pos = np.array([self.grid_size-1, self.grid_size-1])
        self.obstacle_pos = np.array([[2, 2], [3, 3]])

    def reset(self):
        self.agent_pos = np.array([0, 0])
        return self.agent_pos

    def step(self, action):
        if action == 0:  # up
            self.agent_pos[0] = max(0, self.agent_pos[0] - 1)
        elif action == 1:  # down
            self.agent_pos[0] = min(self.grid_size - 1, self.agent_pos[0] + 1)
        elif action == 2:  # left
            self.agent_pos[1] = max(0, self.agent_pos[1] - 1)
        elif action == 3:  # right
            self.agent_pos[1] = min(self.grid_size - 1, self.agent_pos[1] + 1)

        done = bool(np.array_equal(self.agent_pos, self.goal_pos))
        reward = -1 if not done else 0
        return self.agent_pos, reward, done, {}

    def render(self, mode='human'):
        grid = np.zeros((self.grid_size, self.grid_size))
        grid[self.agent_pos[0], self.agent_pos[1]] = 0.5  # Agent
        grid[self.goal_pos[0], self.goal_pos[1]] = 1.0  # Goal
        for obstacle in self.obstacle_pos:
            grid[obstacle[0], obstacle[1]] = -1.0  # Obstacle

        print(grid)

# Instantiate the environment and agent
env = SimpleEnv()
state_size = 2
action_size = 4
agent = DQNAgent(state_size, action_size)

batch_size = 32
EPISODES = 100

# Main training loop
for e in range(EPISODES):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    print(".... EPISODES ....", e)
    print(".... States .....")
    print(state)
    for time in range(100):
        print("Time: ", time)
        env.render()
        action = agent.act(state)
        print("Actions: ", env.step(action))
        next_state, reward, done, _ = env.step(action)
        next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
        # print("next_state: ", next_state)
        reward = reward if not done else -10
        print("reward: ", reward)
        agent.remember(state, action, next_state, reward, done)
        state = next_state
        print("Replay: ", agent.replay(batch_size))
        # agent.replay(batch_size)
        if done:
            print("episode: {}/{}, score: {}, epsilon: {:.2f}"
                  .format(e, EPISODES, time, agent.epsilon))
            break
    if e % 10 == 0:
        agent.update_target_network()


.... EPISODES .... 0
.... States .....
tensor([[0., 0.]])
Time:  0
[[ 0.5  0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   0. ]
 [ 0.   0.  -1.   0.   0. ]
 [ 0.   0.   0.  -1.   0. ]
 [ 0.   0.   0.   0.   1. ]]
Actions:  (array([0, 1]), -1, False, {})
reward:  -1
inside replay: 
self.memory:  1
Replay:  None
Time:  1
[[ 0.   0.   0.5  0.   0. ]
 [ 0.   0.   0.   0.   0. ]
 [ 0.   0.  -1.   0.   0. ]
 [ 0.   0.   0.  -1.   0. ]
 [ 0.   0.   0.   0.   1. ]]
Actions:  (array([0, 1]), -1, False, {})
reward:  -1
inside replay: 
self.memory:  2
Replay:  None
Time:  2
[[ 0.5  0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   0. ]
 [ 0.   0.  -1.   0.   0. ]
 [ 0.   0.   0.  -1.   0. ]
 [ 0.   0.   0.   0.   1. ]]
Actions:  (array([0, 1]), -1, False, {})
reward:  -1
inside replay: 
self.memory:  3
Replay:  None
Time:  3
[[ 0.   0.   0.5  0.   0. ]
 [ 0.   0.   0.   0.   0. ]
 [ 0.   0.  -1.   0.   0. ]
 [ 0.   0.   0.  -1.   0. ]
 [ 0.   0.   0.   0.   1. ]]
Actions:  (array([0, 2]), -1, False, {})
rewa

  next_state_batch = torch.stack([torch.tensor(state, dtype=torch.float32) for state in batch.next_state])


RuntimeError: index 2 is out of bounds for dimension 1 with size 1