In [1]:
import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import timm
from models.pim_module import FPN, WeaklySelector, GCNCombiner,SharedPluginMoodel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: backbone nodes to tap, input size, and plugin hyper-parameters.
# Keys of `return_nodes` are timm ResNet-50 node names; values are the short
# names used for each feature map in every downstream output dict.
return_nodes = {
    'layer1.2.act3': 'layer1',
    'layer2.3.act3': 'layer2',
    'layer3.5.act3': 'layer3',
    'layer4.2.act3': 'layer4',
}
img_size = 224
num_classes = 10
# Number of features the selector keeps per layer; keys must match the
# values of `return_nodes`.
num_selects = {'layer1': 32, 'layer2': 32, 'layer3': 32, 'layer4': 32}
fpn_size = 512
comb_proj_size = 512

# Weights are randomly initialised (pretrained=False) and the probes below use
# torch.randn, so seed torch before anything random happens to make the
# notebook reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)

assert set(num_selects) == set(return_nodes.values()), \
    "num_selects keys must match the layer names produced by return_nodes"

# Two separately constructed ResNet-50 backbones with random (non-pretrained)
# weights; they are created in order, first backbone1 then backbone2.
backbone1, backbone2 = (
    timm.create_model('resnet50', pretrained=False, num_classes=num_classes)
    for _ in range(2)
)

In [3]:
# Uncomment to list backbone1's (train-mode) graph node names; the keys of
# `return_nodes` (e.g. 'layer1.2.act3') must be chosen from this list.
# get_graph_node_names(backbone1)[0]

In [3]:
# Wrap backbone1 so it returns the feature maps named in `return_nodes`, then
# run one random batch through it. `outs` is handed to the plugin modules in
# the next cell, presumably only so they can read the feature-map shapes
# (NOTE(review): confirm in models.pim_module).
backbone11 = create_feature_extractor(backbone1, return_nodes=return_nodes)
rand_in = torch.randn(1, 3, img_size, img_size)
# The extractor reuses backbone1's submodules, so a train-mode forward on noise
# would update backbone1's BatchNorm running statistics. Probe in eval mode
# without autograd, and always restore the original mode afterwards.
was_training = backbone11.training
backbone11.eval()
try:
    with torch.no_grad():
        outs = backbone11(rand_in)
finally:
    backbone11.train(was_training)


In [4]:
# Build the plugin modules exactly once. The same instances are passed to
# both networks in the next cell, so those networks share these modules.
shared_fpn = FPN(outs, fpn_size, proj_type="Conv", upsample_type="Bilinear")
shared_selector = WeaklySelector(outs, num_classes, num_selects, fpn_size)
shared_combiner = GCNCombiner(
    num_classes=num_classes,
    fpn_size=fpn_size,
    # The combiner consumes every selected feature from every layer.
    total_num_selects=sum(num_selects.values()),
)

In [5]:
# Create two SharedPluginMoodel instances that differ only in their backbone
# and receive the *same* FPN / selector / combiner objects built above.
# The common arguments live in one dict so the two constructions cannot
# silently drift apart.
# NOTE(review): if SharedPluginMoodel registers the shared modules as
# submodules, net1.parameters() and net2.parameters() overlap; deduplicate
# before handing both to a single optimizer.
shared_plugin_kwargs = dict(
    return_nodes=return_nodes,
    img_size=img_size,
    use_fpn=True,
    fpn_size=fpn_size,
    proj_type="Conv",
    upsample_type="Bilinear",
    use_selection=True,
    num_classes=num_classes,
    num_selects=num_selects,
    use_combiner=True,
    comb_proj_size=comb_proj_size,
    fpn=shared_fpn,            # shared FPN
    selector=shared_selector,  # shared selector
    combiner=shared_combiner,  # shared combiner
)

net1 = SharedPluginMoodel(backbone=backbone1, **shared_plugin_kwargs)
net2 = SharedPluginMoodel(backbone=backbone2, **shared_plugin_kwargs)

In [6]:
# One random image, sized from `img_size` so it stays in sync with the config
# cell, run through net1; the result is a dict of named outputs.
rand_inp = torch.randn(1, 3, img_size, img_size)
outs1 = net1(rand_inp)

In [7]:
# Inspect which outputs net1 returns: one entry per tapped layer, a
# select_/drop_ pair per layer, and the combiner output 'comb_outs'.
outs1.keys()

dict_keys(['layer1', 'layer2', 'layer3', 'layer4', 'select_layer1', 'drop_layer1', 'select_layer2', 'drop_layer2', 'select_layer3', 'drop_layer3', 'select_layer4', 'drop_layer4', 'comb_outs'])

In [16]:
# The combiner output should hold one row of `num_classes` values per input
# image; check it explicitly, then display the shape.
assert outs1['comb_outs'].shape == (rand_inp.shape[0], num_classes)
outs1['comb_outs'].shape

torch.Size([1, 10])

In [18]:
# Concatenate every output of net1 along dim=1 into one tensor.
# 2-D entries (such as 'comb_outs', shape (1, num_classes)) get a singleton
# dim inserted at position 1 so they can be joined with the 3-D entries.
# NOTE(review): assumes every 3-D entry has the same batch and last dim
# (num_classes) — confirm.
result = torch.cat(
    [
        tensor.unsqueeze(1) if tensor.dim() == 2 else tensor
        for tensor in outs1.values()
    ],
    dim=1,
)

In [19]:
# Average the first 5 entries along dim 1 and show the resulting shape.
# NOTE(review): 5 is a magic number; given the key order of outs1 these rows
# presumably come from 'layer1' — confirm the intended selection.
result[:, :5].mean(dim=1).shape

torch.Size([1, 10])

In [8]:
# Run the second network on the same random input. net2 has its own backbone
# but was given the same FPN / selector / combiner objects as net1.
outs2 = net2(rand_inp)

In [None]:
# Both networks are built with identical plugin settings, so they should
# expose the same set of outputs.
assert outs2.keys() == outs1.keys(), "net1 and net2 returned different output keys"
outs2.keys()

dict_keys(['layer1', 'layer2', 'layer3', 'layer4', 'select_layer1', 'drop_layer1', 'select_layer2', 'drop_layer2', 'select_layer3', 'drop_layer3', 'select_layer4', 'drop_layer4', 'comb_outs'])

In [11]:
# Peek at net2's selected features for layer1 (torch truncates the repr).
# NOTE(review): prefer `.shape` or a small slice here to keep the saved
# notebook output compact.
outs2['select_layer1']

tensor([[[ 0.3097, -0.0814, -0.4883,  ..., -0.8681,  1.8043, -0.3105],
         [-0.5980, -0.2696,  0.6486,  ..., -0.1383,  0.8112, -1.0788],
         [ 0.0657, -0.7937, -0.5046,  ..., -0.0362,  1.5001, -0.8177],
         ...,
         [-0.0561, -0.3114, -0.4861,  ...,  0.1306,  0.3294, -0.2948],
         [-0.1141, -0.2894, -0.2055,  ..., -0.0591,  0.4489, -0.0492],
         [-0.2455,  0.0032, -0.1823,  ..., -0.6737,  0.2210, -0.7378]]],
       grad_fn=<StackBackward0>)