Merged
13 changes: 11 additions & 2 deletions scratchgpt/dataloader.py
@@ -21,8 +21,10 @@ def __init__(self, file_path: Path) -> None:
             raise ValueError(f"File path {file_path} does not exist")
 
         self._data = ""
+        print(f"Loading data from {file_path}")
         with open(file_path) as f:
             self._data = f.read()
+        print("Data Loaded")
 
     @override
     def get_text(self) -> str:
@@ -38,12 +40,19 @@ def __init__(self, dir_path: Path) -> None:
             raise ValueError(f"Directory path {dir_path} is not a directory")
 
         self._data = ""
-        for file_path in dir_path.rglob("*"):  # Recursively find all files
-            print(f"Loading data from {file_path}")
+        print(f"Loading data from {dir_path}")
+        total_read: int = 0
+        for idx, file_path in enumerate(dir_path.rglob("*")):
             if file_path.is_file() and not file_path.name.startswith("."):
                 with open(file_path, encoding="utf-8") as f:
                     self._data += f.read() + "\n"
 
+            if idx % 500 == 1:
+                total_read += 500
+                print(f"Read {total_read} files")
+
+        print("Data Loaded")
+
     @override
     def get_text(self) -> str:
         return self._data
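Note on the progress counter: `idx` indexes every path yielded by `rglob("*")`, directories and dotfiles included, so the `idx % 500 == 1` branch reports multiples of 500 paths enumerated rather than files actually read. A minimal sketch of an exact per-file count under the same skip rules (the helper name `load_directory` is illustrative, not from this PR):

```python
from pathlib import Path


def load_directory(dir_path: Path, report_every: int = 500) -> str:
    """Concatenate all non-hidden files under dir_path, counting exactly."""
    chunks: list[str] = []
    files_read = 0
    for file_path in dir_path.rglob("*"):
        if file_path.is_file() and not file_path.name.startswith("."):
            with open(file_path, encoding="utf-8") as f:
                chunks.append(f.read() + "\n")
            files_read += 1  # count only files actually read
            if files_read % report_every == 0:
                print(f"Read {files_read} files")
    return "".join(chunks)
```

Accumulating into a list and joining once also avoids repeated string concatenation, though for typical corpus sizes either approach works.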
14 changes: 12 additions & 2 deletions scratchgpt/main.py
@@ -7,6 +7,7 @@
 import torch
 from pydantic_yaml import parse_yaml_file_as, to_yaml_file
 from rich.pretty import pprint as rpprint
+from torch.nn import functional as F
 from torch.optim.adamw import AdamW
 from torch.optim.optimizer import Optimizer
 from torch.types import Tensor
@@ -106,10 +107,16 @@ def run_epoch(
         if is_train and optimizer is not None:
             optimizer.zero_grad(set_to_none=True)
 
-        logits, loss = model(batch, targets)
+        logits = model(batch)
+
+        B, T, C = logits.shape
+        logits = logits.view(B * T, C)
+        targets = targets.view(B * T)
+
+        loss: Tensor = F.cross_entropy(logits, targets)
 
         if is_train and optimizer is not None:
-            loss.backward()
+            loss.backward()  # type: ignore[no-untyped-call]
             optimizer.step()
 
         average_loss.add(loss.item())
@@ -148,6 +155,7 @@ def main() -> None:
     train_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "train", 0.9)
     val_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "validation", 0.1)
 
+    print("Loading train and validation loaders")
     cpu_count = os.cpu_count() or 4
     train_dataloader = DataLoader(
         train_dataset,
@@ -165,6 +173,8 @@
         shuffle=False,
     )
 
+    print("Loaders initialized")
+
     best_model_path = get_best_model_weights_path(args.experiment)
     latest_model_path = get_latest_model_weights_path(args.experiment)
 
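With `forward` now returning bare logits, the training loop computes the loss itself. `F.cross_entropy` expects `(N, C)` logits against `(N,)` class indices, which is why the `(B, T, C)` model output and `(B, T)` targets are flattened first. A self-contained shape check (the dimension values are illustrative, not from the PR):

```python
import torch
from torch.nn import functional as F

B, T, C = 2, 8, 65  # batch, block size, vocab size (example values)
logits = torch.randn(B, T, C)          # what model(batch) now returns
targets = torch.randint(0, C, (B, T))  # next-token ids, one per position

# cross_entropy wants (N, C) logits against (N,) class indices,
# so both tensors are flattened over the batch and time dimensions.
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
print(loss.item())  # a scalar, roughly ln(65) for random logits
```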
18 changes: 5 additions & 13 deletions scratchgpt/model/model.py
@@ -148,7 +148,7 @@ def __init__(
         self._lm_head = nn.Linear(arch.embedding_size, arch.vocab_size)
         self._device = device
 
-    def forward(self, context: Tensor, targets: Tensor | None = None) -> tuple[Tensor, Tensor]:
+    def forward(self, context: Tensor) -> Tensor:
         context = context.long()
         B, T = context.shape
 
@@ -157,22 +157,13 @@ def forward(self, context: Tensor, targets: Tensor | None = None) -> tuple[Tensor, Tensor]:
         x = tok_emb + pos_emb  # B, T, C
         x = self._blocks(x)
         x = self._block_norm(x)
-        logits = self._lm_head(x)  # (B, T, vocab_size)
-
-        if targets is None:
-            loss = torch.empty(0)
-        else:
-            B, T, C = logits.shape
-            logits = logits.view(B * T, C)
-            targets = targets.view(B * T)
-            loss = F.cross_entropy(logits, targets)
-
-        return logits, loss
+        logits: Tensor = self._lm_head(x)  # (B, T, vocab_size)
+        return logits
 
     def generate(self, context: Tensor, max_new_tokens: int) -> Tensor:
         for _ in range(max_new_tokens):
             cropped_context = context[:, -self._block_size :]
-            logits, _loss = self(cropped_context)
+            logits = self(cropped_context)
             logits = logits[:, -1, :]  # becomes (B, C)
             probs = F.softmax(logits, dim=-1)
             idx_next = torch.multinomial(probs, num_samples=1)
@@ -208,5 +199,6 @@ def input_constructor(input_shape: Any) -> Tensor:
     )
 
     print(f" FLOPs per forward pass: {flops:,}")
+    print(f" Params: {params}")
 
     print("=========================")