In [1]:
import numpy as np
import pickle as pkl
import torch

# Load the vocabulary object serialized at `vocab_path`.
# NOTE(review): torch.load unpickles the file by default, so only point this
# at vocabulary files that come from a trusted source.
# NOTE(review): later cells call len(vocab) and vocab.get_itos(); presumably
# this is a torchtext Vocab — confirm.
vocab_path = "../processed_data/remove-stopwords-punct-25000.vocab"

vocab = torch.load(vocab_path)

In [2]:
class embedding:
    """Build an embedding matrix for a vocabulary from pre-trained word vectors.

    Parameters
    ----------
    pretrained_embeds
        Word-vector store exposing ``vector_size``, ``key_to_index`` and
        ``__getitem__`` (the gensim ``KeyedVectors`` interface).  If it also
        exposes a ``vectors`` array with one row per key, that array is used
        to compute the ``<UNK>`` vector without a Python-level loop.
    vocab
        Vocabulary exposing ``__len__`` and ``get_itos()`` (index -> token).

    Attributes
    ----------
    embed_dim : dimensionality of each embedding vector.
    vocab_size : number of rows of the matrix built by ``get_embed_matrix``.
    unk_embed : vector assigned to the ``<UNK>`` token.
    """

    def __init__(self, pretrained_embeds, vocab):
        self.embeds = pretrained_embeds
        self.vocab = vocab
        self.embed_dim = pretrained_embeds.vector_size
        self.vocab_size = len(vocab)
        self.unk_embed = self.get_unk_embed()

    def get_unk_embed(self):
        """Return the mean of all pre-trained vectors, used for ``<UNK>``.

        Returns
        -------
        numpy.ndarray
            float64 array of shape ``(embed_dim,)``.  An all-zero vector is
            returned when the pre-trained store is empty.
        """
        keys = self.embeds.key_to_index
        if len(keys) == 0:
            # Nothing to average; avoid a division by zero.
            return np.zeros(self.embed_dim)

        # Fast path: average the backing array in one vectorised call.
        vectors = getattr(self.embeds, "vectors", None)
        if vectors is not None and len(vectors) == len(keys):
            # Accumulate in float64 so the result matches the precision of
            # the explicit summation below.
            return np.asarray(vectors).mean(axis=0, dtype=np.float64)

        # Fallback for stores that only support per-key lookup.
        embed_sum = np.zeros(self.embed_dim)
        for word in keys:
            embed_sum += self.embeds[word]
        return embed_sum / len(keys)

    def get_embed_matrix(self, seed=None):
        """Return the ``(vocab_size, embed_dim)`` embedding matrix.

        Row ``i`` corresponds to token ``self.vocab.get_itos()[i]``:

        * ``<PAD>``  -> zero vector
        * ``<UNK>``  -> ``self.unk_embed``
        * token present in the pre-trained store -> its pre-trained vector
        * any other token -> a random Gaussian vector scaled to unit L2 norm

        Parameters
        ----------
        seed : int or None, optional
            When given, the random vectors are drawn from a private
            ``numpy.random.RandomState(seed)`` so the matrix is reproducible.
            When ``None`` (default) NumPy's global random state is used.

        Returns
        -------
        numpy.ndarray
            float64 array of shape ``(vocab_size, embed_dim)``.
        """
        # Both the np.random module and RandomState expose `.normal`.
        rng = np.random if seed is None else np.random.RandomState(seed)
        known = self.embeds.key_to_index

        # Initialize the embedding matrix; <PAD> rows simply stay zero.
        embed_matrix = np.zeros((self.vocab_size, self.embed_dim))

        # Use the instance's own vocabulary, not a module-level global.
        for i, token in enumerate(self.vocab.get_itos()):
            if token == '<PAD>':
                continue
            if token == '<UNK>':
                embed_matrix[i] = self.unk_embed
            elif token in known:
                embed_matrix[i] = self.embeds[token]
            else:
                # Token is not found in the pre-trained embeddings:
                # represent it with a unit-length random vector.
                rand_vec = rng.normal(0, 1, size=self.embed_dim)
                norm = np.linalg.norm(rand_vec)
                embed_matrix[i] = rand_vec / norm if norm > 0 else rand_vec
        return embed_matrix

In [3]:
import gensim.downloader

# Load the two pre-trained embedding sets by name through gensim's downloader.
# NOTE(review): presumably each call fetches the model on first use and is
# slow / memory-heavy — confirm before re-running this cell.
word2vec_vectors = gensim.downloader.load('word2vec-google-news-300')
glove_vectors = gensim.downloader.load('glove-wiki-gigaword-300')

In [4]:
# get_embed_matrix() draws random vectors from NumPy's global RNG for tokens
# missing from the pre-trained embeddings.  Seed the RNG right before each
# build so the saved matrices are identical across notebook runs, and so
# re-running either build on its own gives the same result.
SEED = 42

word2vec = embedding(word2vec_vectors, vocab)
np.random.seed(SEED)
word2vec_embeds = word2vec.get_embed_matrix()

glove = embedding(glove_vectors, vocab)
np.random.seed(SEED)
glove_embeds = glove.get_embed_matrix()

In [6]:
# save the results
def pkl_save(path,obj):
    """Pickle ``obj`` into the file at ``path``, overwriting any existing file."""
    with open(path, mode='wb') as sink:
        writer = pkl.Pickler(sink)
        writer.dump(obj)

# Destination files for the two embedding matrices.
word2vec_name = "../processed_data/word2vec.pickle"
glove_name = "../processed_data/glove.pickle"

# Write each matrix to its destination, word2vec first and then glove.
for out_path, matrix in (
    (word2vec_name, word2vec_embeds),
    (glove_name, glove_embeds),
):
    pkl_save(out_path, matrix)