In [1]:
!python3 -m pip install -e .

import random
import time
import copy
from collections import deque

import numpy as np

import torch as tc
import torch.optim as opt

import gym
# NOTE(review): gym_snake is never referenced by name in this notebook;
# presumably importing it registers the 'snake-v0' id with gym, which would
# make this import order-sensitive relative to gym.make below — confirm
# against the package's __init__ before reordering or removing it.
import gym_snake
# Single module-level environment instance; both the training loop and
# validation_reward() call reset()/step() on this same object.
env = gym.make('snake-v0')

from model import NN
from core import CUDA_AVAILABLE, DEVICE

Obtaining file:///home/lobo/gym-snake
Installing collected packages: gym-snake
  Attempting uninstall: gym-snake
    Found existing installation: gym-snake 0.0.1
    Uninstalling gym-snake-0.0.1:
      Successfully uninstalled gym-snake-0.0.1
  Running setup.py develop for gym-snake
Successfully installed gym-snake


In [2]:
# Training hyper-parameters, grouped by what they control.

# --- Run length -------------------------------------------------------------
EPISODE_CNT = 50000              # number of training episodes
VALIDATION_EPISODE_CNT = 100     # greedy episodes averaged per validation pass

# --- Q-learning update ------------------------------------------------------
DISC_RATIO = 0.86                # discount applied to the target net's max Q-value
BATCH_SIZE = 64                  # transitions sampled from replay memory per update
LEARN_FREQ = 16                  # a gradient update is attempted when step % LEARN_FREQ == 0
TARGET_UPD_FREQ = 1000           # target net is re-synced when step % TARGET_UPD_FREQ == 0
REPLAY_MEM_SIZE = 1000000        # maxlen of the replay deque (oldest entries dropped)

# --- Epsilon-greedy schedule ------------------------------------------------
# eps = clamp(((EPISODE_CNT - epi) / EPISODE_CNT) ** EPS_EXPONENT, EPS_MIN, EPS_MAX)
EPS_MAX = 1.0
EPS_MIN = 0.1
EPS_EXPONENT = 2.6

# --- Reporting / checkpointing (all measured in episodes) -------------------
STAT_DISPLAY_FREQ = 100                  # print averaged loss / reward
SAVE_TEMP_FREQ = 400                     # overwrite the rolling checkpoint
SAVE_VERSIONS_FREQ = EPISODE_CNT // 10   # NOTE(review): not referenced in the visible cells
VALIDATION_FREQ = EPISODE_CNT // 100     # run validation_reward()

In [4]:
# Report which device path the imported flag selects.
print("CUDA: ", CUDA_AVAILABLE)

# Online network: the one the optimizer updates. Kept in training mode.
net = NN()
if CUDA_AVAILABLE:
    net = net.cuda()
net.train()

# Target network: an independent copy of the online network, kept in eval
# mode; the training loop refreshes its weights via load_state_dict().
net_target = copy.deepcopy(net)
net_target.load_state_dict(net.state_dict())
net_target.eval()

# Loss, optimizer and bounded replay memory (oldest transitions are evicted
# once REPLAY_MEM_SIZE entries are stored).
lossf = tc.nn.MSELoss()
opter = opt.Adam(net.parameters(), lr=1e-4)
expr_mem = deque(maxlen=REPLAY_MEM_SIZE)

def validation_reward(x):
    """Return the mean episode reward of the greedy policy defined by ``x``.

    Plays ``VALIDATION_EPISODE_CNT`` episodes on the module-level ``env``,
    always taking the arg-max action of ``x``'s output, and averages the
    per-episode reward sums.

    Args:
        x: callable mapping a state tensor (already moved to ``DEVICE``) to a
            tensor of action values. If it exposes ``training``/``eval()``/
            ``train()`` like a ``torch.nn.Module``, it is switched to eval
            mode for the duration of the call and restored afterwards.

    Returns:
        Average of the summed rewards over the validation episodes.
    """
    # The training loop passes the online network, which is in train mode;
    # evaluate in eval mode and put the previous mode back when done.
    was_training = getattr(x, "training", False)
    if was_training:
        x.eval()

    episode_returns = []
    try:
        # No gradients are needed for evaluation; skip building the graph.
        with tc.no_grad():
            for _ in range(VALIDATION_EPISODE_CNT):
                s = env.reset()
                done = False
                rwdsum = 0
                while not done:
                    qval = x(tc.tensor(s).to(DEVICE))
                    actidx = np.argmax(qval.cpu().numpy())
                    s, rwd, done, info = env.step(actidx)
                    rwdsum += rwd
                episode_returns.append(rwdsum)
    finally:
        if was_training:
            x.train()

    return sum(episode_returns) / len(episode_returns)

# DQN training loop.
#
# Per environment step: pick an epsilon-greedy action, store the transition in
# replay memory, periodically fit the online network to one-step TD targets
# computed from the target network, and periodically re-sync the target
# network. Per episode: decay epsilon, checkpoint, validate and print stats.

# Imported once here instead of on every checkpoint inside the loop.
from IPython.display import clear_output

eps = EPS_MAX
losses = []                 # loss values since the last stats print
rwdsums = []                # episode reward sums since the last stats print
validation_rwds = [-9999]   # sentinel so max() is defined before the first validation
total_steps = 0             # environment steps across ALL episodes

for epi in range(EPISODE_CNT):
    s1 = env.reset()
    done = False

    rwdsum = 0
    step = 0                # steps within the current episode
    while not done:
        if random.random() < eps:
            # NOTE(review): action count is hard-coded to 3 — confirm it
            # matches the environment's action space.
            actidx = np.random.randint(0, 3)
        else:
            # Only run the network when its output is actually used, and
            # without recording a graph that would be thrown away.
            with tc.no_grad():
                qval = net(tc.tensor(s1).to(DEVICE))
            actidx = np.argmax(qval.cpu().numpy())
        s2, rwd, done, info = env.step(actidx)
        rwdsum += rwd
        expr_mem.append((s1, actidx, rwd, s2, int(done)))
        s1 = s2

        # Learning cadence is deliberately left on the per-episode counter.
        if step % LEARN_FREQ == 0 and len(expr_mem) > BATCH_SIZE:
            bat = random.sample(expr_mem, BATCH_SIZE)
            states, actions, rewards, next_states, dones = zip(*bat)
            s1bat = tc.tensor(list(states)).to(DEVICE)
            # gather() needs int64 indices; rewards and done flags take part
            # in float arithmetic, so fix their dtypes explicitly.
            abat = tc.tensor(list(actions), dtype=tc.int64).to(DEVICE)
            rbat = tc.tensor(list(rewards), dtype=tc.float32).to(DEVICE)
            s2bat = tc.tensor(list(next_states)).to(DEVICE)
            dbat = tc.tensor(list(dones), dtype=tc.float32).to(DEVICE)

            q1 = net(s1bat)
            # squeeze(1) rather than squeeze() so a batch of one keeps its
            # batch dimension.
            x = q1.gather(1, abat.unsqueeze(dim=1)).squeeze(1)
            with tc.no_grad():
                q2 = net_target(s2bat)
                # (1 - done) zeroes the bootstrap term on terminal transitions.
                y = rbat + DISC_RATIO * ((1 - dbat) * tc.max(q2, dim=1)[0])
            loss = lossf(x, y)
            opter.zero_grad()
            loss.backward()
            opter.step()
            losses.append(loss.item())

        # Sync on the global counter: the per-episode counter restarts at 0,
        # which made this fire on the first step of every episode regardless
        # of TARGET_UPD_FREQ.
        if total_steps % TARGET_UPD_FREQ == 0:
            net_target.load_state_dict(net.state_dict())
            net_target.eval()
        step += 1
        total_steps += 1

    rwdsums.append(rwdsum)
    # Polynomial decay of epsilon over the run, clamped to [EPS_MIN, EPS_MAX].
    eps = min(EPS_MAX, max(EPS_MIN, ((EPISODE_CNT - epi) / EPISODE_CNT) ** EPS_EXPONENT))

    if epi and epi % SAVE_TEMP_FREQ == 0:
        tc.save(net.state_dict(), './netw.pt')
        clear_output(wait=True)
    if epi and epi % VALIDATION_FREQ == 0:
        cur_validation_rwd = validation_reward(net)
        if cur_validation_rwd < max(validation_rwds) * 0.5:
            # Early stopping is currently disabled: the message is printed
            # but training continues.
            print("Exited at validation, episode{}, max:{}, cur:{}".format(epi, max(validation_rwds), cur_validation_rwd))
            #break
        validation_rwds.append(cur_validation_rwd)
    if epi and epi % STAT_DISPLAY_FREQ == 0:
        # No gradient step may have run in this window; report NaN instead of
        # dividing by zero.
        loss_avg = sum(losses) / len(losses) if losses else float("nan")
        print("{}/{}({:.2f}%): LossAvg={:.4f} RwdAvg={:.4f}".format(
            epi,
            EPISODE_CNT,
            epi / EPISODE_CNT * 100,
            loss_avg,
            sum(rwdsums) / len(rwdsums)))
        rwdsums = []
        losses = []


#save
tc.save(net.state_dict(), './netw.pt')
print("DONE!!!")

KeyboardInterrupt: 