In [1]:
from abc import ABC, abstractmethod
from mlagents_envs.base_env import ActionTuple
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.environment import UnityEnvironment
import numpy as np
from tqdm import tqdm

class UnityEnvironmentAbstract(ABC):
    @abstractmethod
    def step(action):
        pass

    @abstractmethod
    def reset():
        pass

    @abstractmethod
    def close():
        pass

In [2]:
class MiniFootballEnv(UnityEnvironmentAbstract):
    def __init__(self, path='./mini_football_windows/Mini Football Environment.exe'):
        self.channel = EngineConfigurationChannel()
        self.env = UnityEnvironment(path, seed=42, side_channels=[self.channel])
        self.state = None

        self.env.step()
        self.behavior_name = list(self.env.behavior_specs)[0]
        
        self.reset()

    def step(self, action):

        terminated = False
        truncated = False
        info = None
        step_reward = 0

        transformed_action = ActionTuple(np.array(action[:2]).reshape(1,2).astype(np.float32), np.array(action[-1]).reshape(1,1).astype(np.float32))

        self.env.set_actions(behavior_name=self.behavior_name, action=transformed_action)
        self.env.step()

        decision_steps, terminal_steps = self.env.get_steps(self.behavior_name)
        step_reward += decision_steps.reward[0]

        for agent_id in terminal_steps:
            step_reward += terminal_steps.reward[0]
            terminated = True

        if not terminated:
            observation = decision_steps.obs[0][0]
            
        else:
            observation = []
            #self.reset()
        
        self.state = observation

        return observation, step_reward, terminated, truncated, info

    def reset(self):
        self.env.reset()
        self.env.step()
        decision_steps, terminal_steps = self.env.get_steps(self.behavior_name)
        self.state = decision_steps.obs[0][0]

    def close(self):
        pass

    def set_channel_params(self, width, height, quality_level, time_scale, target_frame_rate, capture_frame_rate):
        self.channel.set_configuration_parameters(
            width= width,
            height= height,
            quality_level= quality_level,
            time_scale= time_scale,
            target_frame_rate= target_frame_rate,
            capture_frame_rate= capture_frame_rate,
            )


In [3]:
import onnxruntime as rt

# Load Brain
sess = rt.InferenceSession("trained_brains/FootballPlayer.onnx")

# Params
input_name0 = sess.get_inputs()[0].name
input_name1= sess.get_inputs()[1].name
label_name = sess.get_outputs()[0].name

def get_brain_action(observation):

    if len(observation) == 0:
        return [0,0,0]
    
    observation = np.concatenate((observation[:15], observation[3:6]-observation[6:9], observation[15:], observation[15+3:15+6]-observation[15+6:15+9]), axis=0)
    
    a = np.ones((2,2)).astype(np.float32)
    new_obs = observation.astype(np.float32)

    pred = sess.run(['continuous_actions', 'discrete_actions'], {input_name0: np.array(new_obs).reshape(1,36), input_name1: a})
    return np.concatenate((pred[0][0], pred[1][0]), axis=0)

In [4]:
# Make Environment
env = MiniFootballEnv()
env.set_channel_params(
    width=848, 
    height=480, 
    quality_level=1, 
    time_scale=100, 
    target_frame_rate=-1, 
    capture_frame_rate=60
    )

In [5]:
env.reset()

In [6]:
all_rewards = []

for _ in tqdm(range(100)):

    terminated = False
    steps = 0
    episode_reward = 0
    observation = env.state

    while not terminated:
        
        action = get_brain_action(observation)
        
        # Skip 5 Frames
        for _ in range(5):
            
            observation, reward, terminated, _, _ = env.step(action)
            episode_reward += reward
            steps += 1

            if terminated:
                break

    all_rewards.append(episode_reward)

np.mean(all_rewards)

100%|██████████| 100/100 [02:40<00:00,  1.61s/it]


0.965104984366335