-
Notifications
You must be signed in to change notification settings - Fork 0
Move model to device #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
scratchgpt/main.py
Outdated
| latest_model_path = get_latest_model_weights_path(args.experiment) | ||
|
|
||
| model = TransformerLanguageModel(NUM_HEADS, tokenizer.vocab_size, N_EMBED, BLOCK_SIZE, NUM_BLOCKS) | ||
| model = model.to(DEVICE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should fix it in load_model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unless you want to change the signature, we can't.
We need to assign model = model.to(DEVICE)
So.. we either just add that line, or we change
def load_model(model_path: str, model: nn.Module, device: torch.device) -> Noneto
def load_model(model_path: str, model: nn.Module, device: torch.device) -> nn.Module
...
return model
But it's messy because we are doing:
else:
print("No model path exists, proceeding with a new model")In my experience, I always have a line like
model = model.to(DEVICE)
after creation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha - then I think we can remove model.to from load_model and add this line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But also - we can do return model.to(DEVICE) from load_model. Since we always provide a valid model pointer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.. if you disagree, I don't care
cf9bdf6 to
fd1cde2
Compare
fd1cde2 to
cad5896
Compare
ayeganov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ship it!
What does this PR do?
It moves the model to the same device as all the tensors