In [1]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import PIL
from torch import nn
from helpers.utils import load_yaml, load_checkpoint
from transformers import AutoModelForImageTextToText
import torch
from typing import Any, Dict, List, Optional, Tuple, Union

data = [{'images': PIL.Image.open('../../../Databases/PULSE/mimic_v4/p1856/p18569031/s41816242/41816242-0.png').convert('RGB'),
  'question': "<image>\nWhat are the main ECG features I'm seeing here?",
  'answer': 'Upon examining this ECG, I notice that the heart rhythm is irregularly irregular, with no discernible P waves. This is consistent with atrial fibrillation (AFib). Additionally, the ventricular response is slow, which is unusual in AFib. I also observe ventricular couplets, which are pairs of ventricular beats that are closely coupled. Furthermore, the QRS complexes exhibit a Left bundle branch block (LBBB) pattern, characterized by a prolonged QRS duration and a specific morphology. These findings are indicative of an abnormal ECG.',
  'id': '18569031_41816242',
  'metadata': {'subtask': 'morphology', 'type': 'open-ended'},
  'source': 'MIMIC-IV-ECG'}]

model_id = "google/paligemma2-3b-mix-224"
processor = AutoProcessor.from_pretrained(model_id)

class BaselineVLModel(nn.Module):
    def __init__(self,model_pretrained_path):
        super().__init__()
        self.model_pretrained_path = model_pretrained_path
        self.model_pretrained = AutoModelForImageTextToText.from_pretrained(self.model_pretrained_path, attn_implementation='eager')

        ##Unshare memory weightsn
        self.model_pretrained.language_model.lm_head.weight = torch.nn.Parameter(self.model_pretrained.language_model.lm_head.weight.clone())
        self.model_pretrained.language_model.model.embed_tokens.weight = torch.nn.Parameter(self.model_pretrained.language_model.model.embed_tokens.weight.clone())
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        **kwargs
    ):
        output = self.model_pretrained(
            input_ids = input_ids,
            attention_mask = attention_mask,
            position_ids = position_ids,
            past_key_values = past_key_values,
            inputs_embeds = inputs_embeds,
            labels = labels,
            use_cache = use_cache,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            return_dict = return_dict,
            pixel_values = pixel_values,
            **kwargs
        )
        return output


    def generate(
        self,
        input_ids = None,
        max_new_tokens=50,
        **kwargs
    ):
        output = self.model_pretrained.generate(
            input_ids = input_ids,
            max_new_tokens = max_new_tokens,
            **kwargs
        )
        return output[:,input_ids.shape[1]:]

model = BaselineVLModel(model_id)
model = model.to("cuda")
inputs = processor(
            [item['images'] for item in data],
            [f"Question: {item['question']} Answer: " for item in data],
            truncation=True,
            padding=True,
            max_length=512, 
            return_tensors="pt"
        ).to("cuda")
output = model.generate(**inputs)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
processor.tokenizer.batch_decode(output)

['P-QRS complexes<eos>']