In [None]:
import io
from PIL import Image
from pathlib import Path
import os
import torch
from sklearn.manifold import TSNE
import pickle
import numpy as np
import matplotlib.pyplot as plt
os.environ['HF_HOME'] = '/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/.cache' ## THIS HAS TO BE BEFORE YOU IMPORT TRANSFORMERS
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig

In [None]:
def setup_model(
    model_name: str = "StanfordAIMI/CheXagent-8b",
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
) -> tuple:
    """Load the processor, model and generation config for one checkpoint.

    All three objects are created with ``from_pretrained`` using the same
    ``model_name`` so they cannot drift apart. Calling the function with no
    arguments reproduces the original hard-coded configuration.

    Args:
        model_name: Identifier passed to every ``from_pretrained`` call.
        device: Device the model is moved to with ``.to(device)``.
        dtype: Passed to the model loader as ``torch_dtype``.

    Returns:
        The tuple ``(processor, model, device, dtype, generation_config)``.
    """
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, trust_remote_code=True
    ).to(device)
    generation_config = GenerationConfig.from_pretrained(model_name)

    return processor, model, device, dtype, generation_config

# Load everything once at notebook scope; the cells below reuse these five names.
processor, model, device, dtype, generation_config = setup_model()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [15]:
class EmbeddingCollector:
    """Record intermediate activations of a model during generation.

    Forward hooks are attached to three sub-modules of the model
    (``vision_model.embeddings.patch_embedding``, ``vision_model.post_layernorm``
    and ``language_projection``). What they capture is stored per image in
    ``embeddings_dict`` as numpy arrays and can be pickled to disk.
    """

    def __init__(self, image_folder, processor, model, generation_config):
        """Store the constructor arguments and initialise empty capture slots."""
        self.image_folder = image_folder
        self.processor = processor
        self.model = model
        self.generation_config = generation_config

        # Most recent activation seen by each hook (None until a hook fires).
        self.patch_embeddings_output = None
        self.post_layer_norm_output = None
        self.q_former_output = None
        self.language_projection_output = None
        # Maps image id string -> dict of captured numpy arrays.
        self.embeddings_dict = {}

    def patch_embedding_hook(self, module, input, output):
        """Forward hook: keep a detached CPU copy of the patch-embedding output."""
        self.patch_embeddings_output = output.detach().cpu()

    def post_layer_norm_hook(self, module, input, output):
        """Forward hook: keep a detached CPU copy of the post layer norm output."""
        self.post_layer_norm_output = output.detach().cpu()

    def language_projection_hook(self, module, input, output):
        """Forward hook for the language projection layer.

        The first positional input of the layer is recorded as the Q-Former
        output; the layer's own output is recorded as the language projection.
        """
        self.q_former_output = input[0].detach().cpu()
        self.language_projection_output = output.detach().cpu()

    def generate_with_forward_hooks(self, image_path, prompt, processor, model, device, dtype,
                                    generation_config, image_id_start=8):
        """Run generation on one image and store the activations seen by the hooks.

        Args:
            image_path: Path of the image to open.
            prompt: Text inserted into the ``USER: ... ASSISTANT:`` template.
            processor: Callable that builds the model inputs; its ``tokenizer``
                is used to decode the generated ids.
            model: Model exposing the hooked sub-modules and ``generate``.
            device: Device the inputs are moved to.
            dtype: Dtype the inputs are cast to.
            generation_config: Forwarded to ``model.generate``.
            image_id_start: Index of the first ``/``-separated component of
                ``image_path`` kept in the image id. The default of 8 matches
                the CheXpert train/valid layout; the original notes say 10 for
                test. For VinDr-style ids (file stem only) derive the id
                separately.

        Returns:
            The decoded response string, or ``None`` when ``model.generate``
            returns ``None`` (noted in the original as happening when the
            language model is disabled).

        Raises:
            RuntimeError: If any hook did not fire during the forward pass.
                Nothing is stored for the image in that case.
        """
        # Context manager closes the underlying file once the pixels are converted.
        with Image.open(image_path) as raw_image:
            images = raw_image.convert("RGB")

        image_id_string = "/".join(str(image_path).split("/")[image_id_start:])

        # Clear earlier captures so a hook that fails to fire cannot leave the
        # previous image's activations to be stored under this image's id.
        self.patch_embeddings_output = None
        self.post_layer_norm_output = None
        self.q_former_output = None
        self.language_projection_output = None

        hook_handles = []
        try:
            # register hooks
            hook_handles.append(
                model.vision_model.embeddings.patch_embedding.register_forward_hook(self.patch_embedding_hook)
            )
            hook_handles.append(
                model.vision_model.post_layernorm.register_forward_hook(self.post_layer_norm_hook)
            )
            hook_handles.append(
                model.language_projection.register_forward_hook(self.language_projection_hook)
            )

            # complete a forward pass
            inputs = processor(
                images=images, text=f" USER: <s>{prompt} ASSISTANT: <s>", return_tensors="pt"
            ).to(device=device, dtype=dtype)

            output = model.generate(**inputs, generation_config=generation_config)
        finally:
            # Always detach the hooks, even when the forward pass raises, so
            # they do not accumulate on the shared model across calls.
            for handle in hook_handles:
                handle.remove()

        captured = {
            'patch_embeddings': self.patch_embeddings_output,
            'post_layer_norm': self.post_layer_norm_output,
            'q_former': self.q_former_output,
            'language_projection': self.language_projection_output,
        }
        missing = [name for name, tensor in captured.items() if tensor is None]
        if missing:
            raise RuntimeError(
                f"Forward hooks did not fire for {image_id_string}: {', '.join(missing)}"
            )

        # The hooks already moved the tensors to the CPU, so only the numpy conversion is left.
        self.embeddings_dict[image_id_string] = {
            name: tensor.numpy() for name, tensor in captured.items()
        }

        response = None
        if output is not None:  # output is None when the language model is disabled
            response = processor.tokenizer.decode(output[0], skip_special_tokens=True)

        return response

    def save_embeddings_dict_to_pickle(self, output_path):
        """Pickle ``embeddings_dict`` to ``output_path`` with the highest protocol."""
        with open(output_path, 'wb') as handle:
            pickle.dump(self.embeddings_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [16]:
# One collector instance accumulates the embeddings for every image processed below.
embeddings_collector = EmbeddingCollector(
    image_folder="images",
    processor=processor,
    model=model,
    generation_config=generation_config,
)

In [17]:
# CSV listings and image roots for the CheXpert splits.
# for training images
cheXpert_train_5000_csv_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/train_sample_paths/5,000.csv")
cheXpert_train_5002_10001_csv_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/train_sample_paths/10,000.csv")
cheXpert_small_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/")

chexpert_valid_csv_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/valid.csv")
chexpert_valid_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/")

chexpert_test_csv_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/test.csv")
chexpert_test_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/small/")

# Index of the first CSV line to process. Every line from here to the END of
# the file is used; no upper bound is enforced, so the "up to 10001" range
# implied by the output file name holds only if the CSV has 10,001 lines.
# NOTE(review): confirm the CSV length.
START_LINE = 5002
prompt = "Describe the findings"

# Read the listing and close the file right away instead of holding the handle
# open for the whole (long-running) generation loop.
with open(cheXpert_train_5002_10001_csv_path, 'r') as f:
    lines = f.readlines()

for i, line in enumerate(lines[START_LINE:]):
    if i % 1000 == 0:
        print(f"Collecting image {i}")

    relative_path = line.split(",")[0].strip()
    if not relative_path:
        # A blank line would resolve to the dataset directory itself and abort
        # the run before any embeddings are saved, so skip it.
        continue
    image_path = cheXpert_small_path / relative_path

    response = embeddings_collector.generate_with_forward_hooks(image_path, prompt, processor, model, device, dtype, generation_config)
    # print(response)

embeddings_collector.save_embeddings_dict_to_pickle("chexpert_train_5002_10001_dict.pkl")

Collecting image 0
Collecting image 1000
Collecting image 2000
Collecting image 3000
Collecting image 4000


In [12]:
# Reload a previously saved embeddings dictionary from disk.
# NOTE: pickle.load can execute arbitrary code; only use it on files you produced yourself.
with Path("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/CheXpert-small/embeddings_only_dict/chexpert_5000_train_dict.pkl").open('rb') as handle:
    embeddings_dict = pickle.load(handle)

In [None]:
# VINDR-SPECIFIC

# required_image_id_path = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/image_text_reasoning_datasets/train_pathology_left_or_right_unaninmous_agreement_random_radiologist')
# vindr_dir_train_path = Path('/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/train/')
# with open(required_image_id_path, 'r') as f:
#     f.readline() # skip header    
#     # image_ids = f.readlines()
#     image_ids = {line.split(',')[0].strip() for line in f.readlines()}
#     for image_id in image_ids:
#         image_path = vindr_dir_train_path / f'{image_id}.png'
    
#         prompt = "Describe the findings"
#         embeddings_collector.generate_with_forward_hooks(image_path, prompt, processor, model, device, dtype, generation_config)
   

# embeddings_collector.save_embeddings_dict_to_pickle('train_pathology_unanimous_agreement_random_radiologist.pkl')