In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

from bef.vector_utils import compute_kernel_bias, transform_and_normalize
from faiss_utils.faiss_indexes import create_index_flat
from sentence_transformers import SentenceTransformer
import argparse
import torch
import json
import numpy as np
import os
import faiss
import warnings
#ignoring sklearn deprecation warnings


def warn(*args, **kwargs):
    pass


warnings.warn = warn


In [9]:
class EventsDataset(torch.utils.data.Dataset):

  def __init__(self, graphs):
      'Initialization'
      self.graphs = graphs

  def __len__(self):
      'Denotes the total number of samples'
      return len(self.graphs)

  def __getitem__(self, index):
      'Generates one sample of data'
      # Select sample

      return self.graphs[index]


In [6]:
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--graphfile", type=str, default='/home/julio/repos/event_finder/data/pubmed_full/results/cg/graph_pubmed_2000.json',
                        help="The Json file of graphs ")
    parser.add_argument("--index", type=str, default='/home/julio/repos/event_finder/data/pubmed_full/results/cg/faiss_pubmedd_2000.index',
                        help="The name for the output embedding file")

    parser.add_argument("--sbertmodel", type=str, default='all-MiniLM-L12-v2',
                        help="https://www.sbert.net/docs/pretrained_models.html")

    return parser.parse_args('')


In [7]:
args = parse_args()

In [10]:
REDU_DIM = 128
 # DIR = args.indir
graphs = []
print('=======LOADING EVENTS JSON=======')
with open(args.graphfile) as ff:
    for g in json.load(ff):
        graphs.append(' '.join(
            [n['name'].rstrip() for n in g['nodes']]
        ))

print('=======LOADING MODEL=======')
BERT_MODEL = args.sbertmodel
model = SentenceTransformer(BERT_MODEL)

device = torch.device("cuda")
# Choose whatever GPU device number you want
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
model.to(device)
# Parameters
params = {'batch_size': 2048,
            'shuffle': False,
            'num_workers': 1}
# Generators
graphs_set = EventsDataset(graphs)
graphs_generator = torch.utils.data.DataLoader(graphs_set, **params)

embeddings = []
print('=======ENCODING..=======')
for graph_batch in graphs_generator:
    # obtain encoded tokens
    with torch.no_grad():
        outputs = model.encode(graph_batch)

    embeddings.append(outputs)
    #dbug
    break




In [16]:
embeddings[0].shape

(2048, 384)

In [13]:
embs = np.concatenate(embeddings)


In [14]:
embs.shape

(2048, 384)

In [17]:
d = embs.shape[1]
print('Original dimention:', d)
print('Reducing dimention to: ', REDU_DIM)
kernel, bias = compute_kernel_bias(embs)
kernel = kernel[:, :REDU_DIM]
embs = transform_and_normalize(embs, kernel, bias)


Original dimention: 384
Reducing dimention to:  128


In [18]:
print('++++++CREATING FAISS INDEX==============')

index = create_index_flat(embs, d)

# index_file = os.path.join('out', ONTOLOGY, ONTOLOGY+'.index')
faiss.write_index(index,  args.index)


Loading embeddings..
