diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py index b883e4c..3be6c60 100644 --- a/scratchgpt/dataloader.py +++ b/scratchgpt/dataloader.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod import os +from pathlib import Path from typing import Literal, override import torch @@ -11,13 +12,13 @@ class TextProvider(ABC): @abstractmethod - def get_text(self) -> str: + def get_text(self) -> Path: """This method fetches the text from the underlying storage""" class FileTextProvider(TextProvider): - def __init__(self, file_path: str) -> None: - if not os.path.exists(file_path): + def __init__(self, file_path: Path) -> None: + if not file_path.exists(): raise ValueError(f"File path {file_path} does not exist") self._data = "" @@ -30,21 +31,19 @@ def get_text(self) -> str: class FolderTextProvider(TextProvider): - def __init__(self, dir_path: str) -> None: - if not os.path.exists(dir_path): + def __init__(self, dir_path: Path) -> None: + if not dir_path.exists(): raise ValueError(f"Directory path {dir_path} does not exist") - if not os.path.isdir(dir_path): + if not dir_path.is_dir(): raise ValueError(f"Directory path {dir_path} is not a directory") self._data = "" - for root, _, files in os.walk(dir_path): - print(f"Loading data from {root}") - for file_name in files: - if not file_name.startswith("."): - file_path = os.path.join(root, file_name) - with open(file_path, "r", encoding="utf-8") as f: - self._data += f.read() + "\n" # Concatenate with a + for file_path in dir_path.rglob("*"): # Recursively find all files + if file_path.is_file() and not file_path.name.startswith("."): + print(f"Loading data from {file_path.parent}") + with open(file_path, "r", encoding="utf-8") as f: + self._data += f.read() + "\n" # Concatenate with a # newline between files, could be the place to add # special tokens diff --git a/scratchgpt/infer.py b/scratchgpt/infer.py index bc408d0..0642193 100644 --- a/scratchgpt/infer.py +++ b/scratchgpt/infer.py @@ -1,9 +1,10 @@ import argparse -import pathlib import sys +from pathlib import Path -from pydantic_yaml import parse_yaml_file_as import torch +from pydantic_yaml import parse_yaml_file_as +from rich.pretty import pprint as rpprint from scratchgpt.config import ScratchGPTConfig @@ -28,7 +29,7 @@ def parse_args() -> argparse.Namespace: "--experiment", help="The path to the folder where to save experiment checkpoints", required=True, - type=pathlib.Path, + type=Path, ) parser.add_argument( "-m", @@ -42,9 +43,11 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - config_file = args.experiment / "scratch_gpt.yaml" + config_file: Path = args.experiment / "scratch_gpt.yaml" config = parse_yaml_file_as(ScratchGPTConfig, config_file) - print(f"Using config file {config_file}: {config.model_dump_json(indent=2)}") + print(f"Using config file {config_file}") + rpprint(config.model_dump(), indent_guides=True, expand_all=True) + tokenizer = get_tokenizer(args.experiment) device = torch.device(args.device) diff --git a/scratchgpt/main.py b/scratchgpt/main.py index 9274a66..61c2002 100644 --- a/scratchgpt/main.py +++ b/scratchgpt/main.py @@ -1,10 +1,12 @@ import argparse import os +from pathlib import Path import sys from typing import Literal -from pydantic_yaml import to_yaml_file import torch +from pydantic_yaml import to_yaml_file, parse_yaml_file_as +from rich.pretty import pprint as rpprint from torch.optim.adamw import AdamW from torch.optim.optimizer import Optimizer from torch.types import Tensor @@ -23,9 +25,7 @@ save_tokenizer, ) -config = ScratchGPTConfig() - -torch.manual_seed(config.training.random_seed) +DatasetType = tuple[Tensor, Tensor] def parse_args() -> argparse.Namespace: @@ -38,14 +38,14 @@ def parse_args() -> argparse.Namespace: "--train_source", help="The file you want to train on", required=True, - type=str, + type=Path, ) parser.add_argument( "-e", "--experiment", help="The path to the folder where to save experiment checkpoints", required=True, - type=str, + type=Path, ) parser.add_argument( "-d", @@ -57,7 +57,18 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -DatasetType = tuple[Tensor, Tensor] +def load_or_create_config(experiment_path: Path) -> ScratchGPTConfig: + """ + Load config from experiment folder if it exists, otherwise create default. + """ + config_path: Path = experiment_path / "scratch_gpt.yaml" + + if config_path.exists(): + print(f"Loading existing config from {config_path}") + return parse_yaml_file_as(ScratchGPTConfig, config_path) + else: + print("No existing config found, creating default config") + return ScratchGPTConfig() def run_epoch( @@ -111,8 +122,8 @@ def run_epoch( return average_loss.value() -def get_text_provider(path: str) -> TextProvider: - if os.path.isdir(path): +def get_text_provider(path: Path) -> TextProvider: + if path.is_dir(): return FolderTextProvider(path) return FileTextProvider(path) @@ -120,6 +131,11 @@ def get_text_provider(path: str) -> TextProvider: def main() -> None: args = parse_args() + config = load_or_create_config(args.experiment) + + torch.manual_seed(config.training.random_seed) + print(f"Set random seed to: {config.training.random_seed}") + device = torch.device(args.device) print(f"Using the device: {device}") @@ -127,7 +143,7 @@ def main() -> None: tokenizer = get_tokenizer(args.experiment) config.architecture.vocab_size = tokenizer.vocab_size - print(config) + rpprint(config.model_dump(), indent_guides=True, expand_all=True) train_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "train", 0.9) val_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "validation", 0.1) @@ -211,4 +227,4 @@ def main() -> None: if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/scratchgpt/model_io.py b/scratchgpt/model_io.py index a34c301..fbe69b7 100644 --- a/scratchgpt/model_io.py +++ b/scratchgpt/model_io.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pickle import torch @@ -13,19 +14,19 @@ class ModelLoadFailed(Exception): pass -def get_best_model_weights_path(exp_folder: str) -> str: - return os.path.join(exp_folder, "best_model_weights.pth") +def get_best_model_weights_path(exp_folder: Path) -> Path: + return exp_folder / "best_model_weights.pth" -def get_latest_model_weights_path(exp_folder: str) -> str: - return os.path.join(exp_folder, "latest_model_weights.pth") +def get_latest_model_weights_path(exp_folder: Path) -> Path: + return exp_folder / "latest_model_weights.pth" -def get_tokenizer_path(exp_folder: str) -> str: - return os.path.join(exp_folder, "tokenizer.pkl") +def get_tokenizer_path(exp_folder: Path) -> Path: + return exp_folder / "tokenizer.pkl" -def load_model(model_path: str, model: TransformerLanguageModel, device: torch.device) -> TransformerLanguageModel: +def load_model(model_path: Path, model: TransformerLanguageModel, device: torch.device) -> TransformerLanguageModel: model.to(device) if os.path.exists(model_path): try: @@ -39,7 +40,7 @@ def load_model(model_path: str, model: TransformerLanguageModel, device: torch.d return model -def get_tokenizer(exp_path: str) -> Tokenizer: +def get_tokenizer(exp_path: Path) -> Tokenizer: tokenizer_path = get_tokenizer_path(exp_path) if os.path.exists(tokenizer_path): with open(tokenizer_path, "rb") as f: @@ -49,7 +50,7 @@ def get_tokenizer(exp_path: str) -> Tokenizer: return tokenizer -def save_tokenizer(exp_path: str, tokenizer: Tokenizer) -> None: +def save_tokenizer(exp_path: Path, tokenizer: Tokenizer) -> None: tokenizer_path = get_tokenizer_path(exp_path) with open(tokenizer_path, "wb") as f: pickle.dump(tokenizer, f)