In [1]:
import sys
sys.path.insert(0,"../")
import os.path as osp
import numpy
import torch 
import torch.nn as nn
import torchvision
import copy
from model.vlocnet import VLocNet
from torchvision import transforms, models
from torch.autograd import Variable

In [2]:

def load_state_dict(model, state_dict):
    """
    Loads a state dict when the model has some prefix before the parameter names
    :param model:
    :param state_dict:
    :return: loaded model
    """
    model_names = [n for n, _ in model.named_parameters()]
    state_names = [n for n in state_dict.keys()]

    # find prefix for the model and state dicts from the first param name
    if model_names[0].find(state_names[0]) >= 0:
        model_prefix = model_names[0].replace(state_names[0], '')
        state_prefix = None
    elif state_names[0].find(model_names[0]) >= 0:
        state_prefix = state_names[0].replace(model_names[0], '')
        model_prefix = None
    else:
        print('Could not find the correct prefixes between {:s} and {:s}'.
              format(model_names[0], state_names[0]))
        raise KeyError

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if state_prefix is None:
            k = model_prefix + k
        else:
            k = k.replace(state_prefix, '')
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)

In [3]:
def load_resnet50(model, state_dict):
    """
    Loads a state dict when the model has some prefix before the parameter names
    :param model:
    :param state_dict:
    :return: loaded model
    """
    model_names = [n for n, _ in model.named_parameters()]
#     model_names = [n for n, _ in model.state_dict().keys()]
    state_names = [n for n in state_dict.keys()]
    
    from collections import OrderedDict
#     new_state_dict = OrderedDict()
    new_state_dict= model.state_dict()
    # find prefix for the model and state dicts from the first param name
    counter=0
    marker=copy.deepcopy(state_dict)
    for mn in model_names:    
        for sn in state_names:
            if('resModule' in mn and mn.split('resModule.')[1] == sn):
                
                if(new_state_dict[mn].size()==state_dict[sn].size()):
                    if('layer1.0.bn3.bias' in mn):
                        print('layer1.0.bn3.bias','!!!')
                    counter += 1
                    new_state_dict[mn] = state_dict[sn]
                    marker.pop(sn, default=0)
                else:
                    print(mn)
            elif('head' in mn and  '.'.join(mn.split('.')[1:]) == sn):
                if(new_state_dict[mn].size()==state_dict[sn].size()):
                    counter += 1
                    new_state_dict[mn] = state_dict[sn]
                    marker.pop(sn, default=0)
                else:
                    print(mn)
    
    print('dropped parameters weights: ', marker.keys())
    print('# assigned parameters: ',counter)
    print(len(new_state_dict))

    model.load_state_dict(new_state_dict)
    return model

In [4]:
mn='odom_en1.0.resModule.layer1.0.downsample.1.weight'#.split('resModule.')
'.'.join(mn.split('resModule.')[1])

'l.a.y.e.r.1...0...d.o.w.n.s.a.m.p.l.e...1...w.e.i.g.h.t'

In [5]:
# model.state_dict().keys()

In [6]:
# resnet50 = models.resnet50()

In [7]:
# resnet50.state_dict().keys()

# Load test

In [8]:
model_path='/home/sensetime/DATA/PretrainedModels/ResNet_pytorch/check/resnet50-19c8e357.pth'

In [9]:
checkpoint = torch.load(
                    model_path)

In [10]:
model = VLocNet(share_levels_n=3, recur_pose='')
# model = VLocNet(share_levels_n=3, recur_pose=False)
# params = [p for p in model.named_parameters()]
params = [p for p in model.state_dict()]
print(params)

['odom_en1_head.conv1.weight', 'odom_en1_head.bn1.weight', 'odom_en1_head.bn1.bias', 'odom_en1_head.bn1.running_mean', 'odom_en1_head.bn1.running_var', 'odom_en1_head.bn1.num_batches_tracked', 'odom_en2_head.conv1.weight', 'odom_en2_head.bn1.weight', 'odom_en2_head.bn1.bias', 'odom_en2_head.bn1.running_mean', 'odom_en2_head.bn1.running_var', 'odom_en2_head.bn1.num_batches_tracked', 'global_en_head.conv1.weight', 'global_en_head.bn1.weight', 'global_en_head.bn1.bias', 'global_en_head.bn1.running_mean', 'global_en_head.bn1.running_var', 'global_en_head.bn1.num_batches_tracked', 'odom_en1.0.resModule.layer1.0.conv1.weight', 'odom_en1.0.resModule.layer1.0.bn1.weight', 'odom_en1.0.resModule.layer1.0.bn1.bias', 'odom_en1.0.resModule.layer1.0.bn1.running_mean', 'odom_en1.0.resModule.layer1.0.bn1.running_var', 'odom_en1.0.resModule.layer1.0.bn1.num_batches_tracked', 'odom_en1.0.resModule.layer1.0.conv2.weight', 'odom_en1.0.resModule.layer1.0.bn2.weight', 'odom_en1.0.resModule.layer1.0.bn2.bias

In [11]:
model=load_resnet50(model, checkpoint)
torch.save(model.state_dict(), osp.join(osp.split(model_path)[0], 'vlocnet.pth' ))
params=[p for p in model.named_parameters()]
print(params[0])

layer1.0.bn3.bias !!!
layer1.0.bn3.bias !!!
odom_final_res.resModule.layer4.0.conv1.weight
odom_final_res.resModule.layer4.0.downsample.0.weight
dropped parameters weights:  odict_keys(['bn1.running_mean', 'bn1.running_var', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn3.running_mean', 'layer1.1.bn3.running_var', 'layer1.2.bn1.running_mean', 'layer1.2.bn1.running_var', 'layer1.2.bn2.running_mean', 'layer1.2.bn2.running_var', 'layer1.2.bn3.running_mean', 'layer1.2.bn3.running_var', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn2.running_mean', 'layer2.0.bn2.running_var', 'layer2.0.bn3.running_mean', 'layer2.0.bn3.running_var', 'layer2.0.downsa

In [12]:
print(model.state_dict()['odom_en1_head.conv1.weight'])

tensor([[[[ 0.0133,  0.0147, -0.0154,  ..., -0.0409, -0.0430, -0.0708],
          [ 0.0041,  0.0058,  0.0149,  ...,  0.0022, -0.0209, -0.0385],
          [ 0.0223,  0.0236,  0.0161,  ...,  0.1028,  0.0626,  0.0520],
          ...,
          [-0.0009,  0.0278, -0.0101,  ..., -0.1272, -0.0766,  0.0078],
          [ 0.0036,  0.0480,  0.0621,  ...,  0.0243, -0.0337, -0.0157],
          [-0.0800, -0.0322, -0.0178,  ...,  0.0354,  0.0224,  0.0017]],

         [[-0.0185,  0.0114,  0.0239,  ...,  0.0537,  0.0440, -0.0095],
          [-0.0077,  0.0189,  0.0680,  ...,  0.1596,  0.1461,  0.1200],
          [-0.0460, -0.0761, -0.0896,  ...,  0.1211,  0.1670,  0.1762],
          ...,
          [ 0.0288,  0.0137, -0.0838,  ..., -0.3808, -0.3041, -0.1397],
          [ 0.0829,  0.1386,  0.1524,  ..., -0.0051, -0.1244, -0.1297],
          [-0.0073,  0.0770,  0.1400,  ...,  0.1843,  0.1114,  0.0234]],

         [[-0.0183, -0.0056,  0.0087,  ...,  0.0258,  0.0264, -0.0040],
          [-0.0101,  0.0042,  

In [13]:
check = torch.load(osp.join(osp.split(model_path)[0], 'vlocnet_init_with_preres50.pth' ))
# check.keys()
print(check['Odometry_en1.conv1.weight'])

FileNotFoundError: [Errno 2] No such file or directory: '/home/sensetime/DATA/PretrainedModels/ResNet_pytorch/check/vlocnet_init_with_preres50.pth'

In [14]:
states = [s for s in model.state_dict()]

In [115]:
print(len(states))

1086


In [43]:
x= torch.ones(1, requires_grad=True)
x_ = x
y=3*x

# DEBUG

In [65]:
input_size = 224
dummy_input = [
    Variable(torch.rand(1, 2, 3, input_size, input_size)),
    Variable(torch.rand(1, 2, 7))
]
# self.summary_writer.add_graph(self.model, dummy_input)

In [66]:
model.train()
output=model(dummy_input)

In [67]:
output[1].backward(torch.ones_like(output[1]))

In [68]:
print(model.state_dict()['global_en_head.conv1.weight'].grad)

None


In [91]:
for n, v in model.named_parameters():
#     if v.requires_grad:
    if n=='odom_en2_head.conv1.weight':
        print (v)

Parameter containing:
tensor([[[[ 6.0005e-02,  3.7458e-02,  5.9268e-02,  ...,  6.7889e-02,
            4.6262e-02,  1.2856e-02],
          [-1.8131e-02, -3.2489e-02,  1.4438e-02,  ..., -4.5772e-02,
           -4.1270e-02,  3.8354e-02],
          [ 5.4659e-02,  7.8941e-02,  6.3661e-02,  ...,  6.9199e-02,
           -7.4049e-02,  8.2037e-02],
          ...,
          [ 3.3698e-02,  1.5175e-02, -7.4326e-02,  ...,  5.0901e-02,
            6.2583e-02, -3.4065e-02],
          [ 3.9462e-02,  7.0256e-02, -5.8524e-02,  ..., -7.5620e-02,
            6.4321e-02,  3.3046e-02],
          [-3.5835e-02,  2.1033e-02, -5.8699e-02,  ...,  7.7129e-02,
            3.1169e-02,  7.3544e-02]],

         [[-2.8048e-03,  1.5761e-02, -7.1931e-02,  ..., -6.3730e-02,
            5.2744e-02,  1.9380e-02],
          [-7.2924e-02, -1.6188e-03, -3.7955e-02,  ..., -2.5478e-02,
           -2.9791e-03,  6.8586e-03],
          [-2.5446e-02, -6.7102e-02,  7.9395e-02,  ..., -5.5829e-02,
           -7.9303e-02, -6.8617e-02]

In [87]:
model.named_parameters()

<generator object Module.named_parameters at 0x7f0d82e340a0>

In [90]:
model.state_dict()['odom_en2_head.conv1.weight']

tensor([[[[ 6.0005e-02,  3.7458e-02,  5.9268e-02,  ...,  6.7889e-02,
            4.6262e-02,  1.2856e-02],
          [-1.8131e-02, -3.2489e-02,  1.4438e-02,  ..., -4.5772e-02,
           -4.1270e-02,  3.8354e-02],
          [ 5.4659e-02,  7.8941e-02,  6.3661e-02,  ...,  6.9199e-02,
           -7.4049e-02,  8.2037e-02],
          ...,
          [ 3.3698e-02,  1.5175e-02, -7.4326e-02,  ...,  5.0901e-02,
            6.2583e-02, -3.4065e-02],
          [ 3.9462e-02,  7.0256e-02, -5.8524e-02,  ..., -7.5620e-02,
            6.4321e-02,  3.3046e-02],
          [-3.5835e-02,  2.1033e-02, -5.8699e-02,  ...,  7.7129e-02,
            3.1169e-02,  7.3544e-02]],

         [[-2.8048e-03,  1.5761e-02, -7.1931e-02,  ..., -6.3730e-02,
            5.2744e-02,  1.9380e-02],
          [-7.2924e-02, -1.6188e-03, -3.7955e-02,  ..., -2.5478e-02,
           -2.9791e-03,  6.8586e-03],
          [-2.5446e-02, -6.7102e-02,  7.9395e-02,  ..., -5.5829e-02,
           -7.9303e-02, -6.8617e-02],
          ...,
     

In [13]:
print(type(model.state_dict()))

<class 'collections.OrderedDict'>
