In [65]:
import math
import random
from collections import namedtuple, deque
from random import choice

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np


In [66]:
class Net(nn.Module):
    """Convolutional network producing one score per action (4 actions).

    The training code in this notebook feeds 3x20x20 observations; two
    conv(3x3)+maxpool(2) stages shrink 20 -> 18 -> 9 -> 7 -> 3, which is why
    fc1 expects 64 * 3 * 3 features. `select`/`optimize` treat the 4 outputs
    as action values.
    """

    def __init__(self):
        super().__init__()
        # Construction order (and therefore parameter init order / state_dict keys)
        # is kept identical so saved 'policy.pt' files stay loadable.
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 3 * 3, 400)
        self.fc2 = nn.Linear(400, 400)
        self.fc3 = nn.Linear(400, 4)

    def forward(self, x):
        """Map a batch of observations (batch, 3, H, W) to (batch, 4) scores."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.elu(conv(x)))
        x = x.flatten(start_dim=1)  # keep the batch dimension
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

net = Net()

In [67]:
Transition = namedtuple("Transition", ("state", "action", "reward", "nx_state"))


class Replay:
  """Fixed-capacity FIFO replay buffer of `Transition` records.

  Once `size` entries are held, each new push evicts the oldest one
  (the backing deque has `maxlen=size`).
  """

  def __init__(self, size):
    self.mem = deque(maxlen=size)

  def push(self, *arg):
    """Store one transition from positional (state, action, reward, nx_state)."""
    record = Transition(*arg)
    self.mem.append(record)

  def get(self, batch_size):
    """Return `batch_size` distinct stored transitions, sampled uniformly."""
    return random.sample(self.mem, batch_size)

  def len(self):
    """Number of transitions currently stored."""
    return self.mem.__len__()

In [68]:
# --- hyperparameters ---------------------------------------------------------
BATCH_SIZE = 64      # minibatch size drawn from replay memory
eps_st = 1           # epsilon at step 0
eps_en = 0.1         # asymptotic epsilon
eps_decay = 600      # e-folding scale (in steps) of the exponential epsilon decay
gamma = 0.99         # discount factor
TAU = 0.005          # soft-update rate for the target network
LR = 0.001           # AdamW learning rate

# --- model, optimiser, replay memory -----------------------------------------
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
policy = Net().to(device)
# Initial weights go to disk; training() reloads them from this file.
torch.save(policy.state_dict(), 'policy.pt')
opti = optim.AdamW(policy.parameters(), lr=LR, amsgrad=True)
memory = Replay(size=10000)
step = 0

# Move deltas on (channel, row, col); the index is the action id used by select():
# 0 -> col+1, 1 -> row+1, 2 -> col-1, 3 -> row-1.
actions = (
  [0, 0, 1],
  [0, 1, 0],
  [0, 0, -1],
  [0, -1, 0],
)

def select(state,cur_loc,step):
  """Epsilon-greedy action choice restricted to moves that stay on the 20x20 grid.

  Args:
    state: observation tensor indexed as state[channel, row, col].
    cur_loc: numpy array [0, row, col] (the agent's cell in channel 0).
    step: step counter driving the exponential epsilon decay.

  Returns:
    A (1, 1) long tensor on `device` holding an index into `actions`.
  """
  sample=random.random()
  eps_threshold=eps_en+(eps_st-eps_en)*math.exp(-1*step/eps_decay)
  row,col=cur_loc[1],cur_loc[2]
  # allowed[i] is False when actions[i] would step off the board.
  allowed=[col!=19, row!=19, col!=0, row!=0]
  lst=[i for i in range(4) if allowed[i]]
  if sample<=eps_threshold:
    # Explore: prefer moves onto cells whose channel-0 value is still 0.
    unseen=[a for a in lst if state[tuple(cur_loc+np.array(actions[a]))]==0]
    candidates=unseen if unseen else lst
    return torch.tensor([[random.choice(candidates)]], device=device,dtype=torch.long)
  # Exploit: greedy over the allowed actions only. Pure inference, so no
  # autograd graph is needed (the original built one on every call).
  with torch.no_grad():
    q_allowed=policy(torch.unsqueeze(state,0))[0,lst]
  best=int(q_allowed.argmax(-1).item())
  return torch.tensor(lst[best],device=device).view(1,1)

In [69]:
def optimize():
  """Run one DQN gradient step on a random minibatch from `memory`.

  Does nothing until the buffer holds at least BATCH_SIZE transitions.
  Reads the module-level `memory`, `policy`, `target`, `opti`, `gamma`, `device`.
  Transitions whose `nx_state` is None are treated as terminal (next value 0).
  """
  if memory.len()<BATCH_SIZE:
    return
  transitions=memory.get(BATCH_SIZE)
  batch=Transition(*zip(*transitions))
  # Identity test: `!= None` on a tensor dispatches to Tensor.__ne__.
  non_final_mask=torch.tensor([s is not None for s in batch.nx_state],device=device,dtype=torch.bool)
  non_final_next=[s for s in batch.nx_state if s is not None]
  state_batch=torch.cat(batch.state)
  action_batch=torch.cat(batch.action)
  reward_batch=torch.cat(batch.reward)
  state_action_values=policy(state_batch).gather(1,action_batch)
  nx_state_values=torch.zeros(BATCH_SIZE,device=device)
  # Concatenate once (no CPU seed tensor, so it works on CUDA) and skip the
  # target pass entirely when every sampled transition is terminal.
  if non_final_next:
    with torch.no_grad():
      nx_state_values[non_final_mask]=target(torch.cat(non_final_next)).max(1)[0]
  # Match the prediction dtype in case rewards were stored as float64.
  expect=(nx_state_values*gamma+reward_batch).to(state_action_values.dtype)
  crit=nn.SmoothL1Loss()
  loss=crit(state_action_values,expect.unsqueeze(1))
  opti.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_value_(policy.parameters(), 100)
  opti.step()

In [70]:
# Episodes per training() call; both device branches currently use the same value.
num_episodes = 1000 if torch.cuda.is_available() else 1000
def training(map):
  """Train the global `policy` network for `num_episodes` episodes on one map.

  Each episode starts the agent at row 19, col 0 and runs at most 400 steps,
  ending early when the agent steps onto a cell whose map value equals the
  map's maximum. Every step pushes a transition into `memory`, calls
  `optimize()`, and soft-updates the target network with rate TAU. The trained
  weights are written back to 'policy.pt' at the end.

  Args:
    map: map values as a (20, 20) array/tensor or a (1, 20, 20) tensor.
  """
  # optimize() reads `target` as a module-level name; a local here (as before)
  # left it undefined and raised NameError once the buffer filled.
  global target
  map=torch.as_tensor(map,device=device)
  if map.dim()==2:
    map=map.unsqueeze(0)
  step=0
  ans=torch.max(map)
  # Train the *global* policy in place: `opti`, `select` and `optimize` all use it.
  # A fresh local Net (as before) never received gradient updates, so the saved
  # file was just the reloaded weights.
  policy.load_state_dict(torch.load('policy.pt',map_location=device))
  policy.eval()
  target=Net().to(device)
  target.load_state_dict(policy.state_dict())
  target.eval()
  for _ in range(num_episodes):
    # Channels: 0 = map values revealed so far, 1 = agent position, 2 = visit counts.
    state=torch.zeros((3,20,20),dtype=torch.float32,device=device)
    state[0,19,0]=map[0,19,0]
    state[1,19,0]=1
    state[2,19,0]=1
    cur_loc_0=np.array([0,19,0])
    cur_loc_1=np.array([1,19,0])
    cur_loc_2=np.array([2,19,0])
    for j in range(400):
      step=step+1
      action=select(state,cur_loc_0,step)
      delta=np.array(actions[action.item()])
      nx_loc_0=cur_loc_0+delta
      nx_loc_1=cur_loc_1+delta
      nx_loc_2=cur_loc_2+delta
      passed=state[tuple(nx_loc_2)]
      # Map-value gain, minus tiny penalties for revisits and for elapsed steps.
      reward=map[tuple(nx_loc_0)]-map[tuple(cur_loc_0)]-passed*1e-6-j*1e-6
      # float32 to match the network's outputs in optimize().
      reward=torch.tensor([float(reward)],device=device,dtype=torch.float32)
      nx_state=torch.clone(state)
      nx_state[tuple(nx_loc_0)]=map[tuple(nx_loc_0)]
      nx_state[tuple(cur_loc_1)]=0
      nx_state[tuple(nx_loc_1)]=1
      nx_state[tuple(nx_loc_2)]=nx_state[tuple(nx_loc_2)]+1
      reached=bool(map[tuple(nx_loc_0)]==ans)
      # Goal reached or last allowed step: stored as terminal (no bootstrap).
      nx_state_un=None if (reached or j==399) else torch.unsqueeze(nx_state,0)
      memory.push(torch.unsqueeze(state,0),action,reward,nx_state_un)
      cur_loc_0,cur_loc_1,cur_loc_2=nx_loc_0,nx_loc_1,nx_loc_2
      state=torch.clone(nx_state)
      optimize()
      # Soft update: target <- TAU * policy + (1 - TAU) * target.
      target_net_state_dict=target.state_dict()
      policy_net_state_dict=policy.state_dict()
      for key in policy_net_state_dict:
        target_net_state_dict[key]=policy_net_state_dict[key]*TAU+target_net_state_dict[key]*(1-TAU)
      target.load_state_dict(target_net_state_dict)
      if reached:
        print(j)
        print(state[2])
        break
  torch.save(policy.state_dict(),'policy.pt')


In [71]:
def testing(map):
  """Roll out one episode of at most 400 steps using `select`.

  Args:
    map: map values as a (20, 20) array/tensor or a (1, 20, 20) tensor.

  Returns:
    The 0-based step index at which the agent first steps onto a cell holding
    the map's maximum value, or None if that does not happen within 400 steps.
  """
  map=torch.as_tensor(map,device=device)
  if map.dim()==2:
    map=map.unsqueeze(0)
  ans=torch.max(map)
  # Channels: 0 = map values revealed so far, 1 = agent position, 2 = visit counts.
  state=torch.zeros((3,20,20),dtype=torch.float32,device=device)
  state[0,19,0]=map[0,19,0]
  state[1,19,0]=1
  state[2,19,0]=1
  cur_loc_0=np.array([0,19,0])
  cur_loc_1=np.array([1,19,0])
  cur_loc_2=np.array([2,19,0])
  # NOTE(review): select() is epsilon-greedy with epsilon driven by `step`; with the
  # current eps_* settings this rollout is still mostly random, not a greedy evaluation.
  with torch.no_grad():
    for i in range(400):
      action=select(state,cur_loc_0,i+1)
      delta=np.array(actions[action.item()])
      nx_loc_0=cur_loc_0+delta
      nx_loc_1=cur_loc_1+delta
      nx_loc_2=cur_loc_2+delta
      # Clone rather than alias, mirroring training().
      nx_state=torch.clone(state)
      nx_state[tuple(nx_loc_0)]=map[tuple(nx_loc_0)]
      nx_state[tuple(cur_loc_1)]=0
      # Mark the NEW position (previously cur_loc_1 was set back to 1 here,
      # so the position channel never moved).
      nx_state[tuple(nx_loc_1)]=1
      nx_state[tuple(nx_loc_2)]=nx_state[tuple(nx_loc_2)]+1
      cur_loc_0,cur_loc_1,cur_loc_2=nx_loc_0,nx_loc_1,nx_loc_2
      state=nx_state
      if map[tuple(nx_loc_0)]==ans:
        return i
  return None


In [None]:
# Train on 19 randomly chosen maps from map_arr.txt.
# The file is read once (loop-invariant); reshape(-1, 20, 20) handles both an
# (N, 400) and an (N*20, 20) text layout.
loaded_arr = np.loadtxt("map_arr.txt")
load_sub_map_arr = loaded_arr.reshape(-1, 20, 20)
for i in range(1,20):
  # randrange excludes the bound; randint(0, n) could return n and index past the end.
  test = random.randrange(load_sub_map_arr.shape[0])
  # `game_map` instead of `map` to avoid shadowing the builtin.
  game_map = torch.tensor(load_sub_map_arr[test,:,:], device=device)
  game_map = torch.unsqueeze(game_map, 0)
  training(game_map)

In [None]:
# Alternate training and a test rollout on 99 randomly chosen maps.
loaded_arr = np.loadtxt("map_arr.txt")
load_sub_map_arr = loaded_arr.reshape(-1, 20, 20)
for i in range(1,100):
  # randrange excludes the bound; randint(0, n) could return n and index past the end.
  test = random.randrange(load_sub_map_arr.shape[0])
  # Convert to a (1, 20, 20) device tensor BEFORE training: training() indexes
  # map[0, r, c] and calls torch.max on it, which a raw (20, 20) numpy slice breaks.
  game_map = torch.tensor(load_sub_map_arr[test,:,:], device=device)
  game_map = torch.unsqueeze(game_map, 0)
  training(game_map)
  print(testing(game_map))