In [3]:
#model definition

import random
import time
from collections import deque
import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
# Pick the compute device once; everything else in the notebook refers to DEVICE.
if tc.cuda.is_available():
    DEVICE = tc.device("cuda")
else:
    DEVICE = tc.device("cpu")

import gym
# gym_snake is not referenced by name anywhere below; presumably importing it
# registers the 'snake-v0' id with gym, so it must stay ahead of gym.make()
# — TODO confirm against the gym_snake package.
import gym_snake
# Single module-level environment instance shared by the training and
# evaluation cells.
env = gym.make('snake-v0')

class Option:
    """Namespace of notebook-wide constants (never instantiated here)."""

    # Board height and width; NN uses them to reshape observations to
    # (batch, channels, ROW, COL) and to size its first linear layer.
    ROW, COL = 7, 11

    # Not referenced in the visible cells; presumably a setting consumed by
    # the snake environment — TODO confirm.
    HUNGRY_RATE = 20

class NN(nn.Module):
    """Q-network: two circular-padded 3x3 convolutions followed by an MLP head.

    The network maps an observation with ``in_ch`` channels on an
    ``Option.ROW x Option.COL`` grid to 3 outputs (one per action index used
    by the training loop).
    """

    def __init__(self):
        super(NN,self).__init__()
        self.in_ch=3
        self.out_ch=3

        # Convolutional trunk. padding=1 with a 3x3 kernel keeps the spatial
        # size, so the flattened feature count is out_ch*ROW*COL.
        self.c1=nn.Sequential(
            nn.Conv2d(self.in_ch,3,3,padding=1,padding_mode="circular"),
            nn.LeakyReLU(0.1),

            nn.Conv2d(3,self.out_ch,3,padding=1,padding_mode="circular"),
        )

        # Fully connected head producing one value per action.
        self.d1=nn.Sequential(
            nn.Linear(self.out_ch*Option.ROW*Option.COL,512),
            nn.LeakyReLU(0.1),

            nn.Linear(512,256),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),

            nn.Linear(256,3),
        )

        # self.modules() recurses into the Sequential containers, so every
        # Conv2d/Linear layer is actually reached (checking the containers
        # themselves would match nothing and silently skip initialisation).
        for m in self.modules():
            if isinstance(m,(nn.Linear,nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                # Xavier needs >=2 dimensions to compute fan-in/fan-out and
                # raises on a 1-D bias, so biases are zeroed instead.
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self,x):
        """Return the network output for ``x``, shape ``(batch, 3)``.

        ``x`` is moved to ``DEVICE`` and reshaped to
        ``(-1, in_ch, Option.ROW, Option.COL)``, so a single un-batched
        observation is treated as a batch of one.
        """
        x=x.to(DEVICE)
        x=x.reshape((-1,self.in_ch,Option.ROW,Option.COL))
        x=self.c1(x)
        # Flatten the conv feature maps for the linear head.
        x=x.reshape(-1,self.out_ch*Option.ROW*Option.COL)
        return self.d1(x)

In [None]:
#train on GPU when available (falls back to CPU through DEVICE)

from IPython.display import clear_output

# Use DEVICE rather than .cuda() so the cell also runs on CPU-only machines.
net=NN().to(DEVICE)
net.train()
print(next(net.parameters()).is_cuda)
lossf=tc.nn.MSELoss()
#opter=opt.SGD(net.parameters(),lr=1e-4)#, weight_decay=4e-7)
opter=opt.Adam(net.parameters(),lr=2e-4)

epoch_N=100000      # number of episodes to play
dr=0.85             # discount applied to the bootstrapped next-state value
eps=1.0             # current probability of taking a random action
eps_base=0.99993    # eps decays as eps_base**episode ...
eps_min=0.1         # ... but never below this floor

expr = deque(maxlen=100000)   # replay buffer of (s1, action, reward, s2, done)
batch_N = 150                 # transitions sampled per optimisation step

j=0           # total environment steps taken so far
last_loss=0   # stays 0 (falsy) until the first optimisation step has run
rsums=[]      # per-episode reward sums since the last progress report
for i in range(epoch_N):
    obs=env.reset()
    done=False
    # Report every 100 episodes, but only once training has produced a loss
    # (this also skips i==0, where rsums would still be empty).
    if last_loss and i%100==0:
        print("Progress: {0:.2f} ({1}/{2})".format(i/epoch_N*100,i,epoch_N))
        print("Last Loss: ",last_loss)
        print("Avg Rwd: ", sum(rsums)/len(rsums))
        print("Max Rwd: ", max(rsums))
        rsums=[]

        # Every 1000 episodes: checkpoint the weights and clear the cell output.
        if i%1000==0:
            tc.save(net.state_dict(),'./netw.pt')
            clear_output(wait=True)

    rsum=0
    while not done:
        s1=obs
        # Epsilon-greedy action selection.
        if random.random()<eps:
            actidx=np.random.randint(0,3)
        else:
            # No gradient is needed to pick an action.
            # NOTE(review): net is in train() mode, so Dropout is active for
            # this forward pass as well — confirm that is intended.
            with tc.no_grad():
                qval=net(tc.tensor(s1))
            actidx=int(qval.argmax())
        obs,rwd,done,info=env.step(actidx)
        rsum+=rwd
        # Rotate positive-reward transitions from the eviction end (left) to
        # the back with probability 2/3 per try, so they survive longer.
        while expr and expr[0][2]>0 and random.randint(0,2): #prefer to memorize positive reward more
            expr.append(expr.popleft())
        expr.append((s1,actidx,rwd,obs,int(done)))

        j+=1
        # One optimisation step every batch_N//10 environment steps once the
        # buffer holds more than one batch.
        if len(expr)>batch_N and j%(batch_N//10)==0:
            bat=random.sample(expr,batch_N)
            s1s,acts,rwds,s2s,dones=zip(*bat)
            # Stack observations with numpy first; building a tensor from a
            # Python list of arrays is much slower.
            s1bat=tc.tensor(np.array(s1s)).to(DEVICE)
            abat=tc.tensor(acts).to(DEVICE)
            rbat=tc.tensor(rwds).to(DEVICE)
            s2bat=tc.tensor(np.array(s2s)).to(DEVICE)
            dbat=tc.tensor(dones).to(DEVICE)

            q1=net(s1bat)
            with tc.no_grad():
                q2=net(s2bat)
            # Value of the action actually taken; squeeze only the gathered
            # column so the batch dimension is always preserved.
            x=q1.gather(1,abat.unsqueeze(dim=1)).squeeze(1)
            # Target: reward plus discounted best next value, with the
            # bootstrap term zeroed on terminal transitions.
            y=rbat+dr*((1-dbat)*tc.max(q2,dim=1)[0])
            loss=lossf(x,y)
            opter.zero_grad()
            loss.backward()
            opter.step()
            last_loss=loss.item()
    rsums.append(rsum)

    eps=max(eps_min,eps_base**i)

#save
tc.save(net.state_dict(),'./netw.pt')

Progress: 15.10 (15100/100000)
Last Loss:  0.08240378
Avg Rwd:  6.29
Max Rwd:  15
Progress: 15.20 (15200/100000)
Last Loss:  0.09292434
Avg Rwd:  6.54
Max Rwd:  17
Progress: 15.30 (15300/100000)
Last Loss:  0.050171047
Avg Rwd:  6.23
Max Rwd:  16
Progress: 15.40 (15400/100000)
Last Loss:  0.028395174
Avg Rwd:  6.5
Max Rwd:  17
Progress: 15.50 (15500/100000)
Last Loss:  0.07063913
Avg Rwd:  5.98
Max Rwd:  15


In [None]:
# Play one rendered episode, always taking the highest-valued action.
net.eval()  # switch Dropout off for evaluation

try:
    with tc.no_grad():
        obs=env.reset()
        done=False
        while not done:
            # NN.forward moves the input to DEVICE itself.
            qval=net(tc.tensor(obs))
            obs, _, done, _ = env.step(int(qval.argmax()))
            env.render()
finally:
    # Release the environment even if stepping or rendering raises.
    env.close()