In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import unicodedata
import re
from sklearn.model_selection import train_test_split
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from collections import Counter

# Report the version of every major library so results can be reproduced later.
for lib in (mpl, np, pd, sklearn, torch):
    print(f"{lib.__name__} {lib.__version__}")

# Prefer the first CUDA device and fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# torch.cuda.get_device_name() raises on machines without CUDA, so only query
# it when a GPU was actually selected; otherwise just report the CPU device.
print(torch.cuda.get_device_name(device) if device.type == "cuda" else device)

# Seed every random number generator the notebook relies on.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # harmless no-op when CUDA is not available
np.random.seed(seed)


matplotlib 3.10.0
numpy 2.0.2
pandas 2.2.3
sklearn 1.5.2
torch 2.6.0+cu126
NVIDIA GeForce GTX 1050 Ti


In [7]:
def unicode_to_ascii(string):
    """Strip combining accent marks from ``string``.

    The text is decomposed with NFD normalisation and every character whose
    Unicode category is ``Mn`` (non-spacing mark) is dropped.
    """
    decomposed = unicodedata.normalize("NFD", string)
    kept = [char for char in decomposed if unicodedata.category(char) != "Mn"]
    return "".join(kept)

def preprocess_sentence(string: str):
    """Normalise a raw sentence into space-separated lowercase tokens.

    Steps, in order: lowercase and trim, remove accents, surround the
    punctuation marks ``? . ! , ¿`` with spaces, collapse runs of spaces and
    double quotes, and replace every other non-letter run with one space.
    """
    text = unicode_to_ascii(string.lower().strip())

    substitutions = (
        (r"([?.!,¿])", r" \1 "),       # isolate punctuation as its own token
        (r'[" "]+', " "),              # collapse spaces / double quotes
        (r"[^a-zA-Z?.!,¿]+", " "),     # drop everything that is not a letter or kept punctuation
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    return text.strip()

In [2]:
class LangPairDataset(Dataset):
    """Sentence-pair dataset read from a tab-separated text file.

    Every line of ``file_path`` is split on tabs; the first column is stored as
    the target sentence and the second column as the source sentence, both
    passed through ``preprocess_sentence``.  Items are returned as
    ``(src, trg)`` string pairs.

    Args:
        mode: ``"train"`` or ``"valid"``; selects the rows whose entry in
            ``split_index`` equals this value.
        use_cache: when True and ``cache_path`` exists, load the preprocessed
            sentences from the cache instead of re-reading ``file_path``.  In
            every other case the text file is preprocessed again and the cache
            is (re)written.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"valid"``.
    """
    file_path = Path(r"./data/seq2seq/data_spa_en/spa.txt")
    cache_path = Path(r"./data/seq2seq/.cache/lang_pair.npy")
    # One "train"/"valid" label per sentence pair (about 90% / 10%), drawn once
    # when the class is defined so that all instances share the same split.
    split_index = np.random.choice(a=["train", "valid"], replace=True, p=[0.9, 0.1], size=118964)

    def __init__(self, mode, use_cache=False):
        super().__init__()

        if mode not in ("train", "valid"):
            raise ValueError(f"mode must be 'train' or 'valid', got {mode!r}")

        # The original condition was inverted: use_cache=True forced a rebuild
        # while use_cache=False read the cache.
        if use_cache and self.cache_path.exists():
            # NOTE(security): allow_pickle=True unpickles arbitrary objects;
            # only load cache files that this notebook wrote itself.
            # allow_pickle=True is required because the cache stores a dict.
            lang_pair = np.load(self.cache_path, allow_pickle=True).item()
            trg = lang_pair["trg"]
            src = lang_pair["src"]
        else:
            trg, src = self._load_pairs()
            self.cache_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(self.cache_path, {"trg": trg, "src": src})

        if len(LangPairDataset.split_index) != len(src):
            # The predefined split only fits a corpus of exactly 118964 pairs.
            # For any other size, draw a matching split once and keep it on the
            # class so the train and valid instances stay complementary.
            LangPairDataset.split_index = np.random.choice(
                a=["train", "valid"], replace=True, p=[0.9, 0.1], size=len(src)
            )

        selected = LangPairDataset.split_index == mode
        self.trg = trg[selected]
        self.src = src[selected]

    def _load_pairs(self):
        """Read ``file_path`` and return ``(trg, src)`` arrays of preprocessed sentences."""
        with open(self.file_path, "r", encoding="utf8") as file:
            lines = file.readlines()

        trg, src = [], []
        for line in lines:
            columns = line.split('\t')
            if len(columns) < 2:
                # Blank or malformed line: there is no sentence pair to keep.
                continue
            # Only the first two columns are used; extra columns are ignored.
            trg.append(preprocess_sentence(columns[0]))
            src.append(preprocess_sentence(columns[1]))
        return np.array(trg), np.array(src)

    def __getitem__(self, index):
        return self.src[index], self.trg[index]

    def __len__(self):
        return len(self.src)
    
# Two views over the same corpus, selected by the shared train/valid split.
train_dataset = LangPairDataset(mode="train")
valid_dataset = LangPairDataset(mode="valid")

In [3]:
def get_word_index_map(dataset, mode, threshold=1):
    """Build the vocabulary lookup tables for one side of the sentence pairs.

    Args:
        dataset: iterable of ``(src, trg)`` sentence strings.
        mode: ``"src"`` uses the first element of each pair; any other value
            uses the second element.
        threshold: minimum number of occurrences a word needs to be included.

    Returns:
        ``(word2index, index2word)``.  Indices 0-3 are reserved for the special
        tokens ``PAD``, ``BOS``, ``UNK`` and ``EOS``; the remaining words
        follow in order of first appearance.
    """
    special_tokens = ("PAD", "BOS", "UNK", "EOS")
    word2index = {token: position for position, token in enumerate(special_tokens)}
    index2word = {position: token for token, position in word2index.items()}

    column = 0 if mode == "src" else 1
    frequencies = Counter()
    for pair in dataset:
        frequencies.update(pair[column].split())

    next_index = len(index2word)
    for word, occurrences in frequencies.items():
        if occurrences < threshold:
            continue
        word2index[word] = next_index
        index2word[next_index] = word
        next_index += 1

    return word2index, index2word

# Vocabularies are built from the training split only.
src_word2index, src_index2word = get_word_index_map(dataset=train_dataset, mode="src")
trg_word2index, trg_index2word = get_word_index_map(dataset=train_dataset, mode="trg")

In [4]:
class Tokenizer:
    """Convert tokenised sentences to padded index tensors and back.

    Args:
        word2index: mapping from word to integer id.
        index2word: mapping from integer id to word.
        max_length: upper bound on the encoded length, special tokens included.
        pad_idx, bos_idx, unk_idx, eos_idx: ids of the special tokens.
    """

    def __init__(self, word2index, index2word, max_length=500, pad_idx=0, bos_idx=1, unk_idx=2, eos_idx=3):
        self.word2index = word2index
        self.index2word = index2word
        self.max_length = max_length
        self.pad_idx = pad_idx
        self.bos_idx = bos_idx
        self.unk_idx = unk_idx
        self.eos_idx = eos_idx

    def encode(self, text_list, padding_first=False, add_bos=True, add_eos=True, return_mask=False):
        """Encode a batch of token lists into one padded ``int64`` tensor.

        Args:
            text_list: list of sentences, each a list of word strings.
            padding_first: pad on the left when True, on the right otherwise.
            add_bos / add_eos: whether to add the BOS / EOS id to each sentence.
            return_mask: also return the padding mask.

        Returns:
            ``input_indices`` of shape ``[len(text_list), length]``, or
            ``(input_indices, mask)`` where ``mask`` is 1 at padding positions
            and 0 elsewhere.
        """
        special_count = int(add_bos) + int(add_eos)
        # default=0 keeps an empty batch from raising inside max().
        longest = max((len(text) for text in text_list), default=0)
        max_length = min(self.max_length, longest + special_count)
        # Room left for real words; never negative so the slice below is safe.
        word_budget = max(max_length - special_count, 0)

        indices_list = []
        for text in text_list:
            indices = [self.word2index.get(word, self.unk_idx) for word in text[:word_budget]]

            if add_bos:
                indices = [self.bos_idx] + indices

            if add_eos:
                indices = indices + [self.eos_idx]

            padding = [self.pad_idx] * (max_length - len(indices))
            indices = padding + indices if padding_first else indices + padding

            indices_list.append(indices)

        if indices_list:
            input_indices = torch.tensor(indices_list, dtype=torch.int64)
        else:
            # Keep a 2-D result even when there is nothing to encode.
            input_indices = torch.empty((0, max_length), dtype=torch.int64)
        mask = (input_indices == self.pad_idx).to(dtype=torch.int64)

        return input_indices if not return_mask else (input_indices, mask)

    def decode(self, indices_list, rm_bos=True, rm_eos=True, rm_pad=True, split=False):
        """Turn batches of ids back into words.

        Special tokens are recognised by their id (``bos_idx`` / ``eos_idx`` /
        ``pad_idx``) rather than by a fixed spelling, so removal works whatever
        names the vocabulary gives them.

        Args:
            indices_list: iterable of id sequences (lists or tensors).
            rm_bos: skip BOS ids.
            rm_eos: stop decoding a sentence at the first EOS id.
            rm_pad: skip PAD ids.  They are skipped rather than treated as the
                end of the sentence so that left-padded input decodes fully.
            split: return each sentence as a list of words instead of a string.

        Returns:
            A list with one decoded sentence per input sequence.
        """
        text_list = []
        for indices in indices_list:
            text = []
            for index in indices:
                # int() also accepts 0-d tensors and numpy integers, which
                # would otherwise never match the integer keys of index2word.
                index = int(index)

                if rm_bos and index == self.bos_idx:
                    continue

                if rm_eos and index == self.eos_idx:
                    break

                if rm_pad and index == self.pad_idx:
                    continue

                text.append(self.index2word.get(index, "[UNK]"))
            text_list.append(" ".join(text) if not split else text)
        return text_list
    
# One tokenizer per language, each backed by its own vocabulary tables.
src_tokenizer = Tokenizer(src_word2index, src_index2word)
trg_tokenizer = Tokenizer(trg_word2index, trg_index2word)

In [8]:
def collate_fn(batch):
    """Collate ``(src, trg)`` sentence pairs into the tensors the model consumes.

    Sentences are split on whitespace, encoded with the module-level
    ``src_tokenizer`` / ``trg_tokenizer`` and moved to ``device``.

    Returns:
        dict with keys ``encoder_inputs``, ``encoder_inputs_mask``,
        ``decoder_inputs``, ``decoder_labels`` and ``decoder_labels_mask``.
    """
    src_tokens = [pair[0].split() for pair in batch]
    trg_tokens = [pair[1].split() for pair in batch]

    # Encoder input layout: [PAD]... [BOS] src [EOS]
    enc_ids, enc_mask = src_tokenizer.encode(
        src_tokens, padding_first=True, add_bos=True, add_eos=True, return_mask=True
    )

    # Decoder input layout: [BOS] trg [PAD]...  (no EOS is added here)
    dec_ids = trg_tokenizer.encode(
        trg_tokens, padding_first=False, add_bos=True, add_eos=False, return_mask=False
    )

    # Decoder label layout: trg [EOS] [PAD]...
    label_ids, label_mask = trg_tokenizer.encode(
        trg_tokens, padding_first=False, add_bos=False, add_eos=True, return_mask=True
    )

    tensors = {
        "encoder_inputs": enc_ids,
        "encoder_inputs_mask": enc_mask,
        "decoder_inputs": dec_ids,
        "decoder_labels": label_ids,
        "decoder_labels_mask": label_mask,
    }
    return {name: tensor.to(device) for name, tensor in tensors.items()}


In [None]:
class Encoder(nn.Module):
    """Embedding layer followed by a batch-first GRU."""

    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, encoder_inputs):
        """Encode a batch of token ids.

        Args:
            encoder_inputs: ids of shape ``[batch_size, seq_len]``.

        Returns:
            ``(seq_output, hidden)`` where ``seq_output`` is
            ``[batch_size, seq_len, hidden_dim]`` and ``hidden`` is
            ``[num_layers, batch_size, hidden_dim]``.
        """
        # [batch_size, seq_len] -> [batch_size, seq_len, embedding_dim]
        embedded = self.embedding(encoder_inputs)
        outputs, final_state = self.gru(embedded)
        return outputs, final_state

In [21]:
# Shape check for the encoder on random token ids.
test_encoder = Encoder(100, num_layers=4)
encoder_inputs = torch.randint(low=0, high=65, size=(2, 10))

# Inputs : (2, 10) -> embedding (2, 10, 256) -> GRU output (2, 10, 1024)
# Hidden : (num_layers, batch_size, hidden_dim) = (4, 2, 1024)
encoder_outputs, hidden = test_encoder(encoder_inputs)
print(encoder_outputs.shape, hidden.shape, sep="\n")

torch.Size([2, 10, 1024])
torch.Size([4, 2, 1024])


In [None]:
class BahdanauAttention(nn.Module):
    """Additive attention: ``score = V(tanh(W_k(keys) + W_q(query)))``."""

    def __init__(self, hidden_dim=1024):
        super().__init__()
        self.W_k = nn.Linear(hidden_dim, hidden_dim)
        self.W_q = nn.Linear(hidden_dim, hidden_dim)
        self.V   = nn.Linear(hidden_dim, 1)

    def forward(self, query, keys, values, attn_mask=None):
        """Compute the attention-weighted sum of ``values``.

        Args:
            query: ``[batch_size, hidden_dim]``.
            keys: ``[batch_size, seq_len, hidden_dim]``.
            values: same leading shape as ``keys``.
            attn_mask: optional ``[batch_size, seq_len]`` tensor that is 1 at
                positions to ignore and 0 elsewhere.

        Returns:
            ``(context_vector, scores)`` with ``context_vector`` of shape
            ``[batch_size, hidden_dim]`` and ``scores`` (softmax weights over
            ``seq_len``) of shape ``[batch_size, seq_len, 1]``.
        """
        # The query gets a length-1 sequence axis so it broadcasts over the keys.
        projected_query = self.W_q(query.unsqueeze(-2))
        energy = torch.tanh(self.W_k(keys) + projected_query)
        scores = self.V(energy)

        if attn_mask is not None:
            # A huge negative offset sends masked positions to ~0 after softmax.
            penalty = attn_mask.unsqueeze(-1) * -1e16
            scores += penalty

        weights = F.softmax(scores, dim=-2)
        context_vector = (weights * values).sum(dim=-2)

        return context_vector, weights

W_k.weight torch.Size([1024, 1024])
W_k.bias torch.Size([1024])
W_q.weight torch.Size([1024, 1024])
W_q.bias torch.Size([1024])
V.weight torch.Size([1, 1024])
V.bias torch.Size([1])


In [None]:
class Decoder(nn.Module):
    """GRU decoder that attends over the encoder outputs at every step.

    Args:
        vocab_size: size of the target vocabulary.
        embedding_dim: size of the target word embeddings.
        hidden_dim: GRU hidden size; also the size expected of the encoder
            outputs, because the attention projections are built with it.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability applied to the GRU output before the
            vocabulary projection.
    """

    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024, num_layers=1, dropout=0.6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim + hidden_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.attention = BahdanauAttention(hidden_dim)

    def forward(self, decoder_input, hidden, encoder_outputs, attn_mask=None):
        """Run the decoder on ``decoder_input`` (normally a single time step).

        Args:
            decoder_input: token ids, ``[batch_size, seq_len]``.
            hidden: previous decoder state, either the last layer only
                (``[batch_size, hidden_dim]``) or the full state
                (``[num_layers, batch_size, hidden_dim]``).
            encoder_outputs: ``[batch_size, src_len, hidden_dim]``.
            attn_mask: optional ``[batch_size, src_len]`` mask, 1 at source
                positions that must be ignored.

        Returns:
            ``(logits, hidden, attention_score)`` where ``logits`` is
            ``[batch_size, seq_len, vocab_size]``, ``hidden`` is
            ``[num_layers, batch_size, hidden_dim]`` and ``attention_score``
            is ``[batch_size, src_len, 1]``.
        """
        if hidden.dim() == 3:
            # Full state: the top layer is the attention query.
            query = hidden[-1]
            gru_state = hidden
        else:
            # Only the top layer was passed: use it as the query and as the
            # initial state of every GRU layer.
            query = hidden
            gru_state = hidden.unsqueeze(0).expand(self.gru.num_layers, -1, -1)

        context_vector, attention_score = self.attention(
            query = query,
            keys = encoder_outputs,
            values = encoder_outputs,
            attn_mask = attn_mask,
        )

        # embeds: [batch_size, seq_len, embedding_dim]
        embeds = self.embedding(decoder_input)

        # The same context vector is attached to every input step:
        # [batch_size, seq_len, hidden_dim + embedding_dim]
        context = context_vector.unsqueeze(-2).expand(-1, embeds.size(-2), -1)
        embeds = torch.cat([context, embeds], dim=-1)

        # Feed the previous state into the GRU.  Without it the GRU starts
        # from zeros on every call and nothing is carried between steps.
        # seq_output: [batch_size, seq_len, hidden_dim]
        # hidden    : [num_layers, batch_size, hidden_dim]
        seq_output, hidden = self.gru(embeds, gru_state.contiguous())

        # logits: [batch_size, seq_len, vocab_size]
        logits = self.fc(self.dropout(seq_output))

        return logits, hidden, attention_score

In [None]:
# Small four-layer decoder instance for manual experiments.
test_decoder = Decoder(100, num_layers=4)



In [130]:
class Seq2Seq(nn.Module):
    """Encoder-decoder translation model with attention.

    The encoder's final hidden state seeds the decoder, and the decoder's
    attention is applied to the encoder outputs, so the encoder and decoder
    hidden sizes are expected to match.

    Args:
        src_vocab_size / trg_vocab_size: vocabulary sizes.
        encoder_* / decoder_*: forwarded to ``Encoder`` / ``Decoder``.
        bos_idx: id fed to the decoder as the first token during inference.
        eos_idx: id that ends generation during inference.
        max_length: maximum number of tokens generated by ``inference``.
    """

    def __init__(
            self,
            src_vocab_size,
            trg_vocab_size,
            encoder_embedding_dim=256,
            encoder_hidden_dim=1024,
            encoder_num_layers=1,
            decoder_embedding_dim=256,
            decoder_hidden_dim=1024,
            decoder_num_layers=1,
            bos_idx=1,
            eos_idx=3,
            max_length=512
        ):
        super().__init__()
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.max_length = max_length
        self.encoder = Encoder(src_vocab_size, encoder_embedding_dim, encoder_hidden_dim, encoder_num_layers)
        self.decoder = Decoder(trg_vocab_size, decoder_embedding_dim, decoder_hidden_dim, decoder_num_layers)

    def forward(self, *, encoder_inputs, decoder_inputs, attn_mask=None):
        """Teacher-forced pass over a batch.

        Args:
            encoder_inputs: source ids, ``[batch_size, src_len]``.
            decoder_inputs: target ids fed to the decoder, ``[batch_size, trg_len]``.
            attn_mask: optional ``[batch_size, src_len]`` mask, 1 at source
                positions to ignore.

        Returns:
            ``(logits, scores)``: the per-step logits concatenated along the
            time axis and the per-step attention weights concatenated along
            the last axis.
        """
        encoder_outputs, hidden = self.encoder(encoder_inputs)

        _, seq_len = decoder_inputs.shape
        logits_list = []
        scores_list = []
        for step in range(seq_len):
            # One target token per step; the top layer of the previous state
            # is what the decoder receives.
            logits, hidden, score = self.decoder(
                decoder_inputs[:, step:step + 1],
                hidden[-1],
                encoder_outputs,
                attn_mask=attn_mask
            )

            logits_list.append(logits)
            scores_list.append(score)

        return torch.cat(logits_list, dim=-2), torch.cat(scores_list, dim=-1)

    @torch.no_grad()
    def inference(self, encoder_input, attn_mask=None):
        """Greedy decoding of a single sentence (batch size 1).

        This method does not change the train/eval mode of the model; call
        ``model.eval()`` first so dropout is disabled while generating.

        Args:
            encoder_input: source ids, ``[1, src_len]``.
            attn_mask: optional ``[1, src_len]`` mask, 1 at positions to ignore.

        Returns:
            ``(pred_list, scores)``: the generated ids (ending with the EOS id
            if it was produced within ``max_length`` steps) and the attention
            weights of every step concatenated along the last axis.
        """
        input_device = encoder_input.device
        if attn_mask is not None:
            attn_mask = attn_mask.to(input_device)

        encoder_outputs, hidden = self.encoder(encoder_input)
        # Create the BOS token on the same device as the input; a CPU tensor
        # here fails as soon as the model runs on a GPU.
        decoder_input = torch.tensor([[self.bos_idx]], dtype=torch.int64, device=input_device)
        pred_list = []
        score_list = []
        for _ in range(self.max_length):
            logits, hidden, score = self.decoder(
                decoder_input,
                hidden[-1],
                encoder_outputs,
                attn_mask=attn_mask
            )
            # Greedy choice; the prediction becomes the next decoder input.
            decoder_input = logits.argmax(dim=-1)
            predicted = decoder_input.reshape(-1).item()
            pred_list.append(predicted)
            score_list.append(score)

            if predicted == self.eos_idx:
                break

        return pred_list, torch.cat(score_list, dim=-1)

In [98]:
def cross_entropy_with_padding(logits, labels, padding_mask=None):
    """Token-level cross entropy averaged over non-padding positions.

    Args:
        logits: ``[batch_size, seq_len, class_num]``.
        labels: class ids with ``batch_size * seq_len`` elements.
        padding_mask: optional tensor with the same number of elements as
            ``labels``; 1 marks padding positions to exclude, 0 marks real
            tokens.  When omitted, every position contributes.

    Returns:
        A scalar tensor holding the mean loss over the positions that count.
    """
    batch_size, seq_len, class_num = logits.shape
    # reduction="none" keeps one loss per token.  With the default "mean"
    # reduction the loss is already a scalar and the mask has no effect.
    token_loss = F.cross_entropy(
        logits.reshape(batch_size * seq_len, class_num),
        labels.reshape(-1),
        reduction="none",
    )
    if padding_mask is None:
        return token_loss.mean()

    keep = (1 - padding_mask.reshape(-1)).to(token_loss.dtype)
    # clamp avoids a division by zero when every position is padding.
    return torch.mul(token_loss, keep).sum() / keep.sum().clamp(min=1.0)

In [99]:
class EarlyStopCallback:
    """Track a metric that should increase and signal when it stalls.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        min_delta: a call counts as an improvement only when the metric is at
            least ``best_metric + min_delta``.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience, self.min_delta = patience, min_delta
        self.best_metric = - np.inf
        self.counter = 0

    def __call__(self, metric):
        """Record a new metric value and update the stall counter."""
        improved = metric >= self.best_metric + self.min_delta
        if not improved:
            self.counter += 1
            return
        self.best_metric = metric
        self.counter = 0

    @property
    def early_stop(self):
        """True once ``patience`` consecutive calls brought no improvement."""
        return not self.counter < self.patience


In [103]:
@torch.no_grad()
def evaluating(model, dataloader, loss_fct):
    """Return the mean of the per-batch losses over ``dataloader``.

    Gradients are disabled by the decorator.  The function does not switch the
    model to eval mode; the caller is responsible for that.
    """
    batch_losses = []
    for batch in dataloader:
        logits, _ = model(
            encoder_inputs=batch["encoder_inputs"],
            decoder_inputs=batch["decoder_inputs"],
            attn_mask=batch["encoder_inputs_mask"],
        )
        # Validation loss for this batch, with padded label positions masked.
        batch_loss = loss_fct(
            logits,
            batch["decoder_labels"],
            padding_mask=batch["decoder_labels_mask"],
        )
        batch_losses.append(batch_loss.cpu().item())

    return np.mean(batch_losses)


In [128]:
def training(
    model,
    train_loader,
    val_loader,
    epoch,
    loss_fct,
    optimizer,
    early_stop_callback=None,
    eval_step=500,
    ):
    """Train ``model`` and periodically evaluate it on ``val_loader``.

    Args:
        model: called with ``encoder_inputs``, ``decoder_inputs`` and
            ``attn_mask`` keyword arguments; must return ``(logits, scores)``.
        train_loader / val_loader: iterables of batches shaped like the output
            of ``collate_fn``.
        epoch: number of passes over ``train_loader``.
        loss_fct: called as ``loss_fct(logits, labels, padding_mask=...)``.
        optimizer: optimizer that updates the parameters of ``model``.
        early_stop_callback: optional callable that receives the negated
            validation loss and exposes an ``early_stop`` attribute.
        eval_step: run validation every ``eval_step`` training steps.

    Returns:
        dict with ``"train"`` and ``"valid"`` lists of
        ``{"loss": ..., "step": ...}`` records.
    """
    record_dict = {
        "train": [],
        "valid": []
    }

    global_step = 1
    # Initialised up front: the progress-bar update at the end of an epoch
    # reads both values, and validation may not have run yet (eval_step larger
    # than the number of steps so far) or the loader may be empty.
    loss_value = None
    valid_loss = None
    model.train()
    with tqdm(total=epoch * len(train_loader)) as pbar:
        for epoch_id in range(epoch):

            for batch in train_loader:
                encoder_inputs = batch["encoder_inputs"]
                encoder_inputs_mask = batch["encoder_inputs_mask"]
                decoder_inputs = batch["decoder_inputs"]
                decoder_labels = batch["decoder_labels"]
                decoder_labels_mask = batch["decoder_labels_mask"]

                optimizer.zero_grad()
                logits, _ = model(
                    encoder_inputs=encoder_inputs,
                    decoder_inputs=decoder_inputs,
                    attn_mask=encoder_inputs_mask
                    )
                loss = loss_fct(logits, decoder_labels, padding_mask=decoder_labels_mask)
                loss.backward()
                optimizer.step()

                loss_value = loss.cpu().item()

                record_dict["train"].append({
                    "loss": loss_value, "step": global_step
                })

                if global_step % eval_step == 0:
                    # Switch to eval mode for validation, then back to train mode.
                    model.eval()
                    valid_loss = evaluating(model, val_loader, loss_fct)
                    record_dict["valid"].append({
                        "loss": valid_loss, "step": global_step
                    })
                    model.train()

                    if early_stop_callback is not None:
                        # The callback expects a metric that grows when things
                        # improve, hence the negated loss.
                        early_stop_callback(-valid_loss)
                        if early_stop_callback.early_stop:
                            print(f"Early stop at epoch {epoch_id} / global_step {global_step}")
                            return record_dict

                global_step += 1
                pbar.update(1)
            pbar.set_postfix({"epoch": epoch_id, "loss": loss_value, "valid_loss": valid_loss})

    return record_dict

In [131]:
# Hyper-parameters
epoch = 20
batch_size = 64

# Vocabulary sizes determine the embedding and output dimensions.
src_vocab_size = len(src_word2index)
trg_vocab_size = len(trg_word2index)

model = Seq2Seq(src_vocab_size, trg_vocab_size).to(device)

loss_function = cross_entropy_with_padding

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

early_stop_cb = EarlyStopCallback()

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

# Validate once per epoch: eval_step equals the number of training batches.
record = training(
    model=model,
    train_loader=train_dataloader,
    val_loader=valid_dataloader,
    epoch=epoch,
    loss_fct=loss_function,
    optimizer=optimizer,
    early_stop_callback=early_stop_cb,
    eval_step=len(train_dataloader),
)

  0%|          | 0/33520 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Training and validation loss against the global training step.
plt.plot([i["step"] for i in record["train"]], [i["loss"] for i in record["train"]], label="train")
plt.plot([i["step"] for i in record["valid"]], [i["loss"] for i in record["valid"]], label="val")
plt.xlabel("step")
plt.ylabel("loss")
# The label= arguments above are only rendered once a legend is requested.
plt.legend()
plt.grid()
plt.show()

In [None]:
class Translator:
    """Translate one sentence with a trained model and plot its attention map.

    Args:
        model: model exposing ``inference(encoder_input=..., attn_mask=...)``
            that returns ``(predicted_ids, attention_scores)``; it is switched
            to eval mode on construction.
        src_tokenizer: tokenizer for the source language.
        trg_tokenizer: tokenizer for the target language.
    """

    def __init__(self, model, src_tokenizer, trg_tokenizer):
        self.model = model
        self.model.eval()
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer

    def draw_attention_map(self, scores, src_words_list, trg_words_list):
        """Plot attention weights with source words on x and target words on y.

        Args:
            scores: 2-D array indexed as ``[source position, target position]``.
            src_words_list: tick labels for the source positions.
            trg_words_list: tick labels for the target positions.
        """
        # The matrix is drawn transposed, so columns (x) are source positions
        # and rows (y) are target positions.
        plt.matshow(scores.T, cmap='viridis')

        ax = plt.gca()

        for i in range(scores.shape[0]):
            for j in range(scores.shape[1]):
                # x must be the source index and y the target index to land on
                # the cell that shows scores[i, j] in the transposed image.
                ax.text(i, j, f'{scores[i, j]:.2f}',
                               ha='center', va='center', color='k')

        plt.xticks(range(scores.shape[0]), src_words_list)
        plt.yticks(range(scores.shape[1]), trg_words_list)
        plt.show()

    def __call__(self, sentence):
        """Translate ``sentence``, show the attention map, return the translation."""
        sentence = preprocess_sentence(sentence)
        encoder_input, attn_mask = self.src_tokenizer.encode(
            [sentence.split()],
            padding_first=True,
            add_bos=True,
            add_eos=True,
            return_mask=True,
            )
        # Run on whatever device the model's parameters live on.
        model_device = next(self.model.parameters()).device
        encoder_input = encoder_input.to(device=model_device, dtype=torch.int64)
        attn_mask = attn_mask.to(model_device)

        # Use the model given to this instance (not a notebook global) and its
        # actual method name.
        preds, scores = self.model.inference(encoder_input=encoder_input, attn_mask=attn_mask)

        # Keep every token so the labels line up one-to-one with the rows and
        # columns of the attention matrix.
        trg_sentence = self.trg_tokenizer.decode(
            [preds], split=True, rm_bos=False, rm_eos=False, rm_pad=False
            )[0]

        src_decoded = self.src_tokenizer.decode(
            encoder_input.tolist(),
            split=True,
            rm_bos=False,
            rm_eos=False,
            rm_pad=False
            )[0]

        self.draw_attention_map(
            scores.squeeze(0).cpu().numpy(),
            src_decoded,
            trg_sentence
            )

        # Drop the final token only when it really is the end-of-sentence
        # marker; generation can also stop at the length limit without one.
        if preds and preds[-1] == self.trg_tokenizer.eos_idx:
            trg_sentence = trg_sentence[:-1]
        return " ".join(trg_sentence)


In [None]:
# Inference runs on the CPU; note that model.cpu() moves the model in place.
translator = Translator(
    model=model.cpu(),
    src_tokenizer=src_tokenizer,
    trg_tokenizer=trg_tokenizer,
)

In [None]:
# Example translation; the attention map is drawn and the sentence is the cell output.
translator('hace mucho frio aqui .')

In [None]:
# Second example translation.
translator('esta es mi vida.')