In [22]:
import torch
import torch.nn as nn
import os
from tqdm import tqdm
import string
import pickle
import torchvision.datasets as dset
import nltk

In [23]:
def get_word_list_from_file(filename):
    '''
    Function for loading words from a file, called from the get_word_list_from_dir function.
    Args: filename - file to load words from.
    '''
    with open(filename, 'r', encoding='utf-8') as f:
        # return the split results, which is all the words in the file.
        tokens = nltk.word_tokenize(f.read().lower())  
        return tokens

def get_word_list_from_dir(directory_path):
    '''
    Function for loading words from all *.txt files in a directory.
    Args: directory_path - directory where the *.txt files are stored.
    '''
    directory = os.fsencode(directory_path)
    text_all = []
    for path, directories, files in tqdm(os.walk(directory), position=1):
        for file in tqdm(files, position=0):
            filename = os.fsdecode(file)
            if filename.endswith(".txt"):
                text_all.extend(get_word_list_from_file(os.path.join(path,filename)))
            else:
                continue
    return text_all

In [24]:
DATA_DIR = "example_corpus_data"
CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
EMDEDDING_DIM = 100

raw_text = []

path_name = os.path.join(DATA_DIR, 'data.pickle')
if os.path.exists(path_name):
    with open(path_name, 'rb') as f:
        raw_text = pickle.load(f)
else:
    raw_text = get_word_list_from_dir(DATA_DIR)
    with open(path_name, 'wb') as f:
        pickle.dump(raw_text, f)

raw_text = raw_text
print(len(raw_text))
print(raw_text[:20])

23987393
['ciastko', 'sakiewka', 'z', 'orzech', 'składnik:', '100gram', 'masło', '150gram', 'margaryna', '350gram', 'mąka', '5', 'żółtek,', '1', 'łyżka', 'śmietana', '18%,', '150gram', 'orzech', 'włoski']


In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
from gensim.models import Word2Vec

In [27]:
model = Word2Vec(sentences=[raw_text], epochs=50, vector_size=100, window=5, min_count=1, workers=4, sg=0)

In [30]:
model.wv['żółtek']

array([-0.5241254 ,  0.42368373,  0.06257143, -0.17591274, -0.16976461,
       -1.0694007 ,  0.60494655,  0.47909427, -0.43066508, -0.3187518 ,
       -0.08316223, -0.6879639 , -0.01231351, -0.10204956,  0.36072493,
       -0.5379919 ,  0.32504278, -0.6980526 ,  0.36025333, -1.1260055 ,
        0.5076809 , -0.22917123,  0.78024805, -0.07851519, -0.22857414,
        0.41238773, -0.04523865,  0.02208736, -0.36956808, -0.06972183,
        0.5230254 ,  0.27348298,  0.20937549, -0.30427164, -0.34388313,
        0.14014593, -0.06998207,  0.29179287,  0.05764451, -0.29479212,
        0.32900602, -0.14803606,  0.30954534,  0.1886388 ,  0.03757096,
       -0.22584778, -0.46406853,  0.15664963,  0.48777917,  0.20064838,
        0.06301374, -0.5083499 ,  0.07579345,  0.00242908, -0.26891017,
        0.21862704,  0.2888668 ,  0.3996825 , -0.34850234,  0.14324716,
        0.07275969, -0.18616554,  0.30536485,  0.34978428,  0.25909764,
        0.8929001 , -0.13722323,  0.26157075, -0.99892277, -0.23