# data_loader.py

In [1]:
import os
from torchtext import data, datasets

# Special-token indices. Presumably these mirror the order in which torchtext
# assigns vocabulary indices to the target field's specials
# (<unk>=0, <pad>=1, '<BOS>'=2, '<EOS>'=3) -- confirm against
# `loader.tgt.vocab.stoi` if the Field setup changes.
# NOTE(review): the source field is built without init/eos tokens, so BOS/EOS
# indices are only meaningful on the target side.
PAD = 1
BOS = 2
EOS = 3

In [None]:
class DataLoader():
    """Build torchtext fields, datasets and bucket iterators for a parallel corpus.

    Args:
        train_fn: path prefix of the training corpus; each element of ``exts``
            is appended to it to form the source and target file paths.
        valid_fn: path prefix of the validation corpus (same convention).
        exts: pair ``(src_ext, tgt_ext)`` appended to ``train_fn``/``valid_fn``.
        batch_size: number of sentence pairs per mini-batch.
        device: device the iterators create batch tensors on.
        max_vocab: maximum vocabulary size of each field.
        max_length: sentence pairs where either side has more than this many
            whitespace-separated tokens are dropped.
        fix_length: pad every sequence to this fixed length; ``None`` keeps
            variable-length batches.
        use_bos: prepend ``'<BOS>'`` to every target sentence.
        use_eos: append ``'<EOS>'`` to every target sentence.
        shuffle: shuffle the training iterator.

    When ``train_fn``, ``valid_fn`` and ``exts`` are all given, the datasets,
    ``train_iter``/``valid_iter`` and both vocabularies are built. Otherwise
    only the fields are created, and vocabularies can be attached afterwards
    with ``load_vocab`` (e.g. for inference without a training corpus).
    """

    def __init__(self, train_fn=None,
                 valid_fn=None,
                 exts=None,
                 batch_size=64,
                 device='cpu',
                 max_vocab=99999999,
                 max_length=255,
                 fix_length=None,
                 use_bos=True,
                 use_eos=True,
                 shuffle=True
                 ):
        super(DataLoader, self).__init__()

        # data.Field options used below:
        # * sequential: whether the data is sequential (False disables tokenization).
        # * use_vocab: whether to map tokens through a Vocab (False means the
        #   data must already be numeric).
        # * batch_first: whether produced tensors put the batch dimension first.
        # * include_lengths: return a (padded minibatch, lengths) tuple instead
        #   of only the padded minibatch (torchtext default: False).
        # * fix_length: fixed length every sentence is padded to; None for
        #   variable-length sequences.
        # * init_token: token prepended to every sentence.
        # * eos_token: token appended to every sentence.

        # Only the target side gets BOS/EOS tokens; the source side gets none.
        self.src = data.Field(sequential=True,
                              use_vocab=True,
                              batch_first=True,
                              include_lengths=True,
                              fix_length=fix_length,
                              init_token=None,
                              eos_token=None
                              )

        self.tgt = data.Field(sequential=True,
                              use_vocab=True,
                              batch_first=True,
                              include_lengths=True,
                              fix_length=fix_length,
                              init_token='<BOS>' if use_bos else None,
                              eos_token='<EOS>' if use_eos else None
                              )

        # Without a corpus (e.g. at inference time) stop after creating the
        # fields; `load_vocab` can then attach previously saved vocabularies.
        if train_fn is None or valid_fn is None or exts is None:
            return

        fields = [('src', self.src), ('tgt', self.tgt)]
        train = TranslationDataset(path=train_fn, exts=exts,
                                   fields=fields,
                                   max_length=max_length
                                   )
        valid = TranslationDataset(path=valid_fn, exts=exts,
                                   fields=fields,
                                   max_length=max_length
                                   )

        def sort_key(ex):
            # Source length is weighted by max_length so it dominates the key.
            # Target length acts as the secondary key, so each bucket groups
            # pairs of similar length and needs little padding.
            return len(ex.tgt) + (max_length * len(ex.src))

        self.train_iter = data.BucketIterator(train,
                                              batch_size=batch_size,
                                              device=device,
                                              shuffle=shuffle,
                                              sort_key=sort_key,
                                              sort_within_batch=True
                                              )

        self.valid_iter = data.BucketIterator(valid,
                                              batch_size=batch_size,
                                              device=device,
                                              shuffle=False,
                                              sort_key=sort_key,
                                              sort_within_batch=True
                                              )

        # Both vocabularies are built from the training set only. Batches are
        # produced when the iterators are consumed, so building the vocab after
        # constructing the iterators is fine.
        self.src.build_vocab(train, max_size=max_vocab)
        self.tgt.build_vocab(train, max_size=max_vocab)

    def load_vocab(self, src_vocab, tgt_vocab):
        """Attach pre-built vocabularies to the source and target fields."""
        self.src.vocab = src_vocab
        self.tgt.vocab = tgt_vocab
        
        
        

## Dataset definition: `TranslationDataset` (parallel corpus from two line-aligned files)

In [74]:
import sys
import pandas as pd

class TranslationDataset(data.Dataset):
    """Parallel corpus read from two line-aligned text files.

    Line *i* of the source file is paired with line *i* of the target file.
    A pair is skipped when either side is empty after stripping. When
    ``max_length`` is set, a pair is also skipped if either side has more than
    ``max_length`` whitespace-separated tokens.
    """

    @staticmethod
    def sort_key(ex):
        """Return a sort key interleaving the source and target lengths.

        This must be a staticmethod: torchtext iterators that are not given
        their own ``sort_key`` defer to ``dataset.sort_key(example)``.
        """
        # The target field is named 'tgt' by DataLoader, but 'trg' when fields
        # are passed as a bare pair (see __init__), so accept either name.
        tgt = ex.tgt if hasattr(ex, 'tgt') else ex.trg
        return data.interleave_keys(len(ex.src), len(tgt))

    def __init__(self, path, exts, fields, max_length=None, encoding=None,
                 **kwargs):
        """Read, filter and wrap the sentence pairs.

        Args:
            path: path prefix shared by both files.
            exts: pair ``(src_ext, trg_ext)``. Each is appended to ``path`` and
                the result is passed through ``os.path.expanduser``.
            fields: either ``[(name, field), (name, field)]`` or a bare pair
                ``(src_field, trg_field)``, which is named ``('src', 'trg')``.
            max_length: if truthy, drop pairs where either side has more than
                this many whitespace-separated tokens.
            encoding: encoding used to open both files. ``None`` (the default)
                keeps ``open``'s locale-dependent default. Pass ``'utf-8'`` for
                UTF-8 corpora on systems whose locale encoding differs.
            **kwargs: forwarded to ``data.Dataset.__init__``.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]

        src_path, trg_path = tuple(os.path.expanduser(path + ext) for ext in exts)

        examples = []
        # zip() stops at the shorter file, so any extra lines in the longer
        # file are ignored.
        with open(src_path, encoding=encoding) as src_file:
            with open(trg_path, encoding=encoding) as trg_file:
                for src_line, trg_line in zip(src_file, trg_file):
                    src_line, trg_line = src_line.strip(), trg_line.strip()
                    if max_length and max_length < max(len(src_line.split()),
                                                       len(trg_line.split())):
                        continue
                    if src_line != '' and trg_line != '':
                        examples.append(data.Example.fromlist(
                            [src_line, trg_line], fields))

        super(TranslationDataset, self).__init__(examples, fields, **kwargs)
        
if __name__ == '__main__':
    # Smoke test: build the loader on a local corpus, print the source and
    # target vocabulary sizes, then print the first three training batches.
    #
    # The arguments are hard-coded: the directory is used as the path prefix
    # for both the training and validation corpora, and the two file names
    # (extension included) are passed as `exts`.
    # NOTE(review): `exts` is appended to each path prefix. This therefore reads
    # source sentences from train.csv and target sentences from test.csv, and
    # validation uses exactly the same two files as training. Confirm this is
    # intended.
    data_dir = 'C:/Users/USER/Capstone/'
    loader = DataLoader(data_dir, data_dir, ('train.csv', 'test.csv'),
                        shuffle=False,
                        batch_size=8
                        )

    print(len(loader.src.vocab))
    print(len(loader.tgt.vocab))

    for batch_index, batch in enumerate(loader.train_iter):
        # Both fields use include_lengths=True, so each attribute is a
        # (padded index tensor, lengths tensor) tuple.
        print(batch.src)
        print(batch.tgt)

        if batch_index > 1:  # stop after the third batch
            break

1110293
1111033
(tensor([[  12942,  197180,    1658,     792,  734851,     573,     614,   17296],
        [   3581,    3582,  223993,  211283,     469,      28,     756,    2873],
        [     60,     239,    9267,   19978,      15,      30,    1791,   18329],
        [  53413,  138005,    5114,  158971,      31,       8,  477815,       1],
        [  59446,  286629,  363804, 1084290,   18051,       8,  228006,       1],
        [   2164,     684,   18431,  913551,     102,      25,  108882,       1],
        [    141,  730159,   16248,      15,       7,  496405,       1,       1],
        [ 481901,       1,       1,       1,       1,       1,       1,       1]]), tensor([8, 8, 8, 7, 7, 7, 6, 1]))
(tensor([[      2,    1550,     774,   62223,   81476,    6186,    1266,  202147,
           21772,      25,      40,    5455,    1008,     615,      18,       4,
           54104,       3,       1,       1,       1,       1,       1,       1,
               1,       1,       1,       1,   

# Train.py

In [1]:
import argparse

import torch
import torch.nn as nn

from data_loader import DataLoader
import data_loader
from simple_nmt.seq2seq import Seq2Seq
import simple_nmt.trainer as trainer

ModuleNotFoundError: No module named 'data_loader'

## Reference

https://books.google.co.kr/books?id=LV2nDwAAQBAJ&pg=PA145&lpg=PA145&dq=__init__(self,+train_fn+%3D+None&source=bl&ots=uhXWUwlTcz&sig=ACfU3U0AcrutNKVXJ19pieVc0xctnQOzTA&hl=ko&sa=X&ved=2ahUKEwj6u4-fuM7nAhXYP3AKHfKuDNoQ6AEwAHoECAsQAQ#v=onepage&q=__init__(self%2C%20train_fn%20%3D%20None&f=true