# 中文模型预训练

## Step1 导包

In [1]:
from transformers import TrainingArguments , Trainer
from tokenizers import  ByteLevelBPETokenizer
from torch import nn
from torch.nn import functional as F
import torch
from datasets import load_dataset,load_from_disk
from tqdm.auto import tqdm
import os, re
from torch.utils.data import Dataset , DataLoader
import hashlib
import random
import datasets
from transformers import AutoTokenizer


## Step2 导入数据集

In [None]:
# Load the THUCNews .txt files through the "text" builder, requesting only the "train" split.
# The glob pattern, cache directory and worker count are passed straight to `load_dataset`.
thu = load_dataset(
    "text",
    data_files=r"G:\Anaconda\Kaggle\gpt2\data\THUCNews1\THUCNews1\**\*.txt",
    split="train",
    cache_dir= r'G:\Anaconda\Kaggle\gpt2\data',
    num_proc = 8
)

In [None]:
# Rebind `thu` to whatever `load_dataset` finds under the local `text` directory (no split requested).
# NOTE(review): this replaces the `thu` built by the previous cell — confirm only one of the two cells is meant to run.
thu = load_dataset(r'G:\Anaconda\Kaggle\gpt2\data\text')

In [None]:
import datasets
from tqdm.auto import tqdm

def iter_wiki_pages(
    xml_file,
    min_text_length=200,
    sample_ratio=0.3,
    max_keep=None,
    max_seen=None,
    skip_non_main=True,
    seed=42,
):
    """Stream ``{'title': ..., 'text': ...}`` records from a MediaWiki XML dump.

    The file is parsed incrementally with ``iterparse``; ``lxml`` is used when it
    can be imported, otherwise the standard-library ``xml.etree.ElementTree``.

    Args:
        xml_file: Path (or file object) handed to ``iterparse``.
        min_text_length: Pages whose text is shorter than this are dropped.
        sample_ratio: Probability of keeping an eligible page; ``None`` keeps all.
        max_keep: Stop after yielding this many pages (``None`` = no limit).
        max_seen: Stop after scanning this many ``<page>`` elements (``None`` = no limit).
        skip_non_main: When true, pages whose title contains ``':'`` are skipped.
        seed: Seed of the private ``random.Random`` used for sampling.

    Yields:
        dict with the page ``title`` and raw ``text`` (missing values become ``''``).
    """
    try:
        from lxml import etree as ET
        use_lxml = True
    except Exception:
        import xml.etree.ElementTree as ET
        use_lxml = False

    import random
    rnd = random.Random(seed)

    context = ET.iterparse(xml_file, events=('start', 'end'))
    _, root = next(context)
    # The root tag carries the XML namespace as "{uri}tag"; reuse it for lookups.
    ns = ''
    if '}' in root.tag:
        ns = root.tag.split('}')[0] + '}'

    kept = 0
    seen = 0
    pbar = tqdm(desc='解析页面', unit='page')
    # try/finally guarantees the bar is closed even when the consumer stops
    # iterating early (generator close) or parsing raises.
    try:
        for event, elem in context:
            if event != 'end' or elem.tag != f'{ns}page':
                continue
            if max_seen is not None and seen >= max_seen:
                break
            seen += 1

            title_elem = elem.find(f'{ns}title')
            text_elem = elem.find(f'{ns}revision/{ns}text')
            title = (title_elem.text or '') if title_elem is not None else ''
            text = (text_elem.text or '') if text_elem is not None else ''

            if skip_non_main and ':' in title:
                pass
            elif len(text) >= min_text_length:
                if sample_ratio is None or rnd.random() < sample_ratio:
                    yield {'title': title, 'text': text}
                    kept += 1
                    if max_keep is not None and kept >= max_keep:
                        break

            # Release already-processed elements so memory stays bounded.
            if use_lxml:
                elem.clear()
                while elem.getprevious() is not None:
                    del elem.getparent()[0]
            else:
                root.clear()
            pbar.update(1)
    finally:
        pbar.close()

def load_wiki_xml_as_dataset(xml_file, **kwargs):
    """Build a ``datasets.Dataset`` from the records produced by ``iter_wiki_pages``.

    Args:
        xml_file: Dump path forwarded to ``iter_wiki_pages`` as ``xml_file``.
        **kwargs: Remaining keyword arguments forwarded unchanged to ``iter_wiki_pages``.

    Returns:
        A dataset with two string columns, ``title`` and ``text``.
    """
    schema = datasets.Features(
        {column: datasets.Value('string') for column in ('title', 'text')}
    )
    generator_args = {'xml_file': xml_file}
    generator_args.update(kwargs)
    return datasets.Dataset.from_generator(
        iter_wiki_pages,
        gen_kwargs=generator_args,
        features=schema,
    )

# Location of the zhwiki pages-articles dump to parse.
wiki_xml_path = r"G:\Anaconda\Kaggle\gpt2\data\zhwiki-20251120-pages-articles-multistream.xml\zhwiki-20251120-pages-articles-multistream.xml"

# Build the sampled wiki dataset: pages with at least 200 characters, kept with probability 0.3.
wiki_ds_30 = load_wiki_xml_as_dataset(
    wiki_xml_path,
    min_text_length=200,
    sample_ratio=0.3,
    max_keep=None,    # e.g. 300000 to put an upper bound on the number of kept pages
    max_seen=None,    # e.g. 1000000 to scan only the first one million pages
    skip_non_main=True,
    seed=42,
)


In [None]:
# Persist the sampled wiki dataset to disk.
save_dir_arrow = r"G:\Anaconda\Kaggle\gpt2\data\wiki_ds"
wiki_ds_30.save_to_disk(save_dir_arrow)

In [None]:
# Persist the THUCNews dataset to disk; note that `save_dir_arrow` is reused from the previous cell.
save_dir_arrow = r"G:\Anaconda\Kaggle\gpt2\data\thu_ds"
thu.save_to_disk(save_dir_arrow)

In [None]:
# Display the repr of `thu` as the cell output.
thu

In [None]:
# Display the repr of `wiki_ds_30` as the cell output.
wiki_ds_30

In [None]:
# Show the first row of the "train" split.
# NOTE(review): this indexing presumably needs `thu` to be a mapping of splits; the cell that
# loads with split="train" may not provide one — confirm which load cell must have run.
thu['train'][0]

In [None]:
# Show the first record of the sampled wiki dataset.
wiki_ds_30[0]

In [None]:
from datasets import load_from_disk


# Reload the THUCNews dataset from the directory written by the save cell above.
thu = load_from_disk(r"G:\Anaconda\Kaggle\gpt2\data\thu_ds")

In [None]:
# Reload the first-pass cleaned wiki dataset (the directory a later cell saves `cleaned_ds` to).
cleaned_ds = load_from_disk(r"G:\Anaconda\Kaggle\gpt2\data\wiki")

In [None]:
def clean_batch(batch):
    """Strip MediaWiki markup from every string in ``batch['text']``.

    Designed as a batched ``Dataset.map`` callback: it receives a dict of
    columns and returns a dict with the cleaned ``text`` column.

    Category and file/image links are removed *before* generic internal links
    are unwrapped; otherwise the generic ``[[target|label]]`` rule consumes
    them first and leaves fragments such as ``Category:X`` or ``thumb|caption``.

    Args:
        batch: Mapping with a ``'text'`` key holding a list of strings.

    Returns:
        ``{'text': [...]}`` with one cleaned string per input string.
    """
    # Local import: presumably keeps the function self-contained when `map`
    # runs it in worker processes.
    import re

    comment = re.compile(r'<!--.*?-->', re.DOTALL)
    ref_tag = re.compile(r'<ref[^>]*>.*?</ref>', re.DOTALL)
    html_tag = re.compile(r'<[^>]+>')
    template = re.compile(r'\{\{[^{}]*\}\}')
    table = re.compile(r'\{\|[\s\S]*?\|\}', re.DOTALL)
    external_link = re.compile(r'\[https?:\/\/[^\s\]]+(?:\s+([^\]]+))?\]')
    internal_link = re.compile(r'\[\[([^|\]]+)(?:\|([^\]]+))?\]\]')
    category = re.compile(r'\[\[(?:Category|分类):[^\]]+\]\]', re.IGNORECASE)
    filelink = re.compile(r'\[\[(?:File|Image|文件|图像):[^\]]+\]\]', re.IGNORECASE)
    bolditalic = re.compile(r"'''''(.*?)'''''", re.DOTALL)
    bold = re.compile(r"'''(.*?)'''", re.DOTALL)
    italic = re.compile(r"''(.*?)''", re.DOTALL)
    list_marks = re.compile(r'^[*#;:]+\s*', re.MULTILINE)
    heading = re.compile(r'(==+)\s*(.*?)\s*\1')
    spaces = re.compile(r'[ \t]+')
    trailing_spaces = re.compile(r' +\n')
    blanklines = re.compile(r'\n{3,}')

    def clean_text(text):
        text = comment.sub('', text)
        text = ref_tag.sub('', text)
        text = table.sub('', text)
        # Templates nest; peel them from the inside out, at most 6 levels deep.
        for _ in range(6):
            new = template.sub('', text)
            if new == text:
                break
            text = new
        text = external_link.sub(lambda m: m.group(1) or '', text)
        # Must precede internal_link, which would otherwise unwrap these links.
        text = category.sub('', text)
        text = filelink.sub('', text)
        text = internal_link.sub(lambda m: (m.group(2) or m.group(1)), text)
        text = bolditalic.sub(r'\1', text)
        text = bold.sub(r'\1', text)
        text = italic.sub(r'\1', text)
        text = html_tag.sub('', text)
        text = list_marks.sub('', text)
        text = heading.sub(lambda m: m.group(2), text)
        text = spaces.sub(' ', text)
        text = trailing_spaces.sub('\n', text)
        text = blanklines.sub('\n\n', text)
        return text.strip()

    return {'text': [clean_text(t) for t in batch['text']]}

# Keep only the `text` column, then run the first cleaning pass in batches of 1000
# using half of the available CPU cores (at least one worker).
wiki_ds_30 = wiki_ds_30.select_columns(['text'])
cleaned_ds = wiki_ds_30.map(
    clean_batch,
    batched=True,
    batch_size=1000,
    num_proc=max(1, os.cpu_count() // 2),
    desc='清洗文本'
)
# Last expression: show the resulting dataset summary.
cleaned_ds

In [None]:
# Persist the first-pass cleaned wiki dataset.
save_dir_arrow = r"G:\Anaconda\Kaggle\gpt2\data\wiki"
cleaned_ds.save_to_disk(save_dir_arrow)

In [None]:
# Spot-check the second record of the cleaned dataset.
cleaned_ds[1]

In [None]:
import os


def _convert_batch(batch):
    """Run every string in ``batch['text']`` through an OpenCC ``'t2s'`` converter.

    A new converter is created on each call; the import is local to the function.

    Args:
        batch: Mapping with a ``'text'`` key holding a list of strings.

    Returns:
        ``{'text': [...]}`` with one converted string per input string.
    """
    from opencc import OpenCC

    converter = OpenCC('t2s')
    converted = []
    for item in batch['text']:
        converted.append(converter.convert(item))
    return {'text': converted}

# Apply `_convert_batch` to the cleaned dataset in batches of 1000,
# using half of the available CPU cores (at least one worker).
cleaned_ds_simp = cleaned_ds.map(
    _convert_batch,
    batched=True,
    batch_size=1000,
    num_proc=max(1, os.cpu_count() // 2),
    desc='繁转简'
)


In [None]:
# Persist the converted dataset.
cleaned_ds_simp.save_to_disk(r"G:\Anaconda\Kaggle\gpt2\data\wiki_simple")

In [None]:
# Spot-check the second record after conversion.
cleaned_ds_simp[1]

In [None]:
def strip_batch(batch, _state={'rx': None, 'banned': None, 'en_heading': None}):
    """Second cleaning pass: re-strip wiki markup, then drop noisy lines.

    For each string in ``batch['text']`` the function
      1. removes wiki markup (comments, refs, tables, templates, links, emphasis,
         HTML tags, list markers, heading markers) and parenthesised spans that
         are at least 40% ASCII letters;
      2. walks the result line by line, stopping at the first line that is exactly
         a banned section heading, and skipping category lines, lines mentioning
         ISBN/ISSN, and lines that are at least 40% ASCII letters.

    Category and file/image links are removed before generic internal links are
    unwrapped, so they are deleted rather than reduced to their inner text.

    Args:
        batch: Mapping with a ``'text'`` key holding a list of strings.
        _state: Intentional mutable default used as a per-process cache for the
            compiled patterns and heading sets; callers should not pass it.

    Returns:
        ``{'text': [...]}`` with one cleaned string per input string.
    """
    import re
    if _state['rx'] is None:
        _state['banned'] = {
            '参考文献','参考资料','延伸阅读','外部链接','外部连结','参见','注解',
            '参考','相关条目','外部资源','扩展阅读','入门','专题介绍','选集','参考著作'
        }
        _state['en_heading'] = {'References','External links','See also','Further reading','Overview','Introduction'}
        # Assigned last: 'rx' is the flag that marks the cache as initialised.
        _state['rx'] = {
            "comment": re.compile(r'<!--.*?-->', re.DOTALL),
            "ref": re.compile(r'<ref[^>]*>.*?</ref>', re.DOTALL),
            "html": re.compile(r'<[^>]+>'),
            "template": re.compile(r'\{\{[^{}]*\}\}'),
            "table": re.compile(r'\{\|[\s\S]*?\|\}', re.DOTALL),
            "external": re.compile(r'\[https?:\/\/[^\s\]]+(?:\s+([^\]]+))?\]'),
            "internal": re.compile(r'\[\[([^|\]]+)(?:\|([^\]]+))?\]\]'),
            "category": re.compile(r'\[\[(?:Category|分类):[^\]]+\]\]', re.IGNORECASE),
            "filelink": re.compile(r'\[\[(?:File|Image|文件|图像):[^\]]+\]\]', re.IGNORECASE),
            "bolditalic": re.compile(r"'''''(.*?)'''''", re.DOTALL),
            "bold": re.compile(r"'''(.*?)'''", re.DOTALL),
            "italic": re.compile(r"''(.*?)''", re.DOTALL),
            "listmarks": re.compile(r'^[*#;:]+\s*', re.MULTILINE),
            "heading": re.compile(r'(==+)\s*(.*?)\s*\1'),
            "spaces": re.compile(r'[ \t]+'),
            "blanklines": re.compile(r'\n{3,}'),
        }

    rx = _state['rx']
    banned = _state['banned']
    en_heading = _state['en_heading']

    def ascii_ratio(s):
        # Share of ASCII letters among all characters (0 for an empty string).
        a = sum(1 for ch in s if ('A' <= ch <= 'Z') or ('a' <= ch <= 'z'))
        return a / max(1, len(s))

    def drop_if_ascii(m):
        # Remove a parenthesised span whose content is mostly ASCII letters.
        return '' if ascii_ratio(m.group(1)) >= 0.4 else m.group(0)

    def remove_ascii_parens(text):
        text = re.sub(r'（([^）]+)）', drop_if_ascii, text)
        text = re.sub(r'\(([^\)]+)\)', drop_if_ascii, text)
        return text

    def clean_text(text):
        text = rx["comment"].sub('', text)
        text = rx["ref"].sub('', text)
        text = rx["table"].sub('', text)
        # Templates nest; peel them from the inside out, at most 6 levels deep.
        for _ in range(6):
            new = rx["template"].sub('', text)
            if new == text:
                break
            text = new
        # Drop closing braces left behind by unbalanced templates.
        text = re.sub(r'\}\}+', '', text)
        text = rx["external"].sub(lambda m: m.group(1) or '', text)
        # Must precede the generic internal-link rule, which would unwrap these.
        text = rx["category"].sub('', text)
        text = rx["filelink"].sub('', text)
        text = rx["internal"].sub(lambda m: (m.group(2) or m.group(1)), text)
        text = rx["bolditalic"].sub(r'\1', text)
        text = rx["bold"].sub(r'\1', text)
        text = rx["italic"].sub(r'\1', text)
        text = rx["html"].sub('', text)
        text = rx["listmarks"].sub('', text)
        text = rx["heading"].sub(lambda m: m.group(2), text)
        text = rx["spaces"].sub(' ', text)
        text = re.sub(r' +\n', '\n', text)
        text = rx["blanklines"].sub('\n\n', text)
        text = remove_ascii_parens(text)
        return text.strip()

    def strip_noise(text):
        out = []
        for line in text.splitlines():
            s = line.strip()
            if s == '':
                out.append('')
                continue
            if s in banned or s in en_heading:
                # Everything from a banned heading onward is discarded.
                break
            if s.startswith('Category:') or s.startswith('分类:'):
                continue
            if 'ISBN' in s or 'ISSN' in s:
                continue
            if ascii_ratio(s) >= 0.4:
                continue
            out.append(line)
        res = '\n'.join(out)
        return rx["blanklines"].sub('\n\n', res).strip()

    return {'text': [strip_noise(clean_text(t)) for t in batch['text']]}
# Run the second cleaning pass (`strip_batch`) over the converted dataset,
# in batches of 1000 on half of the available CPU cores (at least one worker).
cleaned_ds_final = cleaned_ds_simp.map(
    strip_batch,
    batched=True,
    batch_size=1000,
    num_proc=max(1, os.cpu_count() // 2),
    desc='二次清洗'
)

In [None]:
# Spot-check the first fully cleaned record.
cleaned_ds_final[0]

In [None]:
# Persist the final wiki dataset; later cells read it back from this directory.
cleaned_ds_final.save_to_disk(r"G:\Anaconda\Kaggle\gpt2\data\wiki_dataset")

### 如果已经处理并保存了可以直接读取数据

In [None]:
# Shortcut: reload both prepared datasets from disk instead of re-running the cells above.
thu = load_from_disk(r"G:\Anaconda\Kaggle\gpt2\data\thu_ds")
wiki_ds = load_from_disk(r'G:\Anaconda\Kaggle\gpt2\data\wiki_dataset')

## Step3 训练tokenizer

In [None]:
from tqdm.auto import tqdm
from tokenizers import ByteLevelBPETokenizer
from transformers import GPT2TokenizerFast

def iter_ds_text_with_progress(
    ds,
    sample_ratio=0.3,
    seed=42,
    max_examples=None,
    text_column="text",
    min_len=50,
    desc="遍历样本"
):
    """Yield sampled paragraphs from ``ds`` while reporting progress.

    Each element of ``ds`` may be a plain string or a dict; for dicts the text is
    read from ``text_column``. Elements with empty text are skipped. A sampled
    element is split on newlines and every stripped paragraph of at least
    ``min_len`` characters is yielded.

    Args:
        ds: Iterable of strings or dicts; ``len(ds)`` is used for the bar total
            when available.
        sample_ratio: Probability of using an element; ``None`` uses all of them.
        seed: Seed of the private ``random.Random`` used for sampling.
        max_examples: Stop after yielding this many paragraphs (falsy = no limit).
        text_column: Dict key that holds the text.
        min_len: Minimum paragraph length in characters.
        desc: Progress-bar description.

    Yields:
        str: one paragraph at a time.
    """
    import random
    rnd = random.Random(seed)
    try:
        total = len(ds)
    except Exception:
        total = None
    pbar = tqdm(total=total, desc=desc, unit="样本")
    count = 0
    # try/finally closes the bar on normal exhaustion, on reaching max_examples,
    # and when the consumer abandons or closes the generator early.
    try:
        for ex in ds:
            if total is not None:
                pbar.update(1)
            if isinstance(ex, str):
                t = ex
            else:
                t = ex.get(text_column, "") if isinstance(ex, dict) else ""
            if not t:
                continue
            if sample_ratio is None or rnd.random() < sample_ratio:
                for p in t.split("\n"):
                    p = p.strip()
                    if len(p) >= min_len:
                        yield p
                        count += 1
                        if max_examples and count >= max_examples:
                            return
    finally:
        pbar.close()

# Paragraph stream for the wiki corpus: 30% of the records, capped at 1,000,000 paragraphs.
# This is a generator object, so it can be consumed only once per cell run.
wiki_iter = iter_ds_text_with_progress(
    wiki_ds,
    sample_ratio=0.3,
    seed=42,
    max_examples=1_000_000,
    text_column="text",
    min_len=50,
    desc="Wiki"
)

# Paragraph stream for THUCNews: 20% of the records, capped at 500,000 paragraphs.
# NOTE(review): if `thu` is a mapping of splits rather than a single dataset, iterating it
# may yield split names instead of rows — confirm the type of `thu` here.
thu_iter = iter_ds_text_with_progress(
    thu,
    sample_ratio=0.2,
    seed=42,
    max_examples=500_000,
    text_column="text",
    min_len=50,
    desc="THUCNews"
)

def corpus_iter_with_tqdm():
    """Yield every item of the global ``wiki_iter`` followed by every item of ``thu_iter``.

    A progress bar counts the yielded items. Both globals are read when their
    loop starts; because they are one-shot iterators, a second call yields
    nothing unless they are rebuilt first.

    Yields:
        The items of ``wiki_iter`` and then ``thu_iter``, unchanged.
    """
    pbar = tqdm(desc="训练语料段数", unit="段")
    # try/finally closes the bar even when the consumer stops early.
    try:
        for t in wiki_iter:
            pbar.update(1)
            yield t
        for t in thu_iter:
            pbar.update(1)
            yield t
    finally:
        pbar.close()

# Train a case-preserving byte-level BPE tokenizer on the combined paragraph stream.
# The four special tokens are listed first in `special_tokens`.
tokenizer = ByteLevelBPETokenizer(lowercase=False)
tokenizer.train_from_iterator(
    corpus_iter_with_tqdm(),
    vocab_size=50000,
    min_frequency=2,
    special_tokens=["<|pad|>", "<|bos|>", "<|eos|>", "<|unk|>"]
)

In [None]:
# Write the trained tokenizer model files into `out_dir` (created if missing).
out_dir = r"G:\Anaconda\Kaggle\gpt2\tokenizer\bytebpe_zh"
os.makedirs(out_dir, exist_ok=True)
tokenizer.save_model(out_dir)

## Step4 使用tokenizer处理数据集

In [None]:
from data_build import build_pack_from_arrow_buckets

# Tokenise and pack the wiki dataset into files named after `out_prefix`, with a
# 1024-token context and a 1e9-token target.
# NOTE(review): the meaning of `bucket_ratios`, `min_chars` and `max_doc_chars` is defined in
# `data_build`, which is not visible here — confirm there before changing them.
build_pack_from_arrow_buckets(
  arrow_dir=r"G:\Anaconda\Kaggle\gpt2\data\wiki_dataset",
  tokenizer_dir=r"G:\Anaconda\Kaggle\gpt2\tokenizer\bytebpe_zh",
  out_prefix=r"G:\Anaconda\Kaggle\gpt2\data\packed\wiki_1b",
  ctx_len=1024,
  target_tokens=1_000_000_000,
  bucket_ratios=(0.15, 0.35, 0.35, 0.15),
  seed=42,
  min_chars=8,
  max_doc_chars=8000,
  batch_size=64
)

In [None]:
from data_build import build_pack_from_arrow_buckets_streaming

# Same packing step for THUCNews, using the streaming variant of the builder and
# the same context length, token target and bucket settings as the wiki run.
build_pack_from_arrow_buckets_streaming(
  arrow_dir=r"G:\Anaconda\Kaggle\gpt2\data\thu_ds",
  tokenizer_dir=r"G:\Anaconda\Kaggle\gpt2\tokenizer\bytebpe_zh",
  out_prefix=r"G:\Anaconda\Kaggle\gpt2\data\packed\thu_1b",
  ctx_len=1024,
  target_tokens=1_000_000_000,
  bucket_ratios=(0.15, 0.35, 0.35, 0.15),
  seed=42,
  min_chars=8,
  max_doc_chars=8000,
  batch_size=64
)

In [None]:
# NOTE(review): the training cells read `<prefix>.bin` through `np.memmap` in `PackedBinDataset`;
# whether `load_from_disk` can open this `.bin` path is unverified — confirm this cell is still needed.
ds = load_from_disk(r"G:\Anaconda\Kaggle\gpt2\data\packed\wiki_1b.bin")

In [None]:
from data_build import mix_packed_bins

# Combine the THUCNews and wiki packs into a single shuffled pack under `mix_2b`;
# this is the prefix the training cell passes as `bin_prefix`.
mix_packed_bins(
  prefixes=[
    r"G:\Anaconda\Kaggle\gpt2\data\packed\thu_1b",
    r"G:\Anaconda\Kaggle\gpt2\data\packed\wiki_1b",
  ],
  out_prefix=r"G:\Anaconda\Kaggle\gpt2\data\packed\mix_2b",
  shuffle=True,
  seed=42
)

## Step5 训练模型

In [10]:
from model import GPTModel

In [3]:

class LMDataset(Dataset):
    """Map-style dataset exposing the ``input_ids`` column of a dataset stored in ``data_dir``.

    Loading falls back through three strategies, each tried only if the previous raised:
      1. ``datasets.load_from_disk(data_dir)``;
      2. ``datasets.Dataset.from_file`` on the first ``data-*.arrow`` file (sorted by name);
      3. reading that file with ``pyarrow.ipc`` (file format, then stream format) and
         materialising the ``input_ids`` column as a Python list.

    Only the first matching ``data-*.arrow`` file is used by strategies 2 and 3.
    """
    def __init__(self, data_dir):
        import os
        import glob
        import datasets
        # Populated only by the pyarrow fallback; when set, `self.ds` is None.
        self._pylist = None
        try:
            ds = datasets.load_from_disk(data_dir)
            if isinstance(ds, datasets.DatasetDict):
                # Prefer the "train" split, otherwise take the first split present.
                ds = ds.get("train", list(ds.values())[0])
            self.ds = ds
        except Exception:
            pattern = os.path.join(data_dir, "data-*.arrow")
            files = sorted(glob.glob(pattern))
            if not files:
                # Nothing to fall back to: re-raise the load_from_disk error.
                raise
            try:
                ds = datasets.Dataset.from_file(files[0])
                self.ds = ds
            except Exception:
                import pyarrow.ipc as pa_ipc
                try:
                    reader = pa_ipc.open_file(files[0])
                    table = reader.read_all()
                except Exception:
                    reader = pa_ipc.open_stream(files[0])
                    table = reader.read_all()
                col = table.column("input_ids")
                self._pylist = col.to_pylist()
                self.ds = None
    def __len__(self):
        """Number of rows in whichever backing store was loaded."""
        return len(self._pylist) if self._pylist is not None else len(self.ds)
    def __getitem__(self, idx):
        """Return ``{"input_ids": ...}`` for row ``idx``."""
        if self._pylist is not None:
            return {"input_ids": self._pylist[idx]}
        ex = self.ds[idx]
        return {"input_ids": ex["input_ids"]}

def load_tokenizer(tokenizer_dir):
    """Load the tokenizer stored in ``tokenizer_dir`` and make sure it has a pad token.

    ``AutoTokenizer.from_pretrained`` is tried first; if it raises, a
    ``GPT2TokenizerFast`` is built from ``vocab.json`` and ``merges.txt`` in the
    same directory. When the tokenizer has no pad token id, ``"<|pad|>"`` is
    registered as the pad token on a best-effort basis (failures are ignored).

    Args:
        tokenizer_dir: Directory holding the tokenizer files.

    Returns:
        The loaded tokenizer.
    """
    import os
    from transformers import AutoTokenizer, GPT2TokenizerFast

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    except Exception:
        vocab_path = os.path.join(tokenizer_dir, "vocab.json")
        merges_path = os.path.join(tokenizer_dir, "merges.txt")
        tokenizer = GPT2TokenizerFast(vocab_file=vocab_path, merges_file=merges_path)

    if tokenizer.pad_token_id is not None:
        return tokenizer

    try:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    except Exception:
        pass
    return tokenizer

class LMDataCollator:
    """Collate variable-length ``input_ids`` into fixed-length language-model batches.

    Every sequence is truncated or right-padded with ``pad_id`` to
    ``context_length``. Labels are the inputs shifted left by one position; the
    final position and every position whose target is padding are set to -100
    so they are ignored by the loss.
    """

    def __init__(self, context_length, pad_id):
        """Store the target sequence length and the padding token id."""
        self.context_length = context_length
        self.pad_id = pad_id

    def __call__(self, batch):
        """Build ``input_ids``, ``attention_mask`` and ``labels`` tensors.

        Args:
            batch: Sequence of dicts, each with an ``"input_ids"`` entry.

        Returns:
            Dict of three ``(len(batch), context_length)`` int64 tensors.
        """
        ids = []
        for item in batch:
            # as_tensor avoids the copy warning when the item is already a tensor.
            x = torch.as_tensor(item["input_ids"], dtype=torch.long)
            if x.shape[0] >= self.context_length:
                x = x[: self.context_length]
            else:
                pad = torch.full((self.context_length - x.shape[0],), self.pad_id, dtype=torch.long)
                x = torch.cat([x, pad], dim=0)
            ids.append(x)
        input_ids = torch.stack(ids, dim=0)
        attention_mask = (input_ids != self.pad_id).long()
        labels = input_ids.roll(-1, dims=1)
        labels[:, -1] = -100
        # A position must not be trained to predict padding: mask every label
        # whose target token is a pad (same criterion as attention_mask).
        labels = labels.masked_fill(labels == self.pad_id, -100)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

def build_lm_dataloader(data_dir, tokenizer_dir, batch_size, context_length, shuffle=True, num_workers=0):
    """Create a ``DataLoader`` over ``LMDataset(data_dir)`` using ``LMDataCollator``.

    The padding id comes from the tokenizer loaded from ``tokenizer_dir``; 0 is
    used when the tokenizer reports no pad token id.

    Args:
        data_dir: Directory passed to ``LMDataset``.
        tokenizer_dir: Directory passed to ``load_tokenizer``.
        batch_size: Number of sequences per batch.
        context_length: Fixed sequence length produced by the collator.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    tokenizer = load_tokenizer(tokenizer_dir)
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0
    dataset = LMDataset(data_dir)
    collator = LMDataCollator(context_length, pad_id)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collator,
    )

def get_vocab_size(tokenizer_dir):
    """Return the vocabulary size of the tokenizer stored in ``tokenizer_dir``.

    ``len(tokenizer)`` is preferred; if taking the length raises, the
    tokenizer's ``vocab_size`` attribute is used instead.

    Returns:
        int: the vocabulary size.
    """
    tokenizer = load_tokenizer(tokenizer_dir)
    try:
        size = len(tokenizer)
    except Exception:
        size = tokenizer.vocab_size
    return int(size)

def compute_lm_loss(logits, labels):
    """Cross-entropy between flattened ``logits`` and ``labels``, skipping labels equal to -100.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        labels: Integer tensor with one entry per logit row after flattening.

    Returns:
        Scalar loss tensor (mean over the labels that are not -100).
    """
    vocab_dim = logits.size(-1)
    flat_logits = logits.view(-1, vocab_dim)
    flat_labels = labels.view(-1)
    return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

In [4]:
import os, numpy as np, torch
from torch.utils.data import Dataset, DataLoader

class PackedBinDataset(Dataset):
    """Dataset over a packed token file pair ``<prefix>.bin`` / ``<prefix>.idx``.

    The ``.idx`` file is read as consecutive 12-byte little-endian records: a
    signed 8-byte token offset followed by a signed 4-byte token count. The
    ``.bin`` file is memory-mapped read-only as int32 tokens.
    """

    def __init__(self, prefix):
        """Load the index into ``self.entries`` and memory-map the token file."""
        self.bin_path = prefix + ".bin"
        self.idx_path = prefix + ".idx"
        self.entries = []
        with open(self.idx_path, "rb") as index_file:
            # Read fixed-size records until read() returns an empty chunk.
            for record in iter(lambda: index_file.read(12), b""):
                offset = int.from_bytes(record[:8], "little", signed=True)
                length = int.from_bytes(record[8:], "little", signed=True)
                self.entries.append((offset, length))
        # Tokens stay on disk and are paged in on access.
        self.mm = np.memmap(self.bin_path, dtype=np.int32, mode="r")

    def __len__(self):
        """Number of index records."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``{"input_ids": int64 tensor}`` for record ``idx``."""
        start, count = self.entries[idx]
        window = self.mm[start: start + count]
        # astype copies out of the memmap and widens to int64.
        return {"input_ids": torch.from_numpy(window.astype(np.int64))}

def collate_packed(batch):
    """Stack equal-length ``input_ids`` tensors into a language-model batch.

    No padding is applied. The attention mask is all ones; labels are the inputs
    shifted left by one position with the final position set to -100.

    Args:
        batch: Sequence of dicts, each with an ``"input_ids"`` tensor.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    rows = [sample["input_ids"] for sample in batch]
    input_ids = torch.stack(rows, dim=0)
    labels = torch.roll(input_ids, shifts=-1, dims=1)
    labels[:, -1] = -100
    return {
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids, dtype=torch.long),
        "labels": labels,
    }

In [5]:
import os, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from model import GPTModel, compute_lm_loss, get_vocab_size, load_tokenizer

import os, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
import math
from tqdm.auto import tqdm
from model import GPTModel, compute_lm_loss, get_vocab_size, load_tokenizer

class PackedBinDataset(Dataset):
    """Dataset over a packed token file pair ``<prefix>.bin`` / ``<prefix>.idx``.

    The ``.idx`` file holds consecutive 12-byte little-endian records: a signed
    8-byte token offset followed by a signed 4-byte token count. The ``.bin``
    file is memory-mapped read-only as int32 tokens.

    Raises:
        ValueError: if the index file ends with a partial (non 12-byte) record.
    """

    def __init__(self, prefix):
        """Load the index into ``self.entries`` and memory-map the token file."""
        self.bin_path = prefix + ".bin"
        self.idx_path = prefix + ".idx"
        self.entries = []
        with open(self.idx_path, "rb") as f:
            while True:
                rec = f.read(12)
                if not rec:
                    break
                if len(rec) != 12:
                    # A short read can only happen at EOF and means the index is
                    # truncated; parsing it would create a bogus entry.
                    raise ValueError(
                        f"truncated index record in {self.idx_path}: expected 12 bytes, got {len(rec)}"
                    )
                off = int.from_bytes(rec[:8], "little", signed=True)
                ln = int.from_bytes(rec[8:], "little", signed=True)
                self.entries.append((off, ln))
        self.mm = np.memmap(self.bin_path, dtype=np.int32, mode="r")

    def __len__(self):
        """Number of index records."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``{"input_ids": int64 tensor}`` for record ``idx``."""
        off, ln = self.entries[idx]
        arr = self.mm[off: off + ln]
        # astype copies out of the memmap and widens to int64.
        return {"input_ids": torch.from_numpy(arr.astype(np.int64))}

def collate_packed(batch):
    """Turn a list of ``{"input_ids": tensor}`` items into one training batch.

    All sequences must share one length (they are stacked, not padded). The
    returned labels are ``input_ids`` rotated left by one, with the last column
    overwritten by -100; the attention mask is all ones.
    """
    stacked = torch.stack(tuple(entry["input_ids"] for entry in batch), dim=0)
    mask = torch.ones_like(stacked, dtype=torch.long)
    shifted = stacked.roll(-1, dims=1)
    shifted[:, -1].fill_(-100)
    collated = {}
    collated["input_ids"] = stacked
    collated["attention_mask"] = mask
    collated["labels"] = shifted
    return collated

def notebook_train(
    bin_prefix,
    tokenizer_dir,
    ctx_len=1024,
    emb_dim=768, n_heads=12, n_layers=12, dropout=0.1, qkv_bias=True, tie_weights=True,
    batch_size=8, num_workers=0,
    lr=3e-4, weight_decay=0.1, warmup_steps=3000, max_steps=0, epochs=1, grad_accum=1,
    log_every=50, save_every=1000, save_dir=r"G:\Anaconda\Kaggle\gpt2\checkpoints", resume=""
):
    """Train a ``GPTModel`` on a packed token file and return the model.

    Args:
        bin_prefix: Prefix of the ``.bin``/``.idx`` pair read by ``PackedBinDataset``.
        tokenizer_dir: Tokenizer directory, used to obtain the nominal vocab size.
        ctx_len, emb_dim, n_heads, n_layers, dropout, qkv_bias: Model configuration
            values placed in the ``cfg`` dict passed to ``GPTModel``.
        tie_weights: Share the output projection weight with the token embedding.
        batch_size, num_workers: ``DataLoader`` settings.
        lr, weight_decay: AdamW settings; decay is not applied to parameters whose
            name ends with ``bias`` or contains ``ln``/``layernorm``.
        warmup_steps: Linear warm-up length, in optimizer updates.
        max_steps: Total number of batches to process; when falsy,
            ``epochs * len(dataloader)`` is used.
        epochs: Maximum number of passes over the data.
        grad_accum: Number of batches accumulated per optimizer update.
        log_every, save_every: Logging / checkpoint interval in batches (falsy disables).
        save_dir: Directory receiving ``ckpt_step_<n>.pt`` files.
        resume: Optional checkpoint path to continue from.

    Returns:
        The trained model (left in training mode).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocab_nominal = get_vocab_size(tokenizer_dir)

    ds = PackedBinDataset(bin_prefix)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_packed)

    # Probe a single batch so the embedding is at least large enough for the ids
    # seen there. Only this one batch is inspected.
    b = next(iter(dl))
    max_id = int(b["input_ids"].max().item())
    vocab_eff = max(vocab_nominal, max_id + 1)

    cfg = {
        "vocab_size": vocab_eff,
        "context_length": ctx_len,
        "emb_dim": emb_dim,
        "dropout": dropout,
        "n_heads": n_heads,
        "qkv_bias": qkv_bias,
        "n_layers": n_layers,
    }
    model = GPTModel(cfg).to(device)
    if tie_weights:
        model.output_layer.weight = model.embedding_layer.embedding.weight

    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.endswith("bias") or ("ln" in n.lower()) or ("layernorm" in n.lower()):
            no_decay.append(p)
        else:
            decay.append(p)
    optim = AdamW([
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ], lr=lr)

    total_steps = max_steps if (max_steps and max_steps > 0) else (epochs * len(dl))
    # The scheduler advances once per optimizer update, so the decay horizon is
    # expressed in updates rather than in batches.
    sched_total = max(1, math.ceil(total_steps / grad_accum))
    # Updates already performed by a resumed run; set below when resuming.
    sched_offset = 0

    def lr_lambda(step):
        step = step + sched_offset
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (sched_total - step) / max(1, sched_total - warmup_steps))
    scheduler = LambdaLR(optim, lr_lambda)

    try:
        scaler = torch.amp.GradScaler('cuda' if device == 'cuda' else 'cpu')
    except Exception:
        scaler = GradScaler()

    def _autocast_ctx():
        # Prefer the device-aware API; fall back to the name imported by the notebook.
        try:
            return torch.amp.autocast('cuda')
        except Exception:
            return autocast()

    start_step = 0
    if resume:
        ckpt = torch.load(resume, map_location="cpu")
        model.load_state_dict(ckpt.get("model", {}), strict=False)
        if "optimizer" in ckpt:
            optim.load_state_dict(ckpt["optimizer"])
        try:
            scaler.load_state_dict(ckpt.get("scaler", {}))
        except Exception:
            # Best effort: a missing or incompatible scaler state restarts loss scaling.
            pass
        start_step = int(ckpt.get("step", 0))
        # Continue the LR schedule where the checkpoint left off instead of
        # restarting the warm-up from zero.
        sched_offset = start_step // grad_accum

    os.makedirs(save_dir, exist_ok=True)
    model.train()
    step = start_step
    running, count = 0.0, 0

    p_epoch = tqdm(total=total_steps, initial=min(start_step, total_steps), desc="训练", unit="step")
    try:
        for _ in range(epochs):
            # Also stops the epoch loop once the inner loop has hit total_steps.
            if step >= total_steps:
                break
            for batch in dl:
                step += 1

                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                if device == "cuda":
                    with _autocast_ctx():
                        logits = model(input_ids, attention_mask=attention_mask)
                        loss = compute_lm_loss(logits, labels) / grad_accum
                else:
                    logits = model(input_ids, attention_mask=attention_mask)
                    loss = compute_lm_loss(logits, labels) / grad_accum

                scaler.scale(loss).backward()
                if step % grad_accum == 0:
                    # Unscale before clipping so the threshold applies to true gradients.
                    scaler.unscale_(optim)
                    clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optim)
                    scaler.update()
                    optim.zero_grad(set_to_none=True)
                    scheduler.step()

                running += loss.item()
                count += 1
                if log_every and step % log_every == 0:
                    # `loss` was divided by grad_accum; undo that for reporting.
                    avg = running * grad_accum / max(1, count)
                    ppl = math.exp(avg) if avg < 700 else float('inf')
                    lr_now = scheduler.get_last_lr()[0]
                    valid_tokens = int((batch["labels"] != -100).sum().item())
                    p_epoch.set_postfix_str(f"loss {avg:.4f} | ppl {ppl:.2f} | lr {lr_now:.6f} | valid {valid_tokens}")
                    running, count = 0.0, 0

                if save_every and step % save_every == 0:
                    path = os.path.join(save_dir, f"ckpt_step_{step}.pt")
                    _obj = {
                        "model": model.state_dict(),
                        "optimizer": optim.state_dict(),
                        "scaler": scaler.state_dict(),
                        "step": step,
                        "cfg": cfg,
                    }
                    try:
                        torch.save(_obj, path)
                    except RuntimeError:
                        # Retry with the legacy (non-zipfile) serialisation format.
                        torch.save(_obj, path, _use_new_zipfile_serialization=False)

                p_epoch.update(1)
                if step >= total_steps:
                    break
    finally:
        # Close the bar even when training is interrupted.
        p_epoch.close()
    return model

def save_final_model(model, save_dir=r'G:\Anaconda\Kaggle\gpt2\checkpoints', file_name="final_model.pt"):
    """Write the model weights (moved to CPU) to ``save_dir/file_name``.

    The file contains ``{'model': state_dict}``. If ``torch.save`` raises
    ``RuntimeError``, the save is retried with the legacy serialisation format.
    The destination path is printed afterwards.

    Args:
        model: Module whose ``state_dict()`` is saved.
        save_dir: Output directory (created if missing).
        file_name: Name of the checkpoint file.
    """
    os.makedirs(save_dir, exist_ok=True)
    final_ckpt = os.path.join(save_dir, file_name)
    weights = tqdm(model.state_dict().items(), desc="准备最终权重", unit="param")
    cpu_weights = {name: tensor.detach().cpu() for name, tensor in weights}
    try:
        torch.save({'model': cpu_weights}, final_ckpt)
    except RuntimeError:
        torch.save({'model': cpu_weights}, final_ckpt, _use_new_zipfile_serialization=False)
    print("已保存:", final_ckpt)

In [6]:
# Launch training on the mixed pack with batch size 1 and no gradient accumulation,
# then save the returned model's weights as `final_model.pt`.
# `epochs` is left at its default, and `max_steps` is set far above any realistic step count.
model = notebook_train(
  bin_prefix=r"G:\Anaconda\Kaggle\gpt2\data\packed\mix_2b",
  tokenizer_dir=r"G:\Anaconda\Kaggle\gpt2\tokenizer\bytebpe_zh",
  ctx_len=1024,
  emb_dim=768, n_heads=12, n_layers=12, dropout=0.1, qkv_bias=False, tie_weights=True,
  batch_size=1, num_workers=0,
  lr=3e-4, weight_decay=0.1, warmup_steps=3000, max_steps=10000000000, grad_accum=1,
  log_every=50, save_every=1000, save_dir=r"G:\Anaconda\Kaggle\gpt2\checkpoints", resume=""
)
save_final_model(model, save_dir=r'G:\Anaconda\Kaggle\gpt2\checkpoints', file_name="final_model.pt")

  scaler = GradScaler()


训练:   0%|          | 0/10000000000 [00:00<?, ?step/s]

  with autocast():


KeyboardInterrupt: 

In [None]:
# Manual save of the current `model` weights to `final_model.pt`.
# NOTE(review): `model` is only bound once `notebook_train` has returned — confirm it exists
# before running this cell after an interrupted run.
save_dir=r'G:\Anaconda\Kaggle\gpt2\checkpoints'
final_ckpt = os.path.join(save_dir, "final_model.pt")
torch.save({
        'model': model.state_dict(),
    }, final_ckpt)

## Step6 使用模型生成文本

In [10]:
import os, torch
from model import GPTModel, load_tokenizer

def load_model_from_ckpt(ckpt_path, device=None, dtype=None, cfg_override=None, tie_weights=True):
    """Rebuild a ``GPTModel`` from a checkpoint file and put it in eval mode.

    Args:
        ckpt_path: Path to a checkpoint dict containing a ``'model'`` state dict.
        device: Target device; defaults to CUDA when available, else CPU.
        dtype: Optional dtype the model is cast to before moving to ``device``.
        cfg_override: Model config dict. When ``None``, the ``'cfg'`` entry stored
            in the checkpoint is used.
        tie_weights: Share the output projection weight with the token embedding
            before loading the weights.

    Returns:
        ``(model, cfg)``.

    Raises:
        ValueError: if no config is given and the checkpoint does not contain one.
    """
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # Training checkpoints carry their own config; an explicit override wins.
    cfg = cfg_override if cfg_override is not None else ckpt.get('cfg')
    if cfg is None:
        raise ValueError(
            "no model config available: pass cfg_override or use a checkpoint that stores 'cfg'"
        )
    model = GPTModel(cfg)
    if tie_weights:
        try:
            model.output_layer.weight = model.embedding_layer.embedding.weight
        except AttributeError:
            # Best effort: the model does not expose the expected layer attributes.
            pass
    # Non-strict so missing or unexpected keys do not abort loading.
    model.load_state_dict(ckpt['model'], strict=False)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if dtype is not None:
        model.to(dtype=dtype)
    model.to(device)
    model.eval()
    return model, cfg

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0):
    """Mask logits outside the top-k set and/or the top-p nucleus with ``-inf``.

    Args:
        logits: Float tensor whose last dimension indexes the vocabulary.
        top_k: Keep only the ``top_k`` largest logits per row (0 disables).
        top_p: Keep the smallest set of highest-probability tokens whose
            cumulative probability reaches ``top_p`` (1.0 disables). The token
            that crosses the threshold is kept, and at least one token always
            survives.

    Returns:
        A tensor of the same shape with filtered entries set to ``-inf``.
    """
    if top_k and top_k > 0:
        v, _ = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1)
        kth = v[..., -1:]
        logits = logits.masked_fill(logits < kth, -float('inf'))
    if top_p and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the first token that pushes the cumulative mass
        # past top_p stays in the nucleus; without the shift the kept mass can
        # fall below top_p.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float('inf'))
    return logits

def generate(prompt, ckpt_path, tokenizer_dir, max_new_tokens=200, temperature=1.0, top_k=50, top_p=0.95, stop_on_eos=True, device=None, use_half=True, cfg_override=None, tie_weights=True):
    """Sample a continuation of ``prompt`` from a checkpointed model.

    The model and tokenizer are loaded on every call.

    Args:
        prompt: Text to continue.
        ckpt_path: Checkpoint passed to ``load_model_from_ckpt``.
        tokenizer_dir: Directory passed to ``load_tokenizer``.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Softmax temperature (clamped to at least 1e-5).
        top_k, top_p: Filtering parameters passed to ``top_k_top_p_filtering``.
        stop_on_eos: Stop as soon as the tokenizer's EOS id is sampled.
        device: Target device; defaults to CUDA when available, else CPU.
        use_half: Run the model in float16 when ``device`` is exactly ``'cuda'``.
        cfg_override, tie_weights: Forwarded to ``load_model_from_ckpt``.

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.

    Raises:
        ValueError: if the prompt encodes to no tokens and there is no BOS token.
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    tok = load_tokenizer(tokenizer_dir)
    eos_id = tok.eos_token_id
    dtype = torch.float16 if (use_half and device == 'cuda') else None
    model, cfg = load_model_from_ckpt(ckpt_path, device=device, dtype=dtype, cfg_override=cfg_override, tie_weights=tie_weights)
    ids = tok(prompt, add_special_tokens=False, return_attention_mask=False)['input_ids']
    if tok.bos_token_id is not None:
        ids = [tok.bos_token_id] + ids
    if not ids:
        # An empty sequence would reach the model as a zero-length tensor.
        raise ValueError("prompt produced no tokens and the tokenizer has no BOS token")
    input_ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    context_length = cfg['context_length']
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most the last `context_length` tokens.
            x = input_ids[:, -context_length:]
            attn = torch.ones_like(x, dtype=torch.long)
            logits = model(x, attention_mask=attn)
            # Sample in float32: softmax/cumsum over the vocabulary loses
            # precision in float16.
            logits = logits[:, -1, :].float() / max(temperature, 1e-5)
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            nid = int(next_id.item())
            input_ids = torch.cat([input_ids, next_id], dim=1)
            if stop_on_eos and eos_id is not None and nid == eos_id:
                break
    return tok.decode(input_ids[0].tolist(), skip_special_tokens=True)

In [13]:
# Model configuration handed to `generate` as `cfg_override`.
# NOTE(review): `vocab_size` is hard-coded here; presumably it must match the value used when
# the checkpoint was trained — confirm against the training run.
cfg = {
  "vocab_size": 50002,
  "context_length": 1024,
  "emb_dim": 768,
  "dropout": 0.1,
  "n_heads": 12,
  "qkv_bias": False,
  "n_layers": 12,
}
# Sample up to 512 new tokens from the final checkpoint on CUDA in half precision.
text = generate(
  prompt="东南大学",
  ckpt_path=r"G:\Anaconda\Kaggle\gpt2\checkpoints\final_model.pt",
  tokenizer_dir=r"G:\Anaconda\Kaggle\gpt2\tokenizer\bytebpe_zh",
  max_new_tokens=512,
  temperature=0.8,
  top_k=30, top_p=0.9,
  stop_on_eos=True,
  device='cuda',
  use_half=True,
  cfg_override=cfg,
  tie_weights=True
)
print(text)

东南大学大学是江苏省理工学院的招生专业之一，由学校教务处负责管理。目前，东南大学在实行平行志愿录取工作，并组织开展“补录”。
