In [73]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from random import choice

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np

# Pick the compute device first: CUDA when PyTorch can see a GPU, CPU otherwise.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# When matplotlib uses an inline (notebook) backend, make IPython's display helpers available.
backend_name = matplotlib.get_backend()
is_ipython = 'inline' in backend_name
if is_ipython:
  from IPython import display

# Interactive plotting mode.
plt.ion()

In [74]:
# One stored step of experience. optimize() treats nx_state=None as "episode ended here".
Transition=namedtuple("Transition",("state","action","reward","nx_state"))
class Replay:
  """Fixed-capacity experience-replay buffer of ``Transition`` tuples.

  Backed by ``deque(maxlen=size)``: once ``size`` items are stored, each new
  ``push`` silently evicts the oldest one.
  """
  def __init__(self,size):
    self.mem=deque([],maxlen=size)
  def push(self,*arg):
    """Store one transition; positional args follow Transition field order."""
    self.mem.append(Transition(*arg))
  def get(self,batch_size):
    """Return ``batch_size`` distinct transitions sampled uniformly at random.

    Raises ValueError (from random.sample) if fewer are stored.
    """
    return random.sample(self.mem,batch_size)
  def get_index(self,index):
    """Return the transition at ``index`` (0 = oldest still stored; negatives allowed)."""
    return self.mem[index]
  def __len__(self):
    """Number of transitions currently stored, so ``len(buffer)`` works."""
    return len(self.mem)
  def len(self):
    """Kept for existing callers; same as ``len(buffer)``."""
    return len(self.mem)

In [75]:
class DQN(nn.Module):
  """Fully connected Q-network: observation vector -> one Q-value per action.

  Two hidden layers of width 400 with ELU activations. Keep the attribute
  names layer1..layer3: they are the state_dict keys saved to 'policy.pt'.
  """
  def __init__(self,n_observations,n_actions):
    super().__init__()
    width=400
    self.layer1=nn.Linear(n_observations,width)
    self.layer2=nn.Linear(width,width)
    self.layer3=nn.Linear(width,n_actions)
  def forward(self,x):
    """Map observations ``x`` (last dim = n_observations) to Q-values (last dim = n_actions)."""
    first=F.elu(self.layer1(x))
    second=F.elu(self.layer2(first))
    q_values=self.layer3(second)
    return q_values

In [76]:
BATCH_SIZE=64          # transitions per optimisation step
eps_st=1               # epsilon at step 0
eps_en=0.1             # floor that epsilon decays towards
eps_decay=600          # time constant (in steps) of the exponential epsilon decay in select()
gamma=0.99             # discount factor
TAU=0.005              # soft-update rate of the target network
LR=0.001
n_actions=4            # one per entry of `actions` below
n_para=3               # features recorded per move
n_behind=5             # number of most recent moves that make up one observation
expo=2                 # NOTE(review): not used anywhere in this notebook
n_obs=n_para*n_behind  # width of one observation fed to the DQN
policy=DQN(n_obs,n_actions).to(device)
target=DQN(n_obs,n_actions).to(device)
# Standard DQN: the target net starts as an exact copy of the policy net.
# Otherwise the bootstrap targets come from an unrelated random network.
target.load_state_dict(policy.state_dict())
# Note: this overwrites any existing policy.pt with freshly initialised weights.
torch.save(policy.state_dict(),'policy.pt')
opti=optim.AdamW(policy.parameters(),lr=LR,amsgrad=True)
memory=Replay(10000)
step=0
turn=0
# Move for each action index, added to the (row, col) location:
# 0 -> col+1, 1 -> row+1, 2 -> col-1, 3 -> row-1.
actions=([0,1],[1,0],[0,-1],[-1,0])


def select(state,cur_loc,turn,step,check=None):
  """Epsilon-greedy choice among the moves that stay on the 20x20 grid.

  Args:
    state: observation tensor. Only used on the greedy branch, where it is fed
      to ``policy`` and row 0 of the output is read (expected (1, n_para*n_behind)).
    cur_loc: current (row, col), as a list or ndarray.
    turn: move index within the episode. While ``turn < n_behind`` the observation
      history is not full yet, so a random move is always taken.
    step: global step counter driving epsilon = eps_en + (eps_st-eps_en)*exp(-step/eps_decay).
    check: per-cell visit counts (20x20 array). On the random branch, moves into
      cells with count 0 are preferred. ``None`` disables that preference.

  Returns:
    Long tensor of shape (1, 1) on ``device`` holding an index into ``actions``.
  """
  sample=random.random()
  eps_threshold=eps_en+(eps_st-eps_en)*math.exp(-1*step/eps_decay)
  row,col=cur_loc[0],cur_loc[1]
  # Drop actions that would step off the grid (see the action table in `actions`).
  off_grid={0: col==19, 1: row==19, 2: col==0, 3: row==0}
  legal=[a for a in range(4) if not off_grid[a]]
  if sample<=eps_threshold or turn<n_behind:
    candidates=legal
    if check is not None:
      here=np.asarray(cur_loc)
      unvisited=[a for a in legal if check[tuple(here+np.array(actions[a]))]==0]
      if unvisited:
        candidates=unvisited
    return torch.tensor([[random.choice(candidates)]],device=device,dtype=torch.long)
  # Greedy: argmax of Q over the legal actions only. no_grad because this is a
  # pure inference pass; the gradient step happens in optimize().
  with torch.no_grad():
    q_legal=policy(state)[0,legal]
  best=legal[q_legal.argmax().item()]
  return torch.tensor([[best]],device=device,dtype=torch.long)


In [77]:
def optimize():
  """Run one DQN gradient step on ``policy`` with a random minibatch from ``memory``.

  Does nothing until the buffer holds at least ``BATCH_SIZE`` transitions.
  Target for each transition: reward + gamma * max_a' target(nx_state)[a'],
  where the bootstrap term is 0 when ``nx_state`` is None (terminal).
  Uses SmoothL1 (Huber) loss and clips gradient values to [-100, 100].
  """
  if memory.len()<BATCH_SIZE:
    return
  transitions=memory.get(BATCH_SIZE)
  batch=Transition(*zip(*transitions))

  state_batch=torch.cat(batch.state)
  action_batch=torch.cat(batch.action)
  # Q(s, a) for the actions actually taken, shape (BATCH_SIZE, 1).
  state_action_values=policy(state_batch).gather(1,action_batch)
  # Cast to the network's dtype: rewards built from float64 numpy values would
  # otherwise promote the target to float64 and make loss.backward() fail.
  reward_batch=torch.cat(batch.reward).to(dtype=state_action_values.dtype)

  non_final_mask=[i for i,s in enumerate(batch.nx_state) if s is not None]
  nx_state_values=torch.zeros(BATCH_SIZE,device=device)
  # Skip the target pass when every sampled transition is terminal
  # (torch.cat / the network cannot take an empty list).
  if non_final_mask:
    non_final_next_states=torch.cat([batch.nx_state[i] for i in non_final_mask])
    with torch.no_grad():
      nx_state_values[non_final_mask]=target(non_final_next_states).max(1)[0]

  expect=nx_state_values*gamma+reward_batch
  crit=nn.SmoothL1Loss()
  loss=crit(state_action_values,expect.unsqueeze(1))
  opti.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_value_(policy.parameters(),100)
  opti.step()

In [78]:
# Episodes per call to training(map). The previous CUDA/CPU conditional assigned
# 1000 in both branches, so it is collapsed to one assignment.
num_episodes = 1000
def training(map):
  """Train the module-level ``policy``/``target`` DQN on one 20x20 height map.

  Runs ``num_episodes`` episodes. Each episode starts at (19, 0) and ends when the
  agent lands on a cell equal to ``np.max(map)`` (the move index and the visit grid
  are then printed) or after 400 moves. After every move: the transition is pushed
  to ``memory`` (once a full observation exists), ``optimize()`` takes one gradient
  step, and ``target`` is soft-updated towards ``policy`` with rate ``TAU``.
  The trained ``policy`` weights are saved to 'policy.pt' at the end.

  Each move is encoded as [height change, action index, check[destination]], and
  the observation is the last ``n_behind`` moves flattened to (1, n_para*n_behind).
  Reward per move: height change - 2*check[destination] - 2*move_index.
  ``step`` is local, so the epsilon schedule restarts for every map.
  """
  # This function works on the global `policy` / `target`: those are the networks
  # optimize() trains, so they must be the ones soft-updated and saved.
  step=0
  ans=np.max(map)
  for _ in range(num_episodes):
    cur_loc=np.array([19,0])
    state=torch.empty((1,0),dtype=torch.float32,device=device)
    check=np.zeros((20,20))
    check[19,0]=1
    for zz in range(400):
      step+=1
      action=select(state,cur_loc,zz,step,check)
      a=action.item()
      nx_loc=cur_loc+np.array(actions[a])
      passed=check[tuple(nx_loc)]
      gain=map[tuple(nx_loc)]-map[tuple(cur_loc)]
      # float32 to match the network; float64 rewards break loss.backward().
      reward=torch.tensor([gain-passed*2-zz*2],dtype=torch.float32,device=device)
      nx_pos=torch.tensor([[gain,a,passed]],dtype=torch.float32,device=device)
      # Append this move and keep only the most recent n_behind moves.
      nx_state=torch.cat((state,nx_pos),-1)[:,-n_para*n_behind:]
      reached=map[tuple(nx_loc)]==ans
      # Goal reached or move limit hit: store nx_state=None so optimize() treats
      # the transition as terminal and does not bootstrap past it.
      done=reached or zz==399
      if state.size(1)==n_para*n_behind:
        memory.push(state,action,reward,None if done else nx_state)
      check[tuple(cur_loc)]+=1
      cur_loc=nx_loc
      state=nx_state
      optimize()
      target_sd=target.state_dict()
      policy_sd=policy.state_dict()
      for key in policy_sd:
        target_sd[key]=policy_sd[key]*TAU+target_sd[key]*(1-TAU)
      target.load_state_dict(target_sd)
      if reached:
        print(zz)
        print(check)
        break
  torch.save(policy.state_dict(),'policy.pt')


In [79]:
def testing(map):
  """Roll out the current ``policy`` on one map and report when it reaches the maximum.

  Uses the same start cell (19, 0), move encoding and visit bookkeeping as
  ``training``. ``select`` is called with step=1e9, so epsilon sits at its floor
  ``eps_en``; random moves are still taken with that probability, and always for
  the first ``n_behind`` moves.

  Returns:
    The 0-based index of the move that lands on a cell equal to ``np.max(map)``,
    or None if that does not happen within 400 moves.
  """
  ans=np.max(map)
  state=torch.empty((1,0),dtype=torch.float32,device=device)
  check=np.zeros((20,20))
  # Mark the start as visited exactly as training() does, so the visit-count
  # feature means the same thing at test time as it did during training.
  check[19,0]=1
  cur_loc=np.array([19,0])
  with torch.no_grad():
    for turn in range(400):
      # `check` must be passed: select() uses it to prefer unvisited cells.
      action=select(state,cur_loc,turn,1e9,check)
      a=action.item()
      nx_loc=cur_loc+np.array(actions[a])
      if map[tuple(nx_loc)]==ans:
        return turn
      passed=check[tuple(nx_loc)]
      gain=map[tuple(nx_loc)]-map[tuple(cur_loc)]
      nx_pos=torch.tensor([[gain,a,passed]],dtype=torch.float32,device=device)
      # Append this move and keep only the most recent n_behind moves.
      state=torch.cat((state,nx_pos),-1)[:,-n_para*n_behind:]
      check[tuple(cur_loc)]+=1
      cur_loc=nx_loc
  return None


In [None]:
# Train on 19 randomly chosen maps. Each row of map_arr.txt holds one flattened 20x20 map.
loaded_arr = np.loadtxt("map_arr.txt")
load_sub_map_arr = loaded_arr.reshape(loaded_arr.shape[0], 20, 20)
for i in range(1,20):
  # randrange excludes its bound; randint(0, n) is inclusive and could index past the last map.
  test = random.randrange(load_sub_map_arr.shape[0])
  # Named game_map (not `map`) so the builtin map() is not shadowed for later cells.
  game_map=load_sub_map_arr[test,:,:]
  training(game_map)

In [None]:
# Evaluate on 99 randomly chosen maps; prints the goal-reaching move index, or None on failure.
loaded_arr = np.loadtxt("map_arr.txt")
load_sub_map_arr = loaded_arr.reshape(loaded_arr.shape[0], 20, 20)
for i in range(1,100):
  # randrange excludes its bound; randint(0, n) is inclusive and could index past the last map.
  test = random.randrange(load_sub_map_arr.shape[0])
  # Named game_map (not `map`) so the builtin map() is not shadowed for later cells.
  game_map=load_sub_map_arr[test,:,:]
  print(testing(game_map))
