In [2]:
# !pip install -U sentence-transformers

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

In [180]:
sentences = ["This is an example sentence", "Each sentence is converted", "This is a demonstration"]

sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embeddings = sentence_model.encode(sentences)
# print(embeddings)

In [181]:
np.dot(embeddings[0], embeddings[1])/ (np.linalg.norm(embeddings[0])* np.linalg.norm(embeddings[1]))

0.40455922

In [182]:
np.dot(embeddings[0], embeddings[2])/ (np.linalg.norm(embeddings[0])* np.linalg.norm(embeddings[2]))

0.475783

In [183]:
def get_cosine_sim(embed1, embed2):
    assert embed1.shape == embed2.shape, \
    f"Both embeddings must be of same shape, got {embed1.shape} and {embed2.shape}"
    return np.dot(embed1, embed2)/ (np.linalg.norm(embed1)*np.linalg.norm(embed2))
    

In [184]:
import pickle
import pprint
pp = pprint.PrettyPrinter(indent=4)
from IPython.display import clear_output
import time
import os.path
from os import path
from IPython.display import Image, display
import random
import warnings
from tqdm import tqdm
warnings.filterwarnings('ignore')


In [7]:
with open('meme_900k_cleaned_data.pkl', 'rb') as f:
    data = pickle.load(f)

In [20]:
for uuid in tqdm(data['uuid_caption_dic']):
    captions = data['uuid_caption_dic'][uuid]
    captions = [' '.join(caption) for caption in captions]
    captions = [caption.replace('<emp>', '') for caption in captions]
    data['uuid_caption_dic'][uuid] = captions

100%|██████████| 300/300 [00:00<00:00, 573.95it/s]


In [10]:
# meme_embeddings = {}

# for uuid in tqdm(data['uuid_caption_dic']):
#     captions = data['uuid_caption_dic'][uuid]
#     embeddings = model.encode(captions)
#     average = np.mean(embeddings,  axis=0)
#     meme_embeddings[uuid] = average

100%|██████████| 300/300 [03:48<00:00,  1.32it/s]


In [11]:
# with open("meme_embeddings.pkl", 'wb') as f:
#     pickle.dump(meme_embeddings, f)

In [8]:
with open("meme_embeddings.pkl", 'rb') as f:
    meme_embeddings = pickle.load(f)

In [198]:
all_uuids = list(data['uuid_caption_dic'].keys())

niter = 50
random_similarities = []
true_similarities = []
candidate_cap_index = [np.random.choice(len(data['uuid_caption_dic'][uuid])) for uuid in all_uuids]
candidate_captions = [data['uuid_caption_dic'][all_uuids[i]][ind] for i, ind in enumerate(candidate_cap_index)]
print(len(candidate_captions), candidate_captions[0])
for embed_uuid, caption in zip(all_uuids, candidate_captions):
    embedding = sentence_model.encode([caption])[0]
    average_sim = 0
    for i in range(niter):
        random_uuid = embed_uuid
        while random_uuid == embed_uuid:
            random_uuid = np.random.choice(all_uuids)
        random_meme_embedding  = meme_embeddings[random_uuid]
        average_sim += get_cosine_sim(embedding, random_meme_embedding)
    average_sim /= niter
    true_similarity = get_cosine_sim(embedding, meme_embeddings[embed_uuid])
    true_similarities.append(true_similarity)
    random_similarities.append(average_sim)
    

300 jets y u no get good quarterbacks?!


In [199]:
np.mean(random_similarities)

0.2462124640348251

In [200]:
np.mean(true_similarities)

0.37961653

In [19]:
# similarity score 

# now test full user captions for accuracy
import os
import regex as re
import pickle
testing_user_captions = []
dir_path = './memes900k_qa/'
for path in tqdm(os.listdir(dir_path)):
    if os.path.isfile(os.path.join(dir_path, path)):
        if not re.match(r'.*_manual.pkl', path):
            with open(os.path.join(dir_path, path), 'rb') as f:
                dic = pickle.load(f)
                for v in dic['qa'].keys(): 
                    testing_user_captions.append([v, dic['uuid']])

100%|██████████| 152/152 [00:00<00:00, 2611.53it/s]


In [25]:
import random


In [28]:
candidate_user_prompts_and_uuids = random.choices(testing_user_captions, k=500)

In [29]:
candidate_user_prompts = [candidate_tuple[0] for candidate_tuple in candidate_user_prompts_and_uuids]
candidate_uuids = [candidate_tuple[1] for candidate_tuple in candidate_user_prompts_and_uuids]

In [30]:
new_meme_embeddings = {}
for uuid in tqdm(data['uuid_caption_dic']):
    captions = data['uuid_caption_dic'][uuid][:1000]
    embeddings = model.encode(captions)
    average = np.mean(embeddings,  axis=0)
    new_meme_embeddings[uuid] = average

100%|██████████| 300/300 [32:07<00:00,  6.43s/it]


In [31]:
# with open("user_prompt_meme_embeddings.pkl", 'wb') as f:
#     pickle.dump(new_meme_embeddings, f)

In [35]:
candidate_captions = []
for uuid in candidate_uuids:
    caption = np.random.choice(data['uuid_caption_dic'][uuid][1000:3000])
    candidate_captions.append(caption)

In [37]:
user_sims = []
caption_sims = []

for i, uuid in enumerate(candidate_uuids):
    meme_embedding = new_meme_embeddings[uuid]
    caption_embedding = model.encode([candidate_captions[i]])[0]
    user_embedding = model.encode([candidate_user_prompts[i]])[0]
    caption_sims.append(get_cosine_sim(caption_embedding, meme_embedding))
    user_sims.append(get_cosine_sim(user_embedding, meme_embedding))

In [38]:
np.mean(caption_sims)

0.38320404

In [39]:
np.mean(user_sims)

0.3654713

# CLIP Baseline

For every meme caption, we try to find the best matching image through CLIP 

In [22]:
import clip
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [123]:
uuid_feature_dic = {}
all_paths = list(data['uuid_image_path_dic'].values())
all_uuids = list(data['uuid_caption_dic'].keys())
for i in tqdm(range(len(all_paths))):
    uuid = all_uuids[i]
    img_path = data['uuid_image_path_dic'][uuid]
    image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    uuid_feature_dic[uuid] = image_features.squeeze()

100%|██████████| 300/300 [00:04<00:00, 72.33it/s]


In [124]:
img_path

'./memes900k/images/bane-permission-to-die.jpg'

In [125]:
uuid_feature_dic[uuid].shape

torch.Size([512])

In [126]:
all_captions = []
all_uuid_inds = []
for i, uuid in enumerate(data['uuid_caption_dic'].keys()):
    captions = data['uuid_caption_dic'][uuid]
    all_captions.extend(captions)
    all_uuid_inds.extend([i for _ in captions])

In [136]:
list(data['uuid_caption_dic'].keys()) == list(uuid_feature_dic.keys()) # Need to match for getting top k 

True

In [137]:
batch_size = 300

In [28]:
tokenized_caps = []
num_batches = len(all_captions)// batch_size
for i in range(num_batches):
    batch = all_captions[i*batch_size: (i+1)*batch_size]
    batch_tokenized = clip.tokenize(batch).to(device)
    tokenized_caps.append(batch_tokenized)

In [29]:
batch_tokenized

tensor([[49406,   953,   649,  ...,     0,     0,     0],
        [49406,   340,   692,  ...,     0,     0,     0],
        [49406,  2543,  1563,  ...,     0,     0,     0],
        ...,
        [49406,   827,  2943,  ...,     0,     0,     0],
        [49406,   827,   592,  ...,     0,     0,     0],
        [49406,   592,   720,  ...,     0,     0,     0]], device='cuda:0',
       dtype=torch.int32)

In [30]:

text_feats = []
with torch.no_grad():
    for batch_tokenized in tqdm(tokenized_caps):
        batch_text_feats = model.encode_text(batch_tokenized)
        batch_text_feats /= batch_text_feats.norm(dim=-1, keepdim=True)
        text_feats.append(batch_text_feats)

100%|██████████| 3000/3000 [16:29<00:00,  3.03it/s]


In [129]:
image_features = torch.vstack(list(uuid_feature_dic.values())).to(device)

In [130]:
image_features /= image_features.norm(dim=-1, keepdim=True)

In [131]:
new_feats = []
for i,batch_text_feats in enumerate(text_feats):
    new_batch = []
    for j, text_feat in enumerate(batch_text_feats):
        ind = i*batch_size + j
        uuid_ind = all_uuid_inds[ind]
        image_feat = image_features[uuid_ind]
        new_batch.append(image_feat)
    new_batch = torch.vstack(new_batch).to(device)
    new_feats.append(new_batch)

In [132]:
all_inds = []
for i,batch_text_feats in enumerate(text_feats):
    for j, text_feat in enumerate(batch_text_feats):
        ind = i*batch_size + j
        all_inds.append(ind)

In [138]:
for i, batch in enumerate(new_feats):
    assert (batch.norm(dim=-1) == 1).all(), f"Error at {i}"

AssertionError: Error at 50

In [142]:
image_features.norm(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9995, 0.9995, 1.0000, 1.0000,
        1.0000, 0.9995, 1.0000, 1.0000, 1.0000, 1.0000, 0.9995, 1.0000, 1.0000,
        1.0000, 0.9995, 1.0000, 1.0000, 1.0000, 0.9995, 1.0000, 1.0000, 1.0000,
        1.0000, 0.9995, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9995,
        1.0000, 1.0000, 1.0000, 1.0000, 0.9995, 0.9995, 1.0000, 0.9995, 1.0000,
        1.0000, 1.0000, 1.0000, 0.9995, 0.9995, 0.9995, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 0.9995, 0.9995, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 0.9995, 1.0000, 0.9995, 0.9995, 0.9995, 1.0000, 1.0000, 1.0000,
        1.0000, 0.9995, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9995, 1.0000, 1.0000, 1.0000,
        0.9995, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9995,
        0.9995, 1.0000, 0.9995, 1.0000, 

In [167]:
all_values = []
all_indices = []
for batch_text_feats in text_feats:
    similarity = (100.0 * batch_text_feats @ image_features.T).softmax(dim=-1)
    values, indices = similarity.topk(100)
    values = values.cpu().numpy().tolist()
    all_values.extend(values)
    indices = indices.cpu().numpy().tolist()
    all_indices.extend(indices)

In [168]:
def top_k(top_indices, true_indices, k, debug=False):
    assert len(top_indices) and len(true_indices), \
    f"Inputs should have non zero length, got {len(top_indices)} and {len(true_indices)}"
    assert len(top_indices[0]) >= k, f"Length should be atleast {k}, got {len(top_indices[0])}"
    accuracy = 0
    if debug:
        debug_set = set()
    for i in range(len(top_indices)):
        if true_indices[i] in top_indices[i][:k]:
            accuracy += 1
        elif debug and true_indices[i] not in debug_set:
            debug_set.add(true_indices[i])
    if debug:
        print("No matches at:\n")
        print(debug_set)
    accuracy /= len(top_indices)
    return accuracy

In [169]:
len(all_uuid_inds), len(all_indices)

(900000, 900000)

In [170]:
top_k(all_indices, all_uuid_inds, 10)

0.26787666666666665

In [171]:
top_k(all_indices, all_uuid_inds, 5)

0.2111888888888889

In [172]:
top_k(all_indices, all_uuid_inds, 1)

0.09203

In [149]:
# sanity check with image features
all_sanity_values = []
all_sanity_indices = []
# assert (image_features.norm(dim=-1) == 1 ).all()
for i, batch_text_feats in enumerate(new_feats):
#     assert (batch_text_feats.norm(dim=-1) == 1).all(), f"Error at {i}"
    similarity = (100.0 * batch_text_feats @ image_features.T).softmax(dim=-1)
    values, indices = similarity.topk(100)
    values = values.cpu().numpy().tolist()
    all_sanity_values.extend(values)
    indices = indices.cpu().numpy().tolist()
    all_sanity_indices.extend(indices)

# Template Relevance Score


In [201]:
clip_similarities= []
for i in tqdm(range(len(candidate_captions))):
    embedding = sentence_model.encode([candidate_captions[i]])[0]
    ind = i*3000 + candidate_cap_index[i]
    retrieved_meme_uuid = all_uuids[all_indices[ind][0]] # take the first retrieved image
    retrieved_meme_embedding = meme_embeddings[retrieved_meme_uuid]
    sim = get_cosine_sim(embedding, retrieved_meme_embedding)
    clip_similarities.append(sim)

100%|██████████| 300/300 [00:02<00:00, 139.52it/s]


In [202]:
np.mean(clip_similarities)

0.30860755