# In[6]:
import io
import pickle

import numpy as np
import requests
from PIL import Image

# In[7]:
# Public API of this notebook/module, i.e. what ``from <module> import *`` exposes.
__all__ = [
    "download_image",
    "display_images",
    "get_image_embeddings",
    "ImageSemantics",
]


def download_image(img_url, timeout=None):
    """Fetch an image from the web.

    Parameters
    ----------
    img_url : str
        The URL of the image to fetch.
    timeout : float or tuple, optional
        Passed straight through to ``requests.get``. ``None`` (the default)
        waits indefinitely.

    Returns
    -------
    PIL.Image.Image
        The decoded image.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status code.
    """
    response = requests.get(img_url, timeout=timeout)
    # Fail with a clear HTTP error rather than handing an error page to PIL,
    # which would only report that it cannot identify the image file.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))

# In[8]:
def display_images(image_ids, mappings_object, show=None):
    """Download and display the images for the given IDs.

    Parameters
    ----------
    image_ids : Iterable
        Image IDs to display, in display order.
    mappings_object
        Object exposing an ``image_ids_to_urls`` mapping from image ID to URL.
    show : Callable[[PIL.Image.Image], Any], optional
        Called once per downloaded image. When omitted, each image's own
        ``show()`` method (PIL's external-viewer display) is used; inside a
        notebook, pass e.g. ``IPython.display.display`` to render inline.

    Returns
    -------
    str
        A status message stating how many images were requested.
    """
    # Materialise so generators work and len() below is well defined.
    image_ids = list(image_ids)
    for image_id in image_ids:
        img = download_image(mappings_object.image_ids_to_urls[image_id])
        # Downloading alone renders nothing; hand the image to a display hook.
        if show is None:
            img.show()
        else:
            show(img)
    return "Displaying top {} images".format(len(image_ids))

# In[9]:
def get_image_embeddings(img_ids, resnet18_features, dim=512):
    """Stack the ResNet-18 feature vectors for a sequence of image IDs.

    Parameters
    ----------
    img_ids : Sequence[int]
        N image IDs.
    resnet18_features : Dict[int, np.ndarray]
        Mapping of image ID -> shape-(dim,) ResNet feature vector.
    dim : int, optional
        Length of each feature vector (default 512).

    Returns
    -------
    numpy.ndarray, shape-(N, dim), dtype float32
        Row ``n`` holds the feature vector of ``img_ids[n]``.

    Raises
    ------
    KeyError
        If an ID is missing from ``resnet18_features``.
    """
    # Pre-allocating keeps a (0, dim) result for empty input, which building
    # the array from an empty list would not.
    vectors = np.zeros((len(img_ids), dim), dtype=np.float32)
    for n, img_id in enumerate(img_ids):
        vectors[n] = resnet18_features[img_id]
    return vectors

# In[11]:
class ImageSemantics:
    """In-memory database mapping image IDs to encoded image embeddings."""

    def __init__(self):
        # image ID -> output of the encoder passed to ``create_database``
        self.database = {}

    def __repr__(self):
        return "Database of Image Semantics"

    def create_database(self, LinearEncoder, image_ids, resnet18_features=None):
        """Encode each image's ResNet features and store the result.

        Parameters
        ----------
        LinearEncoder : Callable
            Applied to each image's shape-(1, 512) feature array (as built by
            ``get_image_embeddings``); its output becomes that image's entry.
        image_ids : Iterable
            IDs of the images to encode.
        resnet18_features : Dict[int, np.ndarray], optional
            Image ID -> ResNet feature vector. When omitted, a global named
            ``resnet`` is used instead.
        """
        if resnet18_features is None:
            # NOTE(review): relies on a notebook-level global ``resnet``;
            # passing the features explicitly is preferred.
            resnet18_features = resnet  # noqa: F821
        for image_id in image_ids:
            self.database[image_id] = LinearEncoder(
                get_image_embeddings((image_id,), resnet18_features)
            )

    def query_database(self, caption, num_outs, mappings_object):
        """Display and return the ``num_outs`` images best matching ``caption``.

        Parameters
        ----------
        caption : str
            Query text, embedded via ``mappings_object.caption_to_emb``.
        num_outs : int
            Number of top-scoring images to return.
        mappings_object
            Provides ``caption_to_emb`` and the ``image_ids_to_urls`` mapping
            used by ``display_images``.

        Returns
        -------
        list
            Image IDs ordered from highest to lowest dot-product score; empty
            when ``num_outs <= 0`` or the database is empty.
        """
        if num_outs <= 0 or not self.database:
            return []
        image_ids = list(self.database)
        # Flatten each stored embedding (e.g. shape (1, D)) into one row of an
        # (N, D) matrix. Assumes entries are np.asarray-convertible — TODO
        # confirm for the encoder's actual output type.
        image_embs = np.vstack(
            [np.asarray(self.database[i]).reshape(-1) for i in image_ids]
        )
        caption_emb = np.asarray(mappings_object.caption_to_emb(caption)).reshape(-1)
        scores = image_embs @ caption_emb
        # argsort is ascending: reverse so the best match comes first, and map
        # row positions back to the image IDs they stand for.
        top_rows = np.argsort(scores)[::-1][:num_outs]
        top_ids = [image_ids[row] for row in top_rows]
        display_images(top_ids, mappings_object)
        return top_ids

    def save_database(self, file_path):
        """Pickle ``self.database`` to ``file_path``; returns a status string."""
        with open(file_path, mode="wb") as opened_file:
            pickle.dump(self.database, opened_file)
        return "Database Saved"

    def load_database(self, file_path):
        """Replace ``self.database`` with the pickle stored at ``file_path``.

        Security: unpickling can execute arbitrary code, so only load files
        from a trusted source.
        """
        with open(file_path, mode="rb") as opened_file:
            self.database = pickle.load(opened_file)
        return "Database Loaded"