In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
import pickle
import re
from collections import defaultdict
from tqdm import tqdm
from scipy.spatial.distance import cosine

### Preparing Data

In [12]:
# text_data = []

In [13]:
# def file_to_data(path):
#     try:
#         with open(path) as f:
#             data = [[x.rstrip().split('\t')[1], x.rstrip().split('\t')[2], x.rstrip().split('\t')[0]]  for x in f.readlines()]
#     except FileNotFoundError:
#         print("File does not exist")
#         return
    
#     formatted_data = []
#     for row in data:
#         formatted_data.append([row[0], row[1], row[2]])

#     return formatted_data

In [123]:
# year = '2016'
# data_file = 'postediting.test.tsv'
# path = f'data/sts/semeval-sts/{year}/{data_file}'
# data = file_to_data(path)

In [127]:
with open('data/Multi-30k/train.en') as f:
    data = [x.strip() for x in f.readlines()]
data[:5]

['Two young, White males are outside near many bushes.',
 'Several men in hard hats are operating a giant pulley system.',
 'A little girl climbing into a wooden playhouse.',
 'A man in a blue shirt is standing on a ladder cleaning a window.',
 'Two men are at the stove preparing food.']

In [135]:
# for dat in data[15000:20000]:
#     text_data.append(dat)

In [136]:
# text_data = list(set(text_data))
# len(text_data)

60265

In [3]:
with open('data/base_text_data.pkl', 'rb') as f:
    text_data = pickle.load(f)

### Word2Vec Training

In [4]:
class Tokenizer:
    def __init__(self):
        self.punctuations = [r'\.', r'\.{2,}',
                             r'\!+', r'\:+', r'\;+', r'\"+', r"\'+", r'\?+', r'\,+', r'\(|\)|\[|\]|\{|\}|\<|\>']

    def clean(self, line):
        for pattern in self.punctuations:
            line = re.sub(pattern, '', line)
        line = re.sub(r'[^a-z]', ' ', line.lower())
        return line

    def tokenize(self, line):
        line = self.clean(line)
        return line.split()

In [5]:
class Review:
    tokenizer = Tokenizer()
    
    def __init__(self, text):
        self.text = text
        self.tokens = Review.tokenizer.tokenize(self.text)

    def __iter__(self):
        return iter(self.tokens)

    def __str__(self):
        return self.text

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return self.tokens[idx]


In [147]:
training_data = []

for r in text_data:
    training_data.append(Review(r))

In [6]:
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        input_embeds = self.embeddings(inputs)
        embeds = torch.mean(input_embeds, dim=1)
        out = self.linear(embeds)
        return F.log_softmax(out, dim=1)

In [7]:
class Word2Vec:
    def __init__(self, reviews, context_size=2, embedding_size=50, oov_threshold=2, neg_sample_size=5, lr=0.001):

        self.reviews = reviews
        self.oov_threshold = oov_threshold
        self.oov_token = '<OOV>'
        self.context_size = context_size
        self.embedding_size = embedding_size

        self.vocabulary = {self.oov_token}
        self.vocab_idx = {self.oov_token: 0}
        self.vocab_ridx = {0: self.oov_token}

        self.freq = defaultdict(int)
        self.freq_dist = [0]
        self.total_word_count = 0

        self.build_vocabulary()

        self.BATCH_SIZE = 64
        self.neg_sample_size = neg_sample_size

        self.model = CBOW(self.N, self.embedding_size)
        self.dataset = self.build_dataset()
        self.weights = self.negative_sampling()
        self.optimizer = optim.Adam(self.model.parameters(), lr)
    
    def build_vocabulary(self):
        
        print("Building Vocabulary")
        for review in tqdm(self.reviews):
            for token in review:
                self.freq[token] += 1

        index = 1
        for token, f in self.freq.items():
            if f > self.oov_threshold:
                self.vocabulary.add(token)
                self.vocab_idx[token] = index
                self.vocab_ridx[index] = token
                self.freq_dist.append(f)
                index += 1
            else:
                self.freq_dist[0] += f

        self.total_word_count = sum(self.freq.values())        
        self.N = len(self.vocabulary)
        print(f"Total Vocabulary Size: {self.N}")

    def build_dataset(self):

        print("Building Dataset")
        dataset = []
        for review in tqdm(self.reviews):
            for i in range(self.context_size, len(review) - self.context_size):
                focus = review[i]
                if focus not in self.vocabulary:
                    focus = self.oov_token
                focus_index = self.vocab_idx[focus]
                context_indices = []
                for j in range(i - self.context_size, i + self.context_size + 1):
                    if i == j:
                        continue
                    context = review[j]
                    if context not in self.vocabulary:
                        context = self.oov_token
                    context_index = self.vocab_idx[context]
                    context_indices.append(context_index)
                dataset.append((context_indices, focus_index))

        return dataset


    def negative_sampling(self):
        print("Computing Weights")
        normalized_freq = F.normalize(torch.Tensor(self.freq_dist).pow(0.75), dim=0)
        weights = torch.ones(len(self.freq_dist))

        for _ in tqdm(range(len(self.freq_dist))):
            for _ in range(self.neg_sample_size):
                neg_index = torch.multinomial(normalized_freq, 1)[0]
                weights[neg_index] += 1
        
        return weights


    def train(self, num_epochs):
        losses = []
        loss_fn = nn.NLLLoss(weight=self.weights)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch}")
            net_loss = 0
            for i in tqdm(range(0, len(self.dataset), self.BATCH_SIZE)):
                batch = self.dataset[i: i+self.BATCH_SIZE]

                context = [x[0] for x in batch]
                focus = [x[1] for x in batch]

                context_var = Variable(torch.LongTensor(context))
                focus_var = Variable(torch.LongTensor(focus))

                self.optimizer.zero_grad()
                log_probs = self.model(context_var)
                loss = loss_fn(log_probs, focus_var)
                loss.backward()
                self.optimizer.step()

                net_loss += loss.item()
            print(f"Loss: {loss.item()}")
            losses.append(net_loss)

    
    def get_embedding(self, word_idx):
        embedding_index = Variable(torch.LongTensor([word_idx]))
        return self.model.embeddings(embedding_index).data[0]
    
    def get_closest_vector(self, _word, k):
        
        word = _word.lower()

        if word not in self.vocabulary:
            word = self.oov_token

        distances = []
        focus_index = self.vocab_idx[word]
        focus_embedding = self.get_embedding(focus_index)

        for i in range(1, self.N):
            if i == focus_index:
                continue
        
            comp_embedding = self.get_embedding(i)
            comp_word = self.vocab_ridx[i]
            dist = cosine(focus_embedding, comp_embedding)
            distances.append({'Word': comp_word, 'Distance': dist})
        
        distances = sorted(distances, key=lambda x: x['Distance'])

        return [x['Word'] for x in distances[:k]]

In [153]:
encoder = Word2Vec(training_data)
encoder.train(1)

Building Vocabulary


100%|██████████| 60265/60265 [00:00<00:00, 394942.21it/s]


Total Vocabulary Size: 12907
Building Dataset


100%|██████████| 60265/60265 [00:01<00:00, 39110.01it/s]


Computing Weights


100%|██████████| 12907/12907 [00:28<00:00, 460.48it/s]


Epoch 0


100%|██████████| 7848/7848 [02:49<00:00, 46.36it/s]

Loss: 1.6111623048782349



