# Set Parameter
- Attention = Luong
- Teacher Forcing Ratio = 0.5
- Layer = 1
- Batch size = 32
- Learning rate = 0.001
- Hidden units = 200
- Epochs = 100
- N = 100
- Data Length = 100K
- Data = [single_Ctype4, last_separator_Ctype4, single_Ctype2_concat, separator_Ctype4]
- Deduplication
- Random split

# Import packages

Import the packages needed for the experiments.

In [1]:
import os
import argparse
import logging
import sys

import torch
from torch.optim.lr_scheduler import StepLR
import torchtext

# Make the project packages importable by adding the directory three levels above the
# current working directory to sys.path. Note that '__file__' is a string literal here,
# so its dirname is '' and abspath('') resolves to the working directory.
project_root = os.path.abspath(os.path.dirname('__file__'))
for _ in range(3):
    project_root = os.path.dirname(project_root)
sys.path.append(project_root)

from trainer.supervised_trainer_unmatching import SupervisedTrainer_unmatching
from models.encoderRNN_gru import EncoderRNN_gru
from models.decoderRNN_gru import DecoderRNN_gru
from models.seq2seq import Seq2seq
from loss.loss import Perplexity
from optim.optim import Optimizer
from dataset import fields

import matplotlib.pyplot as plt



# Log format

In [2]:
# Log line layout: timestamp, level name left-justified in a 6-character field, message.
LOG_FORMAT = '%(asctime)s %(levelname)-6s %(message)s'
log_level = 'info'
# Look up the numeric level constant on the logging module from its upper-cased name.
numeric_level = getattr(logging, log_level.upper())
logging.basicConfig(level=numeric_level, format=LOG_FORMAT)

In [3]:
# data_path entries are spliced into the dataset directory name by the training loop;
# data_name entries are handed to the trainer as its file_name. The two lists are zipped,
# so they must stay aligned element by element.
data_path = ["single_Ctype4"]
data_name = ["copy_rand_correction_single_Ctype4"]

# One entry per dataset is appended to each of these by the training loop.
character_accuracy, sentence_accuracy, f1_score = [], [], []

# Prepare dataset

In [None]:
# Start from empty result lists every time this cell runs. Without this, re-running the
# cell appends a second set of results behind the first, and the later cells that read
# index 0 would silently show the stale run.
character_accuracy, sentence_accuracy, f1_score = [], [], []

# Loop-invariant experiment configuration.
max_len = 104          # also passed to the encoder and decoder constructors below
hidden_size = 200
bidirectional = False
optimizer = "Adam"     # forwarded by name to the trainer


def len_filter(example):
    """Keep an example only when both its source and its target have at most max_len items."""
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# i is the dataset directory suffix, j is the file name given to the trainer.
for i, j in zip(data_path, data_name):
    print("data : %s" % i)
    train_path = "../../../data/copy_rand/correction_" + i + "/data_train.txt"
    # The evaluation split is read from data_test.txt.
    dev_path = "../../../data/copy_rand/correction_" + i + "/data_test.txt"

    # Fresh fields per dataset so each run builds its own vocabularies.
    src = fields.SourceField()
    tgt = fields.TargetField()

    # Both files are tab-separated with the source in the first column and the target
    # in the second; over-length examples are dropped by len_filter.
    train = torchtext.data.TabularDataset(
        path=train_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
        filter_pred=len_filter
    )
    dev = torchtext.data.TabularDataset(
        path=dev_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
        filter_pred=len_filter
    )

    # Vocabularies are built from the training split only.
    src.build_vocab(train)
    tgt.build_vocab(train)
    # Not used further in this cell; kept as notebook-level names for later inspection.
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # The loss receives an all-ones weight per target token and the padding token index.
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    # Drop the reference to the previous iteration's model before building a new one.
    seq2seq = None
    encoder = EncoderRNN_gru(len(src.vocab), max_len, hidden_size,
                         bidirectional=bidirectional, variable_lengths=True)
    # The decoder width doubles when the encoder is bidirectional.
    # NOTE(review): the truncated warning at the top of this cell's output mentions
    # dropout and num_layers; presumably dropout_p=0.2 reaches a single-layer GRU where
    # it has no effect - confirm in DecoderRNN_gru.
    decoder = DecoderRNN_gru(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,
                         dropout_p=0.2, use_attention="Luong", bidirectional=bidirectional,
                         eos_id=tgt.eos_id, sos_id=tgt.sos_id)
    seq2seq = Seq2seq(encoder, decoder)
    if torch.cuda.is_available():
        seq2seq.cuda()

    # Overwrite every model parameter in place with values drawn uniformly from [-0.08, 0.08).
    for param in seq2seq.parameters():
        param.data.uniform_(-0.08, 0.08)

    # train
    t = SupervisedTrainer_unmatching(loss=loss, batch_size=32,
                          checkpoint_every=50,
                          print_every=100,
                          hidden_size=hidden_size,
                          path="GRU",
                          file_name=j)

    seq2seq, ave_loss, character_accuracy_list, sentence_accuracy_list, f1_score_list = t.train(seq2seq, train,
                                                                             num_epochs=100, dev_data=dev,
                                                                             optimizer=optimizer,
                                                                             teacher_forcing_ratio=0.5)

    # Collect the metric lists returned by the trainer, one entry per dataset.
    character_accuracy.append(character_accuracy_list)
    sentence_accuracy.append(sentence_accuracy_list)
    f1_score.append(f1_score_list)

data : single_Ctype4


  "num_layers={}".format(dropout, num_layers))
2019-05-09 20:25:38,657 INFO   Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
), Scheduler: None
2019-05-09 20:27:34,900 INFO   Finished epoch 1: Train loss: 1.3513, Dev loss: 1.1910, Accuracy(character): 0.9528, Accuracy(sentence): 0.0632, F1 Score: 0.0000
2019-05-09 20:29:51,055 INFO   Finished epoch 2: Train loss: 1.2586, Dev loss: 1.1442, Accuracy(character): 0.9722, Accuracy(sentence): 0.1997, F1 Score: 0.0000
2019-05-09 20:32:04,306 INFO   Finished epoch 3: Train loss: 1.2347, Dev loss: 1.1543, Accuracy(character): 0.9605, Accuracy(sentence): 0.1032, F1 Score: 0.0000
2019-05-09 20:34:05,437 INFO   Finished epoch 4: Train loss: 1.2315, Dev loss: 1.1394, Accuracy(character): 0.9660, Accuracy(sentence): 0.1244, F1 Score: 0.0000
2019-05-09 20:36:19,448 INFO   Finished epoch 5: Train loss: 1.2178, Dev loss: 1.1977, Accuracy(character): 0.9545, Accuracy(senten

In [8]:
# Character accuracy values collected for the first dataset (index 0).
first_character_accuracy = character_accuracy[0]
print(first_character_accuracy)

[0.9527795363124724, 0.9722095317656632, 0.9604966029460713, 0.9660028107548356, 0.9544650102357597, 0.9551725938350487, 0.9504566388854443, 0.9771985128235251, 0.9760569156099143, 0.955990584406402, 0.9662225369455449, 0.9619052436241333, 0.9692600883690878, 0.9767193574423001, 0.9750387784093999, 0.9477405185973202, 0.9765496679088811, 0.9684062019348958, 0.9771669679743639, 0.9673130097484461, 0.9735121164853382, 0.9787170165969417, 0.9772409352068798, 0.9722834989981791, 0.9860816511226703, 0.98105623032526, 0.9529905604758269, 0.9627080056476157, 0.9860626154378316, 0.9864982606822824, 0.9148207491140249, 0.9752753647781092, 0.987924305415054, 0.9730128376658552, 0.9831648403286756, 0.973097138555855, 0.9778419189710723, 0.9804737383691966, 0.9852163432775751, 0.9864291883401536, 0.9848628234162854, 0.9791107833347474, 0.9892078544498905, 0.980308399849455, 0.9833127747937076, 0.9764914731009456, 0.9818802035404198, 0.9685247670576053, 0.9792369627313924, 0.9382079915068213, 0.986

In [9]:
# Sentence accuracy values collected for the first dataset (index 0).
first_sentence_accuracy = sentence_accuracy[0]
print(first_sentence_accuracy)

[0.0632, 0.19974, 0.10322, 0.1244, 0.05952, 0.07744, 0.10058, 0.2587, 0.2695, 0.02826, 0.16732, 0.13034, 0.28872, 0.34654, 0.34854, 0.27238, 0.40404, 0.35942, 0.4467, 0.29778, 0.41998, 0.50108, 0.47568, 0.40124, 0.62122, 0.52574, 0.29956, 0.2761, 0.61584, 0.6013, 0.28566, 0.48288, 0.65264, 0.49218, 0.5998, 0.49702, 0.51988, 0.58, 0.61664, 0.678, 0.6284, 0.5881, 0.70026, 0.5739, 0.61032, 0.59422, 0.62026, 0.50132, 0.54912, 0.37672, 0.69568, 0.72654, 0.64684, 0.68702, 0.49758, 0.7473, 0.63958, 0.67776, 0.6768, 0.63736, 0.6379, 0.7285, 0.65592, 0.67342, 0.68442, 0.72886, 0.72192, 0.62778, 0.59268, 0.69564, 0.64366, 0.63468, 0.67746, 0.68762, 0.62816, 0.66496, 0.73896, 0.71012, 0.5709, 0.64646, 0.65108, 0.7315, 0.67594, 0.67044, 0.66822, 0.40788, 0.56956, 0.7544, 0.7423, 0.64936, 0.66534, 0.79704, 0.48904, 0.75794, 0.67844, 0.77106, 0.75812, 0.68076, 0.77184, 0.74362]


In [10]:
# F1 score values collected for the first dataset (index 0).
first_f1_score = f1_score[0]
print(first_f1_score)

[0, 0, 0, 0, 0, 5.592372004585745e-05, 0.009604174762671404, 0.03874131589701675, 0.09793695244457006, 0.10818209469735375, 0.12111933743255116, 0.19392041005196842, 0.2600651924940534, 0.28816089935007905, 0.3193967322999581, 0.3501186949853995, 0.42618855761482677, 0.3755485649928153, 0.4649650584849957, 0.4673527923063945, 0.48576332007176837, 0.5602447811996452, 0.5422123252567025, 0.5068280287380487, 0.595547153250173, 0.5794440757944407, 0.49871138570167695, 0.48902734510211143, 0.6056963227801094, 0.594809663775181, 0.41365968542350073, 0.5580100365115619, 0.6507070291048873, 0.5836551508193301, 0.6369985492088128, 0.5963221219578889, 0.5574366396284204, 0.6424560945354835, 0.6245744351593935, 0.7006461110656743, 0.6689693518222102, 0.6934294334562521, 0.693690727081138, 0.6426077222528529, 0.7187066673192058, 0.6955056179775281, 0.7084328264421265, 0.6562750453883773, 0.6250468874863415, 0.5000414683099176, 0.6519610086494898, 0.7164824974886022, 0.6806800317650681, 0.728684241