In [7]:
#model definition

import random
import time
from collections import deque
import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
# Torch device used throughout the notebook: CUDA when PyTorch reports it is
# available, otherwise the CPU.
DEVICE = tc.device("cuda" if tc.cuda.is_available() else "cpu")

import gym
# NOTE(review): gym_snake is never referenced by name; presumably importing it
# registers the 'snake-v0' environment id with gym -- confirm against the package.
import gym_snake
# Single shared environment instance; the code below reads env.height / env.width
# from it and steps it during training and evaluation.
env = gym.make('snake-v0')

def obs2input(obs, height=None, width=None):
    """Encode an observation as a two-plane grid of floats.

    Plane 0 marks every snake segment with ``1.0 + i / len(snake)``, where ``i``
    is the segment's index in ``obs["snake"]``. Plane 1 marks the dot with
    ``1.0``. Every other cell is ``0.0``.

    The grid is laid out ``[plane][y][x]`` with ``height`` rows and ``width``
    columns, matching the ``(in_ch, height, width)`` reshape done by the network.
    (The previous version allocated ``width`` rows of ``height`` cells, which
    only worked on square boards.)

    Args:
        obs: mapping with a ``"snake"`` sequence of ``(y, x)`` pairs and a
            ``"dot"`` ``(y, x)`` pair.
        height: number of rows; defaults to ``env.height``.
        width: number of columns; defaults to ``env.width``.

    Returns:
        Nested Python list of floats with dimensions ``2 x height x width``.
    """
    rows = env.height if height is None else height
    cols = env.width if width is None else width

    snake = obs["snake"]
    dot_y, dot_x = obs["dot"]

    # Outer index is y (rows), inner index is x (columns).
    grid = [[[0.0] * cols for _ in range(rows)] for _ in range(2)]

    length = len(snake)
    for i, (y, x) in enumerate(snake):
        grid[0][y][x] = 1.0 + i / length
    grid[1][dot_y][dot_x] = 1.0

    return grid

class NN(nn.Module):
    """Fully connected Q-network over the flattened two-plane board encoding.

    The input is reshaped to ``(batch, in_ch, height, width)``, flattened, and
    passed through three linear layers that output four values per sample
    (used elsewhere in this notebook as one Q-value per action).

    Args:
        height: board rows; defaults to ``env.height``.
        width: board columns; defaults to ``env.width``.
    """

    def __init__(self, height=None, width=None):
        super(NN,self).__init__()
        self.in_ch=2
        self.out_ch=4
        self.height=env.height if height is None else height
        self.width=env.width if width is None else width

        # Convolutional front-end, currently disabled (empty Sequential); the
        # matching call in forward() is commented out as well.
        self.c1=nn.Sequential(
#             nn.Conv2d(self.in_ch,10,3,padding=1,padding_mode="circular"),
#             nn.Dropout2d(0.3),
#             nn.LeakyReLU(0.1),
#             nn.Conv2d(10,15,3,padding=1,padding_mode="circular"),
#             nn.Dropout2d(0.3),
#             nn.LeakyReLU(0.1),
#             nn.Conv2d(15,8,3,padding=1,padding_mode="circular"),
#             nn.Dropout2d(0.3),
#             nn.LeakyReLU(0.1),
#             nn.Conv2d(8,self.out_ch,3,padding=1,padding_mode="circular"),
#             nn.Dropout2d(0.3),
#             nn.LeakyReLU(0.1),
        )

        self.d1=nn.Sequential(
            #nn.Linear(self.out_ch*self.height*self.width,256),
            nn.Linear(self.in_ch*self.height*self.width,256),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.1),

            nn.Linear(256,128),
            nn.Dropout(0.35),
            nn.LeakyReLU(0.1),

            nn.Linear(128,4),
        )

        # Walk every sub-module: the previous loop only looked at the two
        # Sequential containers themselves, so no layer was ever re-initialised.
        for m in self.modules():
            if isinstance(m,(nn.Linear,nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                # Xavier init needs >= 2 dimensions, so it cannot be applied to
                # the 1-D bias vector; start biases at zero instead.
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self,x):
        """Return the network output, shape ``(batch, 4)``, for board tensor ``x``.

        ``x`` may be a single encoded board or a batch; it is moved to the
        device holding the model's parameters before use.
        """
        device=self.d1[0].weight.device
        x=x.to(device).reshape((-1,self.in_ch,self.height,self.width))
        #x=self.c1(x)
        x=x.reshape(-1,self.in_ch*self.height*self.width)
        return self.d1(x)

In [8]:
#train on DEVICE (GPU when available, otherwise CPU)

from IPython.display import clear_output

# Place the network on the shared DEVICE; a bare .cuda() raises on CPU-only hosts.
net=NN().to(DEVICE)
net.train()
print(next(net.parameters()).is_cuda)
lossf=tc.nn.MSELoss()
opter=opt.Adam(net.parameters(),lr=5e-5)#, weight_decay=5e-6)

epoch_N=20000      # number of training episodes
dr=0.75            # discount applied to the bootstrapped next-state value
eps=1.0            # current exploration probability (epsilon-greedy)
eps_base=0.9998    # per-episode decay base: eps = eps_base ** episode
eps_min=0.05       # lower bound on eps

expr = deque(maxlen=10000)   # replay buffer of (s1, action, reward, s2, done)
batch_N = 128

j=0                # total environment steps taken so far
last_loss=None     # most recent training loss; None until the first update
rsums=[]           # episode returns since the last progress report
for i in range(epoch_N):
    obs=env.reset()
    done=False

    # Report every 100 episodes once at least one optimisation step has run.
    if last_loss is not None and rsums and i%100==0:
        print("Progress: {0:.2f} ({1}/{2})".format(i/epoch_N*100,i,epoch_N))
        print("Last Loss: ",last_loss)
        print("Avg Rwd: ", sum(rsums)/len(rsums))
        print("Max Rwd: ", max(rsums))
        rsums=[]

        # Checkpoint and wipe the cell output every 1000 episodes.
        if i%1000==0:
            tc.save(net.state_dict(),'./netw.pt')
            clear_output(wait=True)

    rsum=0
    while not done:
        s1=obs2input(obs)
        if random.random()<eps:
            actidx=int(np.random.randint(0,4))
        else:
            # Only run the network for greedy steps, and without building an
            # autograd graph that would never be used.
            # NOTE(review): the net stays in train mode here, so dropout is
            # active while choosing actions (and for the targets below) --
            # confirm this is intended.
            with tc.no_grad():
                actidx=int(tc.argmax(net(tc.tensor(s1))))
        obs,rwd,done,info=env.step(actidx)
        rsum+=rwd
        expr.append((s1,actidx,rwd,obs2input(obs),int(done)))

        j+=1
        # One gradient step every batch_N//10 environment steps once the
        # buffer holds more than one batch.
        if len(expr)>batch_N and j%(batch_N//10)==0:
            s1s,acts,rwds,s2s,dones=zip(*random.sample(expr,batch_N))
            s1bat=tc.tensor(s1s,dtype=tc.float32,device=DEVICE)
            abat=tc.tensor(acts,dtype=tc.long,device=DEVICE)
            rbat=tc.tensor(rwds,dtype=tc.float32,device=DEVICE)
            s2bat=tc.tensor(s2s,dtype=tc.float32,device=DEVICE)
            dbat=tc.tensor(dones,dtype=tc.float32,device=DEVICE)

            q1=net(s1bat)
            with tc.no_grad():
                q2=net(s2bat)
            # Value of the action actually taken; squeeze only the gathered
            # column so the batch dimension is always preserved.
            x=q1.gather(1,abat.unsqueeze(dim=1)).squeeze(1)
            # Bootstrapped target; (1-dbat) drops the next-state term when the
            # transition ended the episode.
            y=rbat+dr*((1-dbat)*tc.max(q2,dim=1)[0])
            loss=lossf(x,y)
            opter.zero_grad()
            loss.backward()
            opter.step()
            last_loss=loss.item()
    rsums.append(rsum)

    eps=max(eps_min,eps_base**i)

#save
tc.save(net.state_dict(),'./netw.pt')

Progress: 95.50 (19100/20000)
Last Loss:  0.27602392
Avg Rwd:  -4.34
Max Rwd:  -1
Progress: 96.00 (19200/20000)
Last Loss:  0.5702325
Avg Rwd:  -4.25
Max Rwd:  0
Progress: 96.50 (19300/20000)
Last Loss:  0.50435394
Avg Rwd:  -4.22
Max Rwd:  2
Progress: 97.00 (19400/20000)
Last Loss:  0.80050725
Avg Rwd:  -4.41
Max Rwd:  0
Progress: 97.50 (19500/20000)
Last Loss:  0.4726485
Avg Rwd:  -4.39
Max Rwd:  0
Progress: 98.00 (19600/20000)
Last Loss:  0.42089307
Avg Rwd:  -4.25
Max Rwd:  1
Progress: 98.50 (19700/20000)
Last Loss:  0.09751186
Avg Rwd:  -4.2
Max Rwd:  0
Progress: 99.00 (19800/20000)
Last Loss:  0.9190427
Avg Rwd:  -4.06
Max Rwd:  3
Progress: 99.50 (19900/20000)
Last Loss:  0.021211188
Avg Rwd:  -4.34
Max Rwd:  0


In [146]:
# Play one rendered episode with the trained network, always taking the
# highest-valued action.

# Safety cap: a deterministic greedy policy can cycle without ever ending the
# episode, which would otherwise hang the cell.
# NOTE(review): confirm whether the environment has its own step limit.
MAX_EVAL_STEPS=10000

# Switch off dropout for evaluation; the training cell leaves the net in train mode.
net.eval()
obs=env.reset()
try:
    with tc.no_grad():
        for step in range(MAX_EVAL_STEPS):
            # forward() moves the input to the right device itself.
            qvals=net(tc.tensor(obs2input(obs))).squeeze()
            obs, rwd, done, info = env.step(int(tc.argmax(qvals)))
            env.render()
            if done:
                break
finally:
    # Always release the render window, even on interrupt or error.
    env.close()