In [None]:
import os
import nibabel as nib
import numpy as np

def file_to_ndarray(filepath):
    """
    Load a NIfTI volume from disk and return its voxel data as an ndarray.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path to a ``.nii`` or ``.nii.gz`` file. The extension check is
        case-insensitive.

    Returns
    -------
    numpy.ndarray or None
        The array produced by ``get_fdata()``. ``None`` is returned (and a
        message printed) when the extension is not supported or when loading
        fails for any reason.
    """
    # os.path.splitext() only returns the last suffix ('.gz' for 'x.nii.gz'),
    # so a '.nii.gz' test on its result can never match. Check the whole name.
    filename = os.path.basename(os.fspath(filepath)).lower()
    if not filename.endswith(('.nii', '.nii.gz')):
        print("Unsupported file format.")
        return None

    try:
        nii_img = nib.load(filepath)
        return nii_img.get_fdata()
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"An error occurred while processing the file: {e}")
        return None

In [None]:
# Load one example volume into the notebook-global `voxel`.
# The path is absolute and specific to the author's machine.
filepath = r'C:\Users\acer\Desktop\Project_TMJOA\Data\47-4881 L 2014.nii.gz'

voxel = file_to_ndarray(filepath)
# NOTE(review): file_to_ndarray returns None on failure, in which case the
# next line raises AttributeError.
print(voxel.shape)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def display_array_slice(array3d, axis=0, slice_num=0):
    """
    Display a 2D slice from a 3D numpy array.

    Parameters:
    -----------
    array3d : numpy.ndarray
        Input 3D array
    axis : int
        Axis along which to take the slice (0, 1, or 2)
    slice_num : int
        Index of the slice to display; must satisfy
        0 <= slice_num < array3d.shape[axis]

    Returns:
    --------
    None (displays the plot)

    Raises:
    -------
    ValueError
        If the input is not a 3D ndarray, the axis is not 0/1/2, or the slice
        index is out of range (negative indices included).
    """
    # Input validation
    if not isinstance(array3d, np.ndarray) or array3d.ndim != 3:
        raise ValueError("Input must be a 3D numpy array")

    if axis not in [0, 1, 2]:
        raise ValueError("Axis must be 0, 1, or 2")
    axis = int(axis)

    # Negative indices are rejected as well: they would otherwise wrap around
    # silently and show a slice counted from the end.
    max_slice = array3d.shape[axis] - 1
    if slice_num < 0 or slice_num > max_slice:
        raise ValueError(f"Slice number must be between 0 and {max_slice} for axis {axis}")

    # Build an index that selects `slice_num` on `axis` and everything else
    # on the two remaining axes.
    index = [slice(None)] * 3
    index[axis] = slice_num
    slice_2d = array3d[tuple(index)]

    axis_labels = {0: "front to back", 1: "top to bottom", 2: "left to right"}
    title = f"Slice {slice_num} along axis {axis} ({axis_labels[axis]})"

    # Create the plot
    plt.figure(figsize=(8, 6))
    plt.imshow(slice_2d, cmap='viridis')
    plt.colorbar(label='Value')
    plt.title(title)
    plt.xlabel('Column')
    plt.ylabel('Row')
    plt.show()

In [None]:
# Display slice 112 of the loaded volume along each of the three axes
display_array_slice(voxel, axis=0, slice_num=112)  # slice 112 along axis 0
display_array_slice(voxel, axis=1, slice_num=112)  # slice 112 along axis 1
display_array_slice(voxel, axis=2, slice_num=112)  # slice 112 along axis 2

In [None]:
def compute_histogram(ndarray):
    """
    Histogram the values of an array with fixed 10-unit-wide bins.

    Parameters
    ----------
    ndarray : array_like
        Values of any shape; they are flattened before counting.

    Returns
    -------
    histogram_values : numpy.ndarray
        Count per bin (800 bins).
    bin_edges : numpy.ndarray
        The 801 bin edges: -4000, -3990, ..., 4000.

    Notes
    -----
    Values outside [-4000, 4000] are not counted (``numpy.histogram``
    ignores samples that fall outside the outermost edges).
    """
    # ravel() avoids the full copy that flatten() always makes.
    flat_array = np.asarray(ndarray).ravel()

    # The stop value 4001 makes arange include 4000 as the final edge.
    bins = np.arange(-4000, 4001, 10)

    histogram_values, bin_edges = np.histogram(flat_array, bins=bins)
    return histogram_values, bin_edges


In [None]:
# Histogram of the whole volume; the print skips the first bin (index 0).
histogram_list, bin_edges = compute_histogram(voxel)
print(histogram_list[1:])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from scipy.stats import norm

def plot_histogram_peaks_normal(arr, bin_edges, variance1=1.0, variance2=1.0, height=None, distance=1, prominence=None):
    """
    Draw a histogram, mark its local peaks, and overlay a normal curve on each
    of the two tallest peaks.

    Parameters:
    -----------
    arr : numpy.ndarray
        1D array of histogram heights
    bin_edges : numpy.ndarray
        Bin edges; must contain len(arr) + 1 values
    variance1 : float
        Variance of the normal curve drawn at the tallest peak
    variance2 : float
        Variance of the normal curve drawn at the second tallest peak
    height : float or None
        Minimum peak height, forwarded to scipy.signal.find_peaks
    distance : int
        Minimum horizontal distance between peaks, forwarded to find_peaks
    prominence : float or None
        Minimum peak prominence, forwarded to find_peaks

    Returns:
    --------
    tuple
        (peak indices, peak heights, peak bin centers), each ordered from the
        tallest peak to the shortest.
    """
    if len(arr) + 1 != len(bin_edges):
        raise ValueError("bin_edges should have length equal to arr length + 1")

    # Locate local maxima, then rank them from tallest to shortest
    peaks, _ = find_peaks(arr, height=height, distance=distance, prominence=prominence)
    peak_heights = arr[peaks]
    ranking = np.argsort(peak_heights)[::-1]
    peaks_sorted = peaks[ranking]
    heights_sorted = peak_heights[ranking]

    plt.figure(figsize=(12, 6))

    # Histogram bars, drawn at the bin centers
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bar_width = bin_edges[1] - bin_edges[0]
    plt.bar(bin_centers, arr, width=bar_width, alpha=0.5, color='blue', label='Histogram')

    # Peak markers
    peak_x_positions = bin_centers[peaks_sorted]
    plt.scatter(peak_x_positions, heights_sorted, c='red', s=100, label='Local Peaks')

    # (color, line style, variance) for the tallest and second tallest peak:
    # green dashed first, magenta dotted second
    curve_specs = [('g', '--', variance1), ('m', ':', variance2)]
    n_curves = min(2, len(peaks_sorted))

    for rank, (color, line_style, variance) in enumerate(curve_specs[:n_curves], start=1):
        peak_center = peak_x_positions[rank - 1]
        peak_height = heights_sorted[rank - 1]
        sigma = np.sqrt(variance)

        # Evaluate the curve over +/- 4 standard deviations around the peak
        x_normal = np.linspace(peak_center - 4*sigma, peak_center + 4*sigma, 200)
        y_normal = norm.pdf(x_normal, peak_center, sigma)

        # Stretch the curve so its maximum coincides with the peak height
        y_normal = y_normal * (peak_height / np.max(y_normal))

        plt.plot(x_normal, y_normal,
                 color=color,
                 linestyle=line_style,
                 label=f'Normal at Peak {rank} (σ²={variance:.1f})',
                 linewidth=2)

        plt.annotate(f'Peak {rank}',
                     (peak_center, peak_height),
                     xytext=(5, 5),
                     textcoords='offset points')

    plt.grid(True, alpha=0.3)
    plt.xlabel('Bin Center')
    plt.ylabel('Count')
    plt.title('Histogram with Local Peaks and Normal Distributions')
    plt.legend()
    plt.show()

    return peaks_sorted, heights_sorted, peak_x_positions


In [None]:
# Recompute the whole-volume histogram; the result feeds the peak-plotting cell.
histogram_val, bin_edges = compute_histogram(voxel)

In [None]:
# Drop the first bin (count and edge) before analysing the histogram
input_y = histogram_val[1:]
input_x = bin_edges[1:]

# Plot a 300-bin window (indices 300..599 of the trimmed histogram) with a
# different variance for each peak; the edge slice carries one extra element.
peaks, heights, peak_centers = plot_histogram_peaks_normal(
    input_y[300:600],
    input_x[300:601],
    variance1=10000,    # Variance for highest peak
    variance2=22000,    # Variance for second peak
    height=0,      # Minimum peak height of 0
    distance=20,       # Minimum 20 bins between peaks
    prominence=2000     # Minimum prominence
)

print("Peak bin centers:", peak_centers)
print("Peak heights:", heights)

In [None]:
import numpy as np
from scipy.stats import norm

def rescale_by_probability(image, target_mean, variance):
    """
    Map each pixel to how likely its value is under a normal distribution,
    min-max normalised over the image and scaled to 0-255.

    Parameters:
    -----------
    image : numpy.ndarray
        Input grayscale image array (any array-like is accepted)
    target_mean : float
        The target pixel value (mean of the normal distribution)
    variance : float
        Variance of the normal distribution

    Returns:
    --------
    numpy.ndarray
        uint8 array of the same shape as ``image``. The pixel whose density is
        highest in this image maps to 255 and the one whose density is lowest
        maps to 0. If every pixel has the same density (for example a constant
        image) the result is all zeros.
    """
    # Work on a float view of the input; the input itself is never modified.
    values = np.asarray(image, dtype=float)

    # Normal density of every pixel value
    probabilities = norm.pdf(values, target_mean, np.sqrt(variance))

    lowest = probabilities.min()
    span = probabilities.max() - lowest
    if not span > 0:
        # Flat density map: min-max normalisation would be 0/0 and produce
        # NaNs, so return an all-zero map instead.
        return np.zeros(values.shape, dtype=np.uint8)

    # Normalise to [0, 1], then convert to uint8 (0-255)
    probabilities = (probabilities - lowest) / span
    return (probabilities * 255).astype(np.uint8)

def visualize_rescaling(original, rescaled, target_mean, variance):
    """
    Show the source image and its rescaled counterpart next to each other,
    both in grayscale with a colorbar. ``target_mean`` and ``variance`` are
    only used in the title of the right-hand panel.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # (image, title) for the left and right panel, in drawing order
    panels = [
        (original, 'Original Image'),
        (rescaled, f'Probability Map\n(mean={target_mean}, variance={variance})'),
    ]
    for axis, (picture, caption) in zip(axes, panels):
        handle = axis.imshow(picture, cmap='gray')
        axis.set_title(caption)
        plt.colorbar(handle, ax=axis)

    plt.tight_layout()
    plt.show()

In [None]:
# Take slice 112 along axis 0 of the loaded volume
slice_2d = voxel[112, :, :]

# Centre and variance of the normal distribution used for this map
# (presumably the "air" peak, going by the output variable name — TODO confirm)
target_mean = 355
variance = 10000  # Wide variance to show the effect

# Rescale the image
rescaled_image_air = rescale_by_probability(slice_2d, target_mean, variance)

# Visualize results
visualize_rescaling(slice_2d, rescaled_image_air, target_mean, variance)

In [None]:
# Take slice 112 along axis 0 of the loaded volume
slice_2d = voxel[112, :, :]

# Centre and variance of the normal distribution used for this map
# (presumably the "bone" peak, going by the output variable name — TODO confirm)
target_mean = 935
variance = 22000  # Wide variance to show the effect

# Rescale the image
rescaled_image_bone = rescale_by_probability(slice_2d, target_mean, variance)

# Visualize results
visualize_rescaling(slice_2d, rescaled_image_bone, target_mean, variance)

In [None]:
import numpy as np

def rescale_range(array, a, b):
    """
    Rescale values in range [a,b] to [0,255] based on their position in the range.
    Values outside [a,b] become 0.

    Parameters:
    array: np.ndarray - Input array (not modified)
    a: float - Lower bound of the range
    b: float - Upper bound of the range

    Returns:
    np.ndarray with the same shape and dtype as the input. When b <= a the
    range is empty or degenerate and the result is all zeros.
    """
    array = np.asarray(array)

    # Decide membership from the ORIGINAL values. Building the mask after the
    # out-of-range values have been zeroed would wrongly treat them as
    # in-range whenever a <= 0 <= b.
    in_range = (array >= a) & (array <= b)

    result = np.zeros_like(array)
    if b > a:
        # Linear rescaling of values in range [a,b] to [0,255]
        result[in_range] = ((array[in_range] - a) / (b - a)) * 255

    return result

# Example usage:
# array = your_array
# rescaled = rescale_range(array, 100, 200)

In [None]:
# Take slice 112 along axis 0 of the loaded volume
slice_2d = voxel[112, :, :]
# a and b match the two target means used in the probability-map cells
a = 355
b = 935
# Shrink the interval [a, b] by 25% of its width on each side
adjust_const = 0.25
adjust = int((b-a)*adjust_const)
print(b-a, adjust)

# Values inside the shrunken interval are rescaled to 0-255, the rest become 0
area_of_uncertainty = rescale_range(slice_2d, a+adjust, b-adjust)

In [None]:
# The two trailing zeros only fill the "mean"/"variance" fields of the panel
# title; they do not affect the image shown.
visualize_rescaling(slice_2d, area_of_uncertainty, 0, 0)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Combine the three 2D maps computed in earlier cells into one RGB image.
# The arrays must share the same shape; astype(np.uint8) truncates the
# fractional part of the float-valued uncertainty map.
red = rescaled_image_air.astype(np.uint8)
green = area_of_uncertainty.astype(np.uint8)
blue = rescaled_image_bone.astype(np.uint8)

# Stack along a new last axis -> shape (rows, cols, 3)
rgb_image = np.stack([red, green, blue], axis=2)

# Display the image
plt.imshow(rgb_image)
plt.axis('off')
plt.show()

In [None]:
import os
import shutil
import random

def split_dataset(source_folder, destination_base, train_ratio=0.7, val_ratio=0.2, seed=None):
    """
    Randomly split the ``.nii.gz`` files of a folder into train/val/test
    sub-folders by copying them.

    Parameters
    ----------
    source_folder : str
        Folder containing the ``.nii.gz`` files (not searched recursively).
    destination_base : str
        Folder in which ``train``, ``val`` and ``test`` sub-folders are
        created (if missing) and filled.
    train_ratio : float
        Fraction of files copied to ``train``.
    val_ratio : float
        Fraction of files copied to ``val``; whatever remains goes to ``test``.
    seed : int or None
        When given, the shuffle uses a private ``random.Random(seed)`` so the
        split is reproducible. When None (default), the global ``random``
        state is used, as before.

    Raises
    ------
    ValueError
        If a ratio is negative or the two ratios sum to more than 1.
    """
    # Small tolerance so that ratios such as 0.7 + 0.3 are never rejected
    # because of floating-point rounding.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and val_ratio must be non-negative and sum to at most 1")

    for split in ('train', 'val', 'test'):
        os.makedirs(os.path.join(destination_base, split), exist_ok=True)

    # os.listdir() order is arbitrary; sorting first makes a seeded shuffle
    # give the same split on every machine.
    files = sorted(f for f in os.listdir(source_folder) if f.endswith('.nii.gz'))
    if seed is None:
        random.shuffle(files)
    else:
        random.Random(seed).shuffle(files)

    n_files = len(files)
    n_train = int(n_files * train_ratio)
    n_val = int(n_files * val_ratio)

    partitions = {
        'train': files[:n_train],
        'val': files[n_train:n_train + n_val],
        'test': files[n_train + n_val:],
    }

    # A distinct loop name avoids shadowing `files` while iterating.
    for split, split_files in partitions.items():
        for name in split_files:
            shutil.copy2(os.path.join(source_folder, name),
                         os.path.join(destination_base, split, name))

# Usage: copy the files of the source folder into train/val/test sub-folders
# of the destination folder using the default 70% / 20% / remainder split.
# Both paths are absolute and specific to the author's machine. Each run
# reshuffles, and copies into whatever is already in the destination.
source_folder = r'D:\Kananat\_Segmented_Preprocessed_expand5px'
destination_folder = r'D:\Kananat\_dataset'
split_dataset(source_folder, destination_folder)

Calculate histogram

In [None]:
import os
import numpy as np
import nibabel as nib

def process_nii_files(folder_path):
    """
    Accumulate one voxel-value histogram over every ``.nii.gz`` file found
    directly inside ``folder_path``.

    Bins are one unit wide with edges -1500, -1499, ..., 2000. Each file name
    is printed as it is processed.

    Returns
    -------
    total_hist : numpy.ndarray
        Float array with the summed count of each of the 3500 bins.
    bins : numpy.ndarray
        The 3501 bin edges.
    """
    # Unit-wide bin edges from -1500 up to and including 2000
    bins = np.arange(-1500, 2001)
    total_hist = np.zeros(len(bins)-1)

    nii_names = (name for name in os.listdir(folder_path) if name.endswith('.nii.gz'))
    for filename in nii_names:
        print(filename)

        # Load the volume and add its histogram to the running total
        volume = nib.load(os.path.join(folder_path, filename)).get_fdata()
        total_hist += np.histogram(volume, bins=bins)[0]

        # Release the volume before the next file is loaded
        del volume

    return total_hist, bins

# Usage: accumulate the voxel-value histogram over every file of the training
# split (absolute, machine-specific path). `histogram` and `bin_edges` are
# notebook globals reused by later cells.
folder_path = r'D:\Kananat\_dataset\train'
histogram, bin_edges = process_nii_files(folder_path)

# Plot result: one unit-wide bar per bin, positioned at the bin's left edge
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6))
plt.bar(bin_edges[:-1], histogram, width=1)
plt.xlabel('Voxel Value')
plt.ylabel('Frequency')
plt.title('Voxel Value Distribution')
plt.show()

In [None]:
from sklearn.mixture import GaussianMixture
import numpy as np

def fit_gmm(histogram_data, bin_edges):
    """
    Fit a two-component 1-D Gaussian mixture to data summarised by a histogram.

    The raw samples are not available, so a surrogate data set is generated:
    for every bin, as many values as the bin's count are drawn uniformly
    between the bin's edges (using the global ``np.random`` state).

    Parameters
    ----------
    histogram_data : array_like
        Count per bin; fractional counts are truncated and negative counts
        are treated as 0.
    bin_edges : array_like
        Bin edges, at least ``len(histogram_data) + 1`` values.

    Returns
    -------
    means : numpy.ndarray, shape (2,)
        Component means in ascending order.
    variances : numpy.ndarray, shape (2,)
        Component variances, in the same order as ``means``.
    """
    counts = np.clip(np.asarray(histogram_data).astype(np.int64), 0, None)
    edges = np.asarray(bin_edges, dtype=float)
    n_bins = len(counts)

    # Repeat each bin's lower/upper edge once per counted value and draw all
    # samples in one vectorised call, instead of growing a Python list of
    # (potentially millions of) individual floats.
    lows = np.repeat(edges[:n_bins], counts)
    highs = np.repeat(edges[1:n_bins + 1], counts)
    data = np.random.uniform(lows, highs).reshape(-1, 1)

    # Fit GMM
    gmm = GaussianMixture(n_components=2, random_state=0)
    gmm.fit(data)

    # The mixture's component order is arbitrary; sort by mean so index 0 is
    # always the lower-intensity component and index 1 the higher one.
    means = gmm.means_.flatten()
    variances = gmm.covariances_.flatten()
    order = np.argsort(means)

    return means[order], variances[order]

In [None]:
# Fit the two-component mixture to the histogram accumulated over the training
# folder and print the resulting component means and variances.
mean, variance = fit_gmm(histogram, bin_edges)
print(mean,variance)

In [None]:
# Use the full training-set histogram and its bin edges
input_y = histogram
input_x = bin_edges

# Plot with different variances for each peak
peaks, heights, peak_centers = plot_histogram_peaks_normal(
    input_y,
    input_x,
    variance1=23000,    # Variance for highest peak
    variance2=34000,    # Variance for second peak
    height=0,      # Minimum peak height of 0
    distance=500,       # Minimum 500 bins between peaks
    prominence=50000     # Minimum prominence
)

print("Peak bin centers:", peak_centers)
print("Peak heights:", heights)

In [None]:
# Build the RGB composite for one validation volume using hand-entered
# parameters (313 / 23000 and 875 / 34000).
# NOTE(review): the variances match those passed to the peak-plotting cell;
# the means look like values read off its output — confirm.
filepath = r'D:\Kananat\_dataset\val\58-42016 L.nii.gz'

voxel = file_to_ndarray(filepath)

# Slice 112 along axis 0
slice_2d = voxel[112, :, :]

# Lower-intensity component ("air" map, going by the variable name)
target_mean = 313
variance = 23000  # Wide variance to show the effect

# Rescale the image
rescaled_image_air = rescale_by_probability(slice_2d, target_mean, variance)

# Higher-intensity component ("bone" map, going by the variable name)
target_mean = 875
variance = 34000  # Wide variance to show the effect

# Rescale the image
rescaled_image_bone = rescale_by_probability(slice_2d, target_mean, variance)

# Interval between the two means, shrunk by 25% of its width on each side
a = 313
b = 875
adjust_const = 0.25
adjust = int((b-a)*adjust_const)
print(b-a, adjust)

area_of_uncertainty = rescale_range(slice_2d, a+adjust, b-adjust)

import numpy as np
import matplotlib.pyplot as plt

# The three 2D maps must share the same shape; astype(np.uint8) truncates the
# fractional part of the float-valued uncertainty map.
red = rescaled_image_air.astype(np.uint8)
green = area_of_uncertainty.astype(np.uint8)
blue = rescaled_image_bone.astype(np.uint8)

# Combine into RGB
rgb_image = np.stack([red, green, blue], axis=2)

# Display the image
plt.imshow(rgb_image)
plt.axis('off')
plt.show()

# Test

In [None]:
import os
import nibabel as nib
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture

In [None]:
def file_to_ndarray(filepath):
    """
    Load a NIfTI volume from disk and return its voxel data as an ndarray.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path to a ``.nii`` or ``.nii.gz`` file. The extension check is
        case-insensitive.

    Returns
    -------
    numpy.ndarray or None
        The array produced by ``get_fdata()``. ``None`` is returned (and a
        message printed) when the extension is not supported or when loading
        fails for any reason.
    """
    # os.path.splitext() only returns the last suffix ('.gz' for 'x.nii.gz'),
    # so a '.nii.gz' test on its result can never match. Check the whole name.
    filename = os.path.basename(os.fspath(filepath)).lower()
    if not filename.endswith(('.nii', '.nii.gz')):
        print("Unsupported file format.")
        return None

    try:
        nii_img = nib.load(filepath)
        return nii_img.get_fdata()
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"An error occurred while processing the file: {e}")
        return None

def fit_gmm(histogram_data, bin_edges):
    """
    Fit a two-component 1-D Gaussian mixture to data summarised by a histogram.

    The raw samples are not available, so a surrogate data set is generated:
    for every bin, as many values as the bin's count are drawn uniformly
    between the bin's edges (using the global ``np.random`` state).

    Parameters
    ----------
    histogram_data : array_like
        Count per bin; fractional counts are truncated and negative counts
        are treated as 0.
    bin_edges : array_like
        Bin edges, at least ``len(histogram_data) + 1`` values.

    Returns
    -------
    means : numpy.ndarray, shape (2,)
        Component means in ascending order.
    variances : numpy.ndarray, shape (2,)
        Component variances, in the same order as ``means``.
    """
    counts = np.clip(np.asarray(histogram_data).astype(np.int64), 0, None)
    edges = np.asarray(bin_edges, dtype=float)
    n_bins = len(counts)

    # Repeat each bin's lower/upper edge once per counted value and draw all
    # samples in one vectorised call, instead of growing a Python list of
    # (potentially millions of) individual floats.
    lows = np.repeat(edges[:n_bins], counts)
    highs = np.repeat(edges[1:n_bins + 1], counts)
    data = np.random.uniform(lows, highs).reshape(-1, 1)

    # Fit GMM
    gmm = GaussianMixture(n_components=2, random_state=0)
    gmm.fit(data)

    # The mixture's component order is arbitrary; sort by mean so index 0 is
    # always the lower-intensity component and index 1 the higher one.
    means = gmm.means_.flatten()
    variances = gmm.covariances_.flatten()
    order = np.argsort(means)

    return means[order], variances[order]

def rescale_by_probability(image, target_mean, variance):
    """
    Map each pixel to how likely its value is under a normal distribution,
    min-max normalised over the image and scaled to 0-255.

    Parameters:
    -----------
    image : numpy.ndarray
        Input grayscale image array (any array-like is accepted)
    target_mean : float
        The target pixel value (mean of the normal distribution)
    variance : float
        Variance of the normal distribution

    Returns:
    --------
    numpy.ndarray
        uint8 array of the same shape as ``image``. The pixel whose density is
        highest in this image maps to 255 and the one whose density is lowest
        maps to 0. If every pixel has the same density (for example a constant
        image) the result is all zeros.
    """
    # Work on a float view of the input; the input itself is never modified.
    values = np.asarray(image, dtype=float)

    # Normal density of every pixel value
    probabilities = norm.pdf(values, target_mean, np.sqrt(variance))

    lowest = probabilities.min()
    span = probabilities.max() - lowest
    if not span > 0:
        # Flat density map: min-max normalisation would be 0/0 and produce
        # NaNs, so return an all-zero map instead.
        return np.zeros(values.shape, dtype=np.uint8)

    # Normalise to [0, 1], then convert to uint8 (0-255)
    probabilities = (probabilities - lowest) / span
    return (probabilities * 255).astype(np.uint8)

def rescale_range(array, a, b):
    """
    Rescale values in range [a,b] to [0,255] based on their position in the range.
    Values outside [a,b] become 0.

    Parameters:
    array: np.ndarray - Input array (not modified)
    a: float - Lower bound of the range
    b: float - Upper bound of the range

    Returns:
    np.ndarray with the same shape and dtype as the input. When b <= a the
    range is empty or degenerate and the result is all zeros.
    """
    array = np.asarray(array)

    # Decide membership from the ORIGINAL values. Building the mask after the
    # out-of-range values have been zeroed would wrongly treat them as
    # in-range whenever a <= 0 <= b.
    in_range = (array >= a) & (array <= b)

    result = np.zeros_like(array)
    if b > a:
        # Linear rescaling of values in range [a,b] to [0,255]
        result[in_range] = ((array[in_range] - a) / (b - a)) * 255

    return result

In [None]:
# Load data (absolute, machine-specific path)
filepath = r"C:\Users\acer\Desktop\Data\47-4881 L 2014.nii.gz"
voxel = file_to_ndarray(filepath)

# Compute histogram
flat_array = voxel.flatten()
bins = np.arange(-500, 1500, 1)  # unit-wide bins, edges -500 ... 1499 (the stop value 1500 is excluded)

# Voxels outside the bin range are not counted by np.histogram
histogram_values, bin_edges = np.histogram(flat_array, bins=bins)
mean, variance = fit_gmm(histogram_values, bin_edges)
print(f"mean : {mean}, variance : {variance}")

# The code below treats index 0 as the lower ("air") component and index 1 as
# the higher ("bone") one.
# NOTE(review): verify that fit_gmm returns the components in that order.
slice_2d = voxel[112, :, :]
rescaled_image_air = rescale_by_probability(slice_2d, int(mean[0]), int(variance[0]))
rescaled_image_bone = rescale_by_probability(slice_2d, int(mean[1]), int(variance[1]))

# Interval between the two means, shrunk by 25% of its width on each side
a = mean[0]
b = mean[1]
adjust_const = 0.25
adjust = int((b-a)*adjust_const)
area_of_uncertainty = rescale_range(slice_2d, a+adjust, b-adjust)

# astype(np.uint8) truncates the fractional part of the uncertainty map
red = rescaled_image_air.astype(np.uint8)
green = area_of_uncertainty.astype(np.uint8)
blue = rescaled_image_bone.astype(np.uint8)

# Combine into RGB
rgb_image = np.stack([red, green, blue], axis=2)

# Display the image
plt.imshow(rgb_image)
plt.axis('off')
plt.show()

# Processing folder

In [None]:
import os
import nibabel as nib
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from glob import glob
from PIL import Image

In [None]:
def fit_gmm(histogram_data, bin_edges):
    """
    Fit a two-component 1-D Gaussian mixture to data summarised by a histogram.

    The raw samples are not available, so a surrogate data set is generated:
    for every bin, as many values as the bin's count are drawn uniformly
    between the bin's edges (using the global ``np.random`` state).

    Parameters
    ----------
    histogram_data : array_like
        Count per bin; fractional counts are truncated and negative counts
        are treated as 0.
    bin_edges : array_like
        Bin edges, at least ``len(histogram_data) + 1`` values.

    Returns
    -------
    output_mean : numpy.ndarray, shape (2,)
        Component means in ascending order.
    output_variances : numpy.ndarray, shape (2,)
        Component variances, in the same order as ``output_mean``.
    """
    counts = np.clip(np.asarray(histogram_data).astype(np.int64), 0, None)
    edges = np.asarray(bin_edges, dtype=float)
    n_bins = len(counts)

    # Repeat each bin's lower/upper edge once per counted value and draw all
    # samples in one vectorised call, instead of growing a Python list of
    # (potentially millions of) individual floats.
    lows = np.repeat(edges[:n_bins], counts)
    highs = np.repeat(edges[1:n_bins + 1], counts)
    data = np.random.uniform(lows, highs).reshape(-1, 1)

    # Fit GMM
    gmm = GaussianMixture(n_components=2, random_state=0)
    gmm.fit(data)

    # Reorder means and variances with one shared permutation, so each
    # variance stays attached to its own mean whatever order the mixture
    # reports its components in.
    means = gmm.means_.flatten()
    variances = gmm.covariances_.flatten()
    order = np.argsort(means)

    output_mean = means[order]
    output_variances = variances[order]

    return output_mean, output_variances

def rescale_uncertainty(array, a, b):
    """
    Rescale values in range [a,b] to [0,255] based on their position in the
    range; values outside [a,b] become 0.

    Parameters:
    array: np.ndarray - Input array (not modified)
    a: float - Lower bound of the range
    b: float - Upper bound of the range

    Returns:
    np.ndarray with the same shape and dtype as the input. When b <= a the
    range is empty or degenerate and the result is all zeros.
    """
    array = np.asarray(array)

    # Decide membership from the ORIGINAL values. Building the mask after the
    # out-of-range values have been zeroed would wrongly treat them as
    # in-range whenever a <= 0 <= b.
    in_range = (array >= a) & (array <= b)

    result = np.zeros_like(array)
    if b > a:
        # Linear rescaling of values in range [a,b] to [0,255]
        result[in_range] = ((array[in_range] - a) / (b - a)) * 255

    return result

def rescale_by_probability(image, target_mean, variance):
    """
    Map each pixel to how likely its value is under a normal distribution,
    min-max normalised over the image and scaled to 0-255.

    Parameters:
    -----------
    image : numpy.ndarray
        Input grayscale image array (any array-like is accepted)
    target_mean : float
        Mean of the normal distribution
    variance : float
        Variance of the normal distribution

    Returns:
    --------
    numpy.ndarray
        uint8 array of the same shape as ``image``. The pixel whose density is
        highest in this image maps to 255 and the one whose density is lowest
        maps to 0. If every pixel has the same density (for example a constant
        image) the result is all zeros.
    """
    # Work on a float view of the input; the input itself is never modified.
    values = np.asarray(image, dtype=float)

    # Normal density of every pixel value
    probabilities = norm.pdf(values, target_mean, np.sqrt(variance))

    lowest = probabilities.min()
    span = probabilities.max() - lowest
    if not span > 0:
        # Flat density map: min-max normalisation would be 0/0 and produce
        # NaNs, so return an all-zero map instead.
        return np.zeros(values.shape, dtype=np.uint8)

    # Normalise to [0, 1], then convert to uint8 (0-255)
    probabilities = (probabilities - lowest) / span
    return (probabilities * 255).astype(np.uint8)

import numpy as np

def pad_image(image, target_size=(224, 224, 3)):
    """
    Zero-pad an image so that it sits centred in a canvas of the target
    height and width.

    Parameters
    ----------
    image : numpy.ndarray
        Image of shape (H, W) or (H, W, C).
    target_size : tuple
        Desired size; only the first two entries (height, width) are used.
        Trailing axes of the image, such as channels, are never padded.

    Returns
    -------
    numpy.ndarray
        Padded copy with the same dtype as ``image``. When the amount of
        padding is odd, the extra row/column goes to the bottom/right.

    Raises
    ------
    ValueError
        If the image is taller or wider than the target size.
    """
    image = np.asarray(image)
    target_h, target_w = target_size[0], target_size[1]

    missing_h = target_h - image.shape[0]
    missing_w = target_w - image.shape[1]
    if missing_h < 0 or missing_w < 0:
        raise ValueError(
            f"Image of shape {image.shape} is larger than target size ({target_h}, {target_w})"
        )

    # before = floor(missing / 2); after takes the remainder
    pad_width = [
        (missing_h // 2, missing_h - missing_h // 2),  # Height padding
        (missing_w // 2, missing_w - missing_w // 2),  # Width padding
    ]
    # No padding on any trailing axis, so 2-D images work as well
    pad_width += [(0, 0)] * (image.ndim - 2)

    return np.pad(image, pad_width, mode='constant', constant_values=0)


def slice_extraction(input_folder, output_base_dir):
    """
    Convert every ``.nii.gz`` volume in a folder into a series of RGB JPEG
    slices.

    For each volume a two-component mixture is fitted (``fit_gmm``) to the
    histogram of its voxel values in [-500, 1499]. Every 10th slice along
    axis 0 (indices 0, 10, ..., below 224), cropped to its first 156 rows and
    columns, is then encoded as:

    - red:   ``rescale_by_probability`` with the first component,
    - green: ``rescale_uncertainty`` over the interval between the two means,
             shrunk by 25% of its width on each side,
    - blue:  ``rescale_by_probability`` with the second component.

    The result is zero-padded with ``pad_image`` and saved as
    ``<volume name>_<slice index, 3 digits>.jpg``. Slices in which more than
    90% of the pixels are below -2000 are skipped.

    Parameters
    ----------
    input_folder : str
        Folder containing the ``.nii.gz`` volumes.
    output_base_dir : str
        Folder receiving the JPEG files; created if it does not exist.
    """
    # Image.save() does not create missing folders.
    os.makedirs(output_base_dir, exist_ok=True)

    nii_files = glob(os.path.join(input_folder, '*.nii.gz'))

    for nii_file in nii_files:
        print(f"Processing : {nii_file}")

        data = nib.load(nii_file).get_fdata()

        # Strip both suffixes of "<name>.nii.gz"
        base_name = os.path.splitext(os.path.splitext(os.path.basename(nii_file))[0])[0]

        # Unit-wide bins with edges -500 ... 1499; np.histogram flattens itself
        bins = np.arange(-500, 1500, 1)
        histogram_values, bin_edges = np.histogram(data, bins=bins)

        mean, variance = fit_gmm(histogram_values, bin_edges)
        print(f"mean : {mean}, variance : {variance}")

        # Loop-invariant: margin removed from each end of [mean[0], mean[1]]
        adjust_const = 0.25
        adjust = int((mean[1]-mean[0])*adjust_const)

        # Every 10th slice, capped both at 224 and at the volume's actual
        # depth so shallower volumes do not raise IndexError.
        for i in range(0, min(224, data.shape[0]), 10):

            slice_2d = data[i, 0:156 , 0:156]

            empty_count = np.sum(slice_2d < -2000)
            empty_ratio = empty_count / (slice_2d.shape[0] * slice_2d.shape[1])

            if empty_ratio > 0.9:
                continue

            print(f"Extracting slice number : {i}")

            rescaled_image_air = rescale_by_probability(slice_2d, int(mean[0]), int(variance[0]))
            rescaled_image_bone = rescale_by_probability(slice_2d, int(mean[1]), int(variance[1]))

            area_of_uncertainty = rescale_uncertainty(slice_2d, mean[0]+adjust, mean[1]-adjust)

            red = rescaled_image_air.astype(np.uint8)
            green = area_of_uncertainty.astype(np.uint8)
            blue = rescaled_image_bone.astype(np.uint8)

            rgb_image = np.stack([red, green, blue], axis=2)

            rgb_image = pad_image(rgb_image)

            out_file = os.path.join(output_base_dir, f"{base_name}_{i:03d}.jpg")

            # Separate name: do not reuse the variable of the loaded volume
            slice_image = Image.fromarray(rgb_image)
            slice_image.save(out_file)

In [None]:
# Extract RGB slices for one class folder of the test split.
# Both paths are absolute and specific to the author's machine.
input_folder = r"D:\Kananat\_dataset\test\erosion_1"
output_base_dir = r"D:\Kananat\_dataset_2d\test\erosion_1"

slice_extraction(input_folder, output_base_dir)

In [None]:
import numpy as np
import nibabel as nib
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

def extract_cbct_layer_with_gmm(nii_file_path, layer_number, background_value=-250):
    """
    Extract a 2D layer from 3D CBCT data with GMM-based tissue classification.
    GMM is fitted on the entire 3D volume for robust tissue classification.

    Parameters:
    -----------
    nii_file_path : str
        Path to the .nii.gz file
    layer_number : int
        Layer index to extract along axis 0 (0-based indexing)
    background_value : float
        Background voxel value to exclude from calculations (default: -250)

    Returns:
    --------
    rgb_image : PIL.Image
        RGB image where:
        - Red channel: Soft tissue probability (0-255)
        - Green channel: Original voxel values rescaled to 0-255 using the
          min/max of the volume's non-background voxels
        - Blue channel: Bone probability (0-255)
        Background pixels are 0 in all three channels.

    Raises:
    -------
    ValueError
        If layer_number is beyond the volume's first axis, or if the volume
        has fewer than 10 non-background voxels.

    Side effects:
    -------------
    Prints progress and summary statistics, shows a 2x3 debugging figure, and
    reseeds the global NumPy random generator (seed 42) when the volume has
    more than 1,000,000 non-background voxels.
    """

    # Load the NIfTI file
    print(f"Loading NIfTI file: {nii_file_path}")
    nii_img = nib.load(nii_file_path)
    volume_data = nii_img.get_fdata()

    print(f"Volume shape: {volume_data.shape}")
    print(f"Volume value range: [{volume_data.min():.1f}, {volume_data.max():.1f}]")

    # Validate layer number against axis 0, which is the axis the layer is
    # sliced from further down
    if layer_number >= volume_data.shape[0]:
        raise ValueError(f"Layer {layer_number} exceeds volume depth {volume_data.shape[0]}")

    # Remove background values from ENTIRE 3D volume for GMM fitting
    print("Preparing 3D volume data for GMM fitting...")
    volume_non_background_mask = volume_data != background_value
    volume_non_background_values = volume_data[volume_non_background_mask]

    print(f"Total voxels in volume: {volume_data.size}")
    print(f"Non-background voxels in volume: {len(volume_non_background_values)}")
    print(f"Background voxels in volume: {volume_data.size - len(volume_non_background_values)}")

    # Check before calling .min()/.max(): those raise a cryptic error on an
    # empty array
    if len(volume_non_background_values) < 10:
        raise ValueError("Insufficient non-background voxels in entire volume for GMM fitting")

    print(f"Volume non-background range: [{volume_non_background_values.min():.1f}, {volume_non_background_values.max():.1f}]")

    # For very large volumes, use a random sample for GMM fitting to speed up computation
    max_samples_for_gmm = 1000000  # 1M samples should be sufficient
    if len(volume_non_background_values) > max_samples_for_gmm:
        print(f"Volume has {len(volume_non_background_values)} non-background voxels.")
        print(f"Using random sample of {max_samples_for_gmm} voxels for GMM fitting...")
        np.random.seed(42)  # For reproducible results
        sample_indices = np.random.choice(len(volume_non_background_values),
                                        size=max_samples_for_gmm, replace=False)
        gmm_fitting_data = volume_non_background_values[sample_indices]
    else:
        gmm_fitting_data = volume_non_background_values

    print(f"Using {len(gmm_fitting_data)} voxels for GMM fitting")

    # Extract the specified layer for final processing
    layer_2d = volume_data[layer_number, :, :]
    layer_non_background_mask = layer_2d != background_value
    print(f"Extracted layer {layer_number} with shape: {layer_2d.shape}")
    print(f"Non-background voxels in layer: {np.sum(layer_non_background_mask)}")

    # Fit Gaussian Mixture Model with 2 components on 3D volume data
    print("Fitting Gaussian Mixture Model on entire 3D volume...")
    gmm = GaussianMixture(n_components=2, random_state=42)
    gmm.fit(gmm_fitting_data.reshape(-1, 1))

    # Get GMM parameters
    means = gmm.means_.flatten()
    stds = np.sqrt(gmm.covariances_).flatten()
    weights = gmm.weights_

    print(f"GMM Component 1: mean={means[0]:.1f}, std={stds[0]:.1f}, weight={weights[0]:.3f}")
    print(f"GMM Component 2: mean={means[1]:.1f}, std={stds[1]:.1f}, weight={weights[1]:.3f}")

    # Determine which component is bone (higher mean) and which is soft tissue
    if means[0] > means[1]:
        bone_idx, soft_tissue_idx = 0, 1
    else:
        bone_idx, soft_tissue_idx = 1, 0

    print(f"Bone component (higher intensity): Component {bone_idx}")
    print(f"Soft tissue component (lower intensity): Component {soft_tissue_idx}")

    # Apply the trained GMM to the selected layer
    print(f"Applying GMM to layer {layer_number}...")
    # Calculate probabilities for all pixels in the layer
    # For background pixels, set probabilities to 0
    bone_prob = np.zeros_like(layer_2d)
    soft_tissue_prob = np.zeros_like(layer_2d)

    # Only calculate probabilities when the layer has non-background pixels
    if np.sum(layer_non_background_mask) > 0:
        # Get probability predictions for the entire layer
        all_probs = gmm.predict_proba(layer_2d.reshape(-1, 1))
        all_probs_reshaped = all_probs.reshape(layer_2d.shape[0], layer_2d.shape[1], 2)

        # Extract bone and soft tissue probabilities
        bone_prob = all_probs_reshaped[:, :, bone_idx]
        soft_tissue_prob = all_probs_reshaped[:, :, soft_tissue_idx]

        # Set background pixels to 0 probability
        bone_prob[~layer_non_background_mask] = 0
        soft_tissue_prob[~layer_non_background_mask] = 0

    # Rescale original voxel values to [0, 255] using the VOLUME range for consistency
    original_scaled = np.zeros_like(layer_2d)
    if len(volume_non_background_values) > 0:
        vol_min_val, vol_max_val = volume_non_background_values.min(), volume_non_background_values.max()
        value_span = vol_max_val - vol_min_val
        # A flat volume (max == min) would divide by zero; the channel then
        # simply stays at 0
        if value_span > 0:
            original_scaled[layer_non_background_mask] = 255 * (layer_2d[layer_non_background_mask] - vol_min_val) / value_span

    # Rescale probabilities to [0, 255]
    bone_prob_scaled = (bone_prob * 255).astype(np.uint8)
    soft_tissue_prob_scaled = (soft_tissue_prob * 255).astype(np.uint8)
    original_scaled = original_scaled.astype(np.uint8)

    # Create RGB image
    # Red: Soft tissue probability
    # Green: Original voxel values
    # Blue: Bone probability
    rgb_array = np.stack([soft_tissue_prob_scaled, original_scaled, bone_prob_scaled], axis=2)
    rgb_image = Image.fromarray(rgb_array, 'RGB')

    # Create debugging visualizations
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Row 1: Original data and histograms
    # Original layer
    im1 = axes[0, 0].imshow(layer_2d, cmap='gray')
    axes[0, 0].set_title(f'Original Layer {layer_number}')
    axes[0, 0].axis('off')
    plt.colorbar(im1, ax=axes[0, 0])

    # Histogram of VOLUME non-background values (used for GMM)
    axes[0, 1].hist(gmm_fitting_data, bins=50, alpha=0.7, density=True, color='gray')
    axes[0, 1].set_title('Histogram of Volume Data (GMM Training)')
    axes[0, 1].set_xlabel('Voxel Value')
    axes[0, 1].set_ylabel('Density')

    # GMM overlay on histogram
    x_range = np.linspace(gmm_fitting_data.min(), gmm_fitting_data.max(), 1000)

    # Plot individual components (weighted normal density of each component)
    for i in range(2):
        component_pdf = weights[i] * (1/np.sqrt(2*np.pi*gmm.covariances_[i,0,0])) * \
                       np.exp(-0.5 * ((x_range - means[i])**2) / gmm.covariances_[i,0,0])
        label = f'{"Bone" if i == bone_idx else "Soft Tissue"}'
        color = 'blue' if i == bone_idx else 'red'
        axes[0, 1].plot(x_range, component_pdf, color=color, label=label, linewidth=2)

    # Plot total GMM (score_samples returns log-density, hence the exp)
    total_pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
    axes[0, 1].plot(x_range, total_pdf, 'k--', label='Total GMM', linewidth=2)
    axes[0, 1].legend()

    # GMM classification result for the layer
    classification = gmm.predict(layer_2d.reshape(-1, 1)).reshape(layer_2d.shape)
    classification_display = np.full_like(layer_2d, -1, dtype=int)  # -1 for background
    classification_display[layer_non_background_mask] = classification[layer_non_background_mask]

    im3 = axes[0, 2].imshow(classification_display, cmap='RdBu')
    axes[0, 2].set_title(f'GMM Classification of Layer {layer_number}\n(Red=Soft Tissue, Blue=Bone)')
    axes[0, 2].axis('off')

    # Row 2: Probability maps and final RGB
    # Bone probability
    im4 = axes[1, 0].imshow(bone_prob, cmap='Blues', vmin=0, vmax=1)
    axes[1, 0].set_title('Bone Probability')
    axes[1, 0].axis('off')
    plt.colorbar(im4, ax=axes[1, 0])

    # Soft tissue probability
    im5 = axes[1, 1].imshow(soft_tissue_prob, cmap='Reds', vmin=0, vmax=1)
    axes[1, 1].set_title('Soft Tissue Probability')
    axes[1, 1].axis('off')
    plt.colorbar(im5, ax=axes[1, 1])

    # Final RGB result
    axes[1, 2].imshow(rgb_array)
    axes[1, 2].set_title('Final RGB Image\n(R=Soft Tissue, G=Original, B=Bone)')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\n=== Summary Statistics ===")
    print(f"Volume shape: {volume_data.shape}")
    print(f"Volume non-background voxels: {len(volume_non_background_values)} / {volume_data.size}")
    print(f"Layer shape: {layer_2d.shape}")
    print(f"Layer non-background pixels: {np.sum(layer_non_background_mask)} / {layer_2d.size}")
    print(f"Layer background pixels: {np.sum(~layer_non_background_mask)} / {layer_2d.size}")

    if np.sum(layer_non_background_mask) > 0:
        print(f"Mean bone probability (layer non-background): {bone_prob[layer_non_background_mask].mean():.3f}")
        print(f"Mean soft tissue probability (layer non-background): {soft_tissue_prob[layer_non_background_mask].mean():.3f}")
        print(f"Layer voxel range: [{layer_2d[layer_non_background_mask].min():.1f}, {layer_2d[layer_non_background_mask].max():.1f}]")

    return rgb_image

# Example usage:
# rgb_img = extract_cbct_layer_with_gmm('path/to/your/cbct_data.nii.gz', layer_number=50)
# rgb_img.save('cbct_layer_rgb.png')

In [None]:
# Example call (replace with your actual file path and layer number).
# Processes layer 112 of one test volume and writes the RGB result to the
# current working directory.
image_path = r"C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D\test\0\50-30909 R_adjustedBG.nii.gz"  # Replace with your NIfTI file path
rgb_img = extract_cbct_layer_with_gmm(image_path, layer_number=112)
rgb_img.save('cbct_layer_rgb.png')

In [None]:
import numpy as np
import nibabel as nib
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from PIL import Image
import os
import warnings
warnings.filterwarnings('ignore')

def extract_cbct_layer_with_gmm(nii_file_path, output_image_path, layer_number, debug_folder_path, background_value=-250):
    """
    Extract a 2D layer from 3D CBCT data with GMM-based tissue classification.
    GMM is fitted on the entire 3D volume for robust tissue classification.

    A two-component Gaussian mixture is fitted to the non-background voxel
    values of the whole volume (a fixed-seed random sample of at most
    1,000,000 voxels). The component with the higher mean is treated as bone,
    the other one as soft tissue. The layer ``volume[:, :, layer_number]`` is
    then encoded as an RGB image and written to disk together with debug images.

    Parameters:
    -----------
    nii_file_path : str
        Path to the .nii.gz file
    output_image_path : str
        Path to save the output RGB image (e.g., 'output.png')
    layer_number : int
        Layer index to extract along the last axis (0-based indexing)
    debug_folder_path : str
        Path to folder where debug visualizations will be saved (created if missing)
    background_value : float
        Background voxel value to exclude from calculations (default: -250)

    Returns:
    --------
    rgb_image : PIL.Image
        RGB image where:
        - Red channel: Soft tissue probability (0-255)
        - Green channel: Original voxel values (0-255), scaled with the
          non-background range of the whole volume; all zeros when that range is 0
        - Blue channel: Bone probability (0-255)
        Background pixels are 0 in every channel.

    Raises:
    -------
    ValueError
        If layer_number is not below the size of the volume's last axis, or if
        the volume contains fewer than 10 non-background voxels.

    Files Created:
    --------------
    - output_image_path: RGB image file
    - debug_folder_path/gmm_analysis.png: Complete debugging visualization
    - debug_folder_path/layer_original.png: Original layer grayscale
    - debug_folder_path/bone_probability.png: Bone probability map
    - debug_folder_path/soft_tissue_probability.png: Soft tissue probability map
    """

    # Create debug folder if it doesn't exist
    os.makedirs(debug_folder_path, exist_ok=True)

    # Load the NIfTI file
    print(f"Loading NIfTI file: {nii_file_path}")
    nii_img = nib.load(nii_file_path)
    volume_data = nii_img.get_fdata()

    print(f"Volume shape: {volume_data.shape}")
    print(f"Volume value range: [{volume_data.min():.1f}, {volume_data.max():.1f}]")

    # Validate layer number
    if layer_number >= volume_data.shape[2]:
        raise ValueError(f"Layer {layer_number} exceeds volume depth {volume_data.shape[2]}")

    # Remove background values from ENTIRE 3D volume for GMM fitting
    print("Preparing 3D volume data for GMM fitting...")
    volume_non_background_values = volume_data[volume_data != background_value]

    print(f"Total voxels in volume: {volume_data.size}")
    print(f"Non-background voxels in volume: {len(volume_non_background_values)}")
    print(f"Background voxels in volume: {volume_data.size - len(volume_non_background_values)}")

    # Check the count BEFORE calling min()/max(): on an all-background volume those
    # reductions would fail with an unrelated "zero-size array" message.
    if len(volume_non_background_values) < 10:
        raise ValueError("Insufficient non-background voxels in entire volume for GMM fitting")

    print(f"Volume non-background range: [{volume_non_background_values.min():.1f}, {volume_non_background_values.max():.1f}]")

    # For very large volumes, use a random sample for GMM fitting to speed up computation
    max_samples_for_gmm = 1000000  # 1M samples should be sufficient
    if len(volume_non_background_values) > max_samples_for_gmm:
        print(f"Volume has {len(volume_non_background_values)} non-background voxels.")
        print(f"Using random sample of {max_samples_for_gmm} voxels for GMM fitting...")
        # A private, fixed-seed generator draws the same sample as seeding the global
        # NumPy RNG with 42 would, without resetting that global state for the caller.
        rng = np.random.RandomState(42)
        sample_indices = rng.choice(len(volume_non_background_values),
                                    size=max_samples_for_gmm, replace=False)
        gmm_fitting_data = volume_non_background_values[sample_indices]
    else:
        gmm_fitting_data = volume_non_background_values

    print(f"Using {len(gmm_fitting_data)} voxels for GMM fitting")

    # Extract the specified layer for final processing
    layer_2d = volume_data[:, :, layer_number]
    layer_non_background_mask = layer_2d != background_value
    print(f"Extracted layer {layer_number} with shape: {layer_2d.shape}")
    print(f"Non-background voxels in layer: {np.sum(layer_non_background_mask)}")

    # Fit Gaussian Mixture Model with 2 components on 3D volume data
    print("Fitting Gaussian Mixture Model on entire 3D volume...")
    gmm = GaussianMixture(n_components=2, random_state=42)
    gmm.fit(gmm_fitting_data.reshape(-1, 1))

    # Get GMM parameters
    means = gmm.means_.flatten()
    stds = np.sqrt(gmm.covariances_).flatten()
    weights = gmm.weights_

    print(f"GMM Component 1: mean={means[0]:.1f}, std={stds[0]:.1f}, weight={weights[0]:.3f}")
    print(f"GMM Component 2: mean={means[1]:.1f}, std={stds[1]:.1f}, weight={weights[1]:.3f}")

    # Determine which component is bone (higher mean) and which is soft tissue
    if means[0] > means[1]:
        bone_idx, soft_tissue_idx = 0, 1
    else:
        bone_idx, soft_tissue_idx = 1, 0

    print(f"Bone component (higher intensity): Component {bone_idx}")
    print(f"Soft tissue component (lower intensity): Component {soft_tissue_idx}")

    # Apply the trained GMM to the selected layer
    print(f"Applying GMM to layer {layer_number}...")
    # For background pixels (and for an all-background layer) probabilities stay 0
    bone_prob = np.zeros_like(layer_2d)
    soft_tissue_prob = np.zeros_like(layer_2d)

    # Only calculate probabilities when the layer has non-background pixels
    if np.sum(layer_non_background_mask) > 0:
        # Get probability predictions for the entire layer
        all_probs = gmm.predict_proba(layer_2d.reshape(-1, 1))
        all_probs_reshaped = all_probs.reshape(layer_2d.shape[0], layer_2d.shape[1], 2)

        # Extract bone and soft tissue probabilities
        bone_prob = all_probs_reshaped[:, :, bone_idx]
        soft_tissue_prob = all_probs_reshaped[:, :, soft_tissue_idx]

        # Set background pixels to 0 probability
        bone_prob[~layer_non_background_mask] = 0
        soft_tissue_prob[~layer_non_background_mask] = 0

    # Rescale original voxel values to [0, 255] using the VOLUME range for consistency
    original_scaled = np.zeros_like(layer_2d)
    vol_min_val = volume_non_background_values.min()
    vol_max_val = volume_non_background_values.max()
    vol_range = vol_max_val - vol_min_val
    # A constant-valued volume has zero range; keep the channel at 0 instead of
    # dividing by zero and casting NaN to uint8.
    if vol_range > 0:
        original_scaled[layer_non_background_mask] = 255 * (layer_2d[layer_non_background_mask] - vol_min_val) / vol_range

    # Rescale probabilities to [0, 255]
    bone_prob_scaled = (bone_prob * 255).astype(np.uint8)
    soft_tissue_prob_scaled = (soft_tissue_prob * 255).astype(np.uint8)
    original_scaled = original_scaled.astype(np.uint8)

    # Create RGB image
    # Red: Soft tissue probability
    # Green: Original voxel values
    # Blue: Bone probability
    rgb_array = np.stack([soft_tissue_prob_scaled, original_scaled, bone_prob_scaled], axis=2)
    rgb_image = Image.fromarray(rgb_array, 'RGB')

    # Save the main RGB output image
    print(f"Saving RGB image to: {output_image_path}")
    rgb_image.save(output_image_path)

    # Save individual probability maps and original layer
    print(f"Saving debug images to: {debug_folder_path}")

    # Save original layer as grayscale (stretched over the layer's own range;
    # a constant layer is written as all zeros rather than NaN)
    layer_min_val, layer_max_val = layer_2d.min(), layer_2d.max()
    layer_range = layer_max_val - layer_min_val
    if layer_range > 0:
        layer_gray = (layer_2d - layer_min_val) / layer_range * 255
    else:
        layer_gray = np.zeros_like(layer_2d)
    original_layer_img = Image.fromarray(layer_gray).convert('L')
    original_layer_img.save(os.path.join(debug_folder_path, 'layer_original.png'))

    # Save bone probability map
    bone_prob_img = Image.fromarray(bone_prob_scaled, 'L')
    bone_prob_img.save(os.path.join(debug_folder_path, 'bone_probability.png'))

    # Save soft tissue probability map
    soft_tissue_prob_img = Image.fromarray(soft_tissue_prob_scaled, 'L')
    soft_tissue_prob_img.save(os.path.join(debug_folder_path, 'soft_tissue_probability.png'))

    # Create and save comprehensive debugging visualization
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Row 1: Original data and histograms
    # Original layer
    im1 = axes[0, 0].imshow(layer_2d, cmap='gray')
    axes[0, 0].set_title(f'Original Layer {layer_number}')
    axes[0, 0].axis('off')
    fig.colorbar(im1, ax=axes[0, 0])

    # Histogram of VOLUME non-background values (used for GMM)
    axes[0, 1].hist(gmm_fitting_data, bins=50, alpha=0.7, density=True, color='gray')
    axes[0, 1].set_title('Histogram of Volume Data (GMM Training)')
    axes[0, 1].set_xlabel('Voxel Value')
    axes[0, 1].set_ylabel('Density')

    # GMM overlay on histogram
    x_range = np.linspace(gmm_fitting_data.min(), gmm_fitting_data.max(), 1000)

    # Plot individual components (weighted 1-D Gaussian densities)
    for i in range(2):
        variance = gmm.covariances_[i, 0, 0]
        component_pdf = weights[i] * (1 / np.sqrt(2 * np.pi * variance)) * \
                       np.exp(-0.5 * ((x_range - means[i]) ** 2) / variance)
        label = f'{"Bone" if i == bone_idx else "Soft Tissue"}'
        color = 'blue' if i == bone_idx else 'red'
        axes[0, 1].plot(x_range, component_pdf, color=color, label=label, linewidth=2)

    # Plot total GMM
    total_pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
    axes[0, 1].plot(x_range, total_pdf, 'k--', label='Total GMM', linewidth=2)
    axes[0, 1].legend()

    # GMM classification result for the layer.
    # The raw component index is arbitrary (bone may be component 0 or 1), so map it
    # to fixed labels first: 0 = soft tissue (red end of RdBu), 1 = bone (blue end).
    # Background is masked out so it is not coloured as a tissue class, and the fixed
    # vmin/vmax keep the colours stable when only one class occurs in the layer.
    classification = gmm.predict(layer_2d.reshape(-1, 1)).reshape(layer_2d.shape)
    tissue_labels = (classification == bone_idx).astype(float)
    classification_display = np.ma.masked_where(~layer_non_background_mask, tissue_labels)

    axes[0, 2].imshow(classification_display, cmap='RdBu', vmin=0, vmax=1)
    axes[0, 2].set_title(f'GMM Classification of Layer {layer_number}\n(Red=Soft Tissue, Blue=Bone)')
    axes[0, 2].axis('off')

    # Row 2: Probability maps and final RGB
    # Bone probability
    im4 = axes[1, 0].imshow(bone_prob, cmap='Blues', vmin=0, vmax=1)
    axes[1, 0].set_title('Bone Probability')
    axes[1, 0].axis('off')
    fig.colorbar(im4, ax=axes[1, 0])

    # Soft tissue probability
    im5 = axes[1, 1].imshow(soft_tissue_prob, cmap='Reds', vmin=0, vmax=1)
    axes[1, 1].set_title('Soft Tissue Probability')
    axes[1, 1].axis('off')
    fig.colorbar(im5, ax=axes[1, 1])

    # Final RGB result
    axes[1, 2].imshow(rgb_array)
    axes[1, 2].set_title('Final RGB Image\n(R=Soft Tissue, G=Original, B=Bone)')
    axes[1, 2].axis('off')

    fig.tight_layout()

    # Save the comprehensive debug visualization
    debug_plot_path = os.path.join(debug_folder_path, 'gmm_analysis.png')
    fig.savefig(debug_plot_path, dpi=300, bbox_inches='tight')
    plt.close(fig)  # Close exactly this figure to free memory

    # Print summary statistics
    print("\n=== Summary Statistics ===")
    print(f"Volume shape: {volume_data.shape}")
    print(f"Volume non-background voxels: {len(volume_non_background_values)} / {volume_data.size}")
    print(f"Layer shape: {layer_2d.shape}")
    print(f"Layer non-background pixels: {np.sum(layer_non_background_mask)} / {layer_2d.size}")
    print(f"Layer background pixels: {np.sum(~layer_non_background_mask)} / {layer_2d.size}")

    if np.sum(layer_non_background_mask) > 0:
        print(f"Mean bone probability (layer non-background): {bone_prob[layer_non_background_mask].mean():.3f}")
        print(f"Mean soft tissue probability (layer non-background): {soft_tissue_prob[layer_non_background_mask].mean():.3f}")
        print(f"Layer voxel range: [{layer_2d[layer_non_background_mask].min():.1f}, {layer_2d[layer_non_background_mask].max():.1f}]")

    print("\n=== Files Created ===")
    print(f"Main RGB output: {output_image_path}")
    print(f"Debug folder: {debug_folder_path}")
    print("  - gmm_analysis.png: Complete analysis visualization")
    print("  - layer_original.png: Original layer grayscale")
    print("  - bone_probability.png: Bone probability map")
    print("  - soft_tissue_probability.png: Soft tissue probability map")

    return rgb_image

# Example usage:
# rgb_img = extract_cbct_layer_with_gmm(
#     nii_file_path='path/to/cbct_data.nii.gz',
#     output_image_path='output/layer_50_rgb.png',
#     layer_number=50,
#     debug_folder_path='debug_output/'
# )

In [None]:
# Inputs for a single-layer run of extract_cbct_layer_with_gmm.
# NOTE(review): absolute, machine-specific Windows paths; adjust before running elsewhere.
nii_file = r"C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D\test\0\50-30909 R_adjustedBG.nii.gz"
output_path = r"test.png"  # relative path: written to the current working directory
debug = r"C:\Users\acer\Desktop\Project_TMJOA\Data\debug"  # folder that receives the debug images

# With the definition in the cell directly above, layer_number indexes the last axis
# of the volume; the returned PIL image is also saved to output_path by the function.
rgb_img = extract_cbct_layer_with_gmm(
    nii_file_path=nii_file,
    output_image_path=output_path,
    layer_number=50,
    debug_folder_path=debug
)

In [None]:
# 3D NIfTI to 2D PNG Layer Extractor - Top N Informative Layers
# Extracts the N most informative layers with minimum spacing between them

import os
import numpy as np
import nibabel as nib
from PIL import Image
from pathlib import Path
import logging

# Additional imports for GMM processing
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Set up logging for Jupyter
# INFO level with a "LEVEL - message" format; `logger` is the logger named after this module.
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# =============================================================================
# CONFIGURATION - Modify these parameters
# =============================================================================

# Dataset paths
# NOTE(review): absolute, machine-specific paths. process_dataset (defined below) scans
# its input directory for train/val/test sub-folders containing class folders "0" and "1"
# and mirrors that layout under its output directory.
INPUT_DATASET_DIR = r"C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D"  # Update this path
OUTPUT_DATASET_DIR = r"C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_2D_v3"  # Update this path

# Processing parameters
N_LAYERS = 10          # Number of most informative layers to extract
MIN_SPACING = 5       # Minimum spacing between selected layers (k parameter)
BACKGROUND_VALUE = -250   # Background value to exclude

# GMM processing parameters
# Both values are read as module-level globals inside process_nifti_file.
USE_GMM_PROCESSING = True  # Set to False to use simple normalization
DEBUG_FOLDER_PATH = r"C:\Users\acer\Desktop\Project_TMJOA\Data\debug"   # Set to a path like 'debug_output/' for GMM debugging, or None to disable

# Echo the active configuration so the cell output records the settings in use.
print("Configuration:")
print(f"Input dataset: {INPUT_DATASET_DIR}")
print(f"Output dataset: {OUTPUT_DATASET_DIR}")
print(f"Number of layers to extract: {N_LAYERS}")
print(f"Minimum spacing between layers: {MIN_SPACING}")
print(f"Background value: {BACKGROUND_VALUE}")
print(f"Use GMM processing: {USE_GMM_PROCESSING}")
print(f"Debug folder: {DEBUG_FOLDER_PATH}")

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

# Your comprehensive GMM-based layer extraction function
def extract_cbct_layer_with_gmm(nii_file_path, output_image_path, layer_number, debug_folder_path, background_value=-250):
    """
    Extract a 2D layer from 3D CBCT data with GMM-based tissue classification.
    GMM is fitted on the entire 3D volume for robust tissue classification.

    A two-component Gaussian mixture is fitted to the non-background voxel
    values of the whole volume (a fixed-seed random sample of at most
    1,000,000 voxels). The component with the higher mean is treated as bone,
    the other one as soft tissue. The layer ``volume[layer_number, :, :]`` is
    encoded as an RGB image and saved to ``output_image_path``.

    Parameters:
    -----------
    nii_file_path : str
        Path to the .nii.gz file
    output_image_path : str
        Path to save the output RGB image (e.g., 'output.png')
    layer_number : int
        Layer index to extract along axis 0 (0-based indexing)
    debug_folder_path : str or None
        Folder for the debug visualization. A falsy value (None or '') disables
        all debug output. Otherwise one ``<id>_gmm_analysis.png`` is written,
        where ``<id>`` is the file name up to its first underscore; the plot is
        skipped when a file of that name already exists.
    background_value : float
        Background voxel value to exclude from calculations (default: -250)

    Returns:
    --------
    rgb_image : PIL.Image
        RGB image where:
        - Red channel: Soft tissue probability (0-255)
        - Green channel: Original voxel values (0-255), scaled with the
          non-background range of the whole volume; all zeros when that range is 0
        - Blue channel: Bone probability (0-255)
        Background pixels are 0 in every channel.

    Raises:
    -------
    ValueError
        If layer_number is not below the size of axis 0, or if the volume
        contains fewer than 10 non-background voxels.
    """

    # Create debug folder if it doesn't exist
    if debug_folder_path:
        os.makedirs(debug_folder_path, exist_ok=True)

    # Load the NIfTI file
    nii_img = nib.load(nii_file_path)
    volume_data = nii_img.get_fdata()

    # Validate layer number (layers are taken along axis 0)
    if layer_number >= volume_data.shape[0]:
        raise ValueError(f"Layer {layer_number} exceeds volume depth {volume_data.shape[0]}")

    # Remove background values from ENTIRE 3D volume for GMM fitting
    volume_non_background_values = volume_data[volume_data != background_value]

    if len(volume_non_background_values) < 10:
        raise ValueError("Insufficient non-background voxels in entire volume for GMM fitting")

    # For very large volumes, use a random sample for GMM fitting to speed up computation
    max_samples_for_gmm = 1000000  # 1M samples should be sufficient
    if len(volume_non_background_values) > max_samples_for_gmm:
        # A private, fixed-seed generator draws the same sample as seeding the global
        # NumPy RNG with 42 would, without resetting that global state on every call.
        rng = np.random.RandomState(42)
        sample_indices = rng.choice(len(volume_non_background_values),
                                    size=max_samples_for_gmm, replace=False)
        gmm_fitting_data = volume_non_background_values[sample_indices]
    else:
        gmm_fitting_data = volume_non_background_values

    # Extract the specified layer for final processing (along axis 0)
    layer_2d = volume_data[layer_number, :, :]
    layer_non_background_mask = layer_2d != background_value

    # Fit Gaussian Mixture Model with 2 components on 3D volume data
    gmm = GaussianMixture(n_components=2, random_state=42)
    gmm.fit(gmm_fitting_data.reshape(-1, 1))

    # Get GMM parameters
    means = gmm.means_.flatten()
    weights = gmm.weights_

    # Determine which component is bone (higher mean) and which is soft tissue
    if means[0] > means[1]:
        bone_idx, soft_tissue_idx = 0, 1
    else:
        bone_idx, soft_tissue_idx = 1, 0

    # For background pixels (and for an all-background layer) probabilities stay 0
    bone_prob = np.zeros_like(layer_2d)
    soft_tissue_prob = np.zeros_like(layer_2d)

    # Only calculate probabilities when the layer has non-background pixels
    if np.sum(layer_non_background_mask) > 0:
        # Get probability predictions for the entire layer
        all_probs = gmm.predict_proba(layer_2d.reshape(-1, 1))
        all_probs_reshaped = all_probs.reshape(layer_2d.shape[0], layer_2d.shape[1], 2)

        # Extract bone and soft tissue probabilities
        bone_prob = all_probs_reshaped[:, :, bone_idx]
        soft_tissue_prob = all_probs_reshaped[:, :, soft_tissue_idx]

        # Set background pixels to 0 probability
        bone_prob[~layer_non_background_mask] = 0
        soft_tissue_prob[~layer_non_background_mask] = 0

    # Rescale original voxel values to [0, 255] using the VOLUME range for consistency
    original_scaled = np.zeros_like(layer_2d)
    vol_min_val = volume_non_background_values.min()
    vol_max_val = volume_non_background_values.max()
    vol_range = vol_max_val - vol_min_val
    # A constant-valued volume has zero range; keep the channel at 0 instead of
    # dividing by zero and casting NaN to uint8.
    if vol_range > 0:
        original_scaled[layer_non_background_mask] = 255 * (layer_2d[layer_non_background_mask] - vol_min_val) / vol_range

    # Rescale probabilities to [0, 255]
    bone_prob_scaled = (bone_prob * 255).astype(np.uint8)
    soft_tissue_prob_scaled = (soft_tissue_prob * 255).astype(np.uint8)
    original_scaled = original_scaled.astype(np.uint8)

    # Create RGB image
    # Red: Soft tissue probability
    # Green: Original voxel values
    # Blue: Bone probability
    rgb_array = np.stack([soft_tissue_prob_scaled, original_scaled, bone_prob_scaled], axis=2)
    rgb_image = Image.fromarray(rgb_array, 'RGB')

    # Save the main RGB output image
    rgb_image.save(output_image_path)

    # Save debug visualizations only if a debug folder is provided. The path is built
    # inside this guard because os.path.join() raises TypeError when given None.
    if debug_folder_path:
        # The file name up to its first underscore identifies the scan in the debug file name
        patient_id = os.path.basename(nii_file_path).split('_')[0]
        debug_plot_path = os.path.join(debug_folder_path, f'{patient_id}_gmm_analysis.png')

        # Only one plot per scan: skip when it was already written by an earlier call
        if not os.path.exists(debug_plot_path):
            fig = None
            try:
                fig, axes = plt.subplots(2, 3, figsize=(15, 10))

                # Row 1: Original data and histograms
                # Original layer
                im1 = axes[0, 0].imshow(layer_2d, cmap='gray')
                axes[0, 0].set_title(f'Original Layer {layer_number}')
                axes[0, 0].axis('off')
                fig.colorbar(im1, ax=axes[0, 0])

                # Histogram of VOLUME non-background values (used for GMM)
                axes[0, 1].hist(gmm_fitting_data, bins=50, alpha=0.7, density=True, color='gray')
                axes[0, 1].set_title('Histogram of Volume Data (GMM Training)')
                axes[0, 1].set_xlabel('Voxel Value')
                axes[0, 1].set_ylabel('Density')

                # GMM overlay on histogram
                x_range = np.linspace(gmm_fitting_data.min(), gmm_fitting_data.max(), 1000)

                # Plot individual components (weighted 1-D Gaussian densities)
                for i in range(2):
                    variance = gmm.covariances_[i, 0, 0]
                    component_pdf = weights[i] * (1 / np.sqrt(2 * np.pi * variance)) * \
                                   np.exp(-0.5 * ((x_range - means[i]) ** 2) / variance)
                    label = f'{"Bone" if i == bone_idx else "Soft Tissue"}'
                    color = 'blue' if i == bone_idx else 'red'
                    axes[0, 1].plot(x_range, component_pdf, color=color, label=label, linewidth=2)

                # Plot total GMM
                total_pdf = np.exp(gmm.score_samples(x_range.reshape(-1, 1)))
                axes[0, 1].plot(x_range, total_pdf, 'k--', label='Total GMM', linewidth=2)
                axes[0, 1].legend()

                # GMM classification result for the layer.
                # The raw component index is arbitrary (bone may be component 0 or 1), so
                # map it to fixed labels first: 0 = soft tissue (red end of RdBu),
                # 1 = bone (blue end). Background is masked out so it is not coloured as
                # a tissue class; fixed vmin/vmax keep colours stable for one-class layers.
                classification = gmm.predict(layer_2d.reshape(-1, 1)).reshape(layer_2d.shape)
                tissue_labels = (classification == bone_idx).astype(float)
                classification_display = np.ma.masked_where(~layer_non_background_mask, tissue_labels)

                axes[0, 2].imshow(classification_display, cmap='RdBu', vmin=0, vmax=1)
                axes[0, 2].set_title(f'GMM Classification of Layer {layer_number}\n(Red=Soft Tissue, Blue=Bone)')
                axes[0, 2].axis('off')

                # Row 2: Probability maps and final RGB
                # Bone probability
                im4 = axes[1, 0].imshow(bone_prob, cmap='Blues', vmin=0, vmax=1)
                axes[1, 0].set_title('Bone Probability')
                axes[1, 0].axis('off')
                fig.colorbar(im4, ax=axes[1, 0])

                # Soft tissue probability
                im5 = axes[1, 1].imshow(soft_tissue_prob, cmap='Reds', vmin=0, vmax=1)
                axes[1, 1].set_title('Soft Tissue Probability')
                axes[1, 1].axis('off')
                fig.colorbar(im5, ax=axes[1, 1])

                # Final RGB result
                axes[1, 2].imshow(rgb_array)
                axes[1, 2].set_title('Final RGB Image\n(R=Soft Tissue, G=Original, B=Bone)')
                axes[1, 2].axis('off')

                fig.tight_layout()

                # Save the comprehensive debug visualization
                fig.savefig(debug_plot_path, dpi=300, bbox_inches='tight')
            except Exception as e:
                # Debug output is best-effort: the RGB image is already saved
                print(f"    Warning: Could not create debug plot: {str(e)}")
            finally:
                # Release the figure even when plotting or saving failed
                if fig is not None:
                    plt.close(fig)

    return rgb_image

def normalize_voxel_values(data, background_value=-250):
    """
    Linearly rescale voxel values to [0, 255] and return them as uint8.

    Values below ``background_value`` are first clipped up to it. The clipped
    data is then mapped so that its own minimum becomes 0 and its own maximum
    becomes 255 (the minimum equals ``background_value`` only when the data
    actually contains values at or below it). Fractions are truncated by the
    uint8 cast.

    Args:
        data: 2D or 3D numpy array (any array-like is accepted)
        background_value: Lower clipping bound (default: -250)

    Returns:
        uint8 array with the same shape as ``data``. All zeros when every
        clipped value is identical; an empty uint8 array for empty input.
    """
    data = np.asarray(data)

    # min()/max() are undefined on a zero-size array; return the empty result directly
    if data.size == 0:
        return np.zeros(data.shape, dtype=np.uint8)

    # Clip values to ensure minimum is background_value
    clipped = np.clip(data, background_value, None)

    # Get min and max values
    min_val = clipped.min()
    max_val = clipped.max()

    if max_val == min_val:
        # Handle case where all values are the same (avoids division by zero)
        return np.zeros_like(clipped, dtype=np.uint8)

    # Normalize to [0, 255]
    normalized = ((clipped - min_val) / (max_val - min_val)) * 255
    return normalized.astype(np.uint8)

def calculate_layer_informativeness(data, background_value=-250):
    """
    Score every axis-0 layer of a volume by how much of it is not background.

    The score of layer ``i`` is the fraction of voxels in ``data[i, :, :]``
    whose value differs from ``background_value``. Other measures (variance of
    the non-background voxels, edge content, texture) could be folded into the
    score later; at present the fraction alone is used.

    Args:
        data: 3D numpy array
        background_value: Background value to exclude

    Returns:
        List of tuples: (layer_index, informativeness_score), in layer order
    """
    voxels_per_layer = data.shape[1] * data.shape[2]

    def non_background_ratio(index):
        # Boolean mask of the voxels that carry signal in this layer
        foreground = data[index, :, :] != background_value
        return np.sum(foreground) / voxels_per_layer

    return [(index, non_background_ratio(index)) for index in range(data.shape[0])]

def select_top_layers_with_spacing(layer_scores, n_layers, min_spacing):
    """
    Select top N layers ensuring minimum spacing between them

    Layers are visited from highest to lowest score (ties keep their input
    order). A layer is accepted when its index is at least ``min_spacing``
    away from every layer accepted so far; selection stops once ``n_layers``
    layers have been accepted. A warning is printed when fewer layers than
    requested satisfy the spacing constraint.

    Args:
        layer_scores: List of tuples (layer_index, score)
        n_layers: Number of layers to select; values <= 0 select nothing
        min_spacing: Minimum spacing between selected layers

    Returns:
        List of tuples: (layer_index, score) for selected layers, sorted by
        layer index
    """
    # Nothing requested: without this guard the first candidate would be accepted
    # before the size check below ever runs.
    if n_layers <= 0:
        return []

    # Sort by score in descending order
    ranked_layers = sorted(layer_scores, key=lambda x: x[1], reverse=True)

    selected_layers = []
    for layer_idx, score in ranked_layers:
        # Accept only when far enough from every layer chosen so far
        far_enough = all(
            abs(layer_idx - chosen_idx) >= min_spacing
            for chosen_idx, _ in selected_layers
        )
        if not far_enough:
            continue

        selected_layers.append((layer_idx, score))
        if len(selected_layers) >= n_layers:
            break

    # Sort selected layers by index for consistent naming
    selected_layers.sort(key=lambda x: x[0])

    if len(selected_layers) < n_layers:
        print(f"    ⚠️ Only found {len(selected_layers)} layers that satisfy spacing constraint")

    return selected_layers

def extract_selected_layers(data, selected_layers):
    """
    Build the per-layer metadata used when saving the selected layers.

    Each selected layer gets a 1-based rank by its position in
    ``selected_layers`` and a key of the form
    ``layer_<index>_rank_<rank>_score_<score>``. ``data`` is accepted for
    interface compatibility but is not read.

    Args:
        data: 3D numpy array
        selected_layers: List of tuples (layer_index, score)

    Returns:
        Dictionary mapping each key to a dict with 'layer_index', 'score' and
        'rank', in the order of ``selected_layers``
    """
    return {
        f"layer_{layer_idx:03d}_rank_{rank:02d}_score_{score:.3f}": {
            'layer_index': layer_idx,
            'score': score,
            'rank': rank,
        }
        for rank, (layer_idx, score) in enumerate(selected_layers, start=1)
    }

def save_layer_as_png(layer_data, output_path):
    """
    Write a single-channel layer to disk as an RGB image (fallback method).

    The layer is replicated into three identical channels along a new trailing
    axis, so the saved image looks grayscale.

    Args:
        layer_data: 2D numpy array (normalized to 0-255)
        output_path: Output file path
    """
    # Same plane for R, G and B
    channels = [layer_data] * 3
    stacked = np.stack(channels, axis=-1)

    # Hand the array to PIL and write it out
    Image.fromarray(stacked, mode='RGB').save(output_path)

def process_nifti_file(nifti_path, output_dir, n_layers, min_spacing, background_value=-250):
    """
    Process a single NIfTI file and extract top N informative layers with spacing

    Layers along axis 0 are scored, the best ``n_layers`` that respect
    ``min_spacing`` are selected, and each one is written to ``output_dir`` as
    ``<base>_<rank>.png``, where ``<base>`` is the file name without its
    extension, truncated at the first underscore. Depending on the module-level
    ``USE_GMM_PROCESSING`` flag a layer is rendered either with
    ``extract_cbct_layer_with_gmm`` or with plain min/max normalization. The
    module-level ``DEBUG_FOLDER_PATH`` is forwarded to the GMM function.

    Errors are reported with ``print`` and never raised: a failing layer is
    skipped, a failing file yields 0.

    Args:
        nifti_path: Path to .nii.gz file
        output_dir: Output directory for this file's layers
        n_layers: Number of layers to extract
        min_spacing: Minimum spacing between layers
        background_value: Background value

    Returns:
        Number of layers that were saved successfully
    """
    try:
        # Load NIfTI file
        nifti_img = nib.load(nifti_path)
        data = nifti_img.get_fdata()

        # Calculate informativeness for all layers
        layer_scores = calculate_layer_informativeness(data, background_value)

        # Select top N layers with spacing constraint
        selected_layers = select_top_layers_with_spacing(layer_scores, n_layers, min_spacing)

        if not selected_layers:
            print(f"  ⚠️ No suitable layers found")
            return 0

        # Get layer information
        layers = extract_selected_layers(data, selected_layers)

        # Base name for the output files: drop the .nii.gz extension ...
        filename_base = Path(nifti_path).stem.replace('.nii', '')
        # ... and keep only the part before the first underscore
        filename_base = filename_base.split('_')[0]
        saved_count = 0

        # Create debug folder for this file if enabled
        file_debug_folder = None
        if DEBUG_FOLDER_PATH:
            file_debug_folder = DEBUG_FOLDER_PATH
            os.makedirs(file_debug_folder, exist_ok=True)

        # Width of the progress bar follows the number of layers requested, so the bar
        # neither overflows (n_layers > 10) nor stays partly empty (n_layers < 10).
        bar_length = max(n_layers, len(layers))

        for layer_info in layers.values():
            layer_idx = layer_info['layer_index']

            # Create output filename
            output_filename = f"{filename_base}_{layer_info['rank']}.png"
            output_path = os.path.join(output_dir, output_filename)

            try:
                if USE_GMM_PROCESSING:
                    # GMM-based rendering; the function saves the image itself
                    extract_cbct_layer_with_gmm(
                        nii_file_path=str(nifti_path),
                        output_image_path=output_path,
                        layer_number=layer_idx,
                        debug_folder_path=file_debug_folder,
                        background_value=background_value
                    )

                    # One filled cell per finished layer; "\r" redraws the bar in place
                    filled_length = layer_info['rank']
                    bar = '█' * filled_length + '-' * (bar_length - filled_length)
                    print(f"    \r   Processing {filename_base} |{bar}|", end="")
                else:
                    # Use simple normalization as fallback
                    layer_data = data[layer_idx, :, :]
                    normalized_layer = normalize_voxel_values(layer_data, background_value)
                    save_layer_as_png(normalized_layer, output_path)

                saved_count += 1

            except Exception as layer_error:
                # Best effort per layer: report and continue with the next one
                print(f"    ❌ Error processing layer {layer_idx}: {str(layer_error)}")
                continue

        return saved_count

    except Exception as e:
        # Best effort per file: report and let the caller continue with other files
        print(f"  ❌ Error processing {nifti_path}: {str(e)}")
        return 0

def process_dataset(input_dataset_dir, output_dataset_dir, n_layers, min_spacing, background_value=-250):
    """
    Process an entire dataset laid out as ``<split>/<class>/*.nii.gz``.

    Only split folders named 'train', 'val' or 'test' and, inside them, class
    folders named '0' or '1' are visited; every other entry is ignored. For
    each visited class folder the matching output folder
    ``<output_dataset_dir>/<split>/<class>`` is created and every ``*.nii.gz``
    file is passed to ``process_nifti_file``, whose return value is treated as
    the number of layers saved for that file.

    Splits, classes and files are visited in sorted order so that the
    processing order, the printed log and the key order of the returned
    statistics are the same on every run and every filesystem.

    Args:
        input_dataset_dir: Input dataset directory
        output_dataset_dir: Output dataset directory
        n_layers: Number of layers to extract per file
        min_spacing: Minimum spacing between layers
        background_value: Background value

    Returns:
        Dictionary with processing statistics:
            'total_files': number of files handed to process_nifti_file
            'processed_files': number of files for which more than 0 layers were saved
            'total_layers_saved': sum of layers saved over all files
            'splits': {split_name: {'classes': {class_name: {
                'total_files', 'processed_files', 'total_layers'}}}}
    """
    input_path = Path(input_dataset_dir)
    output_path = Path(output_dataset_dir)

    stats = {
        'total_files': 0,
        'processed_files': 0,
        'total_layers_saved': 0,
        'splits': {}
    }

    # Path.iterdir() / Path.glob() yield entries in arbitrary order, so every
    # directory listing below is sorted to keep runs reproducible.

    # Iterate through train/val/test folders
    for split_dir in sorted(input_path.iterdir()):
        if not (split_dir.is_dir() and split_dir.name in ['train', 'val', 'test']):
            continue

        print(f"\n{'='*50}")
        print(f"Processing split: {split_dir.name}")
        print(f"{'='*50}")

        stats['splits'][split_dir.name] = {'classes': {}}

        # Iterate through class folders (0, 1)
        for class_dir in sorted(split_dir.iterdir()):
            if not (class_dir.is_dir() and class_dir.name in ['0', '1']):
                continue

            print(f"\n📁 Processing class: {class_dir.name}")

            # Create output directory
            output_class_dir = output_path / split_dir.name / class_dir.name
            output_class_dir.mkdir(parents=True, exist_ok=True)

            # Process all .nii.gz files in this directory
            nifti_files = sorted(class_dir.glob('*.nii.gz'))
            print(f"Found {len(nifti_files)} NIfTI files")

            class_stats = {
                'total_files': len(nifti_files),
                'processed_files': 0,
                'total_layers': 0
            }

            for nifti_file in nifti_files:
                stats['total_files'] += 1
                layers_saved = process_nifti_file(nifti_file, output_class_dir, n_layers, min_spacing, background_value)

                # A file only counts as processed when at least one layer was written.
                if layers_saved > 0:
                    stats['processed_files'] += 1
                    class_stats['processed_files'] += 1
                    stats['total_layers_saved'] += layers_saved
                    class_stats['total_layers'] += layers_saved

            stats['splits'][split_dir.name]['classes'][class_dir.name] = class_stats
            print(f"Class {class_dir.name} summary: {class_stats['processed_files']}/{class_stats['total_files']} files processed, {class_stats['total_layers']} layers saved")

    return stats

# =============================================================================
# MAIN PROCESSING
# =============================================================================

# Validate input directory.
# os.path.isdir is used instead of os.path.exists so that a path pointing at a
# regular file is rejected here rather than failing later when process_dataset
# iterates over the directory.
if not os.path.isdir(INPUT_DATASET_DIR):
    print(f"❌ Input dataset directory does not exist: {INPUT_DATASET_DIR}")
    print("Please update the INPUT_DATASET_DIR variable with the correct path")
else:
    print(f"✅ Input dataset found: {INPUT_DATASET_DIR}")
    
    # Validate parameters
    if N_LAYERS <= 0:
        print(f"❌ N_LAYERS must be positive, got: {N_LAYERS}")
    elif MIN_SPACING < 0:
        print(f"❌ MIN_SPACING must be non-negative, got: {MIN_SPACING}")
    else:
        print(f"✅ Starting processing to extract {N_LAYERS} layers with {MIN_SPACING} spacing...")
        
        # Process the dataset
        print(f"\n🚀 Starting dataset processing...")
        stats = process_dataset(INPUT_DATASET_DIR, OUTPUT_DATASET_DIR, N_LAYERS, MIN_SPACING, BACKGROUND_VALUE)
        
        # Print final summary
        print(f"\n{'='*60}")
        print("📊 PROCESSING SUMMARY")
        print(f"{'='*60}")
        print(f"Total files processed: {stats['processed_files']}/{stats['total_files']}")
        print(f"Total layers extracted: {stats['total_layers_saved']}")
        # max(..., 1) avoids a division by zero when no file was processed.
        print(f"Average layers per file: {stats['total_layers_saved']/max(stats['processed_files'], 1):.1f}")
        print(f"Output directory: {OUTPUT_DATASET_DIR}")
        
        # Per-split / per-class breakdown
        for split_name, split_data in stats['splits'].items():
            print(f"\n{split_name.upper()}:")
            for class_name, class_data in split_data['classes'].items():
                avg_layers = class_data['total_layers'] / max(class_data['processed_files'], 1)
                print(f"  Class {class_name}: {class_data['processed_files']}/{class_data['total_files']} files → {class_data['total_layers']} layers (avg: {avg_layers:.1f})")
        
        # Only claim success when every discovered file yielded at least one
        # layer; otherwise say what went wrong instead of reporting success.
        failed_files = stats['total_files'] - stats['processed_files']
        if stats['total_files'] == 0:
            print(f"\n⚠️ No NIfTI files were found under {INPUT_DATASET_DIR}; nothing was processed.")
        elif failed_files > 0:
            print(f"\n⚠️ Processing finished, but {failed_files} of {stats['total_files']} files produced no layers. Check the log above.")
        else:
            print(f"\n✅ Processing completed successfully!")

Configuration:
Input dataset: C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D
Output dataset: C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_2D_v3
Number of layers to extract: 10
Minimum spacing between layers: 5
Background value: -250
Use GMM processing: True
Debug folder: C:\Users\acer\Desktop\Project_TMJOA\Data\debug
✅ Input dataset found: C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D
✅ Starting processing to extract 10 layers with 5 spacing...

🚀 Starting dataset processing...

Processing split: test

📁 Processing class: 0
Found 16 NIfTI files
9 Processing 50-30909 R |██████████|                                    

KeyboardInterrupt: 