In [2]:
import os
import pickle
from pathlib import Path
from pickle import dump, load

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

In [3]:
import spacy
from collections import Counter
import gensim.downloader

# spaCy pipeline used only for tokenising captions here, so the NER and
# dependency-parser components are switched off.
nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser'])

# NOTE: despite the variable name, these are the pretrained Google News
# word2vec vectors ('word2vec-google-news-300'), not GloVe.
glove_emb = gensim.downloader.load('word2vec-google-news-300')



In [18]:
# MSVD caption table; its 'caption' column is what this notebook tokenises.
LABELS_PATH = '../data/MSVD_label_final.csv'
label_final_df = pd.read_csv(filepath_or_buffer=LABELS_PATH)


In [19]:
# All captions as a list of str. Missing captions (NaN) are dropped and the
# remaining values coerced to str, because spaCy's nlp.pipe() only accepts
# str/Doc inputs.
all_sent = label_final_df['caption'].dropna().astype(str).tolist()

In [20]:
# Token frequency counter over all captions. Keys are the raw token text from
# the spaCy tokenizer (case-sensitive, no lowercasing).
# Non-string entries (e.g. NaN for a missing caption) are skipped explicitly,
# because nlp.pipe() rejects them. Any other error propagates instead of being
# swallowed, so a half-filled counter never reaches the vocabulary cells.
wc = Counter(
    token.text
    for doc in nlp.pipe(sent for sent in all_sent if isinstance(sent, str))
    for token in doc
)

In [21]:
# Number of distinct (case-sensitive) tokens seen in the captions.
len(wc.keys())

4624

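Number of unique words: 6157<br>
Words occurring more than 5 times: 1259<br>
(Note: these figures do not match the `len(wc)` output shown above, 4624, and may be stale.)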

In [22]:
# Vocabulary cap: keep only the top_N most frequent caption tokens
# (the special tokens are added on top of this).
top_N = 1_500

In [23]:
# Initialise the vocabulary lookups and the embedding matrix.
# Rows 0..len(SPECIAL_TOKENS)-1 hold the special tokens and are filled here.
# The remaining rows are reserved for the top_N most frequent caption tokens
# and stay zero until the vocabulary loop fills them.
EMBEDDING_SIZE = 300
SPECIAL_TOKENS = ['<PAD>', '<START>', '<END>', '<UNK>']  # list position == index
SPECIAL_INIT_STD = 0.2
SEED = 42

# Seed the global NumPy RNG so the random special-token vectors (and the
# random out-of-vocabulary vectors drawn afterwards) are reproducible.
np.random.seed(SEED)

vocab_size = len(SPECIAL_TOKENS) + len(wc.most_common(top_N))
embedding = np.zeros((vocab_size, EMBEDDING_SIZE))
word2idx = {}
idx2word = {}

for idx, token in enumerate(SPECIAL_TOKENS):
    word2idx[token] = idx
    idx2word[idx] = token
    # NOTE(review): <PAD> also gets a random vector rather than zeros; confirm
    # that downstream code masks or ignores padded positions.
    embedding[idx] = np.random.normal(0, SPECIAL_INIT_STD, EMBEDDING_SIZE)

In [24]:
# Add the top_N most frequent caption tokens to the vocabulary.
# A word found in the pretrained word2vec model copies its vector. Any other
# word gets a random N(0, 0.1**2) vector (normal, so not bounded to any interval).
# Indices start right after the special-token rows. They are derived from the
# matrix shape rather than len(word2idx), so re-running this cell overwrites
# the same rows instead of indexing past the end of `embedding`.
OOV_INIT_STD = 0.1
top_words = [word for word, _ in wc.most_common(top_N)]
first_word_idx = embedding.shape[0] - len(top_words)

oov_words = []  # words without a pretrained vector, kept for inspection
for wid, word in enumerate(top_words, start=first_word_idx):
    word2idx[word] = wid
    idx2word[wid] = word
    if word in glove_emb:
        embedding[wid] = glove_emb.get_vector(word)
    else:
        embedding[wid] = np.random.normal(0, OOV_INIT_STD, EMBEDDING_SIZE)
        oov_words.append(word)
count = len(oov_words)  # number of randomly initialised vocabulary words

# Every embedding row must belong to exactly one vocabulary entry.
assert len(word2idx) == len(idx2word) == embedding.shape[0]

In [25]:
# Report how many vocabulary words fell back to random initialisation.
print('{} words are not in google news word2vec'.format(count))

13 words are not in google news word2vec


In [26]:
# Save the embedding matrix and both vocabulary lookups under ../model.
MODEL_DIR = Path('../model')
# np.save() and open(..., 'wb') fail if the directory does not exist yet.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

np.save(MODEL_DIR / 'MSVD_embedding.npy', embedding)
with open(MODEL_DIR / 'MSVD_word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open(MODEL_DIR / 'MSVD_idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)