In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
from pathlib import Path
import json
import re

# import tqdm
import matplotlib.pyplot as plt

In [None]:
# Training hyper-parameters.
batch_size = 128
num_epochs = 30
context_len = 64  # poems longer than this many characters are discarded
initial_lr = 1e-3
# Directory that holds the poetry JSON files.
data_path = "./data/chinese-poetry/唐诗"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("now using", device)

## Self-Attention


In [None]:
# Minimal self-attention demo on one random sequence of 3 tokens with 5 features.
x = torch.rand(1, 3, 5)
print(x)
# Pairwise dot products between tokens: (1, 3, 5) x (1, 5, 3) -> (1, 3, 3).
raw_weights = torch.bmm(x, x.permute(0, 2, 1))
print(raw_weights)
# Normalise every row so the weights a token assigns to all tokens sum to 1.
weights = torch.softmax(raw_weights, dim=2)
print(weights)
# Each output token is the weighted sum of all input tokens.
y = torch.bmm(weights, x)
print(y)

In [None]:
# converter = opencc.OpenCC("t2s")


def sentenceParse(para):
    """Strip annotations from a raw poem string and reject unusable poems.

    Removes bracketed notes (full-width and ASCII parentheses, braces and
    book-title marks), square brackets, ASCII digits and hyphens, then
    normalises sentence-ending punctuation to a single "。".

    Returns the cleaned text, or "" when the poem contains the character "𫗋",
    is longer than ``context_len`` characters, or is shorter than 24 characters.
    """
    # Order matters: every pattern is applied to the result of the previous one.
    for pattern in (r"（.*?）", r"\(.*?\)", r"{.*?}", r"《.*?》", r"[\[\]]"):
        para = re.sub(pattern, "", para)
    para = para.translate(str.maketrans("", "", "0123456789-"))
    for pattern, replacement in ((r"。。", "。"), (r"？。", "。"), (r"？", "。")):
        para = re.sub(pattern, replacement, para)
    # para = converter.convert(para)
    if "𫗋" in para:
        return ""
    if len(para) > context_len or len(para) < 24:
        return ""
    return para


def parseRawData(author=None, constrain=None):
    """Load and filter poems from the JSON files matching "data3*" under ``data_path``.

    Args:
        author: When truthy, keep only poems whose "author" field equals it.
        constrain: When not None, keep only poems in which every non-empty
            piece (split on "，", "！" or "。") has exactly this many characters.

    Returns:
        A list of cleaned poem strings, files being visited in sorted order.
    """

    def handleJson(file_path):
        """Return the usable poems found in one JSON file."""
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        rst = []
        for poetry in data:
            if author and poetry.get("author") != author:
                continue

            # An entry without "content" used to crash on the None value; fall
            # back to the "paragraphs" key and skip entries that have neither.
            paragraphs = poetry.get("content") or poetry.get("paragraphs")
            if not paragraphs:
                continue

            if constrain is not None and any(
                len(tr) != constrain and len(tr) != 0
                for s in paragraphs
                for tr in re.split("[，！。]", s)
            ):
                continue

            pdata = sentenceParse("".join(paragraphs))
            # sentenceParse returns "" for poems it rejects.
            if not pdata:
                continue
            # Drop poems that contain the erroneous-character marker.
            if "□" in pdata:
                continue
            # Only keep poems in which every line has five or seven characters.
            segments = [segment for segment in re.split(r"[，。]", pdata) if segment]
            if any(len(segment) not in (5, 7) for segment in segments):
                continue
            rst.append(pdata)
        return rst

    poems = []
    # glob() order depends on the filesystem; sort it for a reproducible result.
    for file_path in sorted(Path(data_path).glob("data3*")):
        poems.extend(handleJson(file_path))
    return poems


# De-duplicate, then sort. Iterating a set of strings yields a different order
# in every interpreter session (string hashing is randomised), which would
# change the vocabulary indices derived from `poems` and make a model saved in
# one session inconsistent with the vocabulary rebuilt in the next.
poems = sorted(set(parseRawData()))

In [None]:
# Build the vocabulary: each distinct character receives the next free index,
# in order of first appearance.
word_to_index = {}
for poem in poems:
    for word in poem:
        word_to_index.setdefault(word, len(word_to_index))
# Special tokens are appended after all ordinary characters.
word_to_index["<EOP>"] = len(word_to_index)  # End Of Poem token
word_to_index["<START>"] = len(word_to_index)  # Start token
# Reverse lookup, used when decoding model output.
index_to_word = dict(zip(word_to_index.values(), word_to_index.keys()))

vocab_size = len(word_to_index)

print("VOCAB_SIZE:", vocab_size)
print("data_size", len(poems))


# Turn a sentence into a list of tokens and append the end-of-poem marker.
def sentence_to_list(sentence):
    """Return the characters of ``sentence`` followed by the "<EOP>" token."""
    return [*sentence, "<EOP>"]


# Tokenise every poem; from here on `poems` holds token lists ending in "<EOP>".
poems = list(map(sentence_to_list, poems))


# Build the index tensor of a word.
def create_index_tensor(word, word_to_index):
    """Return a one-element int64 tensor holding the vocabulary index of ``word``."""
    index = word_to_index[word]
    return torch.tensor([index], dtype=torch.long)


# Build a true one-hot encoded vector.
def create_one_hot_vector(word, word_to_index, vocab_size):
    """Return a ``torch.zeros(vocab_size)`` vector with a 1 at the index of ``word``."""
    one_hot = torch.zeros(vocab_size)
    hot_position = word_to_index[word]
    one_hot[hot_position] = 1.0
    return one_hot


# Pre-compute the index tensor of every vocabulary entry once.
index_tensors = dict(
    (token, create_index_tensor(token, word_to_index)) for token in word_to_index
)

# 如果需要，可以这样创建 one-hot 编码向量
# one_hot_vectors = {
#     word: create_one_hot_vector(word, word_to_index, vocab_size) for word in word_to_index
# }

In [None]:
def generate_sample(sequence):
    """Build one next-token training pair from a token sequence.

    The input is every token but the last and the target is every token but
    the first, so target position ``i`` holds the token that follows input
    position ``i``.

    Returns:
        ``(encoded_inputs, encoded_outputs)``: two index tensors obtained by
        concatenating the pre-built per-token tensors from ``index_tensors``.
    """
    # Look each token up in the pre-built table, then join into one tensor per side.
    encoded_inputs = torch.cat([index_tensors[token] for token in sequence[:-1]])
    encoded_outputs = torch.cat([index_tensors[token] for token in sequence[1:]])
    return encoded_inputs, encoded_outputs


# generate_sample(poems[0], one_hot_vectors)


class PoetryDataset(Dataset):
    """Dataset yielding one (input, target) pair per poem via ``generate_sample``."""

    def __init__(self, poems, transform=None):
        """Keep the poems and an optional callable that is applied to the inputs."""
        self.transform = transform
        self.poems = poems

    def __len__(self):
        """Number of poems."""
        return len(self.poems)

    def __getitem__(self, index):
        """Return the (input, target) tensors of the poem at ``index``."""
        input_data, output_data = generate_sample(self.poems[index])
        if not self.transform:
            return input_data, output_data
        # The transform only touches the inputs; targets are returned unchanged.
        return self.transform(input_data), output_data


def custom_collate_fn(batch):
    """Collate (input, target) pairs into two padded, batch-first tensors.

    Sequences shorter than the longest one in the batch are filled with the
    index of "<START>", which therefore doubles as the padding token.
    """
    sequences, targets = zip(*batch)
    pad_index = word_to_index["<START>"]
    # Pad both sides to a common length so they can be stacked into a batch.
    padded_sequences, padded_targets = (
        nn.utils.rnn.pad_sequence(group, batch_first=True, padding_value=pad_index)
        for group in (sequences, targets)
    )
    return padded_sequences, padded_targets


# Wrap the tokenised poems and serve them in shuffled, padded batches.
dataset = PoetryDataset(poems)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    collate_fn=custom_collate_fn,
    shuffle=True,
)

![Multi-head attention](../multi-head.png)

![Self-attention](../self-attention.png)


In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_size: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_size``.
        mask: When true, position ``i`` may only attend to positions ``<= i``.
    """

    def __init__(self, embed_size, num_heads=4, mask=False):
        super().__init__()

        assert embed_size % num_heads == 0, "Embedding size 必须是 heads 的整数倍"
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # Projections that produce the queries, keys and values of all heads at once.
        self.query_projection = nn.Linear(embed_size, embed_size, bias=False)
        self.key_projection = nn.Linear(embed_size, embed_size, bias=False)
        self.value_projection = nn.Linear(embed_size, embed_size, bias=False)

        # Applied after the heads have been concatenated again.
        self.fc_out = nn.Linear(embed_size, embed_size)
        self.mask = mask

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_length, embed_size).

        Returns a tensor of the same shape.
        """
        batch_size, seq_length, _ = x.size()

        def split_heads(projected):
            # (batch, seq, embed) -> (batch * heads, seq, head_dim)
            return (
                projected.view(batch_size, seq_length, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .reshape(batch_size * self.num_heads, seq_length, self.head_dim)
            )

        queries = split_heads(self.query_projection(x))
        keys = split_heads(self.key_projection(x))
        values = split_heads(self.value_projection(x))

        # Scaled dot-product attention. The scale is sqrt(d_k), where d_k is the
        # size of one head's key vectors; dividing by sqrt(embed_size) instead
        # would flatten the attention distribution more the more heads are used.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.head_dim**0.5)

        if self.mask:
            # Hide future tokens. The mask is created on the input's own device
            # so the module does not depend on a module-level `device` variable.
            future = (
                torch.ones(seq_length, seq_length, device=x.device)
                .triu(diagonal=1)
                .bool()
            )
            scores = scores.masked_fill(future, float("-inf"))

        attention = F.softmax(scores, dim=2)

        # Weight the values, then merge the heads back into the feature dimension.
        out = torch.bmm(attention, values).view(
            batch_size, self.num_heads, seq_length, self.head_dim
        )
        out = out.transpose(1, 2).reshape(batch_size, seq_length, self.embed_size)

        return self.fc_out(out)

![Transformer architecture](../transformer-architecture.png)


In [None]:
class TransformerBlock(nn.Module):
    """One Transformer layer: self-attention followed by a position-wise MLP.

    Each sub-layer is wrapped in a residual connection, and LayerNorm is
    applied to the sum.
    """

    def __init__(self, embed_size, num_heads, mask=False):
        super().__init__()
        hidden_size = 4 * embed_size

        self.attention = SelfAttention(embed_size, num_heads=num_heads, mask=mask)
        self.norm1 = nn.LayerNorm(embed_size)
        self.ff = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return the normalised result."""
        # Self-attention with residual connection.
        x = self.norm1(self.attention(x) + x)
        # Feed-forward network with residual connection.
        return self.norm2(self.ff(x) + x)

In [None]:
class Transformer(nn.Module):
    """Stack of Transformer blocks on top of token and learned position embeddings.

    Args:
        embed_size: Feature size used throughout the model.
        num_heads: Attention heads per block.
        num_layers: Number of Transformer blocks.
        seq_length: Size of the position-embedding table, i.e. the longest
            input sequence the model accepts.
        num_tokens: Size of the token-embedding table.
        num_classes: Size of the output layer.
        mask: Passed to every block.
    """

    def __init__(
        self,
        embed_size,
        num_heads,
        num_layers,
        seq_length,
        num_tokens,
        num_classes,
        mask=False,
    ):
        super().__init__()

        self.token_emb = nn.Embedding(num_tokens, embed_size)
        self.pos_emb = nn.Embedding(seq_length, embed_size)

        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, num_heads, mask) for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        """
        :param x: A (batch_size, seq_length) tensor of integer values representing words (in some predetermined vocabulary).
        :return: A (batch_size, seq_length, num_classes) tensor of unnormalised scores (logits), one row per input position.
        :raises IndexError: If the sequence is longer than the position-embedding table.
        """
        batch_size, seq_length = x.size()

        # Fail early with a readable message instead of an out-of-range lookup
        # inside the position embedding.
        max_length = self.pos_emb.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds the maximum of {max_length}"
            )

        # Token embeddings.
        tokens = self.token_emb(x)

        # Position embeddings, created directly on the input's device.
        positions = torch.arange(seq_length, device=x.device)
        positions = self.pos_emb(positions).expand(batch_size, seq_length, -1)

        # Sum of token and position embeddings.
        x = tokens + positions

        # Run through every Transformer block.
        for layer in self.layers:
            x = layer(x)

        # Project every position to the output classes.
        return self.fc_out(x)

In [None]:
# Causal masking: every position may only attend to itself and earlier positions.
mask = True
model = Transformer(
    embed_size=256,
    num_heads=4,
    num_layers=8,
    seq_length=128,
    num_tokens=vocab_size,
    num_classes=vocab_size,
    mask=mask,
)

optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr, weight_decay=5e-4)
# `verbose` is deprecated in ReduceLROnPlateau (since PyTorch 2.2) and may be
# rejected by newer releases, so it is no longer passed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=9
)
# "<START>" is the padding value used by custom_collate_fn, so padded target
# positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word_to_index["<START>"])

In [None]:
def train(model, data_loader, num_epochs, device, optimizer, criterion, scheduler):
    """Train ``model`` for ``num_epochs`` epochs and return per-epoch statistics.

    After every epoch the whole model object is saved to "model.pth" and
    ``scheduler`` is stepped with the epoch's average loss.

    Args:
        model: Module mapping a batch of inputs to scores whose last dimension
            is the number of classes.
        data_loader: Yields ``(inputs, targets)`` batches and exposes ``.dataset``.
        num_epochs: Number of passes over ``data_loader``.
        device: Device the model and the batches are moved to.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss taking (N, classes) scores and (N,) targets.
        scheduler: Scheduler whose ``step`` accepts the average epoch loss.

    Returns:
        A dict with the keys "train_loss_per_epoch" and
        "train_perplexity_per_epoch", each a list of Python floats.
    """
    log_dict = {
        "train_loss_per_epoch": [],
        "train_perplexity_per_epoch": [],
    }
    start_time = time.time()
    model = model.to(device)
    for epoch in range(num_epochs):
        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"Epoch: {epoch+1:03d}/{num_epochs:03d} | Current Learning Rate: {current_lr:.6f}"
        )
        total_loss = 0
        model.train()
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Flatten with the model's own output width instead of relying on
            # a module-level `vocab_size` variable.
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
            )
            loss.backward()
            optimizer.step()

            # Read the scalar once; it is reused for the running sum and the log line.
            batch_loss = loss.item()
            # Weight by batch size so the (possibly smaller) last batch counts proportionally.
            total_loss += batch_loss * inputs.size(0)

            if not batch_idx % 100:
                print(
                    f"Epoch: {epoch + 1:03d}/{num_epochs:03d} | Batch {batch_idx:04d}/{len(data_loader):04d} | Loss: {batch_loss:.6f}"
                )
        torch.save(model, "model.pth")

        avg_loss = total_loss / len(data_loader.dataset)
        scheduler.step(avg_loss)
        # Store a plain float rather than a 0-dim tensor so the log stays
        # serialisable and easy to plot.
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        log_dict["train_loss_per_epoch"].append(avg_loss)
        log_dict["train_perplexity_per_epoch"].append(perplexity)

        print(f"Time elapsed: {(time.time() - start_time) / 60:.2f} min")

    print(f"Total Training Time: {(time.time() - start_time)/ 60:.2f} min")
    return log_dict

In [None]:
# Run the full training loop and keep the per-epoch statistics for plotting.
log_dict = train(
    model=model,
    data_loader=data_loader,
    num_epochs=num_epochs,
    device=device,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
)
# model = torch.load("model-full.pth")

In [None]:
def plot_training_stats(log_dict):
    """Plot per-epoch training loss and perplexity side by side.

    The figure is written to "training_stats.svg" and then shown.
    """
    # (log key, curve label and panel title, y-axis label) for each panel.
    panels = (
        ("train_loss_per_epoch", "Training Loss", "Loss"),
        ("train_perplexity_per_epoch", "Training Perplexity", "Perplexity"),
    )
    plt.figure(figsize=(10, 6))
    for position, (key, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(log_dict[key], label=title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.grid(True)
    plt.savefig("training_stats.svg")
    plt.show()


# Visualise the statistics collected during training.
plot_training_stats(log_dict=log_dict)

In [None]:
def generate_text(start_word="<START>", top_k=1, temperature=0.6, log=False):
    """Continue ``start_word`` with the trained model, one token at a time.

    Args:
        start_word: Prompt. A string is split into characters, unless it is
            itself a multi-character vocabulary entry such as "<START>", which
            is fed as one token and left out of the returned text. Any other
            iterable of tokens is used as is.
        top_k: Sample among the ``top_k`` most likely next tokens (1 is greedy).
            Values above the vocabulary size are clamped.
        temperature: Divides the logits before the softmax; must be positive.
        log: When true, print the candidates and the running text at every step.

    Returns:
        The prompt followed by the generated tokens, stripped of surrounding
        whitespace. Generation stops at "<EOP>" or once ``context_len`` tokens
        are reached.

    Raises:
        ValueError: If the prompt is empty, ``top_k < 1`` or ``temperature <= 0``.
        KeyError: If a prompt token is not in the vocabulary.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Iterating over "<START>" would split it into "<", "S", ... and fail the
    # vocabulary lookup, so a special token is treated as a single token.
    is_special_token = (
        isinstance(start_word, str)
        and len(start_word) > 1
        and start_word in word_to_index
    )
    if is_special_token:
        words = [start_word]
        generated_text = ""
        prompt_length = 1
    else:
        words = list(start_word)
        generated_text = "".join(words)
        prompt_length = len(generated_text)
    if not words:
        raise ValueError("start_word must contain at least one token")

    with torch.no_grad():
        # Shape (1, prompt tokens): the model expects a batch dimension.
        input_tensor = torch.LongTensor([[word_to_index[word] for word in words]])

        for _ in range(context_len - prompt_length):
            output = model(input_tensor.to(device))

            # Logits predicted for the token after the last position.
            last_word = output[:, -1, :].view(-1)

            # Temperature scaling: dividing by a value below 1 widens the gaps
            # between logits, so the softmax favours the largest one more.
            scaled_logits = last_word / temperature
            probabilities = F.softmax(scaled_logits, dim=-1)

            k = min(top_k, probabilities.numel())
            probabilities, top_indices = probabilities.topk(k)
            top_words = [index_to_word[index.item()] for index in top_indices]
            # Renormalise so the kept candidates form a distribution again.
            probabilities = probabilities / torch.sum(probabilities)

            probabilities_np = probabilities.cpu().numpy()
            indices_np = top_indices.cpu().numpy()
            if log:
                for word, prob in zip(top_words, probabilities_np):
                    print(f"{word}: {prob:.4f}")

            selected_index = int(np.random.choice(indices_np, p=probabilities_np))

            next_word = index_to_word[selected_index]
            if next_word == "<EOP>":
                break
            generated_text += next_word
            if log:
                print(generated_text)
            # Append the chosen token, keeping the batch dimension.
            input_tensor = torch.cat(
                [input_tensor, torch.LongTensor([[selected_index]])], dim=1
            )

    return generated_text.strip()

In [None]:
# Greedy continuations (top_k=1) of two opening lines.
print(generate_text(start_word="故人西辞黄鹤楼", top_k=1))
print(generate_text(start_word="长安一片月", top_k=1))
# Sampled continuations: ten poems per prompt.
for i in range(10):
    print(generate_text(start_word="月", top_k=9))
for i in range(10):
    print(generate_text(start_word="海棠", top_k=3))
# One run with step-by-step logging of the candidates.
print(generate_text(start_word="风", top_k=3, log=True))

In [None]:
def generate_acrostic(start_chars, top_k=1, log=False, temperature=0.7):
    """Generate one continuation per character of ``start_chars``.

    Every character seeds an independent generation that runs until the model
    emits "<EOP>" or ``context_len`` tokens are reached. Each result is
    followed by a newline.

    Args:
        start_chars: Iterable of vocabulary characters, one per output line.
        top_k: Sample among the ``top_k`` most likely next tokens (1 is greedy).
            Values above the vocabulary size are clamped.
        log: When true, print the candidates and the running text at every step.
        temperature: Divides the logits before the softmax; must be positive.
            Defaults to the value that used to be hard-coded.

    Returns:
        All generated lines concatenated into one string.

    Raises:
        ValueError: If ``top_k < 1`` or ``temperature <= 0``.
        KeyError: If a start character is not in the vocabulary.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    generated_text = ""
    with torch.no_grad():
        for start_char in start_chars:
            # Shape (1, 1): the model expects a batch dimension.
            input_vector = torch.LongTensor([[word_to_index[start_char]]])
            generated_text += start_char

            # The prompt is one token, so at most context_len - 1 more follow.
            for _ in range(context_len - 1):
                output = model(input_vector.to(device))

                # Logits predicted for the token after the last position.
                last_word_logits = output[:, -1, :].view(-1)

                scaled_logits = last_word_logits / temperature
                probabilities = F.softmax(scaled_logits, dim=-1)

                k = min(top_k, probabilities.numel())
                probabilities, top_indices = probabilities.topk(k)
                top_words = [index_to_word[index.item()] for index in top_indices]
                # Renormalise so the kept candidates form a distribution again.
                probabilities = probabilities / torch.sum(probabilities)

                probabilities_np = probabilities.cpu().numpy()
                indices_np = top_indices.cpu().numpy()
                if log:
                    for word, prob in zip(top_words, probabilities_np):
                        print(f"{word}: {prob:.4f}")

                selected_index = int(np.random.choice(indices_np, p=probabilities_np))

                next_word = index_to_word[selected_index]
                if next_word == "<EOP>":
                    break
                generated_text += next_word
                if log:
                    print(generated_text)
                # Append the chosen token, keeping the batch dimension.
                input_vector = torch.cat(
                    [input_vector, torch.LongTensor([[selected_index]])],
                    dim=1,
                )

            generated_text += "\n"  # Add a newline after each line of the poem

    return generated_text


# One generated line per character of the given phrase.
print(generate_acrostic(start_chars="南昌航空大学", top_k=3))