Conversation

@dariocazzani
Contributor

What does this PR do?

It moves the model to the same device as all the tensors.

@dariocazzani requested a review from ayeganov on August 22, 2025 21:06
latest_model_path = get_latest_model_weights_path(args.experiment)

model = TransformerLanguageModel(NUM_HEADS, tokenizer.vocab_size, N_EMBED, BLOCK_SIZE, NUM_BLOCKS)
model = model.to(DEVICE)
Contributor

Maybe we should fix it in load_model?

Contributor Author

Unless you want to change the signature, we can't.
We need to assign model = model.to(DEVICE).

So we either just add that line, or we change

def load_model(model_path: str, model: nn.Module, device: torch.device) -> None

to

def load_model(model_path: str, model: nn.Module, device: torch.device) -> nn.Module:
    ...
    return model

But it's messy because we are doing:

    else:
        print("No model path exists, proceeding with a new model")

In my experience, I always have a line like

model = model.to(DEVICE) 

after creation
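
A minimal sketch of the pattern discussed above, assuming standard PyTorch. It illustrates that `nn.Module.to` moves parameters in place and returns the same module, so the reassignment is a safe habit, while `Tensor.to` returns a new tensor whenever a conversion actually happens, so for tensors the reassignment is required. The `nn.Linear` layer is only a stand-in for the real model, and the `DEVICE` definition here is an assumption about how the constant is set.

```python
import torch
import torch.nn as nn

# Assumption: DEVICE is chosen roughly like this in the project.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for TransformerLanguageModel(...).
model = nn.Linear(4, 2)

# nn.Module.to moves parameters in place and returns the module itself.
moved = model.to(DEVICE)
same_object = moved is model  # True

# Tensor.to returns a new tensor when a conversion is needed,
# so the result must be kept.
x = torch.zeros(1, 4)
x64 = x.to(torch.float64)
tensor_same_object = x64 is x  # False: a converted copy was created

# Inputs must live on the same device as the model's parameters.
y = model(x.to(DEVICE))
```

Writing `model = model.to(DEVICE)` right after construction keeps model code visually consistent with tensor code, which is likely why the idiom is so common.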

Contributor

Gotcha - then I think we can remove model.to from load_model and add this line.

Contributor

But also - we can do return model.to(DEVICE) from load_model, since we always pass in a valid model object.
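
A hedged sketch of what that suggestion could look like. The real body of `load_model` is not shown in this conversation, so everything inside the function is an assumption: that the checkpoint file holds a plain `state_dict`, that a missing path falls through to the "new model" branch, and that the device move happens once at the end so both branches return a model on the target device.

```python
import os

import torch
import torch.nn as nn


def load_model(model_path: str, model: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical version: load weights if present, then return the model on `device`."""
    if model_path and os.path.exists(model_path):
        # Assumption: the file was written with torch.save(model.state_dict(), path).
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        print("No model path exists, proceeding with a new model")
    # Single exit point: both branches hand back a model on the requested device.
    return model.to(device)
```

With this shape the call site would become `model = load_model(latest_model_path, model, DEVICE)`, and no separate `.to(DEVICE)` line is needed after construction.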

Contributor Author

Done. If you disagree, I don't care.

@dariocazzani force-pushed the hotfix/model_to_device branch from cf9bdf6 to fd1cde2 on August 22, 2025 23:55
@dariocazzani requested a review from ayeganov on August 22, 2025 23:55
@dariocazzani force-pushed the hotfix/model_to_device branch from fd1cde2 to cad5896 on August 22, 2025 23:56
Contributor

@ayeganov left a comment

Ship it!

@dariocazzani merged commit 3536277 into main on Aug 23, 2025
@dariocazzani deleted the hotfix/model_to_device branch on August 23, 2025 01:18