In [1]:
import io
from PIL import Image
from pathlib import Path
import os
import torch
from sklearn.manifold import TSNE
import pickle
import numpy as np
import matplotlib.pyplot as plt
os.environ['HF_HOME'] = '/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/.cache' ## THIS HAS TO BE BEFORE YOU IMPORT TRANSFORMERS
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig

In [6]:
def setup_model() -> tuple:
    device = "cuda"
    dtype = torch.float16

    processor = AutoProcessor.from_pretrained("StanfordAIMI/CheXagent-8b", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "StanfordAIMI/CheXagent-8b", torch_dtype=dtype, trust_remote_code=True
    ).to(device)
    generation_config = GenerationConfig.from_pretrained("StanfordAIMI/CheXagent-8b")

    return processor, model, device, dtype, generation_config

processor, model, device, dtype, generation_config = setup_model()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacty of 44.31 GiB of which 33.00 MiB is free. Process 1667519 has 18.64 GiB memory in use. Including non-PyTorch memory, this process has 25.62 GiB memory in use. Of the allocated memory 24.84 GiB is allocated by PyTorch, and 281.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

: 

In [3]:
class EmbeddingCollector:

    def __init__ (self, image_folder, processor, model, generation_config):
        self.image_folder = image_folder
        self.processor = processor
        self.model = model
        self.generation_config = generation_config

        self.patch_embeddings_output = None
        self.post_layer_norm_output = None
        self.q_former_output = None
        self.language_projection_output = None
        self.embeddings_dict = {}

    def patch_embedding_hook(self, module, input, output):
        output = output.cpu().detach()  # Assuming you want to move data to CPU for analysis
        self.patch_embeddings_output = output

        
    def post_layer_norm_hook(self,module, input, output):
        """Function to be called by the hook for the post layer norm layer."""
        output = output.cpu().detach()  # Assuming you want to move data to CPU for analysis
        self.post_layer_norm_output = output

    
    def language_projection_hook(self, module, input, output):
        """Function to be called by the hook for the language projection layer."""
        input = input[0].cpu().detach()
        self.q_former_output = input

        output = output.cpu().detach()
        self.language_projection_output = output
        

    def generate_with_forward_hooks(self,image_path, prompt, processor, model, device, dtype, generation_config):
        images = Image.open(image_path).convert("RGB")

        # VINDR-SPECIFIC convert image_id to a string
        if isinstance(image_path, Path):
            image_id_string = str(image_path).split("/")[-1].split(".")[0]

        # register hooks
        patch_embeddings = model.vision_model.embeddings.patch_embedding.register_forward_hook(self.patch_embedding_hook)
        post_layer_norm = model.vision_model.post_layernorm.register_forward_hook(self.post_layer_norm_hook)
        language_projection = model.language_projection.register_forward_hook(self.language_projection_hook)
        
        # complete a forward pass 
        inputs = processor(
            images=images, text=f" USER: <s>{prompt} ASSISTANT: <s>", return_tensors="pt"
        ).to(device=device, dtype=dtype)
        output = model.generate(**inputs, generation_config=generation_config)[0]
        response = processor.tokenizer.decode(output, skip_special_tokens=True)

        self.embeddings_dict[image_id_string] = {
            'patch_embeddings': self.patch_embeddings_output.cpu().numpy(),
            'post_layer_norm': self.post_layer_norm_output.cpu().numpy(),
            'q_former': self.q_former_output.cpu().numpy(),
            'language_projection': self.language_projection_output.cpu().numpy(),
        }

        # remove hooks
        patch_embeddings.remove()
        post_layer_norm.remove()
        language_projection.remove()

        return response
    
    def save_embeddings_dict_to_pickle(self, output_path):
        with open(output_path, 'wb') as handle:
            pickle.dump(self.embeddings_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [4]:
embeddings_collector = EmbeddingCollector("images", processor, model, generation_config)

In [5]:
test_png_dset_path = '/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/test/688ecdb1a4e994d42b5a50a8c4a9736f.png'
prompt = "Describe the findings in this image."
embeddings_collector.generate_with_forward_hooks(test_png_dset_path, prompt, processor, model, device, dtype, generation_config)

  [torch.tensor(pixel_values) for pixel_values in encoding_image_processor["pixel_values"]]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


(tensor([[[-0.1465,  0.7617,  1.7031,  ..., -0.6055, -0.4121, -0.1396],
         [-0.1465,  0.7617,  1.7031,  ..., -0.6094, -0.4180, -0.1396],
         [-0.1523,  0.7617,  1.7109,  ..., -0.6016, -0.4082, -0.1406],
         ...,
         [-0.1543,  0.7617,  1.7109,  ..., -0.6055, -0.4082, -0.1416],
         [-0.0884,  0.8320,  1.7812,  ..., -0.5391, -0.4551, -0.2275],
         [-0.1445,  0.7578,  1.7031,  ..., -0.6055, -0.4121, -0.1377]]],
       device='cuda:0', dtype=torch.bfloat16),)


TypeError: Got unsupported ScalarType BFloat16