In [30]:
import torch
import torch.utils.data as data

In [70]:
class Word2VecDataset(data.Dataset):
    """Make pair word dataset
    
    Skip Gram is default mode
    if skip_gram = False, it will use CBOW mode
    """
    def __init__(self, data: list, window_size: int, skip_gram: bool=True):
        self.data = data
        self.window_size = window_size
        self.skip_gram = skip_gram
        self.pair_word_data = []
        self.word2idx = {}
        self.idx2word = {}
        self.word_prob = []
        self.make_dict()
        self.calculate_sampling_prob()
        self.generate_pair_word()
        
    def generate_pair_word(self):
        print("Generating pair word data...")
        for sentence in self.data:
            words = sentence.split(" ")
            word_len = len(words)
            for idx in range(word_len):
                # get left and right side output word of skip word
                lower_idx = (idx - self.window_size) if idx >= self.window_size else 0
                upper_idx = (idx + self.window_size) if idx <= (word_len - 1) - self.window_size else (word_len - 1)
                BOW = []
                for offset in range((upper_idx - lower_idx) + 1):
                    nearest_idx = lower_idx + offset
                    if nearest_idx != idx:
                        BOW.append(words[nearest_idx])
                if self.skip_gram: # use skip gram mode
                    self.pair_word_data.append((words[idx], BOW))
                else: # use CBOW mode
                    self.pair_word_data.append((BOW, words[idx]))
    
    def make_dict(self):
        print("Makeing dictionary...")
        for sentence in self.data:
            words = sentence.split(" ")
            for word in words:
                if word not in self.word2idx:
                    word_idx = len(self.word2idx)
                    self.word2idx[word] = word_idx
                    self.idx2word[word_idx] = word
                    self.word_prob.append(0)
                self.word_prob[int(self.word2idx[word])] += 1
                
    def calculate_sampling_prob(self):
        self.word_prob = [prob * 0.75 for prob in self.word_prob]
        prob_sum = 0
        for prob in self.word_prob:
            prob_sum += prob
        self.word_prob = [prob / prob_sum for prob in self.word_prob]
        
    def __getitem__(self, idx: int) -> list:
        # pair_word_data: list[(list of words, word)]
        return self.pair_word_data[idx]
    
    def __len__(self) -> int:
        return len(self.pair_word_data)

In [68]:
"""
%run utils.ipynb
dataset = Word2VecDataset(train_set, 2)
print(dataset.word_prob)
print(len(dataset))
print(dataset[1])
"""

Reading data...
Makeing dictionary...
[0.002234815045974319, 0.0028192699197251993, 0.00016709610620292792, 0.0005119702614710714, 0.0002487366695075548, 0.020156063185218023, 7.095862044233925e-05, 0.04432166992338758, 0.00019761594295232115, 0.03160176496215922, 0.0003822609552861502, 6.638064492993027e-05, 2.136388572457526e-05, 0.013580564357561251, 0.00043872265327252766, 2.288987756204492e-06, 0.1537177357638314, 0.0005035773063649882, 0.023098938443778264, 2.288987756204492e-06, 0.0008179316248837385, 0.0002487366695075548, 0.020632172638508556, 2.7467853074453904e-05, 0.011209173042133397, 0.00010910841637908079, 0.0035448790384420233, 0.0010300444902920215, 5.340971431143815e-06, 0.0002617076001260469, 0.00024797367358882, 0.00010910841637908079, 0.000500525322690049, 0.0003914169063109681, 2.5178865318249412e-05, 2.365287348077975e-05, 0.00018083003274015487, 0.00035021512669928726, 0.0005691949553761837, 0.00010834542046034596, 0.00037920897161121083, 0.0035616649486541895, 