# CC3M Similarity
Before starting, please download the training and validation splits of CC3M from https://ai.google.com/research/ConceptualCaptions/download

In [1]:
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import pandas as pd
import numpy as np

import datasets
from datasets import load_dataset
import datasets
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader
import torch

import os
import requests
from tqdm import tqdm
import pickle 
from PIL import Image
from PIL import ImageFile                                                      
ImageFile.LOAD_TRUNCATED_IMAGES = True

from glob import glob
import time

In [2]:
val_df = pd.read_csv('Validation_GCC-1.1.0-Validation.tsv', sep='\t', header=None, names=['caption', 'image_url'])
trn_df = pd.read_csv('Train_GCC-training.tsv', sep='\t', header=None, names=['caption', 'image_url'])

In [3]:
val_df.shape, trn_df.shape

((15840, 2), (3318333, 2))

### Test Parallel Fetch

In [4]:
def invalid_images_as_none(batch):
    images = []
    for image_url in batch["image_url"]:
        try:
            image = Image.open(requests.get(image_url, stream=True, timeout=5).raw)
        except Exception:
            image = None
        images.append(image)
    batch["image"] = images
    return batch

In [4]:
dset = datasets.Dataset.from_pandas(val_df)
dset = dset.with_transform(invalid_images_as_none)

In [None]:
%%time
seq_times = []
nd = 256

start_time = time.time() 
for i, batch in enumerate(dset):
    if i == nd:
        break
    end_time = time.time() 
    seq_times.append(end_time - start_time)
    start_time = time.time() 

In [6]:
%%time
bs16_times = []
bs = 16
nd = 1024

loader = DataLoader(dset, batch_size=bs, num_workers=bs, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})

item_count = 0
start_time = time.time() 
for batch in loader:
    item_count += len(batch['caption'])
    
    end_time = time.time() 
    bs16_times.append(end_time - start_time)
    start_time = time.time() 
    
    if item_count >= nd:
        break

CPU times: user 1.92 s, sys: 2.94 s, total: 4.86 s
Wall time: 2min 30s


In [7]:
%%time
bs32_times = []
bs = 32
nd = 1024

loader = DataLoader(dset, batch_size=bs, num_workers=bs, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})

item_count = 0
start_time = time.time() 
for batch in loader:
    item_count += len(batch['caption'])
    
    end_time = time.time() 
    bs32_times.append(end_time - start_time)
    start_time = time.time() 
    
    if item_count >= nd:
        break

CPU times: user 2.37 s, sys: 5.07 s, total: 7.44 s
Wall time: 2min 27s


In [8]:
%%time
bs64_times = []
bs = 64
nd = 1024

loader = DataLoader(dset, batch_size=bs, num_workers=bs, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})

item_count = 0
start_time = time.time() 
for batch in loader:
    item_count += len(batch['caption'])
    
    end_time = time.time() 
    bs64_times.append(end_time - start_time)
    start_time = time.time() 
    
    if item_count >= nd:
        break



CPU times: user 5.57 s, sys: 16.4 s, total: 22 s
Wall time: 3min 58s


In [9]:
%%time
bs128_times = []
bs = 128
nd = 1024

loader = DataLoader(dset, batch_size=bs, num_workers=bs, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})

item_count = 0
start_time = time.time() 
for batch in loader:
    item_count += len(batch['caption'])
    
    end_time = time.time() 
    bs128_times.append(end_time - start_time)
    start_time = time.time() 
    
    if item_count >= nd:
        break

Traceback (most recent call last):
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/site-packages/PIL/Image.py", line 712, in __getstate__
    im_data = self.tobytes()  # load image first
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/site-packages/PIL/Image.py", line 755, in tobytes
    self.load()
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/site-packages/PIL/ImageFile.py", line 288, in load
    raise OSError(msg)
OSError: image file is truncated (26 bytes not processed)


CPU times: user 17 s, sys: 47.7 s, total: 1min 4s
Wall time: 5min 9s


In [10]:
%%time
bs = 128
nb = 64

loader = DataLoader(dset, batch_size=bs, num_workers=bs, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})
image_exist = []

item_count = 0
for batch in loader:
    item_count += len(batch['caption'])
    image_exist += map(lambda x: x is not None, batch['image'])
    if item_count == bs * nb:
        break
sum(image_exist) / len(image_exist)

Traceback (most recent call last):
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/site-packages/PIL/Image.py", line 712, in __getstate__
    im_data = self.tobytes()  # load image first
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/site-packages/PIL/Image.py", line 755, in tobytes
    self.load()
  File "/home/scahyawijaya/anaconda3/envs/env_indot0/lib/python3.10/site-packages/PIL/ImageFile.py", line 288, in load
    raise OSError(msg)
OSError: image file is truncated (26 bytes not processed)


CPU times: user 24.6 s, sys: 1min 16s, total: 1min 41s
Wall time: 6min 52s


0.69384765625

In [34]:
sum(seq_times) * 4

1697.1393938064575

In [28]:
sum(bs16_times), sum(bs16_times) / len(bs16_times), sum(bs16_times[1:])

(125.81549263000488, 1.9658670723438263, 97.05540823936462)

In [29]:
sum(bs32_times), sum(bs32_times) / len(bs32_times), sum(bs32_times[1:])

(87.29814529418945, 2.7280670404434204, 47.26372528076172)

In [30]:
sum(bs64_times), sum(bs64_times) / len(bs64_times), sum(bs64_times[1:])

(148.8310990333557, 9.301943689584732, 77.15501260757446)

In [31]:
sum(bs128_times), sum(bs128_times) / len(bs128_times), sum(bs128_times[1:])

(275.9560577869415, 34.49450722336769, 58.863117694854736)

# Load Model

In [4]:
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32").to('cuda')

  return self.fget.__get__(instance, owner)()


# Load SEA-VQA & CVQA

In [4]:
sea_vqa_dataset = load_dataset('wit543/sea-vqa')
sea_vqa_dataset['indonesia'][0]

{'question': 'What is the primary activity depicted in the image?',
 'choice_a': 'Gathering fruits',
 'choice_b': 'Fishing',
 'choice_c': 'Picking herbs',
 'choice_d': 'Planting trees',
 'correct_answer': 'c',
 'image_path': 'https://ich.unesco.org/img/photo/thumb/16712-HUG.jpg',
 'image_page': 'https://ich.unesco.org/en/photo-pop-up-00973?photoID=16712',
 'copyright': 'Photograph: Ganesh Ahsha Dalila© Dwi Ranny Pertiwi Zarman, Indonesia, 2022'}

In [5]:
cvqa_dataset = load_dataset('afaji/cvqa')

cvqa_sea_subsets = [
    "('Indonesian', 'Indonesia')",
    "('Malay', 'Malaysia')",
    "('Javanese', 'Indonesia')",
    "('Minangkabau', 'Indonesia')",
    "('Sundanese', 'Indonesia')",
    "('Chinese', 'Singapore')"
]
cvqa_dataset_filt = cvqa_dataset['test'].filter(lambda x: x['Subset'] in cvqa_sea_subsets, num_proc=32)
cvqa_dataset_filt[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=154x215>,
 'ID': '5865939224275596310_1',
 'Subset': "('Sundanese', 'Indonesia')",
 'Question': 'Naon prestasina inohong dina gambar?',
 'Translated Question': 'What is the achievement of the figure in the picture?',
 'Options': ['Gubernur Jawa Barat',
  'Wali kota Bandung',
  'Gubernur DKI Jakarta',
  'Wali kota Tasik'],
 'Translated Options': ['Governor of West Java',
  'Mayor of Bandung',
  'Governor of the Special Capital Region of Jakarta',
  'Mayor of Tasik'],
 'Label': -1,
 'Category': 'Public Figure and pop culture',
 'Image Type': 'External',
 'Image Source': 'https://upload.wikimedia.org/wikipedia/commons/0/01/Mayor_of_Bandung_Dada_Rosada.jpg',
 'License': 'Public domain'}

### Extract SEA-VQA

In [7]:
%%time
# Extract Text & Image Features from SEA-VQA

sea_vqa_images_filt = []
sea_vqa_images_embed = []
sea_vqa_caption = []
sea_vqa_culture = []
for key in sea_vqa_dataset.keys():
    for row in tqdm(sea_vqa_dataset[key]):
        try:
            img_opened = Image.open(requests.get(row['image_path'], stream=True).raw)
            sea_vqa_images_embed.append(model.encode(img_opened))
            sea_vqa_images_filt.append(img_opened)
            if row['correct_answer'] in ['a', 'b', 'c', 'd']:
                sea_vqa_caption.append(row['question'] + " " + row['choice_' + row['correct_answer']])
            else:
                sea_vqa_caption.append(row['question'])
            sea_vqa_culture.append(key)
        except:
            print(row)
pickle.dump((sea_vqa_images_filt, sea_vqa_images_embed, sea_vqa_caption, sea_vqa_culture), open('sea_vqa.pkl', 'wb'))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 304/304 [16:08<00:00,  3.18s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 752/752 [30:48<00:00,  2.46s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [02:14<00:00,  1.87s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 189/189 [05:58<00:00,  1.90s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 153/153 

CPU times: user 25min 8s, sys: 55 s, total: 26min 3s
Wall time: 1h 21min 50s


In [2]:
(sea_vqa_images_filt, sea_vqa_images_embed, sea_vqa_caption, sea_vqa_culture) = pickle.load(open('sea_vqa.pkl', 'rb'))

In [3]:
len(sea_vqa_images_filt), len(sea_vqa_images_embed), len(sea_vqa_caption), len(sea_vqa_culture)

(1999, 1999, 1999, 1999)

### Extract CVQA

In [7]:
%%time
# Extract Text & Image Features from CVQA

cvqa_images_filt = []
cvqa_images_embed = []
cvqa_caption = []
cvqa_culture = []
for row in tqdm(cvqa_dataset_filt):
    try:
        cvqa_images_embed.append(model.encode(row['image']))
        cvqa_images_filt.append(row['image'])
        cvqa_caption.append(row['Translated Question'] + " " + ', '.join(row['Translated Options']))
        cvqa_culture.append(eval(cvqa_dataset_filt[0]['Subset'])[0])
    except:
        print(row)
pickle.dump((cvqa_images_filt, cvqa_images_embed, cvqa_caption, cvqa_culture), open('cvqa.pkl', 'wb'))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1687/1687 [11:04<00:00,  2.54it/s]


CPU times: user 23min 58s, sys: 3min 19s, total: 27min 18s
Wall time: 13min 41s


In [8]:
(cvqa_images_filt, cvqa_images_embed, cvqa_caption, cvqa_culture) = pickle.load(open('cvqa.pkl', 'rb'))

 ### Extract CC3M Validation

In [5]:
%%time
# Extract Text & Image Features from CC3M
bs = 64
cc3m_val_images_embed = []
cc3m_val_images_filt = []
cc3m_val_caption = []

def invalid_images_as_none(batch):
    images = []
    for image_url in batch["image_url"]:
        try:
            image = Image.open(requests.get(image_url, stream=True, timeout=8).raw).convert('RGB')
        except Exception:
            image = None
        images.append(image)
    batch["image"] = images
    return batch

dset = datasets.Dataset.from_pandas(val_df)
dset = dset.with_transform(invalid_images_as_none)

loader = DataLoader(dset, batch_size=bs, num_workers=bs, prefetch_factor=8, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})
for i, batch in tqdm(enumerate(loader)):
    imgs = []
    for i, img in enumerate(batch['image']):
        if img is not None:
            if img.size[0] < 50 or img.size[1] < 50:
                continue
            imgs.append(img)
            cc3m_val_images_filt.append(batch['image_url'][i])
            cc3m_val_caption.append(batch['caption'][i])
    
    img_embeds = model.encode(imgs, batch_size=bs)
    for img_emb in img_embeds:
        cc3m_val_images_embed.append(img_emb)

    if i == len(loader) - 1:
        break

248it [04:40,  1.13s/it]

CPU times: user 14min 36s, sys: 1min 41s, total: 16min 17s
Wall time: 4min 44s





In [6]:
print(len(cc3m_val_images_embed), len(cc3m_val_images_filt), len(cc3m_val_caption), flush=True)
pickle.dump((cc3m_val_images_filt, cc3m_val_images_embed, cc3m_val_caption, []), open('./cc3m_val.pkl', 'wb'))

10978 10978 10978


 ### Extract CC3M Training

In [5]:
%%time
# Extract Text & Image Features from CC3M
bs = 64
cc3m_trn_images_embed = []
cc3m_trn_images_filt = []
cc3m_trn_caption = []

def invalid_images_as_none(batch):
    images = []
    for image_url in batch["image_url"]:
        try:
            image = Image.open(requests.get(image_url, stream=True, timeout=8).raw).convert('RGB')
        except Exception:
            image = None
        images.append(image)
    batch["image"] = images
    return batch

dset = datasets.Dataset.from_pandas(trn_df)
dset = dset.with_transform(invalid_images_as_none)

loader = DataLoader(dset, batch_size=bs, num_workers=bs, prefetch_factor=8, collate_fn=lambda x: {k: [row[k] for row in x] for k in x[0]})
for i, batch in tqdm(enumerate(loader)):
    imgs = []
    for i, img in enumerate(batch['image']):
        if img is not None:
            if img.size[0] < 50 or img.size[1] < 50:
                continue
            imgs.append(img)
            cc3m_trn_images_filt.append(batch['image_url'][i])
            cc3m_trn_caption.append(batch['caption'][i])
    
    img_embeds = model.encode(imgs, batch_size=bs)
    for img_emb in img_embeds:
        cc3m_trn_images_embed.append(img_emb)
        
    if i == len(loader) - 1:
        break

3559it [1:01:44,  1.65it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

9400it [2:28:23,  1.07it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

51849it [13:34:47,  1.06it/s]

CPU times: user 2d 3h 9min 50s, sys: 4h 29min 22s, total: 2d 7h 39min 12s
Wall time: 13h 34min 56s





In [6]:
print(len(cc3m_trn_images_embed), len(cc3m_trn_images_filt), len(cc3m_trn_caption), flush=True)
pickle.dump((cc3m_trn_images_filt, cc3m_trn_images_embed, cc3m_trn_caption, []), open('./cc3m_trn.pkl', 'wb'))

2292954 2292954 2292954
