# Pre-training Model Notebook

This notebook migrates the code from pre_train.py, maintaining the original logic and structure completely unchanged.

## Third-party Library Dependencies

In [1]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import tokenizers
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import random
import gc
import json

from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any

## Project Internal Dependencies

The following code is migrated from utils.py, dataset.py, and models.py files

In [2]:
# TextGenerator from utils.py
import time

class TextGenerator:
    def __init__(
        self,
        model: nn.Module,
        tokenizer: tokenizers.Tokenizer,
        device,
        padding_side="right",
    ) -> None:
        self.tokenizer = tokenizer
        self.device = device
        if isinstance(model, nn.DataParallel):
            print("DataParallel is used")
            self.model = model.module
        else:
            self.model = model
        self.seq_max_len = self.model.args.seq_max_len
        self.padding_side = padding_side
        self.tokenizer.enable_padding(direction=padding_side, length=self.seq_max_len)
        self.tokenizer.enable_truncation(max_length=self.seq_max_len, direction=padding_side)

    def generate(
    self,
    start_token: str,
    gen_seq_len=30,
    temperature=0.7,
    frequency_penalty=0.1,
    top_k=20,
    print_out=True,
    ):
        with torch.no_grad():
            self.model.eval()
            tokens = [start_token]
            all_token_ids = self.tokenizer.encode(start_token).ids

            for i in range(gen_seq_len):
                all_token_ids = self.tokenizer.encode(''.join(tokens)).ids
                input_tensor = torch.tensor(all_token_ids).int().unsqueeze(0).to(self.device)
                out = self.model(input_tensor)
                
                if self.padding_side == "right":
                    logits = out[0, len(tokens)-1, :]  
                elif self.padding_side == "left":
                    logits = out[0, -1, :]
                else:
                    raise ValueError("padding_side must be 'right' or 'left'")
                
                if frequency_penalty != 0:
                    tokens_tensor = torch.tensor(all_token_ids, device=self.device)
                    unique, counts = torch.unique(tokens_tensor, return_counts=True)
                    
                    penalty = torch.zeros_like(logits)
                    penalty[unique] = counts.float() * frequency_penalty
                    logits = logits - penalty
                
                if top_k is not None and top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = -float('Inf')
                
                probabilities = F.softmax(logits / temperature, dim=-1)
                next_token_id = probabilities.multinomial(num_samples=1).item()
                
                tokens.append(self.tokenizer.decode([next_token_id], skip_special_tokens=False))
                if print_out:
                    print(tokens[-1], end=" ", flush=True)
            
            return tokens


class DebugTimer:
    def __init__(self, name=None):
        self.start_time = None
        self.name = name

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            self.timer_start(self.name)
            result = func(*args, **kwargs)
            self.timer_stop()
            return result

        return wrapper

    def timer_start(self, name=None):
        self.start_time = time.perf_counter()
        if name is not None:
            self.name = name
        print(f"{self.name}:", end="")

    def timer_stop(self):
        elapsed_time = round(time.perf_counter() - self.start_time, 4)
        print(f"{elapsed_time}s")

def _format_string(s, length, fill_char=" "):
    return s.ljust(length, fill_char)

def model_structure(model):
    print("-" * 90)
    print(
        "|"
        + _format_string("weight name", 31)
        + "|"
        + _format_string("weight shape", 42)
        + "|"
        + _format_string("number", 13)
        + "|"
    )
    print("-" * 90)

    total_params = 0
    type_size = 1

    for key, param in model.named_parameters():
        formatted_key = _format_string(key, 30)
        shape_str = _format_string(str(param.shape), 40)
        param_count = param.numel()
        formatted_count = _format_string(str(param_count), 10)

        print(f"| {formatted_key} | {shape_str} | {formatted_count} |")
        total_params += param_count

    print("-" * 90)
    print(f"The total number of parameters: {total_params}")
    print(
        f"The parameters of Model {model._get_name()}: {total_params * type_size / 1e6:.4f}M"
    )
    print("-" * 90)
    return total_params


class WarmUpCosineLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, total_epochs, warmup_epochs, min_lr=0, last_epoch=-1):
        self.last_epoch = last_epoch
        self.total_epochs = total_epochs
        self.warmup_epochs = warmup_epochs
        self.min_lr = min_lr
        super(WarmUpCosineLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            return [
                base_lr * (self.last_epoch + 1) / self.warmup_epochs
                for base_lr in self.base_lrs
            ]
        else:
            current_epoch = self.last_epoch - self.warmup_epochs
            total_cosine_epochs = self.total_epochs - self.warmup_epochs
            return [
                self.min_lr
                + (base_lr - self.min_lr)
                * (
                    1
                    + torch.cos(
                        torch.tensor(current_epoch / total_cosine_epochs * torch.pi)
                    )
                )
                / 2
                for base_lr in self.base_lrs
            ]

In [3]:
# Dataset classes from dataset.py
import collections
import sys
import json
import random

class StreamingTextDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str,
        tokenizer: tokenizers.Tokenizer,
        seq_max_len: int = 192,
        downsample: int = 1,
        batch: bool = None,
        re_tokenize: bool = False,
        padding_side: str = "right",
    ):
        super().__init__()
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.seq_max_len = seq_max_len
        self.re_tokenize = re_tokenize
        self.padding_side = padding_side

        self.line_offsets = []
        self._build_line_index(downsample)

    def _build_line_index(self, downsample: int):
        with open(self.data_dir, "rb") as f:
            offset = 0
            line_count = 0
            while True:
                self.line_offsets.append(offset)
                line = f.readline()
                if not line:
                    break

                offset += len(line)
                line_count += 1
            self.line_offsets = random.sample(
                self.line_offsets, k=int(len(self.line_offsets) * downsample)
            )

    def pad_seq(
        self,
        seq: list[int],
        max_len: int,
        truncation=True,
        padding_value=0,
        padding_side="left",
    ):
        if truncation:
            if padding_side == "right":
                seq = seq[:max_len]
            elif padding_side == "left":
                seq = seq[-max_len:]
            else:
                raise ValueError("padding_side must be 'left' or 'right'")

        if len(seq) < max_len:
            if padding_side == "left":
                seq = [padding_value] * (max_len - len(seq)) + seq
            elif padding_side == "right":
                seq = seq + [padding_value] * (max_len - len(seq))
            else:
                raise ValueError("padding_side must be 'left' or 'right'")

        return seq

    def __len__(self):
        return len(self.line_offsets)

    def __getitem__(self, index):
        with open(self.data_dir, "r", encoding="utf-8") as f:
            f.seek(self.line_offsets[index])
            line = f.readline().strip()

        if self.re_tokenize:
            raw = self.tokenizer.encode(line).ids
        else:
            raw = self.tokenizer.encode(line.split(" "), is_pretokenized=True).ids

        raw = self.pad_seq(
            raw,
            max_len=self.seq_max_len,
            truncation=True,
            padding_value=0,
            padding_side=self.padding_side,
        )

        raw_tensor = torch.tensor(raw, dtype=torch.long)

        return (raw_tensor[:-1].contiguous(), raw_tensor[1:].contiguous())

class Vocab:
    def __init__(self, word_counts, specials=["<PAD>", "<UNK>"]):
        self.stoi = {}
        self.itos = []
        for special in specials:
            self.stoi[special] = len(self.stoi)

        for word, _ in word_counts:
            self.stoi.setdefault(word, len(self.stoi))

        self.itos = list(self.stoi.keys())
        self.default_index = self.stoi["<UNK>"]

    def __call__(self, word):
        return self.stoi.get(word, self.default_index)

    def __len__(self):
        return len(self.stoi)

    def build_from_dict(self, word_dict):
        self.stoi = word_dict
        self.itos = list(self.stoi.keys())
        self.default_index = self.stoi["<UNK>"]

In [4]:
# Model classes from models.py
import math

@dataclass
class MyLMArgs:
    d_model: int
    d_inner: int
    n_layers: int
    vocab_size: int
    seq_max_len: int
    use_moe: bool = False
    n_heads: int = None
    n_experts: int = 4
    n_experts_per_tok: int = 2
    d_conv: int = 3
    conv_bias: bool = True
    ffn_bias: bool = False
    attn_bias: bool = False
    d_head: int = 64
    dropout: float = 0.1
    init_std: float = 0.25
    resid_pdrop: float = 0.1
    resid_scale: float = 1.0
    layer_scale: float = 1.0
    use_deepnet_scaling: bool = True


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self._reset_parameters()
    
    def _reset_parameters(self, init_std=0.02):
        with torch.no_grad():
            self.weight.fill_(1.0)
    
    def forward(self, x):
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x.to(input_dtype)

class CausalConv1d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(CausalConv1d, self).__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=self.pad,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input):
        return self.conv(input)[:, :, : -self.pad]

class GPT2PositionEmbedding(nn.Module):
    def __init__(self, seq_max_len, d_model):
        super().__init__()
        self.pos_emb = nn.Embedding(seq_max_len, d_model)
        self._reset_parameters()

    def _reset_parameters(self, init_std=0.02):
        nn.init.normal_(self.pos_emb.weight, std=init_std)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        assert (
            seq_len <= self.pos_emb.num_embeddings
        ), f"Sequence length {seq_len} exceeds max length {self.pos_emb.num_embeddings}"
        pos = torch.arange(seq_len).to(x.device)
        pos_emb = self.pos_emb(pos)
        pos_emb = pos_emb.unsqueeze(0)
        return x + pos_emb

class MyPositionEmbedding(nn.Module):
    def __init__(self, d_model, d_out, d_inner=None):
        super().__init__()
        if d_inner is None:
            d_inner = d_model // 16
        self.d_inner = d_inner
        self.d_model = d_model
        self.gru = nn.GRU(d_inner, d_inner, bias=False, batch_first=True)
        self.pos_conv = nn.Conv1d(d_model, d_inner, 1)
        self.proj_conv = nn.Conv1d(d_model, d_model - d_inner, 1)
        self.linear = nn.Linear(d_model, d_out, bias=False)

    def _reset_parameters(self, init_std=0.02):
        torch.nn.init.normal_(self.linear.weight, std=init_std)
        if hasattr(self, 'pos_conv') and self.pos_conv.weight is not None:
            torch.nn.init.normal_(self.pos_conv.weight, std=init_std)
        if hasattr(self, 'proj_conv') and self.proj_conv.weight is not None:
            torch.nn.init.normal_(self.proj_conv.weight, std=init_std)
        if hasattr(self, 'gru'):
            for name, param in self.gru.named_parameters():
                if 'weight' in name:
                    torch.nn.init.normal_(param, std=init_std)
                elif 'bias' in name:
                    torch.nn.init.zeros_(param)

    def forward(self, x):
        res = x
        x = x.transpose(1, 2)
        x_proj = self.proj_conv(x).transpose(1, 2)
        pos = self.pos_conv(x).transpose(1, 2)
        pos, _ = self.gru(pos)
        x = self.linear(torch.cat([pos, x_proj], dim=-1))

        return x + res

class ALiBi(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        slopes = torch.Tensor(self._get_slopes(num_heads))
        self.register_buffer("slopes", slopes)

    def _get_slopes(self, n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : n - closest_power_of_2
                ]
            )

    def forward(self, seq_len, batch_size, device):
        context_position = torch.arange(seq_len, device=device)[:, None]
        memory_position = torch.arange(seq_len, device=device)[None, :]
        relative_position = torch.abs(
            context_position - memory_position
        )

        bias = relative_position[None, ...] * self.slopes[:, None, None]
        bias = -bias

        bias = bias.repeat(batch_size, 1, 1)
        return bias

class Attention(nn.Module):
    def __init__(self, args: MyLMArgs):
        super().__init__()
        self.d_model = args.d_model
        self.n_heads = args.n_heads or (args.d_model // args.d_head)
        self.d_head = args.d_head
        self.seq_max_len = args.seq_max_len

        self.q_proj = nn.Linear(args.d_model, args.d_model)
        self.k_proj = nn.Linear(args.d_model, args.d_model)
        self.v_proj = nn.Linear(args.d_model, args.d_model)
        self.o_proj = nn.Linear(args.d_model, args.d_model)

        self.register_buffer("cos_cached", torch.zeros(1, 1, args.seq_max_len, args.d_head))
        self.register_buffer("sin_cached", torch.zeros(1, 1, args.seq_max_len, args.d_head))

        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        
        self._init_rope()
        self._reset_parameters()

    def _init_rope(self):
        d_head_half = self.d_head // 2
        inv_freq = 1.0 / (
            100 ** (torch.arange(0, d_head_half, dtype=torch.float) / d_head_half)
        )

        t = torch.arange(self.seq_max_len, dtype=torch.float)
        freqs = torch.einsum("i,j->ij", t, inv_freq)

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0))
        self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0))

    def _reset_parameters(self, init_std=0.02):
        torch.nn.init.normal_(self.q_proj.weight, std=init_std)
        torch.nn.init.normal_(self.k_proj.weight, std=init_std)
        torch.nn.init.normal_(self.v_proj.weight, std=init_std)
        torch.nn.init.normal_(self.o_proj.weight, std=init_std)

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary_pos_emb(
        self, q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ):
        cos = cos[:, :, : q.size(2), :]
        sin = sin[:, :, : q.size(2), :]

        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self, x: torch.Tensor, token_ids=None) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(
            1, 2
        )
        k = k.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(
            1, 2
        )
        v = v.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(
            1, 2
        )

        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]
        q, k = self._apply_rotary_pos_emb(q, k, cos, sin)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_head))
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=att.device)).view(1, 1, seq_len, seq_len)
        att = att.masked_fill(causal_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v
        y = (
            y.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        )

        y = self.resid_dropout(self.o_proj(y))
        return y


class FFN(nn.Module):
    def __init__(self, args: MyLMArgs):
        super().__init__()
        self.args = args
        self.gate_proj = nn.Linear(args.d_model, args.d_inner, bias=False)
        self.up_proj = nn.Linear(args.d_model, args.d_inner, bias=False)
        self.down_proj = nn.Linear(args.d_inner, args.d_model, bias=False)
        self.act_fn = nn.SiLU()
        self._reset_parameters()

    def _reset_parameters(self, init_std=0.02):
        print('on call _reset_parameters')
        torch.nn.init.normal_(self.gate_proj.weight, std=init_std)
        torch.nn.init.normal_(self.up_proj.weight, std=init_std)
        torch.nn.init.normal_(self.down_proj.weight, std=init_std)

    def forward(self, x: torch.Tensor, token_ids=None) -> torch.Tensor:
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

class MoEFFN(nn.Module):
    def __init__(self, args: MyLMArgs):
        super().__init__()
        self.args = args
        self.router = nn.Linear(args.d_model, args.n_experts)
        self.experts = nn.ModuleList([FFN(args) for _ in range(args.n_experts)])

    def _reset_parameters(self, init_std=0.02, ffn_scale=1.0, resid_scale=1.0, n_layers=12):
        torch.nn.init.normal_(self.router.weight, std=init_std * 0.1)
        if self.router.bias is not None:
            nn.init.zeros_(self.router.bias)
        for expert in self.experts:
            expert._reset_parameters(init_std, ffn_scale, resid_scale, n_layers)

    def forward(self, x, token_ids=None):
        probs = F.softmax(self.router(x), dim=-1)
        top_k_probs, top_k_indices = torch.topk(
            probs, self.args.n_experts_per_tok, dim=-1
        )
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        expert_inputs = x.view(
            -1, x.shape[-1]
        )
        expert_inputs = expert_inputs.repeat_interleave(
            self.args.n_experts_per_tok, dim=0
        )
        flat_top_k_idx = top_k_indices.view(-1)
        expert_outputs = torch.zeros_like(expert_inputs)

        for expert_idx, expert in enumerate(self.experts):
            mask = flat_top_k_idx == expert_idx
            if mask.any():
                expert_outputs[mask] = expert(expert_inputs[mask], token_ids=token_ids)
        expert_outputs = (
            expert_outputs.view(*top_k_probs.shape, -1) * top_k_probs.unsqueeze(-1)
        ).sum(
            dim=2
        )

        return expert_outputs

class MyLMDecoderLayer(nn.Module):
    def __init__(self, args: MyLMArgs):
        super().__init__()
        self.args = args
        self.attn = Attention(args)
        self.mlp = MoEFFN(args) if args.use_moe else FFN(args)
        self.input_layernorm = RMSNorm(args.d_model)
        self.post_attention_layernorm = RMSNorm(args.d_model)
        

    def forward(self, x: torch.Tensor, token_ids=None) -> torch.Tensor:
        residual = x
        x = self.input_layernorm(x)
        x = self.attn(x, token_ids=token_ids)
        x = residual + x
        

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x, token_ids=token_ids)
        x = residual + x

        return x


class MyLM(nn.Module):
    def __init__(self, args: MyLMArgs):
        super().__init__()
        self.args = args

        self.token_embedding = nn.Embedding(args.vocab_size, args.d_model)

        self.blocks = nn.ModuleList(
            [MyLMDecoderLayer(args) for _ in range(args.n_layers)]
        )

        self.norm = RMSNorm(args.d_model)
        self.head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self._reset_parameters()

    def _reset_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor, token_ids=None) -> torch.Tensor:
        x = self.token_embedding(x)

        for block in self.blocks:
            x = block(x, token_ids=token_ids)

        x = self.norm(x)
        logits = self.head(x)

        return logits

## Pre-training Main Code

The following is the core training code migrated from pre_train.py

In [5]:
t = DebugTimer()

@dataclass
class TrainingConfig:
    data_dir: str = r"data_large_ChatML.txt"
    tokenizer_dir: str = r"bpe_tokenizer_6k_0724_ChatML.json"
    model_save_dir: str = r"model\model_state.pth"
    ckpt_save_dir: str = r"ckpt\ckpt.pth"
    config_save_dir: str = r"config.json"
    log_dir: str = r"logs"
    padding_side = "left"

    seed: int = 42
    epochs: int = 5
    batch_size: int = 32
    batch_acceleration: int = 4
    dataset_downsample: int = 0.008
    valset_rate: float = 0.01
    val_interval_step: int = 100
    seq_max_len=192

    learning_rate: float = 5e-3
    min_learning_rate: float = 5e-4
    warmup_steps: int = 1
    use_amp: bool = False

    model_args = MyLMArgs(
        d_model=256,
        d_inner=int(((256 * (8 / 3)) // 64) * 64),
        d_head=64,
        n_heads=None,
        n_layers=1,
        vocab_size=None,
        seq_max_len=seq_max_len,
        use_moe=False,
        n_experts=None,
        n_experts_per_tok=None,
        d_conv = None,
        conv_bias = None,
        ffn_bias = False,
        attn_bias = True,
        dropout = 0.1
    )

    ckpt_interval_step: int = 1000
    resume_from: Optional[str] = None

class PreTrainer:
    def __init__(self, config: TrainingConfig):
        self.config = config
        self._set_seed()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizers.Tokenizer.from_file(config.tokenizer_dir)
        self.config.model_args.vocab_size = len(self.tokenizer.get_vocab())
        self.train_loader, self.val_loader = self._build_dataloader()
        self.model = self._build_model().to(self.device)
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.optimizer, self.scheduler = self._build_optimizer()
        self.scaler = torch.GradScaler(self.device, enabled=config.use_amp)
        self.generator = TextGenerator(
            self.model, self.tokenizer, self.device, padding_side=config.padding_side
        )

        self.current_epoch = 0
        self.global_step = 0
        self.start_epoch = 0
        self.current_step = 0
        self.start_step = 0
        self.train_loss_log = []
        self.val_loss_log = []
        self.lr_log = []

        if config.resume_from is not None:
            self.load_checkpoint(config.resume_from)

    def _set_seed(self):
        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        torch.cuda.manual_seed(self.config.seed)
        torch.cuda.manual_seed_all(self.config.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

    def _build_model(self):
        model = MyLM(self.config.model_args)
        return model

    def _build_dataloader(self):
        dataset = StreamingTextDataset(
            self.config.data_dir,
            downsample=self.config.dataset_downsample,
            seq_max_len=self.config.seq_max_len,
            tokenizer=self.tokenizer,
            re_tokenize=False,
            batch=False,
            padding_side=self.config.padding_side,
        )
        val_dataset_len = int(len(dataset) * self.config.valset_rate)
        train_dataset_len = len(dataset) - val_dataset_len
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_dataset_len, val_dataset_len]
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=4
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=1
        )

        return train_loader, val_loader

    def _build_optimizer(self):
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            amsgrad=False,
            betas=(0.85, 0.99),
            eps=1e-6,
            weight_decay=0.005,
        )

        scheduler = WarmUpCosineLR(
            optimizer,
            total_epochs=(
                self.config.epochs
                * (len(self.train_loader) // self.config.batch_acceleration + 1)
            )
            + 1,
            warmup_epochs=self.config.warmup_steps,
            min_lr=self.config.min_learning_rate,
        )
        return optimizer, scheduler

    def save_checkpoint(self, path: str, is_final: bool = False):
        if isinstance(self.model, nn.DataParallel):
            model_state_dict = self.model.module.state_dict()
        else:
            model_state_dict = self.model.state_dict()
        state = {
            "epoch": self.current_epoch,
            "global_step": self.global_step,
            "current_step": self.current_step,
            "model_state_dict": model_state_dict,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "train_loss": self.train_loss_log[-1] if self.train_loss_log else None,
            "rng_states": {
                "torch": torch.get_rng_state(),
                "cuda": (
                    torch.cuda.get_rng_state_all()
                    if torch.cuda.is_available()
                    else None
                ),
                "random": random.getstate(),
                "numpy": np.random.get_state(),
            },
        }
        torch.save(state, path)
        if is_final:
            torch.save(model_state_dict, self.config.model_save_dir)

    def load_checkpoint(self, checkpoint_path: str):
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, weights_only=False)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.start_epoch = checkpoint["epoch"]
        self.start_step = checkpoint["current_step"]

        rng_states = checkpoint["rng_states"]
        torch.set_rng_state(rng_states["torch"])
        if rng_states["cuda"] and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(rng_states["cuda"])
        random.setstate(rng_states["random"])
        np.random.set_state(rng_states["numpy"])

    def _train_step(self, inputs, targets):
        self.model.train()
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        with torch.autocast(str(self.device), enabled=self.config.use_amp):
            output = self.model(inputs)
            loss = self.criterion(
                output.view(-1, self.config.model_args.vocab_size), targets.view(-1)
            )

        loss = loss / self.config.batch_acceleration

        self.scaler.scale(loss).backward()

        if ((self.current_step + 1) % self.config.batch_acceleration == 0) or (
            self.current_step + 1 == len(self.train_loader)
        ):
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)
            self.scheduler.step()

        return loss.item() * self.config.batch_acceleration

    def log(self):
        total_params = model_structure(self.model)
        print(f"Training parameters:")
        print(f"Vocabulary size: {self.tokenizer.get_vocab_size()}")
        val_dataset_len, train_dataset_len = len(self.val_loader.dataset), len(
            self.train_loader.dataset
        )
        print(f"Context length: {self.config.model_args.seq_max_len}")
        print(f"Dataset size: {val_dataset_len+train_dataset_len}")
        print(f"Training set size: {train_dataset_len}")
        print(f"Validation set size: {val_dataset_len}")
        nums_token = self.config.model_args.seq_max_len * train_dataset_len
        print(f"Tokens: {nums_token/1e6:.3f}M")
        print(f"Model parameters: {total_params/1e6:.3f}M")
        print(
            f"Computational complexity: {(nums_token * total_params * 6)/1e12:.2f}TFLOPs * {self.config.epochs} = {(nums_token * total_params * 6 * self.config.epochs)/1e12:.2f}TFLOPs"
        )

    def train(self):
        gc.collect()
        writer = SummaryWriter(log_dir=self.config.log_dir)
        print("~~~Training started~~~")

        if torch.cuda.device_count() > 1:
            print(f"Multi-GPU training: {torch.cuda.device_count()} GPUs")
            self.model = nn.DataParallel(self.model)
        val_loss = 0
        for epoch in range(self.config.epochs):
            bar = tqdm(self.train_loader, unit="step")
            if epoch < self.start_epoch:
                print(f"Skipping already trained epoch: {epoch}")
                continue
            elif epoch == self.start_epoch:
                bar.update(self.start_step)

            self.current_epoch = epoch
            train_loss_sum = 0

            torch.cuda.empty_cache()

            for i, (train_inputs, train_targets) in enumerate(self.train_loader):
                if i <= self.start_step and epoch == self.start_epoch:
                    continue
                self.current_step = i
                loss = self._train_step(train_inputs, train_targets)

                train_loss_sum += loss
                self.train_loss_log.append((self.global_step, loss))
                writer.add_scalar("Loss/train", loss, self.global_step)
                self.lr_log.append(
                    (self.global_step, float(self.scheduler.get_last_lr()[0]))
                )
                self.global_step += 1
                writer.add_scalar(
                    "LearningRate",
                    float(self.scheduler.get_last_lr()[0]),
                    self.global_step,
                )

                if i % self.config.val_interval_step == 0:
                    val_loss = self.validate()
                    self.generate_test("AI is")
                    self.val_loss_log.append((self.global_step, val_loss))
                    writer.add_scalar("Loss/val", val_loss, self.global_step)

                bar.update(1)
                bar.postfix = f"train_loss: {loss:.2f} test_loss: {val_loss:.2f} lr: {self.scheduler.get_last_lr()[0]:.2e}"

                if self.global_step % self.config.ckpt_interval_step == 0:
                    val_loss = self.validate()
                    self.val_loss_log.append((self.global_step, val_loss))
                    ckpt_path = f"{self.config.ckpt_save_dir.rsplit('.', 1)[0]}_epoch_{self.current_epoch}_step_{self.global_step}.pth"
                    self.save_checkpoint(ckpt_path, is_final=False)

            bar.close()

            val_loss = self.validate()
            self.val_loss_log.append((self.global_step, val_loss))

            self.generate_test()

            self.save_checkpoint(
                f"{self.config.ckpt_save_dir.rsplit('.', 1)[0]}_epoch_{epoch}.pth",
                is_final=True,
            )

            test_text = self.generate_test(gen_len=10)
            writer.add_text(
                "GeneratedText", f"epoch_{epoch}: {test_text}", self.global_step
            )

            print(f"Learning rate {self.scheduler.get_last_lr()}")
            print(
                f"Epoch: {epoch+1}/{self.config.epochs}, avg_train_loss: {train_loss_sum/len(self.train_loader)}, avg_test_loss: {val_loss}"
            )

        self.save_checkpoint(self.config.model_save_dir, is_final=True)

    def validate(self) -> float:
        self.model.eval()
        val_loss_sum = 0
        with torch.no_grad():
            for val_inputs, val_targets in self.val_loader:
                val_inputs = val_inputs.to(self.device)
                val_targets = val_targets.to(self.device)
                with torch.autocast(str(self.device), enabled=self.config.use_amp):
                    val_output = self.model(val_inputs)
                    loss = self.criterion(
                        val_output.view(-1, self.config.model_args.vocab_size),
                        val_targets.view(-1),
                    )
                val_loss_sum += loss.item()
        return val_loss_sum / len(self.val_loader)

    def generate_test(self, start: str = "I", gen_len: int = 25):
        self.model.eval()
        ans = self.generator.generate(
            start_token=start, gen_seq_len=gen_len, print_out=False
        )
        ans = ans[len(start) :]
        result = "".join(ans)
        print(f"(input){start}-> {result}")
        return result

    def plot_losses(self):
        fig, ax1 = plt.subplots(figsize=(16, 10))

        train_steps = [step for step, loss in self.train_loss_log]
        train_losses = [loss for step, loss in self.train_loss_log]

        val_steps = [step for step, loss in self.val_loss_log]
        val_losses = [loss for step, loss in self.val_loss_log]

        ax1.plot(train_steps, train_losses, label="Train Loss")
        ax1.plot(val_steps, val_losses, "o-", label="Test Loss")
        ax1.set_xlabel("Steps")
        ax1.set_ylabel("Loss")
        ax1.legend(loc="upper left")

        ax2 = ax1.twinx()
        ax2.plot(
            [step for step, _ in self.lr_log],
            [float(value) for _, value in self.lr_log],
            label="Learning Rate",
            color="c",
            linestyle="--",
        )
        ax2.set_ylabel("Learning Rate")
        ax2.tick_params(axis="y")
        ax2.legend(loc="upper right")

        plt.title("Train and Test Loss Curves with Learning Rate")
        plt.show()

# The main execution code is commented out to avoid running when importing in notebook
# if __name__ == "__main__":
#     config = TrainingConfig()
#     config_dict = asdict(config.model_args)
#     with open(config.config_save_dir, "w") as f:
#         json.dump(config_dict, f, indent=4)
#     trainer = PreTrainer(config)
#     trainer.log()
#     trainer.train()
#     trainer.plot_losses()
#
#     MAX_LEN = 10
#     T = 0.8
#     while True:
#         start = input("In>>")
#         if start[:2] == "T=":
#             T = float(start[2:])
#             print(f"T={T}")
#         else:
#             print(
#                 f"T={T}\n"
#                 + "".join(
#                     trainer.generator.generate(
#                         start_token=start,
#                         gen_seq_len=MAX_LEN,
#                         temperature=T,
#                         frequency_penalty=10,
#                         print_out=False,
#                     )
#                 )
#             )

## Test Code

The following code is used to test if the notebook works properly

In [6]:
# Test the configuration creation
config = TrainingConfig()
print("Configuration created successfully")
print(f"Model parameters: d_model={config.model_args.d_model}, n_layers={config.model_args.n_layers}")

# Print config dictionary
config_dict = asdict(config.model_args)
print("Model arguments dictionary:", config_dict)

# Print model structure
model = MyLM(config.model_args)
print("Model created successfully")

Configuration created successfully
Model parameters: d_model=256, n_layers=1
Model arguments dictionary: {'d_model': 256, 'd_inner': 640, 'n_layers': 1, 'vocab_size': None, 'seq_max_len': 192, 'use_moe': False, 'n_heads': None, 'n_experts': None, 'n_experts_per_tok': None, 'd_conv': None, 'conv_bias': None, 'ffn_bias': False, 'attn_bias': True, 'd_head': 64, 'dropout': 0.1, 'init_std': 0.25, 'resid_pdrop': 0.1, 'resid_scale': 1.0, 'layer_scale': 1.0, 'use_deepnet_scaling': True}


TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
