In [1]:
# 1) Hyperparameters
N_STEPS = 10      # number of discrete masking levels (length of mask_probs)
NUM_EPOCHS = 80
BATCH_SIZE = 16   # per-device training batch size
MAX_LEN = 256     # tokens per training chunk (see group_texts)
PREFIX_LEN = 4    # leading positions the collator never masks
MODEL_DIR = "weights/roberta-diffusion-single-with-prefix"
SAVE_DIR = "weights/roberta"

# Evenly spaced mask probabilities, ordered from fully masked (1.0)
# down to 1/N_STEPS; the collator samples one of these per batch.
mask_probs = [step / N_STEPS for step in range(N_STEPS, 0, -1)]
mask_probs

[1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

In [2]:
# 0) Load the raw Chinese-Poems corpus from the Hugging Face Hub.
from datasets import load_dataset

dataset = load_dataset(path="larryvrh/Chinese-Poems")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Keep only poems whose dynasty is Tang ("唐代") or Song ("宋代").
dataset["train"] = dataset["train"].filter(
    lambda example: example["dynasty"] in ("唐代", "宋代")
)

In [4]:
# Inspect which columns the training split carries before tokenizing.
dataset["train"].column_names

['dynasty', 'author', 'title', 'content']

In [5]:
from transformers import (
    RobertaTokenizerFast,
    RobertaForMaskedLM,
    Trainer,
    TrainingArguments,
)
# 2) Load model and tokenizer.
# `model_max_length` is the supported keyword; the older `max_len` alias is
# deprecated, and passing it here makes a separate re-assignment unnecessary.
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_DIR, model_max_length=MAX_LEN)
model = RobertaForMaskedLM.from_pretrained(MODEL_DIR)

In [6]:
def tokenize_function(examples):
    """Tokenize a batch of poems from the ``content`` column.

    No padding is applied here (chunks are made uniform later by
    ``group_texts``). Each poem is truncated to ``MAX_LEN`` tokens, so any
    text beyond that point is dropped before grouping.
    """
    poems = examples["content"]
    return tokenizer(poems, max_length=MAX_LEN, truncation=True, padding=False)


tokenized = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

In [7]:
def group_texts(examples, max_len=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size chunks.

    Intended for ``Dataset.map(..., batched=True)``. Every column (e.g.
    ``input_ids`` / ``attention_mask``) is flattened across the batch and cut
    into consecutive windows of exactly ``max_len`` tokens.

    Args:
        examples: mapping from column name to a list of per-example token lists.
        max_len: chunk length; ``None`` (default) uses the notebook's ``MAX_LEN``.

    Returns:
        dict with the same keys, each holding a list of ``max_len``-long lists.
        The remainder shorter than ``max_len`` is dropped. If the whole batch
        holds fewer than ``max_len`` tokens, no chunks are emitted. Every
        emitted chunk therefore has the same length and can be stacked
        directly into a tensor by the collator.
    """
    from itertools import chain

    if max_len is None:
        max_len = MAX_LEN
    # chain.from_iterable flattens in linear time. sum(lists, []) re-copies the
    # growing accumulator on every addition and is quadratic in the batch size.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    if not concatenated:
        return {}
    total_length = len(next(iter(concatenated.values())))
    # Round down to a whole number of chunks. This is 0 when total_length < max_len,
    # which avoids emitting a single short (ragged) chunk.
    total_length = (total_length // max_len) * max_len
    return {
        k: [t[i : i + max_len] for i in range(0, total_length, max_len)]
        for k, t in concatenated.items()
    }

In [8]:
# Re-chunk the tokenized poems into fixed-length MAX_LEN windows.
tokenized = tokenized.map(
    group_texts,
    remove_columns=tokenized["train"].column_names,
    batched=True,
)

Map: 100%|██████████| 138771/138771 [01:55<00:00, 1201.67 examples/s]


In [9]:
import torch
def diffusion_collator(features):
    """Collate token chunks into one masked-diffusion training batch.

    A single masking probability ``p`` is drawn from ``mask_probs`` for the
    whole batch (not per example). A position is maskable when it is attended,
    is not a special token, and does not fall within the first ``PREFIX_LEN``
    positions. Each maskable position is then independently replaced by
    ``<mask>`` with probability ``p``.

    Args:
        features: list of dicts with ``input_ids`` and ``attention_mask`` lists.
            After ``group_texts`` these are normally all the same length. If
            lengths differ, shorter rows are right-padded with the tokenizer's
            pad token and attention 0, so padding is never masked or scored.

    Returns:
        dict of ``torch.long`` tensors of shape (B, L), where L is the longest
        sequence in ``features``:
          - input_ids: selected positions replaced by ``tokenizer.mask_token_id``
          - attention_mask: as given (0 on any added padding)
          - labels: original token ids at masked positions, -100 elsewhere

    Raises:
        ValueError: if the batch is ragged and the tokenizer has no pad token.
    """
    lengths = [len(f["input_ids"]) for f in features]
    B, L = len(features), max(lengths)
    if all(n == L for n in lengths):
        batch_input_ids = torch.tensor([f["input_ids"] for f in features], dtype=torch.long)
        batch_attention = torch.tensor([f["attention_mask"] for f in features], dtype=torch.long)
    else:
        # torch.tensor() cannot build a ragged batch, so right-pad to the longest row.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("diffusion_collator: ragged batch but tokenizer has no pad token")
        batch_input_ids = torch.full((B, L), pad_id, dtype=torch.long)
        batch_attention = torch.zeros((B, L), dtype=torch.long)
        for row, (f, n) in enumerate(zip(features, lengths)):
            batch_input_ids[row, :n] = torch.tensor(f["input_ids"], dtype=torch.long)
            batch_attention[row, :n] = torch.tensor(f["attention_mask"], dtype=torch.long)

    # Labels start as the clean ids; unmasked positions are set to -100 below.
    labels = batch_input_ids.clone()

    # One probability for the whole batch. .item() turns the sampled index into a plain int.
    p = float(mask_probs[int(torch.randint(low=0, high=N_STEPS, size=(1,)).item())])

    # Positions that must never be masked: special tokens (CLS/SEP/PAD/...) and the prefix.
    is_special = torch.zeros_like(batch_input_ids, dtype=torch.bool)
    for sid in set(tokenizer.all_special_ids):
        is_special |= batch_input_ids == sid
    is_prefix = (torch.arange(L) < PREFIX_LEN).unsqueeze(0)  # (1, L), broadcasts over B

    mask_candidates = (batch_attention == 1) & ~is_special & ~is_prefix  # (B, L)
    rand = torch.rand_like(batch_input_ids, dtype=torch.float)  # uniform [0, 1)
    to_mask = (rand < p) & mask_candidates

    batch_input_ids[to_mask] = tokenizer.mask_token_id
    # Loss is computed only on masked positions.
    labels[~to_mask] = -100

    return {
        "input_ids": batch_input_ids,
        "attention_mask": batch_attention,
        "labels": labels,
    }

In [13]:
# Training arguments: checkpoint every epoch into SAVE_DIR, log every 200 steps.
training_args = TrainingArguments(
    output_dir=SAVE_DIR,
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    logging_steps=200,
    save_strategy="epoch",
    # save_total_limit=12,  # optional cap on how many checkpoints are retained
)

In [None]:
# Wire the model, arguments, grouped training split and masking collator together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=diffusion_collator,
    train_dataset=tokenized["train"],
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
