In [2]:
import os
# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them (debugging aid; slows GPU execution).
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import spacy
import pandas as pd
import numpy as np
from spacy.symbols import ORTH
from torch import nn

In [4]:
class atisDataProcessor:
    """Load the ATIS text-to-SQL data and prepare it for the tagging/classification models.

    Construction eagerly:
      * reads ``type_file`` into ``var2dtype`` and indexes the datatypes
        (``dtype2idx`` / ``idx2dtype``);
      * assigns an id to every distinct SQL template, taking the shortest SQL
        string of each example as its template (``template2idx`` /
        ``idx2template``, ``template_classes``);
      * tokenises every question with spaCy and stores exactly one sample dict
        per question in ``train_data`` / ``dev_data`` / ``test_data`` according
        to its ``question-split`` field;
      * builds the word vocabulary from ``train_data`` and a GloVe-initialised
        ``embedding_matrix`` for it.

    Sample dict keys:
        tokens:   spaCy tokens (lower-cased); tokens that are keys of the
                  sentence's ``variables`` dict are replaced by their value.
        vars:     per-token tag index from ``name2idx`` ("-" = 0 for non-variables).
        type:     per-token variable-type index from ``var2idx`` ("-" = 0).
        dtype:    per-token datatype index from ``dtype2idx``.
        template: SQL template id.
        split:    the sentence's ``question-split`` value.
    """

    def __init__(self, data_file, type_file, glove_path, type_emb_dim=50, dtype_emb_dim=50):
        """Load and preprocess everything.

        Args:
            data_file: path to the ATIS JSON file (a list of examples, each with
                ``sql``, ``variables`` and ``sentences`` entries).
            type_file: path to the schema file; its first line is skipped as a header.
            glove_path: path to a GloVe text file (``word v1 v2 ...`` per line).
            type_emb_dim, dtype_emb_dim: unused here; kept for API compatibility.
        """
        self.glove_path = glove_path

        # Variable-type vocabulary; "-" (index 0) marks tokens that are not variables.
        self.var2idx = {"-": 0}
        self.idx2var = {0: "-"}
        self.var_idx = 1
        self.var2dtype = {}  # schema field name -> datatype (both lower-cased)

        self.nlp = spacy.load('en_core_web_sm')

        # Tag vocabulary: variable names; 0 = "-" (not a variable), 1 = PAD.
        self.name2idx = {"-": 0, "PAD": 1}
        self.idx2name = {0: "-", 1: "PAD"}
        self.name_idx = 2

        # Word vocabulary; indices 0 and 1 are reserved for PAD and UNK, so the
        # first real word must get index 2 (starting at 1 would alias UNK).
        self.word2idx = {"PAD": 0, "UNK": 1}
        self.idx2word = {0: "PAD", 1: "UNK"}
        self.word_idx = len(self.word2idx)

        # SQL template vocabulary.
        self.template2idx = {}
        self.idx2template = {}
        self.template_idx = 0

        self.train_data = []  # Training Dataset
        self.dev_data = []    # Dev Dataset
        self.test_data = []   # Testing Dataset

        self._load_dtypes(type_file)
        with open(data_file, 'r', encoding='utf-8') as df:
            print(f"Loading all data in json...")
            dataset = json.load(df)
        self._load_templates(dataset)
        self._process_samples(dataset)

        self.wordmapping()
        self.glovemapping()

    def _load_dtypes(self, type_file):
        """Fill ``var2dtype`` from the schema file and index the distinct datatypes."""
        with open(type_file, 'r', encoding='utf-8') as tf:
            print(f"Loading datatype of variables for additional information on learning...")
            next(tf)  # skip the header row
            for line in tf:
                parts = line.replace(",", "").strip().split()
                if len(parts) < 2:
                    continue  # blank/short rows carry no field-name/datatype pair
                self.var2dtype[parts[1].lower()] = parts[-1].lower()
        type_set = sorted(set(self.var2dtype.values()))
        self.dtype2idx = {t: i for i, t in enumerate(type_set)}
        self.idx2dtype = {i: t for t, i in self.dtype2idx.items()}

    def _load_templates(self, dataset):
        """Assign an id to each distinct template (shortest SQL string of an example).

        ``idx2template`` stores the template followed by one ``{var: value}``
        dict per variable of the example's first sentence.
        """
        print(f"Loading sql template...")
        for obj in dataset:
            template = min(obj['sql'], key=len)
            if template in self.template2idx:
                continue
            template_with_default = [template]
            template_with_default.extend(
                {var: value} for var, value in obj['sentences'][0]['variables'].items()
            )
            self.template2idx[template] = self.template_idx
            self.idx2template[self.template_idx] = template_with_default
            self.template_idx += 1
            print(f"add a new template: {self.template_idx}")
        print(len(self.template2idx))
        self.template_classes = len(self.template2idx)

    def _process_samples(self, dataset):
        """Register every example's variables and build one sample per sentence."""
        var_type = {}  # variable name -> lower-cased variable type, accumulated over all examples
        print(f"processing samples...")
        for obj in dataset:
            for v in obj['variables']:
                # Key types by the same lower-cased string that _build_sample
                # looks up, so the membership test there can actually match.
                vtype = v['type'].lower()
                var_type[v['name']] = vtype
                if vtype not in self.var2idx:
                    self.var2idx[vtype] = self.var_idx
                    self.idx2var[self.var_idx] = vtype
                    self.var_idx += 1
                if v['name'] not in self.name2idx:
                    self.name2idx[v['name']] = self.name_idx
                    self.idx2name[self.name_idx] = v['name']
                    self.name_idx += 1

            template_id = self.template2idx[min(obj['sql'], key=len)]
            for sentence in obj['sentences']:
                # Exactly one sample per sentence.
                self._add_sample(self._build_sample(sentence, var_type, template_id))
        print(f"length of training set: {len(self.train_data)}")
        print(f"length of dev set: {len(self.dev_data)}")
        print(f"length of test set: {len(self.test_data)}")

    def _build_sample(self, sentence, var_type, template_id):
        """Tokenise one sentence and return its sample dict (see class docstring)."""
        variables = sentence['variables']
        for var in variables:
            # Register each placeholder as a spaCy special case so it stays one token.
            self.nlp.tokenizer.add_special_case(var, [{ORTH: var}])
        tokens = [tok.text.lower() for tok in self.nlp(sentence['text'])]

        default_dtype = self.dtype2idx[self.var2dtype['-']]  # the schema must contain a '-' entry
        labels = [self.name2idx['-']] * len(tokens)
        types = [self.var2idx['-']] * len(tokens)
        dtypes = [default_dtype] * len(tokens)
        for i, tok in enumerate(tokens):
            vtype = var_type.get(tok)
            if vtype is None or vtype not in self.var2idx:
                continue
            labels[i] = self.name2idx[tok]
            types[i] = self.var2idx[vtype]
            # Types absent from the schema fall back to the '-' datatype.
            dtypes[i] = self.dtype2idx.get(self.var2dtype.get(vtype), default_dtype)

        return {
            'tokens': [variables.get(tok, tok) for tok in tokens],
            'vars': labels,
            'type': types,
            'dtype': dtypes,
            'template': template_id,
            'split': sentence['question-split'],
        }

    def _add_sample(self, sample):
        """Append ``sample`` to the list for its split; unknown splits go to training."""
        target = {
            'train': self.train_data,
            'dev': self.dev_data,
            'test': self.test_data,
        }.get(sample['split'])
        if target is None:
            print(f"this sample not belongs to any dataset, adding it to training dataset..")
            target = self.train_data
        target.append(sample)

    def wordmapping(self):
        """Add every token of ``train_data`` to ``word2idx`` / ``idx2word``.

        Only training samples are scanned; dev/test-only words map to UNK in
        ``collate_fn``.
        """
        for sample in self.train_data:
            for token in sample['tokens']:
                if token not in self.word2idx:
                    self.word2idx[token] = self.word_idx
                    self.idx2word[self.word_idx] = token
                    self.word_idx += 1

    def glovemapping(self):
        """Build ``self.embedding_matrix`` (``len(word2idx)`` x GloVe dim) from ``self.glove_path``.

        Vocabulary words found in the GloVe file get their GloVe vector. Other
        rows are drawn from N(0, 0.1^2), and row 0 (PAD) is zero unless the file
        contains the literal word "PAD". The vector size is taken from the first
        non-blank line. Rows with a different size are skipped.

        Raises:
            ValueError: if the GloVe file has no non-blank line.
        """
        glove_dict = {}
        dims = None
        with open(self.glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                if dims is None:
                    dims = len(parts) - 1
                word = parts[0]
                # Only vocabulary words are needed; skip malformed rows too.
                if word not in self.word2idx or len(parts) - 1 != dims:
                    continue
                glove_dict[word] = torch.tensor([float(x) for x in parts[1:]], dtype=torch.float)
        if dims is None:
            raise ValueError(f"GloVe file {self.glove_path!r} is empty")

        vocab_size = len(self.word2idx)
        self.embedding_matrix = torch.randn(vocab_size, dims) * 0.1
        self.embedding_matrix[0] = torch.zeros(dims)
        for word, idx in self.word2idx.items():
            vec = glove_dict.get(word)
            if vec is not None:
                self.embedding_matrix[idx] = vec

    def getDataLoader(self, split="train", batch_size=32, shuffle=True):
        """Return a DataLoader over the requested split, batched with ``collate_fn``.

        Raises:
            ValueError: if ``split`` is not "train", "dev" or "test".
        """
        splits = {"train": self.train_data, "dev": self.dev_data, "test": self.test_data}
        if split not in splits:
            raise ValueError("Unknown split: {}".format(split))
        dataset = TextDataset(splits[split])
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

    def collate_fn(self, batch):
        """Right-pad a list of samples into LongTensors.

        Returns:
            word_idx [B, L] (PAD = 0, unknown words = UNK), label_idx [B, L]
            (padding = -100, ignored by the loss), type_idx [B, L], dtype_idx
            [B, L] (padding = 0) and class_labels [B] (template ids), where L
            is the longest sample in the batch.
        """
        batch_size = len(batch)
        max_len = max(len(sample["tokens"]) for sample in batch)
        unk = self.word2idx['UNK']
        word_idx = torch.zeros(batch_size, max_len, dtype=torch.long)
        label_idx = torch.full((batch_size, max_len), fill_value=-100, dtype=torch.long)
        type_idx = torch.zeros(batch_size, max_len, dtype=torch.long)
        dtype_idx = torch.zeros(batch_size, max_len, dtype=torch.long)
        class_labels = torch.zeros(batch_size, dtype=torch.long)
        for i, sample in enumerate(batch):
            seq_len = len(sample["tokens"])
            word_idx[i, :seq_len] = torch.tensor(
                [self.word2idx.get(token, unk) for token in sample["tokens"]], dtype=torch.long
            )
            label_idx[i, :seq_len] = torch.tensor(sample["vars"], dtype=torch.long)
            type_idx[i, :seq_len] = torch.tensor(sample["type"], dtype=torch.long)
            dtype_idx[i, :seq_len] = torch.tensor(sample["dtype"], dtype=torch.long)
            class_labels[i] = sample["template"]
        return word_idx, label_idx, type_idx, dtype_idx, class_labels

In [5]:
class TextDataset(Dataset):
    """Map-style dataset exposing an in-memory list of sample dicts unchanged."""

    def __init__(self, data):
        # Keep a reference to the caller's list (no copy is made).
        self.data = data

    def __len__(self):
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

In [6]:
class ClassificationModels(nn.Module):
    """Joint SQL-template classifier and per-token variable tagger.

    Each token is represented by the concatenation of a word embedding
    (initialised from a copy of ``embedding_matrix`` and fine-tuned), a
    variable-type embedding and a datatype embedding. An encoder chosen by
    ``model_type`` produces per-token features. The tagger reads them
    directly; the classifier reads their mean over the non-padding positions
    (word index != 0).

    model_type:
        "linear":      no encoder; linear heads on the raw embeddings.
        "feedforward": one hidden ReLU layer per head.
        "lstm":        bidirectional LSTM over the packed (unpadded) sequence.
        "transformer": nn.TransformerEncoder with a key-padding mask;
                       ``word_emb_dim + type_emb_dim + dtype_emb_dim`` must be
                       divisible by ``nhead``.
    """

    def __init__(self, embedding_matrix, type_vocab_size, dtype_vocab_size, type_emb_dim=50, dtype_emb_dim=50,
                 model_type="linear", hidden_dim=128, template_classes=0, tag_classes=0, num_layers=1, nhead=4):
        super(ClassificationModels, self).__init__()
        self.model_type = model_type
        _, word_emb_dim = embedding_matrix.size()
        # nn.Embedding.from_pretrained does not copy its input; clone so that
        # fine-tuning never writes into the caller's matrix (which is shared
        # by every model built from the same processor).
        self.word_emb = nn.Embedding.from_pretrained(embedding_matrix.clone(), freeze=False, padding_idx=0)
        self.type_emb = nn.Embedding(type_vocab_size, type_emb_dim, padding_idx=0)
        self.dtype_emb = nn.Embedding(dtype_vocab_size, dtype_emb_dim, padding_idx=0)
        input_dim = word_emb_dim + type_emb_dim + dtype_emb_dim
        print(f"Initialize {model_type} model...")
        if model_type == "linear":
            self.fc_cls = nn.Linear(input_dim, template_classes)
            self.fc_tag = nn.Linear(input_dim, tag_classes)
        elif model_type == "feedforward":
            self.ff_cls = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, template_classes)
            )
            self.ff_tag = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, tag_classes)
            )
        elif model_type == "lstm":
            self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                                batch_first=True, bidirectional=True)
            self.fc_cls = nn.Linear(hidden_dim * 2, template_classes)
            self.fc_tag = nn.Linear(hidden_dim * 2, tag_classes)
        elif model_type == "transformer":
            encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=nhead)
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.fc_cls = nn.Linear(input_dim, template_classes)
            self.fc_tag = nn.Linear(input_dim, tag_classes)
        else:
            raise ValueError("Incorrect model type")
        # Applied to the concatenated input embeddings (active only in train mode).
        self.dropout = nn.Dropout(0.1)

    @staticmethod
    def _masked_mean(feats, pad_mask):
        """Mean of ``feats`` [B, L, D] over positions where ``pad_mask`` [B, L] is False."""
        keep = (~pad_mask).unsqueeze(-1).to(feats.dtype)
        return (feats * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

    def _encode_lstm(self, x, pad_mask):
        """Run the BiLSTM over the real (right-padded) tokens only; padded outputs are zero."""
        lengths = (~pad_mask).sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        return out  # [batch, seq_len, 2*hidden_dim]

    def forward(self, word_idx, type_idx, dtype_idx):
        """
        Args:
            word_idx, type_idx, dtype_idx: LongTensors [batch, seq_len]; word
                index 0 is PAD and padding is assumed to be on the right.
                NOTE(review): a row consisting only of padding would yield NaN
                in the transformer branch — collate_fn never produces one for
                non-empty samples.
        Returns:
            class_logits [batch, template_classes],
            tag_logits [batch, seq_len, tag_classes].
        """
        pad_mask = word_idx.eq(0)  # True at padding positions
        x = torch.cat((self.word_emb(word_idx), self.type_emb(type_idx), self.dtype_emb(dtype_idx)), dim=2)
        x = self.dropout(x)  # [batch, seq_len, input_dim]

        if self.model_type in ("linear", "feedforward"):
            feats = x
        elif self.model_type == "lstm":
            feats = self._encode_lstm(x, pad_mask)
        elif self.model_type == "transformer":
            # The encoder layer is sequence-first: [seq_len, batch, input_dim].
            feats = self.transformer(x.transpose(0, 1), src_key_padding_mask=pad_mask).transpose(0, 1)
        else:
            raise ValueError("Incorrect model type")

        cls_feat = self._masked_mean(feats, pad_mask)
        if self.model_type == "feedforward":
            return self.ff_cls(cls_feat), self.ff_tag(feats)
        return self.fc_cls(cls_feat), self.fc_tag(feats)

In [7]:
def _batch_correct_counts(class_logits, tag_logits, class_labels, label_idx):
    """Return (correct_cls, total_cls, correct_tag, total_tag) for one batch; tags labelled -100 are ignored."""
    preds = class_logits.argmax(dim=1)
    pred_tags = tag_logits.argmax(dim=2)  # [batch, seq_len]
    mask = label_idx != -100
    return (
        (preds == class_labels).sum().item(),
        class_labels.size(0),
        ((pred_tags == label_idx) & mask).sum().item(),
        mask.sum().item(),
    )


def _ratio(num, den):
    """num / den, or 0 when nothing was counted."""
    return num / den if den > 0 else 0


def train_model(processor, model, epochs=10, lr=1e-3, weight_cls=1.0, weight_tag=1.0, patience=3):
    """Jointly train template classification and variable tagging with early stopping.

    The loss is ``weight_cls * CE(class) + weight_tag * CE(tags)``; tag positions
    labelled -100 are ignored. After each epoch the dev-set accuracies are
    printed. Training stops once dev classification accuracy has not improved
    for ``patience`` consecutive epochs. The weights of the best dev epoch are
    then loaded back into the model.

    Args:
        processor: object whose ``getDataLoader(split, shuffle=...)`` yields
            (word_idx, label_idx, type_idx, dtype_idx, class_labels) batches.
        model: module returning (class_logits, tag_logits).

    Returns:
        The trained model (on CUDA if available, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion_cls = nn.CrossEntropyLoss()
    criterion_tag = nn.CrossEntropyLoss(ignore_index=-100)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val_acc = float("-inf")
    best_state = None
    counter = 0

    for epoch in range(1, epochs + 1):
        model.train()
        train_counts = [0, 0, 0, 0]
        for batch in processor.getDataLoader("train", shuffle=True):
            word_idx, label_idx, type_idx, dtype_idx, class_labels = (t.to(device) for t in batch)
            optimizer.zero_grad()
            class_logits, tag_logits = model(word_idx, type_idx, dtype_idx)
            loss_cls = criterion_cls(class_logits, class_labels)
            # CrossEntropyLoss wants the class dimension second: [batch, tag_classes, seq_len].
            loss_tag = criterion_tag(tag_logits.permute(0, 2, 1), label_idx)
            loss = weight_cls * loss_cls + weight_tag * loss_tag
            loss.backward()
            optimizer.step()
            counts = _batch_correct_counts(class_logits, tag_logits, class_labels, label_idx)
            train_counts = [a + b for a, b in zip(train_counts, counts)]

        train_acc_cls = _ratio(train_counts[0], train_counts[1])
        train_acc_tag = _ratio(train_counts[2], train_counts[3])

        model.eval()
        val_counts = [0, 0, 0, 0]
        with torch.no_grad():
            for batch in processor.getDataLoader("dev", shuffle=False):
                word_idx, label_idx, type_idx, dtype_idx, class_labels = (t.to(device) for t in batch)
                class_logits, tag_logits = model(word_idx, type_idx, dtype_idx)
                counts = _batch_correct_counts(class_logits, tag_logits, class_labels, label_idx)
                val_counts = [a + b for a, b in zip(val_counts, counts)]
        val_acc_cls = _ratio(val_counts[0], val_counts[1])
        val_acc_tag = _ratio(val_counts[2], val_counts[3])

        print(f"Epoch {epoch}: Train_cls_acc={train_acc_cls:.4f}, Train_tag_acc={train_acc_tag:.4f}, " +
              f"Val_cls_acc={val_acc_cls:.4f}, Val_tag_acc={val_acc_tag:.4f}")

        if val_acc_cls > best_val_acc:
            best_val_acc = val_acc_cls
            # Snapshot (not reference) the weights of the best epoch.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"The accuracy of dev set has not improved for {patience} epochs, stop training...")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

In [8]:
def evaluate_model(processor, model, batch_size=32):
    """Compute test-set accuracies for template classification and variable tagging.

    Args:
        processor: object whose ``getDataLoader(split=..., batch_size=..., shuffle=...)``
            yields (word_idx, label_idx, type_idx, dtype_idx, class_labels) batches.
        model: module returning (class_logits, tag_logits); it is evaluated on
            the device of its first parameter (CPU if it has none).
        batch_size: test batch size.

    Returns:
        (acc_cls, acc_tag). Tag accuracy ignores positions labelled -100.
        Either value is 0 when nothing was counted.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    corr_cls = corr_tag = total_cls = total_tag = 0
    loader = processor.getDataLoader(split="test", batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch in loader:
            word_idx, label_idx, type_idx, dtype_idx, class_labels = (t.to(device) for t in batch)
            # A single forward pass per batch is enough for both heads.
            class_logits, tag_logits = model(word_idx, type_idx, dtype_idx)

            preds = class_logits.argmax(dim=1)
            corr_cls += (preds == class_labels).sum().item()
            total_cls += class_labels.size(0)

            pred_tags = tag_logits.argmax(dim=2)
            mask = label_idx != -100  # -100 marks padding
            corr_tag += ((pred_tags == label_idx) & mask).sum().item()
            total_tag += mask.sum().item()

    acc_cls = corr_cls / total_cls if total_cls else 0
    acc_tag = corr_tag / total_tag if total_tag else 0
    return acc_cls, acc_tag

In [9]:
def inference(model, processor, question):
    """Predict the SQL template id and the variable tags for one question.

    The question is tokenised with ``processor.nlp``. Tokens are mapped to
    word indices lower-cased, with unknown words mapped to UNK. Type and
    datatype inputs are all set to index 0, since they are not known at
    inference time. NOTE(review): during training these inputs were filled
    from the gold variables, so predictions here see a different input
    distribution.

    Returns:
        (pred_class, variables): the argmax template id, and a list of
        (token, tag_name) pairs for tokens whose predicted tag index is not 0
        ("-"). Tag names come from ``processor.idx2name`` ("UNK" if missing).

    Raises:
        ValueError: if the question produces no tokens.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    tokens = [tok.text for tok in processor.nlp(question.strip())]
    if not tokens:
        raise ValueError("question is empty")
    unk = processor.word2idx["UNK"]
    word_idxs = torch.tensor([[processor.word2idx.get(tok.lower(), unk) for tok in tokens]],
                             dtype=torch.long, device=device)
    type_idxs = torch.zeros_like(word_idxs)
    dtype_idxs = torch.zeros_like(word_idxs)
    with torch.no_grad():
        class_logits, tag_logits = model(word_idxs, type_idxs, dtype_idxs)
    pred_class = class_logits.argmax(dim=1).item()
    pred_tags = tag_logits.argmax(dim=2).squeeze(0).tolist()  # [seq_len]
    # Tag logits index the name vocabulary (name2idx/idx2name).
    variables = [
        (tok, processor.idx2name.get(tag, "UNK"))
        for tok, tag in zip(tokens, pred_tags)
        if tag != 0
    ]
    return pred_class, variables

In [10]:
class GenerationModels(nn.Module):
    """Placeholder for the generation models (LSTM, LSTM with attention, Transformer).

    Not implemented; ``forward`` raises NotImplementedError.
    """

    def __init__(self, *args, **kwargs):
        # nn.Module.__init__ must run, otherwise parameters(), to(), __call__ etc.
        # fail with AttributeError on the missing internal registries.
        super().__init__()

    def forward(self, *args, **kwargs):
        raise NotImplementedError("GenerationModels is a placeholder and has no forward pass yet")
    

In [11]:
# Input files for the ATIS experiments.
data_file = "atis.json"
type_file = "atis-schema.csv"
glove_path = "glove.6B.50d.txt"
tags_file = "atis-fields.txt"  # not referenced by the cells in this notebook

processor = atisDataProcessor(
    data_file,
    type_file,
    glove_path,
    type_emb_dim=50,
    dtype_emb_dim=50,
)

Loading datatype of variables for additional information on learning...
Loading all data in json...
Loading sql template...
add a new template: 1
add a new template: 2
add a new template: 3
add a new template: 4
add a new template: 5
add a new template: 6
add a new template: 7
add a new template: 8
add a new template: 9
add a new template: 10
add a new template: 11
add a new template: 12
add a new template: 13
add a new template: 14
add a new template: 15
add a new template: 16
add a new template: 17
add a new template: 18
add a new template: 19
add a new template: 20
add a new template: 21
add a new template: 22
add a new template: 23
add a new template: 24
add a new template: 25
add a new template: 26
add a new template: 27
add a new template: 28
add a new template: 29
add a new template: 30
add a new template: 31
add a new template: 32
add a new template: 33
add a new template: 34
add a new template: 35
add a new template: 36
add a new template: 37
add a new template: 38
add a new t

  vec = torch.tensor([float(x) for x in parts[1:]], dtype=torch.float)


In [None]:
model_set = ["linear", "feedforward", "lstm", "transformer"]
# Settings shared by every variant; only ``model_type`` changes per run.
# nhead=5 must divide the input width (word dim + 50 + 50; 150 with 50-d GloVe).
shared_model_kwargs = dict(
    embedding_matrix=processor.embedding_matrix,
    type_vocab_size=len(processor.var2idx),
    dtype_vocab_size=len(processor.dtype2idx),
    type_emb_dim=50,
    dtype_emb_dim=50,
    hidden_dim=128,
    template_classes=processor.template_classes,
    tag_classes=len(processor.name2idx),
    num_layers=1,
    nhead=5,
)
for model_type in model_set:
    model = ClassificationModels(model_type=model_type, **shared_model_kwargs)
    print(f"===============Training Model=================")
    train_model(processor, model, epochs=20, lr=1e-3, patience=3)
    print(f"===============Testing Model=================")
    acc_cls, acc_tag = evaluate_model(processor, model)
    print(f"Test set {model_type} —  Classification Acc: {acc_cls:.4f},  Tagging Acc: {acc_tag:.4f}")

Initialize linear model...


In [None]:
# Run the last trained model on a single example question.
question = "show me the flights arriving at MKE"
pred_id, vars_detected = inference(model, processor, question)
for label, value in (("template ID", pred_id), ("variables", vars_detected)):
    print(f"{label}: {value}")

In [None]:
# LSTM seq2seq
class Seq2SeqLSTM(nn.Module):
    """Plain LSTM encoder-decoder used with teacher forcing.

    The encoder's final (hidden, cell) state initialises the decoder. The
    decoder reads the target token embeddings and emits output-vocabulary
    logits at every target position.
    """

    def __init__(self, input_vocab_size, output_vocab_size, emb_dim=100, hidden_dim=256):
        super().__init__()
        # Source side: token embeddings followed by a single-layer LSTM.
        self.encoder_embedding = nn.Embedding(input_vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        # Target side: token embeddings, LSTM and projection to the output vocabulary.
        self.decoder_embedding = nn.Embedding(output_vocab_size, emb_dim)
        self.decoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, output_vocab_size)

    def forward(self, src, tgt):
        """src: [batch, src_len] ids, tgt: [batch, tgt_len] ids -> [batch, tgt_len, output_vocab_size] logits."""
        _, final_state = self.encoder(self.encoder_embedding(src))
        decoded, _ = self.decoder(self.decoder_embedding(tgt), final_state)
        return self.out(decoded)


# Attention LSTM seq2seq
class Seq2SeqLSTMWithAttention(nn.Module):
    """LSTM encoder-decoder with dot-product attention, used with teacher forcing.

    At each target step the query is ``self.attn([target embedding; previous
    top-layer decoder hidden state])``, projected to ``hidden_dim``. It is
    scored against every encoder output, and the attention-weighted context is
    fed to the decoder LSTM together with the target embedding. Because the
    query is projected to ``hidden_dim``, ``emb_dim`` and ``hidden_dim`` may
    differ.
    """

    def __init__(self, input_vocab_size, output_vocab_size, emb_dim=100, hidden_dim=256):
        super().__init__()
        self.encoder_embedding = nn.Embedding(input_vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)

        self.decoder_embedding = nn.Embedding(output_vocab_size, emb_dim)
        self.decoder = nn.LSTM(emb_dim + hidden_dim, hidden_dim, batch_first=True)
        # Builds the attention query from [target embedding; decoder hidden state].
        self.attn = nn.Linear(hidden_dim + emb_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_vocab_size)

    def forward(self, src, tgt):
        """src: [batch, src_len] ids, tgt: [batch, tgt_len] ids -> [batch, tgt_len, output_vocab_size] logits."""
        # Encoder
        src_emb = self.encoder_embedding(src)
        encoder_outputs, (hidden, cell) = self.encoder(src_emb)  # [B, S, H]
        keys = encoder_outputs.transpose(1, 2)                    # [B, H, S]

        tgt_emb = self.decoder_embedding(tgt)                     # [B, T, E]
        if tgt.size(1) == 0:
            return tgt_emb.new_zeros(tgt.size(0), 0, self.out.out_features)

        outputs = []
        for t in range(tgt.size(1)):
            emb_t = tgt_emb[:, t:t + 1, :]                        # input at the current step, [B, 1, E]
            prev_h = hidden[-1].unsqueeze(1)                      # previous top-layer hidden state, [B, 1, H]
            query = self.attn(torch.cat((emb_t, prev_h), dim=-1))  # [B, 1, H]
            attn_weights = torch.softmax(torch.bmm(query, keys), dim=-1)  # attention weights, [B, 1, S]
            context = torch.bmm(attn_weights, encoder_outputs)    # weighted sum, [B, 1, H]
            rnn_input = torch.cat((emb_t, context), dim=-1)       # [B, 1, E + H]
            output, (hidden, cell) = self.decoder(rnn_input, (hidden, cell))
            outputs.append(self.out(output))                      # next-token logits, [B, 1, V]
        return torch.cat(outputs, dim=1)


# Transformer
class Seq2SeqTransformer(nn.Module):
    """Transformer encoder-decoder used with teacher forcing.

    Token embeddings get fixed sinusoidal position encodings added (the
    Transformer layers themselves are order-agnostic). The decoder uses a
    causal mask, so position t only attends to target positions <= t.
    """

    def __init__(self, input_vocab_size, output_vocab_size, emb_dim=256, num_heads=4, num_layers=2):
        super().__init__()
        # Token embeddings
        self.src_embedding = nn.Embedding(input_vocab_size, emb_dim)
        self.tgt_embedding = nn.Embedding(output_vocab_size, emb_dim)

        # Transformer (batch-first: tensors are [batch, seq, emb_dim])
        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True
        )
        self.out = nn.Linear(emb_dim, output_vocab_size)

    @staticmethod
    def _positional_encoding(length, dim, device, dtype):
        """Return the [length, dim] sinusoidal position table (sin on even, cos on odd columns)."""
        import math
        position = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, dim, 2, device=device, dtype=torch.float32) * (-math.log(10000.0) / dim)
        )
        table = torch.zeros(length, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(position * freqs)
        table[:, 1::2] = torch.cos(position * freqs)[:, : dim // 2]  # odd dim: one fewer cos column
        return table.to(dtype)

    def forward(self, src, tgt):
        """src: [batch, src_len] ids, tgt: [batch, tgt_len] ids -> [batch, tgt_len, output_vocab_size] logits."""
        emb_dim = self.src_embedding.embedding_dim
        src_emb = self.src_embedding(src)
        src_emb = src_emb + self._positional_encoding(src.size(1), emb_dim, src.device, src_emb.dtype)
        tgt_emb = self.tgt_embedding(tgt)
        tgt_emb = tgt_emb + self._positional_encoding(tgt.size(1), emb_dim, tgt.device, tgt_emb.dtype)

        tgt_len = tgt.size(1)
        # True above the diagonal = future positions the decoder may not attend to.
        causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device), diagonal=1)

        memory = self.transformer.encoder(src_emb)  # encode the source sequence
        output = self.transformer.decoder(tgt_emb, memory, tgt_mask=causal_mask)  # decode the target sequence
        return self.out(output)  # next-token logits