Problem with loading checkpoint of a model with embeddings #2359

@narain1

🐛 Bug

Unable to load a checkpoint for a model that contains embeddings.

Code sample

Model architecture:

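import os

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# get_base() (the image backbone) and Mish are helpers defined elsewhere (not shown).

class Model(pl.LightningModule):
    def __init__(self, emb_szs):
        super().__init__()
        m = get_base()
        # Image encoder: every backbone layer except the last one, then flatten.
        self.enc = nn.Sequential(*list(m.children())[:-1], nn.Flatten())
        nc = list(m.children())[-1].in_features
        self.head = nn.Sequential(nn.Linear(2*nc+25, 512), Mish(),
                                  nn.BatchNorm1d(512), nn.Dropout(0.5), nn.Linear(512, 2))
        # One nn.Embedding(num_embeddings, embedding_dim) per (c, s) pair in emb_szs.
        self.embs = nn.ModuleList([nn.Embedding(c, s) for c, s in emb_szs])

    def forward(self, xb, x_cat, x_cont):
        # Look up each categorical column (codes shifted down by 1) and concatenate.
        x1 = [e(x_cat[:, i] - 1) for i, e in enumerate(self.embs)]
        x1 = torch.cat(x1, 1)
        x_img = self.enc(xb)
        x = torch.cat([x1, x_cont.unsqueeze(1)], 1)
        x = torch.cat([x, x_img], 1)
        return self.head(x)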
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(os.getcwd(), 'model_dir'),
    # save_top_k=True,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
)

# early_stopping is defined elsewhere (not shown).
trainer = Trainer(max_epochs=15,
                  early_stop_callback=early_stopping,
                  gpus=1,
                  gradient_clip_val=1.0,
                  weights_save_path=os.getcwd(),
                  checkpoint_callback=checkpoint_callback,
                  num_sanity_val_steps=0
                  )
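
A stripped-down sketch of the embedding half of Model may help when reading the error. It drops the image encoder because get_base() and Mish are not shown; `EmbOnly` and the four embedding sizes are hypothetical. It illustrates that nn.ModuleList registers each embedding in the state dict as "embs.<i>.weight" (the same key pattern named in the error below) and how forward() concatenates the embedding outputs with the continuous feature.

```python
import torch
import torch.nn as nn


class EmbOnly(nn.Module):
    """Hypothetical stand-in for Model: only self.embs and the continuous
    feature are kept; the image encoder and head are omitted."""

    def __init__(self, emb_szs):
        super().__init__()
        # ModuleList registers each embedding, so its weight appears in
        # state_dict() under the key "embs.<i>.weight".
        self.embs = nn.ModuleList([nn.Embedding(c, s) for c, s in emb_szs])

    def forward(self, x_cat, x_cont):
        # Same lookup as the original forward(): codes are shifted down by 1.
        x1 = torch.cat([e(x_cat[:, i] - 1) for i, e in enumerate(self.embs)], 1)
        return torch.cat([x1, x_cont.unsqueeze(1)], 1)


emb_szs = [(10, 4), (6, 3), (8, 5), (4, 2)]  # hypothetical (cardinality, dim) pairs
m = EmbOnly(emb_szs)
print(list(m.state_dict().keys()))
# ['embs.0.weight', 'embs.1.weight', 'embs.2.weight', 'embs.3.weight']

x_cat = torch.tensor([[1, 1, 1, 1], [10, 6, 8, 4]])  # 1-based codes
x_cont = torch.tensor([0.5, -1.0])
out = m(x_cat, x_cont)
print(out.shape)  # torch.Size([2, 15])
```

With these sizes the tabular part contributes 4 + 3 + 5 + 2 + 1 = 15 features; in the real model, this width plus the image features must match the input size of the first nn.Linear in self.head.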

Training runs without problems, but when I call trainer.test() a RuntimeError is raised:

RuntimeError: Error(s) in loading state_dict for Model:
Unexpected key(s) in state_dict: "embs.0.weight", "embs.1.weight", "embs.2.weight", "embs.3.weight".
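
This message is what PyTorch's load_state_dict raises in its default strict mode when the saved dict contains parameters that the target module does not have. The minimal plain-PyTorch sketch below (no Lightning; the `Embs` module and the sizes are hypothetical) reproduces the same message by loading four saved embedding weights into a module rebuilt with an empty emb_szs. It only shows one way the message can arise; it does not establish what Lightning 0.8.1 does internally when trainer.test() reloads the model.

```python
import torch
import torch.nn as nn


class Embs(nn.Module):
    # Hypothetical minimal module holding only the embedding list from Model.
    def __init__(self, emb_szs):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(c, s) for c, s in emb_szs])


# Weights saved from a module with four embeddings (hypothetical sizes).
saved = Embs([(10, 4), (6, 3), (8, 5), (4, 2)]).state_dict()

# A module rebuilt with an empty emb_szs has no embedding parameters,
# so every "embs.<i>.weight" key in the saved dict is unexpected.
target = Embs([])
try:
    target.load_state_dict(saved)  # strict=True is the default
    error = None
except RuntimeError as exc:
    error = str(exc)
print(error)

# For diagnosis only: strict=False skips the check and reports the mismatch.
result = target.load_state_dict(saved, strict=False)
print(result.unexpected_keys)
```

Comparing unexpected_keys (and missing_keys) with the rebuilt model's own state_dict().keys() is a quick way to see which constructor arguments differ between the saved and the rebuilt model.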

Expected behavior

As described in the documentation, trainer.test() should use the best checkpoint for testing, but loading the checkpoint fails.
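
One possible workaround, sketched here in plain PyTorch under stated assumptions, is to rebuild the model with the same emb_szs used for training and load the weights manually. The sketch assumes that, as with Lightning checkpoints, the weights are stored under the "state_dict" key; `Embs`, the sizes, and the temporary file path are hypothetical stand-ins for the real Model and the best checkpoint file.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Embs(nn.Module):
    # Hypothetical stand-in for Model (embedding part only).
    def __init__(self, emb_szs):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(c, s) for c, s in emb_szs])


emb_szs = [(10, 4), (6, 3), (8, 5), (4, 2)]  # hypothetical sizes

# Simulate a checkpoint file whose weights live under "state_dict".
trained = Embs(emb_szs)
ckpt_path = os.path.join(tempfile.mkdtemp(), "best.ckpt")
torch.save({"state_dict": trained.state_dict()}, ckpt_path)

# Rebuild with the SAME constructor arguments, then load the weights.
ckpt = torch.load(ckpt_path, map_location="cpu")
restored = Embs(emb_szs)
restored.load_state_dict(ckpt["state_dict"])  # strict: keys must match exactly

# Check that every tensor was restored unchanged.
same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

With the real Model, the restored instance could then be passed to the trainer explicitly, e.g. trainer.test(model), so that testing does not rely on the trainer reconstructing the model itself.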

Environment

  • CUDA:
    • GPU:
      • Tesla P100-PCIE-16GB
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.1
    • pyTorch_debug: False
    • pyTorch_version: 1.5.1
    • pytorch-lightning: 0.8.1
    • tensorboard: 2.2.2
    • tqdm: 4.45.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.6
  • version: #1 SMP Sat Jun 13 11:04:33 PDT 2020

Labels: bug (Something isn't working), help wanted (Open to be worked on)
