In [2]:
import os
# import sys

# from src.processing import mean_pooling, mean_pooling_embedding_with_normalization

# sys.path.append('.')

import argparse
import json
import numpy as np

import faiss
import torch
# from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Name of the corpus to embed; it is interpolated into the data file paths below.
dataset = 'musique'
# Width of every embedding row and of the FAISS index built from those rows.
# NOTE(review): the original note read "chunks of max length 512 tokens -> R^768 ???";
# confirm the encoder's maximum sequence length and hidden size.
dim = 768
# Normalization flag. NOTE(review): no code visible in this notebook reads it; the
# embedding helper always normalizes and the file names hard-code the "_norm" suffix.
norm = True
# unit = 'proposition'

In [4]:
# File-name-safe label for the embedding model; it only feeds the cache/index paths.
model_label = "facebook_contriever"

In [5]:
# Where the FAISS inner-product index for this dataset/model pair is stored.
index_path = f'data/{dataset}/{dataset}_{model_label}_proposition_ip_norm.index'
# Where the embedding matrix for this dataset/model pair is cached as a .npy file.
vector_path = f'data/{dataset}/{dataset}_{model_label}_proposition_vectors_norm.npy'

# Pull the cached embeddings into memory right away when a previous run left them on disk.
if os.path.isfile(vector_path):
    vectors = np.load(vector_path)

In [6]:
vector_path

'data/musique/musique_facebook_contriever_proposition_vectors_norm.npy'

In [7]:
# Load the proposition corpus for the selected dataset and flatten every entry into
# one "title\npropositions" string, which is the text that gets embedded later.
supported_datasets = ('musique', '2wikimultihopqa')
if dataset not in supported_datasets:
    # Fail loudly: otherwise `corpus` would be undefined, or silently left over from
    # an earlier run of this cell with a different dataset.
    raise ValueError(
        'Unsupported dataset {!r}; expected one of {}'.format(dataset, supported_datasets)
    )

corpus_path = f'data/{dataset}_proposition_corpus.json'
# The context manager closes the file handle as soon as the JSON has been parsed.
with open(corpus_path, 'r', encoding='utf-8') as corpus_file:
    corpus = json.load(corpus_file)

corpus_contents = [item['title'] + '\n' + item['propositions'] for item in corpus]
print('corpus size: {}'.format(len(corpus_contents)))

corpus size: 11656


In [8]:
# Rough length statistics for the corpus entries, printed as max, min and mean.
# Character counts are divided by 4 before printing.
# NOTE(review): presumably a ~4 characters-per-token heuristic to compare against the
# encoder's sequence limit — confirm with the real tokenizer if the numbers matter.
chars_per_token = 4
line_lengths = [len(line) for line in corpus_contents]

if line_lengths:
    total_len = sum(line_lengths)
    max_len = max(line_lengths)
    # A real minimum instead of a fixed sentinel, which would cap the reported value.
    min_len = min(line_lengths)
    print(max_len / chars_per_token)
    print(min_len / chars_per_token)
    print((total_len / len(line_lengths)) / chars_per_token)
else:
    # An empty corpus has no statistics; avoid dividing by zero for the mean.
    total_len = max_len = min_len = 0
    print('corpus is empty; no length statistics to report')

525.25
31.25
149.56335792724778


In [None]:
#create sentence-level embeddings using mean-pooling and normalize to prepare for cosine similarity indexing
#note: UPDATE TO USE distributedDataParallel
def mean_pooling(tokenEmbeddings, paddingInfo):
    """Average token embeddings over the sequence axis, ignoring masked-out positions.

    Args:
        tokenEmbeddings: tensor indexed as (batch, sequence, features); the sum is
            taken over dim 1.
        paddingInfo: mask indexed as (batch, sequence); entries that are zero mark
            positions whose embeddings are excluded from the average.

    Returns:
        Tensor indexed as (batch, features): the per-row sum of the kept token
        embeddings divided by the per-row sum of the mask.
    """
    # Positions to drop, broadcast across the feature axis.
    is_padding = paddingInfo.unsqueeze(-1).bool().logical_not()
    # Number of kept tokens per row, shaped to broadcast against the summed embeddings.
    token_counts = paddingInfo.sum(dim=1).unsqueeze(-1)
    summed = tokenEmbeddings.masked_fill(is_padding, 0).sum(dim=1)
    return summed / token_counts

def mean_pooling_embedding_with_normalization(batch_str, tokenizer, model, mps_device):
    """Embed a batch of strings as L2-normalized, mean-pooled sentence vectors.

    Args:
        batch_str: list of strings to embed.
        tokenizer: callable tokenizer; invoked with padding, truncation and
            ``return_tensors='pt'``, and its result is moved to ``mps_device``.
        model: encoder called with the tokenizer output as keyword arguments; the
            first element of its output is used as the per-token embeddings
            (assumed to be indexed (batch, sequence, features) — confirm for
            models other than the one used in this notebook).
        mps_device: torch device the tokenized batch is moved to. The model is
            expected to already live on that device.

    Returns:
        Tensor with one unit-length row per input string, on ``mps_device`` and
        detached from autograd (the forward pass runs under ``torch.no_grad()``).
    """
    chunks = tokenizer(batch_str, padding=True, truncation=True, return_tensors='pt').to(mps_device)
    # Inference only: skipping autograd saves memory and lets callers convert the
    # result to NumPy without an explicit detach().
    with torch.no_grad():
        model_output = model(**chunks)
        # The attention mask comes from the tokenizer output; the model output does
        # not carry one, so looking it up there raises.
        sentenceEmbeddings = mean_pooling(model_output[0], chunks['attention_mask'])
        # normalize() clamps the norm with a small epsilon, so an all-zero row stays
        # zero instead of turning into NaN.
        sentenceEmbeddingsNorm = torch.nn.functional.normalize(sentenceEmbeddings, p=2, dim=1)
    return sentenceEmbeddingsNorm

In [14]:
# Reuse cached embeddings when they exist; otherwise embed the whole corpus in
# batches and cache the resulting matrix at vector_path.
if os.path.isfile(vector_path):
    print('Loading existing vectors:', vector_path)
    vectors = np.load(vector_path)
    print('Vectors loaded:', len(vectors))
else:
    # load model
    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
    model = AutoModel.from_pretrained('facebook/contriever')
    # Prefer the MPS backend. When it is unavailable, report why and fall back to
    # CUDA or CPU so the device variable is always defined before the loop runs.
    if torch.backends.mps.is_available():
        mps_device = torch.device("mps")
    else:
        if not torch.backends.mps.is_built():
            print("MPS not available because the current PyTorch install was not "
                "built with MPS enabled.")
        else:
            print("MPS not available because the current MacOS version is not 12.3+ "
                "and/or you do not have an MPS-enabled device on this machine.")
        mps_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print('Falling back to device:', mps_device)
    # No DataParallel wrapper: it splits batches across several CUDA GPUs, and a
    # single device gives it nothing to split across.
    model.to(mps_device)
    #test batch size = 16 and batch size = 32
    batch_size = 16
    vectors = np.zeros((len(corpus_contents), dim))
    failed_batches = []
    #get batch_size number of entries from corpus_contents, tokenize and embed them in `dim`-dimensional space
    for idx in range(0, len(corpus_contents), batch_size):
        end_idx = min(idx + batch_size, len(corpus_contents))
        seqs = corpus_contents[idx:end_idx]
        try:
            # Inference only; no_grad keeps autograd from tracking the forward pass.
            with torch.no_grad():
                batch_embeddings = mean_pooling_embedding_with_normalization(seqs, tokenizer, model, mps_device)
        except Exception as e:
            # Best effort: keep going with zero rows, but remember the range so the
            # failure is reported below instead of passing silently.
            batch_embeddings = torch.zeros((len(seqs), dim))
            failed_batches.append((idx, end_idx))
            print(f'Error at {idx}:', e)
        # detach() is required before numpy() whenever the tensor still tracks gradients.
        vectors[idx:end_idx] = batch_embeddings.detach().to('cpu').numpy()
    if failed_batches:
        print('WARNING: {} of {} batches failed and were stored as zero vectors: {}'.format(
            len(failed_batches), len(range(0, len(corpus_contents), batch_size)), failed_batches))
        print('Delete {} after fixing the errors to recompute them.'.format(vector_path))
    print("Type of vectors is {}".format(type(vectors)))
    # Make sure the target directory exists so the computed vectors are not lost on save.
    os.makedirs(os.path.dirname(vector_path) or '.', exist_ok=True)
    np.save(vector_path, vectors)
    print('vectors saved to {}'.format(vector_path))

#using FAISS on CPU (GPU support unavailable for mac)
# This runs whether the vectors were loaded or freshly computed, so a missing index
# is also built when only the vector cache exists.
if os.path.isfile(index_path):
    print('index file already exists:', index_path)
    print('index size: {}'.format(faiss.read_index(index_path).ntotal))
else:
    print('Building index...')
    index = faiss.IndexFlatIP(dim)
    # The vectors are handed to FAISS as float32.
    vectors = vectors.astype('float32')
    index.add(vectors)

    # save faiss index to file
    os.makedirs(os.path.dirname(index_path) or '.', exist_ok=True)
    faiss.write_index(index, index_path)
    print('index saved to {}'.format(index_path))
    print('index size: {}'.format(index.ntotal))