# 诗经

In [3]:
# Training corpus for the character-level model: five poems, each introduced
# by a title line of the form 诗经·国风·周南·<title> and separated by blank
# lines. The helpers in this file strip all whitespace (including these line
# breaks) before building the vocabulary and the training windows, so the
# title lines and the "·" separator become part of the training data.
# NOTE(review): the first and third body lines of the second poem are
# identical — verify against a reference edition if textual fidelity matters.
text = """
诗经·国风·周南·关雎
关关雎鸠，在河之洲。窈窕淑女，君子好逑。
参差荇菜，左右流之。窈窕淑女，寤寐求之。
求之不得，寤寐思服。悠哉悠哉，辗转反侧。
参差荇菜，左右采之。窈窕淑女，琴瑟友之。
参差荇菜，左右芼之。窈窕淑女，钟鼓乐之。

诗经·国风·周南·葛覃
葛之覃兮，施于中谷，维叶萋萋。黄鸟于飞，集于灌木，其鸣喈喈。
葛之覃兮，施于中谷，维叶莫莫。是刈是濩，为绖为粥，效禄无疆。
葛之覃兮，施于中谷，维叶萋萋。黄鸟于飞，集于灌木，其鸣喈喈。

诗经·国风·周南·卷耳
采采卷耳，不盈顷筐。嗟我怀人，置彼周行。
陟彼崔嵬，我马虺虺。我姑酌彼金罍，维以不永怀。
陟彼高冈，我马玄黄。我姑酌彼兕觥，维以不永伤。
陟彼砠矣，我马瘏矣。我仆痡矣，云何吁矣。

诗经·国风·周南·樛木
南有樛木，葛藟累之。乐只君子，福履绥之。
南有樛木，葛藟荒之。乐只君子，福履将之。
南有樛木，葛藟萦之。乐只君子，福履成之。

诗经·国风·周南·螽斯
螽斯羽，诜诜兮。宜尔子孙，振振兮。
螽斯羽，薨薨兮。宜尔子孙，绳绳兮。
螽斯羽，揖揖兮。宜尔子孙，蛰蛰兮。
"""

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class CharRNN(nn.Module):
    """Character-level sequence model: a stacked LSTM followed by a linear head.

    Args:
        input_size: Width of each input vector fed to the LSTM.
        hidden_size: Number of features in each LSTM hidden/cell state.
        output_size: Number of logits produced for every time step.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used between LSTM layers (only when
            ``n_layers > 1``) and on the LSTM output before the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        # nn.LSTM applies dropout only *between* stacked layers and emits a
        # warning when asked for dropout with a single layer, so pass 0 then.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.rnn = nn.LSTM(input_size, hidden_size, n_layers,
                           batch_first=True, dropout=inter_layer_dropout)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden):
        """Run one forward pass.

        Args:
            x: Batch-first input accepted by the LSTM
                (``batch_first=True`` is set in ``__init__``).
            hidden: ``(h_0, c_0)`` tuple, e.g. from :meth:`init_hidden`.

        Returns:
            ``(logits, hidden)`` where ``logits`` is the linear head applied
            to every time step of the (dropped-out) LSTM output and
            ``hidden`` is the LSTM's final ``(h_n, c_n)``.
        """
        output, hidden = self.rnn(x, hidden)
        output = self.dropout(output)
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return zeroed ``(h_0, c_0)`` states for ``batch_size`` sequences.

        The states are allocated with the dtype and device of the model's
        own parameters, so they stay compatible after ``model.to(...)`` or
        ``model.double()``.
        """
        reference = next(self.parameters())
        shape = (self.n_layers, batch_size, self.hidden_size)
        # An LSTM needs two states: the hidden state and the cell state.
        return reference.new_zeros(shape), reference.new_zeros(shape)

def process_data(text):
    """Build the character vocabulary of ``text``.

    All whitespace is discarded first. Returns a 3-tuple of the sorted list
    of distinct characters, a char -> index dict and the inverse
    index -> char dict, where an index is the character's position in the
    sorted list.
    """
    # Splitting on whitespace and re-joining removes spaces and newlines.
    stripped = ''.join(text.split())
    vocabulary = sorted(set(stripped))

    to_index = {}
    to_char = {}
    for position, symbol in enumerate(vocabulary):
        to_index[symbol] = position
        to_char[position] = symbol
    return vocabulary, to_index, to_char

def create_sequences(text, char_to_idx, seq_length, step=3):
    """Cut ``text`` into overlapping (input, target) index windows.

    Whitespace is removed from ``text`` first. A window starts every
    ``step`` characters; each input window holds ``seq_length`` character
    indices and its target is the same window shifted right by one
    character.

    Args:
        text: Raw training text.
        char_to_idx: Mapping from character to integer index.
        seq_length: Number of characters per window (must be >= 1).
        step: Distance between the starts of consecutive windows.

    Returns:
        ``(x, y)``: two ``torch.long`` tensors of shape
        ``(n_windows, seq_length)``. When the text is too short to yield a
        single window both tensors have shape ``(0, seq_length)``.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1 (or, from ``range``,
            if ``step`` is 0).
        KeyError: If a character inside a window is missing from
            ``char_to_idx``.
    """
    if seq_length < 1:
        raise ValueError("seq_length must be at least 1")

    text = ''.join(text.split())

    # The upper bound leaves room for the one-character shift of the target,
    # so every window and its target are guaranteed to be full length.
    starts = range(0, len(text) - seq_length, step)
    x = [[char_to_idx[ch] for ch in text[i:i + seq_length]] for i in starts]
    y = [[char_to_idx[ch] for ch in text[i + 1:i + seq_length + 1]] for i in starts]

    if not x:
        # torch.tensor([]) would yield a float tensor of shape (0,); keep the
        # documented dtype and rank even when there is nothing to return.
        empty = torch.empty((0, seq_length), dtype=torch.long)
        return empty, empty.clone()

    return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

def train_model(model, data, targets, criterion, optimizer, vocab_size, batch_size=64):
    """Train ``model`` for one pass over ``data`` and return the mean batch loss.

    Every sequence is used: the final batch may be smaller than
    ``batch_size`` instead of being dropped.

    Args:
        model: Module exposing ``init_hidden(batch_size)`` and a
            ``forward(x, hidden)`` that returns ``(logits, hidden)``.
        data: Integer character indices, one row per sequence.
        targets: Tensor of target indices with the same layout as ``data``.
        criterion: Loss called as ``criterion(logits_2d, targets_1d)``.
        optimizer: Optimizer over ``model.parameters()``.
        vocab_size: Width of the one-hot input and of the model's logits.
        batch_size: Maximum number of sequences per optimisation step.

    Returns:
        The average of the per-batch loss values as a Python float.

    Raises:
        ValueError: If ``data`` is empty or ``batch_size`` is smaller than 1.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("train_model needs at least one training sequence")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    model.train()
    total_loss = 0.0
    n_batches = 0

    for start in range(0, n_samples, batch_size):
        batch_data = torch.as_tensor(data[start:start + batch_size], dtype=torch.long)
        batch_targets = targets[start:start + batch_size]

        # Vectorised one-hot encoding: (batch, seq) -> (batch, seq, vocab).
        x = nn.functional.one_hot(batch_data, num_classes=vocab_size).float()

        # Each batch starts from a fresh state; detach so no graph from
        # init_hidden can leak into this step's backward pass.
        hidden = model.init_hidden(len(batch_data))
        if isinstance(hidden, tuple):
            hidden = tuple(h.detach() for h in hidden)
        else:
            hidden = hidden.detach()

        output, _ = model(x, hidden)
        loss = criterion(output.reshape(-1, vocab_size), batch_targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()

        # Clip the gradient norm to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

        optimizer.step()
        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches

def generate_text(model, initial_str, char_to_idx, idx_to_char, vocab_size,
                  pred_length=100, temperature=0.8, force_punctuation=True):
    """Sample ``pred_length`` characters from ``model`` after ``initial_str``.

    The whole prompt is fed through the model so that the recurrent state
    reflects every prompt character, then one character is appended per
    step by sampling from the temperature-scaled softmax of the logits.

    Args:
        model: Module exposing ``eval()``, ``init_hidden(batch_size)`` and a
            ``forward(x, hidden)`` that returns ``(logits, hidden)``.
        initial_str: Non-empty prompt; every character must be a key of
            ``char_to_idx``.
        char_to_idx: Mapping from character to integer index.
        idx_to_char: Mapping from integer index to character.
        vocab_size: Width of the one-hot input vectors.
        pred_length: Number of characters to append.
        temperature: Positive softmax temperature; lower values make the
            sampling greedier.
        force_punctuation: When true, the sampled character is replaced by
            '，' whenever the current length is a multiple of 7, otherwise
            by '。' when it is a multiple of 15. A mark is only forced if it
            is in ``char_to_idx``, since it becomes the next model input.

    Returns:
        ``initial_str`` followed by the generated characters.

    Raises:
        ValueError: If ``initial_str`` is empty or contains a character
            missing from ``char_to_idx``, or if ``temperature`` is not
            positive.
    """
    if not initial_str:
        raise ValueError("initial_str must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Every prompt character must be encodable before anything is generated.
    for char in initial_str:
        if char not in char_to_idx:
            raise ValueError(f"Character '{char}' not in vocabulary")

    def encode(chars):
        """One-hot encode ``chars`` as a (1, len(chars), vocab_size) tensor."""
        indices = torch.tensor([[char_to_idx[ch] for ch in chars]])
        encoded = torch.zeros(1, len(chars), vocab_size)
        return encoded.scatter_(2, indices.unsqueeze(-1), 1.0)

    model.eval()
    generated = list(initial_str)  # list + join avoids quadratic str +=
    hidden = model.init_hidden(1)

    with torch.no_grad():
        # Prime the recurrent state with everything but the last prompt
        # character; the last one is the first input of the sampling loop.
        if len(generated) > 1:
            _, hidden = model(encode(generated[:-1]), hidden)

        for _ in range(pred_length):
            output, hidden = model(encode(generated[-1:]), hidden)

            # softmax is numerically stable, unlike exp() followed by a
            # manual normalisation, which overflows for large logits.
            probs = torch.softmax(output[0, -1] / temperature, dim=-1)

            # Sample, then wrap into the valid range in case the model emits
            # more logits than idx_to_char has entries.
            char_idx = torch.multinomial(probs, 1).item() % len(idx_to_char)

            forced = None
            if force_punctuation:
                if len(generated) % 7 == 0:
                    forced = '，'
                elif len(generated) % 15 == 0:
                    forced = '。'

            if forced is not None and forced in char_to_idx:
                generated.append(forced)
            else:
                generated.append(idx_to_char[char_idx])

    return ''.join(generated)

def find_valid_starts(text, length=2):
    """Return the sorted distinct substrings of ``length`` characters.

    Whitespace is removed from ``text`` before the substrings are taken, so
    a substring may span what used to be a line break.
    """
    compact = ''.join(text.split())
    window_count = len(compact) - length + 1
    unique_windows = {compact[pos:pos + length] for pos in range(window_count)}
    return sorted(unique_windows)

if __name__ == "__main__":
    # Hyper-parameters: window length, LSTM width and LSTM depth.
    seq_length, hidden_size, n_layers = 20, 256, 2

    # Vocabulary first, then the (input, target) training windows.
    chars, char_to_idx, idx_to_char = process_data(text)
    x, y = create_sequences(text, char_to_idx, seq_length)

    # Inputs are one-hot vectors, so the input width equals the vocabulary
    # size, which is also the number of output logits.
    vocab_size = len(chars)
    input_size = vocab_size

    # Model, loss and optimiser.
    model = CharRNN(input_size, hidden_size, vocab_size, n_layers, dropout=0.2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.002)

    # Number of passes over the training windows.
    n_epochs = 200

    print(f"开始训练，词汇表大小: {vocab_size}，序列长度: {seq_length}")

    for epoch in range(n_epochs):
        loss = train_model(model, x, y, criterion, optimizer, vocab_size)
        if (epoch + 1) % 20 != 0:
            continue
        # Every 20th epoch: report the loss and print a short sample.
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {loss:.4f}")
        sample = generate_text(model, "关雎", char_to_idx, idx_to_char,
                               vocab_size, pred_length=50)
        print(f"Sample text:\n{sample}\n")

    print("\n训练完成！")

    # List a few two-character prompts that occur in the corpus.
    valid_starts = find_valid_starts(text)
    print("\n可用的起始短语示例（前10个）：")
    print(", ".join(valid_starts[:10]))

    print("\n生成示例：")

    # Generate from prompts taken from the training text.
    initial_texts = ["南有", "采采", "君子"]
    for init_text in initial_texts:
        generated = generate_text(model, init_text, char_to_idx, idx_to_char,
                                  vocab_size, pred_length=100)
        print(f"\n起始：\"{init_text}\"\n生成：{generated}")

开始训练，词汇表大小: 143，序列长度: 20
Epoch 20/200, Loss: 4.3526
Sample text:
关雎灌差子求之，之·何之是经，。周。。·。，螽风于，，之，，。葛，窈嵬，履木·，风南，，，。采于卷，木·

Epoch 40/200, Loss: 2.8431
Sample text:
关雎以寤陟不永，效。陟，嵬姑，。酌侧效彼，，彼我崔我不疆，我。酌我，酌，，为绖我我永，葛虺。我我痡，陟以

Epoch 60/200, Loss: 1.3654
Sample text:
关雎关木，维维，之。。诗差淑，。左右采之。，只君，子福履，之。诗经国国，风····葛，葛之。兮，于，中，

Epoch 80/200, Loss: 0.6595
Sample text:
关雎飞，关中灌，，莫喈莫。黄，。于，集灌灌，莫喈喈喈。诗，·。·周周·，卷卷耳，耳不，顷。。嗟我怀，置彼

Epoch 100/200, Loss: 0.3881
Sample text:
关雎风·周周南，螽斯斯诜，宜，。。。我尔哉，，何思之。参，荇。，左右流，。窈窕悠女，，瑟友。。参差，菜，

Epoch 120/200, Loss: 0.2588
Sample text:
关雎木，葛藟荒，。乐只君子，，。成之。诗经，国风·周南·，卷。采采不不，，顷顷，。嗟，彼人。置周矣，陟彼

Epoch 140/200, Loss: 0.2064
Sample text:
关雎风·周南·，雎葛采覃覃兮，。君中谷，维，莫莫。是刈是，，。为为。。，高矣，我马瘏，。我。痡矣，，何吁

Epoch 160/200, Loss: 0.1503
Sample text:
关雎风·周南·，螽螽斯羽，诜，。洲宜我不，，云何吁。。诗，·。·周南·，耳采采卷耳，，盈顷。。嗟我，人，

Epoch 180/200, Loss: 0.1166
Sample text:
关雎木，葛藟萦，。乐只君子，，。成之。诗经，国风·周南·，耳。采卷耳，，盈顷筐。嗟我，人，。彼玄黄，我彼

Epoch 200/200, Loss: 0.0947
Sample text:
关雎矣，我马虺，。我姑酌彼金，。我不永怀。，彼高冈，我马，黄。我姑酌彼，觥维维以永伤，陟彼。矣，我，瘏矣


训练完成！

可用的起始短语示例（前10个）：
·关, 