In [1]:
import pickle as pkl
from transformers import DistilBertTokenizer, DistilBertModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm as tqdm
import numpy as np
import torch

### Load the combined image and text data

In [2]:
# Read the combined image/text records produced upstream.
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
with open("../data/img_text_comb_updated.pkl", mode="rb") as pickle_file:
    data_pkl = pkl.load(pickle_file)

In [3]:
# Tokenizer and encoder come from the same checkpoint so the vocabulary
# matches the model's embedding table.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained(
    "distilbert-base-uncased"
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# Prefer the first GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [9]:
# Place the DistilBERT weights on the selected device before inference.
model = model.to(device=device)

### DistilBERT base embeddings

In [10]:
model.eval()  # inference mode: disables dropout
with torch.no_grad():  # no gradients needed for feature extraction
    for key, record in tqdm(data_pkl.items()):
        # Tokenize all sentences of this entry together as one padded batch.
        batch = tokenizer(
            list(record["Sentence"]),
            return_tensors='pt',
            truncation=True,
            padding=True,
        ).to(device)

        # Use the hidden state of the [CLS] token (index 0) as each sentence's embedding.
        hidden = model(**batch)["last_hidden_state"]
        record["text_emb(DistilBERT)"] = hidden[:, 0, :].cpu().numpy()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 342/342 [00:30<00:00, 11.25it/s]


## MiniLM-L6 Sentence-Transformer embeddings

In [None]:
# Keep a dedicated name so the DistilBERT `model` is not overwritten; reusing
# `model` here would make a re-run of the DistilBERT cell call the wrong model.
minilm_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)
for key in tqdm(data_pkl.keys()):
    # Wrap in list() exactly as the DistilBERT cell does; encode() is documented
    # to accept a str or a list of str, and "Sentence" may be some other sequence type.
    sentences = list(data_pkl[key]["Sentence"])
    # encode() returns a NumPy array by default (convert_to_numpy=True).
    data_pkl[key]["text_emb(all-MiniLM-L6-v2)"] = minilm_model.encode(sentences)

In [None]:
# Write the records, with the embedding fields added above, to a new versioned file.
with open("../data/img_text_comb_updated_v3.pkl", mode="wb") as out_file:
    pkl.dump(data_pkl, out_file)