# install all necessary libraries

In [None]:
!pip install tensorboardX

In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install loguru

In [None]:
pip install -r requirements.txt

In [None]:
!pip install pypinyin

In [None]:
!pip install jieba

In [None]:
pip install tqdm

In [None]:
!pip install torch

In [None]:
!pip install torchvision

# SET UP

In [1]:
import argparse
import os
import sys
import numpy as np
from loguru import logger
import torch
import torch.nn as nn
import torch.nn.functional as F
import operator
from sklearn.model_selection import train_test_split
sys.path.append('../')
sys.path.append('../..')
#from pycorrector.seq2seq.data_reader import *
#from pycorrector.seq2seq.train import *
#from pycorrector.seq2seq.infer import *

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--raw_train_path",
                    default="../pycorrector/data/cn/sighan_2015/train.tsv", type=str,
                    help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
                    )
parser.add_argument("--dataset", default="sighan", type=str,
                    help="Dataset name. selected in the list:" + ", ".join(["sighan", "cged"])
                    )
parser.add_argument("--use_segment", action="store_true", help="Whether not to segment train data")
parser.add_argument("--do_preprocess", action="store_true", default="True",help="Whether not to preprocess train data")
parser.add_argument("--segment_type", default="char", type=str,
                        help="Segment data type, selected in list: " + ", ".join(["char", "word"]))
parser.add_argument("--model_name_or_path",
                    default="bert-base-chinese", type=str,
                    help="Path to pretrained model or model identifier from huggingface.co/models",
                    )
parser.add_argument("--model_dir", default="output/RNA/", type=str, help="Dir for model save.")
parser.add_argument("--arch", default="seq2seq", type=str,
                    help="The name of the task to train selected in the list: " + ", ".join(
                        ['seq2seq', 'convseq2seq', 'bertseq2seq']),
                    )
parser.add_argument("--train_path", default="output/train.txt", type=str, help="Train file after preprocess.")
parser.add_argument("--test_path", default="output/test.txt", type=str, help="Test file after preprocess.")
parser.add_argument("--max_length", default=500, type=int,
                    help="The maximum total input sequence length after tokenization. \n"
                            "Sequences longer than this will be truncated, sequences shorter padded.",
                    )
parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
parser.add_argument("--embed_size", default=128, type=int, help="Embedding size.")
parser.add_argument("--hidden_size", default=128, type=int, help="Hidden size.")
parser.add_argument("--dropout", default=0.25, type=float, help="Dropout rate.")
parser.add_argument("--epochs", default=100, type=int, help="Epoch num.")

args = parser.parse_args([])

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
data_list = []
data_list.extend(get_data_file("../pycorrector/data/RNA/train", args.use_segment, args.segment_type))

NameError: name 'get_data_file' is not defined

# preprocess.py

In [7]:
# get data_list from path
def get_data_file(path, use_segment, segment_type):
    '''
        params: (str,bool, str) -> list(list(str,str))
        pupose: get data_list from path
    '''
    data_list = []
    if not os.path.exists(path):
        print('%s not exists' % path)
        return data_list
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith("#"):
                continue
            parts = line.split("\t")
            if len(parts) != 2:
                continue
            source = ' '.join(segment(parts[0].strip(), cut_type=segment_type)) if use_segment else parts[0].strip()
            target = ' '.join(segment(parts[1].strip(), cut_type=segment_type)) if use_segment else parts[1].strip()

            pair = [source, target]
            if pair not in data_list:
                data_list.append(pair)
    return data_list


def _save_data(data_list, data_path):
    '''
        params: (list(list(str)), str) -> empty
        purpose: save data_list to data_path
    '''
    dirname = os.path.dirname(data_path)
    os.makedirs(dirname, exist_ok=True)
    with open(data_path, 'w', encoding='utf-8') as f:
        count = 0
        for src, dst in data_list:
            f.write(src + '\t' + dst + '\n')
            count += 1
        print("save line size:%d to %s" % (count, data_path))


def save_corpus_data(data_list, train_data_path, test_data_path):
    '''
        params: (list(list(str)), str, str) -> empty
        purpose: split the data_list to train and test and then save to train_data_path and test_data_path
    '''
    train_lst, test_lst = train_test_split(data_list, test_size=0.1)
    _save_data(train_lst, train_data_path)
    _save_data(test_lst, test_data_path)


# data_reader.py

In [5]:
# Define constants associated with the usual special tokens.
SOS_TOKEN = '<sos>' # represent the start of a sequence
EOS_TOKEN = '<eos>' # represent the end of a sequence
UNK_TOKEN = '<unk>' # unknown token -  is used to replace the rare words that did not fit in your vocabulary.
PAD_TOKEN = '<pad>' # all the sentence in a batch should have the same length, it will be used to pad the sequence to fit the length
class CscDataset(object):
    '''
        purpose: a dataset class that load the data from a json file and return the data_list
    '''
    def __init__(self, file_path):
        self.data = json.load(open(file_path, 'r', encoding='utf-8'))

    def load(self):
        '''
            params: empty -> list(str)
        '''
        data_list = []
        for item in self.data:
            data_list.append(item['original_text'] + '\t' + item['correct_text'])
        return data_list

def create_dataset(path, num_examples=None, split_on_space=False):
    '''
        params: str, int, bool -> list(list(list(str),list(str)))
        purpose: generate datalist from path to list of tokens
    '''
    if path.endswith('.json'):
        d = CscDataset(path)
        lines = d.load()
    else:
        lines = open(path, 'r', encoding='utf-8').read().strip().split('\n')
    word_pairs = [[preprocess_sentence(s, split_on_space) for s in l.split('\t')] for l in lines[:num_examples]]
    return zip(*word_pairs)


def preprocess_sentence(sentence, split_on_space=False):
    '''
        purpose: give a sentence, slipt the string into tokens 
    '''
    # adding a start and an end token to the sentence
    # so that the model know when to start and stop predicting.
    return [SOS_TOKEN] + (sentence.split() if split_on_space else list(sentence)) + [EOS_TOKEN]


In [6]:
from collections import Counter
'''
class Counter(object):
    def __init__(self):
        self.counter = {}
    def update(self, token):
'''    


def save_word_dict(dict_data, save_path):
    '''
        purpose: save the dict_data to save_path
    '''
    with open(save_path, 'w', encoding='utf-8') as f:
        for k, v in dict_data.items():
            f.write("%s\t%d\n" % (k, v))


def load_word_dict(save_path):
    '''
        purpose: load the word_dict from save_path and return
    '''
    dict_data = dict()
    num = 0
    with open(save_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip('\n')
            items = line.split('\t')
            num += 1
            try:
                dict_data[items[0]] = int(items[1])
            except IndexError:
                logger.error('IndexError, index:%s, line:%s' % (num, line))
    return dict_data
def read_vocab(input_texts, max_size=None, min_count=0):
    '''
        purpose: count the number of each vocabs and return the vocab2id dict
    '''
    token_counts = Counter()
    special_tokens = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]

    for texts in input_texts:
        #for token in texts:
        token_counts.update(texts)
    del token_counts[SOS_TOKEN]
    del token_counts[EOS_TOKEN]
    # Sort word count by value
    count_pairs = token_counts.most_common()
    vocab = [k for k, v in count_pairs if v >= min_count]
    # Insert the special tokens to the beginning
    vocab[0:0] = special_tokens
    full_token_id = list(zip(vocab, range(len(vocab))))[:max_size]
    vocab2id = dict(full_token_id)
    return vocab2id

In [9]:
src_vocab_path = os.path.join(args.model_dir, 'vocab_source.txt')

source_texts, target_texts = create_dataset(args.train_path, split_on_space=True)
src_2_ids = read_vocab(source_texts)
save_word_dict(src_2_ids, src_vocab_path)

src_2_ids = load_word_dict(src_vocab_path)
src_2_ids

ValueError: not enough values to unpack (expected 2, got 1)

In [8]:
def prepare_data(seqs, max_length=None):
    if max_length:
        seqs = [seq[:max_length] for seq in seqs]
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)

    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths  # x_mask
def get_minibatches(n, minibatch_size, shuffle=True):
    idx_list = np.arange(0, n, minibatch_size)  # [0, 1, ..., n-1]
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches
def gen_examples(src_sentences, trg_sentences, batch_size, max_length=None):
    minibatches = get_minibatches(len(src_sentences), batch_size)
    examples = []
    for minibatch in minibatches:
        mb_src_sentences = [src_sentences[t] for t in minibatch]
        mb_trg_sentences = [trg_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_src_sentences, max_length)
        mb_y, mb_y_len = prepare_data(mb_trg_sentences, max_length)
        examples.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return examples
def one_hot(src_sentences, trg_sentences, src_dict, trg_dict, sort_by_len=True):
    """vector the sequences.
    """

    out_src_sentences = [[src_dict.get(w, 0) for w in sent] for sent in src_sentences]
    out_trg_sentences = [[trg_dict.get(w, 0) for w in sent] for sent in trg_sentences]

    # sort sentences by english lengths
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # sort length
    if sort_by_len:
        sorted_index = len_argsort(out_src_sentences)
        out_src_sentences = [out_src_sentences[i] for i in sorted_index]
        out_trg_sentences = [out_trg_sentences[i] for i in sorted_index]

    return out_src_sentences, out_trg_sentences

# seq2se2.py

In [9]:
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: 
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        '''
            torch.nn.Embedding(numembeddings, embeddingdim)
                * numembeddings代表一共有多少个词
                * embedding_dim代表每个词创建一个多少维的向量来表示他
        '''
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        '''
            torch.nn.GRU(input_size, hidden_size, num_layers, bias,batch_first,dropout,bidirectional)
                * input_size: the number of expected features in the input x
                * hidden_size: the number of features in the hidden state h
                * batch_first: if True, then (batch, seq, feature), else (seq, batch, feature)
                * bidirectional: if True, becomes a bidirectional GRU. Default: False
            
        '''
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        # 将x根据长度来排序
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(),
                                                            batch_first=True)
        '''
            https://zhuanlan.zhihu.com/p/34418001
            torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False,enforce_sorted=True)
            当我们进行batch个训练数据一起计算的时候，我们会遇到多个训练样例长度不同时的情况，这样我们就会很自然的进行padding，
            将短句子padding为跟最长的句子一样
            
            pytorch中RNN处理变长padding主要用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行。
            
            
            
        '''
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        hid = torch.cat([hid[-2], hid[-1]], dim=1)
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)

        return out, hid


class Attention(nn.Module):
    """
    Luong Attention,根据context vectors和当前的输出hidden states，计算输出
    """

    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()

        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size

        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        # output: batch_size, output_len, dec_hidden_size
        # context: batch_size, context_len, 2*enc_hidden_size

        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(
            batch_size, input_len, -1)  # batch_size, context_len, dec_hidden_size

        # context_in.transpose(1,2): batch_size, dec_hidden_size, context_len
        # output: batch_size, output_len, dec_hidden_size
        attn = torch.bmm(output, context_in.transpose(1, 2))
        # batch_size, output_len, context_len

        attn.data.masked_fill(mask, -1e6)

        attn = F.softmax(attn, dim=2)
        # batch_size, output_len, context_len

        context = torch.bmm(attn, context)
        # batch_size, output_len, enc_hidden_size

        output = torch.cat((context, output), dim=2)  # batch_size, output_len, hidden_size*2

        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn


class Decoder(nn.Module):
    """
    decoder会根据已经翻译的句子内容，和context vectors，来决定下一个输出的单词
    """

    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):
        # a mask of shape x_len * y_len
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=x_len.device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=x_len.device)[None, :] < y_len[:, None]
        mask = ~ x_mask[:, :, None] * y_mask[:, None, :]
        return mask

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_length, embed_size

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        mask = self.create_mask(y_lengths, ctx_lengths)

        output, attn = self.attention(output_seq, ctx, mask)
        output = F.log_softmax(self.out(output), -1)

        return output, hid, attn


class Seq2Seq(nn.Module):
    """
    Seq2Seq, 最后我们构建Seq2Seq模型把encoder, attention, decoder串到一起
    """

    def __init__(self,
                 encoder_vocab_size,
                 decoder_vocab_size,
                 embed_size,
                 enc_hidden_size,
                 dec_hidden_size,
                 dropout,
                 ):
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder(vocab_size=encoder_vocab_size,
                               embed_size=embed_size,
                               enc_hidden_size=enc_hidden_size,
                               dec_hidden_size=dec_hidden_size,
                               dropout=dropout)
        self.decoder = Decoder(vocab_size=decoder_vocab_size,  # len(trg_2_ids),
                               embed_size=embed_size,
                               enc_hidden_size=enc_hidden_size,
                               dec_hidden_size=dec_hidden_size,
                               dropout=dropout)

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out,
                                         ctx_lengths=x_lengths,
                                         y=y,
                                         y_lengths=y_lengths,
                                         hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=128):
        print(len(x))
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(ctx=encoder_out,
                                             ctx_lengths=x_lengths,
                                             y=y,
                                             y_lengths=torch.ones(batch_size).long().to(y.device),
                                             hid=hid)
            
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)


class LanguageModelCriterion(nn.Module):
    """
    masked cross entropy loss
    """

    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # input: (batch_size * seq_len) * vocab_size
        input = input.contiguous().view(-1, input.size(2))
        # target: batch_size * 1
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output


# infer.py

In [33]:
unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤', '\t', '֍', '玕', '', '《', '》']


def get_errors(corrected_text, origin_text):
    
    corrected_lst = corrected_text.split()
    origin_lst = origin_text.split()
    sub_details = []
    for i, ori_token in enumerate(origin_lst):
        if i >= len(corrected_lst):
            continue
        if ori_token != corrected_lst[i]:
            sub_details.append((ori_token, corrected_lst[i],i,i+1))
    return corrected_text, sub_details
    
    '''
    print(corrected_text)
    print(origin_text)
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if i >= len(corrected_text):
            continue
        if ori_char in unk_tokens:
            # deal with unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if ori_char != corrected_text[i]:
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details
    '''

class Inference(object):
    def __init__(self, model_dir, arch='convseq2seq',
                 embed_size=128, hidden_size=128, dropout=0.25, max_length=500):
        logger.debug("Device: {}".format(device))
        logger.debug(f'Use {arch} model.')
        if arch in ['seq2seq', 'convseq2seq']:
            src_vocab_path = os.path.join(model_dir, 'vocab_source.txt')
            trg_vocab_path = os.path.join(model_dir, 'vocab_target.txt')
            self.src_2_ids = load_word_dict(src_vocab_path)
            self.trg_2_ids = load_word_dict(trg_vocab_path)
            self.id_2_trgs = {v: k for k, v in self.trg_2_ids.items()}
            if arch == 'seq2seq':
                self.model = Seq2Seq(encoder_vocab_size=len(self.src_2_ids),
                                     decoder_vocab_size=len(self.trg_2_ids),
                                     embed_size=embed_size,
                                     enc_hidden_size=hidden_size,
                                     dec_hidden_size=hidden_size,
                                     dropout=dropout).to(device)
                model_path = os.path.join(model_dir, 'seq2seq.pth')
                logger.debug('Load model from {}'.format(model_path))
                self.model.load_state_dict(torch.load(model_path, map_location=device))
                self.model.eval()
            else:
                trg_pad_idx = self.trg_2_ids[PAD_TOKEN]
                self.model = ConvSeq2Seq(encoder_vocab_size=len(self.src_2_ids),
                                         decoder_vocab_size=len(self.trg_2_ids),
                                         embed_size=embed_size,
                                         enc_hidden_size=hidden_size,
                                         dec_hidden_size=hidden_size,
                                         dropout=dropout,
                                         trg_pad_idx=trg_pad_idx,
                                         device=device,
                                         max_length=max_length).to(device)
                model_path = os.path.join(model_dir, 'convseq2seq.pth')
                self.model.load_state_dict(torch.load(model_path, map_location=device))
                logger.debug('Load model from {}'.format(model_path))
                self.model.eval()
        elif arch == 'bertseq2seq':
            # Bert Seq2seq model
            use_cuda = True if torch.cuda.is_available() else False

            # encoder_type=None, encoder_name=None, decoder_name=None
            self.model = Seq2SeqModel("bert", "{}/encoder".format(model_dir),
                                      "{}/decoder".format(model_dir), use_cuda=use_cuda)
        else:
            logger.error('error arch: {}'.format(arch))
            raise ValueError("Model arch choose error. Must use one of seq2seq model.")
        self.arch = arch
        self.max_length = max_length

    def predict(self, sentence_list):
        result = []
        if self.arch in ['seq2seq', 'convseq2seq']:
            for query in sentence_list:
                out = []
                tokens = query.split()
                tokens = [SOS_TOKEN] + tokens + [EOS_TOKEN]
                src_ids = [self.src_2_ids[i] for i in tokens if i in self.src_2_ids]

                sos_idx = self.trg_2_ids[SOS_TOKEN]
                if self.arch == 'seq2seq':
                    src_tensor = torch.from_numpy(np.array(src_ids).reshape(1, -1)).long().to(device)
                    src_tensor_len = torch.from_numpy(np.array([len(src_ids)])).long().to(device)
                    sos_tensor = torch.Tensor([[self.trg_2_ids[SOS_TOKEN]]]).long().to(device)
                    translation, attn = self.model.translate(src_tensor, src_tensor_len, sos_tensor, self.max_length)
                    translation = [self.id_2_trgs[i] for i in translation.data.cpu().numpy().reshape(-1) if
                                   i in self.id_2_trgs]
                else:
                    src_tensor = torch.from_numpy(np.array(src_ids).reshape(1, -1)).long().to(device)
                    translation, attn = self.model.translate(src_tensor, sos_idx)
                    translation = [self.id_2_trgs[i] for i in translation if i in self.id_2_trgs]
                for word in translation:
                    if word != EOS_TOKEN:
                        out.append(word)
                    else:
                        break
                corrected_text = ' '.join(out) # 已修改
                corrected_text, sub_details = get_errors(corrected_text, query)
                result.append([corrected_text, sub_details])
        else:
            corrected_sents = self.model.predict(sentence_list)
            corrected_sents = [i.replace(' ', '') for i in corrected_sents]
            for c, s in zip(corrected_sents, sentence_list):
                c = c.replace(' ', '')
                c, sub_details = get_errors(c, s)
                result.append([c, sub_details])
        return result


# train.py

In [29]:
def evaluate_seq2seq_model(model, data, device, loss_fn):
    model.eval()
    total_num_words = 0.
    total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    loss = total_loss / total_num_words
    return loss


def train_seq2seq_model(model, train_data, device, loss_fn, optimizer, model_dir, epochs=20):
    best_loss = 1e3
    train_data, dev_data = train_test_split(train_data, test_size=0.1, shuffle=True)

    for epoch in range(epochs):
        model.train()
        total_num_words = 0.
        total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(train_data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update optimizer
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                logger.info("Epoch :{}/{}, iteration :{}/{} loss:{:.4f}".format(epoch, epochs, it, len(train_data),
                                                                                loss.item()))
        cur_loss = total_loss / total_num_words
        logger.info("Epoch :{}/{}, training loss:{:.4f}".format(epoch, epochs, cur_loss))
        if epoch % 1 == 0:
            if dev_data:
                eval_loss = evaluate_seq2seq_model(model, dev_data, device, loss_fn)
                logger.info('Epoch:{}, dev loss:{:.4f}'.format(epoch, eval_loss))
                cur_loss = eval_loss
            # find best model
            is_best = cur_loss < best_loss
            best_loss = min(cur_loss, best_loss)
            if is_best:
                model_path = os.path.join(model_dir, 'seq2seq.pth')
                torch.save(model.state_dict(), model_path)
                logger.info('Epoch:{}, save new bert model:{}'.format(epoch, model_path))


In [30]:
def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs,
          model_dir, max_length, use_segment, model_name_or_path):
    logger.info("device: {}".format(device))
    arch = arch.lower()
    logger.debug(f'use {arch} model.')
    os.makedirs(model_dir, exist_ok=True)
    if arch in ['seq2seq', 'convseq2seq']:
        src_vocab_path = os.path.join(model_dir, 'vocab_source.txt')
        trg_vocab_path = os.path.join(model_dir, 'vocab_target.txt')

        source_texts, target_texts = create_dataset(train_path, split_on_space=True)
        logger.debug("source_texts:",source_texts)
        src_2_ids = read_vocab(source_texts)
        trg_2_ids = read_vocab(target_texts)
        save_word_dict(src_2_ids, src_vocab_path)
        save_word_dict(trg_2_ids, trg_vocab_path)

        src_2_ids = load_word_dict(src_vocab_path)
        trg_2_ids = load_word_dict(trg_vocab_path)
        id_2_srcs = {v: k for k, v in src_2_ids.items()}
        id_2_trgs = {v: k for k, v in trg_2_ids.items()}
        train_src, train_trg = one_hot(source_texts, target_texts, src_2_ids, trg_2_ids, sort_by_len=True)

        logger.debug(f'src: {[id_2_srcs[i] for i in train_src[0]]}')
        logger.debug(f'trg: {[id_2_trgs[i] for i in train_trg[0]]}')

        train_data = gen_examples(train_src, train_trg, batch_size, max_length)

        if arch == 'seq2seq':
            # Normal seq2seq
            model = Seq2Seq(encoder_vocab_size=len(src_2_ids),
                            decoder_vocab_size=len(trg_2_ids),
                            embed_size=embed_size,
                            enc_hidden_size=hidden_size,
                            dec_hidden_size=hidden_size,
                            dropout=dropout).to(device)
            logger.info(model)
            loss_fn = LanguageModelCriterion().to(device)
            optimizer = torch.optim.Adam(model.parameters())

            train_seq2seq_model(model, train_data, device, loss_fn, optimizer, model_dir, epochs=epochs)
        else:
            # Conv seq2seq model
            trg_pad_idx = trg_2_ids[PAD_TOKEN]
            model = ConvSeq2Seq(encoder_vocab_size=len(src_2_ids),
                                decoder_vocab_size=len(trg_2_ids),
                                embed_size=embed_size,
                                enc_hidden_size=hidden_size,
                                dec_hidden_size=hidden_size,
                                dropout=dropout,
                                trg_pad_idx=trg_pad_idx,
                                device=device,
                                max_length=max_length).to(device)
            logger.info(model)
            loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
            optimizer = torch.optim.Adam(model.parameters())

            train_convseq2seq_model(model, train_data, device, loss_fn, optimizer, model_dir, epochs=epochs)
    elif arch == 'bertseq2seq':
        # Bert Seq2seq model
        model_args = {
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": max_length if max_length else 128,
            "train_batch_size": batch_size if batch_size else 8,
            "num_train_epochs": epochs if epochs else 10,
            "save_eval_checkpoints": False,
            "save_model_every_epoch": False,
            "silent": False,
            "evaluate_generated_text": True,
            "evaluate_during_training": True,
            "evaluate_during_training_verbose": True,
            "best_model_dir": os.path.join(model_dir, 'best_model'),
            "use_multiprocessing": False,
            "save_best_model": True,
            "max_length": max_length if max_length else 128,  # The maximum length of the sequence to be generated.
            "output_dir": model_dir if model_dir else "./output/bertseq2seq/",
        }

        use_cuda = True if torch.cuda.is_available() else False
        # encoder_type=None, encoder_name=None, decoder_name=None
        # encoder_name="bert-base-chinese"
        model = Seq2SeqModel("bert", model_name_or_path, model_name_or_path, args=model_args, use_cuda=use_cuda)

        logger.info('start train bertseq2seq ...')
        data = load_bert_data(train_path, use_segment)
        logger.info(f'load data done, data size: {len(data)}')
        logger.debug(f'data samples: {data[:10]}')
        train_data, dev_data = train_test_split(data, test_size=0.1, shuffle=False)

        train_df = pd.DataFrame(train_data, columns=['input_text', 'target_text'])
        dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])

        def count_matches(labels, preds):
            logger.debug(f"labels: {labels[:10]}")
            logger.debug(f"preds: {preds[:10]}")
            match = sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])
            logger.debug(f"match: {match}")
            return match

        model.train_model(train_df, eval_data=dev_df, matches=count_matches)
    else:
        logger.error('error arch: {}'.format(arch))
        raise ValueError("Model arch choose error. Must use one of seq2seq model.")

In [36]:
if args.do_preprocess:
    # Preprocess
    data_list = []
    os.makedirs(args.model_dir, exist_ok=True)
    '''
    if args.dataset == 'sighan':
        data_list.extend(get_data_file(args.raw_train_path, args.use_segment, args.segment_type))
    else:
        data_list.extend(parse_xml_file(args.raw_train_path, args.use_segment, args.segment_type))
    '''
    # get the data list
    data_list.extend(get_data_file("../pycorrector/data/RNA/train", args.use_segment, args.segment_type))
    if data_list:
        # save the datalist to a file
        save_corpus_data(data_list, args.train_path, args.test_path)
# Train model with train data file
train(args.arch, 
        args.train_path, # specify which model architecture to use
        args.batch_size,
        args.embed_size,
        args.hidden_size,
        args.dropout,
        args.epochs,
        args.model_dir,
        args.max_length,
        args.use_segment,
        args.model_name_or_path,
        )

2022-09-19 14:47:55.988 | INFO     | __main__:train:3 - device: cuda
2022-09-19 14:47:55.989 | DEBUG    | __main__:train:5 - use seq2seq model.


save line size:41487 to output/train.txt
save line size:4610 to output/test.txt


2022-09-19 14:47:57.928 | DEBUG    | __main__:train:12 - source_texts:
2022-09-19 14:48:01.305 | DEBUG    | __main__:train:24 - src: ['<sos>', 'ASN', 'GLN', '<eos>']
2022-09-19 14:48:01.305 | DEBUG    | __main__:train:25 - trg: ['<sos>', 'ASN', 'GLN', '<eos>']
2022-09-19 14:48:02.077 | INFO     | __main__:train:37 - Seq2Seq(
  (encoder): Encoder(
    (embed): Embedding(24, 128)
    (rnn): GRU(128, 128, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.25, inplace=False)
    (fc): Linear(in_features=256, out_features=128, bias=True)
  )
  (decoder): Decoder(
    (embed): Embedding(24, 128)
    (attention): Attention(
      (linear_in): Linear(in_features=256, out_features=128, bias=False)
      (linear_out): Linear(in_features=384, out_features=128, bias=True)
    )
    (rnn): GRU(128, 128, batch_first=True)
    (out): Linear(in_features=128, out_features=24, bias=True)
    (dropout): Dropout(p=0.25, inplace=False)
  )
)
2022-09-19 14:48:02.135 | INFO     | __main__:train

2022-09-19 14:50:37.400 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :400/1167 loss:0.0256
2022-09-19 14:50:40.313 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :500/1167 loss:0.0089
2022-09-19 14:50:43.572 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :600/1167 loss:0.0044
2022-09-19 14:50:46.703 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :700/1167 loss:0.0548
2022-09-19 14:50:49.761 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :800/1167 loss:0.0120
2022-09-19 14:50:52.648 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :900/1167 loss:0.0374
2022-09-19 14:50:55.401 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :1000/1167 loss:0.0109
2022-09-19 14:50:58.254 | INFO     | __main__:train_seq2seq_model:62 - Epoch :4/100, iteration :1100/1167 loss:0.0131
2022-09-19 14:51:00.226 | INFO     | __main__:train_seq2seq_mo

2022-09-19 14:53:40.045 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :500/1167 loss:0.0071
2022-09-19 14:53:43.309 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :600/1167 loss:0.0037
2022-09-19 14:53:46.451 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :700/1167 loss:0.0180
2022-09-19 14:53:49.506 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :800/1167 loss:0.0078
2022-09-19 14:53:52.402 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :900/1167 loss:0.0278
2022-09-19 14:53:55.153 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :1000/1167 loss:0.0075
2022-09-19 14:53:58.014 | INFO     | __main__:train_seq2seq_model:62 - Epoch :9/100, iteration :1100/1167 loss:0.0068
2022-09-19 14:53:59.986 | INFO     | __main__:train_seq2seq_model:65 - Epoch :9/100, training loss:0.0096
2022-09-19 14:54:01.391 | INFO     | __main__:train_seq2seq_model:69 - Ep

2022-09-19 14:56:39.858 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :500/1167 loss:0.0059
2022-09-19 14:56:43.120 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :600/1167 loss:0.0057
2022-09-19 14:56:46.266 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :700/1167 loss:0.0134
2022-09-19 14:56:49.328 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :800/1167 loss:0.0045
2022-09-19 14:56:52.234 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :900/1167 loss:0.0281
2022-09-19 14:56:54.993 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :1000/1167 loss:0.0092
2022-09-19 14:56:57.860 | INFO     | __main__:train_seq2seq_model:62 - Epoch :14/100, iteration :1100/1167 loss:0.0064
2022-09-19 14:56:59.835 | INFO     | __main__:train_seq2seq_model:65 - Epoch :14/100, training loss:0.0089
2022-09-19 14:57:01.251 | INFO     | __main__:train_seq2seq_model

2022-09-19 14:59:36.918 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :400/1167 loss:0.0158
2022-09-19 14:59:39.808 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :500/1167 loss:0.0069
2022-09-19 14:59:43.075 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :600/1167 loss:0.0075
2022-09-19 14:59:46.213 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :700/1167 loss:0.0141
2022-09-19 14:59:49.268 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :800/1167 loss:0.0085
2022-09-19 14:59:52.167 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :900/1167 loss:0.0255
2022-09-19 14:59:54.918 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :1000/1167 loss:0.0036
2022-09-19 14:59:57.779 | INFO     | __main__:train_seq2seq_model:62 - Epoch :19/100, iteration :1100/1167 loss:0.0059
2022-09-19 14:59:59.750 | INFO     | __main__:train_se

2022-09-19 15:02:39.558 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :500/1167 loss:0.0037
2022-09-19 15:02:42.815 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :600/1167 loss:0.0032
2022-09-19 15:02:45.953 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :700/1167 loss:0.0182
2022-09-19 15:02:49.006 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :800/1167 loss:0.0041
2022-09-19 15:02:51.909 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :900/1167 loss:0.0269
2022-09-19 15:02:54.659 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :1000/1167 loss:0.0167
2022-09-19 15:02:57.510 | INFO     | __main__:train_seq2seq_model:62 - Epoch :24/100, iteration :1100/1167 loss:0.0060
2022-09-19 15:02:59.483 | INFO     | __main__:train_seq2seq_model:65 - Epoch :24/100, training loss:0.0077
2022-09-19 15:03:00.884 | INFO     | __main__:train_seq2seq_model

2022-09-19 15:05:39.294 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :500/1167 loss:0.0046
2022-09-19 15:05:42.556 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :600/1167 loss:0.0035
2022-09-19 15:05:45.691 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :700/1167 loss:0.0242
2022-09-19 15:05:48.741 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :800/1167 loss:0.0080
2022-09-19 15:05:51.641 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :900/1167 loss:0.0315
2022-09-19 15:05:54.400 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :1000/1167 loss:0.0044
2022-09-19 15:05:57.247 | INFO     | __main__:train_seq2seq_model:62 - Epoch :29/100, iteration :1100/1167 loss:0.0071
2022-09-19 15:05:59.217 | INFO     | __main__:train_seq2seq_model:65 - Epoch :29/100, training loss:0.0082
2022-09-19 15:06:00.619 | INFO     | __main__:train_seq2seq_model

2022-09-19 15:08:42.312 | INFO     | __main__:train_seq2seq_model:62 - Epoch :34/100, iteration :600/1167 loss:0.0009
2022-09-19 15:08:45.447 | INFO     | __main__:train_seq2seq_model:62 - Epoch :34/100, iteration :700/1167 loss:0.0174
2022-09-19 15:08:48.505 | INFO     | __main__:train_seq2seq_model:62 - Epoch :34/100, iteration :800/1167 loss:0.0029
2022-09-19 15:08:51.394 | INFO     | __main__:train_seq2seq_model:62 - Epoch :34/100, iteration :900/1167 loss:0.0249
2022-09-19 15:08:54.151 | INFO     | __main__:train_seq2seq_model:62 - Epoch :34/100, iteration :1000/1167 loss:0.0044
2022-09-19 15:08:56.993 | INFO     | __main__:train_seq2seq_model:62 - Epoch :34/100, iteration :1100/1167 loss:0.0058
2022-09-19 15:08:58.966 | INFO     | __main__:train_seq2seq_model:65 - Epoch :34/100, training loss:0.0076
2022-09-19 15:09:00.419 | INFO     | __main__:train_seq2seq_model:69 - Epoch:34, dev loss:0.0072
2022-09-19 15:09:00.433 | INFO     | __main__:train_seq2seq_model:62 - Epoch :35/100, 

2022-09-19 15:11:42.063 | INFO     | __main__:train_seq2seq_model:62 - Epoch :39/100, iteration :600/1167 loss:0.0030
2022-09-19 15:11:45.186 | INFO     | __main__:train_seq2seq_model:62 - Epoch :39/100, iteration :700/1167 loss:0.0194
2022-09-19 15:11:48.251 | INFO     | __main__:train_seq2seq_model:62 - Epoch :39/100, iteration :800/1167 loss:0.0025
2022-09-19 15:11:51.140 | INFO     | __main__:train_seq2seq_model:62 - Epoch :39/100, iteration :900/1167 loss:0.0267
2022-09-19 15:11:53.899 | INFO     | __main__:train_seq2seq_model:62 - Epoch :39/100, iteration :1000/1167 loss:0.0040
2022-09-19 15:11:56.749 | INFO     | __main__:train_seq2seq_model:62 - Epoch :39/100, iteration :1100/1167 loss:0.0049
2022-09-19 15:11:58.724 | INFO     | __main__:train_seq2seq_model:65 - Epoch :39/100, training loss:0.0072
2022-09-19 15:12:00.126 | INFO     | __main__:train_seq2seq_model:69 - Epoch:39, dev loss:0.0072
2022-09-19 15:12:00.140 | INFO     | __main__:train_seq2seq_model:62 - Epoch :40/100, 

2022-09-19 15:14:44.996 | INFO     | __main__:train_seq2seq_model:62 - Epoch :44/100, iteration :700/1167 loss:0.0179
2022-09-19 15:14:48.058 | INFO     | __main__:train_seq2seq_model:62 - Epoch :44/100, iteration :800/1167 loss:0.0029
2022-09-19 15:14:50.940 | INFO     | __main__:train_seq2seq_model:62 - Epoch :44/100, iteration :900/1167 loss:0.0263
2022-09-19 15:14:53.699 | INFO     | __main__:train_seq2seq_model:62 - Epoch :44/100, iteration :1000/1167 loss:0.0056
2022-09-19 15:14:56.567 | INFO     | __main__:train_seq2seq_model:62 - Epoch :44/100, iteration :1100/1167 loss:0.0051
2022-09-19 15:14:58.545 | INFO     | __main__:train_seq2seq_model:65 - Epoch :44/100, training loss:0.0077
2022-09-19 15:14:59.942 | INFO     | __main__:train_seq2seq_model:69 - Epoch:44, dev loss:0.0073
2022-09-19 15:14:59.956 | INFO     | __main__:train_seq2seq_model:62 - Epoch :45/100, iteration :0/1167 loss:0.0483
2022-09-19 15:15:02.669 | INFO     | __main__:train_seq2seq_model:62 - Epoch :45/100, it

2022-09-19 15:17:49.326 | INFO     | __main__:train_seq2seq_model:62 - Epoch :49/100, iteration :800/1167 loss:0.0052
2022-09-19 15:17:52.234 | INFO     | __main__:train_seq2seq_model:62 - Epoch :49/100, iteration :900/1167 loss:0.0247
2022-09-19 15:17:54.993 | INFO     | __main__:train_seq2seq_model:62 - Epoch :49/100, iteration :1000/1167 loss:0.0058
2022-09-19 15:17:57.860 | INFO     | __main__:train_seq2seq_model:62 - Epoch :49/100, iteration :1100/1167 loss:0.0051
2022-09-19 15:17:59.833 | INFO     | __main__:train_seq2seq_model:65 - Epoch :49/100, training loss:0.0077
2022-09-19 15:18:01.229 | INFO     | __main__:train_seq2seq_model:69 - Epoch:49, dev loss:0.0089
2022-09-19 15:18:01.243 | INFO     | __main__:train_seq2seq_model:62 - Epoch :50/100, iteration :0/1167 loss:0.0433
2022-09-19 15:18:03.964 | INFO     | __main__:train_seq2seq_model:62 - Epoch :50/100, iteration :100/1167 loss:0.0077
2022-09-19 15:18:07.192 | INFO     | __main__:train_seq2seq_model:62 - Epoch :50/100, it

2022-09-19 15:20:52.389 | INFO     | __main__:train_seq2seq_model:62 - Epoch :54/100, iteration :900/1167 loss:0.0248
2022-09-19 15:20:55.150 | INFO     | __main__:train_seq2seq_model:62 - Epoch :54/100, iteration :1000/1167 loss:0.0067
2022-09-19 15:20:58.011 | INFO     | __main__:train_seq2seq_model:62 - Epoch :54/100, iteration :1100/1167 loss:0.0057
2022-09-19 15:20:59.990 | INFO     | __main__:train_seq2seq_model:65 - Epoch :54/100, training loss:0.0075
2022-09-19 15:21:01.427 | INFO     | __main__:train_seq2seq_model:69 - Epoch:54, dev loss:0.0080
2022-09-19 15:21:01.440 | INFO     | __main__:train_seq2seq_model:62 - Epoch :55/100, iteration :0/1167 loss:0.0381
2022-09-19 15:21:04.164 | INFO     | __main__:train_seq2seq_model:62 - Epoch :55/100, iteration :100/1167 loss:0.0039
2022-09-19 15:21:07.380 | INFO     | __main__:train_seq2seq_model:62 - Epoch :55/100, iteration :200/1167 loss:0.0044
2022-09-19 15:21:10.314 | INFO     | __main__:train_seq2seq_model:62 - Epoch :55/100, it

2022-09-19 15:23:55.336 | INFO     | __main__:train_seq2seq_model:62 - Epoch :59/100, iteration :1000/1167 loss:0.0032
2022-09-19 15:23:58.192 | INFO     | __main__:train_seq2seq_model:62 - Epoch :59/100, iteration :1100/1167 loss:0.0048
2022-09-19 15:24:00.174 | INFO     | __main__:train_seq2seq_model:65 - Epoch :59/100, training loss:0.0069
2022-09-19 15:24:01.576 | INFO     | __main__:train_seq2seq_model:69 - Epoch:59, dev loss:0.0079
2022-09-19 15:24:01.590 | INFO     | __main__:train_seq2seq_model:62 - Epoch :60/100, iteration :0/1167 loss:0.0341
2022-09-19 15:24:04.311 | INFO     | __main__:train_seq2seq_model:62 - Epoch :60/100, iteration :100/1167 loss:0.0039
2022-09-19 15:24:07.530 | INFO     | __main__:train_seq2seq_model:62 - Epoch :60/100, iteration :200/1167 loss:0.0032
2022-09-19 15:24:10.452 | INFO     | __main__:train_seq2seq_model:62 - Epoch :60/100, iteration :300/1167 loss:0.0023
2022-09-19 15:24:13.308 | INFO     | __main__:train_seq2seq_model:62 - Epoch :60/100, it

2022-09-19 15:26:58.200 | INFO     | __main__:train_seq2seq_model:62 - Epoch :64/100, iteration :1100/1167 loss:0.0057
2022-09-19 15:27:00.173 | INFO     | __main__:train_seq2seq_model:65 - Epoch :64/100, training loss:0.0068
2022-09-19 15:27:01.574 | INFO     | __main__:train_seq2seq_model:69 - Epoch:64, dev loss:0.0077
2022-09-19 15:27:01.588 | INFO     | __main__:train_seq2seq_model:62 - Epoch :65/100, iteration :0/1167 loss:0.0323
2022-09-19 15:27:04.304 | INFO     | __main__:train_seq2seq_model:62 - Epoch :65/100, iteration :100/1167 loss:0.0037
2022-09-19 15:27:07.522 | INFO     | __main__:train_seq2seq_model:62 - Epoch :65/100, iteration :200/1167 loss:0.0080
2022-09-19 15:27:10.448 | INFO     | __main__:train_seq2seq_model:62 - Epoch :65/100, iteration :300/1167 loss:0.0023
2022-09-19 15:27:13.307 | INFO     | __main__:train_seq2seq_model:62 - Epoch :65/100, iteration :400/1167 loss:0.0155
2022-09-19 15:27:16.208 | INFO     | __main__:train_seq2seq_model:62 - Epoch :65/100, ite

2022-09-19 15:30:00.255 | INFO     | __main__:train_seq2seq_model:65 - Epoch :69/100, training loss:0.0067
2022-09-19 15:30:01.659 | INFO     | __main__:train_seq2seq_model:69 - Epoch:69, dev loss:0.0075
2022-09-19 15:30:01.673 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iteration :0/1167 loss:0.0326
2022-09-19 15:30:04.387 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iteration :100/1167 loss:0.0069
2022-09-19 15:30:07.604 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iteration :200/1167 loss:0.0056
2022-09-19 15:30:10.538 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iteration :300/1167 loss:0.0020
2022-09-19 15:30:13.401 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iteration :400/1167 loss:0.0112
2022-09-19 15:30:16.298 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iteration :500/1167 loss:0.0062
2022-09-19 15:30:19.575 | INFO     | __main__:train_seq2seq_model:62 - Epoch :70/100, iter

2022-09-19 15:33:01.674 | INFO     | __main__:train_seq2seq_model:69 - Epoch:74, dev loss:0.0080
2022-09-19 15:33:01.688 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :0/1167 loss:0.0299
2022-09-19 15:33:04.401 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :100/1167 loss:0.0041
2022-09-19 15:33:07.630 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :200/1167 loss:0.0032
2022-09-19 15:33:10.553 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :300/1167 loss:0.0026
2022-09-19 15:33:13.420 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :400/1167 loss:0.0103
2022-09-19 15:33:16.325 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :500/1167 loss:0.0031
2022-09-19 15:33:19.605 | INFO     | __main__:train_seq2seq_model:62 - Epoch :75/100, iteration :600/1167 loss:0.0019
2022-09-19 15:33:22.742 | INFO     | __main__:train_seq2seq_model:62 - Epoch :7

2022-09-19 15:36:01.885 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :0/1167 loss:0.0300
2022-09-19 15:36:04.602 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :100/1167 loss:0.0041
2022-09-19 15:36:07.829 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :200/1167 loss:0.0077
2022-09-19 15:36:10.750 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :300/1167 loss:0.0043
2022-09-19 15:36:13.619 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :400/1167 loss:0.0102
2022-09-19 15:36:16.510 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :500/1167 loss:0.0034
2022-09-19 15:36:19.779 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :600/1167 loss:0.0010
2022-09-19 15:36:22.918 | INFO     | __main__:train_seq2seq_model:62 - Epoch :80/100, iteration :700/1167 loss:0.0168
2022-09-19 15:36:25.993 | INFO     | __main__:train_seq2se

2022-09-19 15:39:04.742 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :100/1167 loss:0.0040
2022-09-19 15:39:07.973 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :200/1167 loss:0.0068
2022-09-19 15:39:10.890 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :300/1167 loss:0.0032
2022-09-19 15:39:13.760 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :400/1167 loss:0.0099
2022-09-19 15:39:16.643 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :500/1167 loss:0.0025
2022-09-19 15:39:19.923 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :600/1167 loss:0.0025
2022-09-19 15:39:23.061 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :700/1167 loss:0.0167
2022-09-19 15:39:26.134 | INFO     | __main__:train_seq2seq_model:62 - Epoch :85/100, iteration :800/1167 loss:0.0049
2022-09-19 15:39:29.033 | INFO     | __main__:train_seq2

2022-09-19 15:42:08.198 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :200/1167 loss:0.0089
2022-09-19 15:42:11.123 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :300/1167 loss:0.0025
2022-09-19 15:42:13.988 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :400/1167 loss:0.0094
2022-09-19 15:42:16.875 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :500/1167 loss:0.0062
2022-09-19 15:42:20.151 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :600/1167 loss:0.0024
2022-09-19 15:42:23.296 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :700/1167 loss:0.0182
2022-09-19 15:42:26.374 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :800/1167 loss:0.0029
2022-09-19 15:42:29.267 | INFO     | __main__:train_seq2seq_model:62 - Epoch :90/100, iteration :900/1167 loss:0.0228
2022-09-19 15:42:32.032 | INFO     | __main__:train_seq2

2022-09-19 15:45:11.369 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :300/1167 loss:0.0093
2022-09-19 15:45:14.239 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :400/1167 loss:0.0095
2022-09-19 15:45:17.125 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :500/1167 loss:0.0037
2022-09-19 15:45:20.399 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :600/1167 loss:0.0055
2022-09-19 15:45:23.528 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :700/1167 loss:0.0164
2022-09-19 15:45:26.586 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :800/1167 loss:0.0042
2022-09-19 15:45:29.490 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :900/1167 loss:0.0216
2022-09-19 15:45:32.259 | INFO     | __main__:train_seq2seq_model:62 - Epoch :95/100, iteration :1000/1167 loss:0.0057
2022-09-19 15:45:35.113 | INFO     | __main__:train_seq

In [None]:
'''
data = get_data_file("../pycorrector/data/RNA/train", args.use_segment, args.segment_type)
res = []
for i in range(len(data)-1,len(data)-100,-1):
    a,b = data[i]
    if a != b:
        res.append(i)
res
'''

In [None]:
 [data[i][1] for i in res[:2]]

In [20]:
eg5 = "MET ASN LYS SER VAL ALA PRO LEU LEU LEU ALA ALA SER ILE LEU TYR GLY GLY ALA ALA ALA GLN GLN THR VAL TRP GLY GLN CYS GLY GLY ILE GLY TRP SER GLY PRO THR ASN CYS ALA PRO GLY SER ALA CYS SER THR LEU ASN PRO TYR TYR ALA GLN CYS ILE PRO GLY ALA THR THR ILE THR THR SER THR ARG PRO PRO SER GLY PRO THR THR THR THR ARG ALA THR SER THR SER SER SER THR PRO PRO THR SER SER GLY VAL ARG PHE ALA GLY VAL ASN ILE ALA GLY PHE ASP PHE GLY CYS THR THR ASP GLY THR CYS VAL THR SER LYS VAL TYR PRO PRO LEU LYS ASN PHE THR GLY SER ASN ASN TYR PRO ASP GLY ILE GLY GLN MET GLN HIS PHE VAL ASN ASP ASP GLY MET THR ILE PHE ARG LEU PRO VAL GLY TRP GLN TYR LEU VAL ASN ASN ASN LEU GLY GLY ASN LEU ASP SER THR SER ILE SER LYS TYR ASP GLN LEU VAL GLN GLY CYS LEU SER LEU GLY ALA TYR CYS ILE VAL ASP ILE HIS ASN TYR ALA ARG TRP ASN GLY GLY ILE ILE GLY GLN GLY GLY PRO THR ASN ALA GLN PHE THR SER LEU TRP SER GLN LEU ALA SER LYS TYR ALA SER GLN SER ARG VAL TRP PHE GLY ILE MET ASN GLU PRO HIS ASP VAL ASN ILE ASN THR TRP ALA ALA THR VAL GLN GLU VAL VAL THR ALA ILE ARG ASN ALA GLY ALA THR SER GLN PHE ILE SER LEU PRO GLY ASN ASP TRP GLN SER ALA GLY ALA PHE ILE SER ASP GLY SER ALA ALA ALA LEU SER GLN VAL THR ASN PRO ASP GLY SER THR THR ASN LEU ILE PHE ASP VAL HIS LYS TYR LEU ASP SER ASP ASN SER GLY THR HIS ALA GLU CYS THR THR ASN ASN ILE ASP GLY ALA PHE SER PRO LEU ALA THR TRP LEU ARG GLN ASN ASN ARG GLN ALA ILE LEU THR GLU THR GLY GLY GLY ASN VAL GLN SER CYS ILE GLN ASP MET CYS GLN GLN ILE GLN TYR LEU ASN GLN ASN SER ASP VAL TYR LEU GLY TYR VAL GLY TRP GLY ALA GLY SER PHE ASP SER THR TYR VAL LEU THR GLU THR PRO THR GLY SER GLY ASN SER TRP THR ASP THR SER LEU VAL SER SER CYS LEU ALA ARG LYS GLY"
eg7 = "MET ALA PRO SER VAL THR LEU PRO LEU THR THR ALA ILE LEU ALA ILE ALA ARG LEU VAL ALA ALA GLN GLN PRO GLY THR SER THR PRO GLU VAL HIS PRO LYS LEU THR THR TYR LYS CYS THR LYS SER GLY GLY CYS VAL ALA GLN ASP THR SER VAL VAL LEU ASP TRP ASN TYR ARG TRP MET HIS ASP ALA ASN TYR ASN SER CYS THR VAL ASN GLY GLY VAL ASN THR THR LEU CYS PRO ASP GLU ALA THR CYS GLY LYS ASN CYS PHE ILE GLU GLY VAL ASP TYR ALA ALA SER GLY VAL THR THR SER GLY SER SER LEU THR MET ASN GLN TYR MET PRO SER SER SER GLY GLY TYR SER SER VAL SER PRO ARG LEU TYR LEU LEU ASP SER ASP GLY GLU TYR VAL MET LEU LYS LEU ASN GLY GLN GLU LEU SER PHE ASP VAL ASP LEU SER ALA LEU PRO CYS GLY GLU ASN GLY SER LEU TYR LEU SER GLN MET ASP GLU ASN GLY GLY ALA ASN GLN TYR ASN THR ALA GLY ALA ASN TYR GLY SER GLY TYR CYS ASP ALA GLN CYS PRO VAL GLN THR TRP ARG ASN GLY THR LEU ASN THR SER HIS GLN GLY PHE CYS CYS ASN GLU MET ASP ILE LEU GLU GLY ASN SER ARG ALA ASN ALA LEU THR PRO HIS SER CYS THR ALA THR ALA CYS ASP SER ALA GLY CYS GLY PHE ASN PRO TYR GLY SER GLY TYR LYS SER TYR TYR GLY PRO GLY ASP THR VAL ASP THR SER LYS THR PHE THR ILE ILE THR GLN PHE ASN THR ASP ASN GLY SER PRO SER GLY ASN LEU VAL SER ILE THR ARG LYS TYR GLN GLN ASN GLY VAL ASP ILE PRO SER ALA GLN PRO GLY GLY ASP THR ILE SER SER CYS PRO SER ALA SER ALA TYR GLY GLY LEU ALA THR MET GLY LYS ALA LEU SER SER GLY MET VAL LEU VAL PHE SER ILE TRP ASN ASP ASN SER GLN TYR MET ASN TRP LEU ASP SER GLY ASN ALA GLY PRO CYS SER SER THR GLU GLY ASN PRO SER ASN ILE LEU ALA ASN ASN PRO ASN THR HIS VAL VAL PHE SER ASN ILE ARG TRP GLY ASP ILE GLY SER THR THR ASN SER THR ALA PRO PRO PRO PRO PRO ALA SER SER THR THR PHE SER THR THR ARG ARG SER SER THR THR SER SER SER PRO SER CYS THR GLN THR HIS TRP GLY GLN CYS GLY GLY ILE GLY TYR SER GLY CYS LYS THR CYS THR SER GLY THR THR CYS GLN TYR SER ASN ASP TYR TYR SER GLN CYS LEU"

In [37]:
m = Inference(args.model_dir,
                  args.arch,
                  embed_size=args.embed_size,
                  hidden_size=args.hidden_size,
                  dropout=args.dropout,
                  max_length=args.max_length
                  )

inputs = [eg5, eg7]

#inputs = [data[i][0] for i in res[:2]]
outputs = m.predict(inputs)

for a, b in zip(inputs, outputs):
    print('input  :', a)
    print('predict:', b[0], b[1])
    print()



2022-09-19 17:53:34.112 | DEBUG    | __main__:__init__:36 - Device: cuda
2022-09-19 17:53:34.113 | DEBUG    | __main__:__init__:37 - Use seq2seq model.
2022-09-19 17:53:34.122 | DEBUG    | __main__:__init__:52 - Load model from output/RNA/seq2seq.pth


1
1
input  : MET ASN LYS SER VAL ALA PRO LEU LEU LEU ALA ALA SER ILE LEU TYR GLY GLY ALA ALA ALA GLN GLN THR VAL TRP GLY GLN CYS GLY GLY ILE GLY TRP SER GLY PRO THR ASN CYS ALA PRO GLY SER ALA CYS SER THR LEU ASN PRO TYR TYR ALA GLN CYS ILE PRO GLY ALA THR THR ILE THR THR SER THR ARG PRO PRO SER GLY PRO THR THR THR THR ARG ALA THR SER THR SER SER SER THR PRO PRO THR SER SER GLY VAL ARG PHE ALA GLY VAL ASN ILE ALA GLY PHE ASP PHE GLY CYS THR THR ASP GLY THR CYS VAL THR SER LYS VAL TYR PRO PRO LEU LYS ASN PHE THR GLY SER ASN ASN TYR PRO ASP GLY ILE GLY GLN MET GLN HIS PHE VAL ASN ASP ASP GLY MET THR ILE PHE ARG LEU PRO VAL GLY TRP GLN TYR LEU VAL ASN ASN ASN LEU GLY GLY ASN LEU ASP SER THR SER ILE SER LYS TYR ASP GLN LEU VAL GLN GLY CYS LEU SER LEU GLY ALA TYR CYS ILE VAL ASP ILE HIS ASN TYR ALA ARG TRP ASN GLY GLY ILE ILE GLY GLN GLY GLY PRO THR ASN ALA GLN PHE THR SER LEU TRP SER GLN LEU ALA SER LYS TYR ALA SER GLN SER ARG VAL TRP PHE GLY ILE MET ASN GLU PRO HIS ASP VAL ASN ILE ASN THR