In [1]:
import datetime
import math
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [2]:
def read_in_data(text_file):

    data_output = []
    with open(text_file, "r") as file1:
        data_list = file1.readlines()
    for line in data_list:
        token_list = line.split()
        data_output.append(token_list)
    
    return data_output

In [3]:
training_data = read_in_data("../data/english-french_small/dev.en")

training_data = training_data[:5]

In [None]:
# TODO: Preprocess data

MAX_VOCABULARY_SIZE = 1000


In [4]:
def create_vocabulary(training_set):

    vocabulary = []
    for sentence in training_set:
        for word in sentence:
            if word not in vocabulary:
                vocabulary.append(word)
                
    w2i = dict()
    i2w = dict()
    for idx, word in enumerate(vocabulary):
        i2w[idx] = word
        w2i[word] = idx
    
    return w2i, i2w

In [5]:
w2i, i2w = create_vocabulary(training_data)
V = len(w2i.keys())

In [7]:
def generate_skipgram(sentence, context_window_size):

    skipgram_array = []
    for idx, word in enumerate(sentence):
        context_set = []
        window_size = context_window_size
        for index in range(max(idx - window_size, 0), min(len(sentence), idx + window_size + 1)):
            if index != idx:
                context_set.append([word, sentence[index]])
        skipgram_array.append(context_set)

    return skipgram_array

In [8]:
WINDOW_SIZE = 2
# Training data :
context_data = [generate_skipgram(sentence, WINDOW_SIZE) for sentence in training_data]

In [10]:
def make_input_batches(training_set, batch_size):

    sentences = training_set
    random.shuffle(sentences)
    
    new_data = []
    num_samples = len(sentences)
    for idx in range(num_samples // batch_size):
        batch = sentences[(idx)*batch_size : (idx+1)*batch_size]
        new_data.append(batch)
    
    return np.array(new_data)

In [19]:
def onehot(word, vocab_size=V):
    one_hot = torch.zeros(vocab_size)
    one_hot[w2i[word]] = 1.0

    return one_hot

In [12]:
BATCH_SIZE = 10
input_data = make_input_batches(training_data, BATCH_SIZE)

In [None]:
def divergence_closed_form(mu, variance):
    '''
    Closed form of the KL divergence
    '''
    return -0.5 * torch.sum(1 + variance - torch.pow(mu, 2) - torch.exp(variance))

In [None]:
def ELBO():
    '''
    Evidence Lower BOund
    '''
    pass

In [13]:
class linearity(nn.Module):
    def __init__(self, embedding_dimension, vocabulary_size):
        super(linearity, self).__init__()
        self.fc1 = nn.Linear(vocabulary_size, embedding_dimension)
    
    def forward(self, x):
        return self.fc1(x)

In [15]:
class BSG_Net(nn.Module):

    def __init__(self, vocabulary_size, embedding_dimension=20):

        super(BSG_Net, self).__init__()

        self.embedding_dimension = embedding_dimension

        self.fc1 = nn.Linear(vocabulary_size, embedding_dimension)
        self.fc2 = nn.Linear(embedding_dimension * 2, embedding_dimension * 2)
        self.fc3 = nn.Linear(embedding_dimension * 2, embedding_dimension)
        self.fc4 = nn.Linear(embedding_dimension * 2, embedding_dimension)
        
        # for reparameterization
        self.re1 = nn.Linear(embedding_dimension, vocabulary_size)

    def forward(self, x):

        context_representation = torch.zeros(self.embedding_dimension * 2)

        for pair in x:
            center_word = self.fc1(onehot(pair[0]))
            context_word = self.fc1(onehot(pair[1]))

            concatenated = torch.cat([center_word, context_word], dim=0)
            concatenated = F.relu(self.fc2(concatenated))
            context_representation += concatenated

        mu = self.fc3(context_representation)
        sigma = F.softplus(self.fc4(context_representation))

        epsilon = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(self.embedding_dimension), torch.eye(self.embedding_dimension)).sample()

        z = mu + epsilon * sigma

        output = F.softmax(self.re1(z), dim=0)
    
        return output, mu, sigma


In [27]:
class prior_Net(nn.Module):

    def __init__(self, vocabulary_size, embedding_dimension=20):

        super(prior_Net, self).__init__()

        self.embedding_dimension = embedding_dimension

        self.L = nn.Linear(vocabulary_size, embedding_dimension)
        self.S = nn.Linear(vocabulary_size, embedding_dimension)

        self.fc1 = nn.Linear(embedding_dimension, vocabulary_size)

    def forward(self, x):
        one_hot_x = onehot(x)

        mean = self.L(one_hot_x)
        std = F.softplus(self.S(one_hot_x))
        z = torch.distributions.multivariate_normal.MultivariateNormal(mean,torch.mm(torch.eye(self.embedding_dimension), std)).sample()

        return F.softmax(self.fc1(z))


In [28]:
model = prior_Net(V, 20)

X_data = []

for sentence in training_data:
    for context_set in sentence:
        print(model(context_set))
        break
    break


RuntimeError: dimension out of range (expected to be in range of [-1, 0], but got 1)

In [16]:
EMBEDDING_DIMENSION = 20

model = BSG_Net(V, EMBEDDING_DIMENSION)
X_data = []

for sentence in context_data:
    for context_set in sentence:
        print(model(context_set))
        break
    break
    
    
    
#     for context_set in sentence:
#         to_add = []
#         for pair in context_set:
#             center_word = model_1(onehot(pair[0]))
#             context_word = model_1(onehot(pair[1]))

#             concatenated = torch.cat([center_word, context_word], dim=0)
#             concatenated = F.relu(model_2(concatenated))
#             to_add.append(concatenated)
            
#         X_data.append(to_add)

(tensor(1.00000e-02 *
       [ 2.5028,  0.9399,  3.0797,  1.1794,  1.5582,  0.9494,  1.3026,
         0.7748,  2.1435,  1.0236,  1.1902,  1.1210,  0.8714,  3.0826,
         1.5120,  2.0029,  1.0275,  1.4397,  1.3446,  1.4108,  1.6099,
         0.8656,  0.4063,  0.8884,  0.8358,  1.0510,  0.7438,  0.7721,
         1.1343,  3.1821,  1.6474,  1.1428,  1.0547,  0.8487,  2.2437,
         1.8391,  0.7768,  1.7111,  1.5330,  0.5382,  1.0233,  1.4370,
         2.2370,  1.3191,  1.4335,  1.5465,  0.7196,  2.0937,  1.5601,
         0.8138,  0.9269,  1.6357,  1.1205,  1.0677,  1.8069,  0.7476,
         1.8372,  1.8415,  0.8729,  1.5163,  1.1594,  1.4797,  0.7273,
         2.6367,  1.0206,  1.0036,  1.4164,  1.6321,  0.9004,  0.4773,
         1.5247,  1.2849,  0.5629,  1.3362]), tensor([ 0.0628,  0.0070, -0.0653,  0.0863,  0.0874,  0.0363, -0.0386,
        -0.1624, -0.0034,  0.3828, -0.0785,  0.3049, -0.1053, -0.0991,
        -0.0387,  0.1531, -0.1124,  0.2176, -0.1926, -0.0540]), tensor([ 0.6794,