In [1]:
import os
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized,
# which is why it comes before the torch import below. With "2", physical GPU 2
# is exposed to this process as device index 0 (the index queried later).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import sys
# Make the local BiomedParse checkout importable; presumably this is where the
# `utils` and `inference` modules imported below are found.
# NOTE(review): absolute, machine-specific path — consider a configurable constant.
sys.path.insert(0, '/Disk1/afrouz/Projects/BiomedParse')

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch
import torch.nn.functional as F
import hydra
from hydra import compose
from hydra.core.global_hydra import GlobalHydra
# Project-local BiomedParse helpers (resolved via the sys.path entry above).
# NOTE(review): mpatches, slice_nms, segmentation and label are not used by any
# cell visible here; keep only if later cells need them.
from utils import process_input, process_output, slice_nms
from inference import postprocess, merge_multiclass_masks

from skimage import segmentation
from skimage.measure import label

  import pynvml  # type: ignore[import]


In [2]:
from huggingface_hub import hf_hub_download
import os

# Pick the compute device. Which physical GPU "cuda" / index 0 refers to is
# controlled by CUDA_VISIBLE_DEVICES, set in the first cell.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
print("Using device:", device)
gpu_label = torch.cuda.get_device_name(0) if has_cuda else "CPU"
print("GPU Name:", gpu_label)

# Clear any Hydra state left by a previous run of this cell, so re-running it
# does not fail on a second initialize, then build the 3D model from its config.
# NOTE(review): config_path is resolved relative to the caller (for a notebook,
# effectively its working directory) — confirm if the notebook is moved.
GlobalHydra.instance().clear()
hydra.initialize(
    config_path="../../../../BiomedParse/configs/model",
    job_name="example_prediction",
    version_base=None,
)
cfg = compose(config_name="biomedparse_3D")
model = hydra.utils.instantiate(cfg, _convert_="object")

# Download the pretrained weights (hf_hub_download reuses its local cache when
# present), load them, and put the model on the device in eval mode.
checkpoint_path = hf_hub_download(
    repo_id="microsoft/BiomedParse",
    filename="biomedparse_v2.ckpt",
)
model.load_pretrained(checkpoint_path)
model = model.to(device).eval()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
GPU Name: NVIDIA RTX A6000




Checkpoint loaded successfully!


In [3]:
import nibabel as nib

# One case of the merged BraTS data; change `sample_name` (or `data_root`) to
# load a different case. The resulting paths are identical to the original.
# NOTE(review): absolute, machine-specific path — consider a config constant.
data_root = "/Disk1/afrouz/Data/Merged"
sample_name = "BraTS20_Training_001"
sample_dir = os.path.join(data_root, sample_name)

# Load the FLAIR modality; get_fdata() returns a floating-point array
# (float64, as printed below). `image` is kept raw for visualization later.
flair_path = os.path.join(sample_dir, f"{sample_name}_flair.nii")
flair_img = nib.load(flair_path)
image = flair_img.get_fdata()

# Load the free-text description used as the segmentation prompt. An explicit
# encoding avoids depending on the platform's default locale encoding.
text_path = os.path.join(sample_dir, f"{sample_name}_flair_text.txt")
with open(text_path, "r", encoding="utf-8") as f:
    text_prompt = f.read().strip()

print(f"Image shape: {image.shape}")
print(f"Image dtype: {image.dtype}")
print(f"Image range: [{image.min():.2f}, {image.max():.2f}]")
print(f"\nText prompt:\n{text_prompt}")

Image shape: (240, 240, 155)
Image dtype: float64
Image range: [0.00, 625.00]

Text prompt:
The lesion area is in the right frontal and parietal lobes with a mixed pattern of high and low signals with speckled high signal regions. Edema is mainly observed in the right parietal lobe, partially extending to the frontal lobe, presenting as high signal, indicating significant tissue swelling around the lesion. Necrosis is within the lesions of the right parietal and frontal lobes, appearing as mixed, with alternating high and low signal regions. Ventricular compression is seen in the lateral ventricles with significant compressive effects on the brain tissue and ventricles.


In [4]:
# Preprocess image for BiomedParse.
# The FLAIR volume holds raw scanner intensities (range [0, 625] above), while the
# tensor below is cast to integers for the model.
# NOTE(review): assumes BiomedParse v2 expects 8-bit-range inputs in [0, 255]
# (as in the CVPR 2025 BiomedSegFM MRI preprocessing) — confirm against
# process_input and the model's pixel normalization.
def normalize_mri_intensity(volume, lower_pct=0.5, upper_pct=99.5):
    """Clip a volume to foreground percentiles and rescale it to [0, 255].

    Percentiles are computed over voxels > 0 only, so the zero background does
    not skew the bounds; background voxels stay 0.

    Args:
        volume: array-like of raw intensities (any shape).
        lower_pct: lower clipping percentile, in [0, 100].
        upper_pct: upper clipping percentile, in [0, 100].

    Returns:
        float64 array of the same shape holding integer values in [0, 255].
        An all-background volume returns all zeros.
    """
    volume = np.asarray(volume, dtype=np.float64)
    foreground = volume > 0
    if not foreground.any():
        return np.zeros_like(volume)
    low, high = np.percentile(volume[foreground], [lower_pct, upper_pct])
    # Guard against a constant foreground (high == low) to avoid dividing by zero.
    scale = 255.0 / max(high - low, np.finfo(np.float64).eps)
    rescaled = (np.clip(volume, low, high) - low) * scale
    rescaled[~foreground] = 0.0
    return np.round(rescaled)


# Normalized copy for the model; `image` itself stays raw for plotting.
image_norm = normalize_mri_intensity(image)

imgs, pad_width, padded_size, valid_axis = process_input(image_norm, 512)
imgs = imgs.to(device).int()

# Model input: add a batch dimension to the slice stack; one text prompt.
input_tensor = {
    "image": imgs.unsqueeze(0),
    "text": [text_prompt],
}

print(f"Normalized intensity range: [{image_norm.min():.0f}, {image_norm.max():.0f}]")
print(f"Preprocessed image shape: {imgs.shape}")
print(f"Pad width: {pad_width}")
print(f"Padded size: {padded_size}")
print(f"Valid axis: {valid_axis}")

Preprocessed image shape: torch.Size([240, 512, 512])
Pad width: [[0, 0], [0, 0], [42, 43]]
Padded size: 240
Valid axis: 0


In [5]:
# Run text-prompted inference over the whole slice stack without autograd.
# slice_batch_size=4 presumably bounds how many slices are processed at once.
with torch.no_grad():
    output = model(input_tensor, mode="eval", slice_batch_size=4)

# Per-slice mask predictions (printed as [1, n_slices, 128, 128] in this run).
mask_preds = output["predictions"]["pred_gmasks"]
print(f"Raw mask predictions shape: {mask_preds.shape}")

# Upsample to the in-plane size produced by process_input. Taking it from `imgs`
# keeps this cell in sync with the target size passed to process_input instead
# of repeating the literal 512. This cell re-derives mask_preds from `output`,
# so it is safe to re-run.
target_hw = tuple(imgs.shape[-2:])
mask_preds = F.interpolate(
    mask_preds,
    size=target_hw,
    mode="bicubic",
    align_corners=False,
    antialias=True,
)

print(f"Interpolated mask shape: {mask_preds.shape}")

Raw mask predictions shape: torch.Size([1, 240, 128, 128])
Interpolated mask shape: torch.Size([1, 240, 512, 512])


In [None]:
# Postprocess masks and map them back to the original volume geometry.
# Each stage gets its own name instead of overwriting `mask_preds`, so re-running
# this cell does not apply postprocess/merge a second time to its own output.
mask_post = postprocess(mask_preds, output["predictions"]["object_existence"])
print(f"After postprocess: {mask_post.shape}")

# A single text prompt was given, so there is one class; it is assigned id 1.
# NOTE(review): presumably merge_multiclass_masks labels the i-th prompt's mask
# with ids[i] — confirm.
ids = [1]
mask_merged = merge_multiclass_masks(mask_post, ids)
print(f"After merge: {mask_merged.shape}")

# Undo the padding/resizing from process_input to return to original dimensions.
final_mask = process_output(mask_merged, pad_width, padded_size, valid_axis)
print(f"Final mask shape: {final_mask.shape}")
print(f"Original image shape: {image.shape}")

# The visualization indexes `image` and `final_mask` with the same slice indices,
# so their shapes must agree.
assert tuple(final_mask.shape) == tuple(image.shape), (
    f"final mask shape {tuple(final_mask.shape)} != image shape {tuple(image.shape)}"
)

In [None]:
# Visualize results - show the middle slice of each orthogonal view.
# Top row: raw FLAIR; bottom row: the predicted mask overlaid on the same slice.
slice_idx_axial = image.shape[2] // 2
slice_idx_coronal = image.shape[1] // 2
slice_idx_sagittal = image.shape[0] // 2

# (view name, axis that is sliced, slice index) — one column per view.
views = [
    ("Axial", 2, slice_idx_axial),
    ("Coronal", 1, slice_idx_coronal),
    ("Sagittal", 0, slice_idx_sagittal),
]


def _slice_along(volume, axis, index):
    """Return the 2D slice of a 3D array at ``index`` along ``axis``."""
    selector = [slice(None)] * 3
    selector[axis] = index
    return volume[tuple(selector)]


# Hide background (0) voxels so the semi-transparent overlay does not tint the
# whole slice, and fix the colour range so label values map to the same colour
# in every panel (label 1 stays at the top of the colormap, as before).
mask_vmax = max(1, int(np.max(final_mask)))

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for col, (view_name, axis, slice_idx) in enumerate(views):
    image_slice = _slice_along(image, axis, slice_idx)
    mask_slice = np.ma.masked_equal(
        np.asarray(_slice_along(final_mask, axis, slice_idx)), 0
    )

    # Original image
    axes[0, col].imshow(image_slice, cmap='gray')
    axes[0, col].set_title(f'FLAIR - {view_name} (slice {slice_idx})')
    axes[0, col].axis('off')

    # Prediction overlaid
    axes[1, col].imshow(image_slice, cmap='gray')
    axes[1, col].imshow(mask_slice, cmap='jet', alpha=0.5, vmin=0, vmax=mask_vmax)
    axes[1, col].set_title(f'Prediction - {view_name}')
    axes[1, col].axis('off')

plt.tight_layout()
plt.show()

# Compute the label histogram once and reuse it for both lines.
mask_values, mask_counts = np.unique(final_mask, return_counts=True)
print("\nPrediction statistics:")
print(f"Unique values in mask: {mask_values}")
print(f"Mask value counts: {(mask_values, mask_counts)}")