In [None]:
import numpy as np 
from os import listdir
from collections import OrderedDict
import string
import torch
from torch.autograd import Variable
import torch.functional as F
import torch.nn.functional as F

In [None]:
def load_file(filename):
    """Read a UTF-8 text file and return its entire contents as one string.

    Args:
        filename: Path of the file to read.

    Returns:
        The decoded text of the file.
    """
    # The context manager closes the handle even if read() raises
    # (for example on a decoding error), so no descriptor is leaked.
    with open(filename, 'r', encoding="utf8") as file:
        return file.read()

In [None]:
def clean_file(filename):
    """Turn raw text into a list of unique, normalized word tokens.

    Despite its name, ``filename`` is the text itself, not a path. The text
    is split on whitespace; punctuation is stripped from each piece; pieces
    that are not purely alphabetic or are a single character are dropped;
    the survivors are lower-cased and de-duplicated, keeping the order of
    first appearance.

    Returns:
        A list of distinct lower-case tokens.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)
    unique_words = OrderedDict()
    for raw_piece in filename.split():
        candidate = raw_piece.translate(strip_punctuation)
        if not candidate.isalpha():
            continue
        # The length check is done before lower-casing.
        if len(candidate) <= 1:
            continue
        unique_words.setdefault(candidate.lower(), None)
    return list(unique_words)

In [None]:
def add_file_to_vocab(filename, vocab):
    """Load one file, tokenize it, and append its tokens to ``vocab``.

    Args:
        filename: Path of the file, read via ``load_file``.
        vocab: List extended in place with the file's tokens.

    Returns:
        The token list produced by ``clean_file`` for this file.
    """
    tokens = clean_file(load_file(filename))
    vocab.extend(tokens)
    return tokens

In [None]:
def process_docs(directory, vocab):
    """Tokenize every file in ``directory`` and record the results.

    Each file is handed to ``add_file_to_vocab``, which extends ``vocab`` in
    place. The per-file token list it returns is appended to the
    module-level ``tokens`` list, so ``tokens`` must be defined before this
    function is called.

    Args:
        directory: Directory whose entries are all treated as text files.
        vocab: List that accumulates tokens across files.
    """
    # listdir() makes no ordering guarantee. Sorting keeps the first-seen
    # order of words, and so the indices derived from it, stable across
    # runs and machines.
    for filename in sorted(listdir(directory)):
        path = directory + '/' + filename
        tokens.append(add_file_to_vocab(path, vocab))

In [None]:
# Rebuild the corpus from scratch every time this cell runs.
# `vocab` collects every token from every file (repeats across files kept).
# `tokens` holds one token list per file; it starts empty so that no blank
# placeholder document is carried along.
vocab = []
tokens = []
process_docs('aclImdb_v1/aclImdb/train/neg', vocab)
process_docs('aclImdb_v1/aclImdb/train/pos', vocab)

In [None]:
# De-duplicate the accumulated tokens, keeping each word at the position
# where it was first seen.
clean_vocab = list(OrderedDict.fromkeys(vocab))

In [None]:
def save_list(lines, filename):
    """Write the strings in ``lines`` to ``filename``, separated by spaces.

    The joined text is encoded as UTF-8 and written in binary mode, so the
    bytes on disk are the same on every platform.

    Args:
        lines: Iterable of strings to store.
        filename: Destination path; an existing file is overwritten.
    """
    data = ' '.join(lines)
    # The context manager closes the handle even if the write fails.
    with open(filename, 'wb') as file:
        file.write(data.encode("utf-8"))
 
# Persist the de-duplicated vocabulary as space-separated words.
save_list(clean_vocab, 'vocab.txt')

In [None]:
# Two-way lookup between a word and its position in clean_vocab.
idx2word = dict(enumerate(clean_vocab))
word2idx = {word: position for position, word in idx2word.items()}

In [None]:
# Number of distinct words; this sizes the one-hot input vectors and both
# weight matrices used for training.
vocabulary_size = len(clean_vocab)

In [None]:
# Build (center word index, context word index) training pairs: every word
# is paired with each neighbour at most `window_size` positions away in the
# same token list.
window_size = 2
idx_pairs = []
for sentence in tokens:
    indices = [word2idx[word] for word in sentence]
    n_words = len(indices)
    for center_word_pos, center_word_idx in enumerate(indices):
        # Clamp the window to the bounds of the token list.
        first = max(0, center_word_pos - window_size)
        last = min(n_words, center_word_pos + window_size + 1)
        for context_word_pos in range(first, last):
            if context_word_pos != center_word_pos:
                idx_pairs.append((center_word_idx, indices[context_word_pos]))

idx_pairs = np.array(idx_pairs)

In [None]:
def get_input_layer(word_idx):
    """Return a float32 one-hot vector for ``word_idx``.

    The vector has length ``vocabulary_size`` (a module-level value), is
    zero everywhere, and has 1.0 written at ``word_idx``.
    """
    one_hot = torch.zeros(vocabulary_size, dtype=torch.float32)
    one_hot[word_idx] = 1.0
    return one_hot

In [None]:
# Train two weight matrices with plain per-pair SGD.
# W1 maps a one-hot word vector to an `embedding_dims`-sized vector, and
# W2 maps that vector back to one score per vocabulary word. The loss is
# the negative log-likelihood of the context word under a softmax.
embedding_dims = 5
W1 = torch.randn(embedding_dims, vocabulary_size, dtype=torch.float32, requires_grad=True)
W2 = torch.randn(vocabulary_size, embedding_dims, dtype=torch.float32, requires_grad=True)
num_epochs = 101
learning_rate = 0.001

for epo in range(num_epochs):
    loss_val = 0
    for data, target in idx_pairs:
        x = get_input_layer(data)
        y_true = torch.tensor([int(target)], dtype=torch.long)

        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)

        log_softmax = F.log_softmax(z2, dim=0)

        loss = F.nll_loss(log_softmax.view(1, -1), y_true)
        # The loss is a 0-dim tensor; .item() extracts the Python float.
        # Indexing it with [0] raises IndexError on current PyTorch.
        loss_val += loss.item()
        loss.backward()

        # Update the weights outside autograd tracking, then clear the
        # accumulated gradients before the next pair.
        with torch.no_grad():
            W1 -= learning_rate * W1.grad
            W2 -= learning_rate * W2.grad
            W1.grad.zero_()
            W2.grad.zero_()
    if epo % 10 == 0:
        print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')

In [None]:
# Keep short aliases for the two weight matrices, then show the shape and
# values of the log-softmax output left over from the last training step.
u, v = W1, W2
print(log_softmax.shape)
print(log_softmax)

In [None]:
from sklearn.manifold import TSNE

In [None]:
# Compute one output vector per distinct word for the t-SNE projection.
# `labels[i]` is the word whose log-softmax output is stored in `tokens[i]`.
labels = []
tokens = []

# Iterate the de-duplicated vocabulary. `vocab` repeats a word once per
# file it occurs in, which would produce duplicate rows and labels.
# No gradients are needed here, so autograd tracking is switched off.
with torch.no_grad():
    for w in clean_vocab:
        # Direct indexing fails loudly on an unknown word. A None index
        # from .get() would set every position of the one-hot vector.
        x = get_input_layer(word2idx[w])
        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)
        log_softmax = F.log_softmax(z2, dim=0)
        labels.append(w)
        tokens.append(log_softmax.detach().numpy())

print(tokens[0])
print(len(labels))

In [None]:
# Project the per-word vectors down to two dimensions for plotting.
# NOTE(review): the name of the iteration-count argument (`n_iter`) may
# differ between scikit-learn releases; confirm against the installed one.
print("Computing t-SNE embedding")
tsne = TSNE(
    n_components=2,
    perplexity=40,
    init='pca',
    n_iter=2500,
    random_state=23,
)
new_values = tsne.fit_transform(tokens)

In [None]:
# `plt` is not imported in any cell visible in this notebook, so bring it
# into scope here; without this the cell fails on a fresh kernel.
import matplotlib.pyplot as plt

# Split the 2-D t-SNE coordinates into separate x and y lists.
x = [value[0] for value in new_values]
y = [value[1] for value in new_values]

# Draw each word as its own scatter point and label it with the word.
plt.figure(figsize=(16, 16))
for i in range(len(x)):
    plt.scatter(x[i], y[i])
    plt.annotate(labels[i], xy=(x[i], y[i]), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')