In [1]:
import torch
import torch.nn as nn
from collections import defaultdict, deque
from typing import Dict, List, Tuple, Set, Any

### Test my idea: a model assembled from a dict of named modules

In [None]:
class MODEL(nn.Module):
	"""Sequential container assembled from a ``name -> module`` mapping.

	Every entry of ``model_dict`` is registered as a sub-module under its key.
	``forward`` then feeds the input through the modules named in
	``forward_order``, one after the other.

	Args:
		model_dict: Mapping from attribute name to the module to register.
		forward_order: Names of the modules to run, in execution order. When
			omitted (or empty) the insertion order of ``model_dict`` is used.

	Raises:
		ValueError: If ``forward_order`` names a module that is not a key of
			``model_dict``.
	"""

	def __init__(self, model_dict: dict[str, nn.Module], forward_order: list[str] = None):
		super().__init__()
		for key, value in model_dict.items():
			self.add_module(key, value)

		# Copy the list so that later mutation of the caller's object cannot
		# silently change what this model executes.
		order = list(forward_order) if forward_order else list(model_dict.keys())

		# Fail at construction time with a readable message instead of an
		# AttributeError in the middle of a forward pass.
		unknown = [name for name in order if name not in model_dict]
		if unknown:
			raise ValueError(f"forward_order references unknown modules: {unknown}")

		self.forward_order = order

	def forward(self, x):
		"""Apply the modules listed in ``forward_order`` to ``x`` in sequence."""
		for module_name in self.forward_order:
			x = getattr(self, module_name)(x)
		return x

In [None]:
# Small CNN: two conv/ReLU/pool stages followed by a two-layer classifier head.
model_dict = dict(
	conv1=nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
	relu1=nn.ReLU(),
	pool1=nn.MaxPool2d(kernel_size=2, stride=2),
	conv2=nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
	relu2=nn.ReLU(),
	pool2=nn.MaxPool2d(kernel_size=2, stride=2),
	flatten=nn.Flatten(),
	# 32 channels * 8 * 8 assumes 32x32 inputs halved by each of the two poolings.
	fc1=nn.Linear(in_features=32 * 8 * 8, out_features=128),
	relu3=nn.ReLU(),
	fc2=nn.Linear(in_features=128, out_features=4),
)

# Explicit execution order, one module name per line.
forward_order = [
	# first conv stage
	"conv1",
	"relu1",
	"pool1",
	# second conv stage
	"conv2",
	"relu2",
	"pool2",
	# classifier head
	"flatten",
	"fc1",
	"relu3",
	"fc2",
]

model = MODEL(model_dict, forward_order)

In [None]:
# Smoke test: push one random 3x32x32 input through the model.
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module itself instead of .forward() so registered hooks are run.
output = model(dummy_input)
output

### Model Creator that can choose output

In [2]:
class MODEL(nn.Module):
    """Run named sub-modules as a directed acyclic computation graph.

    Every entry of ``model_dict`` is registered as a sub-module under its key.
    ``computation_graph`` is a list of ``(source, [destinations])`` pairs: the
    output of ``source`` is fed to each destination. A module with no source
    receives the model input, a module with one source receives that source's
    output, and a module with several sources receives their sum (Python
    ``sum`` over the source outputs, in graph order).

    ``forward`` returns the output of the last module in ``forward_order``.

    Args:
        model_dict: Mapping from attribute name to the module to register.
        computation_graph: Edges of the graph as described above. When it is
            ``None`` the modules are simply chained in insertion order. An
            explicit (possibly empty) list uses the graph semantics above.
            Destination names that are not registered modules are ignored.

    Raises:
        ValueError: If the graph contains a cycle among registered modules.
    """

    def __init__(self, model_dict: dict[str, nn.Module], computation_graph: list[tuple[str, list[str]]] = None):
        super().__init__()
        for key, value in model_dict.items():
            self.add_module(key, value)

        self.computation_graph = computation_graph
        if computation_graph:
            self.forward_order = self._build_forward_order(computation_graph)
        else:
            self.forward_order = list(model_dict.keys())

        # The predecessor lookup does not depend on the input, so compute it
        # once here instead of rescanning the graph for every module on every
        # forward call.
        self._input_sources = self._collect_input_sources(computation_graph)

    def _build_forward_order(self, graph):
        """Build execution order from computation graph using topological sort.

        Uses Kahn's algorithm. Modules are visited in registration order, so
        the result is deterministic (iterating a ``set`` of strings is not
        stable across interpreter runs).
        """
        names = list(self._modules)
        known = set(names)
        adj = {name: [] for name in names}
        indegree = {name: 0 for name in names}

        for node, next_nodes in graph:
            if node not in known:
                continue
            for next_node in next_nodes:
                if next_node in known:
                    # Accumulate rather than overwrite: the same source may
                    # appear in more than one graph entry.
                    adj[node].append(next_node)
                    indegree[next_node] += 1

        queue = deque(name for name in names if indegree[name] == 0)
        order = []

        while queue:
            node = queue.popleft()
            order.append(node)
            for neighbor in adj[node]:
                indegree[neighbor] -= 1
                if indegree[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) != len(names):
            # Only nodes on (or downstream of) a cycle never reach indegree 0.
            scheduled = set(order)
            stuck = [name for name in names if name not in scheduled]
            raise ValueError(f"computation_graph contains a cycle involving: {stuck}")

        return order

    def _collect_input_sources(self, graph):
        """Map every registered module name to the list of its source nodes."""
        sources = {name: [] for name in self._modules}
        for src_node, dest_nodes in graph or ():
            # A source counts once per graph entry, even if a destination is
            # listed twice in that entry.
            for dest in dict.fromkeys(dest_nodes):
                if dest in sources:
                    sources[dest].append(src_node)
        return sources

    def forward(self, x):
        """Evaluate the graph on ``x`` and return the last computed output."""
        if self.computation_graph is None:
            # No graph given: behave like a plain sequential container.
            for module_name in self.forward_order:
                x = getattr(self, module_name)(x)
            return x

        node_outputs = {}
        # Defined up front so a model without modules returns its input
        # instead of hitting an unbound local.
        output = x

        for module_name in self.forward_order:
            module = getattr(self, module_name)
            input_sources = self._input_sources.get(module_name, ())

            # Handle input based on number of sources
            if len(input_sources) == 0:
                current_input = x
            elif len(input_sources) == 1:
                current_input = node_outputs[input_sources[0]]
            else:
                current_input = sum(node_outputs[src] for src in input_sources)

            output = module(current_input)
            node_outputs[module_name] = output

        return output

In [3]:
# Same small CNN as a plain chain, now expressed as a computation graph.
model_dict = {
	"conv1": nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
	"relu1": nn.ReLU(),
	"pool1": nn.MaxPool2d(2, 2),
	"conv2": nn.Conv2d(16, 32, 3, padding=1),
	"relu2": nn.ReLU(),
	"pool2": nn.MaxPool2d(2, 2),
	"flatten": nn.Flatten(),
	# 32 channels * 8 * 8 assumes 32x32 inputs halved by each of the two poolings.
	"fc1": nn.Linear(32 * 8 * 8, 128),
	"relu3": nn.ReLU(),
	"fc2": nn.Linear(128, 4),
}

# Each entry is (source node, [nodes that consume its output]).
computation_graph = [
	("conv1", ["relu1"]),
	("relu1", ["pool1"]),
	("pool1", ["conv2"]),
	("conv2", ["relu2"]),
	("relu2", ["pool2"]),
	("pool2", ["flatten"]),
	("flatten", ["fc1"]),
	("fc1", ["relu3"]),
	("relu3", ["fc2"]),
	("fc2", []),
]

model = MODEL(model_dict, computation_graph)
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module itself instead of .forward() so registered hooks are run.
output = model(dummy_input)
output

tensor([[-0.1183,  0.2067,  0.0747,  0.0485]], grad_fn=<AddmmBackward0>)

##### Test skip connections (ResNet-style) on a small MLP

In [10]:
# Small MLP with a skip branch: the output of relu1 goes both through
# fc2 -> relu2 -> fc3 and through skip_fc; relu3 receives the sum of the two.
mlp_model_dict = dict(
    fc1=nn.Linear(in_features=10, out_features=20),
    relu1=nn.ReLU(),
    fc2=nn.Linear(in_features=20, out_features=20),
    relu2=nn.ReLU(),
    fc3=nn.Linear(in_features=20, out_features=10),
    relu3=nn.ReLU(),
    skip_fc=nn.Linear(in_features=20, out_features=10),
)

# Written as "source -> consumers" and converted to the (source, [consumers])
# pair list that MODEL expects.
mlp_computation_graph = list(
    {
        "fc1": ["relu1"],
        "relu1": ["fc2", "skip_fc"],
        "fc2": ["relu2"],
        "relu2": ["fc3"],
        "fc3": ["relu3"],
        "skip_fc": ["relu3"],
        "relu3": [],
    }.items()
)

# Build the model and check that a random batch of one sample runs through it.
mlp_model = MODEL(mlp_model_dict, mlp_computation_graph)
dummy_input = torch.randn(1, 10)
output = mlp_model(dummy_input)
print(f"MLP output shape: {output.shape}")
mlp_model

MLP output shape: torch.Size([1, 10])


MODEL(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=20, out_features=20, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=20, out_features=10, bias=True)
  (relu3): ReLU()
  (skip_fc): Linear(in_features=20, out_features=10, bias=True)
)

### Model Extractor