In [1]:
import os
# optimizing memory allocation to reduce fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
print("‚úÖ Memory fragmentation rules applied.")

‚úÖ Memory fragmentation rules applied.


In [1]:
import os
import torch

# 1. Help PyTorch manage fragmented memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# 2. Clear any lingering cache
torch.cuda.empty_cache()

print(f"‚úÖ Memory settings applied. Free memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f} GB")

  from .autonotebook import tqdm as notebook_tqdm


‚úÖ Memory settings applied. Free memory: 3.29 GB


In [3]:
import os
import sys

# Get the path to your current environment
conda_prefix = sys.prefix
lib_path = os.path.join(conda_prefix, 'lib')

# Force this path to the front of the line
os.environ['LD_LIBRARY_PATH'] = f"{lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"

print(f"‚úÖ Forced Library Path: {lib_path}")

‚úÖ Forced Library Path: /home/fetalusr1/miniconda3/envs/fetal_project/lib


In [4]:
import torch

try:
    # Try a simple calculation on the GPU
    x = torch.tensor([1.0, 2.0]).cuda()
    y = torch.tensor([3.0, 4.0]).cuda()
    z = x * y
    print("--------------------------------------------------")
    print(f"üéâ SUCCESS: GPU Math works! Result: {z.cpu().numpy()}")
    print("--------------------------------------------------")
except RuntimeError as e:
    print("--------------------------------------------------")
    print(f"‚ùå FAILURE: {e}")
    print("--------------------------------------------------")

--------------------------------------------------
üéâ SUCCESS: GPU Math works! Result: [3. 8.]
--------------------------------------------------


In [2]:
from PIL import Image
import torch
import numpy as np
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
import matplotlib.pyplot as plt
from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats
from inference_utils.processing_utils import process_intensity_image
from inference_utils.processing_utils import read_nifti
import nibabel as nib
import pandas as pd
import SimpleITK as sitk
from skimage.measure import regionprops, label
from skimage.transform import resize


out_probs = []
predicted_masks = []

  from .autonotebook import tqdm as notebook_tqdm
Authorization required, but no authorization protocol specified

Authorization required, but no authorization protocol specified

Authorization required, but no authorization protocol specified

Authorization required, but no authorization protocol specified



Deformable Transformer Encoder is not available.




## Loading the Finetuned BiomedParse model

In [3]:
# Build model config
opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
opt = init_distributed(opt)

# Load model from pretrained weights
finetuned_pth = '/home/fetalusr1/Fetal-Head-Segmentation-master/model_state_dict.pt' # Replace with the path to your finetuned checkpoint

model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained=finetuned_pth).eval().cuda()

with torch.no_grad():
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)

$UNUSED$ criterion.empty_weight, Ckpt Shape: torch.Size([17])


## Utilities

In [4]:
def get_segmentation_masks(original_image, segmentation_masks, texts, rotate=0):
    ''' Plot a list of segmentation mask over an image showing only the segmented region.
    '''
    original_image = original_image[:, :, :3]

    segmented_images = []

    for i, mask in enumerate(segmentation_masks):
        segmented_image = original_image.copy()
        segmented_image[mask <= 0.5] = [0, 0, 0]
        segmented_images.append(segmented_image)
        
    return segmented_images

In [5]:
def inference_nifti(file_path, text_prompts, is_CT, slice_idx, site=None, HW_index=(0, 1), channel_idx=None, rotate=0):

    image = read_nifti(file_path, is_CT, slice_idx, site=site, HW_index=HW_index, channel_idx=channel_idx)
    
    pred_mask,out_prob = interactive_infer_image(model, Image.fromarray(image), text_prompts)
    predicted_masks.append(pred_mask)
    segmented_images = get_segmentation_masks(image, pred_mask, text_prompts, rotate=rotate)
    out_probs.append(out_prob)
    
    return image, pred_mask, segmented_images

### Post-processing Utility

In [6]:
def process_predicted_volume(volume_data, threshold_factor=0.35, output_prefix='processed'):
    """
    Process the predicted volume to filter based on ellipse measurements.
    """
    data = volume_data
    print(f"Processing volume with shape: {data.shape}")
    
    # Calculate measurements for all slices
    results = []
    z_0 = data.shape[2] // 2  # Reference slice (middle slice)
    
    print(f"Reference slice: {z_0}")
    
    for i in range(data.shape[2]):
        slice_data = data[:, :, i]
        
        # Skip empty slices
        if np.sum(slice_data) == 0:
            continue
            
        # Binarize the slice
        slice_bin = np.where(slice_data > 0, 1, 0).astype(np.uint8)
        
        # Fill holes
        slice_bin_filled = sitk.BinaryFillhole(sitk.GetImageFromArray(slice_bin))
        slice_bin_filled = sitk.GetArrayFromImage(slice_bin_filled)
        
        # Get region properties
        labeled_image = label(slice_bin_filled)
        props = regionprops(labeled_image)
        
        for prop in props:
            results.append({
                'slice_index': i,
                'major_axis_length': prop.major_axis_length,
                'minor_axis_length': prop.minor_axis_length,
                'centroid_x': prop.centroid[1],
                'centroid_y': prop.centroid[0],
                'orientation': prop.orientation,
                'area': prop.area
            })
    
    # Create DataFrame
    df_results = pd.DataFrame(results)
    print(f"Found {len(results)} regions across {len(df_results['slice_index'].unique())} slices")
    
    # Get reference slice measurements for filtering
    standard_slice_data = df_results[df_results['slice_index'] == z_0]
    
    if standard_slice_data.empty:
        print(f"Warning: No data found in reference slice {z_0}")
        # Use overall median as fallback
        major_axis_length_std = df_results['major_axis_length'].median()
        minor_axis_length_std = df_results['minor_axis_length'].median()
        centroid_x_std = df_results['centroid_x'].median()
        centroid_y_std = df_results['centroid_y'].median()
    else:
        major_axis_length_std = standard_slice_data['major_axis_length'].values[0]
        minor_axis_length_std = standard_slice_data['minor_axis_length'].values[0]
        centroid_x_std = standard_slice_data['centroid_x'].values[0]
        centroid_y_std = standard_slice_data['centroid_y'].values[0]
    
    # Define thresholds
    major_axis_length_threshold = major_axis_length_std * (1 - threshold_factor)
    minor_axis_length_threshold = minor_axis_length_std * (1 - threshold_factor)
    
    print(f"Reference measurements - Major: {major_axis_length_std:.2f}, Minor: {minor_axis_length_std:.2f}")
    print(f"Filtering thresholds - Major: {major_axis_length_threshold:.2f}, Minor: {minor_axis_length_threshold:.2f}")
    
    # Filter based on thresholds
    filtered_df = df_results[
        (df_results['major_axis_length'] >= major_axis_length_threshold) &
        (df_results['minor_axis_length'] >= minor_axis_length_threshold)
    ]
    
    print(f"After filtering: {len(filtered_df)} regions in {len(filtered_df['slice_index'].unique())} slices")
    
    # In filtered_df, in case of repeated slices, keep the one with maximum major axis length
    filtered_df = filtered_df.loc[filtered_df.groupby('slice_index')['major_axis_length'].idxmax()]
    
    # Create filtered volume
    filtered_slices = filtered_df['slice_index'].unique()
    filtered_volume = np.zeros_like(data)
    
    for slice_idx in range(data.shape[2]):
        if slice_idx in filtered_slices:
            filtered_volume[:, :, slice_idx] = data[:, :, slice_idx]
    
    return filtered_volume, filtered_df

### Interpolation Utility

In [7]:
def interpolate_blank_slices(image_path, processed_volume, blank_slices, predicted_masks, delta=1):
    """
    Interpolate blank slices in the processed volume using the previous slice.
    """
    vol_data = nib.load(image_path).get_fdata()
    central_slice = vol_data.shape[2] // 2
    
    for slice_idx in blank_slices:
        # Ensure we have a valid previous slice
        prev_slice_idx = slice_idx - delta
        if prev_slice_idx < 0 or prev_slice_idx >= len(predicted_masks):
            continue
            
        # Get the previous mask
        prev_mask = predicted_masks[prev_slice_idx][0]  # Get first mask from the list
        
        #update predicted_masks
        predicted_masks[slice_idx] = [prev_mask.copy()]  # Store the previous mask
        # Ensure the previous mask is not empty
        if np.sum(prev_mask) == 0:
            print(f"Warning: Previous mask for slice {prev_slice_idx} is empty. Skipping interpolation for slice {slice_idx}.")
            continue
        # Scale the mask based on position relative to center
        if slice_idx < central_slice: 
            # Increase the mask size by 0.5%
            new_mask = prev_mask * 1.005
        else:
            # Decrease the mask size by 0.5%
            new_mask = prev_mask * 0.995
        
        # Read the original image for this slice
        image = read_nifti(image_path, is_CT=False, slice_idx=slice_idx, site=None, HW_index=(0, 1), channel_idx=None)
        
        # Get the segmented image
        new_segmented_image = get_segmentation_masks(image, [new_mask], ['fetal head'], rotate=0)[0]
        
        # Convert RGB segmentation to grayscale if needed
        if len(new_segmented_image.shape) == 3:
            gray_mask = np.mean(new_segmented_image, axis=2)
        else:
            gray_mask = new_segmented_image
        
        # Resize to match volume dimensions and store
        from skimage.transform import resize
        processed_volume[:, :, slice_idx] = resize(gray_mask, (vol_data.shape[0], vol_data.shape[1]), preserve_range=True)
    
    return processed_volume

## Working

In [8]:
image_path = '/home/fetalusr1/Fetal-Head-Segmentation-master/28GW.nii'
text_prompt = ['fetal head']
vol = nib.load(image_path)
vol_data = vol.get_fdata()
vol_data.shape

(160, 160, 160)

In [13]:
from torch.cuda.amp import autocast
import torch
import numpy as np
import nibabel as nib
from skimage.transform import resize
from PIL import Image

# --- 1. DEFINE A ROBUST CUSTOM READER ---
def custom_read_nifti(file_path, slice_idx):
    """
    Safely reads a slice from a NIfTI file, handling 3D/4D shapes automatically.
    Replaces the buggy read_nifti from the library.
    """
    # Load volume
    vol = nib.load(file_path).get_fdata()
    
    # Handle dimensions (3D vs 4D)
    if vol.ndim == 4:
        # If 4D, take the first channel (standard for medical data)
        slice_data = vol[:, :, slice_idx, 0]
    else:
        # If 3D, just take the slice
        slice_data = vol[:, :, slice_idx]

    # Normalize Intensity (Robust Percentile Scaling)
    # This prevents black/white outliers from ruining the image
    p_low = np.percentile(slice_data, 0.5)
    p_high = np.percentile(slice_data, 99.5)
    slice_data = np.clip(slice_data, p_low, p_high)
    
    # Scale to 0-255 (uint8) for the model
    if slice_data.max() > slice_data.min():
        slice_data = (slice_data - slice_data.min()) / (slice_data.max() - slice_data.min())
    slice_data = (slice_data * 255).astype(np.uint8)
    
    # Convert grayscale to RGB (H, W, 3) because the model expects color input
    image_rgb = np.stack([slice_data, slice_data, slice_data], axis=-1)
    
    return image_rgb

# --- 2. RUN INFERENCE USING THE CUSTOM READER ---

# Initialize volume
vol_data = nib.load(image_path).get_fdata() # Just to get the shape
pred_volume = np.zeros((vol_data.shape[0], vol_data.shape[1], vol_data.shape[2]))

print(f"üöÄ Starting GPU Inference on {vol_data.shape[2]} slices...")
print(f"   (Using Custom Reader + Autocast)")

for slice_idx in range(vol_data.shape[2]):
    
    # Clean GPU memory
    torch.cuda.empty_cache()

    # A. Use our NEW reader (Bypassing the library error)
    image = custom_read_nifti(image_path, slice_idx)
    
    # B. Run Model (With mixed precision to fit in memory)
    with autocast():
        # Note: We call interactive_infer_image directly since we already read the image
        pred_mask, out_prob = interactive_infer_image(model, Image.fromarray(image), text_prompt)
    
    # C. Process Mask
    # The model output is often [1, H, W], we need [H, W]
    if len(pred_mask.shape) == 3:
        mask_2d = pred_mask[0]
    else:
        mask_2d = pred_mask

    # D. Save to volume
    pred_volume[:, :, slice_idx] = resize(mask_2d, (vol_data.shape[0], vol_data.shape[1]), preserve_range=True)

    # Progress Print
    if slice_idx % 10 == 0:
        print(f"   ‚úÖ Processed slice {slice_idx}/{vol_data.shape[2]}")

print("üéâ Inference Complete! Starting post-processing...")

# --- 3. RUN POST-PROCESSING ---
processed_volume, filtered_measurements = process_predicted_volume(
    pred_volume, 
    threshold_factor=0.4, 
    output_prefix='3_2'
)

print(f"Original volume had {np.sum(pred_volume > 0)} non-zero voxels")
print(f"Processed volume has {np.sum(processed_volume > 0)} non-zero voxels")

üöÄ Starting GPU Inference on 160 slices...
   (Using Custom Reader + Autocast)
   ‚úÖ Processed slice 0/160
   ‚úÖ Processed slice 10/160
   ‚úÖ Processed slice 20/160
   ‚úÖ Processed slice 30/160
   ‚úÖ Processed slice 40/160
   ‚úÖ Processed slice 50/160
   ‚úÖ Processed slice 60/160
   ‚úÖ Processed slice 70/160
   ‚úÖ Processed slice 80/160
   ‚úÖ Processed slice 90/160
   ‚úÖ Processed slice 100/160
   ‚úÖ Processed slice 110/160
   ‚úÖ Processed slice 120/160
   ‚úÖ Processed slice 130/160
   ‚úÖ Processed slice 140/160
   ‚úÖ Processed slice 150/160
üéâ Inference Complete! Starting post-processing...
Processing volume with shape: (160, 160, 160)
Reference slice: 80
Found 1496 regions across 160 slices
Reference measurements - Major: 0.00, Minor: 0.00
Filtering thresholds - Major: 0.00, Minor: 0.00
After filtering: 1496 regions in 160 slices
Original volume had 1798198 non-zero voxels
Processed volume has 1798198 non-zero voxels


In [14]:
#Get the first slice that survived filtering
first_filtered_slice = min(filtered_measurements['slice_index'].unique())
last_filtered_slice = max(filtered_measurements['slice_index'].unique())
print(f"First filtered slice: {first_filtered_slice}")
print(f"Last filtered slice: {last_filtered_slice}")
#from the filtered slice to the center slice, get all the slices which are blank
blank_slices = []
for slice_idx in range(first_filtered_slice, last_filtered_slice + 1):
    if np.sum(processed_volume[:, :, slice_idx]) == 0:
        blank_slices.append(slice_idx)
# Print the blank slices
print(f"Blank slices from {first_filtered_slice} to {vol_data.shape[2]-1}: {blank_slices}")

First filtered slice: 0
Last filtered slice: 159
Blank slices from 0 to 159: []


## Save Results

In [15]:
import os

# Create results directories if they don't exist
os.makedirs('./results', exist_ok=True)
os.makedirs('./FilteredRes', exist_ok=True)

# Load original NIfTI for header info
original_nii = nib.load(image_path)

# Save raw prediction
pred_nii = nib.Nifti1Image(pred_volume, original_nii.affine, original_nii.header)
raw_filename = f'./results/segmentation_result_3_2_raw.nii.gz'
nib.save(pred_nii, raw_filename)
print(f"Raw prediction saved to {raw_filename}")

# Save processed prediction
processed_nii = nib.Nifti1Image(processed_volume, original_nii.affine, original_nii.header)
processed_filename = f'./FilteredRes/segmentation_result_3_2_filtered.nii.gz'
nib.save(processed_nii, processed_filename)
print(f"Processed prediction saved to {processed_filename}")

interpolated_volume = interpolate_blank_slices(image_path, processed_volume, blank_slices, predicted_masks, delta=1)
# Save interpolated prediction
interpolated_nii = nib.Nifti1Image(interpolated_volume, original_nii.affine, original_nii.header)
interpolated_filename = f'./FilteredRes/segmentation_result_3_2_interpolated.nii.gz'
nib.save(interpolated_nii, interpolated_filename)
print(f"Interpolated prediction saved to {interpolated_filename}")


Raw prediction saved to ./results/segmentation_result_3_2_raw.nii.gz
Processed prediction saved to ./FilteredRes/segmentation_result_3_2_filtered.nii.gz
Interpolated prediction saved to ./FilteredRes/segmentation_result_3_2_interpolated.nii.gz
