In [1]:
from sklearn.neighbors import NearestNeighbors
import numpy as np
import json
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from sklearn.decomposition import TruncatedSVD

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id = 'naver/splade-cocondenser-ensembledistil'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

In [3]:


def get_sparse_vec(text):
    tokens = tokenizer(text, return_tensors='pt')
    output = model(**tokens)
    vec = torch.max(
    torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),dim=1
    )[0].squeeze()

    # print(vec.shape)
    return vec



In [4]:
text = 'Cell Culture MDCK, MDCK-pTR GFP-RasV12, and MDCK-pTR GFP-cSrcY527F cells were cultured as previously described 10 19 .'

get_sparse_vec(text)

tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SqueezeBackward0>)

In [5]:
# Load your data from a JSON file
with open("/home/julio/repos/event_finder/data/pubmed_70s/cg/events_graph.json", "r") as read_file:
    json_data = json.load(read_file)


In [6]:

# Generate the text data by concatenating the node names
data = {}
for graph in json_data:
    id = graph['id']
    text = ' '.join(node['name'] for node in graph['nodes'])
    data[id] = text

In [None]:
# Generate sparse vectors for all texts
# This line takes a lot of RAM and CPUS 
sparse_vectors = {id: get_sparse_vec(text) for id, text in data.items()}


In [7]:
len(data)

222009

In [None]:

# Create a list of vectors and a corresponding list of IDs
vectors = list(sparse_vectors.values())
ids = list(sparse_vectors.keys())


In [None]:
# Use TruncatedSVD to reduce the dimensionality of the vectors
svd = TruncatedSVD(n_components=128, random_state=42)
vectors_svd = svd.fit_transform(vectors)

# Create a NearestNeighbors instance
# nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(vectors)


In [None]:

# Now, suppose you have a new piece of text, and you want to find the
# most similar text in your original data:
query = "query text"

# Generate a sparse vector for the query
query_vector = get_sparse_vec(query)

# Query the model to get the index of the most similar vector
distances, indices = nbrs.kneighbors([query_vector])

# Get the ID of the most similar text
most_similar_id = ids[indices[0][0]]

print("The most similar text to the query is: ", data[most_similar_id])