In [None]:
%pip install -U sentencepiece
%pip install -U transformers
%pip install -U bitsandbytes
%pip install -U accelerate
%pip install -U huggingface_hub
%pip install -U Biopython
%pip install -U ollama

In [None]:
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel, pipeline
from torch.utils.data import DataLoader
from Bio import Medline

In [None]:
# Use the first CUDA GPU when torch reports one as available; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

## Load Data
Download the PubMed data with the NCBI Entrez Direct (EDirect) command-line tools:
`esearch -db pubmed -query "intelligence[tiab]" -mindate 2013 -maxdate 2023 | efetch -format medline > ./pubmed_data`

## Preprocessing

In [None]:
# Keep only the MEDLINE records that carry every field the later cells read
# (PubMed id, title, abstract) and count the records that are dropped.
REQUIRED_FIELDS = ("PMID", "TI", "AB")

preprocessed_records = []
missed = 0

# The encoding is stated explicitly so parsing does not depend on the platform
# default. NOTE(review): efetch output is presumably UTF-8 — confirm.
with open("pubmed_data", encoding="utf-8") as stream:
    for article in Medline.parse(stream):
        # A record lacking any required field is skipped and counted exactly once.
        if any(field not in article for field in REQUIRED_FIELDS):
            missed += 1
            continue

        # Build the slimmed-down record without rebinding the loop variable.
        preprocessed_records.append(
            {
                "id": article["PMID"],
                "title": article["TI"],
                "text": article["AB"],
            }
        )

In [None]:
# Number of records that had a PMID, a title and an abstract and were therefore kept.
len(preprocessed_records)

In [None]:
# Number of records skipped during preprocessing because a required field was absent.
missed

In [None]:
# Serialize the kept records as one JSON array and write it to disk so the
# dataset class can reload them from the file.
serialized_records = json.dumps(preprocessed_records)
with open("pubmed_data_preprocessed.json", mode="w") as json_file:
    json_file.write(serialized_records)

In [None]:
class PubMedDataset(Dataset):
    """Map-style dataset that yields the ``"text"`` field of each stored record.

    The JSON file at ``path`` is read completely in the constructor and kept in
    ``self.data``; items are looked up there by position.
    """

    def __init__(self, path):
        """Load the JSON document stored at ``path`` into memory."""
        with open(path, 'r') as handle:
            self.data = json.load(handle)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``"text"`` value of the record at position ``idx``."""
        return self.data[idx]["text"]

## Embedding

In [None]:
# Tokenizer and encoder are loaded from the same checkpoint; the encoder is
# then moved to the selected `device`.
EMBEDDING_CHECKPOINT = 'sentence-transformers/all-mpnet-base-v2'
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_CHECKPOINT)
model = AutoModel.from_pretrained(EMBEDDING_CHECKPOINT).to(device)

In [None]:
# Batches of 64 abstracts are served in file order (shuffle=False), so the
# position of an embedding computed from this loader matches the position of
# its record in the JSON file.
dataset = PubMedDataset(path='./pubmed_data_preprocessed.json')
dataloader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
# Open question from the author: why not take the CLS token instead?
# NOTE(review): the all-mpnet-base-v2 model card appears to use attention-masked
# mean pooling like this — confirm before switching.
def mean_pooling(last_hidden_state, attention_mask):
    """Average the token vectors of each sequence, ignoring masked positions.

    Args:
        last_hidden_state: tensor of token vectors; the mask is expanded to its size
            and the reduction runs over dimension 1.
        attention_mask: tensor with one entry per token; entries equal to 0 remove
            the token from the average.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total, with the
        divisor clamped to at least 1e-9 so a fully masked row yields zeros instead
        of a division by zero.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    weighted_sum = (last_hidden_state * mask).sum(1)
    token_count = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / token_count

In [None]:
# Encode every batch of abstracts without gradient tracking. Each pooled batch
# is moved to the CPU and its rows are appended one by one, so `embeddings` is a
# flat list with one vector per abstract; `embeddings_stacked` is the same data
# as a single tensor.
embeddings = []
with torch.no_grad():
    for batch_texts in dataloader:
        encoded = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
        hidden_states = model(**encoded).last_hidden_state
        batch_vectors = mean_pooling(hidden_states, encoded["attention_mask"]).to("cpu")
        embeddings += list(batch_vectors)
embeddings_stacked = torch.stack(embeddings)

In [None]:
# Persists `embeddings` (the Python list of per-abstract vectors), not the
# stacked tensor `embeddings_stacked`.
torch.save(embeddings, "pubmed_data_embeddings.pt")

## Question Answering
1. summarize relevant papers
2. answer question

In [None]:
# Extractive question answering over a given context.
pipe_qa = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2",
)
# Abstractive summarizer used to condense the retrieved abstracts.
summarizer = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
)

In [None]:
question = "What is the influence of alcohol on minors?"

# Embed the question with the same encoder and pooling used for the abstracts.
# This is inference only, so gradient tracking is switched off.
with torch.no_grad():
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True).to(device)
    query_outputs = mean_pooling(model(**inputs).last_hidden_state, inputs["attention_mask"]).to("cpu")

# Rank all abstracts by cosine similarity to the question, best match first.
sim = torch.cosine_similarity(embeddings_stacked, query_outputs)
# Plain Python ints under a name that does not shadow the builtin `sorted`;
# slicing also copes with a corpus smaller than the requested top-k.
top_indices = torch.argsort(sim, descending=True)[:2].tolist()

# Summarize the best-matching abstracts. The summaries are joined with a space
# so the last sentence of one does not run into the first sentence of the next.
summaries = []
for index in top_indices:
    text = preprocessed_records[index]["text"]
    summaries.append(summarizer(text, max_length=100, min_length=50, do_sample=False)[0]["summary_text"])
context = " ".join(summaries)



In [None]:
# Extract the answer span for `question` from the summarized context.
pipe_qa(question=question, context=context)

## Answer Extraction
1. find sentences similar to question
2. summarize similar sentences => answer


In [None]:
# Rebinds `summarizer` to a T5-base summarization pipeline; every later call to
# `summarizer(...)` uses this pipeline until the name is assigned again.
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base")

In [None]:
question = "What is the influence of alcohol on minors?"

# Embed the question (inference only, so no gradient tracking).
with torch.no_grad():
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True).to(device)
    query_outputs = mean_pooling(model(**inputs).last_hidden_state, inputs["attention_mask"]).to("cpu")

# Step 1: pick the abstracts most similar to the question.
sim = torch.cosine_similarity(embeddings_stacked, query_outputs)
top_documents = torch.argsort(sim, descending=True)[:3].tolist()

# Split those abstracts into sentences. Trailing periods and empty fragments
# are removed here because a period is appended again when the relevant text
# is assembled; otherwise an abstract's last sentence would end in "..".
sentences = []
for index in top_documents:
    text = preprocessed_records[index]["text"]
    for fragment in text.split(". "):
        fragment = fragment.strip().rstrip(".")
        if fragment:
            sentences.append(fragment)

# Step 2: embed every sentence and rank the sentences against the question.
# The result is kept under its own name so the document-level `embeddings`
# list is not overwritten.
with torch.no_grad():
    sentences_tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
    out = model(**sentences_tokens)
    sentence_embeddings = mean_pooling(out.last_hidden_state, sentences_tokens["attention_mask"]).to("cpu")

sentence_sim = torch.cosine_similarity(sentence_embeddings, query_outputs)
top_sentences = torch.argsort(sentence_sim, descending=True)[:2].tolist()

relevant_text = "".join(sentences[index] + ". " for index in top_sentences)

# Step 3: summarize the selected sentences together with the question.
summary = summarizer(relevant_text + " " + question, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]
summary

In [None]:
# Show the selected sentences that were passed (together with the question) to the summarizer.
relevant_text

## Natural Question Answering (llama)


In [None]:
# High-level text-generation pipeline for the Llama 2 chat checkpoint.
# NOTE(review): this checkpoint is presumably gated on the Hugging Face Hub and
# needs an authenticated session — confirm.
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="meta-llama/Llama-2-7b-chat-hf",
)

In [None]:
question = "How can human intelligence be defined?"

# Embed the question (inference only, so no gradient tracking).
with torch.no_grad():
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True).to(device)
    query_outputs = mean_pooling(model(**inputs).last_hidden_state, inputs["attention_mask"]).to("cpu")

# Rank all abstracts against the question and keep the best matches as plain
# ints, under a name that does not shadow the builtin `sorted`. Slicing also
# copes with a corpus smaller than the requested top-k.
sim = torch.cosine_similarity(embeddings_stacked, query_outputs)
top_indices = torch.argsort(sim, descending=True)[:4].tolist()

# One "DOCUMENT-ID / DOCUMENT-TEXT" entry per retrieved abstract. Each entry is
# a single-line f-string with explicit newlines so that no source-code
# indentation leaks into the prompt. `summarizer` is whichever summarization
# pipeline is currently bound to that name in the kernel.
context = ""
for index in top_indices:
    record = preprocessed_records[index]
    document_summary = summarizer(record["text"], max_length=100, min_length=50, do_sample=False)[0]["summary_text"]
    context += f"DOCUMENT-ID: {record['id']}\nDOCUMENT-TEXT: {document_summary}\n"

prompt = f"""ANSWER the following QUESTION soley based on the CONTEXT given. Cite the DOCUMENT-ID in your ANSWER where appropriate otherwise dont explicitly mention that you answer based on the context.

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

# The pipeline output starts with the prompt, so it is sliced off to leave only
# the generated answer.
answer = pipe(prompt)
answer[0]["generated_text"][len(prompt):]