In [6]:
# 1) Hyperparameters
N_STEPS = 10  # number of entries in the mask-probability schedule below
NUM_EPOCHS = 80
BATCH_SIZE = 16
MAX_LEN = 256  # tokenizer truncation length and chunk size used by later cells
PREFIX_LEN = 4
MODEL_DIR = "weights/roberta-diffusion-single-with-prefix"
SAVE_DIR = "weights/roberta"

# Mask probabilities in DESCENDING order: 1.0 first, then down to 1/N_STEPS
# in equal decrements of 1/N_STEPS.
mask_probs = [step / N_STEPS for step in range(N_STEPS, 0, -1)]
mask_probs

[1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

In [None]:
# 0) Load dataset
from datasets import load_dataset
# Load the "larryvrh/Chinese-Poems" dataset (presumably fetched from the
# Hugging Face Hub on first use — confirm network/cache requirements).
# Later cells read its 'train' split.
dataset = load_dataset("larryvrh/Chinese-Poems")


In [None]:
# Keep only the train-split poems whose "dynasty" field is 唐代 (Tang) or
# 宋代 (Song); the filtered split replaces dataset['train'].
dataset["train"] = dataset["train"].filter(
    lambda poem: poem["dynasty"] == "唐代" or poem["dynasty"] == "宋代"
)

In [24]:
# Display the column names of the train split (the tokenization step later
# reads "content" and removes all of these original columns).
dataset['train'].column_names

['dynasty', 'author', 'title', 'content']

In [None]:
from transformers import (
    RobertaTokenizerFast,
    RobertaForMaskedLM,
    Trainer,
    TrainingArguments,
)
# 2) Load model and tokenizer
# Both are restored from the checkpoint directory MODEL_DIR.
# `model_max_length` is the documented keyword for capping the tokenizer's
# input length (the older `max_len` spelling is deprecated), so the limit is
# set directly at construction time instead of being patched on afterwards.
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_DIR, model_max_length=MAX_LEN)
model = RobertaForMaskedLM.from_pretrained(MODEL_DIR)

In [25]:
def tokenize_function(examples):
    """Tokenize the "content" field of a batch of examples.

    Args:
        examples: Mapping that provides the texts under the "content" key.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts when
        called without padding and with truncation to ``MAX_LEN`` tokens.
    """
    poems = examples["content"]
    # No padding here: sequences keep their own lengths at this stage.
    encode_options = {
        "padding": False,
        "truncation": True,
        "max_length": MAX_LEN,
    }
    return tokenizer(poems, **encode_options)
# Tokenize the dataset in batches and drop the original columns (taken from
# the train split) so only the tokenizer's outputs remain.
tokenized = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

Map: 100%|██████████| 138771/138771 [00:16<00:00, 8597.46 examples/s]


In [None]:
def group_texts(examples):
    """Concatenate a batch of sequences and re-split it into MAX_LEN-sized chunks.

    Args:
        examples: Mapping from a feature name to a list of per-example lists
            (e.g. token ids). The length of the first feature's concatenation
            is used for every feature, so all features are assumed to have
            the same total length — TODO confirm against the caller.

    Returns:
        A dict with the same keys, each holding a list of chunks. When the
        batch contains at least MAX_LEN items, the trailing remainder is
        dropped so every chunk has exactly MAX_LEN items. A batch with fewer
        than MAX_LEN items is returned as one single (short) chunk. An empty
        mapping yields an empty dict.
    """
    from itertools import chain

    # Flatten each feature in linear time; sum(lists, []) would re-copy the
    # growing list once per example.
    concatenated = {
        key: list(chain.from_iterable(examples[key])) for key in examples.keys()
    }
    # Length of the first feature (0 when there are no features at all).
    total_length = len(next(iter(concatenated.values()), []))
    # Drop the small remainder so that full batches only produce full chunks.
    if total_length >= MAX_LEN:
        total_length = (total_length // MAX_LEN) * MAX_LEN
    # Split by chunks of MAX_LEN.
    return {
        key: [tokens[start : start + MAX_LEN] for start in range(0, total_length, MAX_LEN)]
        for key, tokens in concatenated.items()
    }

Sample 0: [0, 41907, 46, 10659, 41907, 17772, 13859, 36484, 18537, 9357, 47842, 10809, 49075, 3602, 41907, 13859, 15375, 42393, 15389, 15389, 43251, 4394, 14285, 42393, 7471, 4333, 37127, 17772, 15722, 42393, 862, 6248, 47876, 9470, 36484, 16948, 10809, 36714, 5782, 6248, 49117, 20024, 45682, 50118, 41907, 48894, 42393, 15389, 13859, 48635, 12410, 36484, 27, 15113, 47504, 11582, 48820, 23171, 48128, 12410, 43251, 4394, 14285, 42393, 14292, 12410, 46499, 3602, 48186, 10674, 48341, 4394, 47876, 3602, 36484, 3070, 4958, 49117, 14292, 45682, 2]
Length: 81
---
Sample 1: [0, 49874, 4333, 36484, 19002, 3070, 47856, 27819, 36484, 6248, 10172, 36714, 7258, 3602, 36484, 2840, 9253, 36714, 4394, 3070, 43251, 4394, 14285, 41907, 15375, 11582, 41907, 11582, 2469, 47089, 23171, 36714, 4333, 8210, 36484, 3602, 11582, 47954, 6248, 47240, 7487, 45682, 50118, 48617, 15264, 36484, 6248, 10172, 36484, 13859, 8210, 48607, 25448, 37127, 15389, 9264, 49212, 4333, 41907, 49794, 43251, 4394, 14285, 46499, 1816