In [1]:
import os
import cv2
import logging
import random
import pandas as pd
import numpy as np
from glob import glob
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from skimage import morphology
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torchvision.models.segmentation as models
from typing import Dict, List
import torch.multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from scipy.ndimage import binary_fill_holes
from skimage.morphology import remove_small_objects

# Runtime configuration for this notebook.
# Seed Python's `random` module so any sampling done with it in later cells repeats.
SEED = 10
random.seed(SEED)

# Disable cuDNN autotuning and allow reduced-precision (e.g. TF32) float32 matmuls
# on hardware that supports them.
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision('high')

  check_for_updates()


In [1]:
import pandas as pd 

In [2]:
import os 
# Move to the repository root so `src.*` imports and relative `data/...` paths resolve.
# Guarded on the presence of `src/` so re-running this cell does not keep climbing upward.
if not os.path.isdir('src'):
    os.chdir('../')
os.getcwd()

'/home/akshar/Omdena/TB Detection/github/Omdena-PuneIndiaChapter-EarlyDetectionTuberculosis'

In [3]:
from src.constants import CLF_DATA_DIR, PROC_DATA_DIR

In [6]:
# Keep only the segmented-lung image path and its label from the processing log.
df = pd.read_csv('data/processed/processed_dataset.csv').loc[:, ['segmented_lung_path', 'label']]
df.head()

Unnamed: 0,segmented_lung_path,label
0,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
1,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
2,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
3,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
4,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive


In [11]:
# Strip everything up to and including `/data/processed/` so paths are relative to the
# processed-data directory instead of embedding one machine's absolute home directory.
# Paths that are already relative contain no match, so re-running this cell is a no-op.
df['segmented_lung_path'] = df['segmented_lung_path'].str.replace(r'^.*/data/processed/', '', regex=True)

df.head()

Unnamed: 0,segmented_lung_path,label
0,segmented_lung/tb_positive/a0e8ff5bff05d0c5644...,tb_positive
1,segmented_lung/tb_positive/90a02093a5d50b28fb1...,tb_positive
2,segmented_lung/tb_positive/83cf58db5b1ef9d68f0...,tb_positive
3,segmented_lung/tb_positive/9d2cd99f0580bfc04d5...,tb_positive
4,segmented_lung/tb_positive/71f64b8dbec4588b953...,tb_positive


In [12]:
# Remove the `segmented_lung/` folder component so each path reads `<label>/<file>`.
lung_paths = df['segmented_lung_path']
df['segmented_lung_path'] = lung_paths.str.replace("segmented_lung/", "", regex=False)
df.head()

Unnamed: 0,segmented_lung_path,label
0,tb_positive/a0e8ff5bff05d0c564450e92346d3fc9_A...,tb_positive
1,tb_positive/90a02093a5d50b28fb1f21fa371c849f_A...,tb_positive
2,tb_positive/83cf58db5b1ef9d68f0b7dd8db4c624f_A...,tb_positive
3,tb_positive/9d2cd99f0580bfc04d57998257a141e4_A...,tb_positive
4,tb_positive/71f64b8dbec4588b9533ed0c02f1f2d8_A...,tb_positive


In [None]:
# Segmented lung images are written under `data/processed/segmented_lung/<label>/` (relative to the repo root).

In [10]:
# Preview the first rows of the path/label table.
df.head(n=5)

Unnamed: 0,segmented_lung_path,label
0,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
1,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
2,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
3,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive
4,/home/akshar/Omdena/TB Detection/github/Omdena...,tb_positive


In [4]:
# setting up device 
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

cuda


In [5]:
def grayscale_truncation(image, **kwargs):
    """
    Clip an image to the intensity range found in its central region.

    The minimum and maximum are taken from the central crop covering the middle half of
    each spatial axis (rows ``H//4 : 3*H//4``, columns ``W//4 : 3*W//4``). The whole image
    is then clipped to that ``[min, max]`` range.

    Parameters:
    - image (numpy.ndarray): Image of shape (H, W) or (H, W, C). For multi-channel input the
      range is computed over all channels of the central crop.
    - **kwargs: Ignored; accepted so the function can be used as an albumentations ``A.Lambda`` target.

    Returns:
    - numpy.ndarray: The clipped image, with the same shape and dtype as ``image``.
      An empty input is returned as an empty copy.
    """
    image = np.asarray(image)
    if image.size == 0:
        return image.copy()

    height, width = image.shape[:2]

    # Define central region (the middle half along each spatial axis)
    central_region = image[height // 4:3 * height // 4, width // 4:3 * width // 4]
    # A side of length 1 gives an empty crop (e.g. 1//4 == 3*1//4 == 0); fall back to the
    # whole image so min/max are defined instead of raising on an empty array.
    if central_region.size == 0:
        central_region = image

    min_val, max_val = central_region.min(), central_region.max()
    return np.clip(image, min_val, max_val)

In [6]:
def convert_to_rgb(image, **kwargs):
    """
    Replicate an image into three identical channels along a new last axis.

    An (H, W) input becomes (H, W, 3). Extra keyword arguments are ignored; they are
    accepted for use with albumentations ``A.Lambda``.
    """
    single_channel = np.asarray(image)[..., np.newaxis]
    return np.repeat(single_channel, 3, axis=-1)

In [7]:
def grayscale_inversion(image, **kwargs):
    """
    Invert pixel intensities as ``255 - pixel``.

    Args:
        image (numpy.ndarray): Grayscale image of shape (H, W) or (H, W, 1). A trailing
            singleton channel is dropped first; any other shape is inverted as-is.
            The formula assumes an 8-bit (0-255) intensity range.
        **kwargs: Ignored; accepted for use with albumentations ``A.Lambda``.

    Returns:
        numpy.ndarray: The inverted image. (H, W, 1) input is returned as (H, W).
    """
    has_singleton_channel = image.ndim == 3 and image.shape[-1] == 1
    source = image[..., 0] if has_singleton_channel else image
    return 255 - source

In [8]:
def load_model(state_dict_path : Path,
               device : torch.device):
    """
    Build a DeepLabV3-MobileNetV3-Large with a 2-channel head and load segmentation weights.

    Args:
        state_dict_path (Path | str): Path to the saved ``state_dict`` file. Strings are
            converted to ``Path``.
        device (torch.device | str): Device the weights are mapped to and the model is moved
            onto (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        torch.nn.Module: The model with weights loaded, placed on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If ``state_dict_path`` is not an existing file.
        RuntimeError: If ``load_state_dict`` (or ``torch.load``) raises a RuntimeError,
            e.g. because of missing/unexpected keys or shape mismatches.
    """
    state_dict_path = Path(state_dict_path)
    if not state_dict_path.is_file():
        raise FileNotFoundError(f"State dictionary file not found: {state_dict_path}")

    # Architecture only; all weights come from the checkpoint loaded below.
    model = models.deeplabv3_mobilenet_v3_large(weights=None)
    # Backbone frozen as in the training setup (irrelevant for inference, kept for parity).
    for param in model.backbone.parameters():
        param.requires_grad = False
    # Replace the final 1x1 conv so the head emits 2 channels, matching the checkpoint.
    model.classifier[4] = nn.Conv2d(
        in_channels=256,
        out_channels=2,
        kernel_size=(1, 1)
    )

    try:
        # weights_only=True refuses arbitrary pickled objects; a plain tensor state_dict loads fine.
        state_dict = torch.load(state_dict_path, map_location=device, weights_only=True)
        # Checkpoints saved from a torch.compile'd model carry an '_orig_mod.' key prefix.
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        # No aux head is requested when building the model above, so drop aux_classifier
        # weights to keep the strict load from failing on unexpected keys.
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("aux_classifier.")}
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        raise RuntimeError(f"Error loading state dictionary: {e}") from e

    model = model.to(device)
    model.eval()
    return model


In [9]:
def create_segmented_lung(image, mask):
    """
    Return a copy of ``image`` with every pixel where ``mask == 0`` set to 0 (black background).

    ``mask`` is used as a boolean index into ``image``, so its shape must match the image's
    leading dimensions. The input image itself is left unmodified.
    """
    background = mask == 0
    result = image.copy()
    result[background] = 0
    return result

In [10]:
# NOTE(review): SEG_MODEL_PATH is not brought in by the `from src.constants import ...` cell
# (it imports only CLF_DATA_DIR and PROC_DATA_DIR), so a fresh kernel raised NameError here.
# Assumes it is defined in src.constants next to the other paths — confirm.
from src.constants import SEG_MODEL_PATH

model = load_model(SEG_MODEL_PATH, device)

_______________________________________________

**TESTING**

#### **Pre-Processing for Segmentation**

In [23]:
def preprocess_image(image_path, size=(520, 520)):
    """
    Read an image as grayscale and turn it into a normalized single-image batch for the model.

    Steps: resize to ``size`` (bilinear), clip to the central-region intensity range,
    invert intensities, replicate to 3 channels, apply CLAHE, normalize with the given
    mean/std, and convert to a tensor.

    Args:
        image_path (str | Path): Path of the image file to read.
        size (tuple[int, int]): Target ``(width, height)`` passed to ``cv2.resize``.
            Defaults to ``(520, 520)``.

    Returns:
        torch.Tensor: The transformed image with a leading batch axis of size 1
        (channels-first after ``ToTensorV2``).

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable, or unsupported format).
    """
    image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising; without this check
    # the failure surfaces later as an opaque cv2.resize assertion.
    if image is None:
        raise OSError(f"Could not read image: {image_path}")

    image_resized = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    transform = A.Compose([
        A.Lambda(image=grayscale_truncation, p=1.0),
        A.Lambda(image=grayscale_inversion, p=1.0),
        A.Lambda(image=convert_to_rgb, p=1.0),
        A.CLAHE(clip_limit=5, tile_grid_size=(8, 8), p=1.0),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    transformed = transform(image=image_resized)
    return transformed['image'].unsqueeze(0)

##### **Post-Processing for Segmented Masks** 

In [22]:
def apply_gaussian_blur(mask: np.ndarray, kernel_size: tuple) -> np.ndarray:
    """
    Smooth ``mask`` with a Gaussian filter of the given ``(width, height)`` kernel size.

    ``sigmaX=0`` lets OpenCV derive the standard deviation from the kernel size.
    """
    blurred = cv2.GaussianBlur(mask, kernel_size, sigmaX=0)
    return blurred

In [21]:
def apply_threshold(mask: np.ndarray, threshold_value: float) -> np.ndarray:
    """
    Binarize ``mask`` relative to its own peak value.

    Pixels above ``threshold_value * max(mask)`` are set to ``max(mask)``; all others to 0
    (OpenCV ``THRESH_BINARY``).
    """
    peak = np.max(mask)
    cutoff = threshold_value * peak
    _retval, binary = cv2.threshold(mask, cutoff, peak, cv2.THRESH_BINARY)
    return binary

In [20]:


def post_process_mask(predicted_mask: np.ndarray,
                      config: Dict) -> np.ndarray:
    """
    Clean a predicted lung score map into a binary (0/1) lung mask.

    Pipeline: Gaussian blur -> threshold relative to the map's maximum -> hole filling ->
    small-object removal -> morphological opening -> erosion -> keep connected components
    of at least ``min_object_size`` pixels -> dilation -> final small-object removal.

    Parameters:
    - predicted_mask: 2D numpy array of per-pixel scores output by the model.
    - config: dict of optional settings; ``None`` is treated as an empty dict.
        - gaussian_kernel (tuple, default (5, 5)): Gaussian blur kernel size.
        - threshold_value (float, default 0.5): fraction of the map's maximum used as the threshold.
        - min_object_size (int, default 1000): minimum region size in pixels to keep.
        - opening_size (int, default 5): elliptical kernel size for morphological opening.
        - erosion_size (int, default 3): elliptical kernel size for erosion.
        - dilation_size (int, default 3): elliptical kernel size for dilation.

    Returns:
    - A uint8 array of the same shape with values 0/1; or
    - ``np.zeros_like(predicted_mask)`` when ``predicted_mask`` has no elements; or
    - ``None`` when nothing survives erosion or the final cleaning, so callers can skip the image.
    """
    # process_directory's default config is None; treat it as "use every default".
    config = config or {}
    gaussian_kernel = config.get('gaussian_kernel', (5, 5))
    threshold_value = config.get('threshold_value', 0.5)
    min_object_size = config.get('min_object_size', 1000)
    opening_size = config.get('opening_size', 5)
    erosion_size = config.get('erosion_size', 3)
    dilation_size = config.get('dilation_size', 3)

    if predicted_mask.size == 0:
        logging.warning("Predicted mask is empty, returning a zero mask.")
        return np.zeros_like(predicted_mask)

    # Step 1: Gaussian blur to smooth the raw scores.
    blurred_mask = apply_gaussian_blur(predicted_mask, gaussian_kernel)

    # Step 2: Threshold relative to the blurred map's maximum.
    binary_mask = apply_threshold(blurred_mask, threshold_value)

    # Step 3: Fill holes inside the foreground.
    filled_mask = binary_fill_holes(binary_mask).astype(np.uint8)

    # Step 4: Remove small foreground objects.
    cleaned_mask = remove_small_objects(filled_mask.astype(bool), min_size=min_object_size)

    # Step 5: Morphological opening to remove specks and smooth boundaries.
    opening_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (opening_size, opening_size))
    opened_mask = cv2.morphologyEx(cleaned_mask.astype(np.uint8), cv2.MORPH_OPEN, opening_kernel)

    # Step 6: Erosion to shrink the mask and break thin bridges between regions.
    erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erosion_size, erosion_size))
    eroded_mask = cv2.erode(opened_mask, erosion_kernel, iterations=1)
    eroded_mask = (eroded_mask > 0).astype(np.uint8)
    if not eroded_mask.any():
        print("Eroded mask is empty; skipping this mask.")
        return None

    # Step 7: Keep only connected components of at least min_object_size pixels. Areas come
    # from a single connectedComponentsWithStats pass instead of rescanning per component.
    _num_labels, labels, stats, _centroids = cv2.connectedComponentsWithStats(eroded_mask)
    areas = stats[1:, cv2.CC_STAT_AREA]  # row 0 is the background label
    kept_labels = np.flatnonzero(areas >= min_object_size) + 1
    largest_components_mask = np.isin(labels, kept_labels).astype(np.uint8)

    # Step 8: Dilation to recover lung area lost to erosion.
    dilation_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
    dilated_mask = cv2.dilate(largest_components_mask, dilation_kernel, iterations=1)

    # Step 9: Final cleaning - remove small objects once more.
    final_mask = remove_small_objects(dilated_mask.astype(bool), min_size=min_object_size)

    # An empty result would otherwise be saved downstream as an all-black "segmented lung";
    # report it the same way as an empty erosion result.
    if not final_mask.any():
        print("Final mask is empty after cleaning; skipping this mask.")
        return None

    # No intensity truncation here: the mask is already binary, and clipping it to its
    # central-region range could only blank it (empty centre) or fill it (full centre).
    return final_mask.astype(np.uint8)

#### **Post-Processing Configuration**

In [14]:
# Settings read by post_process_mask; min_object_size is raised above its 1000-pixel default.
p_config = {
    'gaussian_kernel': (5, 5),
    'threshold_value': 0.5,
    'min_object_size': 1500,
    'opening_size': 5,
    'erosion_size': 3,
    'dilation_size': 3,
}

### **Processed Image dataset creation**

In [24]:
_IMAGE_SUFFIXES = ('.jpg', '.png', '.jpeg')


def _list_image_files(class_dir: Path, suffixes=_IMAGE_SUFFIXES) -> List[Path]:
    """Return files directly in ``class_dir`` whose lower-cased suffix is in ``suffixes``, sorted by path."""
    if not class_dir.is_dir():
        return []
    return sorted(
        path for path in class_dir.iterdir()
        if path.is_file() and path.suffix.lower() in suffixes
    )


def _predict_lung_scores(model: torch.nn.Module, image_path: Path, device) -> np.ndarray:
    """Run the segmentation model on one image and return its squeezed sigmoid score map."""
    input_tensor = preprocess_image(str(image_path)).to(device)
    with torch.inference_mode():
        output = model(input_tensor)['out']
        # Two-channel head: channel 1's logits are used as the lung score; otherwise the
        # single channel is used.
        pred = torch.sigmoid(output[:, 1, :, :]) if output.shape[1] == 2 else torch.sigmoid(output.squeeze(1))
    return np.squeeze(pred.cpu().numpy())


def process_directory(
    model: torch.nn.Module,
    input_dir: Path,
    output_dir: Path,
    device: torch.device,
    config: dict = None
):
    """
    Segment every image under ``input_dir/<class>/`` and save masks, lungs and a CSV index.

    For each class folder (``tb_positive``, ``tb_negative``) this writes, under ``output_dir``:
    ``predicted_mask/<class>/`` (the model's score map scaled to 0-255, not resized),
    ``post_processed_mask/<class>/`` (the cleaned binary mask resized to the original image),
    ``segmented_lung/<class>/`` (the original image with non-lung pixels zeroed), and
    ``processed_dataset.csv`` listing all saved paths and labels.

    Args:
        model: Segmentation model whose forward output is a dict with an ``'out'`` tensor.
        input_dir (Path | str): Directory containing one sub-folder per class.
        output_dir (Path | str): Directory to write results into (created if missing).
        device: Device the input tensors are moved to.
        config (dict | None): Post-processing settings forwarded to ``post_process_mask``;
            ``None`` means use its defaults.

    Images that cannot be read, or whose mask is invalid or empty after post-processing,
    are reported and skipped.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    if config is None:
        config = {}

    if not input_dir.exists():
        print(f"Input directory {input_dir} does not exist.")
        return

    output_dir.mkdir(parents=True, exist_ok=True)

    class_names = ('tb_positive', 'tb_negative')
    output_kinds = ('segmented_lung', 'predicted_mask', 'post_processed_mask')
    for class_name in class_names:
        for img_type in output_kinds:
            (output_dir / img_type / class_name).mkdir(parents=True, exist_ok=True)

    results = []
    model.eval()

    for class_name in class_names:
        # Paths from the listing already include the class directory; use them as-is
        # (re-joining them onto class_dir doubled the path when input_dir was relative).
        image_paths = _list_image_files(input_dir / class_name)

        for image_path in tqdm(image_paths, desc=f'Processing {class_name} images'):
            original_image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
            if original_image is None:
                print(f"Failed to read image {image_path}. Skipping.")
                continue

            mask = _predict_lung_scores(model, image_path, device)
            if mask.ndim != 2 or mask.size == 0:
                print(f"Invalid mask shape at {image_path}. Skipping.")
                continue

            processed_mask = post_process_mask(mask, config=config)
            if processed_mask is None or processed_mask.size == 0:
                print(f"Processed mask is invalid for {image_path}. Skipping.")
                continue

            # Nearest-neighbour keeps the resized mask binary.
            processed_mask_resized = cv2.resize(
                processed_mask,
                (original_image.shape[1], original_image.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )

            thresholded_mask = (processed_mask_resized > 0.5).astype(np.float32)
            isolated_lungs = create_segmented_lung(original_image, thresholded_mask)

            base_filename = f"{image_path.stem}_{class_name}{image_path.suffix}"

            predicted_mask_path = output_dir / 'predicted_mask' / class_name / base_filename
            cv2.imwrite(str(predicted_mask_path), (mask * 255).astype(np.uint8))

            post_processed_mask_path = output_dir / 'post_processed_mask' / class_name / base_filename
            cv2.imwrite(str(post_processed_mask_path), (processed_mask_resized * 255).astype(np.uint8))

            segmented_lung_path = output_dir / 'segmented_lung' / class_name / base_filename
            cv2.imwrite(str(segmented_lung_path), isolated_lungs)

            results.append({
                'input_image_path': str(image_path),
                'predicted_mask_path': str(predicted_mask_path),
                'post_processed_mask_path': str(post_processed_mask_path),
                'segmented_lung_path': str(segmented_lung_path),
                'label': class_name
            })

    # Explicit columns so the CSV keeps its header even when no image was processed.
    results_df = pd.DataFrame(results, columns=[
        'input_image_path', 'predicted_mask_path', 'post_processed_mask_path',
        'segmented_lung_path', 'label',
    ])
    results_df.to_csv(output_dir / 'processed_dataset.csv', index=False)

    print(f"Processed {len(results)} images.")

In [25]:
# Segment both class folders of the classification data and write masks, lungs and the CSV index.
process_directory(
    model,
    CLF_DATA_DIR,
    PROC_DATA_DIR,
    device,
    config=p_config,
)

Processing tb_positive images:   0%|          | 0/6222 [00:00<?, ?it/s]

Processing tb_positive images: 100%|██████████| 6222/6222 [09:26<00:00, 10.99it/s]
Processing tb_negative images: 100%|██████████| 6881/6881 [10:32<00:00, 10.88it/s]


Processed 13103 images.
