# Coronary Stenosis Detection - vast.ai (RTX 5090)
Python port of the MATLAB stenosis detection pipeline from SAM-VMNet.

**Pipeline:** Segmentation mask → Skeletonize → Measure radii → Find bifurcations → BFS paths → Detect V-shape narrowing → Classify severity

## 1. Setup (run once)

In [None]:
# Install deps if not already installed
!pip install -q numpy opencv-python scikit-image scipy matplotlib tqdm

In [None]:
import sys, os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path

# Add project to path
PROJECT_ROOT = Path('.').resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from stenosis_detection.stenosis_detection import run_stenosis_detection, plot_results

%matplotlib inline
plt.rcParams['figure.dpi'] = 120
print(f'Working dir: {PROJECT_ROOT}')
print('Ready!')

## 2. Set your image paths
You need two images:
1. **Original** angiography image
2. **Segmented mask** (binary vessel mask from SAM-VMNet or other segmentation)

In [None]:
# === SET YOUR PATHS HERE ===
ORIGINAL_IMG = 'data/vessel/test/images/sample.png'   # original angiography
MASK_IMG     = 'data/vessel/test/masks/sample.png'     # segmented vessel mask

# Report whether each configured input exists, so typos surface before the run.
for path in (ORIGINAL_IMG, MASK_IMG):
    found = os.path.exists(path)
    print(f'OK: {path}' if found else f'MISSING: {path} -- update the path above!')

## 3. Preview inputs

In [None]:
# Show the two inputs side by side. A panel shows a placeholder when
# cv2.imread returns None (missing or unreadable file).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))


def _preview_panel(ax, img, label, to_display=None, **imshow_kwargs):
    """Draw *img* on *ax* titled '<label> <shape>', or a placeholder if *img* is None."""
    if img is None:
        ax.text(0.5, 0.5, 'File not found', ha='center')
    else:
        shown = img if to_display is None else to_display(img)
        ax.imshow(shown, **imshow_kwargs)
        ax.set_title(f'{label} {img.shape}')
    ax.axis('off')


orig = cv2.imread(ORIGINAL_IMG)
# OpenCV loads BGR; matplotlib expects RGB.
_preview_panel(ax1, orig, 'Original',
               to_display=lambda bgr: cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

mask = cv2.imread(MASK_IMG, cv2.IMREAD_GRAYSCALE)
_preview_panel(ax2, mask, 'Mask', cmap='gray')

plt.tight_layout()
plt.show()

## 4. Run Stenosis Detection

In [None]:
%%time
print('Running stenosis detection...')
result = run_stenosis_detection(ORIGINAL_IMG, MASK_IMG)
print('Done!')

In [None]:
# Visualize all 3 stages and also write the figure to disk.
RESULT_FIGURE = 'stenosis_result.png'
plot_results(result, save_path=RESULT_FIGURE)

## 5. Detailed results

In [None]:
# Accept either an ndarray or a nested list for the point coordinates
# (the table below uses 2-D indexing).
pts = np.asarray(result['points'])
degs = result['degrees']


def _severity_label(degree):
    """Map a fractional stenosis degree (printed as degree*100 %) to a severity bucket."""
    if degree > 0.75:
        return 'SEVERE (>75%)'
    if degree > 0.5:
        return 'MODERATE (50-75%)'
    if degree >= 0.25:
        return 'MILD (25-50%)'
    # Previously these were mislabelled as 'MILD (25-50%)'.
    return 'MINIMAL (<25%)'


print(f"Total stenosis points: {len(degs)}")
print(f"Skeleton points: {len(result['skeleton_coords'][0])}")
print(f"Bifurcation points: {len(result['bifurcations'])}")
print()

if len(degs) > 0:
    print(f"{'#':<4} {'Location':<20} {'Degree':<10} {'Severity'}")
    print('-' * 50)
    for i, d in enumerate(degs):
        # NOTE(review): whether points are (x, y) or (row, col) is defined by
        # run_stenosis_detection -- confirm before labelling the axes.
        loc = f"({pts[i, 0]:.0f}, {pts[i, 1]:.0f})"
        deg = f"{d * 100:.1f}%"
        # Pad the formatted fields themselves so the columns line up with the header.
        print(f"{i + 1:<4} {loc:<20} {deg:<10} {_severity_label(d)}")
else:
    print('No stenosis detected.')

## 6. Batch processing (optional)
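Process multiple image pairs at once. Each image in `IMAGE_DIR` is paired with the mask in `MASK_DIR` that has the same filename. Images without a mask are skipped. Result figures are written to `OUTPUT_DIR`.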

In [None]:
# Batch mode: set directories
IMAGE_DIR = 'data/vessel/test/images/'
MASK_DIR  = 'data/vessel/test/masks/'
OUTPUT_DIR = 'stenosis_results/'

# Input extensions to process. They are compared case-insensitively, so 'A.PNG' counts.
IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
# Output extensions matplotlib's savefig can write. Inputs with any other
# extension (e.g. .bmp) get their result figure saved as .png instead.
# NOTE(review): assumes plot_results saves via matplotlib savefig, which infers
# the format from the file extension -- confirm in stenosis_detection.
SAVEABLE_EXTS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

os.makedirs(OUTPUT_DIR, exist_ok=True)

if os.path.isdir(IMAGE_DIR) and os.path.isdir(MASK_DIR):
    images = sorted(f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(IMAGE_EXTS))
    print(f'Found {len(images)} images')

    n_ok = n_skipped = n_failed = 0
    for img_name in images:
        img_path = os.path.join(IMAGE_DIR, img_name)
        # Masks are paired with images by identical filename.
        mask_path = os.path.join(MASK_DIR, img_name)
        if not os.path.exists(mask_path):
            print(f'  Skip {img_name} (no mask)')
            n_skipped += 1
            continue

        stem, ext = os.path.splitext(img_name)
        out_name = f'result_{img_name}' if ext.lower() in SAVEABLE_EXTS else f'result_{stem}.png'

        print(f'  Processing {img_name}...')
        try:
            r = run_stenosis_detection(img_path, mask_path)
            plot_results(r, save_path=os.path.join(OUTPUT_DIR, out_name))
        except Exception as e:  # one bad pair should not abort the whole batch
            print(f'    ERROR: {type(e).__name__}: {e}')
            n_failed += 1
        else:
            n_ok += 1
        finally:
            # Close figures even when plotting fails part-way, so they do not
            # pile up over a long batch.
            plt.close('all')
    print('Batch complete!')
    print(f'  {n_ok} succeeded, {n_skipped} skipped (no mask), {n_failed} failed')
else:
    print('Directories not found. Update IMAGE_DIR and MASK_DIR above.')