diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py
index ddc6036..7d1573e 100644
--- a/scratchgpt/dataloader.py
+++ b/scratchgpt/dataloader.py
@@ -21,8 +21,10 @@ def __init__(self, file_path: Path) -> None:
             raise ValueError(f"File path {file_path} does not exist")

         self._data = ""
+        print(f"Loading data from {file_path}")
         with open(file_path) as f:
             self._data = f.read()
+        print("Data Loaded")

     @override
     def get_text(self) -> str:
@@ -38,12 +40,19 @@ def __init__(self, dir_path: Path) -> None:
             raise ValueError(f"Directory path {dir_path} is not a directory")

         self._data = ""
-        for file_path in dir_path.rglob("*"):  # Recursively find all files
-            print(f"Loading data from {file_path}")
+        print(f"Loading data from {dir_path}")
+        total_read: int = 0
+        for idx, file_path in enumerate(dir_path.rglob("*")):
             if file_path.is_file() and not file_path.name.startswith("."):
                 with open(file_path, encoding="utf-8") as f:
                     self._data += f.read() + "\n"
+                if idx % 500 == 1:
+                    total_read += 500
+                    print(f"Read {total_read} files")
+
+        print("Data Loaded")
+
     @override
     def get_text(self) -> str:
         return self._data

diff --git a/scratchgpt/main.py b/scratchgpt/main.py
index b12ab2e..e97d29c 100644
--- a/scratchgpt/main.py
+++ b/scratchgpt/main.py
@@ -7,6 +7,7 @@
 import torch
 from pydantic_yaml import parse_yaml_file_as, to_yaml_file
 from rich.pretty import pprint as rpprint
+from torch.nn import functional as F
 from torch.optim.adamw import AdamW
 from torch.optim.optimizer import Optimizer
 from torch.types import Tensor
@@ -106,10 +107,16 @@ def run_epoch(
         if is_train and optimizer is not None:
             optimizer.zero_grad(set_to_none=True)

-        logits, loss = model(batch, targets)
+        logits = model(batch)
+
+        B, T, C = logits.shape
+        logits = logits.view(B * T, C)
+        targets = targets.view(B * T)
+
+        loss: Tensor = F.cross_entropy(logits, targets)

         if is_train and optimizer is not None:
-            loss.backward()
+            loss.backward()  # type: ignore[no-untyped-call]
             optimizer.step()

         average_loss.add(loss.item())
@@ -148,6 +155,7 @@ def main() -> None:
     train_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "train", 0.9)
     val_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "validation", 0.1)

+    print("Loading train and validation loaders")
     cpu_count = os.cpu_count() or 4
     train_dataloader = DataLoader(
         train_dataset,
@@ -165,6 +173,8 @@
         shuffle=False,
     )

+    print("Loaders initialized")
+
     best_model_path = get_best_model_weights_path(args.experiment)
     latest_model_path = get_latest_model_weights_path(args.experiment)

diff --git a/scratchgpt/model/model.py b/scratchgpt/model/model.py
index 2e1f7b0..0ea227e 100644
--- a/scratchgpt/model/model.py
+++ b/scratchgpt/model/model.py
@@ -148,7 +148,7 @@ def __init__(
         self._lm_head = nn.Linear(arch.embedding_size, arch.vocab_size)
         self._device = device

-    def forward(self, context: Tensor, targets: Tensor | None = None) -> tuple[Tensor, Tensor]:
+    def forward(self, context: Tensor) -> Tensor:
         context = context.long()

         B, T = context.shape
@@ -157,22 +157,13 @@ def forward(self, context: Tensor, targets: Tensor | None = None) -> tuple[Tenso
         x = tok_emb + pos_emb  # B, T, C
         x = self._blocks(x)
         x = self._block_norm(x)
-        logits = self._lm_head(x)  # (B, T, vocab_size)
-
-        if targets is None:
-            loss = torch.empty(0)
-        else:
-            B, T, C = logits.shape
-            logits = logits.view(B * T, C)
-            targets = targets.view(B * T)
-            loss = F.cross_entropy(logits, targets)
-
-        return logits, loss
+        logits: Tensor = self._lm_head(x)  # (B, T, vocab_size)
+        return logits

     def generate(self, context: Tensor, max_new_tokens: int) -> Tensor:
         for _ in range(max_new_tokens):
             cropped_context = context[:, -self._block_size :]
-            logits, _loss = self(cropped_context)
+            logits = self(cropped_context)
             logits = logits[:, -1, :]  # becomes (B, C)
             probs = F.softmax(logits, dim=-1)
             idx_next = torch.multinomial(probs, num_samples=1)
@@ -208,5 +199,6 @@ def input_constructor(input_shape: Any) -> Tensor:
     )

     print(f" FLOPs per forward pass: {flops:,}")
+    print(f" Params: {params}")
     print("=========================")
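
For context, the net effect of the main.py and model.py hunks is that forward() now returns only the (B, T, vocab_size) logits, and the caller flattens them and computes the cross-entropy loss itself. Below is a minimal sketch of that calling pattern, using a stand-in embedding/linear model and dummy tensors; the toy sizes and the stand-in model are assumptions for illustration, not the project's real architecture or config.

import torch
from torch import nn
from torch.nn import functional as F

# Stand-in for the transformer: anything mapping (B, T) token ids to
# (B, T, vocab_size) logits satisfies the new forward() contract.
vocab_size, block_size = 64, 8  # assumed toy sizes
model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))

batch = torch.randint(0, vocab_size, (4, block_size))    # (B, T) input ids
targets = torch.randint(0, vocab_size, (4, block_size))  # (B, T) next-token ids

logits = model(batch)  # (B, T, vocab_size); the model no longer returns a loss
B, T, C = logits.shape
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
loss.backward()
print(f"loss: {loss.item():.4f}")

This also keeps generate() simple: it can call the model directly for logits without unpacking and discarding a placeholder loss.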