In [21]:
import torch
#from pinecone import Pinecone
#from sentence_transformers import SentenceTransformer
from Bio import Medline
import os
import json
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification, AutoModelForQuestionAnswering
#import ollama
#from fastapi import FastAPI
#import uvicorn

In [75]:
from opensearchpy import OpenSearch

host = 'localhost'
port = 9200
auth = ('admin', 'admin')


# Create the client with SSL/TLS enabled, but hostname verification disabled.
client = OpenSearch(
    hosts = [{'host': host, 'port': port}],
    http_compress = True, # enables gzip compression for request bodies
    http_auth = auth,
    use_ssl = True,
    verify_certs = False,
    ssl_assert_hostname = False,
    ssl_show_warn = False,
)

index_name = 'pub_med_index'
client.indices.delete(index=index_name)


{'acknowledged': True}

In [None]:
pubmed_data_path = "/home/chris/University/NLP_project/pubmed_data.json"
pubmed_data_preprocessed_path = "/home/paperspace/pubmed_data_preprocessed.json"


In [None]:
records=[]
missed=0
num=0
with open("pubmed_data") as stream:
    for article in Medline.parse(stream):
        if not "PMID" in article:
            missed += 1
            continue

        if not "TI" in article:
            missed += 1
            continue

        if not "FAU" in article:
            missed += 1
            continue

        if not "DP" in article:
            missed += 1
            continue

        if not "AB" in article:
            missed += 1
            continue
        num+=1
        records.append(article)

with open(pubmed_data_preprocessed_path, 'w') as f:
    f.write(json.dumps(records))
print(num)

## Embeddings

In [None]:
class PubMedDataset(Dataset):
    def __init__(self, path):
        with open(path, 'r') as f:
          self.data = json.loads(f.read())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]["text"]
        return sample

In [None]:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2').to(device)

In [None]:
dataset = PubMedDataset(pubmed_data_preprocessed_path)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
def mean_pooling(last_hidden_state, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [None]:
embeddings = []
with torch.no_grad():
    for i, sample in enumerate(dataloader):
        inputs = tokenizer(sample, return_tensors="pt", padding=True, truncation=True).to(device)
        out = model(**inputs)
        pooled = mean_pooling(out.last_hidden_state, inputs["attention_mask"]).to("cpu")
        embeddings.extend(pooled)
embeddings_stacked = torch.stack(embeddings)

In [None]:
torch.save(embeddings_stacked, '/home/paperspace/pubmed_data_embeddings.bin')

In [22]:
embeddings_stacked=torch.load('/home/paperspace/pubmed_data_embeddings.bin')

## KNN

In [76]:
# Define an index mapping with a custom analyzer
index_mapping = {
  "settings": {
    "index.knn": True
  },
  "mappings": {
    "properties": {
      "title": {
        "type": "text",
        "analyzer": "standard"
      },
      "TI": {
        "type": "text",
        "analyzer": "standard"
      },
      "AB": {
        "type": "text",
        "analyzer": "standard"
      },
      "vector": {
        "type": "knn_vector",
        "dimension": len(embeddings_stacked[0])  
      }
    }
  }
}


# Create the index with the custom mapping
index_name = "pub_med_index"
client.indices.create(index=index_name, body=index_mapping)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'pub_med_index'}

In [78]:
actions = [ 
    ({"index": {"_index": "pub_med_index",
                "_id":doc["PMID"]}},
     {"TI":doc["TI"],
      "AB":doc["AB"],
      "vector":embeddings_stacked[num].tolist()
    }
    )
    for num,doc in enumerate(records[0:10])
]

"""'PMID': 
'STAT': 
'DRDT': 
'CTDT':
'PB': 
'DP': 
'TI': 
'BTI': 
'AB': 
'CI': 
'FED': 
'ED': 
'FAU': 
'AU': 
'AD': 
'LA': 
'PT': 
'PL':
'OTO': 
'OT': 
'EDAT': 
'CRDT': 
'AID':"""

"'PMID': \n'STAT': \n'DRDT': \n'CTDT':\n'PB': \n'DP': \n'TI': \n'BTI': \n'AB': \n'CI': \n'FED': \n'ED': \n'FAU': \n'AU': \n'AD': \n'LA': \n'PT': \n'PL':\n'OTO': \n'OT': \n'EDAT': \n'CRDT': \n'AID':"

In [79]:
request = '\n'.join([f'{json.dumps(item, indent=None, separators=(",", ":"))}' for tpl in actions for item in tpl])

In [80]:
try:
    response = client.bulk(body=request, refresh=True)
    print("Bulk request successful.")
except Exception as e:
    print(f"Failed to perform bulk request. Error: {e}")

Bulk request successful.


In [40]:
records['20301331']

{'PMID': '20301331',
 'STAT': 'Publisher',
 'DRDT': ['20230511'],
 'CTDT': ['19981012'],
 'PB': ['University of Washington, Seattle'],
 'DP': '1993',
 'TI': 'Achondroplasia.',
 'BTI': ['GeneReviews((R))'],
 'AB': "CLINICAL CHARACTERISTICS: Achondroplasia is the most common cause of disproportionate short stature. Affected individuals have rhizomelic shortening of the limbs, macrocephaly, and characteristic facial features with frontal bossing and midface retrusion. In infancy, hypotonia is typical, and acquisition of developmental motor milestones is often both aberrant in pattern and delayed. Intelligence and life span are usually near normal, although craniocervical junction compression increases the risk of death in infancy. Additional complications include obstructive sleep apnea, middle ear dysfunction, kyphosis, and spinal stenosis. DIAGNOSIS/TESTING: Achondroplasia can be diagnosed by characteristic clinical and radiographic findings in most affected individuals. In individuals 

In [81]:
import torch
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from Bio import Medline
import ollama
from fastapi import FastAPI
import uvicorn

from opensearchpy import OpenSearch
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification, AutoModelForQuestionAnswering



if torch.cuda.is_available():
  device = "cuda:0"
else:
  device = "cpu"
device

records = {}
missed = 0

with open("pubmed_data") as stream:    
    for article in Medline.parse(stream):

        if not "PMID" in article:
            missed += 1
            continue

        if not "TI" in article:
            missed += 1
            continue

        if not "FAU" in article:
            missed += 1
            continue

        if not "DP" in article:
            missed += 1
            continue

        if not "AB" in article:
            missed += 1
            continue
        
        records[article["PMID"]] = article

#model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2').to(device)

# why not take cls token?
def mean_pooling(last_hidden_state, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# https://opensearch.org/docs/latest/clients/python-low-level/

host = 'localhost'
port = 9200
auth = ('admin', 'admin')

# Create the client with SSL/TLS enabled, but hostname verification disabled.
client = OpenSearch(
    hosts = [{'host': host, 'port': port}],
    http_compress = True, # enables gzip compression for request bodies
    http_auth = auth,
    use_ssl = True,
    verify_certs = False,
    ssl_assert_hostname = False,
    ssl_show_warn = False,
)

index_name = 'pub_med_index'

def retrieve_documents(question):
    
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True).to(device)
    query_outputs = mean_pooling(model(**inputs).last_hidden_state, inputs["attention_mask"]).to("cpu")
    print(len(query_outputs[0].tolist()))

    # Define the KNN search query
    knn_query = {
        "size": 1,
        "_source": ["title", "text"],
        "query": {
            "knn": {
                "vector": {
                    "vector": query_outputs[0].tolist(),
                    "k": 1
                }
            }
        }
    }

    # Perform the KNN search
    response = client.search(index=index_name, body=knn_query)
    
    print(response)
    
    return [(res['_id'], res['_score']) for res in response['hits']['hits'][:]]
    

def generate(question: str):
    documents = retrieve_documents(question)
    return {"answer": [f"DOCUMENT-ID: {records[id]['PMID']}\n FULL-AUTHOR: {records[id]['FAU']}\n PUBLICATION-DATE: {records[id]['DP']}\n TEXT: {records[id]['AB']}\n SCORE: {round(score,2)} \n DOCUMENT-TITLE: {records[id]['TI']}" for id,score in documents]}


answer = generate("Why is alcohol bad?")



768
{'took': 50, 'timed_out': False, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}, 'hits': {'total': {'value': 1, 'relation': 'eq'}, 'max_score': 0.07615581, 'hits': [{'_index': 'pub_med_index', '_id': '20301451', '_score': 0.07615581, '_source': {}}]}}


In [82]:
answer

{'answer': ["DOCUMENT-ID: 20301451\n FULL-AUTHOR: ['Scarpa, Maurizio']\n PUBLICATION-DATE: 1993\n TEXT: CLINICAL CHARACTERISTICS: Mucopolysaccharidosis type II (MPS II; also known as Hunter syndrome) is an X-linked multisystem disorder characterized by glycosaminoglycan (GAG) accumulation. The vast majority of affected individuals are male; on rare occasion heterozygous females manifest findings. Age of onset, disease severity, and rate of progression vary significantly among affected males. In those with early progressive disease, CNS involvement (manifest primarily by progressive cognitive deterioration), progressive airway disease, and cardiac disease usually result in death in the first or second decade of life. In those with slowly progressive disease, the CNS is not (or is minimally) affected, although the effect of GAG accumulation on other organ systems may be early progressive to the same degree as in those who have progressive cognitive decline. Survival into the early adult 