Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions scratchgpt/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
import os
from pathlib import Path
from typing import Literal, override

import torch
Expand All @@ -11,13 +12,13 @@

class TextProvider(ABC):
    """Abstract interface for objects that supply training text.

    Concrete providers implement :meth:`get_text` to hand back the text they
    hold as a single string.
    """

    @abstractmethod
    def get_text(self) -> str:
        """This method fetches the text from the underlying storage

        Returns:
            The text as a ``str``. Filesystem locations are *inputs* to the
            concrete providers, so the return type stays ``str`` rather than
            ``Path``.
        """


class FileTextProvider(TextProvider):
def __init__(self, file_path: str) -> None:
if not os.path.exists(file_path):
def __init__(self, file_path: Path) -> None:
if not file_path.exists():
raise ValueError(f"File path {file_path} does not exist")

self._data = ""
Expand All @@ -30,21 +31,19 @@ def get_text(self) -> str:


class FolderTextProvider(TextProvider):
def __init__(self, dir_path: str) -> None:
if not os.path.exists(dir_path):
def __init__(self, dir_path: Path) -> None:
if not dir_path.exists():
raise ValueError(f"Directory path {dir_path} does not exist")

if not os.path.isdir(dir_path):
if not dir_path.is_dir():
raise ValueError(f"Directory path {dir_path} is not a directory")

self._data = ""
for root, _, files in os.walk(dir_path):
print(f"Loading data from {root}")
for file_name in files:
if not file_name.startswith("."):
file_path = os.path.join(root, file_name)
with open(file_path, "r", encoding="utf-8") as f:
self._data += f.read() + "\n" # Concatenate with a
for file_path in dir_path.rglob("*"): # Recursively find all files
if file_path.is_file() and not file_path.name.startswith("."):
print(f"Loading data from {file_path.parent}")
with open(file_path, "r", encoding="utf-8") as f:
self._data += f.read() + "\n" # Concatenate with a
# newline between files, could be the place to add
# special tokens

Expand Down
13 changes: 8 additions & 5 deletions scratchgpt/infer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
import pathlib
import sys
from pathlib import Path

from pydantic_yaml import parse_yaml_file_as
import torch
from pydantic_yaml import parse_yaml_file_as
from rich.pretty import pprint as rpprint

from scratchgpt.config import ScratchGPTConfig

Expand All @@ -28,7 +29,7 @@ def parse_args() -> argparse.Namespace:
"--experiment",
help="The path to the folder where to save experiment checkpoints",
required=True,
type=pathlib.Path,
type=Path,
)
parser.add_argument(
"-m",
Expand All @@ -42,9 +43,11 @@ def parse_args() -> argparse.Namespace:

def main() -> None:
args = parse_args()
config_file = args.experiment / "scratch_gpt.yaml"
config_file: Path = args.experiment / "scratch_gpt.yaml"
config = parse_yaml_file_as(ScratchGPTConfig, config_file)
print(f"Using config file {config_file}: {config.model_dump_json(indent=2)}")
print(f"Using config file {config_file}")
rpprint(config.model_dump(), indent_guides=True, expand_all=True)

tokenizer = get_tokenizer(args.experiment)

device = torch.device(args.device)
Expand Down
38 changes: 27 additions & 11 deletions scratchgpt/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import argparse
import os
from pathlib import Path
import sys
from typing import Literal

from pydantic_yaml import to_yaml_file
import torch
from pydantic_yaml import to_yaml_file, parse_yaml_file_as
from rich.pretty import pprint as rpprint
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from torch.types import Tensor
Expand All @@ -23,9 +25,7 @@
save_tokenizer,
)

config = ScratchGPTConfig()

torch.manual_seed(config.training.random_seed)
DatasetType = tuple[Tensor, Tensor]


def parse_args() -> argparse.Namespace:
Expand All @@ -38,14 +38,14 @@ def parse_args() -> argparse.Namespace:
"--train_source",
help="The file you want to train on",
required=True,
type=str,
type=Path,
)
parser.add_argument(
"-e",
"--experiment",
help="The path to the folder where to save experiment checkpoints",
required=True,
type=str,
type=Path,
)
parser.add_argument(
"-d",
Expand All @@ -57,7 +57,18 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()


DatasetType = tuple[Tensor, Tensor]
def load_or_create_config(experiment_path: Path) -> ScratchGPTConfig:
    """
    Load config from experiment folder if it exists, otherwise create default.

    Args:
        experiment_path: Experiment folder that may contain
            ``scratch_gpt.yaml``.

    Returns:
        The config parsed from ``scratch_gpt.yaml`` when that path exists,
        otherwise a default-constructed ``ScratchGPTConfig``.
    """
    config_path: Path = experiment_path / "scratch_gpt.yaml"

    if config_path.exists():
        print(f"Loading existing config from {config_path}")
        return parse_yaml_file_as(ScratchGPTConfig, config_path)

    # Nothing saved for this experiment yet: fall back to the defaults.
    print("No existing config found, creating default config")
    return ScratchGPTConfig()


def run_epoch(
Expand Down Expand Up @@ -111,23 +122,28 @@ def run_epoch(
return average_loss.value()


def get_text_provider(path: Path) -> TextProvider:
    """Pick the text provider that matches what ``path`` points to.

    Args:
        path: Location of the training data.

    Returns:
        A ``FolderTextProvider`` when ``path`` is a directory, otherwise a
        ``FileTextProvider`` for the same path.
    """
    if path.is_dir():
        return FolderTextProvider(path)
    # Anything that is not a directory is handed to the single-file provider.
    return FileTextProvider(path)


def main() -> None:
args = parse_args()

config = load_or_create_config(args.experiment)

torch.manual_seed(config.training.random_seed)
print(f"Set random seed to: {config.training.random_seed}")

device = torch.device(args.device)
print(f"Using the device: {device}")

text_provider = get_text_provider(args.train_source)

tokenizer = get_tokenizer(args.experiment)
config.architecture.vocab_size = tokenizer.vocab_size
print(config)
rpprint(config.model_dump(), indent_guides=True, expand_all=True)

train_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "train", 0.9)
val_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "validation", 0.1)
Expand Down Expand Up @@ -211,4 +227,4 @@ def main() -> None:


# Script entry point: run main() exactly once when executed directly.
if __name__ == "__main__":
    main()
19 changes: 10 additions & 9 deletions scratchgpt/model_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import pickle

import torch
Expand All @@ -13,19 +14,19 @@ class ModelLoadFailed(Exception):
pass


def get_best_model_weights_path(exp_folder: str) -> str:
return os.path.join(exp_folder, "best_model_weights.pth")
def get_best_model_weights_path(exp_folder: Path) -> Path:
return exp_folder / "best_model_weights.pth"


def get_latest_model_weights_path(exp_folder: str) -> str:
return os.path.join(exp_folder, "latest_model_weights.pth")
def get_latest_model_weights_path(exp_folder: Path) -> Path:
return exp_folder / "latest_model_weights.pth"


def get_tokenizer_path(exp_folder: str) -> str:
return os.path.join(exp_folder, "tokenizer.pkl")
def get_tokenizer_path(exp_folder: Path) -> Path:
return exp_folder / "tokenizer.pkl"


def load_model(model_path: str, model: TransformerLanguageModel, device: torch.device) -> TransformerLanguageModel:
def load_model(model_path: Path, model: TransformerLanguageModel, device: torch.device) -> TransformerLanguageModel:
model.to(device)
if os.path.exists(model_path):
try:
Expand All @@ -39,7 +40,7 @@ def load_model(model_path: str, model: TransformerLanguageModel, device: torch.d
return model


def get_tokenizer(exp_path: str) -> Tokenizer:
def get_tokenizer(exp_path: Path) -> Tokenizer:
tokenizer_path = get_tokenizer_path(exp_path)
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "rb") as f:
Expand All @@ -49,7 +50,7 @@ def get_tokenizer(exp_path: str) -> Tokenizer:
return tokenizer


def save_tokenizer(exp_path: str, tokenizer: Tokenizer) -> None:
def save_tokenizer(exp_path: Path, tokenizer: Tokenizer) -> None:
tokenizer_path = get_tokenizer_path(exp_path)
with open(tokenizer_path, "wb") as f:
pickle.dump(tokenizer, f)
Expand Down