In [63]:
import torch
from torch import nn
from copy import deepcopy
from tqdm import tqdm
import torchvision
import torch.fx
from segformer import MixVisionTransformer
from fusion import *

In [64]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fcd4c44aad0>

In [65]:
def find_conv_bn_pairs(traced_model):
    conv_bn_pairs = []
    prev_node = None
    module_dict = dict(traced_model.named_modules())  # Get all modules with proper dot-separated names

    for node in traced_model.graph.nodes:
        if node.op == 'call_module':
            module = module_dict[node.target]
            if isinstance(module, nn.Conv2d):
                prev_node = node
            elif isinstance(module, nn.BatchNorm2d) and prev_node:
                # Use the full dot-separated module names
                conv_name = node.target  # Already in dot notation
                bn_name = prev_node.target  # Already in dot notation
                conv_bn_pairs.append((bn_name, conv_name))  # Keep order (conv, bn)
                prev_node = None
    return conv_bn_pairs

In [77]:
model = torchvision.models.resnext101_64x4d(pretrained=False)
traced = torch.fx.symbolic_trace(model) 
conv_bn_pairs = find_conv_bn_pairs(traced)
print(f"Conv bn pairs: {len(conv_bn_pairs)}, Con extracted: {len(extract_layers_hierarchy(model))}")

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Conv bn pairs: 104, Con extracted: 104


In [72]:
fused_model = deepcopy(model)
fuseable_layer_attributes = extract_layers_hierarchy(model)
# fuseable_layer_attributes = conv_bn_pairs

params_reduced = 0
feasible_cnt = 0
for fuseable_layer_attribute in tqdm(fuseable_layer_attributes, desc="Fusing Layers"):
    try:
        conv_layer = get_layer_by_path(fused_model, fuseable_layer_attribute[0])
        bn_layer = get_layer_by_path(fused_model, fuseable_layer_attribute[1])
        if isinstance(bn_layer, nn.Identity):
            continue
        # Fuse conv and bn layers
        fused_layer = fuse_conv_and_bn(conv_layer, bn_layer)
        num_conv_params = sum(p.numel() for p in conv_layer.parameters())
        num_bn_params = sum(p.numel() for p in bn_layer.parameters())
        num_fused_params = sum(p.numel() for p in fused_layer.parameters())
        params_reduced += num_conv_params + num_bn_params - num_fused_params
        rsetattr(fused_model, fuseable_layer_attribute[0], fused_layer)
        rsetattr(fused_model, fuseable_layer_attribute[1], nn.Identity())
        feasible_cnt += 1
    except:
        pass

print(f"Fusion completed: BatchNorm fusion finished. {params_reduced} parameters were reduced after fusion. Feasible fused: {feasible_cnt}")

Fusing Layers: 100%|██████████| 99/99 [00:00<00:00, 552.98it/s]

Fusion completed: BatchNorm fusion finished. 12608 parameters were reduced after fusion. Feasible fused: 99





In [73]:
fused_model1 = deepcopy(model)
# fuseable_layer_attributes = extract_layers_hierarchy(model)
fuseable_layer_attributes = conv_bn_pairs

params_reduced = 0
feasible_cnt = 0
for fuseable_layer_attribute in tqdm(fuseable_layer_attributes, desc="Fusing Layers"):
    try:
        conv_layer = get_layer_by_path(fused_model1, fuseable_layer_attribute[0])
        bn_layer = get_layer_by_path(fused_model1, fuseable_layer_attribute[1])
        if isinstance(bn_layer, nn.Identity):
            continue
        # Fuse conv and bn layers
        fused_layer = fuse_conv_and_bn(conv_layer, bn_layer)
        num_conv_params = sum(p.numel() for p in conv_layer.parameters())
        num_bn_params = sum(p.numel() for p in bn_layer.parameters())
        num_fused_params = sum(p.numel() for p in fused_layer.parameters())
        params_reduced += num_conv_params + num_bn_params - num_fused_params
        rsetattr(fused_model1, fuseable_layer_attribute[0], fused_layer)
        rsetattr(fused_model1, fuseable_layer_attribute[1], nn.Identity())
        feasible_cnt += 1
    except:
        pass

print(f"Fusion completed: BatchNorm fusion finished. {params_reduced} parameters were reduced after fusion. Feasible fused: {feasible_cnt}")

Fusing Layers: 100%|██████████| 200/200 [00:01<00:00, 168.64it/s]

Fusion completed: BatchNorm fusion finished. 13888 parameters were reduced after fusion. Feasible fused: 102





In [None]:
feasible_cnt

In [39]:
dummy = torch.randn((1, 3, 224, 224))
model.eval()
fused_model.eval()
fused_model1.eval()
out1 = fused_model(dummy)
out2 = model(dummy)
out3 = fused_model(dummy)

In [40]:
torch.allclose(out1, out3, atol=1e-4)

True

In [49]:
start = time.time()
for _ in range(20):
    model(dummy)
end = time.time()
print(f"Duration: {end - start}")

Duration: 1.4178943634033203


In [56]:
start = time.time()
for _ in range(20):
    fused_model(dummy)
end = time.time()
print(f"Duration: {end - start}")

Duration: 1.4924664497375488


In [57]:
start = time.time()
for _ in range(20):
    fused_model1(dummy)
end = time.time()
print(f"Duration: {end - start}")

Duration: 1.4566829204559326
