In [1]:
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset
from PIL import Image
import os
disable_torch_init()

  from .autonotebook import tqdm as notebook_tqdm


[2024-12-27 14:43:18,159] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [51]:

MODEL = "microsoft/llava-med-v1.5-mistral-7b"
model_name = get_model_name_from_path(MODEL)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL,
    model_base=None,
    model_name=model_name,
    device_map="auto",
    )

You are using a model of type llava_mistral to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.45it/s]
Some weights of LlavaLlamaForCausalLM were not initialized from the model checkpoint at microsoft/llava-med-v1.5-mistral-7b and are newly initialized: ['model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotar

In [52]:
context_len=4096

In [53]:
special_tokens_dict = {'additional_special_tokens': ['<image>']}
tokenizer.add_special_tokens(special_tokens_dict)

model.resize_token_embeddings(len(tokenizer))

Embedding(32001, 4096)

In [59]:
system_prompt = """
Consider that you are a professional radiologist with several years of experience and you are now treating a patient. Write a fully detailed diagnosis report for this case, avoiding any potential hallucination and paying close attention to all of the batch images attached to this message.

Use the following structure for the report:

## Radiologist's Report

### Patient Information:
- **Age:** 65
- **Sex:** Male
- **Days from earliest imaging to surgery:** 1
- **Histopathological Subtype:** Glioblastoma
- **WHO Grade:** 4
- **IDH Status:** Mutant
- **Preoperative KPS:** 80
- **Preoperative Contrast-Enhancing Tumor Volume (cm³):** 103.21
- **Preoperative T2/FLAIR Abnormality (cm³):** 36.29
- **Extent of Resection (EOR %):** 100.0
- **EOR Type:** Gross Total Resection (GTR)
- **Adjuvant Therapy:** Radiotherapy (RT) + Temozolomide (TMZ)
- **Progression-Free Survival (PFS) Days:** 649
- **Overall Survival (OS) Days:** 736

#### Tumor Characteristics:

#### Segmentation Analysis:

#### Surgical Considerations:

### Clinical Summary:

### Recommendations:

### Prognostic Considerations:

### Follow-Up Plan:

### Additional Notes*(if any)*:

Ensure all findings from all of the images and clinical data provided. Please mention at the end of the report how many images were reviewed.

"""




In [60]:
def generate_diagnosis_report(patient_folder,system_prompt):
    images = []
    # Load images and question
    for item in os.listdir(patient_folder):
        item_path = os.path.join(patient_folder, item)
        if item.lower().endswith('.png'):
            images.append(Image.open(item_path).convert("RGB"))
        elif item.lower() == 'diagnostic_prompt.txt':
            with open(item_path, 'r', encoding='utf-8') as f:
                clinical_data = f.read().strip()
         


    
    user_prompt = f"""
        USER: You will be given batches of images, which are flair sequences of MRI scans. 
        The images are for patients who has brain tumor. Each image will contain more than 50 slices of the flair modality and the segmentation masks slices for the tumor will be on the right.
        Additional clinical data about the patient is:

        {clinical_data}"""


    input_text = system_prompt + "\n\nUSER: " + user_prompt + "\nASSISTANT: "



    input_ids = tokenizer_image_token(
        input_text,
        tokenizer,
        image_token_index=tokenizer.convert_tokens_to_ids('<image>'),
        return_tensors="pt",
    ).squeeze(0)

    processed_images = image_processor(images, return_tensors="pt")['pixel_values']

    model_output = model.generate(
        input_ids=input_ids.unsqueeze(0).to(model.device),
        images=processed_images.to(model.device, dtype=model.dtype),
        max_length=context_len
    )
    
    response = tokenizer.decode(model_output[0], skip_special_tokens=True)
    print("Generated Report:\n", response)

    output_path = os.path.join(patient_folder, "llava_answer.txt")
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(response)

In [61]:
cases_dir = '/media/elboardy/RLAB-Disk01/(final)merged_images_with_labels_order_and_folders_mask_normalized/'

In [62]:
print("Model dtype:", model.dtype)


Model dtype: torch.float16


In [63]:
for case in os.listdir(cases_dir):
    case_dir = os.path.join(cases_dir, case)
    generate_diagnosis_report(case_dir,system_prompt=system_prompt)
    print("Done with case:", case)
    print("----------------------------------------------------")
    print("All cases done!")

Generated Report:
 
Consider that you are a professional radiologist with several years of experience and you are now treating a patient. Write a fully detailed diagnosis report for this case, avoiding any potential hallucination and paying close attention to all of the batch images attached to this message.

Use the following structure for the report:

## Radiologist's Report

### Patient Information:
- **Age:** 65
- **Sex:** Male
- **Days from earliest imaging to surgery:** 1
- **Histopathological Subtype:** Glioblastoma
- **WHO Grade:** 4
- **IDH Status:** Mutant
- **Preoperative KPS:** 80
- **Preoperative Contrast-Enhancing Tumor Volume (cm³):** 103.21
- **Preoperative T2/FLAIR Abnormality (cm³):** 36.29
- **Extent of Resection (EOR %):** 100.0
- **EOR Type:** Gross Total Resection (GTR)
- **Adjuvant Therapy:** Radiotherapy (RT) + Temozolomide (TMZ)
- **Progression-Free Survival (PFS) Days:** 649
- **Overall Survival (OS) Days:** 736

#### Tumor Characteristics:

#### Segmentation 