In [None]:
import os
import sys
import re

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.models as models

from collections import OrderedDict

In [None]:
# ImageNet-pretrained ResNet-50; CustomModel below takes its feature layers from it.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of `weights=`.
# IMAGENET1K_V1 is the checkpoint `pretrained=True` loads, so the weights are unchanged.
try:
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:  # torchvision < 0.13 has no weight enums
    resnet = models.resnet50(pretrained=True)

In [None]:
def get_layers(model, layer_specs):
    """Collect named sub-modules of ``model`` into an ``nn.Sequential``.

    Each entry of ``layer_specs`` must be the exact qualified name of a module
    as reported by ``model.named_modules()`` (e.g. ``"conv1"``, ``"layer1"``,
    ``"layer2.0"``, ``"layer1.0.conv3"``). Naming a container such as
    ``"layer1"`` selects the whole container, including all of its children.

    Selected modules are added in the order ``model.named_modules()`` visits
    them (not the order of ``layer_specs``). They are the same module objects
    as in ``model``, not copies, so their parameters are shared with ``model``.

    Args:
        model: Module to pick layers from.
        layer_specs: Iterable of qualified module names.

    Returns:
        nn.Sequential holding the selected modules. ``nn.Sequential`` child
        names may not contain ``"."``, so dots in the qualified names are
        replaced by ``"_"`` (``"layer2.0"`` -> ``"layer2_0"``).

    Raises:
        KeyError: if a name in ``layer_specs`` does not exist in ``model``.
    """
    wanted = set(layer_specs)
    # Exact matching on purpose: a substring test would also select, e.g.,
    # every "layerX.Y.relu" for the spec "relu".
    selected = [(name, module) for name, module in model.named_modules() if name in wanted]

    missing = wanted - {name for name, _ in selected}
    if missing:
        raise KeyError(f"layers not found in model: {sorted(missing)}")

    # add_module() rejects names containing "."; "_" keeps them readable.
    return nn.Sequential(OrderedDict(
        (name.replace(".", "_"), module) for name, module in selected
    ))


In [None]:
# TODO: the final layer of this model still needs to be replaced with the corresponding ResNet layer.

class CustomModel(nn.Module):
    """Frozen feature extractor taken from ``base_model`` plus a trainable MLP head.

    ``get_layers(base_model, layer_names)`` supplies the feature extractor.
    Its flattened output size is measured once with a dummy forward pass and
    used to size the first ``nn.Linear`` of the classifier. All
    feature-extractor parameters are then frozen (``requires_grad=False``);
    only the classifier is trainable.

    NOTE(review): if ``get_layers`` returns the original module objects (as
    the version in this notebook does), the feature layers are shared with
    ``base_model``, so freezing them also freezes them in ``base_model``.

    Args:
        base_model: Module to take the feature layers from.
        layer_names: Names passed to ``get_layers`` to select the layers.
        num_classes: Output size of the classifier.
        input_size: Per-sample input shape used for the dummy forward pass.
        hidden_dim: Width of the classifier's hidden layer.

    Raises:
        RuntimeError: presumably, from PyTorch, when the selected layers are
            not shape-compatible when chained (the dummy forward pass fails).
    """

    def __init__(self, base_model, layer_names, num_classes,
                 input_size=(3, 224, 224), hidden_dim=256):
        super().__init__()
        self.features = get_layers(base_model, layer_names)
        num_features = self._probe_num_features(input_size)

        self.classifier = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )
        self.frozen()

    def _probe_num_features(self, input_size):
        """Return the flattened per-sample size of ``self.features``' output.

        One random sample of shape ``input_size`` is pushed through the
        features in eval mode. In train mode, BatchNorm layers would overwrite
        their (possibly pretrained, shared) running statistics even under
        ``no_grad``. Every submodule's train/eval flag is restored afterwards.
        """
        # Build the dummy input on the same device/dtype as the feature weights.
        ref = next(self.features.parameters(), None)
        device = None if ref is None else ref.device
        dtype = None if ref is None else ref.dtype

        saved_modes = [(module, module.training) for module in self.features.modules()]
        self.features.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, *input_size, device=device, dtype=dtype)
                out = self.features(dummy_input)
        finally:
            for module, mode in saved_modes:
                module.training = mode
        return torch.flatten(out, 1).size(1)

    def forward(self, x):
        """Map a batch ``x`` to logits of shape ``(batch, num_classes)``."""
        x = self.features(x)
        # flatten() instead of view(): also works for non-contiguous feature maps.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def frozen(self):
        """Disable gradients for all feature-extractor parameters.

        The classifier stays trainable. ``requires_grad`` does not affect
        BatchNorm running statistics, which still update whenever the module
        is in training mode.
        """
        for param in self.features.parameters():
            param.requires_grad = False


In [None]:
# Feature extractor: the ResNet-50 stem plus whole bottleneck blocks, in forward order.
# Whole blocks are required: a lone `conv3` skips its block's conv1/conv2/BN/ReLU and
# residual connection, and consecutive conv3 layers are not channel-compatible
# (layer1.1.conv3 expects the block's inner width, not layer1.0.conv3's expanded
# output), so CustomModel's dummy forward pass would fail.
# NOTE(review): without pooling the flattened feature map is large; appending
# 'avgpool' would shrink the classifier input -- confirm intent (see TODO on CustomModel).
model = CustomModel(resnet, [
    'conv1', 'bn1', 'relu', 'maxpool',
    'layer1',                # every bottleneck block of layer1
    'layer2.0', 'layer2.1',  # first two bottleneck blocks of layer2
    'layer3.0',              # first bottleneck block of layer3
], num_classes=10)


