In [5]:
#model definition

import random
import time
from collections import deque
import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
DEVICE = tc.device("cuda") if tc.cuda.is_available() else tc.device("cpu")

import gym
import gym_snake
# Build the snake environment; the "snake-v0" id is presumably registered as a
# side effect of `import gym_snake` above — TODO confirm.
env = gym.make("snake-v0")

class Option:
    """Board/game constants shared by the model code.

    ROW and COL give the spatial size of one observation plane (the network
    reshapes its input to (-1, channels, ROW, COL)).
    """

    HUNGRY_RATE = 20  # not read in the visible code; presumably mirrors an env setting — TODO confirm
    ROW = 7
    COL = 11

class NN(nn.Module):
    """Q-network for the snake agent.

    Two 3x3 convolutions with circular padding (the kernel wraps around the
    board edges) feed a three-layer fully connected head that outputs one
    Q-value for each of the 3 actions.
    """

    def __init__(self):
        super(NN, self).__init__()
        self.in_ch = 3   # observation planes per board (input channels)
        self.out_ch = 4  # channels produced by the conv stack

        self.c1 = nn.Sequential(
            nn.Conv2d(self.in_ch, 6, 3, padding=1, padding_mode="circular"),
            nn.Dropout2d(0.5),
            nn.LeakyReLU(0.1),

            nn.Conv2d(6, self.out_ch, 3, padding=1, padding_mode="circular"),
            nn.BatchNorm2d(self.out_ch),
        )

        self.d1 = nn.Sequential(
            nn.Linear(self.out_ch * Option.ROW * Option.COL, 512),
            nn.LeakyReLU(0.1),

            nn.Linear(512, 256),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),

            nn.Linear(256, 3),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Conv2d/Linear layer."""
        # Walk self.modules() rather than the two Sequential containers: the
        # containers themselves are never Linear/Conv2d, so checking them would
        # silently skip every layer. Xavier init needs a tensor with >= 2 dims,
        # so the 1-D biases are zeroed instead.
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Return Q-values of shape (N, 3) for a batch of observations.

        ``x`` is reshaped to (N, in_ch, ROW, COL), so any tensor whose element
        count is a multiple of in_ch*ROW*COL is accepted (a single unbatched
        observation yields N=1). It is first moved to the device and dtype of
        the network's parameters, so e.g. float64 numpy-derived input works.
        """
        param = next(self.parameters())
        x = x.to(device=param.device, dtype=param.dtype)
        x = x.reshape(-1, self.in_ch, Option.ROW, Option.COL)
        x = self.c1(x)
        x = x.reshape(-1, self.out_ch * Option.ROW * Option.COL)
        return self.d1(x)

In [6]:
# Train the Q-network with epsilon-greedy exploration and experience replay.
# Runs on the GPU when one is available (DEVICE), otherwise on the CPU.
from IPython.display import clear_output

net = NN().to(DEVICE)
net.train()
print(next(net.parameters()).is_cuda)
lossf = tc.nn.MSELoss()
opter = opt.Adam(net.parameters(), lr=1e-3, weight_decay=3e-7)
# Multiply the learning rate by 0.9945 every 100 optimiser steps.
sched = opt.lr_scheduler.StepLR(opter, step_size=100, gamma=0.9945)

epoch_N = 100000     # number of episodes
dr = 0.85            # discount applied to the bootstrapped next-state value
eps = 1.0            # current exploration rate
eps_base = 0.99993   # eps after episode i is max(eps_min, eps_base**i)
eps_min = 0.1        # exploration floor
REPORT_EVERY = 100   # episodes between progress printouts
SAVE_EVERY = 1000    # episodes between checkpoints (must be a multiple of REPORT_EVERY)

expr = deque(maxlen=100000)  # replay buffer of (s1, action, reward, s2, done) tuples
batch_N = 150


def _select_action(model, obs, epsilon):
    """Epsilon-greedy action index in {0, 1, 2}.

    With probability ``epsilon`` a uniformly random action is returned;
    otherwise the argmax of the model's Q-values. The greedy pass runs in eval
    mode and without autograd so dropout noise and single-sample BatchNorm
    statistics do not perturb action selection; the previous mode is restored.
    """
    if random.random() < epsilon:
        return np.random.randint(0, 3)
    was_training = model.training
    model.eval()
    try:
        with tc.no_grad():
            qval = model(tc.as_tensor(obs))
    finally:
        model.train(was_training)
    return int(qval.argmax(dim=1).item())


def _train_step(batch):
    """Run one Q-learning update on ``batch`` and return the loss as a float.

    Uses the notebook globals ``net``, ``lossf``, ``opter``, ``sched`` and ``dr``.
    The bootstrap target comes from the same network (no separate target net).
    """
    s1s, acts, rwds, s2s, dones = zip(*batch)
    # Stack the ndarrays first: building a tensor from a list of ndarrays is very slow.
    s1bat = tc.as_tensor(np.stack(s1s)).to(DEVICE)
    s2bat = tc.as_tensor(np.stack(s2s)).to(DEVICE)
    abat = tc.tensor(acts, dtype=tc.int64, device=DEVICE)
    rbat = tc.tensor(rwds, dtype=tc.float32, device=DEVICE)
    dbat = tc.tensor(dones, dtype=tc.float32, device=DEVICE)

    q1 = net(s1bat)
    with tc.no_grad():
        q2 = net(s2bat)
    x = q1.gather(1, abat.unsqueeze(1)).squeeze(1)
    # Terminal transitions (done == 1) get no bootstrapped future value.
    y = rbat + dr * (1 - dbat) * q2.max(dim=1).values
    loss = lossf(x, y)
    opter.zero_grad()
    loss.backward()
    opter.step()
    sched.step()
    return loss.item()


j = 0
last_loss = 0
rsums = []
try:
    for i in range(epoch_N):
        obs = env.reset()
        done = False
        if last_loss and i % REPORT_EVERY == 0:
            print("Progress: {0:.2f} ({1}/{2})".format(i / epoch_N * 100, i, epoch_N))
            print("Last Loss: ", last_loss)
            print("Avg Rwd: ", sum(rsums) / len(rsums))
            print("Max Rwd: ", max(rsums))
            rsums = []

            if i % SAVE_EVERY == 0:
                tc.save(net.state_dict(), './netw.pt')
                clear_output(wait=True)

        rsum = 0
        while not done:
            s1 = obs
            actidx = _select_action(net, s1, eps)
            obs, rwd, done, info = env.step(actidx)
            rsum += rwd
            # Rotate positive-reward transitions off the left end (each with
            # probability 2/3 of continuing) so that, once the buffer is full,
            # eviction prefers non-positive transitions.
            while expr and expr[0][2] > 0 and random.randint(0, 2):
                expr.append(expr.popleft())
            expr.append((s1, actidx, rwd, obs, int(done)))

            j += 1
            if len(expr) > batch_N and j % (batch_N // 10) == 0:
                last_loss = _train_step(random.sample(expr, batch_N))
        rsums.append(rsum)

        eps = max(eps_min, eps_base ** i)
finally:
    # Checkpoint even when the run is stopped early (e.g. KeyboardInterrupt).
    tc.save(net.state_dict(), './netw.pt')

Progress: 11.10 (11100/100000)
Last Loss:  0.059996046
Avg Rwd:  1.28
Max Rwd:  13
Progress: 11.20 (11200/100000)
Last Loss:  0.07775326
Avg Rwd:  0.9
Max Rwd:  14
Progress: 11.30 (11300/100000)
Last Loss:  0.060875576
Avg Rwd:  1.62
Max Rwd:  11
Progress: 11.40 (11400/100000)
Last Loss:  0.07968573
Avg Rwd:  1.8
Max Rwd:  18
Progress: 11.50 (11500/100000)
Last Loss:  0.076544285
Avg Rwd:  1.86
Max Rwd:  13


KeyboardInterrupt: 