<a href="https://colab.research.google.com/github/Soutrik-Chakraborty/cardiac-modeling/blob/main/Untitled4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install nibabel scikit-image trimesh vedo

Collecting trimesh
  Downloading trimesh-4.7.1-py3-none-any.whl.metadata (18 kB)
Collecting vedo
  Downloading vedo-2025.5.4-py3-none-any.whl.metadata (14 kB)
Collecting vtk (from vedo)
  Downloading vtk-9.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.5 kB)
Downloading trimesh-4.7.1-py3-none-any.whl (709 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m709.0/709.0 kB[0m [31m46.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading vedo-2025.5.4-py3-none-any.whl (2.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m105.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading vtk-9.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (112.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.1/112.1 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh, vtk, vedo
Successfully installed trimesh-4.7.1 vedo-2025.5.4 vtk-9.5.0


In [3]:
import os, glob, zipfile
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import nibabel as nib  # for .nii images
import numpy as np
from scipy.ndimage import zoom
from tkinter import Tk, filedialog
import shutil
import pandas as pd
from skimage import measure
import trimesh
from vedo import Plotter, Mesh
import matplotlib.pyplot as plt

In [4]:
from google.colab import files
# Opens Colab's file picker. Select `cropped.zip` AND `hvsmr_clinical.csv`:
# the zip is read from /content/cropped.zip by unzip_local_zip(), and the
# clinical CSV is read later by the age filter in the __main__ cell.
uploaded = files.upload()

Saving cropped.zip to cropped.zip
Saving hvsmr_clinical.csv to hvsmr_clinical.csv
Saving hvsmr_technical.csv to hvsmr_technical.csv


In [5]:
def unzip_local_zip(zip_path="/content/cropped.zip", output_dir=None):
    """Extract a ZIP archive and print an indented tree of what was extracted.

    Args:
        zip_path: Path to the ZIP file. Defaults to the Colab upload location
            of ``cropped.zip``.
        output_dir: Directory to extract into. Defaults to a
            ``cropped_unzipped`` folder next to ``zip_path``.

    Returns:
        The directory the archive was extracted into.

    Raises:
        FileNotFoundError: If ``zip_path`` is not an existing file.
    """
    if not os.path.isfile(zip_path):
        raise FileNotFoundError(f"Zip file not found at: {zip_path}")

    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(zip_path), "cropped_unzipped")
    os.makedirs(output_dir, exist_ok=True)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

    print(f"Extracted dataset to: {output_dir}")

    # Print the contents of the extracted directory for verification.
    print("\nContents of extracted directory:")
    for root, _dirs, filenames in os.walk(output_dir):
        # relpath (not str.replace) so a trailing separator on output_dir, or the
        # output_dir string recurring deeper in the path, can't skew the depth.
        rel = os.path.relpath(root, output_dir)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        print('{}{}/'.format(indent, os.path.basename(os.path.normpath(root))))
        subindent = ' ' * 4 * (level + 1)
        for f in filenames:
            print('{}{}'.format(subindent, f))

    return output_dir


In [6]:
# Extract /content/cropped.zip into /content/cropped_unzipped for inspection.
# Note: the __main__ cell below calls unzip_local_zip() again and re-extracts.
output_dir = unzip_local_zip()

Extracted dataset to: /content/cropped_unzipped

Contents of extracted directory:
cropped_unzipped/
    cropped/
        pat45_cropped_seg_endpoints.nii.gz
        pat2_cropped_seg_endpoints.nii.gz
        pat39_cropped_seg_endpoints.nii.gz
        pat4_cropped_seg.nii.gz
        pat4_cropped.nii.gz
        pat21_cropped.nii.gz
        pat48_cropped_seg_endpoints.nii.gz
        pat42_cropped_seg_endpoints.nii.gz
        pat15_cropped_seg.nii.gz
        pat34_cropped_seg_endpoints.nii.gz
        pat15_cropped.nii.gz
        pat57_cropped_seg_endpoints.nii.gz
        pat9_cropped_seg.nii.gz
        pat39_cropped.nii.gz
        pat20_cropped_seg.nii.gz
        pat59_cropped_seg_endpoints.nii.gz
        pat48_cropped_seg.nii.gz
        pat3_cropped.nii.gz
        pat10_cropped.nii.gz
        pat0_cropped_seg_endpoints.nii.gz
        pat13_cropped_seg.nii.gz
        pat20_cropped_seg_endpoints.nii.gz
        pat38_cropped_seg_endpoints.nii.gz
        pat19_cropped_seg_endpoints.nii.gz
     

In [7]:
def convert_nii_gz_to_nii(input_path, output_path=None):
    """Decompress a ``.nii.gz`` NIfTI file to ``.nii`` and delete the original.

    Args:
        input_path: Path to a ``.nii.gz`` file.
        output_path: Destination path. Defaults to ``input_path`` with the
            trailing ``.gz`` removed.

    Returns:
        The path the image was written to.

    Raises:
        ValueError: If ``input_path`` does not end in ``.nii.gz``, or if
            ``output_path`` resolves to the same file as ``input_path``
            (saving and then deleting would destroy the only copy).
    """
    if not input_path.endswith(".nii.gz"):
        raise ValueError("Input file must be a .nii.gz file")

    if output_path is None:
        output_path = input_path[:-3]  # strip the ".gz" extension

    if os.path.abspath(output_path) == os.path.abspath(input_path):
        raise ValueError(
            "output_path must differ from input_path; the input file is deleted after conversion"
        )

    img = nib.load(input_path)
    nib.save(img, output_path)
    print(f"Converted {input_path} -> {output_path}")

    # Only reached when nib.save did not raise, so the converted copy exists.
    os.remove(input_path)
    print(f"Deleted original: {input_path}")
    return output_path


In [8]:
def batch_convert_and_delete(folder_path):
    """Convert every ``.nii.gz`` entry directly inside ``folder_path``.

    Each matching entry is handed to ``convert_nii_gz_to_nii``, which writes
    the uncompressed ``.nii`` next to it and removes the compressed file.
    Subdirectories are not descended into.
    """
    compressed_names = [name for name in os.listdir(folder_path) if name.endswith(".nii.gz")]
    for name in compressed_names:
        convert_nii_gz_to_nii(os.path.join(folder_path, name))

In [9]:
class HVSMR3DDataset(Dataset):
    """Paired HVSMR image / segmentation volumes resampled to a fixed grid.

    Files are discovered as ``*.nii`` directly inside ``root``. A file whose
    basename contains ``_seg`` is treated as a mask; any other file is an
    image. An image ``<key>.nii`` is paired with the mask ``<key>_seg.nii``
    (so files such as ``*_seg_endpoints.nii`` never pair). Pairs are sorted by
    key and split deterministically: the first ``split_ratio`` fraction of the
    paired keys forms ``"train"``, the remainder ``"val"``; any other
    ``split`` value keeps every pair.

    Args:
        root: Directory containing the ``.nii`` files.
        split: ``"train"``, ``"val"``, or any other value for all pairs.
        split_ratio: Fraction of the paired keys assigned to ``"train"``.
        target_shape: Spatial shape every volume is resampled to.

    Raises:
        ValueError: If the selected split contains no pairs.
    """

    def __init__(self, root, split="train", split_ratio=0.8, target_shape=(128, 128, 128)):
        self.target_shape = tuple(target_shape)
        all_files = sorted(glob.glob(os.path.join(root, "*.nii")))

        # Classify on the basename so a "_seg" in the directory path cannot
        # turn every file into a mask.
        images = [f for f in all_files if "_seg" not in os.path.basename(f)]
        masks = [f for f in all_files if "_seg" in os.path.basename(f)]

        image_dict = {os.path.basename(img).replace(".nii", ""): img for img in images}
        mask_dict = {os.path.basename(mask).replace("_seg.nii", ""): mask for mask in masks}

        paired_keys = sorted(image_dict.keys() & mask_dict.keys())

        # Split on the *paired* keys: images without a mask must not move the
        # train/val boundary (previously they could leave "val" empty).
        split_index = int(len(paired_keys) * split_ratio)
        if split == "train":
            paired_keys = paired_keys[:split_index]
        elif split == "val":
            paired_keys = paired_keys[split_index:]

        self.pairs = [(image_dict[key], mask_dict[key]) for key in paired_keys]

        if not self.pairs:
            raise ValueError(f"No image/mask pairs found for split '{split}' in {root}. Please check the file structure and naming convention.")

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _fit_to_shape(arr, shape):
        """Zero-pad or crop ``arr`` at its far edges so it has exactly ``shape``."""
        shape = tuple(shape)
        if arr.shape == shape:
            return arr
        fitted = np.zeros(shape, dtype=arr.dtype)
        region = tuple(slice(0, min(a, b)) for a, b in zip(arr.shape, shape))
        fitted[region] = arr[region]
        return fitted

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for pair ``idx``.

        ``image`` is a float32 tensor of shape ``(1, *target_shape)``,
        z-score normalised; ``mask`` is an int64 tensor of ``target_shape``.
        """
        img_path, msk_path = self.pairs[idx]
        img = nib.load(img_path).get_fdata().astype(np.float32)
        msk = nib.load(msk_path).get_fdata().astype(np.int64)

        shape = self.target_shape
        # order=1 (linear) for intensities; order=0 (nearest) keeps labels integral.
        img = zoom(img, np.array(shape) / img.shape, order=1)
        msk = zoom(msk, np.array(shape) / msk.shape, order=0)
        # Guard both arrays identically so image and mask shapes always agree.
        img = self._fit_to_shape(img, shape)
        msk = self._fit_to_shape(msk, shape)

        # A constant volume has std 0; avoid turning it into all-NaN.
        std = img.std()
        img = (img - img.mean()) / (std if std > 0 else 1.0)
        return torch.unsqueeze(torch.from_numpy(img), 0), torch.from_numpy(msk)


In [10]:
class UNet3D(nn.Module):
    """Minimal 3-D encoder/decoder producing per-voxel class logits.

    Encoder: two 3x3x3 convolutions (``in_ch`` -> 16 -> 32) with ReLU, then a
    2x max-pool. Decoder: a 2x transposed convolution (32 -> 16) with ReLU and
    a 1x1x1 projection to ``out_ch`` channels. Despite the name there are no
    skip connections. Spatial sizes should be even so the pool/upsample
    round-trip restores the input size.
    """

    def __init__(self, in_ch=1, out_ch=9):
        super().__init__()
        encoder_layers = [
            nn.Conv3d(in_ch, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv3d(16, 32, 3, padding=1),
            nn.ReLU(),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        self.pool = nn.MaxPool3d(2)
        decoder_layers = [
            nn.ConvTranspose3d(32, 16, 2, stride=2),
            nn.ReLU(),
            nn.Conv3d(16, out_ch, 1),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        # encode at full resolution -> halve -> upsample back and classify
        return self.dec(self.pool(self.enc(x)))


In [11]:
def train_model(root):
    train_ds = HVSMR3DDataset("/content/patients_0_to_5", "train")
    val_ds   = HVSMR3DDataset("/content/patients_0_to_5", "val")
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)
    model = UNet3D()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(10):
        model.train()
        for img, msk in train_loader:
            img, msk = img.to(device), msk.to(device)
            pred = model(img)
            loss = loss_fn(pred, msk)
            optim.zero_grad(); loss.backward(); optim.step()

        # Validation
        model.eval()
        dice_sum, count = 0, 0
        with torch.no_grad():
            for img, msk in val_loader:
                img, msk = img.to(device), msk.to(device)
                pred = torch.argmax(model(img), dim=1)
                dice = (2*(pred==msk).float().sum()) / (pred.numel() + msk.numel())
                dice_sum += dice.item()
                count += 1
        print(f"Epoch {epoch}  Val Dice: {dice_sum/count:.4f}")

    torch.save(model.state_dict(), "hvsmr3d_unet.pth")

def dice_score(pred, true):
    """Label-aware Dice between two label arrays.

    A voxel counts toward the overlap only when both arrays hold the *same*
    non-zero label there; the denominator is the number of non-zero voxels in
    each array. The 1e-5 term keeps the ratio finite when both arrays are all
    zero (the score is then 0).
    """
    overlap = np.sum(np.logical_and(pred == true, true > 0))
    denominator = np.sum(pred > 0) + np.sum(true > 0) + 1e-5
    return 2. * overlap / denominator

def jaccard_index(pred, true):
    """Macro-averaged Jaccard index (IoU) between two label arrays.

    Equivalent to ``sklearn.metrics.jaccard_score(true.flatten(),
    pred.flatten(), average='macro')``: for every label that appears in
    either array (background included), IoU = |pred==c AND true==c| /
    |pred==c OR true==c|, then the unweighted mean. Implemented with NumPy
    because ``jaccard_score`` is not imported in this notebook.

    Returns:
        The mean IoU as a float, or ``nan`` for empty inputs.

    Raises:
        ValueError: If the arrays have different numbers of elements.
    """
    pred = np.asarray(pred).ravel()
    true = np.asarray(true).ravel()
    if pred.size != true.size:
        raise ValueError(f"pred and true differ in size: {pred.size} vs {true.size}")

    labels = np.union1d(pred, true)
    if labels.size == 0:
        return float("nan")

    # Each label occurs in at least one array, so every union is non-empty.
    ious = [
        np.sum((pred == c) & (true == c)) / np.sum((pred == c) | (true == c))
        for c in labels
    ]
    return float(np.mean(ious))

def hausdorff(pred, true):
    """Symmetric Hausdorff distance, in voxel units, between foreground sets.

    Foreground is every element ``> 0``; labels are not distinguished. The
    result is the larger of the two directed Hausdorff distances.

    Returns:
        ``0.0`` if both arrays have no foreground, ``inf`` if exactly one of
        them is empty, otherwise the symmetric distance.
    """
    # Imported here: the notebook's import cell does not bring it into scope.
    from scipy.spatial.distance import directed_hausdorff

    coords_pred = np.argwhere(pred > 0)
    coords_true = np.argwhere(true > 0)
    if len(coords_pred) == 0 and len(coords_true) == 0:
        return 0.0
    if len(coords_pred) == 0 or len(coords_true) == 0:
        return float("inf")
    return max(directed_hausdorff(coords_pred, coords_true)[0],
               directed_hausdorff(coords_true, coords_pred)[0])

def show_3d_slice(pred, true, slice_id=64):
    """Show ground truth, prediction, and a red/blue overlay for one slice.

    Both volumes are indexed along their first axis at ``slice_id``; the three
    panels are laid out side by side in a 12x4 figure.
    """
    panels = (
        ("Ground Truth", ((true, 'gray', None),)),
        ("Prediction", ((pred, 'gray', None),)),
        ("Overlay", ((true, 'Reds', 0.5), (pred, 'Blues', 0.5))),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        for volume, colormap, transparency in layers:
            # alpha=None is imshow's default, i.e. fully opaque
            plt.imshow(volume[slice_id], cmap=colormap, alpha=transparency)
    plt.show()

def evaluate_model(model_path, dataset_root):
    """Run a saved UNet3D over the validation split, save predictions, report Dice.

    Args:
        model_path: Path to a ``state_dict`` produced by ``train_model``.
        dataset_root: Directory of paired ``.nii`` files. Predictions are
            written to ``<dataset_root>/predictions`` as
            ``pred_<idx>_<source image name>``.

    Returns:
        The mean ``dice_score`` over the validation volumes (``nan`` if none).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet3D()
    # The checkpoint is a plain state_dict; weights_only refuses arbitrary pickled objects.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    dataset = HVSMR3DDataset(dataset_root, split="val")
    # No shuffling, batch_size=1: batch idx corresponds to dataset.pairs[idx].
    loader = DataLoader(dataset, batch_size=1)

    pred_output_dir = os.path.join(dataset_root, "predictions")
    os.makedirs(pred_output_dir, exist_ok=True)

    dice_scores = []
    with torch.no_grad():
        for idx, (img, msk) in enumerate(loader):
            output = model(img.to(device))
            pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()
            true = msk.squeeze().cpu().numpy()

            source_name = os.path.basename(dataset.pairs[idx][0])
            out_path = os.path.join(pred_output_dir, f"pred_{idx:03d}_{source_name}")
            # The prediction lives on the dataset's resampled grid, not the
            # source image's, so the source affine is not reused here.
            out_nifti = nib.Nifti1Image(pred.astype(np.uint8), affine=np.eye(4))
            nib.save(out_nifti, out_path)
            print(f"Saved prediction to: {out_path}")

            dice_scores.append(dice_score(pred, true))

    mean_dice = float(np.mean(dice_scores)) if dice_scores else float("nan")
    print("\n=== Overall Evaluation ===")
    print(f"Average Dice: {mean_dice:.4f}")
    return mean_dice


In [14]:
def generate_colored_cardiac_model(nii_path, save_dir, visualize=True, level=0.5):
    """Build one colored surface mesh per labeled chamber in a label volume.

    For each label in the (assumed) label->chamber mapping, voxels equal to
    that label are surfaced with marching cubes, colored, and written to
    ``<save_dir>/<base name>_<Chamber_Name>.ply``. Optionally all meshes are
    shown together in a vedo window.

    Args:
        nii_path: Path to a NIfTI label volume (``.nii`` or ``.nii.gz``).
        save_dir: Directory for the ``.ply`` files; created if missing.
        visualize: If True, open a vedo window with every chamber mesh.
        level: Marching-cubes iso-value. 0.5 lies midway between background
            (0) and the binary chamber mask (1); the previous value of 0
            placed the surface on the background voxels themselves.

    Returns:
        List of vedo ``Mesh`` objects, one per label present in the volume.
    """
    volume = nib.load(nii_path).get_fdata()

    # --- Mapping from labels to chambers and colors ---
    # !! This is an assumed mapping. You may need to verify and adjust it. !!
    label_to_chamber = {
        1: "Left Ventricle",
        2: "Right Ventricle",
        3: "Left Atrium",
        4: "Right Atrium"
    }

    chamber_to_color = {
        "Left Ventricle": "red",
        "Right Ventricle": "blue",
        "Left Atrium": "red",
        "Right Atrium": "blue"
    }

    os.makedirs(save_dir, exist_ok=True)

    # Strip ".nii.gz" as a unit so output names don't keep a stray ".nii".
    base_name = os.path.basename(nii_path)
    if base_name.endswith(".nii.gz"):
        base_name = base_name[:-len(".nii.gz")]
    else:
        base_name = os.path.splitext(base_name)[0]

    all_meshes = []
    captions = []

    for label, chamber_name in label_to_chamber.items():
        binary = (volume == label).astype(np.uint8)

        # marching_cubes needs the label to be present.
        if not np.any(binary):
            print(f"Warning: No voxels with label {label} ({chamber_name}) found in {nii_path}.")
            continue

        verts, faces, _normals, _values = measure.marching_cubes(binary, level=level)

        color = chamber_to_color.get(chamber_name, "gray")  # gray if not mapped
        mesh = Mesh([verts, faces])
        mesh.color(color).opacity(0.7)
        all_meshes.append(mesh)
        captions.append(f"{chamber_name}: {color}")

        save_path = os.path.join(save_dir, f"{base_name}_{chamber_name.replace(' ', '_')}.ply")
        mesh.write(save_path)
        print(f"Saved {chamber_name} mesh to: {save_path}")

    if visualize and all_meshes:
        print("Visualizing combined colored model...")
        # Created only when needed; named `plotter` so matplotlib's `plt` isn't shadowed.
        plotter = Plotter(title="Color-Coded 3D Cardiac Model")
        # Caption lists every drawn chamber (previously only the last loop value was passed).
        plotter.show(all_meshes, "\n".join(captions))

    return all_meshes


In [15]:
if __name__ == "__main__":
    # 1. Unzip the dataset
    dataset_root = unzip_local_zip()
    source_dir = os.path.join(dataset_root, "cropped")

    # 2. Convert .nii.gz to .nii (compressed originals are deleted)
    batch_convert_and_delete(source_dir)

    # 3. Filter patients by age (0-5 years)
    csv_path = "hvsmr_clinical.csv"  # uploaded alongside cropped.zip
    df = pd.read_csv(csv_path)

    age_0to5_dir = "/content/patients_0_to_5"
    os.makedirs(age_0to5_dir, exist_ok=True)

    # List the directory once instead of once per CSV row.
    nii_names = [name for name in os.listdir(source_dir) if name.endswith(".nii")]

    copied = 0
    for _, row in df.iterrows():
        age = row["Age"]
        if not (0 <= age <= 5):
            continue
        # Trailing "_" so "pat1" does not also match pat10, pat11, ...
        prefix = f"pat{row['Pat']}_"
        for file in nii_names:
            if file.startswith(prefix):
                src = os.path.join(source_dir, file)
                dst = os.path.join(age_0to5_dir, file)
                shutil.copy2(src, dst)
                copied += 1

    print(f"✅ Done! {copied} files successfully copied for patients aged 0-5.")

    # 4. Train the model
    train_model(age_0to5_dir)

    # 5. Evaluate the model
    evaluate_model("hvsmr3d_unet.pth", age_0to5_dir)

    # 6. Generate color-coded 3D meshes from the saved predictions
    pred_folder = os.path.join(age_0to5_dir, "predictions")
    nii_files = sorted(glob.glob(os.path.join(pred_folder, "*.nii")))
    mesh_dir = "/content/meshes"
    os.makedirs(mesh_dir, exist_ok=True)

    if nii_files:
        for nii_path in nii_files:
            print(f"Processing: {nii_path}")
            # visualize=False: don't open a render window per file in a batch loop.
            generate_colored_cardiac_model(nii_path, mesh_dir, visualize=False)
    else:
        print("No prediction files found to generate meshes.")


Extracted dataset to: /content/cropped_unzipped

Contents of extracted directory:
cropped_unzipped/
    cropped/
        pat45_cropped_seg_endpoints.nii.gz
        pat15_cropped.nii
        pat33_cropped_seg.nii
        pat0_cropped.nii
        pat2_cropped_seg_endpoints.nii.gz
        pat1_cropped_seg_endpoints.nii
        pat39_cropped_seg_endpoints.nii.gz
        pat4_cropped_seg.nii.gz
        pat4_cropped_seg.nii
        pat23_cropped_seg.nii
        pat4_cropped.nii.gz
        pat9_cropped_seg_endpoints.nii
        pat6_cropped_seg.nii
        pat21_cropped.nii.gz
        pat26_cropped_seg_endpoints.nii
        pat48_cropped_seg_endpoints.nii.gz
        pat10_cropped_seg_endpoints.nii
        pat42_cropped_seg_endpoints.nii.gz
        pat1_cropped_seg.nii
        pat15_cropped_seg.nii.gz
        pat8_cropped.nii
        pat58_cropped_seg_endpoints.nii
        pat35_cropped.nii
        pat34_cropped_seg_endpoints.nii.gz
        pat15_cropped.nii.gz
        pat57_cropped_seg_endpoi