In [1]:
#model definition

import random
import time
from collections import deque
import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
# Compute device used by the cells below: CUDA when PyTorch reports it as
# available, otherwise the CPU.
DEVICE = tc.device("cuda" if tc.cuda.is_available() else "cpu")

import gym
# NOTE(review): gym_snake is never referenced by name in the visible cells; it
# is presumably imported for the side effect of registering 'snake-v0' with
# gym, so it has to stay ahead of gym.make -- confirm.
import gym_snake
# Single module-level environment instance reused by the play/reset cells.
env = gym.make('snake-v0')

class Option:
    """Namespace of constants for the snake setup.

    ``ROW`` and ``COL`` are the last two dimensions the network reshapes its
    input to; ``HUNGRY_RATE`` is not referenced by the visible cells.
    """

    HUNGRY_RATE = 20
    ROW = 7
    COL = 11
    
class FFC(nn.Module):
    """Convolution block: bias-free, circularly padded Conv2d -> LeakyReLU(0.1).

    The padding is ``ker_sz // 2`` on each side, so for odd kernel sizes the
    spatial size of the input is preserved.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels.
        ker_sz: (square) kernel size.
    """

    def __init__(self, chn_in, chn_out, ker_sz):
        super().__init__()
        # Attribute names `c` / `a` are part of the state_dict layout.
        self.c = nn.Conv2d(
            in_channels=chn_in,
            out_channels=chn_out,
            kernel_size=ker_sz,
            padding=ker_sz // 2,
            padding_mode="circular",
            bias=False,
        )
        # Dropout2d(0.5) and BatchNorm2d(chn_out) stages are intentionally left disabled.
        self.a = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Apply the convolution, then the leaky-ReLU activation."""
        convolved = self.c(x)
        return self.a(convolved)
    
class CTR(nn.Module):
    """Dense block: Linear -> optional Dropout(0.5) -> LeakyReLU(0.1).

    Args:
        N_in: number of input features.
        N_out: number of output features.
        drop_out: when true, a ``Dropout(0.5)`` layer is inserted between the
            linear layer and the activation.
    """

    def __init__(self, N_in, N_out, drop_out=False):
        super().__init__()
        # Attribute names `l` / `d` / `a` are part of the state_dict layout.
        self.l = nn.Linear(in_features=N_in, out_features=N_out)
        self.drop_out = drop_out
        if self.drop_out:
            self.d = nn.Dropout(p=0.5)
        self.a = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Run the linear layer, the dropout (if enabled), then the activation."""
        hidden = self.l(x)
        return self.a(self.d(hidden) if self.drop_out else hidden)

class NN(nn.Module):
    """Residual convolutional network producing 3 output scores per board.

    Pipeline (see ``forward``):
      1. a bias-free 1x1 convolution lifts the 3 input channels to ``chn_in``;
      2. four ``FFC`` blocks (kernel sizes 3, 3, 5, 3) run in sequence, and the
         lifted input is added back onto the output of every block;
      3. the feature map is flattened and fed to a dense head that ends in
         ``nn.Linear(32, 3)``.

    Attribute names and the order of the layers inside ``dense`` define the
    ``state_dict`` keys, so they must stay as they are for saved checkpoints
    to keep loading.
    """

    def __init__(self):
        super(NN, self).__init__()
        # The skip connection in forward() adds tensors with chn_in and
        # chn_mid/chn_out channels, so these three widths have to be equal.
        self.chn_in = 12
        self.chn_mid = 12
        self.chn_out = 12
        self.ch_adjuster = nn.Conv2d(3, self.chn_in, 1, padding=0, bias=False)
        self.ffcs = nn.ModuleList([
            FFC(self.chn_in, self.chn_mid, 3),
            FFC(self.chn_mid, self.chn_mid, 3),
            FFC(self.chn_mid, self.chn_mid, 5),
            FFC(self.chn_mid, self.chn_out, 3)])

        self.dense = nn.Sequential(
            nn.Linear(self.chn_out * Option.ROW * Option.COL, 256),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),

            CTR(256, 128, True),
            CTR(128, 128),
            CTR(128, 32),

            nn.Linear(32, 3),
        )

    def forward(self, x):
        """Compute the 3 output scores for a batch of boards.

        Args:
            x: tensor whose element count is a multiple of
                ``3 * Option.ROW * Option.COL``; it is reshaped to
                ``(-1, 3, Option.ROW, Option.COL)``.

        Returns:
            Tensor of shape ``(batch, 3)``.
        """
        # Follow the network's own parameters instead of a global device so the
        # model works wherever it was placed (.cpu(), .cuda(), .to(...)), and
        # align the dtype so non-float32 inputs do not trip the convolution.
        anchor = self.ch_adjuster.weight
        x = x.to(device=anchor.device, dtype=anchor.dtype)
        x = x.reshape((-1, 3, Option.ROW, Option.COL))
        xa = self.ch_adjuster(x)
        x = xa
        for ffc in self.ffcs:
            # Residual connection: every block output is summed with the
            # channel-adjusted input, not with the previous block's output.
            x = ffc(x) + xa
        x = x.reshape(-1, self.chn_out * Option.ROW * Option.COL)
        return self.dense(x)

In [9]:
#load
# Build the network on DEVICE rather than with a hard-coded .cuda(), so this
# cell also works on CPU-only machines, and remap the checkpoint tensors onto
# that same device while loading.
# NOTE: tc.load unpickles the file -- only load checkpoints you trust.
net_trained = NN().to(DEVICE)
net_trained.load_state_dict(tc.load('./netw.pt', map_location=DEVICE))

<All keys matched successfully>

In [16]:
# Let the trained network play one episode, always taking the action with the
# largest network output, and render every step.
# (`time` is already imported in the first cell.)
net_trained.eval()  # switch the Dropout layers to inference behaviour

obs = env.reset()
try:
    with tc.no_grad():
        done = False
        while not done:
            # NN.forward moves the input to the right device itself, and no
            # detach is needed because gradients are disabled here.
            res = net_trained(tc.tensor(obs)).squeeze().tolist()
            obs, rwd, done, _ = env.step(res.index(max(res)))
            env.render()
            time.sleep(0.1)  # slow the playback down so it can be watched
    time.sleep(0.5)  # keep the final frame visible for a moment
finally:
    # Always release the environment, even if the loop raises or the cell is
    # interrupted part-way through the episode.
    env.close()

In [4]:
# Start a new episode, keep its first observation, and render the environment.
obs = env.reset()
env.render()

True