In [None]:
pwd

In [None]:
from PIL import Image

img_path = './image/1.png'
def _get_img_from_path(img_path, transform=None):
    with open(img_path, 'rb') as f:
        img = Image.open(f).convert('RGB')
    if transform is not None:
        img = transform(img)
    return img

In [None]:
img = _get_img_from_path(img_path)
img.show()
print(img)

In [1]:
# 测试resnet
from models.image_encoders.resnet import ResNet18Layer4Upper, ResNet18Layer4Lower
pretrained = True
lower_encoder = ResNet18Layer4Lower(pretrained)
# print(type(lower_encoder))
# print(lower_encoder.layer_shapes())
lower_feature_shape = lower_encoder.layer_shapes()['layer4'] # 512
feature_size = 512
norm_scale = 4
model = ResNet18Layer4Upper(lower_feature_shape, feature_size, pretrained=pretrained,norm_scale=norm_scale)
# print(model)
print(model.lower_feature_shape, model.feature_size) # 512 512

# 结论：ResNet18Layer4Lower主要负责特征的初步提取，而ResNet18Layer4Upper则负责将这些特征转换为更高级的、用于后续任务的表示。
#      这两个类设计为连续使用，其中ResNet18Layer4Lower的输出直接作为ResNet18Layer4Upper的输入。这种设计使得整个模型的两部分可以灵活地根据需要进行调整或替换。

ValueError: Unknown value 'True' for ResNet18_Weights.

In [None]:
import pickle
def create_read_func(vocab_path):
    # 用于从给定的路径 vocab_path 读取并反序列化数据
    def read_func():
        with open(vocab_path, 'rb') as f:
            # 使用了 Python 的 pickle 模块来加载存储的对象
            data = pickle.load(f)
        return data

    return read_func
read_func = create_read_func('D:/Datasets/fashionIQ/fashion_iq_vocab.pkl')
data = read_func()
print(type(data))

In [None]:
# 测试RoBerta
from models.text_encoders.roberta import RobertaEncoder,BertFc
model1, model2 = RobertaEncoder(512), BertFc(512)
print(model1, model2)


In [2]:
import torch
import torch.nn as nn
class GlobalCrossAttentionMap(nn.Module):
    def __init__(self, feature_size, text_feature_size, num_heads, normalizer=None, *args, **kwargs):
        super().__init__()

        # 注意力头的数量
        self.n_heads = num_heads
        # 每个注意力头处理的特征维度，计算为 feature_size 除以 num_heads
        self.c_per_head = feature_size // num_heads
        assert feature_size == self.n_heads * self.c_per_head

        # 用于将文本特征映射到与图像特征相同的维度
        self.W_t = nn.Linear(text_feature_size, feature_size)
        self.normalize = normalizer if normalizer else nn.Softmax(dim=1)

    def forward(self, x, t):
        # 将图像的四个维度剥离保存，batch_size,chanel,height,weight
        b, c, h, w = x.size()
        # x 的维度被重塑以适应注意力头的处理，首先按照注意力头数量和每头的特征维度重塑
        x_reshape = x.view(b * self.n_heads, self.c_per_head, h, w)
        # 然后将高度和宽度合并
        x_reshape = x_reshape.view(b * self.n_heads, self.c_per_head, h * w)

        # 文本特征 t 通过 self.W_t 映射
        t_mapped = self.W_t(t)
        # 并调整维度以匹配处理后的图像特征 x_reshape
        t_mapped = t_mapped.view(b * self.n_heads, self.c_per_head, 1)

        # 使用批量矩阵乘法 (torch.bmm) 计算注意力映射，然后通过根号缩放因子进行缩放
        att_map = torch.bmm(x_reshape.transpose(1, 2), t_mapped).squeeze(-1) / (self.c_per_head ** 0.5)
        # 应用归一化函数（默认为 Softmax）来归一化注意力映射
        att_map = self.normalize(att_map)  # (b * n_heads, h * w)
        # 注意力映射的维度被调整回原始的批次大小和注意力头数量
        att_map = att_map.view(b * self.n_heads, 1, h * w)
        att_map = att_map.view(b, self.n_heads, h * w)

        return att_map

In [5]:
class SelfAttentionMap(nn.Module):
    def __init__(self, feature_size, num_heads, *args, **kwargs):
        super().__init__()

        self.n_heads = num_heads
        self.c_per_head = feature_size // num_heads
        assert feature_size == self.n_heads * self.c_per_head

        self.W_k = nn.Conv2d(feature_size, feature_size, kernel_size=1, bias=False)
        self.W_q = nn.Conv2d(feature_size, feature_size, kernel_size=1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x, *args, **kwargs):
        b, c, h, w = x.size()

        keys, queries = self.W_k(x), self.W_q(x)
        keys = keys.view(b * self.n_heads, self.c_per_head, h, w).view(b * self.n_heads, self.c_per_head, h * w)
        queries = queries.view(b * self.n_heads, self.c_per_head, h, w).view(b * self.n_heads, self.c_per_head, h * w)

        att_map = torch.bmm(queries.transpose(1, 2), keys) / (self.c_per_head ** 0.5)
        att_map = self.softmax(att_map)  # (b * num_heads, h * w, h * w), torch.sum(att_map[batch_idx][?]) == 1
        att_map = att_map.view(b, self.n_heads, h * w, h * w)

        return att_map

In [6]:
def reshape_text_features_to_concat(text_features, image_features_shapes):
    return text_features.view((*text_features.size(), 1, 1)).repeat(1, 1, *image_features_shapes[2:])
from models.attention_modules.external_attention import ExternalAttention
class AttentionModule(nn.Module):
    def __init__(self, feature_size, text_feature_size, num_heads, *args, **kwargs):
        super().__init__()

        self.n_heads = num_heads
        self.c_per_head = feature_size // num_heads
        assert feature_size == self.n_heads * self.c_per_head
        
        # 两个组件的初始化
        self.self_att_generator = SelfAttentionMap(feature_size, num_heads, *args, **kwargs)
        self.external_att = ExternalAttention(d_model= feature_size, S= 64)
        self.global_att_generator = GlobalCrossAttentionMap(feature_size, text_feature_size, num_heads, *args, **kwargs)

        # 合并图像和调整形状后的文本特征
        self.merge = nn.Conv2d(feature_size + text_feature_size, feature_size, kernel_size=1, bias=False)
        # 用于生成值（values）
        self.W_v = nn.Conv2d(feature_size, feature_size, kernel_size=1, bias=False)
        # 用于最终输出的调整
        self.W_r = nn.Conv2d(feature_size, feature_size, kernel_size=1)

    def forward(self, x, t, return_map=False, *args, **kwargs):
        b, c, h, w = x.size()
        
        # 将文本特征调整形状以匹配图像特征的维度
        t_reshaped = reshape_text_features_to_concat(t, x.size())
        # 将图像特征和调整形状后的文本特征合并后，通过卷积层merge处理
        vl_features = self.merge(torch.cat([x, t_reshaped], dim=1))  # (b, c, h, w)

        # 通过卷积层W_v处理合并后的特征，用于生成注意力机制中的值（values）
        values = self.W_v(vl_features)
        values = values.view(b * self.n_heads, self.c_per_head, h, w).view(b * self.n_heads, self.c_per_head, h * w)
        
        # 通过外部注意力模块处理图像特征
        external_att_out = self.external_att(x.view(b * h * w, c))  # (b * h * w, c)
        external_att_out = external_att_out.view(b, h, w, c).permute(0, 3, 1, 2)  # (b, c, h, w)

        # 通过自注意力生成器计算自注意力映射图
        self_att_map = self.self_att_generator(x)  # (b, num_heads, h * w, h * w)
        # 通过全局交叉注意力生成器计算全局交叉注意力映射图
        global_cross_att_map = self.global_att_generator(x, t)
        global_cross_att_map = global_cross_att_map.view(b, self.n_heads, 1, h * w)  # (b, num_heads, 1, h * w)
        #  将自注意力图和全局交叉注意力图相加，得到最终的注意力图
        att_map = self_att_map + global_cross_att_map  # (b, num_heads, h * w, h * w)
        att_map_reshaped = att_map.view(b * self.n_heads, h * w, h * w)  # (b * num_heads, h * w, h * w)

        # 使用注意力图重新加权values，通过矩阵乘法和重塑操作处理，最后通过卷积层W_r调整输出
        att_out = torch.bmm(values, att_map_reshaped.transpose(1, 2))  # (b * num_heads, c_per_head, h * w)
        att_out = att_out.view(b, self.n_heads * self.c_per_head, h * w)
        att_out = att_out.view(b, self.n_heads * self.c_per_head, h, w)
        att_out = self.W_r(att_out + external_att_out)  # 将外部注意力的输出与现有的注意力输出相加

        return att_out, att_map if return_map else att_out

In [7]:
att = AttentionModule(feature_size=512, text_feature_size= 512, num_heads=2, normalizer=nn.LayerNorm)
print(att)

AttentionModule(
  (self_att_generator): SelfAttentionMap(
    (W_k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (W_q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (softmax): Softmax(dim=2)
  )
  (external_att): ExternalAttention(
    (mk): Linear(in_features=512, out_features=64, bias=False)
    (mv): Linear(in_features=64, out_features=512, bias=False)
    (softmax): Softmax(dim=1)
  )
  (global_att_generator): GlobalCrossAttentionMap(
    (W_t): Linear(in_features=512, out_features=512, bias=True)
  )
  (merge): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (W_v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (W_r): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
)


In [1]:
import torch
from torch import nn

'''为特征加入高斯噪声'''
class NormalAugmenter(nn.Module):

    def __init__(self, feature_size, alpha_scale=1, beta_scale=1, *args, **kwargs):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.instance_norm = nn.InstanceNorm1d(feature_size) # 用于对特征进行归一化处理
        self.alpha_scale = alpha_scale # 1
        self.beta_scale = beta_scale # 1

    def forward(self, features, *args, **kwargs):
        std, mean = torch.std_mean(features, dim=1)
        normal_alpha = torch.distributions.Normal(loc=1, scale=std)
        normal_beta = torch.distributions.Normal(loc=mean, scale=std)
        alpha = self.alpha_scale * normal_alpha.sample([features.shape[1]]).transpose(-1, -2)
        beta = self.beta_scale * normal_beta.sample([features.shape[1]]).transpose(-1, -2)

        features = self.instance_norm(features)
        x = alpha * features + beta
        return x

    @classmethod
    def code(cls) -> str:
        return 'normal_gaussian'

In [2]:
model = NormalAugmenter(512)
print(model)

NormalAugmenter(
  (instance_norm): InstanceNorm1d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)


In [3]:
import torch
import torch.nn as nn

from models.attention_modules.self_attention import AttentionModule
class DisentangledTransformer(nn.Module):
    def __init__(self, feature_size, text_feature_size, num_heads, global_styler=None, *args, **kwargs):
        super().__init__()
        self.n_heads = num_heads
        self.c_per_head = feature_size // num_heads
        assert feature_size == self.n_heads * self.c_per_head

        self.att_module = AttentionModule(feature_size, text_feature_size, num_heads, *args, **kwargs)
        self.att_module2 = AttentionModule(feature_size, text_feature_size, num_heads, *args, **kwargs)
        self.global_styler = global_styler

        self.weights = nn.Parameter(torch.tensor([1., 1.]))
        self.instance_norm = nn.InstanceNorm2d(feature_size)

    def forward(self, x, t, *args, **kwargs):
        normed_x = self.instance_norm(x)
        att_out, att_map = self.att_module(normed_x, t, return_map=True)
        out = normed_x + self.weights[0] * att_out

        att_out2, att_map2 = self.att_module2(out, t, return_map=True)
        out = out + self.weights[1] * att_out2

        out = self.global_styler(out, t, x=x)

        return out, att_map

In [4]:
model = DisentangledTransformer(512,512,2)
print(model)

DisentangledTransformer(
  (att_module): AttentionModule(
    (self_att_generator): SelfAttentionMap(
      (W_k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (W_q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (softmax): Softmax(dim=2)
    )
    (global_att_generator): GlobalCrossAttentionMap(
      (W_t): Linear(in_features=512, out_features=512, bias=True)
      (normalize): Softmax(dim=1)
    )
    (merge): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (W_v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (W_r): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
  )
  (att_module2): AttentionModule(
    (self_att_generator): SelfAttentionMap(
      (W_k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (W_q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (softmax): Softmax(dim=2)
    )
    (global_att_generator): GlobalCrossAttentionMap(
    

In [2]:
from tqdm import tqdm
import time

# 使用tqdm显示一个简单的循环进度条
for i in tqdm(range(10)):
    time.sleep(0.1)  # 模拟耗时操作


100%|██████████| 10/10 [00:01<00:00,  9.11it/s]


In [1]:
from tqdm import tqdm
import time

# 使用更多参数自定义进度条
for i in tqdm(range(10), desc="Processing", unit="step"):
    time.sleep(0.1)  # 模拟耗时操作


Processing: 100%|██████████| 10/10 [00:01<00:00,  9.09step/s]


In [1]:
import torch
# 1. 查看预训练权重的关键字
pretrained_weights_path = './ckpt/resnet18.pth'
state_dict = torch.load(pretrained_weights_path)
# 获取预训练权重的关键字
pretrained_keys = state_dict.keys()
print("预训练权重关键字：")
for key in state_dict.keys():
    print(key)

预训练权重关键字：
conv1.weight
bn1.running_mean
bn1.running_var
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.weight
layer1.1.bn2.bias
layer2.0.conv1.weight
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.running_mean
layer2.0.downsample.1.running_var
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.running_mean
layer2.1.bn1.running_var
layer2.1.bn1.weight
layer2.1.bn

In [2]:
# 2. 查看自身网络模型的关键字
from torchvision.models import resnet18
net = resnet18()
model_keys = net.state_dict().keys()

print("模型权重的关键字：")
for key in model_keys:
    print(key)

模型权重的关键字：
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tracked
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.num_batches_tracked
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.num_batches_tracked
layer2.0.downsample.0.weight
layer2.

In [3]:
# 3. 找出模型中缺失的权重（针对预训练关键字比模型关键字少）
missing_keys = model_keys - pretrained_keys
print("模型中缺失的权重关键字：")
for key in missing_keys:
    print(key)

模型中缺失的权重关键字：
layer2.0.downsample.1.num_batches_tracked
layer3.0.bn2.num_batches_tracked
layer4.1.bn1.num_batches_tracked
layer1.0.bn1.num_batches_tracked
layer4.1.bn2.num_batches_tracked
layer1.1.bn2.num_batches_tracked
layer3.1.bn1.num_batches_tracked
layer3.1.bn2.num_batches_tracked
layer2.0.bn2.num_batches_tracked
layer2.0.bn1.num_batches_tracked
layer2.1.bn1.num_batches_tracked
layer3.0.bn1.num_batches_tracked
bn1.num_batches_tracked
layer2.1.bn2.num_batches_tracked
layer3.0.downsample.1.num_batches_tracked
layer1.1.bn1.num_batches_tracked
layer4.0.downsample.1.num_batches_tracked
layer1.0.bn2.num_batches_tracked
layer4.0.bn2.num_batches_tracked
layer4.0.bn1.num_batches_tracked


In [4]:
# 4. 找出模型中多余的权重（针对预训练关键字比模型关键字多）
unexpected_keys = pretrained_keys - model_keys
print("模型中多余的权重关键字：")
for key in unexpected_keys:
    print(key)

模型中多余的权重关键字：


In [5]:
# 关键字不匹配
net.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
# 5. 解决关键字不匹配问题
model_weights_path = './ckpt/resnet18.pth'
# 加载预训练权重
ckpt = torch.load(model_weights_path)
net = resnet18()
# 得到模型参数
model_dict = net.state_dict()
# 判断预训练模型中网络的模块是否修改后的网络中也存在，并且shape相同，如果相同则取出
pretrained_dict = {k:v for k, v in ckpt.items() if k in model_dict and (v.shape == model_dict[k].shape)}
# 更新之后的model_dict
model_dict.update(pretrained_dict)
# 加载真正需要的state_dict
net.load_state_dict(model_dict, strict=True)

<All keys matched successfully>

In [8]:
import torch
pretrained_weights_path = './ckpt/resnet18.pth'
state_dict = torch.load(pretrained_weights_path)
from torchvision.models import resnet18,resnet50
net = resnet50()
model_keys = net.state_dict().keys()
net.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var". 
	size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
	size mismatch for layer1.1.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
	size mismatch for layer2.0.conv1.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
	size mismatch for layer2.0.downsample.0.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for layer2.0.downsample.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.0.downsample.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.0.downsample.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.0.downsample.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
	size mismatch for layer3.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for layer3.0.downsample.0.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 512, 1, 1]).
	size mismatch for layer3.0.downsample.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.0.downsample.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.0.downsample.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.0.downsample.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for layer3.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
	size mismatch for layer4.0.conv1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 1, 1]).
	size mismatch for layer4.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([2048, 1024, 1, 1]).
	size mismatch for layer4.0.downsample.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.0.downsample.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.0.downsample.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.0.downsample.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for layer4.1.conv1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 2048, 1, 1]).
	size mismatch for fc.weight: copying a param with shape torch.Size([1000, 512]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).

In [9]:
# 5. 解决关键字不匹配问题
model_weights_path = './ckpt/resnet18.pth'
# 加载预训练权重
ckpt = torch.load(model_weights_path)
net = resnet50()
# 得到模型参数
model_dict = net.state_dict()
# 判断预训练模型中网络的模块是否修改后的网络中也存在，并且shape相同，如果相同则取出
pretrained_dict = {k:v for k, v in ckpt.items() if k in model_dict and (v.shape == model_dict[k].shape)}
# 更新之后的model_dict
model_dict.update(pretrained_dict)
# 加载真正需要的state_dict
net.load_state_dict(model_dict, strict=True)

<All keys matched successfully>

In [5]:
# 测试ResNet-50
from models.image_encoders.resnet import ResNet50Layer4Lower, ResNet50Layer4Upper
net1 = ResNet50Layer4Lower(pretrained=True)
net2 = ResNet50Layer4Upper(2048,512,pretrained=True,norm_scale=4)

In [6]:
print(net1)

ResNet50Layer4Lower(
  (_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(

In [7]:
print(net2)

ResNet50Layer4Upper(
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=512, bias=True)
)
