In [1]:
import torch

from seq2seq.data import Dataset
from seq2seq.model import Seq2seqModel
from seq2seq.train import Trainer, Evaluator

In [2]:
GPU_ID = 1
device = torch.device("cuda:{}".format(GPU_ID)) if torch.cuda.is_available() else torch.device("cpu")

# 1. Preprocess Data

In [3]:
SRC_FILE_PATH = 'smart_mobile/smart_src_khaiii.txt'
TGT_FILE_PATH = 'smart_mobile/smart_tgt_khaiii.txt'

EXPR_PATH = 'smart_mobile/'

BATCH_SIZE = 32
MAX_SRC_LEN = 10
MAX_TGT_LEN = 10

In [4]:
%%time
dataset = Dataset(
    src_file_path=SRC_FILE_PATH,
    tgt_file_path=TGT_FILE_PATH,
    max_src_len=MAX_SRC_LEN,
    max_tgt_len=MAX_TGT_LEN
)

Reading lines...
Read 90729 sentence pairs

Trim data to 20074 sentence pairs
Avg length of src :  7.272591411776427
Avg length of tgt :  7.307063863704294

Source vocab : 5062 (0 reduced)
Target vocab : 5060 (0 reduced)

Success to preprocess data!

CPU times: user 1.09 s, sys: 129 ms, total: 1.21 s
Wall time: 1.21 s


# 2. Define Model

In [5]:
NUM_LAYERS = 1
INPUT_SIZE = dataset.src_vocab_size
EMBED_SIZE = 64
HIDDEN_SIZE = 64
OUTPUT_SIZE = dataset.tgt_vocab_size

In [6]:
model = Seq2seqModel(
    n_layers=NUM_LAYERS,
    input_size=INPUT_SIZE,
    emb_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
    output_size=OUTPUT_SIZE,
    max_tgt_len=MAX_TGT_LEN,
    dropout_p=0.0,
    bi_encoder=True,
    device=device
)

In [11]:
next(model.parameters()).get_device()

RuntimeError: get_device is not implemented for tensors with CPU backend

# 3. Train Model

In [8]:
trainer = Trainer(
    model=model,
    dataset=dataset,
    device=device,
    print_interval=1,
    plot_interval=-1,
    checkpoint_interval=10,
    expr_path=EXPR_PATH
)

In [9]:
trainer.train(num_epoch=10, batch_size=BATCH_SIZE)

Start to train


RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'index'

# 4. Evaluate Model

In [None]:
evaluator = Evaluator(dataset, model, device=device)
evaluator.loadModel(EXPR_PATH+'ep10.model')

In [None]:
pairs, attn_list = evaluator.evalModel(num=10, beam_size=5)
for p, attn in zip(pairs, attn_list):
    print('Input : ' + ' '.join(p[0]))
    print('Gen   : ' + ' '.join(p[1]))
    print()