<a href="https://www.kaggle.com/code/navneetguglani/tb-x-ray-visualization?scriptVersionId=241991121" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

**I have taken the following approach to solve this task**

**Loading the X-ray Image**

I used OpenCV's `imread` function with the `cv2.IMREAD_GRAYSCALE` flag to load the X-ray image in grayscale. Grayscale is appropriate because X-rays are already monochrome, so a single intensity channel carries all of the information.

**Preprocessing the Image**

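Contrast enhancement: the image is resized to 512x512 and enhanced with CLAHE (Contrast Limited Adaptive Histogram Equalization, clip limit 2.0, 8x8 tiles). CLAHE equalizes the histogram locally in each tile, which makes the TB area more visible. Smoothing: the X-ray itself is not blurred. Instead, small specks are removed by morphological operations during segmentation, and a Gaussian blur is applied to the detected region masks and the heatmap so that elevations change gradually.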

**Segmenting the TB Area**

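I applied a fixed binary threshold (220 by default) to keep only the brightest pixels. The resulting mask is then cleaned with a morphological opening, which removes small specks, and a closing, which fills small holes. Contour detection finds candidate TB-affected regions. I keep only the contours whose area lies between 100 pixels and a quarter of the image, which discards both tiny artifacts and implausibly large blobs. Each kept region is filled, blurred, and accumulated into a heatmap, and the heatmap is normalized to 0-255.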

**Creating a Depth Map**

I created a depth map from the intensity values of the image. TB-affected areas are often denser and appear brighter on X-rays, so brightness is used as elevation. The TB heatmap is smoothed with a Gaussian blur, multiplied by an emphasis factor (2.5 by default), and added to the normalized intensity. This produces a smooth transition from normal tissue up to the suspected TB areas. Finally, the depth map is min-max normalized to values between 0 and 1.

**3D Visualization**

I used matplotlib's 3D plotting capabilities to create a surface plot. The X and Y coordinates correspond to pixel positions, and the Z coordinate (height) comes from the depth map. Because X-rays can be large, only every 4th row and column is kept, which speeds up rendering. Each face is colored by the enhanced X-ray intensity mapped through the viridis colormap. The three highest points are labeled "TB AREA" when their elevation exceeds 0.7.
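The sketch below isolates the downsampling and face-coloring step on a synthetic random image, rendered with the headless Agg backend. The function name `plot_downsampled_surface` is hypothetical. Its `"gray"` default colormap is an assumption, chosen for a grayscale X-ray look; the notebook itself passes `cm.viridis`. `matplotlib.colormaps` requires matplotlib 3.5 or newer.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend: render without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_downsampled_surface(img, depth_map, scale_factor=4, cmap_name="gray"):
    """Plot depth_map as a 3-D surface, keeping every scale_factor-th row and column."""
    y, x = np.mgrid[0:img.shape[0], 0:img.shape[1]]
    step = (slice(None, None, scale_factor), slice(None, None, scale_factor))
    x_down, y_down, z_down = x[step], y[step], depth_map[step]
    # Colour each face by the (downsampled) X-ray intensity.
    colors = matplotlib.colormaps[cmap_name](img[step] / 255.0)
    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(x_down, y_down, z_down, facecolors=colors, rstride=1, cstride=1, shade=False)
    ax.set_zlabel("Elevation")
    return fig, z_down


img = (np.random.default_rng(0).random((256, 256)) * 255).astype(np.uint8)
depth = img / 255.0
fig, z = plot_downsampled_surface(img, depth)
print(z.shape)  # (64, 64)
plt.close(fig)
```

A stride of 4 reduces the number of plotted faces by a factor of 16, so a 512x512 X-ray becomes a 128x128 mesh.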

**Rotation for 3D Effect**

added a functionality to rotate the 3D visualization, allowing to see the TB area from different angles. This helps understand the spatial characteristics of the TB affected area.

In [1]:
!pip install gradio
import kagglehub

# Download the TB chest X-ray dataset; `path` is the local directory that contains it.
path = kagglehub.dataset_download("tawsifurrahman/tuberculosis-tb-chest-xray-dataset")
print("Path to dataset files:", path)


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import cv2
import os
import glob
from PIL import Image
import gradio as gr
import warnings
warnings.filterwarnings('ignore')

DATASET_PATH = path
TB_DIR = os.path.join(DATASET_PATH,"TB_Chest_Radiography_Database","Tuberculosis")
NORMAL_DIR = os.path.join(DATASET_PATH,"TB_Chest_Radiography_Database","Normal")



In [3]:
def find_tb_image():
    print("Looking for a clear TB X-ray image...")
    tb_images = glob.glob(os.path.join(TB_DIR, "*.png"))
    if not tb_images:
        raise ValueError(f"No TB images found in {TB_DIR}")
    print(f"Found {len(tb_images)} TB images")
    return tb_images[0]
    
def load_and_preprocess_image(image_path, size=(512, 512)):
    print(f"Loading image: {image_path}")
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image from {image_path}")
    img = cv2.resize(img, size)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img_enhanced = clahe.apply(img)
    return img_enhanced
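`find_tb_image` returns `tb_images[0]`, but `glob.glob` returns paths in arbitrary, filesystem-dependent order. The selected image (`Tuberculosis-173.png` in the run below) may therefore differ between machines. Sorting the list first makes the choice reproducible. The sketch below uses a natural-sort key so that `Tuberculosis-2` sorts before `Tuberculosis-10`. The helper `natural_key` is hypothetical; in `find_tb_image` it would be used as `sorted(glob.glob(...), key=natural_key)`.

```python
import re


def natural_key(path):
    """Split digits out of a file name so 'Tuberculosis-2' sorts before 'Tuberculosis-10'."""
    return [int(part) if part.isdigit() else part.lower() for part in re.split(r"(\d+)", path)]


paths = ["Tuberculosis-10.png", "Tuberculosis-2.png", "Tuberculosis-1.png"]
print(sorted(paths, key=natural_key))
# ['Tuberculosis-1.png', 'Tuberculosis-2.png', 'Tuberculosis-10.png']
```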

In [4]:
def detect_tb_regions(img, threshold=220):
    _, binary = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY)
    kernel = np.ones((5, 5), np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    heatmap = np.zeros_like(img, dtype=np.float32)
    min_area = 100  
    max_area = img.shape[0] * img.shape[1] // 4
    for contour in contours:
        area = cv2.contourArea(contour)
        if min_area < area < max_area:
            mask = np.zeros_like(img, dtype=np.uint8)
            cv2.drawContours(mask, [contour], 0, 255, -1)
            mask = cv2.GaussianBlur(mask, (15, 15), 0)
            heatmap += mask.astype(np.float32)
    if np.max(heatmap) > 0:
        heatmap = 255 * (heatmap / np.max(heatmap))
    return heatmap.astype(np.uint8)

In [5]:
def create_depth_map(img, heatmap, emphasis_factor=2.5):
    img_norm = img.astype(float) / 255.0
    heatmap_norm = heatmap.astype(float) / 255.0
    heatmap_smooth = cv2.GaussianBlur(heatmap_norm, (15, 15), 0)
    depth_map = img_norm + (heatmap_smooth * emphasis_factor)
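    depth_min, depth_max = np.min(depth_map), np.max(depth_map)
    if depth_max == depth_min:
        # Uniform input: nothing to normalize, and dividing would produce NaNs.
        return np.zeros_like(depth_map)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)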
    return depth_map

def create_3d_visualization(img, depth_map, title="3D Visualization of TB Area in Chest X-ray"):
    y, x = np.mgrid[0:img.shape[0], 0:img.shape[1]]
    scale_factor = 4
    x_down = x[::scale_factor, ::scale_factor]
    y_down = y[::scale_factor, ::scale_factor]
    z_down = depth_map[::scale_factor, ::scale_factor]
    gray_colors = img[::scale_factor, ::scale_factor] / 255.0
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(x_down,y_down,z_down,facecolors=cm.viridis(gray_colors),rstride=1,cstride=1,antialiased=True,shade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Elevation')
    ax.set_title(title)
    ax.text2D(0.05, 0.95,"Elevated areas (peaks) highlight potential TB regions",transform=ax.transAxes, fontsize=12)
    flat_idx = np.argsort(z_down.flatten())[-3:]
    high_pts = np.unravel_index(flat_idx, z_down.shape)
    for i in range(len(high_pts[0])):
        y_idx,x_idx = high_pts[0][i],high_pts[1][i]
        z_val = z_down[y_idx,x_idx]
        if z_val > 0.7:
            ax.text(x_down[y_idx,x_idx],y_down[y_idx, x_idx],z_val + 0.05,"TB AREA",color='red',fontweight='bold',fontsize=8)
    ax.view_init(elev=30, azim=-60)
    plt.tight_layout()
    return fig
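The peak-labelling step in `create_3d_visualization` also appears in the rotating and Gradio versions. The sketch below runs that step on a small hand-made array so its behavior can be checked. `top_peaks` is a hypothetical helper that uses the same `np.argsort(...)[-3:]` plus `np.unravel_index` logic and the same 0.7 cut-off.

```python
import numpy as np


def top_peaks(z, k=3, min_height=0.7):
    """Return (row, col, height) for the k highest cells of z that exceed min_height."""
    flat_idx = np.argsort(z.ravel())[-k:]             # indices of the k largest values
    rows, cols = np.unravel_index(flat_idx, z.shape)  # back to 2-D coordinates
    return [(int(r), int(c), float(z[r, c])) for r, c in zip(rows, cols) if z[r, c] > min_height]


z = np.array([[0.1, 0.9, 0.2],
              [0.8, 0.3, 0.75],
              [0.0, 0.65, 0.4]])
print(top_peaks(z))  # [(1, 2, 0.75), (1, 0, 0.8), (0, 1, 0.9)]
```

On a real depth map, the three highest cells are often neighbours on the same peak, so the "TB AREA" labels can overlap. Finding distinct local maxima first would spread the labels across separate regions.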

In [6]:
def create_rotating_visualization(img, depth_map, output_gif="tb_rotating_3d.gif"):
    print("Creating rotating 3D visualization...")
    frames = []
    angles = list(range(0, 360, 15)) 
    for idx, angle in enumerate(angles):
        print(f"Rendering frame {idx+1}/{len(angles)}: {angle}° rotation")
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        y, x = np.mgrid[0:img.shape[0],0:img.shape[1]]
        scale_factor = 4
        x_down = x[::scale_factor,::scale_factor]
        y_down = y[::scale_factor,::scale_factor]
        z_down = depth_map[::scale_factor, ::scale_factor]
        colors = img[::scale_factor,::scale_factor] / 255.0
        
        surf = ax.plot_surface(x_down,y_down,z_down,facecolors=cm.viridis(colors),rstride=1,cstride=1,antialiased=True,shade=True)
        ax.set_title('TB Areas in 3D (Rotating View)')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Elevation')
        flat_idx = np.argsort(z_down.flatten())[-3:]
        high_pts = np.unravel_index(flat_idx, z_down.shape)
        for i in range(len(high_pts[0])):
            y_idx, x_idx = high_pts[0][i], high_pts[1][i]
            z_val = z_down[y_idx, x_idx]
            if z_val > 0.7:
                ax.text(x_down[y_idx, x_idx],y_down[y_idx, x_idx],z_val + 0.05,"TB",color='red',fontweight='bold',fontsize=8)
        ax.view_init(elev=30, azim=angle)
        frame_file = f'temp_frame_{idx:03d}.png'
        plt.savefig(frame_file)
        frames.append(frame_file)
        plt.close(fig)
    try:
        images = [Image.open(frame) for frame in frames]
        images[0].save(output_gif, save_all=True, append_images=images[1:], 
                      optimize=False, duration=100, loop=0)
        print(f"Rotating visualization saved as '{output_gif}'")
        for frame in frames:
            os.remove(frame)
    except Exception as e:
        print(f"Error creating GIF: {e}")
        print("Individual frames were saved as temp_frame_*.png")
    return output_gif
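`create_rotating_visualization` writes each frame to `temp_frame_*.png`, reopens the files, and rebuilds the mesh grid on every iteration. The sketch below assembles the same kind of GIF but keeps the frames in memory with `io.BytesIO` and builds the grid once. `rotating_gif` and its defaults are illustrative, and the sketch skips the downsampling and "TB" labels for brevity. The final Pillow call (`save_all=True`, `append_images`, `duration`, `loop=0`) is the same one the notebook uses.

```python
import io
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def rotating_gif(depth_map, out_path, angles=range(0, 360, 15), elev=30, duration_ms=100):
    """Render depth_map at each azimuth in angles and save the frames as a looping GIF."""
    y, x = np.mgrid[0:depth_map.shape[0], 0:depth_map.shape[1]]  # built once, reused per frame
    frames = []
    for angle in angles:
        fig = plt.figure(figsize=(4, 3))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(x, y, depth_map, cmap="viridis")
        ax.view_init(elev=elev, azim=angle)
        buf = io.BytesIO()
        fig.savefig(buf, format="png")  # keep the frame in memory instead of a temp file
        plt.close(fig)
        buf.seek(0)
        frames.append(Image.open(buf).convert("RGB"))
    frames[0].save(out_path, save_all=True, append_images=frames[1:],
                   duration=duration_ms, loop=0)
    return out_path


# Asymmetric bump so that every azimuth produces a visibly different frame.
yy, xx = np.mgrid[-1:1:24j, -1:1:32j]
bump = np.exp(-4 * ((xx - 0.4) ** 2 + yy ** 2))
with tempfile.TemporaryDirectory() as tmp:
    gif = rotating_gif(bump, os.path.join(tmp, "demo_rotation.gif"), angles=[0, 90, 180, 270])
    with Image.open(gif) as im:
        n_frames = im.n_frames
print(n_frames)  # 4
```

With the notebook's `range(0, 360, 15)`, this gives 24 frames, which at 100 ms each loop every 2.4 seconds.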

In [7]:
def create_gradio_interface():
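    def process_image(input_img, elev, azim):
        if input_img is None:
            # Gradio passes None when the user submits without uploading an image.
            raise gr.Error("Please upload a chest X-ray image first.")
        if len(input_img.shape) == 3:
            img = cv2.cvtColor(input_img, cv2.COLOR_RGB2GRAY)
        else:
            img = input_img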
        img = cv2.resize(img, (512, 512))
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_enhanced = clahe.apply(img)
        heatmap = detect_tb_regions(img_enhanced)
        depth_map = create_depth_map(img_enhanced, heatmap)
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        y, x = np.mgrid[0:img.shape[0], 0:img.shape[1]]
        scale_factor = 4
        x_down = x[::scale_factor, ::scale_factor]
        y_down = y[::scale_factor, ::scale_factor]
        z_down = depth_map[::scale_factor, ::scale_factor]
        gray_colors = img_enhanced[::scale_factor, ::scale_factor] / 255.0
        surf = ax.plot_surface(
            x_down, y_down, z_down,
            facecolors=cm.viridis(gray_colors),
            rstride=1, cstride=1,
            antialiased=True, shade=True)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Elevation')
        ax.set_title(f'TB Chest X-Ray 3D Visualization (Elevation: {elev}°, Azimuth: {azim}°)')
        flat_idx = np.argsort(z_down.flatten())[-3:]
        high_pts = np.unravel_index(flat_idx, z_down.shape)
        for i in range(len(high_pts[0])):
            y_idx, x_idx = high_pts[0][i], high_pts[1][i]
            z_val = z_down[y_idx, x_idx]
            if z_val > 0.7:
                ax.text(x_down[y_idx, x_idx],y_down[y_idx, x_idx], z_val + 0.05,"TB AREA",color='red',fontweight='bold',fontsize=8)
        ax.view_init(elev=elev, azim=azim)
        plt.tight_layout()
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        if len(img_enhanced.shape) == 2:
            img_color = cv2.cvtColor(img_enhanced, cv2.COLOR_GRAY2BGR)
        else:
            img_color = img_enhanced
        heatmap_overlay = cv2.addWeighted(img_color, 0.7, heatmap_colored, 0.3, 0)
        if len(heatmap_overlay.shape) == 3:
            heatmap_overlay = cv2.cvtColor(heatmap_overlay, cv2.COLOR_BGR2RGB)
        return img_enhanced, heatmap_overlay, fig

    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(type="numpy", label="Upload Chest X-ray"),
            gr.Slider(minimum=0, maximum=90, value=30, step=5, label="Elevation Angle"),
            gr.Slider(minimum=-180, maximum=180, value=-60, step=15, label="Azimuth Angle"),
        ],
        outputs=[
            gr.Image(type="numpy", label="Original X-ray"),
            gr.Image(type="numpy", label="TB Region Heatmap"),
            gr.Plot(label="3D Visualization")
        ],
        title="TB Chest X-ray 3D Visualization",
        description="""
        Upload a chest X-ray to generate a 3D visualization highlighting potential TB regions.
        
        Instructions:
        1. Upload a chest X-ray image
        2. Use the sliders below to adjust the viewing angle:
           - Elevation Angle: controls the vertical (top-down) viewing angle (0-90°)
           - Azimuth Angle: controls the horizontal (left-right) rotation (-180° to 180°)
        3. Elevated areas (peaks) in the 3D visualization indicate potential TB regions
        """
    )
    
    return iface

In [8]:
def main():
    print("Starting 3D visualization process for TB chest X-ray...")
    print(f"Dataset path: {DATASET_PATH}")
    try:
        tb_image_path = find_tb_image()
        print(f"Selected TB image: {tb_image_path}")
        chest_xray = load_and_preprocess_image(tb_image_path)
        tb_heatmap = detect_tb_regions(chest_xray)
        depth_map = create_depth_map(chest_xray, tb_heatmap)
        cv2.imwrite("tb_original.png", chest_xray)
        print("Original image saved as 'tb_original.png'")
        heatmap_colored = cv2.applyColorMap(tb_heatmap,cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(cv2.cvtColor(chest_xray,cv2.COLOR_GRAY2BGR),0.7,heatmap_colored,0.3,0)
        cv2.imwrite("tb_overlay.png",overlay)
        print("TB region overlay saved as 'tb_overlay.png'")
        fig = create_3d_visualization(chest_xray, depth_map)
        plt.savefig("tb_3d_visualization.png", dpi=300)
        plt.close(fig)
        print("3D visualization saved as 'tb_3d_visualization.png'")
        gif_path = create_rotating_visualization(chest_xray, depth_map)
        print("Process completed successfully!")
        print(f"Check the following files in your current directory:")
        print("- tb_original.png: Original X-ray image")
        print("- tb_overlay.png: 2D visualization with TB areas highlighted")
        print("- tb_3d_visualization.png: Static 3D visualization")
        print(f"- {gif_path}: Rotating 3D visualization")
        print("Launching Gradio interface for interactive visualization...")
        iface = create_gradio_interface()
        iface.launch(share=True)
    except Exception as e:
        print(f"Error: {e}")

if __name__ == "__main__":
    main()

Starting 3D visualization process for TB chest X-ray...
Dataset path: /kaggle/input/tuberculosis-tb-chest-xray-dataset
Looking for a clear TB X-ray image...
Found 700 TB images
Selected TB image: /kaggle/input/tuberculosis-tb-chest-xray-dataset/TB_Chest_Radiography_Database/Tuberculosis/Tuberculosis-173.png
Loading image: /kaggle/input/tuberculosis-tb-chest-xray-dataset/TB_Chest_Radiography_Database/Tuberculosis/Tuberculosis-173.png
Original image saved as 'tb_original.png'
TB region overlay saved as 'tb_overlay.png'
3D visualization saved as 'tb_3d_visualization.png'
Creating rotating 3D visualization...
Rendering frame 1/24: 0° rotation
Rendering frame 2/24: 15° rotation
Rendering frame 3/24: 30° rotation
Rendering frame 4/24: 45° rotation
Rendering frame 5/24: 60° rotation
Rendering frame 6/24: 75° rotation
Rendering frame 7/24: 90° rotation
Rendering frame 8/24: 105° rotation
Rendering frame 9/24: 120° rotation
Rendering frame 10/24: 135° rotation
Rendering frame 11/24: 150° rotati

**Documentation used as references**

https://docs.opencv.org/ for cv2.threshold, cv2.morphologyEx, cv2.findContours methods

https://matplotlib.org/ for creating 3D plots and surface visualizations

https://numpy.org/doc/stable/ for array manipulation

https://radiopaedia.org/articles/tuberculosis-pulmonary-manifestations for the radiographic manifestations of pulmonary TB; CLAHE (Contrast Limited Adaptive Histogram Equalization) itself is covered in the OpenCV documentation

https://www.gradio.app/docs/ for web interface

https://pillow.readthedocs.io/en/stable/ for creating animated GIFs


