In [1]:
import pickle
import torch
from models.v4.model_v4 import Policy, Value, count_parameters
from format_obs import format_obs
import torch.nn.functional as F
from post_process import post_process

with open('test_obs_caca.pkl', 'rb') as f:
    obs = pickle.load(f)

In [2]:
pi = Policy()
v = Value()

In [3]:
x = format_obs(obs, 0)

In [4]:
y = pi(x)

torch.Size([1, 8, 32]) torch.Size([1, 2, 32]) torch.Size([1, 41, 32])


RuntimeError: shape '[1, 8, 8]' is invalid for input of size 1312

In [5]:
score = v(x)

In [8]:
score

{'plane_value': tensor([[-0.0646],
         [ 0.0922]], grad_fn=<AddmmBackward0>),
 'cargo_value': tensor([[-0.3025],
         [-0.2957],
         [-0.3206],
         [-0.2833],
         [-0.3021],
         [-0.3143],
         [-0.2907],
         [-0.2983],
         [-0.2981],
         [-0.3187],
         [-0.3205],
         [-0.3150],
         [-0.3000],
         [-0.2812],
         [-0.3135],
         [-0.3145],
         [-0.2982],
         [-0.2780],
         [-0.3017],
         [-0.3138],
         [-0.2956],
         [-0.2904],
         [-0.2893],
         [-0.3134],
         [-0.3121],
         [-0.3177],
         [-0.3190],
         [-0.3146],
         [-0.3148],
         [-0.3123],
         [-0.2827],
         [-0.2999],
         [-0.2804],
         [-0.3045],
         [-0.2942],
         [-0.2986],
         [-0.2819],
         [-0.3152],
         [-0.2841],
         [-0.3131],
         [-0.2969]], grad_fn=<AddmmBackward0>),
 'value': tensor(-12.3635, grad_fn=<AddBackward0>)}

In [7]:
actions = post_process(x, y)
actions

{'a_0': {'priority': 1,
  'cargo_to_load': [],
  'cargo_to_unload': [],
  'destination': 8},
 'a_1': {'priority': 0,
  'cargo_to_load': [],
  'cargo_to_unload': [],
  'destination': 0}}

In [None]:
y['plane_assignments_mtx']

tensor([[0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.]])

In [None]:
x = format_obs(obs, 0)
y = pi(x)
yc = v(x)

In [None]:
class Solution:
    def __init__(self):
        pass

    def post_process(self, obs, x, y):
        """
        Post-process the action logits into valid actions.

        Args:
            obs (dict): observation from the environment
            x: {
                nodes: torch_geometric.data.Data,
                agents: {
                    agents: [{
                        id: int,
                        location: int,
                        cargo_onboard: List[int],
                        tensor: torch.Tensor, shape=(num_agent_features)
                    }],
                    mask: torch.Tensor, shape=(num_agents)
                },
                cargo: {
                    cargo: [{
                        id: int,
                        location: int,
                        destination: int,
                        tensor: torch.Tensor, shape=(num_cargo_features)
                    }],
                    mask: torch.Tensor, shape=(num_tasks)
                }
            }
            y: {
                plane_assignments_mtx: tensor (p x n) routing assignments for planes
                cargo_assignments_mtx: tensor (c x p) assignment of cargo to planes
                n1: tensor (n x f) for skip connection to critic
                p1: tensor (p x f) for skip connection to critic
                p2: tensor (p x f) for skip connection to critic
                c2: tensor (c x f) for skip connection to critic
            }

        returns: actions (dict) {agent: action, ...}
        """

        # Create a map of agent and cargo ids
        agent_map = []
        for agent in x['agents']['agents']:
            agent_map.append(agent['id'])
        agent_map.append(None)

        cargo_map = []
        for cargo in x['cargo']['cargo']:
            cargo_map.append(cargo['id'])

        # Loop through assignments and create preference_dict
        preference_dict = {}
        unload = []

        # Preference dict - planes to nodes
        plane_to_node = torch.argmax(y['plane_assignments_mtx'], dim=1)
        for plane_idx, node_id in enumerate(plane_to_node):
            plane_id = agent_map[plane_idx]
            node_id = node_id.item() + 1 # +1 to convert from 0-indexed to 1-indexed
            # print(f'Plane {plane_id} -> Node {node_id}')
            preference_dict[plane_id] = {'node': node_id, 'cargo': [], 'unload': []}

        # Preference dict - cargo to planes
        cargo_to_plane = torch.argmax(y['cargo_assignments_mtx'], dim=1)
        for cargo_idx, plane_idx in enumerate(cargo_to_plane):
            cargo_id = cargo_map[cargo_idx]
            plane_id = agent_map[plane_idx]
            # print(f'Cargo {cargo_id} -> Plane {plane_id}')
            if plane_id is not None:
                preference_dict[plane_id]['cargo'].append(cargo_id)
            else:
                unload.append(cargo_id)

        # Add unload to preference dict
        for plane_id, preference in preference_dict.items():
            cargo_onboard = obs[plane_id]['cargo_onboard']
            for c in unload:
                if c in cargo_onboard:
                    preference['unload'].append(c)
        
        """
        preference_dict = {
            plane_id: {
                'node': node_id,
                'cargo': [cargo_id, ...],
                'unload': [cargo_id, ...]
            }
        }
        """

        # Create cargo dict
        cargo_dict = {}
        for cargo in obs['a_0']['globalstate']['active_cargo']:
            cargo_dict[cargo.id] = cargo

        # Create actions
        actions = {}
        for plane_id, preference in preference_dict.items():
            # NOTE: Agents can only do one action per timestep

            more = True
            action = None

            # If plane is processing or flying, skip
            if obs[plane_id]['state'] in [1, 2] and more:
                more = False

            # If we have cargo to unload, unload cargo
            if len(preference['unload']) > 0 and len(obs[plane_id]['cargo_onboard']) > 0 and more:
                most_important_cargo = min(preference['unload'], key=lambda x: cargo_dict[x].hard_deadline)
                action = {'priority': 1, 'cargo_to_load': [], 'cargo_to_unload': [most_important_cargo], 'destination': 0} #self._action_helper.unload_action(most_important_cargo)
                more = False

            # If we have capacity, load cargo in deadline order
            capacity = obs[plane_id]['max_weight'] - obs[plane_id]['current_weight']
            if capacity > 0 and len(preference['cargo']) > 0 and more:
                most_important_cargo = min(preference['cargo'], key=lambda x: cargo_dict[x].hard_deadline)
                if cargo_dict[most_important_cargo].weight <= capacity:
                    action = {'priority': 1, 'cargo_to_load': [most_important_cargo], 'cargo_to_unload': [], 'destination': 0} #self._action_helper.load_action(most_important_cargo)
                    more = False

            # If node preference is not current node, takeoff to node
            if preference['node'] != obs[plane_id]['current_airport'] and more:
                action = {'priority': 1, 'cargo_to_load': [], 'cargo_to_unload': [], 'destination': preference['node']} #self._action_helper.takeoff_action(preference['node'])
                more = False

            if action is None:
                action = {'priority': 0, 'cargo_to_load': [], 'cargo_to_unload': [], 'destination': 0} #self._action_helper.noop_action()
            
            # Assign action to plane
            actions[plane_id] = action

            # # Debugging print statements
            # if len(actions[plane_id]['cargo_to_unload']) > 0:
            #     print(f"Plane {plane_id} unloading cargo {actions[plane_id]['cargo_to_unload']}")
            # if len(actions[plane_id]['cargo_to_load']) > 0:
            #     print(f"Plane {plane_id} loading cargo {actions[plane_id]['cargo_to_load']}")

        return actions

In [None]:
if 1==1:
    if 1==1:
        # Create a map of agent and cargo ids
        agent_map = []
        for agent in x['agents']['agents']:
            agent_map.append(agent['id'])
        agent_map.append(None)

        cargo_map = []
        for cargo in x['cargo']['cargo']:
            cargo_map.append(cargo['id'])

        # Loop through assignments and create preference_dict
        preference_dict = {}
        unload = []

        # Preference dict - planes to nodes
        plane_to_node = torch.argmax(y['plane_assignments_mtx'], dim=1)
        for plane_idx, node_id in enumerate(plane_to_node):
            plane_id = agent_map[plane_idx]
            node_id = node_id.item() + 1 # +1 to convert from 0-indexed to 1-indexed
            # print(f'Plane {plane_id} -> Node {node_id}')
            preference_dict[plane_id] = {'node': node_id, 'cargo': [], 'unload': []}


        # Preference dict - cargo to planes
        cargo_to_plane = torch.argmax(y['cargo_assignments_mtx'], dim=1)
        for cargo_idx, plane_idx in enumerate(cargo_to_plane):
            cargo_id = cargo_map[cargo_idx]
            plane_id = agent_map[plane_idx]
            # print(f'Cargo {cargo_id} -> Plane {plane_id}')
            if plane_id is not None:
                preference_dict[plane_id]['cargo'].append(cargo_id)
            else:
                unload.append(cargo_id)

        # Add unload to preference dict
        for plane_id, preference in preference_dict.items():
            cargo_onboard = obs[plane_id]['cargo_onboard']
            for c in unload:
                if c in cargo_onboard:
                    preference['unload'].append(c)

        # Create cargo dict
        cargo_dict = {}
        for cargo in obs['a_0']['globalstate']['active_cargo']:
            cargo_dict[cargo.id] = cargo

        # Create actions
        for plane_id, preference in preference_dict.items():
            # NOTE: Agents can only take 1 action per timestep

            # If we have cargo to unload, unload cargo
            if len(preference['unload']) > 0 and len(obs[plane_id]['cargo_onboard']) > 0:
                most_important_cargo = min(preference['unload'], key=lambda x: cargo_dict[x].hard_deadline)
                # TODO: Add logic to unload cargo
                break

            # If we have capacity, load cargo in deadline order
            capacity = obs[plane_id]['max_weight'] - obs[plane_id]['current_weight']
            if capacity > 0 and len(preference['cargo']) > 0:
                if cargo_dict[min(preference['cargo'], key=lambda x: cargo_dict[x].hard_deadline)] <= capacity:
                    # TODO: Add logic to load cargo
                    continue
            
            # If node preference is not current node, takeoff to node
            current_airport = obs[plane_id]['current_airport']
            if preference['node'] != current_airport:
                # Takeoff to node
                # TODO: Add logic to takeoff to node
                continue
                
        

In [None]:
a

{'plane_assignments_mtx': tensor([[0.0000, 0.0012, 0.0012, 0.0012, 0.0000, 0.0000, 0.0000, 0.2432],
         [0.0000, 0.0000, 0.0016, 0.3245, 0.0000, 0.0016, 0.0000, 0.0000]],
        grad_fn=<MulBackward0>),
 'cargo_assignments_mtx': tensor([[0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.9995],
         [0.0000, 0.0000, 0.999

In [None]:
# get parameter count
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(actor))
print(count_parameters(critic))

1217896
928266
