In [1]:
import numpy as np
import torch
import torch.nn as nn

d_k = 64 # Q的维度
d_v = 64 # V的维度
d_embedding = 512 # embedding的维度
n_heads = 8 # 多头注意力的个数
batch_size = 10
n_layers = 6 # 解码器的层数

In [2]:
class ScaledDotProductAttention(nn.Module):
    """
    缩放点积注意力
    简单理解 ScaledDotProductAttention，目的是计算Query和Key的相似权重，作用于Value
    结果是
    Query1: {Value1: w11, Value2: w12, Value3: w13}
    Query2: {Value1: w21, Value2: w22, Value3: w23}
    """
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # 维度信息
        # Q: [batch_size, n_heads, len_q, d_k]
        # K: [batch_size, n_heads, len_k, d_k]
        # V: [batch_size, n_heads, len_v(=len_k), d_v]
        # attn_mask: [batch_size, n_heads, len_q, len_k]
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]
        # scores: [batch_size, n_heads, len_q, len_k]
        # 加上注意力掩码, 将attn_mask中为True的位置的分数设置为极小值
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.
        # softmax归一化 => 注意力权重
        weights = nn.Softmax(dim=-1)(scores)
        # weights: [batch_size, n_heads, len_q, len_k]
        context = torch.matmul(weights, V) 
        # context: [batch_size, n_heads, len_q, d_v]
        return context, weights # 返回上下文变量 和 注意力分数

In [3]:
class MultiHeadAttention(nn.Module):
    """
    多头注意力
    简单理解，先放大维度，提取Q、K、V的各个维度的信息，再缩小维度，得到最终的结果
    黑盒的看是 (Q、K、V) -> Q
    """
    def __init__(self, d_embedding=d_embedding, n_heads=n_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_embedding = d_embedding
        self.n_heads = n_heads

        self.W_Q = nn.Linear(d_embedding, n_heads * d_k)
        self.W_K = nn.Linear(d_embedding, n_heads * d_k)
        self.W_V = nn.Linear(d_embedding, n_heads * d_v)
        self.linear = nn.Linear(n_heads * d_v, d_embedding)
        self.layer_norm = nn.LayerNorm(d_embedding)

    def forward(self, Q, K, V, attn_mask):
        # 维度信息
        # Q: [batch_size, len_q, d_embedding]
        # K: [batch_size, len_k, d_embedding]
        # V: [batch_size, len_v(=len_k), d_embedding]
        # attn_mask: [batch_size, len_q, len_k]
        
        residual, batch_size = Q, Q.size(0)
        # 线性层，维度提升，为了捕捉更多信息
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) 
        # q_s: [batch_size, n_heads, len_q, d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)
        # k_s: [batch_size, n_heads, len_k, d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)
        # v_s: [batch_size, n_heads, len_v(=len_k), d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        # attn_mask: [batch_size, n_heads, len_q, len_k]

        # 点积缩放注意力
        context, weights = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        # context: [batch_size, n_heads, len_q, d_v]
        # weights: [batch_size, n_heads, len_q, len_k]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        # context: [batch_size, len_q, n_heads * d_v]

        # 线性层，降维成 Q 原始的维度
        output = self.linear(context) 
        # output: [batch_size, len_q, d_embedding]
        
        # 残差连接，并做归一化（方便将当前Q往下层传递，所以做了残差）
        output = self.layer_norm(output + residual) 
        # output: [batch_size, len_q, d_embedding]
        return output, weights

In [4]:
class PoswiseFeedForwardNet(nn.Module):
    """
    前馈神经网络，目标是优化每个标记（单词）的表征
    对每个位置的d_embedding维度进行升维 => 降维 => 残差归一化
    """
    def __init__(self, d_ff=2048):
        super(PoswiseFeedForwardNet, self).__init__()
        # 输入升维
        self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=d_ff, kernel_size=1)
        # 输入降维
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_embedding, kernel_size=1)
        # 定义 归一化
        self.layer_norm = nn.LayerNorm(d_embedding)

    def forward(self, inputs):
        # inputs [batch_size, len_q, d_embedding]
        residual = inputs

        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))
        # [batch_size, d_ff, len_q]

        output = self.conv2(output).transpose(1, 2)
        # [batch_size, len_q, d_embedding]
        
        output = self.layer_norm(output + residual)
        # [batch_size, len_q, d_embedding]
        return output

In [5]:
def get_pos_enc_table(n_position, embedding_dim):
    # 位置编码表：目的是让模型知道输入序列中单词的位置信息
    # 也可以用自然序列(1,2,3)作为位置编码，但正余弦能更好表达位置信息
    # 维度信息
    # n_position: 输入序列最大长度
    # embedding_dim: 词向量维度

    pos_table = np.zeros((n_position, embedding_dim), dtype=np.float32)
    for pos_i in range(n_position):
        for idx in range(embedding_dim):
            angle = pos_i / np.power(10000, 2 * (idx // 2) / embedding_dim)
            pos_table[pos_i, idx] = angle
    
    pos_table[:, 0::2] = np.sin(pos_table[:, 0::2]) # dim 2i偶数维
    pos_table[:, 1::2] = np.cos(pos_table[:,1::2]) # dim 2i+1奇数维
    # pos_table: [n_position, embedding_dim]
    return torch.FloatTensor(pos_table)

In [6]:
def get_attn_pad_mask(seq_q, seq_k):
    # 填充注意力掩码
    # seq_q: [batch_size, len_q]
    # seq_k: [batch_size, len_k]
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()

    # =0的位置会变成True,其他是False
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) 
    # [batch_size, 1, len_k]

    pad_aatn_mask = pad_attn_mask.expand(batch_size, len_q, len_k)
    # [batch_size, len_q, len_k]
    return pad_attn_mask

In [7]:
def get_attn_subsequent_mask(seq):
    # 注意力掩码，屏蔽未来的信息
    # seq: [batch_size, seq_len(Q)=seq_len(K)]
    
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # attn_shape: [batch_size, seq_len, seq_len]

    # triu triangle upper
    subsequent_mask = np.triu(np.ones(attn_shape), k=1)
    # subsequent_mask: [batch_size, seq_len, seq_len]

    subsequent_mask = torch.from_numpy(subsequent_mask).byte()
    # subsequent_mask: [batch_size, seq_len, seq_len]
    return subsequent_mask

In [8]:
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()

        self.self_attn = MultiHeadAttention()
        self.feed_forward = PoswiseFeedForwardNet()
        self.norm1 = nn.LayerNorm(d_embedding)
        self.norm2 = nn.LayerNorm(d_embedding)

    def forward(self, dec_inputs, attn_mask):
        # dec_inputs: [batch_size, seq_len, d_embedding]
        attn_outputs, _ = self.self_attn(dec_inputs, dec_inputs, dec_inputs, attn_mask)
        # attn_outputs: [batch_size, seq_len, d_embedding]

        # 残差连接 + 归一化
        norm1_outputs = self.norm1(dec_inputs + attn_outputs)
        # norm1_outputs: [batch_size, seq_len, d_embedding]

        ff_outputs = self.feed_forward(norm1_outputs)
        # ff_outputs: [batch_size, seq_len, d_embedding]
        dec_outputs = self.norm2(norm1_outputs + ff_outputs)
        # dec_outputs: [batch_size, seq_len, d_embedding]
        return dec_outputs

In [9]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, max_seq_len):
        super(Decoder, self).__init__()
        # 词典维度
        self.src_emb = nn.Embedding(vocab_size, d_embedding)
        # 位置编码
        self.pos_emb = nn.Embedding(max_seq_len, d_embedding)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs):
        # dec_inputs: [batch_size, seq_len]

        # 创建位置编码
        positions = torch.arange(len(dec_inputs), device=dec_inputs.device).unsqueeze(-1)
        # positions: [batch_size, seq_len, 1]
        inputs_embedding = self.src_emb(dec_inputs) + self.pos_emb(positions)
        # inputs_embedding: [batch_size, seq_len, d_embedding]

        # 注意力掩码，屏蔽未来的信息
        attn_mask = get_attn_subsequent_mask(inputs_embedding).to(device)
        attn_mask = torch.gt(attn_mask, 0)
        # print(attn_mask.shape)
        # print(attn_mask.dtype)
        # attn_mask: [batch_size, seq_len, seq_len]

        dec_outputs = inputs_embedding
        for layer in self.layers:
            dec_outputs = layer(dec_outputs, attn_mask)
        return dec_outputs

In [10]:
class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len):
        super(GPT, self).__init__()

        self.decoder = Decoder(vocab_size, max_seq_len) # 解码器
        self.projection = nn.Linear(d_embedding, vocab_size) # 输出结果

    def forward(self, dec_inputs):
        dec_outputs = self.decoder(dec_inputs)
        # dec_outputs: [batch_size, tgt_len, embedding_dim]
        # 预测结果
        dec_outputs = self.projection(dec_outputs)
        # dec_outputs: [batch_size, tgt_len, vocab_size]
        return dec_outputs

In [11]:
from collections import Counter
class LanguageCorpus:
    def __init__(self, sentences):
        self.sentences = sentences

        self.seq_len = max([len(sentence.split()) for sentence in sentences]) + 2
        self.vocab = self.create_vocab()
        self.idx2word = {v: k for k, v in self.vocab.items()}

    def create_vocab(self):
        vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2}
        word_counts = Counter()
        for sentence in self.sentences:
            word_counts.update(sentence.split())
        for word in word_counts:
            if word not in vocab:
                vocab[word] = len(vocab)
        return vocab

    def make_batch(self, batch_size, test_batch=False):
        input_batch, output_batch = [], []
        # 取batch_size个句子
        sentence_idxs = torch.randperm(len(self.sentences))[:batch_size]

        for idx in sentence_idxs:
            sentence = self.sentences[idx]
            # 完整seq拼接 <sos> + 句子内容 + <eos>
            seq = [self.vocab['<sos>']] + [self.vocab[word] for word in sentence.split()] + [self.vocab['<eos>']]
            # 序列填充到seq_len长度
            seq += [self.vocab['<pad>']] * (self.seq_len - len(seq))
            input_batch.append(seq[:-1])
            output_batch.append(seq[1:])
        return torch.LongTensor(input_batch), torch.LongTensor(output_batch)        

In [12]:
sentences = []
with open("lang.txt", "r") as f:
    sentences = [line.strip() for line in f.readlines()]
corpus = LanguageCorpus(sentences)
vocab_size = len(corpus.vocab)
max_seq_len = corpus.seq_len

print(f"词汇表大小 {vocab_size}")
print(f"最长句子长度 {max_seq_len}")

词汇表大小 133
最长句子长度 17


In [13]:
learning_rate = 0.0001
epoches = 500

In [14]:
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPT(vocab_size, max_seq_len).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epoches):
    optimizer.zero_grad()
    inputs, targets = corpus.make_batch(batch_size=batch_size)
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_fn(outputs.view(-1, vocab_size), targets.view(-1))

    if (epoch+1) % 100 == 0:
        print(f"epoch: {epoch+1: 04d} cost={loss:6f}")
    loss.backward()
    optimizer.step()

epoch:  100 cost=0.364337
epoch:  200 cost=0.226223
epoch:  300 cost=0.201120
epoch:  400 cost=0.211093
epoch:  500 cost=0.183992


In [15]:
def generate_text(model, input_str, max_len=50, debug=False):
    model.eval()

    input_tokens = [corpus.vocab[token] for token in input_str]
    output_tokens = input_tokens.copy()
    with torch.no_grad():
        for _ in range(max_len):
            input_tensor = torch.tensor([output_tokens]).to(device)
            # input_tensor: [1, seq_len]
            output = model(input_tensor)
            # output: [1, seq_len, vocab_size]
            next_token = torch.argmax(output[:, -1, :], dim=-1).item()
            if next_token == corpus.vocab["<eos>"]:
                break
            output_tokens.append(next_token)
    output_str = " ".join([corpus.idx2word[token] for token in output_tokens])            
    return output_str

In [16]:
intput_list = [["I", "am"], ["Python"]]

for input_str in intput_list:
    gen_text = generate_text(model, input_str)
    print("input_str: ", input_str)
    print("gen_text: ", gen_text)

input_str:  ['I', 'am']
gen_text:  I am excited to see how AI will continue to develop and change the world.
input_str:  ['Python']
gen_text:  Python is a popular programming language.


In [17]:
import pandas as pd

data_types = ["train", "test"]
wiki_datas = {}

for data_type in data_types:
    df = pd.read_csv(f"{data_type}.csv")
    print(len(df))
    df.columns = ["idx", "text"]
    res = []
    for i, item in df.iterrows():
        text = item.to_dict()["text"]
        res.append(text)
    wiki_datas[data_type] = res

def read_wikitext(data_type):
    return wiki_datas[data_type]

17514
2181


In [18]:
# from datasets import load_dataset

# wiki_dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
# def read_wikitext(data_type, select_sentences=0):
#     # select_sentences 选用的句子数，现在数量太大跑不动..
#     ds = wiki_dataset[data_type]
#     res = []
#     idx = 0
#     for x in ds:
#         x = x['text'].strip()
#         if x == "" or len(x) < 3:
#             continue
#         if x.startswith("="):
#             continue
#         res.append(x)
#         idx += 1
#         if select_sentences > 0 and idx >= select_sentences:
#             break
#     print(f"wikitext: {data_type} has {len(res)} sentences")
#     return res
    

In [19]:
# train_iter = WikiText2(split='train') # 加载训练部分
# wiki_data_path = "../../../datas/wikitext-103/"
# def read_wikitext(file_path, ):
#     # select_sentences 选用的句子数，现在数量太大跑不动..
#     res = []
#     idx = 0
#     with open(file_path, 'r', encoding='utf-8') as f:
#         for line in f:
#             line = line.strip()
#             if line == "":
#                 continue
#             if line.startswith("="):
#                 continue
#             res.append(line)
#             idx += 1
#             if idx >= select_sentences:
#                 break
#         return res

In [22]:
import torch
import os
from torch.utils.data import DataLoader, Dataset
from utils import get_tokenizer
from vocab import build_vocab_from_iterator
# from torchtext.data.utils import get_tokenizer
# from torchtext.vocab import build_vocab_from_iterator 


tokenizer = get_tokenizer('basic_english')
train_iter = read_wikitext("train")

def yield_tokens(data_iter):
    for sentence in data_iter:
        yield tokenizer(sentence)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<pad>", "<sos>", "<eos>"])
vocab.set_default_index(vocab['<pad>'])

print("词汇表大小",     len(vocab))
print("词汇示例(word2idx)", vocab["<eos>"])

词汇表大小 65987
词汇示例(word2idx) 2


In [24]:
max_seq_len = 256

class WikiDataset(Dataset):
    def __init__(self, data_iter, vocab, max_len=max_seq_len):
        self.data = []

        for sentence in data_iter:
            # 每个句子进行tokenize, 为<sos>和<eos>留空间
            tokens = tokenizer(sentence)[:max_len-2]
            origin_chars = ["<sos>"] + tokens + ["<eos>"]
            dest_tokens = []
            for ch in origin_chars:
                dest_tokens.append(vocab[ch])
            # tokens = [vocab["<sos>"]] + vocab() + [vocab["<eos>"]]
            self.data.append(dest_tokens)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        source = self.data[idx][:-1]
        target = self.data[idx][1:]
        return torch.tensor(source), torch.tensor(target)

train_dataset = WikiDataset(train_iter, vocab)
print("Dataset数据条目数", len(train_dataset))

Dataset数据条目数 17514


In [25]:
sample_source, sample_target = train_dataset[20]
print("source:", sample_source.shape)
print("target:", sample_target.shape)

decoded_source = " ".join(vocab.lookup_tokens(sample_source.tolist()))
print("decoded source:", decoded_source)
decoded_target = " ".join(vocab.lookup_tokens(sample_target.tolist()))
print("decoded target:", decoded_target)

source: torch.Size([117])
target: torch.Size([117])
decoded source: <sos> two manga adaptations were produced , following each of the game ' s main female protagonists imca and riela . they were senjō no valkyria 3 namo naki chikai no hana ( 戦場のヴァルキュリア3 名もなき誓いの花 , lit . valkyria of the battlefield 3 the flower of the nameless oath ) , illustrated by naoyuki fujisawa and eventually released in two volumes after being serialized in dengeki maoh between 2011 and 2012 and senjō no valkyria 3 -akaki unmei no ikusa otome- ( 戦場のヴァルキュリア3 -赤き運命の戦乙女- , lit . valkyria of the battlefield 3 -the valkyrie of the crimson fate ) , illustrated by mizuki tsuge and eventually released in a single volume by kadokawa shoten in 2012 .
decoded target: two manga adaptations were produced , following each of the game ' s main female protagonists imca and riela . they were senjō no valkyria 3 namo naki chikai no hana ( 戦場のヴァルキュリア3 名もなき誓いの花 , lit . valkyria of the battlefield 3 the flower of the nameless oath ) 

In [26]:
def pad_sequence(sequences, padding_value=0, length=None):
    """
    填充序列，目的sequences token序列长度相同
    """
    max_length = max(len(seq) for seq in sequences) if length is None else length
    # 全零张量
    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)

    for i, seq in enumerate(sequences):
        end = len(seq)
        result[i, :end] = seq[:end]
    return result

def collate_fn(batch):
    """
    对batch数据进行预处理，让其src、tgt长度一致
    """
    # batch: [(src1, tgt1), (src2, tgt2), ...]
    sources, targets = zip(*batch)

    tmps = []
    tmps.extend(sources)
    tmps.extend(targets)
    max_length = max([len(s) for s in tmps])
    pad_val = vocab["<pad>"]

    sources = pad_sequence(sources, padding_value=pad_val, length=max_length)
    targets = pad_sequence(targets, padding_value=pad_val, length=max_length)
    return sources, targets

In [27]:
valid_iter = read_wikitext("test")
valid_dataset = WikiDataset(valid_iter, vocab)
print("valid_dataset", len(valid_dataset))

valid_dataset 2181


In [32]:
max_seq_len

256

In [28]:
import torch.optim as optim

batch_size = 10

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

device = "cuda" if torch.cuda.is_available() else "cpu"
# if torch.backends.mps.is_available():
#     device = "mps"
device = torch.device(device)

model = GPT(len(vocab), max_seq_len).to(device)

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
save_path = "wikitext_best.pth"
epoches = 2 # 训练x轮
min_valid_loss = float("inf")

In [29]:
for epoch in range(epoches):
    epoch_loss = 0
    # 训练模式
    for batch_idx, (source, target) in enumerate(train_dataloader):
        inputs, targets = source.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs.view(-1, len(vocab)), targets.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if (batch_idx + 1) % 100 == 0:
            print(f"【train】epoch: {epoch+1}, batch_idx: {batch_idx+1}, loss: {epoch_loss / (batch_idx + 1):.4f}")
        epoch_loss /= len(inputs)
    print(f"【train】epoch: {epoch+1}, loss: {epoch_loss}")

    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for source, target in valid_dataloader:
            source = source.to(device)
            target = target.to(device)
            outputs = model(source)
            loss = loss_fn(outputs.view(-1, len(vocab)), target.view(-1))
            valid_loss += loss.item()
    valid_loss /= len(valid_dataloader)
    print("【valid】Epoch: {}, Valid Loss: {:.4f}".format(epoch+1, valid_loss))
    if valid_loss < min_valid_loss:
        min_valid_loss = valid_loss
        torch.save(model.state_dict(), save_path)
    model.train()

【train】epoch: 1, batch_idx: 100, loss: 0.0350


KeyboardInterrupt: 

In [30]:
model.load_state_dict(torch.load("wikitext_best.pth", map_location=torch.device("cpu")))

<All keys matched successfully>

In [31]:
# 集束搜索
def print_candidate(candidate, prefix=""):
    special_tokens = {'<pad>', '<eos>', '<bos>', '<unk>'}
    s = " ".join([vocab.get_itos()[token] for token in candidate if vocab.get_itos()[token] not in special_tokens])
    print(prefix + f"序列: {s}")

def generate_text_beam_search(model, input_str, max_len=20, beam_width=5, debug=False):
    # model.eval()
    input_tokens = [vocab[token] for token in input_str.split()]
    # 初始化候选列表
    candidates = [(input_tokens, 0.0)]
    if debug:
        print(len(input_tokens))
    if debug:
        print_candidate(candidates[0][0], prefix="输入")
    with torch.no_grad():
        final_results = []
        for i in range(max_len): # 最多max_len个token
            new_candidates = []
            for candidate, candidate_score in candidates:
                inputs = torch.LongTensor(candidate).unsqueeze(0).to(device)
                # inputs: [1, seq_len]
                outputs = model(inputs)
                # outputs: [1, seq_len, vocab_size]
                logits = outputs[:, -1, :] # 只关心最后一步的数据
                # logits [1, vocab_size]
                scores, next_tokens = torch.topk(logits, beam_width, dim=-1)
                # scores: [1, beam_width]
                # next_tokens: [1, beam_width]
                for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):
                    new_candidate = candidate + [next_token.item()]
                    new_score = candidate_score + score.item()
                    if next_token.item() == vocab['<eos>']:
                        final_results.append((new_candidate, new_score))
                    else:
                        new_candidates.append((new_candidate, new_score))
            # print(f"第{i+1}次预测, 共有 {len(new_candidates)} 个候选 {len(final_results)}个结果集")
            # 从新生成的候选中选择最好的 beam_width 个
            candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if debug:
                print_candidate(candidates[0][0])
            # print(f"最佳候选序列的token: {[vocab.get_itos()[token] for token in best_candidate]}")
        # 将过程中的遇到<eos>的结果数据也放到候选中
        candidates.extend(final_results)
        best_candidate, _ = sorted(candidates, key=lambda x: x[1], reverse=True)[0]

        special_tokens = {'<pad>', '<eos>', '<bos>', '<unk>'}
        best_candidate_strs = [vocab.get_itos()[token] for token in best_candidate if vocab.get_itos()[token] not in special_tokens]
        
        if debug:
            print(len(best_candidate))
        return ' '.join(best_candidate_strs)

input_str = "the first"
gen_text = generate_text_beam_search(model, input_str)
print("input_str", input_str)
print("gen_text", gen_text)

input_str the first
gen_text the first down on september . the song on the song on september 11 , the song . the song on the
