In [6]:
12 % 11

1

In [19]:
import torch
import numpy as np
import pandas as pd
import os.path

In [2]:
where_data = 'data'
image_name_x_image_filename : dict[str, str] = {} # key: image name; value: image filename
with open(f"{where_data}/templates.txt", 'r') as f:
    for line in f:
        # image_name describes some aspect of an image
        image_name, _, url = line.strip().split('\t')
        image_filename = url.split('/')[-1]
        image_name_x_image_filename[image_name] = image_filename
print(f"{len(image_name_x_image_filename.values())} image files!")

300 image files!


In [4]:
image_name_n_caption : list[tuple[str, str]] = [] # (image name, its caption)
with open(f"{where_data}/captions.txt", 'r') as f:
    for line in f:
        image_name, _, caption = line.strip().split('\t')
        if image_name in image_name_x_image_filename:
            image_name_n_caption.append((image_name, caption))
print(f"{len(image_name_n_caption)} (image_name, caption) pairs!")

900000 (image_name, caption) pairs!


In [5]:
image_name_n_caption

[('Y U No', 'commercial <sep> y u no same volume as show!?'),
 ('Y U No', 'Victoria <sep> y u no tell us your secret?!'),
 ('Y U No', 'KONY <sep> Y u no take justin bieber'),
 ('Y U No', 'TED <sep> y u no tell us how you met their mother'),
 ('Y U No', 'Google <sep> Y U NO LET ME FINISH TYPING?'),
 ('Y U No', 'universal remote <sep> y u no work on universe?'),
 ('Y U No', 'pink floyd <sep> y u no need no education?'),
 ('Y U No', 'INTERNET <sep> y u nO LET ME STUDY'),
 ('Y U No',
  'Girl looking for prince charming <sep> Y u no check friend zone?!'),
 ('Y U No', 'i held the door <sep> y u no say thank you'),
 ('Y U No', 'Forever alone! <sep> Y U NO FIND OTHER FOREVER ALONE?'),
 ('Y U No', 'Team rocket <sep> y u no catch a different pikachu?'),
 ('Y U No', 'ASIANs <sep> Y U KNOW?'),
 ('Y U No', 'fox news <sep> Y U no have news about foxes?'),
 ('Y U No', 'Mayans <sep> Y u no finish calendar?'),
 ('Y U No',
  'Ugly girl y u play hard to get <sep> when u already hard to want'),
 ('Y U No'

In [6]:
image_names : list[str] = []
image_filenames : list[str] = []
with open(f"{where_data}/templates.txt", 'r') as f:
    for line in f:
        # image_name describes some aspect of an image
        image_name, _, url = line.strip().split('\t')
        image_filename = url.split('/')[-1]
        image_names.append(image_name)
        image_filenames.append(image_filename)
image_name_n_image_index : dict[str, int] = {name : index \
    for name, index \
    in zip(image_names, np.arange(len(image_names)))}

In [34]:
if os.path.isfile('caption_embeddings.pt'):
    print("loading")
    image_index_TO_caption_index = np.load('image_index_TO_caption_index.npy', allow_pickle=True).item()
else:
    print("making")
    image_index_TO_caption_index = {index : [] \
        for index in np.arange(len(image_names)) #  : dict[int, NDArray]
    }
    captions : list[str] = []
    with open(f"{where_data}/captions.txt", 'r') as f:
        for caption_index, line in enumerate(f):
            image_name, _, caption = line.strip().split('\t')
            if image_name in image_name_n_image_index: # DON'T find in list
                captions.append(caption)
                image_index = image_name_n_image_index[image_name]
                image_index_TO_caption_index[image_index].append(caption_index)
    for key in image_index_TO_caption_index.keys():
        image_index_TO_caption_index[key] = np.asarray(image_index_TO_caption_index[key])
    np.save('image_index_TO_caption_index.npy', image_index_TO_caption_index)

loading


In [14]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")

tensor([1.], device='mps:0')


In [42]:
"""All caption embeddings"""
from sentence_transformers import SentenceTransformer
# https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
if os.path.isfile('caption_embeddings.pt'):
    caption_embeddings = torch.load(f='caption_embeddings.pt')
else:
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    model.to(mps_device)
    caption_embeddings = model.encode(captions)
    print(f"{len(caption_embeddings)} embeddings, {caption_embeddings[0].shape} long ea")
    torch.save(obj=caption_embeddings, f='caption_embeddings.pt')
    pass
len(caption_embeddings)

900000

In [46]:
caption_embeddings.shape

(900000, 384)

In [38]:
"""(All) Image Embeddings"""
from PIL import Image
from torchvision import transforms
from torchvision import models
import torch.nn as nn
if os.path.isfile('image_embeddings.pt'):
    image_embeddings = torch.load(f='image_embeddings.pt')
else:
    resnet = models.resnet50(pretrained=True)
    modules = list(resnet.children())[:-1]
    resnet = nn.Sequential(*modules)
    resnet.to(mps_device)
    for p in resnet.parameters():
        p.requires_grad = False
    image_transformation = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), # rearrange PIL image to shape=(C, H, W)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    images = []
    for image_filename in image_filenames:
        image = Image.open(fp=f"{where_data}/images/{image_filename}")
        image = image_transformation(image)
        images.append(image)
    images = torch.stack(images)
    image_embeddings = resnet(images).squeeze()
    torch.save(obj=image_embeddings, f='image_embeddings.pt')

In [47]:
image_embeddings.shape

torch.Size([300, 2048])