### Proof of Concept

Outline

- import data / create dataloader classes
- import text and vision model / create multimodal medvqa model that inherits from hf
- import hf trainer and create model trainer that inherits from hf
- train on base model / data, then fine tune model from checkpoint using anything (radiology in practice, slate for demo)

Data:
https://huggingface.co/datasets/xmcmic/PMC-VQA/tree/main

### TODO:
- bert tokenization of text in dataset wrapper
- data augmentation python script

- peft + clip models

In [47]:
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import os
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dataclasses import dataclass, field
import transformers
from transformers import AutoModel,BertConfig,AutoTokenizer,AutoProcessor,LlamaTokenizer, CLIPVisionModel, CLIPVisionConfig
from torchvision import transforms
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F
from peft import get_peft_model, LoraConfig, TaskType

### Dataset

In [93]:
# VQA Dataset

class VQA_Dataset(Dataset):
    def __init__(self, csv_path, img_root_dir, image_res, pretrained_tokenizer, is_train=True):
        self.is_train = is_train
        self.dataset = pd.read_csv(csv_path)
        self.img_root_dir = img_root_dir
        self.img_path_list = np.asarray(self.dataset['Figure_path'])
        self.question_list = np.asarray(self.dataset['Question'])
        self.choice_list = np.asarray(self.dataset.iloc[:,-5:-1])
        self.answer_list = np.asarray(self.dataset['Answer'])
        self.answer_label_list = np.asarray(self.dataset['Answer_label'])
        
        # initialize bert tokenizer + special tokens dict from pretrained llama model
        self.tokenizer = LlamaTokenizer.from_pretrained(pretrained_tokenizer) # explore diff models - can fine-tune on medical but prob not too relevant
        special_tokens_dict = {'mask_token': "</s>",
                               'eos_token': "</s>",
                               'bos_token': "<s>",
                               'unk_token': "<unk>"}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer.pad_token_id=0
    
        # initialize, augment, and normalize torch img transformer
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([                        
                transforms.RandomResizedCrop([image_res,image_res],scale=(0.2, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.RandomHorizontalFlip(),
                #RandomAugment(2,7,isPIL=True,augs=['Identity','Equalize','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
                transforms.ToTensor(),
                normalize,
            ])

    def __len__(self):
        return len(self.img_path_list)
    
    def encode_text(self, question_text, question_text_with_answer, mask_token= '</s>', pad_token='<unk>', eos_token = '</s>'):
        # encode text using bert / llama tokenizer
        return

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_root_dir, self.img_path_list[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)

        question = self.question_list[idx]
        answer = self.answer_list[idx]
        answer_label = self.answer_label_list[idx]
        choice_A = str(self.choice_list[idx][0])
        choice_B = str(self.choice_list[idx][1])
        choice_C = str(self.choice_list[idx][2])
        choice_D = str(self.choice_list[idx][3])

        return question, answer, image # should be dictionary of all values

In [94]:
# model config
@dataclass
class ModelArguments:
    embed_dim: Optional[int] = field(default=768)
    pretrained_tokenizer:  Optional[str] = field(default="TencentARC/LLaMA-Pro-8B") # TBD
    pretrained_model: Optional[str] = field(default="meta-llama/Llama-2-7b")
    image_encoder: Optional[str] = field(default="CLIP")
    # pmcclip_pretrained: Optional[str] = field(default="./models/pmc_clip/checkpoint.pt")
    clip_pretrained: Optional[str] = field(default="openai/clip-vit-base-patch32")
    # ckp: Optional[str] = field(default="./Results/VQA_lora_noclip/vqa/checkpoint-6500")

# data config
@dataclass
class DataArguments:
    image_res: Optional[int] = field(default=512)
    img_root_dir: str = field(default='images/')
    Train_path: str = field(default='data/train.csv')
    Test_path: str= field(default='data/test.csv')
    
# optimization, training data, etc.
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    output_dir: Optional[str] = field(default="results")
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    logging_dir: Optional[str] = field(default="logs")
    logging_steps: Optional[int] = field(default=50)

In [95]:
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
parser.add_argument("-f", "--file", required=False)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()[:3]

Train_dataset = VQA_Dataset(data_args.Train_path, data_args.img_root_dir, data_args.image_res, model_args.pretrained_tokenizer, is_train=True)
Test_dataset = VQA_Dataset(data_args.Test_path, data_args.img_root_dir, data_args.image_res, model_args.pretrained_tokenizer, is_train=True)

In [96]:
# single training example
Train_dataset[0]

('What is the uptake pattern in the breast? ',
 'Focal uptake pattern',
 tensor([[[-0.6623, -0.6281, -0.6452,  ...,  2.2318,  2.2318,  2.2318],
          [-0.6452, -0.6281, -0.6452,  ...,  2.2318,  2.2318,  2.2318],
          [-0.6281, -0.6452, -0.6623,  ...,  2.2318,  2.2318,  2.2318],
          ...,
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489]],
 
         [[-0.5476, -0.5126, -0.5301,  ...,  2.4111,  2.4111,  2.4111],
          [-0.5301, -0.5126, -0.5301,  ...,  2.4111,  2.4111,  2.4111],
          [-0.5126, -0.5301, -0.5476,  ...,  2.4111,  2.4111,  2.4111],
          ...,
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286]],
 
         [[-0.3230, -0.2881,

### Model

In [97]:
# bert hyperparams
@dataclass
class CLIPconfig:
    bert_model: str = field(default='dmis-lab/biobert-v1.1')
    MOMENTUM: float = 0.5  # 0.99
    
# peft for fine tuning + lora
@dataclass
class PEFTconfig:
    use_pretrained_peft: bool = field(default=True)
    model: str = field(default='lora')
    lora_rank: int = field(default=8) # dimension of low-rank matrices
    lora_alpha: int = field(default=32) # scaling factor for the low-rank matrices
    lora_dropout: int = field(default=0.1) # dropout probability of the LoRA layers
    # num_virtual_tokens: int = field(default=32)
    # mapping_hidden_dim: int = field(default=1024)

In [98]:
def get_peft_config(peft_args: PEFTconfig):
    if peft_args.use_pretrained_peft:
        peft_config = LoraConfig(
            task_type = TaskType.CAUSAL_LM, inference_mode=False,
            r = peft_args.lora_rank,
            lora_alpha = peft_args.lora_alpha, 
            lora_dropout = peft_args.lora_dropout
        )
    else:
        # TODO: lora fine tuning on text encoder for medical domain
        # perf_config = ...
        pass
    return peft_config

In [99]:
parser = transformers.HfArgumentParser((CLIPconfig, PEFTconfig))
parser.add_argument("-f", "--file", required=False)
clip_args, peft_args = parser.parse_args_into_dataclasses()[:2]

In [104]:
class VQA_Model(nn.Module):
    def __init__(self, config):
        super(VQA_Model, self).__init__()
        embed_dim = config.embed_dim
        self.tokenizer = LlamaTokenizer.from_pretrained(config.pretrained_tokenizer)
        self.llama_model = transformers.LlamaModel.from_pretrained(config.pretrained_model)
        # peft_config = get_peft_config(PEFTconfig)
        self.text_encoder = AutoTokenizer.from_pretrained(config.pretrained_tokenizer) # currently just biobert/tencent; TODO: change to peft model
        self.text_embed = nn.Sequential(nn.Linear(4096, embed_dim))
        
        # clip vision encoder
        if config.image_encoder == "CLIP":
            self.image_encoder_name = "CLIP"
            clip_config = CLIPVisionConfig(image_size=512)
            self.image_encoder = CLIPVisionModel(clip_config).from_pretrained(config.clip_pretrained) # "openai/clip-vit-base-patch32"
            self.image_encoder.vision_model.embeddings = transformers.models.clip.modeling_clip.CLIPVisionEmbeddings(clip_config)
        else:
            self.image_encoder_name = "PMC_CLIP"
            # TODO: pmc clip handling...
        # other hidden players

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

In [105]:
model = VQA_Model(model_args)

OSError: You are trying to access a gated repo.
Make sure to request access at https://huggingface.co/meta-llama/Llama-2-7b and pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`.