#### Relevant import statements

In [14]:
import datetime
import math
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

#### Defining the data set:

In [15]:
# from utils import generate_training_set()

def read_in_data(text_file):

    data_output = []
    with open(text_file, "r") as file1:
        data_list = file1.readlines()
    for line in data_list:
        token_list = line.split()
        data_output.append(token_list)
    
    return data_output

#### Creating the vocabulary:

In [16]:
training_data = read_in_data("../data/english-french_small/dev.en")

In [17]:
def create_vocabulary(training_set):

    vocabulary = []
    for sentence in training_set:
        for word in sentence:
            if word not in vocabulary:
                vocabulary.append(word)
                
    w2i = dict()
    i2w = dict()
    for idx, word in enumerate(vocabulary):
        i2w[idx] = word
        w2i[word] = idx
    
    return w2i, i2w

In [18]:
w2i, i2w = create_vocabulary(training_data)

In [19]:
print(len(w2i.keys()))

322


#### Generating the skipgram data for a single sentence:

In [20]:
def generate_skipgram(sentence, context_window_size):

    skipgram_array = []
    for idx, word in enumerate(sentence):
        window_size = random.randint(1, context_window_size)
        for index in range(max(idx - window_size, 0), min(len(sentence), idx + window_size + 1)):
            if index != idx:
                skipgram_array.append([w2i[word], w2i[sentence[index]]])

    return skipgram_array

#### Generating the data for the whole training corpus:

In [21]:
def generate_corpus_skipgrams(training_set, window_size):

    skipgrams = []
    for sentence in training_set:
        for skipgram in generate_skipgram(sentence, window_size):
            skipgrams.append(skipgram)
    random.shuffle(skipgrams)

    return np.array(skipgrams)

#### Defining the network which will generate the embeddings:

In [54]:
class Skipgram_Net(nn.Module):

    def __init__(self, embedding_dimension, vocabulary_size):
        super(Skipgram_Net, self).__init__()
        self.fc1 = nn.Linear(vocabulary_size, embedding_dimension)
        self.fc2 = nn.Linear(embedding_dimension, vocabulary_size)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim = 0)

In [None]:
y = np.array([[1, 2, 3],
             [4, 5, 6],
             [7, 8, 9]])
y = torch.tensor(y)

In [55]:
def make_batches(skipgram_training_data, batch_size):

    new_data = []
    num_samples = skipgram_training_data.shape[0]
    for idx in range(num_samples // batch_size):
        batch = skipgram_training_data[(idx)*batch_size : (idx+1)*batch_size]
        new_data.append(batch)
    
    return np.array(new_data)

In [56]:
BATCH_SIZE = 10
skipgram_training_data = generate_corpus_skipgrams(training_data, 5)
skipgram_training_data = make_batches(skipgram_training_data, BATCH_SIZE)

#### Training the network:

In [None]:
EMBEDDING_DIMENSION = 20
epochs = 1000

vocab_size = len(w2i.keys())
model = Skipgram_Net(EMBEDDING_DIMENSION, vocab_size)
optimizer = optim.Adam(model.parameters(), lr = 0.001)
loss_fn = torch.nn.NLLLoss()

print(len(skipgram_training_data))

for epoch in range(epochs):
    print("EPOCH NUMBER:", epoch)
    i = 0
    
    for data_point in skipgram_training_data:
        x_values = data_point[:, 0]
        y_values = data_point[:, 1]
        input_to_network = torch.zeros(BATCH_SIZE, vocab_size)

        for idx in range(BATCH_SIZE):
            input_to_network[idx, x_values[idx]] = 1.0
        target = torch.tensor(y_values, dtype=torch.long)
        output_of_network = model(input_to_network)
        
        loss = loss_fn(output_of_network, target)
        if loss.item() < 0.01:
            break
        
        if i == 0:
            print("LOSS at step {} was {}".format(i, loss.item()))
        i +=1
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

351
EPOCH NUMBER: 0
LOSS at step 0 was 2.298279285430908
EPOCH NUMBER: 1
LOSS at step 0 was 2.287540912628174
EPOCH NUMBER: 2
LOSS at step 0 was 2.2682347297668457
EPOCH NUMBER: 3
LOSS at step 0 was 2.2258684635162354
EPOCH NUMBER: 4
LOSS at step 0 was 2.1514077186584473
EPOCH NUMBER: 5
LOSS at step 0 was 2.0500528812408447
EPOCH NUMBER: 6
LOSS at step 0 was 1.9372835159301758
EPOCH NUMBER: 7
LOSS at step 0 was 1.8262275457382202
EPOCH NUMBER: 8
LOSS at step 0 was 1.7230656147003174
EPOCH NUMBER: 9
LOSS at step 0 was 1.6293684244155884
EPOCH NUMBER: 10
LOSS at step 0 was 1.5448157787322998
EPOCH NUMBER: 11
LOSS at step 0 was 1.4684841632843018
EPOCH NUMBER: 12
LOSS at step 0 was 1.3993176221847534
EPOCH NUMBER: 13
LOSS at step 0 was 1.3363301753997803
EPOCH NUMBER: 14
LOSS at step 0 was 1.2786991596221924
EPOCH NUMBER: 15
LOSS at step 0 was 1.2257839441299438
EPOCH NUMBER: 16
LOSS at step 0 was 1.1770950555801392
EPOCH NUMBER: 17
LOSS at step 0 was 1.1322457790374756
EPOCH NUMBER: 18
L

EPOCH NUMBER: 148
LOSS at step 0 was 0.45528125762939453
EPOCH NUMBER: 149
LOSS at step 0 was 0.455468088388443
EPOCH NUMBER: 150
LOSS at step 0 was 0.4551655650138855
EPOCH NUMBER: 151
LOSS at step 0 was 0.4553641378879547
EPOCH NUMBER: 152
LOSS at step 0 was 0.4550662934780121
EPOCH NUMBER: 153
LOSS at step 0 was 0.4552761912345886
EPOCH NUMBER: 154
LOSS at step 0 was 0.4549824297428131
EPOCH NUMBER: 155
LOSS at step 0 was 0.4552031457424164
EPOCH NUMBER: 156
LOSS at step 0 was 0.4549132287502289
EPOCH NUMBER: 157
LOSS at step 0 was 0.4551447927951813
EPOCH NUMBER: 158
LOSS at step 0 was 0.4548572897911072
EPOCH NUMBER: 159
LOSS at step 0 was 0.4550991654396057
EPOCH NUMBER: 160
LOSS at step 0 was 0.4548136591911316
EPOCH NUMBER: 161
LOSS at step 0 was 0.45506566762924194
EPOCH NUMBER: 162
LOSS at step 0 was 0.454781711101532
EPOCH NUMBER: 163
LOSS at step 0 was 0.45504388213157654
EPOCH NUMBER: 164
LOSS at step 0 was 0.45476093888282776
EPOCH NUMBER: 165
LOSS at step 0 was 0.4550325

EPOCH NUMBER: 295
LOSS at step 0 was 0.45587649941444397
EPOCH NUMBER: 296
LOSS at step 0 was 0.4552501142024994
EPOCH NUMBER: 297
LOSS at step 0 was 0.455859512090683
EPOCH NUMBER: 298
LOSS at step 0 was 0.45522984862327576
EPOCH NUMBER: 299
LOSS at step 0 was 0.45584186911582947
EPOCH NUMBER: 300
LOSS at step 0 was 0.45520901679992676
EPOCH NUMBER: 301
LOSS at step 0 was 0.455824077129364
EPOCH NUMBER: 302
LOSS at step 0 was 0.4551883339881897
EPOCH NUMBER: 303
LOSS at step 0 was 0.455806165933609
EPOCH NUMBER: 304
LOSS at step 0 was 0.4551669955253601
EPOCH NUMBER: 305
LOSS at step 0 was 0.4557882249355316
EPOCH NUMBER: 306
LOSS at step 0 was 0.45514655113220215
EPOCH NUMBER: 307
LOSS at step 0 was 0.455770343542099
EPOCH NUMBER: 308
LOSS at step 0 was 0.45512619614601135
EPOCH NUMBER: 309
LOSS at step 0 was 0.4557517468929291
EPOCH NUMBER: 310
LOSS at step 0 was 0.4551040530204773
EPOCH NUMBER: 311
LOSS at step 0 was 0.45573315024375916
EPOCH NUMBER: 312
LOSS at step 0 was 0.455082

In [26]:
# Save trained model

# st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
# create directory if it does not exist
if not os.path.exists('../models'):
    os.makedirs('../models')

with open('../models/skipgram_{}-{}.model'.format(str(epochs), str(EMBEDDING_DIMENSION)), 'wb') as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

#### Getting the word embeddings from the trained model:

In [42]:
def create_embeddings(trained_model):

    embeddings = dict()
    params = list(trained_model.parameters())
    learned_weights = trained_model.fc1.weight.data
    
    for word in w2i.keys():
        word_idx = w2i[word]
        embeddings[word] = learned_weights[:, word_idx].numpy()
    
    return embeddings

In [43]:
embeddings_filepath = '../models/embeddings.pickle'

# check if embeddings file exists
if os.path.exists(embeddings_filepath):
    with open(embeddings_filepath, 'rb') as file:
        embeddings_dict = pickle.load(file)
else:
    embeddings_dict = create_embeddings(model)
    with open(embeddings_filepath, 'wb') as file:
        pickle.dump(embeddings_dict, file, protocol=pickle.HIGHEST_PROTOCOL)

In [47]:
print(embeddings_dict["the"])
print(embeddings_dict["as"])
print(embeddings_dict["big"])

[-3.03332098e-02  4.79692519e-02  4.85931449e-02 -3.48969921e-02
  3.10392724e-03  4.31550257e-02  3.60021368e-02 -4.11281399e-02
  4.39530499e-02  1.57726705e-02 -4.27361578e-02  2.90390607e-02
 -5.16125597e-02  1.16651533e-02  4.88539506e-03 -5.40946051e-03
  4.61067408e-02  2.15863455e-02  8.72466262e-05 -4.32019755e-02
  4.78134304e-02 -3.20397392e-02 -4.47081923e-02  4.25932184e-02
 -2.26148162e-02  1.47703011e-02  5.23346290e-02 -3.06650903e-02
  5.37822954e-02 -1.88278444e-02 -5.73800579e-02 -7.77467713e-03
  2.38426141e-02 -5.25103211e-02  5.50821330e-03 -2.64428928e-02
 -7.21240183e-03  4.44953404e-02 -4.04703878e-02 -3.83299552e-02
  3.72918099e-02  4.39799950e-02 -3.07066049e-02  3.52076325e-03
  3.60439047e-02  4.35653999e-02 -3.20944302e-02  4.92330752e-02
  1.48759745e-02  3.86610553e-02  2.55891364e-02 -9.20342980e-04
 -2.46214960e-02  5.10316230e-02  5.32553671e-03 -3.01607940e-02
 -2.99250670e-02  3.55398245e-02  1.68395299e-03 -1.42269256e-02
  4.00430784e-02  4.90308

KeyError: 'big'