In [1]:
import os
import logging
import sys
import argparse
import re
import queue
import threading
from math import ceil
from datetime import datetime
from tqdm import tqdm
import cv2
from PIL import Image
import PIL
from torch.backends import cudnn
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms
import importlib

import torch.nn.functional as F
import numpy as np
import transforms.transforms as extended_transforms

from collections import OrderedDict
from datasets import MSD, Trans10k, GDD
from optimizer import restore_snapshot

import transforms.joint_transforms as joint_transforms

from utils.my_data_parallel import MyDataParallel
from utils.misc import fast_hist, save_log, \
    evaluate_eval_for_inference, cal_mae, cal_ber, evaluate_eval_for_inference_with_mae_ber

from network.EBLNet import EBLNet_resnet50_os8
from config import assert_cfg_vid


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up configurations with a custom function instead of the validated assert_and_infer_cfg
assert_cfg_vid()

In [3]:
# Initialise a model
model = EBLNet_resnet50_os8(num_classes=3,
                            criterion=None,
                            num_cascade=1,
                            num_points=96,
                            threshold=0.85)



In [8]:
type(model)

network.EBLNet.EBLNet

In [4]:
checkpoint_temp = torch.load("checkpoints/Trans10k_resnet50_os8.pth")
print(checkpoint_temp.keys())
state_dict_temp = checkpoint_temp['state_dict']

dict_keys(['state_dict', 'optimizer', 'epoch', 'mean_iu', 'command'])


In [5]:
state_dict_temp.keys()

odict_keys(['module.layer0.0.0.weight', 'module.layer0.0.1.weight', 'module.layer0.0.1.bias', 'module.layer0.0.1.running_mean', 'module.layer0.0.1.running_var', 'module.layer0.0.1.num_batches_tracked', 'module.layer0.0.3.weight', 'module.layer0.0.4.weight', 'module.layer0.0.4.bias', 'module.layer0.0.4.running_mean', 'module.layer0.0.4.running_var', 'module.layer0.0.4.num_batches_tracked', 'module.layer0.0.6.weight', 'module.layer0.1.weight', 'module.layer0.1.bias', 'module.layer0.1.running_mean', 'module.layer0.1.running_var', 'module.layer0.1.num_batches_tracked', 'module.layer1.0.conv1.weight', 'module.layer1.0.bn1.weight', 'module.layer1.0.bn1.bias', 'module.layer1.0.bn1.running_mean', 'module.layer1.0.bn1.running_var', 'module.layer1.0.bn1.num_batches_tracked', 'module.layer1.0.conv2.weight', 'module.layer1.0.bn2.weight', 'module.layer1.0.bn2.bias', 'module.layer1.0.bn2.running_mean', 'module.layer1.0.bn2.running_var', 'module.layer1.0.bn2.num_batches_tracked', 'module.layer1.0.con

In [6]:
state_dict_temp.keys()
new_state_dict = OrderedDict()
for k, v in state_dict_temp.items():
    name = k[7:]
    new_state_dict[name] = v

In [7]:
model.load_state_dict(new_state_dict)

RuntimeError: Error(s) in loading state_dict for EBLNet:
	Missing key(s) in state_dict: "final_seg_out_pre.0.3.weight", "final_seg_out_pre.0.4.weight", "final_seg_out_pre.0.4.bias", "final_seg_out_pre.0.4.running_mean", "final_seg_out_pre.0.4.running_var". 
	Unexpected key(s) in state_dict: "body_fines.1.weight", "body_fines.2.weight", "body_fuse.1.weight", "body_fuse.2.weight", "edge_extractors.1.pre_extractor.0.weight", "edge_extractors.1.pre_extractor.1.weight", "edge_extractors.1.pre_extractor.1.bias", "edge_extractors.1.pre_extractor.1.running_mean", "edge_extractors.1.pre_extractor.1.running_var", "edge_extractors.1.pre_extractor.1.num_batches_tracked", "edge_extractors.1.extractor.0.weight", "edge_extractors.1.extractor.1.weight", "edge_extractors.1.extractor.1.bias", "edge_extractors.1.extractor.1.running_mean", "edge_extractors.1.extractor.1.running_var", "edge_extractors.1.extractor.1.num_batches_tracked", "edge_extractors.2.pre_extractor.0.weight", "edge_extractors.2.pre_extractor.1.weight", "edge_extractors.2.pre_extractor.1.bias", "edge_extractors.2.pre_extractor.1.running_mean", "edge_extractors.2.pre_extractor.1.running_var", "edge_extractors.2.pre_extractor.1.num_batches_tracked", "edge_extractors.2.extractor.0.weight", "edge_extractors.2.extractor.1.weight", "edge_extractors.2.extractor.1.bias", "edge_extractors.2.extractor.1.running_mean", "edge_extractors.2.extractor.1.running_var", "edge_extractors.2.extractor.1.num_batches_tracked", "refines.1.gcn.conv_adj.weight", "refines.1.gcn.bn_adj.weight", "refines.1.gcn.bn_adj.bias", "refines.1.gcn.bn_adj.running_mean", "refines.1.gcn.bn_adj.running_var", "refines.1.gcn.bn_adj.num_batches_tracked", "refines.1.gcn.conv_wg.weight", "refines.1.gcn.bn_wg.weight", "refines.1.gcn.bn_wg.bias", "refines.1.gcn.bn_wg.running_mean", "refines.1.gcn.bn_wg.running_var", "refines.1.gcn.bn_wg.num_batches_tracked", "refines.2.gcn.conv_adj.weight", "refines.2.gcn.bn_adj.weight", "refines.2.gcn.bn_adj.bias", "refines.2.gcn.bn_adj.running_mean", "refines.2.gcn.bn_adj.running_var", "refines.2.gcn.bn_adj.num_batches_tracked", "refines.2.gcn.conv_wg.weight", "refines.2.gcn.bn_wg.weight", "refines.2.gcn.bn_wg.bias", "refines.2.gcn.bn_wg.running_mean", "refines.2.gcn.bn_wg.running_var", "refines.2.gcn.bn_wg.num_batches_tracked", "edge_out_pre.1.0.weight", "edge_out_pre.1.1.weight", "edge_out_pre.1.1.bias", "edge_out_pre.1.1.running_mean", "edge_out_pre.1.1.running_var", "edge_out_pre.1.1.num_batches_tracked", "edge_out_pre.2.0.weight", "edge_out_pre.2.1.weight", "edge_out_pre.2.1.bias", "edge_out_pre.2.1.running_mean", "edge_out_pre.2.1.running_var", "edge_out_pre.2.1.num_batches_tracked", "edge_out.1.weight", "edge_out.2.weight", "body_out_pre.1.0.weight", "body_out_pre.1.1.weight", "body_out_pre.1.1.bias", "body_out_pre.1.1.running_mean", "body_out_pre.1.1.running_var", "body_out_pre.1.1.num_batches_tracked", "body_out_pre.2.0.weight", "body_out_pre.2.1.weight", "body_out_pre.2.1.bias", "body_out_pre.2.1.running_mean", "body_out_pre.2.1.running_var", "body_out_pre.2.1.num_batches_tracked", "body_out.1.weight", "body_out.2.weight", "final_seg_out_pre.1.0.weight", "final_seg_out_pre.1.1.weight", "final_seg_out_pre.1.1.bias", "final_seg_out_pre.1.1.running_mean", "final_seg_out_pre.1.1.running_var", "final_seg_out_pre.1.1.num_batches_tracked", "final_seg_out_pre.2.0.weight", "final_seg_out_pre.2.1.weight", "final_seg_out_pre.2.1.bias", "final_seg_out_pre.2.1.running_mean", "final_seg_out_pre.2.1.running_var", "final_seg_out_pre.2.1.num_batches_tracked", "final_seg_out_pre.2.3.weight", "final_seg_out_pre.2.4.weight", "final_seg_out_pre.2.4.bias", "final_seg_out_pre.2.4.running_mean", "final_seg_out_pre.2.4.running_var", "final_seg_out_pre.2.4.num_batches_tracked", "final_seg_out.1.weight", "final_seg_out.2.weight". 
	size mismatch for final_seg_out_pre.0.0.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).

In [None]:
def get_model(network, num_classes, criterion, args):
    """
    Fetch Network Function Pointer
    """
    module = network[:network.rfind('.')]
    model = network[network.rfind('.') + 1:]
    mod = importlib.import_module(module)
    net_func = getattr(mod, model)
    if model == 'EBLNet_resnet50_os8' or model == 'EBLNet_resnet50_os16' or \
            model == 'EBLNet_resnet101_os8' or model == 'EBLNet_resnext101_os8':
        net = net_func(num_classes=num_classes, criterion=criterion,
                       num_cascade=args.num_cascade, num_points=args.num_points, threshold=args.thres_gcn)
    else:
        net = net_func(num_classes=num_classes, criterion=criterion)
    return net

def get_net(args, criterion):
    """
    Get Network Architecture based on arguments provided
    """
    net = get_model(network=args.arch, num_classes=args.dataset_cls.num_classes,
                    criterion=criterion, args=args)

    net = net.cuda()
    return net

