In [1]:
#!python -m pip install --upgrade pip
#!python -m pip install torchtext==0.6.0
#!python -m pip install einops

In [3]:
#!python -m pip install spacy

In [4]:
#!python -m spacy download en_core_web_sm
#!python -m spacy download zh_core_web_sm
#!python -m spacy download de_core_news_sm

#### Load dataset from the file

In [5]:
from torchtext.data import Field
from torch.utils.data import Dataset, random_split
import json
from einops import rearrange

max_dataset_size = 440000
train_set_size = 400000
valid_set_size = 40000

max_input_length = 128
max_target_length = 128
EPOCH = 20 

SRC = Field(tokenize = "spacy",
            tokenizer_language="zh_core_web_sm",
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True,
            batch_first=True)

TRG = Field(tokenize = "spacy",
            tokenizer_language="en_core_web_sm",
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True,
            batch_first=True)


zh_words_set, en_words_set = [[]], [[]]

class TRANS(Dataset):
    def __init__(self, data_file):
        self.en_data, self.zh_data = self.load_data(data_file)
    
    def load_data(self, data_file):
        en_data, zh_data = [], []
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f): 
                if idx >= max_dataset_size: #We limit the dataset we use
                    break
                sample = json.loads(line.strip()) #Sample: dict()

                en_data.append(TRG.tokenize(sample['english']))

                zh_data.append(SRC.tokenize(sample['chinese']))

        return en_data, zh_data #A dict(idx): EN, ZH
    
    def __len__(self):
        return len(self.en_data)

    def __getitem__(self, idx):
        return self.en_data[idx], self.zh_data[idx]

data = TRANS('data/translation2019zh_train.json')
train_data, valid_data = random_split(data, [train_set_size, valid_set_size])
test_data = TRANS('data/translation2019zh_valid.json')

SRC.build_vocab(data.zh_data, min_freq=20)
TRG.build_vocab(data.en_data, min_freq=20)

### Demo the vocab results

In [6]:
# Printing a list of tokens mapping integer to strings
print(SRC.vocab.itos[:50])
# Printing a dict mapping tokens to indices
#print(TRG.vocab.stoi)
# Printing the index of an actual word
print(SRC.vocab.stoi['游戏'])

['<unk>', '<pad>', '<sos>', '<eos>', '的', '，', '。', '在', '了', '是', '和', '、', '一', '我', '中', '他', '对', '你', '有', '一个', '“', '不', '”', '我们', '与', '：', '上', '为', '会', '这', '将', '也', '（', '）', '他们', '人', '可以', '说', '就', '个', '都', '被', '到', '而', '研究', '能', '进行', '并', '它', '从']
638


In [9]:
print(len(SRC.vocab))
print(len(TRG.vocab))

22602
21797


In [7]:
# Checking index of special tokens
import torch.nn as nn
PAD_IDX = TRG.vocab.stoi['<pad>']
SOS_IDX = TRG.vocab.stoi['<sos>']
EOS_IDX = TRG.vocab.stoi['<eos>']
UNK_IDX = TRG.vocab.stoi['<unk>']
print('pad index:', PAD_IDX)
print('sos index:', SOS_IDX)
print('eos index:', EOS_IDX)
print('unk index:', UNK_IDX)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

pad index: 1
sos index: 2
eos index: 3
unk index: 0


### Prepare dataloader

In [8]:
import torch
from torch.utils.data import DataLoader
#torch.cuda.empty_cache() #清空缓存


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

def collote_fn(batch_samples):
    en_batch, zh_batch = zip(*batch_samples)
    zh_batch = SRC.process(zh_batch) #Pad, then convert to tensor
    en_batch = TRG.process(en_batch)
    
    #Then we need to transpose: [B, seq] -> [seq, B]
    return en_batch.transpose(0,1), zh_batch.transpose(0,1)

train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=32, shuffle=False, collate_fn=collote_fn)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collote_fn)

Using cuda device


### Consturcting the model 

In [10]:
import torch.nn as nn
import torch.optim as optim
from transformers import get_scheduler
from torch.optim import AdamW
import math

###足够大的值，使其不同次数运行的模型大小相同
INPUT_DIM = 23500
OUTPUT_DIM = 23000

ENC_EMB_DIM = 16
ATTN_DIM = 4
DROPOUT = 0.2

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(INPUT_DIM, ENC_EMB_DIM)
        self.tgt_embedding = nn.Embedding(INPUT_DIM, ENC_EMB_DIM)
        self.transformer = nn.Transformer(nhead=ATTN_DIM, num_encoder_layers=2, d_model=ENC_EMB_DIM, dropout=DROPOUT)
        self.linear = nn.Linear(ENC_EMB_DIM, OUTPUT_DIM)
        pos_dropout = 0.1
        max_seq_length = 128
        self.pos_enc = PositionalEncoding(ENC_EMB_DIM, pos_dropout, max_seq_length)
    
    def forward(self, src, tgt, teacher_forcing_ratio=0.7, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_mask=None):
        # TODO: Investigate masks, positional encoding, understand Rearrange(), debug model output (output has negative numbers for some reason)
        # Original src shape: (sentence length=24?, batch_size=128)
        # Original tgt shape: (sentence length=24?, batch_size=128)
        # Transformer expects: (sentence length=24, batch_size=128, embedding_size=128)

        src_emb = self.pos_enc(self.src_embedding(src) * math.sqrt(ENC_EMB_DIM))
        tgt_emb = self.pos_enc(self.tgt_embedding(tgt) * math.sqrt(ENC_EMB_DIM))
        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        out = self.linear(out)
        return out

model = TransformerModel().to(device)
for p in model.parameters(): #Initialize the parameters
    if p.dim() > 1:
        nn.init.xavier_normal_(p)


optimizer = AdamW(model.parameters(), lr=1e-3) 
scheduler = get_scheduler( 
    "linear",
    optimizer=optimizer,
    num_warmup_steps=len(train_dataloader)//10,
    num_training_steps=EPOCH*len(train_dataloader),
)


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')

  from .autonotebook import tqdm as notebook_tqdm


The model has 1,699,800 trainable parameters


## Training and Evaluation


In [11]:
def gen_nopeek_mask(length):
    mask = rearrange(torch.triu(torch.ones(length, length)) == 1, 'h w -> w h') #[Seq_length-1, Seq_length-1], generate triangle mask
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) #Fill we certain value
    return mask


def indices_to_string(LANGUAGE, batch): #string -> idx
    words_list = []
    for sentence in batch.transpose(1, 0):
        sentence_list = sentence.tolist()
        words = []
        for index in sentence_list:
            word = LANGUAGE.vocab.itos[index]
            words.append(word)
        words_list.append(words)
    return words_list

def string_to_indices(LANGUAGE, sentence): #idx -> string
    if (LANGUAGE==TRG): #English
        words = sentence.split()
    else: #Chinese
        words = LANGUAGE.tokenize(sentence)
    indices = []
    for word in words:
        if word in LANGUAGE.vocab.stoi:
            index = LANGUAGE.vocab.stoi[word]
            indices.append(index)
        else:
            index = LANGUAGE.vocab.stoi['<unk>']
            indices.append(index)
    result = torch.tensor(indices)
    return result

def save_json(predicted_words,tgt_words,src_words): #将结果输出为json
    print('saving predicted results...')
    results = []
    for source, pred, label in zip(src_words, predicted_words, tgt_words):
        results.append({
            "sentence": source, 
            "prediction": pred, 
            "translation": label[0]
        })
    with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
        for exapmle_result in results:
            f.write(json.dumps(exapmle_result, ensure_ascii=False) + '\n')

def inference(model, example_sentence_src):
    # Translate example sentence
    example_tensor_src = string_to_indices(SRC, example_sentence_src).view(-1, 1) #[Sentence_length, Batch]
    example_sentence_tgt = '<sos>' #Start index
    example_tensor_tgt = string_to_indices(TRG, example_sentence_tgt).view(-1, 1) #Also convert to [Sentence_length, Batch]
    src = example_tensor_src.to(device)
    tgt = example_tensor_tgt.to(device)

    for i in range(128): #Max length of generated sentece: 128
        ###Below is similar to the Transforemer model
        src_key_padding_mask = src == PAD_IDX
        tgt_key_padding_mask = tgt == PAD_IDX
        memory_key_padding_mask = src_key_padding_mask.clone()
        src_key_padding_mask = rearrange(src_key_padding_mask, 'n s -> s n')
        tgt_key_padding_mask = rearrange(tgt_key_padding_mask, 'n s -> s n')
        memory_key_padding_mask = rearrange(memory_key_padding_mask, 'n s -> s n')
        tgt_mask = gen_nopeek_mask(tgt.shape[0]).to(device)

        output = model(src, tgt, 0, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_mask=tgt_mask) #turn off teacher forcing

        output_index = torch.argmax(output, dim=2)[-1].item() #得到输出后，在vocab维度上求argmax（找到可能性最大的单词idx)
        if (output_index==4 and example_sentence_tgt[-2]==','): #多个逗号
            break
        output_word = TRG.vocab.itos[output_index] #idx -> string
        example_sentence_tgt = example_sentence_tgt + ' ' + output_word #Concatenate
        example_tensor_tgt = string_to_indices(TRG, example_sentence_tgt).view(-1, 1) #重新转换为#[Sentence_length, Batch]的tensor 作为target
        tgt = example_tensor_tgt.to(device)
        if output_word == '<eos>':
            break
    return example_sentence_tgt


### Defining training and testing functions

In [12]:
import math
from tqdm.auto import tqdm
import evaluate
bleu = evaluate.load("bleu")

train_loss_list, val_loss_list = [],[]

def train(model: nn.Module,
          dataloader: DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0
    
    progress_bar = tqdm(range(len(dataloader))) #进度条
    progress_bar.set_description(f'loss: {0:>7f}')

    for batch_idx, (tgt, src) in enumerate(dataloader):

        src = src.to(device) #ZH
        tgt = tgt.to(device) #EN

        optimizer.zero_grad()

        # Original src shape: (sentence length=24, batch_size=16)
        # Transformer expects: (sentence length=24, batch_size=16, embedding_size=128)
        src_key_padding_mask = src == PAD_IDX
        tgt_key_padding_mask = tgt == PAD_IDX
        memory_key_padding_mask = src_key_padding_mask.clone()
        src_key_padding_mask = rearrange(src_key_padding_mask, 'n s -> s n') #[batch_size, seq_length]
        t = rearrange(tgt_key_padding_mask, 'n s -> s n')
        memory_key_padding_mask = rearrange(memory_key_padding_mask, 'n s -> s n')
        tgt_sentence_len = tgt.shape[0] - torch.sum(tgt_key_padding_mask, axis=1)
        tgt_inp, tgt_out = tgt[:-1, :], tgt[1:, :]
        tgt_key_padding_mask = tgt_key_padding_mask[:, :-1]
        tgt_mask = gen_nopeek_mask(tgt_inp.shape[0]).to(device)

        output = model(src, tgt_inp, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_mask=tgt_mask)

        #from_one_hot = torch.argmax(output, dim=2) 
        output = output.view(-1, output.shape[-1]) #[seq_length*batch_size, vocab_size]
        tgt_out = tgt_out.contiguous().view(-1) #[seq_length*batch_size]

        loss = criterion(output, tgt_out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()

        progress_bar.set_description(f'[TRAIN] loss: {epoch_loss/(batch_idx+1):>7f}')
        progress_bar.update(1) #更新进度条

    train_loss_list.append(epoch_loss / len(dataloader)) #记录当前epoch的loss
    return epoch_loss / len(dataloader)


def test(model: nn.Module,
             dataloader: DataLoader,
             criterion: nn.Module):

    model.eval()
    epoch_loss = 0

    progress_bar = tqdm(range(len(dataloader))) #进度条
    progress_bar.set_description(f'loss: {0:>7f}')

    predicted_words, tgt_words, src_words = [], [], []
    with torch.no_grad():
        for batch_idx, (tgt, src) in enumerate(dataloader):

            src = src.to(device) #ZH
            tgt = tgt.to(device) #EN

            src_key_padding_mask = src == PAD_IDX
            tgt_key_padding_mask = tgt == PAD_IDX
            memory_key_padding_mask = src_key_padding_mask.clone()
            src_key_padding_mask = rearrange(src_key_padding_mask, 'n s -> s n')
            tgt_key_padding_mask = rearrange(tgt_key_padding_mask, 'n s -> s n')
            memory_key_padding_mask = rearrange(memory_key_padding_mask, 'n s -> s n')
            tgt_mask = gen_nopeek_mask(tgt.shape[0]).to(device)
            output = model(src, tgt, 0, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_mask=tgt_mask) #turn off teacher forcing
            #output: [seq_length, batch_size, vocab_size]
            from_one_hot = torch.argmax(output, dim=2) #[seq_length, batch_size]

            output = output[1:].view(-1, output.shape[-1])

            tgt = tgt[1:].contiguous().view(-1)
            loss = criterion(output, tgt)
            epoch_loss += loss.item()

            progress_bar.set_description(f'[VAL] loss: {epoch_loss/(batch_idx+1):>7f}')
            progress_bar.update(1) #更新进度条
    val_loss_list.append(epoch_loss / len(dataloader)) #记录当前epoch的loss
    return epoch_loss / len(dataloader)

=====Epoch 0=====


[TRAIN] loss: 5.809503: 100%|██████████| 12500/12500 [07:58<00:00, 26.14it/s]
[VAL] loss: 7.901631: 100%|██████████| 1250/1250 [00:20<00:00, 61.34it/s]


	Train Loss: 5.810 | Train PPL: 333.453
	 Val. Loss: 7.902 |  Val. PPL: 2701.685
saving new weights...

<sos> The <unk> is a <unk> of the <unk> . <eos>
<sos> You ? <eos>
=====Epoch 1=====


[TRAIN] loss: 5.212811: 100%|██████████| 12500/12500 [08:11<00:00, 25.44it/s]
[VAL] loss: 7.749830: 100%|██████████| 1250/1250 [00:20<00:00, 60.54it/s]


	Train Loss: 5.213 | Train PPL: 183.610
	 Val. Loss: 7.750 |  Val. PPL: 2321.179
saving new weights...

<sos> The <unk> is not not not to be used to be used . <eos>
<sos> <unk> : I ? <eos>
=====Epoch 2=====


[TRAIN] loss: 5.109494: 100%|██████████| 12500/12500 [08:16<00:00, 25.18it/s]
[VAL] loss: 7.755970: 100%|██████████| 1250/1250 [00:20<00:00, 60.43it/s]


	Train Loss: 5.109 | Train PPL: 165.586
	 Val. Loss: 7.756 |  Val. PPL: 2335.473
saving new weights...

<sos> The <unk> is not a good to be used . <eos>
<sos> <unk> : You ? How ? <eos>
=====Epoch 3=====


[TRAIN] loss: 5.046818: 100%|██████████| 12500/12500 [08:15<00:00, 25.21it/s]
[VAL] loss: 7.705946: 100%|██████████| 1250/1250 [00:20<00:00, 60.38it/s]


	Train Loss: 5.047 | Train PPL: 155.527
	 Val. Loss: 7.706 |  Val. PPL: 2221.518
saving new weights...

<sos> <unk> , the <unk> can be not not not not not not not not not not not not . <eos>
<sos> What is you ? <eos>
=====Epoch 4=====


[TRAIN] loss: 5.004351: 100%|██████████| 12500/12500 [08:13<00:00, 25.35it/s]
[VAL] loss: 7.688020: 100%|██████████| 1250/1250 [00:20<00:00, 60.21it/s]


	Train Loss: 5.004 | Train PPL: 149.060
	 Val. Loss: 7.688 |  Val. PPL: 2182.050
saving new weights...

<sos> The <unk> is not easy to be used to be not to be not to be not to be <unk> . <eos>
<sos> What is you ? Are you ? <eos>
=====Epoch 5=====


[TRAIN] loss: 4.974838: 100%|██████████| 12500/12500 [08:17<00:00, 25.10it/s]
[VAL] loss: 7.693235: 100%|██████████| 1250/1250 [00:20<00:00, 60.08it/s]


	Train Loss: 4.975 | Train PPL: 144.725
	 Val. Loss: 7.693 |  Val. PPL: 2193.458
saving new weights...

<sos> So , the <unk> can be used to be used to be used . <eos>
<sos> Yes , I love ? <eos>
=====Epoch 6=====


[TRAIN] loss: 4.952798: 100%|██████████| 12500/12500 [08:13<00:00, 25.33it/s]
[VAL] loss: 7.645921: 100%|██████████| 1250/1250 [00:20<00:00, 60.06it/s]


	Train Loss: 4.953 | Train PPL: 141.571
	 Val. Loss: 7.646 |  Val. PPL: 2092.095
saving new weights...

<sos> The <unk> is not easy to be used to be used . <eos>
<sos> What is I ? <eos>
=====Epoch 7=====


[TRAIN] loss: 4.935265: 100%|██████████| 12500/12500 [08:13<00:00, 25.33it/s]
[VAL] loss: 7.660564: 100%|██████████| 1250/1250 [00:20<00:00, 60.42it/s]


	Train Loss: 4.935 | Train PPL: 139.110
	 Val. Loss: 7.661 |  Val. PPL: 2122.954
saving new weights...

<sos> The <unk> is not easy to be used to be <unk> . <eos>
<sos> Yes ? I ? I ? I ? I ? <eos>
=====Epoch 8=====


[TRAIN] loss: 4.919694: 100%|██████████| 12500/12500 [08:14<00:00, 25.27it/s]
[VAL] loss: 7.690053: 100%|██████████| 1250/1250 [00:20<00:00, 60.10it/s]


	Train Loss: 4.920 | Train PPL: 136.961
	 Val. Loss: 7.690 |  Val. PPL: 2186.489
saving new weights...

<sos> The <unk> is not a <unk> . <eos>
<sos> <unk> : Yes ? <unk> ? <eos>
=====Epoch 9=====


[TRAIN] loss: 4.905723: 100%|██████████| 12500/12500 [08:24<00:00, 24.80it/s]
[VAL] loss: 7.642820: 100%|██████████| 1250/1250 [00:20<00:00, 59.78it/s]


	Train Loss: 4.906 | Train PPL: 135.060
	 Val. Loss: 7.643 |  Val. PPL: 2085.616
saving new weights...

<sos> The <unk> is not a <unk> . <eos>
<sos> <unk> : Yes ? <unk> ? <eos>
=====Epoch 10=====


[TRAIN] loss: 4.893079: 100%|██████████| 12500/12500 [08:15<00:00, 25.25it/s]
[VAL] loss: 7.694813: 100%|██████████| 1250/1250 [00:20<00:00, 60.08it/s]


	Train Loss: 4.893 | Train PPL: 133.364
	 Val. Loss: 7.695 |  Val. PPL: 2196.923
saving new weights...

<sos> The <unk> is not easy to be <unk> . <eos>
<sos> Yes , I love ? Yes ? What ? <eos>
=====Epoch 11=====


[TRAIN] loss: 4.881420: 100%|██████████| 12500/12500 [08:22<00:00, 24.88it/s]
[VAL] loss: 7.733724: 100%|██████████| 1250/1250 [00:20<00:00, 59.92it/s]


	Train Loss: 4.881 | Train PPL: 131.818
	 Val. Loss: 7.734 |  Val. PPL: 2284.092
saving new weights...

<sos> <unk> , the <unk> can be <unk> . <eos>
<sos> Yes , I ? <unk> ? <eos>
=====Epoch 12=====


[TRAIN] loss: 4.843423:   2%|▏         | 215/12500 [00:08<08:04, 25.35it/s]

KeyboardInterrupt: 

### Train and validate

In [None]:

CLIP = 0.5

best_valid_loss = float('inf') #以loss 作为更新的指标

for epoch in range(EPOCH):

    print(f"=====Epoch {epoch}=====")
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
    val_loss = test(model, valid_dataloader, criterion)

    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {val_loss:.3f} |  Val. PPL: {math.exp(val_loss):7.3f}')

    if val_loss<best_valid_loss: #储存更好的模型
        best_valid_loss = val_loss
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{epoch+1}_valid_loss_{best_valid_loss:0.2f}_model_weights.bin')


### Inference

In [None]:
model_file = "epoch_7_valid_loss_7.64_model_weights.bin" #Mutable
model.load_state_dict(torch.load(model_file)) #Loading storage parameter
inference(model, "你好嘛？") #Generate sequence