In [1]:
# from transformers import AutoModel, AutoTokenizer

# model_path = "/vast/work/public/ml-datasets/llama-2/Llama-2-7b-hf"  # Replace with the full path if needed
# tokenizer = AutoTokenizer.from_pretrained(model_path)
# model = AutoModel.from_pretrained(model_path)


In [2]:
import torch 

In [3]:
from huggingface_hub import hf_hub_download 
import torch 
import os 
from open_flamingo import create_model_and_transforms 
from accelerate import Accelerator 
from einops import repeat
from PIL import Image 
import matplotlib.pyplot as plt
import sys

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from src.utils import FlamingoProcessor

In [5]:
def clean_generation(response):
    """
    Tidy up text decoded from the open-flamingo based model.

    The decoded output can differ slightly from the input prompt (e.g. an
    '<unk> ' marker is prepended and extra spaces are added). Every '<unk> '
    occurrence is removed and surrounding whitespace is trimmed.
    """
    pieces = response.split('<unk> ')
    without_unk = ''.join(pieces)
    return without_unk.strip()

In [6]:
# Demo images for the few-shot prompt: main() opens these in list order and
# passes them to the processor. The list has seven entries and the prompt in
# main() contains seven <image> placeholders -- presumably matched by position
# (TODO confirm against FlamingoProcessor). Paths are relative, so they are
# resolved against the current working directory.
image_paths = [
    'img/synpic50962.jpg',
    'img/synpic52767.jpg',
    'img/synpic30324.jpg',
    'img/synpic21044.jpg',
    'img/synpic54802.jpg',
    'img/synpic57813.jpg',
    'img/synpic47964.jpg'
]

In [7]:
def main(llama_path="/vast/work/public/ml-datasets/llama-2/Llama-2-7b-hf", max_new_tokens=10):
    """
    Run the Med-Flamingo few-shot demo end to end and return the cleaned response.

    Builds the Flamingo model on top of a local Llama checkpoint, loads the
    Med-Flamingo weights from the Hugging Face hub, preprocesses the images in
    the module-level `image_paths`, and generates a continuation of the
    few-shot prompt.

    Args:
        llama_path: Local directory of the Llama language model / tokenizer.
        max_new_tokens: Maximum number of tokens to generate after the prompt.

    Returns:
        The decoded generation (prompt plus continuation) after
        `clean_generation` has been applied.

    Raises:
        ValueError: If `llama_path` does not exist.
    """
    accelerator = Accelerator() #when using cpu: cpu=True
    device = accelerator.device

    print('Loading model..')

    # >>> add your local path to Llama-7B (v1) model here (via `llama_path`).
    # NOTE(review): the default path points at Llama-2-7b, not Llama-7B (v1).
    # The checkpoint is loaded with strict=False, so a base-model mismatch would
    # not raise; the gibberish tail in the recorded output ("...Gewficfic...")
    # looks consistent with such a mismatch -- confirm which base model the
    # Med-Flamingo checkpoint expects.
    if not os.path.exists(llama_path):
        raise ValueError('Llama model not yet set up, please check README for instructions!')

    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path=llama_path,
        tokenizer_path=llama_path,
        cross_attn_every_n_layers=4
    )

    # load med-flamingo checkpoint:
    checkpoint_path = hf_hub_download("med-flamingo/med-flamingo", "model.pt")
    print(f'Downloaded Med-Flamingo checkpoint to {checkpoint_path}')
    # NOTE(review): torch.load unpickles the downloaded file; only use trusted
    # checkpoints (consider weights_only=True if the torch version supports it).
    load_result = model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
    # strict=False silently drops keys the model does not know; surface them so
    # a checkpoint/model mismatch is not hidden.
    if load_result.unexpected_keys:
        print(f'Warning: {len(load_result.unexpected_keys)} checkpoint keys were not used by the model')
    processor = FlamingoProcessor(tokenizer, image_processor)

    # go into eval mode and prepare:
    model = accelerator.prepare(model)
    model.eval()

    # Step 1: Load images.
    # Read the pixel data inside the context manager so each file handle is
    # closed deterministically; the loaded image stays usable afterwards.
    demo_images = []
    for path in image_paths:
        with Image.open(path) as img:
            img.load()
        demo_images.append(img)

    # Step 2: Define multimodal few-shot prompt.
    # example few-shot prompt:
    prompt = "You are a helpful medical assistant. You are being provided with images, a question about the image and an answer. Follow the examples and answer the last question. <image>Question: What is/are the structure near/in the middle of the brain? Answer: pons.<|endofchunk|><image>Question: Is there evidence of a right apical pneumothorax on this chest x-ray? Answer: yes.<|endofchunk|><image>Question: Is/Are there air in the patient's peritoneal cavity? Answer: no.<|endofchunk|><image>Question: Does the heart appear enlarged? Answer: yes.<|endofchunk|><image>Question: What side are the infarcts located? Answer: bilateral.<|endofchunk|><image>Question: Which image modality is this? Answer: mr flair.<|endofchunk|><image>Question: Where is the largest mass located in the cerebellum? Answer:"

    # Step 3: Preprocess data.
    print('Preprocess data')
    pixels = processor.preprocess_images(demo_images)
    # Add leading batch (b=1) and frame (T=1) axes around the stacked images.
    pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1)
    tokenized_data = processor.encode_text(prompt)

    # Step 4: Generate response.
    # actually run few-shot prompt through model:
    print('Generate from multimodal few-shot prompt')
    # Inference only: make sure no autograd graph is recorded during generation.
    with torch.no_grad():
        generated_text = model.generate(
            vision_x=pixels.to(device),
            lang_x=tokenized_data["input_ids"].to(device),
            attention_mask=tokenized_data["attention_mask"].to(device),
            max_new_tokens=max_new_tokens,
        )

    #free memory
    torch.cuda.empty_cache()
    response = processor.tokenizer.decode(generated_text[0])
    response = clean_generation(response)

    print(f'{response=}')
    # Return the text so callers (e.g. later notebook cells) can use it.
    return response

In [8]:
if __name__ == "__main__":
    # Keep main()'s return value in a module-level name: the `response`
    # variable inside main() is local, so a later cell that evaluates
    # `response` would otherwise fail with NameError on a fresh kernel.
    response = main()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Loading model..


Loading checkpoint shards: 100%|██████████| 2/2 [01:24<00:00, 42.31s/it]


Flamingo model initialized with 1309919248 trainable parameters
Downloaded Med-Flamingo checkpoint to /home/sd4175/.cache/huggingface/hub/models--med-flamingo--med-flamingo/snapshots/7243cd83bd426ceade9c4de9844cc5e5f3ff75e0/model.pt
Preprocess data
Generate from multimodal few-shot prompt




response="<s> You are a helpful medical assistant. You are being provided with images, a question about the image and an answer. Follow the examples and answer the last question. <image> Question: What is/are the structure near/in the middle of the brain? Answer: pons.<|endofchunk|><image> Question: Is there evidence of a right apical pneumothorax on this chest x-ray? Answer: yes.<|endofchunk|><image> Question: Is/Are there air in the patient's peritoneal cavity? Answer: no.<|endofchunk|><image> Question: Does the heart appear enlarged? Answer: yes.<|endofchunk|><image> Question: What side are the infarcts located? Answer: bilateral.<|endofchunk|><image> Question: Which image modality is this? Answer: mr flair.<|endofchunk|><image> Question: Where is the largest mass located in the cerebellum? Answer:agi Gewficficficficficibm Gewibm"


In [None]:
# Display the module-level `response`. This needs an earlier cell to have
# bound that name at module scope; the `response` assigned inside main() is
# local to that function and is not visible here.
response