diff --git a/scratchgpt/main.py b/scratchgpt/main.py index a338fbc..bc28ef7 100644 --- a/scratchgpt/main.py +++ b/scratchgpt/main.py @@ -5,11 +5,13 @@ from typing import Literal import torch +from ptflops import get_model_complexity_info from torch import Tensor, nn from torch.nn import functional as F from torch.optim.adamw import AdamW from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader + from tqdm import tqdm from .dataloader import FileTextProvider, FolderTextProvider, TextDataset, TextProvider @@ -51,15 +53,32 @@ def parse_args() -> argparse.Namespace: def print_model_complexity(model: nn.Module) -> None: - """ - Helper function to report the complexity of the model - """ + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("=== MODEL COMPLEXITY ===") + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") + print(f"Model size: {total_params * 4 / 1024 / 1024:.2f} MB (float32)") + input_shape = (BLOCK_SIZE,) - flops, params = 0, 0 # get_model_complexity_info(model, input_shape, print_per_layer_stat=True, as_strings=True) + def input_constructor(input_shape): + return torch.randint(0, model._token_embedding_table.num_embeddings, + (1,) + input_shape, device=DEVICE) + + flops, params = get_model_complexity_info( + model, + input_shape, + input_constructor=input_constructor, + print_per_layer_stat=False, + as_strings=False + ) + + print(f" FLOPs per forward pass: {flops:,}") + print(f"GFLOPs per forward pass: {flops / 1e9:.2f}") - print(flops) - print(params) + print("=========================") class Head(nn.Module):