diff --git a/scratchgpt/main.py b/scratchgpt/main.py index a338fbc..5873fa9 100644 --- a/scratchgpt/main.py +++ b/scratchgpt/main.py @@ -280,7 +280,7 @@ def main() -> None: latest_model_path = get_latest_model_weights_path(args.experiment) model = TransformerLanguageModel(NUM_HEADS, tokenizer.vocab_size, N_EMBED, BLOCK_SIZE, NUM_BLOCKS) - load_model(best_model_path, model, DEVICE) + model = load_model(best_model_path, model, DEVICE) print_model_complexity(model) optimizer = AdamW(model.parameters(), lr=LEARNING_RATE) diff --git a/scratchgpt/model_io.py b/scratchgpt/model_io.py index 98401f9..ab72b37 100644 --- a/scratchgpt/model_io.py +++ b/scratchgpt/model_io.py @@ -24,17 +24,18 @@ def get_tokenizer_path(exp_folder: str) -> str: return os.path.join(exp_folder, "tokenizer.pkl") -def load_model(model_path: str, model: nn.Module, device: torch.device) -> None: +def load_model(model_path: str, model: nn.Module, device: torch.device) -> nn.Module: + model.to(device) if os.path.exists(model_path): try: print(f"Loading weights from: {model_path}") model_dict = torch.load(model_path, map_location=device) model.load_state_dict(model_dict) - model.to(device) except Exception: raise ModelLoadFailed(model_path) else: - print("No model path exists, proceeding with a new model") + print("No saved model, starting from scratch...gpt, lol!") + return model def get_tokenizer(exp_path: str) -> Tokenizer: