Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
with:
version: "latest"
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- run: uv sync --group dev
- run: uv run ruff check .
- run: uv run mypy .
6 changes: 0 additions & 6 deletions main.py

This file was deleted.

63 changes: 35 additions & 28 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,44 +22,51 @@ dependencies = [
[dependency-groups]
dev = [
"bandit>=1.8.6",
"black>=25.1.0",
"isort>=6.0.1",
"mypy>=1.17.1",
"pylint>=3.3.8",
"pytest>=8.4.1",
"ruff>=0.1.0",
]

[tool.isort]
profile = "black"
line_length = 120
force_sort_within_sections = true
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
default_section = "THIRDPARTY"
skip_glob = [".venv"]

[tool.pylint."MESSAGES CONTROL"]
disable = ["missing-module-docstring", "missing-class-docstring"]
extension-pkg-whitelist = "pydantic"

[tool.pylint.REPORTS]
output-format = "parseable"
[tool.pytest.ini_options]
asyncio_mode = "auto"

[tool.pylint.FORMAT]
max-line-length = 120
[tool.mypy]
python_version = "3.12"
warn_unused_configs = true
files = ["scratchgpt/"]
ignore_missing_imports = false
check_untyped_defs = true
explicit_package_bases = true
warn_unreachable = true
warn_redundant_casts = true
strict = true
exclude = [".venv"]

[tool.pylint.DESIGN]
max-args = 10
max-attributes = 10
[[tool.mypy.overrides]]
module = ["ptflops"]
ignore_missing_imports = true

[tool.black]
[tool.ruff]
line-length = 120
target-version = ['py312']
target-version = "py312"

[tool.pytest.ini_options]
asyncio_mode = "auto"
[tool.ruff.lint]
select = [
"F", # Pyflakes
"E", # pycodestyle errors
"W", # pycodestyle warnings
"I", # isort
"N", # pep8-naming
"UP", # pyupgrade
"B", # flake8-bugbear (catches common bugs)
"C4", # flake8-comprehensions
"PIE", # flake8-pie
"SIM", # flake8-simplify
]
ignore = ["N812", "N806"]

[tool.mypy]
python_version = "3.12"
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"] # Allow unused imports in __init__.py

[build-system]
requires = ["hatchling"]
Expand Down
13 changes: 5 additions & 8 deletions scratchgpt/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
import os
from pathlib import Path
from typing import Literal, override

Expand All @@ -12,7 +11,7 @@

class TextProvider(ABC):
@abstractmethod
def get_text(self) -> Path:
def get_text(self) -> str:
"""This method fetches the text from the underlying storage"""


Expand All @@ -22,7 +21,7 @@ def __init__(self, file_path: Path) -> None:
raise ValueError(f"File path {file_path} does not exist")

self._data = ""
with open(file_path, "r") as f:
with open(file_path) as f:
self._data = f.read()

@override
Expand All @@ -40,12 +39,10 @@ def __init__(self, dir_path: Path) -> None:

self._data = ""
for file_path in dir_path.rglob("*"): # Recursively find all files
print(f"Loading data from {file_path}")
if file_path.is_file() and not file_path.name.startswith("."):
print(f"Loading data from {file_path.parent}")
with open(file_path, "r", encoding="utf-8") as f:
self._data += f.read() + "\n" # Concatenate with a
# newline between files, could be the place to add
# special tokens
with open(file_path, encoding="utf-8") as f:
self._data += f.read() + "\n"

@override
def get_text(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion scratchgpt/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from scratchgpt.config import ScratchGPTConfig

from .main import TransformerLanguageModel
from .model.model import TransformerLanguageModel
from .model_io import get_best_model_weights_path, get_tokenizer, load_model


Expand Down
6 changes: 3 additions & 3 deletions scratchgpt/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import os
from pathlib import Path
import sys
from pathlib import Path
from typing import Literal

import torch
from pydantic_yaml import to_yaml_file, parse_yaml_file_as
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
from rich.pretty import pprint as rpprint
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
Expand Down Expand Up @@ -227,4 +227,4 @@ def main() -> None:


if __name__ == "__main__":
main()
main()
2 changes: 1 addition & 1 deletion scratchgpt/metering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def value(self) -> tuple[float, float]:

class AverageValueMeter(Meter):
def __init__(self) -> None:
super(AverageValueMeter, self).__init__()
super().__init__()
self.reset()
self.val: float = 0.0

Expand Down
7 changes: 3 additions & 4 deletions scratchgpt/model/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
from typing import Any

from ptflops import get_model_complexity_info
import torch
from ptflops import get_model_complexity_info
from torch import Tensor, nn
from torch.nn import functional as F

Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
self._dropout = nn.Dropout(dropout_rate)

def forward(self, context: Tensor) -> Tensor:
out = torch.cat([head(context) for head in self._heads], dim=-1)
out: Tensor = torch.cat([head(context) for head in self._heads], dim=-1)
out = self._proj(out)
out = self._dropout(out)
return out
Expand Down Expand Up @@ -202,12 +202,11 @@ def input_constructor(input_shape: Any) -> Tensor:
flops, params = get_model_complexity_info(
model,
input_shape,
input_constructor=input_constructor, # type: ignore
input_constructor=input_constructor,
print_per_layer_stat=False,
as_strings=False,
)

print(f" FLOPs per forward pass: {flops:,}")
print(f"GFLOPs per forward pass: {flops / 1e9:.2f}")

print("=========================")
8 changes: 4 additions & 4 deletions scratchgpt/model_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
import pickle
from pathlib import Path

import torch

Expand All @@ -10,7 +10,7 @@
from .tokenizer.tiktoken import TiktokenWrapper


class ModelLoadFailed(Exception):
class ModelLoadFailedError(Exception):
pass


Expand All @@ -33,8 +33,8 @@ def load_model(model_path: Path, model: TransformerLanguageModel, device: torch.
print(f"Loading weights from: {model_path}")
model_dict = torch.load(model_path, map_location=device)
model.load_state_dict(model_dict)
except Exception:
raise ModelLoadFailed(model_path)
except Exception as error:
raise ModelLoadFailedError(model_path) from error
else:
print("No saved model, starting from scratch...gpt, lol!")
return model
Expand Down
6 changes: 3 additions & 3 deletions scratchgpt/tokenizer/char_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def get_vocab(text: str) -> list[str]:
chars = sorted(list(set(text)))
chars = sorted(set(text))
return chars


Expand All @@ -13,7 +13,7 @@ def str_to_int(chars: list[str]) -> dict[str, int]:


def int_to_str(chars: list[str]) -> dict[int, str]:
return {idx: char for idx, char in enumerate(chars)}
return dict(enumerate(chars))


class CharTokenizer(Tokenizer):
Expand Down Expand Up @@ -48,7 +48,7 @@ def decode(
class Utf8Tokenizer(Tokenizer):

def __init__(self) -> None:
self._vocabulary = list(range(0, 256))
self._vocabulary = list(range(256))

@property
@override
Expand Down
2 changes: 1 addition & 1 deletion scratchgpt/tokenizer/tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def vocabulary(self) -> list[str]:
"""Return the learned vocabulary"""
return list(self._get_cached_vocabulary())

@lru_cache(maxsize=1)
@lru_cache(maxsize=1) # noqa: B019
def _get_cached_vocabulary(self) -> tuple[str, ...]:
"""
Cache and return the vocabulary as a tuple.
Expand Down
30 changes: 15 additions & 15 deletions tests/test_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from scratchgpt.tokenizer.char_tokenizer import CharTokenizer, Utf8Tokenizer


test_data = [
"Привет, как дела? 😊",
"Я люблю читать книги.",
"Москва - это красивый город.",
"안녕하세요 👋",
"나는 당신을 만나서 행복해요 😊",
"서울은 아름다운 도시입니다.",
"Ciao, come stai? 😊",
"Amo leggere libri.",
"Roma è una città bellissima.",
"Hello, how are you? 👋",
"I love to read books.",
"New York City is a bustling metropolis 🗽️"
"Привет, как дела? 😊",
"Я люблю читать книги.",
"Москва - это красивый город.",
"안녕하세요 👋",
"나는 당신을 만나서 행복해요 😊",
"서울은 아름다운 도시입니다.",
"Ciao, come stai? 😊",
"Amo leggere libri.",
"Roma è una città bellissima.",
"Hello, how are you? 👋",
"I love to read books.",
"New York City is a bustling metropolis 🗽️",
]

def test_basic_char_tokenizer():

def test_basic_char_tokenizer() -> None:
test_corpus = "".join(test_data)
tokenizer = CharTokenizer(test_corpus)

Expand All @@ -27,7 +27,7 @@ def test_basic_char_tokenizer():
assert test_sample == decoded, "Oh no, this thing failed"


def test_utf8_tokenizer():
def test_utf8_tokenizer() -> None:
tokenizer = Utf8Tokenizer()

for test_sample in test_data:
Expand Down
Loading