In [None]:
!python3 -m pip install -e .

from IPython.core.debugger import set_trace

import random
import time
import copy
from collections import deque

import numpy as np

import torch as tc
import torch.optim as opt

import gym
# NOTE(review): gym_snake is never referenced by name in the visible cells; presumably it is
# imported for the side effect of registering 'snake-v0' with gym — confirm.
import gym_snake
# Single environment instance shared by the training cell and the playback cell.
env = gym.make('snake-v0')

from model import NN
from core import CUDA_AVAILABLE, DEVICE

from per_memory import PERMemory

In [None]:
# Training hyper-parameters (edit here; every later cell reads these names).

# --- Run length ---
EPISODE_CNT = 20000                      # number of training episodes

# --- Learning ---
BATCH_SIZE = 64                          # transitions sampled per train() call; also the
                                         # minimum replay fill before learning starts
DISC_RATIO = 0.99                        # discount applied to the bootstrapped next-state value
LEARNING_RATE = 1e-4                     # Adam learning rate
LEARN_FREQ = 8                           # train() runs when the step counter is a multiple of this
TARGET_UPD_FREQ = LEARN_FREQ * 5000      # step-counter interval for copying net -> net_target
REPLAY_MEM_SIZE = 1000000                # capacity handed to PERMemory

# --- Exploration: eps = clamp(((EPISODE_CNT - epi) / EPISODE_CNT) ** EPS_EXPONENT) ---
EPS_MAX = 1.0
EPS_MIN = 0.01
EPS_EXPONENT = 2.7

# --- Reporting / checkpointing (all measured in episodes) ---
STAT_DISPLAY_FREQ = 100                  # print averaged loss / reward this often
SAVE_TEMP_FREQ = 500                     # overwrite ./netw.pt and clear cell output this often
SAVE_VERSIONS_FREQ = EPISODE_CNT // 10   # not referenced by the visible cells

# --- Validation (not referenced by the visible cells) ---
VALIDATION_EPISODE_CNT = 100
VALIDATION_FREQ = EPISODE_CNT // 100

In [None]:
# Report which device path this run will take.
print("CUDA: ", CUDA_AVAILABLE)

# Online network: the one that is optimised and used to pick actions.
if CUDA_AVAILABLE:
    net = NN().cuda()
else:
    net = NN()

# Target network: starts as an exact copy of the online network.
net_target = copy.deepcopy(net)
net_target.load_state_dict(net.state_dict())

# Only the online network is kept in training mode.
net.train()
net_target.eval()

# Optimiser over the online network's parameters, and the prioritised replay memory.
opter = opt.Adam(net.parameters(), lr=LEARNING_RATE)
perm = PERMemory(REPLAY_MEM_SIZE, alpha=0.7, beta=0.5)

def train():
    """Run one double-DQN optimisation step on a minibatch drawn from `perm`.

    Reads the notebook globals `perm`, `net`, `net_target`, `opter`, `epi`,
    `BATCH_SIZE`, `EPISODE_CNT`, `DISC_RATIO` and `DEVICE`, and appends the scalar
    loss of this step to the global `losses` list.

    Each stored transition is `(s1, a, r, s2, d)` where `s1`/`s2` are two-part
    observations (both parts are fed to the network) and `d` is 1 for terminal
    transitions.
    """
    # perm.sample yields (index, importance-sampling weight, transition) triples; the
    # second argument is the fraction of training completed so far.
    idxs, isws, bat = zip(*perm.sample(BATCH_SIZE, epi / EPISODE_CNT))
    # NOTE(review): `idxs` is not used below, so the priorities of the sampled
    # transitions are not refreshed after this update — confirm whether PERMemory
    # expects a priority update here.

    s1s, acts, rwds, s2s, dones = zip(*bat)

    s1bat = tc.tensor([s[0] for s in s1s]).to(DEVICE)
    starv1bat = tc.tensor([s[1] for s in s1s]).to(DEVICE)
    abat = tc.tensor(list(acts)).to(DEVICE)
    rbat = tc.tensor(list(rwds)).to(DEVICE)
    s2bat = tc.tensor([s[0] for s in s2s]).to(DEVICE)
    starv2bat = tc.tensor([s[1] for s in s2s]).to(DEVICE)
    dbat = tc.tensor(list(dones)).to(DEVICE)

    # Q(s1, a) for the actions that were actually taken. squeeze(1) keeps the batch
    # dimension even when the batch holds a single transition.
    q1 = net(s1bat, starv1bat)
    x = q1.gather(1, abat.unsqueeze(dim=1)).squeeze(1)

    # Double DQN target: the online network picks the next action, the target
    # network evaluates it. No gradient flows through the target.
    with tc.no_grad():
        q2t = net_target(s2bat, starv2bat)
        q2 = net(s2bat, starv2bat)
        best_next = q2.argmax(dim=1, keepdim=True)
        next_val = q2t.gather(1, best_next).squeeze(1)
        y = rbat + DISC_RATIO * ((1 - dbat) * next_val)

    # Importance-sampling-weighted squared TD error.
    loss = (tc.tensor(isws).to(DEVICE) * (x - y) ** 2).mean()
    opter.zero_grad()
    loss.backward()
    opter.step()
    losses.append(loss.item())

# Hoisted out of the episode loop; used to wipe the cell output at each checkpoint.
from IPython.display import clear_output

eps=EPS_MAX
losses=[]    # per-train() losses since the last stats print
rwdsums=[]   # per-episode reward sums since the last stats print

# Environment-step counter shared by ALL episodes. LEARN_FREQ and TARGET_UPD_FREQ are
# expressed in steps; resetting this per episode would mean the target network is only
# synced if a single episode lasted TARGET_UPD_FREQ steps.
step=1
for epi in range(1,EPISODE_CNT+1):
    s1=env.reset()
    done=False

    rwdsum=0
    while not done:
        # Acting needs no gradients; running under no_grad also keeps the TD error
        # stored in replay memory from holding on to an autograd graph.
        with tc.no_grad():
            q1=net(tc.tensor(s1[0]).to(DEVICE),tc.tensor(s1[1]).to(DEVICE))
        # Epsilon-greedy choice among the 3 actions.
        actidx=(np.random.randint(0,3) if random.random()<eps else
                np.argmax(q1.cpu().numpy()))
        val1=q1[0][actidx]
        s2,rwd,done,info=env.step(actidx)
        with tc.no_grad():
            # Both observation parts go to DEVICE, matching the s1 forward pass above.
            s2_obs=tc.tensor(s2[0]).to(DEVICE)
            s2_starv=tc.tensor(s2[1]).to(DEVICE)
            q2t=net_target(s2_obs,s2_starv)
            q2=net(s2_obs,s2_starv)
        # Double-DQN TD error of this transition, used as its initial priority.
        val2=rwd+DISC_RATIO*((1-done)*q2t[0][tc.argmax(q2,dim=1)])
        td=val2-val1
        perm.push(td,(s1,actidx,rwd,s2,int(done)))

        s1=s2
        rwdsum+=rwd

        if step%LEARN_FREQ==0 and perm.cnt>=BATCH_SIZE:
            train()
        if step%TARGET_UPD_FREQ==0:
            net_target.load_state_dict(net.state_dict())
            net_target.eval()
        step+=1

    rwdsums.append(rwdsum)
    # Polynomial epsilon decay over the run, clamped to [EPS_MIN, EPS_MAX].
    eps=min(EPS_MAX,max(EPS_MIN, ((EPISODE_CNT-epi)/EPISODE_CNT)**EPS_EXPONENT ))

    if epi%SAVE_TEMP_FREQ==0:
        tc.save(net.state_dict(),'./netw.pt')
        clear_output(wait=True)
    if epi%STAT_DISPLAY_FREQ==0:
        # No train() call may have happened in this window (e.g. replay memory still
        # filling up); report nan instead of dividing by zero.
        loss_avg=float(sum(losses))/len(losses) if losses else float('nan')
        print("{}/{}({:.2f}%): LossAvg={:.4f} RwdAvg={:.4f}".format(
            epi,
            EPISODE_CNT,
            epi/EPISODE_CNT*100,
            loss_avg,
            float(sum(rwdsums))/len(rwdsums)))
        rwdsums=[]
        losses=[]


#save
tc.save(net.state_dict(),'./netw.pt')
print("DONE!!!")

In [None]:
from core import render_inline
#load
# Build the network on DEVICE (the same device the observations are moved to below), so
# this cell also runs on CPU-only machines; map_location lets a checkpoint saved from a
# GPU run be loaded there too.
net=NN().to(DEVICE)
net.load_state_dict(tc.load('./netw.pt',map_location=DEVICE))
net.eval()

# Play one greedy episode, drawing every frame inline with render_inline.
with tc.no_grad():
    obs=env.reset()
    done=False
    while not done:
        # Feed every part of the observation to the network and take the best action.
        res=net(*map(lambda x:tc.tensor(x).to(DEVICE),obs)).squeeze().tolist()
        obs, rwd, done, _ = env.step(res.index(max(res)))
        render_inline(env)
env.close()