In [2]:
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
import torch
import numpy as np
import random

In [5]:
# Engine side channel used to set the Unity window size, quality level and time scale.
channel = EngineConfigurationChannel()
channel.set_configuration_parameters(
    width=640,
    height=480,
    quality_level=0,
    time_scale=1,
)

# Launch the Unity build and expose it through the gym-style wrapper.
unity_env = UnityEnvironment(file_name="LanderProbe.exe", side_channels=[channel])
env = UnityToGymWrapper(unity_env)



In [4]:
# Shut down the gym wrapper around the Unity environment.
# NOTE(review): read top to bottom, this cell runs right after the environment
# is created and before the cells that step it -- confirm it is meant to be
# executed last (or moved to the end of the notebook).
env.close()

In [3]:
# Display the wrapper's action size together with its gym action space.
env.action_size, env.action_space

(9, MultiDiscrete([3 3 3 3 3 3 2 2 2]))

In [3]:
# Smoke test: drive the environment with 100 uniformly sampled actions.
observation = env.reset()
for _ in range(100):
    action = env.action_space.sample()
    # The wrapper's step() yields four values in the classic gym order
    # (observation, reward, done, info). The fourth value is an info dict,
    # not a "truncated" flag, so it must not take part in the reset decision.
    observation, reward, done, info = env.step(action)

    # Start a fresh episode as soon as the current one has ended.
    if done:
        observation = env.reset()

In [6]:
class DQN(torch.nn.Module):
    """Small fully connected Q-network.

    Maps an observation vector of length ``nb_observation`` to ``nb_action``
    values through two ReLU hidden layers of width 32 and 16.
    """

    def __init__(self, nb_action, nb_observation):
        super().__init__()
        hidden_widths = (32, 16)

        # Build Linear -> ReLU pairs for the hidden layers, then the output head.
        modules = []
        fan_in = nb_observation
        for width in hidden_widths:
            modules.append(torch.nn.Linear(fan_in, width, dtype=torch.float32))
            modules.append(torch.nn.ReLU())
            fan_in = width
        modules.append(torch.nn.Linear(fan_in, nb_action, dtype=torch.float32))

        self.stack = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.stack(x)

In [7]:
# Probability of picking a random action instead of the greedy one.
epsilon = 0.01
def policy(q_values: torch.Tensor):
    """Epsilon-greedy action selection over a vector of Q-values.

    With probability ``epsilon`` returns a uniformly random index into the
    first dimension of ``q_values``; otherwise returns the index of the
    largest entry along that dimension.
    """
    explore = random.random() < epsilon
    if not explore:
        return q_values.argmax(0).item()
    last_index = q_values.size(0) - 1
    return random.randint(0, last_index)

In [13]:
# One independent Q-network per action branch, each fed the 21-value observation.
# Throttle networks score 2 choices each (three networks).
throttleDQN = [DQN(2, 21) for _ in range(3)]

# Angle networks score 3 choices each (six networks).
# This must be a list, not a set literal: a set iterates in hash-based order,
# so a network's position would not identify a fixed action branch.
angleDQN = [DQN(3, 21) for _ in range(6)]

In [9]:
def combine_actions(throttle_actions, angle_actions):
    """Merge per-branch choices into one flat action vector for the environment.

    The environment reports its action space as
    ``MultiDiscrete([3 3 3 3 3 3 2 2 2])``: the six 3-way branches come first
    and the three 2-way branches last. The angle choices (values 0..2) are
    therefore placed before the throttle choices (values 0..1); the opposite
    order would feed 3-way indices into 2-way branches.

    Args:
        throttle_actions: sequence of chosen indices for the 2-way branches.
        angle_actions: sequence of chosen indices for the 3-way branches.

    Returns:
        A 1-D float32 tensor laid out as ``[*angle_actions, *throttle_actions]``.
    """
    # Convert the plain index sequences to tensors.
    throttle_tensor = torch.tensor(throttle_actions, dtype=torch.float32)
    angle_tensor = torch.tensor(angle_actions, dtype=torch.float32)

    # Angle branches first, throttle branches last, matching the action space.
    return torch.cat((angle_tensor, throttle_tensor))

In [12]:
# Collects the total reward of each finished training episode.
history = list()

In [14]:
# Discount factor applied to the bootstrapped next-state value in train().
y = 0.9
# Squared-error loss between the chosen action's Q-value and its TD target.
lossFn = torch.nn.MSELoss()

# One Adam optimizer (default settings) per network, in the same iteration
# order as the network collections themselves.
throttleOptimizers, angleOptimizers = (
    [torch.optim.Adam(net.parameters()) for net in nets]
    for nets in (throttleDQN, angleDQN)
)

def _td_update(dqn, optimizer, action_index, obs_tensor, next_obs_tensor, reward, done):
    """Apply a single one-step TD update to one Q-network.

    The target is ``reward + y * max_a Q(next_obs, a)``, with the bootstrap
    term dropped when ``done`` is true. Only the Q-value of ``action_index``
    is regressed towards that target.
    """
    q_values = dqn(obs_tensor)
    with torch.no_grad():
        next_q = dqn(next_obs_tensor).max(0)[0]
        # No bootstrapping from the next state once the episode has ended.
        continue_mask = 0.0 if done else 1.0
        target = reward + y * next_q * continue_mask
    loss = lossFn(q_values[action_index], target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train():
    """Run one episode, updating every branch network after each step.

    Each network picks its branch action with ``policy``; the choices are
    merged by ``combine_actions`` and sent to the environment. All networks
    are then updated with the same reward. The episode's total reward is
    appended to ``history`` and printed.
    """
    observation = env.reset()
    is_done = False
    total_reward = 0

    while not is_done:
        observation_tensor = torch.tensor(observation, dtype=torch.float32)

        # Action selection needs no gradients; the update below runs its own
        # forward pass with autograd enabled.
        with torch.no_grad():
            throttle_actions = [policy(dqn(observation_tensor)) for dqn in throttleDQN]
            angle_actions = [policy(dqn(observation_tensor)) for dqn in angleDQN]

        action = combine_actions(throttle_actions, angle_actions)

        # step() yields four values in the classic gym order
        # (observation, reward, done, info). The fourth value is an info dict,
        # not a "truncated" flag, so only `done` decides when the episode ends.
        next_observation, reward, done, _info = env.step(action)
        total_reward += reward
        next_obs_tensor = torch.tensor(next_observation, dtype=torch.float32)

        # Same update rule for both families of networks.
        for networks, optimizers, chosen in (
            (throttleDQN, throttleOptimizers, throttle_actions),
            (angleDQN, angleOptimizers, angle_actions),
        ):
            for dqn, opt, a in zip(networks, optimizers, chosen):
                _td_update(dqn, opt, a, observation_tensor, next_obs_tensor, reward, done)

        observation = next_observation
        is_done = bool(done)

    history.append(total_reward)
    print(total_reward)

# Run the training episodes back to back; each call plays one full episode.
episode_count = 10000
for i in range(episode_count):
    train()


-96.8999999538064
-97.39999996125698
-96.8999999538064
-97.89999996870756
-97.39999996125698
-97.99999997019768
-97.29999995976686
-98.0999999716878
-97.39999996125698
-97.59999996423721
-97.69999996572733
-97.99999997019768
-97.4999999627471
-97.99999997019768
-97.69999996572733
-97.99999997019768
-96.8999999538064
-97.79999996721745
-97.4999999627471
-96.99999995529652
-97.39999996125698
-97.29999995976686
-97.29999995976686
-94.79999992251396
-95.49999993294477
-97.09999995678663
-97.69999996572733
-95.29999992996454
-97.79999996721745
-96.19999994337559
-95.99999994039536
-97.09999995678663
-97.19999995827675
-97.29999995976686
-98.19999997317791
-97.09999995678663
-96.99999995529652
-97.19999995827675
-96.99999995529652
-97.09999995678663
-98.39999997615814
-95.99999994039536
-97.79999996721745
-96.69999995082617
-97.39999996125698
-97.19999995827675
-97.19999995827675
-95.89999993890524
-95.09999992698431
-97.4999999627471
-97.89999996870756
-97.69999996572733
-97.69999996572733


KeyboardInterrupt: 

In [17]:
print(history)

[-96.8999999538064, -97.39999996125698, -96.8999999538064, -97.89999996870756, -97.39999996125698, -97.99999997019768, -97.29999995976686, -98.0999999716878, -97.39999996125698, -97.59999996423721, -97.69999996572733, -97.99999997019768, -97.4999999627471, -97.99999997019768, -97.69999996572733, -97.99999997019768, -96.8999999538064, -97.79999996721745, -97.4999999627471, -96.99999995529652, -97.39999996125698, -97.29999995976686, -97.29999995976686, -94.79999992251396, -95.49999993294477, -97.09999995678663, -97.69999996572733, -95.29999992996454, -97.79999996721745, -96.19999994337559, -95.99999994039536, -97.09999995678663, -97.19999995827675, -97.29999995976686, -98.19999997317791, -97.09999995678663, -96.99999995529652, -97.19999995827675, -96.99999995529652, -97.09999995678663, -98.39999997615814, -95.99999994039536, -97.79999996721745, -96.69999995082617, -97.39999996125698, -97.19999995827675, -97.19999995827675, -95.89999993890524, -95.09999992698431, -97.4999999627471, -97.89