# Motivation

Perhaps a more practice-based approach could be more fruitful. can use textbook & images as RAG dataset.

# Dependencies

In [2]:
!pip install -U datasets pinecone-client sentence-transformers torch

Collecting datasets
  Using cached datasets-2.14.6-py3-none-any.whl.metadata (19 kB)
Collecting pinecone-client
  Using cached pinecone_client-2.2.4-py3-none-any.whl.metadata (7.8 kB)
Collecting sentence-transformers
  Using cached sentence-transformers-2.2.2.tar.gz (85 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting torch
  Using cached torch-2.1.0-cp311-cp311-manylinux1_x86_64.whl.metadata (25 kB)
Collecting loguru>=0.5.0 (from pinecone-client)
  Using cached loguru-0.7.2-py3-none-any.whl.metadata (23 kB)
Collecting dnspython>=2.0.0 (from pinecone-client)
  Using cached dnspython-2.4.2-py3-none-any.whl.metadata (4.9 kB)
Collecting scikit-learn (from sentence-transformers)
  Using cached scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting scipy (from sentence-transformers)
  Using cached scipy-1.11.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting nltk (from sentence-transformers

Collecting joblib (from nltk->sentence-transformers)
  Downloading joblib-1.3.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn->sentence-transformers)
  Downloading threadpoolctl-3.2.0-py3-none-any.whl.metadata (10.0 kB)
Downloading datasets-2.14.6-py3-none-any.whl (493 kB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m493.7/493.7 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m[31m2.5 MB/s[0m eta [36m0:00:01[0m
[?25hDownloading pinecone_client-2.2.4-py3-none-any.whl (179 kB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.4/179.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0m01[0m
[?25hDownloading dnspython-2.4.2-py3-none-any.whl (300 kB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m300.4/300.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hDownloading loguru-0.7.2-py3-none-any.whl (62

# Dataset

In [3]:
from datasets import load_dataset

# load the dataset from huggingface in streaming mode and shuffle it
radiology_data = load_dataset(
    path = 'Ka4on/radiology',
    split='train',
    streaming=True
).shuffle(seed=960)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# streaming mode allows us to iterate over the dataset without needing to download it
# show the contents of a single document in the dataset
next(iter(radiology_data))

{'instruction': 'Generate impression based on medical findings.',
 'input': 'Dysarthria. There is no evidence of intracranial hemorrhage, mass, or acute infarct. There are mild scattered foci of cerebral white matter T2 hyperintensity. There is diffuse cerebral volume loss, which is most pronounced in the medial temporal lobes. There is no midline shift or herniation. The major cerebral flow voids are intact. The orbits, skull, paranasal sinuses, and scalp soft tissues are grossly unremarkable.',
 'output': '1. Nonspecific mild scattered foci of cerebral white matter T2 hyperintensity may represent chronic small vessel ischemic disease. Otherwise, no evidence of acute infarction.2. Diffuse cerebral volume loss, which is most pronounced in the medial temporal lobes, which may represent Alzheimer disease in the appropriate clinical setting. '}

In [5]:
from tqdm.auto import tqdm

total_prognosis_count = 50000

counter = 0
docs = []

for d in tqdm(radiology_data, total=total_prognosis_count):
    # extract the fields we need
    doc = {
        "input": d["input"],
        "output": d["output"],
    }
    docs.append(doc)

    # stop iteration once we reach 50k
    if counter == total_prognosis_count:
        break
    counter += 1

100%|██████████████████████████████████| 50000/50000 [00:04<00:00, 10365.67it/s]


In [6]:
import pandas as pd

df = pd.DataFrame(docs)
df.head()

Unnamed: 0,input,output
0,Dysarthria. There is no evidence of intracrani...,1. Nonspecific mild scattered foci of cerebral...
1,Male 8 years old Reason: Ao root dilatation Le...,1. Status post arterial switch operation.2. No...
2,Pituitary adenoma status post TSH in 11/2013: ...,Interval evolution postoperative findings rela...
3,"History of neuroblastoma of lumbar spine, rela...",1. Postoperative findings related to laminecto...
4,"Encephalopathy: confusion, encephalopathy. Man...",Scattered chronic infarcts and probable chroni...


# Initialize Pinecone idx

In [7]:
import pinecone

# connect to pinecone environment
pinecone.init(
    api_key="05ce3e92-d0da-4ef4-9e3b-2a1ad822689b",
    environment="us-east1-gcp"  # find next to API key in console
)

In [8]:
index_name = "qa"

# check if the abstractive-question-answering index exists
if index_name not in pinecone.list_indexes():
    # create the index if it does not exist
    pinecone.create_index(
        index_name,
        dimension=768,
        metric="cosine"
    )

# connect to abstractive-question-answering index we created
index = pinecone.Index(index_name)

# Initialize Retriever

##### Retriever Tasks:

- Generate embeddings for all historical passages (context vectors/embeddings)
- Generate embeddings for our questions (query vector/embedding)

The retriever will create embeddings such that the questions and passages that hold the answers to our queries are close to one another in the vector space. 

Uses SentenceTransformer model based on Microsoft's MPNet as our retriever.

In [9]:
import torch
from sentence_transformers import SentenceTransformer

# set device to GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# load the retriever model from huggingface model hub
retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)
retriever

  return torch._C._cuda_getDeviceCount() > 0


cpu


Downloading (…)e933c/.gitattributes: 100%|█████| 737/737 [00:00<00:00, 4.39MB/s]
Downloading (…)_Pooling/config.json: 100%|█████| 190/190 [00:00<00:00, 1.26MB/s]
Downloading (…)cbe6ee933c/README.md: 100%|█| 9.85k/9.85k [00:00<00:00, 28.8MB/s]
Downloading (…)e6ee933c/config.json: 100%|█████| 591/591 [00:00<00:00, 3.17MB/s]
Downloading (…)ce_transformers.json: 100%|██████| 116/116 [00:00<00:00, 843kB/s]
Downloading (…)33c/data_config.json: 100%|█| 15.7k/15.7k [00:00<00:00, 77.7MB/s]
Downloading pytorch_model.bin: 100%|██████████| 438M/438M [00:04<00:00, 105MB/s]
Downloading (…)nce_bert_config.json: 100%|████| 53.0/53.0 [00:00<00:00, 294kB/s]
Downloading (…)cial_tokens_map.json: 100%|█████| 239/239 [00:00<00:00, 1.32MB/s]
Downloading (…)e933c/tokenizer.json: 100%|███| 466k/466k [00:00<00:00, 3.88MB/s]
Downloading (…)okenizer_config.json: 100%|█████| 383/383 [00:00<00:00, 2.70MB/s]
Downloading (…)933c/train_script.py: 100%|█| 13.2k/13.2k [00:00<00:00, 43.7MB/s]
Downloading (…)cbe6ee933c/vo

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

# Generate Embeddings and Upsert

In [10]:
# we will use batches of 64
batch_size = 64

for i in tqdm(range(0, len(df), batch_size)):
    # find end of batch
    i_end = min(i+batch_size, len(df))
    # extract batch
    batch = df.iloc[i:i_end]
    # generate embeddings for batch
    emb = retriever.encode(batch["input"].tolist()).tolist()
    # get metadata
    meta = batch.to_dict(orient="records")
    # create unique IDs
    ids = [f"{idx}" for idx in range(i, i_end)]
    # add all to upsert list
    to_upsert = list(zip(ids, emb, meta))
    # upsert/insert these records to pinecone
    _ = index.upsert(vectors=to_upsert)

# check that we have all vectors in index
index.describe_index_stats()

100%|█████████████████████████████████████████| 782/782 [30:06<00:00,  2.31s/it]


{'dimension': 768,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 50001}},
 'total_vector_count': 50001}

# Initialize Generator

In [11]:
from transformers import BartTokenizer, BartForConditionalGeneration

# load bart tokenizer and model from huggingface
# TODO: can use LLM trained on textbook, but using RAG with textbooks may be more effective
# going to write down thought process and steps later in documentation
tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to(device)

Downloading (…)okenizer_config.json: 100%|████| 27.0/27.0 [00:00<00:00, 189kB/s]
Downloading (…)olve/main/vocab.json: 100%|███| 899k/899k [00:00<00:00, 14.8MB/s]
Downloading (…)olve/main/merges.txt: 100%|███| 456k/456k [00:00<00:00, 22.8MB/s]
Downloading (…)/main/tokenizer.json: 100%|█| 1.36M/1.36M [00:00<00:00, 6.53MB/s]
Downloading (…)lve/main/config.json: 100%|█| 1.32k/1.32k [00:00<00:00, 3.41MB/s]
Downloading pytorch_model.bin: 100%|███████| 1.63G/1.63G [00:19<00:00, 83.3MB/s]


In [17]:
def query_pinecone(query, top_k):
    # generate embeddings for the query
    xq = retriever.encode([query]).tolist()
    # search pinecone index for context passage with the answer
    xc = index.query(xq, top_k=top_k, include_metadata=True)
    return xc

def format_query(query, context):
    # extract passage_text from Pinecone search result and add the <P> tag
    context = [f"<P> {m['metadata']['input']}" for m in context]
    # concatinate all context passages
    context = " ".join(context)
    # contcatinate the query and context passages
    query = f"question: {query} context: {context}"
    return query

In [24]:
# example
query = "are there any image findings that are particularly indicative of a positive or negative prognosis for patients with lung cancer?"

context = query_pinecone(query, top_k=1)
print(context)

query = format_query(query, context['matches'])
print(query)

{'matches': [{'id': '39306',
              'metadata': {'input': 'Evaluate disease status for clinical '
                                    'trial. Provide 3-D measurements are all '
                                    'reference lesions and compare prior CT. '
                                    'Lung cancer status post 22 cycles of '
                                    'chemotherapy. CHEST:LUNGS AND PLEURA: '
                                    'Right upper lobe mass measures 2.3 x 1.8 '
                                    'cm (image 30; series 4), not '
                                    'significantly changed from previous '
                                    'study.MEDIASTINUM AND HILA: Small '
                                    'pericardial effusion. Stable calcified '
                                    'right hilar and subcarinal lymph '
                                    'nodes.CHEST WALL: No significant '
                                    'abnormality notedABDOMEN:LIVE

In [25]:
# TODO: would insert RAG functionality here - explore how!

# Testing / Eval

In [26]:
def generate_answer(query):
    # tokenize the query to get input_ids
    inputs = tokenizer([query], max_length=1024, return_tensors="pt").to(device)
    # use generator to predict output ids
    ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=40)
    # use tokenizer to decode the output ids
    answer = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return print(answer)

In [27]:
query = "are there any image findings that are particularly indicative of a positive or negative prognosis for patients with lung cancer?"
context = query_pinecone(query, top_k=5)
query = format_query(query, context["matches"])
generate_answer(query)

There are a couple of things to consider. The first is that there is a lot of variability in the image quality. The second is that there is a lot of variability in the image quality


# Appendix

### source / references
- https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/search/question-answering/abstractive-question-answering.ipynb#scrollTo=6iFLxPPvx2Tm