# MedVQA

Outline

- import data / create dataloader classes
- import text and vision model / create multimodal medvqa model that inherits from hf
- import hf trainer and create model trainer that inherits from hf
- train on base model / data, then fine tune model from checkpoint using anything (radiology in practice, slate for demo)

Data:
https://huggingface.co/datasets/xmcmic/PMC-VQA/tree/main

### TODO:
- bert tokenization of text in dataset wrapper
- data augmentation python script
- peft + clip models

In [2]:
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import os
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dataclasses import dataclass, field
import transformers
from transformers import AutoModel,BertConfig,AutoTokenizer,AutoProcessor,LlamaTokenizer, CLIPVisionModel, CLIPVisionConfig # model
from transformers import Trainer # training
from torchvision import transforms
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F
from peft import get_peft_model, LoraConfig, TaskType
from einops import rearrange

from Transformer import Transformer

# Dataset

In [52]:
# VQA Dataset

class VQA_Dataset(Dataset):
    def __init__(self, csv_path, img_root_dir, image_res, pretrained_tokenizer, is_train=True):
        self.is_train = is_train
        self.dataset = pd.read_csv(csv_path)
        self.img_root_dir = img_root_dir
        self.img_path_list = np.asarray(self.dataset['Figure_path'])
        self.question_list = np.asarray(self.dataset['Question'])
        self.choice_list = np.asarray(self.dataset.iloc[:,-5:-1])
        self.answer_list = np.asarray(self.dataset['Answer'])
        self.answer_label_list = np.asarray(self.dataset['Answer_label'])
        
        # initialize bert tokenizer + special tokens dict from pretrained llama model
        self.tokenizer = LlamaTokenizer.from_pretrained(pretrained_tokenizer) # explore diff models - can fine-tune on medical but prob not too relevant
        special_tokens_dict = {'mask_token': "</s>",
                               'eos_token': "</s>",
                               'bos_token': "<s>",
                               'unk_token': "<unk>"}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer.pad_token_id=0
    
        # initialize, augment, and normalize torch img transformer
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([                        
                transforms.RandomResizedCrop([image_res,image_res],scale=(0.2, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.RandomHorizontalFlip(),
                #RandomAugment(2,7,isPIL=True,augs=['Identity','Equalize','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
                transforms.ToTensor(),
                normalize,
            ])

    def __len__(self):
        return len(self.img_path_list)
    
    def encode_text(self, question_text, question_text_with_answer, mask_token= '</s>', pad_token='<unk>', eos_token = '</s>'):
        # encode text using bert / llama tokenizer
        def measure_text_len(text):
            text_logits = self.tokenizer.encode(text)
            return len(text_logits) - 1

        question_text_with_answer_logits = question_text_with_answer.split()
        question_text_logits = question_text.split()
        bert_input_logits = []
        output_mask = []
        bert_label_logits = []

        # create output mask by comparing len of text with text + answers
        for i, logit in enumerate(question_text_with_answer_logits):
            if i < len(question_text_logits):
                word_len = measure_text_len(logit)
                bert_input_logits += [logit]
                bert_label_logits += [pad_token] * word_len
                output_mask += [0] * word_len
            else:
                word_len = measure_text_len(logit)
                bert_input_logits += [mask_token] * word_len
                bert_label_logits += [logit]
                output_mask += [1] * word_len

        bert_input_logits += [eos_token]
        bert_label_logits += [eos_token]
        bert_input = ' '.join(bert_input_logits)
        bert_label = ' '.join(bert_label_logits)
        return bert_input, bert_label

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_root_dir, self.img_path_list[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)

        question = self.question_list[idx]
        answer = self.answer_list[idx]
        answer_label = self.answer_label_list[idx]
        choice_A = str(self.choice_list[idx][0])
        choice_B = str(self.choice_list[idx][1])
        choice_C = str(self.choice_list[idx][2])
        choice_D = str(self.choice_list[idx][3])
        
        # encode text logits + mlm
        question_text = 'Question: '+ question + 'The choices are: '  + choice_A + ' ' + choice_B + ' ' + choice_C + ' ' + choice_D + 'The Answer is: '
        question_text_with_answer = 'Question: '+ question + 'The choices are: '  + choice_A + ' ' + choice_B + ' ' + choice_C + ' ' + choice_D + 'The Answer is: ' + answer_label
        bert_input, bert_label = self.encode_text(question_text,question_text_with_answer)

        if self.is_train:
            encoded_input = self.tokenizer(bert_input, add_special_tokens=True, padding='max_length', truncation=True, max_length= 256)
            encoded_label = self.tokenizer(bert_label, add_special_tokens=True, padding='max_length', truncation=True, max_length= 256)
        else:
            encoded_input = self.tokenizer(bert_input, add_special_tokens=True, padding='max_length', truncation=True, max_length= 256,return_tensors="pt")
            encoded_label = self.tokenizer(bert_label, add_special_tokens=True, padding='max_length', truncation=True, max_length= 256,return_tensors="pt")

        return {
            "image": image,
            "text_logits": encoded_input['input_ids'],
            "attention_mask_logits": encoded_input['attention_mask'],
            "label": encoded_label['input_ids'],
            "image_path": img_path,
            'Choice_A': choice_A,
            'Choice_B': choice_B,
            'Choice_C': choice_C,
            'Choice_D': choice_D,
            'Answer': answer,
            'Answer_label': answer_label,
            }

In [53]:
# model config
@dataclass
class ModelArguments:
    embed_dim: Optional[int] = field(default=768)
    pretrained_tokenizer:  Optional[str] = field(default="TencentARC/LLaMA-Pro-8B") # TBD
    pretrained_model: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")
    image_encoder: Optional[str] = field(default="CLIP")
    # pmcclip_pretrained: Optional[str] = field(default="./models/pmc_clip/checkpoint.pt")
    clip_pretrained: Optional[str] = field(default="openai/clip-vit-base-patch32")
    # ckp: Optional[str] = field(default="./Results/VQA_lora_noclip/vqa/checkpoint-6500")

# data config
@dataclass
class DataArguments:
    image_res: Optional[int] = field(default=512)
    img_root_dir: str = field(default='images/')
    Train_path: str = field(default='data/train.csv')
    Test_path: str= field(default='data/test.csv')
    
# optimization, training data, etc.
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    output_dir: Optional[str] = field(default="results")
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    logging_dir: Optional[str] = field(default="logs")
    logging_steps: Optional[int] = field(default=50)

In [54]:
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
parser.add_argument("-f", "--file", required=False)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()[:3]

Train_dataset = VQA_Dataset(data_args.Train_path, data_args.img_root_dir, data_args.image_res, model_args.pretrained_tokenizer, is_train=True)
Test_dataset = VQA_Dataset(data_args.Test_path, data_args.img_root_dir, data_args.image_res, model_args.pretrained_tokenizer, is_train=True)

In [55]:
# single training example
Train_dataset[0].keys()

dict_keys(['image', 'text_logits', 'attention_mask_logits', 'label', 'image_path', 'Choice_A', 'Choice_B', 'Choice_C', 'Choice_D', 'Answer', 'Answer_label'])

# Model

### Architecture

1. **Text Embedding Layer (text_embedder):** This layer is responsible for embedding textual inputs. It consists of a linear transformation (fully connected layer) that maps input features (of size 4096) to an embedding space of size 768.

2. **Image Encoder (image_encoder):** This component processes visual inputs. It uses the CLIPVisionModel, which incorporates a CLIPVisionTransformer. This transformer model consists of several components:

3. **CLIPVisionEmbeddings:** Handles embeddings for visual inputs, including patch embeddings and position embeddings.

4. **CLIPEncoder:** Contains a stack of CLIP encoder layers. Each layer includes self-attention mechanisms and feedforward neural networks (MLPs).

5. **LayerNorm:** Applies layer normalization to the output of the encoder.

6. **Fusion Module (fusion_module):** This module fuses information from both textual and visual inputs. It consists of several Residual Attention Blocks, each containing a multihead attention mechanism, layer normalization, and a feedforward neural network.

7. **Query Decoder Layer (fquery_decoder_layer):** This layer is part of the query decoder. It includes a transformer decoder layer, which consists of self-attention mechanisms, multihead attention mechanisms, linear transformations, layer normalization, and dropout.

8. **Query Decoder (fquery_decoder):** This component comprises a stack of transformer decoder layers.

9. **Softmax Layer (softmax):** Applies the softmax function to the output logits along the specified dimension (-1, which is typically the last dimension) to obtain probability scores.

### Forward Step
1. **Encode Images (CLIP):**
Use CLIP image encoder to encode input images, slicing off the first token.

2. **Encode Text:**
Encode input text using specified text encoder (BERT or bioBERT). Takes *text_logits* and *attention_mask_logits* as inputs.

3. **Concatenate Features:**
Concatenate encoded image and text features, adding special tokens for images.

4. **Fusion with Transformer:**
Pass concatenated features through a transformer-based fusion module.

5. **Masked Language Modeling (MLM):**
Apply MLM to fused features, generating predictions for masked tokens.

6. **Softmax Activation:**
Apply softmax activation to output logits, producing probabilities.

In [56]:
# bert hyperparams
@dataclass
class CLIPconfig:
    bert_model: str = field(default='dmis-lab/biobert-v1.1')
    MOMENTUM: float = 0.5  # 0.99
    context_length: int = 256
    vocab_size: int = 32000
    width: int = 768
    heads: int = 8
    layers: int = 12
    fusion_layers: int = 1
    
# peft for fine tuning + lora
@dataclass
class PEFTconfig:
    use_pretrained_peft: bool = field(default=True)
    model: str = field(default='lora')
    lora_rank: int = field(default=8) # dimension of low-rank matrices
    lora_alpha: int = field(default=32) # scaling factor for the low-rank matrices
    lora_dropout: int = field(default=0.1) # dropout probability of the LoRA layers
    # num_virtual_tokens: int = field(default=32)
    # mapping_hidden_dim: int = field(default=1024)

In [57]:
def get_peft_config(peft_args: PEFTconfig):
    if peft_args.use_pretrained_peft:
        peft_config = LoraConfig(
            task_type = TaskType.CAUSAL_LM, inference_mode=False,
            r = peft_args.lora_rank,
            lora_alpha = peft_args.lora_alpha, 
            lora_dropout = peft_args.lora_dropout
        )
    else:
        # TODO: lora fine tuning on text encoder for medical domain
        # perf_config = ...
        pass
    return peft_config

In [58]:
parser = transformers.HfArgumentParser((CLIPconfig, PEFTconfig))
parser.add_argument("-f", "--file", required=False)
clip_args, peft_args = parser.parse_args_into_dataclasses()[:2]

In [59]:
class VQA_Model(nn.Module):
    def __init__(self, config):
        super(VQA_Model, self).__init__()
        embed_dim = config.embed_dim
        self.tokenizer = LlamaTokenizer.from_pretrained(config.pretrained_tokenizer)
        # self.llama_model = transformers.LlamaModel.from_pretrained(config.pretrained_model, token = "")
        # peft_config = get_peft_config(PEFTconfig)
        self.text_encoder = AutoTokenizer.from_pretrained(config.pretrained_tokenizer) # currently just biobert/tencent; TODO: change to peft model
        self.text_embedder = nn.Sequential(nn.Linear(4096, embed_dim))
        self.context_length = 256
        
        # clip vision encoder
        if config.image_encoder == "CLIP":
            self.image_encoder_name = "CLIP"
            clip_config = CLIPVisionConfig(image_size=512)
            self.image_encoder = CLIPVisionModel(clip_config).from_pretrained(config.clip_pretrained) # "openai/clip-vit-base-patch32"
            self.image_encoder.vision_model.embeddings = transformers.models.clip.modeling_clip.CLIPVisionEmbeddings(clip_config)
        else:
            self.image_encoder_name = "PMC_CLIP"
            # TODO: pmc clip handling...
        # other hidden layers
        self.fquery_input = nn.Parameter(torch.empty(32, 768))
        self.fquery_decoder_layer = nn.TransformerDecoderLayer(embed_dim, nhead=4, dim_feedforward=768, dropout=0.1, activation='relu', norm_first=True)
        self.fquery_decoder_norm = nn.LayerNorm(embed_dim)
        self.fquery_decoder = nn.TransformerDecoder(self.fquery_decoder_layer, 12, self.fquery_decoder_norm)

        text_cfg = CLIPconfig
        self.transformer_width = text_cfg.width
        self.positional_embedding = nn.Parameter(torch.empty(CLIPconfig.context_length, text_cfg.width))
        self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim))
        self.mlm_projection = nn.Parameter(torch.empty(text_cfg.width, text_cfg.vocab_size))
        self.softmax = nn.LogSoftmax(dim=-1)
        
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.img_special_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.fusion_module = Transformer(
            width=text_cfg.width,
            layers=text_cfg.fusion_layers,
            heads=text_cfg.heads,
        )
        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
        self.init_parameters()

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(CLIPconfig.context_length, CLIPconfig.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def init_parameters(self):
        nn.init.normal_(self.positional_embedding, std=0.01)
        nn.init.constant_(self.logit_scale, np.log(1/0.07))
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer_width ** -0.5)
        if self.mlm_projection is not None:
            nn.init.normal_(self.mlm_projection, std=self.transformer_width ** -0.5)
    
    def forward(self, image, text_logits, attention_mask_logits):
        batch_size, _ = text_logits.shape
        if self.image_encoder_name == 'CLIP':
            # resize image to batch_size, 
            image_logits = self.image_encoder(image).last_hidden_state[:,1:,:]
        else:
            print('not supported') # TODO

        image_query_features = self.fquery_input.unsqueeze(0).expand(batch_size, -1, -1) # resize to (b . .)
        image_logits = self.fquery_decoder(image_query_features.transpose(0,1), image_logits.transpose(0,1)).transpose(0,1)
        
        question_features  = self.text_encoder(text = text_logits, attention_mask = attention_mask_logits)[0] # TODO: TRAIN TEXT ENCODER
        question_features = rearrange(question_features, 'b n d -> (b n) d') # (embed dim, 4096)
        question_features = self.text_embedder(question_features)
        x = rearrange(question_features,'(b n) d -> b n d', b = batchsize)
        
        B, _len, _dim = x.shape
        img_special_tokens = self.img_special_token.expand(B, -1, -1) # (128, 1, embed_dim)
        x = torch.cat([x, img_special_tokens, image_logits], dim=1) # combine image and text features
        x = x.permute(1, 0, 2)  # batch-first -> sequence-first
        x = self.fusion_module(x)
        x = x.permute(1, 0, 2)  # sequence-first -> batch-first
        x = x[:, :-33, :]  # Remove token (img_special_token, img)
        out = self.softmax(x @ self.mlm_projection)  # [batch_size=128, n_ctx=77, vocab_size=49409]
        return out 

# Training

In [60]:
class VQATrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        image = inputs['image']  
        label = inputs['labels'].to(dtype=torch.long) 
        input_logits = inputs['text_logits'] 
        attention_mask_logits = inputs['attention_mask_logits'] 
        outputs = model(image, input_logits, attention_mask_logits)
        loss = F.nll_loss(outputs.transpose(1, 2), label, ignore_index=0)
        return (loss, {'outputs':outputs}) if return_outputs else loss

In [61]:
model = VQA_Model(model_args)

In [62]:
trainer = VQATrainer(model=model, 
                     train_dataset = Train_dataset, 
                     eval_dataset = Test_dataset,
                     args=training_args
                    )
trainer.train()
trainer.save_state()

ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).