In [None]:
import torch
from Bio import Medline
import os
import json
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification, AutoModelForQuestionAnswering
from opensearchpy import OpenSearch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

In [None]:
# Connection settings can be overridden through environment variables so the
# notebook does not need to be edited (or secrets committed) per deployment.
# The fallbacks are the previous hard-coded values, so existing setups still work.
# NOTE(review): the fallback password is still in source -- rotate it and drop
# the default once OPENSEARCH_PASSWORD is set wherever this notebook runs.
host = os.environ.get("OPENSEARCH_HOST", 'localhost')
port = int(os.environ.get("OPENSEARCH_PORT", 9200))
auth = (
    os.environ.get("OPENSEARCH_USER", 'admin'),
    os.environ.get("OPENSEARCH_PASSWORD", 'qaOllama2'),
)

# Create the client with SSL/TLS enabled. Certificate verification and hostname
# verification are BOTH disabled, and the related warnings are silenced.
client = OpenSearch(
    hosts=[{'host': host, 'port': port}],
    http_compress=True,  # enables gzip compression for request bodies
    http_auth=auth,
    use_ssl=True,
    verify_certs=False,
    ssl_assert_hostname=False,
    ssl_show_warn=False,
)

#index_name = 'pub_med_index'
#client.indices.delete(index=index_name)


In [None]:
# Input/output locations. Both can be overridden through environment variables
# so the notebook is not tied to one machine's home directory; the defaults are
# the original paths.
# Raw PubMed export, parsed with Bio.Medline.parse.
pubmed_data_path = os.environ.get("PUBMED_DATA_PATH", "/home/paperspace/pubmed_data")
# JSON dump of the records that pass the required-field filter.
pubmed_data_preprocessed_path = os.environ.get(
    "PUBMED_DATA_PREPROCESSED_PATH", "/home/paperspace/pubmed_data_preprocessed.json"
)


In [None]:
# Keep only the articles that carry every field used later in the notebook:
# PMID (document id), TI (title), FAU (full author), DP (publication date)
# and AB (abstract text).
required_fields = ("PMID", "TI", "FAU", "DP", "AB")

records = []
missed = 0  # articles skipped because at least one required field is absent
num = 0     # articles kept
with open(pubmed_data_path) as stream:
    for article in Medline.parse(stream):
        if all(field in article for field in required_fields):
            num += 1
            records.append(article)
        else:
            missed += 1

# Persist the filtered records as a single JSON list.
with open(pubmed_data_preprocessed_path, 'w') as out_file:
    json.dump(records, out_file)
print(num)

## Embeddings

In [None]:
class PubMedDataset(Dataset):
    """Dataset over a JSON file of article records that yields each record's abstract.

    The file at ``path`` is parsed as JSON and kept fully in memory; items are
    looked up by position and must expose an ``"AB"`` key.
    """

    def __init__(self, path):
        """Read and parse the JSON file at ``path`` into ``self.data``."""
        with open(path, 'r') as handle:
            self.data = json.load(handle)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``"AB"`` value of the record at position ``idx``."""
        record = self.data[idx]
        return record["AB"]

In [None]:
# Tokenizer and encoder weights come from the same checkpoint; the encoder is
# moved onto the selected device.
checkpoint_name = 'sentence-transformers/all-mpnet-base-v2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name)
model = model.to(device)

In [None]:
# Iterate the abstracts in file order (no shuffling), 64 per batch, so the
# embeddings come out in the same order as the stored records.
loader_options = {"batch_size": 64, "shuffle": False}
dataset = PubMedDataset(pubmed_data_preprocessed_path)
dataloader = DataLoader(dataset, **loader_options)

In [None]:
def mean_pooling(last_hidden_state, attention_mask):
    """Average token vectors along dim 1, counting only positions where the mask is set.

    Args:
        last_hidden_state: tensor of per-token vectors; the mask is broadcast to
            its full size after gaining a trailing dimension.
        attention_mask: tensor with one fewer dimension than ``last_hidden_state``;
            non-zero entries mark tokens that contribute to the average.

    Returns:
        The mask-weighted sum over dim 1 divided by the mask sum over dim 1.
        The divisor is clamped to at least 1e-9, so a fully masked row yields
        zeros instead of a division by zero.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    weighted_sum = (last_hidden_state * mask).sum(1)
    token_counts = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / token_counts

In [None]:
# Encode every abstract into one fixed-size vector.
# - Each batch is L2-normalised so the stored document vectors are comparable
#   with the query vectors, which are produced by
#   SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode();
#   that pipeline normalises its output for this checkpoint, and mixing
#   unit-length queries with un-normalised documents skews distance ranking.
# - Each batch is moved to the CPU immediately so GPU memory is bounded by one
#   batch rather than growing with the corpus (and the saved tensor can be
#   loaded on a machine without a GPU).
embeddings = []
model.eval()  # make sure dropout is disabled while encoding
with torch.no_grad():
    for sample in dataloader:
        inputs = tokenizer(sample, return_tensors="pt", padding=True, truncation=True).to(device)
        out = model(**inputs)
        pooled = mean_pooling(out.last_hidden_state, inputs["attention_mask"])
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        # extend() iterates the batch tensor row by row: one vector per abstract.
        embeddings.extend(pooled.cpu())
# One row per abstract, in dataloader (= record) order.
embeddings_stacked = torch.stack(embeddings)

In [None]:
# Write the stacked embedding tensor to disk with torch.save.
embeddings_path = '/home/paperspace/pubmed_data_embeddings.bin'
torch.save(embeddings_stacked, embeddings_path)

In [None]:
#embeddings_stacked=torch.load('/home/paperspace/pubmed_data_embeddings.bin')

## KNN

In [None]:
#index_name = 'pub_med_index'
#client.indices.delete(index=index_name)

In [None]:
# Index definition: k-NN enabled, text fields analysed with the built-in
# "standard" analyzer, and one knn_vector field sized to the embedding width.
index_name = "pub_med_index"
index_mapping = {
  "settings": {
    "index.knn": True
  },
  "mappings": {
    "properties": {
      # NOTE(review): the bulk-indexing cell writes only TI, AB and vector;
      # "title" is kept for backward compatibility but looks unused -- confirm.
      "title": {
        "type": "text",
        "analyzer": "standard"
      },
      "TI": {
        "type": "text",
        "analyzer": "standard"
      },
      "AB": {
        "type": "text",
        "analyzer": "standard"
      },
      "vector": {
        "type": "knn_vector",
        # Width of one embedding row.
        "dimension": len(embeddings_stacked[0])
      }
    }
  }
}

# Creating an index that already exists raises, which breaks re-running the
# notebook. Only create it when it is missing; delete it first if the mapping
# has to change.
if client.indices.exists(index=index_name):
    print(f"Index '{index_name}' already exists; leaving it unchanged.")
else:
    client.indices.create(index=index_name, body=index_mapping)

In [None]:
# Bulk-index the records together with their embedding vectors, 1000 at a time.
# records[j] and embeddings_stacked[j] must describe the same article, so the
# two sequences have to be the same length.
batch_size = 1000
assert len(records) == len(embeddings_stacked), (
    f"records ({len(records)}) and embeddings ({len(embeddings_stacked)}) are out of sync"
)
for i in range(0, len(records), batch_size):
    batch = records[i:i + batch_size]
    # Enumerate from i so `row` is the article's GLOBAL position. Indexing the
    # embeddings with the position inside the batch would attach the first
    # `batch_size` vectors to every batch.
    actions = [
        (
            {"index": {"_index": index_name, "_id": doc["PMID"]}},
            {"TI": doc["TI"], "AB": doc["AB"], "vector": embeddings_stacked[row].tolist()},
        )
        for row, doc in enumerate(batch, start=i)
    ]
    # Newline-delimited JSON: an action line followed by its document line.
    request = '\n'.join(
        json.dumps(item, indent=None, separators=(",", ":"))
        for tpl in actions
        for item in tpl
    )
    try:
        response = client.bulk(body=request, refresh=True)
    except Exception as e:
        print(f"Failed to perform bulk request. Error: {e}")
        continue
    # The call can succeed while individual items fail; the response flags that
    # through its top-level "errors" field.
    if response.get("errors"):
        failed = sum(1 for item in response.get("items", []) if "error" in item.get("index", {}))
        print(f"Bulk request completed with {failed} failed item(s).")
    else:
        print("Bulk request successful.")
    

In [None]:
import torch
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from Bio import Medline
import ollama
from fastapi import FastAPI
import uvicorn

from opensearchpy import OpenSearch
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification, AutoModelForQuestionAnswering



# Select the compute device: first CUDA GPU if available, else CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

# Build a PMID -> article lookup, skipping (and counting) every article that
# lacks one of the fields used when formatting answers.
required_fields = ("PMID", "TI", "FAU", "DP", "AB")

records = {}
missed = 0

with open(pubmed_data_path) as stream:
    for article in Medline.parse(stream):
        if any(field not in article for field in required_fields):
            missed += 1
            continue

        records[article["PMID"]] = article

#model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# Load tokenizer and encoder from the same checkpoint and place the encoder on
# the selected device.
checkpoint_name = 'sentence-transformers/all-mpnet-base-v2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name)
model = model.to(device)

# Why not take the CLS token? The published usage example for
# all-mpnet-base-v2 pools by averaging the token embeddings under the attention
# mask, so the same pooling is used here.
def mean_pooling(last_hidden_state, attention_mask):
    """Return the attention-masked mean of the token vectors along dim 1.

    ``attention_mask`` gains a trailing dimension and is expanded to the size of
    ``last_hidden_state``; masked-out positions contribute nothing to the sum.
    The divisor (mask sum over dim 1) is clamped to at least 1e-9 so a fully
    masked row produces zeros rather than a division by zero.
    """
    expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    numerator = torch.sum(last_hidden_state * expanded_mask, 1)
    denominator = torch.clamp(expanded_mask.sum(1), min=1e-9)
    return numerator / denominator

# https://opensearch.org/docs/latest/clients/python-low-level/

# Connection settings can be overridden through environment variables so the
# code does not need to be edited (or secrets committed) per deployment.
# The fallbacks are the previous hard-coded values, so existing setups still work.
# NOTE(review): the fallback password is still in source -- rotate it and drop
# the default once OPENSEARCH_PASSWORD is set wherever this runs.
host = os.environ.get("OPENSEARCH_HOST", 'localhost')
port = int(os.environ.get("OPENSEARCH_PORT", 9200))
auth = (
    os.environ.get("OPENSEARCH_USER", 'admin'),
    os.environ.get("OPENSEARCH_PASSWORD", 'qaOllama2'),
)

# Create the client with SSL/TLS enabled. Certificate verification and hostname
# verification are BOTH disabled, and the related warnings are silenced.
client = OpenSearch(
    hosts=[{'host': host, 'port': port}],
    http_compress=True,  # enables gzip compression for request bodies
    http_auth=auth,
    use_ssl=True,
    verify_certs=False,
    ssl_assert_hostname=False,
    ssl_show_warn=False,
)

# Name of the k-NN index that is queried for retrieval.
index_name = 'pub_med_index'

# Loaded on first use and then reused; constructing a SentenceTransformer on
# every query reloads the model weights each time.
_query_encoder = None


def _get_query_encoder():
    """Return the shared SentenceTransformer query encoder, loading it once."""
    global _query_encoder
    if _query_encoder is None:
        _query_encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    return _query_encoder


def retrieve_documents(question, k=5):
    """Return the ids and scores of the ``k`` nearest indexed documents.

    The question is encoded with the shared SentenceTransformer and sent as a
    k-NN query against the ``vector`` field of ``index_name``.

    Args:
        question: text to search for.
        k: number of neighbours to request and hits to return (default 5, the
            previously hard-coded value).

    Returns:
        A list of ``(_id, _score)`` tuples in the order returned by the search.
    """
    vector = _get_query_encoder().encode([question])[0].tolist()

    # "size" caps the returned hits; "k" is the neighbour count for the k-NN query.
    knn_query = {
        "size": k,
        "_source": ["TI", "AB"],
        "query": {
            "knn": {
                "vector": {
                    "vector": vector,
                    "k": k
                }
            }
        }
    }

    response = client.search(index=index_name, body=knn_query)

    return [(hit['_id'], hit['_score']) for hit in response['hits']['hits']]
    

def generate(question: str):
    """Retrieve documents for ``question`` and format each one as a text passage.

    Every hit returned by ``retrieve_documents`` is looked up in ``records`` by
    its id and rendered with its PMID, full author, publication date, abstract,
    score (rounded to two decimals) and title.

    Returns:
        ``{"answer": [...]}`` with one formatted string per retrieved document.

    Raises:
        KeyError: if a retrieved id is not a key of ``records``.
    """
    passages = []
    for doc_id, score in retrieve_documents(question):
        rec = records[doc_id]
        passages.append(
            f"DOCUMENT-ID: {rec['PMID']}\n FULL-AUTHOR: {rec['FAU']}\n PUBLICATION-DATE: {rec['DP']}\n TEXT: {rec['AB']}\n SCORE: {round(score,2)} \n DOCUMENT-TITLE: {rec['TI']}"
        )
    return {"answer": passages}


# Run the retrieval pipeline once with a sample question.
sample_question = "Why is alcohol bad?"
answer = generate(sample_question)

In [None]:
# Display the dictionary returned by generate() for the sample question.
answer