# BERT + Discrete (Masked) Diffusion
A minimal training pipeline (LLaDA / SMDM‑style) implemented with **PyTorch Lightning**.

*Generated automatically on 2025-08-07 11:26:09.*

## 1  Environment & installs
Run the following cell **once** (e.g. on Colab) to install required libraries.

In [1]:
# !pip install -U torch pytorch-lightning transformers datasets accelerate sentencepiece

In [2]:
# 顶格放这段，后面再 import transformers
import os
import warnings
import logging
import sys
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"   # HF 4.42+ 支持
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message="The current process just got forked")  # 可选
warnings.filterwarnings("ignore", message="A parameter name that contains")
logging.getLogger("transformers").setLevel(logging.ERROR)

home_path = os.getcwd().split("DialFill-DM")[0]
target_path = home_path + "DialFill-DM"
sys.path.append(target_path)

## 2  Imports / basic config

In [3]:
import math, random, torch, torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional
import pytorch_lightning as pl
from datasets import load_dataset
from transformers import AutoTokenizer, BertForMaskedLM

pl.seed_everything(42)

@dataclass
class TrainConfig:
    model_name: str = "bert-base-uncased"
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    text_column: str = "text"
    max_length: int = 128
    batch_size: int = 16
    num_workers: int = 2
    lr: float = 3e-5
    weight_decay: float = 0.01
    max_steps: int = 3000
    warmup_steps: int = 100
    val_check_interval: int = 500
    log_every_n_steps: int = 20
    T: int = 8
    mask_token_mode: str = "mask"
    use_time_embed: bool = False
    random_replace_prob: float = 0.0
    sampling_topk: int = 50
    sampling_temperature: float = 1.0
    gradient_clip_val: float = 1.0
    devices = 'cuda:1'
    precision: str = "bf16-mixed" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "16-mixed"

cfg = TrainConfig()

Seed set to 42


## 3  Lightning DataModule

In [4]:
from dataset_process.base_conv import BaseConvDataset, Collator
from transformers import AutoTokenizer

class ConvDataModule(pl.LightningDataModule):
    """
    DataModule that wraps your custom conversational dataset defined in base_conv.py.
    Adjust `dataset` to point at the JSONL (or other) file
    containing your dialogues + knowledge as expected by `BaseConvDataset`.
    """
    def __init__(self, cfg: TrainConfig, dataset: str):
        super().__init__()
        self.cfg = cfg
        self.dataset = dataset
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.pad_token_id = self.tokenizer.pad_token_id

    def setup(self, stage=None):
        self.train_ds = BaseConvDataset(
            data=self.dataset,
            tokenizer=self.tokenizer,
            max_history=3,
            max_seq_length=self.cfg.max_length,
            is_generation=False,
            include_triples=True,
        )
        self.val_ds = self.train_ds  # ← replace if you have a separate validation file

    def collate(self, batch):
        collator = Collator(
            pad=self.pad_token_id,
            padding_side='right',
        )
        return collator(batch)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_ds,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
            collate_fn=self.collate,
            pin_memory=True,
        )

    def val_dataloader(self):
            return torch.utils.data.DataLoader(
                self.val_ds,
                batch_size=self.cfg.batch_size,
                shuffle=False,
                num_workers=self.cfg.num_workers,
                collate_fn=self.collate,
                pin_memory=True,
            )


## 4  Discrete mask scheduler

In [5]:
class DiscreteMaskScheduler(torch.nn.Module):
    def __init__(self, tokenizer, T: int, max_length: int,
                 random_replace_prob: float = 0.0, schedule: str = "cosine"):
        super().__init__()
        self.T = T
        self.mask_id = tokenizer.mask_token_id
        self.pad_id  = tokenizer.pad_token_id
        self.vocab_size = tokenizer.vocab_size
        self.random_replace_prob = random_replace_prob

        ts = torch.arange(1, T + 1, dtype=torch.float)
        if schedule == "linear":
            m = ts / T
        else:
            m = torch.sin((ts / T) * math.pi / 2.0)

        # **把 m_table 注册成 buffer，随模型一起搬到对应 device**
        self.register_buffer("m_table", torch.clamp(m, 1e-4, 0.9999))

    @torch.no_grad()
    def q_sample(self, x0_ids: torch.LongTensor, t: torch.LongTensor):
        B, L = x0_ids.shape
        # m_t = self.m_table[t-1].view(B, 1).to(x0_ids.device)
        m_t = self.m_table.to(x0_ids.device)[t.cpu() - 1].view(B, 1)
        is_pad = x0_ids.eq(self.pad_id)
        mask_draw = torch.rand(B, L, device=x0_ids.device)
        to_mask = (mask_draw < m_t) & (~is_pad)
        x_t = x0_ids.clone()
        x_t[to_mask] = self.mask_id

        if self.random_replace_prob > 0:
            rnd_draw = torch.rand(B, L, device=x0_ids.device)
            do_replace = (rnd_draw < self.random_replace_prob) & (~to_mask) & (~is_pad)
            rand_ids = torch.randint(0, self.vocab_size, (B, L), device=x0_ids.device)
            rand_ids = torch.where(rand_ids.eq(self.pad_id)|rand_ids.eq(self.mask_id),
                                   (rand_ids+1) % self.vocab_size, rand_ids)
            x_t[do_replace] = rand_ids[do_replace]
        return x_t, to_mask

## 5  LightningModule (BERT denoiser)

In [6]:
class DiffusionBert(pl.LightningModule):
    def __init__(self, cfg: TrainConfig, tokenizer):
        super().__init__()
        self.save_hyperparameters()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.bert = BertForMaskedLM.from_pretrained(cfg.model_name)
        self.bert.resize_token_embeddings(len(tokenizer))
        self.scheduler = DiscreteMaskScheduler(tokenizer, cfg.T, cfg.max_length,
                                               cfg.random_replace_prob)
        if cfg.use_time_embed:
            self.time_embed = torch.nn.Embedding(cfg.T+1, self.bert.config.hidden_size)
        else:
            self.time_embed = None

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.weight']
        pg = [
            {'params': [p for n,p in self.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.cfg.weight_decay},
            {'params': [p for n,p in self.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        opt = torch.optim.AdamW(pg, lr=self.cfg.lr)
        sched = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=self.cfg.warmup_steps)
        return {'optimizer': opt, 'lr_scheduler': {'scheduler': sched, 'interval':'step'}}

    def _add_time(self, emb, t):
        if self.time_embed is None:
            return emb
        return emb + self.time_embed(t.clamp(0,self.cfg.T)).unsqueeze(1)

    def training_step(self, batch, _):
        x0, attn = batch['input_ids'], batch['attention_mask']
        t = torch.randint(1, self.cfg.T+1, (x0.size(0),), device=self.device)
        xt, to_mask = self.scheduler.q_sample(x0, t)
        if self.time_embed is not None:
            emb = self.bert.get_input_embeddings()(xt)
            logits = self.bert(inputs_embeds=self._add_time(emb,t),
                               attention_mask=attn,return_dict=True).logits
        else:
            logits = self.bert(input_ids=xt, attention_mask=attn, return_dict=True).logits
        loss = F.cross_entropy(logits[to_mask], x0[to_mask]) if to_mask.any() else logits.new_zeros(())
        self.log('train/loss', loss, prog_bar=True, on_step=True)
        return loss

    def validation_step(self, batch, _):
        x0, attn = batch['input_ids'], batch['attention_mask']
        t = torch.full((x0.size(0),), math.ceil(self.cfg.T/2), device=self.device)
        xt, to_mask = self.scheduler.q_sample(x0, t)
        if self.time_embed is not None:
            emb = self.bert.get_input_embeddings()(xt)
            logits = self.bert(inputs_embeds=self._add_time(emb,t),
                               attention_mask=attn,return_dict=True).logits
        else:
            logits = self.bert(input_ids=xt, attention_mask=attn, return_dict=True).logits
        loss = F.cross_entropy(logits[to_mask], x0[to_mask]) if to_mask.any() else logits.new_zeros(())
        self.log('val/loss', loss, prog_bar=True, on_epoch=True)

    @torch.no_grad()
    def sample(self, prompts, num_steps=None, start_mask_ratio=1.0):
        self.eval()
        num_steps = num_steps or self.cfg.T
        enc = self.tokenizer(prompts, return_tensors='pt', padding='max_length',
                             truncation=True, max_length=self.cfg.max_length).to(self.device)
        x = enc.input_ids.clone()
        attn = enc.attention_mask
        mask_id = self.tokenizer.mask_token_id
        # initial masking
        rnd = torch.rand_like(x, dtype=torch.float)
        mask_init = (rnd < start_mask_ratio) & attn.bool()
        x[mask_init] = mask_id
        B = x.size(0)
        for step in range(num_steps,0,-1):
            t = torch.full((B,), step, device=self.device, dtype=torch.long)
            if self.time_embed is not None:
                emb = self.bert.get_input_embeddings()(x)
                logits = self.bert(inputs_embeds=self._add_time(emb,t),
                                   attention_mask=attn,return_dict=True).logits
            else:
                logits = self.bert(input_ids=x, attention_mask=attn, return_dict=True).logits
            probs = torch.softmax(logits/self.cfg.sampling_temperature, dim=-1)
            topk = min(self.cfg.sampling_topk, probs.size(-1))
            if topk < probs.size(-1):
                topk_probs, topk_ids = torch.topk(probs, k=topk, dim=-1)
                topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
                idx = torch.distributions.Categorical(topk_probs).sample()
                sampled = topk_ids.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
            else:
                sampled = torch.distributions.Categorical(probs).sample()
            fill = x.eq(mask_id) & attn.bool()
            x[fill] = sampled[fill]
        return self.tokenizer.batch_decode(x, skip_special_tokens=True)

## 6  Train

In [7]:
from huggingface_hub import login
from datasets import load_dataset
login(token="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")
# --- Use your dataset ------------------------------------------------------------
dataset = load_dataset("asnower/opendialkg")
dm = ConvDataModule(cfg, dataset['train'])

Using the latest cached version of the dataset since asnower/opendialkg couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/xueqiang/.cache/huggingface/datasets/asnower___opendialkg/default/0.0.0/309a8fc4ba0488c87e42e5a5ef2648b6c3144f0f (last modified on Thu Aug  7 21:22:25 2025).


In [8]:
# --- model & trainer ----------------------------------------------------------------
model = DiffusionBert(cfg, dm.tokenizer)  # Assuming DiffusionBert class defined earlier
trainer = pl.Trainer(
    max_steps=cfg.max_steps,
    gradient_clip_val=cfg.gradient_clip_val,
    devices='auto',
    accelerator='auto',
    precision=cfg.precision,
    log_every_n_steps=cfg.log_every_n_steps,
    val_check_interval=cfg.val_check_interval,
)

trainer.fit(model, dm)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Preci

Sanity Checking: |                                                                                  | 0/? [00:…

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/xueqiang/anaconda3/envs/DialFill_DM/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/xueqiang/anaconda3/envs/DialFill_DM/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/xueqiang/anaconda3/envs/DialFill_DM/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/xueqiang/DialFill-DM/dataset_process/base_conv.py", line 671, in __getitem__
    item_dict = process_method(
  File "/home/xueqiang/DialFill-DM/dataset_process/base_conv.py", line 711, in _build_from_segments
    text_input = self.tokenizer.apply_chat_template(
  File "/home/xueqiang/anaconda3/envs/DialFill_DM/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 1786, in apply_chat_template
    chat_template = self.get_chat_template(chat_template, tools)
  File "/home/xueqiang/anaconda3/envs/DialFill_DM/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 2025, in get_chat_template
    raise ValueError(
ValueError: Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating


## 7  Sampling demo

In [None]:
prompts = [
    "In recent years, diffusion models for language have",
    "The quick brown fox"
]
samples = model.sample(
    prompts,
    num_steps=cfg.T,
    start_mask_ratio=0.5,   # 更少初始掩码
    )                       # 在模块里把 topk=10, temperature=0.7

for p, s in zip(prompts, samples):
    print('\nPrompt:', p)
    print('Output:', s)