# In [None]:
import os
import numpy as np
import nibabel as nib
import cv2
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import math
import gzip
import shutil
import tempfile
from skimage.filters import sobel
from skimage.feature import canny

# Helper functions
def handle_compressed_file(filepath):
    """Return a path that nibabel can load for *filepath*.

    Files with the non-standard ``.niigz`` extension (gzip-compressed NIfTI
    whose extension lost its dot) are gunzipped into a new temporary ``.nii``
    file and that file's path is returned; the caller owns the temporary file
    and must delete it. Any other path is returned unchanged.

    Args:
        filepath: Path to the image file.

    Returns:
        Either the path of a new temporary ``.nii`` file or *filepath* itself.

    Raises:
        OSError: if a ``.niigz`` file cannot be read or is not valid gzip
            (``gzip.BadGzipFile`` is an ``OSError``). The partially written
            temporary file is removed before the error propagates.
    """
    if not filepath.endswith('.niigz'):
        return filepath
    # The payload is written *decompressed*, so the temp file must carry a
    # plain ``.nii`` suffix: nibabel chooses its opener from the extension
    # and would try to gunzip a ``.nii.gz`` file a second time.
    with tempfile.NamedTemporaryFile(suffix='.nii', delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        with gzip.open(filepath, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    except BaseException:
        # Don't leak the temp file when decompression fails; re-raise as-is.
        os.unlink(tmp_path)
        raise
    return tmp_path

def rotate_image_clockwise(image, angle=90):
    """Rotate *image* clockwise by *angle* degrees.

    Multiples of 90 degrees use ``np.rot90``. This is lossless: no pixel is
    cropped or resampled, and for odd quarter turns a non-square ``(h, w)``
    image becomes ``(w, h)``. Any other angle falls back to an OpenCV affine
    warp about the exact pixel centre, with nearest-neighbour sampling and
    replicated borders, on the original ``(h, w)`` canvas.

    A trailing single-channel axis ``(h, w, 1)`` is dropped, so the result is
    2-D in both paths (matching what the OpenCV warp returns).

    Args:
        image: 2-D array, or H x W x C array.
        angle: Clockwise rotation in degrees.

    Returns:
        The rotated array, with the same dtype as *image*.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if angle % 90 == 0:
        # np.rot90 with positive k turns counter-clockwise as displayed, so
        # negate for a clockwise turn. Copy to C order so downstream OpenCV
        # calls get a plain contiguous buffer instead of a strided view.
        return np.ascontiguousarray(np.rot90(image, k=-int(angle // 90)))
    (h, w) = image.shape[:2]
    # Exact pixel-grid centre; (w // 2, h // 2) is half a pixel off for even sizes.
    center = ((w - 1) / 2.0, (h - 1) / 2.0)
    M = cv2.getRotationMatrix2D(center, -angle, 1.0)
    return cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_NEAREST,
                          borderMode=cv2.BORDER_REPLICATE)

def enhanced_edge_detection(seg_binary):
    """Return a 0/255 ``uint8`` edge map for a segmentation mask.

    Three detectors are OR-ed together:

    * skimage Canny with sigma 1.5;
    * a skimage Sobel magnitude thresholded at 0.1;
    * an OpenCV morphological gradient with a 3x3 elliptical element.

    The union is then morphologically closed with a 3x3 square kernel to
    bridge small gaps.

    Args:
        seg_binary: Mask array; any value > 0 counts as foreground.

    Returns:
        ``np.uint8`` edge image: 255 on edges, 0 elsewhere.
    """
    mask_u8 = (seg_binary > 0).astype(np.uint8) * 255
    mask_unit = mask_u8 / 255.0  # [0, 1] float image for the skimage filters

    canny_edges = canny(mask_unit, sigma=1.5)
    sobel_edges = sobel(mask_unit) > 0.1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    morph_edges = cv2.morphologyEx(mask_u8, cv2.MORPH_GRADIENT, ellipse) > 0

    union_u8 = (canny_edges | sobel_edges | morph_edges).astype(np.uint8) * 255
    closing_kernel = np.ones((3, 3), np.uint8)
    return cv2.morphologyEx(union_u8, cv2.MORPH_CLOSE, closing_kernel)

def detect_landmarks(seg_binary, delta_y_fraction=0.15, y_tolerance=20, curvature_threshold=0.2, min_x_separation=0.2):
    """Locate apex and mitral-annulus landmarks on the largest mask contour.

    Heuristic pipeline:

    1. Take the largest external contour of ``seg_binary > 0``.
    2. Smooth its x/y coordinates circularly (Gaussian, sigma=5) and compute
       the unsigned curvature, min-max normalised to [0, 1].
    3. Apex = the contour point with the smallest row (y), i.e. the top of
       the image.
    4. "Basal" points lie in the bottom ``delta_y_fraction`` of the
       contour's height. The band is doubled once if fewer than two
       curvature peaks fall inside it.
    5. Septal and lateral points are picked from the basal band with a
       view-dependent rule (see inline comments).

    Args:
        seg_binary: 2-D mask; any value > 0 is foreground.
        delta_y_fraction: Height of the basal band, as a fraction of the
            contour's vertical extent.
        y_tolerance: Pixel scale for the row-alignment score when pairing a
            lateral point with the septal point ('2CH' branch only).
        curvature_threshold: Minimum normalised curvature for septal
            candidates ('4CH' branch). Half of it is the minimum for lateral
            candidates ('2CH' branch).
        min_x_separation: Minimum septal-lateral horizontal distance, as a
            fraction of the image width. Below it, a farther peak is sought
            ('2CH' branch only).

    Returns:
        ``(apex_point, septal_point, lateral_point, mitral_mid, curvature,
        contour)``:
        - apex, septal and lateral points as integer ``(x, y)`` tuples;
        - ``mitral_mid`` as the float midpoint of septal and lateral;
        - the normalised curvature array;
        - the squeezed OpenCV contour (``(N, 2)`` for N > 1 points).

    Raises:
        ValueError: if no contour is found, or if the basal band is empty.
            The latter only happens for a negative ``delta_y_fraction``,
            because the lowest contour point always lies in the band.

    Also prints the detected view type and the landmarks to stdout.
    """
    seg_binary = (seg_binary > 0).astype(np.uint8) * 255
    contours, _ = cv2.findContours(seg_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        raise ValueError("No LV contour found")
    # Largest blob only; squeeze drops OpenCV's (N, 1, 2) middle axis.
    # NOTE(review): a single-point contour squeezes to shape (2,), and the
    # column indexing below would then raise IndexError.
    contour = max(contours, key=cv2.contourArea).squeeze()
    x = contour[:, 0].astype(float)
    y = contour[:, 1].astype(float)
    img_width = seg_binary.shape[1]
    # 'wrap' smoothing because the contour is closed.
    x_smooth = gaussian_filter1d(x, sigma=5, mode='wrap')
    y_smooth = gaussian_filter1d(y, sigma=5, mode='wrap')
    dx = np.gradient(x_smooth)
    dy = np.gradient(y_smooth)
    ddx = np.gradient(dx)
    ddy = np.gradient(dy)
    # |x'y'' - y'x''| / (x'^2 + y'^2)^1.5, then min-max normalised so the
    # thresholds below are relative to this particular contour.
    curvature = np.abs(dx * ddy - dy * ddx) / (dx**2 + dy**2 + 1e-10)**1.5
    curvature = (curvature - np.min(curvature)) / (np.max(curvature) - np.min(curvature) + 1e-10)
    # Apex: topmost contour point (smallest row index).
    apex_idx = np.argmin(y)
    apex_point = (int(x[apex_idx]), int(y[apex_idx]))
    mean_x = np.mean(x)
    # NOTE(review): the median and mean x of one contour rarely differ by 30%
    # of the image width, so this likely labels almost every mask '2CH' and
    # the '4CH' branch may be effectively unused — confirm. The view known
    # to the caller is not passed in.
    view_type = '2CH' if abs(np.median(x) - mean_x) < 0.3*img_width else '4CH'
    max_y = max(y)
    min_y = min(y)
    # Basal band: the bottom delta_y_fraction of the contour's height.
    basal_threshold = max_y - delta_y_fraction * (max_y - min_y)
    # Curvature peaks. find_peaks does not wrap around, so a corner exactly
    # at the contour's start/end index cannot be reported.
    peaks, _ = find_peaks(curvature, prominence=0.0005)
    basal_peaks = [i for i in peaks if y[i] >= basal_threshold]
    if len(basal_peaks) < 2:
        # Too few corners near the base: double the band height once.
        basal_threshold = max_y - 2 * delta_y_fraction * (max_y - min_y)
        basal_peaks = [i for i in peaks if y[i] >= basal_threshold]
    # Every contour index inside the (possibly widened) band, not just peaks.
    basal_points = [i for i in range(len(y)) if y[i] >= basal_threshold]
    if view_type == '4CH':
        if basal_points:
            # Lateral = lowest basal point (largest row index).
            sorted_by_y = sorted(basal_points, key=lambda i: -y[i])
            lateral_idx = sorted_by_y[0]
            # mean_x is re-bound here to the mean x of the basal band only.
            mean_x = np.mean(x[basal_points])
            # Septal candidates: opposite side of the band mean from the
            # lateral point, with at least curvature_threshold curvature.
            septal_candidates = [i for i in basal_points 
                                 if (x[i] - mean_x) * (x[lateral_idx] - mean_x) < 0
                                 and curvature[i] >= curvature_threshold]
            if septal_candidates:
                septal_idx = max(septal_candidates, key=lambda i: curvature[i])
            else:
                # Fallback: highest-curvature basal point other than lateral.
                other_points = [i for i in basal_points if i != lateral_idx]
                if other_points:
                    septal_idx = max(other_points, key=lambda i: curvature[i])
                else:
                    # Degenerate: septal and lateral coincide.
                    septal_idx = lateral_idx
        else:
            raise ValueError("No basal points found for 4CH")
    else:
        # '2CH' branch: taken for every mask not classified as '4CH'.
        if len(basal_peaks) >= 2:
            # Septal = strongest basal curvature peak.
            basal_peaks_sorted = sorted(basal_peaks, key=lambda i: curvature[i], reverse=True)
            septal_idx = basal_peaks_sorted[0]
            septal_side = "right" if x[septal_idx] > mean_x else "left"
            # Lateral candidates: the remaining peaks on the other side of the
            # whole-contour mean x, with at least half the curvature threshold.
            lateral_candidates = [i for i in basal_peaks_sorted[1:] 
                                  if ((x[i] > mean_x and septal_side == "left") or 
                                      (x[i] < mean_x and septal_side == "right")) 
                                  and curvature[i] >= curvature_threshold / 2]
            if lateral_candidates:
                # Score = 0.6 * curvature + 0.4 * row alignment with septal;
                # alignment falls linearly to 0 at y_tolerance pixels apart.
                scores = []
                for i in lateral_candidates:
                    curvature_score = curvature[i] * 0.6
                    y_alignment_score = (1 - min(1, abs(y[i] - y[septal_idx]) / y_tolerance)) * 0.4
                    scores.append(curvature_score + y_alignment_score)
                lateral_idx = lateral_candidates[np.argmax(scores)]
            else:
                # No peak on the opposite side: take the second-strongest peak.
                lateral_idx = basal_peaks_sorted[1]
            min_x_dist = min_x_separation * img_width
            if abs(x[septal_idx] - x[lateral_idx]) < min_x_dist:
                # Too close horizontally: scan the lower-ranked peaks for one on
                # the lateral point's side that is far enough from septal.
                outward_dir = 1 if x[lateral_idx] > mean_x else -1
                for i in basal_peaks_sorted[2:]:
                    if abs(x[i] - x[septal_idx]) > min_x_dist and \
                       ((x[i] > mean_x and outward_dir > 0) or (x[i] < mean_x and outward_dir < 0)):
                        lateral_idx = i
                        break
            if abs(x[septal_idx] - x[lateral_idx]) < min_x_dist:
                print(f"Warning: Points still close after fallback. Distance: {abs(x[septal_idx] - x[lateral_idx]):.1f} px")
        else:
            if basal_points:
                # Fewer than two basal peaks: use the left- and right-most basal
                # points; the one with higher curvature is called septal.
                left_idx = basal_points[np.argmin(x[basal_points])]
                right_idx = basal_points[np.argmax(x[basal_points])]
                if curvature[left_idx] > curvature[right_idx]:
                    septal_idx = left_idx
                    lateral_idx = right_idx
                else:
                    septal_idx = right_idx
                    lateral_idx = left_idx
            else:
                raise ValueError("No basal points found for 2CH")
    septal_point = (int(x[septal_idx]), int(y[septal_idx]))
    lateral_point = (int(x[lateral_idx]), int(y[lateral_idx]))
    # Float midpoint of the two integer annulus points; the long axis runs
    # from the apex to this point downstream.
    mitral_mid = ((septal_point[0] + lateral_point[0]) / 2, (septal_point[1] + lateral_point[1]) / 2)
    print(f"View type: {view_type}")
    print(f"Apex: {apex_point}")
    print(f"Septal: {septal_point}")
    print(f"Lateral: {lateral_point}")
    print(f"Mitral Mid: {mitral_mid}")
    return apex_point, septal_point, lateral_point, mitral_mid, curvature, contour

def compute_triangular_volume(diameters_2ch, diameters_3ch, diameters_4ch, L_cm, nr_disks=20, theta_deg=60):
    """Estimate a volume by stacking hexagonal slabs built from three views.

    Each slab's cross-section is a hexagon whose six spokes are the 2CH, 3CH
    and 4CH half-diameters, each used twice (once per side), with
    ``theta_deg`` between neighbouring spokes. The area is the sum of six
    triangles ``0.5 * r_i * r_j * sin(theta)``. The slab volume is that area
    times the slab height ``L_cm / nr_disks``.

    Args:
        diameters_2ch, diameters_3ch, diameters_4ch: Per-slab diameters.
            Only as many slabs as the shortest sequence are evaluated.
        L_cm: Long-axis length used for the slab height.
        nr_disks: Number of slabs the long axis is divided into.
        theta_deg: Angle between adjacent spokes, in degrees.

    Returns:
        ``(total_volume, disk_volumes, 0.0)``:
        - the summed volume (cm^3 when inputs are in cm);
        - the per-slab volumes, 0.0 for any slab where a half-diameter is
          below 0.05;
        - a constant 0.0 placeholder for a basal volume.
        Returns ``(0.0, [], 0.0)`` when ``nr_disks <= 0``.
    """
    if nr_disks <= 0:
        return 0.0, [], 0.0

    slab_height = L_cm / nr_disks
    spoke_sin = np.sin(np.deg2rad(theta_deg))

    total = 0.0
    per_disk = []
    # zip() stops at the shortest sequence -> truncate to the common length.
    for d2, d3, d4 in zip(diameters_2ch, diameters_3ch, diameters_4ch):
        r2, r3, r4 = d2 / 2, d3 / 2, d4 / 2
        if any(r < 0.05 for r in (r2, r3, r4)):
            per_disk.append(0.0)  # too thin in at least one view
            continue
        # Spokes alternate 2CH, 3CH, 4CH, 2CH, 3CH, 4CH, so the six
        # neighbouring pairs are (2,3), (3,4), (4,2), each occurring twice.
        neighbour_pairs = [(r2, r3), (r3, r4), (r4, r2)] * 2
        slab_area = sum([0.5 * p * q * spoke_sin for p, q in neighbour_pairs])
        slab_volume = slab_area * slab_height
        total += slab_volume
        per_disk.append(slab_volume)

    return total, per_disk, 0.0

def compute_tbc_diameters(seg_binary, apex, septal, lateral, mitral_mid, pixel_spacing=(0.1, 0.1), nr_disks=20):
    """Measure chord diameters perpendicular to the apex-to-mitral long axis.

    The mask is re-outlined with ``enhanced_edge_detection``. The largest
    external contour of that edge map is filled, and its morphological
    gradient serves as the boundary. The segment from *apex* to *mitral_mid*
    is split into ``nr_disks`` equal slabs. From each slab's midpoint a ray
    is cast both ways along the perpendicular; the first boundary pixel hit
    on each side gives the chord endpoints.

    Args:
        seg_binary: 2-D mask (any non-zero value is foreground).
        apex: ``(x, y)`` pixel coordinates of the apical end of the long axis.
        septal, lateral: Annulus points; accepted for interface
            compatibility but not used in the measurement.
        mitral_mid: ``(x, y)`` pixel coordinates of the basal end of the
            long axis.
        pixel_spacing: ``(x_spacing, y_spacing)`` used to convert pixel
            offsets to lengths. Results are in cm when these are cm/pixel.
        nr_disks: Number of slabs along the long axis.

    Returns:
        ``(L_cm, diameters_cm, disk_info, basal_area_cm2, None, filled_mask)``:
        - long-axis length;
        - one diameter per slab (0.0 where an endpoint was not found);
        - per-slab dicts with 'center', 'point1', 'point2' and 'diameter_cm';
        - half the area of a circle whose diameter is the last (mitral-end)
          chord;
        - a ``None`` placeholder;
        - the filled 0/1 mask.
        With no contour it returns ``(0.0, [], [], 0.0, None, None)``; with
        a zero-length axis it returns ``(L_cm, [], [], 0.0, None, None)``.
    """
    edges = enhanced_edge_detection(seg_binary)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return 0.0, [], [], 0.0, None, None

    max_contour = max(contours, key=cv2.contourArea)
    filled_mask = np.zeros_like(seg_binary, dtype=np.uint8)
    cv2.drawContours(filled_mask, [max_contour], -1, 1, thickness=cv2.FILLED)

    # The rays stop on the morphological gradient of the filled region.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    boundary = cv2.morphologyEx(filled_mask, cv2.MORPH_GRADIENT, kernel)

    dx_long = mitral_mid[0] - apex[0]
    dy_long = mitral_mid[1] - apex[1]
    L_cm = np.sqrt((dx_long * pixel_spacing[0])**2 + (dy_long * pixel_spacing[1])**2)
    if L_cm == 0:
        return L_cm, [], [], 0.0, None, None

    # Unit vector perpendicular to the long axis, in pixel space.
    long_norm = np.sqrt(dx_long**2 + dy_long**2) + 1e-10
    perp_vector = np.array([-dy_long / long_norm, dx_long / long_norm])

    # Cap each ray at the image diagonal. No in-image pixel is farther away,
    # and leaving the image ends the ray anyway. The previous cap of
    # max(shape) // 2 could stop short of the wall when a slab centre sat
    # far from the image centre.
    height, width = seg_binary.shape[:2]
    max_dist = int(np.ceil(np.hypot(height, width))) + 1

    diameters_cm = []
    disk_info = []
    for i in range(nr_disks):
        # Slab midpoint along the apex -> mitral_mid segment.
        t = (i + 0.5) / nr_disks
        center_x = apex[0] + t * dx_long
        center_y = apex[1] + t * dy_long

        point1 = _first_boundary_hit(boundary, center_x, center_y, perp_vector, -1, max_dist)
        point2 = _first_boundary_hit(boundary, center_x, center_y, perp_vector, 1, max_dist)

        if point1 is not None and point2 is not None:
            dx_px = point2[0] - point1[0]
            dy_px = point2[1] - point1[1]
            diameter_cm = np.sqrt((dx_px * pixel_spacing[0])**2 + (dy_px * pixel_spacing[1])**2)
        else:
            diameter_cm = 0.0
        diameters_cm.append(diameter_cm)
        disk_info.append({
            'center': (center_x, center_y),
            'point1': point1,
            'point2': point2,
            'diameter_cm': diameter_cm
        })

    # Half of a circle's area (pi * d^2 / 4) for the most basal chord.
    basal_diameter = diameters_cm[-1] if diameters_cm else 0.0
    basal_area_cm2 = (np.pi * basal_diameter * basal_diameter) / 8

    return L_cm, diameters_cm, disk_info, basal_area_cm2, None, filled_mask


def _first_boundary_hit(boundary, center_x, center_y, perp_vector, direction, max_dist):
    """Return the first (x, y) pixel set in *boundary* along a ray, or None.

    The ray starts one step from (center_x, center_y) and advances in unit
    steps of ``direction * perp_vector`` (direction is -1 or +1). It gives up
    when it leaves the image or after ``max_dist - 1`` steps.
    """
    rows, cols = boundary.shape[:2]
    for dist in range(1, max_dist):
        px = int(round(center_x + direction * dist * perp_vector[0]))
        py = int(round(center_y + direction * dist * perp_vector[1]))
        if not (0 <= px < cols and 0 <= py < rows):
            return None
        if boundary[py, px]:
            return (px, py)
    return None

def visualize_processing(seg_data, landmarks, disk_info, view, phase, pixel_spacing):
    """Show a three-panel diagnostic figure for one view/phase.

    Panels:
    1. The enhanced edge map of *seg_data*.
    2. The mask with apex, septal, lateral and mitral-midpoint markers, the
       long axis and the mitral plane, plus a legend.
    3. The mask with every measured chord and its centre, and the same
       markers and lines drawn thinner and unlabelled.

    Args:
        seg_data: Mask image shown in every panel.
        landmarks: ``(apex, septal, lateral, mitral_mid)`` as (x, y) pairs.
        disk_info: Per-disk dicts with 'center', 'point1' and 'point2'.
            Chords with a missing endpoint are skipped.
        view, phase: Labels used in the panel titles.
        pixel_spacing: Accepted for interface compatibility; not used.
    """
    apex, septal, lateral, mitral_mid = landmarks
    marker_specs = (
        (apex, 'red', 100, 'x', 'Apex'),
        (septal, 'blue', 100, 'x', 'Septal'),
        (lateral, 'green', 100, 'x', 'Lateral'),
        (mitral_mid, 'magenta', 150, '+', 'Mitral Mid'),
    )

    def draw_landmarks(line_width, with_labels):
        # Landmark markers, then long axis and mitral plane; labels optional.
        for point, colour, size, symbol, name in marker_specs:
            extra = {'label': name} if with_labels else {}
            plt.scatter(point[0], point[1], c=colour, s=size, marker=symbol, **extra)
        axis_extra = {'label': 'Long Axis'} if with_labels else {}
        plane_extra = {'label': 'Mitral Plane'} if with_labels else {}
        plt.plot([apex[0], mitral_mid[0]], [apex[1], mitral_mid[1]], 'y--', linewidth=line_width, **axis_extra)
        plt.plot([septal[0], lateral[0]], [septal[1], lateral[1]], 'c--', linewidth=line_width, **plane_extra)

    plt.figure(figsize=(18, 6))
    edge_map = enhanced_edge_detection(seg_data)

    plt.subplot(131)
    plt.imshow(edge_map, cmap='gray')
    plt.title(f'{view} {phase} - Edge Detection')

    plt.subplot(132)
    plt.imshow(seg_data, cmap='gray')
    draw_landmarks(2, True)
    plt.legend(loc='upper right')
    plt.title(f'{view} {phase} - Landmarks')

    plt.subplot(133)
    plt.imshow(seg_data, cmap='gray')
    for disk in disk_info:
        start, end = disk['point1'], disk['point2']
        if not (start and end):
            continue
        plt.plot([start[0], end[0]], [start[1], end[1]], 'r-', linewidth=1.5)
        plt.scatter(disk['center'][0], disk['center'][1], c='yellow', s=30)
    draw_landmarks(1, False)
    plt.title(f'{view} {phase} - TBC Method Disks')

    plt.tight_layout()
    plt.show()

def process_patient(patient_dir, visualize=False):
    """Measure every view/phase of one patient and estimate EDV, ESV and EF.

    For each ``*_gt.nii.gz`` / ``*_gt.niigz`` file in *patient_dir* whose
    name identifies a view (2CH/3CH/4CH) and a phase (ED/ES):
    - label 1 is extracted as a 0/255 mask and rotated 90 degrees clockwise;
    - landmarks are detected;
    - per-disk diameters are measured.
    When all six view/phase combinations succeed, EDV and ESV come from the
    triangular-decomposition method. Each phase uses the longest of its OWN
    three long-axis lengths. EF = (EDV - ESV) / EDV * 100.

    Args:
        patient_dir: Directory holding one patient's ground-truth files.
        visualize: If True, show per-file diagnostic plots and a summary plot.

    Returns:
        ``(volumes, all_landmarks, all_disk_info)``.

        ``volumes`` maps view -> phase -> result dict, or None if no file was
        found. A result dict has ``'status'``: on success it is 'success',
        alongside ``'L_cm'``, ``'diameters_cm'`` and ``'basal_area_cm2'``;
        otherwise it is an error description. With a complete set,
        ``volumes`` also gets 'EDV', 'ESV', 'EF', 'ED_disk_volumes' and
        'ES_disk_volumes'.

        ``all_landmarks`` and ``all_disk_info`` are keyed by
        ``(view, phase)``.

        Returns ``(None, None, None)`` if the directory itself cannot be
        processed (e.g. it cannot be listed).
    """
    patient_id = os.path.basename(patient_dir)
    print(f"\nProcessing Patient: {patient_id}")
    volumes = {'2CH': {'ED': None, 'ES': None}, '3CH': {'ED': None, 'ES': None}, '4CH': {'ED': None, 'ES': None}}
    all_landmarks = {}
    all_disk_info = {}
    try:
        for file in os.listdir(patient_dir):
            if not (file.endswith('_gt.niigz') or file.endswith('_gt.nii.gz')):
                continue
            view, phase = _parse_view_phase(file)
            if not view or not phase:
                continue
            print(f"\nProcessing {view} {phase}...")
            filepath = os.path.join(patient_dir, file)
            temp_file = None
            try:
                if file.endswith('.niigz'):
                    temp_file = handle_compressed_file(filepath)
                    gt_img = nib.load(temp_file)
                else:
                    gt_img = nib.load(filepath)
                gt_data = gt_img.get_fdata()
                # After the 90-degree clockwise turn below, the image's first
                # axis runs along the rotated x, so zooms[0] is the x spacing
                # that compute_tbc_diameters expects first.
                pixel_spacing = _pixel_spacing_cm(gt_img.header)
                mask = np.zeros_like(gt_data, dtype=np.uint8)
                mask[gt_data == 1] = 255  # keep only label 1 (treated as the LV)
                rotated_mask = rotate_image_clockwise(mask)
                try:
                    apex, septal, lateral, mitral_mid, _, _ = detect_landmarks(rotated_mask)
                    L_cm, diameters_cm, disk_info, basal_area_cm2, _, _ = compute_tbc_diameters(
                        rotated_mask, apex, septal, lateral, mitral_mid, pixel_spacing
                    )
                    volumes[view][phase] = {
                        'L_cm': L_cm,
                        'diameters_cm': diameters_cm,
                        'basal_area_cm2': basal_area_cm2,
                        'status': 'success'
                    }
                    all_landmarks[(view, phase)] = (apex, septal, lateral, mitral_mid)
                    all_disk_info[(view, phase)] = disk_info
                    if visualize:
                        visualize_processing(rotated_mask, (apex, septal, lateral, mitral_mid),
                                             disk_info, view, phase, pixel_spacing)
                except Exception as landmark_error:
                    print(f"Landmark detection failed for {view} {phase}: {str(landmark_error)}")
                    volumes[view][phase] = {
                        'status': f'landmark_error: {str(landmark_error)}'
                    }
            except Exception as file_error:
                print(f"Error processing file {file}: {str(file_error)}")
                volumes[view][phase] = {
                    'status': f'file_error: {str(file_error)}'
                }
            finally:
                if temp_file and os.path.exists(temp_file):
                    os.unlink(temp_file)
    except Exception as patient_error:
        print(f"Error processing patient {patient_id}: {str(patient_error)}")
        return None, None, None

    views = ('2CH', '3CH', '4CH')
    missing = [f"{view} {phase}"
               for view in views
               for phase in ('ED', 'ES')
               if not volumes[view][phase] or volumes[view][phase].get('status') != 'success']
    if missing:
        print(f"\nIncomplete data for EF calculation. Missing: {', '.join(missing)}")
        return volumes, all_landmarks, all_disk_info

    try:
        # Each phase uses the longest of its own three long axes to limit
        # foreshortening. The LV long axis typically shortens in systole, so
        # reusing the ED length for ES would inflate ESV and bias EF low.
        L_ed = max(volumes[v]['ED']['L_cm'] for v in views)
        L_es = max(volumes[v]['ES']['L_cm'] for v in views)

        ed_vol, ed_disk_volumes, _ = compute_triangular_volume(
            volumes['2CH']['ED']['diameters_cm'],
            volumes['3CH']['ED']['diameters_cm'],
            volumes['4CH']['ED']['diameters_cm'],
            L_ed
        )
        es_vol, es_disk_volumes, _ = compute_triangular_volume(
            volumes['2CH']['ES']['diameters_cm'],
            volumes['3CH']['ES']['diameters_cm'],
            volumes['4CH']['ES']['diameters_cm'],
            L_es
        )

        ef = ((ed_vol - es_vol) / ed_vol) * 100 if ed_vol > 0 else 0

        volumes['EDV'] = ed_vol
        volumes['ESV'] = es_vol
        volumes['EF'] = ef
        volumes['ED_disk_volumes'] = ed_disk_volumes
        volumes['ES_disk_volumes'] = es_disk_volumes

        print("\nTriangular Decomposition Method Results:")
        print(f"EDV: {ed_vol:.2f} ml")
        print(f"ESV: {es_vol:.2f} ml")
        print(f"EF: {ef:.2f}%")

        if visualize:
            plt.figure(figsize=(12, 6))
            plt.subplot(121)
            plt.plot(ed_disk_volumes, 'b-', label='ED Disk Volume')
            plt.plot(es_disk_volumes, 'r-', label='ES Disk Volume')
            plt.xlabel('Disk Number')
            plt.ylabel('Volume Contribution (ml)')
            plt.title('Volume Distribution Along Long Axis')
            plt.legend()

            plt.subplot(122)
            plt.bar(['EDV', 'ESV', 'EF'], [ed_vol, es_vol, ef],
                    color=['blue', 'red', 'green'])
            plt.ylabel('Volume (ml) / Percentage (%)')
            plt.title(f'Triangular Method: EF = {ef:.1f}%')
            plt.tight_layout()
            plt.show()

    except Exception as vol_error:
        print(f"Volume calculation error: {str(vol_error)}")

    return volumes, all_landmarks, all_disk_info


# nibabel spatial-unit labels (NIfTI ``xyzt_units``) -> centimetres per unit.
_SPATIAL_UNIT_TO_CM = {'meter': 100.0, 'mm': 0.1, 'micron': 1e-4}


def _parse_view_phase(filename):
    """Return ``(view, phase)`` parsed from a ground-truth file name.

    View: the first of '2ch', '3ch', '4ch' that occurs in the lower-cased
    name. Phase: prefer a whole 'ed' / 'es' token (the name split on
    non-alphanumerics), so that words merely containing those letters
    (e.g. 'seed') don't decide it. If no such token exists, fall back to a
    plain substring test. Either element is None when it cannot be
    determined.
    """
    lowered = filename.lower()
    view = None
    for tag, label in (('2ch', '2CH'), ('3ch', '3CH'), ('4ch', '4CH')):
        if tag in lowered:
            view = label
            break
    tokens = ''.join(ch if ch.isalnum() else ' ' for ch in lowered).split()
    if 'ed' in tokens:
        phase = 'ED'
    elif 'es' in tokens:
        phase = 'ES'
    elif 'ed' in lowered:
        phase = 'ED'
    elif 'es' in lowered:
        phase = 'ES'
    else:
        phase = None
    return view, phase


def _pixel_spacing_cm(header):
    """Return the first two voxel spacings of *header*, converted to cm.

    The ``get_zooms()`` values are scaled by the header's declared spatial
    unit when ``get_xyzt_units()`` reports a known one. If the unit is
    unknown, or the header has no unit information, the zooms are returned
    unscaled (i.e. assumed to already be in cm).
    """
    zooms = header.get_zooms()[:2]
    get_units = getattr(header, 'get_xyzt_units', None)
    unit = get_units()[0] if callable(get_units) else 'unknown'
    scale = _SPATIAL_UNIT_TO_CM.get(unit, 1.0)
    return tuple(float(z) * scale for z in zooms)

if __name__ == "__main__":
    # Default data root; override with the ECHO_BASE_DIR environment variable.
    base_dir = os.environ.get("ECHO_BASE_DIR", r"E:\echo heart\Resources")
    if not os.path.isdir(base_dir):
        print(f"Base directory not found: {base_dir}")
    else:
        # Sorted so the numbering shown to the user is stable between runs.
        patient_dirs = sorted(d for d in os.listdir(base_dir)
                              if os.path.isdir(os.path.join(base_dir, d)))
        if not patient_dirs:
            print("No patient directories found")
        else:
            print("Available patients:")
            for i, patient in enumerate(patient_dirs, start=1):
                print(f"{i}. {patient}")
            choice = input("Enter patient number to process (or 'all' for all patients): ").strip()
            if choice.lower() == 'all':
                selected = patient_dirs
            elif choice.isdigit():
                idx = int(choice) - 1
                selected = [patient_dirs[idx]] if 0 <= idx < len(patient_dirs) else []
                if not selected:
                    print("Invalid patient number")
            else:
                selected = []
                print("Invalid input")
            for patient_dir in selected:
                process_patient(os.path.join(base_dir, patient_dir), visualize=True)