In [1]:
import datetime
import math
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [2]:
def read_in_data(text_file, encoding="utf-8"):
    """Read a whitespace-tokenised corpus, one sentence per line.

    Parameters
    ----------
    text_file : str or path-like
        Path of the text file to read.
    encoding : str, default "utf-8"
        Encoding used to decode the file. Pinned explicitly because ``open``
        otherwise falls back to the platform's locale encoding, which can
        fail on or mis-decode non-ASCII characters on some systems.

    Returns
    -------
    list of list of str
        One token list per line, in file order, split on runs of whitespace
        with ``str.split()``; blank lines therefore give ``[]``.
    """
    with open(text_file, "r", encoding=encoding) as handle:
        # Iterating the handle yields the same lines as readlines() without
        # materialising an intermediate list.
        return [line.split() for line in handle]

In [33]:
# Load the tokenised sentences from ../data/english-french_small/dev.en
# (presumably the English side of the dev split) and keep only the first 5
# for quick experiments.
training_data = read_in_data("../data/english-french_small/dev.en")[:5]

In [34]:
def create_vocabulary(training_set):
    """Build word <-> index lookup tables from tokenised sentences.

    Indices are assigned consecutively from 0 in order of each word's first
    appearance while scanning sentences left to right.

    Parameters
    ----------
    training_set : iterable of iterable of str
        Sentences, each a sequence of tokens.

    Returns
    -------
    (w2i, i2w) : (dict, dict)
        ``w2i`` maps word -> index and ``i2w`` maps index -> word; both are
        insertion-ordered by index.
    """
    # A dict gives O(1) membership checks (the previous list scan was
    # O(vocabulary) per token) and preserves first-seen order, so the
    # resulting indices are unchanged.
    w2i = {}
    for sentence in training_set:
        for word in sentence:
            if word not in w2i:
                w2i[word] = len(w2i)

    i2w = {idx: word for word, idx in w2i.items()}
    return w2i, i2w

In [35]:
w2i, i2w = create_vocabulary(training_data)
# Vocabulary size: number of distinct tokens in training_data.
V = len(w2i)

In [36]:
print(len(w2i))  # number of distinct tokens in the vocabulary

74


In [37]:
def generate_skipgram(sentence, context_window_size):
    """Collect (center, context) word pairs for every position of a sentence.

    For position ``i`` the context is every other position within
    ``context_window_size`` tokens on either side, clipped to the sentence
    bounds.

    Parameters
    ----------
    sentence : sequence of str
        Tokens of one sentence.
    context_window_size : int
        Maximum distance between a center word and a context word.

    Returns
    -------
    list
        One entry per token (same length as ``sentence``). Each entry is a
        list of ``[center_word, context_word]`` pairs, ordered left to right.
    """
    length = len(sentence)
    return [
        [
            [center, sentence[pos]]
            for pos in range(max(i - context_window_size, 0),
                             min(length, i + context_window_size + 1))
            if pos != i
        ]
        for i, center in enumerate(sentence)
    ]

In [38]:
WINDOW_SIZE = 2

# Skip-gram (center, context) pairs for every sentence of the training data.
context_data = [
    generate_skipgram(tokens, context_window_size=WINDOW_SIZE)
    for tokens in training_data
]

In [39]:
# Sanity check: generate_skipgram emits one context set per token, so the
# first sentence and its context data must have the same length.
assert len(context_data[0]) == len(training_data[0])

print(len(context_data[0]))
print(len(training_data[0]))

print(training_data[0])
print("======")
print(context_data[0])


28
28
['each', 'of', 'them', 'is', 'very', 'complex', ',', 'but', 'the', 'link', 'between', 'the', 'two', 'is', 'even', 'more', 'complex', 'which', 'makes', 'the', 'whole', 'situation', 'for', 'most', 'people', 'understandably', 'confusing', '.']
[[['each', 'of'], ['each', 'them']], [['of', 'each'], ['of', 'them'], ['of', 'is']], [['them', 'each'], ['them', 'of'], ['them', 'is'], ['them', 'very']], [['is', 'of'], ['is', 'them'], ['is', 'very'], ['is', 'complex']], [['very', 'them'], ['very', 'is'], ['very', 'complex'], ['very', ',']], [['complex', 'is'], ['complex', 'very'], ['complex', ','], ['complex', 'but']], [[',', 'very'], [',', 'complex'], [',', 'but'], [',', 'the']], [['but', 'complex'], ['but', ','], ['but', 'the'], ['but', 'link']], [['the', ','], ['the', 'but'], ['the', 'link'], ['the', 'between']], [['link', 'but'], ['link', 'the'], ['link', 'between'], ['link', 'the']], [['between', 'the'], ['between', 'link'], ['between', 'the'], ['between', 'two']], [['the', 'link'], ['t

In [10]:
def make_input_batches(training_set, batch_size):
    """Shuffle sentences and group them into fixed-size batches.

    Parameters
    ----------
    training_set : sequence
        Sentences (e.g. token lists). It is copied before shuffling, so the
        caller's list is left in its original order.
    batch_size : int
        Number of sentences per batch; must be >= 1.

    Returns
    -------
    numpy.ndarray
        ``len(training_set) // batch_size`` batches; trailing sentences that
        do not fill a whole batch are dropped (so fewer sentences than
        ``batch_size`` yields an empty array). When every sentence has the
        same length the result has shape
        ``(num_batches, batch_size, sentence_len)``. For ragged sentences it
        is an object array of shape ``(num_batches, batch_size)`` whose cells
        hold the sentence objects.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

    # random.shuffle works in place; shuffle a copy so the caller's data
    # (e.g. the notebook-global training_data) is not silently reordered.
    sentences = list(training_set)
    random.shuffle(sentences)

    num_batches = len(sentences) // batch_size
    batches = [
        sentences[i * batch_size:(i + 1) * batch_size]
        for i in range(num_batches)
    ]

    try:
        return np.array(batches)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from ragged nested lists
        # (sentences of different lengths). Store each sentence as a single
        # object instead, matching the (num_batches, batch_size) layout older
        # NumPy produced.
        ragged = np.empty((num_batches, batch_size), dtype=object)
        for i, batch in enumerate(batches):
            for j, sentence in enumerate(batch):
                ragged[i, j] = sentence
        return ragged

In [57]:
def get_one_hot_vector(word, vocab_size=None, word_to_index=None):
    """Return a float32 one-hot torch vector for ``word``.

    Parameters
    ----------
    word : str
        Token to encode; must be a key of the vocabulary mapping.
    vocab_size : int, optional
        Length of the returned vector. When omitted, it defaults to the size
        of the vocabulary mapping. The default is resolved at call time, so
        rebuilding the vocabulary in an earlier cell cannot leave a stale
        size baked into this function.
    word_to_index : dict, optional
        Word -> index mapping. Defaults to the notebook-global ``w2i``.

    Returns
    -------
    torch.Tensor
        Shape ``(vocab_size,)``, all zeros except 1.0 at the word's index.

    Raises
    ------
    KeyError
        If ``word`` is not in the vocabulary mapping.
    """
    if word_to_index is None:
        word_to_index = w2i
    if vocab_size is None:
        vocab_size = len(word_to_index)

    one_hot = torch.zeros(vocab_size)
    one_hot[word_to_index[word]] = 1.0
    return one_hot

In [11]:
BATCH_SIZE = 10
# NOTE(review): training_data is sliced to 5 sentences earlier, and
# make_input_batches keeps only len(training_data) // BATCH_SIZE full batches,
# so with these values input_data contains no batches.
input_data = make_input_batches(training_data, batch_size=BATCH_SIZE)

In [46]:
class linearity(nn.Module):
    """A single affine layer from a vocabulary-sized input to an embedding.

    Parameters
    ----------
    embedding_dimension : int
        Output size of the layer.
    vocabulary_size : int
        Input size of the layer.
    """

    def __init__(self, embedding_dimension, vocabulary_size):
        super().__init__()
        self.fc1 = nn.Linear(in_features=vocabulary_size,
                             out_features=embedding_dimension)

    def forward(self, x):
        """Apply ``fc1`` to ``x`` and return the result unchanged."""
        projected = self.fc1(x)
        return projected

In [73]:
# Scratch check: multiply a (2,) standard-normal sample by a (1, 2) tensor of
# ones; broadcasting yields a (1, 2) result holding the same values.
epsilon1 = torch.distributions.multivariate_normal.MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=torch.eye(2),
).sample()

epsilon2 = torch.ones(1, 2)

print(epsilon1)
print(epsilon2)
print(epsilon1 * epsilon2)

tensor([ 0.2424,  0.3701])
tensor([[ 1.,  1.]])
tensor([[ 0.2424,  0.3701]])


In [88]:
class BSG_Net(nn.Module):
    """Encoder/decoder network (presumably a Bayesian Skip-Gram variant).

    For a context set of ``[center_word, context_word]`` pairs:

    * Each word is embedded by the shared layer ``fc1`` applied to its
      one-hot vector.
    * The two embeddings of a pair are concatenated and passed through
      ``fc2`` + ReLU, and the results are summed over all pairs.
    * ``fc3`` maps the sum to ``mu``; ``fc4`` followed by softplus maps it to
      a positive ``sigma``.
    * A latent ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` is decoded by
      ``re1`` and a softmax into a distribution over the vocabulary.

    Parameters
    ----------
    vocabulary_size : int
        Size of the one-hot input and of the output distribution.
    embedding_dimension : int, default 20
        Size of word embeddings and of the latent ``z``.
    word_to_index : dict, optional
        Word -> index mapping used to build one-hot inputs. When None, the
        notebook-global ``w2i`` is looked up at call time.
    """

    def __init__(self, vocabulary_size, embedding_dimension=20, word_to_index=None):

        super(BSG_Net, self).__init__()

        # Layer creation order is kept as before so seeded initialisation
        # and state_dict key order are unchanged.
        self.fc1 = nn.Linear(vocabulary_size, embedding_dimension)
        self.fc2 = nn.Linear(embedding_dimension * 2, embedding_dimension * 2)
        self.fc3 = nn.Linear(embedding_dimension * 2, embedding_dimension)
        self.fc4 = nn.Linear(embedding_dimension * 2, embedding_dimension)

        self.embedding_dimension = embedding_dimension

        self.re1 = nn.Linear(embedding_dimension, vocabulary_size)

        self.word_to_index = word_to_index

    def _embed(self, word):
        """Embed one word: ``fc1`` applied to its one-hot vector."""
        mapping = w2i if self.word_to_index is None else self.word_to_index
        # Size the one-hot by this model's own input width (not a global),
        # and match the parameters' device/dtype so .to()/.double() work.
        one_hot = self.fc1.weight.new_zeros(self.fc1.in_features)
        one_hot[mapping[word]] = 1.0
        return self.fc1(one_hot)

    def forward(self, x):
        """Encode a context set and decode one sampled latent.

        Parameters
        ----------
        x : iterable of pairs
            Each pair is indexable as ``pair[0]`` (center word) and
            ``pair[1]`` (context word), as produced by ``generate_skipgram``
            for one position.

        Returns
        -------
        (output, mu, sigma)
            ``output`` has shape ``(vocabulary_size,)`` and sums to 1.
            ``mu`` and ``sigma`` have shape ``(embedding_dimension,)``, with
            ``sigma > 0``.
        """
        context_representation = self.fc1.weight.new_zeros(self.embedding_dimension * 2)

        for pair in x:
            concatenated = torch.cat([self._embed(pair[0]), self._embed(pair[1])], dim=0)
            context_representation = context_representation + F.relu(self.fc2(concatenated))

        mu = self.fc3(context_representation)
        sigma = F.softplus(self.fc4(context_representation))

        # eps ~ N(0, I). randn_like draws from the same distribution as a
        # MultivariateNormal with identity covariance, but skips building
        # (and Cholesky-factorising) a distribution on every call and stays
        # on mu's device/dtype.
        epsilon = torch.randn_like(mu)

        z = mu + epsilon * sigma

        output = F.softmax(self.re1(z), dim=0)

        return output, mu, sigma


In [89]:
EMBEDDING_DIMENSION = 20

model = BSG_Net(V, EMBEDDING_DIMENSION)
# NOTE(review): X_data is not populated by this cell; it is kept in case a
# later cell relies on it.
X_data = []

# Smoke test: run the untrained model once, on the first context set of the
# first sentence, and print its (output, mu, sigma) tuple. The guard
# reproduces the old double-`break` loop, which printed nothing when there
# was no such context set.
if context_data and context_data[0]:
    print(model(context_data[0][0]))

(tensor(1.00000e-02 *
       [ 2.3372,  0.8322,  0.6292,  1.7487,  0.8757,  1.5447,  1.3844,
         1.7886,  0.5361,  1.6353,  3.1622,  0.6323,  1.9057,  2.5837,
         1.1526,  0.6768,  1.4584,  1.2509,  1.0355,  2.1900,  0.6910,
         0.7611,  1.2955,  1.2102,  2.6714,  2.1262,  0.6125,  0.9043,
         0.9463,  0.8904,  1.4939,  1.1081,  1.1373,  0.5659,  0.6063,
         2.2028,  1.1779,  0.9723,  1.0147,  3.9599,  2.1676,  1.7888,
         1.3008,  0.6819,  2.1627,  0.9816,  1.2672,  3.3311,  1.3914,
         0.6511,  0.7149,  1.4801,  1.9497,  1.0113,  0.9677,  0.7966,
         1.5462,  1.0199,  1.4093,  0.5986,  2.0467,  0.9312,  0.6401,
         0.5863,  1.7831,  0.8144,  0.8820,  0.5907,  1.5452,  0.8565,
         3.1911,  1.1057,  0.8524,  1.2479]), tensor([ 0.1090,  0.0958, -0.2351, -0.2278, -0.2172,  0.0290, -0.2035,
        -0.1593, -0.2088, -0.0719, -0.1707, -0.1486, -0.2703, -0.0294,
         0.0143, -0.1398, -0.1037,  0.2416,  0.0019,  0.2682]), tensor([ 0.5642,