In [1]:
import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import timm
from models.pim_module import FPN, WeaklySelector, GCNCombiner,SharedPluginMoodel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration parameters.
# Nodes tapped from the timm ResNet-50 graph, mapped to the names used downstream.
# Each key appears to be the last bottleneck activation of a ResNet stage.
return_nodes = {
    'layer1.2.act3': 'layer1',
    'layer2.3.act3': 'layer2',
    'layer3.5.act3': 'layer3',
    'layer4.2.act3': 'layer4',
}
img_size = 224
num_classes = 10
# Per-layer counts handed to WeaklySelector. Presumably the number of spatial
# positions kept per layer, so each should not exceed that layer's H*W at img_size
# -- TODO confirm against WeaklySelector.
num_selects = {'layer1': 2048, 'layer2': 512, 'layer3': 128, 'layer4': 32}
fpn_size = 256
comb_proj_size = 512
seed = 0

# The backbones are randomly initialised (pretrained=False) and later cells draw
# random probe inputs, so seed torch first to make Restart & Run All reproducible.
torch.manual_seed(seed)
backbone1 = timm.create_model('resnet50', pretrained=False, num_classes=num_classes)
backbone2 = timm.create_model('resnet50', pretrained=False, num_classes=num_classes)

In [3]:
# Sanity check: every node requested in `return_nodes` must exist in the traced
# backbone graph (both train and eval modes). Inspect `train_node_names` to browse
# the available node names if this assertion fails.
train_node_names, eval_node_names = get_graph_node_names(backbone1)
available_nodes = set(train_node_names) & set(eval_node_names)
missing_nodes = sorted(set(return_nodes) - available_nodes)
assert not missing_nodes, f"return_nodes not found in backbone graph: {missing_nodes}"

In [4]:
# Wrap backbone1 in a feature extractor only to obtain example feature maps
# (`outs`, keyed by the values of return_nodes). These are used to build the
# shared plugin modules below.
# backbone1 itself is left unwrapped, for two reasons:
#   - net1 and net2 both receive a plain timm backbone, configured the same way.
#   - Re-running this cell does not wrap an already-wrapped model.
feature_extractor = create_feature_extractor(backbone1, return_nodes=return_nodes)
rand_in = torch.randn(1, 3, img_size, img_size)
outs = feature_extractor(rand_in)


In [5]:
# Build the plugin modules once. The very same instances are handed to both
# networks, which is what makes them "shared". `outs` supplies example feature maps.
shared_fpn = FPN(outs, fpn_size, proj_type="Conv", upsample_type="Bilinear")
shared_selector = WeaklySelector(outs, num_classes, num_selects, fpn_size)

# The combiner consumes the selected positions from every layer.
total_selected = sum(num_selects.values())
shared_combiner = GCNCombiner(
    total_num_selects=total_selected,
    fpn_size=fpn_size,
    num_classes=num_classes,
)

In [6]:
# Create two SharedPluginMoodel instances that share the plugin modules above.
# Only the backbone differs between them. Every other argument lives in one dict,
# so the two networks cannot silently drift apart in configuration.
shared_net_kwargs = dict(
    return_nodes=return_nodes,
    img_size=img_size,
    use_fpn=True,
    fpn_size=fpn_size,
    proj_type="Conv",          # presumably must match the proj_type used for shared_fpn
    upsample_type="Bilinear",  # presumably must match the upsample_type used for shared_fpn
    use_selection=True,
    num_classes=num_classes,
    num_selects=num_selects,
    use_combiner=True,
    comb_proj_size=comb_proj_size,
    fpn=shared_fpn,            # shared FPN instance
    selector=shared_selector,  # shared selector instance
    combiner=shared_combiner,  # shared combiner instance
)

net1 = SharedPluginMoodel(backbone=backbone1, **shared_net_kwargs)
net2 = SharedPluginMoodel(backbone=backbone2, **shared_net_kwargs)

In [7]:
# Smoke-test net1 with a random single-image batch. The size comes from img_size,
# so the probe always matches the resolution the networks were built for.
rand_inp = torch.randn(1, 3, img_size, img_size)
outs1 = net1(rand_inp)

In [8]:
# Show which outputs net1 returns: per-layer features, select_/drop_ entries per layer, and comb_outs.
outs1.keys()

dict_keys(['layer1', 'layer2', 'layer3', 'layer4', 'select_layer1', 'drop_layer1', 'select_layer2', 'drop_layer2', 'select_layer3', 'drop_layer3', 'select_layer4', 'drop_layer4', 'comb_outs'])

In [9]:
# Forward the same random input through net2, which was built with the same FPN/selector/combiner instances as net1.
outs2 = net2(rand_inp)

In [10]:
# Both networks use the same plugin configuration, so they must expose identical output keys.
assert outs2.keys() == outs1.keys(), "net1 and net2 produced different output keys"
outs2.keys()

dict_keys(['layer1', 'layer2', 'layer3', 'layer4', 'select_layer1', 'drop_layer1', 'select_layer2', 'drop_layer2', 'select_layer3', 'drop_layer3', 'select_layer4', 'drop_layer4', 'comb_outs'])