In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

In [2]:
class DyT(nn.Module):
    """Dynamic Tanh: an element-wise stand-in for a normalization layer.

    Computes ``gamma * tanh(alpha * x) + beta`` without any mean/variance
    statistics. ``alpha`` is a single learnable scalar initialised to
    ``init_alpha``; ``gamma`` (ones) and ``beta`` (zeros) are learnable vectors
    of length ``dims`` that broadcast over the trailing dimension of ``x``.

    Extra keyword arguments are accepted and ignored, so callers can forward a
    shared ``**kwargs`` dict without filtering it first.
    """

    def __init__(self, dims, init_alpha=0.5, **kwargs):
        super().__init__()
        # Registration order (alpha, beta, gamma) determines the state_dict key order.
        self.alpha = nn.Parameter(init_alpha * torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(dims))
        self.gamma = nn.Parameter(torch.ones(dims))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.gamma + self.beta

In [7]:
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``W3( swish(W1 x) * (W2 x) )`` with bias-free projections.

    Args:
        dim: input/output feature size.
        scale: hidden width as a fraction of ``dim``; the hidden size is
            ``int(dim * scale)`` (truncated), exposed as ``expansion_dim``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dim, scale=0.66, **kwargs):
        super().__init__()
        hidden = int(dim * scale)
        self.expansion_dim = hidden
        # Creation order l1 -> l2 -> l3 is kept so random weight init is reproducible.
        self.l1 = nn.Linear(dim, hidden, bias=False)   # gate branch (goes through swish)
        self.l2 = nn.Linear(dim, hidden, bias=False)   # value branch
        self.l3 = nn.Linear(hidden, dim, bias=False)   # projection back to dim

    def forward(self, x):
        gate = self.l1(x)
        value = self.l2(x)
        # swish(g) = g * sigmoid(g)
        return self.l3(gate * torch.sigmoid(gate) * value)

In [10]:
# Layer shapes / parameter counts for one SwiGLU on a (512, 768) input; torchsummary prepends the batch dim (-1)
summary(SwiGLU(dim=768), (512, 768), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1             [-1, 512, 506]         388,608
            Linear-2             [-1, 512, 506]         388,608
            Linear-3             [-1, 512, 768]         388,608
Total params: 1,165,824
Trainable params: 1,165,824
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.50
Forward/backward pass size (MB): 6.95
Params size (MB): 4.45
Estimated Total Size (MB): 12.90
----------------------------------------------------------------


In [52]:
class Encoder_Block(nn.Module): ## prenorm support
    """One transformer encoder layer (self-attention + ReLU feed-forward) using DyT as the norm.

    Two residual layouts are supported:

    * ``prenorm=True``:  ``x = x + Drop(MHA(N1(x)))``;  ``x = x + Drop(FF(N2(x)))``.
      The norm only feeds each sub-layer; the residual path carries the
      un-normalized ``x``.
    * ``prenorm=False`` (postnorm, original BERT):
      ``x = N1(x + Drop(MHA(x)))``;  ``x = N2(x + Drop(FF(x)))``.

    Args:
        dim: model width (attention embed dim and FF input/output size).
        num_heads: attention heads; nn.MultiheadAttention requires ``dim % num_heads == 0``.
        d_mha: dropout probability passed to nn.MultiheadAttention.
        d_ff: dropout at the end of the feed-forward stack.
        d_res: dropout applied to each sub-layer output before the residual add.
        prenorm: selects the residual layout described above.
        **kwargs: forwarded to both DyT norms (e.g. ``init_alpha``).
    """

    def __init__(self, dim=768, num_heads=12, d_mha=0.1, d_ff=0.1, d_res=0.1, prenorm=True, **kwargs):
        super().__init__()
        ff_hidden_dim = dim * 4
        self.prenorm = prenorm  ## BERT uses postnorm

        self.mha = nn.MultiheadAttention(dim, num_heads, d_mha, batch_first=True)
        self.norm1 = DyT(dim, **kwargs)

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, dim),
            nn.Dropout(d_ff)
        )

        self.norm2 = DyT(dim, **kwargs)
        self.dropout = nn.Dropout(d_res)

    def forward(self, x, pad_mask=None):
        """Apply the block.

        Args:
            x: presumably (batch, seq, dim), since the attention uses batch_first=True.
            pad_mask: passed to nn.MultiheadAttention as ``key_padding_mask``. For a
                boolean mask, True marks a key position to IGNORE (padding).
        """
        if self.prenorm:
            # Normalize a copy for the sub-layer only. Overwriting x here would put the
            # normalized tensor on the skip connection, which is not pre-norm.
            h = self.norm1(x)
            attn_output, _ = self.mha(h, h, h, key_padding_mask=pad_mask)
            x = x + self.dropout(attn_output)
            x = x + self.dropout(self.ff(self.norm2(x)))
            return x

        # postnorm: residual add first, then normalize the sum
        attn_output, _ = self.mha(x, x, x, key_padding_mask=pad_mask)
        x = self.norm1(x + self.dropout(attn_output))
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x

In [42]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` Encoder_Block layers followed by an optional final DyT.

    Prenorm blocks end with a residual add and no norm, so a final DyT is applied
    to the stack output. Postnorm blocks already end with ``norm2``, so the final
    norm is an identity there; applying another one would double-normalize.

    Args:
        num_layers: number of encoder blocks.
        dim: model width.
        prenorm: residual layout passed to every block (and selects the final norm).
        **kwargs: forwarded to each Encoder_Block and to the final DyT
            (e.g. ``num_heads``, dropouts, ``init_alpha``).
    """

    def __init__(self, num_layers=12, dim=768, prenorm=True, **kwargs):
        super().__init__()

        self.layers = nn.ModuleList([
            Encoder_Block(dim, prenorm=prenorm, **kwargs) for _ in range(num_layers)
        ])

        # Same DyT options (e.g. init_alpha) as the per-block norms; skipped for postnorm.
        self.norm = DyT(dim, **kwargs) if prenorm else nn.Identity()
        self.prenorm = prenorm

    def forward(self, x, pad_mask=None):
        """Run x through every block (sharing pad_mask), then the final norm."""
        for block in self.layers:
            x = block(x, pad_mask)
        return self.norm(x)

In [43]:
# torchsummary accepts device="cuda" or "cpu"; fall back to CPU when no GPU is available
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(Encoder(num_layers=2).to(summary_device), (512, 768), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
               DyT-1             [-1, 512, 768]               0
MultiheadAttention-2  [[-1, 512, 768], [-1, 512, 512]]               0
           Dropout-3             [-1, 512, 768]               0
               DyT-4             [-1, 512, 768]               0
            Linear-5            [-1, 512, 3072]       2,362,368
              ReLU-6            [-1, 512, 3072]               0
            Linear-7             [-1, 512, 768]       2,360,064
           Dropout-8             [-1, 512, 768]               0
           Dropout-9             [-1, 512, 768]               0
    Encoder_Block-10             [-1, 512, 768]               0
              DyT-11             [-1, 512, 768]               0
MultiheadAttention-12  [[-1, 512, 768], [-1, 512, 512]]               0
          Dropout-13             [-1, 512, 768]               0
              DyT-14    

In [48]:
class MyBERT(nn.Module):
    """BERT-style encoder: token + segment + learned positional embeddings fed to ``Encoder``.

    Args:
        vocab_size: number of token ids in the embedding table.
        seq_len: maximum supported sequence length (size of the positional table).
        dim: model width.
        **kwargs: forwarded to ``Encoder`` (e.g. ``num_layers``, ``prenorm``, ``num_heads``).
    """

    def __init__(self, vocab_size, seq_len=512, dim=768, **kwargs):
        super().__init__()
        self.encoder = Encoder(dim=dim, **kwargs)
        self.token_embeddings = nn.Embedding(vocab_size, dim)
        self.segment_embeddings = nn.Embedding(2, dim)  # two segments (sentence A / B)
        self.positional_embeddings = nn.Embedding(seq_len, dim)
        # (1, max_seq_len) row of position indices; a buffer so it moves with .to(device)
        self.register_buffer("position_ids", torch.arange(seq_len).unsqueeze(0))

    def forward(self, x, segment=None, pad_mask=None):
        """Embed token ids ``x`` of shape (batch, seq) and run the encoder.

        Args:
            x: (batch, seq) integer token ids, with seq <= the configured max ``seq_len``.
            segment: (batch, seq) segment ids in {0, 1}; defaults to all zeros.
            pad_mask: forwarded to the encoder / nn.MultiheadAttention ``key_padding_mask``
                (boolean: True marks a padded position to ignore).

        Raises:
            ValueError: if the input is longer than the positional table.
        """
        batch_size, seq_len = x.shape
        max_len = self.position_ids.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds the maximum of {max_len}")

        # Slice before expand: expand() cannot shrink the size-max_len dim, so inputs
        # shorter than max_len would otherwise raise.
        position_ids = self.position_ids[:, :seq_len].expand(batch_size, -1)
        if segment is None:
            segment = torch.zeros_like(x)

        x = self.token_embeddings(x) + self.segment_embeddings(segment) + self.positional_embeddings(position_ids)

        x = self.encoder(x, pad_mask)

        return x

In [54]:
# Small 4-layer model with a toy 1000-token vocabulary for the smoke test below
bert = MyBERT(vocab_size=1000, dim=768, num_layers=4)

In [55]:
# Smoke test: a random batch should come back as (batch, seq_len, dim) with finite values
torch.manual_seed(0)

batch_size = 2
seq_len = 512
vocab_size = bert.token_embeddings.num_embeddings  # keep token ids inside the model's vocabulary

# Random input IDs in [0, vocab_size)
x = torch.randint(0, vocab_size, (batch_size, seq_len))

# Dummy segment IDs (all 0s = single segment for now)
segment = torch.zeros_like(x, dtype=torch.long)

# Padding mask in nn.MultiheadAttention key_padding_mask convention:
# True = padded position (ignored), False = real token. All False = no padding.
# (An all-True mask hides every key, and the fully-masked attention rows become NaN.)
pad_mask = torch.zeros_like(x, dtype=torch.bool)

# Forward pass
output = bert(x, segment, pad_mask)
assert output.shape == (batch_size, seq_len, bert.token_embeddings.embedding_dim)
assert torch.isfinite(output).all(), "forward pass produced NaN/Inf"
print(output.shape)  # torch.Size([2, 512, 768])

torch.Size([2, 512, 768])


## Pretraining work

In [None]:
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

"""
Script that processes the Project Gutenberg files into fewer larger files.
"""

import argparse
import os
import re
from tqdm import tqdm
from gutenberg.src.cleanup import strip_headers


def is_english(text, threshold=0.9):
    """Heuristically decide whether ``text`` is mostly English from its ASCII share.

    Args:
        text: the document to classify.
        threshold: minimum fraction of ASCII characters (code point < 128);
            the share must be strictly greater than this.

    Returns:
        True if more than ``threshold`` of the characters are ASCII. Empty text
        returns False (nothing to classify) instead of raising ZeroDivisionError.
    """
    if not text:
        return False
    ascii_chars = sum(1 for c in text if ord(c) < 128)
    return ascii_chars / len(text) > threshold


def _read_text(file_path, fallback_encoding):
    """Return the contents of ``file_path`` as UTF-8, retrying with ``fallback_encoding`` on decode errors."""
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()
    except UnicodeDecodeError:
        tqdm.write(f"Warning: UnicodeDecodeError encountered. Trying fallback encoding for {file_path}")
        with open(file_path, "r", encoding=fallback_encoding) as file:
            return file.read()


def _write_chunk(target_dir, index, parts, separator):
    """Write ``parts`` joined by ``separator`` to ``target_dir/combined_<index>.txt`` (UTF-8)."""
    target_file_path = os.path.join(target_dir, f"combined_{index}.txt")
    with open(target_file_path, "w", encoding="utf-8") as target_file:
        target_file.write(separator.join(parts))


def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"):
    """Concatenate English text files into chunk files of roughly ``max_size_mb`` each.

    Processing steps for each file:

    * Read it as UTF-8, falling back to ``fallback_encoding`` on decode errors.
    * Skip it if it is empty or not mostly ASCII per ``is_english``.
    * Pass it through ``strip_headers``, which presumably removes the Gutenberg
      boilerplate.
    * Collapse runs of blank lines into a single blank line.

    Documents are joined with ``separator`` and written to
    ``target_dir/combined_<n>.txt``. A new file is started when appending the next
    document (plus its separator) would push the current one past the limit. A
    single document larger than the limit gets its own file; it is not split.

    Returns:
        The number of combined files written (0 if no file qualified).
    """
    os.makedirs(target_dir, exist_ok=True)

    max_bytes = max_size_mb * 1024 * 1024
    separator_bytes = len(separator.encode("utf-8"))
    current_content = []
    current_size = 0
    files_written = 0

    for file_path in tqdm(file_paths):
        content = _read_text(file_path, fallback_encoding)

        if not content:
            tqdm.write(f"Skipping {file_path} as it is empty.")
            continue
        if not is_english(content):
            tqdm.write(f"Skipping {file_path} as it does not contain primarily English text.")
            continue
        content = strip_headers(content)

        # Regular expression to replace multiple blank lines with a single blank line
        content = re.sub(r'\n\s*\n', '\n\n', content)
        doc_size = len(content.encode("utf-8"))
        # The separator only costs bytes when it joins this doc to a previous one.
        join_size = separator_bytes if current_content else 0

        # Flush only a non-empty chunk; otherwise an oversized first document
        # would produce an empty combined file.
        if current_content and current_size + join_size + doc_size > max_bytes:
            files_written += 1
            _write_chunk(target_dir, files_written, current_content, separator)
            current_content = []
            current_size = 0
            join_size = 0

        current_content.append(content)
        current_size += join_size + doc_size

    if current_content:
        files_written += 1
        _write_chunk(target_dir, files_written, current_content, separator)
    return files_written


# if __name__ == "__main__":

#     parser = argparse.ArgumentParser(description="Preprocess and combine text files for pretraining")

#     parser.add_argument("--data_dir", type=str, default="gutenberg/data/raw",
#                         help="Directory containing the downloaded raw training data")
#     parser.add_argument("--max_size_mb", type=int, default=500,
#                         help="The maximum file size for each concatenated file in megabytes")
#     parser.add_argument("--output_dir", type=str, default="gutenberg_preprocessed",
#                         help="Directory where the preprocessed data will be saved")

#     args = parser.parse_args()

#     all_files = [os.path.join(path, name) for path, subdirs, files in os.walk(args.data_dir)
#                  for name in files if name.endswith((".txt", ".txt.utf8"))]

#     print(f"{len(all_files)} file(s) to process.")
#     file_counter = combine_files(all_files, args.output_dir, max_size_mb=args.max_size_mb)
#     print(f"{file_counter} file(s) saved in {os.path.abspath(args.output_dir)}")