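**Install numpy==1.23.5 — the code is known to work with this version. Restart the runtime after this cell so the pinned version is the one that gets imported.**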

In [None]:
# Pin numpy to the version this notebook was tested with (see the note above).
# NOTE(review): if numpy was already imported in this kernel, restart the runtime after this
# cell so the reinstalled version is actually loaded.
!pip install numpy==1.23.5 --force-reinstall


**Install the required libraries and verify that they import**

In [None]:
# Install the third-party libraries used by the Geo_Seg code, then import the key ones as a smoke test.
# NOTE(review): tensorboardX, SimpleITK, nibabel, torch-geometric and pytorch3d are not pinned;
# a newer release could pull in a different numpy than the version pinned in the previous cell.
!pip install tensorboardX
!pip install SimpleITK
!pip install nibabel

# Reinstall packages (if not persisted)
!pip install pyvista==0.36.1
!pip install trimesh==3.12.6
# PyG extension wheels from the index built for torch 2.0.0 + CUDA 11.8.
# NOTE(review): confirm the runtime's torch.__version__ matches that build; otherwise pip may fall
# back to compiling from source or install binaries that do not match the installed torch.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install torch-geometric
# pytorch3d is installed from the head of its GitHub repository (unpinned).
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git'

# Smoke test: any failed install above surfaces here as an ImportError.
import torch
import torch_geometric
import pyvista
import trimesh
import pytorch3d

print(" All modules are working after restart.")


**Mount Google Drive (the dataset is stored there)**

In [None]:
# Mount Google Drive at /content/drive; the dataset folders read below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


**Clone the repository; it contains the code, models, etc.**


**Note: you must create `data.json` (it is generated by a cell further below) and reference it from `dataset.py`.**


In [None]:
import os
import shutil


%cd /content

#  Remove the repo if it exists
shutil.rmtree("/content/Geo_Seg", ignore_errors=True)

#  Clone repo from scratch
!git clone https://github.com/AliakbarMzadeh/Geo_Seg.git

#  Now CD into it
%cd /content/Geo_Seg


**The model needs one point-cloud CSV file (`mesh.csv`) per case, which is used by the GCN blocks. Generate it with MeshLab or with the code below, which samples 3000 points on the triangle surface of each `.stl` mesh.**

In [None]:
import trimesh
import pandas as pd
import os
import numpy as np

# Verified STL directory (surface meshes named Normal_1.stl ... Normal_20.stl)
stl_dir = "/content/drive/MyDrive/Team_Internship_Dataset/Normal/SurfaceMeshes"

# Output path for converted dataset
out_dir = "/content/converted_dataset/CoronaryArtery"
os.makedirs(out_dir, exist_ok=True)

N_CASES = 20            # number of Normal_<n>.stl files to convert
N_SURFACE_POINTS = 3000 # points sampled on each mesh surface
SAMPLING_SEED = 0       # base seed so the sampled point clouds are reproducible across runs

for case_idx in range(N_CASES):
    # Source files are 1-based (Normal_1..Normal_20); case folders are 0-based (case0..case19).
    case_name = f"case{case_idx}"
    case_path = os.path.join(out_dir, case_name)
    os.makedirs(case_path, exist_ok=True)

    mesh_path = os.path.join(stl_dir, f"Normal_{case_idx + 1}.stl")
    mesh = trimesh.load(mesh_path)

    # mesh.sample draws random surface points. With the trimesh version pinned in the install
    # cell the sampler appears to use numpy's global RNG, so seeding it per case makes mesh.csv
    # reproducible. NOTE(review): newer trimesh releases take an explicit `seed` argument
    # instead — confirm against the installed version.
    np.random.seed(SAMPLING_SEED + case_idx)
    points = mesh.sample(N_SURFACE_POINTS)

    # Space-separated x y z rows, no header/index.
    pd.DataFrame(points).to_csv(os.path.join(case_path, "mesh.csv"), index=False, header=False, sep=" ")
    print(f"Saved: {case_name}/mesh.csv")


**Copy the data from Google Drive into the Colab workspace**

In [None]:
import shutil
import os

# Drive folders with the CT volumes and their annotation masks; both use Normal_<n>.nrrd names.
image_dir = "/content/drive/MyDrive/Team_Internship_Dataset/Normal/CTCA"
label_dir = "/content/drive/MyDrive/Team_Internship_Dataset/Normal/Annotations"
out_dir = "/content/converted_dataset/CoronaryArtery"

for case_idx in range(20):
    case_dir = os.path.join(out_dir, f"case{case_idx}")
    os.makedirs(case_dir, exist_ok=True)

    # Source numbering is 1-based, case folders are 0-based.
    src_name = f"Normal_{case_idx + 1}.nrrd"

    # Keep the .nrrd container as-is (do NOT rename to .nii.gz); image first, then label.
    for src_folder, dst_name in ((image_dir, "image.nrrd"), (label_dir, "label.nrrd")):
        shutil.copy(os.path.join(src_folder, src_name), os.path.join(case_dir, dst_name))

    print(f"Copied Normal_{case_idx + 1} as .nrrd to {case_dir}")


**NOTE: in this dataset the Z depth (number of slices) differs between samples, so every volume is center-cropped or padded to a common depth. This is required whenever the batch size is greater than 1.**

In [None]:
import SimpleITK as sitk
import os
import json
import numpy as np
import shutil


# Where the per-case folders produced by the copy step live, and where resized copies go.
original_dir = "/content/converted_dataset/CoronaryArtery"
corrected_dir = "/content/converted_dataset/CoronaryArtery_fixed"

# Number of Z slices every image/label volume is cropped or padded to.
target_depth = 160

os.makedirs(corrected_dir, exist_ok=True)

# Collects the names of cases whose depth differs from target_depth.
inconsistent_cases = list()

def crop_or_pad(volume, target_depth, pad_value=0):
    """Center-crop or pad ``volume`` along axis 0 so that axis has ``target_depth`` entries.

    Args:
        volume: array whose first axis is the slice (Z) axis, e.g. the (z, y, x) array
            returned by ``sitk.GetArrayFromImage``. Any number of trailing axes is accepted.
        target_depth: desired size of axis 0; must be non-negative.
        pad_value: constant written into added slices (default 0, the previous behavior).

    Returns:
        Array of shape ``(target_depth,) + volume.shape[1:]``. When cropping, the central
        slices are kept (an odd surplus drops one extra slice from the end). When padding,
        slices are added on both sides (an odd deficit adds the extra slice at the end).
        If the depth already matches, the input array itself is returned.

    Raises:
        ValueError: if ``volume`` is 0-dimensional or ``target_depth`` is negative.
    """
    volume = np.asarray(volume)
    if volume.ndim == 0:
        raise ValueError("volume must have at least one axis")
    if target_depth < 0:
        raise ValueError(f"target_depth must be non-negative, got {target_depth}")

    depth = volume.shape[0]
    if depth > target_depth:
        start = (depth - target_depth) // 2
        return volume[start:start + target_depth]
    if depth < target_depth:
        pad_before = (target_depth - depth) // 2
        pad_after = target_depth - depth - pad_before
        # Pad only the slice axis; every other axis is left untouched.
        pad_width = [(pad_before, pad_after)] + [(0, 0)] * (volume.ndim - 1)
        return np.pad(volume, pad_width, mode='constant', constant_values=pad_value)
    return volume

N_CASES = 20  # case0 .. case19, as produced by the copy step


def _crop_pad_offset(depth, target_depth):
    """Return the original slice index that becomes slice 0 after center crop/pad.

    Positive when slices are cropped from the front, negative (-pad_before) when blank
    slices are prepended, and 0 when the depth already matches.
    """
    if depth > target_depth:
        return (depth - target_depth) // 2
    return -((target_depth - depth) // 2)


def _write_with_geometry(array, reference, z_offset, out_path):
    """Write ``array`` to ``out_path`` while keeping ``reference``'s physical frame.

    ``sitk.GetImageFromArray`` yields unit spacing, zero origin and identity direction.
    Copying spacing/direction from ``reference`` and moving the origin to the physical
    position of the new first slice keeps the resized volume in the same world frame
    as the original scan.
    NOTE(review): mesh.csv points come from the STL surface; if those are in the scan's
    physical coordinates, preserving the frame here keeps them co-registered — confirm.
    """
    image = sitk.GetImageFromArray(array)
    image.SetSpacing(reference.GetSpacing())
    image.SetDirection(reference.GetDirection())
    # SimpleITK indices are (x, y, z) while the numpy array is (z, y, x).
    image.SetOrigin(reference.TransformContinuousIndexToPhysicalPoint((0.0, 0.0, float(z_offset))))
    sitk.WriteImage(image, out_path)


for i in range(N_CASES):
    case = f"case{i}"
    case_path = os.path.join(original_dir, case)
    img_path = os.path.join(case_path, "image.nrrd")
    lbl_path = os.path.join(case_path, "label.nrrd")

    img = sitk.ReadImage(img_path)
    lbl = sitk.ReadImage(lbl_path)

    img_np = sitk.GetArrayFromImage(img)
    lbl_np = sitk.GetArrayFromImage(lbl)

    # Image and label are cropped/padded with the same offsets, so they must share one grid.
    if img_np.shape != lbl_np.shape:
        raise ValueError(f"{case}: image shape {img_np.shape} != label shape {lbl_np.shape}")

    if img_np.shape[0] != target_depth:
        inconsistent_cases.append(case)

        # Crop/pad (blank slices are filled with 0)
        z_offset = _crop_pad_offset(img_np.shape[0], target_depth)
        img_fixed = crop_or_pad(img_np, target_depth)
        lbl_fixed = crop_or_pad(lbl_np, target_depth)

        # Save, preserving spacing/direction and the shifted origin
        case_fixed_path = os.path.join(corrected_dir, case)
        os.makedirs(case_fixed_path, exist_ok=True)
        _write_with_geometry(img_fixed, img, z_offset, os.path.join(case_fixed_path, "image.nrrd"))
        _write_with_geometry(lbl_fixed, lbl, z_offset, os.path.join(case_fixed_path, "label.nrrd"))

        # Also copy mesh
        shutil.copy(os.path.join(case_path, "mesh.csv"), os.path.join(case_fixed_path, "mesh.csv"))

if len(inconsistent_cases) == 0:
    print("All images and labels have consistent Z-depth.")
    final_data_path = original_dir
else:
    print("Found inconsistent Z-depths in cases:", inconsistent_cases)
    # data.json below points EVERY case at corrected_dir, so cases that already had the
    # target depth must be copied there too; otherwise their entries reference missing files.
    for i in range(N_CASES):
        case = f"case{i}"
        if case in inconsistent_cases:
            continue
        src_case = os.path.join(original_dir, case)
        dst_case = os.path.join(corrected_dir, case)
        os.makedirs(dst_case, exist_ok=True)
        for fname in ("image.nrrd", "label.nrrd", "mesh.csv"):
            shutil.copy(os.path.join(src_case, fname), os.path.join(dst_case, fname))
    print(f" Saved corrected files to: {corrected_dir}")
    final_data_path = corrected_dir

# Rewrite data.json to match the final path (entries are relative to /content/converted_dataset)
print(" Writing new data.json...")
data_json = {}
base_path = os.path.basename(final_data_path)

for i in range(N_CASES):
    case = f"case{i}"
    data_json[case] = [{
        "image": f"{base_path}/{case}/image.nrrd",
        "label": f"{base_path}/{case}/label.nrrd",
        "verts": f"{base_path}/{case}/mesh.csv"
    }]

with open("/content/converted_dataset/data.json", "w") as f:
    json.dump(data_json, f, indent=2)

print("Final data.json updated.")


**Reduce `batch_size`, `num_worker`, and `total_epoches`: the data are 3D medical images and the U-Net model is large, so training needs a lot of RAM and GPU memory.**

**The paper states that an NVIDIA A100 (80 GB) GPU is required.**

In [None]:
import yaml

yaml_path = "/content/Geo_Seg/config/config-s1-train.yaml"

with open(yaml_path, 'r') as f:
    config = yaml.safe_load(f)

# Patch dataset loading
config['dataset']['batch_size'] = 1          # smaller batch to reduce memory
config['dataset']['num_worker'] = 0          # prevent worker overload

# Patch training loop (short range for test runs)
config['trainer']['total_epoches'] = 2
config['trainer']['current_epoch'] = 1

with open(yaml_path, 'w') as f:
    # sort_keys=False keeps the repo's original key order, so the patched file diffs cleanly.
    yaml.dump(config, f, sort_keys=False)

# Echo the values actually written, so the message cannot drift from the patch above.
print(
    " Patched config-s1-train.yaml: "
    f"batch_size={config['dataset']['batch_size']}, "
    f"num_worker={config['dataset']['num_worker']}, "
    f"total_epoches={config['trainer']['total_epoches']}, "
    f"current_epoch={config['trainer']['current_epoch']}"
)

**Run Stage 1**

In [None]:
%cd /content/Geo_Seg
!python3 train.py -c ./config/config-s1-train.yaml | tee train_stage1_output.log


**Plot the training metrics after Stage 1**

In [None]:
import re
import matplotlib.pyplot as plt

# Log written by the Stage 1 training cell.
log_path = "/content/Geo_Seg/train_stage1_output.log"

# Line formats emitted by train.py. Chamfer and Loss accept scientific notation (e.g. 1.2e-05).
TRAIN_RE = re.compile(r"\[TRAIN\]\[Epoch (\d+)\] Dice: ([0-9.]+) \| Chamfer: ([0-9.eE+-]+)")
EVAL_RE = re.compile(r"\[EVAL\]\[Epoch (\d+)\] Dice: ([0-9.]+) \| Chamfer: ([0-9.eE+-]+)")
LOSS_RE = re.compile(r"Epoch\s*(\d+).*Loss: ([0-9.eE+-]+)")


def parse_training_log(path):
    """Parse a train.py log into per-epoch metric series.

    Returns:
        dict of lists:
          train_epochs / train_dice / train_chamfer  from "[TRAIN][Epoch N] ..." lines;
          loss_epochs / train_loss                   from "... finished ! Loss: ..." lines,
                                                     kept with their OWN epoch list so the
                                                     loss curve never depends on how many
                                                     [TRAIN] lines matched;
          eval_epochs / eval_dice / eval_chamfer     from "[EVAL][Epoch N] ..." lines, sorted
                                                     by epoch (a later line for the same
                                                     epoch replaces an earlier one).
    """
    series = {key: [] for key in
              ("train_epochs", "train_dice", "train_chamfer", "loss_epochs", "train_loss")}
    eval_by_epoch = {}
    with open(path, "r") as f:
        for line in f:
            if "[TRAIN][Epoch" in line:
                match = TRAIN_RE.search(line)
                if match:
                    series["train_epochs"].append(int(match.group(1)))
                    series["train_dice"].append(float(match.group(2)))
                    series["train_chamfer"].append(float(match.group(3)))
            elif "[EVAL][Epoch" in line:
                match = EVAL_RE.search(line)
                if match:
                    eval_by_epoch[int(match.group(1))] = (float(match.group(2)), float(match.group(3)))
            elif "finished ! Loss:" in line:
                match = LOSS_RE.search(line)
                if match:
                    series["loss_epochs"].append(int(match.group(1)))
                    series["train_loss"].append(float(match.group(2)))
    eval_epochs = sorted(eval_by_epoch)
    series["eval_epochs"] = eval_epochs
    series["eval_dice"] = [eval_by_epoch[ep][0] for ep in eval_epochs]
    series["eval_chamfer"] = [eval_by_epoch[ep][1] for ep in eval_epochs]
    return series


def plot_metric(x, y, title, ylabel, color):
    """Draw one metric-vs-epoch line chart; an empty series prints a notice instead."""
    if not x:
        print(f"No data parsed for: {title}")
        return
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(x, y, marker='o', color=color)
    ax.set(title=title, xlabel="Epoch", ylabel=ylabel)
    ax.grid(True)
    plt.show()


metrics = parse_training_log(log_path)

plot_metric(metrics["loss_epochs"], metrics["train_loss"], "Training Loss per Epoch", "Loss", "blue")
plot_metric(metrics["train_epochs"], metrics["train_dice"], "Training Dice per Epoch", "Dice", "blue")
plot_metric(metrics["eval_epochs"], metrics["eval_dice"], "Evaluation Dice per Epoch", "Dice", "orange")
plot_metric(metrics["train_epochs"], metrics["train_chamfer"], "Training Chamfer Distance per Epoch", "Chamfer Distance", "blue")
plot_metric(metrics["eval_epochs"], metrics["eval_chamfer"], "Evaluation Chamfer Distance per Epoch", "Chamfer Distance", "orange")


In [None]:
!grep "Epoch" train_stage1_output.log
!grep "dice" train_stage1_output.log
!grep "chamfer distance" train_stage1_output.log

**Stage 2**

**Reduce `batch_size`, `num_worker`, and `total_epoches` as in Stage 1**

In [None]:

import yaml

yaml_path = "/content/Geo_Seg/config/config-s2-train.yaml"

# Colab-friendly overrides for a minimal test: one sample per batch, no loader workers,
# and a short epoch range. Applied section by section, key by key.
overrides = {
    'dataset': {'batch_size': 1, 'num_worker': 0},
    'trainer': {'total_epoches': 2, 'current_epoch': 1},
}

with open(yaml_path, "r") as f:
    config = yaml.safe_load(f)

for section, values in overrides.items():
    for key, value in values.items():
        config[section][key] = value

with open(yaml_path, "w") as f:
    yaml.dump(config, f)

print(" Patched config-s2-train.yaml for minimal test.")


**Run Stage 2**

In [None]:
%cd /content/Geometry_Segmentation_for_Coronary_Artery
!python3 train.py -c ./config/config-s2-train.yaml | tee train_stage2_output.log


**Plot the Stage 2 training metrics**

In [None]:
import re
import matplotlib.pyplot as plt

# Log written by the Stage 2 training cell, which runs train.py from the repo root /content/Geo_Seg.
log_path = "/content/Geo_Seg/train_stage2_output.log"

# Line formats emitted by train.py. Chamfer and Loss accept scientific notation (e.g. 1.2e-05).
TRAIN_RE = re.compile(r"\[TRAIN\]\[Epoch (\d+)\] Dice: ([0-9.]+) \| Chamfer: ([0-9.eE+-]+)")
EVAL_RE = re.compile(r"\[EVAL\]\[Epoch (\d+)\] Dice: ([0-9.]+) \| Chamfer: ([0-9.eE+-]+)")
LOSS_RE = re.compile(r"Epoch\s*(\d+).*Loss: ([0-9.eE+-]+)")


def parse_training_log(path):
    """Parse a train.py log into per-epoch metric series.

    Returns:
        dict of lists:
          train_epochs / train_dice / train_chamfer  from "[TRAIN][Epoch N] ..." lines;
          loss_epochs / train_loss                   from "... finished ! Loss: ..." lines,
                                                     kept with their OWN epoch list so the
                                                     loss curve never depends on how many
                                                     [TRAIN] lines matched;
          eval_epochs / eval_dice / eval_chamfer     from "[EVAL][Epoch N] ..." lines, sorted
                                                     by epoch (a later line for the same
                                                     epoch replaces an earlier one).
    """
    series = {key: [] for key in
              ("train_epochs", "train_dice", "train_chamfer", "loss_epochs", "train_loss")}
    eval_by_epoch = {}
    with open(path, "r") as f:
        for line in f:
            if "[TRAIN][Epoch" in line:
                match = TRAIN_RE.search(line)
                if match:
                    series["train_epochs"].append(int(match.group(1)))
                    series["train_dice"].append(float(match.group(2)))
                    series["train_chamfer"].append(float(match.group(3)))
            elif "[EVAL][Epoch" in line:
                match = EVAL_RE.search(line)
                if match:
                    eval_by_epoch[int(match.group(1))] = (float(match.group(2)), float(match.group(3)))
            elif "finished ! Loss:" in line:
                match = LOSS_RE.search(line)
                if match:
                    series["loss_epochs"].append(int(match.group(1)))
                    series["train_loss"].append(float(match.group(2)))
    eval_epochs = sorted(eval_by_epoch)
    series["eval_epochs"] = eval_epochs
    series["eval_dice"] = [eval_by_epoch[ep][0] for ep in eval_epochs]
    series["eval_chamfer"] = [eval_by_epoch[ep][1] for ep in eval_epochs]
    return series


def plot_metric(x, y, title, ylabel, color):
    """Draw one metric-vs-epoch line chart; an empty series prints a notice instead."""
    if not x:
        print(f"No data parsed for: {title}")
        return
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(x, y, marker='o', color=color)
    ax.set(title=title, xlabel="Epoch", ylabel=ylabel)
    ax.grid(True)
    plt.show()


metrics = parse_training_log(log_path)

# Plot each metric separately
plot_metric(metrics["loss_epochs"], metrics["train_loss"], "Stage 2 - Training Loss per Epoch", "Loss", "blue")
plot_metric(metrics["train_epochs"], metrics["train_dice"], "Stage 2 - Training Dice per Epoch", "Dice", "blue")
plot_metric(metrics["eval_epochs"], metrics["eval_dice"], "Stage 2 - Evaluation Dice per Epoch", "Dice", "orange")
plot_metric(metrics["train_epochs"], metrics["train_chamfer"], "Stage 2 - Training Chamfer Distance per Epoch", "Chamfer Distance", "blue")
plot_metric(metrics["eval_epochs"], metrics["eval_chamfer"], "Stage 2 - Evaluation Chamfer Distance per Epoch", "Chamfer Distance", "orange")


**Prediction**

**Point the prediction config at the trained weights and reduce the batch size and number of workers**

In [None]:
import os

import yaml

REPO_DIR = "/content/Geo_Seg"  # where the clone cell checks out the repository

# The prediction config inside the cloned repository (the old
# /content/Geometry_Segmentation_for_Coronary_Artery path does not exist in this notebook).
cfg_path = os.path.join(REPO_DIR, "config", "config-predict.yaml")

# Stage 2 "latest" checkpoint. The path is relative; presumably it is resolved against the
# directory predict.py is launched from (the prediction cell runs it from /content/Geo_Seg) — confirm.
ckpt_path = "./checkpoints/Tag-GeometrySegmentation-CoronaryArtery-s2-latest-checkpoint.pth"

expected_ckpt = os.path.normpath(os.path.join(REPO_DIR, ckpt_path))
if not os.path.exists(expected_ckpt):
    print(f" WARNING: checkpoint not found at {expected_ckpt}; check that Stage 2 finished and saved it there.")

with open(cfg_path, "r") as f:
    cfg = yaml.safe_load(f)

# Both network modules load their weights from the same checkpoint file.
cfg['network']['modules']['Unet']['cur_params'] = ckpt_path
cfg['network']['modules']['Gseg']['cur_params'] = ckpt_path

# Small batch and no worker processes to limit memory use.
cfg['dataset']['batch_size'] = 2
cfg['dataset']['num_worker'] = 0
cfg['dataset']['is_shuffle'] = False  # fixed sample order makes prediction reproducible

with open(cfg_path, "w") as f:
    yaml.dump(cfg, f)

print(f" Prediction config updated to use: {ckpt_path} with CPU-safe settings.")



**Run Prediction**

In [None]:
# Run inference from the repository root with the prediction config.
%cd /content/Geo_Seg
!python3 predict.py -c ./config/config-predict.yaml

In [None]:
from google.colab import files

# Pull the predicted surface mesh and the point cloud down to the local machine, in this order.
for artifact_path in (
    "/content/Geo_Seg/predict/right-predict-0.stl",
    "/content/Geo_Seg/CoronaryArtery-pointcloud.xyz",
):
    files.download(artifact_path)