# 使用CVAE实现宋词续写

数据集: [chinese-poetry/chinese-poetry: The most comprehensive database of Chinese poetry ](https://github.com/chinese-poetry/chinese-poetry)


> 为什么会有这个项目：  
> 花了很长时间终于彻底理解了ACT的模型部分，正好又看到苏剑林大佬写的VAE文章，讲得非常透彻，让我信心大增，想试着整点活。最近事情也告一段落，于是趁机摸鱼。这个项目花了三个多小时写完，改了几次模型结构但是基本没调参。最后让AI帮忙优化了一下代码（要优雅！！），纯属娱乐，不要太在意效果（

In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

## 1. 数据加载与预处理

In [2]:

# 假设数据格式为 [{"rhythmic": "水调歌头", "lines": ["明月几时有", "把酒问青天", ...]}, ...]
# 读取所有ci.song开头的json文件
data_dir = './datas/宋词'  # 修改为你的数据文件夹路径
all_data = []
for fname in os.listdir(data_dir):
    if fname.startswith('ci.song') and fname.endswith('.json'):
        with open(os.path.join(data_dir, fname), encoding='utf-8') as f:
            all_data.extend(json.load(f))
            print(f"Loaded {len(all_data)} items from {fname}")

# 构建词表
lines = []
rhythmic = set()
for item in all_data:
    rhythmic.add(item['rhythmic'])
    for i in range(len(item['paragraphs']) - 1):
        lines.append((item['rhythmic'], item['paragraphs'][i], item['paragraphs'][i+1]))

# 字符级tokenizer
from collections import Counter
all_text = ''.join([l[1]+l[2] for l in lines])
char_count = Counter(all_text)
chars = ['<PAD>', '<BOS>', '<EOS>', '<UNK>'] + [c for c, _ in char_count.most_common()]
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}

# 词牌名编码
rhythmic2idx = {c: i for i, c in enumerate(sorted(list(rhythmic)))}
idx2rhythmic = {i: c for c, i in rhythmic2idx.items()}

def encode_line(line, max_len):
    ids = [char2idx.get('<BOS>')]
    for c in line:
        ids.append(char2idx.get(c, char2idx['<UNK>']))
    ids.append(char2idx.get('<EOS>'))
    if len(ids) < max_len:
        ids += [char2idx['<PAD>']] * (max_len - len(ids))
    else:
        ids = ids[:max_len]
    return ids

max_len = max(max(len(l[1]), len(l[2])) for l in lines) + 2  # +2 for BOS/EOS

class SongLineDataset(Dataset):
    def __init__(self, lines):
        self.data = lines

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        rhythmic, prev, next_ = self.data[idx]
        return (
            rhythmic2idx[rhythmic],
            torch.tensor(encode_line(prev, max_len)),
            torch.tensor(encode_line(next_, max_len))
        )

dataset = SongLineDataset(lines)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

Loaded 1000 items from ci.song.15000.json
Loaded 2000 items from ci.song.7000.json
Loaded 3000 items from ci.song.3000.json
Loaded 4000 items from ci.song.11000.json
Loaded 4050 items from ci.song.21000.json
Loaded 5050 items from ci.song.2000.json
Loaded 6050 items from ci.song.17000.json
Loaded 7050 items from ci.song.16000.json
Loaded 8050 items from ci.song.18000.json
Loaded 9050 items from ci.song.10000.json
Loaded 10050 items from ci.song.6000.json
Loaded 11050 items from ci.song.19000.json
Loaded 12050 items from ci.song.8000.json
Loaded 13050 items from ci.song.9000.json
Loaded 14050 items from ci.song.4000.json
Loaded 15050 items from ci.song.20000.json
Loaded 16050 items from ci.song.0.json
Loaded 17050 items from ci.song.13000.json
Loaded 18050 items from ci.song.12000.json
Loaded 19050 items from ci.song.5000.json
Loaded 19053 items from ci.song.2019y.json
Loaded 20053 items from ci.song.14000.json
Loaded 21053 items from ci.song.1000.json


## 2. CVAE模型定义


In [3]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, rhythmic_size, emb_dim, cond_dim, hidden_dim, latent_dim):
        super().__init__()
        self.char_emb = nn.Embedding(vocab_size, emb_dim)
        self.rhythmic_emb = nn.Embedding(rhythmic_size, cond_dim)
        self.rnn = nn.GRU(emb_dim + cond_dim, hidden_dim, batch_first=True)
        self.linear_mu = nn.Linear(hidden_dim, latent_dim)
        self.linear_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, rhythmic):
        x_emb = self.char_emb(x)
        rhythmic_emb = self.rhythmic_emb(rhythmic).unsqueeze(1).expand(-1, x.size(1), -1)
        inp = torch.cat([x_emb, rhythmic_emb], dim=-1)
        _, h = self.rnn(inp)
        h = h.squeeze(0)
        mu = self.linear_mu(h)
        logvar = self.linear_logvar(h)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, vocab_size, rhythmic_size, emb_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()
        self.char_emb = nn.Embedding(vocab_size, emb_dim)
        self.rhythmic_emb = nn.Embedding(rhythmic_size, cond_dim)
        self.latent2hidden = nn.Linear(latent_dim + cond_dim, hidden_dim)
        self.rnn = nn.GRU(emb_dim + cond_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, z, rhythmic):
        rhythmic_emb = self.rhythmic_emb(rhythmic)
        zc = torch.cat([z, rhythmic_emb], dim=-1)
        h0 = self.latent2hidden(zc).unsqueeze(0)
        x_emb = self.char_emb(x)
        rhythmic_emb_seq = rhythmic_emb.unsqueeze(1).expand(-1, x.size(1), -1)
        inp = torch.cat([x_emb, rhythmic_emb_seq], dim=-1)
        out, _ = self.rnn(inp, h0)
        logits = self.out(out)
        return logits

class CVAE(nn.Module):
    def __init__(self, vocab_size, rhythmic_size, emb_dim=128, cond_dim=32, hidden_dim=256, latent_dim=64):
        super().__init__()
        self.encoder = Encoder(vocab_size, rhythmic_size, emb_dim, cond_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(vocab_size, rhythmic_size, emb_dim, cond_dim, latent_dim, hidden_dim)

    def forward(self, prev_line, next_line, rhythmic):
        mu, logvar = self.encoder(prev_line, rhythmic)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        logits = self.decoder(next_line[:, :-1], z, rhythmic)
        return logits, mu, logvar

    def generate(self, prev_line, rhythmic, max_len=20, device='cpu'):
        self.eval()
        with torch.no_grad():
            mu, logvar = self.encoder(prev_line, rhythmic)
            z = mu  # 使用均值
            input_seq = torch.tensor([[char2idx['<BOS>']]] * prev_line.size(0), device=device)
            outputs = []
            h = None
            for _ in range(max_len):
                logits = self.decoder(input_seq, z, rhythmic)
                next_token = logits[:, -1, :].argmax(-1, keepdim=True)
                outputs.append(next_token)
                input_seq = torch.cat([input_seq, next_token], dim=1)
                if (next_token == char2idx['<EOS>']).all():
                    break
            outputs = torch.cat(outputs, dim=1)
        return outputs

## 3. 模型训练与推理

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CVAE(len(char2idx), len(rhythmic2idx)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=char2idx['<PAD>'])

In [4]:
def decode_line(ids):
    chars_ = []
    for i in ids:
        c = idx2char.get(i, '')
        if c == '<EOS>':
            break
        if c not in ['<PAD>', '<BOS>', '<EOS>']:
            chars_.append(c)
    return ''.join(chars_)

def predict_next_line(prev_line, rhythmic_name):
    model.eval()
    prev_ids = torch.tensor([encode_line(prev_line, max_len)], device=device)
    rhythmic_id = torch.tensor([rhythmic2idx[rhythmic_name]], device=device)
    out_ids = model.generate(prev_ids, rhythmic_id, max_len=max_len, device=device)
    return decode_line(out_ids[0].cpu().numpy())

In [None]:
def kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for rhythmic, prev, next_ in tqdm(dataloader):
        rhythmic = rhythmic.to(device)
        prev = prev.to(device)
        next_ = next_.to(device)
        logits, mu, logvar = model(prev, next_, rhythmic)
        loss_rec = loss_fn(logits.reshape(-1, logits.size(-1)), next_[:,1:].reshape(-1))
        loss_kl = kl_loss(mu, logvar)
        loss = loss_rec + 0.1 * loss_kl
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
    print(f"明月几时有 (水调歌头) → {predict_next_line('明月几时有', '水调歌头')}")  

100%|██████████| 2157/2157 [00:45<00:00, 47.60it/s]


Epoch 1, Loss: 5.0727
明月几时有 (水调歌头) → 一笑一枝，一笑一枝春色，不见不知春。


100%|██████████| 2157/2157 [00:44<00:00, 48.17it/s]


Epoch 2, Loss: 4.3998
明月几时有 (水调歌头) → 一笑一枝头，看取次，一杯酒，一杯酒。


100%|██████████| 2157/2157 [00:44<00:00, 48.05it/s]


Epoch 3, Loss: 4.1749
明月几时有 (水调歌头) → 一笑一生一，一笑付与君。


100%|██████████| 2157/2157 [00:44<00:00, 47.99it/s]


Epoch 4, Loss: 4.0416
明月几时有 (水调歌头) → 一杯酒，一笑一，一杯酒。


100%|██████████| 2157/2157 [00:44<00:00, 47.96it/s]


Epoch 5, Loss: 3.9461
明月几时有 (水调歌头) → 天际水云乡。


100%|██████████| 2157/2157 [00:44<00:00, 47.98it/s]


Epoch 6, Loss: 3.8729
明月几时有 (水调歌头) → 不用登临临水，不用登临赋酒，不用写长安。


100%|██████████| 2157/2157 [00:44<00:00, 48.08it/s]


Epoch 7, Loss: 3.8148
明月几时有 (水调歌头) → 一笑天然无处，不见龙山飞舞，一笑俯天潢。


100%|██████████| 2157/2157 [00:44<00:00, 48.04it/s]


Epoch 8, Loss: 3.7657
明月几时有 (水调歌头) → 不知何处，有人知我亦何忧。


100%|██████████| 2157/2157 [00:44<00:00, 48.09it/s]


Epoch 9, Loss: 3.7241
明月几时有 (水调歌头) → 一笑嫣然一笑，不肯放春风力，不用污人愁。


100%|██████████| 2157/2157 [00:44<00:00, 48.11it/s]

Epoch 10, Loss: 3.6874
明月几时有 (水调歌头) → 一笑平生志，非雾非烟非。





In [6]:
tests = [
    ('明月几时有', '水调歌头'),
    ('了却君王天下事', '破阵子'),
    ('稻花香里说丰年', '西江月'),
    ('众里寻他千百度', '青玉案'),
    ('江晚正愁余', '菩萨蛮'),
    ('我想下班', '水调歌头'),
    ('怎么还有一个小时才下班', '破阵子'),
    ('怎么还有三天才周末', '西江月'),
    ('原来周二才是最难熬的', '青玉案')]


for prev_line, rhythmic_name in tests:
    print(f"{prev_line} ({rhythmic_name}) → {predict_next_line(prev_line, rhythmic_name)}")

明月几时有 (水调歌头) → 一笑平生志，非雾非烟非。
了却君王天下事 (破阵子) → 一曲清歌，一声声断云。
稻花香里说丰年 (西江月) → 不道春工不管。
众里寻他千百度 (青玉案) → 一年一度春风雨。
江晚正愁余 (菩萨蛮) → 一枝春色生春色。
我想下班 (水调歌头) → 一笑平生志，非雾非烟非。
怎么还有一个小时才下班 (破阵子) → 一曲清歌，一声声断云。
怎么还有三天才周末 (西江月) → 不道春工不管。
原来周二才是最难熬的 (青玉案) → 一年一度春风雨。
