In [5]:
#model definition

import random
import time
from collections import deque
import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = tc.device("cuda" if tc.cuda.is_available() else "cpu")

import gym
import gym_snake
# Module-level environment; obs2input and NN read env.height / env.width from it.
# NOTE(review): importing gym_snake presumably registers 'snake-v0' with gym — confirm.
env = gym.make('snake-v0')

def obs2input(obs, height=None, width=None):
    """Encode a snake observation as a 2-channel grid of floats.

    Channel 0 marks every snake cell with ``1.0 + i/len(snake)``, where ``i``
    is the segment's index in ``obs["snake"]``. Channel 1 marks the dot cell
    with ``1.0``. Every other cell is ``0.0``.

    Args:
        obs: mapping with key "snake" (sequence of ``(y, x)`` pairs) and key
            "dot" (a single ``(y, x)`` pair).
        height: number of rows of the grid; defaults to ``env.height``.
        width: number of columns of the grid; defaults to ``env.width``.

    Returns:
        Nested list indexed ``[channel][y][x]`` with shape
        ``[2][height][width]``.
    """
    if height is None:
        height = env.height
    if width is None:
        width = env.width

    snake = obs["snake"]
    dot_y, dot_x = obs["dot"]

    # The grid is indexed [channel][y][x], so the middle axis must have
    # `height` entries and the innermost `width`. This is also the
    # (in_ch, height, width) layout that NN.forward reshapes to. Allocating
    # it the other way round only works on square boards.
    grid = [[[0.0] * width for _ in range(height)] for _ in range(2)]

    n_segments = len(snake)
    for i, (y, x) in enumerate(snake):
        # Values in [1, 2) encode the segment's position along the body.
        grid[0][y][x] = 1.0 + i / n_segments
    grid[1][dot_y][dot_x] = 1.0

    return grid

class NN(nn.Module):
    """Small convolutional network mapping a 2-channel board to 4 action values.

    The input is reshaped to ``(N, 2, height, width)``, passed through three
    3x3 convolutions with circular padding (spatial size is preserved), then
    flattened and fed through three linear layers ending in 4 outputs.

    Args:
        height: board height; defaults to ``env.height``.
        width: board width; defaults to ``env.width``.
    """

    def __init__(self, height=None, width=None):
        super(NN,self).__init__()
        self.in_ch=2
        self.out_ch=3
        self.height=env.height if height is None else height
        self.width=env.width if width is None else width

        self.c1=nn.Sequential(
             nn.Conv2d(self.in_ch,3,3,padding=1,padding_mode="circular"),
             nn.LeakyReLU(0.1),
             nn.Conv2d(3,2,3,padding=1,padding_mode="circular"),
             nn.Dropout2d(0.45),
             nn.LeakyReLU(0.1),
             nn.Conv2d(2,self.out_ch,3,padding=1,padding_mode="circular"),
             nn.Dropout2d(0.45),
             nn.LeakyReLU(0.1),
        )

        self.d1=nn.Sequential(
            nn.Linear(self.out_ch*self.height*self.width,200),
            nn.LeakyReLU(0.1),

            nn.Linear(200,100),
            nn.Dropout(0.45),
            nn.LeakyReLU(0.1),

            nn.Linear(100,4),
        )

        # self.modules() walks into the Sequential containers. Iterating over
        # [self.c1, self.d1] only ever sees the containers themselves, so the
        # initialisation below would never run.
        for m in self.modules():
            if isinstance(m,(nn.Linear,nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                # Xavier init needs at least 2 dimensions, so the 1-D bias is
                # zeroed instead.
                init.zeros_(m.bias)

    def forward(self,x):
        """Return a ``(N, 4)`` tensor for input holding N boards.

        ``x`` may have any shape that reshapes to ``(-1, in_ch, height, width)``;
        it is moved to the device the network's parameters live on.
        """
        device=next(self.parameters()).device
        x=x.to(device).reshape((-1,self.in_ch,self.height,self.width))
        x=self.c1(x)
        x=x.reshape(-1,self.out_ch*self.height*self.width)
        return self.d1(x)

In [7]:
#train on DEVICE (GPU when available, otherwise CPU)
# Imported once here instead of inside the training loop.
from IPython.display import clear_output

# .to(DEVICE) instead of .cuda() so the cell also runs on CPU-only machines.
net=NN().to(DEVICE)
net.train()
print(next(net.parameters()).is_cuda)
lossf=tc.nn.MSELoss()
opter=opt.Adam(net.parameters(),lr=1e-3)#, weight_decay=1e-5)

epoch_N=100000      # number of episodes to play
dr=0.75             # discount applied to the bootstrapped next-state value
eps=1.0             # probability of taking a random action
eps_base=0.99996    # eps is set to eps_base**episode after every episode...
eps_min=0.1         # ...but never drops below this floor

expr = deque(maxlen=10000)  # replay buffer of (s1, action, reward, s2, done)
batch_N = 128

j=0                 # total environment steps taken so far
last_loss=None      # most recent training loss; None until the first update
rsums=[]            # episode returns since the last progress report
for i in range(epoch_N):
    obs=env.reset()
    done=False
    if last_loss is not None and rsums and i%100==0:
        print("Progress: {0:.2f} ({1}/{2})".format(i/epoch_N*100,i,epoch_N))
        print("Last Loss: ",last_loss)
        print("Avg Rwd: ", sum(rsums)/len(rsums))
        print("Max Rwd: ", max(rsums))
        rsums=[]

        # Checkpoint and clear the cell output every 1000 episodes.
        if i%1000==0:
            tc.save(net.state_dict(),'./netw.pt')
            clear_output(wait=True)

    rsum=0
    s1=obs2input(obs)
    while not done:
        if random.random()<eps:
            actidx=int(np.random.randint(0,4))
        else:
            # Acting needs no gradients, and the forward pass is skipped
            # entirely on exploratory steps.
            # NOTE(review): net is in train mode here, so dropout perturbs
            # the greedy choice — confirm this is intended.
            with tc.no_grad():
                qval=net(tc.tensor(s1))
            actidx=int(qval.argmax().item())
        obs,rwd,done,info=env.step(actidx)
        rsum+=rwd
        s2=obs2input(obs)
        expr.append((s1,actidx,rwd,s2,int(done)))
        # The encoded next state becomes the current state of the next step.
        s1=s2

        j+=1
        # Train on a random replay batch every batch_N//10 environment steps.
        if len(expr)>batch_N and j%(batch_N//10)==0:
            bat=random.sample(expr,batch_N)
            s1s,acts,rwds,s2s,dones=zip(*bat)
            # Explicit dtypes: gather needs int64 indices, and the target must
            # be float32 to match the network output in the MSE loss.
            s1bat=tc.tensor(s1s,dtype=tc.float32,device=DEVICE)
            abat=tc.tensor(acts,dtype=tc.long,device=DEVICE)
            rbat=tc.tensor(rwds,dtype=tc.float32,device=DEVICE)
            s2bat=tc.tensor(s2s,dtype=tc.float32,device=DEVICE)
            dbat=tc.tensor(dones,dtype=tc.float32,device=DEVICE)

            q1=net(s1bat)
            with tc.no_grad():
                q2=net(s2bat)
            # squeeze(1) drops only the gathered column axis.
            x=q1.gather(1,abat.unsqueeze(dim=1)).squeeze(1)
            # Target: reward plus discounted best next value, zeroed on
            # terminal transitions.
            y=rbat+dr*((1-dbat)*tc.max(q2,dim=1)[0])
            loss=lossf(x,y)
            opter.zero_grad()
            loss.backward()
            opter.step()
            last_loss=loss.item()
    rsums.append(rsum)

    eps=max(eps_min,eps_base**i)

#save
tc.save(net.state_dict(),'./netw.pt')

Progress: 99.10 (99100/100000)
Last Loss:  1.7279191
Avg Rwd:  -4.840399999999998
Max Rwd:  -1.4000000000000004
Progress: 99.20 (99200/100000)
Last Loss:  0.8600985
Avg Rwd:  -4.975799999999998
Max Rwd:  -3.3800000000000003
Progress: 99.30 (99300/100000)
Last Loss:  0.9464791
Avg Rwd:  -5.004999999999998
Max Rwd:  -2.5600000000000005
Progress: 99.40 (99400/100000)
Last Loss:  0.8577997
Avg Rwd:  -5.083799999999999
Max Rwd:  -3.1
Progress: 99.50 (99500/100000)
Last Loss:  1.2193669
Avg Rwd:  -4.861399999999997
Max Rwd:  -2.12
Progress: 99.60 (99600/100000)
Last Loss:  1.5765176
Avg Rwd:  -5.014199999999999
Max Rwd:  -3.2800000000000002
Progress: 99.70 (99700/100000)
Last Loss:  0.8000279
Avg Rwd:  -4.958399999999998
Max Rwd:  -2.4800000000000004
Progress: 99.80 (99800/100000)
Last Loss:  0.7863785
Avg Rwd:  -4.951199999999999
Max Rwd:  -2.6
Progress: 99.90 (99900/100000)
Last Loss:  0.80442786
Avg Rwd:  -5.075599999999999
Max Rwd:  -4.08


In [72]:
# Play one greedy episode with the trained network and render every step.
net.eval()

try:
    with tc.no_grad():
        obs=env.reset()
        done=False
        while not done:
            # forward() moves the input to the right device itself, and
            # nothing needs detaching under no_grad.
            qvals=net(tc.tensor(obs2input(obs))).squeeze().tolist()
            obs, rwd, done, info = env.step(qvals.index(max(qvals)))
            env.render()
finally:
    # Close the environment even if the rollout raises.
    env.close()