# BERT + Discrete (Masked) Diffusion
A minimal training pipeline (LLaDA / SMDM‑style) implemented with **PyTorch Lightning**.

*Generated automatically on 2025-08-07 11:26:09.*

## 1  Environment & installs
Uncomment and run the following cell **once** (e.g. on Colab) to install the required libraries.

In [1]:
# !pip install -U torch pytorch-lightning transformers datasets accelerate sentencepiece

In [None]:
# Keep this cell at the very top: it must run before `transformers` is imported.
import os, warnings, logging
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"   # supported by HF 4.42+ (per the original note)
# Turn off tokenizers' internal parallelism (the fork warning text asks for this variable to be set explicitly).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message="The current process just got forked")  # optional
warnings.filterwarnings("ignore", message="A parameter name that contains")
logging.getLogger("transformers").setLevel(logging.ERROR)               # optional


## 2  Imports / basic config

In [2]:
import math, random, torch, torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional
import pytorch_lightning as pl
from datasets import load_dataset
from transformers import AutoTokenizer, BertForMaskedLM

# Fix the global random seed through Lightning's helper so runs are repeatable.
pl.seed_everything(42)

@dataclass
class TrainConfig:
    """Hyperparameters shared by the data module, the diffusion model and the Trainer."""

    # Checkpoint name handed to AutoTokenizer / BertForMaskedLM.from_pretrained.
    model_name: str = "bert-base-uncased"
    # Arguments for datasets.load_dataset(name, config).
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    # Dataset column that holds the raw text.
    text_column: str = "text"
    # Every example is truncated / padded to exactly this many tokens.
    max_length: int = 128
    # DataLoader settings.
    batch_size: int = 16
    num_workers: int = 2
    # AdamW settings (weight decay is not applied to bias / LayerNorm weights).
    lr: float = 3e-5
    weight_decay: float = 0.01
    # Trainer stops after this many optimizer steps.
    max_steps: int = 3000
    # Number of steps over which LinearLR ramps the learning rate up from 0.1 * lr.
    warmup_steps: int = 100
    # Trainer validation / logging cadence.
    val_check_interval: int = 500
    log_every_n_steps: int = 20
    # Number of discrete diffusion timesteps; also the default number of sampling steps.
    T: int = 8
    # NOTE(review): not read by any code visible in this notebook — confirm whether it is still needed.
    mask_token_mode: str = "mask"
    # When True, a learned per-timestep embedding is added to the input embeddings.
    use_time_embed: bool = False
    # Probability of swapping a non-masked, non-pad token for a random one during noising.
    random_replace_prob: float = 0.0
    # Top-k / temperature used by BertDiscreteDiffusion.sample().
    sampling_topk: int = 50
    sampling_temperature: float = 1.0
    # Passed to Trainer(gradient_clip_val=...).
    gradient_clip_val: float = 1.0
    # Evaluated once, when the class body runs: bf16 mixed precision if CUDA supports it, else fp16 mixed.
    precision: str = "bf16-mixed" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "16-mixed"

# Single module-level config instance used by every later cell.
cfg = TrainConfig()

Seed set to 42


## 3  Lightning DataModule

In [3]:
class TextDataModule(pl.LightningDataModule):
    """Tokenised text dataset wrapped as a Lightning DataModule.

    Owns the tokenizer (guaranteeing it has a mask token), tokenises the
    ``train`` and ``validation`` splits to fixed-length sequences, and serves
    them as batches of ``input_ids`` / ``attention_mask`` long tensors.
    """

    def __init__(self, cfg: TrainConfig):
        super().__init__()
        self.cfg = cfg
        tok = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
        # The diffusion process needs a mask token; add one if the vocab lacks it.
        if tok.mask_token is None:
            tok.add_special_tokens({"mask_token": "[MASK]"})
        self.tokenizer = tok
        self.pad_token_id = tok.pad_token_id

    def prepare_data(self):
        """Trigger the dataset download / cache build; nothing is kept."""
        load_dataset(self.cfg.dataset_name, self.cfg.dataset_config)

    def setup(self, stage=None):
        """Load the dataset and tokenise the train and validation splits."""
        raw = load_dataset(self.cfg.dataset_name, self.cfg.dataset_config)

        def _tokenize(examples):
            # Drop empty and whitespace-only rows before tokenising.
            kept = [line for line in examples[self.cfg.text_column]
                    if line and not line.isspace()]
            return self.tokenizer(
                kept,
                truncation=True,
                max_length=self.cfg.max_length,
                padding='max_length',
                return_attention_mask=True,
            )

        def _encode_split(split_name):
            split = raw[split_name]
            # Removing the original columns lets the mapped batch change length.
            return split.map(_tokenize, batched=True,
                             remove_columns=split.column_names)

        self.train_ds = _encode_split('train')
        self.val_ds = _encode_split('validation')

    def collate(self, batch):
        """Stack a list of tokenised rows into a dict of long tensors."""
        def _stack(key):
            return torch.tensor([row[key] for row in batch], dtype=torch.long)

        return {'input_ids': _stack('input_ids'),
                'attention_mask': _stack('attention_mask')}

    def _make_loader(self, dataset, shuffle):
        # Shared DataLoader construction; only the dataset and shuffling differ.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            collate_fn=self.collate,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_ds, shuffle=False)

## 4  Discrete mask scheduler

In [4]:
class DiscreteMaskScheduler(torch.nn.Module):
    """Forward (noising) process of a masked discrete diffusion.

    For every timestep ``t`` in ``1..T`` a masking probability ``m_t`` is
    precomputed; ``q_sample`` replaces each non-pad token by the mask token
    independently with probability ``m_t``.

    Args:
        tokenizer: object exposing ``mask_token_id``, ``pad_token_id`` and
            ``vocab_size``.
        T: number of diffusion timesteps (must be >= 1).
        max_length: accepted for interface compatibility; not used here.
        random_replace_prob: probability of swapping a token that was *not*
            masked (and is not padding) for a random vocabulary id.
        schedule: ``"linear"`` gives ``m_t = t / T``; any other value
            (including the default ``"cosine"``) gives
            ``m_t = sin(pi/2 * t / T)``.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """

    def __init__(self, tokenizer, T: int, max_length: int,
                 random_replace_prob: float = 0.0, schedule: str = "cosine"):
        super().__init__()
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        self.T = T
        self.mask_id = tokenizer.mask_token_id
        self.pad_id  = tokenizer.pad_token_id
        self.vocab_size = tokenizer.vocab_size
        self.random_replace_prob = random_replace_prob

        ts = torch.arange(1, T + 1, dtype=torch.float)
        if schedule == "linear":
            m = ts / T
        else:
            m = torch.sin((ts / T) * math.pi / 2.0)

        # Registered as a buffer so the table follows the module across devices.
        # Clamped so that no timestep masks nothing or (quite) everything.
        self.register_buffer("m_table", torch.clamp(m, 1e-4, 0.9999))

    @torch.no_grad()
    def q_sample(self, x0_ids: torch.LongTensor, t: torch.LongTensor):
        """Corrupt clean token ids ``x0_ids`` at per-row timesteps ``t``.

        Args:
            x0_ids: ``(B, L)`` clean token ids.
            t: ``(B,)`` timesteps, each in ``1..T``.

        Returns:
            ``(x_t, to_mask)``: the corrupted ids and a ``(B, L)`` bool tensor
            marking positions that were replaced by the mask token. Padding
            positions are never masked or replaced.
        """
        B, L = x0_ids.shape
        device = x0_ids.device
        # Index on the data's device; moving ``t`` to the CPU would force a
        # host synchronisation on every call.
        m_t = self.m_table.to(device)[t.to(device) - 1].view(B, 1)
        is_pad = x0_ids.eq(self.pad_id)
        mask_draw = torch.rand(B, L, device=device)
        to_mask = (mask_draw < m_t) & (~is_pad)
        x_t = x0_ids.clone()
        x_t[to_mask] = self.mask_id

        if self.random_replace_prob > 0:
            rnd_draw = torch.rand(B, L, device=device)
            do_replace = (rnd_draw < self.random_replace_prob) & (~to_mask) & (~is_pad)
            rand_ids = torch.randint(0, self.vocab_size, (B, L), device=device)
            # Shift forbidden ids (pad / mask) to the next id. Two passes are
            # needed because the two forbidden ids may be adjacent, in which
            # case one shift would move pad onto mask (or vice versa).
            for _ in range(2):
                forbidden = rand_ids.eq(self.pad_id) | rand_ids.eq(self.mask_id)
                rand_ids = torch.where(forbidden,
                                       (rand_ids + 1) % self.vocab_size, rand_ids)
            x_t[do_replace] = rand_ids[do_replace]
        return x_t, to_mask

## 5  LightningModule (BERT denoiser)

In [5]:
class BertDiscreteDiffusion(pl.LightningModule):
    """BERT masked-LM used as the denoiser of a masked discrete diffusion.

    Training corrupts each batch with ``DiscreteMaskScheduler.q_sample`` at a
    random timestep and minimises cross-entropy on the masked positions only.
    ``sample`` runs the reverse process: starting from a (partially) masked
    prompt it reveals the masked tokens progressively over several steps.

    Args:
        cfg: training / sampling hyperparameters.
        tokenizer: tokenizer matching ``cfg.model_name``; must define a mask
            token and a pad token.
    """

    def __init__(self, cfg: TrainConfig, tokenizer):
        super().__init__()
        self.save_hyperparameters()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.bert = BertForMaskedLM.from_pretrained(cfg.model_name)
        # Keep the embedding matrix in sync with the tokenizer in case special
        # tokens were added to it.
        self.bert.resize_token_embeddings(len(tokenizer))
        self.scheduler = DiscreteMaskScheduler(tokenizer, cfg.T, cfg.max_length,
                                               cfg.random_replace_prob)
        if cfg.use_time_embed:
            self.time_embed = torch.nn.Embedding(cfg.T+1, self.bert.config.hidden_size)
        else:
            self.time_embed = None

    def configure_optimizers(self):
        """AdamW (no weight decay on bias / LayerNorm weights) with linear LR warm-up."""
        no_decay_tags = ['bias', 'LayerNorm.weight']
        decay_params, plain_params = [], []
        for name, param in self.named_parameters():
            if any(tag in name for tag in no_decay_tags):
                plain_params.append(param)
            else:
                decay_params.append(param)
        pg = [
            {'params': decay_params, 'weight_decay': self.cfg.weight_decay},
            {'params': plain_params, 'weight_decay': 0.0},
        ]
        opt = torch.optim.AdamW(pg, lr=self.cfg.lr)
        # Ramps from 0.1 * lr to lr over warmup_steps, then stays constant.
        sched = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=self.cfg.warmup_steps)
        return {'optimizer': opt, 'lr_scheduler': {'scheduler': sched, 'interval':'step'}}

    def _add_time(self, emb, t):
        """Add the timestep embedding (broadcast over the sequence) if enabled."""
        if self.time_embed is None:
            return emb
        return emb + self.time_embed(t.clamp(0,self.cfg.T)).unsqueeze(1)

    def _denoise_logits(self, ids, attn, t):
        """Run BERT on (possibly masked) ``ids`` and return vocabulary logits."""
        if self.time_embed is None:
            return self.bert(input_ids=ids, attention_mask=attn, return_dict=True).logits
        emb = self.bert.get_input_embeddings()(ids)
        return self.bert(inputs_embeds=self._add_time(emb, t),
                         attention_mask=attn, return_dict=True).logits

    @staticmethod
    def _masked_loss(logits, targets, to_mask):
        """Cross-entropy over masked positions; a graph-connected zero if none."""
        if to_mask.any():
            return F.cross_entropy(logits[to_mask], targets[to_mask])
        # A fresh zero tensor has no grad_fn and would make backward() fail;
        # this keeps the loss attached to the graph with zero gradient.
        return logits.sum() * 0.0

    def _draw_tokens(self, logits):
        """Sample one token id per position using temperature and top-k."""
        probs = torch.softmax(logits/self.cfg.sampling_temperature, dim=-1)
        topk = min(self.cfg.sampling_topk, probs.size(-1))
        if topk < probs.size(-1):
            topk_probs, topk_ids = torch.topk(probs, k=topk, dim=-1)
            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
            idx = torch.distributions.Categorical(topk_probs).sample()
            return topk_ids.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
        return torch.distributions.Categorical(probs).sample()

    def training_step(self, batch, _):
        """Noise the batch at random timesteps and return the masked-token loss."""
        x0, attn = batch['input_ids'], batch['attention_mask']
        t = torch.randint(1, self.cfg.T+1, (x0.size(0),), device=self.device)
        xt, to_mask = self.scheduler.q_sample(x0, t)
        logits = self._denoise_logits(xt, attn, t)
        loss = self._masked_loss(logits, x0, to_mask)
        self.log('train/loss', loss, prog_bar=True, on_step=True)
        return loss

    def validation_step(self, batch, _):
        """Log the masked-token loss at the fixed mid-range timestep ceil(T/2)."""
        x0, attn = batch['input_ids'], batch['attention_mask']
        t = torch.full((x0.size(0),), math.ceil(self.cfg.T/2), device=self.device)
        xt, to_mask = self.scheduler.q_sample(x0, t)
        logits = self._denoise_logits(xt, attn, t)
        loss = self._masked_loss(logits, x0, to_mask)
        self.log('val/loss', loss, prog_bar=True, on_epoch=True)

    @torch.no_grad()
    def sample(self, prompts, num_steps=None, start_mask_ratio=1.0):
        """Generate text by iteratively unmasking a (partially) masked prompt.

        Args:
            prompts: list of strings; each is tokenised and padded to
                ``cfg.max_length``. Only non-padding positions are ever masked
                or filled, so the output has at most the prompt's token length.
            num_steps: number of reverse steps (defaults to ``cfg.T``).
            start_mask_ratio: probability with which each prompt token is
                initially replaced by the mask token.

        Returns:
            List of decoded strings, special tokens removed.

        At step ``s`` (counting down from ``num_steps``) each still-masked
        position is revealed with probability ``1/s``, so tokens are committed
        gradually and later steps can condition on earlier ones; the final
        step reveals everything that is left. Filling every mask on the first
        step would turn all remaining steps into no-ops.
        """
        was_training = self.training
        self.eval()
        try:
            num_steps = num_steps or self.cfg.T
            enc = self.tokenizer(prompts, return_tensors='pt', padding='max_length',
                                 truncation=True, max_length=self.cfg.max_length).to(self.device)
            x = enc.input_ids.clone()
            attn = enc.attention_mask
            valid = attn.bool()
            mask_id = self.tokenizer.mask_token_id
            # initial masking
            rnd = torch.rand_like(x, dtype=torch.float)
            x[(rnd < start_mask_ratio) & valid] = mask_id
            B = x.size(0)
            for step in range(num_steps, 0, -1):
                masked = x.eq(mask_id) & valid
                if not masked.any():
                    break  # nothing left to denoise
                t = torch.full((B,), step, device=self.device, dtype=torch.long)
                sampled = self._draw_tokens(self._denoise_logits(x, attn, t))
                if step > 1:
                    # Keep each masked position masked with prob (step-1)/step.
                    stay_masked = torch.rand(x.shape, device=x.device) < (step - 1) / step
                    reveal = masked & ~stay_masked
                else:
                    reveal = masked
                x[reveal] = sampled[reveal]
            return self.tokenizer.batch_decode(x, skip_special_tokens=True)
        finally:
            # Do not leave the module in eval mode if it was training before.
            self.train(was_training)

## 6  Train

In [6]:
# Build the data pipeline first: the model is constructed with the tokenizer
# that the DataModule owns.
dm = TextDataModule(cfg)
dm.prepare_data()
dm.setup()
model = BertDiscreteDiffusion(cfg, dm.tokenizer)

# Keep only the checkpoint with the lowest 'val/loss' and log the learning
# rate at every step.
callbacks = [
    pl.callbacks.ModelCheckpoint(monitor='val/loss', mode='min', save_top_k=1),
    pl.callbacks.LearningRateMonitor(logging_interval='step'),
]
trainer = pl.Trainer(
    max_steps=cfg.max_steps,
    val_check_interval=cfg.val_check_interval,
    gradient_clip_val=cfg.gradient_clip_val,
    precision=cfg.precision,
    accelerator='auto',
    devices='auto',
    log_every_n_steps=cfg.log_every_n_steps,
    callbacks=callbacks,
)
trainer.fit(model, dm)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Preci

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type                  | Params
----------------------------------------------------
0 | bert      | BertForMaskedLM       | 109 M 
1 | scheduler | DiscreteMaskScheduler | 0     
----------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
438.057   Total estimated model params size (MB)


Sanity Checking: |                                                                                  | 0/? [00:…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Training: |                                                                                         | 0/? [00:…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Validation: |                                                                                       | 0/? [00:…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Validation: |                                                                                       | 0/? [00:…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Validation: |                                                                                       | 0/? [00:…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Validation: |                                                                                       | 0/? [00:…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
`Trainer.fit` stopped: `max_steps=3000` reached.


## 7  Sampling demo

In [8]:
prompts = [
    "In recent years, diffusion models for language have",
    "The quick brown fox",
]

# start_mask_ratio=0.5 masks fewer prompt tokens at the start than the default 1.0.
# Top-k and temperature are read from cfg inside the module
# (original note: set topk=10, temperature=0.7 there).
samples = model.sample(prompts, num_steps=cfg.T, start_mask_ratio=0.5)

for prompt, output in zip(prompts, samples):
    print('\nPrompt:', prompt)
    print('Output:', output)


Prompt: In recent years, diffusion models for language have
Output: in recent years, diffusion models in and have

Prompt: The quick brown fox
Output: 3 of & fox
