In [3]:
from dataclasses import dataclass
from typing import Optional  # needed here: Optional is used below, before any later cell imports it


@dataclass
class ModelArgs:
    """Hyper-parameters that define the transformer's size and regularisation."""

    dim: int = 4096  # model width: embedding size and residual-stream size
    n_layers: int = 32  # number of TransformerBlocks stacked in the model
    n_heads: int = 32  # attention heads; each head has dim // n_heads features
    n_kv_heads: Optional[int] = None  # key/value heads; None presumably means "same as n_heads" — confirm in model.py
    vocab_size: int = 32000  # tokenizer vocabulary size (rows of the embedding / output matrix)
    hidden_dim: Optional[int] = None  # feed-forward width passed to FeedForward; None presumably derives it from dim
    multiple_of: int = 256  # passed to FeedForward; presumably the hidden width is rounded to a multiple of this
    norm_eps: float = 1e-5  # epsilon for the RMSNorm layers
    max_seq_len: int = 2048  # longest sequence supported (size of the causal mask / RoPE tables)
    dropout: float = 0.0  # dropout probability

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional

class Transformer(nn.Module):
    """Decoder-only transformer language model (skeleton: the block stack is added separately).

    Pipeline: token ids -> token embeddings -> dropout -> transformer blocks (fed the RoPE
    cos/sin tables) -> final norm -> logits over the vocabulary.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size

        # Convert input token indices into dense vector representations
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        # Dropout applied to the embeddings in forward()
        self.dropout = nn.Dropout(args.dropout)

        # Add transformer blocks here (self.layers and self.norm, used by forward())
        ...

        # Convert the final hidden state of the model back into logits over the vocabulary
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Weight tying: the embedding table and the output projection are both
        # (vocab_size, dim) matrices, so sharing one removes vocab_size * dim parameters.
        self.tok_embeddings.weight = self.output.weight

        # Precompute the rotary positional embedding (RoPE) tables for every position up to
        # max_seq_len. They are non-persistent buffers, so model.to(device) moves them,
        # but they are not stored in the state dict.
        freqs_cos, freqs_sin = self._precompute_rope_freqs(args.dim // args.n_heads, args.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    @staticmethod
    def _precompute_rope_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0):
        """Return (cos, sin) tables of shape (max_seq_len, head_dim // 2).

        Entry [p, i] is cos/sin of p * theta ** (-2i / head_dim): position p times the
        rotation frequency of feature pair i. theta=10000 is the usual RoPE base.
        """
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: head_dim // 2].float() / head_dim))
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        return torch.cos(angles), torch.sin(angles)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Run the model.

        Args:
            tokens: (bsz, seqlen) token ids.
            targets: optional ids with the same shape as tokens; entries equal to -1 are
                ignored by the loss.

        Returns:
            Logits of shape (bsz, seqlen, vocab_size) when targets are given (and
            self.last_loss holds the cross-entropy). Otherwise (bsz, 1, vocab_size) for the
            last position only (and self.last_loss is None).
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)

        # Only the first seqlen rows of the RoPE tables are needed for this input.
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]
        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:  # training-stage
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:  # inference-stage: only the hidden state of the last token in each sequence
            # [-1] as a list keeps the time dimension: (bsz, 1, dim)
            logits = self.output(h[:, [-1], :])
            self.last_loss = None

        return logits

In [5]:
class Transformer(nn.Module):
    """Excerpt of ``Transformer.__init__`` showing the block stack and the final norm.

    NOTE(review): running this cell redefines ``Transformer`` with only the attributes
    below. The training cell later imports the (presumably complete) class from model.py.
    """

    def __init__(self, args: "ModelArgs"):
        # nn.Module must be initialised before any submodule (e.g. the ModuleList) is assigned.
        super().__init__()
        ...  # embeddings, output head and RoPE tables, as in the full __init__

        # args.n_layers decoder blocks applied in sequence; each block stores its layer_id.
        self.layers = nn.ModuleList(
            TransformerBlock(layer_id, args) for layer_id in range(args.n_layers)
        )

        # Final normalisation: applied to the output of the last block, before the output head.
        # (The per-block attention_norm / ffn_norm normalise the attention and FFN inputs.)
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        ...

class TransformerBlock(nn.Module):
    """One pre-norm decoder block: h = x + Attention(norm(x)); out = h + FeedForward(norm(h))."""

    def __init__(self, layer_id: int, args: "ModelArgs"):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        # Separate norms for the two sub-layer inputs.
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        # Call the submodules themselves (not .forward) so nn.Module hooks still run.
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        return h + self.feed_forward(self.ffn_norm(h))

In [6]:
class Attention(nn.Module):
    """Causal self-attention with rotary positional embeddings (RoPE).

    With ``args.n_kv_heads`` set to a divisor of ``args.n_heads``, keys/values use fewer
    heads (grouped-query attention) and each is shared by n_heads // n_kv_heads query heads.
    ``args.n_kv_heads = None`` gives standard multi-head attention.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.dim % self.n_heads == 0, "dim must be divisible by n_heads"
        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
        self.n_rep = self.n_heads // self.n_kv_heads  # query heads sharing each key/value head
        self.head_dim = args.dim // args.n_heads

        # QKV projections
        self.wq = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)

        # Final projection into the residual stream
        self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=False)

        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)

        # Causal mask: -inf strictly above the diagonal, so position i cannot attend to j > i.
        # A non-persistent buffer follows model.to(device), unlike a plain tensor attribute.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        self.register_buffer("mask", torch.triu(mask, diagonal=1), persistent=False)

    @staticmethod
    def _apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
        """Rotate each (even, odd) feature pair of xq and xk by a position-dependent angle.

        xq: (bsz, seqlen, heads, head_dim); xk: (bsz, seqlen, kv_heads, head_dim);
        freqs_cos / freqs_sin: (seqlen, head_dim // 2). Returns tensors shaped like xq, xk.
        """
        def rotate(t):
            # t_r holds the even features and t_i the odd ones: (bsz, seqlen, heads, head_dim // 2)
            t_r, t_i = t.float().reshape(*t.shape[:-1], -1, 2).unbind(-1)
            cos = freqs_cos.reshape(1, t_r.shape[1], 1, t_r.shape[-1])
            sin = freqs_sin.reshape(1, t_r.shape[1], 1, t_r.shape[-1])
            rotated = torch.stack((t_r * cos - t_i * sin, t_r * sin + t_i * cos), dim=-1)
            return rotated.flatten(3).type_as(t)

        return rotate(xq), rotate(xk)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        """x: (bsz, seqlen, dim) -> (bsz, seqlen, dim); freqs_*: (seqlen, head_dim // 2)."""
        bsz, seqlen, _ = x.shape

        # Project, then split the last dimension into heads: (bsz, seqlen, heads, head_dim)
        xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        # Relative position information enters only through queries and keys.
        xq, xk = self._apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # Put heads before the sequence axis, (bsz, heads, seqlen, head_dim), and repeat
        # key/value heads so every query head has a matching one.
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2).repeat_interleave(self.n_rep, dim=1)
        xv = xv.transpose(1, 2).repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product scores; head_dim ** 0.5 avoids depending on `math` being imported.
        scores = torch.matmul(xq, xk.transpose(2, 3)) / (self.head_dim ** 0.5)
        scores = scores + self.mask[:, :, :seqlen, :seqlen]
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)  # (bsz, n_heads, seqlen, head_dim)

        # Merge the heads back into the model dimension before the output projection.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.resid_dropout(self.wo(output))

In [7]:
@dataclass
class ModelArgs:
    """Smaller configuration used for the TinyStories demo.

    Must be a dataclass so it can be built with keyword arguments, e.g. ModelArgs(dim=...).
    """

    dim: int = 288
    n_layers: int = 6
    n_heads: int = 6
    n_kv_heads: Optional[int] = 6
    vocab_size: int = 2048  # matches the 2048-piece tokenizer trained with train_vocab(2048, ...)
    hidden_dim: Optional[int] = None
    multiple_of: int = 32
    norm_eps: float = 1e-5
    max_seq_len: int = 256
    dropout: float = 0.1

In [10]:
# Fetch the TinyStories archive (~1.5 GB) and unpack it under demo_data/.
from utils import download_TinyStories

download_TinyStories(
    data_dir="demo_data",
)

Downloading https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz to demo_data/TinyStories_all_data.tar.gz...


demo_data/TinyStories_all_data.tar.gz: 100%|█| 1.50G/1.50G [00:36<00:00, 43.7MiB


Unpacking demo_data/TinyStories_all_data.tar.gz...
Download done.


In [11]:
import glob
import json
import os
from tqdm import tqdm
import sentencepiece as spm

def train_vocab(vocab_size, data_dir, dataset_name="TinyStories_all_data", num_shards=10):
    """
    Train a custom SentencePiece BPE tokenizer on a subset of the TinyStories dataset.

    The stripped "story" field of every example in the first `num_shards` JSON shards
    (sorted by filename) under `<data_dir>/<dataset_name>/` is written, one per line, to a
    temporary text file. SentencePiece trains on that file, and the file is always removed
    afterwards, even if training fails.

    Args:
        vocab_size: Number of pieces in the learned vocabulary; must be positive.
        data_dir: Directory holding the dataset folder; the model files are written here too.
        dataset_name: Sub-directory of data_dir containing the *.json shards.
        num_shards: How many shards to train on. Fewer shards train faster but see less text.

    Returns:
        Path of the trained model, "<data_dir>/tok<vocab_size>.model".

    Raises:
        AssertionError: if vocab_size is not positive.
        FileNotFoundError: if no *.json shards exist in the dataset directory.
    """
    assert vocab_size > 0, "Vocab size must be positive"
    prefix = os.path.join(data_dir, f"tok{vocab_size}")  # output file prefix
    temp_file = os.path.join(data_dir, "temp.txt")

    shard_dir = os.path.join(data_dir, dataset_name)
    shard_filenames = sorted(glob.glob(os.path.join(shard_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"No *.json shards found in {shard_dir}")
    selected = shard_filenames[:num_shards]

    print(f"Writing temporary file {temp_file} with {len(selected)} shards...")
    try:
        with open(temp_file, "w", encoding="utf-8") as of:
            for shard in tqdm(selected):
                # Explicit encoding: the platform default is not UTF-8 everywhere.
                with open(shard, "r", encoding="utf-8") as f:
                    data = json.load(f)
                for example in data:
                    of.write(example["story"].strip() + "\n")
        print(f"Size: {os.path.getsize(temp_file) / 1024 / 1024:.2f} MB")
        print("Train the sentencepiece model ...")
        spm.SentencePieceTrainer.train(
            input=temp_file,
            model_prefix=prefix,
            model_type="bpe",
            vocab_size=vocab_size,
            split_digits=False,
            allow_whitespace_only_pieces=True,
            byte_fallback=True,
            unk_surface=r" \342\201\207 ",
            normalization_rule_name="identity",
        )
    finally:
        # Don't leave a multi-hundred-MB temp file behind when writing or training fails.
        if os.path.exists(temp_file):
            os.remove(temp_file)

    tokenizer_model = f"{prefix}.model"
    print(f"Trained tokenizer is in {tokenizer_model}")
    return tokenizer_model

# vocab_size=2048 matches ModelArgs.vocab_size in the demo configuration.
tokenizer_model = train_vocab(
    vocab_size=2048,
    data_dir="demo_data",
    dataset_name="TinyStories_all_data",
)

Writing temporary file demo_data/temp.txt with 10 shards...


100%|███████████████████████████████████████████| 10/10 [00:04<00:00,  2.22it/s]
sentencepiece_trainer.cc(78

Size: 739.57 MB
Train the sentencepiece model ...
Trained tokenizer is in demo_data/tok2048.model


) LOG(INFO) Starts training with : 
trainer_spec {
  input: demo_data/temp.txt
  input_format: 
  model_prefix: demo_data/tok2048
  model_type: BPE
  vocab_size: 2048
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 1
  required_chars: 
  byte_fallback: 1
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  \342\201\207 
  enable_differential_privacy: 0
  differenti

In [12]:
# Encode every shard with the trained tokenizer. The resulting .bin files go into a new
# TinyStories_all_data_pretok directory under data_dir.
# NOTE(review): the saved output ends in a KeyboardInterrupt, so not every shard may have
# been written; re-run this cell to completion before training.
from tokenizer import pretokenize

output_bin_dir = pretokenize(
    data_dir="demo_data",
    dataset_name="TinyStories_all_data",
    tokenizer_model=tokenizer_model,
)

#words: 2048 - BOS ID: 1 - EOS ID: 2
#words: 2048 - BOS ID: 1 - EOS ID: 2



  0%|          | 0/100000 [00:00<?, ?it/s][A
  1%|          | 666/100000 [00:00<00:14, 6655.59it/s][A
  1%|▏         | 1337/100000 [00:00<00:14, 6684.53it/s][A
  2%|▏         | 2016/100000 [00:00<00:14, 6730.06it/s][A
  3%|▎         | 2690/100000 [00:00<00:14, 6721.81it/s][A
  3%|▎         | 3365/100000 [00:00<00:14, 6729.32it/s][A
  4%|▍         | 4038/100000 [00:00<00:14, 6723.05it/s][A
  5%|▍         | 4711/100000 [00:00<00:14, 6664.72it/s][A
  5%|▌         | 5378/100000 [00:00<00:14, 6551.10it/s][A
  6%|▌         | 6038/100000 [00:00<00:14, 6562.19it/s][A
  7%|▋         | 6713/100000 [00:01<00:14, 6617.69it/s][A
  7%|▋         | 7382/100000 [00:01<00:13, 6639.07it/s][A
  8%|▊         | 8053/100000 [00:01<00:13, 6658.93it/s][A
  9%|▊         | 8720/100000 [00:01<00:13, 6646.69it/s][A
  9%|▉         | 9385/100000 [00:01<00:13, 6639.72it/s][A
 10%|█         | 10050/100000 [00:01<00:13, 6634.32it/s][A
 11%|█▏        | 11392/100000 [00:01<00:13, 6637.74it/s][A
 12%|█▏ 

Saved demo_data/TinyStories_all_data_pretok/data00.bin, average seqlen: 213.85
Saved demo_data/TinyStories_all_data_pretok/data01.bin, average seqlen: 213.54
#words: 2048 - BOS ID: 1 - EOS ID: 2
#words: 2048 - BOS ID: 1 - EOS ID: 2




  0%|          | 0/100000 [00:00<?, ?it/s][A[A


  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A

  1%|          | 668/100000 [00:00<00:14, 6678.84it/s][A[A


  1%|          | 657/100000 [00:00<00:15, 6566.70it/s][A[A[A

  1%|▏         | 1336/100000 [00:00<00:14, 6653.19it/s][A[A


  1%|▏         | 1314/100000 [00:00<00:15, 6566.56it/s][A[A[A

  2%|▏         | 2015/100000 [00:00<00:14, 6712.92it/s][A[A


  2%|▏         | 1992/100000 [00:00<00:14, 6662.96it/s][A[A[A

  3%|▎         | 2687/100000 [00:00<00:14, 6700.53it/s][A[A


  3%|▎         | 2666/100000 [00:00<00:14, 6687.84it/s][A[A[A

  3%|▎         | 3358/100000 [00:00<00:14, 6680.38it/s][A[A


  3%|▎         | 3352/100000 [00:00<00:14, 6748.18it/s][A[A[A

  4%|▍         | 4027/100000 [00:00<00:15, 6364.84it/s][A[A


  4%|▍         | 4027/100000 [00:00<00:15, 6355.53it/s][A[A[A

  5%|▍         | 4667/100000 [00:00<00:15, 6184.72it/s][A[A


  5%|▍         | 4667/100000 [00:00<00:15, 6194.29it

Saved demo_data/TinyStories_all_data_pretok/data02.bin, average seqlen: 214.08
Saved demo_data/TinyStories_all_data_pretok/data03.bin, average seqlen: 213.70
#words: 2048 - BOS ID: 1 - EOS ID: 2
#words: 2048 - BOS ID: 1 - EOS ID: 2






  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A[A




  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A[A[A



  1%|          | 671/100000 [00:00<00:14, 6705.05it/s][A[A[A[A




  1%|          | 680/100000 [00:00<00:14, 6798.11it/s][A[A[A[A[A



  1%|▏         | 1342/100000 [00:00<00:14, 6664.15it/s][A[A[A[A




  1%|▏         | 1360/100000 [00:00<00:14, 6773.37it/s][A[A[A[A[A



  2%|▏         | 2016/100000 [00:00<00:14, 6697.23it/s][A[A[A[A




  2%|▏         | 2042/100000 [00:00<00:14, 6788.93it/s][A[A[A[A[A



  3%|▎         | 2686/100000 [00:00<00:14, 6594.54it/s][A[A[A[A




  3%|▎         | 2721/100000 [00:00<00:14, 6622.94it/s][A[A[A[A[A



  3%|▎         | 3346/100000 [00:00<00:15, 6410.87it/s][A[A[A[A




  3%|▎         | 3384/100000 [00:00<00:14, 6467.84it/s][A[A[A[A[A



  4%|▍         | 4015/100000 [00:00<00:14, 6503.18it/s][A[A[A[A




  4%|▍         | 4034/100000 [00:00<00:14, 6476.40it/s][A[A[A[A[A



  

Saved demo_data/TinyStories_all_data_pretok/data04.bin, average seqlen: 213.78
Saved demo_data/TinyStories_all_data_pretok/data05.bin, average seqlen: 213.83
#words: 2048 - BOS ID: 1 - EOS ID: 2
#words: 2048 - BOS ID: 1 - EOS ID: 2








  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A[A[A[A






  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A[A[A[A[A





  1%|          | 677/100000 [00:00<00:14, 6765.20it/s][A[A[A[A[A[A






  1%|          | 661/100000 [00:00<00:15, 6603.08it/s][A[A[A[A[A[A[A





  1%|▏         | 1354/100000 [00:00<00:14, 6666.06it/s][A[A[A[A[A[A






  1%|▏         | 1322/100000 [00:00<00:14, 6579.00it/s][A[A[A[A[A[A[A





  2%|▏         | 2021/100000 [00:00<00:14, 6656.87it/s][A[A[A[A[A[A






  2%|▏         | 1986/100000 [00:00<00:14, 6604.34it/s][A[A[A[A[A[A[A





  3%|▎         | 2687/100000 [00:00<00:14, 6600.37it/s][A[A[A[A[A[A






  3%|▎         | 2648/100000 [00:00<00:14, 6610.33it/s][A[A[A[A[A[A[A





  3%|▎         | 3359/100000 [00:00<00:14, 6642.72it/s][A[A[A[A[A[A






  3%|▎         | 3316/100000 [00:00<00:14, 6634.13it/s][A[A[A[A[A[A[A





  4%|▍         | 4024/100000 [00:00<00:14, 6551.53

Saved demo_data/TinyStories_all_data_pretok/data06.bin, average seqlen: 213.68
Saved demo_data/TinyStories_all_data_pretok/data07.bin, average seqlen: 213.65
#words: 2048 - BOS ID: 1 - EOS ID: 2
#words: 2048 - BOS ID: 1 - EOS ID: 2










  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A








  0%|          | 0/100000 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A







  1%|          | 673/100000 [00:00<00:14, 6721.32it/s][A[A[A[A[A[A[A[A








  1%|          | 664/100000 [00:00<00:14, 6636.34it/s][A[A[A[A[A[A[A[A[A







  1%|▏         | 1346/100000 [00:00<00:14, 6711.56it/s][A[A[A[A[A[A[A[A








  1%|▏         | 1332/100000 [00:00<00:14, 6656.49it/s][A[A[A[A[A[A[A[A[A







  2%|▏         | 2018/100000 [00:00<00:14, 6673.99it/s][A[A[A[A[A[A[A[A








  2%|▏         | 2000/100000 [00:00<00:14, 6664.82it/s][A[A[A[A[A[A[A[A[A







  3%|▎         | 2688/100000 [00:00<00:14, 6681.35it/s][A[A[A[A[A[A[A[A








  3%|▎         | 2674/100000 [00:00<00:14, 6693.29it/s][A[A[A[A[A[A[A[A[A







  3%|▎         | 3357/100000 [00:00<00:14, 6674.48it/s][A[A[A[A[A[A[A[A








  3%|▎         | 3344/100000 [00:00<00:14,

#words: 2048 - BOS ID: 1 - EOS ID: 2
#words: 2048 - BOS ID: 1 - EOS ID: 2


 12%|█▏        | 11958/100000 [00:01<00:13, 6601.68it/s]
 12%|█▏        | 12181/100000 [00:01<00:13, 6670.76it/s]


KeyboardInterrupt: 

In [None]:
# Demo training configuration: the Config object plus its to_dict() snapshot.
from demo_training_config import Config

config = Config()
config_dict = config.to_dict()

# Frequently used paths from the configuration
pretok_bin_dir, model_out_dir = config.pretok_bin_dir, config.model_out_dir
...

In [None]:
from model import Transformer, ModelArgs
from tokenizer import BatchProcessor
import dataclasses
import math
import os
import time
from contextlib import nullcontext
from functools import partial
import torch

# Outline of the training script. Pieces that the original outline elided with "..." are
# marked TODO and must be filled in before this cell can train a model.

# Mixed precision: bfloat16 autocast on CUDA, plain float32 on CPU.
# Assumes config.device is a device string such as "cpu", "cuda" or "cuda:0".
device_type = "cuda" if "cuda" in str(config.device) else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
# bfloat16 has float32's exponent range, so no loss scaling is needed. A disabled GradScaler
# is a pass-through that keeps the float16-ready scale/step/update call pattern below.
scaler = torch.cuda.amp.GradScaler(enabled=False)

# Create batches using partial function
iter_batches = partial(
    BatchProcessor.iter_batches,
    batch_size=config.batch_size,
    device=config.device,
    # TODO: remaining iter_batches arguments (elided in the original outline)
)

# Initialize model and optimizer. Every ModelArgs field that Config also defines is passed
# through, so fields such as vocab_size or max_seq_len are not silently left at defaults.
# NOTE(review): assumes model.ModelArgs is a dataclass, like the notebook's ModelArgs.
model_args = ModelArgs(**{
    field.name: getattr(config, field.name)
    for field in dataclasses.fields(ModelArgs)
    if hasattr(config, field.name)
})
model = Transformer(model_args)
model.to(config.device)
optimizer = model.configure_optimizers(
    config.weight_decay,
    config.learning_rate,
    # TODO: remaining configure_optimizers arguments (elided in the original outline)
)


@torch.no_grad()
def estimate_loss(eval_iters: int = 100):
    """Estimate the mean loss on the train and val splits.

    Averages model.last_loss over `eval_iters` batches per split in eval mode (dropout off),
    and always switches the model back to train mode, even if evaluation raises.

    Returns:
        (train_loss, val_loss) as Python floats.
    """
    model.eval()
    try:
        split_means = []
        for split in ("train", "val"):  # assumes iter_batches accepts split="val" — TODO confirm
            batch_iter = iter_batches(split=split)
            total = 0.0
            for _ in range(eval_iters):
                xb, yb = next(batch_iter)
                with ctx:
                    model(xb, yb)  # forward with targets stores the loss in model.last_loss
                total += model.last_loss.item()
            split_means.append(total / eval_iters)
    finally:
        model.train()
    train_loss, val_loss = split_means
    return train_loss, val_loss


# Training loop
iter_num = 0
train_batch_iter = iter_batches(split="train")
X, Y = next(train_batch_iter)  # the loop body expects a batch to be ready on entry
while iter_num <= config.max_iters:
    # TODO: adjust the learning rate for this iteration (schedule elided in the original outline)
    if iter_num % config.eval_interval == 0:
        train_loss, val_loss = estimate_loss()
        print(f"step {iter_num}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
        # TODO: save checkpoints (elided in the original outline)

    # Forward and backward passes with gradient accumulation. Dividing by the number of
    # micro-steps makes the accumulated gradient an average rather than a sum.
    for _micro_step in range(config.gradient_accumulation_steps):
        with ctx:
            model(X, Y)
            loss = model.last_loss / config.gradient_accumulation_steps
        # Fetch the next batch before backward (original ordering), presumably so data
        # loading can overlap the backward pass.
        X, Y = next(train_batch_iter)
        scaler.scale(loss).backward()

    # Optimizer step and update
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    iter_num += 1