https://github.com/SenticNet/context2vec

In [1]:
import argparse
import os
import sys

sys.path.append("..")
from pycorrector.deepcontext.infer import Inference
from pycorrector.deepcontext.preprocess import parse_xml_file, save_corpus_data, get_data_file

In [2]:
# Command-line style configuration for the notebook. Every option has a
# default, so the notebook runs without a real command line.
parser = argparse.ArgumentParser()

# Data and pipeline options

parser.add_argument("--raw_train_path",
                    default="../pycorrector/data/cn/sighan_2015/train.tsv", type=str,
                    help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
                    )

parser.add_argument("--dataset", default="sighan", type=str,
                    help="Dataset name. selected in the list:" + ", ".join(["sighan", "cged"])
                    )
# Boolean switches are plain ``store_true`` flags (default False). Combining
# ``store_true`` with ``default=True`` would make them impossible to turn off,
# so the switches this notebook needs are passed explicitly to parse_args below.
parser.add_argument("--no_segment", action="store_true", help="Whether not to segment train data in preprocess")
parser.add_argument("--do_train", action="store_true", help="Whether to train")
parser.add_argument("--do_predict", action="store_true", help="Whether to predict")
parser.add_argument("--segment_type", default="char", type=str,
                    help="Segment data type, selected in list: " + ", ".join(["char", "word"]))
parser.add_argument("--model_dir", default="output/models/", type=str, help="Dir for model save.")
parser.add_argument("--train_path", default="output/train.txt", type=str, help="Train file after preprocess.")
parser.add_argument("--vocab_path", default="output/vocab.txt", type=str, help="Vocab file for train data.")

# Other parameters
parser.add_argument("--batch_size", default=8, type=int, help="Batch size.")
parser.add_argument("--embed_size", default=128, type=int, help="Embedding size.")
parser.add_argument("--hidden_size", default=128, type=int, help="Hidden size.")
parser.add_argument("--learning_rate", default=1e-3, type=float, help="Learning rate.")
parser.add_argument("--n_layers", default=2, type=int, help="Num layers.")
parser.add_argument("--min_freq", default=1, type=int, help="Mini word frequency.")
parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
parser.add_argument("--epochs", default=20, type=int, help="Epoch num.")

# A notebook has no real argv, so enable the switches for this run explicitly.
# The resulting namespace has no_segment, do_train and do_predict set to True.
args = parser.parse_args(["--no_segment", "--do_train", "--do_predict"])

In [3]:
import os
import sys

from xml.dom import minidom

sys.path.append('../..')
from pycorrector.utils.tokenizer import segment
from pycorrector.deepcontext import config


def parse_xml_file(path, use_segment, segment_type):
    """Collect the corrected sentence of every ``DOC`` element in an XML file.

    Args:
        path: XML file to parse with ``minidom``.
        use_segment: When true, the sentence is tokenised with ``segment`` and
            the tokens are joined with single spaces; otherwise the stripped
            text is kept unchanged.
        segment_type: Forwarded to ``segment`` as ``cut_type``.

    Returns:
        A list with one string per ``DOC`` element, in document order.
    """
    print('Parse data from %s' % path)
    root = minidom.parse(path).documentElement
    sentences = []
    for doc_node in root.getElementsByTagName('DOC'):
        # Only the first CORRECTION element and its first child node are read.
        correction_node = doc_node.getElementsByTagName('CORRECTION')[0]
        corrected = correction_node.childNodes[0].data.strip()
        if use_segment:
            corrected = ' '.join(segment(corrected, cut_type=segment_type))
        sentences.append(corrected)
    return sentences


def get_data_file(path, use_segment, segment_type):
    """Read the second column of a two-column, tab-separated text file.

    Lines are stripped first; lines starting with ``#`` and lines that do not
    split into exactly two tab-separated fields are ignored.

    Args:
        path: UTF-8 text file to read.
        use_segment: When true, the second field is tokenised with ``segment``
            and the tokens are joined with single spaces; otherwise the
            stripped field is kept unchanged.
        segment_type: Forwarded to ``segment`` as ``cut_type``.

    Returns:
        A list of the retained second-column strings, in file order.
    """
    targets = []
    with open(path, 'r', encoding='utf-8') as reader:
        for raw_line in reader:
            record = raw_line.strip()
            fields = record.split("\t")
            # Keep only non-comment rows made of exactly two columns.
            if record.startswith("#") or len(fields) != 2:
                continue
            corrected = fields[1].strip()
            if use_segment:
                corrected = ' '.join(segment(corrected, cut_type=segment_type))
            targets.append(corrected)
    return targets


def save_corpus_data(data_list, data_path):
    """Write every entry of ``data_list`` to ``data_path``, one entry per line.

    The file is written as UTF-8 and overwritten if it already exists. The
    parent directory is created when it is missing.

    Args:
        data_list: Iterable of strings; a newline is appended to each.
        data_path: Destination file. May be a bare file name, in which case
            the file is written to the current working directory.
    """
    dirname = os.path.dirname(data_path)
    # os.makedirs('') raises FileNotFoundError, so only create a parent
    # directory when the path actually has one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    count = 0
    with open(data_path, 'w', encoding='utf-8') as f:
        for line in data_list:
            f.write(line + '\n')
            count += 1
    print("save line size:%d to %s" % (count, data_path))

In [13]:
# Load the target column of the RNA training file as-is (no segmentation).
args.no_segment, args.segment_type = True, "char"
args.use_segment = not args.no_segment
data = get_data_file(
    "../pycorrector/data/RNA/train",
    args.use_segment,
    args.segment_type,
)

In [14]:
# Sanity check: show the first entry returned by get_data_file.
print(data[0])

ASP ALA ILE ALA ASP ALA SER LYS ARG PHE SER ASP ALA THR TYR PRO ILE ALA GLU LYS PHE ASP TRP GLY GLY SER SER ALA ILE ALA LYS TYR ILE ALA ASP ALA SER ALA GLY ASN PRO ARG GLN ALA ALA LEU ALA VAL GLU LYS LEU LEU GLU VAL GLY LEU THR MET ASP PRO LYS LEU VAL ARG ALA ALA VAL GLU ALA HIS SER LYS ALA LEU ASP SER ALA LYS LYS ASN ALA LYS LEU MET ALA SER LYS GLU ASP PHE ALA ALA VAL ASN GLU ALA LEU ALA ARG MET ILE ALA SER ALA ASP LYS GLN LYS PHE ALA ALA LEU ARG THR ALA PHE PRO GLU SER ARG GLU LEU GLN GLY LYS LEU PHE ALA GLY ASN ASN ALA PHE GLU ALA GLU LYS ALA TYR ASP SER PHE LYS ALA LEU THR SER ALA VAL ARG ASP ALA SER ILE ASN GLY ALA LYS ALA PRO VAL ILE ALA GLU ALA ALA ARG ALA GLU ARG TYR VAL GLY ASP GLY PRO VAL GLY ARG ALA ALA LYS LYS PHE SER GLU ALA THR TYR PRO ILE MET ASP LYS LEU ASP TRP GLY LYS SER PRO GLU ILE SER LYS TYR ILE GLU THR ALA SER ALA LYS ASN PRO LYS MET MET ALA ASP GLY ILE ASP LYS THR LEU GLU VAL ALA LEU THR MET ASN GLN ASN ALA ILE ASN ASP ALA VAL PHE ALA HIS VAL ARG ALA ILE LYS GLY 

In [15]:
import time

import numpy as np
import torch
from torch import optim
from pycorrector.deepcontext import config
from pycorrector.deepcontext.data_reader import write_config
from pycorrector.deepcontext.model import Context2vec
from pycorrector.deepcontext.dataset import Dataset

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(train_path,
          model_dir,
          vocab_path,
          batch_size=64,
          epochs=3,
          word_embed_size=200,
          hidden_size=200,
          learning_rate=0.0001,
          n_layers=1,
          min_freq=1,
          dropout=0.0):
    """Train a Context2vec model and checkpoint it whenever the loss improves.

    The model configuration is written to ``<model_dir>/config.json`` before
    training. After each epoch the mean per-word loss of the most recent
    reporting window is compared with the best value so far; on improvement
    the model and optimizer state dicts are saved to ``model.pth`` and
    ``model_optimizer.pth`` inside ``model_dir``.

    Args:
        train_path: Preprocessed training file; must exist.
        model_dir: Output directory for the config and checkpoints. Created
            when missing.
        vocab_path: Vocabulary file path handed to ``Dataset``.
        batch_size: Batch size handed to ``Dataset``.
        epochs: Number of passes over ``dataset.train_data``.
        word_embed_size: Embedding size handed to ``Context2vec``.
        hidden_size: Hidden size handed to ``Context2vec``.
        learning_rate: Learning rate of the Adam optimizer.
        n_layers: Layer count handed to ``Context2vec``.
        min_freq: Minimum frequency handed to ``Dataset``.
        dropout: Dropout rate handed to ``Context2vec``.

    Raises:
        FileNotFoundError: If ``train_path`` is not an existing file.
    """
    # Fail fast with a useful message before any expensive setup happens.
    if not os.path.isfile(train_path):
        raise FileNotFoundError("train file not found: {}".format(train_path))
    # config.json and the checkpoints are written here, so make sure it exists
    # (os.makedirs('') would raise, hence the guard).
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    print("device: {}".format(device))

    print('Loading input file')
    dataset = Dataset(train_path,
                      batch_size,
                      min_freq,
                      device,
                      vocab_path)
    # Per-word frequencies in vocabulary order.
    # NOTE(review): presumably used by the model's negative sampling - confirm in Context2vec.
    counter = np.array([dataset.word_freqs[word] for word in dataset.vocab_2_ids])
    model = Context2vec(vocab_size=len(dataset.vocab_2_ids),
                        counter=counter,
                        word_embed_size=word_embed_size,
                        hidden_size=hidden_size,
                        n_layers=n_layers,
                        use_mlp=True,
                        dropout=dropout,
                        pad_index=dataset.pad_index,
                        device=device,
                        is_inference=False).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print('batch_size:', batch_size, 'epochs:', epochs, 'word_embed_size:', word_embed_size, 'hidden_size:',
          hidden_size, 'device:', device)
    print('model:', model)

    # save model config
    output_config_file = os.path.join(model_dir, 'config.json')
    write_config(output_config_file,
                 vocab_size=len(dataset.vocab_2_ids),
                 word_embed_size=word_embed_size,
                 hidden_size=hidden_size,
                 n_layers=n_layers,
                 use_mlp=True,
                 dropout=dropout,
                 pad_index=dataset.pad_index,
                 pad_token=dataset.pad_token,
                 unk_token=dataset.unk_token,
                 sos_token=dataset.sos_token,
                 eos_token=dataset.eos_token,
                 learning_rate=learning_rate
                 )

    interval = 1e5  # number of words between two progress reports
    # Start from +inf so the first measured loss is always a candidate,
    # however large it is.
    best_loss = float("inf")
    print("train start...")
    for epoch in range(epochs):
        cur_at = time.time()
        # Accumulate as a Python float: no float32 rounding on large sums and
        # no dependency on a tensor being present when the epoch is empty.
        total_loss = 0.0
        word_count = 0
        next_count = interval
        last_accum_loss = 0.0
        last_word_count = 0
        cur_loss = None  # stays None when the epoch yields no trainable words
        for mb_x, _mb_x_len in dataset.train_data:
            sentence = torch.from_numpy(mb_x).to(device).long()

            # Drop the first and last position of every sequence to form the targets.
            target = sentence[:, 1:-1]
            if target.size(0) == 0:
                continue
            optimizer.zero_grad()
            loss = model(sentence, target)
            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach().mean())

            minibatch_size, sentence_length = target.size()
            word_count += minibatch_size * sentence_length
            window_words = word_count - last_word_count
            if window_words <= 0:
                # No target words since the last report: nothing to average yet.
                continue
            accum_mean_loss = total_loss / word_count if total_loss > 0.0 else 0.0
            cur_mean_loss = (total_loss - last_accum_loss) / window_words
            cur_loss = cur_mean_loss
            if word_count >= next_count:
                now = time.time()
                duration = now - cur_at
                throughput = float(window_words) / duration if duration > 0 else float("inf")
                print('{} words, {:.2f} sec, {:.2f} words/sec, {:.4f} accum_loss/word, {:.4f} cur_loss/word'
                      .format(word_count, duration, throughput, accum_mean_loss, cur_mean_loss))
                next_count += interval
                cur_at = now
                last_accum_loss = total_loss
                last_word_count = word_count

        # find best model
        is_best = cur_loss is not None and cur_loss < best_loss
        if is_best:
            best_loss = cur_loss
        print('epoch:[{}/{}], total_loss:[{}], best_cur_loss:[{}]'
              .format(epoch + 1, epochs, total_loss, best_loss))
        if is_best:
            torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
            torch.save(optimizer.state_dict(), os.path.join(model_dir, 'model_optimizer.pth'))
            print('epoch:{}, save new context2vec model:{}'.format(epoch + 1, model_dir))

In [23]:
def main(train_source="../pycorrector/data/RNA/train", max_train_lines=1000):
    """Preprocess the corpus, train the model and run a sample prediction.

    Which stages run is controlled by the notebook-level ``args`` namespace:
    ``args.do_train`` gates preprocessing plus training, ``args.do_predict``
    gates the sample prediction.

    Args:
        train_source: Tab-separated corpus file read with ``get_data_file``.
        max_train_lines: Number of leading entries kept for training;
            ``None`` keeps every entry.
    """
    os.makedirs(args.model_dir, exist_ok=True)

    # Train
    if args.do_train:
        # Preprocess
        args.use_segment = not args.no_segment
        # NOTE(review): args.dataset and args.raw_train_path are not consulted
        # here; the corpus is always read from train_source with get_data_file.
        data_list = get_data_file(train_source, args.use_segment, args.segment_type)[:max_train_lines]
        save_corpus_data(data_list, args.train_path)

        # Train model with train data file
        train(args.train_path,
              args.model_dir,
              args.vocab_path,
              batch_size=args.batch_size,
              epochs=args.epochs,
              word_embed_size=args.embed_size,
              hidden_size=args.hidden_size,
              learning_rate=args.learning_rate,
              n_layers=args.n_layers,
              min_freq=args.min_freq,
              dropout=args.dropout
              )

    # Predict
    if args.do_predict:
        inference = Inference(args.model_dir, args.vocab_path)
        inputs = [
            'ASP ALA ILE ALA ASP ALA SER LYS ARG PHE SER ASP ALA THR TYR PRO ILE ALA GLU LYS PHE ASP TRP GLY GLY SER SER ALA ILE ALA LYS TYR ILE ALA ASP ALA SER ALA GLY ASN PRO ARG GLN ALA ALA LEU ALA VAL GLU LYS LEU LEU GLU VAL GLY LEU THR MET ASP PRO LYS LEU VAL ARG ALA ALA VAL GLU ALA HIS SER LYS ALA LEU ASP SER ALA LYS LYS ASN ALA LYS LEU MET ALA SER LYS GLU ASP PHE ALA ALA VAL ASN GLU ALA LEU ALA ARG MET ILE ALA SER ALA ASP LYS GLN LYS PHE ALA ALA LEU ARG THR ALA PHE PRO GLU SER ARG GLU LEU GLN GLY LYS LEU PHE ALA GLY ASN ASN ALA PHE GLU ALA GLU LYS ALA TYR ASP SER PHE LYS ALA LEU THR SER ALA VAL ARG ASP ALA SER ILE ASN GLY ALA LYS ALA PRO VAL ILE ALA GLU ALA ALA ARG ALA GLU ARG TYR VAL GLY ASP GLY PRO VAL GLY ARG ALA ALA LYS LYS PHE SER GLU ALA THR TYR PRO ILE MET ASP LYS LEU ASP TRP GLY LYS SER PRO GLU ILE SER LYS TYR ILE GLU THR ALA SER ALA LYS ASN PRO LYS MET MET ALA ASP GLY ILE ASP LYS THR LEU GLU VAL ALA LEU THR MET ASN GLN ASN ALA ILE ASN ASP ALA VAL PHE ALA HIS VAL ARG ALA ILE LYS GLY ALA LEU ASN THR PRO GLY LEU VAL ALA GLU ARG ASP ASP PHE ALA ARG VAL ASN LEU ALA LEU ALA LYS MET ILE ALA THR ALA ASP PRO ALA LYS PHE LYS ALA LEU LEU THR ALA PHE PRO GLY ASN ALA ASP LEU GLN MET ALA LEU PHE ALA ALA ASN ASN PRO GLU GLN ALA LYS ALA ALA TYR GLU THR PHE VAL ALA LEU THR SER ALA VAL ALA SER SER THR'
        ]
        for sequence in inputs:
            output = inference.predict(sequence)
            print('input  :', sequence)
            print('predict:', output)
            print()

# Run the pipeline; its stages are gated by args.do_train and args.do_predict.
main()

save line size:1000 to output/train.txt
device: cpu
Loading input file
batch_size: 8 epochs: 20 word_embed_size: 128 hidden_size: 128 device: cpu
model: Context2vec(
  (drop): Dropout(p=0.0, inplace=False)
  (l2r_emb): Embedding(23, 128, padding_idx=0)
  (l2r_rnn): LSTM(128, 128, num_layers=2, batch_first=True)
  (r2l_emb): Embedding(23, 128, padding_idx=0)
  (r2l_rnn): LSTM(128, 128, num_layers=2, batch_first=True)
  (criterion): NegativeSampling(
    (W): Embedding(23, 128, padding_idx=0)
    (logsigmoid): LogSigmoid()
  )
  (MLP): MLP(
    (drop): Dropout(p=0.0, inplace=False)
    (MLP): ModuleList(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (activation_function): ReLU()
  )
)
train start...
102080 words, 17.85 sec, 5720.01 words/sec, 3.7169 accum_loss/word, 3.7169 cur_loss/word
203064 words, 34.31 sec, 2943.09 words/sec, 2.2126 accum_loss/word, 0.6920 cur_loss/word
301376 words, 15.98 sec

203064 words, 15.66 sec, 6446.53 words/sec, 0.6921 accum_loss/word, 0.6920 cur_loss/word
301376 words, 15.22 sec, 6458.01 words/sec, 0.6920 accum_loss/word, 0.6920 cur_loss/word
405312 words, 16.18 sec, 6424.97 words/sec, 0.6920 accum_loss/word, 0.6921 cur_loss/word
500296 words, 15.06 sec, 6306.11 words/sec, 0.6921 accum_loss/word, 0.6921 cur_loss/word
epoch:[15/20], total_loss:[378379.3125], best_cur_loss:[0.6917896113119835]
102080 words, 15.95 sec, 6398.20 words/sec, 0.6921 accum_loss/word, 0.6921 cur_loss/word
203064 words, 15.63 sec, 6462.03 words/sec, 0.6921 accum_loss/word, 0.6920 cur_loss/word
301376 words, 15.22 sec, 6457.57 words/sec, 0.6920 accum_loss/word, 0.6920 cur_loss/word
405312 words, 16.14 sec, 6438.97 words/sec, 0.6920 accum_loss/word, 0.6921 cur_loss/word
500296 words, 14.66 sec, 6478.72 words/sec, 0.6921 accum_loss/word, 0.6921 cur_loss/word
epoch:[16/20], total_loss:[378379.3125], best_cur_loss:[0.6917896113119835]
102080 words, 15.99 sec, 6385.24 words/sec, 0.6

2022-08-30 17:46:04.355 | DEBUG    | pycorrector.deepcontext.infer:__init__:31 - device: cpu
2022-08-30 17:46:04.397 | DEBUG    | pycorrector.deepcontext.infer:__init__:43 - Loaded deep context model: output/models/, spend: 0.042 s.


epoch:[20/20], total_loss:[378379.3125], best_cur_loss:[0.6917896113119835]
input  : ASP ALA ILE ALA ASP ALA SER LYS ARG PHE SER ASP ALA THR TYR PRO ILE ALA GLU LYS PHE ASP TRP GLY GLY SER SER ALA ILE ALA LYS TYR ILE ALA ASP ALA SER ALA GLY ASN PRO ARG GLN ALA ALA LEU ALA VAL GLU LYS LEU LEU GLU VAL GLY LEU THR MET ASP PRO LYS LEU VAL ARG ALA ALA VAL GLU ALA HIS SER LYS ALA LEU ASP SER ALA LYS LYS ASN ALA LYS LEU MET ALA SER LYS GLU ASP PHE ALA ALA VAL ASN GLU ALA LEU ALA ARG MET ILE ALA SER ALA ASP LYS GLN LYS PHE ALA ALA LEU ARG THR ALA PHE PRO GLU SER ARG GLU LEU GLN GLY LYS LEU PHE ALA GLY ASN ASN ALA PHE GLU ALA GLU LYS ALA TYR ASP SER PHE LYS ALA LEU THR SER ALA VAL ARG ASP ALA SER ILE ASN GLY ALA LYS ALA PRO VAL ILE ALA GLU ALA ALA ARG ALA GLU ARG TYR VAL GLY ASP GLY PRO VAL GLY ARG ALA ALA LYS LYS PHE SER GLU ALA THR TYR PRO ILE MET ASP LYS LEU ASP TRP GLY LYS SER PRO GLU ILE SER LYS TYR ILE GLU THR ALA SER ALA LYS ASN PRO LYS MET MET ALA ASP GLY ILE ASP LYS THR LEU GLU VAL ALA

In [25]:
# Build an Inference object from the saved model directory and vocabulary, then
# run it on a variant of the sample sequence used in main() that differs in a
# few residues near the start.
inference = Inference(args.model_dir, args.vocab_path)
inputs = [
    'ALA ALA ALA ALA ASP ALA SER LYS ARG PHE ALA ASP ALA THR TYR PRO ILE ALA GLU LYS PHE ASP TRP GLY GLY SER SER ALA ILE ALA LYS TYR ILE ALA ASP ALA SER ALA GLY ASN PRO ARG GLN ALA ALA LEU ALA VAL GLU LYS LEU LEU GLU VAL GLY LEU THR MET ASP PRO LYS LEU VAL ARG ALA ALA VAL GLU ALA HIS SER LYS ALA LEU ASP SER ALA LYS LYS ASN ALA LYS LEU MET ALA SER LYS GLU ASP PHE ALA ALA VAL ASN GLU ALA LEU ALA ARG MET ILE ALA SER ALA ASP LYS GLN LYS PHE ALA ALA LEU ARG THR ALA PHE PRO GLU SER ARG GLU LEU GLN GLY LYS LEU PHE ALA GLY ASN ASN ALA PHE GLU ALA GLU LYS ALA TYR ASP SER PHE LYS ALA LEU THR SER ALA VAL ARG ASP ALA SER ILE ASN GLY ALA LYS ALA PRO VAL ILE ALA GLU ALA ALA ARG ALA GLU ARG TYR VAL GLY ASP GLY PRO VAL GLY ARG ALA ALA LYS LYS PHE SER GLU ALA THR TYR PRO ILE MET ASP LYS LEU ASP TRP GLY LYS SER PRO GLU ILE SER LYS TYR ILE GLU THR ALA SER ALA LYS ASN PRO LYS MET MET ALA ASP GLY ILE ASP LYS THR LEU GLU VAL ALA LEU THR MET ASN GLN ASN ALA ILE ASN ASP ALA VAL PHE ALA HIS VAL ARG ALA ILE LYS GLY ALA LEU ASN THR PRO GLY LEU VAL ALA GLU ARG ASP ASP PHE ALA ARG VAL ASN LEU ALA LEU ALA LYS MET ILE ALA THR ALA ASP PRO ALA LYS PHE LYS ALA LEU LEU THR ALA PHE PRO GLY ASN ALA ASP LEU GLN MET ALA LEU PHE ALA ALA ASN ASN PRO GLU GLN ALA LYS ALA ALA TYR GLU THR PHE VAL ALA LEU THR SER ALA VAL ALA SER SER THR'
]
# Print each input next to the model's prediction, separated by a blank line.
for i in inputs:
    output = inference.predict(i)
    print('input  :', i)
    print('predict:', output)
    print()

2022-08-30 17:48:12.242 | DEBUG    | pycorrector.deepcontext.infer:__init__:31 - device: cpu
2022-08-30 17:48:12.272 | DEBUG    | pycorrector.deepcontext.infer:__init__:43 - Loaded deep context model: output/models/, spend: 0.030 s.


input  : ALA ALA ALA ALA ASP ALA SER LYS ARG PHE ALA ASP ALA THR TYR PRO ILE ALA GLU LYS PHE ASP TRP GLY GLY SER SER ALA ILE ALA LYS TYR ILE ALA ASP ALA SER ALA GLY ASN PRO ARG GLN ALA ALA LEU ALA VAL GLU LYS LEU LEU GLU VAL GLY LEU THR MET ASP PRO LYS LEU VAL ARG ALA ALA VAL GLU ALA HIS SER LYS ALA LEU ASP SER ALA LYS LYS ASN ALA LYS LEU MET ALA SER LYS GLU ASP PHE ALA ALA VAL ASN GLU ALA LEU ALA ARG MET ILE ALA SER ALA ASP LYS GLN LYS PHE ALA ALA LEU ARG THR ALA PHE PRO GLU SER ARG GLU LEU GLN GLY LYS LEU PHE ALA GLY ASN ASN ALA PHE GLU ALA GLU LYS ALA TYR ASP SER PHE LYS ALA LEU THR SER ALA VAL ARG ASP ALA SER ILE ASN GLY ALA LYS ALA PRO VAL ILE ALA GLU ALA ALA ARG ALA GLU ARG TYR VAL GLY ASP GLY PRO VAL GLY ARG ALA ALA LYS LYS PHE SER GLU ALA THR TYR PRO ILE MET ASP LYS LEU ASP TRP GLY LYS SER PRO GLU ILE SER LYS TYR ILE GLU THR ALA SER ALA LYS ASN PRO LYS MET MET ALA ASP GLY ILE ASP LYS THR LEU GLU VAL ALA LEU THR MET ASN GLN ASN ALA ILE ASN ASP ALA VAL PHE ALA HIS VAL ARG ALA ILE

In [24]:
def compare_two_strings(a, b):
    """Map each differing position of two sequences to its pair of items.

    The sequences are walked in parallel with ``zip``, so the comparison stops
    at the end of the shorter one; any surplus tail is not reported.

    Args:
        a: First sequence (for example a string).
        b: Second sequence.

    Returns:
        A dict ``{index: (item_from_a, item_from_b)}`` holding only the
        positions where the two items are not equal.
    """
    return {
        position: (left, right)
        for position, (left, right) in enumerate(zip(a, b))
        if left != right
    }
# Quick check: two one-character strings that differ at index 0.
compare_two_strings("a","b")

{0: ('a', 'b')}