In [13]:
import torch
import torch.utils.data as data

In [23]:
class Word2VecDataset(data.Dataset):
    """Dataset of (center word, context words) pairs for word2vec training.

    Every space-separated token of every sentence becomes one sample. Its
    context consists of up to ``window_size`` tokens on each side, truncated
    at the sentence boundaries (so a one-word sentence has an empty context).

    Skip-gram is the default mode and yields ``(word, context_words)``.
    With ``skip_gram=False`` the CBOW layout ``(context_words, word)`` is
    produced instead.
    """

    def __init__(self, data: list, window_size: int, skip_gram: bool = True):
        """Store the configuration and eagerly build all samples.

        Args:
            data: Sentences as strings whose tokens are separated by spaces.
            window_size: Maximum number of context words taken from each
                side of the center word; must not be negative.
            skip_gram: ``True`` for skip-gram samples, ``False`` for CBOW.

        Raises:
            ValueError: If ``window_size`` is negative.
        """
        if window_size < 0:
            # A negative window would silently yield empty contexts for
            # every word, so reject it up front.
            raise ValueError(f"window_size must be >= 0, got {window_size}")
        self.data = data
        self.window_size = window_size
        self.skip_gram = skip_gram
        self.pair_word_data = []
        self.generate_pair_word()

    def generate_pair_word(self):
        """(Re)build ``self.pair_word_data`` from ``self.data``.

        The sample list is rebuilt from scratch on every call, so calling
        this method again does not duplicate samples.
        """
        print("Generating pair word data...")
        pairs = []
        for sentence in self.data:
            # Keep the single-space tokenisation, but drop the empty strings
            # that repeated, leading or trailing spaces (and empty sentences)
            # would otherwise inject as fake words.
            words = [word for word in sentence.split(" ") if word]
            for idx, center in enumerate(words):
                lower_idx = max(0, idx - self.window_size)
                # No upper clamp needed: slicing stops at the end of the list.
                upper_idx = idx + self.window_size
                context = words[lower_idx:idx] + words[idx + 1:upper_idx + 1]
                if self.skip_gram:  # use skip gram mode
                    pairs.append((center, context))
                else:  # use CBOW mode
                    pairs.append((context, center))
        self.pair_word_data = pairs

    def __getitem__(self, idx: int) -> tuple:
        """Return sample ``idx``.

        The sample is ``(word, context_words)`` in skip-gram mode and
        ``(context_words, word)`` in CBOW mode.
        """
        return self.pair_word_data[idx]

    def __len__(self) -> int:
        """Return the number of generated samples."""
        return len(self.pair_word_data)

In [25]:
# Disabled smoke test for Word2VecDataset, wrapped in a string literal so it
# is not executed as code.
# NOTE(review): `train_set` is not defined in this notebook; presumably it is
# created by utils.ipynb, which the snippet loads via %run — confirm.
"""
%run utils.ipynb
dataset = Word2VecDataset(train_set, 2)
len(dataset)
print(dataset[1])
"""

Reading data...
Makeing dictionary...
Generating pair word data...
('paper', ['This', 'proposes', 'three'])
