In [1]:
# from torchtext.datasets import WikiText2 #导入 WikiText2 数据集
from torchtext.data.utils import get_tokenizer #导入分词器
from torchtext.vocab import build_vocab_from_iterator #导入vocabulary工具，用于从一个迭代器构建一个词汇表（Vocabulary），迭代器中包含了分词后的文本数据

In [30]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import math
max_seq_len = 256

In [7]:
torch.cuda.is_available()

True

In [4]:
#获取分词器
tokenizer = get_tokenizer('basic_english')

In [5]:
def load_local_wikitext2(split='train'):
    # 定义文件路径
    file_path = f'../data/traindata/wikitext-2/wiki.{split}.tokens'
    # 读取文件内容
    with open(file_path, 'r', encoding='utf-8') as file:
        data = file.readlines()
    return data

In [8]:
train_iter = load_local_wikitext2(split='train')
valid_iter = load_local_wikitext2(split='valid')

In [9]:
# 定义一个生成器函数，用于将数据集中的文本转换为tokens
def yield_tokens(data_iter):
    for item in data_iter:
        yield tokenizer(item)

# 创建词汇表，包括特殊tokens："<pad>", "<sos>", "<eos>"
#specials:一个包含特殊符号的列表，如<pad>（填充符），<unk>（未知词标记）等。这些特殊符号会被添加到词汇表的开始位置。
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<pad>", "<sos>", "<eos>"])
# 设置当查询的词汇项不在词汇表中时返回的默认索引值。当查询的词汇不在词汇表中默认返回<pad>的值
vocab.set_default_index(vocab["<pad>"])
'''
vocab的几个方法：
__getitem__(self, token): 返回给定词汇项的索引。
__len__(self): 返回词汇表中词汇项的数量。
get_itos(self): 返回一个列表，其中包含词汇表中所有词汇项，索引即为它们在词汇表中的位置。
get_stoi(self): 返回一个字典，键为词汇项，值为它们在词汇表中的索引。
set_default_index(self, index): 设置当查询的词汇项不在词汇表中时返回的默认索引值。
'''

# 打印词汇表信息
print("词汇表大小:", len(vocab))
print("词汇示例(word to index):", {word: vocab[word] for word in ["<pad>", "<sos>", "<eos>", "the", "apple"]})

词汇表大小: 28785
词汇示例(word to index): {'<pad>': 0, '<sos>': 1, '<eos>': 2, 'the': 3, 'apple': 11505}


In [11]:
class WikiDataset(Dataset):
    def __init__(self, data_iter, vocab, max_len=max_seq_len):
        super(WikiDataset, self).__init__()
        self.data = []
        for sentence in data_iter:
            tokens = tokenizer(sentence)[:max_len-2]
            tokens = [vocab['<sos>']] + vocab(tokens) + [vocab['<eos>']]
            self.data.append(tokens)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        source = self.data[idx][:-1]
        target = self.data[idx][1:]
        return torch.LongTensor(source), torch.LongTensor(target)

In [12]:
def attention(query, key, value, mask):
    '''
    query,key,value:[batch,head,seq_len,d_k]
    '''
    d_k = query.size(-1)
    # [batch,head,seq_len,d_k]*[batch,head,d_k,seq_len]=[batch,head,d_k,seq_len,seq_len]
    # QK^T / srqt(d_k)
    score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
    score = score.masked_fill_(mask, -1e9)
    score = torch.softmax(score, dim=-1)
    return torch.matmul(score, value)

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, head, d_embedding):
        super(MultiHeadAttention, self).__init__()
        self.head = head
        self.d_k = d_embedding // head
        self.queryLinear = nn.Linear(d_embedding, d_embedding)
        self.keyLinear = nn.Linear(d_embedding, d_embedding)
        self.valueLinear = nn.Linear(d_embedding, d_embedding)
        self.fc_out = nn.Linear(d_embedding, d_embedding)
        
        self.lary_norm = nn.LayerNorm(normalized_shape=d_embedding)
        
    def forward(self, query, key, value, att_mask):
        '''
        query, key, value：【batch,seq_len, d_embedding】
        '''
        batch = query.size(0)
        query_clone = query.clone()
        # [batch,seq_len,d_embedding]->[batch,seq_len,head,d_k]->[batch,head,seq_len,d_k]
        query = self.queryLinear(query).view(batch,-1,self.head, self.d_k).transpose(1,2)
        key = self.queryLinear(key).view(batch,-1,self.head, self.d_k).transpose(1,2)
        value = self.queryLinear(value).view(batch,-1,self.head, self.d_k).transpose(1,2)
        atten = attention(query, key, value, att_mask)
        #[batch,head,seq_len,d_k]->[batch,seq_len,head,d_k]-> [batch,seq_len,d_embedding]
        atten = atten.transpose(1,2).contiguous().view(batch,-1,self.head*self.d_k)
        Atten = self.fc_out(atten)
        return self.lary_norm(Atten + query_clone) 

![image](../data/image/GPT-FFN.png)

In [14]:
class FeedForwardNetwork(nn.Module):
    def __init__(self, d_embedding, d_ff):
        super(FeedForwardNetwork, self).__init__()
        self.ForwardNetwork = nn.Sequential(
            nn.Linear(d_embedding, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_embedding)
        )
        self.lary_norm = nn.LayerNorm(normalized_shape=d_embedding)
    
    def forward(self, x):
        x_clone = x.clone()
        ForwardNetwork = self.ForwardNetwork(x)
        return self.lary_norm(x_clone + ForwardNetwork)

![image](../data/image/GPT-DecoderLayer.png)

In [15]:
class GPT_DecoderLayer(nn.Module):
    def __init__(self,d_embedding, head, d_ff):
        super(GPT_DecoderLayer, self).__init__()
        self.Multi = MultiHeadAttention(head, d_embedding)
        self.ffn = FeedForwardNetwork(d_embedding, d_ff)
        
    def forward(self, x, mask):
        atten  = self.Multi(x, x, x, mask)
        ffn = self.ffn(atten)
        return ffn

In [16]:
class GPT_Decoders(nn.Module):
    def __init__(self, n_layers,d_embedding, head,  d_ff):
        # GPT_Decoders(n_layers, d_embedding, 4, 1024)
        super(GPT_Decoders, self).__init__()
        self.layers = nn.ModuleList(
            [
                GPT_DecoderLayer(d_embedding, head, d_ff) for _ in range(n_layers)
            ]
        )

    def forward(self, src, src_mask):
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

In [17]:
def create_subsequent_mask(seq_length):
    '''
    seq_length:词的长度
    rerturn [seq_length, seq_length]
    '''
    # 创建后续掩码，上三角矩阵，保留对角线及以下元素，其余置为True
    subsequent_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()
    
    return subsequent_mask

In [18]:
class GPT_1(nn.Module):
    def __init__(self, n_layers ,vocab_size, d_embedding, seq_len):
        # model = GPT_1(6, len(vocab), 512, max_seq_len).to(device)
        super(GPT_1, self).__init__()
        self.src_emb = nn.Embedding(vocab_size, d_embedding)
        self.pos_emb = nn.Embedding(seq_len, d_embedding)
        self.decoder = GPT_Decoders(n_layers, d_embedding, 4, 1024)
        self.projection = nn.Linear(d_embedding, vocab_size)
        
    def forward(self, x, device):
        position = torch.arange(x.size(0), device=device).unsqueeze(-1)
        inputs_embedding = self.src_emb(x) + self.pos_emb(position)
        # [batch.seq_len, d_model]
        attn_mask = create_subsequent_mask(inputs_embedding.size(1)).to(device)
        dec_outputs = self.decoder(inputs_embedding, attn_mask)
        # 传递给全连接层以生成预测
        logits = self.projection(dec_outputs)
        return logits   

In [19]:
def pad_sequence(sequences, padding_value=0, length=None):
    '''
    进行批量训练时，将序列长度补齐
    '''
    max_length = max(len(seq) for seq in sequences) if length is None else length
    # 创建一个具有适当形状的全零张量，用于存储补齐后的序列
    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)
    # 遍历序列，将每个序列的内容复制到结果张量中
    for i, seq in enumerate(sequences):
        end = len(seq)
        result[i, :end] = seq[:end]
    return result

In [20]:
# 定义collate_fn函数，用于将一个批次的数据整理成适当的形状
def collate_fn(batch):
    # 从批次中分离源序列和目标序列
    sources, targets = zip(*batch)    
    # 计算批次中的最大序列长度
    max_length = max(max(len(s) for s in sources), max(len(t) for t in targets))    
    # 使用pad_sequence函数补齐源序列和目标序列
    sources = pad_sequence(sources, padding_value=vocab["<pad>"], length=max_length)
    targets = pad_sequence(targets, padding_value=vocab["<pad>"], length=max_length)    
    # 返回补齐后的源序列和目标序列
    return sources, targets

In [21]:
train_dataset = WikiDataset(train_iter, vocab) # 创建训练数据集
valid_dataset = WikiDataset(valid_iter, vocab) # 创建验证数据集

In [22]:
batch_size = 3

In [23]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [27]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT_1(6, len(vocab), 512, max_seq_len).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # 优化器
epochs = 2  # 训练轮次

In [31]:
for epoch in range(epochs):
    epoch_loss = 0
    for batch_idx, (source, target) in enumerate(train_dataloader): # 用Dataloader加载数据
        inputs, targets = source.to(device), target.to(device)
        optimizer.zero_grad()  # 梯度清零
        outputs = model(inputs, device)  # 获取模型输出
        loss = criterion(outputs.view(-1, len(vocab)), targets.view(-1))  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        epoch_loss += loss.item()        
        if (batch_idx + 1) % 500 == 0: # 每500个批次打印一次损失
            print(f"Batch {batch_idx + 1}/{len(train_dataloader)}, Loss: {loss.item()}")    
    epoch_loss /= len(train_dataloader) # 每轮打印一次损失
    print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {epoch_loss}")

Batch 500/12240, Loss: 7.0839715003967285
Batch 1000/12240, Loss: 5.8813700675964355
Batch 1500/12240, Loss: 6.6938605308532715
Batch 2000/12240, Loss: 6.501121520996094
Batch 2500/12240, Loss: 6.339188575744629
Batch 3000/12240, Loss: 5.825780391693115
Batch 3500/12240, Loss: 5.920065402984619
Batch 4000/12240, Loss: 6.576527118682861
Batch 4500/12240, Loss: 5.715553283691406
Batch 5000/12240, Loss: 6.186413764953613
Batch 5500/12240, Loss: 5.766402721405029
Batch 6000/12240, Loss: 6.603525638580322
Batch 6500/12240, Loss: 5.811763763427734
Batch 7000/12240, Loss: 6.217658996582031
Batch 7500/12240, Loss: 5.253559112548828
Batch 8000/12240, Loss: 0.3145367205142975
Batch 8500/12240, Loss: 6.071967124938965
Batch 9000/12240, Loss: 5.8089599609375
Batch 9500/12240, Loss: 6.809699535369873
Batch 10000/12240, Loss: 2.0836989879608154
Batch 10500/12240, Loss: 5.400591850280762
Batch 11000/12240, Loss: 5.6525492668151855
Batch 11500/12240, Loss: 5.002734184265137
Batch 12000/12240, Loss: 4.

In [43]:
def generate_text_beam_search(model, input_str, device,max_len=50, beam_width=5):
    model.eval()  # 将模型设置为评估（测试）模式，关闭dropout和batch normalization等训练相关的层
    # 将输入字符串中的每个token 转换为其在词汇表中的索引
    input_tokens = [vocab[token] for token in input_str.split()]
    # 创建一个列表，用于存储候选序列
    candidates = [(input_tokens, 0.0)]
    with torch.no_grad():  # 禁用梯度计算，以节省内存并加速测试过程
        for _ in range(max_len):  # 生成最多max_len个tokens
            new_candidates = []
            for candidate, candidate_score in candidates:
                inputs = torch.LongTensor(candidate).unsqueeze(0).to(device)
                outputs = model(inputs,device) # 输出 logits形状为[1, len(output_tokens), vocab_size]
                logits = outputs[:, -1, :] # 只关心最后一个时间步（即最新生成的token）的logits
                # 找到具有最高分数的前beam_width个tokens
                scores, next_tokens = torch.topk(logits, beam_width, dim=-1)
                final_results = [] # 初始化输出序列
                for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):
                    new_candidate = candidate + [next_token.item()]
                    new_score = candidate_score - score.item()  # 使用负数，因为我们需要降序排列
                    if next_token.item() == vocab["<eos>"]:
                        # 如果生成的token是EOS（结束符），将其添加到最终结果中
                        final_results.append((new_candidate, new_score))
                    else:
                        # 将新生成的候选序列添加到新候选列表中
                        new_candidates.append((new_candidate, new_score))
            # 从新候选列表中选择得分最高的beam_width个序列
            candidates = sorted(new_candidates, key=lambda x: x[1])[:beam_width]
    # 选择得分最高的候选序列
    best_candidate, _ = sorted(candidates, key=lambda x: x[1])[0]
    # 将输出 token 转换回文本字符串
    output_str = " ".join([vocab.get_itos()[token] for token in best_candidate if vocab.get_itos()[token] != "<pad>"])
    return output_str


In [46]:
input_str = "my name"  # 输入几个词
generated_text = generate_text_beam_search(model, input_str, device)  # 模型跟着这些词生成后续文本
print("生成的文本:", generated_text)  # 打印生成的文本

生成的文本: my name <unk> ( <unk> ) , <unk> ( <unk> ( <unk> ) , <unk> ( <unk> ) , <unk> ( <unk> ( <unk> ) , <unk> ( <unk> ( <unk> ) , <unk> ( <unk> ( <unk> ) , <unk> ( <unk> ( <unk> ) , <unk> ( <unk> ) ,
