Code modified from https://github.com/tsb0601/MultiMon/blob/main/scrape.py

In [None]:
import re
from tqdm import tqdm
import os
import json
import numpy as np
from datetime import datetime
import torch
from collections import defaultdict
import torch.nn.functional as F
import torch
from PIL import Image
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
import torch
from typing import Any, Callable, Dict, List, Optional, Union
from transformers import AutoTokenizer, BertModel
import torch
import pandas as pd
import numpy as np
import csv 
from pycocotools.coco import COCO
from sentence_transformers import SentenceTransformer
import urllib.request
import zipfile

In [None]:
# Module-level de-duplication state shared by write_unique_rows(): it holds
# the (row[2], row[3]), (row[0], row[1]) and (row[1], row[0]) tuples of every
# row that has already been written.
unique_rows = set()

In [None]:
def load_captions(annotations_path):
    """Return every caption stored in a COCO captions file, lower-cased.

    Images are visited in the order reported by ``getImgIds`` and each
    image's captions are appended together, so captions that describe the
    same image are adjacent in the returned flat list.
    """
    coco = COCO(annotations_path)

    return [
        annotation['caption'].lower()
        for image_id in coco.getImgIds()
        for annotation in coco.loadAnns(coco.getAnnIds(imgIds=image_id))
    ]

def write_unique_rows(row, writer, seen=None):
    """Write ``row`` with ``writer`` unless an equivalent pair was already written.

    A row counts as a duplicate when any of these keys has been recorded:
    ``(row[2], row[3])``, ``(row[0], row[1])`` or the swapped
    ``(row[1], row[0])``. The swapped key makes (A, B) and (B, A) collapse
    into a single output row.

    Args:
        row: Indexable record with at least four fields.
        writer: Object exposing ``writerow`` (e.g. a ``csv.writer``).
        seen: Set of keys already written. Defaults to the module-level
            ``unique_rows`` set, which preserves the original behaviour;
            pass a fresh set to de-duplicate independently of global state.

    Returns:
        True if the row was written, False if it was skipped as a duplicate.
    """
    if seen is None:
        seen = unique_rows

    keys = (
        (row[2], row[3]),
        (row[0], row[1]),
        (row[1], row[0]),
    )

    if any(key in seen for key in keys):
        return False

    # Record the keys before writing, as the original implementation did.
    seen.update(keys)
    writer.writerow(row)
    return True

In [None]:
# Directory that holds (or will receive) the COCO 'annotations' folder.
# An empty string means "the current working directory".
data_dir=''
url="http://images.cocodataset.org/annotations/annotations_trainval2017.zip"

# os.makedirs('') raises FileNotFoundError, so only create a directory when
# one has actually been named.
if data_dir:
    os.makedirs(data_dir, exist_ok=True)
file_path = os.path.join(data_dir, 'annotations.zip')
annotations_dir = os.path.join(data_dir, 'annotations')

# The archive is deleted after extraction, so its absence alone must not
# trigger a download: fetch it only while the extracted folder is missing.
if not os.path.exists(annotations_dir):
    if not os.path.exists(file_path):
        print("Downloading the annotations...")
        urllib.request.urlretrieve(url, file_path)

    print("Extracting the annotations...")
    with zipfile.ZipFile(file_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir or os.curdir)
    os.remove(file_path)

In [None]:
def load_coco():
    """Load the 2017 train and val COCO captions and return them as one list.

    Both annotation files are read from ``<data_dir>/annotations``. The
    per-split and combined caption counts are printed, and the returned list
    holds the train captions followed by the val captions.
    """
    annotations_dir = os.path.join(data_dir, 'annotations')

    train_captions = load_captions(os.path.join(annotations_dir, 'captions_train2017.json'))
    val_captions = load_captions(os.path.join(annotations_dir, 'captions_val2017.json'))
    combined = [*train_captions, *val_captions]

    print(f"Total number of captions (train): {len(train_captions)}")
    print(f"Total number of captions (val): {len(val_captions)}")
    print(f"Total number of captions (train + val): {len(combined)}")

    return combined

In [None]:
# CLIP checkpoint whose text tower and projection head produce the CLIP text
# embeddings used below.
clip_model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14')

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Reference sentence encoder: the embedding loop calls bert_model.encode(...),
# but the notebook never defined bert_model, so that cell raised NameError.
# NOTE(review): the checkpoint name is an assumption (a DistilRoBERTa
# sentence-transformers model) -- confirm against the upstream scrape.py.
bert_model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')

In [None]:
# Corpus and run settings shared by the embedding, mining and export steps.
premises = load_coco()
# A pair is kept only when its CLIP cosine similarity exceeds this value.
similarity_threshold = 0.9
batch_size = 1024
# Only the first 20 batches are processed. Clamp to the corpus size so a
# smaller corpus cannot yield empty batches or never-filled embedding rows.
num_premises = min(20 * batch_size, len(premises))
# Tag used in the name of the output CSV file.
corpus_data = 'COCO'

# Embed the first num_premises captions with the reference sentence encoder,
# one batch at a time, and L2-normalise each embedding so that a dot product
# between two rows is their cosine similarity.
bert_text_embeds_prompts = []
for i in tqdm(range(0, num_premises, batch_size)):  # use len(premises) to cover the full corpus
    premises_batch = premises[i:i+batch_size]
    with torch.no_grad():
        text_embeds_prompts_batch = bert_model.encode(premises_batch)

    # encode() output is converted from a NumPy array to a tensor here.
    text_embeds_prompts_batch = torch.from_numpy(text_embeds_prompts_batch)
    text_embeds_prompts_batch = F.normalize(text_embeds_prompts_batch, dim=1)

    bert_text_embeds_prompts.append(text_embeds_prompts_batch)

# Concatenate the embeddings for all batches
bert_text_embeds_prompts = torch.cat(bert_text_embeds_prompts, dim=0)
# Build the path with os.path.join: a hard-coded backslash is only a path
# separator on Windows and becomes part of the file name elsewhere.
torch.save(bert_text_embeds_prompts, os.path.join(data_dir, 'bert_text_embeds_prompts.pt'))

# split the premises into batches
premises_batches = [premises[i:i+batch_size] for i in range(0, num_premises, batch_size)]

# Compute the CLIP text embedding of every premise, batch by batch. The
# embedding width is read from the model config rather than hard-coded, so
# switching to another CLIP checkpoint does not break the slice assignment.
text_embeds_prompts = torch.zeros(num_premises, clip_model.config.projection_dim)
for i, premises_batch in enumerate(tqdm(premises_batches)):
    tok = tokenizer(premises_batch, return_tensors="pt", padding=True, truncation=True)

    # Keep the projection and normalisation under no_grad as well; otherwise
    # the in-place slice assignment below attaches autograd history (and every
    # batch's graph) to text_embeds_prompts.
    with torch.no_grad():
        text_outputs = clip_model.text_model(**tok)
        # Index 1 of the text-model output is fed to the projection head.
        text_embeds = clip_model.text_projection(text_outputs[1])
        text_embeds_prompt = F.normalize(text_embeds, dim=1)

    start_idx = i * batch_size
    end_idx = min(start_idx + batch_size, num_premises)
    text_embeds_prompts[start_idx:end_idx, :] = text_embeds_prompt
# Build the path with os.path.join: a hard-coded backslash is only a path
# separator on Windows and becomes part of the file name elsewhere.
torch.save(text_embeds_prompts, os.path.join(data_dir, 'clip_text_embeds_prompts.pt'))

# Collect caption pairs that CLIP scores as near-identical while the reference
# encoder disagrees. Each entry is
# (caption 1, caption 2, index 1, index 2, CLIP similarity, BERT similarity).
similar_pairs = []

for batch_start in tqdm(range(0, num_premises, batch_size)):
    batch_stop = batch_start + batch_size
    batch_premises = premises[batch_start:batch_stop]

    # Similarity of every caption in this batch against all embedded captions.
    similarity_matrix = text_embeds_prompts[batch_start:batch_stop] @ text_embeds_prompts.t()
    bert_similarity_matrix = bert_text_embeds_prompts[batch_start:batch_stop] @ bert_text_embeds_prompts.t()

    # Keep pairs above the CLIP threshold whose two scores differ by more than 0.2.
    score_gap = (similarity_matrix - bert_similarity_matrix).abs()
    mask = (similarity_matrix > similarity_threshold) & (score_gap > 0.2)

    # Boolean-mask indexing yields values in the same row-major order as
    # nonzero(), so the indices and scores below line up element by element.
    rows, cols = torch.nonzero(mask, as_tuple=True)
    clip_scores = similarity_matrix[mask].tolist()
    bert_scores = bert_similarity_matrix[mask].tolist()

    similar_pairs.extend(
        (batch_premises[row], premises[col], batch_start + row, col, clip_score, bert_score)
        for row, col, clip_score, bert_score in zip(rows.tolist(), cols.tolist(), clip_scores, bert_scores)
    )

# Export the mined pairs to CSV, skipping refusal-style captions and pairs
# that write_unique_rows() reports as already written.
file_path = f'similar_from_{corpus_data}2.csv'
# Explicit UTF-8 so captions with non-ASCII characters do not depend on the
# platform's default encoding.
with open(file_path, mode='w', newline='', encoding='utf-8') as csvfile:
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(['Sample 1','Sample 2','Sample 1 Index','Sample 2 Index','CLIP Similarity Score','BERT Similarity Score'])

    negative_keywords = ["there is no", "unable", "does not", "do not", "am not", "no image", "no picture"]

    # Highest CLIP similarity first. Field 4 is the CLIP score; field 2 (the
    # previous sort key) is the sample-1 index, which ordered rows by position
    # in the corpus rather than by similarity.
    similar_pairs.sort(key=lambda x: x[4], reverse=True)

    num_written = 0

    for pair in tqdm(similar_pairs):
        # Drop the pair when either caption contains a negative keyword.
        if any(keyword in field for field in pair[:2] for keyword in negative_keywords):
            continue

        if write_unique_rows(pair, csv_writer):
            num_written += 1

print(f"Wrote {num_written} unique pairs to {file_path}")