# Use HuggingFace model in the VILA-M3 workflow

## Download the models and VLM

In [None]:
# Download the pretrained model
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="MONAI/VISTA3D-HF",
    local_dir=".",
)


m3_model_dir = snapshot_download(
    repo_id="MONAI/Llama3-VILA-M3-8B",
    local_dir="./vila_m3_8b",
)


## Download and cache images

In [1]:
import os
from agent_utils import ImageCache

LIVER_URL = "https://developer.download.nvidia.com/assets/Clara/monai/samples/ct_liver_0.nii.gz"

cache_dir = "../data"
cache_images = ImageCache(cache_dir)

os.makedirs(cache_dir, exist_ok=True)
cache_images.cache({"Sample 1": LIVER_URL})

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 75/75 [00:00<00:00, 87771.43it/s]


## Define the expert VISTA-3D model using the HuggingFace pipeline

In [3]:
import os
from pathlib import Path
from shutil import move
from uuid import uuid4
import tempfile
import re
import requests
from agent_utils import get_monai_transforms, get_slice_filenames, SEGMENTATION_TOKEN
from hugging_face_pipeline import HuggingFacePipelineHelper
import torch


class ExpertVista3DHF():
    """Expert model for VISTA-3D."""

    def __init__(self) -> None:
        """Initialize the VISTA-3D expert model."""
        self.model_name = "VISTA3D"
        self.pipeline = HuggingFacePipelineHelper("vista3d").init_pipeline(
            "vista3d_pretrained_model",
            device=torch.device("cuda:0"),
        )

    def _get_label_groups(self):
        """Get the label groups from the label groups path."""
        return {
            "everything": "../experts/vista3d/label_dict.json",
            "hepatic tumor": {
                "liver": 1,
                "hepatic tumor": 26
            },
            "hepatoma": {
                "liver": 1,
                "hepatic tumor": 26
            },
            "pancreatic tumor": {
                "pancreas": 4,
                "pancreatic tumor": 24
            },
            "lung tumor": {
                "lung": 20,
                "lung tumor": 23,
                "left lung upper lobe": 28,
                "left lung lower lobe": 29,
                "right lung upper lobe": 30,
                "right lung middle lobe": 31,
                "right lung lower lobe": 32
            },
            "bone lesion": {
                "bone lesion": 128
            },
            "organs": {
                "liver": 1,
                "kidney": 2,
                "spleen": 3,
                "pancreas": 4,
                "right kidney": 5,
                "right adrenal gland": 8,
                "left adrenal gland": 9,
                "gallbladder": 10,
                "left kidney": 14,
                "brain": 22,
                "lung tumor": 23,
                "pancreatic tumor": 24,
                "hepatic vessel": 25,
                "hepatic tumor": 26,
                "colon cancer primaries": 27,
                "left lung upper lobe": 28,
                "left lung lower lobe": 29,
                "right lung upper lobe": 30,
                "right lung middle lobe": 31,
                "right lung lower lobe": 32,
                "trachea": 57,
                "left kidney cyst": 116,
                "right kidney cyst": 117,
                "prostate": 118,
                "spinal cord": 121,
                "thyroid gland": 126,
                "airway": 132
            },
            "cardiovascular": {
                "aorta": 6,
                "inferior vena cava": 7,
                "portal vein and splenic vein": 17,
                "left iliac artery": 58,
                "right iliac artery": 59,
                "left iliac vena": 60,
                "right iliac vena": 61,
                "left atrial appendage": 108,
                "brachiocephalic trunk": 109,
                "left brachiocephalic vein": 110,
                "right brachiocephalic vein": 111,
                "left common carotid artery": 112,
                "right common carotid artery": 113,
                "heart": 115,
                "pulmonary vein": 119,
                "left subclavian artery": 123,
                "right subclavian artery": 124,
                "superior vena cava": 125
            },
            "gastrointestinal": {
                "esophagus": 11,
                "stomach": 12,
                "duodenum": 13,
                "bladder": 15,
                "small bowel": 19,
                "colon": 62
            },
            "skeleton": {
                "bone": 21,
                "vertebrae L5": 33,
                "vertebrae L4": 34,
                "vertebrae L3": 35,
                "vertebrae L2": 36,
                "vertebrae L1": 37,
                "vertebrae T12": 38,
                "vertebrae T11": 39,
                "vertebrae T10": 40,
                "vertebrae T9": 41,
                "vertebrae T8": 42,
                "vertebrae T7": 43,
                "vertebrae T6": 44,
                "vertebrae T5": 45,
                "vertebrae T4": 46,
                "vertebrae T3": 47,
                "vertebrae T2": 48,
                "vertebrae T1": 49,
                "vertebrae C7": 50,
                "vertebrae C6": 51,
                "vertebrae C5": 52,
                "vertebrae C4": 53,
                "vertebrae C3": 54,
                "vertebrae C2": 55,
                "vertebrae C1": 56,
                "skull": 120,
                "sternum": 122,
                "vertebrae S1": 127,
                "bone lesion": 128,
                "left rib 1": 63,
                "left rib 2": 64,
                "left rib 3": 65,
                "left rib 4": 66,
                "left rib 5": 67,
                "left rib 6": 68,
                "left rib 7": 69,
                "left rib 8": 70,
                "left rib 9": 71,
                "left rib 10": 72,
                "left rib 11": 73,
                "left rib 12": 74,
                "right rib 1": 75,
                "right rib 2": 76,
                "right rib 3": 77,
                "right rib 4": 78,
                "right rib 5": 79,
                "right rib 6": 80,
                "right rib 7": 81,
                "right rib 8": 82,
                "right rib 9": 83,
                "right rib 10": 84,
                "right rib 11": 85,
                "right rib 12": 86,
                "left humerus": 87,
                "right humerus": 88,
                "left scapula": 89,
                "right scapula": 90,
                "left clavicula": 91,
                "right clavicula": 92,
                "left femur": 93,
                "right femur": 94,
                "left hip": 95,
                "right hip": 96,
                "sacrum": 97,
                "costal cartilages": 114
            },
            "muscles": {
                "left gluteus maximus": 98,
                "right gluteus maximus": 99,
                "left gluteus medius": 100,
                "right gluteus medius": 101,
                "left gluteus minimus": 102,
                "right gluteus minimus": 103,
                "left autochthon": 104,
                "right autochthon": 105,
                "left iliopsoas": 106,
                "right iliopsoas": 107
            }
        }

    def label_id_to_name(self, label_id: int, label_dict: dict):
        """
        Get the label name from the label ID.

        Args:
            label_id: the label ID.
            label_dict: the label dictionary.
        """
        for group_dict in list(label_dict.values()):
            if isinstance(group_dict, dict):
                # this will skip str type value, such as "everything": <path>
                for label_name, label_id_ in group_dict.items():
                    if label_id == label_id_:
                        return label_name
        return None

    def segmentation_to_string(
        self,
        output_dir: Path,
        img_file: str,
        seg_file: str,
        label_groups: dict,
        modality: str = "CT",
        slice_index: int | None = None,
        axis: int = 2,
        image_filename: str = "image.jpg",
        label_filename: str = "label.jpg",
        output_prefix=None,
    ):
        """
        Format the segmentation response to a string.

        Args:
            response: the response.
            output_dir: the output directory.
            img_file: the image file path.
            modality: the modality.
            slice_index: the slice index.
            axis: the axis.
            image_filename: the image filename for the sliced image.
            label_filename: the label filename for the sliced image.
            group_label_names: the group label names to filter the label names.
            output_prefix: the output prefix.
            label_groups_path: the label groups path for VISTA-3D.
        """
        global SEGMENTATION_TOKEN
        output_dir = Path(output_dir)
        if output_prefix is None:
            output_prefix = f"The results are {SEGMENTATION_TOKEN}. The colors in this image describe "

        transforms = get_monai_transforms(
            ["image", "label"],
            output_dir,
            modality=modality,
            slice_index=slice_index,
            axis=axis,
            image_filename=image_filename,
            label_filename=label_filename,
        )
        data = transforms({"image": img_file, "label": seg_file})

        formatted_items = []

        for label_id in data["colormap"]:
            label_name = self.label_id_to_name(label_id, label_groups)
            if label_name is not None:
                color = data["colormap"][label_id]
                formatted_items.append(f"{color}: {label_name}")

        return output_prefix + ", ".join(formatted_items) + ". "

    def mentioned_by(self, input: str):
        """
        Check if the VISTA-3D model is mentioned in the input.

        Args:
            input (str): Text from the LLM, e.g. "Let me trigger <VISTA3D(arg)>."

        Returns:
            bool: True if the VISTA-3D model is mentioned, False otherwise.
        """
        matches = re.findall(r"<(.*?)>", str(input))
        if len(matches) != 1:
            return False
        return self.model_name in str(matches[0])

    def download_file(self, url: str, img_file: str):
        """
        Download the file from the URL.

        Args:
            url (str): The URL.
            img_file (str): The file path.
        """
        parent_dir = os.path.dirname(img_file)
        os.makedirs(parent_dir, exist_ok=True)
        with open(img_file, "wb") as f:
            response = requests.get(url)
            f.write(response.content)

    def run(
        self,
        img_file: str = "",
        image_url: str = "",
        input: str = "",
        output_dir: str = "",
        slice_index: int = 0,
        prompt: str = "",
        **kwargs,
    ):
        """
        Run the VISTA-3D model.

        Args:
            image_url (str): The image URL.
            input (str): The input text.
            output_dir (str): The output directory.
            img_file (str): The image file path. If not provided, download from the URL.
            slice_index (int): The slice index.
            prompt (str): The prompt text from the original request.
            **kwargs: Additional keyword arguments.
        """
        if not img_file:
            # Download from the URL
            img_file = os.path.join(output_dir, os.path.basename(image_url))
            self.download_file(image_url, img_file)

        output_dir = Path(output_dir)
        matches = re.findall(r"<(.*?)>", input)
        if len(matches) != 1:
            raise ValueError(f"Expert model {self.model_name} is not correctly enclosed in angle brackets.")

        match = matches[0]

        # Extract the arguments
        arg_matches = re.findall(r"\((.*?)\)", match[len(self.model_name) :])

        if len(arg_matches) == 0:  # <VISTA3D>
            arg_matches = ["everything"]
        if len(arg_matches) == 1 and (arg_matches[0] == "" or arg_matches[0] == None):  # <VISTA3D()>
            arg_matches = ["everything"]
        if len(arg_matches) > 1:
            raise ValueError(
                "Multiple expert model arguments are provided in the same prompt, "
                "which is not supported in this version."
            )

        vista3d_prompts = None
        label_groups = self._get_label_groups()

        if arg_matches[0] not in label_groups:
            raise ValueError(f"Label group {arg_matches[0]} is not accepted by the VISTA-3D model.")

        if arg_matches[0] != "everything":
            vista3d_prompts = [cls_idx for _, cls_idx in label_groups[arg_matches[0]].items()]

        # Trigger the VISTA-3D model
        input_dict = {"image": img_file}
        if vista3d_prompts is not None:
            input_dict["label_prompt"] = vista3d_prompts
        else:
            input_dict["label_prompt"] = [int(i) for i in range(1, 16)]

        with tempfile.TemporaryDirectory() as temp_dir:
            self.pipeline([input_dict], output_dir=temp_dir)
            seg_file = os.path.join(output_dir, "segmentation.nii.gz")
            temp_output_dir = os.path.join(temp_dir, os.listdir(temp_dir)[0])
            output_file = os.path.join(temp_output_dir, os.listdir(temp_output_dir)[0])
            if os.path.exists(seg_file):
                if os.path.isdir(seg_file):
                    from shutil import rmtree
                    rmtree(seg_file)
                else:
                    os.remove(seg_file)
            move(output_file, seg_file)
            print(f"File exists: {os.path.exists(seg_file)}")

        seg_image = f"seg_{uuid4()}.jpg"
        text_output = self.segmentation_to_string(
            output_dir,
            img_file,
            seg_file,
            label_groups,
            modality="CT",
            slice_index=slice_index,
            image_filename=get_slice_filenames(img_file, slice_index),
            label_filename=seg_image,
        )

        if "segmented" in input:
            instruction = ""  # no need to ask for instruction
        else:
            instruction = "Use this result to respond to this prompt:\n" + prompt
        return text_output, os.path.join(output_dir, seg_image), instruction



## (Optional) Test the expert VISTA-3D model

In [4]:
expert = ExpertVista3DHF()
with tempfile.TemporaryDirectory() as temp_dir:
    text_output, seg_image, instruction = expert.run(
        img_file=cache_images.get(LIVER_URL),
        input="Let me trigger <VISTA3D(everything)>.",
        prompt="Describe the image.",
        output_dir="../data",
        slice_index=0,
    )
    print("="*50 + "test run:" + "="*50)
    print(f"Image segmentation is saved to {seg_image}")
    print(f"Instruction passed to the follow-up prompt:\n{instruction}")
    print(f"Text output:\n{text_output}")
    print("="*100)

2025-03-05 09:27:12,946 INFO image_writer.py:197 - writing: /tmp/tmp7h88ipra/ct_liver_0/ct_liver_0_seg.nii.gz
File exists: True
Image segmentation is saved to ../data/seg_4a9d8168-d778-4515-a2b4-639ec1ab10dc.jpg
Instruction passed to the follow-up prompt:
Use this result to respond to this prompt:
Describe the image.
Text output:
The results are <segmentation>. The colors in this image describe red: liver, blue: spleen, yellow: pancreas, magenta: right kidney, green: aorta, indigo: inferior vena cava, darkorange: right adrenal gland, cyan: left adrenal gland, pink: gallbladder, brown: esophagus, orange: stomach, lime: duodenum, orange: bladder. 


## Use the expert VISTA-3D model in the VILA-M3 workflow

In [5]:
from agent_utils import SessionVariables, ChatHistory, M3Generator

model_path = "./vila_m3_8b"
sv = SessionVariables()
m3 = M3Generator(
        cache_images,
        source="local",
        model_path=model_path,
        conv_mode="llama_3",
        experts_classes=[ExpertVista3DHF],
    )

sv.image_url = LIVER_URL
sv.slice_index = 57

chat_history = ChatHistory()

sv, chat_history = m3.process_prompt(
    "Is there a hepatic tumor in the image",
    sv,
    chat_history
)

print("="*100)

for message in chat_history.messages:
    role = message["role"].upper()
    content = message["content"]
    print(f"{role}: {content}")

[2025-03-05 09:27:20,508] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2025-03-05 09:27:20 - root - INFO - x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -c /tmp/tmp_e1iz3l4/test.c -o /tmp/tmp_e1iz3l4/test.o
2025-03-05 09:27:20 - root - INFO - x86_64-linux-gnu-gcc /tmp/tmp_e1iz3l4/test.o -laio -o /tmp/tmp_e1iz3l4/a.out
/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
2025-03-05 09:27:21 - root - INFO - x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -c /tmp/tmpkns0prnm/test.c -o /tmp/tmpkns0prnm/test.o
2025-03-05 09:27:21 - root - INFO - x86_64-linux-gnu-gcc /tmp/tmpkns0prnm/test.o -L/usr/local/cuda -L/usr/local/cuda/lib64 -lcufile -o /tmp/tmpkns0prnm/a.out
2025-03-05 09:27:21 - root - INFO - x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wa

2025-03-05 09:27:35,569 INFO image_writer.py:197 - writing: /tmp/tmp7u6l0pf2/ct_liver_0/ct_liver_0_seg.nii.gz
File exists: True
USER: [{'type': 'text', 'text': "Here is a list of available expert models:\n<BRATS(args)> Modality: MRI, Task: segmentation, Overview: A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data, Accuracy: Tumor core (TC): 0.8559 - Whole tumor (WT): 0.9026 - Enhancing tumor (ET): 0.7905 - Average: 0.8518, Valid args are: None\n<VISTA3D(args)> Modality: CT, Task: segmentation, Overview: domain-specialized interactive foundation model developed for segmenting and annotating human anatomies with precision, Accuracy: 127 organs: 0.792 Dice on average, Valid args are: 'everything', 'hepatic tumor', 'pancreatic tumor', 'lung tumor', 'bone lesion', 'organs', 'cardiovascular', 'gastrointestinal', 'skeleton', or 'muscles'\n<VISTA2D(args)> Modality: cell imaging, Task: segmentation, Overview: model for ce

