In [30]:
import torch
import torch.utils.data as data

In [77]:
class Word2VecDataset(data.Dataset):
    """Dataset of (center word, context window) index pairs for word2vec.

    Each sentence in ``data`` is split on single spaces. For every word
    position one sample is produced that pairs the center word index with
    the indices of the ``window_size`` words on each side of it. Positions
    that fall outside the sentence are filled with the ``<PAD>`` index so
    every context list has exactly ``2 * window_size`` entries.

    Skip-gram is the default mode and yields ``(center, context)``.
    With ``skip_gram=False`` (CBOW mode) it yields ``(context, center)``.

    Args:
        data: List of sentences; each sentence is a space-separated string.
        window_size: Number of context words taken on each side of the
            center word.
        skip_gram: ``True`` for skip-gram sample ordering, ``False`` for CBOW.

    Attributes set by the constructor:
        bag_size: ``window_size * 2 + 1``.
        word2idx: Maps each word to its integer index (``<PAD>`` is 0).
        idx2word: Inverse mapping of ``word2idx``.
        word_prob: Per-index sampling probability (see
            ``calculate_sampling_prob``).
        pair_word_data: The generated samples.
    """

    def __init__(self, data: list, window_size: int, skip_gram: bool = True):
        self.data = data
        self.window_size = window_size
        self.bag_size = window_size * 2 + 1
        self.skip_gram = skip_gram
        self.pair_word_data = []
        self.word2idx = {}
        self.idx2word = {}
        self.word_prob = []
        # Order matters: the vocabulary must exist before probabilities are
        # derived from its counts and before sentences are converted to ids.
        self.make_dict()
        self.calculate_sampling_prob()
        self.generate_pair_word()

    def generate_pair_word(self):
        """Append one (center, context) sample per word to ``pair_word_data``.

        Requires ``make_dict`` to have been called so that every word and
        ``<PAD>`` are present in ``word2idx``.
        """
        print("Generating pair word data...")
        pad_idx = self.word2idx["<PAD>"]
        window = self.window_size
        for sentence in self.data:
            ids = [self.word2idx[word] for word in sentence.split(" ")]
            last = len(ids) - 1
            for pos, center in enumerate(ids):
                # Number of window slots that fall before the first word or
                # after the last word; these are filled with <PAD>.
                left_pad = max(window - pos, 0)
                right_pad = max(pos + window - last, 0)
                context = (
                    [pad_idx] * left_pad
                    + ids[max(pos - window, 0):pos]
                    + ids[pos + 1:pos + window + 1]
                    + [pad_idx] * right_pad
                )
                if self.skip_gram:  # use skip gram mode
                    self.pair_word_data.append((center, context))
                else:  # use CBOW mode
                    self.pair_word_data.append((context, center))

    def make_dict(self):
        """Build ``word2idx`` / ``idx2word`` and count word occurrences.

        After this call ``word_prob[i]`` holds the raw occurrence count of
        the word with index ``i``; ``<PAD>`` is registered first with a
        count of 0.
        """
        print("Makeing dictionary...")
        # Compute the index once so both mappings agree. Taking
        # len(word2idx) again after the insert would give a different key.
        pad_idx = len(self.word2idx)
        self.word2idx["<PAD>"] = pad_idx
        self.idx2word[pad_idx] = "<PAD>"
        self.word_prob.append(0)
        for sentence in self.data:
            for word in sentence.split(" "):
                if word not in self.word2idx:
                    word_idx = len(self.word2idx)
                    self.word2idx[word] = word_idx
                    self.idx2word[word_idx] = word
                    self.word_prob.append(0)
                self.word_prob[self.word2idx[word]] += 1

    def calculate_sampling_prob(self):
        """Turn the raw counts in ``word_prob`` into sampling probabilities.

        Each count is raised to the power 0.75 and the results are
        normalised to sum to 1. Multiplying by 0.75 instead would cancel
        out during normalisation and leave the plain unigram distribution.
        If every count is zero (e.g. empty corpus) the probabilities are
        left at zero instead of dividing by zero.
        """
        weights = [count ** 0.75 for count in self.word_prob]
        total = sum(weights)
        if total == 0:
            self.word_prob = weights
            return
        self.word_prob = [weight / total for weight in weights]

    def __getitem__(self, idx: int) -> tuple:
        # skip-gram: (center index, list of context indices)
        # CBOW:      (list of context indices, center index)
        return self.pair_word_data[idx]

    def __len__(self) -> int:
        return len(self.pair_word_data)

In [81]:
"""
%run utils.ipynb
dataset = Word2VecDataset(train_set, 2, True)
print(len(dataset))
print(dataset[0])
"""

Makeing dictionary...
Generating pair word data...
1310623
(1, [0, 0, 2, 3])
