In [22]:
import math
from typing import Optional

import gin
import numpy as np
import torch
import torch.nn as nn
import torch.optim
from einops import rearrange
from torch.nn import functional as F

In [24]:
class compositional_mlp(nn.Module):
    def __init__(
        self,
        sizes,
        activation,
        num_modules, 
        module_assignment_positions, 
        module_inputs, 
        interface_depths,
        graph_structure,
        output_activation=nn.Identity
    ):
        super().__init__()
        self._num_modules = num_modules
        self.module_assignment_positions = module_assignment_positions
        self._module_inputs = module_inputs         # keys in a dict
        self._interface_depths = interface_depths
        self._graph_structure = graph_structure     # [[0], [1,2], 3] or [[0], [1], [2], [3]]   

        self._module_list = nn.ModuleList() # e.g., object, robot, task...
        
        for graph_depth in range(len(graph_structure)): # root -> children -> ... leaves 
            for j in graph_structure[graph_depth]:          # loop over all module types at this depth
                self._module_list.append(nn.ModuleDict())   # pre, post
                self._module_list[j]['pre_interface'] = nn.ModuleList()
                self._module_list[j]['post_interface'] = nn.ModuleList()
                
                for k in range(num_modules[j]):                 # loop over all modules of this type
                    layers_pre = []
                    layers_post = []
                    for i in range(len(sizes[j]) - 1):              # loop over all depths in this module
                        act = activation if graph_depth < len(graph_structure) - 1 or i < len(sizes[j])-2 else output_activation

                        if i == interface_depths[j]:
                            input_size = sum(sizes[j_prev][-1] for j_prev in graph_structure[graph_depth - 1])
                            input_size += sizes[j][i]
                        else:
                            input_size = sizes[j][i]

                        new_layer = [nn.Linear(input_size, sizes[j][i+1]), act()]
                        if i < interface_depths[j]:
                            layers_pre += new_layer
                        else:
                            layers_post += new_layer
                    if layers_pre:
                        self._module_list[j]['pre_interface'].append(nn.Sequential(*layers_pre))
                    else:   # it's either a root or a module with no preprocessing
                        self._module_list[j]['pre_interface'].append(nn.Identity())
                    self._module_list[j]['post_interface'].append(nn.Sequential(*layers_post))

    def forward(self, input_val):
        x = None
        for graph_depth in range(len(self._graph_structure)):     # root -> children -> ... -> leaves
            x_post = []
            for j in self._graph_structure[graph_depth]:          # nodes (modules) at this depth
                if len(input_val.shape) == 1:
                    x_pre = input_val[self._module_inputs[j]]
                    onehot = input_val[self.module_assignment_positions[j]]
                else:
                    x_pre = input_val[:, self._module_inputs[j]]
                    onehot = input_val[0, self.module_assignment_positions[j]]
                    assert (input_val[:, self.module_assignment_positions[j]] == onehot).all()
                module_index = onehot.argmax()

                print(self._module_list[j]['pre_interface'][module_index])
                print(x_pre)
                x_pre = self._module_list[j]['pre_interface'][module_index](x_pre)
                if x is not None: x_pre = torch.cat((x, x_pre), dim=-1)
                x_post.append(self._module_list[j]['post_interface'][module_index](x_pre))
            x = torch.cat(x_post, dim=-1)
        return x

In [30]:
# Define the input positions and sizes based on your input structure
module_inputs = {
    0: [0],  # 'object-state' (14)
    1: [14],  # 'obstacle-state' (14)
    2: [28],  # 'goal-state' (17)
    3: [45],  # 'robot0_proprio-state' (32)
}

# Define module sizes based on the input sizes and positions
sizes = [
    [14, 32, 64],  # 'object-state' (input size: 14)
    [14, 32, 64],  # 'obstacle-state' (input size: 14)
    [17, 32, 64],  # 'goal-state' (input size: 17)
    [32, 64, 128], # 'robot0_proprio-state' (input size: 32)
]

# Activation function and other parameters
activation = nn.ReLU
num_modules = [4, 4, 4, 4]  # Four module per state
module_assignment_positions = [0, 1, 2, 3]  
interface_depths = [1, 1, 2, 3]  
graph_structure = [[0], [1], [2], [3]]  

# Initialize the model
model = compositional_mlp(
    sizes=sizes,
    activation=activation,
    num_modules=num_modules,
    module_assignment_positions=module_assignment_positions,
    module_inputs=module_inputs,
    interface_depths=interface_depths,
    graph_structure=graph_structure,
)

# Create a sample input tensor based on the input structure
#input_val = torch.randn(77)  # 14 + 14 + 17 + 32 = 77 (total size)

# Sizes for each state type
object_state_size = 14
obstacle_state_size = 14
goal_state_size = 17
robot_state_size = 32

# Create random input vectors for each state (nested array format)
object_state = torch.randn(object_state_size)  # Size [14]
obstacle_state = torch.randn(obstacle_state_size)  # Size [14]
goal_state = torch.randn(goal_state_size)  # Size [17]
robot_state = torch.randn(robot_state_size)  # Size [32]

input_val = [
    object_state,  # size: [14]
    obstacle_state,  # size: [14]
    goal_state,  # size: [17]
    robot_state  # size: [32]
]

print(input_val)

# Forward pass through the network
output = model(input_val)

# Print output to investigate
print("Output of the network:", output)

[tensor([ 1.2011,  0.1963,  0.7617,  0.2197,  0.7114, -1.1768, -0.9514, -1.6777,
        -0.2301,  0.1120, -0.0690,  0.1586,  0.6890,  0.0356]), tensor([ 1.4509,  0.8889,  0.1023, -0.4460, -0.2469, -0.0178,  0.5129,  2.0603,
         0.4722, -1.2127, -0.8660,  0.6401,  0.0864,  0.1021]), tensor([ 2.9942, -0.2817, -0.4680,  0.2616,  0.1397,  1.2501,  0.1648, -0.1226,
        -1.4296, -0.3062, -1.3422,  0.1108,  0.1556,  1.1018,  0.0235, -0.1095,
         2.7542]), tensor([-1.8966, -0.6748,  0.0033,  1.2975,  0.1815,  0.9043,  0.7674,  1.6319,
         1.7122, -0.9976, -0.3286,  0.9625, -1.4661, -0.4840,  2.1565,  0.8144,
         0.4366,  0.4033,  2.0544, -1.4394, -0.1928, -0.4410, -0.5871,  1.0863,
         0.4419, -1.9019, -0.5759,  1.0635, -0.1198, -1.8619,  1.2086,  0.1176])]


AttributeError: 'list' object has no attribute 'shape'