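Reference for the sentence-embedding approach used below (mean of the token vectors from BERT's second-to-last hidden layer): https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/#sentence-vectors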

In [None]:
! wget http://images.cocodataset.org/zips/val2014.zip
! wget https://raw.githubusercontent.com/tylin/coco-caption/master/annotations/captions_val2014.json
! unzip -q val2014.zip

--2022-07-02 12:27:17--  http://images.cocodataset.org/zips/val2014.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 54.231.226.161
Connecting to images.cocodataset.org (images.cocodataset.org)|54.231.226.161|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6645013297 (6.2G) [application/zip]
Saving to: ‘val2014.zip’


2022-07-02 12:35:28 (12.9 MB/s) - ‘val2014.zip’ saved [6645013297/6645013297]

--2022-07-02 12:35:28--  https://raw.githubusercontent.com/tylin/coco-caption/master/annotations/captions_val2014.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 29707814 (28M) [text/plain]
Saving to: ‘captions_val2014.json’


2022-07-02 12:35:32 (334 MB/s) - ‘captions_val2014.json’ saved [29707814/29707814]



In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 31.8 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 53.9 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 59.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 14.3 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling 

In [None]:
import torch
from torchvision.datasets import CocoDetection
import torchvision.transforms as transforms
import pickle as pkl

from transformers import BertTokenizer, BertModel

In [None]:
class UniformCocoCaptions(CocoDetection):
    """COCO dataset whose target is a short, fixed-length list of captions.

    The target of each sample is a list holding at most ``NUM_CAPTIONS``
    caption strings (the first ones found in the annotations), each wrapped
    in BERT's special tokens as ``"[CLS] <caption> [SEP]"``. Keeping the
    number of captions constant makes the targets uniform across samples.
    """

    # How many captions to keep per image. Only the first caption is used.
    NUM_CAPTIONS = 1

    def _load_target(self, id: int):
        """Return the first ``NUM_CAPTIONS`` captions of image ``id``.

        Args:
            id: COCO image id, as passed by the parent class.

        Returns:
            list[str]: captions wrapped as ``"[CLS] ... [SEP]"``.
        """
        annotations = super()._load_target(id)[:self.NUM_CAPTIONS]
        return ["[CLS] " + ann["caption"] + " [SEP]" for ann in annotations]

In [None]:
from pathlib import Path
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
import gc

@torch.no_grad()
def compute_bert_coco_embeds(coco_images_dir, coco_captions_file,
                             batch_size=64, save_root=None,
                             save_all_texts=False):
    """Compute one BERT sentence embedding per COCO image caption.

    For every sample of ``UniformCocoCaptions`` the first caption is
    tokenized, run through ``bert-base-uncased``, and reduced to a single
    vector by averaging the token vectors of the second-to-last hidden layer.

    Args:
        coco_images_dir: Directory containing the COCO images.
        coco_captions_file: Path to the COCO captions annotation JSON.
        batch_size: Accepted for interface compatibility; currently unused
            (captions are processed one at a time).
        save_root: Output root directory. When given, ``<save_root>/image``
            and ``<save_root>/text`` are created and the embeddings are
            written to ``<save_root>/text/BERT.pt``. When ``None``, nothing
            is written to disk.
        save_all_texts: Accepted for interface compatibility; currently unused.

    Returns:
        torch.Tensor: CPU tensor with one row per dataset sample and one
        column per BERT hidden unit.
    """
    text_save_dir = None
    if save_root is not None:
        # The 'image' directory is not written to here; it is still created
        # so the on-disk layout stays the same as before.
        image_save_dir = Path(save_root) / 'image'
        image_save_dir.mkdir(parents=True, exist_ok=True)
        text_save_dir = Path(save_root) / 'text'
        text_save_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    print(f'Using {device}')
    print('Loading BERT')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
    model.to(device)
    model.eval()
    print('Done')
    # Use the caller-supplied locations instead of hard-coded paths.
    dataset = UniformCocoCaptions(root=str(coco_images_dir),
                                  annFile=str(coco_captions_file),
                                  transform=transforms.ToTensor())
    print('COCO dataset:\n', dataset)
    print(len(dataset))

    all_text_embeddings = []
    print('Computing embeddings')
    for i in tqdm(range(len(dataset))):
        # Only the caption is needed. dataset[i] would also read and decode
        # the image, so fetch the target directly by image id.
        text = dataset._load_target(dataset.ids[i])[0]

        # Split the sentence into tokens.
        tokenized_text = tokenizer.tokenize(text)
        # Map the token strings to their vocabulary indices.
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

        # Convert inputs to PyTorch tensors.
        tokens_tensor = torch.tensor([indexed_tokens], device=device)
        # A single unpadded sentence: every position is attended to. This is
        # passed by keyword because the second positional argument of
        # BertModel.forward is the attention mask, not the segment ids.
        attention_mask = torch.ones_like(tokens_tensor)

        outputs = model(input_ids=tokens_tensor, attention_mask=attention_mask)
        # With output_hidden_states=True the third output holds the hidden
        # states of all layers.
        hidden_states = outputs[2]
        # Second-to-last layer, first (and only) item of the batch.
        token_vecs = hidden_states[-2][0]
        sentence_embedding = torch.mean(token_vecs, dim=0)

        # Keep results on the CPU so the saved file loads without a GPU.
        all_text_embeddings.append(sentence_embedding.cpu())

    if all_text_embeddings:
        # Stack into a single 2-D tensor (one row per caption).
        text_embeddings = torch.stack(all_text_embeddings)
    else:
        text_embeddings = torch.empty((0, model.config.hidden_size))

    if text_save_dir is not None:
        torch.save(text_embeddings, text_save_dir / 'BERT.pt')
    print('Done')
    print(f'Text  embeddings: {text_embeddings.shape}')
    gc.collect()
    return text_embeddings

In [None]:
# Inputs: extracted COCO val2014 images and their captions annotation file.
coco_images_dir = 'val2014'
coco_captions_file = 'captions_val2014.json'
# Output root for the computed embeddings.
save_root = 'embeddings/coco_val2014'
batch_size = 256

# Keyword arguments make it explicit which value feeds which parameter.
compute_bert_coco_embeds(
    coco_images_dir=coco_images_dir,
    coco_captions_file=coco_captions_file,
    batch_size=batch_size,
    save_root=save_root,
)

torch.Size([40504, 768])
