In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils import data
# Grid-world layout, one string per row:
#   '*' ordinary cell, '#' forbidden cell, 'x' target cell.
# Each row string is split into single characters so the result is a
# 2-D array of one-character strings.
env = np.array([list(row) for row in (
    "*****",
    "*##**",
    "**#**",
    "*#x#*",
    "*#***",
)])

class DQN_NN(nn.Module):
    """Two-layer perceptron mapping 3 input features to a single output.

    The layers are a 3 -> 100 linear map, a ReLU, and a 100 -> 1 linear
    map. In this file the network is fed (row, column, action) triples
    and its scalar output is used as the action value of that triple.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order (fc1, then fc2) so that seeded
        # parameter initialisation stays reproducible.
        self.fc1 = nn.Linear(in_features=3, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=1)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))`` for an input whose last dim is 3."""
        return self.fc2(F.relu(self.fc1(x)))


class Deep_Q_learning():
    """Off-policy deep Q-learning on a rectangular grid world.

    The environment is a 2-D array of one-character strings in which
    '#' marks a forbidden cell and 'x' marks the target cell. Experience
    is collected once with a uniformly random behaviour policy, and a
    main network is then fitted to TD targets computed from a
    periodically synchronised target network.

    Parameters
    ----------
    env : 2-D array of str
        Grid layout; its shape gives the number of rows and columns.
    _lambda : float
        Discount rate applied to the next-state value in the TD target.
    k : int
        Number of passes over the replay buffer during training.
    alpha : float
        Learning rate of the SGD optimiser.
    step : int
        Number of transitions in the generated episode (replay buffer).
    batch_size : int
        Mini-batch size used when sampling the replay buffer.
    target_update_freq : int
        Number of gradient steps between two target-network syncs.
    """

    # Row / column offsets per action index: up, right, down, left, stay.
    _ROW_STEP = (-1, 0, 1, 0, 0)
    _COL_STEP = (0, 1, 0, -1, 0)

    def __init__(self, env, _lambda, k, alpha=0.01, step=1000, batch_size=100, target_update_freq=100):
        self.env = env
        self._lambda = _lambda  # discount rate
        self.k = k  # number of passes over the replay buffer
        self.alpha = alpha  # learning rate
        self.step = step  # number of steps in the generated episode
        self.batch_size = batch_size  # mini-batch size
        self.target_update_freq = target_update_freq  # target network update frequency

        self.m, self.n = self.env.shape  # size of the environment
        self.action_num = 5  # number of actions
        self.v = np.zeros((self.m, self.n))  # state values
        self.q = np.zeros((self.m, self.n, self.action_num))  # action values
        self.policy = np.zeros((self.m, self.n), dtype=int)  # greedy target policy

        self.main_net = DQN_NN()  # main network (trained)
        self.target_net = DQN_NN()  # target network (periodically synced)
        self.target_net.load_state_dict(self.main_net.state_dict())  # start both nets equal

        self.optimizer = optim.SGD(self.main_net.parameters(), lr=self.alpha)
        self.loss_func = nn.MSELoss()

    def next_state(self, x, y, a):
        """Apply action ``a`` in cell ``(x, y)``.

        Returns ``(x_next, y_next, reward)``. Leaving the grid keeps the
        agent in place with reward -1; entering a forbidden cell gives
        -1; entering the target cell gives +1; any other move gives 0.
        """
        x_next = x + self._ROW_STEP[a]
        y_next = y + self._COL_STEP[a]
        # Out-of-grid moves bounce back and are never rewarded as target hits.
        if not (0 <= x_next < self.m and 0 <= y_next < self.n):
            return x, y, -1
        cell = self.env[x_next, y_next]
        if cell == 'x':
            return x_next, y_next, 1
        if cell == '#':
            return x_next, y_next, -1
        return x_next, y_next, 0

    def generate_episode(self):
        """Roll out ``self.step`` transitions under a uniform random policy.

        The start cell is drawn uniformly among cells that are neither
        forbidden nor the target. Returns a list of tuples
        ``(x, y, a, reward, x_next, y_next)``.
        """
        # Rejection-sample the initial state.
        while True:
            x = np.random.randint(0, self.m)
            y = np.random.randint(0, self.n)
            if self.env[x, y] != '#' and self.env[x, y] != 'x':
                break
        episode = []
        # Behaviour policy: every action is equally likely.
        for _ in range(self.step):
            a = np.random.randint(0, self.action_num)
            x_next, y_next, reward = self.next_state(x, y, a)
            episode.append((x, y, a, reward, x_next, y_next))
            x, y = x_next, y_next
        return episode

    def draw_mini_batch(self, episode):
        """Wrap an episode in a shuffling ``DataLoader``.

        Each batch is a triple ``(state_action, reward, next_state)`` of
        float32 tensors with 3, 1 and 2 columns respectively. Float32 is
        required because the networks' linear layers cannot consume
        integer tensors.
        """
        transitions = torch.from_numpy(
            np.asarray(episode, dtype=np.float32).reshape(-1, 6)
        )
        state_action = transitions[:, 0:3]  # (x, y, a)
        reward = transitions[:, 3:4]  # r, kept as a column
        next_state = transitions[:, 4:6]  # (x', y')
        dataset = data.TensorDataset(state_action, reward, next_state)
        return data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def _max_target_q(self, next_state):
        """Return max over actions of the target net's value, shape (B, 1)."""
        batch = next_state.shape[0]
        # Pair every next state with every action index -> (B, A, 3).
        states = next_state.unsqueeze(1).expand(-1, self.action_num, -1)
        actions = torch.arange(self.action_num, dtype=next_state.dtype)
        actions = actions.view(1, -1, 1).expand(batch, -1, -1)
        inputs = torch.cat((states, actions), dim=-1)
        # The TD target is treated as a constant, so no gradient is tracked.
        with torch.no_grad():
            q_next = self.target_net(inputs).squeeze(-1)  # (B, A)
        return q_next.max(dim=1, keepdim=True).values

    def _refresh_tables(self):
        """Recompute ``q``, ``v`` and ``policy`` from the main network."""
        rows, cols, acts = np.meshgrid(
            np.arange(self.m),
            np.arange(self.n),
            np.arange(self.action_num),
            indexing='ij',
        )
        grid = np.stack((rows, cols, acts), axis=-1).astype(np.float32)
        with torch.no_grad():
            q_values = self.main_net(torch.from_numpy(grid)).squeeze(-1).numpy()
        self.q = q_values.astype(np.float64)
        self.v = self.q.max(axis=-1)
        self.policy = self.q.argmax(axis=-1).astype(int)

    def policy_update(self):
        """Train the main network and derive the greedy policy.

        A single episode is generated and used as the replay buffer. For
        ``self.k`` passes, shuffled mini-batches are drawn from it and
        the main network is moved towards the target
        ``r + _lambda * max_a' q_target(s', a')``. The target network is
        overwritten with the main network's weights every
        ``self.target_update_freq`` gradient steps. Afterwards ``q``,
        ``v`` and ``policy`` are refreshed from the main network.
        """
        replay_buffer = self.generate_episode()
        # An empty buffer (step == 0) leaves nothing to fit; DataLoader
        # would otherwise reject an empty shuffled dataset.
        if replay_buffer:
            # shuffle=True reshuffles on every pass, so one loader suffices.
            dataloader = self.draw_mini_batch(replay_buffer)
            updates = 0
            for _ in range(self.k):
                for state_action, reward, next_state in dataloader:
                    q_value = self.main_net(state_action)
                    target = reward + self._lambda * self._max_target_q(next_state)
                    loss = self.loss_func(q_value, target)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    updates += 1
                    if updates % self.target_update_freq == 0:
                        self.target_net.load_state_dict(self.main_net.state_dict())
        self._refresh_tables()

    def show_policy(self):
        """Print the current policy as one arrow (or 'O' for stay) per cell."""
        s = "↑→↓←O"  # symbol per action index
        for x in range(self.m):
            for y in range(self.n):
                print(s[self.policy[x, y]], end=" ")
            print(" ")
    
if __name__ == "__main__":
    # Hyper-parameters for the demo run, passed by keyword to the agent.
    hyper_params = {
        "_lambda": 0.9,             # discount rate
        "k": 1,                     # number of training iterations
        "alpha": 0.01,              # learning rate
        "step": 1000,               # length of the generated episode
        "batch_size": 10,           # mini-batch size
        "target_update_freq": 100,  # target network update frequency
    }
    dqn = Deep_Q_learning(env, **hyper_params)
    dqn.policy_update()
   

TypeError: 'DataLoader' object is not subscriptable