The "do-eval" option will rewrite trained model #1

Closed
zjjj opened this issue Jul 27, 2021 · 1 comment

zjjj commented Jul 27, 2021

The second command in run.sh, `main.py --do-eval --quantify --model-type roberta --prefix 0524 --filename final --task ast`, overwrites the trained model checkpoint. Also, the code does not seem to be able to load a saved model.

ghost commented Jul 30, 2021

I see; yes, checkpoints are saved as part of evaluation. You can comment out the save call, add a `--do-save` flag, or change the name of the checkpoint file. I would recommend the third option, which was in the original code before I cleaned things up. For example, something like:

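import os
import torch

# Method on the model class: it needs `self` to call state_dict().
def save_pretrained(self, args, filepath=None):
    if filepath is None:
        # Default name encodes task and seed; pass another filepath to keep a trained checkpoint intact
        filepath = os.path.join(args.output_dir, f'{args.task}_model_{args.seed}.pt')
    torch.save(self.state_dict(), filepath)
    print(f"Model weights saved in {filepath}")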
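If you would rather take the second option, an opt-in `--do-save` flag, a rough sketch follows. Only `--do-eval` and `--do-save` come from this thread; `parse_args` and `maybe_save` are hypothetical helpers for illustration, not code from main.py.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical subset of main.py's CLI; the real script has more options.
    parser = argparse.ArgumentParser()
    parser.add_argument('--do-eval', action='store_true')
    # Opt-in flag: checkpoints are written only when it is passed.
    parser.add_argument('--do-save', action='store_true')
    return parser.parse_args(argv)


def maybe_save(model, args, save_fn):
    """Call save_fn(model, args) only when --do-save was given; report whether it ran."""
    if args.do_save:  # argparse maps '--do-save' to the attribute 'do_save'
        save_fn(model, args)
        return True
    return False


if __name__ == '__main__':
    # An evaluation-only run, like the second command in run.sh, without --do-save.
    args = parse_args(['--do-eval'])
    saved = maybe_save(model=None, args=args, save_fn=lambda m, a: print('saving'))
    print(saved)  # False: the existing checkpoint is left alone
```

In a training run you would pass `--do-save` and supply something like `save_fn=lambda m, a: m.save_pretrained(a)`, so evaluation never writes weights unless asked to.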

To load a checkpoint, any standard PyTorch loading function will work. For example:

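import os
import torch

def load_checkpoint(model, load_dir, checkpoint_name, device='cpu'):
    ckpt_path = os.path.join(load_dir, checkpoint_name)
    # Load onto CPU first so a GPU-saved checkpoint also loads on CPU-only machines
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.eval()  # disable dropout for evaluation
    model.to(device)
    return model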
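To confirm that a saved checkpoint actually restores the trained weights, here is a small save/load round trip. It assumes a toy `nn.Linear` as a stand-in for the RoBERTa model and uses a hypothetical file name, `ast_model_42.pt`, in a temporary directory; `load_checkpoint` repeats the steps of the function above so the snippet runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(model, load_dir, checkpoint_name, device='cpu'):
    # Same steps as the loading function above.
    ckpt_path = os.path.join(load_dir, checkpoint_name)
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model.eval()
    return model.to(device)  # nn.Module.to returns the module itself


torch.manual_seed(0)
trained = nn.Linear(4, 2)  # stand-in for the fine-tuned model

with tempfile.TemporaryDirectory() as output_dir:
    name = 'ast_model_42.pt'  # hypothetical task/seed values
    torch.save(trained.state_dict(), os.path.join(output_dir, name))

    # A fresh model with different random weights, as an evaluation run would build.
    fresh = nn.Linear(4, 2)
    restored = load_checkpoint(fresh, output_dir, name)

    x = torch.randn(3, 4)
    print(torch.equal(restored(x), trained(x)))  # True
```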

Hope that helps!

ghost closed this as completed Jan 4, 2022