In [1]:
# Change the working directory three levels up (the resulting path is printed below)
# so that `from run import ...` and './suadd_config.yaml' resolve relative to it.
# NOTE(review): this is a relative jump and not idempotent — re-running this cell
# without restarting the kernel moves further up the tree.
%cd ../../..

/home/user/Projects/denk_baseline


In [2]:
import os
import glob
from pprint import pprint

import torch
import pandas as pd
from omegaconf import OmegaConf

from denk_baseline.datamodules import DataModule
from denk_baseline.lightning_models import SegmentationMulticlassModel
from run import preprocess_config, parse_loggers, get_obj_from_str

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the experiment YAML and pass it through the project's preprocessing step.
config = preprocess_config(OmegaConf.load('./suadd_config.yaml'))

# Clear the trainer's `gpus` setting — presumably so this notebook runs on CPU
# (checkpoints are also loaded with map_location='cpu' further down).
trainer_params = config['trainer']['params']
trainer_params['gpus'] = None

In [4]:
# Build the project DataModule from the preprocessed config.
# NOTE(review): `datamodule` is not used by the checkpoint-export cell that follows;
# confirm a later cell needs it before keeping this (construction may be costly).
datamodule = DataModule(config)

In [5]:
# Export every Lightning checkpoint of this experiment as a plain weights file:
# the LightningModule's state_dict with the leading 'model.' prefix removed from
# each key, saved as {'state_dict': ...} under <save_dir>/tested_models/<exp_name>/
# using the checkpoint's original filename.
exp_name = config['common']['exp_name']
project_name = config['common']['project_name']
save_dir = config['common']['save_dir']

# os.path.join (rather than f'./{save_dir}/...') keeps an absolute save_dir absolute.
exp_dir = os.path.join(save_dir, project_name, exp_name)
# glob.escape so names containing '[', '*' or '?' are matched literally.
ckpt_paths = sorted(glob.glob(os.path.join(glob.escape(exp_dir), '*.ckpt')))

test_models_folder = 'tested_models'
out_folder = os.path.join(save_dir, test_models_folder, exp_name)

if not ckpt_paths:
    # Most likely a wrong save_dir/project_name/exp_name — make it visible.
    print(f'No .ckpt files found in {exp_dir}; nothing exported.')
else:
    os.makedirs(out_folder, exist_ok=True)

PREFIX = 'model.'

# Built once: a strict load_state_dict overwrites every parameter and buffer,
# so reusing one instance gives the same weights as rebuilding per checkpoint.
model = SegmentationMulticlassModel(config)
for ckpt_path in ckpt_paths:
    model_name = os.path.basename(ckpt_path)
    # NOTE(review): torch.load unpickles the file — only use on trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # Strict (default) load: fails loudly if the checkpoint doesn't match the config.
    model.load_state_dict(checkpoint['state_dict'])

    # Strip only the *leading* prefix; str.replace would also remove 'model.'
    # occurring later in a key (e.g. 'model.backbone.submodel.weight').
    state_dict = {
        (name[len(PREFIX):] if name.startswith(PREFIX) else name): tensor
        for name, tensor in model.state_dict().items()
    }

    out_path = os.path.join(out_folder, model_name)
    torch.save({'state_dict': state_dict}, out_path)

_IncompatibleKeys(missing_keys=['patch_embed1.proj.weight'], unexpected_keys=['decode_head.conv_seg.weight', 'decode_head.linear_pred.weight', 'decode_head.conv_seg.bias', 'decode_head.linear_pred.bias', 'decode_head.linear_c4.proj.weight', 'decode_head.linear_c4.proj.bias', 'decode_head.linear_c3.proj.weight', 'decode_head.linear_c3.proj.bias', 'decode_head.linear_c2.proj.weight', 'decode_head.linear_c2.proj.bias', 'decode_head.linear_c1.proj.weight', 'decode_head.linear_c1.proj.bias', 'decode_head.linear_fuse.conv.weight', 'decode_head.linear_fuse.bn.weight', 'decode_head.linear_fuse.bn.bias', 'decode_head.linear_fuse.bn.running_mean', 'decode_head.linear_fuse.bn.running_var', 'decode_head.linear_fuse.bn.num_batches_tracked'])
_IncompatibleKeys(missing_keys=['patch_embed1.proj.weight'], unexpected_keys=['decode_head.conv_seg.weight', 'decode_head.linear_pred.weight', 'decode_head.conv_seg.bias', 'decode_head.linear_pred.bias', 'decode_head.linear_c4.proj.weight', 'decode_head.linear_