In [2]:
#model definition

import random
import time
from collections import deque
import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
DEVICE = tc.device("cuda" if tc.cuda.is_available() else "cpu")

import gym
import gym_snake
env = gym.make('snake-v0')

class Option(object):
    HUNGRY_RATE=20
    ROW=7
    COL=11
    
class FFC(nn.Module):
    def __init__(self, chn_in, chn_out, ker_sz):
        super().__init__()
        self.c=nn.Conv2d(chn_in,chn_out,ker_sz,padding=ker_sz//2,padding_mode="circular",bias=False)
        #self.d=nn.Dropout2d(0.5)
        #self.b=nn.BatchNorm2d(chn_out)
        self.a=nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.a(self.c(x))
    
class CTR(nn.Module):
    def __init__(self, N_in, N_out, drop_out=False):
        super().__init__()
        self.l=nn.Linear(N_in,N_out)
        self.drop_out=drop_out
        if drop_out:
            self.d=nn.Dropout(0.5)
        self.a=nn.LeakyReLU(0.1)

    def forward(self, x):
        x=self.l(x)
        if self.drop_out:
            x=self.d(x)
        return self.a(x)

class NN(nn.Module):
    def __init__(self):
        super(NN,self).__init__()
        self.chn_in=12
        self.chn_mid=12
        self.chn_out=12
        self.ch_adjuster=nn.Conv2d(3,self.chn_in,1,padding=0,bias=False)
        self.ffcs=nn.ModuleList([
            FFC(self.chn_in,self.chn_mid,3),
            FFC(self.chn_mid,self.chn_mid,3),
            FFC(self.chn_mid,self.chn_mid,5),
            FFC(self.chn_mid,self.chn_out,3)])
        
        self.dense=nn.Sequential(
            nn.Linear(self.chn_out*Option.ROW*Option.COL,256),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),
            
            CTR(256,128,True),
            CTR(128,128),
            CTR(128,32),
            
            nn.Linear(32,3),
        )
#         for ffc in self.ffcs:
#             init.xavier_uniform_(ffc.c.weight.data)
#             init.xavier_uniform_(ffc.c.bias.data)
#         init.xavier_uniform_(dense.weight.data)
#         init.xavier_uniform_(dense.bias.data)
                
    def forward(self,x):
        x=x.to(DEVICE)
        #x=x.reshape((-1,self.chn_in,Option.ROW,Option.COL))
        x=x.reshape((-1,3,Option.ROW,Option.COL))
        xa=self.ch_adjuster(x)
        x=xa
        for ffc in self.ffcs:
            x=ffc(x)+xa #residual training
        x=x.reshape(-1,self.chn_out*Option.ROW*Option.COL)
        return self.dense(x)

In [3]:
#train on GPU

net=NN().cuda()
net.train()
print(next(net.parameters()).is_cuda)
lossf=tc.nn.MSELoss()
opter=opt.Adam(net.parameters())

epoch_N=100000
dr=0.8
eps=1.0
eps_base=0.99993
eps_min=0.05

expr = deque(maxlen=100000)
batch_N = 64

j=0
last_loss=0
rsums=[]
for i in range(epoch_N):
    obs=env.reset()
    done=False
    if last_loss and i%100==0:
        print("Progress: {0:.2f} ({1}/{2})".format(i/epoch_N*100,i,epoch_N))
        print("Last Loss: ",last_loss)
        print("Avg Rwd: ", sum(rsums)/len(rsums))
        print("Max Rwd: ", max(rsums))
        rsums=[]
        losses=[]
        
        if i/100%10==0:
            tc.save(net.state_dict(),'./netw.pt')
            from IPython.display import clear_output
            clear_output(wait=True)
        
    rsum=0
    while not done:
        s1=obs
        qval=net(tc.tensor(s1))
        actidx=None
        if random.random()<eps:
            actidx=np.random.randint(0,3)
        else:
            actidx=np.argmax(qval.cpu().detach().numpy())
        obs,rwd,done,info=env.step(actidx)
        rsum+=rwd
        expr.append((s1,actidx,rwd,obs,int(done)))
        
        j+=1
        if len(expr)>batch_N and j%(batch_N//10)==0:
            bat=random.sample(expr,batch_N)
            s1bat=tc.tensor([s1 for (s1,a,r,s2,d) in bat]).to(DEVICE)
            abat=tc.tensor([a for (s1,a,r,s2,d) in bat]).to(DEVICE)
            rbat=tc.tensor([r for (s1,a,r,s2,d) in bat]).to(DEVICE)
            s2bat=tc.tensor([s2 for (s1,a,r,s2,d) in bat]).to(DEVICE)
            dbat=tc.tensor([d for (s1,a,r,s2,d) in bat]).to(DEVICE)

            q1=net(s1bat)
            with tc.no_grad():
                q2=net(s2bat)
            x=q1.gather(1,abat.unsqueeze(dim=1)).squeeze()
            y=rbat+dr*((1-dbat)*tc.max(q2,dim=1)[0])
            loss=lossf(x,y)
            opter.zero_grad()
            loss.backward()
            opter.step()
            last_loss=loss.cpu().detach().numpy()
    rsums.append(rsum)

    eps=max(eps_min,eps_base**i)

#save
tc.save(net.state_dict(),'./netw.pt')

True


KeyboardInterrupt: 