In [None]:
import os
import numpy as np 
import matplotlib.pyplot as plt 
from skimage import morphology as mp

# Input directories. Mask arrays and their image slices share file names.
# The defaults point at the original server layout; set the environment
# variables to run the notebook on another machine without editing code.
MASK_DIR = os.environ.get('BODYCOMP_MASK_DIR', '/data/server2/BodyCompMRL/numpy2D/masks/')
IMG_DIR = os.environ.get('BODYCOMP_IMG_DIR', '/data/server2/BodyCompMRL/numpy2D/slices/')

def create_bowel_mask(mask_sm, mask_scfat, mask_body, sm_radius=2, scfat_radius=5):
    """
    Create bowel mask by subtracting dilated muscle and fat masks from body mask.

    A pixel is labelled bowel when it lies inside ``mask_body`` and outside
    both the dilated skeletal-muscle mask and the dilated subcutaneous-fat
    mask. Dilating first excludes a margin around muscle and fat, so pixels
    bordering them are not labelled bowel.

    Args:
        mask_sm: Skeletal muscle mask (non-zero = foreground).
        mask_scfat: Subcutaneous fat mask (non-zero = foreground), same shape
            as mask_sm.
        mask_body: Body mask (non-zero = foreground), same shape as mask_sm.
        sm_radius: Radius in pixels of the disk used to dilate mask_sm.
        scfat_radius: Radius in pixels of the disk used to dilate mask_scfat.

    Returns:
        Integer array of 0/1 values with the same shape as the inputs.
    """
    # Grow muscle and fat so their borders are excluded from the bowel region
    sm_dilated = mp.binary_dilation(mask_sm, mp.disk(sm_radius)) > 0
    scfat_dilated = mp.binary_dilation(mask_scfat, mp.disk(scfat_radius)) > 0

    # Set difference expressed directly: body AND NOT muscle AND NOT fat.
    # Any non-zero body value counts as foreground.
    mask_bowel = (np.asarray(mask_body) > 0) & ~sm_dilated & ~scfat_dilated

    return np.where(mask_bowel, 1, 0)

def visualize_masks(img, mask, mask_names):
    """
    Display the image alone and with each mask channel overlaid on it.

    The first panel shows the plain image. Each following panel overlays one
    mask channel (in channel order) and is titled with the matching name.

    Args:
        img: 2D image, drawn in grayscale.
        mask: 3D mask array laid out as (H, W, channels).
        mask_names: Names for the channels to overlay; one panel per name,
            name i is used for channel i.

    Raises:
        ValueError: If mask_names has more entries than mask has channels.
    """
    n_overlays = len(mask_names)
    if n_overlays > mask.shape[-1]:
        raise ValueError(
            f'{n_overlays} mask names given but mask has only '
            f'{mask.shape[-1]} channels'
        )

    n_panels = 1 + n_overlays
    # squeeze=False always returns a 2D array of axes, even for a single panel
    _, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
    axes = axes[0]

    for ax in axes:
        ax.axis('off')
        ax.imshow(img, cmap='gray')

    # Display original image
    axes[0].set_title('Original Image')

    # Overlay masks
    for i, (ax, name) in enumerate(zip(axes[1:], mask_names)):
        # Hide zero pixels so only the mask foreground is tinted (otherwise
        # 'Reds' at 0 paints a near-white veil over the whole image). Fixed
        # vmin/vmax keep foreground red even when no zeros remain to
        # anchor the colour scale.
        overlay = np.ma.masked_equal(mask[:, :, i], 0)
        ax.imshow(overlay, cmap='Reds', alpha=0.7, vmin=0, vmax=1)
        ax.set_title(name)

    plt.tight_layout()
    plt.show()


In [None]:
# Process mask files: rebuild the background channel from the body outline,
# append a derived bowel channel, and display the result for visual checking.
mask_names = ['Background', 'Skeletal Muscle', 'Subcutaneous Fat', 'Bowel']

# Set to True to save the (H, W, 4) masks into a sibling "masks_with_bowel"
# directory next to MASK_DIR (created if missing).
SAVE_OUTPUT = False
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.normpath(MASK_DIR)), 'masks_with_bowel')

# Sorted for a reproducible order; non-.npy entries (subdirectories, stray
# files) are ignored.
for mask_filename in sorted(os.listdir(MASK_DIR)):
    if not mask_filename.endswith('.npy'):
        continue
    print(f'Processing: {mask_filename}')

    # Load data (mask first so skipped files don't cost an image read)
    mask_path = os.path.join(MASK_DIR, mask_filename)
    img_path = os.path.join(IMG_DIR, mask_filename)

    mask = np.load(mask_path)

    # Skip if already processed (has 4 channels)
    if mask.ndim == 3 and mask.shape[0] == 4:
        print(f'Skipping {mask_filename} - already has bowel mask')
        continue
    # The indexing below assumes channels-first (3, H, W):
    # 0 = background, 1 = skeletal muscle, 2 = subcutaneous fat
    if mask.ndim != 3 or mask.shape[0] != 3:
        print(f'Skipping {mask_filename} - unexpected mask shape {mask.shape}')
        continue
    if not os.path.isfile(img_path):
        print(f'Skipping {mask_filename} - no matching image at {img_path}')
        continue

    img = np.load(img_path)

    # Extract individual masks
    mask_sm = mask[1, :, :]
    mask_scfat = mask[2, :, :]

    # Create body mask from subcutaneous fat convex hull
    mask_body = mp.convex_hull_image(mask_scfat)

    # Background = everything outside the body outline. logical_not avoids
    # bool/int subtraction; assignment casts to the mask's dtype.
    mask[0, :, :] = np.logical_not(mask_body)

    # Erode body mask for bowel calculation
    mask_body_eroded = mp.binary_erosion(mask_body, mp.disk(10))

    # Generate bowel mask
    mask_bowel = create_bowel_mask(mask_sm, mask_scfat, mask_body_eroded)

    # Append bowel as channel 3, cast to the mask dtype so concatenation
    # does not silently upcast the whole stack to int64.
    mask_bowel = mask_bowel.astype(mask.dtype)[np.newaxis, :, :]
    mask = np.concatenate((mask, mask_bowel), axis=0)

    # Reshape to (H, W, channels) format
    mask = np.moveaxis(mask, 0, -1)
    print(f'Final mask shape: {mask.shape}')

    # Visualize results
    visualize_masks(img, mask, mask_names)

    if SAVE_OUTPUT:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        output_path = os.path.join(OUTPUT_DIR, mask_filename)
        np.save(output_path, mask)
        print(f'Saved to: {output_path}')

print('Processing complete.')