Automatic estimation of Ejection Fraction from echocardiographic images using the Simpson's biplane method of disks
===

***
# <span style="color:brown">Preamble</span>

This notebook computes the left ventricular ejection fraction (EF) with Simpson's biplane method of disks, using the segmentations of 2D echocardiographic images acquired at end-diastole and end-systole in the apical two- and four-chamber views. This method was used in the following paper:

Leclerc S, Smistad E, Pedrosa J, Østvik A, Cervenansky F, Espinosa F, Espeland T, Rye Berg EA, Jodoin PM, Grenier T, Lartizien C, D’hooge J, Lovstakken L, Bernard O. "Deep Learning for Segmentation using an Open Large-Scale Dataset in 2D Echocardiography" IEEE Trans Med Imaging, 2019:38:2198-2210, DOI: 10.1109/TMI.2019.2900516
    
# <span style="color:brown">Objectives</span>

* Provide the code to compute EF for open science purposes 
* This code can be run on the [CAMUS dataset](https://humanheart-project.creatis.insa-lyon.fr/database/#collection/6373703d73e9f0047faa1bc8) to reproduce the EF values provided with this collection
    
***

# <span style="color:brown">Warnings</span>

* We have observed that the way in which Simpson's biplane method is implemented can have a significant influence on the final values calculated. We do not guarantee that the method implemented in this notebook is optimal. The values produced by this method should be used with caution.
    

## Import the required Python libraries

In [None]:
import logging
from pathlib import Path
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import logging
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import PIL
import SimpleITK as sitk
from PIL.Image import Resampling
from skimage.measure import find_contours
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix, f1_score




logger = logging.getLogger(__name__)

## Let's define a few useful functions to load and manipulate images

In [None]:
def sitk_load(filepath: str | Path) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Read an image file with SimpleITK and return its pixel data together with its geometry.

    Args:
        filepath: Location of the image file on disk.

    Returns:
        - ([N], H, W), Pixel array with every size-1 axis removed.
        - Dict with the image's "origin", "spacing" and "direction", exactly as reported by SimpleITK.
    """
    sitk_image = sitk.ReadImage(str(filepath))
    metadata = dict(
        origin=sitk_image.GetOrigin(),
        spacing=sitk_image.GetSpacing(),
        direction=sitk_image.GetDirection(),
    )

    # Drop any size-1 axes from the pixel array
    pixels = np.squeeze(sitk.GetArrayFromImage(sitk_image))
    return pixels, metadata

In [None]:
def resize_image(image: np.ndarray, size: Tuple[int, int], resample: Resampling = Resampling.NEAREST) -> np.ndarray:
    """Resample a 2D array to a new size using PIL.

    Args:
        image: (H, W), Array to resize; its dtype must be one that PIL can wrap as an image.
        size: Target size given in PIL's (width W', height H') order.
        resample: PIL resampling filter (nearest-neighbour by default, so no new values are interpolated).

    Returns:
        (H', W'), The resized array.
    """
    pil_image = PIL.Image.fromarray(image)
    return np.array(pil_image.resize(size, resample=resample))

In [None]:
def resize_image_to_isotropic(
    image: np.ndarray, spacing: Tuple[float, float], resample: Resampling = Resampling.NEAREST
) -> Tuple[np.ndarray, float]:
    """Resample an image so that its pixels have the same physical size along both axes.

    The axis with the coarser spacing is upsampled until its spacing matches the finer one. The finer axis keeps its
    original number of pixels.

    Args:
        image: (H, W), Input image to resize. Must be in a format supported by PIL.
        spacing: Size of the image's pixels along each (height, width) dimension.
        resample: Resampling filter to use.

    Returns:
        - (H', W'), Input image resized so that the spacing is isotropic.
        - The isotropic spacing of the resized image, i.e. the smaller of the two input spacings.
    """
    target_spacing = min(spacing)
    # Factor by which each axis must be stretched to reach the finer spacing (>= 1 along both axes)
    scaling = np.array(spacing) / target_spacing
    new_height, new_width = (np.array(image.shape) * scaling).round().astype(int)
    # PIL expects the target size in (width, height) order
    return resize_image(image, (new_width, new_height), resample=resample), target_spacing

## Implement Simpson's biplane method of disks

In [None]:
def compute_left_ventricle_volumes(
    a2c_ed: np.ndarray,
    a2c_es: np.ndarray,
    a2c_voxelspacing: Tuple[float, float],
    a4c_ed: np.ndarray,
    a4c_es: np.ndarray,
    a4c_voxelspacing: Tuple[float, float],
) -> Tuple[float, float]:
    """Estimate the LV end-diastolic and end-systolic volumes with Simpson's biplane method of disks.

    Disk diameters are measured independently on each of the four binary masks (A2C/A4C at ED/ES). The ED volume
    combines the two ED views and the ES volume combines the two ES views.

    Args:
        a2c_ed: (H,W), Binary LV mask of the 2-chamber apical view (A2C) at end-diastole (ED).
        a2c_es: (H,W), Binary LV mask of the A2C view at end-systole (ES).
        a2c_voxelspacing: Size (in mm) of the A2C pixels along each (height, width) dimension.
        a4c_ed: (H,W), Binary LV mask of the 4-chamber apical view (A4C) at ED.
        a4c_es: (H,W), Binary LV mask of the A4C view at ES.
        a4c_voxelspacing: Size (in mm) of the A4C pixels along each (height, width) dimension.

    Returns:
        LV volumes at ED and ES, in millilitres, rounded to the nearest integer.
    """
    masks = {"a2c_ed": a2c_ed, "a2c_es": a2c_es, "a4c_ed": a4c_ed, "a4c_es": a4c_es}

    # Only warn (never fail): a label above 1 usually means a multi-class mask was passed by mistake
    for mask_name, mask in masks.items():
        if mask.max() <= 1:
            continue
        logger.warning(
            f"`compute_left_ventricle_volumes` expects binary segmentation masks of the left ventricle (LV). "
            f"However, the `{mask_name}` segmentation contains a label greater than '1/True'. If this was done "
            f"voluntarily, you can safely ignore this warning. However, the most likely cause is that you forgot "
            f"to extract the binary LV segmentation from a multi-class segmentation mask."
        )

    view_spacings = {"a2c": a2c_voxelspacing, "a4c": a4c_voxelspacing}
    # (diameters, step_size) per mask, computed in the order a2c_ed, a2c_es, a4c_ed, a4c_es
    measurements = {name: _compute_diameters(mask, view_spacings[name[:3]]) for name, mask in masks.items()}

    # NOTE(review): one disk height (the largest of the four) is shared by the ED and ES volumes, so a volume may use
    # a height derived from another view/instant. Changing this changes the resulting EF values; confirm intent first.
    step_size = max(step for _, step in measurements.values())

    ed_volume = _compute_left_ventricle_volume_by_instant(
        measurements["a2c_ed"][0], measurements["a4c_ed"][0], step_size
    )
    es_volume = _compute_left_ventricle_volume_by_instant(
        measurements["a2c_es"][0], measurements["a4c_es"][0], step_size
    )
    return ed_volume, es_volume


def _compute_left_ventricle_volume_by_instant(
    a2c_diameters: np.ndarray, a4c_diameters: np.ndarray, step_size: float
) -> float:
    """Compute left ventricle volume using Biplane Simpson's method.

    Args:
        a2c_diameters: Diameters measured at each key instant of the cardiac cycle, from the 2-chamber apical view.
        a4c_diameters: Diameters measured at each key instant of the cardiac cycle, from the 4-chamber apical view.
        step_size:

    Returns:
        Left ventricle volume (in millilitres).
    """
    # All measures are now in millimeters, convert to meters by dividing by 1000
    a2c_diameters /= 1000
    a4c_diameters /= 1000
    step_size /= 1000

    # Estimate left ventricle volume from orthogonal disks
    lv_volume = np.sum(a2c_diameters * a4c_diameters) * step_size * np.pi / 4

    # Volume is now in cubic meters, so convert to milliliters (1 cubic meter = 1_000_000 milliliters)
    return round(lv_volume * 1e6)


def _find_distance_to_edge(
    segmentation: np.ndarray, point_on_mid_line: np.ndarray, normal_direction: np.ndarray
) -> float:
    distance = 8  # start a bit in to avoid line stopping early at base
    while True:
        current_position = point_on_mid_line + distance * normal_direction

        y, x = np.round(current_position).astype(int)
        if segmentation.shape[0] <= y or y < 0 or segmentation.shape[1] <= x or x < 0:
            # out of bounds
            return distance

        elif segmentation[y, x] == 0:
            # Edge found
            return distance

        distance += 0.5


def _distance_line_to_points(line_point_0: np.ndarray, line_point_1: np.ndarray, points: np.ndarray) -> np.ndarray:
    # https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
    return np.absolute(np.cross(line_point_1 - line_point_0, line_point_0 - points)) / np.linalg.norm(
        line_point_1 - line_point_0
    )


def _get_angle_of_lines_to_point(reference_point: np.ndarray, moving_points: np.ndarray) -> np.ndarray:
    diff = moving_points - reference_point
    return abs(np.degrees(np.arctan2(diff[:, 0], diff[:, 1])))


def _compute_diameters(segmentation: np.ndarray, voxelspacing: Tuple[float, float]) -> Tuple[np.ndarray, float]:
    """Measure the disk diameters of a structure along its long axis (Simpson's method of disks).

    The algorithm works on an isotropically resampled copy of the mask:

    1. Base: find the longest chord between two contour points. Going back from a later point i to an earlier
       point j, the chord must point within 45 degrees of the +column direction. Every contour point between them
       must lie within 8 pixels of the chord.
    2. Apex: take the contour point farthest from the base's middle contour point.
    3. Split the base-to-apex segment into 20 equal steps. At the start of each step, measure the mask width
       perpendicular to that segment.

    Args:
        segmentation: (H, W), Binary segmentation of the structure for which to find the diameters.
        voxelspacing: Size of the segmentation's voxels along each (height, width) dimension (in mm).

    Returns:
        - (20,), Disk diameters (in mm), ordered from base to apex.
        - Height (in mm) of each disk, i.e. the base-to-apex length divided by 20.

    Raises:
        ValueError: If the mask yields no contour, or no chord satisfies the base criteria.
    """
    num_disks = 20

    # Make image isotropic, i.e. same spacing in both directions, so that a length in pixels converts to
    # millimetres with a single multiplication by the spacing.
    segmentation, isotropic_spacing = resize_image_to_isotropic(segmentation, voxelspacing)

    contours = find_contours(segmentation, 0.5)
    if not contours:
        raise ValueError("Cannot compute diameters: no contour found in the segmentation (is the mask empty?).")
    # NOTE(review): only the first contour is used, which assumes the mask holds a single connected structure
    contour = contours[0]

    # Search the base (AV plane): the longest acceptable chord between two contour points (see docstring)
    best_length = 0
    best_i = best_j = None
    for point_idx in range(2, len(contour)):
        previous_points = contour[:point_idx]
        angles_to_previous_points = _get_angle_of_lines_to_point(contour[point_idx], previous_points)

        for acute_angle_idx in np.nonzero(angles_to_previous_points <= 45)[0]:
            intermediate_points = contour[acute_angle_idx + 1 : point_idx]
            distance_to_intermediate_points = _distance_line_to_points(
                contour[point_idx], contour[acute_angle_idx], intermediate_points
            )
            if np.all(distance_to_intermediate_points <= 8):
                distance = np.linalg.norm(contour[point_idx] - contour[acute_angle_idx])
                if best_length < distance:
                    best_length = distance
                    best_i = point_idx
                    best_j = acute_angle_idx

    if best_i is None:
        raise ValueError("Cannot compute diameters: no valid base (AV plane) was found on the contour.")

    # Contour point halfway (by index) between the two base points
    mid_point = int(best_j + round((best_i - best_j) / 2))
    base_mid = contour[mid_point]

    # Apex is the contour point farthest from the base midpoint
    mid_line_length = 0
    apex = 0
    for idx, point in enumerate(contour):
        length = np.linalg.norm(base_mid - point)
        if mid_line_length < length:
            mid_line_length = length
            apex = idx

    # Long axis (base midpoint -> apex) and its unit normal
    direction = contour[apex] - base_mid
    normal_direction = np.array([-direction[1], direction[0]])
    normal_direction = normal_direction / np.linalg.norm(normal_direction)

    # Width of the mask perpendicular to the long axis, at the start of each step from base to apex
    diameters = []
    for fraction in np.linspace(0, 1, num_disks, endpoint=False):
        point_on_mid_line = base_mid + direction * fraction

        distance1 = _find_distance_to_edge(segmentation, point_on_mid_line, normal_direction)
        distance2 = _find_distance_to_edge(segmentation, point_on_mid_line, -normal_direction)
        diameters.append((distance1 + distance2) * isotropic_spacing)

    step_size = (mid_line_length * isotropic_spacing) / num_disks
    return np.array(diameters), step_size


## Load the 2D segmentation masks required to compute the left ventricular volumes and ejection fraction (EF) for one patient

NOTE: The following cells assume that the `database_nifti` archive was downloaded and extracted so that its patient folders (e.g. `patient0237/`) sit directly under `./data` (see `database_nifti_root` below).

In [None]:
###########################################
# PARAMETERS TO PLAY WITH

# Root folder of the extracted dataset, containing one sub-folder per patient
database_nifti_root: Path = Path("./data")
# Label value of the left ventricle in the ground-truth masks
lv_label: int = 1
# Patient to process (integer between 1 and 500)
patient_id: int = 237


In [None]:
# Show which dataset root the following cells read from
print(str(database_nifti_root))

In [None]:
# Patient folder name is the ID zero-padded to four digits (e.g. "patient0237")
patient_name = "patient" + format(patient_id, "04d")
patient_dir = database_nifti_root / patient_name
# Ground-truth mask filename template, filled in with `.format(...)` for each view/instant
gt_mask_pattern = "{patient_name}_{view}_{instant}_gt.nii.gz"
print(f"Loading data from patient folder: {patient_dir}")

In [None]:
# Apical 2-chamber view at end-diastole
view, instant = "2CH", "ED"
a2c_ed, a2c_info = sitk_load(
    patient_dir / gt_mask_pattern.format(patient_name=patient_name, view=view, instant=instant)
)
# Keep the first two spacing entries and reverse them, so they follow the (height, width) axis order of the mask
a2c_voxelspacing = a2c_info["spacing"][:2][::-1]

In [None]:
# Apical 2-chamber view at end-systole. The view is set explicitly so that this cell does not depend on whichever
# cell last assigned `view` (running the 4CH cells first would otherwise load the wrong view here).
view = "2CH"
instant = "ES"
a2c_es, _ = sitk_load(patient_dir / gt_mask_pattern.format(patient_name=patient_name, view=view, instant=instant))

In [None]:
# Apical 4-chamber view at end-diastole
view, instant = "4CH", "ED"
a4c_ed, a4c_info = sitk_load(
    patient_dir / gt_mask_pattern.format(patient_name=patient_name, view=view, instant=instant)
)
# Keep the first two spacing entries and reverse them, so they follow the (height, width) axis order of the mask
a4c_voxelspacing = a4c_info["spacing"][:2][::-1]

In [None]:
# Apical 4-chamber view at end-systole. The view is set explicitly so that this cell does not depend on whichever
# cell last assigned `view`.
view = "4CH"
instant = "ES"
a4c_es, _ = sitk_load(patient_dir / gt_mask_pattern.format(patient_name=patient_name, view=view, instant=instant))

## Run Simpson's biplane method of disks on the data from the selected patient

In [None]:
# Keep only the LV label from each multi-class ground-truth mask
a2c_ed_lv_mask, a2c_es_lv_mask, a4c_ed_lv_mask, a4c_es_lv_mask = (
    mask == lv_label for mask in (a2c_ed, a2c_es, a4c_ed, a4c_es)
)

# LV volumes (mL) at end-diastole and end-systole via Simpson's biplane method
edv, esv = compute_left_ventricle_volumes(
    a2c_ed_lv_mask, a2c_es_lv_mask, a2c_voxelspacing,
    a4c_ed_lv_mask, a4c_es_lv_mask, a4c_voxelspacing,
)
# EF (%) = (EDV - ESV) / EDV, rounded to the nearest integer
ef = round(100 * (edv - esv) / edv)

print(f"{patient_name=}: {ef=}, {edv=}, {esv=}")

In [None]:
import matplotlib.pyplot as plt

# Volumes computed above for the selected patient
volumes = [edv, esv]  # EDV, ESV
labels = ['End Diastolic', 'End Systolic']
colors = ['#2ca02c', '#d62728']

# Create figure
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Volume plot
ax1.bar(labels, volumes, color=colors)
ax1.set_title('Left Ventricle Volumes')
ax1.set_ylabel('Volume (mL)')
for i, v in enumerate(volumes):
    ax1.text(i, v + 2, str(v), ha='center')

# EF gauge
ax2.set_title('Ejection Fraction')
ax2.set_xlim(0, 100)
ax2.set_ylim(0, 1)
ax2.axis('off')
ax2.text(50, 0.8, f'{ef}%', ha='center', fontsize=24)

# EF reference bands as (start, end, colour, label). Cut-offs follow the ASE/EACVI 2015 chamber-quantification
# recommendations (male ranges: normal 52-72%, mildly reduced 41-51%, moderately 30-40%, severely <30%);
# female ranges differ slightly (normal 54-74%).
ef_bands = [
    (0, 30, 'red', 'Severely\nreduced'),
    (30, 41, 'orange', 'Moderately\nreduced'),
    (41, 52, 'gold', 'Mildly\nreduced'),
    (52, 72, 'green', 'Normal'),
    (72, 100, 'orange', 'Above\nnormal'),
]
for start, end, band_color, band_label in ef_bands:
    ax2.plot([start, end], [0.2, 0.2], color=band_color, linewidth=10, solid_capstyle='butt')
    ax2.text((start + end) / 2, 0.3, band_label, ha='center', fontsize=9)

# Mark the patient's EF on the scale
ax2.plot(ef, 0.12, marker='^', color='black', markersize=12)

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

# Two-entry colormap for boolean LV masks: False -> fully transparent, True -> red at 50% opacity
mask_cmap = ListedColormap([
    [0, 0, 0, 0],    # background
    [1, 0, 0, 0.5],  # LV
])

def load_echo_image(patient_dir, patient_name, view, instant):
    """Load the echocardiographic image (not the ground-truth mask) of one view/instant for a patient.

    The file is read from ``<patient_dir>/<patient_name>_<view>_<instant>.nii.gz``. Only the pixel array is
    returned; the metadata is discarded.
    """
    image_path = patient_dir / f"{patient_name}_{view}_{instant}.nii.gz"
    return sitk_load(image_path)[0]

# Load the original echocardiographic images for the four view/instant combinations
a2c_ed_img = load_echo_image(patient_dir, patient_name, "2CH", "ED")
a2c_es_img = load_echo_image(patient_dir, patient_name, "2CH", "ES")
a4c_ed_img = load_echo_image(patient_dir, patient_name, "4CH", "ED")
a4c_es_img = load_echo_image(patient_dir, patient_name, "4CH", "ES")

views = ["2CH", "4CH"]
instants = ["ED", "ES"]
view_instants = [(v, t) for v in views for t in instants]
images = {
    "2CH_ED": a2c_ed_img,
    "2CH_ES": a2c_es_img,
    "4CH_ED": a4c_ed_img,
    "4CH_ES": a4c_es_img
}
masks = {
    "2CH_ED": a2c_ed_lv_mask,
    "2CH_ES": a2c_es_lv_mask,
    "4CH_ED": a4c_ed_lv_mask,
    "4CH_ES": a4c_es_lv_mask
}

# One row per (view, instant) pair and three columns (image, mask, overlay).
# squeeze=False keeps `axes` 2D even if only one pair is plotted.
fig, axes = plt.subplots(len(view_instants), 3, figsize=(15, 5 * len(view_instants)), squeeze=False)
fig.suptitle(f"Patient {patient_id} - LV Segmentation Visualization", fontsize=16)

for i, (view, instant) in enumerate(view_instants):
    key = f"{view}_{instant}"

    # Original image
    axes[i, 0].imshow(images[key], cmap='gray')
    axes[i, 0].set_title(f"{view} {instant} - Original")
    axes[i, 0].axis('off')

    # Segmentation mask
    axes[i, 1].imshow(masks[key], cmap='gray')
    axes[i, 1].set_title(f"{view} {instant} - LV Mask")
    axes[i, 1].axis('off')

    # Overlay
    axes[i, 2].imshow(images[key], cmap='gray')
    axes[i, 2].imshow(masks[key], cmap=mask_cmap)
    axes[i, 2].set_title(f"{view} {instant} - Overlay")
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Overlay colormaps for boolean masks: False is fully transparent, True gets a tinted, partly transparent colour
lvendo_cmap, lvepi_cmap, la_cmap = (
    ListedColormap([[0, 0, 0, 0], rgba])
    for rgba in (
        [1, 0, 0, 0.7],    # red, 70% opacity -> label 1 ("LVendo")
        [0, 1, 0, 0.4],    # green, 40% opacity -> label 2 ("LVepi")
        [0, 0.5, 1, 0.3],  # light blue, 30% opacity -> label 3 ("LA")
    )
)

def load_all_masks(patient_dir, patient_name, view, instant):
    """Load a patient's ground-truth mask and split it into one boolean mask per label.

    Returns:
        Dict with boolean masks under 'LVendo' (label 1), 'LVepi' (label 2) and 'LA' (label 3), plus 'spacing':
        the first two spacing entries reversed so they follow the mask's (height, width) axis order.
    """
    gt_path = patient_dir / f"{patient_name}_{view}_{instant}_gt.nii.gz"
    gt, meta = sitk_load(gt_path)
    # NOTE(review): label 2 is presumably the myocardium band between endo- and epicardium (a filled region), not
    # just the epicardial boundary. Confirm against the dataset's label definition.
    structure_labels = {'LVendo': 1, 'LVepi': 2, 'LA': 3}
    result = {name: gt == label for name, label in structure_labels.items()}
    result['spacing'] = meta['spacing'][:2][::-1]
    return result

# Structures to outline, drawn in this order: (mask key, line colour, legend label)
contour_layers = [('LVepi', 'lime', 'LVEpi'), ('LVendo', 'red', 'LVEndo'), ('LA', 'cyan', 'LA')]

views = ["2CH"]  # e.g. ["2CH", "4CH"] to add the four-chamber view; the figure grows automatically
instants = ["ED", "ES"]
view_instants = [(v, t) for v in views for t in instants]

# One row per (view, instant); squeeze=False keeps `axes` 2D even when there is a single row
fig, axes = plt.subplots(len(view_instants), 3, figsize=(18, 10 * len(view_instants)), squeeze=False)
fig.suptitle(f"Patient {patient_id} - Corrected LV Structure Visualization", fontsize=16, y=0.98)

for i, (view, instant) in enumerate(view_instants):
    # Load data
    img = load_echo_image(patient_dir, patient_name, view, instant)
    structure_masks = load_all_masks(patient_dir, patient_name, view, instant)

    # Column 1: Original image
    axes[i, 0].imshow(img, cmap='gray')
    axes[i, 0].set_title(f"{view} {instant}\nOriginal Image", pad=10)
    axes[i, 0].axis('off')

    # Column 2: Filled structures (label 2 first so label 1 is drawn on top of it)
    axes[i, 1].imshow(img, cmap='gray')
    axes[i, 1].imshow(structure_masks['LVepi'], cmap=lvepi_cmap)
    axes[i, 1].imshow(structure_masks['LVendo'], cmap=lvendo_cmap)
    axes[i, 1].imshow(structure_masks['LA'], cmap=la_cmap)
    axes[i, 1].set_title(f"{view} {instant}\nStructure Fills", pad=10)
    axes[i, 1].axis('off')

    # Column 3: Structure contours
    axes[i, 2].imshow(img, cmap='gray')
    for mask_key, line_color, legend_label in contour_layers:
        for k, contour in enumerate(find_contours(structure_masks[mask_key], 0.5)):
            # Label only the first contour of each structure so the legend lists each structure once
            axes[i, 2].plot(
                contour[:, 1], contour[:, 0], linewidth=2, color=line_color,
                label=legend_label if k == 0 else None,
            )
    axes[i, 2].set_title(f"{view} {instant}\nStructure Contours", pad=10)
    axes[i, 2].axis('off')

    # Add legend only once
    if i == 0:
        axes[i, 2].legend(loc='upper right', fontsize=8)

plt.tight_layout(pad=3.0)
plt.show()

## Convolutional Vision Transformer 

In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Zero-padding added on every side (the default of 1 keeps H and W unchanged for a 3x3 kernel).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features)
        return F.relu(normalized)

class PatchEmbedding(nn.Module):
    """Embed an image into a sequence of patch tokens using a small convolutional stem.

    Three stride-1 ConvBlocks lift the input to ``embed_dim`` channels at full resolution. A max-pool with kernel and
    stride ``patch_size`` then reduces each ``patch_size x patch_size`` patch to one token. An input of size
    (H, W) yields ``(H // patch_size) * (W // patch_size)`` tokens.

    Args:
        in_channels: Number of input image channels.
        embed_dim: Dimension of each token.
        patch_size: Side length, in pixels, of the square patch summarised by each token.
    """
    def __init__(self, in_channels=1, embed_dim=128, patch_size=16):
        super().__init__()
        self.proj = nn.Sequential(
            ConvBlock(in_channels, embed_dim//4, 7, 3),
            ConvBlock(embed_dim//4, embed_dim//2, 3, 1),
            ConvBlock(embed_dim//2, embed_dim, 3, 1),
            # Pool by the full patch size so each token covers one patch. Self-attention cost grows quadratically
            # with the number of tokens, so pooling by less quickly becomes infeasible at echo resolutions.
            nn.MaxPool2d(kernel_size=patch_size, stride=patch_size)
        )

    def forward(self, x):
        x = self.proj(x)  # [B, C, H, W] -> [B, embed_dim, H // patch_size, W // patch_size]
        # NOTE(review): no positional embedding is added, so self-attention sees the tokens as an unordered set
        return rearrange(x, 'b c h w -> b (h w) c')  # Flatten to sequence [B, N, embed_dim]

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: multi-head self-attention, then an MLP, each with a residual connection.

    Input and output have shape (batch, num_tokens, embed_dim).

    Args:
        embed_dim: Token dimension.
        num_heads: Number of attention heads (PyTorch requires it to divide ``embed_dim``).
        dropout: Dropout probability used on the attention weights and in the MLP.
    """
    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True because tokens arrive as (batch, tokens, dim). With the default batch_first=False, PyTorch
        # treats the first axis as the sequence, so attention would mix different images of the batch instead of
        # the tokens of each image.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim*4, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Self-attention with residual connection
        res = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x)
        x = res + x

        # MLP with residual connection
        res = x
        x = self.norm2(x)
        x = res + self.mlp(x)
        return x

class CVT(nn.Module):
    """Convolutional Vision Transformer for cardiac segmentation.

    The model has three stages:
    1. A convolutional patch embedding turns the image into a token sequence.
    2. A stack of transformer blocks processes the tokens.
    3. A convolutional decoder upsamples the token grid x4. The result is then bilinearly resized to the input size.

    Args:
        in_channels: Number of input image channels.
        num_classes: Number of output classes (logit channels).
        embed_dim: Token dimension.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        patch_size: Passed through to ``PatchEmbedding``.

    Forward returns logits of shape (B, num_classes, H, W) for an input of shape (B, in_channels, H, W).
    """
    def __init__(self, in_channels=1, num_classes=3, embed_dim=128,
                 num_heads=4, num_layers=4, patch_size=16):
        super().__init__()

        self.patch_embed = PatchEmbedding(in_channels, embed_dim, patch_size)

        encoder_blocks = [TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)]
        self.transformer = nn.Sequential(*encoder_blocks)

        # Each ConvTranspose2d(k=4, s=2, p=1) doubles the spatial size
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim//2, 4, 2, 1),
            ConvBlock(embed_dim//2, embed_dim//4),
            nn.ConvTranspose2d(embed_dim//4, embed_dim//8, 4, 2, 1),
            ConvBlock(embed_dim//8, embed_dim//16),
            nn.Conv2d(embed_dim//16, num_classes, 1)
        )

    def forward(self, x):
        _, _, H, W = x.shape

        tokens = self.transformer(self.patch_embed(x))  # [B, N, embed_dim]

        # The token grid size follows from the stride of the embedding's final pooling layer
        pool_stride = self.patch_embed.proj[-1].stride
        feature_map = rearrange(tokens, 'b (h w) c -> b c h w', h=H // pool_stride, w=W // pool_stride)

        logits = self.decoder(feature_map)
        return F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=False)


# 1. Ultrasound image loading
def load_echo_image(patient_dir, patient_name, view, instant):
    """Load the echocardiographic image for one patient/view/instant as a NumPy array, with size-1 axes removed."""
    image_file = str(patient_dir / f"{patient_name}_{view}_{instant}.nii.gz")
    return np.squeeze(sitk.GetArrayFromImage(sitk.ReadImage(image_file)))

# 2. Cardiac segmentation dataset
class CardiacDataset(Dataset):
    """One (image, mask) pair per patient folder, for a fixed view and instant.

    Each item is ``(image, mask)``. ``image`` is a float32 tensor of shape (1, H, W), min-max normalised to [0, 1].
    ``mask`` is an int64 tensor of shape (H, W) holding the ground-truth class indices.
    """

    def __init__(self, root_dir, transform=None, view="2CH", instant="ED", image_size=None):
        """
        Args:
            root_dir: Directory with one sub-folder per patient.
            transform: Optional callable applied to the image tensor only, after normalisation/resizing. It must not
                change the spatial geometry, because the mask is not transformed with it.
            view: View to load ('2CH' or '4CH').
            instant: Cardiac phase ('ED' or 'ES').
            image_size: Optional (height, width). When given, the image is resized bilinearly and the mask with
                nearest-neighbour interpolation, so both stay aligned. This is needed to batch patients whose images
                have different sizes.
        """
        self.root_dir = Path(root_dir)
        self.patient_dirs = sorted([d for d in self.root_dir.iterdir() if d.is_dir()])
        self.transform = transform
        self.view = view
        self.instant = instant
        self.image_size = image_size

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        patient_dir = self.patient_dirs[idx]
        patient_name = patient_dir.name

        # Load ultrasound image and its ground-truth mask
        img = load_echo_image(patient_dir, patient_name, self.view, self.instant)
        mask_path = patient_dir / f"{patient_name}_{self.view}_{self.instant}_gt.nii.gz"
        mask = sitk.GetArrayFromImage(sitk.ReadImage(str(mask_path))).squeeze()

        # Convert to tensors: image [1, H, W] float32, mask [H, W] int64 class indices
        img = torch.as_tensor(np.asarray(img, dtype=np.float32)).unsqueeze(0)
        mask = torch.as_tensor(np.asarray(mask, dtype=np.int64))

        # Min-max normalise to [0, 1]; a constant image would divide by zero (NaNs), so map it to zeros instead
        img_min, img_max = img.min(), img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min)
        else:
            img = torch.zeros_like(img)

        if self.image_size is not None:
            size = tuple(self.image_size)
            # Bilinear for the image; nearest for the mask so that it keeps valid integer labels
            img = F.interpolate(img.unsqueeze(0), size=size, mode='bilinear', align_corners=False).squeeze(0)
            mask = F.interpolate(mask[None, None].float(), size=size, mode='nearest')[0, 0].long()

        if self.transform:
            img = self.transform(img)

        return img, mask

# 3. Training setup
def train_ultrasound_model(data_root, *, epochs=50, batch_size=4, lr=1e-4, num_workers=4, device=None):
    """Train a CVT segmentation model on the 2CH / ED images found under ``data_root``.

    Args:
        data_root: Directory with one sub-folder per patient.
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        num_workers: DataLoader worker processes. Set to 0 if workers fail to start from a notebook (this can happen
            with the 'spawn' multiprocessing start method).
        device: Torch device to train on. Defaults to CUDA when available, otherwise CPU.

    Returns:
        The trained model (still in training mode).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 4 output classes: background + LVendo + LVepi + LA (mask labels 0-3)
    model = CVT(in_channels=1, num_classes=4).to(device)

    # Weighted loss; background is down-weighted (adjust to your class distribution)
    class_weights = torch.tensor([0.1, 1.0, 1.5, 1.0], device=device)  # [bg, LVendo, LVepi, LA]
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_dataset = CardiacDataset(
        root_dir=data_root,
        view="2CH",
        instant="ED",
        transform=None,  # CardiacDataset applies `transform` to the image only, so it must not move pixels
    )

    # NOTE(review): default collation stacks the samples, so every image in a batch must have the same size.
    # If the patients' images differ in size, resize them (or use batch_size=1).
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    for epoch in range(epochs):
        model.train()
        for batch_idx, (images, masks) in enumerate(train_loader):
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

    return model

# 4. Example usage
if __name__ == "__main__":
    # Path to the extracted dataset (one sub-folder per patient). Kept as a Path so the other cells that use
    # `database_nifti_root / ...` keep working after this cell runs.
    database_nifti_root = Path("./data")  # Update with your path

    # Verify data loading
    dataset = CardiacDataset(database_nifti_root)
    img, mask = dataset[0]
    print(f"Image shape: {img.shape}, Mask shape: {mask.shape}")
    print(f"Unique mask values: {torch.unique(mask)}")  # Should be 0,1,2,3

    # Train model and save its weights
    model = train_ultrasound_model(database_nifti_root)
    torch.save(model.state_dict(), "cvt_ultrasound_segmentation.pth")

    # Shape sanity check on a separate, untrained model so the trained `model` above is not overwritten
    demo_model = CVT(in_channels=1, num_classes=3)
    demo_model.eval()
    x = torch.randn(1, 1, 256, 256)
    with torch.no_grad():
        output = demo_model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")  # Should be [1, 3, 256, 256]

## Validation and results

In [None]:
def validate_model(model, dataset, device, num_samples=3, batch_size=8):
    """Plot predictions for random samples, then print pixel-wise metrics over the whole dataset.

    Args:
        model: Segmentation model returning (B, 4, H, W) logits (classes: background, LVendo, LVepi, LA).
        dataset: Dataset yielding (image [1, H, W], mask [H, W]) pairs.
        device: Device the model lives on.
        num_samples: Number of random samples to plot (one figure row each).
        batch_size: Batch size for the metric pass over the whole dataset.
    """
    model.eval()
    # squeeze=False keeps `axes` 2D, so `axes[i, j]` also works when num_samples == 1
    fig, axes = plt.subplots(num_samples, 5, figsize=(20, 4 * num_samples), squeeze=False)

    with torch.no_grad():
        for i in range(num_samples):
            # Get random sample
            idx = np.random.randint(len(dataset))
            img, true_mask = dataset[idx]
            img = img.unsqueeze(0).to(device)

            # Predict
            pred_logits = model(img)
            pred_mask = torch.argmax(pred_logits, dim=1).squeeze().cpu().numpy()

            # Original data
            original_img = img.squeeze().cpu().numpy()
            true_mask = true_mask.numpy()

            # Plotting
            axes[i, 0].imshow(original_img, cmap='gray')
            axes[i, 0].set_title("Original Image")
            axes[i, 0].axis('off')

            # Ground Truth
            axes[i, 1].imshow(true_mask, cmap='jet', vmin=0, vmax=3)
            axes[i, 1].set_title("Ground Truth")
            axes[i, 1].axis('off')

            # Prediction
            axes[i, 2].imshow(pred_mask, cmap='jet', vmin=0, vmax=3)
            axes[i, 2].set_title("Prediction")
            axes[i, 2].axis('off')

            # Individual class comparisons
            for class_idx, class_name in enumerate(["LVendo", "LVepi", "LA"], start=1):
                axes[i, class_idx+2].imshow(original_img, cmap='gray')
                axes[i, class_idx+2].imshow(
                    np.ma.masked_where(pred_mask != class_idx, pred_mask),
                    cmap='autumn', alpha=0.5, vmin=0, vmax=3
                )
                axes[i, class_idx+2].imshow(
                    np.ma.masked_where(true_mask != class_idx, true_mask),
                    cmap='winter', alpha=0.3, vmin=0, vmax=3
                )
                axes[i, class_idx+2].set_title(f"{class_name}\nRed=Pred, Blue=GT")
                axes[i, class_idx+2].axis('off')

    plt.tight_layout()
    plt.show()

    # Quantitative metrics. no_grad: inference only, so autograd does not keep activations for a backward pass.
    all_true = []
    all_pred = []
    with torch.no_grad():
        for img, true_mask in DataLoader(dataset, batch_size=batch_size):
            pred_mask = torch.argmax(model(img.to(device)), dim=1)
            all_true.append(true_mask.flatten())
            all_pred.append(pred_mask.cpu().flatten())

    all_true = torch.cat(all_true).numpy()
    all_pred = torch.cat(all_pred).numpy()

    print("Confusion Matrix:")
    print(confusion_matrix(all_true, all_pred, labels=[0,1,2,3]))

    print("\nF1 Scores:")
    for class_idx, name in enumerate(["Background", "LVendo", "LVepi", "LA"]):
        print(f"{name}: {f1_score(all_true == class_idx, all_pred == class_idx):.3f}")

# Usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The training code builds CVT with 4 output classes (background, LVendo, LVepi, LA). The architecture must match
# the checkpoint, and CVT's default num_classes (3) does not.
model = CVT(num_classes=4).to(device)
# Weights written by the training cell. map_location lets a GPU-trained checkpoint load on a CPU-only machine.
checkpoint_path = "cvt_ultrasound_segmentation.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
dataset = CardiacDataset(database_nifti_root)

validate_model(model, dataset, device, num_samples=3)