In [1]:
import argparse
import math
import os
import random
import time

import spacy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchtext.data import Field, BucketIterator
from torchtext.datasets import Multi30k

from src.beam import beam_search_decoding, batch_beam_search_decoding
from src.model import EncoderRNN, DecoderRNN, Attention, AttnDecoderRNN, Seq2Seq
from src.train_method import train, evaluate

In [2]:
# utils {{{
def tokenize_de(text):
    """
    Split a German sentence into spaCy token strings, returned in reverse order.

    Args:
        text: raw German sentence.

    Returns:
        list[str]: the token texts, last token first.
    """
    tokens = [token.text for token in spacy_de.tokenizer(text)]
    tokens.reverse()
    return tokens

def tokenize_en(text):
    """
    Split an English sentence into spaCy token strings, in their original order.

    Args:
        text: raw English sentence.

    Returns:
        list[str]: the token texts.
    """
    doc = spacy_en.tokenizer(text)
    return list(piece.text for piece in doc)

def init_weights(m, bound=0.08):
    """
    Re-initialize every parameter of module `m` in place from U(-bound, bound).

    Args:
        m: an nn.Module. `parameters()` recurses into sub-modules, so a single
           call on a top-level model covers all of its parameters.
        bound: half-width of the uniform range (default 0.08).
    """
    for param in m.parameters():
        # nn.init.* already runs under torch.no_grad(), so the parameter can be
        # passed directly instead of going through the legacy `.data` attribute.
        nn.init.uniform_(param, -bound, bound)

def epoch_time(start_time, end_time):
    """
    Convert the interval between two timestamps (in seconds) to whole minutes
    and leftover whole seconds.

    Both parts are truncated toward zero with int(), not floored.

    Returns:
        (minutes, seconds) as a tuple of ints.
    """
    seconds_total = end_time - start_time
    whole_minutes = int(seconds_total / 60)
    leftover = int(seconds_total - 60 * whole_minutes)
    return whole_minutes, leftover

def print_n_best(decoded_seq, itos):
    """
    Print each hypothesis in `decoded_seq` as a space-joined sentence, one per line.

    Args:
        decoded_seq: iterable of hypotheses, best first; each hypothesis is an
            iterable of token indices.
        itos: index-to-string lookup used to turn indices into tokens.
    """
    for rank, seq in enumerate(decoded_seq, start=1):
        sentence = " ".join(itos[idx] for idx in seq)
        print(f'Out: Rank-{rank}: {sentence}')
# }}}
# spaCy pipelines backing tokenize_en / tokenize_de; only their `.tokenizer`
# is used by those helpers.
spacy_en = spacy.load('en_core_web_sm')
spacy_de = spacy.load('de_core_news_sm')

Hyperparameters (超参数)

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SOS_token = '<SOS>'
EOS_token = '<EOS>'
# Global seed so weight initialisation and batch shuffling repeat on Restart & Run All.
SEED = 1234
config_dict = {
    # training
    'batch_size': 512, 'n_epochs': 100,
    # embedding / hidden sizes of the encoder and decoder
    'enc_embd_size': 256, 'dec_embd_size': 256,
    'enc_h_size': 512, 'dec_h_size': 512,
    # beam-search decoding settings used in the evaluation cell
    'beam_width': 10, 'n_best': 5,
    'max_dec_steps': 1000,
    # load_pkl: warm-start from model_path; model_path is also where the
    # best-validation checkpoint is written
    "load_pkl": False,
    'model_name': 's2s', 'model_path': './s2s-vanilla.pt',
    # attention: use AttnDecoderRNN instead of DecoderRNN
    'skip_train': False, 'attention': False,
}

# Seed Python's `random` (presumably what torchtext's BucketIterator shuffling
# draws from -- TODO confirm for the installed torchtext version) and torch
# (torch.manual_seed seeds the CPU and all CUDA devices).
random.seed(SEED)
torch.manual_seed(SEED)

DATA

In [4]:
# Both sides get <SOS>/<EOS> markers and are lower-cased; only the tokenizer
# differs (tokenize_de also reverses the German token order).
field_kwargs = dict(init_token=SOS_token, eos_token=EOS_token, lower=True)
SRC = Field(tokenize=tokenize_de, **field_kwargs)
TRG = Field(tokenize=tokenize_en, **field_kwargs)

train_data, valid_data, test_data = Multi30k.splits(
    root=r'./', exts=('.de', '.en'), fields=(SRC, TRG))
for split_name, split in (('training', train_data),
                          ('validation', valid_data),
                          ('testing', test_data)):
    print(f'Number of {split_name} examples: {len(split.examples)}')

# Vocabularies come from the training split only; min_freq=2 drops tokens
# that occur just once there.
for field in (SRC, TRG):
    field.build_vocab(train_data, min_freq=2)
print(f'Unique tokens in source (de) vocabulary: {len(SRC.vocab)}')
print(f'Unique tokens in target (en) vocabulary: {len(TRG.vocab)}')

train_itr, valid_itr, test_itr = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=config_dict["batch_size"],
    device=DEVICE)
enc_v_size, dec_v_size = len(SRC.vocab), len(TRG.vocab)



Number of training examples: 29000
Number of validation examples: 1014
Number of testing examples: 1000
Unique tokens in source (de) vocabulary: 7853
Unique tokens in target (en) vocabulary: 5893




model

In [5]:
# Build the encoder and the (optionally attention-based) decoder from config_dict.
encoder = EncoderRNN(config_dict["enc_embd_size"], config_dict["enc_h_size"], config_dict["dec_h_size"], enc_v_size)
if config_dict["attention"]:
    attn = Attention(config_dict["enc_h_size"], config_dict["dec_h_size"])
    decoder = AttnDecoderRNN(config_dict["dec_embd_size"], config_dict["enc_h_size"], config_dict["dec_h_size"], dec_v_size, attn)
else:
    decoder = DecoderRNN(config_dict["dec_embd_size"], config_dict["dec_h_size"], dec_v_size)
model = Seq2Seq(encoder, decoder, DEVICE).to(DEVICE)
# Apply the uniform initialisation from the utils cell. init_weights walks
# named_parameters(), which recurses into sub-modules, so a single call on the
# top-level model covers every parameter.
init_weights(model)

Loss function, optimizer, and loading pretrained weights (损失函数，优化器，预训练参数)

In [6]:
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

# Optionally warm-start from a saved state dict. map_location remaps the stored
# tensors onto DEVICE, so a checkpoint saved on GPU also loads on a CPU-only machine.
if config_dict["load_pkl"] and config_dict["model_path"] != '':
    model.load_state_dict(torch.load(config_dict["model_path"], map_location=DEVICE))
optimizer = optim.Adam(model.parameters())
# NOTE(review): MultiStepLR decays the lr when its epoch counter hits a milestone,
# and the counter is already 0 right after construction. Milestone 0 therefore
# appears to cut Adam's default lr to 1/10 before the first epoch. Milestone 100
# is reached only after the 100th scheduler.step(), i.e. after the last epoch
# when n_epochs == 100. Confirm this schedule is intended.
scheduler = MultiStepLR(optimizer, milestones=[0,100], gamma=0.1)
# Padding positions in the target are excluded from the loss.
loss_fn = nn.NLLLoss(ignore_index=TRG_PAD_IDX)
# Log-probabilities for NLLLoss, taken over dim=1 (presumably the vocabulary
# axis of the decoder output -- confirm against train()/evaluate()).
log_softmax = nn.LogSoftmax(dim=1)

Training loop (训练过程)

In [7]:
# Train for config_dict["n_epochs"] epochs (unless skip_train is set). Per-epoch
# losses and wall-clock time go to TensorBoard, and the weights with the lowest
# validation loss are checkpointed to config_dict["model_path"].
best_valid_loss = float('inf')
train_loss = valid_loss = float('nan')  # defined even if no epoch runs
if not config_dict["skip_train"]:
    # `with` closes (and flushes) the event writer even if training is interrupted.
    with SummaryWriter("runs/experiment_1") as writer:
        for epoch in range(config_dict["n_epochs"]):
            start_time = time.time()

            train_loss = train(model, train_itr, optimizer, loss_fn, log_softmax)
            valid_loss = evaluate(model, valid_itr, loss_fn, log_softmax)
            scheduler.step()

            # Named epoch_secs, not epoch_time, so the epoch_time() helper is not shadowed.
            epoch_secs = time.time() - start_time

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), config_dict["model_path"])

            writer.add_scalar('train loss', train_loss, epoch + 1)
            writer.add_scalar('val loss', valid_loss, epoch + 1)
            writer.add_scalar("epoch_time", epoch_secs, epoch + 1)

        # Record the run's hyper-parameters once, with the final epoch's losses.
        # add_hparams writes a new summary on every call, so calling it per epoch
        # only produced duplicates.
        writer.add_hparams(config_dict, {
            "train_loss": train_loss,
            "val_loss": valid_loss,
        }, run_name="experiment_1")

    # Continue (e.g. into evaluation) with the best checkpoint, not the last epoch's weights.
    if best_valid_loss < float('inf'):
        model.load_state_dict(torch.load(config_dict["model_path"], map_location=DEVICE))



Evaluation (beam-search decoding on the test set)

In [None]:
# For each test batch, decode the first sentence with both beam-search
# implementations, printing each one's wall-clock time and n-best outputs.
model.eval()
with torch.no_grad():
    TRG_SOS_IDX = TRG.vocab.stoi[TRG.init_token]
    TRG_EOS_IDX = TRG.vocab.stoi[TRG.eos_token]
    # Arguments shared by both decoders, built once so the two calls cannot drift apart.
    beam_kwargs = dict(beam_width=config_dict["beam_width"],
                       n_best=config_dict["n_best"],
                       sos_token=TRG_SOS_IDX,
                       eos_token=TRG_EOS_IDX,
                       max_dec_steps=config_dict["max_dec_steps"],
                       device=DEVICE)
    for batch in test_itr:
        src = batch.src # (T, bs)
        # NOTE(review): tokenize_de reversed the German tokens, so this prints
        # the source sentence back-to-front.
        print(f'In: {" ".join(SRC.vocab.itos[idx] for idx in src[:, 0])}')

        enc_outs, h = model.encoder(src) # (T, bs, H), (bs, H)
        # decoded_seqs: (bs, T)
        # perf_counter is a monotonic, high-resolution clock intended for interval timing.
        start_time = time.perf_counter()
        decoded_seqs = beam_search_decoding(decoder=model.decoder,
                                            enc_outs=enc_outs,
                                            enc_last_h=h,
                                            **beam_kwargs)
        end_time = time.perf_counter()
        print(f'for loop beam search time: {end_time-start_time:.3f}')
        print_n_best(decoded_seqs[0], TRG.vocab.itos)

        start_time = time.perf_counter()
        decoded_seqs = batch_beam_search_decoding(decoder=model.decoder,
                                                  enc_outs=enc_outs,
                                                  enc_last_h=h,
                                                  **beam_kwargs)
        end_time = time.perf_counter()
        print(f'Batch beam search time: {end_time-start_time:.3f}')
        print_n_best(decoded_seqs[0], TRG.vocab.itos)