In [1]:
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn as nn

### Test my idea: build a sequential model from a dict of layers

In [None]:
class MODEL(nn.Module):
    """Sequential container assembled from a name -> module mapping.

    Every entry of ``model_dict`` is registered as a submodule under its key.
    ``forward`` feeds the input through the modules named in
    ``forward_order``, handing each module's output to the next one.

    Args:
        model_dict: Mapping from submodule name to the module to register.
        forward_order: Names of registered modules in execution order.
            ``None`` (the default) runs every module in ``model_dict``
            insertion order. An empty list runs nothing, so ``forward``
            returns its input unchanged.

    Raises:
        ValueError: If ``forward_order`` names a module that was not
            registered from ``model_dict``.
    """

    def __init__(self, model_dict: dict[str, nn.Module], forward_order: Optional[list[str]] = None):
        super().__init__()
        for key, value in model_dict.items():
            self.add_module(key, value)

        # Compare against None explicitly: `forward_order or ...` would replace
        # a deliberately empty order with "run everything".
        if forward_order is None:
            forward_order = list(model_dict.keys())

        # Fail at construction time instead of in the middle of a forward pass.
        unknown = [name for name in forward_order if name not in self._modules]
        if unknown:
            raise ValueError(f"forward_order references unknown modules: {unknown}")

        # Copy so that later mutation of the caller's list cannot change the model.
        self.forward_order = list(forward_order)

    def forward(self, x):
        """Apply the modules listed in ``forward_order`` to ``x`` in sequence."""
        for module_name in self.forward_order:
            x = getattr(self, module_name)(x)
        return x

In [None]:
# Layers of a small CNN, keyed by the name each one is registered under.
# Keyword arguments keep their order, so the dict order is the layer order.
model_dict = dict(
    conv1=nn.Conv2d(3, 16, kernel_size=3, padding=1),
    relu1=nn.ReLU(),
    pool1=nn.MaxPool2d(kernel_size=2, stride=2),
    conv2=nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
    relu2=nn.ReLU(),
    pool2=nn.MaxPool2d(kernel_size=2, stride=2),
    flatten=nn.Flatten(),
    fc1=nn.Linear(in_features=32 * 8 * 8, out_features=128),
    relu3=nn.ReLU(),
    fc2=nn.Linear(in_features=128, out_features=4),
)

# Run the layers in exactly the order they were declared above.
forward_order = list(model_dict)

model = MODEL(model_dict, forward_order)

In [None]:
# Smoke test with one random 3-channel 32x32 input.
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module itself rather than .forward() so registered hooks run.
output = model(dummy_input)
output

### Model creator where the computation graph chooses the outputs

In [13]:
class MODEL(nn.Module):
    """Module container whose data flow is described by a computation graph.

    Every entry of ``model_dict`` is registered as a submodule under its key.
    ``computation_graph`` is a list of ``(source, [destinations])`` pairs;
    each pair sends the output of ``source`` to every listed destination.

    During ``forward``:

    * a module with no incoming edge receives the model input ``x``;
    * a module with one incoming edge receives that source's output;
    * a module with several incoming edges receives the sum of their outputs,
      added in the order the sources appear in ``computation_graph``;
    * a module for which the graph lists no destinations is an output.

    Args:
        model_dict: Mapping from submodule name to the module to register.
        computation_graph: Edge list as described above. ``None`` or an empty
            list means "no edges": every module is applied to ``x``
            independently and every module is an output.

    Raises:
        ValueError: If the graph contains a cycle, or if an edge into a
            registered module starts at a name that is not registered.
    """

    def __init__(self, model_dict: dict[str, nn.Module],
                 computation_graph: Optional[list[tuple[str, list[str]]]] = None):
        super().__init__()
        for key, value in model_dict.items():
            self.add_module(key, value)
        self.computation_graph = computation_graph

        # None is the declared default, so treat it as an empty edge list
        # instead of letting forward() fail while iterating over None.
        edges = computation_graph or []

        # The graph is fixed after construction; index it once here instead of
        # rescanning every edge for every module on every forward pass.
        self._input_sources, self._non_final = self._index_graph(edges)
        self.forward_order = self._build_forward_order(edges)

    def _index_graph(self, graph):
        """Return ``(input_sources, non_final)`` derived from the edge list.

        ``input_sources`` maps each registered module name to the names whose
        outputs feed it, in graph order. ``non_final`` is the set of names
        that have at least one destination listed.
        """
        input_sources = {name: [] for name in self._modules}
        non_final = set()
        for src, dests in graph:
            if dests:
                non_final.add(src)
            for dest in dests:
                if dest not in input_sources:
                    # Destinations that are not registered modules are ignored.
                    continue
                if src not in self._modules:
                    raise ValueError(
                        f"computation_graph routes unknown module {src!r} into {dest!r}"
                    )
                if src not in input_sources[dest]:
                    input_sources[dest].append(src)
        return input_sources, non_final

    def _build_forward_order(self, graph):
        """Topologically sort the registered modules (Kahn's algorithm).

        Modules are seeded in registration order, so the result, and with it
        the order of the tuple returned by ``forward``, is the same on every
        run.

        Raises:
            ValueError: If the edges between registered modules form a cycle.
        """
        # Use the ordered module names, not a set: set iteration order for
        # strings changes between interpreter runs.
        names = list(self._modules.keys())
        successors = {name: [] for name in names}
        indegree = {name: 0 for name in names}

        for src, dests in graph:
            if src not in successors:
                continue
            for dest in dests:
                # Accumulate rather than overwrite, so a source split across
                # several graph entries keeps all its edges; count each edge once.
                if dest in indegree and dest not in successors[src]:
                    successors[src].append(dest)
                    indegree[dest] += 1

        queue = deque(name for name in names if indegree[name] == 0)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for neighbor in successors[node]:
                indegree[neighbor] -= 1
                if indegree[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) != len(names):
            # Whatever never reached indegree 0 sits on, or behind, a cycle.
            scheduled = set(order)
            stuck = [name for name in names if name not in scheduled]
            raise ValueError(f"computation_graph contains a cycle involving: {stuck}")

        return order

    def forward(self, x):
        """Run the graph on ``x`` and return a tuple of all output tensors.

        Outputs appear in the order their modules occur in ``forward_order``.
        """
        node_outputs = {}
        final_outputs = []

        for module_name in self.forward_order:
            sources = self._input_sources[module_name]
            if not sources:
                current_input = x
            elif len(sources) == 1:
                current_input = node_outputs[sources[0]]
            else:
                # Merge fan-in by element-wise addition.
                current_input = sum(node_outputs[src] for src in sources)

            output = getattr(self, module_name)(current_input)
            node_outputs[module_name] = output

            if module_name not in self._non_final:
                final_outputs.append(output)

        return tuple(final_outputs)

In [14]:
# Same small CNN as before, now wired through an explicit computation graph.
model_dict = {
    "conv1": nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    "relu1": nn.ReLU(),
    "pool1": nn.MaxPool2d(2, 2),
    "conv2": nn.Conv2d(16, 32, 3, padding=1),
    "relu2": nn.ReLU(),
    "pool2": nn.MaxPool2d(2, 2),
    "flatten": nn.Flatten(),
    "fc1": nn.Linear(32 * 8 * 8, 128),
    "relu3": nn.ReLU(),
    "fc2": nn.Linear(128, 4),
}

# Each entry is (source, [destinations]); a plain chain ending at fc2.
computation_graph = [
    ("conv1", ["relu1"]),
    ("relu1", ["pool1"]),
    ("pool1", ["conv2"]),
    ("conv2", ["relu2"]),
    ("relu2", ["pool2"]),
    ("pool2", ["flatten"]),
    ("flatten", ["fc1"]),
    ("fc1", ["relu3"]),
    ("relu3", ["fc2"]),
    ("fc2", []),  # no destinations: fc2 is the model output
]

model = MODEL(model_dict, computation_graph)
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module itself rather than .forward() so registered hooks run.
output = model(dummy_input)
output

(tensor([[ 0.0070,  0.0326, -0.0805,  0.1795]], grad_fn=<AddmmBackward0>),)

##### Test multiple outputs

In [36]:
# Branching MLP: two input layers, two fully connected hidden stages, two output heads.
# Modules are created in the same order as their names are listed.
mlp_model_dict = {
    # Input branches, both fed the raw 5-feature input
    **{name: nn.Linear(5, 2) for name in ("in1", "in2")},
    # First hidden stage
    **{name: nn.Linear(2, 4) for name in ("p1", "p2", "p3", "p4")},
    # Second hidden stage
    **{name: nn.Linear(4, 4) for name in ("p5", "p6", "p7", "p8")},
    # Output heads
    **{name: nn.Linear(4, 1) for name in ("o1", "o2")},
}

mlp_computation_graph = (
    # in1 feeds p1/p2, in2 feeds p3/p4
    [("in1", ["p1", "p2"]), ("in2", ["p3", "p4"])]
    # every first-stage unit feeds every second-stage unit
    + [(name, ["p5", "p6", "p7", "p8"]) for name in ("p1", "p2", "p3", "p4")]
    # every second-stage unit feeds both heads
    + [(name, ["o1", "o2"]) for name in ("p5", "p6", "p7", "p8")]
    # the heads have no destinations, so they are the model outputs
    + [("o1", []), ("o2", [])]
)

# Build the model and push one random sample through it.
complex_model = MODEL(mlp_model_dict, mlp_computation_graph)
dummy_input = torch.randn(1, 5)
output_o1, output_o2 = complex_model(dummy_input)

print(f"Input shape: {dummy_input.shape}")
print(f"Output o1 shape: {output_o1.shape}")
print(f"Output o2 shape: {output_o2.shape}")
print(f"Output o1: {output_o1}")
print(f"Output o2: {output_o2}")
complex_model

Input shape: torch.Size([1, 5])
Output o1 shape: torch.Size([1, 1])
Output o2 shape: torch.Size([1, 1])
Output o1: tensor([[-0.0731]], grad_fn=<AddmmBackward0>)
Output o2: tensor([[0.7262]], grad_fn=<AddmmBackward0>)


MODEL(
  (in1): Linear(in_features=5, out_features=2, bias=True)
  (in2): Linear(in_features=5, out_features=2, bias=True)
  (p1): Linear(in_features=2, out_features=4, bias=True)
  (p2): Linear(in_features=2, out_features=4, bias=True)
  (p3): Linear(in_features=2, out_features=4, bias=True)
  (p4): Linear(in_features=2, out_features=4, bias=True)
  (p5): Linear(in_features=4, out_features=4, bias=True)
  (p6): Linear(in_features=4, out_features=4, bias=True)
  (p7): Linear(in_features=4, out_features=4, bias=True)
  (p8): Linear(in_features=4, out_features=4, bias=True)
  (o1): Linear(in_features=4, out_features=1, bias=True)
  (o2): Linear(in_features=4, out_features=1, bias=True)
)

### Visualize the model graph in TensorBoard

In [37]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this experiment go under the 'runs' directory.
writer = SummaryWriter(log_dir='runs/complex_mlp_test')

In [38]:
# Record the graph of complex_model, traced with dummy_input, for TensorBoard.
try:
    writer.add_graph(complex_model, dummy_input)
finally:
    # Close the writer even if tracing fails, so it is never left open.
    writer.close()