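# Introduction to Sequence-to-Sequence
- The most common **sequence-to-sequence (seq2seq) model** is the **encoder-decoder model**, which consists of two parts, an **Encoder** and a **Decoder**. Both parts can be implemented with a **recurrent neural network (RNN)** or a **Transformer**. The architecture is mainly used when the input and the output differ in length
- The **Encoder** encodes **a sequence** of inputs, such as text, video, or audio signals, into **a single vector**. This vector can be thought of as an abstract representation of the whole input that carries all of its information
- The **Decoder** decodes the single vector produced by the Encoder step by step, **emitting one result at a time**, until the final target output has been generated. Each output affects the next one. Typically "< BOS >" is added at the start to signal the beginning of decoding, and "< EOS >" is emitted at the end to signal that the output is complete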
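
The decoding loop described above can be sketched in a few lines of plain Python. This is only an illustration: `toy_decoder_step` is a hypothetical stand-in for a trained decoder (it just upper-cases the input tokens), and the string forms of the BOS/EOS markers are assumptions.

```python
# Toy greedy decoding loop (illustrative; `toy_decoder_step` is hypothetical).
BOS, EOS = "<BOS>", "<EOS>"

def toy_decoder_step(context, prefix):
    """Stand-in for a trained decoder: returns the next token given the
    encoded input (`context`) and the tokens produced so far (`prefix`)."""
    produced = len(prefix) - 1            # prefix[0] is always BOS
    if produced < len(context):
        return context[produced].upper()  # pretend "translation" = uppercase
    return EOS

def greedy_decode(context, max_len=20):
    """Emit one token at a time until EOS is produced or max_len is reached."""
    prefix = [BOS]
    while len(prefix) - 1 < max_len:
        token = toy_decoder_step(context, prefix)
        if token == EOS:
            break
        prefix.append(token)              # each output becomes part of the next input
    return prefix[1:]

print(greedy_decode(["a", "b", "c"]))
```

The loop stops either when the decoder emits EOS or when the length cap is hit, which is why real decoders also need a maximum output length.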


![seq2seq](https://i.imgur.com/0zeDyuI.png)

# Assignment overview

# Install and import the required libraries

In [None]:
# !pip install --user 'torch>=1.5.0' editdistance jieba-fast sacrebleu sacremoses sentencepiece tqdm

In [None]:
# !git clone https://github.com/pytorch/fairseq.git
# !cd fairseq && git checkout 9a1c497
# !pip install --upgrade --user ./fairseq

In [None]:
import sys
import pdb
import pprint
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import tqdm.auto as tqdm
from pathlib import Path
from argparse import Namespace
from fairseq import utils

import matplotlib.pyplot as plt

In [None]:
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level="INFO", # "DEBUG" "WARNING" "ERROR"
    stream=sys.stdout,
)
logger = logging.getLogger("seq2seq")

# Data download

# Dataset

## English to Traditional Chinese bilingual data
* TED2020 http://opus.nlpl.eu/TED2020-v1.php
    - Raw data: 404,726 sentences
    - Training data: 335,785 sentences

### TODO: citation
@inproceedings{reimers-2020-multilingual-sentence-bert,
    title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/2004.09813",
}
### TODO: acknowledgement
J. Tiedemann, 2012, Parallel Data, Tools and Interfaces in OPUS. In Proceedings of the 8th International Conference on Language Resources and Evaluation (LREC 2012)

### Download and unzip the files

In [None]:
data_dir = 'raw'
dataset_name = 'ted2020'
url = 'https://object.pouta.csc.fi/OPUS-TED2020/v1/moses/en-zh_tw.txt.zip'
file_name = 'en-zh.zip'
prefix = os.path.join(data_dir, dataset_name)

!rm -rf {prefix}
!mkdir -p {prefix}
!wget {url} -O {prefix+'/'+file_name}
!unzip {prefix+'/'+file_name} -d {prefix}
!mv {prefix+'/'+'TED2020.en-zh_tw.en'} {prefix+'/'+'en-zh.en'}
!mv {prefix+'/'+'TED2020.en-zh_tw.zh_tw'} {prefix+'/'+'en-zh.zh'}

### Set the languages

In [None]:
src_lang = 'en'
tgt_lang = 'zh'
data_prefix = f'{prefix}/en-zh'

In [None]:
!head {data_prefix+'.'+src_lang} -n 5
!head {data_prefix+'.'+tgt_lang} -n 5

### Preprocessing

In [None]:
import re

def strQ2B(ustring):
    """Convert full-width characters in a string to half-width."""
    # Reference: https://ithelp.ithome.com.tw/articles/10233122
    ss = []
    for s in ustring:
        rstring = ""
        for uchar in s:
            inside_code = ord(uchar)
            if inside_code == 12288:  # the full-width space maps directly to the ASCII space
                inside_code = 32
            elif (inside_code >= 65281 and inside_code <= 65374):  # other full-width characters sit at a fixed offset from ASCII
                inside_code -= 65248
            rstring += chr(inside_code)
        ss.append(rstring)
    return ''.join(ss)

def clean(input_file, output_file, lang='en'):
    with open(output_file, 'w') as out_f:
        with open(input_file, 'r') as in_f:
            for line in in_f:
                if lang == 'en':
                    line = re.sub(r"\([^()]*\)", "", line) # remove ([text])
                    line = line.replace('-', '') # remove '-'
                    line = re.sub('([.,;!?()\"])', r' \1 ', line) # keep punctuation
                elif lang == 'zh':
                    line = strQ2B(line) # Q2B
                    line = re.sub(r"\([^()]*\)", "", line) # remove ([text])
                    line = line.replace(' ', '')
                    line = line.replace('—', '')
                    line = re.sub('([。,;!?()\"~「」])', r' \1 ', line) # keep punctuation
                line = ' '.join(line.strip().split())
                print(line, file=out_f)
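
The full-width to half-width mapping used by `strQ2B` above can also be written as a lookup table. The sketch below is an alternative formulation, not the notebook's code; `FULL_TO_HALF` and `to_half_width` are names introduced here for illustration. It relies on the fact that U+FF01..U+FF5E sit exactly 0xFEE0 (65248) above their ASCII counterparts.

```python
# Compact sketch of full-width -> half-width conversion using str.translate.
# U+FF01..U+FF5E are offset by 0xFEE0 (65248) from ASCII '!'..'~';
# the ideographic space U+3000 maps to an ordinary space.
FULL_TO_HALF = {code: code - 0xFEE0 for code in range(0xFF01, 0xFF5F)}
FULL_TO_HALF[0x3000] = 0x20

def to_half_width(text):
    """Return `text` with full-width ASCII variants replaced by plain ASCII."""
    return text.translate(FULL_TO_HALF)

# Full-width "ABC,123!" written with escapes so the example is encoding-safe.
print(to_half_width("\uff21\uff22\uff23\uff0c\uff11\uff12\uff13\uff01"))
```

Characters outside the mapped range, such as Chinese characters, pass through unchanged.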

In [None]:
for lang in [src_lang, tgt_lang]:
    clean(f'{data_prefix}.{lang}', f'{data_prefix}.{lang}.clean', lang=lang)

In [None]:
!head {data_prefix+'.'+src_lang+'.clean'} -n 5
!head {data_prefix+'.'+tgt_lang+'.clean'} -n 5

### Split into train/valid/test sets

In [None]:
valid_ratio = 0.1
test_ratio = 0.01
train_ratio = 1 - valid_ratio - test_ratio

In [None]:
line_num = sum(1 for line in open(f'{data_prefix}.{src_lang}.clean'))
for lang in [src_lang, tgt_lang]:
    train_f = open(os.path.join(data_dir, dataset_name, f'train.{lang}.untok'), 'w')
    valid_f = open(os.path.join(data_dir, dataset_name, f'valid.{lang}.untok'), 'w')
    test_f = open(os.path.join(data_dir, dataset_name, f'test.{lang}.untok'), 'w')
    count = 0
    for line in open(f'{data_prefix}.{lang}.clean', 'r'):
        if count/line_num < train_ratio:
            train_f.write(line)
        elif count/line_num < train_ratio + valid_ratio:
            valid_f.write(line)
        else:
            test_f.write(line)
        count += 1
    train_f.close()
    valid_f.close()
    test_f.close()

### BPE encoding

In [None]:
#!pip install sentencepiece

In [None]:
import sentencepiece as spm
prefix = os.path.join(data_dir, dataset_name)
vocab_size = 30000
spm.SentencePieceTrainer.train(
    input=','.join([f'{prefix}/train.{src_lang}.untok',
                    f'{prefix}/valid.{src_lang}.untok',
                    f'{prefix}/train.{tgt_lang}.untok',
                    f'{prefix}/valid.{tgt_lang}.untok']),
    model_prefix=f'{prefix}/{vocab_size}',
    vocab_size=vocab_size,
)

In [None]:
spm_model = spm.SentencePieceProcessor(model_file=os.path.join(data_dir, dataset_name, f'{vocab_size}.model'))
for split in ['train', 'valid', 'test']:
    for lang in [src_lang, tgt_lang]:
        with open(os.path.join(data_dir, dataset_name, f'{split}.{lang}'), 'w') as out_f:
            with open(os.path.join(data_dir, dataset_name, f'{split}.{lang}.untok'), 'r') as in_f:
                for line in in_f:
                    line = line.strip()
                    tok = spm_model.encode(line, out_type=str)
                    print(' '.join(tok), file=out_f)

### Binarize the data with fairseq

In [None]:
#!pip install fairseq

In [None]:
!fairseq-preprocess \
    --source-lang {src_lang}\
    --target-lang {tgt_lang}\
    --trainpref {data_dir+'/'+dataset_name+'/train'}\
    --validpref {data_dir+'/'+dataset_name+'/valid'}\
    --testpref {data_dir+'/'+dataset_name+'/test'}\
    --destdir {'./DATA/data-bin'+'/'+dataset_name}\
    --joined-dictionary\
    --workers 4

# CUDA environment

In [None]:
cuda_env = utils.CudaEnvironment()
utils.CudaEnvironment.pretty_print_cuda_env_list([cuda_env])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
logger.info(f"using device: {device}")

# Experiment configuration

In [None]:
config = Namespace(
    datadir = "./DATA/data-bin/ted2020",
    savedir = "./checkpoints-tr2",
    source_lang = "en",
    target_lang = "zh",
    
    # cpu threads when fetching & processing data.
    num_workers=8,  
    # batch size in terms of tokens. gradient accumulation increases the effective batchsize.
    max_tokens=4096,
    accum_steps=8,
    # when calculating loss, normalized by number of sentences instead of number of tokens
    sentence_average=True,
    
    # the lr is calculated by the Noam lr scheduler. you can tune the maximum lr by this factor.
    lr_factor=1.0,
    lr_warmup=4000,
    
    # maximum epochs for training
    max_epoch=30,
    start_epoch=1,
    
    # beam size for beam search
    beam=5, 
    # generate sequences of maximum length ax + b, where x is the source length
    max_len_a=1.2, 
    max_len_b=10, 
    # length penalty: <1.0 favors shorter, >1.0 favors longer sentences
    lenpen=1, 
    # when decoding, post process sentence by removing sentencepiece symbols and jieba tokenization.
    post_process = "sentencepiece",
    remove_jieba = False,
    
    # checkpoints
    keep_last_epochs = 3,
    resume=None,
)
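
The `lr_factor` and `lr_warmup` settings above refer to the Noam learning-rate schedule. As a rough sketch, assuming the standard formula from the Transformer paper and a model dimension of 256 (the `d_model` default below is an assumption, and `noam_lr` is a name introduced here), the learning rate rises linearly during warmup and then decays with the inverse square root of the step:

```python
def noam_lr(step, d_model=256, factor=1.0, warmup=4000):
    """Noam schedule: linear warmup, then inverse-square-root decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first step
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 1000, 4000, 16000):
    print(step, f"{noam_lr(step):.6f}")
```

The peak is reached at `step == warmup`, and `factor` scales the whole curve, which is why it can be used to tune the maximum learning rate.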

# Borrowing fairseq's TranslationTask
* The mmap dataset is very fast
* It comes with a ready-made dataloader iterator
* The dictionaries task.source_dictionary and task.target_dictionary are also handy
* Beam search is already implemented

In [None]:
from fairseq.tasks.translation import TranslationConfig, TranslationTask

## setup task
task_cfg = TranslationConfig(
    data=config.datadir,
    source_lang=config.source_lang,
    target_lang=config.target_lang,
    train_subset="train",
    required_seq_len_multiple=8,
    dataset_impl="mmap",
)
task = TranslationTask.setup_task(task_cfg)

In [None]:
logger.info("loading data for epoch 1")
task.load_dataset(split="train", epoch=1)
task.load_dataset(split="valid", epoch=1)

In [None]:
sample = task.dataset("valid")[1]
pprint.pprint(sample)
pprint.pprint(
    "Source: " + \
    task.source_dictionary.string(
        sample['source'],
        config.post_process,
    )
)
pprint.pprint(
    "Target: " + \
    task.target_dictionary.string(
        sample['target'],
        config.post_process,
    )
)

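## Dataset Iterator
* Caps each batch at N tokens so that GPU memory is used more efficiently
* Shuffles the training set differently in every epoch
* Filters out sentences that are too long
* Pads the sentences within each batch to the same length so the GPU can process them in parallel
* Appends eos and shifts by one position
    - teacher forcing: to train the model to generate the next token given a prefix, the decoder input is the target sequence shifted right by one position.
    - Usually a bos token is added at the start of the input (see the figure below)
![seq2seq](https://i.imgur.com/0zeDyuI.png)
    - fairseq simply moves eos to the beginning instead, and training works about equally well. For example: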
    ```
    # target and decoder input (prev_output_tokens):
                   eos = 2
                target = 419,  711,  238,  888,  792,   60,  968,    8,    2
    prev_output_tokens = 2,  419,  711,  238,  888,  792,   60,  968,    8
    ```
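
The shift shown above can be reproduced with plain Python lists. This is a minimal sketch of the idea rather than fairseq's own collation code; `make_decoder_input` is a hypothetical helper name.

```python
EOS = 2  # eos index, as in the example above

def make_decoder_input(target, eos=EOS):
    """Shift the target right by one, moving the trailing eos to the front."""
    assert target[-1] == eos, "target is expected to end with eos"
    return [eos] + target[:-1]

target = [419, 711, 238, 888, 792, 60, 968, 8, 2]
prev_output_tokens = make_decoder_input(target)
print(prev_output_tokens)

# At step t the decoder sees prev_output_tokens[: t + 1] and must predict target[t].
for t in range(3):
    print(prev_output_tokens[: t + 1], "->", target[t])
```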


In [None]:
def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=1, cached=True):
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(split),
        max_tokens=max_tokens,
        max_sentences=None,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            max_tokens,
        ),
        ignore_invalid_inputs=True,
        seed=138,
        num_workers=num_workers,
        epoch=epoch,
        disable_iterator_cache=not cached,
        # disable_iterator_cache=False (i.e. cached=True) speeds things up. However, with the cache on,
        # changing max_tokens after the first call of this method has no effect.
    )
    return batch_iterator

demo_epoch_obj = load_data_iterator(task, "valid", epoch=1, max_tokens=20, num_workers=1, cached=False)
demo_iter = demo_epoch_obj.next_epoch_itr(shuffle=True)
sample = next(demo_iter)
sample
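
To give a feel for what `max_tokens` does, here is a simplified sketch of batching by a token budget. It is written in the spirit of fairseq's batching but is not its actual implementation; `batch_by_tokens` is a hypothetical helper. The cost of a batch is taken to be (longest sentence) x (number of sentences), since shorter sentences are padded to the longest one.

```python
def batch_by_tokens(lengths, max_tokens):
    """Group sentence indices so that (longest sentence) x (batch size) <= max_tokens.
    Sentences longer than max_tokens are skipped."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current = [], []
    for i in order:
        if lengths[i] > max_tokens:
            continue  # too long to fit in any batch
        # lengths are visited in ascending order, so lengths[i] would be the longest
        if current and lengths[i] * (len(current) + 1) > max_tokens:
            batches.append(current)
            current = []
        current.append(i)
    if current:
        batches.append(current)
    return batches

lengths = [5, 3, 9, 4, 25, 7]
print(batch_by_tokens(lengths, max_tokens=20))
```

Short sentences end up in larger batches and long sentences in smaller ones, so every batch costs roughly the same amount of memory.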

* Each batch is a dictionary with string keys; the values are mostly Tensors, as described below
```python
batch = {
    "id": id, # id of each example
    "nsentences": len(samples), # batch size in sentences
    "ntokens": ntokens, # batch size in tokens
    "net_input": {
        "src_tokens": src_tokens, # source-language sequences
        "src_lengths": src_lengths, # length of each sentence before padding
        "prev_output_tokens": prev_output_tokens, # target sequence shifted right by one, as described above
    },
    "target": target, # target sequences
}
```
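
A batch with this layout could be assembled roughly as follows. This is a plain-Python sketch with lists instead of Tensors; the special-token indices, right-side padding, and the helper names `pad_right` and `collate` are assumptions made for illustration and may differ from fairseq's defaults.

```python
PAD, EOS = 1, 2  # assumed special-token indices

def pad_right(seqs, pad=PAD):
    """Pad every sequence on the right to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [s + [pad] * (width - len(s)) for s in seqs]

def collate(samples):
    """samples: list of dicts with 'id', 'source', 'target' (int lists ending in eos)."""
    # decoder input: the target with its trailing eos moved to the front
    prev = [[s["target"][-1]] + s["target"][:-1] for s in samples]
    return {
        "id": [s["id"] for s in samples],
        "nsentences": len(samples),
        "ntokens": sum(len(s["target"]) for s in samples),  # counted on the target side
        "net_input": {
            "src_tokens": pad_right([s["source"] for s in samples]),
            "src_lengths": [len(s["source"]) for s in samples],
            "prev_output_tokens": pad_right(prev),
        },
        "target": pad_right([s["target"] for s in samples]),
    }

samples = [
    {"id": 0, "source": [11, 12, 13, EOS], "target": [21, 22, EOS]},
    {"id": 1, "source": [14, EOS], "target": [23, 24, 25, 26, EOS]},
]
batch = collate(samples)
print(batch["ntokens"], batch["net_input"]["src_tokens"])
```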

# Defining the model architecture
* We again inherit from fairseq's encoder/decoder/model so that, at test time, we can directly use its built-in beam search

In [None]:
from fairseq.models import (
    FairseqEncoder, 
    FairseqIncrementalDecoder,
    FairseqEncoderDecoderModel
)

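## Encoder
- The encoder of a seq2seq model is an RNN or a Transformer Encoder. The description below uses an RNN as the example; the Transformer differs slightly. For each input, the Encoder outputs a vector and a hidden state, and passes the hidden state on to the next input. In other words, the Encoder reads the input sequence step by step, outputs a single vector at each timestep, and outputs the final hidden state (the context vector) at the last timestep
- Parameters:
  - *args*
      - encoder_embed_dim is the embedding dimension. It compresses one-hot word vectors down to the specified dimension, mainly to reduce dimensionality and condense information
      - encoder_ffn_embed_dim is the dimension of the RNN output and hidden state (hidden dimension)
      - encoder_layers is the number of stacked RNN layers
      - dropout is the probability of zeroing out a node, mainly to prevent overfitting. It is generally applied during training and disabled at test time
  - *dictionary*: the dictionary prepared by fairseq. Here it provides the padding index, which is used to build the encoder padding mask.
  - *embed_tokens*: a token embedding created beforehand (nn.Embedding)

- Inputs:
    - *src_tokens*: the English sentence as a sequence of integers, e.g. 1, 28, 29, 205, 2
- Outputs:
    - *outputs*: the output of the top RNN layer at every timestep, which can later be processed with Attention
    - *final_hiddens*: the hidden state of every layer at the final timestep, passed to the Decoder for decoding
    - *encoder_padding_mask*: marks which positions are padding and should be ignored.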
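
Because the encoder below is bidirectional, the final hidden states of the forward and backward directions of each layer have to be joined before they are handed to the decoder. The sketch below shows that step with nested Python lists instead of Tensors. It assumes the layout that `nn.GRU` uses for its final hidden state, namely `[layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd, ...]`; this list-based `combine_bidir` is an illustration only.

```python
def combine_bidir(final_hiddens, num_layers):
    """final_hiddens: list of length num_layers * 2, laid out as
    [layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd, ...]; each entry is a
    list of per-sentence hidden vectors (batch x hidden).
    Returns num_layers x batch x (2 * hidden)."""
    combined = []
    for layer in range(num_layers):
        fwd = final_hiddens[2 * layer]
        bwd = final_hiddens[2 * layer + 1]
        # concatenate the two directions for every sentence in the batch
        combined.append([f + b for f, b in zip(fwd, bwd)])
    return combined

# 2 layers, batch of 1, hidden size 2
h = [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]
print(combine_bidir(h, num_layers=2))
```

This doubling of the hidden size is why the decoder's hidden dimension needs to be twice the encoder's.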


In [None]:
class RNNEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens
        
        self.embed_dim = args.encoder_embed_dim
        self.hidden_dim = args.encoder_ffn_embed_dim
        self.num_layers = args.encoder_layers
        
        self.dropout_in_module = nn.Dropout(args.dropout)
        self.rnn = nn.GRU(
            self.embed_dim, 
            self.hidden_dim, 
            self.num_layers, 
            dropout=args.dropout, 
            batch_first=False, 
            bidirectional=True
        )
        self.dropout_out_module = nn.Dropout(args.dropout)
        
        self.padding_idx = dictionary.pad()
        
    def combine_bidir(self, outs, bsz: int):
        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(self.num_layers, bsz, -1)

    def forward(self, src_tokens, **unused):
        bsz, seqlen = src_tokens.size()
        
        # get embeddings
        x = self.embed_tokens(src_tokens)
        x = self.dropout_in_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        
        # run the bidirectional RNN
        h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)
        x, final_hiddens = self.rnn(x, h0)
        outputs = self.dropout_out_module(x)
        # outputs = [sequence len, batch size, hid dim * directions], the output of the top RNN layer
        # hidden =  [num_layers * directions, batch size  , hid dim]
        
        # the Encoder is a bidirectional RNN, so the hidden states of the two directions in each layer must be concatenated
        final_hiddens = self.combine_bidir(final_hiddens, bsz)
        # hidden =  [num_layers x batch x num_directions*hidden]
        
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
        return tuple(
            (
                outputs,  # seq_len x batch x hidden
                final_hiddens,  # num_layers x batch x num_directions*hidden
                encoder_padding_mask,  # seq_len x batch
            )
        )
    
    def reorder_encoder_out(self, encoder_out, new_order):
        # used during beam search; the details are not important here
        return tuple(
            (
                encoder_out[0].index_select(1, new_order),
                encoder_out[1].index_select(1, new_order),
                encoder_out[2].index_select(1, new_order),
            )
        )

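## Attention
- When the input is too long, or when the "context vector" alone cannot capture the meaning of the whole input, an Attention Mechanism gives the **Decoder** more information
- Based on the current **Decoder hidden state**, it computes which **Encoder outputs** are most related to it, and uses those relatedness scores to decide what extra information to pass to the **Decoder**
- A common Attention implementation uses a Neural Network / Dot Product to score the relation between the **Decoder hidden state** and the **Encoder outputs**, applies **softmax** to all the scores, and finally takes a **weighted sum** of the **Encoder outputs** using the **softmax** weights

- Parameters:
  - *input_embed_dim*: dimension of the query, i.e. the decoder vector that does the attending
  - *source_embed_dim*: dimension of the key, i.e. the vectors being attended to (encoder outputs)
  - *output_embed_dim*: dimension of the output, i.e. the vector size the next layer expects after attention

- Inputs:
    - *inputs*: the query, the vector that attends to the others
    - *encoder_outputs*: the key/value, the vectors being attended to
    - *encoder_padding_mask*: marks which positions are padding and should be ignored.
- Outputs:
    - *output*: the context vector after attention
    - *attention score*: the attention distribution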
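
The score, softmax, and weighted-sum steps described above can be sketched for a single query in plain Python. This is a simplified illustration without the learned projections; `attend` is a hypothetical helper, and it assumes at least one source position is not padding.

```python
import math

def attend(query, encoder_outputs, padding_mask):
    """query: list of floats; encoder_outputs: list of S vectors;
    padding_mask: list of S bools (True = padding). Returns (context, weights)."""
    # dot-product score per source position; padded positions get -inf
    scores = [
        -math.inf if masked else sum(q * k for q, k in zip(query, key))
        for key, masked in zip(encoder_outputs, padding_mask)
    ]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # exp(-inf) is 0.0 for padded positions
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(encoder_outputs[0])
    # weighted sum of the encoder outputs
    context = [sum(w * v[d] for w, v in zip(weights, encoder_outputs)) for d in range(dim)]
    return context, weights

context, weights = attend(
    query=[1.0, 0.0],
    encoder_outputs=[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]],
    padding_mask=[False, False, True],
)
print(weights, context)
```

The padded third position receives zero weight no matter how large its vector is, which is exactly what the padding mask is for.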


In [None]:
class AttentionLayer(nn.Module):
    def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):
        super().__init__()

        self.input_proj = nn.Linear(input_embed_dim, source_embed_dim, bias=bias)
        self.output_proj = nn.Linear(
            input_embed_dim + source_embed_dim, output_embed_dim, bias=bias
        )

    def forward(self, inputs, encoder_outputs, encoder_padding_mask):
        # inputs: T, B, dim
        # encoder_outputs: S x B x dim
        # padding mask:  S x B
        
        # convert all to batch first
        inputs = inputs.transpose(1,0) # B, T, dim
        encoder_outputs = encoder_outputs.transpose(1,0) # B, S, dim
        encoder_padding_mask = encoder_padding_mask.transpose(1,0) # B, S
        
        # project to the dimension of encoder_outputs
        x = self.input_proj(inputs)

        # compute the attention scores
        # (B, T, dim) x (B, dim, S) = (B, T, S)
        attn_scores = torch.bmm(x, encoder_outputs.transpose(1,2))

        # block attention to padding positions
        if encoder_padding_mask is not None:
            # use broadcasting  B, S -> (B, 1, S)
            encoder_padding_mask = encoder_padding_mask.unsqueeze(1)
            attn_scores = (
                attn_scores.float()
                .masked_fill_(encoder_padding_mask, float("-inf"))
                .type_as(attn_scores)
            )  # FP16 support: cast to float and back

        # softmax over the source dimension
        attn_scores = F.softmax(attn_scores, dim=-1)

        # shapes (B, T, S) x (B, S, dim) = (B, T, dim), a weighted sum
        x = torch.bmm(attn_scores, encoder_outputs)

        # (B, T, dim)
        x = torch.cat((x, inputs), dim=-1)
        x = torch.tanh(self.output_proj(x)) # concat + linear + tanh
        
        # restore the shape (B, T, dim) -> (T, B, dim)
        return x.transpose(1,0), attn_scores

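## Decoder
* The decoder's hidden states are initialized with the encoder's final hidden state (the context vector)
* At each timestep the decoder also updates its hidden states based on the current input (that is, the outputs of the previous timesteps) and emits a result
* Adding attention can improve performance
* We put the seq2seq steps inside the decoder so that the Seq2Seq class defined later works for both RNN and Transformer without being rewritten
- Parameters:
  - *args*
      - decoder_embed_dim is the dimension of the decoder embedding, analogous to encoder_embed_dim
      - decoder_ffn_embed_dim is the hidden dimension of the decoder RNN, analogous to encoder_ffn_embed_dim
      - decoder_layers is the number of layers of the decoder RNN
      - share_decoder_input_output_embed: the decoder's final output projection matrix usually shares parameters with the input embedding
  - *dictionary*: the dictionary prepared by fairseq.
  - *embed_tokens*: a token embedding created beforehand (nn.Embedding)
- Inputs:
    - *prev_output_tokens*: the target as a sequence of integers, already shifted by one position, e.g. 1, 28, 29, 205, 2
    - *encoder_out*: the encoder output
    - *incremental_state*: caches the hidden state of each timestep to speed up decoding at test time; see forward for details
- Outputs:
    - *outputs*: the decoder's logits at each timestep, i.e. the distribution before softmax
    - *extra*: unused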
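
The purpose of *incremental_state* is easier to see on a toy model. The sketch below uses a scalar RNN cell with arbitrary weights (everything here is hypothetical and for illustration only) and checks that feeding only the newest token while resuming from a cached hidden state gives the same result as re-running the whole prefix.

```python
import math

def rnn_step(hidden, token, w_h=0.5, w_x=0.1):
    """Toy scalar RNN cell (weights are arbitrary, for illustration only)."""
    return math.tanh(w_h * hidden + w_x * token)

def decode_full(tokens, h0):
    """Re-run the RNN over the whole prefix; returns the last hidden state."""
    hidden = h0
    for tok in tokens:
        hidden = rnn_step(hidden, tok)
    return hidden

def decode_incremental(tokens, h0, incremental_state):
    """Feed only the newest token, resuming from the cached hidden state."""
    hidden = incremental_state.get("prev_hiddens", h0)
    hidden = rnn_step(hidden, tokens[-1])
    incremental_state["prev_hiddens"] = hidden  # cache for the next timestep
    return hidden

prefix, state = [], {}
for tok in [2, 419, 711, 238]:
    prefix.append(tok)
    fast = decode_incremental(prefix, h0=0.3, incremental_state=state)
    slow = decode_full(prefix, h0=0.3)
    assert abs(fast - slow) < 1e-12
print("incremental and full decoding agree")
```

The full version does work proportional to the prefix length at every step, while the incremental version does a single cell update, which is where the test-time speedup comes from.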

In [None]:
class RNNDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens
        
        assert args.decoder_layers == args.encoder_layers, f"""seq2seq rnn requires that encoder 
        and decoder have same layers of rnn. got: {args.encoder_layers, args.decoder_layers}"""
        assert args.decoder_ffn_embed_dim == args.encoder_ffn_embed_dim*2, f"""seq2seq-rnn requires 
        that decoder hidden to be 2*encoder hidden dim. got: {args.decoder_ffn_embed_dim, args.encoder_ffn_embed_dim*2}"""
        
        self.embed_dim = args.decoder_embed_dim
        self.hidden_dim = args.decoder_ffn_embed_dim
        self.num_layers = args.decoder_layers
        
        
        self.dropout_in_module = nn.Dropout(args.dropout)
        self.rnn = nn.GRU(
            self.embed_dim, 
            self.hidden_dim, 
            self.num_layers, 
            dropout=args.dropout, 
            batch_first=False, 
            bidirectional=False
        )
        self.attention = AttentionLayer(
            self.embed_dim, self.hidden_dim, self.embed_dim, bias=False
        ) 
        # self.attention = None
        self.dropout_out_module = nn.Dropout(args.dropout)
        
        if self.hidden_dim != self.embed_dim:
            self.project_out_dim = nn.Linear(self.hidden_dim, self.embed_dim)
        else:
            self.project_out_dim = None
        
        if args.share_decoder_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            # the input to the output projection has size embed_dim at this point
            self.output_projection = nn.Linear(
                self.embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.embed_dim ** -0.5
            )
        
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, **unused):
        # unpack the encoder output
        encoder_outputs, encoder_hiddens, encoder_padding_mask = encoder_out
        # outputs:          seq_len x batch x num_directions*hidden
        # encoder_hiddens:  num_layers x batch x num_directions*encoder_hidden
        # padding_mask:     seq_len x batch
        
        if incremental_state is not None and len(incremental_state) > 0:
            # state left over from the previous timestep is available, so decoding can resume from it instead of restarting from bos
            prev_output_tokens = prev_output_tokens[:, -1:]
            cache_state = self.get_incremental_state(incremental_state, "cached_state")
            prev_hiddens = cache_state["prev_hiddens"]
        else:
            # no incremental state means this is training, or the first step at test time
            # prepare seq2seq: pass encoder_hiddens in as the decoder's initial hidden states
            prev_hiddens = encoder_hiddens
        
        bsz, seqlen = prev_output_tokens.size()
        
        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout_in_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
                
        # decoder-to-encoder attention
        if self.attention is not None:
            x, attn = self.attention(x, encoder_outputs, encoder_padding_mask)
                        
        # run the unidirectional RNN
        x, final_hiddens = self.rnn(x, prev_hiddens)
        # outputs = [sequence len, batch size, hid dim]
        # hidden =  [num_layers * directions, batch size  , hid dim]
        x = self.dropout_out_module(x)
                
        # project to the embedding size (an extra projection is needed when hidden and embed sizes differ and share_embedding is True)
        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        
        # project to a distribution over the vocabulary
        x = self.output_projection(x)
        
        # T x B x C -> B x T x C
        x = x.transpose(1, 0)
        
        # if decoding incrementally, record this timestep's hidden states so the next timestep can read them back
        cache_state = {
            "prev_hiddens": final_hiddens,
        }
        self.set_incremental_state(incremental_state, "cached_state", cache_state)
        
        return x, None
    
    def reorder_incremental_state(
        self,
        incremental_state,
        new_order,
    ):
        # used during beam search; the details are not important here
        cache_state = self.get_incremental_state(incremental_state, "cached_state")
        prev_hiddens = cache_state["prev_hiddens"]
        prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]
        cache_state = {
            "prev_hiddens": torch.stack(prev_hiddens),
        }
        self.set_incremental_state(incremental_state, "cached_state", cache_state)
        return

In [None]:
class Seq2Seq(FairseqEncoderDecoderModel):
    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args
    
    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
    ):
        """
        Run the forward pass for an encoder-decoder model.
        """
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra

In [None]:
from fairseq.models.transformer import (
    TransformerEncoder, 
    TransformerDecoder,
    TransformerModel,
)

def build_model(args, task):
    """Build the model according to the given arguments."""
    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    # token embeddings
    encoder_embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, src_dict.pad())
    decoder_embed_tokens = nn.Embedding(len(tgt_dict), args.decoder_embed_dim, tgt_dict.pad())
    
    # encoder and decoder
    encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
    decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
    
    # sequence-to-sequence model
    model = Seq2Seq(args, encoder, decoder)
    
    # initialization matters a lot for seq2seq models and needs special handling
    def init_params(module):
        from fairseq.modules import MultiheadAttention
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, MultiheadAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.RNNBase):
            for name, param in module.named_parameters():
                if "weight" in name or "bias" in name:
                    param.data.uniform_(-0.1, 0.1)
            
    # initialize the model
    model.apply(init_params)
    return model

# Model architecture settings

In [None]:
arch_args = Namespace(
    encoder_embed_dim=256,
    encoder_ffn_embed_dim=1024,
    encoder_layers=4,
    decoder_embed_dim=256,
    decoder_ffn_embed_dim=1024,
    decoder_layers=4,
    share_decoder_input_output_embed=True,
    dropout=0.2,
)

# add the arguments needed by the Transformer
def add_transformer_args(args):
    args.encoder_attention_heads=4
    args.encoder_normalize_before=True
    
    args.decoder_attention_heads=4
    args.decoder_normalize_before=True
    
    args.activation_fn="relu"
    args.max_source_positions=1024
    args.max_target_positions=1024
    
    # fill in the remaining Transformer defaults that we did not set
    from fairseq.models.transformer import base_architecture 
    base_architecture(args)

add_transformer_args(arch_args)

In [None]:
model = build_model(arch_args, task)

In [None]:
logger.info(model)

# Label Smoothing Regularization
* Lets the model learn a less concentrated output distribution and keeps it from becoming overconfident
* TODO: explain Label Smoothing
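
As a small numerical illustration of the effect, the sketch below computes the label-smoothed loss for a single position in plain Python, using the same formula as the criterion in this section, `(1 - eps) * nll + (eps / V) * sum(-log p)`. The function name `label_smoothed_nll` and the example logits are made up for illustration.

```python
import math

def label_smoothed_nll(logits, target, smoothing=0.1):
    """Label-smoothed loss for one position:
    (1 - eps) * nll + (eps / V) * sum(-log p)."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    lprobs = [x - log_z for x in logits]   # log-softmax
    nll = -lprobs[target]                  # usual negative log-likelihood
    smooth = -sum(lprobs)                  # penalizes very small probabilities on other classes
    return (1.0 - smoothing) * nll + smoothing / len(lprobs) * smooth

confident = [10.0, 0.0, 0.0, 0.0]  # almost all probability mass on class 0
modest = [3.0, 0.0, 0.0, 0.0]      # correct, but less peaked
for name, logits in [("confident", confident), ("modest", modest)]:
    plain = label_smoothed_nll(logits, 0, smoothing=0.0)
    smoothed = label_smoothed_nll(logits, 0, smoothing=0.1)
    print(name, round(plain, 4), round(smoothed, 4))
```

Without smoothing, the very peaked prediction has the lower loss. With smoothing of 0.1 the ordering flips in this example, because the smoothing term penalizes driving the other classes' probabilities toward zero.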

In [None]:
class LabelSmoothedCrossEntropyCriterion(nn.Module):
    def __init__(self, smoothing, ignore_index=None, reduce=True):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.reduce = reduce
    
    def forward(self, lprobs, target):
        if target.dim() == lprobs.dim() - 1:
            target = target.unsqueeze(-1)
        nll_loss = -lprobs.gather(dim=-1, index=target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        if self.ignore_index is not None:
            pad_mask = target.eq(self.ignore_index)
            nll_loss.masked_fill_(pad_mask, 0.0)
            smooth_loss.masked_fill_(pad_mask, 0.0)
        else:
            nll_loss = nll_loss.squeeze(-1)
            smooth_loss = smooth_loss.squeeze(-1)
        if self.reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        eps_i = self.smoothing / lprobs.size(-1)
        loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss
        return loss
    
criterion = LabelSmoothedCrossEntropyCriterion(
    smoothing=0.1,
    ignore_index=task.target_dictionary.pad(),
)

In [None]:
from fairseq.data import iterators
from torch.cuda.amp import GradScaler, autocast

# sentence_average = False
def train_one_epoch(epoch_itr, model, task, criterion, optimizer, accum_steps=1):
    itr = epoch_itr.next_epoch_itr(shuffle=True)
    itr = iterators.GroupedIterator(itr, accum_steps)
    
    stats = {"loss": []}
    # automatic mixed precision (amp)
    scaler = GradScaler()
    
    model.train()
    progress = tqdm.tqdm(itr, desc=f"train epoch {epoch_itr.epoch}", leave=False)
    for samples in progress:
        model.zero_grad()
        sample_size = 0
        for i, sample in enumerate(samples):
            if i == 1:
                # emptying the CUDA cache after the first step can reduce the chance of OOM
                torch.cuda.empty_cache()

            """gradient accumulation"""
            sample = utils.move_to_cuda(sample, device=device)
            target = sample["target"]
            sample_size_i = sample["nsentences"] if config.sentence_average else sample["ntokens"]
            sample_size += sample_size_i
            
            with autocast():
                net_output = model.forward(**sample["net_input"])
                lprobs = F.log_softmax(net_output[0], -1)            
                loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1))
                
                # logging
                loss_print = loss.item()/sample_size_i
                stats["loss"].append(loss_print)
                progress.set_postfix(loss=loss_print)
                
                scaler.scale(loss).backward()                
        
        scaler.unscale_(optimizer)
        optimizer.multiply_grads(1 / (sample_size or 1.0)) # (sample_size or 1.0) guards against division by zero when sample_size is 0
        
        scaler.step(optimizer)
        scaler.update()
        
    loss_print = np.mean(stats["loss"])
    logger.info(f"training loss:\t{loss_print:.4f}")
    return stats
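
The gradient accumulation in `train_one_epoch` sums the loss over several micro-batches and divides the gradients once by the total sample size. The toy example below, a one-parameter least-squares model with made-up data, sketches why this matches training on one large batch.

```python
def grad_sum_sq_loss(w, batch):
    """Gradient w.r.t. w of the *summed* loss sum((w*x - y)^2) over a batch."""
    return sum(2 * (w * x - y) * x for x, y in batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0), (5.0, 9.0), (6.0, 13.0)]
w = 0.5

# One big batch: gradient of the mean loss.
full = grad_sum_sq_loss(w, data) / len(data)

# Accumulate over micro-batches, then divide once by the total sample size.
accum, sample_size = 0.0, 0
for micro in (data[:2], data[2:3], data[3:]):
    accum += grad_sum_sq_loss(w, micro)
    sample_size += len(micro)
accum /= (sample_size or 1.0)

print(full, accum)
```

The micro-batches may have different sizes, which is why the division uses the accumulated `sample_size` rather than the number of accumulation steps.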

In [None]:

def decode(toks, dictionary, escape_unk=False, remove_jieba=False):
    s = dictionary.string(
        toks.int().cpu(),
        config.post_process,
        # The default unknown string in fairseq is `<unk>`, but
        # this is tokenized by sacrebleu as `< unk >`, inflating
        # BLEU scores. Instead, we use a somewhat more verbose
        # alternative that is unlikely to appear in the real
        # reference, but doesn't get split into multiple tokens.
        unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
    )
    # join into a sentence
    if remove_jieba:
        s = s.replace(" ", "").replace("\u2582", " ")
    return s if s else "UNKNOWNTOKENINHYP"

def inference_step(generator, sample, model):
    gen_out = generator.generate([model], sample)
    srcs = []
    hyps = []
    refs = []
    for i in range(len(gen_out)):
        """for each example, collect the source, reference 
        and most probable hypothesis (index 0) 's tokens, 
        """
        srcs.append(decode(
            utils.strip_pad(sample["net_input"]["src_tokens"][i], task.source_dictionary.pad()), 
            task.source_dictionary,
            remove_jieba=(config.source_lang == "zh") and config.remove_jieba,
        ))
        hyps.append(decode(
            gen_out[i][0]["tokens"], 
            task.target_dictionary,
            remove_jieba=(config.target_lang == "zh") and config.remove_jieba,
        ))
        refs.append(decode(
            utils.strip_pad(sample["target"][i], task.target_dictionary.pad()), 
            task.target_dictionary, 
            escape_unk=True,  # don't count <unk> as matches to the hypo
            remove_jieba=(config.target_lang == "zh") and config.remove_jieba,
        ))
    return srcs, hyps, refs
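
The `post_process` setting used in `decode` above turns subword pieces back into readable text. A rough sketch of that step, assuming the usual SentencePiece convention that U+2581 marks the start of a word (`remove_sentencepiece` is a hypothetical helper, not the fairseq function):

```python
def remove_sentencepiece(pieces_line):
    """Join space-separated subword pieces and turn the word-boundary
    marker (U+2581) back into ordinary spaces."""
    return pieces_line.replace(" ", "").replace("\u2581", " ").strip()

print(remove_sentencepiece("\u2581Thank \u2581you \u2581so \u2581mu ch"))
```

Spaces between pieces carry no meaning after tokenization, so they are dropped first and only the boundary markers become spaces again.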

In [None]:
import shutil
import sacrebleu
from fairseq.dataclass.configs import GenerationConfig
# gen_args = GenerationConfig(beam=5, max_len_a=1.2, max_len_b=10)
sequence_generator = task.build_generator([model], config) #gen_args)

def validate(model, task, criterion):
    logger.info('begin validation')
    itr = load_data_iterator(task, "valid", 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)
    
    stats = {"loss":[], "bleu": 0, "srcs":[], "hyps":[], "refs":[]}
    srcs = []
    hyps = []
    refs = []
    
    model.eval()
    progress = tqdm.tqdm(itr, desc=f"validation", leave=False)
    with torch.no_grad():
        for i, sample in enumerate(progress):
            # validation loss
            sample = utils.move_to_cuda(sample, device=device)
            net_output = model.forward(**sample["net_input"])

            lprobs = F.log_softmax(net_output[0], -1)
            target = sample["target"]
            sample_size = sample["nsentences"] if config.sentence_average else sample["ntokens"]
            loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1)) / sample_size
            progress.set_postfix(valid_loss=loss.item())
            stats["loss"].append(loss)
            
            # compute bleu
            s, h, r = inference_step(sequence_generator, sample, model)
            srcs.extend(s)
            hyps.extend(h)
            refs.extend(r)
            
    stats["loss"] = torch.stack(stats["loss"]).mean().item()
    stats["bleu"] = sacrebleu.corpus_bleu(hyps, [refs], tokenize='zh')
    stats["srcs"] = srcs
    stats["hyps"] = hyps
    stats["refs"] = refs
    
    showid = np.random.randint(len(hyps))
    logger.info("example source: " + srcs[showid])
    logger.info("example hypothesis: " + hyps[showid])
    logger.info("example reference: " + refs[showid])
    
    # show bleu results
    logger.info(f"validation loss:\t{stats['loss']:.4f}")
    logger.info(stats["bleu"].format())
    return stats


def validate_and_save(model, task, criterion, epoch, save=True):
    
    stats = validate(model, task, criterion)
    bleu = stats['bleu']
    loss = stats['loss']
    
    if save:
        # save epoch checkpoints
        savedir = Path(config.savedir).absolute()
        savedir.mkdir(parents=True, exist_ok=True)
        
        check = {
            "model": model.state_dict(),
            "stats": {"bleu": bleu.score, "loss": loss}
        }
        torch.save(check, savedir/f"checkpoint{epoch}.pt")
        shutil.copy(savedir/f"checkpoint{epoch}.pt", savedir/f"checkpoint_last.pt")
        logger.info(f"saved epoch checkpoint: {savedir}/checkpoint{epoch}.pt")
    
        # save epoch samples
        with open(savedir/f"samples{epoch}.{config.source_lang}-{config.target_lang}.txt", "w") as f:
            for s, h in zip(stats["srcs"], stats["hyps"]):
                f.write(f"{s}\t{h}\n")

        # get best valid bleu    
        if getattr(validate_and_save, "best_bleu", 0) < bleu.score:
            validate_and_save.best_bleu = bleu.score
            torch.save(check, savedir/f"checkpoint_best.pt")
            
        del_file = savedir / f"checkpoint{epoch - config.keep_last_epochs}.pt"
        if del_file.exists():
            del_file.unlink()
    
    return stats

def try_load_checkpoint(model, name=None):
    name = name if name else "checkpoint_last.pt"
    checkpath = Path(config.savedir)/name
    if checkpath.exists():
        check = torch.load(checkpath)
        model.load_state_dict(check["model"])
        stats = check["stats"]
        logger.info(f"loaded checkpoint {checkpath}: loss={stats['loss']} bleu={stats['bleu']}")
    else:
        logger.info(f"no checkpoints found at {checkpath}!")
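
The file bookkeeping in `validate_and_save` (write `checkpoint{epoch}.pt`, mirror it to `checkpoint_last.pt`, then delete the file that has fallen out of the `keep_last_epochs` window) can be tried without a model. The sketch below is only an illustration: `rotate_checkpoints` and `run_demo` are hypothetical helpers, and small text files stand in for the output of `torch.save`.

```python
import shutil
import tempfile
from pathlib import Path


def rotate_checkpoints(savedir, epoch, keep_last_epochs):
    """Hypothetical helper mirroring the bookkeeping in validate_and_save."""
    savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # stand-in for torch.save(check, ...): write a small placeholder file
    current = savedir / f"checkpoint{epoch}.pt"
    current.write_text(f"epoch {epoch}")
    shutil.copy(current, savedir / "checkpoint_last.pt")

    # drop the checkpoint that just fell out of the retention window
    stale = savedir / f"checkpoint{epoch - keep_last_epochs}.pt"
    if stale.exists():
        stale.unlink()


def run_demo(num_epochs=5, keep_last_epochs=2):
    """Run the rotation for several epochs and return the surviving file names."""
    with tempfile.TemporaryDirectory() as tmp:
        for epoch in range(1, num_epochs + 1):
            rotate_checkpoints(tmp, epoch, keep_last_epochs)
        return sorted(p.name for p in Path(tmp).iterdir())


print(run_demo())
```

In the notebook itself, `checkpoint_best.pt` is written separately and is never touched by this rotation.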

In [None]:
class NoamOpt:
    "Optimizer wrapper that applies the Noam learning-rate schedule (linear warmup, then inverse-square-root decay)."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
    
    @property
    def param_groups(self):
        return self.optimizer.param_groups
        
    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""                
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.mul_(c)
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step=None):
        "Return the learning rate at `step`: factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)."
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
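
The schedule in `rate` grows linearly for `warmup` steps and then decays with the inverse square root of the step. As a rough check, the standalone sketch below restates the same expression as a plain function and locates its peak. `noam_rate` is a hypothetical name, and the values 256, 2.0 and 4000 are illustrative assumptions, not the notebook's configuration.

```python
def noam_rate(step, model_size, factor, warmup):
    """Same expression as NoamOpt.rate, written as a free function (step >= 1)."""
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)


# illustrative values only (assumptions, not the notebook's config)
MODEL_SIZE, FACTOR, WARMUP = 256, 2.0, 4000

rates = {s: noam_rate(s, MODEL_SIZE, FACTOR, WARMUP) for s in (1, 2000, 4000, 8000, 16000)}
for step, lr in rates.items():
    print(f"step {step:>6}: lr = {lr:.6f}")

# the peak sits at step == warmup, where both arguments of min() coincide
peak_step = max(range(1, 20000), key=lambda s: noam_rate(s, MODEL_SIZE, FACTOR, WARMUP))
print("peak at step", peak_step)
```

Because the function is undefined at `step == 0`, the wrapper increments `_step` before computing the rate, and the plot in the next cell starts at step 1.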

In [None]:
optimizer = NoamOpt(
    model_size=arch_args.encoder_embed_dim, 
    factor=config.lr_factor, 
    warmup=config.lr_warmup, 
    optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))
plt.plot(np.arange(1, 20000), [optimizer.rate(i) for i in range(1, 20000)])
plt.legend([f"{optimizer.model_size}:{optimizer.warmup}"])
None

In [None]:
model = model.to(device=device)
criterion = criterion.to(device=device)

In [None]:
!nvidia-smi

In [None]:
logger.info("task: {}".format(task.__class__.__name__))
logger.info("encoder: {}".format(model.encoder.__class__.__name__))
logger.info("decoder: {}".format(model.decoder.__class__.__name__))
logger.info("criterion: {}".format(criterion.__class__.__name__))
logger.info("optimizer: {}".format(optimizer.__class__.__name__))
logger.info(
    "num. model params: {:,} (num. trained: {:,})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )
)
logger.info(f"max tokens per batch = {config.max_tokens}, accumulate steps = {config.accum_steps}")
epoch_itr = load_data_iterator(task, "train", config.start_epoch, config.max_tokens, config.num_workers)
try_load_checkpoint(model, name=config.resume)
while epoch_itr.next_epoch_idx <= config.max_epoch:
    # train for one epoch
    train_one_epoch(epoch_itr, model, task, criterion, optimizer, config.accum_steps)
    stats = validate_and_save(model, task, criterion, epoch=epoch_itr.epoch)
    logger.info("end of epoch {}".format(epoch_itr.epoch))    
    epoch_itr = load_data_iterator(task, "train", epoch_itr.next_epoch_idx, config.max_tokens, config.num_workers)

# Paired bilingual data
* iwslt17
* A subset of UM-Corpus
    - News: 450,000 sentences
    - Spoken: 220,000 sentences
    - Education: 450,000 sentences
    - Subtitles: 300,000 sentences (including TED)
* TED2020 http://opus.nlpl.eu/TED2020-v1.php
    - Raw data: 404,726 sentences
    - Training data: 335,785 sentences
* OpenSubtitles http://opus.nlpl.eu/OpenSubtitles-v2018.php, http://www.opensubtitles.org/
    - Training data: 4,772,273 sentences

J. Tiedemann, 2012, Parallel Data, Tools and Interfaces in OPUS. In Proceedings of the 8th International Conference on Language Resources and Evaluation (LREC 2012)

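# Chinese monolingual data
- newscrawl dataset (provided by WMT19):
  - 2018 training data: 589,475 sentences
  - 2019 training data: 2,974,040 sentences

- Language
  - Traditional Chinese: news.201X.zh.shuffled.deduped.trad (converted from Simplified to Traditional with opencc)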

In [None]:
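# stop "Run All" here so the cells below are only run manually
raise RuntimeError("execution stopped here on purpose; run the cells below manually")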

In [None]:
try_load_checkpoint(model, name="checkpoint_last.pt")
stats = validate(model, task, criterion)

In [None]:
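def _extract(args,
    ignore=('datadir', 'savedir', 'source_lang', 'target_lang',
            'num_workers', 'keep_last_epochs', 'resume'),
):
    # keep public, non-ignored fields; wrap each value in a list so that
    # the resulting dict maps to a single-row table
    kv = vars(args)
    return {k: [v] for k, v in kv.items() if not (k.startswith('_') or k in ignore)}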

In [None]:
import pandas as pd

# merge the training config and the architecture arguments into one single-row table
out = _extract(config)
out.update(_extract(arch_args))
pd.DataFrame.from_dict(out)
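
To see what `_extract` produces without the real `config`, the sketch below applies the same filtering to toy `Namespace` objects. Everything in it is made up for illustration: `extract` is a shortened stand-in for `_extract`, and the field names and values of `toy_config` and `toy_arch` are assumptions.

```python
from argparse import Namespace


def extract(args, ignore=("datadir", "savedir")):
    """Toy version of _extract: keep public, non-ignored fields as one-element lists."""
    return {k: [v] for k, v in vars(args).items()
            if not (k.startswith("_") or k in ignore)}


# made-up fields for illustration only
toy_config = Namespace(datadir="./data", savedir="./ckpt", lr_factor=2.0, max_epoch=30, _private=1)
toy_arch = Namespace(encoder_embed_dim=256, encoder_layers=1)

row = extract(toy_config)
row.update(extract(toy_arch))
print(row)
```

Each key becomes a column and each one-element list becomes the single row when the dict is passed to `pd.DataFrame.from_dict`.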