This notebook is meant to help understand the NAS pipeline and the different "modules/sections" of the algorithm.

In [1]:
from _configs import OFA_MODEL_PATH

import ofa.model_zoo as ofa
import torch

from search_space import get_search_space

## 1. Understand the model definition

We will check the sampled model and how to define it from the OFA network

In [2]:
def ofa_mobilenet( weights_path: str | None ):
    """
    Build the OFA MobileNetV3 network and, when a path is given, load its weights.

    ### Args:
        `weights_path (str | None)`: Path to a checkpoint whose `"state_dict"` entry
            holds the network weights. When `None`, no checkpoint is read and the
            network is returned as constructed.

    ### Returns:
        The `ofa.OFAMobileNetV3` instance built with kernel sizes `[3, 5, 7]`,
        expand ratios `[3, 4, 6]` and depths `[2, 3, 4]`.
    """
    network = ofa.OFAMobileNetV3(
        dropout_rate=0,
        width_mult=1.0,
        ks_list=[3, 5, 7],
        expand_ratio_list=[3, 4, 6],
        depth_list=[2, 3, 4],
    )

    # The signature explicitly allows None: skip loading instead of handing
    # None to torch.load, which would fail.
    if weights_path is None:
        return network

    # Load on CPU so the checkpoint can be read regardless of where it was saved.
    init_weights = torch.load(weights_path, map_location="cpu")["state_dict"]
    network.load_state_dict(init_weights)
    return network

Define the space and the network:

In [3]:
# Build the OFA MobileNetV3 network and load the checkpoint found at OFA_MODEL_PATH.
ofa_network = ofa_mobilenet( OFA_MODEL_PATH )
# Search space used below for sampling architectures and applying scaling.
space = get_search_space('mobilenetv3-growth')

Sample an architecture:

In [25]:
# Draw a single configuration from the search space (n_samples=1, take element 0).
# NOTE(review): no random seed is set in the visible cells, so the sample
# presumably differs between runs — confirm whether a seed is needed.
sample = space.sample(n_samples=1)[0]
# Keep every entry of the sample except the 'direction' key.
base_arch = { k: sample[k] for k in sample if (k != 'direction') }
# Scaling direction handed to space.apply_scaling in a later cell.
# Presumably one flag per scalable dimension (depth first) — TODO confirm.
scale_dir = [1, 0, 0]

print("Sampled architecture:")
# Bare last expression so the notebook renders the dictionary.
base_arch

Sampled architecture:


{'depths': [3, 3, 3, 3, 3],
 'ksizes': [5, 5, 3, 5, 5, 5, 5, 5, 3, 5, 5, 5, 3, 5, 3],
 'widths': [3, 3, 3, 3, 3, 3, 4, 3, 3, 4, 4, 3, 3, 4, 3],
 'resolution': 152}

Get the expansion:

In [26]:
# Apply the scaling direction to the sampled architecture to obtain the expanded one.
expanded_arch = space.apply_scaling(base_arch, scale_dir)
print("Expanded architecture:")
print("  - By:", scale_dir)
# Bare last expression so the notebook renders the dictionary.
expanded_arch

Expanded architecture:
  - By: [1, 0, 0]


{'depths': [3, 3, 3, 3, 4],
 'ksizes': [5, 5, 3, 5, 5, 5, 5, 5, 3, 5, 5, 5, 3, 5, 3, 3],
 'widths': [3, 3, 3, 3, 3, 3, 4, 3, 3, 4, 4, 3, 3, 4, 3, 3],
 'resolution': 152}

Get the model weights for the base and expanded architectures:

In [27]:
# Activate the base configuration inside the OFA network (ks <- 'ksizes',
# e <- 'widths', d <- 'depths') and extract it with its weights preserved.
# NOTE(review): the 'ksizes'/'widths' lists here have sum(depths) entries;
# confirm that this matches the per-block layout set_active_subnet expects.
ofa_network.set_active_subnet(ks=base_arch['ksizes'], e=base_arch['widths'], d=base_arch['depths'])
base_network = ofa_network.get_active_subnet(preserve_weight=True)

# Same procedure for the expanded configuration; this re-activates a different
# subnet on the same ofa_network object before extracting it.
ofa_network.set_active_subnet(ks=expanded_arch['ksizes'], e=expanded_arch['widths'], d=expanded_arch['depths'])
expanded_network = ofa_network.get_active_subnet(preserve_weight=True)

### 1. Exploration of the expanded architecture.

In [28]:
# Compare how many entries each extracted network has in its `blocks` container.
print("Base model blocks:", len(base_network.blocks) )
print("Expanded model blocks:", len(expanded_network.blocks) )

Base model blocks: 16
Expanded model blocks: 17


In [29]:
from ofa.utils.layers import (
    SEModule,
    MBConvLayer,
    ResidualBlock,
    IdentityLayer,
    ConvLayer,
    LinearLayer,
)
from ofa.utils.pytorch_modules import MyGlobalAvgPool2d, Hsigmoid, Hswish
from torch import nn


def transfer_linear_weights(source_linear: nn.Linear, target_linear: nn.Linear) -> None:
    """
    Write the source linear layer's parameters into the leading corner of the target.

    ### Args:
        `source_linear (nn.Linear)`: Layer whose weight (and bias, if any) is copied
        `target_linear (nn.Linear)`: Layer receiving the values; may have more features
    """
    src_weight = source_linear.weight.data
    n_out, n_in = src_weight.shape

    # The top-left (n_out x n_in) region of the target takes the source matrix.
    target_linear.weight.data[:n_out, :n_in] = src_weight

    src_bias = source_linear.bias
    if src_bias is None:
        return

    # Only the first n_out bias entries are overwritten.
    target_linear.bias.data[:n_out] = src_bias.data

def transfer_conv_weights(source_conv: nn.Conv2d, target_conv: nn.Conv2d) -> None:
    """
    Copy weights from source to target conv layer, centering the kernel.

    Height and width are handled independently: along each axis the centered
    overlap between the two kernels is copied, so a larger target axis receives
    the source in its center and a smaller target axis receives the center crop
    of the source. The bias is copied as well when both layers have one.

    ### Args:
        `source_conv (nn.Conv2d)`: Source convolution layer
        `target_conv (nn.Conv2d)`: Target convolution layer with potentially more channels
    """
    source_weight = source_conv.weight.data
    out_c, in_c, k_h, k_w = source_weight.shape

    # Get target dimensions
    _, _, target_k_h, target_k_w = target_conv.weight.shape

    def _centered_overlap(source_size: int, target_size: int) -> tuple[slice, slice]:
        """Return (source_slice, target_slice) covering the centered common extent."""
        overlap = min(source_size, target_size)
        source_start = (source_size - overlap) // 2
        target_start = (target_size - overlap) // 2
        return (
            slice(source_start, source_start + overlap),
            slice(target_start, target_start + overlap),
        )

    # Per-axis slices avoid mixing the "pad" and "crop" cases when one kernel
    # side grows while the other shrinks.
    src_h, tgt_h = _centered_overlap(k_h, target_k_h)
    src_w, tgt_w = _centered_overlap(k_w, target_k_w)

    target_conv.weight.data[:out_c, :in_c, tgt_h, tgt_w] = source_weight[:, :, src_h, src_w]

    # Convolutions created with bias=True also carry a bias vector to transfer.
    if source_conv.bias is not None and target_conv.bias is not None:
        target_conv.bias.data[:out_c] = source_conv.bias.data

def transfer_bn_weights(source_bn: nn.BatchNorm2d, target_bn: nn.BatchNorm2d) -> None:
    """
    Write the source batch-norm parameters and running statistics into the target.

    Only the first `source_bn.num_features` entries of each target tensor are
    overwritten.

    ### Args:
        `source_bn (nn.BatchNorm2d)`: Layer providing the values
        `target_bn (nn.BatchNorm2d)`: Layer receiving the values
    """
    count = source_bn.num_features

    # Affine parameters first, then the running-statistics buffers.
    for tensor_name in ("weight", "bias", "running_mean", "running_var"):
        source_tensor = getattr(source_bn, tensor_name)
        target_tensor = getattr(target_bn, tensor_name)
        target_tensor.data[:count] = source_tensor.data

def transfer_se_module_weights(source_se: SEModule, target_se: SEModule) -> None:
    """
    Copy weights for Squeeze-and-Excitation module.

    Each pair of layers found in `fc` is forwarded to `transfer_block_weights`,
    which picks the copy routine from the layer type. Earlier only `nn.Linear`
    children were transferred and every other layer type was skipped silently.

    Args:
        source_se (SEModule): Source SE module
        target_se (SEModule): Target SE module
    """
    if not (hasattr(source_se, "fc") and source_se.fc):
        return

    # NOTE(review): in the OFA library `fc` appears to be a Sequential of 1x1
    # nn.Conv2d layers plus activations rather than nn.Linear — confirm against
    # the installed version.
    for source_layer, target_layer in zip(source_se.fc, target_se.fc):
        transfer_block_weights(source_layer, target_layer)
def transfer_convlayer_weights(source_conv: ConvLayer, target_conv: ConvLayer) -> None:
    """
    Transfer the `conv` and `bn` sub-modules of a ConvLayer, when present.

    ### Args:
        `source_conv (ConvLayer)`: Layer providing the weights
        `target_conv (ConvLayer)`: Layer receiving the weights
    """
    # Same order as the layer applies them: convolution first, then batch norm.
    for part_name in ("conv", "bn"):
        source_part = getattr(source_conv, part_name, None)
        if source_part:
            transfer_block_weights(source_part, getattr(target_conv, part_name))

def transfer_residual_block_weights(source_rb: ResidualBlock, target_rb: ResidualBlock) -> None:
    """
    Transfer the weights held by a ResidualBlock.

    The main `conv` path is transferred when present; the `shortcut` is
    transferred only when present and not an IdentityLayer.

    Args:
        source_rb (ResidualBlock): Residual block providing the weights
        target_rb (ResidualBlock): Residual block receiving the weights
    """
    main_path = getattr(source_rb, 'conv', None)
    if main_path:
        transfer_block_weights(main_path, target_rb.conv)

    skip_path = getattr(source_rb, 'shortcut', None)
    # An identity shortcut is skipped here.
    if skip_path and not isinstance(skip_path, IdentityLayer):
        transfer_block_weights(skip_path, target_rb.shortcut)

def transfer_mb_conv_weights(source_mb: MBConvLayer, target_mb: MBConvLayer) -> None:
    """Transfer every sub-module of a Mobile Inverted Bottleneck Conv Layer.

    Sub-modules that are missing or falsy on the source are skipped.

    Args:
        source_mb (MBConvLayer): MBConv layer providing the weights
        target_mb (MBConvLayer): MBConv layer receiving the weights
    """
    # Inverted bottleneck, depth-wise conv, point-wise conv, then the SE module.
    stage_names = ("inverted_bottleneck", "depth_conv", "point_linear", "se")

    for stage_name in stage_names:
        source_stage = getattr(source_mb, stage_name, None)
        if not source_stage:
            continue
        transfer_block_weights(source_stage, getattr(target_mb, stage_name))

def transfer_sequential_weights(source_seq: nn.Sequential, target_seq: nn.Sequential) -> None:
    """
    Transfer weights child by child between two Sequential containers.

    Children are paired positionally with `zip`, so pairing stops at the end of
    the shorter container.

    ### Args:
        `source_seq (nn.Sequential)`: Container providing the weights
        `target_seq (nn.Sequential)`: Container receiving the weights
    """
    child_pairs = zip(source_seq.children(), target_seq.children())
    for source_child, target_child in child_pairs:
        transfer_block_weights(source_child, target_child)

def transfer_block_weights(source_block: nn.Module, target_block: nn.Module) -> None:
    """
    Dispatch the weight transfer according to the type of the source block.

    Unrecognised layer types are reported with a printed warning and left untouched.

    ### Args:
        `source_block (nn.Module)`: Source block from base network
        `target_block (nn.Module)`: Target block from expanded network
    """
    # Tried top to bottom; the first matching type wins.
    handlers = (
        (nn.Sequential, transfer_sequential_weights),
        (nn.Conv2d, transfer_conv_weights),
        (nn.BatchNorm2d, transfer_bn_weights),
        (nn.Linear, transfer_linear_weights),
        # LinearLayer is handled through its inner `linear` attribute.
        (LinearLayer, lambda src, tgt: transfer_linear_weights(src.linear, tgt.linear)),
        (SEModule, transfer_se_module_weights),
        (ConvLayer, transfer_convlayer_weights),
        (ResidualBlock, transfer_residual_block_weights),
        (MBConvLayer, transfer_mb_conv_weights),
    )

    for layer_type, handler in handlers:
        if isinstance(source_block, layer_type):
            handler(source_block, target_block)
            return

    # These layers don't have weights to transfer.
    weightless_types = (IdentityLayer, nn.Identity, nn.ReLU, Hsigmoid, Hswish, MyGlobalAvgPool2d)
    if isinstance(source_block, weightless_types):
        return

    print(f"Warning: Unhandled layer type: {type(source_block)}")

In [87]:
def are_blocks_same(bblock, eblock):
    """
    Tell whether two blocks have the same structure.

    Two blocks count as the same when the `weight` shapes of all their
    sub-modules match in order and their total parameter counts are equal.
    """

    def _signature(block):
        # Total number of scalar parameters in the block.
        n_params = sum(p.numel() for p in block.parameters())
        # Weight shape of every (sub-)module that exposes a `weight` attribute.
        weight_shapes = [m.weight.shape for m in block.modules() if hasattr(m, "weight")]
        return weight_shapes, n_params

    return _signature(bblock) == _signature(eblock)


for i, (bblock, eblock) in enumerate(zip(base_network.blocks, expanded_network.blocks)):
    # Structural comparison first, then the weight transfer for this pair.
    same_blocks = are_blocks_same(bblock, eblock)
    transfer_block_weights(bblock, eblock)

    if same_blocks:
        continue

    # First structurally different pair: show both blocks and stop.
    print(f"Block {i}:", bblock, eblock, "=" * 50, sep="\n")
    break


Block 1:
ResidualBlock(
  (conv): MBConvLayer(
    (inverted_bottleneck): Sequential(
      (conv): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (depth_conv): Sequential(
      (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (point_linear): Sequential(
      (conv): Conv2d(48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)
ResidualBlock(
  (conv): MBConvLayer(
    (inverted_bottleneck): Sequential(
      (conv): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_

The transfer of the weights needs to be done in order:
`ksize` -> `width` -> `depth`

In [113]:
# Skip the first block (index 0) of each network and list the remaining ones
# in an interleaved fashion. Slice once instead of on every iteration.
base_tail = base_network.blocks[1:]
expanded_tail = expanded_network.blocks[1:]
max_blocks = max(len(base_tail), len(expanded_tail))

for i in range(max_blocks):
    # Either network may be the shorter one, so guard both lookups.
    if i < len(base_tail):
        print(base_tail[i])
    if i < len(expanded_tail):
        print(expanded_tail[i])

ResidualBlock(
  (conv): MBConvLayer(
    (inverted_bottleneck): Sequential(
      (conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (depth_conv): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (point_linear): Sequential(
      (conv): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)
ResidualBlock(
  (conv): MBConvLayer(
    (inverted_bottleneck): Sequential(
      (conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

In [30]:
# Get the base and expanded depths
base_depths = base_arch["depths"]
expanded_depths = expanded_arch["depths"]

# Find where the depth increased
added_block_idx = None
for i, (base_d, exp_d) in enumerate(zip(base_depths, expanded_depths)):
    if exp_d > base_d:
        added_block_idx = i
        break

print(f"Base architecture: {base_arch}")
print(f"Block {added_block_idx} was expanded from depth {base_depths[added_block_idx]} to {expanded_depths[added_block_idx]}")

Base architecture: {'depths': [3, 3, 3, 3, 3], 'ksizes': [5, 5, 3, 5, 5, 5, 5, 5, 3, 5, 5, 5, 3, 5, 3], 'widths': [3, 3, 3, 3, 3, 3, 4, 3, 3, 4, 4, 3, 3, 4, 3], 'resolution': 152}
Block 4 was expanded from depth 3 to 4


In [31]:
# One leading block plus all depths up to and including the expanded position.
depths_through_expanded = base_arch["depths"][: added_block_idx + 1]
last_common_block = 1 + sum(depths_through_expanded)
print(f"The last common block is: {last_common_block}")


The last common block is: 16


In [32]:
# Duplicate the block just before `last_common_block` so the base network has a
# counterpart for the block added in the expanded network.
# The inserted entry is the same module object (shared parameters), not a copy.
# The length guard keeps this cell idempotent: re-running it does not insert
# further duplicates once both networks have the same number of blocks.
if len(base_network.blocks) < len(expanded_network.blocks):
    base_network.blocks.insert(
        last_common_block,
        base_network.blocks[last_common_block - 1],
    )

In [34]:
# Transfer weights block by block; stop at the first pair that cannot be
# transferred and show why.
for i, (base_block, expanded_block) in enumerate(zip(base_network.blocks, expanded_network.blocks)):
    try:
        transfer_block_weights(base_block, expanded_block)

    except RuntimeError as err:
        print(f"Block {i}:")
        print(base_block)
        print(expanded_block)
        # Keep the reason for the failure instead of discarding it.
        print(f"Transfer failed: {err}")
        print("=" * 50)
        break