In [1]:
import os
import argparse
import logging
import sys
import json
import numpy as np

import torch
from torch.optim.lr_scheduler import StepLR
import torchtext

# Run from the project root (the parent of the notebook's working directory)
# so the project-local imports below (models, loss, evaluator, dataset) and the
# relative data/, models/, log/ and pretrained_weights/ paths resolve.
# Notebooks have no __file__, so the root is derived from os.getcwd().
# It is computed only once per kernel: re-running this cell would otherwise
# climb one more directory level every time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.path.abspath(os.getcwd()))
os.chdir(PROJECT_ROOT)

from models.seq2seq import Seq2seq
from loss.loss import Perplexity
from evaluator.evaluator import Evaluator
from dataset import fields

import matplotlib.pyplot as plt

In [2]:
# Root-logger configuration: timestamped, level-padded records at the level
# named by `log_level` (resolved to the matching logging constant).
log_level = 'info'
LOG_FORMAT = '%(asctime)s %(levelname)-6s %(message)s'
logging.basicConfig(
    level=getattr(logging, log_level.upper()),
    format=LOG_FORMAT,
)

In [3]:
# Experiment selection: dataset variant, RNN cell type, and the run indices
# (1..10) whose checkpoints are loaded and exported further down.
# NOTE(review): f1_score_lists and data_sort are not used in the visible
# cells — presumably consumed elsewhere; confirm before removing.
f1_score_lists = []
data_name = "copy"
dir_name = "separator_Ctype4_60"
data_sort = ""
rnn = "lstm"
iterator = list(range(1, 11))

# Paths are relative to the project root.
data_path = "data/%s_rand/correction_%s" % (data_name, dir_name)
train_path = "%s/data_train.txt" % data_path
config_path = "models/config.json"

In [4]:
print("RNN is %s" % rnn)
print("data path: %s" % data_path)
f1_score_list = []

# Dataset fields. max_len is the longest source/target sequence that
# len_filter lets into the training set (also copied into the model config).
max_len = 65
# NOTE(review): srcp / tgtp are not used in this cell — presumably kept for
# other cells; confirm before removing.
src, srcp = fields.SourceField(), fields.SourceField()
tgt, tgtp = fields.TargetField(), fields.TargetField()
def len_filter(example):
    """Return True when both sides of ``example`` are at most ``max_len`` long.

    Passed as ``filter_pred`` to the training TabularDataset so over-long
    pairs are dropped. ``max_len`` is the module-level limit, looked up at
    call time. The target is only inspected if the source already fits.
    """
    if len(example.src) > max_len:
        return False
    return len(example.tgt) <= max_len
# Training split: tab-separated (src, tgt) pairs; examples rejected by
# len_filter (either side longer than max_len) are dropped at load time.
train = torchtext.data.TabularDataset(
    path=train_path, format='tsv',
    fields=[('src', src), ('tgt', tgt)],
    filter_pred=len_filter
)
# Vocabularies are built from the filtered training data only.
src.build_vocab(train)
tgt.build_vocab(train)
input_vocab = src.vocab
output_vocab = tgt.vocab

print("src vocab size = %d" % (len(src.vocab)))
print("tgt vocab size = %d" % (len(tgt.vocab)))

# Loss: perplexity with a uniform weight for every target token. The target
# padding index is handed to Perplexity too (presumably so pad positions are
# ignored — TODO confirm in loss.loss).
pad = tgt.vocab.stoi[tgt.pad_token]
weight = torch.ones(len(tgt.vocab))
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    # Same condition under which the model is moved to the GPU below.
    loss.cuda()

# Model / evaluation setup.
# NOTE(review): evaluator and optimizer are not used in this cell —
# presumably kept for other cells; confirm before removing.
evaluator = Evaluator(loss=loss, batch_size=32)

optimizer = "Adam"
seq2seq = None

# Start from the JSON defaults, then override the hyper-parameters this
# experiment uses. The with-block closes the file handle deterministically.
with open(config_path) as config_file:
    config_json = config_file.read()
config = json.loads(config_json)
config["max_len"] = max_len  # keep in sync with the dataset length filter
config["hidden_size"] = 100
config["rnn_cell"] = rnn
config["embedding_size"] = 20
config["use_attention"] = True
config["position_embedding"] = "length"
config["use_memory"] = "queue"
# Toggle: uncomment to override config.json's "pos_add" setting.
#config["pos_add"] = "cat"


def _model_save_name(cfg, data_name, dir_name, rnn, run_idx):
    """Return the experiment name identifying one trained run.

    The name encodes the dataset (``data_name``/``dir_name``), the model
    options in ``cfg``, the RNN cell type and the run index. It is the stem of
    the checkpoint file under ``log/pth/`` and of the export directory under
    ``pretrained_weights/``.
    """
    return (data_name + "_rand_" + dir_name
            + ("_att" if cfg["use_attention"] else "")
            + ("_with_pos_" + cfg["position_embedding"]
               if cfg["position_embedding"] is not None else "")
            # "pos_add" is not overridden by this notebook and may be absent
            # from config.json, so read it with .get() instead of indexing.
            + ("_cat" if cfg.get("pos_add") == "cat" else "")
            + ("_use_stack" if cfg["use_memory"] == "stack" else "")
            + ("_use_queue" if cfg["use_memory"] == "queue" else "")
            + "_emb" + str(cfg["embedding_size"])
            + "_hidden" + str(cfg["hidden_size"])
            + "_" + rnn + "_" + str(run_idx))


# The configuration is identical for every run, so dump it once.
print(json.dumps(config, indent=4))

for i in iterator:
    model_name = _model_save_name(config, data_name, dir_name, rnn, i)

    seq2seq = Seq2seq(config, len(src.vocab), len(tgt.vocab), tgt.sos_id, tgt.eos_id)
    if torch.cuda.is_available():
        seq2seq.cuda()

    # Restore this run's trained weights. map_location="cpu" lets a checkpoint
    # saved on a GPU load on a CPU-only machine; load_state_dict then copies
    # each tensor onto the model's current device. The (strict) load overwrites
    # every parameter, so no random initialisation is needed beforehand.
    log_path = "log/pth/" + model_name + "_model_save.pth"
    seq2seq.load_state_dict(torch.load(log_path, map_location="cpu"))

    weights = seq2seq.state_dict()
    if i == iterator[0]:
        # Same config every run, hence same parameter names: list them once.
        for var_name in weights:
            print(var_name)

    encoder_pos_weight = weights['encoder.pos_embedding.weight'].cpu().numpy()
    decoder_pos_weight = weights['decoder.pos_embedding.weight'].cpu().numpy()

    # makedirs also creates the pretrained_weights/ parent when it is missing;
    # exist_ok keeps re-runs from failing on an existing directory.
    save_path = "pretrained_weights/" + model_name
    os.makedirs(save_path, exist_ok=True)

    np.save(save_path + "/encoder_pos_weight.npy", encoder_pos_weight)
    np.save(save_path + "/decoder_pos_weight.npy", decoder_pos_weight)

RNN is lstm
data path: data/copy_rand/correction_separator_Ctype4_60
src vocab size = 7
tat vacab size = 10




{
    "max_len": 65,
    "embedding_size": 20,
    "hidden_size": 100,
    "input_dropout_p": 0,
    "dropout_p": 0,
    "n_layers": 1,
    "bidirectional": false,
    "rnn_cell": "lstm",
    "variable_lengths": true,
    "embedding": null,
    "update_embedding": true,
    "get_context_vector": false,
    "use_attention": true,
    "attn_layers": 1,
    "hard_attn": false,
    "position_embedding": "length",
    "pos_add": "add",
    "use_memory": "queue",
    "memory_dim": 5
}
pos_embedding.weight
encoder.embedding.weight
encoder.rnn.weight_ih_l0
encoder.rnn.weight_hh_l0
encoder.rnn.bias_ih_l0
encoder.rnn.bias_hh_l0
encoder.pos_embedding.weight
encoder.W_n.weight
encoder.W_n.bias
encoder.W_a.weight
encoder.W_a.bias
encoder.W_sh.weight
encoder.W_sh.bias
decoder.embedding.weight
decoder.pos_embedding.weight
decoder.rnn.weight_ih_l0
decoder.rnn.weight_hh_l0
decoder.rnn.bias_ih_l0
decoder.rnn.bias_hh_l0
decoder.attention1.linear_out.weight
decoder.attention1.linear_out.bias
decoder.out.w