
Running main_test.py alone resulted in an error #17

@shushushulian


I configured the code below with the same settings as the `test_single.sh` file, but loading the model fails with an error saying that the parameter sizes in the checkpoint do not match those of the current model.
```python
import os

# Set before any CUDA initialization so that only GPU 0 is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import argparse
import random

import numpy as np
import torch

from gln.common.cmd_args import cmd_args
from gln.test.model_inference import RetroGLN

# Seed every RNG from cmd_args.seed.
random.seed(cmd_args.seed)
np.random.seed(cmd_args.seed)
torch.manual_seed(cmd_args.seed)

# Same settings as in test_single.sh.
local_args = argparse.Namespace(
    model_for_test="dropbox/schneider50k.ckpt",  # checkpoint folder containing model.dump
)

cmd_args.dropbox = "GLN-master/dropbox"
cmd_args.data_name = "schneider50k"
model = RetroGLN(cmd_args.dropbox, local_args.model_for_test)
```
**Error content**

```
~/Code/Retro/mutilstep/GLN-master/gln/test/model_inference.py in __init__(self, dropbox, model_dump)
     37         model_file = os.path.join(model_dump, 'model.dump')
     38         self.gln = GraphPath(self.args)
---> 39         self.gln.load_state_dict(torch.load(model_file))
     40         self.gln.cuda()
     41         self.gln.eval()

~/.conda/envs/GLN/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1666         if len(error_msgs) > 0:
   1667             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1668                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1669         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1670 

RuntimeError: Error(s) in loading state_dict for GraphPath:
	size mismatch for tpl_fwd_predicate.prod_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
	size mismatch for tpl_fwd_predicate.tpl_enc.prod_gnn.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
	size mismatch for tpl_fwd_predicate.tpl_enc.react_gnn.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
	size mismatch for prod_center_predicate.prod_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
	size mismatch for prod_center_predicate.prod_center_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
	size mismatch for reaction_predicate.prod_enc.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
	size mismatch for reaction_predicate.react_enc.react_gnn.w_n2l.weight: copying a param with shape torch.Size([128, 39]) from checkpoint, the shape in current model is torch.Size([128, 66]).
```
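All seven mismatches in the traceback above are `w_n2l.weight` tensors that differ only in their second dimension (39 in the checkpoint vs 66 in the freshly built model). Assuming these are `torch.nn.Linear` weights, PyTorch stores them as `(out_features, in_features)`, which would mean the saved model was built for 39-dimensional node features while the model constructed here expects 66. The sketch below is a hypothetical helper (not part of GLN) for listing such mismatches. The shape dicts are hard-coded from the traceback so it runs without torch; in practice they could be built with `{k: tuple(v.shape) for k, v in torch.load(model_file).items()}` and the same expression over `self.gln.state_dict()`.

```python
# Hypothetical diagnostic helper; not part of GLN or PyTorch.


def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Return {name: (ckpt_shape, model_shape)} for keys present in both dicts with different shapes."""
    return {
        name: (ckpt_shapes[name], model_shapes[name])
        for name in sorted(ckpt_shapes.keys() & model_shapes.keys())
        if ckpt_shapes[name] != model_shapes[name]
    }


def linear_in_features(shape):
    """For a weight of shape (out_features, in_features), as nn.Linear uses, return in_features."""
    _out_features, in_features = shape
    return in_features


# Two of the shapes reported in the traceback (checkpoint vs current model).
ckpt = {
    "tpl_fwd_predicate.prod_enc.w_n2l.weight": (128, 39),
    "reaction_predicate.prod_enc.w_n2l.weight": (128, 39),
}
current = {
    "tpl_fwd_predicate.prod_enc.w_n2l.weight": (128, 66),
    "reaction_predicate.prod_enc.w_n2l.weight": (128, 66),
}

for name, (old, new) in find_shape_mismatches(ckpt, current).items():
    print(f"{name}: checkpoint expects {linear_in_features(old)} input features, "
          f"current model expects {linear_in_features(new)}")
```

If every mismatch is in the input dimension of these layers, the difference most likely comes from how the node features are computed at load time rather than from the checkpoint file being corrupted, though this sketch alone cannot confirm that.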
