In [1]:
from word_embs import WordEmbsAug
# from torchtext.vocab import Vectors
import torch
# import MeCab


In [88]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_list, phase='train'):
        self.data_list = data_list
        japanese_vectors = Vectors(name='../../data/news/cc.ja.300.vec')
        self.transform = DataTransform(japanese_vectors)
        self.phase = phase

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        text = self.data_list[idx]
        text_transformed = self.transform(text, self.phase)
        return text_transformed


class DataTransform:
    def __init__(self, vectors):
        self.data_transform = {
            'train': Compose([
                RandomDelete(vectors, aug_p=0.1),
                RandomSwap(vectors, aug_p=0.1),
                RandomSubstitute(vectors, aug_p=0.1),
                RandomInsert(vectors, aug_p=0.1),
            ]),
            'val': Compose([

            ])
        }
    
    def __call__(self, text, phase):
        return self.data_transform[phase](text)

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, text):
        for t in self.transforms:
            text = t(text)
        return text

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class RandomSwap:
    def __init__(self, vectors, aug_p=0.5):
        self.swap_aug = WordEmbsAug(model=vectors, action='swap', 
                  stopwords=['<cls>', '<eos>', '<sep>'], aug_p=aug_p)

    def __call__(self, text):
        text_transformed = self.swap_aug.augment(text)
        return text_transformed
    
class RandomInsert:
    def __init__(self, vectors, aug_p=0.5):
        self.swap_aug = WordEmbsAug(model=vectors, action='insert', 
                  stopwords=['<cls>', '<eos>', '<sep>'], aug_p=aug_p)

    def __call__(self, text):
        text_transformed = self.swap_aug.augment(text)
        return text_transformed

class RandomSubstitute:
    def __init__(self, vectors, aug_p=0.5):
        self.swap_aug = WordEmbsAug(model=vectors, action='substitute', 
                  stopwords=['<cls>', '<eos>', '<sep>'], aug_p=aug_p)

    def __call__(self, text):
        text_transformed = self.swap_aug.augment(text)
        return text_transformed

class RandomDelete:
    def __init__(self, vectors, aug_p=0.3):
        self.swap_aug = WordEmbsAug(model=vectors, action='substitute', 
                  stopwords=['<cls>', '<eos>', '<sep>'], aug_p=aug_p)

    def __call__(self, text):
        text_transformed = self.swap_aug.augment(text)
        return text_transformed


In [89]:
train_ds = MyDataset()

In [90]:
train_ds.__getitem__(0)

'<cls> 株価 格付 <organization> <company> 0 アンテルディ 0 格下げ <span> <company> 各社 ラフィット・ロートシルト 守ら <eos>'