In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import opencc
from pathlib import Path
import json
import re

In [131]:
batch_size = 12
num_epochs = 10
context_len = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("now using", device)

now using cuda


## Self-Attention

In [132]:
x = torch.rand(1,3,5)
print(x)
raw_weights = torch.bmm(x, x.transpose(1, 2))
print(raw_weights)
weights = F.softmax(raw_weights, dim=2)
print(weights)
y = torch.bmm(weights, x)
print(y)

tensor([[[0.5865, 0.4071, 0.3756, 0.9656, 0.6540],
         [0.5460, 0.8481, 0.8738, 0.1942, 0.3260],
         [0.9396, 0.4874, 0.4888, 0.8993, 0.2385]]])
tensor([[[2.0109, 1.3946, 1.9574],
         [1.3946, 1.9251, 1.6060],
         [1.9574, 1.6060, 2.2249]]])
tensor([[[0.4020, 0.2170, 0.3810],
         [0.2541, 0.4320, 0.3139],
         [0.3322, 0.2337, 0.4341]]])
tensor([[[0.7123, 0.5334, 0.5269, 0.7729, 0.4245],
         [0.6799, 0.6228, 0.6264, 0.6116, 0.3819],
         [0.7303, 0.5450, 0.5412, 0.7565, 0.3970]]])


In [133]:
converter = opencc.OpenCC("t2s")


def sentenceParse(para):
    para = re.sub(r"（.*?）", "", para)
    para = re.sub(r"{.*?}", "", para)
    para = re.sub(r"《.*?》", "", para)
    para = re.sub(r"[\[\]]", "", para)
    para = "".join([s for s in para if s not in "0123456789-"])
    para = re.sub(r"。。", "。", para)
    para = converter.convert(para)
    if "𫗋" in para or len(para) > context_len:
        return ""
    return para


def parseRawData(author=None, constrain=None):
    def handleJson(file_path):
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        rst = []
        for poetry in data:
            if author and poetry.get("author") != author:
                continue

            paragraphs = poetry.get("paragraphs")
            if any(
                len(tr) != constrain and len(tr) != 0
                for s in paragraphs
                for tr in re.split("[，！。]", s)
                if constrain is not None
            ):
                continue

            pdata = "".join(paragraphs)
            pdata = sentenceParse(pdata)
            if pdata:
                rst.append(pdata)
        return rst

    poems = []
    src_path = Path("./data/全唐诗/")
    for file_path in src_path.glob("poet.tang*"):
        poems.extend(handleJson(file_path))
    # for file_path in src_path.glob("poet.song*"):
    # data.extend(handleJson(file_path))
    return poems

In [134]:
poems = parseRawData() 

In [135]:
# 构建词汇表
word_to_index = {}
for poem in poems:
    for word in poem:
        if word not in word_to_index:
            word_to_index[word] = len(word_to_index)
word_to_index["<EOP>"] = len(word_to_index)  # End Of Poem token
word_to_index["<START>"] = len(word_to_index)  # Start token
index_to_word = {index: word for word, index in word_to_index.items()}

vocab_size = len(word_to_index)

print("VOCAB_SIZE:", vocab_size)
print("data_size", len(poems))


# 将句子转换为列表形式，并添加结束符
def sentence_to_list(sentence):
    return list(sentence) + ["<EOP>"]


poems = [sentence_to_list(poem) for poem in poems]


# 创建单词到one-hot向量的映射
def create_one_hot_vector(word, word_to_index):
    return torch.LongTensor([word_to_index[word]])


one_hot_vectors = {
    word: create_one_hot_vector(word, word_to_index) for word in word_to_index
}
one_hot_vectors

VOCAB_SIZE: 7237
data_size 46969


{'秦': tensor([0]),
 '川': tensor([1]),
 '雄': tensor([2]),
 '帝': tensor([3]),
 '宅': tensor([4]),
 '，': tensor([5]),
 '函': tensor([6]),
 '谷': tensor([7]),
 '壮': tensor([8]),
 '皇': tensor([9]),
 '居': tensor([10]),
 '。': tensor([11]),
 '绮': tensor([12]),
 '殿': tensor([13]),
 '千': tensor([14]),
 '寻': tensor([15]),
 '起': tensor([16]),
 '离': tensor([17]),
 '宫': tensor([18]),
 '百': tensor([19]),
 '雉': tensor([20]),
 '余': tensor([21]),
 '连': tensor([22]),
 '甍': tensor([23]),
 '遥': tensor([24]),
 '接': tensor([25]),
 '汉': tensor([26]),
 '飞': tensor([27]),
 '观': tensor([28]),
 '迥': tensor([29]),
 '凌': tensor([30]),
 '虚': tensor([31]),
 '云': tensor([32]),
 '日': tensor([33]),
 '隐': tensor([34]),
 '层': tensor([35]),
 '阙': tensor([36]),
 '风': tensor([37]),
 '烟': tensor([38]),
 '出': tensor([39]),
 '疎': tensor([40]),
 '岩': tensor([41]),
 '廊': tensor([42]),
 '罢': tensor([43]),
 '机': tensor([44]),
 '务': tensor([45]),
 '崇': tensor([46]),
 '文': tensor([47]),
 '聊': tensor([48]),
 '驻': tensor([49]),
 '辇': tens

In [136]:
def generate_sample(sequence, one_hot_encoding):
    # 打印原始序列（可选）
    # print(sequence)

    # 使用列表推导式生成输入和输出的 one-hot 编码
    inputs = [one_hot_encoding[sequence[i - 1]] for i in range(1, len(sequence))]
    outputs = [one_hot_encoding[sequence[i]] for i in range(1, len(sequence))]

    # 将输入和输出列表合并为张量
    encoded_inputs = torch.cat(inputs)
    encoded_outputs = torch.cat(outputs)

    return encoded_inputs, encoded_outputs


# generate_sample(poems[0], one_hot_vectors)


class PoetryDataset(Dataset):
    def __init__(self, poems, transform=None):
        self.poems = poems
        self.transform = transform

    def __len__(self):
        return len(self.poems)

    def __getitem__(self, index):
        poem = self.poems[index]
        input_data, output_data = generate_sample(poem, one_hot_vectors)
        if self.transform:
            input_data = self.transform(input_data)
        return input_data, output_data


def custom_collate_fn(batch):
    sequences, targets = zip(*batch)
    # 统一长度以进行批处理
    padded_sequences = nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=word_to_index["<START>"]
    )
    padded_targets = nn.utils.rnn.pad_sequence(
        targets, batch_first=True, padding_value=word_to_index["<START>"]
    )
    return padded_sequences, padded_targets


dataset = PoetryDataset(poems)
data_loader = DataLoader(
    dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn
)

In [137]:
class SelfAttention(nn.Module):
    def __init__(self, k, heads=4, mask=False):
        super().__init__()

        assert k % heads == 0 # input vector size 必须是 heads 的整数倍
        self.k, self.heads = k, heads

        # Compute the queries, keys and values for all heads
        self.tokeys = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues = nn.Linear(k, k, bias=False)

        # This will be applied after the multi-head self-attention operation.
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):
        b, t, k = x.size()
        h = self.heads

        # 首先，为所有 heads 计算 query/key/value，得到的是完整嵌入维度的 k*k 矩阵
        queries = self.toqueries(x)
        keys    = self.tokeys(x)
        values  = self.tovalues(x)

        # 接下来将 queries/keys/values 切块（降维），分别送到不同的 head
        s = k // h
        keys    = keys.view(b, t, h, s)
        queries = queries.view(b, t, h, s)
        values  = values.view(b, t, h, s)

        # - fold heads into the batch dimension
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # Get dot product of queries and keys, and scale
        dot = torch.bmm(
            queries, keys.transpose(1, 2)
        )  # -- dot has size (b*h, t, t) containing raw weights
        dot = dot / (k ** (1 / 2))  # scale the dot product
        
        # masking 操作，禁用矩阵对角线及其以下的元素，确保在预测时不会看到未来的信息
        dot = torch.bmm(queries, keys.transpose(1, 2))

        indices = torch.triu_indices(t, t, offset=1)
        dot[:, indices[0], indices[1]] = float('-inf')

        dot = F.softmax(dot, dim=2)
        
        dot = F.softmax(
            dot, dim=2
        )  # normalize, dot now contains row-wise normalized weights

        out = torch.bmm(dot, values).view(b, h, t, s) # apply the self attention to the values

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(out)

In [138]:
class TransformerBlock(nn.Module):
    def __init__(self, k, heads):
        super().__init__()

        self.attention = SelfAttention(k, heads=heads)

        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        self.ff = nn.Sequential(nn.Linear(k, 4 * k), nn.ReLU(), nn.Linear(4 * k, k))

    def forward(self, x):
        attended = self.attention(x)
        x = self.norm1(attended + x)

        fedforward = self.ff(x)
        return self.norm2(fedforward + x)

In [139]:
class Transformer(nn.Module):
    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        # The sequence of transformer blocks that does all the heavy lifting
        tblocks = []
        for i in range(depth):
            tblocks.append(TransformerBlock(k=k, heads=heads))
        self.tblocks = nn.Sequential(*tblocks)

        # Maps the final output sequence to class logits
        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer values representing words (in some predetermined vocabulary).
        :return: A (b, c) tensor of log-probabilities over the classes (where c is the nr. of classes).
        """
        # generate token embeddings
        tokens = self.token_emb(x)
        b, t, k = tokens.size()

        # generate position embeddings
        positions = torch.arange(t)
        positions = positions.to(device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = (
            tokens + positions
        )  # 为什么文本嵌入和位置嵌入相加，没有理论，可能就是实验下来效果不错。
        # https://writings.stephenwolfram.com/2023/02/what-is-chatgpt-doing-and-why-does-it-work/
        x = self.tblocks(x)

        # probabilities
        x = self.toprobs(x)
        return F.log_softmax(x, dim=1)

In [140]:
model = Transformer(
    k=256, heads=4, depth=8, seq_length=512, num_tokens=vocab_size, num_classes=vocab_size
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=word_to_index["<START>"], reduction="mean")

In [141]:
def train(model, data_loader, num_epochs, device, optimizer, criterion):
    model = model.to(device)

    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()

            if not batch_idx % 100:
                print(
                    f"Epoch: {epoch + 1:03d}/{num_epochs:03d} | Batch {batch_idx:05d}/{len(data_loader):05d} | Loss: {loss:.4f}"
                )
        torch.save(model, "model.pth")

In [142]:
train(model, data_loader, num_epochs, device, optimizer, criterion)

Epoch: 001/010 | Batch 00000/03915 | Loss: 9.0253
Epoch: 001/010 | Batch 00100/03915 | Loss: 7.3220
Epoch: 001/010 | Batch 00200/03915 | Loss: 6.9916
Epoch: 001/010 | Batch 00300/03915 | Loss: 6.7564
Epoch: 001/010 | Batch 00400/03915 | Loss: 6.5544
Epoch: 001/010 | Batch 00500/03915 | Loss: 6.3860
Epoch: 001/010 | Batch 00600/03915 | Loss: 6.1584
Epoch: 001/010 | Batch 00700/03915 | Loss: 6.1089
Epoch: 001/010 | Batch 00800/03915 | Loss: 6.0879
Epoch: 001/010 | Batch 00900/03915 | Loss: 5.8858
Epoch: 001/010 | Batch 01000/03915 | Loss: 5.5657
Epoch: 001/010 | Batch 01100/03915 | Loss: 5.6316
Epoch: 001/010 | Batch 01200/03915 | Loss: 5.6627
Epoch: 001/010 | Batch 01300/03915 | Loss: 5.5743
Epoch: 001/010 | Batch 01400/03915 | Loss: 5.7914
Epoch: 001/010 | Batch 01500/03915 | Loss: 5.8535
Epoch: 001/010 | Batch 01600/03915 | Loss: 5.7575
Epoch: 001/010 | Batch 01700/03915 | Loss: 5.5250
Epoch: 001/010 | Batch 01800/03915 | Loss: 5.7871
Epoch: 001/010 | Batch 01900/03915 | Loss: 5.6664


In [2]:
def generate_text(start_word="<START>", top_k=1, log=False):
    generated_text = ""
    words = []
    for word in start_word:
        words += [word]
    print(words)
    with torch.no_grad():
        vectors_list = []
        for word in words:
            word_vector = torch.LongTensor([word_to_index[word]]).unsqueeze(0)
            vectors_list.append(word_vector)
            generated_text += word

        input_vector = torch.cat(vectors_list, dim=1)
        
        for _ in range(context_len - len(words)):
            
            output = model(input_vector.to(device))
            
            last_word = output[:, -1, :]
            last_word = last_word.view(-1)
            top_values, top_indices = last_word.data.topk(top_k)

            probabilities = torch.exp(top_values)
            top_words = [index_to_word[index.item()] for index in top_indices]

            probabilities_np = probabilities.cpu().detach().numpy()
            probabilities_np = probabilities_np / probabilities_np.sum()
            indices_np = top_indices.cpu().detach().numpy()
            if log:
                for word, prob in zip(top_words, probabilities_np):
                    print(f"{word}: {prob:.4f}")

            selected_index = np.random.choice(indices_np, p=probabilities_np)

            next_word = index_to_word[selected_index]
            if next_word == "<EOP>":
                break
            generated_text += next_word
            if log:
                print(generated_text)
            # * 需要升一个维, 因为模型的输入以 batch 为单位
            input_vector = torch.cat([input_vector, torch.LongTensor([selected_index]).unsqueeze(0)], dim=1)

    return generated_text.strip()


print(generate_text("原", top_k=1))
print(generate_text("月", top_k=3))
print(generate_text("泉", top_k=3))
print(generate_text("日", top_k=30))
print(generate_text("沾衣欲湿杏花雨", top_k=3))
print(generate_text("风", top_k=3, log=True))

['原']


NameError: name 'torch' is not defined