# Custom Medical VQA Test

### Goal:
Using a pre-trained LLM and a scraped caption-image dataset, train a VQA model that accurately answers questions grounded in a medical textbook. Explore whether an encoder or a decoder model is more appropriate.

### Textbook Source:
https://drive.google.com/drive/u/2/folders/12mL45XMDRSxhkgMH_PIeQAAsAtbv-X2W

### TODO:
- Question forming: build {question : answer : image} pairings from the scraped dictionary using a coordinate classifier (potentially ask SRI experts)
- Vision + text modalities: use an LLM for the text modality and ??? (ResNet CV) for the image modality (potentially ask SRI experts, or ask whether they have online resources)
- Once both modalities are in place, reference the MedBLIP / Med-PaLM papers and architectures

### Pairing + Question Forming

##### Data Loading

In [1]:
import os
import json
import regex as re
import pandas as pd
import numpy as np

In [2]:
# IMPORTANT: Run scrape notebook first
PDF_URL = "General - Mandell - Core Radiology (1e).pdf"

TEXT_DATA_FOLDER_URL = f"book-scrape/scrape_out/{PDF_URL.split('.pdf')[0]}"
assert os.path.exists(TEXT_DATA_FOLDER_URL), "scrape output not found; run the scrape notebook first"

In [3]:
frames = []

# NOTE: os.listdir order is filesystem-dependent and non-JSON files still consume
# an index, so ch_num is a file index and not necessarily the book's chapter number
for ch_num, fjson in enumerate(os.listdir(TEXT_DATA_FOLDER_URL)):
    ftype = fjson.split('.')[-1]
    if ftype != 'json':
        continue
    fpath = os.path.join(TEXT_DATA_FOLDER_URL, fjson)

    raw_file_data = pd.read_json(fpath)
    raw_file_data['ch'] = ch_num
    frames.append(raw_file_data)

# concatenate once instead of growing the DataFrame inside the loop
raw_data = pd.concat(frames)
raw_data = raw_data.set_index(['ch'])
raw_data = raw_data.sort_index()[['header', 'body', 'images',
                                  'label_range', 'pg_range']]

In [4]:
raw_data.head(20)

ch,header,body,images,label_range,pg_range
0,[STARTUP],Approach to interpreting an angiogram Normal ...,[scrape_out/General - Mandell - Core Radiology...,"[0, 5]","[707, 708]"
0,Air embolism,Percutaneous transluminal angioplasty (PTA) Em...,[scrape_out/General - Mandell - Core Radiology...,"[0, 0]","[709, 712]"
0,superior vena cava (s Vc),Congenital anomalies of the superior vena cava...,[scrape_out/General - Mandell - Core Radiology...,"[1, 0]","[712, 721]"
0,Hepatic artery aneurysmMesenteric ischemiaAcut...,Role of interventional radiology in gastrointe...,[scrape_out/General - Mandell - Core Radiology...,"[0, 2]","[722, 724]"
0,Atherosclerotic renal artery stenosisFibromusc...,Fibromuscular dysplasia (medial fibroplasia su...,[scrape_out/General - Mandell - Core Radiology...,"[0, 3]","[725, 725]"
0,OncocytomaAngiomyolipoma (AML),Axial T 1-weighted mRI shows a mass in the low...,[scrape_out/General - Mandell - Core Radiology...,"[0, 6]","[726, 726]"
0,Renal traumaRenal arteriovenous fistulas and m...,grade Iv renal injury: Selective d SA angiogra...,[scrape_out/General - Mandell - Core Radiology...,"[0, 3]","[727, 728]"
0,Nutcracker syndromeMay–Thürner,may–Thürner: Axial contrast-enhanced CT (left ...,[scrape_out/General - Mandell - Core Radiology...,"[0, 1]","[729, 732]"
0,Varicocele,Retroaortic left renal vein: Initial digital s...,[scrape_out/General - Mandell - Core Radiology...,"[0, 2]","[733, 733]"
0,Percutaneous transhepatic cholangiography (Ptc),Biliary intervention overview and technique Th...,[],"[0, 2]","[734, 734]"
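
Several `header` values in the table above are multiple headings fused together by the scrape (for example `OncocytomaAngiomyolipoma (AML)`). Since `form_question_set` accepts a list of answers, one option is to split such headers before question generation. The following is a minimal sketch, not part of the original pipeline: `split_fused_header` is a hypothetical helper, and the rule it uses (split wherever a lowercase letter or `)` is immediately followed by an uppercase letter) is an assumed heuristic that will also split legitimate mixed-case tokens.

```python
import re
from typing import List

# Heuristic boundary (assumption): a lowercase letter or ')' directly followed
# by an uppercase letter, with no space in between.
_BOUNDARY = re.compile(r"(?<=[a-z)])(?=[A-Z])")


def split_fused_header(header: str) -> List[str]:
    """Split a scraped header into its likely individual headings."""
    parts = [part.strip() for part in _BOUNDARY.split(header)]
    # Drop empty fragments that can appear at the edges
    return [part for part in parts if part]


if __name__ == "__main__":
    print(split_fused_header("OncocytomaAngiomyolipoma (AML)"))
    print(split_fused_header("Nutcracker syndromeMay–Thürner"))
```

The resulting list can be passed directly as `answers` to `form_question_set`, producing one question per heading instead of one question for the fused string.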


##### Question Generation (QG)

In [5]:
# Reference: https://huggingface.co/mrm8488/t5-base-finetuned-question-generation-ap
# AutoModelWithLMHead is deprecated; T5 is an encoder-decoder model, so load it with AutoModelForSeq2SeqLM
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [9]:
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")


In [38]:
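def get_question(answer, context, max_length=64):
    """
    generates one question whose answer is `answer`, given `context`
    """
    # prompt format follows the reference linked above
    input_text = "answer: %s  context: %s </s>" % (answer, context)
    features = tokenizer([input_text], return_tensors='pt')

    output = model.generate(input_ids=features['input_ids'],
                            attention_mask=features['attention_mask'],
                            max_length=max_length)

    # decode() keeps special tokens by default, so the result looks like
    # '<pad> question: ...</s>'; pass skip_special_tokens=True to drop them
    return tokenizer.decode(output[0])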


def form_question_set(answers, context):
    """
    returns list of (question, answer) pairs
    """
    # normalize a single answer string to a one-element list
    if isinstance(answers, str):
        answers = [answers]

    return [(get_question(answer, context), answer) for answer in answers]

In [39]:
context = "Percutaneous transluminal angioplasty (PTA) Embolic materials Complications of embolization Catheter sizing High-flow catheters Selective and superselective catheters Standard floppy tip wires Case courtesy Timothy P. Killoran, MD, Brigham and Women’s Hospital.\ngiant cell arteritis: Selective angiogram of the axillary artery (catheter is not visible) shows mild narrowing\nof the axillary artery, complete occlusion of the axillary/brachial artery at the origin of the brachial artery\n(yellow arrow), and an irregular appearance of the posterior circumflex humeral artery (red arrow).\n• Aortic disease is discussed in the cardiovascular imaging section."
answers = ["Air embolism", "Timothy P. Killoran, MD, Brigham and Women’s Hospital"]

print(form_question_set(answers, context))

[('<pad> question: What is a common complication of embolization?</s>', 'Air embolism'), ('<pad> question: Who kindly provided the catheters for this procedure?</s>', 'Timothy P. Killoran, MD, Brigham and Women’s Hospital')]
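
The decoded questions above still carry the `<pad>` and `</s>` special tokens and the `question:` task prefix. A minimal post-processing sketch follows; `clean_question` is a hypothetical helper (not part of the original notebook), and it assumes the decoded strings always follow the `<pad> question: ...</s>` shape seen in this output.

```python
import re

# Special tokens observed in the decoded output
_SPECIAL_TOKENS = re.compile(r"</?s>|<pad>")
# Task prefix emitted by the question-generation checkpoint
_PREFIX = re.compile(r"^\s*question:\s*", flags=re.IGNORECASE)


def clean_question(decoded: str) -> str:
    """Strip special tokens and the 'question:' prefix from a decoded string."""
    text = _SPECIAL_TOKENS.sub("", decoded)
    text = _PREFIX.sub("", text)
    return text.strip()


if __name__ == "__main__":
    print(clean_question("<pad> question: What is AML?</s>"))
```

Cleaning after decoding keeps `get_question` unchanged, so the raw model output stays available for debugging.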


### Text & Vision Modalities

In [55]:
# qa dataset formation

rows = []

for n in range(len(raw_data)):
    context, answer = raw_data.iloc[n]['body'], raw_data.iloc[n]['header']
    for q, a in form_question_set(answer, context):
        rows.append({"question": q, "answer": a})

# build the frame once instead of concatenating row by row
qa_dataset = pd.DataFrame(rows, columns=["question", "answer"])

In [56]:
# qa_dataset.to_csv('qa_dataset.csv')
# qa_dataset = pd.read_csv('qa_dataset.csv')

In [59]:
print(len(qa_dataset))
qa_dataset.head(20)

472


,question,answer
0,<pad> question: What is the name of the angiog...,[STARTUP]
1,<pad> question: What is a common complication ...,Air embolism
2,<pad> question: What is SVC?</s>,superior vena cava (s Vc)
3,<pad> question: What is the role of interventi...,Hepatic artery aneurysmMesenteric ischemiaAcut...
4,<pad> question: What is the name of the diseas...,Atherosclerotic renal artery stenosisFibromusc...
5,<pad> question: What is AML?</s>,OncocytomaAngiomyolipoma (AML)
6,<pad> question: What was the grade Iv injury?</s>,Renal traumaRenal arteriovenous fistulas and m...
7,<pad> question: What is the name of the syndro...,Nutcracker syndromeMay–Thürner
8,<pad> question: What is the name of the vein t...,Varicocele
9,<pad> question: What is the name of the proced...,Percutaneous transhepatic cholangiography (Ptc)
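
The TODO calls for {question : answer : image} pairings, while `qa_dataset` so far holds only questions and answers. One way to attach the scraped images is sketched below. It is an assumption-laden sketch: `attach_images` is a hypothetical helper, and it assumes that row `i` of the QA frame was generated from row `i` of the raw frame, which holds only when each section yields exactly one question (as it does when `header` is a single string).

```python
import pandas as pd


def attach_images(raw: pd.DataFrame, qa: pd.DataFrame) -> pd.DataFrame:
    """Return one (question, answer, image) row per image of each section.

    Assumes positional alignment: qa row i comes from raw row i.
    """
    merged = qa.reset_index(drop=True).copy()
    merged["image"] = raw["images"].reset_index(drop=True)
    # explode() emits one row per list element; empty lists become NaN
    merged = merged.explode("image", ignore_index=True)
    # Sections without images cannot form a VQA example, so drop them
    return merged.dropna(subset=["image"]).reset_index(drop=True)


if __name__ == "__main__":
    raw = pd.DataFrame({"images": [["a.png", "b.png"], [], ["c.png"]]})
    qa = pd.DataFrame({"question": ["q0", "q1", "q2"],
                       "answer": ["a0", "a1", "a2"]})
    print(attach_images(raw, qa))
```

Sections with several images are repeated once per image, so the same question can be paired with more than one picture; whether that is desirable depends on how closely each image matches the header.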


### Appendix / Old Code

In [None]:
# x-y classification

def classify(row):
    # basic approach: find center of image, and get closest relevant text or N/A
    # returns new df indexed by ???? with text and relevant image
    # possibly add radius check; if text within certain radius, it gets added to df
    image_coords = np.array(row['image_coords'])
    text_coords = np.array(row['text_coords'])
    entries = []

    # potentially use a kd tree, but n*m should be good enough for 2-dimensions
    for img_ref, image_coord_ls in enumerate(image_coords):
        # assumes each image entry holds two opposite corner points
        img_center = (image_coord_ls[0] + image_coord_ls[1]) / 2
        min_dist = np.inf
        min_text_ref = -1  # stays -1 (N/A) when the row has no text
        for text_ref, text_coord_ls in enumerate(text_coords):
            text_center = text_coord_ls[0]
            dist = np.sqrt(np.sum((img_center - text_center) ** 2))
            if dist < min_dist:
                min_dist = dist
                min_text_ref = text_ref
        # fixed: use img_ref (image_ref was undefined) and record the nearest text index
        entries.append({'image_ref': row['images'][img_ref],
                        'text_ref': min_text_ref,
                        'header': row['header'],
                        'text': row['body']})

    # fixed: build the frame from named columns and return it
    return pd.DataFrame(entries, columns=['image_ref', 'text_ref', 'header', 'text'])
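
The comments in `classify` mention a possible radius check. A vectorized sketch of that idea with NumPy broadcasting is given below. `nearest_text_within_radius` is a hypothetical helper, and it assumes the image and text centers have already been reduced to 2-D points; how those centers are derived from the scraped coordinates is left to the caller.

```python
import numpy as np


def nearest_text_within_radius(img_centers, text_centers, radius):
    """For each image center, return the index of the nearest text center,
    or -1 when no text center lies within `radius`."""
    img_centers = np.asarray(img_centers, dtype=float).reshape(-1, 2)
    text_centers = np.asarray(text_centers, dtype=float).reshape(-1, 2)
    if len(text_centers) == 0:
        return np.full(len(img_centers), -1, dtype=int)

    # Pairwise Euclidean distances, shape (n_images, n_texts)
    diffs = img_centers[:, None, :] - text_centers[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)

    nearest = dists.argmin(axis=1)
    nearest_dist = dists[np.arange(len(img_centers)), nearest]
    # Reject matches that are farther away than the allowed radius
    return np.where(nearest_dist <= radius, nearest, -1)


if __name__ == "__main__":
    images = [[0, 0], [10, 10]]
    texts = [[1, 0], [9, 10], [100, 100]]
    print(nearest_text_within_radius(images, texts, radius=2.0))
```

The n*m distance matrix matches the brute-force approach noted in the comments; a k-d tree would only pay off for pages with many more text blocks than a textbook layout typically has.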