Skip to content

Commit fd1cde2

Browse files
committed
Move model to device
1 parent 6fb73a3 commit fd1cde2

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

scratchgpt/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def main() -> None:
     latest_model_path = get_latest_model_weights_path(args.experiment)
 
     model = TransformerLanguageModel(NUM_HEADS, tokenizer.vocab_size, N_EMBED, BLOCK_SIZE, NUM_BLOCKS)
-    load_model(best_model_path, model, DEVICE)
+    model = load_model(best_model_path, model, DEVICE)
 
     print_model_complexity(model)
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

scratchgpt/model_io.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,18 @@ def get_tokenizer_path(exp_folder: str) -> str:
     return os.path.join(exp_folder, "tokenizer.pkl")
 
 
-def load_model(model_path: str, model: nn.Module, device: torch.device) -> None:
+def load_model(model_path: str, model: nn.Module, device: torch.device) -> nn.Module:
+    model.to(device)
     if os.path.exists(model_path):
         try:
             print(f"Loading weights from: {model_path}")
             model_dict = torch.load(model_path, map_location=device)
             model.load_state_dict(model_dict)
-            model.to(device)
         except Exception:
             raise ModelLoadFailed(model_path)
     else:
-        print("No model path exists, proceeding with a new model")
+        print(f"No saved model, starting from scratch...gpt, lol!")
+    return model
 
 
 def get_tokenizer(exp_path: str) -> Tokenizer:

0 commit comments

Comments
 (0)