In [1]:
import numpy as np
import nibabel as nib
from tqdm import tqdm
from sklearn.cluster import KMeans

### Expectation Maximization Algorithm

In [2]:
def gaussian_pdf_1d(x, mu, var):
    """Evaluate the univariate normal density with mean `mu` and variance `var` at `x`.

    `x` may be a scalar or a NumPy array; the result follows NumPy broadcasting.
    No guard is applied for `var <= 0`.
    """
    normaliser = 1.0 / np.sqrt(2 * np.pi * var)
    exponent = -0.5 * ((x - mu) ** 2) / var
    return normaliser * np.exp(exponent)

def E_step(X, means, variances, pis):
    """E-step: posterior responsibility of every component for every sample.

    The weighted densities are evaluated in the log domain and shifted by the
    per-row maximum before exponentiation. A sample lying far from every mean
    therefore still gets a valid responsibility row, instead of the 0/0 = NaN
    that results when all linear-space densities underflow to zero.

    Parameters
    ----------
    X : ndarray
        Samples; `X.shape[0]` is the sample count and `X.flatten()` supplies
        the 1-D values.
    means, variances, pis : sequence of length K
        Component means, variances and mixing weights. `1e-6` is added to
        every variance before use.

    Returns
    -------
    gamma : ndarray of shape (N, K)
        Responsibilities; each row sums to 1.
    """
    N = X.shape[0]
    K = len(means)
    samples = X.flatten()

    # A zero mixing weight maps to -inf, i.e. zero responsibility.
    with np.errstate(divide="ignore"):
        log_pis = np.log(np.asarray(pis, dtype=float))

    log_weighted = np.empty((N, K))
    for k in range(K):
        var_k = variances[k] + 1e-6
        log_weighted[:, k] = (
            log_pis[k]
            - 0.5 * np.log(2 * np.pi * var_k)
            - 0.5 * (samples - means[k]) ** 2 / var_k
        )

    # Shift so the largest term in each row is exp(0) = 1 before normalising.
    row_max = log_weighted.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    gamma = np.exp(log_weighted - row_max)
    gamma /= gamma.sum(axis=1, keepdims=True)
    return gamma

def M_step(X, gamma):
    """M-step: re-estimate mixture parameters from the responsibilities.

    Parameters
    ----------
    X : ndarray
        Samples; multiplied elementwise with `gamma`, and flattened to 1-D for
        the variance update.
    gamma : ndarray of shape (N, K)
        Responsibilities from the E-step.

    Returns
    -------
    means : ndarray of length K
        Responsibility-weighted means.
    variances : list of length K
        Responsibility-weighted variances around the new means.
    pis : ndarray of length K
        Mixing weights, i.e. per-component responsibility mass divided by N.
    """
    n_samples, n_components = X.shape[0], gamma.shape[1]
    resp_mass = gamma.sum(axis=0)
    means = (gamma * X).sum(axis=0) / resp_mass

    samples = X.flatten()
    variances = [
        np.sum(gamma[:, j] * ((samples - means[j]) ** 2)) / resp_mass[j]
        for j in range(n_components)
    ]

    pis = resp_mass / n_samples
    return means, variances, pis

def _mixture_log_likelihood(samples, means, variances, pis, var_floor=1e-6):
    """Total log-likelihood of 1-D `samples` under a Gaussian mixture.

    Uses a per-sample log-sum-exp so that samples whose linear-space density
    underflows contribute a finite term instead of log(0).
    """
    n_components = len(means)
    with np.errstate(divide="ignore"):
        log_pis = np.log(np.asarray(pis, dtype=float))

    log_weighted = np.empty((samples.shape[0], n_components))
    for k in range(n_components):
        var_k = variances[k] + var_floor
        log_weighted[:, k] = (
            log_pis[k]
            - 0.5 * np.log(2 * np.pi * var_k)
            - 0.5 * (samples - means[k]) ** 2 / var_k
        )

    row_max = log_weighted.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    with np.errstate(divide="ignore"):
        per_sample = row_max[:, 0] + np.log(np.exp(log_weighted - row_max).sum(axis=1))
    return np.sum(per_sample)


def gmm(X, K, iterations=100, tol=1e-4, random_state=None):
    """Fit a K-component 1-D Gaussian mixture to `X` with EM, initialised by k-means.

    Parameters
    ----------
    X : ndarray of shape (N, 1)
        Samples, one per row; column 0 holds the value.
    K : int
        Number of mixture components.
    iterations : int
        Maximum number of EM iterations.
    tol : float
        Stop once the absolute change in total log-likelihood between two
        consecutive iterations falls below this value.
    random_state : int, RandomState instance or None
        Forwarded to `KMeans` so the initialisation can be made reproducible.
        The default `None` leaves `KMeans` unseeded.

    Returns
    -------
    means, variances, pis
        Final component parameters.
    gamma : ndarray of shape (N, K)
        Responsibilities evaluated with the *returned* parameters.
    log_likelihoods : list
        Total log-likelihood after each completed iteration.
    """
    N = X.shape[0]
    kmeans = KMeans(n_clusters=K, random_state=random_state).fit(X)
    means = kmeans.cluster_centers_.flatten()
    labels_km = kmeans.labels_

    # Seed variances and mixing weights from the k-means partition.
    variances = []
    pis = []
    for k in range(K):
        cluster_data = X[labels_km == k, 0]
        if len(cluster_data) == 0:
            # Empty cluster: fall back to the variance of the whole data set.
            variance = np.var(X)
        else:
            variance = np.var(cluster_data)
        variances.append(variance)
        pis.append(len(cluster_data) / float(N))
    pis = np.array(pis)

    print("Initial pis:", pis)

    samples = X.flatten()
    log_likelihoods = []
    for i in tqdm(range(iterations), desc="GMM iterations"):
        gamma = E_step(X, means, variances, pis)
        means, variances, pis = M_step(X, gamma)
        if i % 10 == 0:
            print("Iteration", i, "pis:", pis)

        # Likelihood under the freshly updated parameters.
        ll = _mixture_log_likelihood(samples, means, variances, pis)
        log_likelihoods.append(ll)

        if i > 0 and np.abs(ll - log_likelihoods[-2]) < tol:
            break

    # Final E-step: the responsibilities computed inside the loop predate the
    # last M-step, and none exist at all when `iterations` is 0.
    gamma = E_step(X, means, variances, pis)
    return means, variances, pis, gamma, log_likelihoods

### Data

In [3]:
# Load the volume and keep its spatial metadata for writing results later.
img_file = 'data/sald_031764_img.nii'
img = nib.load(img_file)
affine, header = img.affine, img.header
data = img.get_fdata()
print("Loaded brain image with shape:", data.shape)

# One row per voxel with a single intensity column: shape (n_voxels, 1).
X = np.expand_dims(data.flatten(), axis=1)

Loaded brain image with shape: (182, 218, 182)


### GMM Segmentation

In [4]:
means, variances, pis, gamma, log_likelihoods = gmm(X, K=3, iterations=100)
labels = np.argmax(gamma, axis=1)

# Relabel components by ascending mean: the lowest-mean component becomes
# label 0, the next label 1, and so on.
# NOTE(review): this presumes ascending intensity matches the label order used
# for the reference segmentation — confirm.
sorted_indices = np.argsort(means)
mapping = {component: rank for rank, component in enumerate(sorted_indices)}

# Same mapping as an inverse-permutation lookup table, so the relabelling is a
# single vectorised index instead of one Python dict lookup per voxel.
rank_of_component = np.empty(len(sorted_indices), dtype=np.intp)
rank_of_component[sorted_indices] = np.arange(len(sorted_indices))

gmm_labels = rank_of_component[labels]
gmm_segmentation = gmm_labels.reshape(data.shape)

prob_csf_file   = 'data/sald_031764_probmask_csf.nii'
prob_gray_file  = 'data/sald_031764_probmask_graymatter.nii'
prob_white_file = 'data/sald_031764_probmask_whitematter.nii'

Initial pis: [0.74252794 0.11966184 0.13781022]


GMM iterations:   0%|          | 0/100 [00:00<?, ?it/s]

Iteration 0 pis: [0.7286403  0.14018895 0.13117075]


GMM iterations:  10%|█         | 10/100 [00:11<01:43,  1.15s/it]

Iteration 10 pis: [0.70599127 0.18784113 0.1061676 ]


GMM iterations:  20%|██        | 20/100 [00:22<01:29,  1.12s/it]

Iteration 20 pis: [0.70599133 0.20355646 0.09045222]


GMM iterations:  30%|███       | 30/100 [00:34<01:20,  1.15s/it]

Iteration 30 pis: [0.70599142 0.22987054 0.06413805]


GMM iterations:  40%|████      | 40/100 [00:46<01:14,  1.24s/it]

Iteration 40 pis: [0.70599144 0.24599572 0.04801285]


GMM iterations:  50%|█████     | 50/100 [00:59<01:03,  1.27s/it]

Iteration 50 pis: [0.70599144 0.24668005 0.0473285 ]


GMM iterations:  59%|█████▉    | 59/100 [01:11<00:49,  1.21s/it]


### Ground Truth

In [5]:
# Load the three probability masks in label order: 0 = CSF, 1 = gray, 2 = white.
prob_csf, prob_gray, prob_white = (
    nib.load(mask_path).get_fdata()
    for mask_path in (prob_csf_file, prob_gray_file, prob_white_file)
)

# Hard label per voxel = index of the most probable mask. Where all three
# masks are equal (e.g. all zero), argmax returns 0.
probs = np.stack([prob_csf, prob_gray, prob_white], axis=-1)
gt_segmentation = probs.argmax(axis=-1)

In [6]:
# Percentage of voxels whose predicted label equals the reference label; the
# mean runs over every voxel of the volume.
accuracy = (gmm_segmentation == gt_segmentation).mean() * 100
print("Pointwise accuracy of the segmentation: {:.2f}%".format(accuracy))

Pointwise accuracy of the segmentation: 88.94%


### Saving the Segmentations

In [7]:
# When a header is supplied, nibabel takes the on-disk data type from that
# header rather than from the array. Copy the source header and set int16
# explicitly so the label volumes are stored with the intended integer type.
label_header = header.copy()
label_header.set_data_dtype(np.int16)

gmm_seg_img = nib.Nifti1Image(gmm_segmentation.astype(np.int16), affine, label_header)
gt_seg_img  = nib.Nifti1Image(gt_segmentation.astype(np.int16), affine, label_header)

nib.save(gmm_seg_img, 'gmm_segmentation.nii')
nib.save(gt_seg_img, 'ground_truth_segmentation.nii')