In [1]:
import torch
import numpy as np
import pandas as pd
import os.path

In [2]:
"""load image names & filenames"""
# Directory that holds templates.txt (and the other data files of this notebook).
where_data = 'data'

# Parallel lists: entry i of each describes the i-th line of templates.txt.
image_names : list[str] = []
image_filenames : list[str] = []

with open(f"{where_data}/templates.txt", 'r') as template_file:
    for row in template_file:
        # Each line must hold exactly three tab-separated fields:
        # the image name, an ignored middle field, and the image URL.
        image_name, _, url = row.strip().split('\t')
        image_names.append(image_name)
        # The filename is whatever follows the last '/' of the URL.
        image_filenames.append(url.rsplit('/', 1)[-1])

# Reverse lookup: image name -> position of that name in `image_names`.
# If a name occurs more than once, the last position wins.
image_name_n_image_index : dict[str, int] = dict(
    zip(image_names, np.arange(len(image_names)))
)

In [8]:
"""Mapping image index to indices of matching captions"""
def _read_captions(path, name_to_index):
    """Parse the tab-separated captions file at `path`.

    Every line must hold exactly three tab-separated fields; the first is the
    image name and the third is the caption text. Lines whose image name is
    not a key of `name_to_index` are skipped.

    Returns:
        (kept_captions, owner_indices): two lists of equal length, where
        `owner_indices[i]` is `name_to_index[...]` for the image that
        `kept_captions[i]` belongs to.

    Raises:
        ValueError: if a line does not split into exactly three fields.
    """
    kept_captions: list[str] = []
    owner_indices: list[int] = []
    with open(path, 'r') as f:
        for line in f:
            image_name, _, caption = line.strip().split('\t')
            if image_name in name_to_index:  # dict lookup, DON'T search a list
                kept_captions.append(caption)
                owner_indices.append(name_to_index[image_name])
    return kept_captions, owner_indices


# `captions` is read on every run (not only when the map has to be rebuilt):
# the caption-embedding cell needs it whenever its own cache file is missing.
captions, _caption_owner = _read_captions(
    f"{where_data}/captions.txt", image_name_n_image_index
)

if os.path.isfile('image_index_TO_caption_index.npy'):
    print("loading")
    # need .item() to get dictionary back! (otherwise it's an array of objects)
    # allow_pickle is required because the dict is stored as a pickled object;
    # only load cache files this notebook wrote itself.
    # NOTE(review): a cache written before the index fix below may hold raw
    # line numbers of captions.txt; delete the file to rebuild it if
    # captions.txt contains lines for images missing from templates.txt.
    image_index_TO_caption_index = np.load('image_index_TO_caption_index.npy', allow_pickle=True).item()
else:
    print("making")
    # create dictionary MAP of image index TO indices of matching captions
    _indices_per_image = {index : [] for index in np.arange(len(image_names))}
    # Positions are counted over the kept captions only, so every stored index
    # points at the matching entry of `captions` even when lines were skipped.
    for caption_index, image_index in enumerate(_caption_owner):
        _indices_per_image[image_index].append(caption_index)
    # Explicit integer dtype: an image without captions would otherwise get an
    # empty float64 array, which cannot be used for indexing.
    image_index_TO_caption_index = {
        key : np.asarray(indices, dtype=np.int64)
        for key, indices in _indices_per_image.items()
    }
    np.save('image_index_TO_caption_index.npy', image_index_TO_caption_index)

loading


In [14]:
# Select the device used by the embedding cells of this notebook. The name
# `mps_device` is always bound: it falls back to the CPU when the MPS backend
# is unavailable, so `.to(mps_device)` calls do not raise NameError there.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Smoke test: allocate a tensor on the device and show it.
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")
    mps_device = torch.device("cpu")

tensor([1.], device='mps:0')


In [5]:
"""All caption embeddings"""
from sentence_transformers import SentenceTransformer
# https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
if os.path.isfile('caption_embeddings.pt'):
    # Reuse the embeddings computed by an earlier run.
    caption_embeddings = torch.load(f='caption_embeddings.pt')
else:
    # Use the notebook's `mps_device` when MPS is available; otherwise run on
    # the CPU instead of failing on an unbound name.
    embedding_device = mps_device if torch.backends.mps.is_available() else torch.device("cpu")
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    model.to(embedding_device)
    # convert_to_tensor=True: encode() returns a NumPy array by default, which
    # would make this branch produce a different type than the cached branch.
    caption_embeddings = model.encode(captions, convert_to_tensor=True)
    print(f"{len(caption_embeddings)} embeddings, {caption_embeddings[0].shape} long ea")
    torch.save(obj=caption_embeddings, f='caption_embeddings.pt')
print(type(caption_embeddings))
caption_embeddings.shape

<class 'torch.Tensor'>


torch.Size([900000, 384])

In [38]:
"""(All) Image Embeddings"""
from PIL import Image
from torchvision import transforms
from torchvision import models
import torch.nn as nn
if os.path.isfile('image_embeddings.pt'):
    # Reuse the embeddings computed by an earlier run.
    image_embeddings = torch.load(f='image_embeddings.pt')
else:
    # Use the notebook's `mps_device` when MPS is available; otherwise run on
    # the CPU instead of failing on an unbound name.
    embedding_device = mps_device if torch.backends.mps.is_available() else torch.device("cpu")

    resnet = models.resnet50(pretrained=True)
    # Keep every child module except the last one, so the network outputs the
    # features that feed that last module instead of its predictions.
    modules = list(resnet.children())[:-1]
    resnet = nn.Sequential(*modules)
    resnet.to(embedding_device)
    # Inference mode: without eval() the BatchNorm layers normalise with the
    # statistics of the current batch, so an image's embedding would depend on
    # which other images are in the batch.
    resnet.eval()
    for p in resnet.parameters():
        p.requires_grad = False

    image_transformation = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(), # rearrange PIL image to shape=(C, H, W)
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    images = []
    for image_filename in image_filenames:
        # `with` closes the file handle once the pixels have been converted.
        with Image.open(fp=f"{where_data}/images/{image_filename}") as raw_image:
            # Force 3 channels: grayscale / palette / RGBA files would
            # otherwise not match the 3-value mean/std of Normalize, nor
            # stack with the RGB images.
            image = image_transformation(raw_image.convert("RGB"))
        images.append(image)
    images = torch.stack(images)

    with torch.no_grad():
        # The input batch must live on the same device as the model.
        # flatten(start_dim=1) keeps the batch axis even for a single image,
        # where squeeze() would drop it.
        image_embeddings = resnet(images.to(embedding_device)).flatten(start_dim=1)
    torch.save(obj=image_embeddings, f='image_embeddings.pt')
image_embeddings.shape