# Song Ci Continuation with a CVAE

Dataset: [chinese-poetry/chinese-poetry: The most comprehensive database of Chinese poetry](https://github.com/chinese-poetry/chinese-poetry)


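> Why this project exists:  
> After a long time I finally understood the model part of ACT thoroughly, and around then I came across Su Jianlin's (苏剑林) articles on VAEs, which explain things so clearly that my confidence shot up and I wanted to try building something. Things had also just quieted down, so I took the chance to slack off. The project took a little over three hours to write; I changed the model structure a few times but did almost no hyperparameter tuning. Afterwards I had an AI help tidy up the code. This is purely for fun, so don't read too much into the results.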

2025.7.30: Changed the sampling method, so the outputs no longer all start with “一”.

In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

## 1. Data Loading and Preprocessing

In [2]:

# Assumes the data looks like [{"rhythmic": "水调歌头", "paragraphs": ["明月几时有", "把酒问青天", ...]}, ...]
# Read every JSON file whose name starts with ci.song
data_dir = './datas/宋词'  # change this to your data folder
all_data = []
for fname in os.listdir(data_dir):
    if fname.startswith('ci.song') and fname.endswith('.json'):
        with open(os.path.join(data_dir, fname), encoding='utf-8') as f:
            all_data.extend(json.load(f))
            # len(all_data) is the running total across files, not the per-file count
            print(f"Loaded {len(all_data)} items from {fname}")

# Build (rhythmic, previous line, next line) training pairs from adjacent paragraphs
lines = []
rhythmic = set()
for item in all_data:
    rhythmic.add(item['rhythmic'])
    for i in range(len(item['paragraphs']) - 1):
        lines.append((item['rhythmic'], item['paragraphs'][i], item['paragraphs'][i+1]))

# Character-level tokenizer
from collections import Counter
all_text = ''.join([l[1]+l[2] for l in lines])
char_count = Counter(all_text)
chars = ['<PAD>', '<BOS>', '<EOS>', '<UNK>'] + [c for c, _ in char_count.most_common()]
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}

# Map each rhythmic (tune name, 词牌) to an integer id
rhythmic2idx = {c: i for i, c in enumerate(sorted(list(rhythmic)))}
idx2rhythmic = {i: c for c, i in rhythmic2idx.items()}

def encode_line(line, max_len):
    ids = [char2idx.get('<BOS>')]
    for c in line:
        ids.append(char2idx.get(c, char2idx['<UNK>']))
    ids.append(char2idx.get('<EOS>'))
    if len(ids) < max_len:
        ids += [char2idx['<PAD>']] * (max_len - len(ids))
    else:
        ids = ids[:max_len]
    return ids

max_len = max(max(len(l[1]), len(l[2])) for l in lines) + 2  # +2 for BOS/EOS

class SongLineDataset(Dataset):
    def __init__(self, lines):
        self.data = lines

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        rhythmic, prev, next_ = self.data[idx]
        return (
            rhythmic2idx[rhythmic],
            torch.tensor(encode_line(prev, max_len)),
            torch.tensor(encode_line(next_, max_len))
        )

dataset = SongLineDataset(lines)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

Loaded 1000 items from ci.song.15000.json
Loaded 2000 items from ci.song.7000.json
Loaded 3000 items from ci.song.3000.json
Loaded 4000 items from ci.song.11000.json
Loaded 4050 items from ci.song.21000.json
Loaded 5050 items from ci.song.2000.json
Loaded 6050 items from ci.song.17000.json
Loaded 7050 items from ci.song.16000.json
Loaded 8050 items from ci.song.18000.json
Loaded 9050 items from ci.song.10000.json
Loaded 10050 items from ci.song.6000.json
Loaded 11050 items from ci.song.19000.json
Loaded 12050 items from ci.song.8000.json
Loaded 13050 items from ci.song.9000.json
Loaded 14050 items from ci.song.4000.json
Loaded 15050 items from ci.song.20000.json
Loaded 16050 items from ci.song.0.json
Loaded 17050 items from ci.song.13000.json
Loaded 18050 items from ci.song.12000.json
Loaded 19050 items from ci.song.5000.json
Loaded 19053 items from ci.song.2019y.json
Loaded 20053 items from ci.song.14000.json
Loaded 21053 items from ci.song.1000.json
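
To make the preprocessing concrete, here is a minimal standalone sketch of the pair construction and the padding logic on a single toy entry. The entry, the tiny vocabulary, and every `toy_` name are illustrative assumptions; the notebook builds its real vocabulary from the full dataset.

```python
# Toy entry in the same shape the notebook reads (made up for this example).
toy_item = {
    "rhythmic": "水调歌头",
    "paragraphs": ["明月几时有，", "把酒问青天。", "不知天上宫阙，"],
}

# Adjacent paragraphs become (rhythmic, previous line, next line) pairs.
toy_pairs = [
    (toy_item["rhythmic"], prev, nxt)
    for prev, nxt in zip(toy_item["paragraphs"], toy_item["paragraphs"][1:])
]

# Special tokens take ids 0..3, then the characters seen in the toy data.
specials = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
toy_chars = specials + sorted(set("".join(toy_item["paragraphs"])))
toy_char2idx = {c: i for i, c in enumerate(toy_chars)}

def toy_encode_line(line, max_len):
    """Wrap a line in <BOS>/<EOS>, map unseen characters to <UNK>, then pad or truncate."""
    ids = [toy_char2idx["<BOS>"]]
    ids += [toy_char2idx.get(c, toy_char2idx["<UNK>"]) for c in line]
    ids.append(toy_char2idx["<EOS>"])
    ids += [toy_char2idx["<PAD>"]] * (max_len - len(ids))  # no-op when already long enough
    return ids[:max_len]

# Longest line plus two slots for <BOS> and <EOS>, as in the notebook.
toy_max_len = max(len(p) for p in toy_item["paragraphs"]) + 2
encoded = toy_encode_line("明月", toy_max_len)
print(len(toy_pairs), encoded)
```

Truncation cuts off `<EOS>` whenever a line is longer than `max_len - 2`. That cannot happen for the training pairs, because `max_len` is derived from the longest line, but it can happen for arbitrary input at inference time.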


## 2. CVAE Model Definition


In [3]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, rhythmic_size, emb_dim, cond_dim, hidden_dim, latent_dim):
        super().__init__()
        self.char_emb = nn.Embedding(vocab_size, emb_dim)
        self.rhythmic_emb = nn.Embedding(rhythmic_size, cond_dim)
        self.rnn = nn.GRU(emb_dim + cond_dim, hidden_dim, batch_first=True)
        self.linear_mu = nn.Linear(hidden_dim, latent_dim)
        self.linear_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, rhythmic):
        x_emb = self.char_emb(x)
        rhythmic_emb = self.rhythmic_emb(rhythmic).unsqueeze(1).expand(-1, x.size(1), -1)
        inp = torch.cat([x_emb, rhythmic_emb], dim=-1)
        _, h = self.rnn(inp)
        h = h.squeeze(0)
        mu = self.linear_mu(h)
        logvar = self.linear_logvar(h)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, vocab_size, rhythmic_size, emb_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()
        self.char_emb = nn.Embedding(vocab_size, emb_dim)
        self.rhythmic_emb = nn.Embedding(rhythmic_size, cond_dim)
        self.latent2hidden = nn.Linear(latent_dim + cond_dim, hidden_dim)
        self.rnn = nn.GRU(emb_dim + cond_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, z, rhythmic):
        rhythmic_emb = self.rhythmic_emb(rhythmic)
        zc = torch.cat([z, rhythmic_emb], dim=-1)
        h0 = self.latent2hidden(zc).unsqueeze(0)
        x_emb = self.char_emb(x)
        rhythmic_emb_seq = rhythmic_emb.unsqueeze(1).expand(-1, x.size(1), -1)
        inp = torch.cat([x_emb, rhythmic_emb_seq], dim=-1)
        out, _ = self.rnn(inp, h0)
        logits = self.out(out)
        return logits

class CVAE(nn.Module):
    def __init__(self, vocab_size, rhythmic_size, emb_dim=128, cond_dim=32, hidden_dim=256, latent_dim=64):
        super().__init__()
        self.encoder = Encoder(vocab_size, rhythmic_size, emb_dim, cond_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(vocab_size, rhythmic_size, emb_dim, cond_dim, latent_dim, hidden_dim)

    def forward(self, prev_line, next_line, rhythmic):
        mu, logvar = self.encoder(prev_line, rhythmic)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        logits = self.decoder(next_line[:, :-1], z, rhythmic)
        return logits, mu, logvar

    def generate(self, prev_line, rhythmic, max_len=20, device='cpu'):
        self.eval()
        with torch.no_grad():
            mu, _ = self.encoder(prev_line, rhythmic)
            z = mu  # use the posterior mean instead of sampling z
            input_seq = torch.tensor([[char2idx['<BOS>']]] * prev_line.size(0), device=device)
            outputs = []
            for _ in range(max_len):
                logits = self.decoder(input_seq, z, rhythmic)
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                outputs.append(next_token)
                input_seq = torch.cat([input_seq, next_token], dim=1)
                if (next_token == char2idx['<EOS>']).all():
                    break
            outputs = torch.cat(outputs, dim=1)
        return outputs
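
`CVAE.forward` feeds the decoder `next_line[:, :-1]`, and the training loop later scores the logits against `next_[:, 1:]`. The sketch below shows that one-position shift on a single sequence using plain lists; the token ids are hypothetical and chosen only for illustration.

```python
# Hypothetical ids: 0=<PAD>, 1=<BOS>, 2=<EOS>, 10 and up are characters.
PAD, BOS, EOS = 0, 1, 2
next_line = [BOS, 10, 11, 12, EOS, PAD, PAD]

decoder_input = next_line[:-1]  # what the decoder sees at each step
target = next_line[1:]          # what it is asked to predict at each step

for step, (seen, want) in enumerate(zip(decoder_input, target)):
    counted = want != PAD       # mirrors CrossEntropyLoss(ignore_index=<PAD>)
    print(step, seen, want, counted)

# Only non-padding targets contribute to the reconstruction loss.
scored_steps = sum(1 for t in target if t != PAD)
```

At step 0 the decoder sees `<BOS>` and must predict the first character; at the last scored step it sees the final character and must predict `<EOS>`. Steps whose target is `<PAD>` are skipped by the loss.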

## 3. Training and Inference

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CVAE(len(char2idx), len(rhythmic2idx)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=char2idx['<PAD>'])

In [5]:
def decode_line(ids):
    chars_ = []
    for i in ids:
        c = idx2char.get(i, '')
        if c == '<EOS>':
            break
        if c not in ['<PAD>', '<BOS>', '<EOS>']:
            chars_.append(c)
    return ''.join(chars_)

def predict_next_line(prev_line, rhythmic_name):
    model.eval()
    prev_ids = torch.tensor([encode_line(prev_line, max_len)], device=device)
    rhythmic_id = torch.tensor([rhythmic2idx[rhythmic_name]], device=device)
    out_ids = model.generate(prev_ids, rhythmic_id, max_len=max_len, device=device)
    return decode_line(out_ids[0].cpu().numpy())
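
`predict_next_line` relies on `generate`, which draws each character from the softmax distribution with `torch.multinomial`. The notebook does not show the earlier sampling method mentioned in the 2025.7.30 note; assuming it picked the most likely character at every step (greedy decoding), the pure-Python sketch below illustrates why the two behave differently. The logits are made-up numbers, and `greedy_pick` and `sample_pick` are hypothetical helper names.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(probs):
    """Always return the index with the highest probability."""
    return max(range(len(probs)), key=probs.__getitem__)

def sample_pick(probs, rng):
    """Draw one index in proportion to its probability (the multinomial idea)."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

toy_logits = [2.0, 1.5, 1.0, 0.5]  # hypothetical scores for 4 candidate characters
probs = softmax(toy_logits)

rng = random.Random(0)
greedy_choices = {greedy_pick(probs) for _ in range(200)}
sampled_choices = {sample_pick(probs, rng) for _ in range(200)}
print(greedy_choices, sampled_choices)
```

Greedy decoding returns the same index on every call, so the same input always yields the same first character. Sampling spreads its choices across the candidates, which is consistent with the varied outputs in the training log.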

In [6]:
def kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

epochs = 20
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for rhythmic, prev, next_ in tqdm(dataloader):
        rhythmic = rhythmic.to(device)
        prev = prev.to(device)
        next_ = next_.to(device)
        logits, mu, logvar = model(prev, next_, rhythmic)
        loss_rec = loss_fn(logits.reshape(-1, logits.size(-1)), next_[:,1:].reshape(-1))
        loss_kl = kl_loss(mu, logvar)
        loss = loss_rec + 0.1 * loss_kl
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Loss is the epoch average; loss_rec and loss_kl come from the last batch only
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}, loss_rec: {loss_rec.item():.4f}, loss_kl: {loss_kl.item():.4f}")
    print(f"人有悲欢离合 (水调歌头) → {predict_next_line('人有悲欢离合', '水调歌头')}")  

100%|██████████| 2157/2157 [00:45<00:00, 47.23it/s]


Epoch 1, Loss: 5.0770, loss_rec: 4.5186, loss_kl: 0.0027
人有悲欢离合 (水调歌头) → 当戴韵，追下发，歌曲上，云岫脍，竟兰妩。


100%|██████████| 2157/2157 [00:45<00:00, 47.86it/s]


Epoch 2, Loss: 4.3971, loss_rec: 4.4270, loss_kl: 0.0006
人有悲欢离合 (水调歌头) → 直到朝兵萍漠，应是沧溟鸥节，昭汊亦康朱。


100%|██████████| 2157/2157 [00:45<00:00, 47.85it/s]


Epoch 3, Loss: 4.1701, loss_rec: 4.2030, loss_kl: 0.0009
人有悲欢离合 (水调歌头) → 自然素被，幕长一榻弄清香。


100%|██████████| 2157/2157 [00:44<00:00, 47.96it/s]


Epoch 4, Loss: 4.0371, loss_rec: 4.2336, loss_kl: 0.0006
人有悲欢离合 (水调歌头) → 玉楼归觐浦，万斛丹青愧，人协俱怀。


100%|██████████| 2157/2157 [00:44<00:00, 48.00it/s]


Epoch 5, Loss: 3.9425, loss_rec: 3.9783, loss_kl: 0.0008
人有悲欢离合 (水调歌头) → 门外静中夜子，料想云山如故，岁作夜深浮。


100%|██████████| 2157/2157 [00:44<00:00, 47.98it/s]


Epoch 6, Loss: 3.8702, loss_rec: 3.8406, loss_kl: 0.0007
人有悲欢离合 (水调歌头) → 寿比新年继，晋宋与樽颜。


100%|██████████| 2157/2157 [00:44<00:00, 48.01it/s]


Epoch 7, Loss: 3.8120, loss_rec: 4.0450, loss_kl: 0.0011
人有悲欢离合 (水调歌头) → 南舟寄远，奈命碧江濯发，才一步酸鸣。


100%|██████████| 2157/2157 [00:44<00:00, 48.07it/s]


Epoch 8, Loss: 3.7630, loss_rec: 4.0155, loss_kl: 0.0005
人有悲欢离合 (水调歌头) → 唯有兰堂旧侣，临合梦中偏。


100%|██████████| 2157/2157 [00:44<00:00, 48.10it/s]


Epoch 9, Loss: 3.7215, loss_rec: 3.8232, loss_kl: 0.0009
人有悲欢离合 (水调歌头) → 昼生虚席，知华独领即非贤。


100%|██████████| 2157/2157 [00:44<00:00, 48.12it/s]


Epoch 10, Loss: 3.6852, loss_rec: 3.7375, loss_kl: 0.0008
人有悲欢离合 (水调歌头) → 堪笑蝉莼衮冕，堪叹忧谟战在，景致总关弥。


100%|██████████| 2157/2157 [00:44<00:00, 48.10it/s]


Epoch 11, Loss: 3.6531, loss_rec: 3.6829, loss_kl: 0.0006
人有悲欢离合 (水调歌头) → 玉树教乌鹤，风露凝情溢。


100%|██████████| 2157/2157 [00:44<00:00, 48.14it/s]


Epoch 12, Loss: 3.6249, loss_rec: 3.6457, loss_kl: 0.0010
人有悲欢离合 (水调歌头) → 淡月黄昏千仞，皎薄初乾三杰，欲向九霄丹。


100%|██████████| 2157/2157 [00:44<00:00, 48.06it/s]


Epoch 13, Loss: 3.5990, loss_rec: 3.5745, loss_kl: 0.0007
人有悲欢离合 (水调歌头) → 银蟾净映立粲，翠冷断魂流。


100%|██████████| 2157/2157 [00:44<00:00, 48.06it/s]


Epoch 14, Loss: 3.5763, loss_rec: 3.9578, loss_kl: 0.0032
人有悲欢离合 (水调歌头) → 梅冠紫芝田，气倒晚寒生。


100%|██████████| 2157/2157 [00:44<00:00, 48.07it/s]


Epoch 15, Loss: 3.5579, loss_rec: 3.9676, loss_kl: 0.0177
人有悲欢离合 (水调歌头) → 滚滚秦淮月白，灯影漾流光。


100%|██████████| 2157/2157 [00:44<00:00, 48.06it/s]


Epoch 16, Loss: 3.5368, loss_rec: 3.5377, loss_kl: 0.0151
人有悲欢离合 (水调歌头) → 似铁滩皮石，阴影几曾哀。


100%|██████████| 2157/2157 [00:44<00:00, 48.01it/s]


Epoch 17, Loss: 3.5193, loss_rec: 3.5866, loss_kl: 0.0176
人有悲欢离合 (水调歌头) → 重游得丧，举手膝下应相贤。


100%|██████████| 2157/2157 [00:44<00:00, 48.12it/s]


Epoch 18, Loss: 3.5033, loss_rec: 3.5097, loss_kl: 0.0126
人有悲欢离合 (水调歌头) → 谁经重拜，迂岁朝觐照崔嵬。


100%|██████████| 2157/2157 [00:44<00:00, 48.02it/s]


Epoch 19, Loss: 3.4876, loss_rec: 3.7436, loss_kl: 0.0115
人有悲欢离合 (水调歌头) → 不知天不尘，清夜永，庭青渚。


100%|██████████| 2157/2157 [00:44<00:00, 48.03it/s]

Epoch 20, Loss: 3.4744, loss_rec: 3.4385, loss_kl: 0.0167
人有悲欢离合 (水调歌头) → 醉眠狂客醒，满酌佩红扬。
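
The `kl_loss` used above is the closed-form KL divergence between a diagonal Gaussian and the standard normal prior, summed over latent dimensions and averaged over the batch. As a minimal sketch, the same formula for a single sample in plain Python (the function name `kl_to_standard_normal` is an illustrative assumption, not part of the notebook):

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one sample, summed over dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

# Posterior identical to the prior: mu = 0 and logvar = 0 in every dimension.
kl_same = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# One mean shifted by 1, variances still 1.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

print(kl_same, kl_shifted)
```

The divergence is zero only when the posterior matches the prior exactly. The logged `loss_kl` values stay between 0.0005 and 0.0177, which means that on those batches `mu` is close to 0 and `logvar` is close to 0, even though the KL term carries a weight of only 0.1.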





In [7]:
tests = [
    ('人有悲欢离合', '水调歌头'),
    ('了却君王天下事', '破阵子'),
    ('稻花香里说丰年', '西江月'),
    ('众里寻他千百度', '青玉案'),
    ('江晚正愁余', '菩萨蛮'),
    ('争渡争渡', '如梦令'),
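    ('我想下班', '水调歌头'),  # "I want to get off work"
    ('怎么还有一个小时才下班', '破阵子'),  # "How is there still an hour until I get off work"
    ('怎么还有三天才周末', '西江月'),  # "How are there still three days until the weekend"
    ('原来周二才是最难熬的', '青玉案')]  # "Turns out Tuesday is the hardest day to get through"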


for prev_line, rhythmic_name in tests:
    print(f"{prev_line} ({rhythmic_name}) → {predict_next_line(prev_line, rhythmic_name)}")

人有悲欢离合 (水调歌头) → 愁恐醒时知减，成后扫春休。
了却君王天下事 (破阵子) → 不似鱼金累累累。
稻花香里说丰年 (西江月) → 高致宜怀旧处，得烦儿孙戏剧。
众里寻他千百度 (青玉案) → 冠儿茉莉，芙蓉戏彩，唾香沾雨。
江晚正愁余 (菩萨蛮) → 月寒多少情。
争渡争渡 (如梦令) → 休恼。
我想下班 (水调歌头) → 箫鼓地横水，佳处甚云踪。
怎么还有一个小时才下班 (破阵子) → 亲将前后金銮句，应寄南窗写一鞭。
怎么还有三天才周末 (西江月) → 我亦何须语后，何妨到此清闲。
原来周二才是最难熬的 (青玉案) → 夜来何事，早翻马、先成偶。
