# Make embeddings for Bengali language
This notebook handles the embedding process.

### Input:
    - Pre-processed training dataframe.

### Output:
    - The trained weights of the embedding layer.

## Import libraries

In [1]:
# Imports
import re
import string
import json
from datetime import datetime
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import Module
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

device = 'cuda'

import random

random.seed(26)
np.random.seed(62)
torch.manual_seed(2021)

<torch._C.Generator at 0x7f35c4ef9f90>

## Load data

In [2]:
ben_train_df = pd.read_csv('save/bengali_hatespeech_embed_train_preprocessed.csv')
# remove empty texts
ben_train_df = ben_train_df[ben_train_df.sentence.str.len() > 0]

# load mapping {word -> id} and {id -> word}
with open('save/word_to_int_dict.json') as f:
    word_to_int = json.load(f)
with open('save/int_to_word_dict.json') as f:
    int_to_word = json.load(f)
    int_to_word = {int(k) : v for k, v in int_to_word.items()}
with open('save/word_counter.json') as f:
    word_counter = json.load(f)

# get vocab_size
vocab_size = len(word_to_int)
print(f'vocab_size: {vocab_size}')

# get total occurences
total_words = sum(word_counter.values())
print(f'total word occurences: {total_words}')

# extract sentences and labels
train_sentences = [[word_to_int[w] for w in text.split()] for text in ben_train_df['sentence']]
train_labels = ben_train_df['hate'].to_numpy()

vocab_size: 55189
total word occurences: 303627


### Constants and Hyper-parameters

In [3]:
model_save_path = 'save/word2vec_neg.pt'

window_size = 5
embedding_size = 300
neg_sample_factor = 10
noise_dist_alpha = 3/4
learning_rate = 0.01
lr_decay = lambda epoch: max(0.05, 0.9**epoch)
batch_size = 256
epochs = 100

## skip-gram

In [4]:
# sampling probability of pair (center, context)
def sampling_prob(word):
    z = word_counter[word] / total_words
    p_keep = ((z/0.000001)**0.5 + 1) * (0.000001/z)
    return p_keep

In [5]:
# noise distribution
noisy_words = [iw for iw in int_to_word]
noisy_dist = np.array([(word_counter[int_to_word[iw]]/total_words)**noise_dist_alpha for iw in noisy_words])
noisy_dist = noisy_dist / noisy_dist.sum()

# noisy word generator
def get_noise_word(batch_size, neg_factor):
    noise_list = np.random.choice(noisy_words, batch_size*neg_factor, p=noisy_dist)
    noise_list = noise_list.reshape((batch_size, neg_factor))
    return torch.from_numpy(noise_list)

In [6]:
def get_target_context(sentence: list(str())):
    for i, word in enumerate(sentence):
        for j, context_word in enumerate(sentence[i-window_size:i+window_size+1]):
            if j != i and random.random() < sampling_prob(int_to_word[context_word]):
                    yield (torch.tensor(word, dtype=torch.long), 
                           torch.tensor(context_word, dtype=torch.long)
                          )

## Train word-embedding

### Model

In [7]:
class Word2Vec(Module):
    def __init__(self):
        super(Word2Vec, self).__init__()
        self.center_embed = nn.Embedding(vocab_size, embedding_size)        
        self.context_embed = nn.Embedding(vocab_size, embedding_size)
        
        init_range = (2 / (vocab_size + embedding_size)) ** 0.5
        self.center_embed.weight.data.uniform_(-init_range, init_range)
        self.context_embed.weight.data.uniform_(-init_range, init_range)
        
        self.log_sigmoid = nn.LogSigmoid()

    def forward(self, center_ids, context_ids, negative_samples):
        # center_ids, context_ids: [batch_size]
        # negatve_samples: [batch_size, neg_sample_factor]
        
        # center_embed, context_embed: [batch_size, embedding_size]
        center_embed = self.center_embed(center_ids)
        context_embed = self.context_embed(context_ids)
        
        # pos_dot: [batch_size]
        pos_dot = (center_embed * context_embed).sum(axis=1)
        
        # pos_loss: [batch_size]
        pos_loss = self.log_sigmoid(pos_dot)
        
        # negative_embed: [batch_size, neg_sample_factor, embedding_size]
        negative_embed = self.context_embed(negative_samples)
        
        # negs_dot: [batch_size, neg_sample_factor]
        negs_dot = torch.bmm(negative_embed, center_embed.unsqueeze(2)).squeeze(2) * (-1)
        
        # neg_dot: [batch_size]
        neg_dot = negs_dot.sum(axis=1)
        
        # neg_loss: [batch_size]
        neg_loss = self.log_sigmoid(neg_dot)
        
        loss = -(pos_loss + neg_loss).sum()
        return loss, -pos_loss.sum(), -neg_loss.sum()
    
    def to_embed(self, center_id):
        return self.center_embed(center_id)
    
word2vec = Word2Vec()
torch.save(word2vec.state_dict(), model_save_path)

display(word2vec.parameters)

<bound method Module.parameters of Word2Vec(
  (center_embed): Embedding(55189, 300)
  (context_embed): Embedding(55189, 300)
  (log_sigmoid): LogSigmoid()
)>

### Optimizer and Learning-rate scheduler

In [8]:
optimizer = optim.Adam(word2vec.parameters(), lr=learning_rate)
scheduler = LambdaLR(optimizer, lr_lambda=lr_decay)

### Dataset

In [9]:
class W2VDataset(Dataset):
    def __init__(self, sentences):
        self.data = []
        for sentence in sentences:
            for data_point in get_target_context(sentence):
                self.data.append(data_point)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

### Learning parameters

In [10]:
# load initial weights
word2vec.load_state_dict(torch.load(model_save_path, map_location=torch.device(device)))
word2vec = word2vec.to(device)

early_stop = 5
history_losses = []
for epoch in range(1, epochs+1):
    train_dataset = W2VDataset(train_sentences)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    losses, pos_losses, neg_losses = 0., 0., 0.
    cnt = 0
    
    word2vec.train()
    for center_words, context_words in tqdm(train_loader):
        negative_samples = get_noise_word(len(center_words), neg_sample_factor)
        optimizer.zero_grad()
        loss, pos_loss, neg_loss = word2vec(center_words.to(device), context_words.to(device), negative_samples.to(device))
        loss.backward()
        optimizer.step()
        losses += loss
        cnt += len(center_words)
        pos_losses += pos_loss
        neg_losses += neg_loss

    scheduler.step()
    
    epoch_loss = losses / cnt
    print(f'Epoch {epoch:2}: training loss: {epoch_loss:.4f} (pos: {pos_losses/cnt:.4f}, neg: {neg_losses/(cnt*neg_sample_factor):.4f}) over {cnt} training points.')
    
    if epoch % 10 == 0:
        # save embedding
        embedding_weights = word2vec.center_embed.state_dict()
        torch.save(embedding_weights, f'save/embedding_checkpoints/{epoch}_epoch_{embedding_size}_dim_{window_size}_wsize_{neg_sample_factor}_negfac.pt')
    
    history_losses.append(epoch_loss)
    if len(history_losses) > early_stop and min(history_losses[-early_stop:]) >= min(history_losses[:-early_stop]):
        print(f'Early stopping: training loss does not decrease after {early_stop} epochs')
        break

print("Training finished")

100%|██████████| 1655/1655 [02:31<00:00, 10.95it/s]


Epoch  1: training loss: 1.4110 (pos: 1.0727, neg: 0.0338) over 423521 training points.


100%|██████████| 1657/1657 [02:28<00:00, 11.12it/s]


Epoch  2: training loss: 1.3803 (pos: 0.9942, neg: 0.0386) over 424083 training points.


100%|██████████| 1657/1657 [02:29<00:00, 11.12it/s]


Epoch  3: training loss: 1.2126 (pos: 0.8699, neg: 0.0343) over 424038 training points.


100%|██████████| 1659/1659 [02:29<00:00, 11.11it/s]


Epoch  4: training loss: 1.0393 (pos: 0.7624, neg: 0.0277) over 424571 training points.


100%|██████████| 1657/1657 [02:29<00:00, 11.07it/s]


Epoch  5: training loss: 0.9001 (pos: 0.6663, neg: 0.0234) over 424019 training points.


100%|██████████| 1655/1655 [02:29<00:00, 11.11it/s]


Epoch  6: training loss: 0.7711 (pos: 0.5760, neg: 0.0195) over 423453 training points.


100%|██████████| 1656/1656 [02:29<00:00, 11.11it/s]


Epoch  7: training loss: 0.6679 (pos: 0.5023, neg: 0.0166) over 423691 training points.


100%|██████████| 1653/1653 [02:29<00:00, 11.09it/s]


Epoch  8: training loss: 0.5844 (pos: 0.4399, neg: 0.0144) over 423090 training points.


100%|██████████| 1659/1659 [02:29<00:00, 11.11it/s]


Epoch  9: training loss: 0.5071 (pos: 0.3832, neg: 0.0124) over 424610 training points.


100%|██████████| 1657/1657 [02:30<00:00, 11.03it/s]


Epoch 10: training loss: 0.4510 (pos: 0.3418, neg: 0.0109) over 423982 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.95it/s]


Epoch 11: training loss: 0.3958 (pos: 0.2992, neg: 0.0097) over 424160 training points.


100%|██████████| 1660/1660 [02:31<00:00, 10.95it/s]


Epoch 12: training loss: 0.3537 (pos: 0.2668, neg: 0.0087) over 424767 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.95it/s]


Epoch 13: training loss: 0.3115 (pos: 0.2365, neg: 0.0075) over 423873 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.95it/s]


Epoch 14: training loss: 0.2826 (pos: 0.2154, neg: 0.0067) over 423825 training points.


100%|██████████| 1655/1655 [02:34<00:00, 10.72it/s]


Epoch 15: training loss: 0.2561 (pos: 0.1930, neg: 0.0063) over 423580 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.96it/s]


Epoch 16: training loss: 0.2315 (pos: 0.1751, neg: 0.0056) over 424074 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.95it/s]


Epoch 17: training loss: 0.2115 (pos: 0.1600, neg: 0.0052) over 424032 training points.


100%|██████████| 1659/1659 [02:31<00:00, 10.95it/s]


Epoch 18: training loss: 0.1953 (pos: 0.1462, neg: 0.0049) over 424482 training points.


100%|██████████| 1655/1655 [02:32<00:00, 10.85it/s]


Epoch 19: training loss: 0.1780 (pos: 0.1343, neg: 0.0044) over 423645 training points.


100%|██████████| 1658/1658 [02:27<00:00, 11.26it/s]


Epoch 20: training loss: 0.1664 (pos: 0.1234, neg: 0.0043) over 424431 training points.


100%|██████████| 1657/1657 [02:35<00:00, 10.68it/s]


Epoch 21: training loss: 0.1573 (pos: 0.1165, neg: 0.0041) over 424185 training points.


100%|██████████| 1657/1657 [02:29<00:00, 11.06it/s]


Epoch 22: training loss: 0.1461 (pos: 0.1084, neg: 0.0038) over 424027 training points.


100%|██████████| 1658/1658 [02:32<00:00, 10.85it/s]


Epoch 23: training loss: 0.1353 (pos: 0.0996, neg: 0.0036) over 424224 training points.


100%|██████████| 1656/1656 [02:28<00:00, 11.13it/s]


Epoch 24: training loss: 0.1294 (pos: 0.0943, neg: 0.0035) over 423735 training points.


100%|██████████| 1658/1658 [02:28<00:00, 11.18it/s]


Epoch 25: training loss: 0.1245 (pos: 0.0912, neg: 0.0033) over 424245 training points.


100%|██████████| 1655/1655 [02:28<00:00, 11.17it/s]


Epoch 26: training loss: 0.1184 (pos: 0.0872, neg: 0.0031) over 423512 training points.


100%|██████████| 1658/1658 [02:30<00:00, 11.03it/s]


Epoch 27: training loss: 0.1132 (pos: 0.0819, neg: 0.0031) over 424207 training points.


100%|██████████| 1656/1656 [02:29<00:00, 11.10it/s]


Epoch 28: training loss: 0.1125 (pos: 0.0819, neg: 0.0031) over 423903 training points.


100%|██████████| 1660/1660 [02:35<00:00, 10.71it/s]


Epoch 29: training loss: 0.1064 (pos: 0.0765, neg: 0.0030) over 424720 training points.


100%|██████████| 1657/1657 [02:32<00:00, 10.86it/s]


Epoch 30: training loss: 0.1042 (pos: 0.0754, neg: 0.0029) over 424109 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.94it/s]


Epoch 31: training loss: 0.1020 (pos: 0.0722, neg: 0.0030) over 424095 training points.


100%|██████████| 1657/1657 [02:30<00:00, 11.02it/s]


Epoch 32: training loss: 0.0990 (pos: 0.0690, neg: 0.0030) over 424041 training points.


100%|██████████| 1658/1658 [02:31<00:00, 10.97it/s]


Epoch 33: training loss: 0.0947 (pos: 0.0666, neg: 0.0028) over 424373 training points.


100%|██████████| 1655/1655 [02:32<00:00, 10.87it/s]


Epoch 34: training loss: 0.0947 (pos: 0.0662, neg: 0.0028) over 423572 training points.


100%|██████████| 1656/1656 [02:34<00:00, 10.72it/s]


Epoch 35: training loss: 0.0926 (pos: 0.0651, neg: 0.0028) over 423692 training points.


100%|██████████| 1657/1657 [02:37<00:00, 10.54it/s]


Epoch 36: training loss: 0.0879 (pos: 0.0623, neg: 0.0026) over 424129 training points.


100%|██████████| 1654/1654 [02:36<00:00, 10.58it/s]


Epoch 37: training loss: 0.0855 (pos: 0.0595, neg: 0.0026) over 423276 training points.


100%|██████████| 1656/1656 [02:36<00:00, 10.57it/s]


Epoch 38: training loss: 0.0845 (pos: 0.0590, neg: 0.0025) over 423686 training points.


100%|██████████| 1655/1655 [02:34<00:00, 10.73it/s]


Epoch 39: training loss: 0.0834 (pos: 0.0573, neg: 0.0026) over 423433 training points.


100%|██████████| 1658/1658 [02:38<00:00, 10.46it/s]


Epoch 40: training loss: 0.0801 (pos: 0.0552, neg: 0.0025) over 424295 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.92it/s]


Epoch 41: training loss: 0.0793 (pos: 0.0549, neg: 0.0024) over 424066 training points.


100%|██████████| 1658/1658 [02:40<00:00, 10.35it/s]


Epoch 42: training loss: 0.0768 (pos: 0.0523, neg: 0.0024) over 424283 training points.


100%|██████████| 1658/1658 [02:36<00:00, 10.56it/s]


Epoch 43: training loss: 0.0768 (pos: 0.0513, neg: 0.0026) over 424205 training points.


100%|██████████| 1655/1655 [02:36<00:00, 10.57it/s]


Epoch 44: training loss: 0.0739 (pos: 0.0496, neg: 0.0024) over 423547 training points.


100%|██████████| 1657/1657 [02:32<00:00, 10.87it/s]


Epoch 45: training loss: 0.0736 (pos: 0.0492, neg: 0.0024) over 424071 training points.


100%|██████████| 1659/1659 [02:31<00:00, 10.99it/s]


Epoch 46: training loss: 0.0732 (pos: 0.0488, neg: 0.0024) over 424454 training points.


100%|██████████| 1656/1656 [02:34<00:00, 10.69it/s]


Epoch 47: training loss: 0.0722 (pos: 0.0467, neg: 0.0026) over 423743 training points.


100%|██████████| 1657/1657 [02:34<00:00, 10.69it/s]


Epoch 48: training loss: 0.0713 (pos: 0.0462, neg: 0.0025) over 424044 training points.


100%|██████████| 1654/1654 [02:34<00:00, 10.67it/s]


Epoch 49: training loss: 0.0680 (pos: 0.0451, neg: 0.0023) over 423230 training points.


100%|██████████| 1656/1656 [02:35<00:00, 10.67it/s]


Epoch 50: training loss: 0.0683 (pos: 0.0443, neg: 0.0024) over 423795 training points.


100%|██████████| 1657/1657 [02:35<00:00, 10.67it/s]


Epoch 51: training loss: 0.0659 (pos: 0.0430, neg: 0.0023) over 424104 training points.


100%|██████████| 1658/1658 [02:35<00:00, 10.68it/s]


Epoch 52: training loss: 0.0656 (pos: 0.0426, neg: 0.0023) over 424401 training points.


100%|██████████| 1655/1655 [02:35<00:00, 10.67it/s]


Epoch 53: training loss: 0.0651 (pos: 0.0416, neg: 0.0024) over 423653 training points.


100%|██████████| 1656/1656 [02:35<00:00, 10.67it/s]


Epoch 54: training loss: 0.0649 (pos: 0.0412, neg: 0.0024) over 423876 training points.


100%|██████████| 1655/1655 [02:35<00:00, 10.67it/s]


Epoch 55: training loss: 0.0636 (pos: 0.0408, neg: 0.0023) over 423609 training points.


100%|██████████| 1659/1659 [02:35<00:00, 10.66it/s]


Epoch 56: training loss: 0.0633 (pos: 0.0398, neg: 0.0024) over 424686 training points.


100%|██████████| 1656/1656 [02:35<00:00, 10.67it/s]


Epoch 57: training loss: 0.0621 (pos: 0.0393, neg: 0.0023) over 423873 training points.


100%|██████████| 1656/1656 [02:35<00:00, 10.67it/s]


Epoch 58: training loss: 0.0590 (pos: 0.0385, neg: 0.0020) over 423920 training points.


100%|██████████| 1659/1659 [02:35<00:00, 10.66it/s]


Epoch 59: training loss: 0.0606 (pos: 0.0371, neg: 0.0023) over 424577 training points.


100%|██████████| 1657/1657 [02:35<00:00, 10.67it/s]


Epoch 60: training loss: 0.0576 (pos: 0.0361, neg: 0.0021) over 424027 training points.


100%|██████████| 1658/1658 [02:32<00:00, 10.89it/s]


Epoch 61: training loss: 0.0589 (pos: 0.0362, neg: 0.0023) over 424211 training points.


100%|██████████| 1656/1656 [02:32<00:00, 10.88it/s]


Epoch 62: training loss: 0.0578 (pos: 0.0356, neg: 0.0022) over 423936 training points.


100%|██████████| 1657/1657 [02:32<00:00, 10.90it/s]


Epoch 63: training loss: 0.0561 (pos: 0.0345, neg: 0.0022) over 424028 training points.


100%|██████████| 1654/1654 [02:31<00:00, 10.89it/s]


Epoch 64: training loss: 0.0571 (pos: 0.0352, neg: 0.0022) over 423243 training points.


100%|██████████| 1656/1656 [02:32<00:00, 10.89it/s]


Epoch 65: training loss: 0.0549 (pos: 0.0337, neg: 0.0021) over 423879 training points.


100%|██████████| 1658/1658 [02:32<00:00, 10.89it/s]


Epoch 66: training loss: 0.0548 (pos: 0.0325, neg: 0.0022) over 424337 training points.


100%|██████████| 1655/1655 [02:31<00:00, 10.90it/s]


Epoch 67: training loss: 0.0539 (pos: 0.0320, neg: 0.0022) over 423598 training points.


100%|██████████| 1652/1652 [02:31<00:00, 10.90it/s]


Epoch 68: training loss: 0.0532 (pos: 0.0322, neg: 0.0021) over 422750 training points.


100%|██████████| 1658/1658 [02:31<00:00, 10.91it/s]


Epoch 69: training loss: 0.0533 (pos: 0.0319, neg: 0.0021) over 424341 training points.


100%|██████████| 1658/1658 [02:32<00:00, 10.90it/s]


Epoch 70: training loss: 0.0512 (pos: 0.0304, neg: 0.0021) over 424267 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.90it/s]


Epoch 71: training loss: 0.0526 (pos: 0.0307, neg: 0.0022) over 423928 training points.


100%|██████████| 1659/1659 [02:32<00:00, 10.91it/s]


Epoch 72: training loss: 0.0523 (pos: 0.0302, neg: 0.0022) over 424491 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.90it/s]


Epoch 73: training loss: 0.0510 (pos: 0.0295, neg: 0.0022) over 424019 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.90it/s]


Epoch 74: training loss: 0.0501 (pos: 0.0288, neg: 0.0021) over 423890 training points.


100%|██████████| 1659/1659 [02:32<00:00, 10.90it/s]


Epoch 75: training loss: 0.0503 (pos: 0.0292, neg: 0.0021) over 424456 training points.


100%|██████████| 1654/1654 [02:31<00:00, 10.90it/s]


Epoch 76: training loss: 0.0500 (pos: 0.0285, neg: 0.0021) over 423347 training points.


100%|██████████| 1657/1657 [02:32<00:00, 10.90it/s]


Epoch 77: training loss: 0.0496 (pos: 0.0281, neg: 0.0021) over 424039 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.90it/s]


Epoch 78: training loss: 0.0493 (pos: 0.0283, neg: 0.0021) over 423699 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.90it/s]


Epoch 79: training loss: 0.0487 (pos: 0.0276, neg: 0.0021) over 423738 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.91it/s]


Epoch 80: training loss: 0.0479 (pos: 0.0269, neg: 0.0021) over 424143 training points.


100%|██████████| 1659/1659 [02:32<00:00, 10.91it/s]


Epoch 81: training loss: 0.0472 (pos: 0.0267, neg: 0.0020) over 424451 training points.


100%|██████████| 1660/1660 [02:32<00:00, 10.90it/s]


Epoch 82: training loss: 0.0482 (pos: 0.0271, neg: 0.0021) over 424882 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.90it/s]


Epoch 83: training loss: 0.0471 (pos: 0.0264, neg: 0.0021) over 423804 training points.


100%|██████████| 1659/1659 [02:32<00:00, 10.90it/s]


Epoch 84: training loss: 0.0457 (pos: 0.0251, neg: 0.0021) over 424479 training points.


100%|██████████| 1658/1658 [02:32<00:00, 10.90it/s]


Epoch 85: training loss: 0.0477 (pos: 0.0259, neg: 0.0022) over 424321 training points.


100%|██████████| 1655/1655 [02:31<00:00, 10.90it/s]


Epoch 86: training loss: 0.0444 (pos: 0.0249, neg: 0.0020) over 423488 training points.


100%|██████████| 1656/1656 [02:31<00:00, 10.90it/s]


Epoch 87: training loss: 0.0459 (pos: 0.0253, neg: 0.0021) over 423805 training points.


100%|██████████| 1657/1657 [02:32<00:00, 10.90it/s]


Epoch 88: training loss: 0.0465 (pos: 0.0252, neg: 0.0021) over 424023 training points.


100%|██████████| 1657/1657 [02:31<00:00, 10.90it/s]


Epoch 89: training loss: 0.0461 (pos: 0.0245, neg: 0.0022) over 424169 training points.


100%|██████████| 1655/1655 [02:33<00:00, 10.77it/s]


Epoch 90: training loss: 0.0445 (pos: 0.0244, neg: 0.0020) over 423632 training points.


100%|██████████| 1656/1656 [02:40<00:00, 10.29it/s]


Epoch 91: training loss: 0.0445 (pos: 0.0242, neg: 0.0020) over 423789 training points.
Early stopping: training loss does not decrease after 5 epochs
Training finished


In [11]:
# save embedding weights
embedding_weights = word2vec.center_embed.state_dict()
torch.save(embedding_weights, f'save/big_embedding_weights_{window_size}_wsize_{neg_sample_factor}_negfac.pt')