In [1]:
from collections import deque

import torch
import torch.nn as nn

### Test my idea: build a model from a dict of named layers run in an explicit order

In [54]:
class MODEL(nn.Module):
	"""Sequential container built from a dict of named submodules.

	Every entry of ``model_dict`` is registered as a submodule under its key, and
	``forward`` applies the modules one after another in ``forward_order``.

	Args:
		model_dict: Mapping from submodule name to ``nn.Module``.
		forward_order: Names to execute, in order. When omitted (or empty), the
			insertion order of ``model_dict`` is used.

	Raises:
		ValueError: If ``forward_order`` names a module not in ``model_dict``.
	"""

	def __init__(self, model_dict: dict[str, nn.Module], forward_order: "list[str] | None" = None):
		super().__init__()
		for name, module in model_dict.items():
			self.add_module(name, module)

		# Copy so later mutation of the caller's list cannot alter this model.
		# A None/empty order falls back to the dict's insertion order.
		self.forward_order = list(forward_order) if forward_order else list(model_dict.keys())

		# Fail fast here instead of with an AttributeError in the middle of forward().
		unknown = [name for name in self.forward_order if name not in model_dict]
		if unknown:
			raise ValueError(f"forward_order references unknown modules: {unknown}")

	def forward(self, x):
		"""Apply each module named in ``self.forward_order`` to ``x`` in sequence."""
		for module_name in self.forward_order:
			x = getattr(self, module_name)(x)
		return x

In [55]:
# Small CNN sized for 3x32x32 inputs: every 3x3 conv uses padding=1 (spatial size
# unchanged) and every 2x2 max-pool halves it, 32 -> 16 -> 8, which is where
# fc1's 32 * 8 * 8 input features come from.
model_dict = dict(
	conv1=nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
	relu1=nn.ReLU(),
	pool1=nn.MaxPool2d(2, 2),
	conv2=nn.Conv2d(16, 32, 3, padding=1),
	relu2=nn.ReLU(),
	pool2=nn.MaxPool2d(2, 2),
	flatten=nn.Flatten(),
	fc1=nn.Linear(32 * 8 * 8, 128),
	relu3=nn.ReLU(),
	fc2=nn.Linear(128, 4),
)

# Execute the layers exactly in the order they were declared above;
# pass a different list here to rearrange them.
forward_order = list(model_dict)

model = MODEL(model_dict, forward_order)

In [56]:
# Smoke test: one random 3x32x32 input should map to a (1, 4) output (fc2 has 4 outputs).
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module itself rather than .forward() so nn.Module hooks still run.
output = model(dummy_input)
assert output.shape == (1, 4)
output

tensor([[-0.0136, -0.0220, -0.0751, -0.0288]], grad_fn=<AddmmBackward0>)

### Extend the idea: derive the execution order from a computation graph (topological sort)

In [60]:
class MODEL(nn.Module):
	"""Container whose execution order is derived from a computation graph.

	Args:
		model_dict: Mapping from submodule name to ``nn.Module``; every entry is
			registered as a submodule under its key.
		computation_graph: Optional adjacency list ``[(node, [successor, ...]), ...]``.
			When given (and non-empty), ``forward_order`` is a topological order of
			the graph (Kahn's algorithm, FIFO tie-breaking). Otherwise it is the
			insertion order of ``model_dict``.

	Raises:
		ValueError: If the graph contains a cycle or names a node that is not in
			``model_dict``.

	NOTE(review): ``forward`` still chains the modules one after another in
	``forward_order``. The edges only decide the order, so a branching graph is
	not executed as true dataflow: each node receives the previous node's output
	in that order, not necessarily its graph predecessor's.
	"""

	def __init__(self, model_dict: dict[str, nn.Module], computation_graph: "list[tuple[str, list[str]]] | None" = None):
		super().__init__()
		for name, module in model_dict.items():
			self.add_module(name, module)

		self.computation_graph = computation_graph
		if computation_graph:
			self.forward_order = self._build_forward_order(computation_graph)
		else:
			self.forward_order = list(model_dict.keys())

		# Fail fast here rather than with an AttributeError inside forward().
		unknown = [name for name in self.forward_order if name not in model_dict]
		if unknown:
			raise ValueError(f"computation_graph references unknown modules: {unknown}")

	def _build_forward_order(self, graph):
		"""Return a topological order of ``graph`` using Kahn's algorithm.

		Nodes that appear only as successors are treated as sinks, so they need not
		be listed with an empty successor list. Repeated entries for the same source
		node are merged rather than overwritten.

		Raises:
			ValueError: If the graph contains a cycle.
		"""
		adj = {node: [] for node, _ in graph}
		indegree = {node: 0 for node, _ in graph}

		for node, next_nodes in graph:
			# extend, not assign: a node listed twice must keep all of its edges,
			# otherwise indegree counts edges that are never decremented.
			adj[node].extend(next_nodes)
			for next_node in next_nodes:
				adj.setdefault(next_node, [])  # successor-only node -> sink
				indegree[next_node] = indegree.get(next_node, 0) + 1

		# Start from the nodes with no incoming edges (indegree == 0).
		queue = deque(node for node, degree in indegree.items() if degree == 0)
		order = []

		while queue:
			node = queue.popleft()
			order.append(node)
			for neighbor in adj[node]:
				indegree[neighbor] -= 1
				if indegree[neighbor] == 0:
					queue.append(neighbor)

		# Nodes left unvisited still have incoming edges: they sit on (or behind) a cycle.
		if len(order) != len(indegree):
			stuck = [node for node, degree in indegree.items() if degree > 0]
			raise ValueError(f"computation_graph contains a cycle involving: {stuck}")

		return order

	def forward(self, x):
		"""Apply each module named in ``self.forward_order`` to ``x`` in sequence."""
		for module_name in self.forward_order:
			x = getattr(self, module_name)(x)
		return x

In [61]:
# Same small CNN as in the first experiment, rebuilt with fresh layers: reusing
# the earlier module instances would share their weights with the first model.
model_dict = {
	"conv1": nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
	"relu1": nn.ReLU(),
	"pool1": nn.MaxPool2d(2, 2),
	"conv2": nn.Conv2d(16, 32, 3, padding=1),
	"relu2": nn.ReLU(),
	"pool2": nn.MaxPool2d(2, 2),
	"flatten": nn.Flatten(),
	"fc1": nn.Linear(32 * 8 * 8, 128),
	"relu3": nn.ReLU(),
	"fc2": nn.Linear(128, 4),
}

# A linear chain expressed as an adjacency list; MODEL derives the execution
# order from it with a topological sort.
computation_graph = [
	("conv1", ["relu1"]),
	("relu1", ["pool1"]),
	("pool1", ["conv2"]),
	("conv2", ["relu2"]),
	("relu2", ["pool2"]),
	("pool2", ["flatten"]),
	("flatten", ["fc1"]),
	("fc1", ["relu3"]),
	("relu3", ["fc2"]),
	("fc2", []),
]

model = MODEL(model_dict, computation_graph)
# For a linear chain the graph-derived order must equal the declaration order.
assert model.forward_order == list(model_dict)

dummy_input = torch.randn(1, 3, 32, 32)
# Call the module itself rather than .forward() so nn.Module hooks still run.
output = model(dummy_input)
assert output.shape == (1, 4)
output

tensor([[ 0.0279,  0.1236, -0.0747,  0.0309]], grad_fn=<AddmmBackward0>)