# In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

# ------------------------------
# 1. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Paired transform for segmentation samples: bring an image and its label mask
    to one fixed spatial size, then turn both into tensors.
    """
    def __init__(self, size=(512, 512)):
        # Target size as (height, width), passed straight to torchvision's resize.
        self.size = size

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.array): Label mask for ``image`` as a NumPy array.

        Returns:
            tuple: ``(image_tensor, mask_tensor)`` where ``image_tensor`` is a
            float32 [C, H, W] tensor in [0, 1] and ``mask_tensor`` is an int64
            [H, W] tensor.
        """
        resized_image = resize(image, self.size, interpolation=InterpolationMode.BILINEAR)

        # Nearest-neighbour so the resized mask contains only labels present in the input.
        resized_mask = np.array(
            resize(Image.fromarray(mask), self.size, interpolation=InterpolationMode.NEAREST),
            dtype=np.uint8,
        )

        return to_tensor(resized_image), torch.from_numpy(resized_mask).long()

class BinarySegDataset(Dataset):
    """
    Dataset of (RGB image, binary mask) pairs.

    Every entry of ``file_list`` names an image inside ``images_dir``; its mask is
    read from ``masks_dir/<same stem>.npy``.
    """
    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Folder holding the input images.
            masks_dir (str): Folder holding the ``.npy`` masks.
            file_list (list): Image filenames that make up this split.
            transform (callable, optional): Called as ``transform(image, mask)``.
        """
        super().__init__()
        self.images_dir, self.masks_dir = images_dir, masks_dir
        self.file_list = file_list
        self.transform = transform

        # Debug: report the split size.
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Raises:
            FileNotFoundError: if the image or its ``.npy`` mask does not exist.
        """
        name = self.file_list[idx]
        stem, _ = os.path.splitext(name)

        image_path = os.path.join(self.images_dir, name)
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"[Error] Image file not found: {image_path}")
        image = Image.open(image_path).convert("RGB")

        mask_file = os.path.join(self.masks_dir, f"{stem}.npy")
        if not os.path.isfile(mask_file):
            raise FileNotFoundError(f"[Error] Mask file not found: {mask_file}")
        raw_mask = np.load(mask_file)

        # Every label above 1 is folded into the foreground class.
        mask = np.where(raw_mask > 1, 1, raw_mask).astype(np.uint8)

        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask

# ------------------------------
# 2. U-Net Model Definition
# ------------------------------

class DoubleConv(nn.Module):
    """
    Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size; only the channel count changes,
    from ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, the second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # A single Sequential named `conv` keeps the conv.<index> state_dict keys unchanged.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv-BN-ReLU stages on ``x``."""
        return self.conv(x)

class UNet(nn.Module):
    """
    U-Net architecture for image segmentation.

    Four max-pool downsampling stages (64 -> 128 -> 256 -> 512 channels), a
    1024-channel bottleneck, and four transposed-conv upsampling stages with skip
    connections.

    Input height/width do not need to be divisible by 16 (they must be at least
    16). Whenever pooling floored an odd size, the upsampled tensor is zero-padded
    to the skip tensor's size before concatenation. The output therefore has the
    same H x W as the input.
    """
    def __init__(self, in_channels=3, out_channels=2):
        """
        Args:
            in_channels (int): Number of input channels (e.g., 3 for RGB).
            out_channels (int): Number of output channels/classes.
        """
        super(UNet, self).__init__()
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to(up, skip):
        """
        Zero-pad ``up`` so its last two dimensions equal those of ``skip``.

        MaxPool2d(2) floors odd sizes, so ConvTranspose2d(stride=2) can return a
        tensor one pixel smaller than its skip connection. The padding is split as
        evenly as possible between the two sides. Returns ``up`` unchanged when
        the sizes already match.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            up, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )

    def forward(self, x):
        """
        Forward pass through the U-Net.

        Returns per-class logits with ``out_channels`` channels and the input's H x W.
        """
        d1 = self.down1(x)
        p1 = self.pool1(d1)
        d2 = self.down2(p1)
        p2 = self.pool2(d2)
        d3 = self.down3(p2)
        p3 = self.pool3(d3)
        d4 = self.down4(p3)
        p4 = self.pool4(d4)

        bn = self.bottleneck(p4)

        up_4 = self._pad_to(self.up4(bn), d4)
        merge4 = torch.cat([d4, up_4], dim=1)
        c4 = self.conv4(merge4)

        up_3 = self._pad_to(self.up3(c4), d3)
        merge3 = torch.cat([d3, up_3], dim=1)
        c3 = self.conv3(merge3)

        up_2 = self._pad_to(self.up2(c3), d2)
        merge2 = torch.cat([d2, up_2], dim=1)
        c2 = self.conv2(merge2)

        up_1 = self._pad_to(self.up1(c2), d1)
        merge1 = torch.cat([d1, up_1], dim=1)
        c1 = self.conv1(merge1)

        out = self.out(c1)
        return out

# ------------------------------
# 3. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx):
    """
    Trains the model for one epoch.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer.
        criterion (Loss): Loss function.
        device (torch.device): Device to run on.
        epoch_idx (int): Current epoch index (0-based; printed 1-based).

    Returns:
        float: Sample-weighted average training loss over the samples the loader
        actually yielded this epoch. This stays correct with ``drop_last`` or a
        subset sampler. Returns NaN if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")
    # Use tqdm to show progress
    for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Training", leave=False)):
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        # Weight by batch size so a short final batch counts proportionally
        # (assumes the criterion returns a per-batch mean, e.g. CrossEntropyLoss default).
        total_loss += loss.item() * batch_size
        num_samples += batch_size
        if batch_idx % 10 == 0:
            print(f"[Train] Batch {batch_idx}, Loss: {loss.item():.4f}")
    return total_loss / num_samples if num_samples else float("nan")

def validate_one_epoch(model, loader, criterion, device, epoch_idx):
    """
    Validates the model for one epoch.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the validation set.
        criterion (Loss): Loss function.
        device (torch.device): Device to run on.
        epoch_idx (int): Current epoch index (0-based; printed 1-based).

    Returns:
        float: Sample-weighted average validation loss over the samples the loader
        actually yielded. Returns NaN if the loader yielded no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")
    # No autograd graph is needed anywhere in validation, so one context covers the loop.
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)

            batch_size = images.size(0)
            # Weight by batch size (assumes the criterion returns a per-batch mean).
            total_loss += loss.item() * batch_size
            num_samples += batch_size
            if batch_idx % 10 == 0:
                print(f"[Val] Batch {batch_idx}, Loss: {loss.item():.4f}")
    return total_loss / num_samples if num_samples else float("nan")

# ------------------------------
# 4. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Computes the Dice coefficient for binary masks.

    Dice = (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth),
    computed on the flattened tensors. Any two tensors with the same number of
    elements are accepted, including non-contiguous views.

    Args:
        pred (torch.Tensor): Predicted binary mask (0/1 values), e.g. [H, W].
        target (torch.Tensor): Ground truth binary mask with the same number of elements.
        smooth (float): Smoothing term; it also makes two empty masks score 1.0 instead of 0/0.

    Returns:
        float: Dice coefficient.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, are handled.
    pred_flat = pred.reshape(-1).float()
    target_flat = target.reshape(-1).float()

    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice.item()

def test_segmentation(model, loader, device):
    """
    Score a trained model on a test loader with the per-image Dice coefficient.

    Each image's prediction is the argmax over the model's class dimension. Dice is
    computed per image (not per batch) and then averaged.

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): Yields ``(images, masks)`` batches.
        device (torch.device): Device to run on.

    Returns:
        Mean Dice over all test images, or ``0`` if the loader produced none.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing"):
            images = images.to(device)
            targets = masks.to(device)          # [B, H, W]
            predicted = model(images).argmax(dim=1)  # [B, H, W]
            scores.extend(
                dice_coefficient(one_pred, one_target)
                for one_pred, one_target in zip(predicted, targets)
            )

    return np.mean(scores) if scores else 0

def visualize_predictions(model, dataset, device, num_samples=3):
    """
    Visualizes predictions alongside the original images and ground truth masks.

    Shows one figure per sample: the image, its ground-truth mask and the argmax
    prediction.

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): The test dataset; items are ``(image [C, H, W], mask [H, W])`` tensors.
        device (torch.device): Device to run on.
        num_samples (int): Number of distinct random samples to visualize. It is capped
            at ``len(dataset)``, because sampling without replacement cannot draw more
            items than exist. Nothing is plotted if the result is <= 0.
    """
    model.eval()
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device)  # [1, C, H, W]

        with torch.no_grad():
            output = model(image_batch)  # [1, n_classes, H, W]
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        # Plotting
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        # to_tensor yields [0, 1] floats; scale those up, otherwise assume 0-255 values.
        if image_np.max() > 1:
            image_np = image_np.astype(np.uint8)
        else:
            image_np = (image_np * 255).astype(np.uint8)

        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis("off")

        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        plt.tight_layout()
        plt.show()

# ------------------------------
# 5. Main Function
# ------------------------------

def _list_files(directory, extensions):
    """Return the sorted names in ``directory`` whose lower-cased name ends with ``extensions`` (str or tuple)."""
    return sorted(f for f in os.listdir(directory) if f.lower().endswith(extensions))


def _collect_split(images_dir, masks_dir, split_name):
    """
    List one split's image files and check that each has a ``.npy`` mask.

    Args:
        images_dir (str): Directory with .png/.jpg/.jpeg images.
        masks_dir (str): Directory with .npy masks.
        split_name (str): Label used in log messages ("training" or "test").

    Returns:
        list or None: Sorted image filenames. Returns None, after printing an error,
        in any of these cases:
          - the split has no images;
          - the split has no masks;
          - the image and mask counts differ;
          - an image has no ``<stem>.npy`` mask.
    """
    images = _list_files(images_dir, (".png", ".jpg", ".jpeg"))
    print(f"[Main] Found {len(images)} {split_name} image files in {images_dir}")
    if not images:
        print(f"[Error] No {split_name} image files found. Check your {split_name} path!")
        return None

    masks = _list_files(masks_dir, ".npy")
    print(f"[Main] Found {len(masks)} {split_name} mask files in {masks_dir}")
    if not masks:
        print(f"[Error] No {split_name} mask files found. Check your {split_name} mask path!")
        return None

    if len(images) != len(masks):
        print(f"[Error] Number of {split_name} images and masks do not match!")
        return None

    # Equal counts do not guarantee pairing. The dataset opens masks_dir/<stem>.npy,
    # so report unpaired names now instead of failing in the middle of an epoch.
    mask_names = set(masks)
    unpaired = [f for f in images if os.path.splitext(f)[0] + ".npy" not in mask_names]
    if unpaired:
        print(f"[Error] {len(unpaired)} {split_name} image(s) have no matching .npy mask, "
              f"e.g. {unpaired[0]}")
        return None
    return images


def main():
    """
    Train a UNet on the phantom training set and evaluate it.

    The training images are split 80/20 into train/validation. The weights with the
    lowest validation loss are kept, then the test Dice is reported and a few test
    predictions are plotted.
    """
    # ------------------------------
    # Paths Configuration
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 4
    num_epochs    = 10
    learning_rate = 1e-4
    best_model_path = "best_model.pth"

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("[Main] Using device:", device)

    # ------------------------------
    # Training files and Train/Validation Split
    # ------------------------------
    all_train_images = _collect_split(images_dir, masks_dir, "training")
    if all_train_images is None:
        return

    train_files, val_files = train_test_split(
        all_train_images,
        test_size=0.2,  # 20% for validation
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Datasets and DataLoaders
    # ------------------------------
    transform = ResizeAndToTensor(size=(512, 512))

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # ------------------------------
    # Test Dataset and DataLoader
    # ------------------------------
    all_test_images = _collect_split(test_images_dir, test_masks_dir, "test")
    if all_test_images is None:
        return

    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=transform
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    print(f"[Main] Test samples: {len(test_dataset)}")

    # ------------------------------
    # Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = UNet(in_channels=3, out_channels=2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated for LR schedulers (PyTorch >= 2.2); the LR is printed every epoch below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    saved_checkpoint = False
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
        val_loss   = validate_one_epoch(model, val_loader, criterion, device, epoch)

        # Step scheduler with validation loss
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Save the model if validation loss has decreased (a NaN loss never qualifies).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            saved_checkpoint = True
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

    print(f"\nTraining complete. Best Validation Loss: {best_val_loss:.4f}")

    # ------------------------------
    # Testing
    # ------------------------------
    if saved_checkpoint:
        print(f"Best model saved as '{best_model_path}'.")
        print("\n>>> Loading best model for testing...")
        # map_location lets the checkpoint load regardless of the device it was saved from.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        # Do not load a checkpoint this run never wrote (it could be stale or missing).
        print("[Warning] No checkpoint was saved this run; testing the final weights.")
    model.eval()

    test_dice = test_segmentation(model, test_loader, device)
    print(f"Test Dice Coefficient: {test_dice:.4f}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(model, test_dataset, device, num_samples=3)
    else:
        print("[Warning] No test samples available for visualization.")

if __name__ == "__main__":
    main()


# In [2]:
import math
import os
import random

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
import torchvision.models as models  # For ResNet encoder

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchmetrics



# ------------------------------
# 0. Reproducibility (Optional)
# ------------------------------

def set_seed(seed=42):
    """
    Seed every random number generator used here and make cuDNN deterministic.

    Args:
        seed (int): Seed for Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(42)

# ------------------------------
# 1. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Paired transform for segmentation samples.

    Steps:
      1. Resize the image (bilinear) and its mask (nearest) to a fixed size.
      2. Optionally apply the same random flips and right-angle rotation to both.
      3. Return a normalized image tensor and an int64 mask tensor.
    """
    def __init__(self, size=(512, 512), augment=False,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """
        Args:
            size (tuple): Target (height, width).
            augment (bool): If True, apply random horizontal/vertical flips and a random
                rotation by 0/90/180/270 degrees.
            mean (sequence): Per-channel normalization mean (default: ImageNet statistics).
            std (sequence): Per-channel normalization std (default: ImageNet statistics).
        """
        self.size = size  # (height, width)
        self.augment = augment
        self.normalize = transforms.Normalize(mean=list(mean), std=list(std))

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.array or PIL Image): Corresponding mask as a NumPy array or PIL Image.

        Returns:
            image_tensor (torch.Tensor): Resized and normalized image tensor [C, H, W].
            mask_tensor (torch.Tensor): Resized mask tensor (Long type).

        Raises:
            TypeError: if ``mask`` is neither a NumPy array nor a PIL Image.
        """
        # Resize image using bilinear interpolation
        image = resize(image, self.size, interpolation=InterpolationMode.BILINEAR)

        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)
        elif not isinstance(mask, Image.Image):
            raise TypeError(f"Unsupported mask type: {type(mask)}")

        # Nearest-neighbour so the resized mask only contains labels present in the input.
        mask = resize(mask, self.size, interpolation=InterpolationMode.NEAREST)
        # Cast to uint8 once and keep it as a PIL image, so the augmentations below
        # need no NumPy <-> PIL round trip each.
        mask = Image.fromarray(np.array(mask, dtype=np.uint8))

        if self.augment:
            # Random draws happen in a fixed order: horizontal flip, vertical flip, rotation angle.
            if random.random() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

            if random.random() > 0.5:
                image = image.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

            angle = random.choice([0, 90, 180, 270])
            if angle != 0:
                image = image.rotate(angle)
                mask = mask.rotate(angle)

        image_tensor = self.normalize(to_tensor(image))   # [C, H, W], float32
        mask_tensor = torch.from_numpy(np.array(mask, dtype=np.uint8)).long()  # [H, W], int64
        return image_tensor, mask_tensor

class BinarySegDataset(Dataset):
    """
    Custom Dataset for binary image segmentation.

    For each image filename in ``file_list``, the mask is looked up in ``masks_dir``
    under the same stem, trying ``.npy``, ``.png``, ``.jpg`` and ``.jpeg`` in that
    order. Mask values greater than 1 are mapped to 1 before the optional transform
    is applied.
    """
    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Directory with input images.
            masks_dir (str): Directory with corresponding mask files (.npy or image files).
            file_list (list): List of image filenames.
            transform (callable, optional): Called as ``transform(image, mask)``.
        """
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir  = masks_dir
        self.file_list  = file_list
        self.transform  = transform

        # Debug: Print how many files in this split
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Retrieves the image and mask pair at the specified index.

        Returns:
            The transform's ``(image, mask)`` output. Without a transform, an RGB PIL
            image and a uint8 PIL mask.

        Raises:
            FileNotFoundError: if the image or its mask cannot be found.
        """
        img_filename = self.file_list[idx]
        base_name = os.path.splitext(img_filename)[0]

        img_path = os.path.join(self.images_dir, img_filename)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"[Error] Image file not found: {img_path}")
        # The context manager closes the file handle once the converted copy exists.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        mask_path = self._find_mask(base_name)
        if mask_path is None:
            raise FileNotFoundError(f"[Error] Mask file not found for image: {img_filename}")

        if mask_path.endswith('.npy'):
            mask_np = np.load(mask_path)
        else:
            with Image.open(mask_path) as mask_img:
                mask_np = np.array(mask_img.convert("L"))  # Convert to grayscale

        # Binarize in NumPy *before* building a PIL image: Image.fromarray rejects some
        # dtypes that np.save commonly produces (e.g. int64).
        mask_np = np.where(mask_np > 1, 1, mask_np).astype(np.uint8)
        mask = Image.fromarray(mask_np)

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

    def _find_mask(self, base_name):
        """Return the first existing ``masks_dir/base_name<ext>`` for the supported extensions, else None."""
        for ext in ('.npy', '.png', '.jpg', '.jpeg'):
            candidate = os.path.join(self.masks_dir, base_name + ext)
            if os.path.isfile(candidate):
                return candidate
        return None

# ------------------------------
# 2. U-Net Model Definition with ResNet Encoder
# ------------------------------

class DoubleConv(nn.Module):
    """
    (3x3 conv -> BatchNorm -> ReLU) applied twice.

    Spatial size is preserved (``padding=1``). Channels go from ``in_channels`` to
    ``out_channels`` in the first stage and stay there in the second.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(stage_in):
            # One conv-BN-ReLU stage that always produces `out_channels` feature maps.
            return [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Kept as one Sequential named `conv` so state_dict keys remain conv.0 ... conv.5.
        self.conv = nn.Sequential(*conv_bn_relu(in_channels), *conv_bn_relu(out_channels))

    def forward(self, x):
        """Apply the two stages to ``x``."""
        features = self.conv(x)
        return features

class ResNetUNet(nn.Module):
    """
    U-Net architecture with ResNet encoder.

    The decoder works as follows:
      - It upsamples layer4 features three times with transposed convolutions,
        reaching 1/4 of the input resolution.
      - Each upsampled tensor is concatenated with the skip tensor from layer3,
        layer2 and layer1 respectively.
      - Per-class logits are predicted at 1/4 resolution and bilinearly resized
        to the input's exact H x W.

    Inputs do not need to be divisible by 32. ResNet's stride-2 stages round sizes
    up, so each upsampled tensor is cropped (or, defensively, zero-padded) to its
    skip tensor's size before concatenation.
    """
    def __init__(self, n_classes=2, encoder_name='resnet34', pretrained=True):
        """
        Args:
            n_classes (int): Number of output classes.
            encoder_name (str): ResNet encoder to use ('resnet18', 'resnet34', 'resnet50' or 'resnet101').
            pretrained (bool): Whether to use pretrained ResNet weights.

        Raises:
            ValueError: if ``encoder_name`` is not one of the supported variants.
        """
        super(ResNetUNet, self).__init__()

        # Constructor plus output channels of [stem, layer1, layer2, layer3, layer4] per variant.
        variants = {
            'resnet18': (models.resnet18, [64, 64, 128, 256, 512]),
            'resnet34': (models.resnet34, [64, 64, 128, 256, 512]),
            'resnet50': (models.resnet50, [64, 256, 512, 1024, 2048]),
            'resnet101': (models.resnet101, [64, 256, 512, 1024, 2048]),
        }
        if encoder_name not in variants:
            raise ValueError("Unsupported ResNet variant")
        build_encoder, encoder_channels = variants[encoder_name]
        self.encoder = build_encoder(pretrained=pretrained)

        # Encoder layers (reusing the ResNet's own modules)
        self.initial = nn.Sequential(
            self.encoder.conv1,  # [B, 64, H/2, W/2]
            self.encoder.bn1,
            self.encoder.relu,
            self.encoder.maxpool  # [B, 64, H/4, W/4]
        )
        self.encoder_layer1 = self.encoder.layer1  # [B, 64 or 256, H/4, W/4]
        self.encoder_layer2 = self.encoder.layer2  # [B, 128 or 512, H/8, W/8]
        self.encoder_layer3 = self.encoder.layer3  # [B, 256 or 1024, H/16, W/16]
        self.encoder_layer4 = self.encoder.layer4  # [B, 512 or 2048, H/32, W/32]

        # Decoder layers
        self.up4 = nn.ConvTranspose2d(encoder_channels[4], encoder_channels[3], kernel_size=2, stride=2)
        self.conv4 = DoubleConv(encoder_channels[3] + encoder_channels[3], encoder_channels[3])

        self.up3 = nn.ConvTranspose2d(encoder_channels[3], encoder_channels[2], kernel_size=2, stride=2)
        self.conv3 = DoubleConv(encoder_channels[2] + encoder_channels[2], encoder_channels[2])

        self.up2 = nn.ConvTranspose2d(encoder_channels[2], encoder_channels[1], kernel_size=2, stride=2)
        self.conv2 = DoubleConv(encoder_channels[1] + encoder_channels[1], encoder_channels[1])

        # The decoder stops at 1/4 resolution (no stage uses the stem as a skip);
        # the logits are upsampled bilinearly in forward() instead.
        self.out_conv = nn.Conv2d(encoder_channels[1], n_classes, kernel_size=1)

    @staticmethod
    def _align_to(up, skip):
        """
        Crop or zero-pad ``up`` on the bottom/right so its H, W equal ``skip``'s.

        Returns ``up`` with identical values when the sizes already match.
        """
        h, w = skip.shape[-2:]
        up = up[..., :h, :w]
        pad_h = h - up.shape[-2]
        pad_w = w - up.shape[-1]
        if pad_h or pad_w:
            # F.pad order for the last two dims is (left, right, top, bottom).
            up = nn.functional.pad(up, [0, pad_w, 0, pad_h])
        return up

    def forward(self, x):
        """Return per-class logits of shape [B, n_classes, H, W] for input [B, 3, H, W]."""
        # Encoder (shape comments assume resnet34 and H, W divisible by 32)
        x0 = self.initial(x)           # [B, 64, H/4, W/4]
        x1 = self.encoder_layer1(x0)   # [B, 64, H/4, W/4]
        x2 = self.encoder_layer2(x1)   # [B, 128, H/8, W/8]
        x3 = self.encoder_layer3(x2)   # [B, 256, H/16, W/16]
        x4 = self.encoder_layer4(x3)   # [B, 512, H/32, W/32]

        # Decoder
        up4 = self._align_to(self.up4(x4), x3)            # [B, 256, H/16, W/16]
        conv4 = self.conv4(torch.cat([up4, x3], dim=1))   # [B, 256, H/16, W/16]

        up3 = self._align_to(self.up3(conv4), x2)         # [B, 128, H/8, W/8]
        conv3 = self.conv3(torch.cat([up3, x2], dim=1))   # [B, 128, H/8, W/8]

        up2 = self._align_to(self.up2(conv3), x1)         # [B, 64, H/4, W/4]
        conv2 = self.conv2(torch.cat([up2, x1], dim=1))   # [B, 64, H/4, W/4]

        out = self.out_conv(conv2)     # [B, n_classes, H/4, W/4]

        # Resize to the exact input size. With align_corners=True this equals
        # scale_factor=4 whenever H and W are multiples of 4.
        out = nn.functional.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=True)
        return out

# ------------------------------
# 3. Early Stopping Class
# ------------------------------

class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.

    A NaN validation loss always counts as "no improvement". A diverged run
    therefore still advances toward ``patience`` instead of becoming the new best
    and disabling early stopping.
    """
    def __init__(self, patience=7, verbose=False, delta=0.0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None    # best loss so far; None until the first non-NaN loss
        self.counter = 0         # consecutive calls without improvement
        self.early_stop = False  # becomes True once counter reaches patience

    def __call__(self, val_loss):
        """Record one validation loss and update ``best_loss``, ``counter`` and ``early_stop``."""
        if math.isnan(val_loss):
            # NaN compares False with everything; without this guard it would become best_loss.
            self._register_no_improvement()
        elif self.best_loss is None:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Initial best loss: {self.best_loss:.4f}")
        elif val_loss > self.best_loss - self.delta:
            self._register_no_improvement()
        else:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Validation loss improved to: {self.best_loss:.4f}")

    def _register_no_improvement(self):
        """Advance the patience counter and set ``early_stop`` once it reaches ``patience``."""
        self.counter += 1
        if self.verbose:
            print(f"[EarlyStopping] No improvement. Counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print("[EarlyStopping] Early stopping triggered.")

# ------------------------------
# 4. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx):
    """
    Run a single optimisation pass over ``loader``.

    Every batch is moved to ``device``, forwarded, scored with ``criterion``
    and back-propagated, followed by one optimizer step. The loss is logged
    every 10th batch.

    Args:
        model (nn.Module): The segmentation model (switched to train mode).
        loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer stepped once per batch.
        criterion (Loss): Loss function applied to (outputs, masks).
        device (torch.device): Device to run on.
        epoch_idx (int): Zero-based epoch index (printed 1-based).

    Returns:
        float: Sample-weighted mean training loss over ``loader.dataset``.
    """
    model.train()
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")

    running_loss = 0.0
    progress = tqdm(loader, desc="Training", leave=False)
    for step, (images, masks) in enumerate(progress):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, masks)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        # Weight by batch size so a smaller final batch is averaged correctly.
        running_loss += loss_value * images.size(0)
        if step % 10 == 0:
            print(f"[Train] Batch {step}, Loss: {loss_value:.4f}")

    average_loss = running_loss / len(loader.dataset)
    print(f"--- [Train] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")
    return average_loss

def validate_one_epoch(model, loader, criterion, device, epoch_idx, metrics=None):
    """
    Evaluate the model on ``loader`` without gradient tracking.

    Accumulates the sample-weighted loss. When ``metrics`` is a non-empty
    dict, each metric is also updated with argmax predictions per batch.
    The metrics are computed and printed once at the end, then reset so
    they can be reused next epoch.

    Args:
        model (nn.Module): The segmentation model (switched to eval mode).
        loader (DataLoader): DataLoader for the validation set.
        criterion (Loss): Loss function applied to (outputs, masks).
        device (torch.device): Device to run on.
        epoch_idx (int): Zero-based epoch index (printed 1-based).
        metrics (dict, optional): Name -> TorchMetrics metric object.

    Returns:
        float: Sample-weighted mean validation loss. Metric values are only printed.
    """
    model.eval()
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")

    running_loss = 0.0
    with torch.no_grad():
        for step, (images, masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            images = images.to(device)
            masks = masks.to(device)
            logits = model(images)
            batch_loss = criterion(logits, masks)
            running_loss += batch_loss.item() * images.size(0)
            if step % 10 == 0:
                print(f"[Val] Batch {step}, Loss: {batch_loss.item():.4f}")

            if metrics:
                predictions = logits.argmax(dim=1)
                for metric in metrics.values():
                    metric(predictions, masks)

    average_loss = running_loss / len(loader.dataset)
    print(f"--- [Val] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")

    if metrics:
        metric_values = {name: metric.compute().item() for name, metric in metrics.items()}
        # Reset so the next epoch starts from a clean accumulator.
        for metric in metrics.values():
            metric.reset()
        summary = " | ".join([f"{name}: {value:.4f}" for name, value in metric_values.items()])
        print(f"--- [Val] Epoch {epoch_idx+1} Metrics: {summary} ---")

    return average_loss

# ------------------------------
# 5. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Computes the Dice coefficient for binary masks.

    Args:
        pred (torch.Tensor): Predicted mask [H, W].
        target (torch.Tensor): Ground truth mask [H, W].
        smooth (float): Smoothing factor to avoid division by zero.

    Returns:
        float: Dice coefficient.
    """
    pred_flat = pred.view(-1).float()
    target_flat = target.view(-1).float()
    
    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice.item()

def test_segmentation(model, loader, device, metrics=None):
    """
    Evaluates the model on the test set and reports segmentation metrics.

    With a non-empty ``metrics`` dict, each TorchMetrics object is updated
    with argmax predictions for every batch, then computed once and reset.
    Without metrics, the per-sample Dice coefficient (``dice_coefficient``)
    is averaged over every sample in ``loader``.

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): DataLoader for the test set.
        device (torch.device): Device to run on.
        metrics (dict, optional): Name -> TorchMetrics metric object.

    Returns:
        dict: Metric name -> value. When no metrics are given this is
        ``{'dice': mean_dice}``, with 0.0 for an empty loader.
    """
    model.eval()
    dice_scores = []  # per-sample Dice; only filled when metrics is falsy
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing", leave=False):
            images = images.to(device)
            masks = masks.to(device)  # [B, H, W]

            outputs = model(images)    # [B, n_classes, H, W]
            preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            if metrics:
                for metric in metrics.values():
                    metric(preds, masks)
            else:
                # Score every batch here so all samples count, not just the
                # last batch left over in the loop variables.
                dice_scores.extend(
                    dice_coefficient(pred, mask) for pred, mask in zip(preds, masks)
                )

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        # Reset so the metric objects can be reused.
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metric_values.items()])
        print(f"--- [Test] Metrics: {metric_str} ---")
    else:
        metric_values = {
            'dice': float(np.mean(dice_scores)) if dice_scores else 0.0
        }
        print(f"--- [Test] Dice Coefficient: {metric_values['dice']:.4f} ---")

    return metric_values



def visualize_predictions(model, dataset, device, num_samples=3, post_process_flag=False):
    """
    Visualizes predictions alongside the original images and ground truth masks.

    For each randomly chosen sample, shows the un-normalized input image,
    the ground-truth mask and the argmax prediction side by side.

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): Dataset yielding (image_tensor [C, H, W], mask_tensor [H, W]).
        device (torch.device): Device to run on.
        num_samples (int): Number of samples to visualize; clamped to len(dataset).
        post_process_flag (bool): Kept for backward compatibility. No
            post-processing is implemented, so predictions are always shown
            raw and a warning is printed when this is True.
    """
    model.eval()
    # np.random.choice(..., replace=False) raises if asked for more samples than exist.
    num_samples = min(num_samples, len(dataset))
    if post_process_flag:
        print("[Warning] post_process_flag is set but no post-processing is implemented; "
              "showing raw predictions.")
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # ImageNet mean/std used to undo normalization for display.
    # NOTE(review): assumes the dataset's transform applied this Normalize — confirm.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device)  # [1, C, H, W]

        with torch.no_grad():
            output = model(image_batch)  # [1, n_classes, H, W]
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        image_np = image_np * std + mean  # Unnormalize
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis("off")

        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        plt.tight_layout()
        plt.show()
        # Free the figure; with non-interactive backends figures otherwise accumulate.
        plt.close(fig)

# ------------------------------
# 6. Main Function
# ------------------------------

def _list_paired_images(images_dir, masks_dir, label):
    """
    Return the sorted image filenames in ``images_dir`` that have a mask.

    A mask matches when ``masks_dir/<stem><ext>`` exists for one of
    .npy/.png/.jpg/.jpeg. Progress and error messages are tagged with
    ``label`` (e.g. "training" or "test"). An empty list signals failure.
    """
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    print(f"[Main] Found {len(image_files)} {label} image files in {images_dir}")
    if not image_files:
        print(f"[Error] No {label} image files found. Check your {label} path!")
        return []

    mask_extensions = ('.npy', '.png', '.jpg', '.jpeg')
    # Filtering an already-sorted list keeps it sorted.
    paired = [
        f for f in image_files
        if any(
            os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ext))
            for ext in mask_extensions
        )
    ]
    print(f"[Main] {len(paired)} {label} images have corresponding masks.")
    if not paired:
        print(f"[Error] No {label} mask files found or mismatched filenames. "
              f"Check your {label} mask path!")
    return paired


def _make_loader(dataset, batch_size, shuffle):
    """Build a DataLoader with the worker / pin-memory settings shared by every split."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,  # Adjust based on your CPU
        pin_memory=torch.cuda.is_available(),
    )


def main():
    """
    Train, validate and test the segmentation model end to end.

    The pipeline:
    - collects image/mask pairs and splits train/val;
    - trains with AdamW and a cosine LR schedule;
    - checkpoints the best validation loss and stops early when it stalls;
    - evaluates and visualizes the best checkpoint on the separate test set.
    """
    # Local imports so main() does not depend on these names being imported
    # at module level.
    import torchmetrics
    from torch.optim.lr_scheduler import CosineAnnealingLR

    # ------------------------------
    # Define Your Directories Here
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 8
    num_epochs    = 20
    learning_rate = 1e-3
    val_split     = 0.2
    save_path     = "best_model.pth"
    patience      = 5
    post_process_flag = True  # Forwarded to visualize_predictions

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Main] Using device: {device}")

    # ------------------------------
    # Collect File Lists
    # ------------------------------
    all_train_images = _list_paired_images(images_dir, masks_dir, "training")
    if not all_train_images:
        return
    all_test_images = _list_paired_images(test_images_dir, test_masks_dir, "test")
    if not all_test_images:
        return

    # ------------------------------
    # Train/Validation Split
    # ------------------------------
    train_files, val_files = train_test_split(
        all_train_images,
        test_size=val_split,
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Create Datasets with Transforms
    # ------------------------------
    train_transform = ResizeAndToTensor(size=(512, 512), augment=True)
    val_transform = ResizeAndToTensor(size=(512, 512), augment=False)
    test_transform = ResizeAndToTensor(size=(512, 512), augment=False)

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=train_transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=val_transform
    )
    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=test_transform
    )

    # ------------------------------
    # Create DataLoaders
    # ------------------------------
    train_loader = _make_loader(train_dataset, batch_size, shuffle=True)
    val_loader = _make_loader(val_dataset, batch_size, shuffle=False)
    test_loader = _make_loader(test_dataset, batch_size, shuffle=False)

    # ------------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = ResNetUNet(n_classes=2, encoder_name='resnet34', pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    def build_metrics():
        """Fresh metric objects on ``device`` (val and test must not share state)."""
        return {
            'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
            'iou': torchmetrics.JaccardIndex(task='binary').to(device),
            'accuracy': torchmetrics.Accuracy(task='binary').to(device),
        }

    metrics_val = build_metrics()
    metrics_test = build_metrics()

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    checkpoint_saved = False
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
        val_loss = validate_one_epoch(model, val_loader, criterion, device, epoch, metrics_val)
        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {optimizer.param_groups[0]['lr']:.6f} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

        # Feed EVERY epoch's loss to EarlyStopping so its best_loss/counter
        # follow the true best. Calling it only on non-improving epochs seeded
        # it with a worse loss and let it reset on losses above the real best.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("[Main] Early stopping triggered. Stopping training.")
            break

    print(f"\n[Main] Training complete. Best Validation Loss: {best_val_loss:.4f}")

    # ------------------------------
    # Testing
    # ------------------------------
    if checkpoint_saved:
        print(f"[Main] Best model saved at: {save_path}")
        print("\n>>> Loading best model for testing...")
        # map_location lets a checkpoint saved on one device load onto `device`.
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        # Avoid loading a missing file or a stale checkpoint from a previous run.
        print("[Warning] No checkpoint was saved during this run; testing the final weights instead.")
    model.eval()

    test_metrics = test_segmentation(model, test_loader, device, metrics=metrics_test)
    print(f"Test Metrics: {test_metrics}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(model, test_dataset, device, num_samples=10, post_process_flag=post_process_flag)
    else:
        print("[Warning] No test samples available for visualization.")

# ------------------------------
# 7. Run the Main Function
# ------------------------------

if __name__ == "__main__":
    # Entry point: run the full train / validate / test pipeline defined in main().
    main()


[Main] Using device: cuda
[Main] Found 14737 training image files in /home/yanghehao/tracklearning/segmentation/phantom_train/images/
[Main] 14737 training images have corresponding masks.
[Main] Found 3685 test image files in /home/yanghehao/tracklearning/segmentation/phantom_test/images/
[Main] 3685 test images have corresponding masks.
[Main] Training samples: 11789
[Main] Validation samples: 2948
[Dataset] Initialized with 11789 files.
[Dataset] Initialized with 2948 files.
[Dataset] Initialized with 3685 files.




--- [Train] Starting epoch 1 ---


Training:   0%|          | 3/1474 [00:00<06:43,  3.65it/s]

[Train] Batch 0, Loss: 0.8693


Training:   1%|          | 11/1474 [00:01<03:12,  7.59it/s]

[Train] Batch 10, Loss: 0.3038


Training:   1%|▏         | 21/1474 [00:03<02:36,  9.27it/s]

[Train] Batch 20, Loss: 0.1870


Training:   2%|▏         | 32/1474 [00:04<02:49,  8.53it/s]

[Train] Batch 30, Loss: 0.1338


Training:   3%|▎         | 42/1474 [00:05<02:55,  8.15it/s]

[Train] Batch 40, Loss: 0.1068


Training:   4%|▎         | 52/1474 [00:06<02:30,  9.44it/s]

[Train] Batch 50, Loss: 0.0834


Training:   4%|▍         | 63/1474 [00:07<02:39,  8.87it/s]

[Train] Batch 60, Loss: 0.0698


Training:   5%|▍         | 71/1474 [00:08<02:47,  8.39it/s]

[Train] Batch 70, Loss: 0.0598


Training:   6%|▌         | 83/1474 [00:10<02:33,  9.06it/s]

[Train] Batch 80, Loss: 0.0533


Training:   6%|▌         | 91/1474 [00:11<02:35,  8.89it/s]

[Train] Batch 90, Loss: 0.0411


Training:   7%|▋         | 103/1474 [00:12<02:42,  8.45it/s]

[Train] Batch 100, Loss: 0.0401


Training:   8%|▊         | 111/1474 [00:13<02:46,  8.18it/s]

[Train] Batch 110, Loss: 0.0369


Training:   8%|▊         | 123/1474 [00:15<02:52,  7.85it/s]

[Train] Batch 120, Loss: 0.0318


Training:   9%|▉         | 132/1474 [00:16<02:21,  9.51it/s]

[Train] Batch 130, Loss: 0.0302


Training:  10%|▉         | 142/1474 [00:17<02:08, 10.39it/s]

[Train] Batch 140, Loss: 0.0326


Training:  10%|█         | 152/1474 [00:18<02:33,  8.62it/s]

[Train] Batch 150, Loss: 0.0253


Training:  11%|█         | 163/1474 [00:20<02:40,  8.16it/s]

[Train] Batch 160, Loss: 0.0233


Training:  12%|█▏        | 171/1474 [00:21<02:37,  8.28it/s]

[Train] Batch 170, Loss: 0.0255


Training:  12%|█▏        | 183/1474 [00:22<02:37,  8.20it/s]

[Train] Batch 180, Loss: 0.0226


Training:  13%|█▎        | 191/1474 [00:23<02:28,  8.61it/s]

[Train] Batch 190, Loss: 0.0203


Training:  14%|█▍        | 203/1474 [00:24<02:22,  8.92it/s]

[Train] Batch 200, Loss: 0.0233


Training:  14%|█▍        | 211/1474 [00:26<02:35,  8.12it/s]

[Train] Batch 210, Loss: 0.0225


Training:  15%|█▌        | 223/1474 [00:27<02:27,  8.49it/s]

[Train] Batch 220, Loss: 0.0208


Training:  16%|█▌        | 231/1474 [00:28<02:16,  9.12it/s]

[Train] Batch 230, Loss: 0.0210


Training:  16%|█▋        | 243/1474 [00:29<02:17,  8.92it/s]

[Train] Batch 240, Loss: 0.0168


Training:  17%|█▋        | 251/1474 [00:30<02:24,  8.48it/s]

[Train] Batch 250, Loss: 0.0154


Training:  18%|█▊        | 261/1474 [00:32<02:35,  7.78it/s]

[Train] Batch 260, Loss: 0.0165


Training:  18%|█▊        | 271/1474 [00:33<02:20,  8.57it/s]

[Train] Batch 270, Loss: 0.0160


Training:  19%|█▉        | 283/1474 [00:34<02:23,  8.31it/s]

[Train] Batch 280, Loss: 0.0143


Training:  20%|█▉        | 291/1474 [00:35<02:20,  8.43it/s]

[Train] Batch 290, Loss: 0.0164


Training:  20%|██        | 301/1474 [00:36<02:04,  9.42it/s]

[Train] Batch 300, Loss: 0.0188


Training:  21%|██        | 311/1474 [00:38<02:16,  8.51it/s]

[Train] Batch 310, Loss: 0.0151


Training:  22%|██▏       | 321/1474 [00:39<02:20,  8.19it/s]

[Train] Batch 320, Loss: 0.0114


Training:  23%|██▎       | 333/1474 [00:40<02:03,  9.22it/s]

[Train] Batch 330, Loss: 0.0189


Training:  23%|██▎       | 342/1474 [00:41<02:17,  8.25it/s]

[Train] Batch 340, Loss: 0.0154


Training:  24%|██▍       | 352/1474 [00:43<02:23,  7.81it/s]

[Train] Batch 350, Loss: 0.0123


Training:  25%|██▍       | 362/1474 [00:44<02:12,  8.36it/s]

[Train] Batch 360, Loss: 0.0145


Training:  25%|██▌       | 372/1474 [00:45<02:10,  8.43it/s]

[Train] Batch 370, Loss: 0.0181


Training:  26%|██▌       | 380/1474 [00:46<02:10,  8.40it/s]

[Train] Batch 380, Loss: 0.0152


Training:  27%|██▋       | 392/1474 [00:48<02:15,  7.97it/s]

[Train] Batch 390, Loss: 0.0126


Training:  27%|██▋       | 400/1474 [00:49<02:09,  8.32it/s]

[Train] Batch 400, Loss: 0.0117


Training:  28%|██▊       | 410/1474 [00:50<02:08,  8.25it/s]

[Train] Batch 410, Loss: 0.0107


Training:  29%|██▊       | 422/1474 [00:51<02:03,  8.49it/s]

[Train] Batch 420, Loss: 0.0100


Training:  29%|██▉       | 430/1474 [00:53<02:06,  8.28it/s]

[Train] Batch 430, Loss: 0.0112


Training:  30%|██▉       | 442/1474 [00:54<02:07,  8.09it/s]

[Train] Batch 440, Loss: 0.0142


Training:  31%|███       | 452/1474 [00:55<01:56,  8.74it/s]

[Train] Batch 450, Loss: 0.0121


Training:  31%|███▏      | 462/1474 [00:56<01:59,  8.47it/s]

[Train] Batch 460, Loss: 0.0098


Training:  32%|███▏      | 470/1474 [00:57<01:48,  9.23it/s]

[Train] Batch 470, Loss: 0.0127


Training:  33%|███▎      | 482/1474 [00:59<02:04,  7.97it/s]

[Train] Batch 480, Loss: 0.0135


Training:  33%|███▎      | 492/1474 [01:00<02:04,  7.88it/s]

[Train] Batch 490, Loss: 0.0132


Training:  34%|███▍      | 501/1474 [01:01<01:49,  8.90it/s]

[Train] Batch 500, Loss: 0.0144


Training:  35%|███▍      | 513/1474 [01:03<01:58,  8.13it/s]

[Train] Batch 510, Loss: 0.0092


Training:  35%|███▌      | 521/1474 [01:04<01:52,  8.45it/s]

[Train] Batch 520, Loss: 0.0098


Training:  36%|███▌      | 533/1474 [01:05<01:53,  8.27it/s]

[Train] Batch 530, Loss: 0.0116


Training:  37%|███▋      | 542/1474 [01:06<01:44,  8.90it/s]

[Train] Batch 540, Loss: 0.0114


Training:  38%|███▊      | 553/1474 [01:08<01:56,  7.92it/s]

[Train] Batch 550, Loss: 0.0112


Training:  38%|███▊      | 561/1474 [01:09<01:52,  8.12it/s]

[Train] Batch 560, Loss: 0.0136


Training:  39%|███▉      | 573/1474 [01:10<01:47,  8.40it/s]

[Train] Batch 570, Loss: 0.0115


Training:  39%|███▉      | 581/1474 [01:11<01:37,  9.17it/s]

[Train] Batch 580, Loss: 0.0116


Training:  40%|████      | 593/1474 [01:13<01:44,  8.46it/s]

[Train] Batch 590, Loss: 0.0086


Training:  41%|████      | 601/1474 [01:14<01:42,  8.55it/s]

[Train] Batch 600, Loss: 0.0107


Training:  42%|████▏     | 613/1474 [01:15<01:37,  8.79it/s]

[Train] Batch 610, Loss: 0.0125


Training:  42%|████▏     | 621/1474 [01:16<01:41,  8.41it/s]

[Train] Batch 620, Loss: 0.0132


Training:  43%|████▎     | 633/1474 [01:18<01:40,  8.35it/s]

[Train] Batch 630, Loss: 0.0092


Training:  43%|████▎     | 641/1474 [01:19<01:20, 10.36it/s]

[Train] Batch 640, Loss: 0.0117


Training:  44%|████▍     | 653/1474 [01:20<01:33,  8.79it/s]

[Train] Batch 650, Loss: 0.0119


Training:  45%|████▍     | 661/1474 [01:21<01:43,  7.82it/s]

[Train] Batch 660, Loss: 0.0169


Training:  46%|████▌     | 673/1474 [01:23<01:32,  8.69it/s]

[Train] Batch 670, Loss: 0.0104


Training:  46%|████▌     | 681/1474 [01:24<01:37,  8.17it/s]

[Train] Batch 680, Loss: 0.0079


Training:  47%|████▋     | 693/1474 [01:25<01:30,  8.64it/s]

[Train] Batch 690, Loss: 0.0112


Training:  48%|████▊     | 702/1474 [01:27<01:36,  8.03it/s]

[Train] Batch 700, Loss: 0.0092


Training:  48%|████▊     | 712/1474 [01:28<01:36,  7.88it/s]

[Train] Batch 710, Loss: 0.0093


Training:  49%|████▉     | 722/1474 [01:29<01:33,  8.07it/s]

[Train] Batch 720, Loss: 0.0118


Training:  50%|████▉     | 730/1474 [01:30<01:27,  8.45it/s]

[Train] Batch 730, Loss: 0.0097


Training:  50%|█████     | 742/1474 [01:31<01:15,  9.64it/s]

[Train] Batch 740, Loss: 0.0098


Training:  51%|█████     | 750/1474 [01:32<01:23,  8.66it/s]

[Train] Batch 750, Loss: 0.0147


Training:  52%|█████▏    | 762/1474 [01:34<01:24,  8.43it/s]

[Train] Batch 760, Loss: 0.0105


Training:  52%|█████▏    | 770/1474 [01:35<01:25,  8.21it/s]

[Train] Batch 770, Loss: 0.0091


Training:  53%|█████▎    | 782/1474 [01:36<01:17,  8.96it/s]

[Train] Batch 780, Loss: 0.0115


Training:  54%|█████▎    | 790/1474 [01:37<01:23,  8.14it/s]

[Train] Batch 790, Loss: 0.0124


Training:  54%|█████▍    | 802/1474 [01:39<01:19,  8.50it/s]

[Train] Batch 800, Loss: 0.0097


Training:  55%|█████▌    | 811/1474 [01:40<01:20,  8.28it/s]

[Train] Batch 810, Loss: 0.0124


Training:  56%|█████▌    | 823/1474 [01:41<01:15,  8.62it/s]

[Train] Batch 820, Loss: 0.0124


Training:  56%|█████▋    | 831/1474 [01:43<01:16,  8.37it/s]

[Train] Batch 830, Loss: 0.0099


Training:  57%|█████▋    | 843/1474 [01:44<01:16,  8.29it/s]

[Train] Batch 840, Loss: 0.0096


Training:  58%|█████▊    | 851/1474 [01:45<01:17,  8.01it/s]

[Train] Batch 850, Loss: 0.0100


Training:  58%|█████▊    | 861/1474 [01:46<01:13,  8.37it/s]

[Train] Batch 860, Loss: 0.0119


Training:  59%|█████▉    | 871/1474 [01:48<01:08,  8.78it/s]

[Train] Batch 870, Loss: 0.0120


Training:  60%|█████▉    | 883/1474 [01:49<01:10,  8.38it/s]

[Train] Batch 880, Loss: 0.0106


Training:  60%|██████    | 891/1474 [01:50<01:12,  8.03it/s]

[Train] Batch 890, Loss: 0.0127


Training:  61%|██████▏   | 903/1474 [01:51<01:01,  9.32it/s]

[Train] Batch 900, Loss: 0.0133


Training:  62%|██████▏   | 911/1474 [01:53<01:07,  8.37it/s]

[Train] Batch 910, Loss: 0.0093


Training:  63%|██████▎   | 923/1474 [01:54<01:14,  7.36it/s]

[Train] Batch 920, Loss: 0.0085


                                                            

KeyboardInterrupt: 

In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
import torchvision.models as models  # For ResNet encoder

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchmetrics
import random

# ------------------------------
# 0. Reproducibility (Optional)
# ------------------------------

def set_seed(seed=42):
    """
    Seed every RNG the pipeline uses and make cuDNN deterministic.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch (CPU
    and all CUDA devices). cuDNN benchmarking is turned off because its
    timing-based kernel selection can vary between runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# ------------------------------
# 1. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Paired transform for segmentation samples.

    The steps are:
    - resize the image (bilinear) and the mask (nearest-neighbour) to ``size``;
    - optionally apply the same random flips and right-angle rotation to both;
    - return an ImageNet-normalized float image tensor and an int64 mask tensor.
    """
    def __init__(self, size=(512, 512), augment=False):
        """
        Args:
            size (tuple or int): Target size passed to torchvision ``resize``
                ((height, width) when a tuple).
            augment (bool): If True, apply random flips and rotations.
        """
        self.size = size  # (height, width)
        self.augment = augment
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.array or PIL Image): Corresponding mask as a NumPy array or PIL Image.

        Returns:
            image_tensor (torch.Tensor): Resized and normalized image tensor [C, H, W].
            mask_tensor (torch.Tensor): Resized mask tensor (Long type), [H, W] for a
                single-channel mask.

        Raises:
            TypeError: If ``mask`` is neither a NumPy array nor a PIL Image.
        """
        # Resize image using bilinear interpolation
        image = resize(image, self.size, interpolation=InterpolationMode.BILINEAR)

        if isinstance(mask, np.ndarray):
            mask_pil = Image.fromarray(mask)
        elif isinstance(mask, Image.Image):
            mask_pil = mask
        else:
            raise TypeError(f"Unsupported mask type: {type(mask)}")

        # Nearest-neighbour interpolation preserves discrete label values.
        mask_pil = resize(mask_pil, self.size, interpolation=InterpolationMode.NEAREST)

        if self.augment:
            image, mask_pil = self._augment(image, mask_pil)

        image_tensor = self.normalize(to_tensor(image))  # [C, H, W], float32
        # Convert the mask once, after all geometric ops, instead of
        # round-tripping PIL <-> NumPy around every augmentation step.
        mask = np.array(mask_pil, dtype=np.uint8)
        mask_tensor = torch.from_numpy(mask).long()  # dtype=torch.int64
        return image_tensor, mask_tensor

    @staticmethod
    def _augment(image, mask):
        """Apply the same random flips and right-angle rotation to a PIL image and mask."""
        # Random horizontal flip
        if random.random() > 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        # Random vertical flip
        if random.random() > 0.5:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

        # On non-square inputs a 90/270 degree turn would either crop the
        # corners (Image.rotate) or swap height and width, so only allow
        # those angles when the image is square.
        if image.width == image.height:
            angles = [0, 90, 180, 270]
        else:
            angles = [0, 180]
        angle = random.choice(angles)
        if angle != 0:
            # transpose() is a lossless pixel permutation, counter-clockwise
            # like Image.rotate.
            op = {
                90: Image.ROTATE_90,
                180: Image.ROTATE_180,
                270: Image.ROTATE_270,
            }[angle]
            image = image.transpose(op)
            mask = mask.transpose(op)

        return image, mask

class BinarySegDataset(Dataset):
    """
    Custom Dataset for binary image segmentation.

    Item ``idx`` pairs ``images_dir/file_list[idx]`` with the first existing
    ``masks_dir/<stem><ext>`` for ext in ``MASK_EXTENSIONS``. Mask values
    greater than 1 are remapped to 1 before the optional transform runs.
    """
    # Searched in this order; the first existing file wins.
    MASK_EXTENSIONS = ('.npy', '.png', '.jpg', '.jpeg')

    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Directory with input images.
            masks_dir (str): Directory with corresponding mask files (.npy or image files).
            file_list (list): List of image filenames.
            transform (callable, optional): Called as ``transform(image, mask)``
                with a PIL RGB image and a PIL mask.
        """
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir  = masks_dir
        self.file_list  = file_list
        self.transform  = transform

        # Debug: Print how many files in this split
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Retrieves the image and mask pair at the specified index.

        Raises:
            FileNotFoundError: If the image or its mask does not exist.
        """
        img_filename = self.file_list[idx]
        base_name = os.path.splitext(img_filename)[0]

        # Load image; the context manager closes the file handle promptly.
        img_path = os.path.join(self.images_dir, img_filename)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"[Error] Image file not found: {img_path}")
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Locate the mask with the same base name and a supported extension.
        mask_path = None
        for ext in self.MASK_EXTENSIONS:
            candidate = os.path.join(self.masks_dir, base_name + ext)
            if os.path.isfile(candidate):
                mask_path = candidate
                break
        if mask_path is None:
            raise FileNotFoundError(f"[Error] Mask file not found for image: {img_filename}")

        # Load the mask straight into NumPy. Passing raw .npy arrays through
        # Image.fromarray first fails for dtypes PIL cannot represent (e.g. int64).
        if mask_path.endswith('.npy'):
            mask_np = np.load(mask_path)
        else:
            with Image.open(mask_path) as mask_img:
                mask_np = np.array(mask_img.convert("L"))  # Convert to grayscale

        # Remap any label >1 to 1 to ensure binary segmentation
        mask_np = np.where(mask_np > 1, 1, mask_np).astype(np.uint8)
        mask = Image.fromarray(mask_np)

        # Apply transformations if any
        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

# ------------------------------
# 2. Transformer Module Implementation
# ------------------------------

class PositionalEncoding(nn.Module):
    """
    Learned additive positional embedding for token sequences.

    Holds a trainable ``[1, max_len, d_model]`` table (truncated-normal init,
    std 0.02) and adds its first ``N`` rows to an input of shape ``[B, N, D]``.

    NOTE: the table has ``max_len * d_model`` parameters. With the default
    ``max_len=512*512`` and ``d_model=2048`` that is ~537M parameters (~2 GB
    in float32). A 512x512 image through a stride-32 encoder yields only
    16x16=256 tokens, so pass a smaller ``max_len`` where possible.
    """
    def __init__(self, d_model, max_len=512*512):
        super(PositionalEncoding, self).__init__()
        self.position_embeddings = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.position_embeddings, std=0.02)

    def forward(self, x):
        # x shape: [B, N, D]
        seq_len = x.size(1)
        max_len = self.position_embeddings.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently returns only
            # max_len rows and the addition fails with an opaque shape error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}."
            )
        return x + self.position_embeddings[:, :seq_len, :]

class TransformerBottleneck(nn.Module):
    """
    Transformer encoder applied over the spatial positions of a feature map.

    An input of shape ``[B, C, H, W]`` (``C`` must equal ``embed_dim``) is
    flattened into ``H*W`` tokens of width ``C``. Learned positional
    embeddings are added, the tokens pass through ``num_layers`` encoder
    layers and a final LayerNorm, and the result is reshaped back to
    ``[B, C, H, W]``. ``num_heads`` must divide ``embed_dim``, and ``H*W``
    must not exceed ``max_len``.
    """
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, dropout=0.1, max_len=512*512):
        super(TransformerBottleneck, self).__init__()
        self.embed_dim = embed_dim
        self.positional_encoding = PositionalEncoding(d_model=embed_dim, max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            dim_feedforward=embed_dim*4,
            # Tokens are laid out as [B, N, C]. Without batch_first=True the
            # layer interprets them as [seq, batch, feature] and would attend
            # across the images in a batch instead of across spatial positions.
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x shape: [B, C, H, W]
        B, C, H, W = x.size()
        # flatten/reshape (unlike view) also accept non-contiguous inputs,
        # e.g. channels_last feature maps.
        x = x.flatten(2).transpose(1, 2)       # [B, N, C], N = H*W
        x = self.positional_encoding(x)        # [B, N, C]
        x = self.transformer_encoder(x)        # [B, N, C]
        x = self.layer_norm(x)                 # [B, N, C]
        # Reshape back to [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x

# ------------------------------
# 3. U-Net Model Definition with Transformer
# ------------------------------

class DoubleConv(nn.Module):
    """
    Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        # A single Sequential named `conv` keeps the state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class ResNetUNetWithTransformer(nn.Module):
    """
    U-Net architecture with ResNet encoder and Transformer bottleneck.

    Skip features come from the ResNet stages ``layer1``-``layer3`` (1/4, 1/8
    and 1/16 of the input resolution). The ``layer4`` output (1/32) is refined
    by a ``TransformerBottleneck`` and decoded back to 1/4 resolution. There, a
    1x1 convolution yields per-class logits, which are bilinearly resized to
    the input size. Input height and width must be multiples of 32 so that
    the upsampled decoder features align with the skip features.
    """
    def __init__(self, n_classes=2, encoder_name='resnet34', pretrained=True, transformer_layers=6, transformer_heads=8):
        """
        Args:
            n_classes (int): Number of output classes.
            encoder_name (str): ResNet encoder to use ('resnet18', 'resnet34', 'resnet50' or 'resnet101').
            pretrained (bool): Whether to use pretrained ResNet weights.
            transformer_layers (int): Number of Transformer encoder layers.
            transformer_heads (int): Number of attention heads in Transformer; must divide
                the encoder's last-stage channel count (512 or 2048).

        Raises:
            ValueError: If ``encoder_name`` is not one of the supported variants.
        """
        super(ResNetUNetWithTransformer, self).__init__()
        # Local import: the class needs torchvision.models, and the imports at
        # the top of this notebook do not bring `models` into scope.
        from torchvision import models

        # Initialize ResNet encoder
        if encoder_name == 'resnet18':
            self.encoder = models.resnet18(pretrained=pretrained)
            encoder_channels = [64, 64, 128, 256, 512]
        elif encoder_name == 'resnet34':
            self.encoder = models.resnet34(pretrained=pretrained)
            encoder_channels = [64, 64, 128, 256, 512]
        elif encoder_name == 'resnet50':
            self.encoder = models.resnet50(pretrained=pretrained)
            encoder_channels = [64, 256, 512, 1024, 2048]
        elif encoder_name == 'resnet101':
            self.encoder = models.resnet101(pretrained=pretrained)
            encoder_channels = [64, 256, 512, 1024, 2048]
        else:
            raise ValueError("Unsupported ResNet variant")

        # Encoder layers
        self.initial = nn.Sequential(
            self.encoder.conv1,  # [B, 64, H/2, W/2]
            self.encoder.bn1,
            self.encoder.relu,
            self.encoder.maxpool  # [B, 64, H/4, W/4]
        )
        self.encoder_layer1 = self.encoder.layer1  # [B, 64 or 256, H/4, W/4]
        self.encoder_layer2 = self.encoder.layer2  # [B, 128 or 512, H/8, W/8]
        self.encoder_layer3 = self.encoder.layer3  # [B, 256 or 1024, H/16, W/16]
        self.encoder_layer4 = self.encoder.layer4  # [B, 512 or 2048, H/32, W/32]

        # Transformer Bottleneck over the H/32 x W/32 grid of layer4 features
        self.transformer = TransformerBottleneck(
            embed_dim=encoder_channels[4],
            num_heads=transformer_heads,
            num_layers=transformer_layers,
            dropout=0.1
        )

        # Decoder: each stage upsamples x2 with a transposed conv to the skip's
        # channel count, concatenates the skip, and fuses with a DoubleConv.
        self.up4 = nn.ConvTranspose2d(encoder_channels[4], encoder_channels[3], kernel_size=2, stride=2)
        self.conv4 = DoubleConv(encoder_channels[3] + encoder_channels[3], encoder_channels[3])

        self.up3 = nn.ConvTranspose2d(encoder_channels[3], encoder_channels[2], kernel_size=2, stride=2)
        self.conv3 = DoubleConv(encoder_channels[2] + encoder_channels[2], encoder_channels[2])

        self.up2 = nn.ConvTranspose2d(encoder_channels[2], encoder_channels[1], kernel_size=2, stride=2)
        self.conv2 = DoubleConv(encoder_channels[1] + encoder_channels[1], encoder_channels[1])

        # 1x1 conv to per-class logits at H/4 x W/4
        self.out_conv = nn.Conv2d(encoder_channels[1], n_classes, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input batch ``[B, 3, H, W]`` with H and W multiples of 32.

        Returns:
            torch.Tensor: Logits ``[B, n_classes, H, W]``.

        Raises:
            ValueError: If H or W is not a multiple of 32.
        """
        H, W = x.shape[-2:]
        if H % 32 or W % 32:
            # Otherwise torch.cat below fails with an opaque size mismatch.
            raise ValueError(
                f"Input height and width must be multiples of 32 for the skip "
                f"connections to align, got {H}x{W}."
            )

        # Encoder (channel counts shown for resnet18/34; x4 for resnet50/101)
        x0 = self.initial(x)           # [B, 64, H/4, W/4]
        x1 = self.encoder_layer1(x0)   # [B, 64, H/4, W/4]
        x2 = self.encoder_layer2(x1)   # [B, 128, H/8, W/8]
        x3 = self.encoder_layer3(x2)   # [B, 256, H/16, W/16]
        x4 = self.encoder_layer4(x3)   # [B, 512, H/32, W/32]

        # Transformer Bottleneck
        x4 = self.transformer(x4)      # [B, 512, H/32, W/32]

        # Decoder
        up4 = self.up4(x4)             # [B, 256, H/16, W/16]
        merge4 = torch.cat([up4, x3], dim=1)  # [B, 512, H/16, W/16]
        conv4 = self.conv4(merge4)     # [B, 256, H/16, W/16]

        up3 = self.up3(conv4)          # [B, 128, H/8, W/8]
        merge3 = torch.cat([up3, x2], dim=1)  # [B, 256, H/8, W/8]
        conv3 = self.conv3(merge3)     # [B, 128, H/8, W/8]

        up2 = self.up2(conv3)          # [B, 64, H/4, W/4]
        merge2 = torch.cat([up2, x1], dim=1)  # [B, 128, H/4, W/4]
        conv2 = self.conv2(merge2)     # [B, 64, H/4, W/4]

        out = self.out_conv(conv2)     # [B, n_classes, H/4, W/4]

        # Upsample to the exact input size. For H, W multiples of 32 this is
        # identical to scale_factor=4 (align_corners=True uses only sizes).
        out = nn.functional.interpolate(out, size=(H, W), mode='bilinear', align_corners=True)  # [B, n_classes, H, W]

        return out

# ------------------------------
# 4. Early Stopping Implementation
# ------------------------------

class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with that epoch's validation loss, then
    check ``early_stop``. A loss counts as an improvement when it is not
    greater than ``best_loss - delta``. A NaN loss never counts as an
    improvement and is never stored as the best loss.
    """
    def __init__(self, patience=7, verbose=False, delta=0.0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None   # best (lowest) loss seen so far; None until the first finite loss
        self.counter = 0        # consecutive epochs without improvement
        self.early_stop = False

    def __call__(self, val_loss):
        # NaN compares False against everything. Without this guard a NaN loss
        # would become best_loss, and every later epoch would then look like an
        # "improvement", so early stopping could never trigger.
        is_nan = val_loss != val_loss

        if self.best_loss is None and not is_nan:
            self.best_loss = val_loss
            if self.verbose:
                print(f"[EarlyStopping] Initial best loss: {self.best_loss:.4f}")
        elif is_nan or val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose:
                print(f"[EarlyStopping] No improvement. Counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("[EarlyStopping] Early stopping triggered.")
        else:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Validation loss improved to: {self.best_loss:.4f}")

# ------------------------------
# 5. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx):
    """
    Trains the model for one epoch.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer.
        criterion (Loss): Loss function.
        device (torch.device): Device to run on.
        epoch_idx (int): Current epoch index (0-based; printed as epoch_idx+1).

    Returns:
        float: Sample-weighted average training loss over the batches actually seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")
    # Use tqdm to show progress
    for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Training", leave=False)):
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size. Count samples seen
        # instead of using len(loader.dataset), which is wrong with
        # drop_last=True or a subset sampler.
        batch_loss = loss.item()
        batch_size = images.size(0)
        total_loss += batch_loss * batch_size
        num_samples += batch_size
        if batch_idx % 10 == 0:
            print(f"[Train] Batch {batch_idx}, Loss: {batch_loss:.4f}")

    if num_samples == 0:
        raise ValueError("[Train] The training loader yielded no samples.")
    average_loss = total_loss / num_samples
    print(f"--- [Train] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")
    return average_loss

def validate_one_epoch(model, loader, criterion, device, epoch_idx, metrics=None):
    """
    Validates the model for one epoch.

    Metric values (if any) are printed and the metrics are reset afterwards;
    only the loss is returned.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the validation set.
        criterion (Loss): Loss function.
        device (torch.device): Device to run on.
        epoch_idx (int): Current epoch index (0-based; printed as epoch_idx+1).
        metrics (dict, optional): Dictionary of TorchMetrics to compute on argmax predictions.

    Returns:
        float: Sample-weighted average validation loss over the batches actually seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            batch_loss = criterion(outputs, masks).item()
            # Count samples seen rather than trusting len(loader.dataset).
            batch_size = images.size(0)
            total_loss += batch_loss * batch_size
            num_samples += batch_size
            if batch_idx % 10 == 0:
                print(f"[Val] Batch {batch_idx}, Loss: {batch_loss:.4f}")

            # Accumulate metric state; update() skips the per-batch value
            # that calling the metric (forward) would compute and discard.
            if metrics:
                preds = torch.argmax(outputs, dim=1)
                for metric in metrics.values():
                    metric.update(preds, masks)

    if num_samples == 0:
        raise ValueError("[Val] The validation loader yielded no samples.")
    average_loss = total_loss / num_samples
    print(f"--- [Val] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        # Reset so the next epoch starts from a clean state
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metric_values.items()])
        print(f"--- [Val] Epoch {epoch_idx+1} Metrics: {metric_str} ---")

    return average_loss

# ------------------------------
# 6. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Computes the Dice coefficient for binary masks.

    ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    The ``smooth`` term makes two empty masks score 1.0 instead of NaN.

    Args:
        pred (torch.Tensor): Predicted mask (any shape, e.g. [H, W]), values in {0, 1}.
        target (torch.Tensor): Ground truth mask with the same number of elements.
        smooth (float): Smoothing factor to avoid division by zero.

    Returns:
        float: Dice coefficient.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose/permute or a strided slice.
    pred_flat = pred.reshape(-1).float()
    target_flat = target.reshape(-1).float()

    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice.item()

def test_segmentation(model, loader, device, metrics=None):
    """
    Evaluates the model on the test set and computes the average Dice coefficient and other metrics.

    With ``metrics`` given, every metric accumulates over all batches and is
    reset after computing. Without it, the per-sample Dice coefficient of the
    argmax prediction is averaged over every sample in the loader (0 if the
    loader is empty).

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): DataLoader for the test set.
        device (torch.device): Device to run on.
        metrics (dict, optional): Dictionary of TorchMetrics.

    Returns:
        dict: Dictionary of average metrics on the test set.
    """
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing", leave=False):
            images = images.to(device)
            masks = masks.to(device)  # [B, H, W]

            outputs = model(images)    # [B, n_classes, H, W]
            preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            if metrics:
                for metric in metrics.values():
                    metric.update(preds, masks)
            else:
                # Score every sample of every batch, not just the last batch.
                dice_scores.extend(
                    dice_coefficient(pred, mask) for pred, mask in zip(preds, masks)
                )

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        # Reset metrics so they can be reused
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metric_values.items()])
        print(f"--- [Test] Metrics: {metric_str} ---")
    else:
        metric_values = {
            'dice': np.mean(dice_scores) if dice_scores else 0
        }
        print(f"--- [Test] Dice Coefficient: {metric_values['dice']:.4f} ---")

    return metric_values

# ------------------------------
# 7. Visualization of Predictions
# ------------------------------

def visualize_predictions(model, dataset, device, num_samples=3, post_process_flag=False):
    """
    Visualizes predictions alongside the original images and ground truth masks.

    Picks up to ``num_samples`` distinct random samples (fewer if the dataset
    is smaller). For each one, shows the image, the ground-truth mask and the
    argmax prediction in one matplotlib figure.

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): The test dataset; items are (image tensor [C, H, W], mask tensor [H, W]).
        device (torch.device): Device to run on.
        num_samples (int): Maximum number of samples to visualize.
        post_process_flag (bool): Currently ignored; no post-processing is implemented.
    """
    model.eval()
    # np.random.choice(..., replace=False) raises when asked for more samples
    # than exist, so cap the request at the dataset size.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device)  # [1, C, H, W]

        with torch.no_grad():
            output = model(image_batch)  # [1, n_classes, H, W]
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        # No post-processing as cv2 is removed
        # If needed, you can implement torch-based post-processing here

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        # Undo ImageNet mean/std normalization for display.
        # NOTE(review): assumes the dataset's transform normalized with these
        # statistics; for an un-normalized [0, 1] tensor the colors will be off.
        image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        # Plotting
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis("off")

        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        plt.tight_layout()
        plt.show()

# ------------------------------
# 8. Main Function
# ------------------------------

def main():
    """
    Train, validate, test and visualize the ResNet-UNet-Transformer model.

    Paths and hyperparameters are hard-coded below. The training folder is
    split into train/validation sets. The separate test folder is evaluated
    after training, using the checkpoint with the lowest validation loss.
    """
    # ------------------------------
    # Define Your Directories Here
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 4
    num_epochs    = 20
    learning_rate = 1e-3
    val_split     = 0.2
    save_path     = "best_model.pth"
    patience      = 7
    post_process_flag = False  # Set to True to apply post-processing if implemented

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Main] Using device: {device}")

    # ------------------------------
    # Collect Training File Lists
    # ------------------------------
    # List all training image files
    all_train_images = sorted([
        f for f in os.listdir(images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ])
    print(f"[Main] Found {len(all_train_images)} training image files in {images_dir}")

    if len(all_train_images) == 0:
        print("[Error] No training image files found. Check your training path!")
        return

    # Ensure corresponding mask files exist
    all_train_images = sorted([
        f for f in all_train_images
        if any(
            os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ext))
            for ext in ['.npy', '.png', '.jpg', '.jpeg']
        )
    ])
    print(f"[Main] {len(all_train_images)} training images have corresponding masks.")

    if len(all_train_images) == 0:
        print("[Error] No training mask files found or mismatched filenames. Check your training mask path!")
        return

    # List all test image files
    all_test_images = sorted([
        f for f in os.listdir(test_images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ])
    print(f"[Main] Found {len(all_test_images)} test image files in {test_images_dir}")

    if len(all_test_images) == 0:
        print("[Error] No test image files found. Check your test path!")
        return

    # Ensure corresponding test mask files exist
    all_test_images = sorted([
        f for f in all_test_images
        if any(
            os.path.isfile(os.path.join(test_masks_dir, os.path.splitext(f)[0] + ext))
            for ext in ['.npy', '.png', '.jpg', '.jpeg']
        )
    ])
    print(f"[Main] {len(all_test_images)} test images have corresponding masks.")

    if len(all_test_images) == 0:
        print("[Error] No test mask files found or mismatched filenames. Check your test mask path!")
        return

    # ------------------------------
    # Train/Validation Split
    # ------------------------------
    train_files, val_files = train_test_split(
        all_train_images,
        test_size=val_split,
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Create Datasets with Transforms
    # ------------------------------
    # NOTE(review): this relies on a ResizeAndToTensor that accepts an
    # `augment` keyword; the ResizeAndToTensor variants visible elsewhere in
    # this notebook only take `size` — confirm the one in scope here does.
    train_transform = ResizeAndToTensor(size=(512, 512), augment=True)
    val_transform = ResizeAndToTensor(size=(512, 512), augment=False)
    test_transform = ResizeAndToTensor(size=(512, 512), augment=False)

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=train_transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=val_transform
    )
    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=test_transform
    )

    # ------------------------------
    # Create DataLoaders
    # ------------------------------
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,  # Adjust based on your CPU cores
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=pin_memory
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=pin_memory
    )

    # ------------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = ResNetUNetWithTransformer(
        n_classes=2,
        encoder_name='resnet50',
        pretrained=True,
        transformer_layers=8,
        transformer_heads=16
    ).to(device)

    # Cross-Entropy Loss
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Initialize EarlyStopping
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # Initialize Metrics
    metrics_val = {
        'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
        'iou': torchmetrics.JaccardIndex(task='binary').to(device),
        'accuracy': torchmetrics.Accuracy(task='binary').to(device)
    }

    metrics_test = {
        'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
        'iou': torchmetrics.JaccardIndex(task='binary').to(device),
        'accuracy': torchmetrics.Accuracy(task='binary').to(device)
    }

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    saved_checkpoint = False  # whether save_path was written during THIS run
    for epoch in range(num_epochs):
        # Train
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)

        # Validate
        val_loss = validate_one_epoch(model, val_loader, criterion, device, epoch, metrics_val)

        # Scheduler step
        scheduler.step()

        # Print epoch summary
        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {optimizer.param_groups[0]['lr']:.6f} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpoint whenever the validation loss reaches a new minimum
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            saved_checkpoint = True
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

        # Feed EVERY epoch to EarlyStopping so its own best-loss tracking stays
        # correct. Calling it only on non-improving epochs left its best_loss
        # set from a worse epoch (and skipped counting that first epoch), so
        # later non-improving epochs could wrongly reset its counter.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("[Main] Early stopping triggered. Stopping training.")
            break

    print(f"\n[Main] Training complete. Best Validation Loss: {best_val_loss:.4f}")

    # ------------------------------
    # Testing
    # ------------------------------
    if saved_checkpoint:
        print(f"[Main] Best model saved at: {save_path}")
        print("\n>>> Loading best model for testing...")
        # map_location lets a checkpoint saved on GPU load on a CPU-only run
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        # Avoid crashing on a missing file or silently loading a stale
        # checkpoint left over from a previous run.
        print("[Warning] No checkpoint was saved in this run; testing the final weights.")
    model.eval()

    test_metrics = test_segmentation(model, test_loader, device, metrics=metrics_test)
    print(f"Test Metrics: {test_metrics}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(model, test_dataset, device, num_samples=3, post_process_flag=post_process_flag)
    else:
        print("[Warning] No test samples available for visualization.")

# ------------------------------
# 9. Run the Main Function
# ------------------------------

# Run the full pipeline defined in main() (train, validate, test, visualize)
# when this code is executed as the main module.
if __name__ == "__main__":
    main()


In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
import torchvision.models as models  # For ResNet encoder

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import torchmetrics
import random

# ------------------------------
# 0. Reproducibility (Optional)
# ------------------------------

def set_seed(seed=42):
    """
    Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``.

    PyTorch is seeded on the CPU and on all CUDA devices. cuDNN is switched to
    deterministic mode with autotuning (benchmark) disabled.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

# ------------------------------
# 1. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Resize an (image, mask) pair to a fixed size and convert both to tensors.

    The image is resized bilinearly, converted to a float tensor in [0, 1]
    and normalized with ImageNet mean/std. The mask is resized with
    nearest-neighbour interpolation, so no new label values appear, and is
    returned as int64. No flips, rotations or other augmentations are applied.
    """
    def __init__(self, size=(512, 512)):
        self.size = size  # (height, width)

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.array): Corresponding mask as a NumPy array.

        Returns:
            image_tensor (torch.Tensor): Resized, normalized image tensor [C, H, W], float32.
            mask_tensor (torch.Tensor): Resized mask tensor [H, W], int64.
        """
        target_hw = self.size

        resized_image = resize(image, target_hw, interpolation=InterpolationMode.BILINEAR)
        # Nearest-neighbour keeps mask labels discrete.
        resized_mask = resize(
            Image.fromarray(mask), target_hw, interpolation=InterpolationMode.NEAREST
        )

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image_tensor = normalize(to_tensor(resized_image))
        mask_array = np.array(resized_mask, dtype=np.uint8)
        mask_tensor = torch.from_numpy(mask_array).long()

        return image_tensor, mask_tensor

class BinarySegDataset(Dataset):
    """
    Binary segmentation dataset pairing images with ``.npy`` masks.

    For an image ``<stem>.<ext>`` in ``images_dir``, the mask must be
    ``<masks_dir>/<stem>.npy``. Mask labels greater than 1 are mapped to 1 and
    the mask is cast to uint8 before ``transform`` (if any) is applied.
    """
    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Directory with input images.
            masks_dir (str): Directory with corresponding mask files (.npy).
            file_list (list): List of image filenames.
            transform (callable, optional): Called as ``transform(image, mask)``
                and expected to return an ``(image, mask)`` pair.
        """
        super().__init__()
        self.images_dir, self.masks_dir = images_dir, masks_dir
        self.file_list = file_list
        self.transform = transform

        # Debug: report the size of this split
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Return the ``(image, mask)`` pair for ``idx``.

        Raises:
            FileNotFoundError: If the image or its ``.npy`` mask is missing.
        """
        name = self.file_list[idx]
        stem, _ = os.path.splitext(name)

        image_file = os.path.join(self.images_dir, name)
        if not os.path.isfile(image_file):
            raise FileNotFoundError(f"[Error] Image file not found: {image_file}")
        rgb = Image.open(image_file).convert("RGB")

        mask_file = os.path.join(self.masks_dir, f"{stem}.npy")
        if not os.path.isfile(mask_file):
            raise FileNotFoundError(f"[Error] Mask file not found: {mask_file}")
        # Collapse every label above 1 onto the foreground class.
        labels = np.load(mask_file)
        labels = np.where(labels > 1, 1, labels).astype(np.uint8)

        if self.transform is not None:
            rgb, labels = self.transform(rgb, labels)
        return rgb, labels

# ------------------------------
# 2. Transformer Module Implementation
# ------------------------------

class PositionalEncoding(nn.Module):
    """
    Learned additive positional embedding for token sequences.

    Holds a trainable ``[1, max_len, d_model]`` table (truncated-normal init,
    std 0.02) and adds its first ``N`` rows to an input of shape ``[B, N, D]``.

    NOTE: the table has ``max_len * d_model`` parameters. With the default
    ``max_len=512*512`` and ``d_model=2048`` that is ~537M parameters (~2 GB
    in float32). A 512x512 image through a stride-32 encoder yields only
    16x16=256 tokens, so pass a smaller ``max_len`` where possible.
    """
    def __init__(self, d_model, max_len=512*512):
        super(PositionalEncoding, self).__init__()
        self.position_embeddings = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.position_embeddings, std=0.02)

    def forward(self, x):
        # x shape: [B, N, D]
        seq_len = x.size(1)
        max_len = self.position_embeddings.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently returns only
            # max_len rows and the addition fails with an opaque shape error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}."
            )
        return x + self.position_embeddings[:, :seq_len, :]

class TransformerBottleneck(nn.Module):
    """
    Transformer encoder applied over the spatial positions of a feature map.

    An input of shape ``[B, C, H, W]`` (``C`` must equal ``embed_dim``) is
    flattened into ``H*W`` tokens of width ``C``. Learned positional
    embeddings are added, the tokens pass through ``num_layers`` encoder
    layers and a final LayerNorm, and the result is reshaped back to
    ``[B, C, H, W]``. ``num_heads`` must divide ``embed_dim``, and ``H*W``
    must not exceed ``max_len``.
    """
    def __init__(self, embed_dim=2048, num_heads=8, num_layers=6, dropout=0.1, max_len=512*512):
        super(TransformerBottleneck, self).__init__()
        self.embed_dim = embed_dim
        self.positional_encoding = PositionalEncoding(d_model=embed_dim, max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            dim_feedforward=embed_dim*4,
            # Tokens are laid out as [B, N, C]. Without batch_first=True the
            # layer interprets them as [seq, batch, feature] and would attend
            # across the images in a batch instead of across spatial positions.
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x shape: [B, C, H, W]
        B, C, H, W = x.size()
        # flatten/reshape (unlike view) also accept non-contiguous inputs,
        # e.g. channels_last feature maps.
        x = x.flatten(2).transpose(1, 2)       # [B, N, C], N = H*W
        x = self.positional_encoding(x)        # [B, N, C]
        x = self.transformer_encoder(x)        # [B, N, C]
        x = self.layer_norm(x)                 # [B, N, C]
        # Reshape back to [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x

# ------------------------------
# 3. U-Net Model Definition with Transformer
# ------------------------------

class DoubleConv(nn.Module):
    """
    ``(Conv2d 3x3 -> BatchNorm2d -> ReLU) x 2`` with spatial size preserved.

    Channels go ``in_channels -> out_channels -> out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(c_in):
            # One 3x3 conv stage; padding=1 keeps H and W unchanged.
            return [
                nn.Conv2d(c_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Flattened into one Sequential named `conv` so the state_dict keys
        # match a plain six-module Sequential.
        self.conv = nn.Sequential(*conv_bn_relu(in_channels), *conv_bn_relu(out_channels))

    def forward(self, x):
        return self.conv(x)

class ResNetUNetWithTransformer(nn.Module):
    """
    U-Net architecture with ResNet encoder and Transformer bottleneck.

    Components:
    - Encoder: the stem and four residual stages of a torchvision ResNet.
    - Bottleneck: a TransformerBottleneck over the stride-32 feature map.
    - Decoder: three 2x transposed-conv upsamplings plus a 1x1 projection. Each is
      followed by concatenation with the matching encoder map and a DoubleConv.

    The stride-4 logits are finally resized to the input's exact spatial size.
    """

    # Output channels of [stem, layer1, layer2, layer3, layer4] per supported ResNet.
    _ENCODER_CHANNELS = {
        'resnet18': (64, 64, 128, 256, 512),
        'resnet34': (64, 64, 128, 256, 512),
        'resnet50': (64, 256, 512, 1024, 2048),
        'resnet101': (64, 256, 512, 1024, 2048),
        'resnet152': (64, 256, 512, 1024, 2048),
    }

    def __init__(self, n_classes=2, encoder_name='resnet101', pretrained=True, transformer_layers=6, transformer_heads=8):
        """
        Args:
            n_classes (int): Number of output classes.
            encoder_name (str): ResNet encoder to use: 'resnet18', 'resnet34',
                'resnet50', 'resnet101' or 'resnet152'.
            pretrained (bool): Whether to use pretrained ResNet weights.
            transformer_layers (int): Number of Transformer encoder layers.
            transformer_heads (int): Number of attention heads in Transformer.

        Raises:
            ValueError: If ``encoder_name`` is not a supported variant.
        """
        super(ResNetUNetWithTransformer, self).__init__()

        if encoder_name not in self._ENCODER_CHANNELS:
            raise ValueError("Unsupported ResNet variant")
        encoder_channels = self._ENCODER_CHANNELS[encoder_name]
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; kept for compatibility with older torchvision releases.
        self.encoder = getattr(models, encoder_name)(pretrained=pretrained)

        # Encoder layers (these reuse the modules of self.encoder)
        self.initial = nn.Sequential(
            self.encoder.conv1,   # [B, 64, H/2, W/2]
            self.encoder.bn1,
            self.encoder.relu,
            self.encoder.maxpool  # [B, 64, H/4, W/4]
        )
        self.encoder_layer1 = self.encoder.layer1  # [B, c1, H/4, W/4]   (c1 = 256 for ResNet50+)
        self.encoder_layer2 = self.encoder.layer2  # [B, c2, H/8, W/8]
        self.encoder_layer3 = self.encoder.layer3  # [B, c3, H/16, W/16]
        self.encoder_layer4 = self.encoder.layer4  # [B, c4, H/32, W/32]

        # Transformer Bottleneck over the deepest feature map
        self.transformer = TransformerBottleneck(
            embed_dim=encoder_channels[4],
            num_heads=transformer_heads,
            num_layers=transformer_layers,
            dropout=0.1
        )

        # Decoder: upsample, concatenate the matching skip, fuse with DoubleConv.
        # up4: c4 @ H/32 -> c3 @ H/16
        self.up4 = nn.ConvTranspose2d(encoder_channels[4], encoder_channels[3], kernel_size=2, stride=2)
        self.conv4 = DoubleConv(encoder_channels[3] * 2, encoder_channels[3])

        # up3: c3 @ H/16 -> c2 @ H/8
        self.up3 = nn.ConvTranspose2d(encoder_channels[3], encoder_channels[2], kernel_size=2, stride=2)
        self.conv3 = DoubleConv(encoder_channels[2] * 2, encoder_channels[2])

        # up2: c2 @ H/8 -> c1 @ H/4
        self.up2 = nn.ConvTranspose2d(encoder_channels[2], encoder_channels[1], kernel_size=2, stride=2)
        self.conv2 = DoubleConv(encoder_channels[1] * 2, encoder_channels[1])

        # up1: 1x1 projection c1 -> 64 (no upsampling: layer1 and the stem are both stride 4)
        self.up1 = nn.ConvTranspose2d(encoder_channels[1], encoder_channels[0], kernel_size=1, stride=1)
        self.conv1 = DoubleConv(encoder_channels[0] * 2, encoder_channels[0])

        # Output convolution
        self.out_conv = nn.Conv2d(encoder_channels[0], n_classes, kernel_size=1)

    @staticmethod
    def _match_spatial(t, ref):
        """Bilinearly resize ``t`` to ``ref``'s (H, W) if they differ; otherwise return ``t`` as-is."""
        if t.shape[-2:] != ref.shape[-2:]:
            t = nn.functional.interpolate(t, size=ref.shape[-2:], mode='bilinear', align_corners=True)
        return t

    def forward(self, x):
        input_size = x.shape[-2:]

        # Encoder
        x0 = self.initial(x)           # stride 4
        x1 = self.encoder_layer1(x0)   # stride 4
        x2 = self.encoder_layer2(x1)   # stride 8
        x3 = self.encoder_layer3(x2)   # stride 16
        x4 = self.encoder_layer4(x3)   # stride 32

        # Transformer Bottleneck
        x4 = self.transformer(x4)

        # Decoder. _match_spatial is a no-op when H and W are multiples of 32.
        # For other sizes it absorbs the rounding of the strided encoder, so the
        # concatenations line up.
        up4 = self._match_spatial(self.up4(x4), x3)
        conv4 = self.conv4(torch.cat([up4, x3], dim=1))

        up3 = self._match_spatial(self.up3(conv4), x2)
        conv3 = self.conv3(torch.cat([up3, x2], dim=1))

        up2 = self._match_spatial(self.up2(conv3), x1)
        conv2 = self.conv2(torch.cat([up2, x1], dim=1))

        up1 = self._match_spatial(self.up1(conv2), x0)
        conv1 = self.conv1(torch.cat([up1, x0], dim=1))

        # Output convolution: [B, n_classes, ~H/4, ~W/4]
        out = self.out_conv(conv1)

        # Resize to the exact input size. For H, W divisible by 4 this equals the
        # previous 4x upsampling (align_corners=True depends only on in/out sizes).
        out = nn.functional.interpolate(out, size=input_size, mode='bilinear', align_corners=True)

        return out

# ------------------------------
# 4. Early Stopping Implementation
# ------------------------------

class EarlyStopping:
    """
    Track validation loss across epochs and flag when training should stop.

    Call the instance once per epoch with that epoch's validation loss.
    - The first call only records a baseline in ``best_loss``.
    - A later loss greater than ``best_loss - delta`` counts as no improvement and
      increments ``counter``. Once ``counter`` reaches ``patience``,
      ``early_stop`` is set to True.
    - Any other loss becomes the new ``best_loss`` and resets ``counter`` to 0.
    """
    def __init__(self, patience=7, verbose=False, delta=0.0):
        """
        Args:
            patience (int): Consecutive non-improving calls tolerated before ``early_stop`` is set.
            verbose (bool): If True, prints a status message on every call.
            delta (float): Minimum decrease of the loss that counts as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None    # best loss so far; None until the first call
        self.counter = 0         # consecutive calls without improvement
        self.early_stop = False  # never reset once it becomes True

    def __call__(self, val_loss):
        # First observation: just record it.
        if self.best_loss is None:
            self.best_loss = val_loss
            if self.verbose:
                print(f"[EarlyStopping] Initial best loss: {self.best_loss:.4f}")
            return

        threshold = self.best_loss - self.delta
        if val_loss > threshold:
            # Not better by at least `delta`.
            self.counter += 1
            if self.verbose:
                print(f"[EarlyStopping] No improvement. Counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("[EarlyStopping] Early stopping triggered.")
            return

        # Improvement: adopt the new best and restart the patience window.
        self.best_loss = val_loss
        self.counter = 0
        if self.verbose:
            print(f"[EarlyStopping] Validation loss improved to: {self.best_loss:.4f}")

# ------------------------------
# 5. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx, scaler):
    """
    Trains the model for one epoch using Mixed Precision.

    Autocast (float16 mixed precision) is enabled only when ``device`` is CUDA.
    The gradient step goes through ``scaler`` (a GradScaler) in both cases.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer.
        criterion (Loss): Loss function.
        device (torch.device): Device to run on.
        epoch_idx (int): Current epoch index (0-based; printed as epoch_idx+1).
        scaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed precision.

    Returns:
        float: Mean per-sample training loss over the samples the loader actually
        yielded (0.0 if it yielded none).
    """
    model.train()
    # Only CUDA autocast is used. On a CPU device this is equivalent to the disabled
    # autocast the old torch.cuda.amp.autocast() produced, without its warning.
    use_amp = torch.device(device).type == "cuda"
    total_loss = 0.0
    num_samples = 0
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")
    # Use tqdm to show progress
    for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Training", leave=False)):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()  # single host sync per batch
        batch_size = images.size(0)
        total_loss += batch_loss * batch_size
        # Count what was actually seen: len(loader.dataset) is wrong with
        # subsetting/distributed samplers or drop_last.
        num_samples += batch_size
        if batch_idx % 10 == 0:
            print(f"[Train] Batch {batch_idx}, Loss: {batch_loss:.4f}")
    average_loss = total_loss / num_samples if num_samples else 0.0
    print(f"--- [Train] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")
    return average_loss

def validate_one_epoch(model, loader, criterion, device, epoch_idx, metrics=None):
    """
    Validates the model for one epoch using Mixed Precision.

    Runs without gradients; autocast is enabled only when ``device`` is CUDA.
    When ``metrics`` is given, each metric is updated with argmax predictions and
    then computed, printed and reset at the end of the epoch.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the validation set.
        criterion (Loss): Loss function.
        device (torch.device): Device to run on.
        epoch_idx (int): Current epoch index (0-based; printed as epoch_idx+1).
        metrics (dict, optional): Dictionary of TorchMetrics to compute.

    Returns:
        float: Mean per-sample validation loss over the samples the loader
        actually yielded (0.0 if it yielded none).
    """
    model.eval()
    use_amp = torch.device(device).type == "cuda"
    total_loss = 0.0
    num_samples = 0
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, masks)

            batch_loss = loss.item()
            batch_size = images.size(0)
            total_loss += batch_loss * batch_size
            num_samples += batch_size
            if batch_idx % 10 == 0:
                print(f"[Val] Batch {batch_idx}, Loss: {batch_loss:.4f}")

            if metrics:
                preds = torch.argmax(outputs, dim=1)
                # update() only accumulates state; calling the metric would also
                # compute (and discard) a per-batch value.
                for metric in metrics.values():
                    metric.update(preds, masks)

    average_loss = total_loss / num_samples if num_samples else 0.0
    print(f"--- [Val] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Val] Epoch {epoch_idx+1} Metrics: {metric_str} ---")

    return average_loss

# ------------------------------
# 6. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Computes the Dice coefficient for binary masks.

    Args:
        pred (torch.Tensor): Predicted mask (e.g. [H, W]); expected to hold 0/1 labels.
        target (torch.Tensor): Ground truth mask with the same number of elements.
        smooth (float): Added to numerator and denominator to avoid division by
            zero; two empty masks therefore score 1.0.

    Returns:
        float: (2 * |pred * target| + smooth) / (|pred| + |target| + smooth).
    """
    # reshape (not view) so non-contiguous inputs such as slices or transposes work.
    pred_flat = pred.reshape(-1).float()
    target_flat = target.reshape(-1).float()

    intersection = (pred_flat * target_flat).sum()
    denominator = pred_flat.sum() + target_flat.sum()
    dice = (2. * intersection + smooth) / (denominator + smooth)
    return dice.item()

def test_segmentation(model, loader, device, metrics=None):
    """
    Evaluates the model on the test set and computes the average Dice coefficient and other metrics.

    Predictions are the argmax over the class dimension. Autocast is enabled
    only when ``device`` is CUDA.

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): DataLoader for the test set.
        device (torch.device): Device to run on.
        metrics (dict, optional): Dictionary of TorchMetrics. When omitted, the
            mean per-image Dice (via ``dice_coefficient``) over every test image
            is reported instead.

    Returns:
        dict: ``{name: value}`` for each metric, or ``{'dice': mean_dice}``
        (0 if the loader yielded no images) when ``metrics`` is not given.
    """
    model.eval()
    use_amp = torch.device(device).type == "cuda"
    dice_scores = []  # per-image Dice; only filled when metrics is not provided
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing", leave=False):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)  # [B, H, W]

            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)    # [B, n_classes, H, W]
            preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            if metrics:
                for metric in metrics.values():
                    metric.update(preds, masks)
            else:
                # Score every image of every batch, not just the last batch.
                dice_scores.extend(
                    dice_coefficient(pred, mask) for pred, mask in zip(preds, masks)
                )

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Test] Metrics: {metric_str} ---")
    else:
        metric_values = {
            'dice': np.mean(dice_scores) if dice_scores else 0
        }
        print(f"--- [Test] Dice Coefficient: {metric_values['dice']:.4f} ---")

    return metric_values

# ------------------------------
# 7. Visualization of Predictions
# ------------------------------

def visualize_predictions(model, dataset, device, num_samples=3, post_process_flag=False):
    """
    Visualizes predictions alongside the original images and ground truth masks.

    Picks distinct random indices with ``np.random.choice`` and shows one figure
    per sample: image, ground-truth mask and argmax prediction.

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): The test dataset; items are (image [C, H, W], mask [H, W]) tensors.
        device (torch.device): Device to run on.
        num_samples (int): Number of samples to visualize; capped at len(dataset).
        post_process_flag (bool): Accepted for API compatibility; currently unused
            because no post-processing is implemented.
    """
    model.eval()
    # replace=False cannot draw more samples than the dataset holds.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    use_amp = torch.device(device).type == "cuda"
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # ImageNet statistics used to undo the normalization for display.
    # NOTE(review): only correct if the dataset's transform normalizes with these values.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device, non_blocking=True)  # [1, C, H, W]

        with torch.no_grad():
            with torch.autocast(device_type="cuda", enabled=use_amp):
                output = model(image_batch)  # [1, n_classes, H, W]
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        image_np = image_np * std + mean  # Unnormalize
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        # Plotting
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis("off")

        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        plt.tight_layout()
        plt.show()

# ------------------------------
# 8. Main Function
# ------------------------------

def main():
    """
    Run the full pipeline: collect data, train, validate, test and visualize.

    Data handling:
    - Images (``.png``/``.jpg``/``.jpeg``) are paired with masks stored as
      ``<image stem>.npy``.
    - The training folder is split into train/validation by ``val_split``.
    - The separate test folder is used only after training.

    The checkpoint with the lowest validation loss is written to ``save_path``
    and reloaded for testing.
    """
    # ------------------------------
    # Define Your Directories Here
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 8    # adjust based on GPU memory
    num_epochs    = 30
    learning_rate = 1e-4
    val_split     = 0.2  # fraction of the training folder held out for validation
    save_path     = "best_model_resnet101.pth"
    patience      = 10   # epochs without val-loss improvement before stopping
    post_process_flag = False  # Set to True to apply post-processing if implemented

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Main] Using device: {device}")

    image_exts = (".png", ".jpg", ".jpeg")

    # ------------------------------
    # Collect Training File Lists
    # ------------------------------
    all_train_images = sorted(
        f for f in os.listdir(images_dir) if f.lower().endswith(image_exts)
    )
    print(f"[Main] Found {len(all_train_images)} training image files in {images_dir}")

    if not all_train_images:
        print("[Error] No training image files found. Check your training path!")
        return

    # Keep only images with a matching "<stem>.npy" mask (filtering preserves the sort order).
    all_train_images = [
        f for f in all_train_images
        if os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ".npy"))
    ]
    print(f"[Main] {len(all_train_images)} training images have corresponding masks.")

    if not all_train_images:
        print("[Error] No training mask files found or mismatched filenames. Check your training mask path!")
        return

    # List all test image files
    all_test_images = sorted(
        f for f in os.listdir(test_images_dir) if f.lower().endswith(image_exts)
    )
    print(f"[Main] Found {len(all_test_images)} test image files in {test_images_dir}")

    if not all_test_images:
        print("[Error] No test image files found. Check your test path!")
        return

    # Ensure corresponding test mask files exist
    all_test_images = [
        f for f in all_test_images
        if os.path.isfile(os.path.join(test_masks_dir, os.path.splitext(f)[0] + ".npy"))
    ]
    print(f"[Main] {len(all_test_images)} test images have corresponding masks.")

    if not all_test_images:
        print("[Error] No test mask files found or mismatched filenames. Check your test mask path!")
        return

    # ------------------------------
    # Train/Validation Split
    # ------------------------------
    train_files, val_files = train_test_split(
        all_train_images,
        test_size=val_split,
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Create Datasets with Transforms
    # ------------------------------
    train_transform = ResizeAndToTensor(size=(512, 512))
    val_transform = ResizeAndToTensor(size=(512, 512))
    test_transform = ResizeAndToTensor(size=(512, 512))

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=train_transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=val_transform
    )
    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=test_transform
    )

    # ------------------------------
    # Create DataLoaders
    # ------------------------------
    num_workers = 8  # adjust based on CPU cores
    use_cuda = torch.cuda.is_available()
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=use_cuda,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=use_cuda and num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # ------------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = ResNetUNetWithTransformer(
        n_classes=2,
        encoder_name='resnet101',
        pretrained=True,
        transformer_layers=6,
        transformer_heads=8
    ).to(device)

    # Cross-Entropy Loss
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # T_0 / T_mult are measured in epochs; the scheduler is stepped once per epoch below.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-6)

    # Initialize EarlyStopping
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # Initialize Metrics
    metrics_val = {
        'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
        'iou': torchmetrics.JaccardIndex(task='binary').to(device),
        'accuracy': torchmetrics.Accuracy(task='binary').to(device)
    }

    metrics_test = {
        'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
        'iou': torchmetrics.JaccardIndex(task='binary').to(device),
        'accuracy': torchmetrics.Accuracy(task='binary').to(device)
    }

    # Initialize Gradient Scaler for Mixed Precision
    scaler = torch.cuda.amp.GradScaler()

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Train
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, scaler)

        # Validate
        val_loss = validate_one_epoch(model, val_loader, criterion, device, epoch, metrics_val)

        # One scheduler step per epoch (val_split is a dataset fraction, not an epoch offset).
        scheduler.step()

        # Print epoch summary
        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {optimizer.param_groups[0]['lr']:.6f} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

        # Feed every epoch's loss to EarlyStopping so its own best_loss/counter
        # stay consistent (it tracks improvements itself).
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("[Main] Early stopping triggered. Stopping training.")
            break

    print(f"\n[Main] Training complete. Best Validation Loss: {best_val_loss:.4f}")
    print(f"[Main] Best model saved at: {save_path}")

    # ------------------------------
    # Testing
    # ------------------------------
    print("\n>>> Loading best model for testing...")
    if not os.path.isfile(save_path):
        print(f"[Error] No checkpoint found at {save_path}. Skipping testing.")
        return
    # map_location lets a checkpoint saved on GPU load on the current device.
    model.load_state_dict(torch.load(save_path, map_location=device))
    model.eval()

    test_metrics = test_segmentation(model, test_loader, device, metrics=metrics_test)
    print(f"Test Metrics: {test_metrics}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(model, test_dataset, device, num_samples=3, post_process_flag=post_process_flag)
    else:
        print("[Warning] No test samples available for visualization.")

# ------------------------------
# 9. Run the Main Function
# ------------------------------

if __name__ == "__main__":
    # Run the pipeline when executed as a script. In a notebook kernel
    # __name__ is also "__main__", so running this cell starts training.
    main()


# In [None]:
import os
import numpy as np
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
import torchvision.models as models  # For ResNet encoder

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import torchmetrics
import random

# ------------------------------
# 0. Reproducibility (Optional)
# ------------------------------

def set_seed(seed=42):
    """
    Seed every RNG used here and force deterministic cuDNN behaviour.

    Seeds torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same value. It then sets cuDNN to use
    deterministic algorithms with the benchmark autotuner disabled.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # deterministic kernels only
    cudnn.benchmark = False      # autotuned kernel choice can vary between runs

set_seed(42)

# ------------------------------
# 1. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Paired, deterministic transform: resize an image/mask pair and convert both to tensors.

    - Image: bilinear resize, ``to_tensor`` (float32 in [0, 1]), then
      normalization with the ImageNet mean/std.
    - Mask: nearest-neighbour resize (so no new label values appear), returned
      as an int64 tensor.

    No random augmentation (flips, rotations, ...) is applied.
    """
    def __init__(self, size=(512, 512)):
        self.size = size  # (height, width)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.array): Corresponding mask as a NumPy array.

        Returns:
            image_tensor (torch.Tensor): [C, H, W] resized and normalized image.
            mask_tensor (torch.Tensor): [H, W] resized mask, dtype torch.int64.
        """
        resized_image = resize(image, self.size, interpolation=InterpolationMode.BILINEAR)
        resized_mask = resize(
            Image.fromarray(mask),
            self.size,
            interpolation=InterpolationMode.NEAREST,
        )

        image_tensor = self.normalize(to_tensor(resized_image))
        mask_array = np.array(resized_mask, dtype=np.uint8)
        mask_tensor = torch.from_numpy(mask_array).long()
        return image_tensor, mask_tensor

class BinarySegDataset(Dataset):
    """
    Custom Dataset for binary image segmentation.

    Each entry of ``file_list`` is an image filename inside ``images_dir``. Its
    mask is loaded from ``masks_dir/<stem>.npy``. Mask values greater than 1 are
    mapped to 1 before the optional ``transform(image, mask)`` is applied.
    """
    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Directory with input images.
            masks_dir (str): Directory with corresponding mask files (.npy).
            file_list (list): List of image filenames.
            transform (callable, optional): Called as ``transform(image, mask)``
                and must return the (image, mask) pair.
        """
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir  = masks_dir
        self.file_list  = file_list
        self.transform  = transform

        # Debug: Print how many files in this split
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Load the (image, mask) pair at ``idx``.

        Returns the transformed pair when a transform is set. Otherwise it returns
        an RGB PIL image and a uint8 NumPy mask.

        Raises:
            FileNotFoundError: If the image or its ``.npy`` mask is missing.
        """
        img_filename = self.file_list[idx]
        base_name = os.path.splitext(img_filename)[0]

        # Load image
        img_path = os.path.join(self.images_dir, img_filename)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"[Error] Image file not found: {img_path}")
        # The context manager closes the file handle deterministically instead of
        # relying on Pillow/GC (relevant with many DataLoader workers).
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Load mask
        mask_path = os.path.join(self.masks_dir, base_name + ".npy")
        if not os.path.isfile(mask_path):
            raise FileNotFoundError(f"[Error] Mask file not found: {mask_path}")
        mask = np.load(mask_path)

        # Remap label >1 to 1 to ensure binary segmentation
        mask = np.where(mask > 1, 1, mask).astype(np.uint8)

        # Apply transform
        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

# ------------------------------
# 2. Transformer Module Implementation
# ------------------------------

class PositionalEncoding(nn.Module):
    """
    Learned absolute positional embedding added to a token sequence.

    Holds a trainable table ``position_embeddings`` of shape [1, max_len, d_model],
    initialised from a truncated normal with std 0.02. ``forward`` adds its first
    ``N`` rows to an input of shape [B, N, d_model].

    NOTE(review): the default ``max_len=512*512`` allocates ``262144 * d_model``
    parameters (2 GiB in float32 for d_model=2048). The ResNet U-Net bottleneck
    only sees (H/32)*(W/32) tokens, so a much smaller ``max_len`` would do.
    Changing it also changes the checkpoint shape.
    """
    def __init__(self, d_model, max_len=512*512):
        super(PositionalEncoding, self).__init__()
        table = torch.zeros(1, max_len, d_model)
        nn.init.trunc_normal_(table, std=0.02)
        self.position_embeddings = nn.Parameter(table)

    def forward(self, x):
        # x shape: [B, N, D]; broadcast-add the first N learned positions.
        seq_len = x.size(1)
        return x + self.position_embeddings[:, :seq_len, :]

class TransformerBottleneck(nn.Module):
    """
    Transformer encoder applied to a CNN feature map, one token per spatial position.

    Input and output are [B, C, H, W] with ``C == embed_dim``. The processing steps are:
    - flatten the map to ``H*W`` tokens of shape [B, H*W, C];
    - add a learned positional embedding;
    - run the tokens through ``num_layers`` encoder layers and a final LayerNorm;
    - reshape back to the input layout.
    """
    def __init__(self, embed_dim=2048, num_heads=8, num_layers=6, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        self.embed_dim = embed_dim
        self.positional_encoding = PositionalEncoding(d_model=embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            dim_feedforward=embed_dim * 4,
            # Tokens are laid out as [B, N, C]. With the default batch_first=False
            # the layer would treat dim 0 (the batch) as the sequence axis and
            # attend across images instead of across spatial positions.
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x shape: [B, C, H, W]
        B, C, H, W = x.shape
        # [B, C, H, W] -> [B, H*W, C]; flatten() also works for non-contiguous inputs.
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.positional_encoding(tokens)   # [B, N, C]
        tokens = self.transformer_encoder(tokens)   # [B, N, C]
        tokens = self.layer_norm(tokens)            # [B, N, C]
        # [B, N, C] -> [B, C, H, W]
        return tokens.transpose(1, 2).reshape(B, C, H, W)

# ------------------------------
# 3. U-Net Model Definition with Transformer
# ------------------------------

class DoubleConv(nn.Module):
    """
    Two stacked (3x3 Conv -> BatchNorm -> ReLU -> Dropout) stages.

    ``padding=1`` keeps H and W unchanged. The first convolution maps
    ``in_channels`` to ``out_channels``, and the second keeps ``out_channels``.
    ``dropout_p`` is the element-wise ``nn.Dropout`` probability used after each stage.
    """
    def __init__(self, in_channels, out_channels, dropout_p=0.5):
        super(DoubleConv, self).__init__()

        def make_stage(stage_in):
            # One conv/bn/relu/dropout group producing out_channels.
            return [
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_p),
            ]

        # Single Sequential named `conv` keeps state_dict keys at conv.0 ... conv.7.
        self.conv = nn.Sequential(*make_stage(in_channels), *make_stage(out_channels))

    def forward(self, x):
        return self.conv(x)

class DiceLoss(nn.Module):
    """
    Soft Dice loss on softmax probabilities.

    ``forward(logits, true)`` works as follows:
    - takes logits [B, C, H, W] and integer labels [B, H, W];
    - one-hot encodes the labels into C classes;
    - computes one smoothed soft Dice score per class, aggregated over batch
      and pixels;
    - returns ``1 - mean(per-class Dice)``.

    It works for any C; the background class is included in the mean.

    Args:
        smooth (float): Added to numerator and denominator of each Dice score.
    """
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, true):
        num_classes = logits.shape[1]
        probs = nn.functional.softmax(logits, dim=1)                    # [B, C, H, W]
        target = nn.functional.one_hot(true, num_classes=num_classes)   # [B, H, W, C]
        target = target.permute(0, 3, 1, 2).float()                     # [B, C, H, W]

        reduce_dims = (0, 2, 3)  # every axis except the class axis
        overlap = (probs * target).sum(dim=reduce_dims)
        denom = (probs + target).sum(dim=reduce_dims)
        per_class_dice = (2. * overlap + self.smooth) / (denom + self.smooth)
        return 1 - per_class_dice.mean()

class ResNetUNetWithTransformer(nn.Module):
    """
    U-Net style segmentation network with a torchvision ResNet encoder and a
    Transformer bottleneck applied to the deepest encoder feature map.

    The decoder upsamples with transposed convolutions and concatenates the
    matching encoder feature map at each scale. Logits are produced at 1/4 of
    the input resolution and bilinearly upsampled by 4x.
    NOTE(review): the skip concatenations presumably require input height and
    width divisible by 32 (the encoder's total stride) -- confirm for non-512 inputs.
    """
    def __init__(self, n_classes=2, encoder_name='resnet50', pretrained=True, transformer_layers=6, transformer_heads=8, dropout_p=0.5):
        """
        Args:
            n_classes (int): Number of output classes.
            encoder_name (str): ResNet variant: 'resnet18', 'resnet34', 'resnet50',
                'resnet101' or 'resnet152'.
            pretrained (bool): Whether to load pretrained ResNet weights.
            transformer_layers (int): Number of Transformer encoder layers.
            transformer_heads (int): Number of attention heads in the Transformer.
            dropout_p (float): Dropout probability inside the decoder blocks.

        Raises:
            ValueError: If ``encoder_name`` is not one of the supported variants.
        """
        super(ResNetUNetWithTransformer, self).__init__()

        # Channel widths of (stem, layer1, layer2, layer3, layer4) per variant.
        narrow = (64, 64, 128, 256, 512)     # resnet18 / resnet34
        wide = (64, 256, 512, 1024, 2048)    # resnet50 / resnet101 / resnet152
        variants = (
            ('resnet18', narrow),
            ('resnet34', narrow),
            ('resnet50', wide),
            ('resnet101', wide),
            ('resnet152', wide),
        )
        for variant, widths in variants:
            if encoder_name == variant:
                self.encoder = getattr(models, variant)(pretrained=pretrained)
                break
        else:
            raise ValueError("Unsupported ResNet variant")
        c0, c1, c2, c3, c4 = widths

        # Encoder stages, reusing the ResNet's own modules.
        stem = self.encoder
        self.initial = nn.Sequential(stem.conv1, stem.bn1, stem.relu, stem.maxpool)  # -> H/4
        self.encoder_layer1 = stem.layer1  # -> H/4
        self.encoder_layer2 = stem.layer2  # -> H/8
        self.encoder_layer3 = stem.layer3  # -> H/16
        self.encoder_layer4 = stem.layer4  # -> H/32

        # Transformer on the H/32 feature map.
        self.transformer = TransformerBottleneck(
            embed_dim=c4,
            num_heads=transformer_heads,
            num_layers=transformer_layers,
            dropout=0.1
        )

        # Decoder: each upN halves the channel count towards the matching skip,
        # and convN fuses [upsampled, skip] (2x channels) back down.
        self.up4 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(2 * c3, c3, dropout_p=dropout_p)
        self.up3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(2 * c2, c2, dropout_p=dropout_p)
        self.up2 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(2 * c1, c1, dropout_p=dropout_p)
        # layer1 and the stem share resolution, so this step only remaps channels.
        self.up1 = nn.ConvTranspose2d(c1, c0, kernel_size=1, stride=1)
        self.conv1 = DoubleConv(2 * c0, c0, dropout_p=dropout_p)

        self.out_conv = nn.Conv2d(c0, n_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape [B, n_classes, H, W] for an input of shape [B, 3, H, W]."""
        skip0 = self.initial(x)
        skip1 = self.encoder_layer1(skip0)
        skip2 = self.encoder_layer2(skip1)
        skip3 = self.encoder_layer3(skip2)
        deepest = self.transformer(self.encoder_layer4(skip3))

        decoded = self.conv4(torch.cat([self.up4(deepest), skip3], dim=1))
        decoded = self.conv3(torch.cat([self.up3(decoded), skip2], dim=1))
        decoded = self.conv2(torch.cat([self.up2(decoded), skip1], dim=1))
        decoded = self.conv1(torch.cat([self.up1(decoded), skip0], dim=1))

        logits = self.out_conv(decoded)  # at H/4
        return nn.functional.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=True)

# ------------------------------
# 4. Early Stopping Implementation
# ------------------------------

class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.

    A call counts as an improvement when ``val_loss <= best_loss - delta``.
    The first non-NaN loss becomes the initial best. NaN losses always count
    as "no improvement" and never become the best, so a diverged run still stops.
    """
    def __init__(self, patience=7, verbose=False, delta=0.0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update ``best_loss``, ``counter`` and ``early_stop``."""
        import math  # local import: only needed for the NaN check below

        if math.isnan(val_loss):
            # NaN compares False against everything; without this branch it would
            # be stored as the new best and every later loss would look like an
            # improvement, so early stopping could never trigger.
            self._register_no_improvement()
        elif self.best_loss is None:
            self.best_loss = val_loss
            if self.verbose:
                print(f"[EarlyStopping] Initial best loss: {self.best_loss:.4f}")
        elif val_loss <= self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Validation loss improved to: {self.best_loss:.4f}")
        else:
            self._register_no_improvement()

    def _register_no_improvement(self):
        """Increment the patience counter and set ``early_stop`` once it reaches ``patience``."""
        self.counter += 1
        if self.verbose:
            print(f"[EarlyStopping] No improvement. Counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print("[EarlyStopping] Early stopping triggered.")

# ------------------------------
# 5. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx, scaler):
    """
    Run one mixed-precision training pass over ``loader``.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): Yields (images, masks) training batches.
        optimizer (Optimizer): Optimizer stepped once per batch.
        criterion (Loss): Loss function called as ``criterion(outputs, masks)``.
        device (torch.device): Device the batches are moved to.
        epoch_idx (int): Zero-based epoch index (printed one-based).
        scaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed precision.

    Returns:
        float: Sample-weighted mean loss, i.e. sum(batch_loss * batch_size)
        divided by ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")
    progress = tqdm(loader, desc="Training", leave=False)
    for step, (images, masks) in enumerate(progress):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            loss = criterion(model(images), masks)

        # Loss scaling via GradScaler for the autocast (fp16) forward pass.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        loss_sum += batch_loss * images.size(0)
        if step % 10 == 0:
            print(f"[Train] Batch {step}, Loss: {batch_loss:.4f}")

    average_loss = loss_sum / len(loader.dataset)
    print(f"--- [Train] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")
    return average_loss

def validate_one_epoch(model, loader, criterion, device, epoch_idx, metrics=None):
    """
    Evaluate the model on ``loader`` under autocast and without gradients.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): Yields (images, masks) validation batches.
        criterion (Loss): Loss function called as ``criterion(outputs, masks)``.
        device (torch.device): Device the batches are moved to.
        epoch_idx (int): Zero-based epoch index (printed one-based).
        metrics (dict, optional): Name -> TorchMetrics object. Each is fed
            argmax predictions per batch, then computed, printed and reset.

    Returns:
        float: Sample-weighted mean validation loss (metric values are only printed).
    """
    model.eval()
    loss_sum = 0.0
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")
    with torch.no_grad():
        for step, (images, masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, masks)
            batch_loss = loss.item()
            loss_sum += batch_loss * images.size(0)
            if step % 10 == 0:
                print(f"[Val] Batch {step}, Loss: {batch_loss:.4f}")

            if metrics:
                preds = outputs.argmax(dim=1)
                for metric in metrics.values():
                    metric(preds, masks)

    average_loss = loss_sum / len(loader.dataset)
    print(f"--- [Val] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")

    if metrics:
        metric_values = {name: m.compute().item() for name, m in metrics.items()}
        for m in metrics.values():
            m.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Val] Epoch {epoch_idx+1} Metrics: {metric_str} ---")

    return average_loss

# ------------------------------
# 6. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Computes the Dice coefficient for binary masks.

    Args:
        pred (torch.Tensor): Predicted mask (any shape, values 0/1).
        target (torch.Tensor): Ground truth mask with the same number of elements.
        smooth (float): Smoothing factor to avoid division by zero.

    Returns:
        float: Dice coefficient, ``(2*|P∩T| + smooth) / (|P| + |T| + smooth)``.
    """
    # reshape (not view): works on non-contiguous inputs such as transposed
    # or sliced tensors, where view(-1) raises.
    pred_flat = pred.reshape(-1).float()
    target_flat = target.reshape(-1).float()

    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice.item()

def test_segmentation(model, loader, device, metrics=None):
    """
    Evaluates the model on the test set and computes the average Dice coefficient and other metrics.

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): Yields (images, masks) test batches.
        device (torch.device): Device to run on.
        metrics (dict, optional): Name -> TorchMetrics object, fed argmax
            predictions per batch, then computed and reset.

    Returns:
        dict: Metric name -> float. With ``metrics`` these are the TorchMetrics
        values; otherwise ``{'dice': mean per-sample Dice over every batch}``
        (0 when the loader yields nothing).
    """
    model.eval()
    dice_scores = []  # per-sample Dice; only filled when no metrics are supplied
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing", leave=False):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)  # [B, H, W]

            with torch.cuda.amp.autocast():
                outputs = model(images)    # [B, n_classes, H, W]
                preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            if metrics:
                for metric in metrics.values():
                    metric(preds, masks)
            else:
                # Score every sample of every batch, not just the final batch.
                dice_scores.extend(dice_coefficient(p, m) for p, m in zip(preds, masks))

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Test] Metrics: {metric_str} ---")
    else:
        metric_values = {
            'dice': np.mean(dice_scores) if dice_scores else 0
        }
        print(f"--- [Test] Dice Coefficient: {metric_values['dice']:.4f} ---")

    return metric_values

# ------------------------------
# 7. Visualization of Predictions
# ------------------------------

def visualize_predictions(model, dataset, device, num_samples=3, post_process_flag=False):
    """
    Visualizes predictions alongside the original images and ground truth masks.

    Picks up to ``num_samples`` distinct random indices from ``dataset`` and,
    for each, plots the image, its ground-truth mask and the argmax prediction.

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): Returns (image tensor [C, H, W], mask tensor [H, W]).
        device (torch.device): Device to run on.
        num_samples (int): Number of samples to visualize; capped at ``len(dataset)``.
        post_process_flag (bool): Currently unused; no post-processing is implemented.
    """
    model.eval()
    # np.random.choice(..., replace=False) raises when asked for more samples than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("[Warning] No samples available for visualization.")
        return
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # NOTE(review): assumes the dataset normalized images with this mean/std;
    # if its transform does not normalize, the displayed colors will be off.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device, non_blocking=True)  # [1, C, H, W]

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                output = model(image_batch)  # [1, n_classes, H, W]
                pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        image_np = image_np * std + mean  # undo mean/std normalization
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis("off")

        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis("off")

        axs[2].imshow(pred, cmap='gray')
        axs[2].set_title("Predicted Mask")
        axs[2].axis("off")

        plt.tight_layout()
        plt.show()
        # Release the figure so many samples don't accumulate open figures.
        plt.close(fig)

# ------------------------------
# 8. Main Function
# ------------------------------

def _collect_paired_files(images_dir, masks_dir, split_name):
    """Return sorted image filenames in ``images_dir`` that have a matching ``<stem>.npy`` mask.

    Prints progress and diagnostics; returns an empty list when no usable pairs exist.
    """
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    print(f"[Main] Found {len(image_files)} {split_name} image files in {images_dir}")
    if not image_files:
        print(f"[Error] No {split_name} image files found. Check your {split_name} path!")
        return []

    paired = [
        f for f in image_files
        if os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ".npy"))
    ]
    print(f"[Main] {len(paired)} {split_name} images have corresponding masks.")
    if not paired:
        print(f"[Error] No {split_name} mask files found or mismatched filenames. Check your {split_name} mask path!")
    return paired


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Build a DataLoader; pinned memory and persistent workers only when CUDA is available."""
    use_cuda = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=use_cuda,
        # persistent_workers is only valid with at least one worker process.
        persistent_workers=use_cuda and num_workers > 0,
    )


def main():
    """
    Train, validate and test the ResNet-U-Net + Transformer segmentation model.

    Splits the training folder into train/validation, trains with AdamW, a
    cosine warm-restart LR schedule and early stopping, keeps the checkpoint
    with the lowest validation loss, then evaluates and visualizes it on the
    separate test folder.
    """
    # Imported here so main() works even if no earlier cell imported them.
    import torchmetrics
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

    # ------------------------------
    # Define Your Directories Here
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 8  # Adjust based on GPU memory
    num_epochs    = 30
    learning_rate = 1e-4
    val_split     = 0.2
    save_path     = "best_model_resnet101.pth"
    patience      = 10
    post_process_flag = False  # Set to True to apply post-processing if implemented

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Main] Using device: {device}")

    # ------------------------------
    # Collect File Lists (images that have a matching .npy mask)
    # ------------------------------
    all_train_images = _collect_paired_files(images_dir, masks_dir, "training")
    if not all_train_images:
        return
    all_test_images = _collect_paired_files(test_images_dir, test_masks_dir, "test")
    if not all_test_images:
        return

    # ------------------------------
    # Train/Validation Split
    # ------------------------------
    train_files, val_files = train_test_split(
        all_train_images,
        test_size=val_split,
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Create Datasets with Transforms
    # ------------------------------
    train_transform = ResizeAndToTensor(size=(512, 512))
    val_transform = ResizeAndToTensor(size=(512, 512))
    test_transform = ResizeAndToTensor(size=(512, 512))

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=train_transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=val_transform
    )
    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=test_transform
    )

    # ------------------------------
    # Create DataLoaders
    # ------------------------------
    num_workers = 8  # Adjust based on CPU cores
    train_loader = _make_loader(train_dataset, batch_size, shuffle=True, num_workers=num_workers)
    val_loader = _make_loader(val_dataset, batch_size, shuffle=False, num_workers=num_workers)
    # Evaluation does not need shuffling; keep the order deterministic.
    test_loader = _make_loader(test_dataset, batch_size, shuffle=False, num_workers=num_workers)

    # ------------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = ResNetUNetWithTransformer(
        n_classes=2,
        encoder_name='resnet34',
        pretrained=True,
        transformer_layers=16,
        transformer_heads=16,
        dropout_p=0.5
    ).to(device)

    criterion = DiceLoss()

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    def _make_metrics():
        """Fresh set of validation/test metrics on ``device``."""
        return {
            'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
            'iou': torchmetrics.JaccardIndex(task='binary').to(device),
            'accuracy': torchmetrics.Accuracy(task='binary').to(device)
        }

    metrics_val = _make_metrics()
    metrics_test = _make_metrics()

    # Gradient scaler for mixed precision
    scaler = torch.cuda.amp.GradScaler()

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    checkpoint_saved = False
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, scaler)
        train_losses.append(train_loss)

        val_loss = validate_one_epoch(model, val_loader, criterion, device, epoch, metrics_val)
        val_losses.append(val_loss)

        # One scheduler step per epoch; T_0/T_mult are expressed in epochs.
        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {optimizer.param_groups[0]['lr']:.6f} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

        # Feed every epoch so the stopper tracks the true best and counts every
        # non-improving epoch.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("[Main] Early stopping triggered. Stopping training.")
            break

    print(f"\n[Main] Training complete. Best Validation Loss: {best_val_loss:.4f}")
    print(f"[Main] Best model saved at: {save_path}")

    # Plot Training and Validation Loss
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10,5))
    plt.plot(epochs, train_losses, 'b', label='Training loss')
    plt.plot(epochs, val_losses, 'r', label='Validation loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # ------------------------------
    # Testing
    # ------------------------------
    print("\n>>> Loading best model for testing...")
    if checkpoint_saved:
        # map_location keeps loading working when the current device differs
        # from the one the checkpoint was saved on.
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        # Never saved this run (e.g. every val loss was NaN): don't pick up a
        # stale file from an earlier run.
        print(f"[Warning] No checkpoint saved this run; testing the final in-memory weights instead of {save_path}.")
    model.eval()

    test_metrics = test_segmentation(model, test_loader, device, metrics=metrics_test)
    print(f"Test Metrics: {test_metrics}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(
            model, test_dataset, device,
            num_samples=min(20, len(test_dataset)),
            post_process_flag=post_process_flag,
        )
    else:
        print("[Warning] No test samples available for visualization.")

# ------------------------------
# 9. Run the Main Function
# ------------------------------

if __name__ == "__main__":
    # Entry point: run the full train/validate/test pipeline when this code is
    # executed directly rather than imported.
    main()


In [None]:
import os
import numpy as np
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
import torchvision.models as models  # For ResNet encoder

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchmetrics
import random

# ------------------------------
# 0. Reproducibility (Optional)
# ------------------------------

def set_seed(seed=42):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), and
    switch cuDNN to deterministic, non-benchmarking mode for reproducible runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels; disabling autotuning avoids run-to-run algorithm changes.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# ------------------------------
# 1. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Joint image/mask transform.

    Resizes both to ``size`` (bilinear for the image, nearest for the mask).
    When ``augment`` is set, it randomly jitters the image's brightness and then
    its contrast by a factor in [0.8, 1.2]. Finally it returns a mean/std-normalized
    float image tensor and an int64 mask tensor.
    """
    def __init__(self, size=(512, 512), augment=False):
        self.size = size  # (height, width)
        self.augment = augment
        # The usual ImageNet channel statistics.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.array): Corresponding label mask.

        Returns:
            tuple: (image tensor [C, H, W] float32, normalized;
                    mask tensor [H, W] int64).
        """
        image = resize(image, self.size, interpolation=InterpolationMode.BILINEAR)
        # Nearest-neighbour interpolation keeps every mask value a valid label.
        resized_mask = resize(Image.fromarray(mask), self.size,
                              interpolation=InterpolationMode.NEAREST)
        mask = np.array(resized_mask, dtype=np.uint8)

        if self.augment:
            # Brightness first, then contrast, each with its own random factor.
            for enhancer_cls in (ImageEnhance.Brightness, ImageEnhance.Contrast):
                image = enhancer_cls(image).enhance(random.uniform(0.8, 1.2))

        image_tensor = self.normalize(to_tensor(image))  # [C, H, W], float32
        mask_tensor = torch.from_numpy(mask).long()      # [H, W], int64
        return image_tensor, mask_tensor

class BinarySegDataset(Dataset):
    """
    Image/mask pairs for binary segmentation.

    Each entry of ``file_list`` is an image filename in ``images_dir``; its mask
    is ``<same stem>.npy`` in ``masks_dir``. Every mask value above 1 is
    collapsed to 1 before the optional joint ``transform(image, mask)`` runs.
    """
    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Directory with input images.
            masks_dir (str): Directory with corresponding mask files (.npy).
            file_list (list): List of image filenames.
            transform (callable, optional): Called as ``transform(image, mask)``.
        """
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.file_list = file_list
        self.transform = transform
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load, binarize and (optionally) transform the pair at ``idx``.

        Raises:
            FileNotFoundError: If the image or its mask file is missing.
        """
        filename = self.file_list[idx]
        stem = os.path.splitext(filename)[0]

        img_path = os.path.join(self.images_dir, filename)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"[Error] Image file not found: {img_path}")
        # Context manager closes the underlying file handle once converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        mask_path = os.path.join(self.masks_dir, stem + ".npy")
        if not os.path.isfile(mask_path):
            raise FileNotFoundError(f"[Error] Mask file not found: {mask_path}")
        # Clamp labels above 1 down to 1 so the mask is strictly binary.
        mask = np.minimum(np.load(mask_path), 1).astype(np.uint8)

        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask

# ------------------------------
# 2. Attention Gates Implementation
# ------------------------------

class AttentionGate(nn.Module):
    """
    Additive attention gate as described in "Attention U-Net: Learning Where
    to Look for the Pancreas".

    ``g`` (gating signal, F_g channels) and ``x`` (skip features, F_l channels)
    are each projected to F_int channels by 1x1 conv + BatchNorm and summed,
    passed through ReLU, then reduced to a one-channel sigmoid map that rescales
    ``x``. The two projections are added, so their spatial shapes must be
    broadcast-compatible (normally identical).
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_ch, out_ch):
            # 1x1 convolution followed by batch norm.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal (decoder side); x: skip connection (encoder side)
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention

# ------------------------------
# 3. Transformer Module Implementation
# ------------------------------

class PositionalEncoding(nn.Module):
    """
    Learned absolute positional embedding.

    Holds a trainable table of shape [1, max_len, d_model], initialised from a
    truncated normal with std 0.02. ``forward`` adds its first N rows to an
    input of shape [B, N, d_model]. The table always has max_len * d_model
    parameters, so the default max_len (512*512) is very large even when N is small.
    """
    def __init__(self, d_model, max_len=512*512):
        super().__init__()
        table = torch.zeros(1, max_len, d_model)
        self.position_embeddings = nn.Parameter(table)
        nn.init.trunc_normal_(self.position_embeddings, std=0.02)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.position_embeddings[:, :seq_len]

class SwinTransformerBlock(nn.Module):
    """
    Pre-norm transformer encoder block: global multi-head self-attention followed
    by an MLP, each with a residual connection.

    Despite the name this is NOT a Swin block: there is no windowing or shifting.
    ``window_size`` and ``shift_size`` are accepted only for interface
    compatibility and are ignored. Every token attends to every other token of
    the same sample. For a real Swin block, use a library such as timm.

    Input/output shape: [B, N, D] (batch first).
    """
    def __init__(self, embed_dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., dropout=0.):
        super(SwinTransformerBlock, self).__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True because inputs are [B, N, D]. With the default
        # (batch_first=False) the batch axis is treated as the sequence, so
        # attention would mix different samples instead of spatial tokens.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # x: [B, N, D]
        normed = self.layer_norm1(x)
        # need_weights=False: the attention map is discarded, so don't compute it.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.layer_norm2(x))
        return x

class TransformerBottleneck(nn.Module):
    """
    Apply a stack of Transformer blocks over the spatial positions of a feature map.

    The [B, C, H, W] map is flattened to H*W tokens of width C (``embed_dim``),
    given learned position embeddings, passed through ``num_layers`` blocks and a
    final LayerNorm, then folded back to [B, C, H, W].
    """
    def __init__(self, embed_dim=2048, num_heads=8, num_layers=6, dropout=0.1, max_len=512*512):
        """
        Args:
            embed_dim (int): Channel count C of the incoming feature map.
            num_heads (int): Attention heads per block.
            num_layers (int): Number of Transformer blocks.
            dropout (float): Dropout inside each block.
            max_len (int): Size of the learned position-embedding table, i.e. the
                largest H*W supported. The default keeps the original 512*512 table
                (so existing checkpoints still load), but it costs
                ``max_len * embed_dim`` parameters; a 16x16 bottleneck only needs 256.
        """
        super(TransformerBottleneck, self).__init__()
        self.embed_dim = embed_dim
        self.positional_encoding = PositionalEncoding(d_model=embed_dim, max_len=max_len)
        self.transformer_blocks = nn.ModuleList([
            SwinTransformerBlock(embed_dim, num_heads, dropout=dropout) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Feature map of shape [B, C, H, W] with C == embed_dim.

        Returns:
            torch.Tensor: Refined feature map of shape [B, C, H, W].
        """
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)          # [B, H*W, C]
        tokens = self.positional_encoding(tokens)
        for block in self.transformer_blocks:
            tokens = block(tokens)                     # [B, H*W, C]
        tokens = self.layer_norm(tokens)
        # reshape (not view) so non-contiguous layouts are handled too.
        return tokens.transpose(1, 2).reshape(B, C, H, W)

# ------------------------------
# 4. U-Net Model Definition with Transformer and Attention Gates
# ------------------------------

class DoubleConv(nn.Module):
    """
    Two (3x3 Conv -> BatchNorm -> ReLU -> Dropout) stages plus a residual path.

    The residual path is the identity when the channel count is unchanged, and a
    3x3 Conv + BatchNorm projection otherwise.
    """
    def __init__(self, in_channels, out_channels, dropout_p=0.5):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_p),
            ]
        self.conv = nn.Sequential(*stages)

        # Project the input only when its channel count differs from the output.
        self.residual = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
            )
        )

    def forward(self, x):
        main_path = self.conv(x)
        return main_path + self.residual(x)

class DiceLoss(nn.Module):
    """
    Soft Dice loss averaged over all classes (background included).

    Softmax probabilities are compared with one-hot targets; overlap and total mass
    are pooled over the batch and spatial axes per class, turned into a smoothed
    Dice score per class, averaged, and subtracted from 1.
    """
    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, true):
        """
        Args:
            logits (torch.Tensor): Raw class scores, shape [B, C, H, W].
            true (torch.Tensor): Integer labels in [0, C), shape [B, H, W].

        Returns:
            torch.Tensor: Scalar ``1 - mean_c(dice_c)``.
        """
        n_classes = logits.shape[1]
        probs = logits.softmax(dim=1)
        target = nn.functional.one_hot(true, num_classes=n_classes).permute(0, 3, 1, 2).float()

        pooled_axes = (0, 2, 3)
        overlap = (probs * target).sum(dim=pooled_axes)
        total = (probs + target).sum(dim=pooled_axes)
        per_class = (2. * overlap + self.smooth) / (total + self.smooth)
        return 1 - per_class.mean()

class CombinedLoss(nn.Module):
    """
    Unweighted sum of cross-entropy and soft Dice loss on the same logits.
    """
    def __init__(self, weight=None, smooth=1):
        """
        Args:
            weight (torch.Tensor, optional): Per-class weights for ``nn.CrossEntropyLoss``.
            smooth (float): Smoothing constant forwarded to ``DiceLoss``.
        """
        super().__init__()
        self.dice = DiceLoss(smooth)
        self.ce = nn.CrossEntropyLoss(weight=weight)

    def forward(self, logits, true):
        ce_term = self.ce(logits, true)
        dice_term = self.dice(logits, true)
        return ce_term + dice_term

class ResNetUNetWithTransformerAttention(nn.Module):
    """
    U-Net style segmentation network.

    A torchvision ResNet provides the encoder, a Transformer stack refines the
    deepest feature map, and each decoder stage fuses the upsampled features with an
    attention-gated encoder skip before a residual DoubleConv.
    """
    def __init__(self, n_classes=2, encoder_name='resnet101', pretrained=True, transformer_layers=6, transformer_heads=8, dropout_p=0.5):
        """
        Args:
            n_classes (int): Number of output classes.
            encoder_name (str): ResNet variant: 'resnet18', 'resnet34', 'resnet50',
                'resnet101' or 'resnet152'.
            pretrained (bool): Whether to load pretrained ResNet weights.
            transformer_layers (int): Number of Transformer blocks in the bottleneck.
            transformer_heads (int): Attention heads per Transformer block.
            dropout_p (float): Dropout probability inside the decoder DoubleConvs.

        Raises:
            ValueError: If ``encoder_name`` is not one of the supported variants.
        """
        super(ResNetUNetWithTransformerAttention, self).__init__()

        # Output widths of (stem, layer1, layer2, layer3, layer4) for each backbone.
        basic_widths = [64, 64, 128, 256, 512]
        bottleneck_widths = [64, 256, 512, 1024, 2048]
        widths_by_encoder = {
            'resnet18': basic_widths,
            'resnet34': basic_widths,
            'resnet50': bottleneck_widths,
            'resnet101': bottleneck_widths,
            'resnet152': bottleneck_widths,
        }
        if encoder_name not in widths_by_encoder:
            raise ValueError("Unsupported ResNet variant")
        # The constructor is looked up by name, e.g. models.resnet101(pretrained=...).
        self.encoder = getattr(models, encoder_name)(pretrained=pretrained)
        c0, c1, c2, c3, c4 = widths_by_encoder[encoder_name]

        # Encoder stages (strides relative to the input; channels c0..c4).
        backbone = self.encoder
        self.initial = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
        )                                           # stride 4,  c0 channels
        self.encoder_layer1 = backbone.layer1       # stride 4,  c1 channels
        self.encoder_layer2 = backbone.layer2       # stride 8,  c2 channels
        self.encoder_layer3 = backbone.layer3       # stride 16, c3 channels
        self.encoder_layer4 = backbone.layer4       # stride 32, c4 channels

        self.transformer = TransformerBottleneck(
            embed_dim=c4,
            num_heads=transformer_heads,
            num_layers=transformer_layers,
            dropout=0.1
        )

        # Decoder: upsample -> gate the matching skip -> concat -> DoubleConv.
        self.up4 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=c3, F_l=c3, F_int=c2)
        self.conv4 = DoubleConv(2 * c3, c3, dropout_p=dropout_p)

        self.up3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=c2, F_l=c2, F_int=c1)
        self.conv3 = DoubleConv(2 * c2, c2, dropout_p=dropout_p)

        self.up2 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=c1, F_l=c1, F_int=c0)
        self.conv2 = DoubleConv(2 * c1, c1, dropout_p=dropout_p)

        # Stage 1 stays at stride 4 (layer1 does not downsample), so this
        # 1x1 "transpose" conv only projects c1 -> c0 channels.
        self.up1 = nn.ConvTranspose2d(c1, c0, kernel_size=1, stride=1)
        self.att1 = AttentionGate(F_g=c0, F_l=c0, F_int=c0 // 2)
        self.conv1 = DoubleConv(2 * c0, c0, dropout_p=dropout_p)

        self.out_conv = nn.Conv2d(c0, n_classes, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input images [B, C, H, W]; H and W presumably need to
                be divisible by 32 so decoder and skip shapes line up — TODO confirm.

        Returns:
            torch.Tensor: Per-pixel logits [B, n_classes, H, W].
        """
        skip0 = self.initial(x)
        skip1 = self.encoder_layer1(skip0)
        skip2 = self.encoder_layer2(skip1)
        skip3 = self.encoder_layer3(skip2)
        deepest = self.encoder_layer4(skip3)

        feats = self.transformer(deepest)

        decoder_stages = (
            (self.up4, self.att4, self.conv4, skip3),
            (self.up3, self.att3, self.conv3, skip2),
            (self.up2, self.att2, self.conv2, skip1),
            (self.up1, self.att1, self.conv1, skip0),
        )
        for upsample, gate, fuse, skip in decoder_stages:
            upsampled = upsample(feats)
            gated_skip = gate(g=upsampled, x=skip)
            feats = fuse(torch.cat([upsampled, gated_skip], dim=1))

        logits = self.out_conv(feats)  # still at stride 4
        return nn.functional.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=True)

# ------------------------------
# 5. Early Stopping Implementation
# ------------------------------

class EarlyStopping:
    """
    Flag a stop once the validation loss has failed to improve for ``patience``
    consecutive calls.

    The first call only records a baseline. After that, a value counts as
    non-improving when ``val_loss > best_loss - delta``; otherwise it becomes the
    new best and the counter resets.
    """
    def __init__(self, patience=7, verbose=False, delta=0.0):
        """
        Args:
            patience (int): Consecutive non-improving calls tolerated before stopping.
            verbose (bool): Print a message on every call.
            delta (float): Minimum decrease required to count as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # First observation: establish the baseline and nothing else.
        if self.best_loss is None:
            self.best_loss = val_loss
            if self.verbose:
                print(f"[EarlyStopping] Initial best loss: {self.best_loss:.4f}")
            return

        stalled = val_loss > self.best_loss - self.delta
        if not stalled:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Validation loss improved to: {self.best_loss:.4f}")
            return

        self.counter += 1
        if self.verbose:
            print(f"[EarlyStopping] No improvement. Counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print("[EarlyStopping] Early stopping triggered.")

# ------------------------------
# 6. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx, scaler, clip_grad_norm=1.0):
    """
    Train the model for one epoch, using mixed precision when running on CUDA.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer.
        criterion (Loss): Loss function taking (logits, masks).
        device (torch.device or str): Device to run on.
        epoch_idx (int): Zero-based epoch index (used for logging only).
        scaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed precision.
        clip_grad_norm (float or None): Max gradient norm; ``None`` disables clipping.

    Returns:
        float: Average per-sample training loss for the epoch.
    """
    # torch.cuda.amp.autocast() is deprecated; bind torch.autocast to the real
    # device and only enable it on CUDA (matching the old behaviour on CPU).
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.train()
    total_loss = 0.0
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")
    for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Training", leave=False)):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=amp_device, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()

        if clip_grad_norm is not None:
            # Gradients must be unscaled before their norm is meaningful.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)

        scaler.step(optimizer)
        scaler.update()

        # One host sync per batch instead of two.
        batch_loss = loss.item()
        total_loss += batch_loss * images.size(0)
        if batch_idx % 10 == 0:
            print(f"[Train] Batch {batch_idx}, Loss: {batch_loss:.4f}")
    average_loss = total_loss / len(loader.dataset)
    print(f"--- [Train] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")
    return average_loss

def validate_one_epoch(model, loader, criterion, device, epoch_idx, metrics=None):
    """
    Evaluate the model for one epoch, using mixed precision when running on CUDA.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the validation set.
        criterion (Loss): Loss function taking (logits, masks).
        device (torch.device or str): Device to run on.
        epoch_idx (int): Zero-based epoch index (used for logging only).
        metrics (dict, optional): Name -> TorchMetrics metric; fed argmax
            predictions, printed, then reset at the end of the epoch.

    Returns:
        float: Average per-sample validation loss for the epoch.
    """
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.eval()
    total_loss = 0.0
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            with torch.autocast(device_type=amp_device, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, masks)
            batch_loss = loss.item()
            total_loss += batch_loss * images.size(0)
            if batch_idx % 10 == 0:
                print(f"[Val] Batch {batch_idx}, Loss: {batch_loss:.4f}")

            if metrics:
                preds = torch.argmax(outputs, dim=1)
                for metric in metrics.values():
                    # update() only accumulates state; calling the metric would
                    # also compute a throwaway per-batch value.
                    metric.update(preds, masks)

    average_loss = total_loss / len(loader.dataset)
    print(f"--- [Val] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Val] Epoch {epoch_idx+1} Metrics: {metric_str} ---")

    return average_loss

# ------------------------------
# 7. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Compute the Dice coefficient for binary (0/1) masks.

    Args:
        pred (torch.Tensor): Predicted mask, any shape (e.g. [H, W]).
        target (torch.Tensor): Ground-truth mask with the same number of elements.
        smooth (float): Smoothing term that keeps the ratio defined for empty masks.

    Returns:
        float: ``(2*|pred ∩ target| + smooth) / (|pred| + |target| + smooth)``.
    """
    # reshape (not view) so non-contiguous inputs such as transposed or sliced
    # tensors are accepted as well.
    pred_flat = pred.reshape(-1).float()
    target_flat = target.reshape(-1).float()

    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice.item()

def test_segmentation(model, loader, device, metrics=None):
    """
    Evaluate the model on the test set.

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): DataLoader for the test set.
        device (torch.device or str): Device to run on.
        metrics (dict, optional): Name -> TorchMetrics metric. When omitted, the
            mean per-sample Dice coefficient over the whole test set is computed.

    Returns:
        dict: Metric name -> value (``{'dice': ...}`` when ``metrics`` is omitted;
        ``dice`` is 0 for an empty loader).
    """
    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.eval()
    # Per-sample Dice scores, accumulated over *every* batch (previously only the
    # last batch was scored, and an empty loader raised UnboundLocalError).
    dice_scores = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing", leave=False):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)  # [B, H, W]

            with torch.autocast(device_type=amp_device, enabled=use_amp):
                outputs = model(images)    # [B, n_classes, H, W]
            preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            if metrics:
                for metric in metrics.values():
                    metric.update(preds, masks)
            else:
                dice_scores.extend(dice_coefficient(p, m) for p, m in zip(preds, masks))

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Test] Metrics: {metric_str} ---")
    else:
        metric_values = {
            'dice': np.mean(dice_scores) if dice_scores else 0
        }
        print(f"--- [Test] Dice Coefficient: {metric_values['dice']:.4f} ---")

    return metric_values

# ------------------------------
# 8. Visualization of Predictions
# ------------------------------

def visualize_predictions(model, dataset, device, num_samples=3, post_process_flag=False):
    """
    Plot randomly chosen samples as (image, ground-truth mask, predicted mask).

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): Dataset yielding (image_tensor [C, H, W], mask_tensor [H, W]).
        device (torch.device or str): Device to run on.
        num_samples (int): Number of distinct samples to show; clamped to
            ``len(dataset)`` (sampling is without replacement).
        post_process_flag (bool): Reserved for post-processing (not implemented).
    """
    model.eval()
    # np.random.choice(replace=False) raises if asked for more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    use_amp = torch.device(device).type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device, non_blocking=True)  # [1, C, H, W]

        with torch.no_grad(), torch.autocast(device_type=amp_device, enabled=use_amp):
            output = model(image_batch)  # [1, n_classes, H, W]
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        if post_process_flag:
            # Placeholder: post-processing (e.g. morphological cleanup) is not implemented.
            pass

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        # Undo ImageNet mean/std normalization for display. NOTE(review): assumes the
        # dataset transform normalized with these statistics — confirm for your transform.
        image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (image_np, None, "Original Image"),
            (mask_np, 'gray', "Ground Truth Mask"),
            (pred, 'gray', "Predicted Mask"),
        )
        for ax, (data, cmap, title) in zip(axs, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        plt.tight_layout()
        plt.show()

# ------------------------------
# 9. Main Function
# ------------------------------

def _list_paired_images(images_dir, masks_dir):
    """
    List candidate images in ``images_dir`` and the subset that has a mask.

    Args:
        images_dir (str): Directory scanned for .png/.jpg/.jpeg files
            (case-insensitive extension match).
        masks_dir (str): Directory expected to contain ``<image stem>.npy`` masks.

    Returns:
        tuple[list[str], list[str]]: ``(image_files, paired_files)``, both sorted;
        ``paired_files`` keeps only images whose .npy mask exists.
    """
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    paired_files = [
        f for f in image_files
        if os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ".npy"))
    ]
    return image_files, paired_files


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Build a DataLoader; pin memory / keep workers alive only when CUDA is available."""
    use_cuda = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=use_cuda,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=use_cuda and num_workers > 0,
    )


def main():
    """
    End-to-end pipeline: split the training folder into train/val, train with
    early stopping (checkpointing the lowest validation loss), plot the loss
    curves, evaluate the best checkpoint on the separate test folder, and
    visualize some test predictions.
    """
    # ------------------------------
    # Define Your Directories Here
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 8  # Adjust based on GPU memory
    num_epochs    = 50
    learning_rate = 1e-4
    val_split     = 0.2
    save_path     = "best_model_resnet101_attention.pth"  # NOTE: the encoder below is resnet34
    patience      = 15
    post_process_flag = False  # Set to True to apply post-processing if implemented

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Main] Using device: {device}")

    # ------------------------------
    # Collect File Lists
    # ------------------------------
    train_candidates, all_train_images = _list_paired_images(images_dir, masks_dir)
    print(f"[Main] Found {len(train_candidates)} training image files in {images_dir}")
    if not train_candidates:
        print("[Error] No training image files found. Check your training path!")
        return
    print(f"[Main] {len(all_train_images)} training images have corresponding masks.")
    if not all_train_images:
        print("[Error] No training mask files found or mismatched filenames. Check your training mask path!")
        return

    test_candidates, all_test_images = _list_paired_images(test_images_dir, test_masks_dir)
    print(f"[Main] Found {len(test_candidates)} test image files in {test_images_dir}")
    if not test_candidates:
        print("[Error] No test image files found. Check your test path!")
        return
    print(f"[Main] {len(all_test_images)} test images have corresponding masks.")
    if not all_test_images:
        print("[Error] No test mask files found or mismatched filenames. Check your test mask path!")
        return

    # ------------------------------
    # Train/Validation Split
    # ------------------------------
    train_files, val_files = train_test_split(
        all_train_images,
        test_size=val_split,
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Create Datasets with Transforms
    # ------------------------------
    train_transform = ResizeAndToTensor(size=(512, 512), augment=True)
    val_transform = ResizeAndToTensor(size=(512, 512), augment=False)
    test_transform = ResizeAndToTensor(size=(512, 512), augment=False)

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=train_transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=val_transform
    )
    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=test_transform
    )

    # ------------------------------
    # Create DataLoaders
    # ------------------------------
    num_workers = 8  # Adjust based on CPU cores
    train_loader = _make_loader(train_dataset, batch_size, shuffle=True, num_workers=num_workers)
    val_loader = _make_loader(val_dataset, batch_size, shuffle=False, num_workers=num_workers)
    # Order doesn't affect test metrics; a fixed order keeps test runs reproducible.
    test_loader = _make_loader(test_dataset, batch_size, shuffle=False, num_workers=num_workers)

    # ------------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = ResNetUNetWithTransformerAttention(
        n_classes=2,
        encoder_name='resnet34',
        pretrained=True,
        transformer_layers=12,
        transformer_heads=16,
        dropout_p=0.1
    ).to(device)

    # Combined loss (Dice + Cross-Entropy), without class weights
    criterion = CombinedLoss(weight=None)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    def make_metrics():
        # For binary masks, F1 on the foreground class *is* the Dice coefficient.
        # torchmetrics.Dice(num_classes=2) with its default micro averaging pools
        # background and foreground, which reduces to plain pixel accuracy.
        return {
            'dice': torchmetrics.F1Score(task='binary').to(device),
            'iou': torchmetrics.JaccardIndex(task='binary').to(device),
            'accuracy': torchmetrics.Accuracy(task='binary').to(device),
        }

    metrics_val = make_metrics()
    metrics_test = make_metrics()

    # Gradient scaler for mixed precision (a no-op unless running on CUDA)
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    saved_checkpoint = False
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, scaler)
        train_losses.append(train_loss)

        val_loss = validate_one_epoch(model, val_loader, criterion, device, epoch, metrics_val)
        val_losses.append(val_loss)

        # Report the LR actually used this epoch, then advance the schedule.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            saved_checkpoint = True
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

        # Feed EVERY epoch to EarlyStopping so its own best/counter track the true
        # best loss. Calling it only on non-improving epochs let a stale, worse
        # "best" reset the patience counter and delay or prevent stopping.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("[Main] Early stopping triggered. Stopping training.")
            break

    print(f"\n[Main] Training complete. Best Validation Loss: {best_val_loss:.4f}")
    print(f"[Main] Best model saved at: {save_path}")

    # Plot Training and Validation Loss
    epochs_range = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs_range, train_losses, 'b', label='Training loss')
    plt.plot(epochs_range, val_losses, 'r', label='Validation loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # ------------------------------
    # Testing
    # ------------------------------
    print("\n>>> Loading best model for testing...")
    if saved_checkpoint:
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        # e.g. every validation loss was NaN: don't load a stale file from another run.
        print("[Warning] No checkpoint was saved during this run; testing the current weights.")
    model.eval()

    test_metrics = test_segmentation(model, test_loader, device, metrics=metrics_test)
    print(f"Test Metrics: {test_metrics}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(
            model, test_dataset, device,
            num_samples=min(20, len(test_dataset)),
            post_process_flag=post_process_flag,
        )
    else:
        print("[Warning] No test samples available for visualization.")

# ------------------------------
# 10. Run the Main Function
# ------------------------------

# Entry point. In a Jupyter/IPython kernel __name__ is normally "__main__" too,
# so executing this cell runs the whole training/testing pipeline.
if __name__ == "__main__":
    main()


In [None]:
import os
import numpy as np
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms.functional import resize, to_tensor
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms

from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchmetrics
import random

# ------------------------------
# 0. Reproducibility (Optional)
# ------------------------------

def set_seed(seed=42):
    """
    Seed Python's ``random``, NumPy, and PyTorch (CPU and every CUDA device),
    and switch cuDNN to deterministic, non-autotuned kernels.

    Args:
        seed (int): Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN algorithms; disabling benchmark stops autotuning from
    # choosing different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once when this cell runs, so later random draws in this process
# (e.g. weight init, the augmentation factors in ResizeAndToTensor) start from a fixed state.
set_seed(42)

# ------------------------------
# 1. Combined Loss Function
# ------------------------------

class DiceLoss(nn.Module):
    """
    Soft Dice loss averaged over every class, background included.

    Per class, the smoothed score ``(2*overlap + smooth) / (total + smooth)`` is
    computed from softmax probabilities and one-hot targets pooled over the batch
    and spatial axes; the loss is ``1 - mean(score)``.
    """
    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, true):
        """
        Args:
            logits (torch.Tensor): Raw class scores, shape [B, C, H, W].
            true (torch.Tensor): Integer labels in [0, C), shape [B, H, W].

        Returns:
            torch.Tensor: Scalar loss.
        """
        probs = logits.softmax(dim=1)
        one_hot = nn.functional.one_hot(true, num_classes=logits.shape[1])
        one_hot = one_hot.permute(0, 3, 1, 2).float()   # -> [B, C, H, W]

        axes = (0, 2, 3)
        overlap = (probs * one_hot).sum(dim=axes)
        total = (probs + one_hot).sum(dim=axes)
        score = (2. * overlap + self.smooth) / (total + self.smooth)
        return 1 - score.mean()

class CombinedLoss(nn.Module):
    """
    Cross-entropy plus soft Dice loss, summed with equal weight.
    """
    def __init__(self, weight=None, smooth=1):
        """
        Args:
            weight (torch.Tensor, optional): Per-class weights for cross-entropy.
            smooth (float): Smoothing constant for the Dice term.
        """
        super().__init__()
        self.dice = DiceLoss(smooth)
        self.ce = nn.CrossEntropyLoss(weight=weight)

    def forward(self, logits, true):
        """Return ``CE(logits, true) + Dice(logits, true)``."""
        cross_entropy = self.ce(logits, true)
        return cross_entropy + self.dice(logits, true)

# ------------------------------
# 2. Dataset and Transforms
# ------------------------------

class ResizeAndToTensor:
    """
    Paired image/mask transform.

    Resizes both to ``size``, optionally jitters the image's brightness and
    contrast (photometric only, so the mask needs no matching change), then
    returns an ImageNet-normalized image tensor and an int64 label tensor.
    """
    def __init__(self, size=(512, 512), augment=False):
        """
        Args:
            size (tuple[int, int]): Target (height, width).
            augment (bool): Apply random brightness/contrast jitter to the image.
        """
        self.size = size  # (height, width)
        self.augment = augment
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __call__(self, image, mask):
        """
        Args:
            image (PIL Image): Input image.
            mask (np.ndarray): Label mask accepted by ``Image.fromarray``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Normalized image [C, H, W] (float32)
            and mask [H, W] (int64).
        """
        # Bilinear for the image; nearest for labels so no new class ids appear.
        image = resize(image, self.size, interpolation=InterpolationMode.BILINEAR)
        resized_mask = resize(Image.fromarray(mask), self.size,
                              interpolation=InterpolationMode.NEAREST)
        mask = np.array(resized_mask, dtype=np.uint8)

        if self.augment:
            # Brightness first, then contrast: each draws one factor in [0.8, 1.2].
            for enhancer_cls in (ImageEnhance.Brightness, ImageEnhance.Contrast):
                image = enhancer_cls(image).enhance(random.uniform(0.8, 1.2))

        image_tensor = self.normalize(to_tensor(image))   # [C, H, W], float32
        mask_tensor = torch.from_numpy(mask).long()        # [H, W], int64
        return image_tensor, mask_tensor

class BinarySegDataset(Dataset):
    """
    Dataset of (RGB image, binary mask) pairs for segmentation.

    For every image ``<stem>.<ext>`` in ``images_dir`` the mask is loaded from
    ``<stem>.npy`` in ``masks_dir``; mask values greater than 1 are collapsed to 1.
    """
    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        """
        Args:
            images_dir (str): Directory containing the input images.
            masks_dir (str): Directory containing the ``.npy`` mask files.
            file_list (list): Image filenames (relative to ``images_dir``) in this split.
            transform (callable, optional): Called as ``transform(image, mask)`` and
                expected to return the transformed ``(image, mask)`` pair.
        """
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.file_list = file_list
        self.transform = transform

        # Debug output: size of this split
        print(f"[Dataset] Initialized with {len(file_list)} files.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Load the image/mask pair at ``idx``.

        Raises:
            FileNotFoundError: If the image or its ``.npy`` mask is missing.
        """
        filename = self.file_list[idx]
        stem, _extension = os.path.splitext(filename)

        image_path = os.path.join(self.images_dir, filename)
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"[Error] Image file not found: {image_path}")
        image = Image.open(image_path).convert("RGB")

        mask_path = os.path.join(self.masks_dir, f"{stem}.npy")
        if not os.path.isfile(mask_path):
            raise FileNotFoundError(f"[Error] Mask file not found: {mask_path}")
        raw_mask = np.load(mask_path)

        # Any label above 1 becomes foreground (1) so the task stays binary.
        mask = np.where(raw_mask > 1, 1, raw_mask).astype(np.uint8)

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

# ------------------------------
# 3. Attention Gates Implementation
# ------------------------------

class AttentionGate(nn.Module):
    """
    Additive attention gate as described in "Attention U-Net: Learning Where to
    Look for the Pancreas".

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels (1x1 conv + BatchNorm), summed, passed through ReLU, and
    reduced to a single-channel map (1x1 conv + BatchNorm + sigmoid) that
    rescales ``x`` element-wise.

    Args:
        F_g (int): Channels of the gating signal.
        F_l (int): Channels of the skip-connection features.
        F_int (int): Channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # 1x1 convs: stride 1 and padding 0 are the Conv2d defaults, bias defaults to True.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        Args:
            g (torch.Tensor): Gating signal (decoder features).
            x (torch.Tensor): Skip-connection features (encoder). The projected
                ``g`` and ``x`` are summed, so their spatial sizes should match.

        Returns:
            torch.Tensor: ``x`` scaled by the attention map; same shape as ``x``.
        """
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention_map = self.psi(combined)
        return x * attention_map

# ------------------------------
# 4. Custom Transformer Bottleneck Implementation
# ------------------------------

class TransformerBottleneck(nn.Module):
    """
    Transformer encoder applied to the bottleneck feature map of the U-Net.

    The [B, C, H, W] feature map is flattened into H*W tokens of size C
    (``C == embed_dim``). A learned positional embedding is added, and the tokens
    pass through ``num_layers`` stacked ``nn.TransformerEncoderLayer`` blocks, so
    self-attention runs across the spatial positions of each sample. The result
    is reshaped back to [B, C, H, W].

    Args:
        embed_dim (int): Channel count of the incoming feature map (token size).
        num_heads (int): Attention heads per layer; must divide ``embed_dim``.
        num_layers (int): Number of stacked encoder layers.
        dropout (float): Dropout probability inside each encoder layer.
        grid_size (tuple[int, int]): (H, W) token grid that the positional
            embedding is learned for. The default (32, 32) corresponds to
            512x512 inputs after four 2x poolings. Inputs with a different grid
            use a bilinearly resized copy of the embedding.
    """
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, dropout=0.1, grid_size=(32, 32)):
        super(TransformerBottleneck, self).__init__()
        self.embed_dim = embed_dim
        self.grid_size = tuple(grid_size)
        grid_h, grid_w = self.grid_size
        self.flatten = nn.Flatten(2)  # [B, C, H, W] -> [B, C, H*W]
        # Learned positional embedding, stored channel-first to match the flattened map.
        self.pos_embedding = nn.Parameter(torch.zeros(1, embed_dim, grid_h * grid_w))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

        # batch_first=True: tokens are passed as [B, seq, embed]. With the default
        # (batch_first=False) PyTorch reads dim 0 as the sequence axis, which would make
        # the layers attend across samples of the batch instead of spatial positions.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        # Retained for backward compatibility; forward() reshapes using the actual input size.
        self.unflatten = nn.Unflatten(2, self.grid_size)

    def _positional_embedding(self, height, width):
        """Return the positional embedding as [1, embed_dim, height * width]."""
        if (height, width) == self.grid_size:
            return self.pos_embedding
        grid_h, grid_w = self.grid_size
        pos = self.pos_embedding.reshape(1, self.embed_dim, grid_h, grid_w)
        pos = nn.functional.interpolate(pos, size=(height, width), mode='bilinear', align_corners=False)
        return pos.flatten(2)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Feature map of shape [B, embed_dim, H, W].

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the channel count of ``x`` differs from ``embed_dim``.
        """
        B, C, H, W = x.size()
        if C != self.embed_dim:
            raise ValueError(f"TransformerBottleneck expected {self.embed_dim} channels, got {C}")
        tokens = self.flatten(x) + self._positional_embedding(H, W)  # [B, C, H*W]
        tokens = tokens.permute(0, 2, 1)                             # [B, H*W, C]
        tokens = self.transformer_encoder(tokens)                    # [B, H*W, C]
        return tokens.permute(0, 2, 1).reshape(B, C, H, W)           # [B, C, H, W]

# ------------------------------
# 5. U-Net++ Model Definition with Transformer and Attention Gates
# ------------------------------

class DoubleConv(nn.Module):
    """
    Two (3x3 conv -> BatchNorm -> ReLU -> Dropout) stages plus a residual path.

    The residual path is the identity when ``in_channels == out_channels``.
    Otherwise it is a 3x3 conv + BatchNorm projection to ``out_channels``.
    All convolutions use padding=1, so the spatial size is preserved.

    Args:
        in_channels (int): Channels of the input.
        out_channels (int): Channels of the output.
        dropout_p (float): Dropout probability after each stage.
    """
    def __init__(self, in_channels, out_channels, dropout_p=0.5):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout_p),
            ])
        self.conv = nn.Sequential(*stages)

        # Shortcut: projection only when the channel counts differ.
        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        main_path = self.conv(x)
        return main_path + self.residual(x)

class UNetPlusPlus(nn.Module):
    """
    Encoder-decoder segmentation network with a Transformer bottleneck,
    attention-gated skip connections, and residual DoubleConv blocks.

    Note: despite the name, each decoder level uses a single attention-gated skip
    connection. The nested dense skip pathways of U-Net++ are not built here.

    Args:
        in_channels (int): Channels of the input image.
        out_channels (int): Number of output channels (classes). The final 1x1
            conv output is returned without any activation.
        dropout_p (float): Dropout probability used in every DoubleConv block.

    Input H and W should be divisible by 16 (four 2x max-pools).
    NOTE(review): TransformerBottleneck's positional embedding is built for a
    32x32 token grid by default, i.e. 512x512 inputs. Confirm before using other
    input sizes.
    """
    def __init__(self, in_channels=3, out_channels=2, dropout_p=0.5):
        super(UNetPlusPlus, self).__init__()

        # Encoder: four residual DoubleConv stages, each followed by 2x max-pooling
        self.enc1 = DoubleConv(in_channels, 64, dropout_p)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc2 = DoubleConv(64, 128, dropout_p)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc3 = DoubleConv(128, 256, dropout_p)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc4 = DoubleConv(256, 512, dropout_p)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Transformer Bottleneck on the 512-channel H/16 x W/16 map
        self.transformer = TransformerBottleneck(
            embed_dim=512,
            num_heads=16,
            num_layers=16,
            dropout=0.1
        )

        # Decoder. At each level a transposed conv upsamples 2x. The attention gate uses
        # the upsampled features (F_g) to gate the same-resolution encoder features (F_l).
        # The gated output keeps F_l channels, and both tensors are concatenated.
        # Level 4
        self.up4_0 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att4_0 = AttentionGate(F_g=256, F_l=512, F_int=128)  # gates enc4 (512 channels)
        self.dec4_0 = DoubleConv(256 + 512, 256, dropout_p)      # up4_0 (256) + gated enc4 (512)

        # Level 3
        self.up3_0 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att3_0 = AttentionGate(F_g=128, F_l=256, F_int=64)   # gates enc3 (256 channels)
        self.dec3_0 = DoubleConv(128 + 256, 128, dropout_p)      # up3_0 (128) + gated enc3 (256)

        # Level 2
        self.up2_0 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att2_0 = AttentionGate(F_g=64, F_l=128, F_int=32)    # gates enc2 (128 channels)
        self.dec2_0 = DoubleConv(64 + 128, 64, dropout_p)        # up2_0 (64) + gated enc2 (128)

        # Final Output: 1x1 conv over [bilinearly upsampled dec2_0 (64), enc1 (64)]
        self.final_conv = nn.Conv2d(64 + 64, out_channels, kernel_size=1)  # Merge up1_0 and enc1

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)          # [B,64,H,W]
        pool1 = self.pool1(enc1)     # [B,64,H/2,W/2]

        enc2 = self.enc2(pool1)      # [B,128,H/2,W/2]
        pool2 = self.pool2(enc2)     # [B,128,H/4,W/4]

        enc3 = self.enc3(pool2)      # [B,256,H/4,W/4]
        pool3 = self.pool3(enc3)     # [B,256,H/8,W/8]

        enc4 = self.enc4(pool3)      # [B,512,H/8,W/8]
        pool4 = self.pool4(enc4)     # [B,512,H/16,W/16]

        # Transformer Bottleneck
        bottleneck = self.transformer(pool4)  # [B,512,H/16,W/16]

        # Decoder (each attention gate returns its encoder input rescaled, so the
        # gated tensor keeps the encoder's channel count)
        up4_0 = self.up4_0(bottleneck)        # [B,256,H/8,W/8]
        att4_0 = self.att4_0(up4_0, enc4)      # [B,512,H/8,W/8]
        merge4_0 = torch.cat([up4_0, att4_0], dim=1)  # [B,256+512=768,H/8,W/8]
        dec4_0 = self.dec4_0(merge4_0)         # [B,256,H/8,W/8]

        up3_0 = self.up3_0(dec4_0)            # [B,128,H/4,W/4]
        att3_0 = self.att3_0(up3_0, enc3)      # [B,256,H/4,W/4]
        merge3_0 = torch.cat([up3_0, att3_0], dim=1)  # [B,128+256=384,H/4,W/4]
        dec3_0 = self.dec3_0(merge3_0)         # [B,128,H/4,W/4]

        up2_0 = self.up2_0(dec3_0)            # [B,64,H/2,W/2]
        att2_0 = self.att2_0(up2_0, enc2)      # [B,128,H/2,W/2]
        merge2_0 = torch.cat([up2_0, att2_0], dim=1)  # [B,64+128=192,H/2,W/2]
        dec2_0 = self.dec2_0(merge2_0)         # [B,64,H/2,W/2]

        # Final upsampling (parameter-free bilinear) back to full resolution
        up1_0 = nn.functional.interpolate(dec2_0, scale_factor=2, mode='bilinear', align_corners=True)  # [B,64,H,W]
        merge1_0 = torch.cat([up1_0, enc1], dim=1)  # [B,64+64=128,H,W] (enc1 skip is not attention-gated)
        final = self.final_conv(merge1_0)            # [B,out_channels,H,W], no activation

        return final

# ------------------------------
# 6. Early Stopping Implementation
# ------------------------------

class EarlyStopping:
    """
    Flag that training should stop once validation loss stops improving.

    A call counts as an improvement unless ``val_loss > best_loss - delta``.
    With ``delta == 0`` a tie therefore counts as an improvement. After
    ``patience`` consecutive non-improving calls, ``early_stop`` becomes True.
    The first call only records the initial reference loss.
    """
    def __init__(self, patience=7, verbose=False, delta=0.0):
        """
        Args:
            patience (int): Consecutive non-improving calls tolerated before stopping.
            verbose (bool): If True, print a message on every call.
            delta (float): Minimum decrease in loss that qualifies as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # First observation: seed the reference value only.
        if self.best_loss is None:
            self.best_loss = val_loss
            if self.verbose:
                print(f"[EarlyStopping] Initial best loss: {self.best_loss:.4f}")
            return

        stalled = val_loss > self.best_loss - self.delta
        if not stalled:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Validation loss improved to: {self.best_loss:.4f}")
            return

        self.counter += 1
        if self.verbose:
            print(f"[EarlyStopping] No improvement. Counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print("[EarlyStopping] Early stopping triggered.")

# ------------------------------
# 7. Training and Validation Functions
# ------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, epoch_idx, scaler, clip_grad_norm=1.0):
    """
    Run one optimisation pass over ``loader`` using autocast mixed precision.

    For every batch: forward pass and loss under ``torch.cuda.amp.autocast``,
    scaled backward, unscale + clip gradients to ``clip_grad_norm``, optimizer
    step through ``scaler``, then ``scaler.update()``.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the training set.
        optimizer (Optimizer): Optimizer.
        criterion (Loss): Loss function called as ``criterion(outputs, masks)``.
        device (torch.device): Device to run on.
        epoch_idx (int): Zero-based epoch index (printed 1-based).
        scaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed precision.
        clip_grad_norm (float): Maximum gradient norm for clipping.

    Returns:
        float: Batch-size-weighted mean loss divided by ``len(loader.dataset)``.
    """
    model.train()
    running_sum = 0.0
    print(f"--- [Train] Starting epoch {epoch_idx+1} ---")
    progress = tqdm(loader, desc="Training", leave=False)
    for step, (batch_images, batch_masks) in enumerate(progress):
        batch_images = batch_images.to(device, non_blocking=True)
        batch_masks = batch_masks.to(device, non_blocking=True)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            logits = model(batch_images)
            loss = criterion(logits, batch_masks)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_sum += batch_loss * batch_images.size(0)
        if step % 10 == 0:
            print(f"[Train] Batch {step}, Loss: {batch_loss:.4f}")

    average_loss = running_sum / len(loader.dataset)
    print(f"--- [Train] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")
    return average_loss

def validate_one_epoch(model, loader, criterion, device, epoch_idx, metrics=None):
    """
    Evaluate the model on ``loader`` without gradients, under autocast.

    If ``metrics`` (a dict of TorchMetrics objects) is given, each metric is
    updated with argmax predictions per batch. The metrics are computed and
    printed at the end of the epoch and then reset.

    Args:
        model (nn.Module): The segmentation model.
        loader (DataLoader): DataLoader for the validation set.
        criterion (Loss): Loss function called as ``criterion(outputs, masks)``.
        device (torch.device): Device to run on.
        epoch_idx (int): Zero-based epoch index (printed 1-based).
        metrics (dict, optional): Dictionary of TorchMetrics to compute.

    Returns:
        float: Batch-size-weighted mean loss divided by ``len(loader.dataset)``.
    """
    model.eval()
    loss_sum = 0.0
    print(f"--- [Val] Starting epoch {epoch_idx+1} ---")
    with torch.no_grad():
        for step, (batch_images, batch_masks) in enumerate(tqdm(loader, desc="Validation", leave=False)):
            batch_images = batch_images.to(device, non_blocking=True)
            batch_masks = batch_masks.to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                logits = model(batch_images)
                loss = criterion(logits, batch_masks)
            batch_loss = loss.item()
            loss_sum += batch_loss * batch_images.size(0)
            if step % 10 == 0:
                print(f"[Val] Batch {step}, Loss: {batch_loss:.4f}")

            if metrics:
                predicted = logits.argmax(dim=1)
                for metric in metrics.values():
                    metric(predicted, batch_masks)

    average_loss = loss_sum / len(loader.dataset)
    print(f"--- [Val] Epoch {epoch_idx+1} Average Loss: {average_loss:.4f} ---")

    if metrics:
        # Compute every metric first, then reset them for the next epoch.
        metric_values = {name: metric.compute().item() for name, metric in metrics.items()}
        for metric in metrics.values():
            metric.reset()
        summary = " | ".join(f"{name}: {value:.4f}" for name, value in metric_values.items())
        print(f"--- [Val] Epoch {epoch_idx+1} Metrics: {summary} ---")

    return average_loss

# ------------------------------
# 8. Testing and Metrics
# ------------------------------

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Compute the Dice coefficient between two binary masks.

        dice = (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)

    Args:
        pred (torch.Tensor): Predicted mask, expected to contain 0/1 values (any shape).
        target (torch.Tensor): Ground-truth mask with the same number of elements.
        smooth (float): Smoothing term; avoids division by zero and makes two
            empty masks score 1.0.

    Returns:
        float: Dice coefficient.

    Raises:
        ValueError: If ``pred`` and ``target`` have different numbers of elements
            (otherwise a length-1 input would silently broadcast).
    """
    if pred.numel() != target.numel():
        raise ValueError(
            f"pred and target must have the same number of elements, got {pred.numel()} and {target.numel()}"
        )
    # reshape (not view) so non-contiguous inputs such as transposed or sliced tensors work.
    pred_flat = pred.reshape(-1).float()
    target_flat = target.reshape(-1).float()

    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return dice.item()

def test_segmentation(model, loader, device, metrics=None):
    """
    Evaluate the model on the test set and report segmentation metrics.

    With ``metrics`` (a dict of TorchMetrics objects), every metric is updated
    with argmax predictions for each batch, then computed and reset. Without
    it, the per-sample Dice coefficient is averaged over all samples of
    ``loader`` (0.0 for an empty loader).

    Args:
        model (nn.Module): The trained segmentation model.
        loader (DataLoader): DataLoader for the test set.
        device (torch.device): Device to run on.
        metrics (dict, optional): Dictionary of TorchMetrics.

    Returns:
        dict: Metric name -> value (``{'dice': ...}`` when ``metrics`` is None).
    """
    model.eval()
    dice_scores = []  # only filled when metrics is None
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Testing", leave=False):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)  # [B, H, W]

            with torch.cuda.amp.autocast():
                outputs = model(images)    # [B, n_classes, H, W]
            preds = torch.argmax(outputs, dim=1)  # [B, H, W]

            if metrics:
                for metric in metrics.values():
                    metric(preds, masks)
            else:
                # Accumulate per-sample Dice for every batch, not only the last one.
                dice_scores.extend(dice_coefficient(p, m) for p, m in zip(preds, masks))

    if metrics:
        metric_values = {k: v.compute().item() for k, v in metrics.items()}
        # Reset so the metric objects can be reused
        for metric in metrics.values():
            metric.reset()
        metric_str = " | ".join(f"{k}: {v:.4f}" for k, v in metric_values.items())
        print(f"--- [Test] Metrics: {metric_str} ---")
    else:
        metric_values = {
            'dice': float(np.mean(dice_scores)) if dice_scores else 0.0
        }
        print(f"--- [Test] Dice Coefficient: {metric_values['dice']:.4f} ---")

    return metric_values

# ------------------------------
# 9. Visualization of Predictions
# ------------------------------

def visualize_predictions(model, dataset, device, num_samples=3, post_process_flag=False):
    """
    Plot randomly chosen samples as (image, ground-truth mask, predicted mask).

    Args:
        model (nn.Module): The trained segmentation model.
        dataset (Dataset): Dataset yielding normalized image tensors [C, H, W]
            and mask tensors [H, W].
        device (torch.device): Device to run on.
        num_samples (int): Number of samples to visualize; clamped to ``len(dataset)``
            because sampling is without replacement.
        post_process_flag (bool): Whether to apply post-processing to predictions
            (currently a no-op placeholder).
    """
    model.eval()
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("[Warning] No samples available for visualization.")
        return
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # Inverse of the mean/std normalization used by the image transform.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    for idx in indices:
        image, mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device, non_blocking=True)  # [1, C, H, W]

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                output = model(image_batch)  # [1, n_classes, H, W]
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # [H, W]

        # Optional Post-Processing
        if post_process_flag:
            # Placeholder: e.g. morphological operations could be applied here.
            pass

        image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
        mask_np = mask.cpu().numpy()  # [H, W]

        # Undo normalization and convert to displayable uint8
        image_np = image_np * channel_std + channel_mean
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (image_np, None, "Original Image"),
            (mask_np, 'gray', "Ground Truth Mask"),
            (pred, 'gray', "Predicted Mask"),
        )
        for ax, (data, cmap, title) in zip(axs, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        plt.tight_layout()
        plt.show()
        # Release the figure; many figures left open accumulate memory.
        plt.close(fig)

# ------------------------------
# 10. Main Function
# ------------------------------

def main():
    """
    End-to-end training and evaluation entry point.

    Steps:
    1. Collect image/mask pairs (masks are ``<stem>.npy``).
    2. Split the training folder into train/val.
    3. Train with AdamW + cosine annealing, checkpointing the weights with the
       lowest validation loss and stopping early when validation loss stalls.
    4. Plot the loss curves.
    5. Reload the best checkpoint to evaluate and visualize on the separate
       test folder.
    """
    # ------------------------------
    # Define Your Directories Here
    # ------------------------------
    # Training and Validation directories (split from the same folder)
    images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_train/images/"
    masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_train/masks/"

    # Testing directories (completely separate)
    test_images_dir = "/home/yanghehao/tracklearning/segmentation/phantom_test/images/"
    test_masks_dir  = "/home/yanghehao/tracklearning/segmentation/phantom_test/masks/"

    # Hyperparameters
    batch_size    = 8  # Adjust based on GPU memory
    num_epochs    = 50
    learning_rate = 1e-4
    val_split     = 0.2
    save_path     = "best_model_unetpp_transformer.pth"
    patience      = 2
    post_process_flag = False  # Set to True to apply post-processing if implemented

    # ------------------------------
    # Device Configuration
    # ------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    print(f"[Main] Using device: {device}")

    # ------------------------------
    # Collect Training File Lists
    # ------------------------------
    all_train_images = sorted([
        f for f in os.listdir(images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ])
    print(f"[Main] Found {len(all_train_images)} training image files in {images_dir}")

    if len(all_train_images) == 0:
        print("[Error] No training image files found. Check your training path!")
        return

    # Keep only images that have a matching <stem>.npy mask
    all_train_images = sorted([
        f for f in all_train_images
        if os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ".npy"))
    ])
    print(f"[Main] {len(all_train_images)} training images have corresponding masks.")

    if len(all_train_images) == 0:
        print("[Error] No training mask files found or mismatched filenames. Check your training mask path!")
        return

    # List all test image files
    all_test_images = sorted([
        f for f in os.listdir(test_images_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ])
    print(f"[Main] Found {len(all_test_images)} test image files in {test_images_dir}")

    if len(all_test_images) == 0:
        print("[Error] No test image files found. Check your test path!")
        return

    # Keep only test images that have a matching <stem>.npy mask
    all_test_images = sorted([
        f for f in all_test_images
        if os.path.isfile(os.path.join(test_masks_dir, os.path.splitext(f)[0] + ".npy"))
    ])
    print(f"[Main] {len(all_test_images)} test images have corresponding masks.")

    if len(all_test_images) == 0:
        print("[Error] No test mask files found or mismatched filenames. Check your test mask path!")
        return

    # ------------------------------
    # Train/Validation Split
    # ------------------------------
    train_files, val_files = train_test_split(
        all_train_images,
        test_size=val_split,
        random_state=42
    )
    print(f"[Main] Training samples: {len(train_files)}")
    print(f"[Main] Validation samples: {len(val_files)}")

    # ------------------------------
    # Create Datasets with Transforms
    # ------------------------------
    train_transform = ResizeAndToTensor(size=(512, 512), augment=True)
    val_transform = ResizeAndToTensor(size=(512, 512), augment=False)
    test_transform = ResizeAndToTensor(size=(512, 512), augment=False)

    train_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=train_files,
        transform=train_transform
    )
    val_dataset = BinarySegDataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        file_list=val_files,
        transform=val_transform
    )
    test_dataset = BinarySegDataset(
        images_dir=test_images_dir,
        masks_dir=test_masks_dir,
        file_list=all_test_images,
        transform=test_transform
    )

    # ------------------------------
    # Create DataLoaders
    # ------------------------------
    num_workers = 8  # Adjust based on CPU cores
    # persistent_workers is only valid with worker processes (num_workers > 0)
    keep_workers = use_cuda and num_workers > 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_cuda,
        persistent_workers=keep_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_cuda,
        persistent_workers=keep_workers
    )
    # Evaluation does not need shuffling; a fixed order keeps test runs reproducible.
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_cuda,
        persistent_workers=keep_workers
    )

    # ------------------------------
    # Initialize Model, Loss, Optimizer, Scheduler
    # ------------------------------
    model = UNetPlusPlus(
        in_channels=3,
        out_channels=2,
        dropout_p=0.1
    ).to(device)

    # Combined Loss (Dice + Cross-Entropy), without class weights
    criterion = CombinedLoss(weight=None)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    # Scheduler: Cosine Annealing, stepped once per epoch
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # Metrics
    metrics_val = {
        'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
        'iou': torchmetrics.JaccardIndex(task='binary').to(device),
        'accuracy': torchmetrics.Accuracy(task='binary').to(device)
    }

    metrics_test = {
        'dice': torchmetrics.Dice(num_classes=2, average='macro').to(device),
        'iou': torchmetrics.JaccardIndex(task='binary').to(device),
        'accuracy': torchmetrics.Accuracy(task='binary').to(device)
    }

    # Gradient Scaler for Mixed Precision
    scaler = torch.cuda.amp.GradScaler()

    # ------------------------------
    # Training Loop
    # ------------------------------
    best_val_loss = float('inf')
    checkpoint_saved = False
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch, scaler)
        train_losses.append(train_loss)

        val_loss = validate_one_epoch(model, val_loader, criterion, device, epoch, metrics_val)
        val_losses.append(val_loss)

        # Record the LR actually used this epoch before the scheduler advances it
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] | LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpoint whenever validation loss reaches a new minimum
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            print(f">> Saved best model with Val Loss: {best_val_loss:.4f}")

        # Feed every epoch's loss to EarlyStopping so its best_loss and counter
        # reflect the full validation history.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("[Main] Early stopping triggered. Stopping training.")
            break

    print(f"\n[Main] Training complete. Best Validation Loss: {best_val_loss:.4f}")
    if checkpoint_saved:
        print(f"[Main] Best model saved at: {save_path}")

    # Plot Training and Validation Loss
    epochs_range = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10,5))
    plt.plot(epochs_range, train_losses, 'b', label='Training loss')
    plt.plot(epochs_range, val_losses, 'r', label='Validation loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # ------------------------------
    # Testing
    # ------------------------------
    if checkpoint_saved:
        print("\n>>> Loading best model for testing...")
        # map_location keeps loading working regardless of the device used for saving
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print("[Warning] No checkpoint was saved (validation loss never improved); testing final weights.")
    model.eval()

    test_metrics = test_segmentation(model, test_loader, device, metrics=metrics_test)
    print(f"Test Metrics: {test_metrics}")

    # ------------------------------
    # Visualization of Test Samples
    # ------------------------------
    if len(test_dataset) > 0:
        print("\n>>> Visualizing predictions on test samples...")
        visualize_predictions(
            model,
            test_dataset,
            device,
            num_samples=min(20, len(test_dataset)),
            post_process_flag=post_process_flag,
        )
    else:
        print("[Warning] No test samples available for visualization.")

# ------------------------------
# 11. Run the Main Function
# ------------------------------

# Run the full pipeline when executed as the main module; importing this file does not start training.
if __name__ == "__main__":
    main()
