In [1]:
import nibabel as nib
import numpy as np
import pandas as pd
import os
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline


## Load CT Scans and Organ Masks

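Set `data_dir` in the data-loading cell below (currently the placeholder `'/path/to/your/data'`) to the directory containing your CT scans and their corresponding organ masks. CT files are selected by the `.nii.gz` extension and the substring `CT` in the file name; each mask path is derived by replacing `CT` with `Mask`.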

In [3]:
import os

# Per-patient file layout (paths relative to each patient's folder)
ct_filename = "ct.nii.gz"
segmentations_dir = "segmentations"
pancreas_segmentations_filename = os.path.join(segmentations_dir, "pancreas.nii.gz")

# Local data root with one sub-folder per cohort
data_folder = "./data"
healthy_pancreas_data_path, unhealthy_pancreas_data_path = (
    os.path.join(data_folder, cohort) for cohort in ("healthy-pancreas", "pancreatic-tumor")
)

# Create both cohort folders up front (no-op when they already exist)
for cohort_path in (healthy_pancreas_data_path, unhealthy_pancreas_data_path):
    os.makedirs(cohort_path, exist_ok=True)

In [4]:
# Patient IDs for each cohort. Each ID is also the name of the patient's sub-folder under the
# corresponding cohort directory (tumor IDs -> unhealthy_pancreas_data_path, healthy IDs ->
# healthy_pancreas_data_path).
unhealthy_pancreas_patient_ids = [
    'BDMAP_00000087',
    'BDMAP_00000093',
    'BDMAP_00000192',
    'BDMAP_00000225',
    'BDMAP_00000243',
    'BDMAP_00000324',
    'BDMAP_00000332',
    'BDMAP_00000416',
    'BDMAP_00000541',
    'BDMAP_00000696',
    'BDMAP_00000714',
    'BDMAP_00000715',
    'BDMAP_00000855',
    'BDMAP_00000940',
    'BDMAP_00000956',
    'BDMAP_00001040',
    'BDMAP_00001067',
    'BDMAP_00001096',
    'BDMAP_00001125',
    'BDMAP_00001205',
    'BDMAP_00001331',
    'BDMAP_00001461',
    'BDMAP_00001464',
    'BDMAP_00001476',
    'BDMAP_00001523',
    'BDMAP_00001564',
    'BDMAP_00001605',
    'BDMAP_00001617',
    'BDMAP_00001646',
    'BDMAP_00001649',
    'BDMAP_00001704',
    'BDMAP_00001746',
    'BDMAP_00001754',
    'BDMAP_00001823',
    'BDMAP_00001862',
    'BDMAP_00002021',
    'BDMAP_00002278',
    'BDMAP_00002298',
    'BDMAP_00002328',
    'BDMAP_00002387',
    'BDMAP_00002402',
    'BDMAP_00002616',
    'BDMAP_00002690',
    'BDMAP_00002793',
    'BDMAP_00002944',
    'BDMAP_00002945',
    'BDMAP_00003017',
    'BDMAP_00003036',
    'BDMAP_00003133',
    'BDMAP_00003141',
    'BDMAP_00003244',
    'BDMAP_00003326',
    'BDMAP_00003347',
    'BDMAP_00003427',
    'BDMAP_00003440',
    'BDMAP_00003451',
    'BDMAP_00003502',
    'BDMAP_00003551',
    'BDMAP_00003590',
    'BDMAP_00003592',
    'BDMAP_00003612',
    'BDMAP_00003658',
    'BDMAP_00003744',
    'BDMAP_00003776',
    'BDMAP_00003781',
    'BDMAP_00003812',
    'BDMAP_00004060',
    'BDMAP_00004106',
    'BDMAP_00004128',
    'BDMAP_00004229',
    'BDMAP_00004231',
    'BDMAP_00004447',
    'BDMAP_00004494',
    'BDMAP_00004511',
    'BDMAP_00004672',
    'BDMAP_00004770',
    'BDMAP_00004804',
    'BDMAP_00004847',
    'BDMAP_00004880',
    'BDMAP_00004927',
    'BDMAP_00004964',
    'BDMAP_00004969',
    'BDMAP_00004992',
    'BDMAP_00005020',
    'BDMAP_00005022',
    'BDMAP_00005070',
    'BDMAP_00005074',
    'BDMAP_00005075',
    'BDMAP_00005185'
]


healthy_pancreas_patient_ids = [
    'BDMAP_00000002',
    'BDMAP_00000110',
    'BDMAP_00000198',
    'BDMAP_00000246',
    'BDMAP_00000351',
    'BDMAP_00000598',
    'BDMAP_00000673',
    'BDMAP_00000682',
    'BDMAP_00000764',
    'BDMAP_00000846',
    'BDMAP_00000878',
    'BDMAP_00000928',
    'BDMAP_00001002',
    'BDMAP_00001348',
    'BDMAP_00001662',
    'BDMAP_00001774',
    'BDMAP_00001820',
    'BDMAP_00001871',
    'BDMAP_00001942',
    'BDMAP_00001943',
    'BDMAP_00002059',
    'BDMAP_00002212',
    'BDMAP_00002236',
    'BDMAP_00002569',
    'BDMAP_00002650',
    'BDMAP_00002753',
    'BDMAP_00002763',
    'BDMAP_00003013',
    'BDMAP_00003033',
    'BDMAP_00003085',
    'BDMAP_00003154',
    'BDMAP_00003265',
    'BDMAP_00003577',
    'BDMAP_00003644',
    'BDMAP_00003876',
    'BDMAP_00004098',
    'BDMAP_00004142',
    'BDMAP_00004202',
    'BDMAP_00004360',
    'BDMAP_00004458',
    'BDMAP_00004480',
    'BDMAP_00005110'
]

# Cohort sanity checks: a duplicated ID would double-count a scan, and an ID present in both
# lists would put the same patient in the healthy and the tumor group.
assert len(set(unhealthy_pancreas_patient_ids)) == len(unhealthy_pancreas_patient_ids), \
    "duplicate IDs in unhealthy_pancreas_patient_ids"
assert len(set(healthy_pancreas_patient_ids)) == len(healthy_pancreas_patient_ids), \
    "duplicate IDs in healthy_pancreas_patient_ids"
assert not set(unhealthy_pancreas_patient_ids) & set(healthy_pancreas_patient_ids), \
    "patient IDs listed in both cohorts"

In [5]:
def _cohort_data_path(patient_health_status="healthy"):
    """Return the cohort root folder for ``patient_health_status``.

    ``"healthy"`` selects ``healthy_pancreas_data_path``; ANY other value selects
    ``unhealthy_pancreas_data_path`` (the pancreatic-tumor cohort), exactly like the
    per-getter if/else this helper replaces.
    """
    if patient_health_status == "healthy":
        return healthy_pancreas_data_path
    return unhealthy_pancreas_data_path


def _patient_dir(patient_id, patient_health_status="healthy"):
    """Return the folder holding all files of one patient."""
    return os.path.join(_cohort_data_path(patient_health_status), patient_id)


# One folder per patient, in the same order as the ID lists
healthy_patient_folders = [_patient_dir(pid, "healthy") for pid in healthy_pancreas_patient_ids]
pancreatic_tumor_folders = [_patient_dir(pid, "unhealthy") for pid in unhealthy_pancreas_patient_ids]


def get_patient_ct_scan_path(patient_id, patient_health_status="healthy"):
    """Return the path of the patient's CT volume (``ct_filename`` inside the patient folder).

    Args:
        patient_id: patient folder name, e.g. ``'BDMAP_00000002'``.
        patient_health_status: ``"healthy"`` for the healthy cohort; any other value
            selects the pancreatic-tumor cohort.
    """
    return os.path.join(_patient_dir(patient_id, patient_health_status), ct_filename)


def get_patient_segmentations_path(patient_id, patient_health_status="healthy"):
    """Return the patient's segmentations folder (``segmentations_dir`` inside the patient folder).

    Args:
        patient_id: patient folder name.
        patient_health_status: ``"healthy"`` or anything else for the pancreatic-tumor cohort.
    """
    return os.path.join(_patient_dir(patient_id, patient_health_status), segmentations_dir)


def get_patient_pancreas_segmentation_path(patient_id, patient_health_status="healthy"):
    """Return the path of the patient's pancreas mask (``pancreas_segmentations_filename``).

    Args:
        patient_id: patient folder name.
        patient_health_status: ``"healthy"`` or anything else for the pancreatic-tumor cohort.
    """
    return os.path.join(_patient_dir(patient_id, patient_health_status), pancreas_segmentations_filename)


# Paths for the first patient of each list (nothing is loaded here). Note that the CT scan and
# the pancreas mask belong to the first pancreatic-tumor patient, while the segmentations folder
# belongs to the first healthy patient.
first_ct_scan_path = get_patient_ct_scan_path(unhealthy_pancreas_patient_ids[0], "unhealthy")
first_healthy_segmentations_path = get_patient_segmentations_path(healthy_pancreas_patient_ids[0], "healthy")
first_pancreas_segmentation_path = get_patient_pancreas_segmentation_path(unhealthy_pancreas_patient_ids[0], "unhealthy")

In [None]:
from tqdm import tqdm
import nibabel as nib

def load_imgs_in_folders_list(folders_list: list, filename=None, desc="Loading CT Scans"):
    """
    Load one NIfTI CT scan from each folder in a list.

    Args:
        folders_list (list): Folders containing the CT scans. Each folder must contain a
            NIfTI file named ``filename``.
        filename (str, optional): Name of the scan file inside each folder. Defaults to the
            notebook-level ``ct_filename``, looked up when the function is called.
        desc (str, optional): Label shown on the tqdm progress bar.

    Returns:
        list: One nibabel image per folder, in the same order as ``folders_list``
        (call ``get_fdata()`` on an image to obtain its voxel array).

    Errors raised by ``nib.load`` (for example a missing file) propagate unchanged.
    """
    if filename is None:
        filename = ct_filename
    # The progress label previously read "Encoding", but this step only loads files.
    return [nib.load(os.path.join(folder, filename)) for folder in tqdm(folders_list, desc=desc)]

healthy_cts = load_imgs_in_folders_list(healthy_patient_folders)
pancreatic_tumor_cts = load_imgs_in_folders_list(pancreatic_tumor_folders)

In [None]:
# Replace with your data directory
data_dir = '/path/to/your/data'

# Fail early with an actionable message if the placeholder above was not replaced.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        f"data_dir {data_dir!r} is not a directory; set it to the folder holding your CT scans and masks."
    )

# List of CT scan files. Sorted because os.listdir order is arbitrary; this keeps ct_scans (and
# every patch, encoding and cluster derived from it) in a reproducible order.
ct_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.nii.gz') and 'CT' in f)

# Lists to hold the data (organ_masks[i] is the mask for ct_scans[i], or None if missing)
ct_scans = []
organ_masks = []

for ct_file in ct_files:
    ct_path = os.path.join(data_dir, ct_file)
    # Substitute only inside the file name: replacing 'CT' in the full path would also rewrite
    # directory components (e.g. '/data/CT_scans/...' -> '/data/Mask_scans/...').
    organ_mask_path = os.path.join(data_dir, ct_file.replace('CT', 'Mask'))  # Adjust according to your file naming convention

    # Load the CT scan
    ct_img = nib.load(ct_path)
    ct_data = ct_img.get_fdata()
    ct_scans.append(ct_data)

    # Load the organ mask
    if os.path.exists(organ_mask_path):
        mask_img = nib.load(organ_mask_path)
        mask_data = mask_img.get_fdata()
        organ_masks.append(mask_data)
    else:
        print(f"Mask file not found for {ct_file}")
        organ_masks.append(None)


## Assign Tissue Types Based on Hounsfield Units

We use the following HU ranges to assign tissue types:
- **Air**: HU ≤ -190
- **Fat**: -190 < HU ≤ -30
- **Soft Tissue**: -30 < HU ≤ 70
- **High-Density Soft Tissue**: 70 < HU ≤ 150
- **Bone**: HU > 150

In [None]:
# Function to assign tissue types based on HU values
def assign_tissue_type(ct_data, thresholds=(-190, -30, 70, 150)):
    """Label every voxel of a CT volume with an integer tissue code derived from its HU value.

    With the default thresholds the codes are:
        0 Air (HU <= -190), 1 Fat (-190 < HU <= -30), 2 Soft Tissue (-30 < HU <= 70),
        3 High-Density Soft Tissue (70 < HU <= 150), 4 Bone (HU > 150).
    In general a voxel's code is the number of thresholds its value strictly exceeds, so every
    bin includes its upper bound. NaN voxels exceed no threshold and therefore get code 0.

    Args:
        ct_data: array of Hounsfield units, any shape.
        thresholds: strictly increasing upper bin bounds; ``len(thresholds) + 1`` codes result.

    Returns:
        ``np.uint8`` array with the same shape as ``ct_data``.

    Raises:
        ValueError: if ``thresholds`` is not a 1-D strictly increasing sequence, or defines more
            codes than fit in ``uint8``.
    """
    hu = np.asarray(ct_data)
    bounds = np.asarray(thresholds)
    if bounds.ndim != 1 or np.any(np.diff(bounds) <= 0):
        raise ValueError(f"thresholds must be a 1-D strictly increasing sequence, got {thresholds!r}")
    if bounds.size > np.iinfo(np.uint8).max:
        raise ValueError(f"at most {np.iinfo(np.uint8).max} thresholds fit in a uint8 label map")

    tissue_types = np.zeros(hu.shape, dtype=np.uint8)
    # Counting exceeded thresholds reproduces the chained range masks while only ever holding
    # the uint8 result plus one boolean temporary (CT volumes can be hundreds of MB).
    for bound in bounds:
        tissue_types += hu > bound

    return tissue_types

# Human-readable names for the integer codes produced by assign_tissue_type (code = list index)
tissue_labels = dict(enumerate([
    'Air',
    'Fat',
    'Soft Tissue',
    'High-Density Soft Tissue',
    'Bone',
]))

# One uint8 tissue-code volume per loaded CT scan, in the same order as ct_scans
tissue_type_maps = [assign_tissue_type(volume) for volume in ct_scans]


## Extract VQ-GAN Encodings

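We tile each CT scan into overlapping 96×96×96-voxel patches (stride 48), label each patch with its majority tissue type, and obtain its VQ-GAN encoding. For this demonstration, `encode_vqgan` is a placeholder that returns random vectors and ignores the patch, so the clustering scores below should be close to chance level (ARI ≈ 0). Replace `encode_vqgan` with your actual VQ-GAN model inference code.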

In [None]:
# Placeholder function for VQ-GAN encoding.
# A seeded module-level generator makes "Restart & Run All" reproduce the same placeholder
# encodings (and therefore the same clustering scores) instead of drawing from unseeded global state.
_PLACEHOLDER_RNG = np.random.default_rng(42)

def encode_vqgan(ct_patch, rng=None, encoding_shape=(24, 24, 24, 8)):
    """Stand-in for VQ-GAN inference: return a flattened random "encoding".

    This function should be replaced with actual VQ-GAN encoding; ``ct_patch`` is ignored here.

    Args:
        ct_patch: CT sub-volume to encode (unused by the placeholder).
        rng: optional ``np.random.Generator``. Defaults to a generator seeded once when this
            cell runs, so results are reproducible across full re-runs.
        encoding_shape: shape of the latent grid before flattening.

    Returns:
        1-D float array of ``prod(encoding_shape)`` values drawn uniformly from [0, 1).
    """
    generator = _PLACEHOLDER_RNG if rng is None else rng
    encoding = generator.random(encoding_shape)
    return encoding.flatten()

# Function to extract patches and their corresponding tissue labels
def extract_patches_and_labels(ct_data, tissue_map, patch_size=96, stride=48, encoder=None):
    """Tile a 3-D CT volume into cubic patches, encode each one, and label it by majority tissue.

    Patch origins are placed every ``stride`` voxels along each axis, and only patches that fit
    entirely inside the volume are used. A trailing strip narrower than a full step is therefore
    not covered, and a volume smaller than ``patch_size`` along any axis yields no patches.

    Args:
        ct_data: 3-D array of CT intensities.
        tissue_map: integer tissue codes with the same shape as ``ct_data``
            (e.g. the output of ``assign_tissue_type``).
        patch_size: edge length of each cubic patch, in voxels.
        stride: step between consecutive patch origins, in voxels.
        encoder: callable mapping a CT patch to its feature vector; defaults to ``encode_vqgan``
            (looked up at call time).

    Returns:
        (encodings, labels): two equal-length lists with one entry per patch, ordered with x
        outermost, then y, then z. ``labels[i]`` is the most frequent code in patch ``i``'s
        tissue map; ties go to the smallest code.

    Raises:
        ValueError: if ``patch_size`` or ``stride`` is not positive, or the two shapes differ.
    """
    if patch_size <= 0 or stride <= 0:
        raise ValueError(f"patch_size and stride must be positive, got {patch_size} and {stride}")
    if ct_data.shape != tissue_map.shape:
        raise ValueError(f"ct_data shape {ct_data.shape} != tissue_map shape {tissue_map.shape}")
    if encoder is None:
        encoder = encode_vqgan

    encodings = []
    labels = []
    x_max, y_max, z_max = ct_data.shape
    for x in range(0, x_max - patch_size + 1, stride):
        for y in range(0, y_max - patch_size + 1, stride):
            for z in range(0, z_max - patch_size + 1, stride):
                window = (slice(x, x + patch_size), slice(y, y + patch_size), slice(z, z + patch_size))
                ct_patch = ct_data[window]
                tissue_patch = tissue_map[window]

                # np.unique returns sorted codes, so argmax breaks ties toward the smallest code.
                codes, counts = np.unique(tissue_patch, return_counts=True)
                labels.append(codes[np.argmax(counts)])
                encodings.append(encoder(ct_patch))

    return encodings, labels

# Encode every patch of every scan and pool them into a single dataset
all_encodings = []
all_labels = []

for volume, tissue_codes in zip(ct_scans, tissue_type_maps):
    scan_encodings, scan_labels = extract_patches_and_labels(volume, tissue_codes)
    all_encodings += scan_encodings
    all_labels += scan_labels


## Perform Clustering on Encodings

We standardize the encodings and cluster them with KMeans. The number of clusters is set to the number of tissue types (5).

In [None]:
from sklearn.preprocessing import StandardScaler

# Convert encodings and labels to numpy arrays (one row / label per patch)
X = np.array(all_encodings)
y_true = np.array(all_labels)

# Fail with a clear message rather than an opaque scikit-learn error when nothing was extracted
# (e.g. no scans loaded, or every volume is smaller than the patch size along some axis).
if X.ndim != 2 or X.shape[0] == 0:
    raise ValueError(
        f"Expected a non-empty 2-D matrix of patch encodings, got shape {X.shape}; check that "
        "ct_scans is non-empty and every volume is at least one patch long along each axis."
    )

# Standardize features (zero mean, unit variance per encoding dimension)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Perform clustering: one cluster per tissue type
n_clusters = len(tissue_labels)

# n_init is pinned because scikit-learn changed its default from 10 to "auto" in 1.4 (warning
# about it since 1.2); pinning keeps the number of initializations the same across versions.
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
kmeans.fit(X_scaled)
y_pred = kmeans.labels_


## Evaluate Clustering

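We evaluate the clustering using the Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI), both of which are unaffected by how the clusters are numbered. We also plot the true tissue type against the assigned cluster. Because KMeans cluster IDs are arbitrary, cluster *k* does not correspond to tissue type *k*, so read each row as the distribution of one tissue type across clusters.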

In [None]:
# Evaluate clustering (both scores are invariant to how cluster IDs are numbered)
ari = adjusted_rand_score(y_true, y_pred)
nmi = normalized_mutual_info_score(y_true, y_pred)

print(f"Adjusted Rand Index (ARI): {ari:.4f}")
print(f"Normalized Mutual Information (NMI): {nmi:.4f}")

# Contingency table: rows = true tissue type, columns = KMeans cluster ID.
# KMeans IDs are arbitrary, so labelling the cluster axis with tissue names would be misleading.
# Reindexing keeps every tissue type and cluster as a row/column even when it has no patches, so
# the tick labels always line up with the counts. (This also removes the need for sklearn's
# confusion_matrix, which was never imported in this notebook.)
contingency = (
    pd.crosstab(y_true, y_pred)
    .reindex(index=sorted(tissue_labels), columns=range(n_clusters), fill_value=0)
)
contingency.index = [tissue_labels[code] for code in contingency.index]
contingency.columns = [f"Cluster {cluster_id}" for cluster_id in contingency.columns]
cm = contingency.to_numpy()

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(contingency, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel('Assigned KMeans Cluster')
ax.set_ylabel('True Tissue Type')
ax.set_title('Tissue Type vs. KMeans Cluster (Contingency Table)')
plt.show()
