## VEGETABLES, VISHNU 

In [1]:
from collections import defaultdict, Counter
import json
import numpy as np
import time
import mygrad as mg
import gensim
from gensim.models.keyedvectors import KeyedVectors
from sklearn.decomposition import TruncatedSVD
import matplotlib.pyplot as plt
%matplotlib inline

import re, string
punc_regex = re.compile('[{}]'.format(re.escape(string.punctuation)))

In [12]:
path = r"./glove.6B.50d.txt.w2v"
glove = KeyedVectors.load_word2vec_format(path, binary=False)

In [10]:
class Model:
    def __init__(self, glove):
        with open(r'./captions_train2014.json') as json_file:
            self.data = json.load(json_file)
            
        self.caption_ids_to_captions = {cap['id']: cap['caption'] for cap in self.data['annotations']}
        self.image_id_to_caption_id = self.generate_image_to_caption_id()
        
        self.img_ids = tuple(cap['image_id'] for cap in self.data['annotations'])
        self.caption_ids = tuple(cap['id'] for cap in self.data['annotations'])
        self.captions = tuple(cap['caption'] for cap in self.data['annotations'])
        
        self.glove = glove
    
    def generate_image_to_caption_id(self):
        image_id_to_caption_ids_dict = {}
        
        for cap_dict in self.data['annotations']:
            if cap_dict['image_id'] in image_id_to_caption_ids_dict:
                image_id_to_caption_ids_dict[cap_dict['image_id']].append(cap_dict['id'])
            else:
                image_id_to_caption_ids_dict[cap_dict['image_id']] = [cap_dict['id']]
                
        return image_id_to_caption_ids_dict
        
    def tokenize(self, corpus):
        return punc_regex.sub('', corpus).lower().split()

    def to_df(self, captions):
        
        counter = Counter()
        for caption in captions:
            counter.update(set(self.tokenize(caption)))
        return dict(counter)
    
    
    def to_idf(self, captions):
        """ 
        Given the vocabulary, and the word-counts for each document, computes
        the inverse document frequency (IDF) for each term in the vocabulary.

        Parameters
        ----------
        vocab : Sequence[str]
            Ordered list of words that we care about.

        counters : Iterable[collections.Counter]
            The word -> count mapping for each document.

        Returns
        -------
        numpy.ndarray
            An array whose entries correspond to those in `vocab`, storing
            the IDF for each term `t`: 
                               log10(N / nt)
            Where `N` is the number of documents, and `nt` is the number of 
            documents in which the term `t` occurs.
        """
        vishnu = self.to_df(captions)
        return {word : np.log10(len(captions)/cnt + 1) for word, cnt in vishnu.items()}
    
    def normalize(self, array):
        #sqrroot(sum(vectorsquared))
        return (sum(array ** 2)) ** 0.5
        
        
    def caption_to_emb(self, caption, idfs):
        vishnu = sum(self.glove[word] * idfs[word] for word in self.tokenize(caption) if word in self.glove)
        return vishnu/self.normalize(vishnu)
    
    def caption_id_to_emb(self):
        idfs = self.to_idf(self.captions)
        return {caption_id : self.caption_to_emb(self.caption_ids_to_captions[caption_id], idfs) for caption_id in self.caption_ids}

In [13]:
model = Model(glove)

In [14]:
model.caption_id_to_emb()

{48: array([ 7.00062215e-02,  6.71695322e-02, -5.78375161e-02, -1.11350305e-01,
         1.45407930e-01, -1.96619909e-02, -8.34553465e-02, -7.27506578e-02,
        -2.84980033e-02, -3.28689776e-02, -6.53046817e-02, -1.04298443e-03,
         3.22337565e-03,  1.20396294e-01,  4.33437452e-02,  1.53016737e-02,
         5.50576523e-02,  1.08734317e-01,  4.45547303e-05, -2.40555212e-01,
         1.15891956e-02,  7.13599548e-02,  2.60164607e-02, -4.67556454e-02,
        -3.34816473e-03, -2.67075300e-01, -9.44786742e-02,  2.17079595e-01,
         1.39012009e-01, -5.40609844e-02,  7.30890334e-01,  2.14696117e-02,
         1.22305192e-02, -5.78963347e-02,  1.15302810e-02,  2.07695395e-01,
         8.94648060e-02,  2.01999307e-01, -9.50818975e-03, -5.20092063e-02,
         7.56237982e-03,  1.39723793e-02,  4.96433675e-02,  7.05822110e-02,
         4.99773547e-02,  4.36063558e-02, -1.11477792e-01, -1.56357884e-01,
         4.22956757e-02, -6.35025725e-02], dtype=float32),
 67: array([ 0.11150842, 