In [1]:
import torch
import torchvision.models as models

import time
from datetime import timedelta

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from functions.base import CheckpointRunner
from functions.evaluator import Evaluator
from models.classifier import Classifier
from models.losses.classifier import CrossEntropyLoss
from models.losses.p2m import P2MLoss
from models.p2m import P2MModel
from utils.average_meter import AverageMeter
from utils.mesh import Ellipsoid
from utils.tensor import recursive_detach
from utils.vis.renderer import MeshRenderer
from options import update_options, options, reset_options

  from .collection import imread_collection_wrapper


In [2]:
ellipsoid = Ellipsoid(mesh_pos=[0.0, 0.0, -0.8])
model = P2MModel(options.model, ellipsoid,
                                      options.dataset.camera_f, options.dataset.camera_c,
                                      options.dataset.mesh_pos)

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import config
import os

import numpy as np

from models.losses.p2m import P2MLoss
from utils.mesh import Ellipsoid
from utils.average_meter import AverageMeter
from functions.evaluator import Evaluator

def summarize_model(model):
    layers = [(name if len(name) > 0 else 'TOTAL', str(module.__class__.__name__), sum(np.prod(p.shape) for p in module.parameters())) for name, module in model.named_modules()]
    layers.append(layers[0])
    del layers[0]

    columns = [
        [" ", list(map(str, range(len(layers))))],
        ["Name", [layer[0] for layer in layers]],
        ["Type", [layer[1] for layer in layers]],
        ["Params", [layer[2] for layer in layers]],
    ]

    n_rows = len(columns[0][1])
    n_cols = 1 + len(columns)

    # Get formatting width of each column
    col_widths = []
    for c in columns:
        col_width = max(len(str(a)) for a in c[1]) if n_rows else 0
        col_width = max(col_width, len(c[0]))  # minimum length is header length
        col_widths.append(col_width)

    # Formatting
    s = "{:<{}}"
    total_width = sum(col_widths) + 3 * n_cols
    header = [s.format(c[0], l) for c, l in zip(columns, col_widths)]

    summary = " | ".join(header) + "\n" + "-" * total_width
    for i in range(n_rows):
        line = []
        for c, l in zip(columns, col_widths):
            line.append(s.format(str(c[1][i]), l))
        summary += "\n" + " | ".join(line)

    return summary

In [4]:
print(summarize_model(model))

    | Name                             | Type              | Params    
-----------------------------------------------------------------------------
0   | nn_encoder                       | P2MResNet         | 25557032  
1   | nn_encoder.conv1                 | Conv2d            | 9408      
2   | nn_encoder.bn1                   | BatchNorm2d       | 128       
3   | nn_encoder.relu                  | ReLU              | 0         
4   | nn_encoder.maxpool               | MaxPool2d         | 0         
5   | nn_encoder.layer1                | Sequential        | 215808    
6   | nn_encoder.layer1.0              | Bottleneck        | 75008     
7   | nn_encoder.layer1.0.conv1        | Conv2d            | 4096      
8   | nn_encoder.layer1.0.bn1          | BatchNorm2d       | 128       
9   | nn_encoder.layer1.0.conv2        | Conv2d            | 36864     
10  | nn_encoder.layer1.0.bn2          | BatchNorm2d       | 128       
11  | nn_encoder.layer1.0.conv3        | Conv2d           

In [5]:
checkpoint = torch.load('checkpoints/resnet.pth.tar')	# 加载模型
params=model.state_dict()
print(checkpoint['model'].keys())												# 查看模型元素

odict_keys(['init_pts', 'nn_encoder.conv1.weight', 'nn_encoder.bn1.weight', 'nn_encoder.bn1.bias', 'nn_encoder.bn1.running_mean', 'nn_encoder.bn1.running_var', 'nn_encoder.bn1.num_batches_tracked', 'nn_encoder.layer1.0.conv1.weight', 'nn_encoder.layer1.0.bn1.weight', 'nn_encoder.layer1.0.bn1.bias', 'nn_encoder.layer1.0.bn1.running_mean', 'nn_encoder.layer1.0.bn1.running_var', 'nn_encoder.layer1.0.bn1.num_batches_tracked', 'nn_encoder.layer1.0.conv2.weight', 'nn_encoder.layer1.0.bn2.weight', 'nn_encoder.layer1.0.bn2.bias', 'nn_encoder.layer1.0.bn2.running_mean', 'nn_encoder.layer1.0.bn2.running_var', 'nn_encoder.layer1.0.bn2.num_batches_tracked', 'nn_encoder.layer1.0.conv3.weight', 'nn_encoder.layer1.0.bn3.weight', 'nn_encoder.layer1.0.bn3.bias', 'nn_encoder.layer1.0.bn3.running_mean', 'nn_encoder.layer1.0.bn3.running_var', 'nn_encoder.layer1.0.bn3.num_batches_tracked', 'nn_encoder.layer1.0.downsample.0.weight', 'nn_encoder.layer1.0.downsample.1.weight', 'nn_encoder.layer1.0.downsample.

In [6]:
checkpoint = torch.load('checkpoints/resnet.pth.tar')	# 加载模型
params=model.state_dict()

In [7]:
for k,v in params.items():
    print(k) #打印网络中的变量名

init_pts
nn_encoder.conv1.weight
nn_encoder.bn1.weight
nn_encoder.bn1.bias
nn_encoder.bn1.running_mean
nn_encoder.bn1.running_var
nn_encoder.bn1.num_batches_tracked
nn_encoder.layer1.0.conv1.weight
nn_encoder.layer1.0.bn1.weight
nn_encoder.layer1.0.bn1.bias
nn_encoder.layer1.0.bn1.running_mean
nn_encoder.layer1.0.bn1.running_var
nn_encoder.layer1.0.bn1.num_batches_tracked
nn_encoder.layer1.0.conv2.weight
nn_encoder.layer1.0.bn2.weight
nn_encoder.layer1.0.bn2.bias
nn_encoder.layer1.0.bn2.running_mean
nn_encoder.layer1.0.bn2.running_var
nn_encoder.layer1.0.bn2.num_batches_tracked
nn_encoder.layer1.0.conv3.weight
nn_encoder.layer1.0.bn3.weight
nn_encoder.layer1.0.bn3.bias
nn_encoder.layer1.0.bn3.running_mean
nn_encoder.layer1.0.bn3.running_var
nn_encoder.layer1.0.bn3.num_batches_tracked
nn_encoder.layer1.0.downsample.0.weight
nn_encoder.layer1.0.downsample.1.weight
nn_encoder.layer1.0.downsample.1.bias
nn_encoder.layer1.0.downsample.1.running_mean
nn_encoder.layer1.0.downsample.1.running_

In [8]:
def string_rename(old_string, new_string, start, end):
    new_string = old_string[:start] + new_string + old_string[end:]
    return new_string

In [17]:
def get_bigmodule_list(children):
    Bigmodule = []
    for name, module in children:
        Bigmodule.append(name)
    return Bigmodule
    

In [18]:
def modify_state_dict(pretrained_dict, model_dict,submodule, old_prefix, new_prefix):
    state_dict = {}
    for k, v in pretrained_dict.items():
        if  k.startswith("nn_encoder"):
            print("Missing key(s) in state_dict :{}".format(k))
        else:
            if k not in old_prefix:
            # state_dict.setdefault(k, v)
                state_dict[k] = v
            else:
                for o, n in zip(old_prefix, new_prefix):
                    prefix = k[:len(o)]
                    if prefix == o:
                        kk = string_rename(old_string=k, new_string=n, start=0, end=len(o))
                        print("rename layer modules:{}-->{}".format(k, kk))
                        state_dict[kk] = v
    return state_dict

In [19]:
state_dict = modify_state_dict(checkpoint["model"], params,get_bigmodule_list(model.named_children()), old_prefix=['gcns.2.conv2.weight','gcns.2.conv2.loop_weight','gcns.2.conv2.bias'], new_prefix=['gcns.3.conv2.weight','gcns.3.conv2.loop_weight','gcns.3.conv2.bias'])
model.load_state_dict(state_dict,strict=False)

Missing key(s) in state_dict :nn_encoder.conv1.weight
Missing key(s) in state_dict :nn_encoder.bn1.weight
Missing key(s) in state_dict :nn_encoder.bn1.bias
Missing key(s) in state_dict :nn_encoder.bn1.running_mean
Missing key(s) in state_dict :nn_encoder.bn1.running_var
Missing key(s) in state_dict :nn_encoder.bn1.num_batches_tracked
Missing key(s) in state_dict :nn_encoder.layer1.0.conv1.weight
Missing key(s) in state_dict :nn_encoder.layer1.0.bn1.weight
Missing key(s) in state_dict :nn_encoder.layer1.0.bn1.bias
Missing key(s) in state_dict :nn_encoder.layer1.0.bn1.running_mean
Missing key(s) in state_dict :nn_encoder.layer1.0.bn1.running_var
Missing key(s) in state_dict :nn_encoder.layer1.0.bn1.num_batches_tracked
Missing key(s) in state_dict :nn_encoder.layer1.0.conv2.weight
Missing key(s) in state_dict :nn_encoder.layer1.0.bn2.weight
Missing key(s) in state_dict :nn_encoder.layer1.0.bn2.bias
Missing key(s) in state_dict :nn_encoder.layer1.0.bn2.running_mean
Missing key(s) in state_

IncompatibleKeys(missing_keys=['nn_encoder.conv1.weight', 'nn_encoder.bn1.weight', 'nn_encoder.bn1.bias', 'nn_encoder.bn1.running_mean', 'nn_encoder.bn1.running_var', 'nn_encoder.layer1.0.conv1.weight', 'nn_encoder.layer1.0.bn1.weight', 'nn_encoder.layer1.0.bn1.bias', 'nn_encoder.layer1.0.bn1.running_mean', 'nn_encoder.layer1.0.bn1.running_var', 'nn_encoder.layer1.0.conv2.weight', 'nn_encoder.layer1.0.bn2.weight', 'nn_encoder.layer1.0.bn2.bias', 'nn_encoder.layer1.0.bn2.running_mean', 'nn_encoder.layer1.0.bn2.running_var', 'nn_encoder.layer1.0.conv3.weight', 'nn_encoder.layer1.0.bn3.weight', 'nn_encoder.layer1.0.bn3.bias', 'nn_encoder.layer1.0.bn3.running_mean', 'nn_encoder.layer1.0.bn3.running_var', 'nn_encoder.layer1.0.downsample.0.weight', 'nn_encoder.layer1.0.downsample.1.weight', 'nn_encoder.layer1.0.downsample.1.bias', 'nn_encoder.layer1.0.downsample.1.running_mean', 'nn_encoder.layer1.0.downsample.1.running_var', 'nn_encoder.layer1.1.conv1.weight', 'nn_encoder.layer1.1.bn1.weigh

In [10]:
print(model.named_children())

<generator object Module.named_children at 0x7faa0127ff50>


In [13]:
Bigmodule = []
for name, module in model.named_children():
    Bigmodule.append(name)
print(Bigmodule)

['nn_encoder', 'gcns', 'unpooling', 'projection', 'gconv']


In [42]:
from collections import OrderedDict
model_v1 = OrderedDict(model.named_children())
    # remove avgpool,fc
model_v1.pop("nn_encoder")
model_v1 = torch.nn.Sequential(model_v1)

In [49]:
for name, module in model_v1.named_children():
    print(name)

gcns
unpooling
projection
gconv


In [53]:
print(params.keys())

odict_keys(['gcns.0.blocks.0.conv1.adj_mat', 'gcns.0.blocks.0.conv1.weight', 'gcns.0.blocks.0.conv1.loop_weight', 'gcns.0.blocks.0.conv1.bias', 'gcns.0.blocks.0.conv2.adj_mat', 'gcns.0.blocks.0.conv2.weight', 'gcns.0.blocks.0.conv2.loop_weight', 'gcns.0.blocks.0.conv2.bias', 'gcns.0.blocks.1.conv1.adj_mat', 'gcns.0.blocks.1.conv1.weight', 'gcns.0.blocks.1.conv1.loop_weight', 'gcns.0.blocks.1.conv1.bias', 'gcns.0.blocks.1.conv2.adj_mat', 'gcns.0.blocks.1.conv2.weight', 'gcns.0.blocks.1.conv2.loop_weight', 'gcns.0.blocks.1.conv2.bias', 'gcns.0.blocks.2.conv1.adj_mat', 'gcns.0.blocks.2.conv1.weight', 'gcns.0.blocks.2.conv1.loop_weight', 'gcns.0.blocks.2.conv1.bias', 'gcns.0.blocks.2.conv2.adj_mat', 'gcns.0.blocks.2.conv2.weight', 'gcns.0.blocks.2.conv2.loop_weight', 'gcns.0.blocks.2.conv2.bias', 'gcns.0.blocks.3.conv1.adj_mat', 'gcns.0.blocks.3.conv1.weight', 'gcns.0.blocks.3.conv1.loop_weight', 'gcns.0.blocks.3.conv1.bias', 'gcns.0.blocks.3.conv2.adj_mat', 'gcns.0.blocks.3.conv2.weight',

In [37]:
torch.save(model_v1.state_dict(), 'model_v1.pth')

RuntimeError: sparse tensors do not have storage