In [10]:
import os
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import gc


In [11]:
model_id = "microsoft/Phi-3.5-vision-instruct"
cases_dir = '/media/RLAB-Disk01/(final)merged_images_with_labels_order_and_folders_mask_normalized'
cache_dir = '/media/RLAB-Disk01/Large-Language-Models-Weights'
offload_folder = '/media/RLAB-Disk01/Large-Language-Models-Weights'

In [12]:
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
#torch.cuda.set_per_process_memory_fraction(0.8)

In [13]:
# import os

# # Make sure cases_dir is properly defined
# # cases_dir = "your/cases/directory/path"

# deleted_count = 0
# for case in os.listdir(cases_dir):
#     case_dir = os.path.join(cases_dir, case)
#     if os.path.isdir(case_dir):
#         response_path = os.path.join(case_dir, 'phi-3.5-vision-instruct-response.txt')
#         if os.path.exists(response_path):
#             try:
#                 os.remove(response_path)
#                 deleted_count += 1
#                 print(f"Deleted: {response_path}")
#             except Exception as e:
#                 print(f"Error deleting {response_path}: {str(e)}")

# print(f"\nDeleted {deleted_count} response files.")
# print(f"Remaining cases without response file: {len(os.listdir(cases_dir)) - deleted_count}")

In [14]:
cases_to_process = []
for case in os.listdir(cases_dir):
    case_dir = os.path.join(cases_dir, case)
    if os.path.isdir(case_dir):
        response_path = os.path.join(case_dir, 'phi-3.5-vision-instruct-response.txt')
        if not os.path.exists(response_path):
            cases_to_process.append(case)
print(f"Found {len(cases_to_process)} cases to process.")

Found 40 cases to process.


In [15]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    _attn_implementation='flash_attention_2',
    cache_dir=cache_dir,
    offload_folder=offload_folder
)

processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=32
)


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


In [16]:
def load_image(image_path):
    try:
        img = Image.open(image_path)
        return img
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None

In [17]:
failed_cases = []

for case in cases_to_process:
    case_dir = os.path.join(cases_dir, case)
    print(f"\nProcessing case: {case}")
    image_files = [
        f for f in os.listdir(case_dir)
        if f.lower().endswith('.png')
    ]

    
    try:
        # Load clinical information
        clinical_info_path = os.path.join(case_dir, 'diagnostic_prompt.txt')
        if not os.path.exists(clinical_info_path):
            raise FileNotFoundError(f"Missing clinical info: {clinical_info_path}")
            
        with open(clinical_info_path, 'r', encoding='utf-8') as f:
            clinical_info = f.read()

        # Prepare prompts
        system_prompt = """Consider that you are a professional radiologist with several years of experience and you are now treating a patient. Write a fully detailed diagnosis report for this case, avoiding any potential hallucination and paying close attention to all of the batch images attached to this message.

Use the following structure for the report:

## Radiologist's Report

### Patient Information:
- *Age:* 65
- *Sex:* Male
- *Days from earliest imaging to surgery:* 1
- *Histopathological Subtype:* Glioblastoma
- *WHO Grade:* 4
- *IDH Status:* Mutant
- *Preoperative KPS:* 80
- *Preoperative Contrast-Enhancing Tumor Volume (cm³):* 103.21
- *Preoperative T2/FLAIR Abnormality (cm³):* 36.29
- *Extent of Resection (EOR %):* 100.0
- *EOR Type:* Gross Total Resection (GTR)
- *Adjuvant Therapy:* Radiotherapy (RT) + Temozolomide (TMZ)
- *Progression-Free Survival (PFS) Days:* 649
- *Overall Survival (OS) Days:* 736

#### Tumor Characteristics:

#### Segmentation Analysis:

#### Surgical Considerations:

### Clinical Summary:

### Recommendations:

### Prognostic Considerations:

### Follow-Up Plan:

### Additional Notes*(if any)*:

Ensure all findings from all of the images and clinical data provided. Please mention at the end of the report how many images were reviewed."""

        user_prompt = f"""You will be given batches of images, which are different sequences of MRI scans. 
        The images are for patients who are likely to have a brain tumor. Each image will contain up to 10 slices for 5 different sequences and the segmentation masks for the tumor at the bottom row of the image. 
        Additional clinical data about the patient is: 
        {clinical_info}."""
        # Create image placeholders

        for image_file in image_files:
            image_path = os.path.join(case_dir, image_file)
            img = load_image(image_path)
   
        images=[]
        images.append(img)
        
        image_placeholders = "\n".join([f"<|image_{i+1}|>" for i in range(len(images))])

        


        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"{image_placeholders}\n{user_prompt}"},
        ]

        # messages = [
        #     {"role": "user", "content": f"{image_placeholders}\n{system_prompt}\n{user_prompt}"},
        # ]

        # Process inputs
        prompt = processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        inputs = processor(
            text=prompt,
            images=images,
            return_tensors="pt"
        ).to(model.device)

        # Generate response
        generation_args = {
            "max_new_tokens": 4096,
            "temperature": 0.7,
            "do_sample": False,
        }
        
        generate_ids = model.generate(
            **inputs,
            eos_token_id=processor.tokenizer.eos_token_id,
            **generation_args
        )

        # Decode response
        response = processor.batch_decode(
            generate_ids[:, inputs['input_ids'].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        # Save results
        with open(os.path.join(case_dir, 'phi-3.5-vision-instruct-response.txt'), 
                 'w', encoding='utf-8') as f:
            f.write(response)

        # Cleanup
        del inputs, generate_ids, response
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Completed case {case}")

    except Exception as e:
        print(f"Error processing {case}: {str(e)}")
        failed_cases.append(case)
        continue

# Save failed cases log
if failed_cases:
    with open(os.path.join(cases_dir, 'phi3_failed_cases.txt'), 'w') as f:
        f.write("\n".join(failed_cases))
    print(f"Logged {len(failed_cases)} failed cases")

print("Processing complete.")


Processing case: RHUH-0019
Completed case RHUH-0019

Processing case: RHUH-0001
Completed case RHUH-0001

Processing case: RHUH-0002
Completed case RHUH-0002

Processing case: RHUH-0003
Completed case RHUH-0003

Processing case: RHUH-0004
Completed case RHUH-0004

Processing case: RHUH-0005
Completed case RHUH-0005

Processing case: RHUH-0006
Completed case RHUH-0006

Processing case: RHUH-0007
Completed case RHUH-0007

Processing case: RHUH-0008
Completed case RHUH-0008

Processing case: RHUH-0009
Completed case RHUH-0009

Processing case: RHUH-0010
Completed case RHUH-0010

Processing case: RHUH-0011
Completed case RHUH-0011

Processing case: RHUH-0012
Completed case RHUH-0012

Processing case: RHUH-0013
Completed case RHUH-0013

Processing case: RHUH-0014
Completed case RHUH-0014

Processing case: RHUH-0015
Completed case RHUH-0015

Processing case: RHUH-0016
Completed case RHUH-0016

Processing case: RHUH-0017
Completed case RHUH-0017

Processing case: RHUH-0018
Completed case RHU