# Custom Medical VQA Test

### Goal: 
Using a pre-trained LLM and a scraped caption : image dataset, train a VQA model that accurately answers questions based on a medical textbook. Explore whether an encoder or a decoder model is more appropriate.

### Textbook Source:
https://drive.google.com/drive/u/2/folders/12mL45XMDRSxhkgMH_PIeQAAsAtbv-X2W

### TODO:
- Question forming: build {question : answer : image} pairings from the scraped dictionary using a coordinate classifier (potentially ask SRI experts)
- Vision + text modalities: use an LLM for the text modality and ??? (ResNet CV) for the image modality (potentially ask SRI experts; maybe ask whether they have any online resources)
- Once the modalities are in place, reference the MedBLIP / Med-PaLM papers and architectures

### Pairing + Question Forming

##### Data Loading

In [13]:
import os
import json
import regex as re
import pandas as pd
import numpy as np

In [5]:
# IMPORTANT: Run scrape notebook first
PDF_URL = "General - Mandell - Core Radiology (1e).pdf"

# Define the folder once, then check that the scrape output exists
TEXT_DATA_FOLDER_URL = f"book-scrape/scrape_out/{PDF_URL.split('.pdf')[0]}"
assert os.path.exists(TEXT_DATA_FOLDER_URL), "Run the scrape notebook first"

In [53]:
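# Keep only JSON files and sort them so chapter numbers are contiguous and
# reproducible (os.listdir returns entries in arbitrary order).
# NOTE: the sort is lexicographic; it matches chapter order only if the
# file names sort that way.
json_files = sorted(f for f in os.listdir(TEXT_DATA_FOLDER_URL)
                    if f.endswith('.json'))

frames = []
for ch_num, fjson in enumerate(json_files):
    fpath = os.path.join(TEXT_DATA_FOLDER_URL, fjson)

    raw_file_data = pd.read_json(fpath)
    raw_file_data['ch'] = ch_num
    frames.append(raw_file_data)

# Concatenate once instead of growing the DataFrame inside the loop
raw_data = pd.concat(frames)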

raw_data = raw_data.set_index(['ch'])
raw_data = raw_data.sort_index()[['header', 'body', 'images',
                                  'label_range', 'pg_range']]

In [54]:
raw_data.head(20)

##### Question Forming (QF)

In [57]:
# how? reference SRI
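
One possible starting point, pending input from SRI, is a fixed question template applied to every image in a section. The sketch below is only an illustration: `QUESTION_TEMPLATE` and `form_pairings` are hypothetical names, and it assumes each row has a `header` string, a `body` string, and an `images` list, which the scraped data may not match exactly.

```python
import pandas as pd

# Hypothetical template; the notebook has not settled on a question format.
QUESTION_TEMPLATE = "What does this image from the section '{header}' show?"

def form_pairings(df):
    """Build one {question, answer, image} record per image in each section.

    Assumes 'header' and 'body' are strings and 'images' is a list of
    image references (an assumption about the scraped schema).
    """
    records = []
    for _, row in df.iterrows():
        for image_ref in row["images"]:
            records.append({
                "question": QUESTION_TEMPLATE.format(header=row["header"]),
                "answer": row["body"],
                "image": image_ref,
            })
    return pd.DataFrame(records, columns=["question", "answer", "image"])

# Toy rows standing in for raw_data
toy = pd.DataFrame({
    "header": ["Section A", "Section B"],
    "body": ["Body text for section A.", "Body text for section B."],
    "images": [["fig_1.png", "fig_2.png"], []],
})
pairings = form_pairings(toy)
```

Sections without images contribute no rows, so `pairings` only contains entries that can be shown to a vision model. Using the whole body as the answer is a placeholder; a coordinate classifier could narrow it to the text closest to each image.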

### Text & Vision Modalities

### Appendix / Old Code

In [None]:
# x-y classification

def classify(row):
    # Basic approach: take the center of each image and find the closest text block.
    # Returns a DataFrame with one row per image: image reference, header, text.
    # Possible extension: radius check, so only text within a certain radius is kept.
    image_coords = np.array(row['image_coords'])
    text_coords = np.array(row['text_coords'])
    entries = []

    # potentially use a k-d tree, but the n*m scan should be good enough in 2 dimensions
    for img_ref, image_coord_ls in enumerate(image_coords):
        # assumes each image entry holds two corner points
        img_center = (image_coord_ls[0] + image_coord_ls[1]) / 2
        min_dist = np.inf
        min_text_ref = -1  # stays -1 (N/A) when there are no text blocks
        for text_ref, text_coord_ls in enumerate(text_coords):
            text_center = text_coord_ls[0]  # first point of the block, used as its position
            dist = np.sqrt(np.sum((img_center - text_center) ** 2))
            if dist < min_dist:
                min_dist = dist
                min_text_ref = text_ref
        # NOTE: 'text' still stores the whole body; 'text_ref' records the closest
        # block so the search result is not discarded
        entries.append({'image_ref': row['images'][img_ref],
                        'header': row['header'],
                        'text': row['body'],
                        'text_ref': min_text_ref})
    return pd.DataFrame(entries,
                        columns=['image_ref', 'header', 'text', 'text_ref'])
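
The comments in `classify` mention a radius check and a possible k-d tree. As a rough sketch of the same nearest-text idea, the version below replaces the nested loops with NumPy broadcasting and adds an optional radius cutoff. `nearest_text` and `max_radius` are hypothetical names, and the inputs are assumed to already be 2-D points (image centers and one point per text block).

```python
import numpy as np

def nearest_text(image_centers, text_points, max_radius=None):
    """For each image center, return the index of the closest text point.

    image_centers: (n, 2) array-like; text_points: (m, 2) array-like.
    Returns an int array of length n; -1 marks images with no text point
    within max_radius (or no text points at all).
    """
    image_centers = np.asarray(image_centers, dtype=float).reshape(-1, 2)
    text_points = np.asarray(text_points, dtype=float).reshape(-1, 2)
    if len(text_points) == 0:
        return np.full(len(image_centers), -1, dtype=int)

    # Pairwise differences via broadcasting: shape (n, m, 2)
    diffs = image_centers[:, None, :] - text_points[None, :, :]
    dists = np.linalg.norm(diffs, axis=2)  # Euclidean distances, shape (n, m)
    nearest = dists.argmin(axis=1)

    if max_radius is not None:
        # Discard matches that are farther away than the allowed radius
        nearest_dist = dists[np.arange(len(image_centers)), nearest]
        nearest = np.where(nearest_dist <= max_radius, nearest, -1)
    return nearest

centers = [[0.0, 0.0], [10.0, 10.0]]
texts = [[1.0, 0.0], [9.0, 9.0], [100.0, 100.0]]
matches = nearest_text(centers, texts)                  # no radius limit
limited = nearest_text(centers, texts, max_radius=1.2)  # with radius limit
```

The full distance matrix costs n*m memory, which should be fine for the handful of images and text blocks on a page; a k-d tree would only pay off for much larger inputs.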