In [None]:
def create_hnn_model_arch(model_version:str, img_size:tuple, class_groups:dict, classes:dict, hierarchy:dict) -> tuple:
    """Build the hierarchical CNN, materialize its lazy layers and log its graph.

    The network is a shared convolutional backbone followed by a stack of
    fully-connected layers that ``DemoModel`` receives as ``model_layers``
    (per the original author's notes, these are instantiated independently
    for every parent/child group of the hierarchy).

    Args:
        model_version: Version tag; only used to name the TensorBoard log
            directory ``./models/hnn_model_v{model_version}_tb_graphs``.
        img_size: ``(height, width)`` of the input images. A random
            ``1 x 3 x height x width`` tensor of this size is pushed through
            the model to initialize the ``LazyLinear`` layer.
        class_groups: Mapping whose values are forwarded to ``DemoModel`` as
            ``output_order`` and whose keys become the first label names.
        classes: Mapping whose keys become the remaining label names.
        hierarchy: Hierarchy description, passed to ``DemoModel`` unchanged.

    Returns:
        ``(model, idx_to_class_name)`` where ``idx_to_class_name`` maps a
        running index to the keys of ``class_groups`` followed by the keys of
        ``classes``.

    Side effects:
        Writes a TensorBoard graph of the model under ``./models``.
    """
    # Widths shared between the backbone, the per-group layers and `size`.
    base_out_features = 1024  # output width of the shared backbone
    feed_features = 512       # width of the activation fed from parent to child
    head_out_features = 64    # output width of the last explicit per-group layer

    base_model = nn.Sequential(
        # padding=0 reproduces the TF implementation of m01, where the default
        # 'padding' argument is 'valid', i.e. no padding.
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        nn.Flatten(),
        # Deviates from the TF implementation of m01: 1024 units instead of 512.
        # LazyLinear infers its input width on the first forward pass.
        nn.LazyLinear(base_out_features),
        nn.ReLU(),
        # Unlike the TF implementation of m01, no Softmax is applied here;
        # it is left to be added later.
    )

    # Independent layers of each parent/child group.
    model_layers = [
        # The output of this first layer is what is fed forward from parent to child.
        nn.Linear(base_out_features, feed_features),
        nn.Linear(feed_features, 256),
        nn.Linear(256, 128),
        nn.Linear(128, head_out_features),
        # The final nn.Linear(64, num_classes) is intentionally omitted: it is
        # created automatically at the end of each parent/child. See:
        # https://github.com/rajivsarvepalli/SimpleHierarchy/blob/8e4c29f334928f43509b2d328b11b4a83f2d2af6/src/simple_hierarchy/hierarchal_model.py#L173
    ]

    # size = (output size of the base model,
    #         input size of the independent layers (model_layers),
    #         output size of the independent layers, excluding the
    #         automatically-created final layer,
    #         output size of the layer that is fed forward from parent to
    #         child, with concatenation)
    size = (base_out_features, base_out_features, head_out_features, feed_features)
    # All 4 layers of model_layers are distinct for each grouping of classes
    # (5 when counting the automatically-created final layer).
    k = 4
    # Feed forward from the fourth-to-last layer (third-to-last when the
    # automatically-created final layer is not counted).
    feed_from = 3
    output_order = list(class_groups.values())
    # Group names come first, followed by the individual class names.
    idx_to_class_name = dict(enumerate([*class_groups, *classes]))
    model = DemoModel(hierarchy, base_model, size, model_layers, k, feed_from, output_order)

    dummy_input = torch.rand(3, *img_size).unsqueeze(0)
    # Arbitrary forward pass to initialize the LazyLinear parameters. The
    # result is discarded, so skip autograd bookkeeping for it.
    with torch.no_grad():
        model(dummy_input)

    # The context manager closes the writer even if add_graph raises.
    with SummaryWriter(f"./models/hnn_model_v{model_version}_tb_graphs") as writer:
        writer.add_graph(model, dummy_input)

    return model, idx_to_class_name
# Example usage: build model version "03" for 306x306 inputs, then run a single
# random 3-channel image through it (the result is the cell's displayed output).
img_size = (306, 306)
demo_model, idx_to_class_name = create_hnn_model_arch(
    model_version="03",
    img_size=img_size,
    class_groups=cgs,
    classes=cs,
    hierarchy=hierarchy,
)
a = torch.unsqueeze(torch.rand(3, *img_size), 0)
demo_model(a)
