# 3. Gaussian Mixture Model

In [45]:
import numpy as np
import nibabel as nib
from sklearn.metrics import accuracy_score
import os
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from scipy import ndimage
from skimage.morphology import remove_small_objects, closing, disk



## GMM Class

In [46]:
class GMM:
    def __init__(self, n_components):
        """
        Initialize a Gaussian Mixture Model
        
        Parameters:
        -----------
        n_components : int
            Number of Gaussian components in the mixture
        """
        self.n_components = n_components
        self.weights = None
        self.means = None
        self.covariances = None
        self.responsibilities = None
        
    def _initialize_parameters(self, X):
        """
        Initialize the model parameters
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Training data
        """
        n_samples, n_features = X.shape
        
        # Random initialization
        indices = np.random.choice(n_samples, self.n_components, replace=False)
        self.means = X[indices]
        self.covariances = np.array([np.eye(n_features) * np.var(X, axis=0).mean() 
                                        for _ in range(self.n_components)])
        self.weights = np.ones(self.n_components) / self.n_components
        
    def _multivariate_gaussian(self, X, mean, covariance):
        """
        Compute the probability density of a multivariate Gaussian distribution
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Data points
        mean : array, shape (n_features,)
            Mean of the Gaussian
        covariance : array, shape (n_features, n_features)
            Covariance matrix of the Gaussian
            
        Returns:
        --------
        pdf : array, shape (n_samples,)
            Probability density for each sample
        """
        n_features = X.shape[1]
        
        # More robust regularization
        covariance_reg = covariance + 1e-5 * np.eye(n_features)
        
        try:
            # Compute determinant and inverse of the covariance matrix
            # Using Cholesky decomposition for stability
            L = np.linalg.cholesky(covariance_reg)
            det = np.prod(np.diag(L))**2
            inv = np.linalg.inv(covariance_reg)
            
            # Compute the normalization constant
            norm_const = 1.0 / (np.power(2 * np.pi, n_features / 2) * np.sqrt(det))
            
            # Compute the exponent term
            X_centered = X - mean
            exponent = -0.5 * np.sum(X_centered @ inv * X_centered, axis=1)
            return norm_const * np.exp(exponent)
        except np.linalg.LinAlgError:
            # Fallback for numerical stability if Cholesky fails
            covariance_reg = covariance + 1e-3 * np.eye(n_features)
            det = np.linalg.det(covariance_reg)
            inv = np.linalg.inv(covariance_reg)
            norm_const = 1.0 / (np.power(2 * np.pi, n_features / 2) * np.sqrt(det))
            X_centered = X - mean
            exponent = -0.5 * np.sum(X_centered @ inv * X_centered, axis=1)
            return norm_const * np.exp(exponent)
    
    def expectation_step(self, X):
        """
        E-step: Compute responsibilities (posterior probabilities)
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Training data
            
        Returns:
        --------
        log_likelihood : float
            Log-likelihood of the data
        """
        n_samples = X.shape[0]
        
        # Compute weighted probabilities for each component
        weighted_probs = np.zeros((n_samples, self.n_components))
        
        for k in range(self.n_components):
            weighted_probs[:, k] = self.weights[k] * self._multivariate_gaussian(
                X, self.means[k], self.covariances[k]
            )
        
        # Compute total probability and responsibilities
        # Add a small epsilon to prevent division by zero
        total_probs = np.sum(weighted_probs, axis=1, keepdims=True) + 1e-10
        self.responsibilities = weighted_probs / total_probs
        
        # Compute log-likelihood
        return np.sum(np.log(total_probs))
    
    def maximization_step(self, X):
        """
        M-step: Update parameters based on computed responsibilities
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Training data
        """
        n_samples, n_features = X.shape
        
        # Compute effective number of points assigned to each component
        N_k = np.sum(self.responsibilities, axis=0) + 1e-10  # Avoid division by zero
        
        # Update weights
        self.weights = N_k / n_samples
        
        # Update means
        self.means = np.dot(self.responsibilities.T, X) / N_k.reshape(-1, 1)
        
        # Update covariances
        for k in range(self.n_components):
            X_centered = X - self.means[k]
            weighted_cov = np.dot(self.responsibilities[:, k] * X_centered.T, X_centered)
            self.covariances[k] = weighted_cov / N_k[k] + 1e-6 * np.eye(n_features)
    
    def fit(self, X, max_iter=35, tol=1e-4, verbose=False):
        """
        Fit the GMM to the data
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Training data
        max_iter : int, default=100
            Maximum number of iterations
        tol : float, default=1e-4
            Convergence threshold for log-likelihood
        verbose : bool, default=False
            Whether to print progress information
            
        Returns:
        --------
        self : object
            Returns self
        """
        # Initialize parameters
        self._initialize_parameters(X)
        
        # Run EM algorithm
        log_likelihood_old = -np.inf
        
        for iteration in range(max_iter):
            # E-step
            log_likelihood = self.expectation_step(X)
            
            # M-step
            self.maximization_step(X)
            
            # Check for convergence
            improvement = log_likelihood - log_likelihood_old
            if verbose and (iteration % 5 == 0 or iteration == max_iter - 1):
                print(f"Iteration {iteration+1}, Log-Likelihood: {log_likelihood:.4f}, Improvement: {improvement:.6f}")
                
            if abs(improvement) < tol:
                if verbose:
                    print(f"Converged after {iteration+1} iterations")
                break
                
            log_likelihood_old = log_likelihood
            
        return self
    
    def predict(self, X):
        """
        Predict the component labels for the data
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Data to predict
            
        Returns:
        --------
        labels : array, shape (n_samples,)
            Component labels
        """
        weighted_probs = np.zeros((X.shape[0], self.n_components))
        
        for k in range(self.n_components):
            weighted_probs[:, k] = self.weights[k] * self._multivariate_gaussian(
                X, self.means[k], self.covariances[k]
            )
            
        return np.argmax(weighted_probs, axis=1)
    
    def predict_proba(self, X):
        """
        Predict the component probabilities for the data
        
        Parameters:
        -----------
        X : array, shape (n_samples, n_features)
            Data to predict
            
        Returns:
        --------
        responsibilities : array, shape (n_samples, n_components)
            Component probabilities
        """
        weighted_probs = np.zeros((X.shape[0], self.n_components))
        
        for k in range(self.n_components):
            weighted_probs[:, k] = self.weights[k] * self._multivariate_gaussian(
                X, self.means[k], self.covariances[k]
            )
            
        total_probs = np.sum(weighted_probs, axis=1, keepdims=True) + 1e-10
        return weighted_probs / total_probs


## Preprocessing and post-procesing of data

In [55]:
def segment_brain_mri_improved(
    mri_path="sald_031764_img.nii", 
    csf_mask_path="sald_031764_probmask_csf.nii",
    wm_mask_path="sald_031764_probmask_whitematter.nii", 
    gm_mask_path="sald_031764_probmask_greymatter.nii",
    output_dir="segmentation_results_improved"
):
    """
    Segment brain MRI into CSF, white matter, and gray matter using GMM with improved
    brain masking to prevent misclassification of empty space as white matter.
    
    Parameters:
    -----------
    mri_path : str
        Path to the MRI .nii file
    csf_mask_path : str
        Path to the CSF probability mask .nii file
    wm_mask_path : str
        Path to the white matter probability mask .nii file
    gm_mask_path : str
        Path to the gray matter probability mask .nii file
    output_dir : str
        Directory to save the output files
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    print("Loading MRI and ground truth masks...")
    # Load MRI data
    mri_img = nib.load(mri_path)
    mri_data = mri_img.get_fdata()
    affine = mri_img.affine
    header = mri_img.header
    
    # Load ground truth masks
    csf_mask = nib.load(csf_mask_path).get_fdata()
    wm_mask = nib.load(wm_mask_path).get_fdata()
    gm_mask = nib.load(gm_mask_path).get_fdata()
    
    # Create ground truth segmentation by taking the maximum probability
    ground_truth = np.zeros_like(mri_data, dtype=np.int16)
    all_masks = np.stack([csf_mask, wm_mask, gm_mask], axis=-1)
    ground_truth = np.argmax(all_masks, axis=-1)
    
    # ------------ IMPROVED BRAIN MASK CREATION ------------
    
    # Create a robust brain mask from the probability masks
    # Sum all probability masks and threshold
    initial_brain_mask = (csf_mask + wm_mask + gm_mask) > 0.1
    
    # Use Otsu thresholding as a backup to catch any brain regions missed by probability masks
    from skimage.filters import threshold_otsu
    
    # Create a copy of the MRI data for thresholding
    mri_data_for_threshold = mri_data.copy()
    
    # Replace zeros with NaN to ignore background in thresholding
    mri_data_for_threshold[mri_data_for_threshold == 0] = np.nan
    
    # Calculate threshold using only non-NaN values
    flat_mri = mri_data_for_threshold.flatten()
    flat_mri = flat_mri[~np.isnan(flat_mri)]
    
    if len(flat_mri) > 0:
        try:
            otsu_thresh = threshold_otsu(flat_mri)
            otsu_mask = mri_data > otsu_thresh
        except:
            # Fallback if Otsu fails
            otsu_mask = mri_data > np.nanmean(flat_mri)
    else:
        otsu_mask = np.zeros_like(mri_data, dtype=bool)
    
    # Combine masks
    combined_mask = initial_brain_mask | otsu_mask
    
    # Apply morphological operations to clean up the mask
    # First, use morphological closing to fill small holes and connect nearby regions
    for i in range(combined_mask.shape[2]):
        combined_mask[:, :, i] = closing(combined_mask[:, :, i], disk(3))
    
    # Fill holes
    brain_mask = ndimage.binary_fill_holes(combined_mask)
    
    # Remove small isolated regions (noise)
    labels, num_features = ndimage.label(brain_mask)
    sizes = ndimage.sum(brain_mask, labels, range(1, num_features + 1))
    
    # Find the largest connected component
    labels, num_features = ndimage.label(brain_mask)
    if num_features > 0:
        sizes = ndimage.sum(brain_mask, labels, range(1, num_features + 1))
        largest_component_label = np.argmax(sizes) + 1 if len(sizes) > 0 else 0
        brain_mask = labels == largest_component_label
    
    # Dilate slightly to ensure we capture the brain boundary
    brain_mask = ndimage.binary_dilation(brain_mask, iterations=2)
    
    # Create a background mask (logical NOT of brain mask)
    background_mask = ~brain_mask
    
    # ------------ INTENSITY NORMALIZATION ------------
    
    # Apply intensity normalization to the MRI data
    # Only normalize within the brain mask
    mri_brain = mri_data[brain_mask]
    
    # Z-score normalization within the brain
    mean_intensity = np.mean(mri_brain)
    std_intensity = np.std(mri_brain)
    
    mri_normalized = np.zeros_like(mri_data)
    mri_normalized[brain_mask] = (mri_data[brain_mask] - mean_intensity) / std_intensity
    
    # Apply bias field correction (simplified)
    # Smooth the image to estimate bias field
    from scipy.ndimage import gaussian_filter
    smooth_brain = np.zeros_like(mri_normalized)
    smooth_brain[brain_mask] = gaussian_filter(mri_normalized[brain_mask], sigma=3)
    
    # Correct bias field
    bias_corrected = np.zeros_like(mri_data)
    bias_corrected[brain_mask] = mri_normalized[brain_mask] - 0.5 * smooth_brain[brain_mask]
    
    # ------------ FEATURE EXTRACTION ------------
    
    # Flatten the 3D volumes to 1D arrays (only for brain voxels)
    intensity_feat = bias_corrected[brain_mask].reshape(-1, 1)
    
    # Get 3D coordinates of brain voxels
    coords = np.array(np.where(brain_mask)).T
    
    # Add additional texture features
    # Local variance (a simple texture measure)
    variance_map = np.zeros_like(mri_data)
    for i in range(1, mri_data.shape[0]-1):
        for j in range(1, mri_data.shape[1]-1):
            for k in range(1, mri_data.shape[2]-1):
                if brain_mask[i, j, k]:
                    patch = mri_normalized[i-1:i+2, j-1:j+2, k-1:k+2]
                    variance_map[i, j, k] = np.var(patch)
    
    # Extract texture features for brain voxels
    texture_feat = variance_map[brain_mask].reshape(-1, 1)
    texture_feat = (texture_feat - np.mean(texture_feat)) / (np.std(texture_feat) + 1e-10)
    
    # Use prior knowledge from probability masks
    prior_csf = csf_mask[brain_mask].reshape(-1, 1)
    prior_wm = wm_mask[brain_mask].reshape(-1, 1)
    prior_gm = gm_mask[brain_mask].reshape(-1, 1)
    
    # Scale down the importance of spatial coordinates
    coords_normalized = (coords - np.mean(coords, axis=0)) / np.std(coords, axis=0) * 0.3
    
    # Create the feature matrix, weighted to emphasize the most important features
    features = np.hstack([
        intensity_feat * 1.5,          # Intensity (most important)
        # distances * 0.8,               # Distance from brain center
        coords_normalized * 0.5,       # Spatial coordinates (least important)
        texture_feat * 0.7,            # Texture features
        prior_csf * 1.0,               # Prior CSF probability
        prior_wm * 1.0,                # Prior WM probability
        prior_gm * 1.0                 # Prior GM probability
    ])
       
    print(f"Fitting GMM with 3 components to {features.shape[0]} voxels with {features.shape[1]} features...")
    
    # Fit GMM with 3 components (CSF, WM, GM) using k-means initialization
    gmm = GMM(n_components=3)
    gmm.fit(features, max_iter=40, tol=1e-3, verbose=True)
    
    # Predict segmentation
    print("Predicting segmentation...")
    labels = gmm.predict(features)
    
    # Map back to 3D volume
    segmentation = np.zeros_like(mri_data, dtype=np.int16)
    segmentation[brain_mask] = labels
    
    # Set background (outside brain) to a special value (-1)
    segmentation[background_mask] = -1
    
    # ------------ SEGMENT IDENTIFICATION ------------
    
    # Calculate mean intensity for each predicted segment
    mean_intensities = [np.mean(mri_data[brain_mask][labels == i]) for i in range(3)]
    
    # Use prior knowledge for more reliable mapping
    # Calculate overlap with ground truth for each segment
    overlap_matrix = np.zeros((3, 3))  # [predicted_class, true_class]
    
    for pred_class in range(3):
        pred_mask = labels == pred_class
        for true_class in range(3):
            true_mask = ground_truth[brain_mask] == true_class
            overlap_matrix[pred_class, true_class] = np.sum(pred_mask & true_mask)
    
    # Determine best mapping based on maximum overlap
    mapping = np.zeros(3, dtype=np.int16)
    
    # First check the most reliable mapping based on intensity
    segment_order = np.argsort(mean_intensities)
    
    # Verify with overlap matrix
    # CSF should have lowest intensity
    mapping[segment_order[0]] = 0
    
    # Check which of the remaining two has higher overlap with gray matter
    if overlap_matrix[segment_order[1], 1] > overlap_matrix[segment_order[2], 1]:
        # Mid intensity is gray matter
        mapping[segment_order[1]] = 1  # GM
        mapping[segment_order[2]] = 2  # WM
    else:
        # Mid intensity is actually white matter (unusual but possible)
        mapping[segment_order[1]] = 2  # WM
        mapping[segment_order[2]] = 1  # GM
    
    print("Mapping segments to anatomical structures based on intensity and overlap...")
    print(f"Segment mapping: {mapping}")
    
    # Apply mapping to brain voxels only
    segmentation_mapped = np.zeros_like(segmentation)
    for i in range(3):
        segmentation_mapped[segmentation == i] = mapping[i]
    
    # Background remains -1
    segmentation_mapped[segmentation == -1] = -1
    
    # ------------ POST-PROCESSING ------------

    # For visualization and saving, we'll set background to 3 (a new class)
    segmentation_display = segmentation_mapped.copy()
    segmentation_display[segmentation_mapped == -1] = 3
    
    # Get ground truth for brain voxels only
    brain_indices = np.where(brain_mask)
    ground_truth_brain = ground_truth[brain_indices]
    segmentation_brain = segmentation_mapped[brain_indices]
    
    # Calculate accuracy for brain voxels only
    print("Calculating accuracy...")
    # Filter out background voxels (-1) from accuracy calculation
    valid_indices = segmentation_brain != -1
    accuracy_overall = np.mean(segmentation_brain[valid_indices] == ground_truth_brain[valid_indices])
    
    # Calculate accuracy for each tissue type
    tissue_accuracies = {}
    for i, tissue in enumerate(['CSF', 'GM', 'WM']):
        mask_true = ground_truth_brain == i
        if np.sum(mask_true) > 0:
            pred_matches = segmentation_brain[mask_true] == i
            accuracy_i = np.mean(pred_matches)
            tissue_accuracies[tissue] = accuracy_i
    
    # Save segmentation results
    print("Saving segmentation results...")
    seg_filename = os.path.join(output_dir, "segmentation_all.nii")
    seg_img = nib.Nifti1Image(segmentation_display, affine, header)
    nib.save(seg_img, seg_filename)
    
    # Save individual tissue segmentations
    for i, tissue in enumerate(['csf', 'gm', 'wm']):
        tissue_mask = segmentation_mapped == i
        tissue_filename = os.path.join(output_dir, f"segmentation_{tissue}.nii")
        tissue_img = nib.Nifti1Image(tissue_mask.astype(np.int16), affine, header)
        nib.save(tissue_img, tissue_filename)
    
    # Save the brain mask
    brain_mask_filename = os.path.join(output_dir, "brain_mask.nii")
    brain_mask_img = nib.Nifti1Image(brain_mask.astype(np.int16), affine, header)
    nib.save(brain_mask_img, brain_mask_filename)

    # Print results
    print("\nSegmentation Results:")
    print(f"Overall Accuracy (within brain): {accuracy_overall:.4f}")
    for tissue, acc in tissue_accuracies.items():
        print(f"{tissue} Accuracy: {acc:.4f}")
    
    print(f"\nSegmentation results saved to {output_dir}")
    
    return segmentation_mapped, ground_truth, brain_mask, accuracy_overall, tissue_accuracies


## Segmentation:

In [56]:
# Execute the segmentation function
values = segment_brain_mri_improved(
    mri_path="datasets/gmm/sald_031764_img.nii",
    csf_mask_path="datasets/gmm/sald_031764_probmask_csf.nii",
    wm_mask_path="datasets/gmm/sald_031764_probmask_whitematter.nii",
    gm_mask_path="datasets/gmm/sald_031764_probmask_graymatter.nii",
    output_dir="segmentation_results"
)

Loading MRI and ground truth masks...
Fitting GMM with 3 components to 2279853 voxels with 8 features...
Iteration 1, Log-Likelihood: -11215911.6116, Improvement: inf
Iteration 6, Log-Likelihood: 20954513.6215, Improvement: 53473.170135
Iteration 11, Log-Likelihood: 22271657.8721, Improvement: 71164.441924
Iteration 16, Log-Likelihood: 23561686.2070, Improvement: 651445.230266
Iteration 21, Log-Likelihood: 23561773.5784, Improvement: 0.030753
Iteration 26, Log-Likelihood: 23561773.6126, Improvement: -0.001170
Iteration 31, Log-Likelihood: 23561773.6133, Improvement: 0.002191
Iteration 36, Log-Likelihood: 23561773.6111, Improvement: -0.002180
Iteration 40, Log-Likelihood: 23561773.6111, Improvement: -0.002179
Predicting segmentation...
Mapping segments to anatomical structures based on intensity and overlap...
Segment mapping: [2 0 1]
Calculating accuracy...
Saving segmentation results...

Segmentation Results:
Overall Accuracy (within brain): 0.7828
CSF Accuracy: 0.2593
GM Accuracy: 1.

## Visualization of results:

### Original Masks:

1. **Original CSF Mask:**  
![Original-csf.png](assets/gmm/original/csf-original.png)

2. **Original Graymatter Mask:**  
![Original-gm.png](assets/gmm/original/gm-original.png)

3. **Original Whitematter Mask:**  
![Original-wm.png](assets/gmm/original/wm-original.png)

---

- **Brain Mask created during data processing:**  
![Brain-Mask.png](assets/gmm/predicted/brain-mask-created.png)

---

### Predicted Masks/Segmentations:

1. **Predicted CSF Segmentation:**  
![predicted-csf.png](assets/gmm/predicted/csf-predicted.png)

2. **Predicted Graymatter Segmentation:**  
![predicted-gm.png](assets/gmm/predicted/gm-predicted.png)

3. **Predicted Whitematter Segmentation:**  
![predicted-wm.png](assets/gmm/predicted/wm-predicted.png)

---

## Comments:

1) Where the Highest Misclassification Occurs and Why:
- Based on typical brain MRI intensity distributions and the GMM approach used in the code, the highest misclassification would likely occur at the gray matter-white matter boundary for these reasons:

  - **Overlapping Intensity Distributions:** Gray matter and white matter have partially overlapping intensity distributions. While white matter generally has higher intensity than gray matter, there's a significant overlap region where voxels could belong to either tissue.
  - **Partial Volume Effects:** Voxels at the boundary between gray and white matter often contain both tissue types due to the limited resolution of MRI. These partial volume voxels have intermediate intensity values that can be misclassified by the GMM.
  - **Tissue Interface Complexity:** The boundary between gray and white matter is anatomically complex with many folds and curves, making it challenging for the intensity-based GMM to correctly classify all voxels.
  - **Limited Feature Set:** The GMM mainly relies on intensity and some spatial information. While this works for gross tissue separation, it struggles with subtle distinctions at tissue interfaces.
  - **Gray Matter's "Middle Position":** Gray matter's intensity distribution sits between CSF (lowest) and white matter (highest), making it susceptible to misclassification in both directions. It can be misclassified as either CSF or white matter depending on local intensity variations.

---