In [28]:
import copy
import random
import time
from collections import deque

import gym
import gym_snake  # imported for its side effect; presumably registers 'snake-v0' with gym -- confirm
import numpy as np
import torch as tc
import torch.optim as opt
from IPython.display import clear_output

from core import CUDA_AVAILABLE, DEVICE
from model import NN

# Environment is created only after every import has run.
env = gym.make('snake-v0')

In [29]:
# Training hyperparameters

# --- Optimisation ---
BATCH_SIZE = 64             # transitions sampled from replay memory per update
DISC_RATIO = 0.9            # discount applied to the bootstrapped next-state value
TARGET_UPD_FREQ = 1000      # training steps between target-network syncs

# --- Episodes and epsilon schedule ---
EPISODE_CNT = 50_000
EPS_MAX = 1.0
EPS_MIN = 0.1
EPS_EXPONENT = 2.6          # power applied to the remaining-episode fraction

# --- Replay memory ---
REPLAY_MEM_SIZE = 500_000   # used as the deque maxlen

# --- Logging / checkpointing (counted in episodes) ---
STAT_DISPLAY_FREQ = 50
SAVE_TEMP_FREQ = 400
SAVE_VERSIONS_FREQ = EPISODE_CNT // 10

In [None]:
# DQN training cell.
#
# Plays EPISODE_CNT episodes with an epsilon-greedy policy, stores every
# transition in a bounded replay memory, and after each environment step
# fits `net` on a random minibatch against targets produced by `net_target`,
# which is re-synced from `net` every TARGET_UPD_FREQ updates.
# Side effects: writes ./netw.pt periodically and at the end; prints stats.
print("CUDA: ",CUDA_AVAILABLE)
net=NN().cuda() if CUDA_AVAILABLE else NN()
# deepcopy already carries the weights over, so no extra load_state_dict is needed.
net_target=copy.deepcopy(net)
net.train()
# NOTE(review): net_target keeps whatever mode it was copied in (as before).
# If model.NN contains dropout/batch-norm, consider net_target.eval() -- confirm.

lossf=tc.nn.MSELoss()
opter=opt.Adam(net.parameters())
expr_mem=deque(maxlen=REPLAY_MEM_SIZE)

eps=EPS_MAX
train_i=0      # number of gradient updates performed so far
losses=[]      # per-update losses since the last stats print
rwdsums=[]     # per-episode reward sums since the last stats print
for epi in range(EPISODE_CNT):
    s1=env.reset()
    done=False
    rwdsum=0
    while not done:
        # Epsilon-greedy action selection.
        if random.random()<eps:
            # 3 = number of discrete actions assumed for this env -- TODO confirm
            actidx=np.random.randint(0,3)
        else:
            # Only evaluate the network when exploiting, and without building
            # an autograd graph that would be thrown away immediately.
            with tc.no_grad():
                qval=net(tc.tensor(s1).to(DEVICE))
            actidx=int(np.argmax(qval.cpu().numpy()))
        s2,rwd,done,info=env.step(actidx)
        rwdsum+=rwd
        expr_mem.append((s1,actidx,rwd,s2,int(done)))
        s1=s2

        if len(expr_mem)>BATCH_SIZE:
            train_i+=1
            bat=random.sample(expr_mem,BATCH_SIZE)
            states,actions,rewards,next_states,dones=zip(*bat)
            # Convert each observation exactly like the acting path does and
            # stack; this avoids tc.tensor() on a list of per-sample arrays,
            # which PyTorch warns is very slow.
            s1bat=tc.stack([tc.tensor(s) for s in states]).to(DEVICE)
            s2bat=tc.stack([tc.tensor(s) for s in next_states]).to(DEVICE)
            # Explicit dtypes: gather() needs int64 indices, and the TD target
            # should not depend on whether the env returns int/float/numpy scalars.
            abat=tc.tensor(actions,dtype=tc.long).to(DEVICE)
            rbat=tc.tensor(rewards,dtype=tc.float32).to(DEVICE)
            dbat=tc.tensor(dones,dtype=tc.float32).to(DEVICE)

            q1=net(s1bat)
            with tc.no_grad():
                q2=net_target(s2bat)
            # Q(s,a) of the action actually taken; squeeze only the gathered
            # column so a batch of one keeps its batch dimension.
            x=q1.gather(1,abat.unsqueeze(1)).squeeze(1)
            # TD target: r + gamma * max_a' Q_target(s',a'), zeroed on terminal steps.
            y=rbat+DISC_RATIO*((1-dbat)*tc.max(q2,dim=1)[0])
            loss=lossf(x,y)
            opter.zero_grad()
            loss.backward()
            opter.step()
            losses.append(loss.item())
            if train_i%TARGET_UPD_FREQ==0:
                net_target.load_state_dict(net.state_dict())

    rwdsums.append(rwdsum)
    # Polynomial decay of epsilon over episodes, clamped to [EPS_MIN, EPS_MAX].
    eps=min(EPS_MAX,max(EPS_MIN,((EPISODE_CNT-epi)/EPISODE_CNT)**EPS_EXPONENT))

    if epi and epi%SAVE_TEMP_FREQ==0:
        # Checkpoint and wipe accumulated cell output so the log stays short.
        tc.save(net.state_dict(),'./netw.pt')
        clear_output(wait=True)
    if epi and epi%STAT_DISPLAY_FREQ==0:
        # losses can be empty if no update ran in this window; avoid dividing by zero.
        loss_avg=sum(losses)/len(losses) if losses else float("nan")
        rwd_avg=sum(rwdsums)/len(rwdsums)
        print("{}/{}({:.2f}%): LossAvg={:.4f} RwdAvg={:.4f}".format(
            epi,
            EPISODE_CNT,
            epi/EPISODE_CNT*100,
            loss_avg,
            rwd_avg))
        rwdsums=[]
        losses=[]

#save
tc.save(net.state_dict(),'./netw.pt')
print("DONE!!!")

CUDA:  True
