In [109]:
from collections import OrderedDict, deque
from typing import Optional

import torch
import torch.fx
import torch.nn as nn

### Test the idea: build a model from a dict of named modules

In [102]:
class MODEL(nn.Module):
    """Sequential model assembled from a dict of named sub-modules.

    Every entry of ``model_dict`` is registered as a child module under its
    key; ``forward`` applies the modules one after another in
    ``forward_order``.

    Args:
        model_dict: mapping of child-module name -> module.
        forward_order: module names in the order they are applied. ``None``
            (or an empty list) means the key order of ``model_dict``. The list
            is copied, so later changes to the caller's list have no effect.

    Raises:
        ValueError: if ``forward_order`` names a module missing from
            ``model_dict``.
    """

    def __init__(self, model_dict: dict[str, nn.Module], forward_order: Optional[list[str]] = None):
        super().__init__()
        for key, value in model_dict.items():
            self.add_module(key, value)

        # Copy so the stored order cannot be mutated through the caller's list.
        self.forward_order = list(forward_order) if forward_order else list(model_dict.keys())

        # Fail at construction time instead of with an AttributeError in forward().
        unknown = [name for name in self.forward_order if name not in self._modules]
        if unknown:
            raise ValueError(f"forward_order references unknown modules: {unknown}")

    def forward(self, x):
        """Apply each module in ``forward_order`` to the running value and return it."""
        for module_name in self.forward_order:
            x = getattr(self, module_name)(x)
        return x

In [103]:
# Layers of a small CNN, keyed by the name each one is registered under.
# fc1's input size is 32 channels x 8 x 8: it assumes 32x32 inputs, which the
# two 2x2 max-pools (after size-preserving padded convs) reduce to 8x8.
model_dict = dict(
    conv1=nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    relu1=nn.ReLU(),
    pool1=nn.MaxPool2d(2, 2),
    conv2=nn.Conv2d(16, 32, 3, padding=1),
    relu2=nn.ReLU(),
    pool2=nn.MaxPool2d(2, 2),
    flatten=nn.Flatten(),
    fc1=nn.Linear(32 * 8 * 8, 128),
    relu3=nn.ReLU(),
    fc2=nn.Linear(128, 4),
)

# Explicit execution order (here identical to the dict's key order).
forward_order = (
    ["conv1", "relu1", "pool1"]            # stage 1
    + ["conv2", "relu2", "pool2"]          # stage 2
    + ["flatten", "fc1", "relu3", "fc2"]   # classifier head
)

model = MODEL(model_dict, forward_order)

In [104]:
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module rather than .forward(): nn.Module.__call__ also runs registered hooks.
output = model(dummy_input)
output

tensor([[-0.1700, -0.0129, -0.1568,  0.1397]], grad_fn=<AddmmBackward0>)

### Model creator with a computation graph (choose where each module's output goes)

In [134]:
class MODEL(nn.Module):
    """Model assembled from named sub-modules wired together by a computation graph.

    Every entry of ``model_dict`` is registered as a child module under its key.
    ``computation_graph`` is a list of ``(src, [dst, ...])`` entries meaning
    "the output of ``src`` feeds each ``dst``". At forward time a module with

    * no incoming edge receives the model input ``x``;
    * one incoming edge receives that predecessor's output;
    * several incoming edges receives the sum of their outputs.

    Args:
        model_dict: mapping of child-module name -> module.
        computation_graph: edge list as described above. ``None`` (or an empty
            list) chains the modules sequentially in ``model_dict`` order.

    Raises:
        ValueError: if the graph names a module missing from ``model_dict`` or
            contains a cycle.
    """

    def __init__(
        self,
        model_dict: dict[str, nn.Module],
        computation_graph: Optional[list[tuple[str, list[str]]]] = None,
    ):
        super().__init__()
        for key, value in model_dict.items():
            self.add_module(key, value)

        if not computation_graph:
            # No graph given: feed each module into the next, in insertion order.
            names = list(self._modules)
            computation_graph = [(src, [dst]) for src, dst in zip(names, names[1:])]
            if names:
                computation_graph.append((names[-1], []))

        self.computation_graph = computation_graph
        self._check_names(computation_graph)

        # Precompute each module's inputs once instead of rescanning the graph
        # for every module on every forward pass.
        self._predecessors = {name: [] for name in self._modules}
        for src, dst in self._iter_edges(computation_graph):
            self._predecessors[dst].append(src)

        self.forward_order = self._build_forward_order(computation_graph)

    @staticmethod
    def _iter_edges(graph):
        """Yield ``(src, dst)`` pairs; a ``dst`` repeated within one entry counts once."""
        for src, dsts in graph:
            for dst in dict.fromkeys(dsts):
                yield src, dst

    def _check_names(self, graph):
        """Raise ValueError if ``graph`` mentions a name that is not a registered module."""
        unknown = sorted({
            name
            for src, dsts in graph
            for name in (src, *dsts)
            if name not in self._modules
        })
        if unknown:
            raise ValueError(f"computation_graph references unknown modules: {unknown}")

    def _build_forward_order(self, graph):
        """Return the module names in a topological order of ``graph`` (Kahn's algorithm).

        Source nodes are enqueued in ``model_dict`` order, so the result is the
        same on every run. Edges touching unknown names are ignored.

        Raises:
            ValueError: if the graph contains a cycle (some modules could
                never be scheduled).
        """
        names = list(self._modules)
        successors = {name: [] for name in names}
        indegree = dict.fromkeys(names, 0)
        for src, dst in self._iter_edges(graph):
            if src in successors and dst in indegree:
                successors[src].append(dst)
                indegree[dst] += 1

        queue = deque(name for name in names if indegree[name] == 0)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for nxt in successors[node]:
                indegree[nxt] -= 1
                if indegree[nxt] == 0:
                    queue.append(nxt)

        if len(order) != len(names):
            stuck = [name for name in names if indegree[name] > 0]
            raise ValueError(f"computation_graph contains a cycle; unschedulable modules: {stuck}")
        return order

    def forward(self, x):
        """Run all modules in ``forward_order`` and return the last one's output.

        With several sink modules only the output of the one scheduled last is
        returned; with no modules at all, ``x`` is returned unchanged.
        """
        node_outputs = {}
        output = x
        for module_name in self.forward_order:
            sources = self._predecessors[module_name]
            if not sources:
                current_input = x
            elif len(sources) == 1:
                current_input = node_outputs[sources[0]]
            else:
                current_input = sum(node_outputs[src] for src in sources)
            output = getattr(self, module_name)(current_input)
            node_outputs[module_name] = output
        return output

In [135]:
# A small CNN wired through an explicit computation graph: each
# (src, [dst, ...]) entry feeds src's output into every dst.
model_dict = {
    "conv1": nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    "relu1": nn.ReLU(),
    "pool1": nn.MaxPool2d(2, 2),
    "conv2": nn.Conv2d(16, 32, 3, padding=1),
    "relu2": nn.ReLU(),
    "pool2": nn.MaxPool2d(2, 2),
    "flatten": nn.Flatten(),
    "fc1": nn.Linear(32 * 8 * 8, 128),
    "relu3": nn.ReLU(),
    "fc2": nn.Linear(128, 4),
}

computation_graph = [
    ("conv1", ["relu1"]),
    ("relu1", ["pool1"]),
    ("pool1", ["conv2"]),
    ("conv2", ["relu2"]),
    ("relu2", ["pool2"]),
    ("pool2", ["flatten"]),
    ("flatten", ["fc1"]),
    ("fc1", ["relu3"]),
    ("relu3", ["fc2"]),
    ("fc2", []),
]

model = MODEL(model_dict, computation_graph)
dummy_input = torch.randn(1, 3, 32, 32)
# Call the module rather than .forward(): nn.Module.__call__ also runs registered hooks.
output = model(dummy_input)
assert output.shape == (1, 4)  # one sample, four logits from fc2
output

tensor([[-0.0152,  0.0952, -0.0076, -0.1873]], grad_fn=<AddmmBackward0>)

##### Test a deeper CNN ("mycnn": a VGG-style conv stack — not ResNet-18, since it has no residual connections)

In [136]:
def _conv_bn_relu_stage(idx, in_channels, out_channels, pool):
    """Return the named layers of one conv stage, in forward order.

    Keys are ``conv{idx}``, ``bn{idx}``, ``relu{idx}`` and, when ``pool`` is
    true, ``maxpool{idx}``.
    """
    stage = {
        f"conv{idx}": nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        f"bn{idx}": nn.BatchNorm2d(out_channels),
        f"relu{idx}": nn.ReLU(),
    }
    if pool:
        stage[f"maxpool{idx}"] = nn.MaxPool2d(kernel_size=2, stride=2)
    return stage


# (stage index, in_channels, out_channels, followed by a 2x2 max-pool?)
_MYCNN_STAGES = [
    (0, 3, 64, True),       # Block 0
    (1, 64, 128, True),     # Block 1
    (2, 128, 256, False),   # Block 2
    (3, 256, 256, False),
    (4, 256, 512, False),   # Block 3
    (5, 512, 512, False),
    (6, 512, 512, True),
    (7, 512, 512, False),   # Block 4
    (8, 512, 512, False),
    (9, 512, 512, True),
    (10, 512, 256, False),  # Block 5
    (11, 256, 256, False),
    (12, 256, 64, False),
    (13, 64, 64, True),
]

# Stages are built in table order, so layers are created (and their weights
# drawn) in the same sequence as a hand-written dict literal would.
mycnn_model_dict = {
    name: layer
    for spec in _MYCNN_STAGES
    for name, layer in _conv_bn_relu_stage(*spec).items()
}
# Final classifier head.
mycnn_model_dict.update(
    avgpool=nn.AdaptiveAvgPool2d((1, 1)),
    flatten=nn.Flatten(),
    fc=nn.Linear(64, 1000),
)

# The network is a single chain, so each layer feeds exactly the next one and
# the last layer ("fc") feeds nothing.
_mycnn_layer_names = list(mycnn_model_dict)
mycnn_computation_graph = [
    (name, [successor])
    for name, successor in zip(_mycnn_layer_names, _mycnn_layer_names[1:])
]
mycnn_computation_graph.append((_mycnn_layer_names[-1], []))

mycnn_model = MODEL(mycnn_model_dict, mycnn_computation_graph)
dummy_input = torch.randn(1, 3, 224, 224)
output = mycnn_model(dummy_input)

print(output.shape)
print(output)

torch.Size([1, 1000])
tensor([[-4.3417e-01, -9.2019e-01,  5.0168e-01,  1.0639e+00,  5.6053e-01,
          2.3193e-01,  1.6699e-01, -7.7604e-01,  8.8394e-01,  6.0368e-01,
         -2.7266e-01,  2.3500e-01, -1.0095e-01, -7.7234e-01,  1.0674e+00,
         -3.3842e-01, -7.8781e-01,  3.1908e-01, -3.2107e-01, -3.1551e-01,
          8.1427e-01,  6.4469e-01,  4.6067e-01,  5.2217e-01, -8.2943e-01,
          4.6688e-01,  7.0582e-01, -3.1540e-01, -9.8112e-02,  1.6774e-02,
          5.4505e-01, -1.0040e-01,  3.4305e-01,  7.7029e-01, -5.7248e-01,
          1.1978e-02, -3.1681e-01, -1.9202e-01, -6.1170e-02,  1.6212e+00,
         -1.0591e+00, -5.0628e-01, -1.2541e+00,  2.1777e-01,  1.2237e+00,
         -7.9195e-01,  1.6583e-01,  1.9182e-01, -5.7734e-01, -3.2444e-01,
         -6.4151e-01,  5.9827e-01,  1.6499e-01, -1.3548e-01,  2.5947e-01,
         -5.4122e-02,  3.6143e-01, -1.2056e-01, -3.4168e-01,  9.4372e-02,
          7.3542e-01, -4.5201e-01, -4.3898e-02,  5.1047e-01,  3.5959e-01,
         -1.9831

### Model Extractor