In [96]:
# Start with basic imports
import os
import numpy as np
from PIL import Image

# Set environment variable to avoid OpenMP error
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
print(torch.__version__)

# Import torchvision components separately
import torchvision
print(torchvision.__version__)
import torchvision.transforms as transforms

# SAHI imports for version 0.11.20
from sahi.slicing import slice_image
from sahi.prediction import PredictionResult

2.7.0.dev20250306
0.22.0.dev20250306


## Define segmentation dataclass

In [97]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted([f for f in os.listdir(image_dir) if not f.startswith('.')])
        self.mask_files = sorted([f for f in os.listdir(mask_dir) if not f.startswith('.')])

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = self.transform(image)
        mask = self.transform(mask)

        return image, mask

## Define the UNet model

In [98]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        def double_conv(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        self.dconv_down1 = double_conv(in_channels, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        # Modified maxpool with ceil_mode
        self.maxpool = nn.MaxPool2d(2, ceil_mode=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(64 + 128, 64)

        self.conv_last = nn.Conv2d(64, out_channels, 1)

    def crop_tensor(self, target_tensor, tensor_to_crop):
        _, _, H, W = tensor_to_crop.size()
        return target_tensor[:, :, :H, :W]

    def forward(self, x):
        # Encoder path
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        # Bottleneck
        x = self.dconv_down4(x)

        # Decoder path with cropping
        x = self.upsample(x)
        conv3 = self.crop_tensor(conv3, x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample(x)
        conv2 = self.crop_tensor(conv2, x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample(x)
        conv1 = self.crop_tensor(conv1, x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        out = self.conv_last(x)
        return out

## Define class for sahi segmentation

In [125]:
class SAHISegmentationAdapter(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # Get all image files
        self.image_files = sorted([f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.jpeg', '.png', '.tif', '.tiff'))])

        # Get corresponding mask files
        self.mask_files = []
        for img_file in self.image_files:
            # Adjust this logic based on how your mask filenames correspond to image filenames
            mask_base = os.path.splitext(img_file)[0]
            mask_candidates = [
                f"{mask_base}.png",
                f"{mask_base}.jpg",
                f"{mask_base}.tif",
                f"{mask_base}_mask.png",
                f"{mask_base}_mask.jpg",
                # Add other potential mask filename patterns
            ]

            found = False
            for mask_file in mask_candidates:
                if os.path.exists(os.path.join(mask_dir, mask_file)):
                    self.mask_files.append(mask_file)
                    found = True
                    break

            if not found:
                print(f"Warning: No mask found for image {img_file}")
                # You can either skip this image or use a blank mask
                # For now, we'll keep the image and create a warning

        # Verify we have the same number of images and masks
        assert len(self.image_files) == len(self.mask_files), "Number of images and masks don't match"

        # Print dataset info
        print(f"Found {len(self.image_files)} image-mask pairs")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")

        # Load mask
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
        mask = Image.open(mask_path).convert("L")  # Convert to grayscale

        # Convert to tensors
        image_tensor = transforms.ToTensor()(image)
        mask_tensor = transforms.ToTensor()(mask)

        # Check if dimensions match
        if image_tensor.shape[1:] != mask_tensor.shape[1:]:
            print(f"Fixing dimension mismatch in item {idx}: Image {image_tensor.shape} vs Mask {mask_tensor.shape}")

            # Resize mask to match image dimensions
            mask_tensor = F.interpolate(
                mask_tensor.unsqueeze(0),  # Add batch dimension
                size=(image_tensor.shape[1], image_tensor.shape[2]),
                mode='nearest'
            ).squeeze(0)  # Remove batch dimension

        # Apply additional transformations if provided
        if self.transform:
            image_tensor, mask_tensor = self.transform(image_tensor, mask_tensor)

        return image_tensor, mask_tensor


In [100]:
# Define a function to properly convert tensors to numpy arrays for SAHI
def prepare_for_sahi(tensor):
    """
    Convert a PyTorch tensor to a numpy array format that SAHI can process

    Args:
        tensor: PyTorch tensor of shape [C, H, W]

    Returns:
        numpy array of shape [H, W, C] with values in range 0-255 (uint8)
    """
    # Move tensor to CPU if it's on another device
    tensor = tensor.cpu()

    # Convert from [C, H, W] to [H, W, C]
    np_array = tensor.permute(1, 2, 0).numpy()

    # If it's a single channel image, remove the channel dimension or repeat to 3 channels
    if np_array.shape[2] == 1:
        np_array = np.repeat(np_array, 3, axis=2)  # Repeat to 3 channels

    # Ensure values are in range 0-255 for uint8
    if np_array.dtype == np.float32 or np_array.dtype == np.float64:
        if np_array.max() <= 1.0:
            np_array = (np_array * 255).astype(np.uint8)
        else:
            np_array = np_array.astype(np.uint8)

    return np_array


In [101]:
def process_image_with_sahi(model, image, slice_height=512, slice_width=512, overlap_ratio=0.2):
    """
    Process a large image using SAHI slicing technique

    Args:
        model: Your UNET model wrapped in SAHISegmentationAdapter
        image: Input image as numpy array (H, W, C)
        slice_height, slice_width: Size of slices
        overlap_ratio: Overlap between adjacent slices

    Returns:
        Full-sized segmentation mask
    """
    # Get image dimensions
    height, width = image.shape[:2]

    # Create empty mask for the full image
    full_mask = np.zeros((height, width), dtype=np.uint8)

    # Create weight map for blending overlapping regions
    weight_map = np.zeros((height, width), dtype=np.float32)

    # Slice the image
    slices = slice_image(
        image=image,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_ratio,
        overlap_width_ratio=overlap_ratio
    )

    # Process each slice
    for slice_data in slices:
        # Get the slice image
        slice_image_data = slice_data["image"]

        # Get coordinates
        x_min, y_min, x_max, y_max = slice_data["coordinates"]

        # Predict on this slice
        slice_mask = model.predict_single_image(slice_image_data)

        # Create weight matrix for this slice (higher in the center, lower at the edges)
        h, w = y_max - y_min, x_max - x_min
        y, x = np.ogrid[:h, :w]
        y_center, x_center = h / 2, w / 2

        # Weight falls off with distance from center (gaussian-like)
        weight = np.exp(-((x - x_center)**2 / (w/2)**2 + (y - y_center)**2 / (h/2)**2))

        # Add weighted mask to full mask
        full_mask[y_min:y_max, x_min:x_max] += (slice_mask * weight).astype(np.uint8)
        weight_map[y_min:y_max, x_min:x_max] += weight

    # Normalize by weight map to blend overlapping regions
    # Avoid division by zero
    weight_map = np.maximum(weight_map, 1e-10)
    full_mask = (full_mask / weight_map).astype(np.uint8)

    # Threshold to get binary mask if needed
    full_mask = (full_mask > 0.5).astype(np.uint8)

    return full_mask

## Define a collate function with padding

In [143]:
def collate_fn(batch):
    images = [item[0] for item in batch]
    masks = [item[1] for item in batch]

    # First, ensure masks match their corresponding images in dimensions
    aligned_masks = []
    for i, (img, msk) in enumerate(zip(images, masks)):
        # Check if dimensions don't match
        if img.shape[1:] != msk.shape[1:]:
            print(f"Item {i}: Fixing dimension mismatch: Image {img.shape} vs Mask {msk.shape}")
            # Resize mask to match image dimensions
            c, h, w = msk.shape
            img_h, img_w = img.shape[1:]

            # Use interpolate to resize
            resized_mask = F.interpolate(
                msk.unsqueeze(0),  # Add batch dimension
                size=(img_h, img_w),
                mode='nearest'
            ).squeeze(0)  # Remove batch dimension

            aligned_masks.append(resized_mask)
        else:
            aligned_masks.append(msk)

    # Now use the aligned masks
    masks = aligned_masks

    # Print all dimensions for debugging
    for i, (img, msk) in enumerate(zip(images, masks)):
        print(f"Item {i} after alignment: Image {img.shape}, Mask {msk.shape}")

    # Find max dimensions in the batch
    # Get the current batch's dimensions
    batch_heights = [img.shape[1] for img in images]
    batch_widths = [img.shape[2] for img in images]

    max_h = max(batch_heights)
    max_w = max(batch_widths)

    print(f"Max dimensions in batch: Height={max_h}, Width={max_w}")

    # Set target dimensions to be divisible by 32 (common for U-Net architectures)
    target_height = ((max_h + 31) // 32) * 32
    target_width = ((max_w + 31) // 32) * 32

    print(f"Target dimensions: Height={target_height}, Width={target_width}")

    # Check if any image is larger than the target
    for i, img in enumerate(images):
        if img.shape[1] > target_height or img.shape[2] > target_width:
            print(f"Warning: Image {i} with shape {img.shape} is larger than target ({target_height}, {target_width})")

    def process_to_target_size(tensor, target_height, target_width, mode='constant'):
        """Resize or pad tensor to target size"""
        # Handle both 2D and 3D tensors
        if tensor.dim() == 2:
            # For 2D tensor (H, W), add channel dimension
            tensor = tensor.unsqueeze(0)  # Convert to (1, H, W)

        # Now tensor should be 3D (C, H, W)
        c, h, w = tensor.shape

        # If image is larger than target in any dimension, resize it down
        if h > target_height or w > target_width:
            print(f"Resizing tensor from {tensor.shape} to fit within ({target_height}, {target_width})")
            # Resize down maintaining aspect ratio
            scale = min(target_height / h, target_width / w)
            new_h, new_w = int(h * scale), int(w * scale)
            resized = F.interpolate(tensor.unsqueeze(0), size=(new_h, new_w), 
                                   mode='bilinear' if c == 3 else 'nearest')
            tensor = resized.squeeze(0)
            c, h, w = tensor.shape

        # Now pad to exact target size
        pad_h = target_height - h
        pad_w = target_width - w

        if pad_h < 0:
            print(f"Error: Negative height padding {pad_h} for tensor of shape {tensor.shape}")
            # Force resize instead of padding
            tensor = F.interpolate(tensor.unsqueeze(0), size=(target_height, w), 
                                  mode='bilinear' if c == 3 else 'nearest').squeeze(0)
            pad_h = 0

        if pad_w < 0:
            print(f"Error: Negative width padding {pad_w} for tensor of shape {tensor.shape}")
            # Force resize instead of padding
            tensor = F.interpolate(tensor.unsqueeze(0), size=(h, target_width), 
                                  mode='bilinear' if c == 3 else 'nearest').squeeze(0)
            pad_w = 0

        print(f"Padding tensor from {tensor.shape} with padding (0, {pad_w}, 0, {pad_h})")
        padded_tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode=mode)
        return padded_tensor

    # Process images and masks
    processed_images = []
    processed_masks = []

    for i, (img, msk) in enumerate(zip(images, masks)):
        print(f"Processing item {i} - Image: {img.shape}, Mask: {msk.shape}")

        try:
            processed_img = process_to_target_size(img, target_height, target_width, mode='reflect')
            processed_images.append(processed_img)

            # Handle mask based on its dimensionality
            processed_mask = process_to_target_size(msk, target_height, target_width, mode='constant')
            processed_masks.append(processed_mask)

            print(f"Successfully processed item {i} - Image: {processed_img.shape}, Mask: {processed_mask.shape}")
        except Exception as e:
            print(f"Error processing item {i}: {e}")
            # Skip this item if there's an error
            continue

    # If we had to skip items, make sure we still have something to return
    if len(processed_images) == 0:
        raise RuntimeError("All items in batch were skipped due to processing errors")

    # Stack into batches
    batched_images = torch.stack(processed_images)
    batched_masks = torch.stack(processed_masks)

    print(f"Final batched shapes - Images: {batched_images.shape}, Masks: {batched_masks.shape}")
    return batched_images, batched_masks


### 0 reduction update

## upload train/val sets

In [145]:
# Initialize datasets
train_dataset = SegmentationDataset(
    image_dir="/Users/aja294/Documents/Hemp_local/projects/pytorch_train_leaf_morpho/data/train/images",
    mask_dir="/Users/aja294/Documents/Hemp_local/projects/pytorch_train_leaf_morpho/data/train/masks"
)

In [146]:
val_dataset = SegmentationDataset(
    image_dir="/Users/aja294/Documents/Hemp_local/projects/pytorch_train_leaf_morpho/data/val/images",
    mask_dir="/Users/aja294/Documents/Hemp_local/projects/pytorch_train_leaf_morpho/data/val/masks"
)

## Data loader setup


In [147]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

### With CPU

In [106]:
device = torch.device('cpu')
print(f"Using device: {device}")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
device

Using device: cpu


device(type='cpu')

### With Metal/Cuda

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

device

Using device: mps


device(type='mps')

### Run Training

###  Updated model train

In [108]:
num_epochs = 10
slice_height = 512
slice_width = 512
overlap_ratio = 0.2

for epoch in range(num_epochs):
    print(f"Starting training epoch {epoch+1}")
    model.train()
    train_loss = 0

    for batch_idx, (images, masks) in enumerate(train_loader):
        print(f"Processing training batch {batch_idx+1}/{len(train_loader)}")
        batch_loss = 0

        # Process each image in the batch
        for i in range(images.shape[0]):
            # Get single image and mask
            image = images[i]
            mask = masks[i]

            # Convert to SAHI-compatible format
            image_np = prepare_for_sahi(image)
            mask_np = prepare_for_sahi(mask.unsqueeze(0) if mask.dim() == 2 else mask)

            print(f"Converted image shape: {image_np.shape}, mask shape: {mask_np.shape}")

            # Slice the image
            image_slices = slice_image(
                image=image_np,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_height_ratio=overlap_ratio,
                overlap_width_ratio=overlap_ratio
            )

            # Process each slice
            slice_losses = []
            optimizer.zero_grad()

            for slice_data in image_slices:
                # Get the slice image and coordinates
                slice_image_data = slice_data["image"]
                x_min, y_min, x_max, y_max = slice_data["coordinates"]

                # Extract corresponding mask slice
                slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                # Convert back to tensors
                slice_image_tensor = torch.from_numpy(
                    slice_image_data.transpose(2, 0, 1)
                ).float().div(255.0).unsqueeze(0).to(device)

                slice_mask_tensor = torch.from_numpy(
                    slice_mask_data[:, :, 0:1].transpose(2, 0, 1)
                ).float().div(255.0).unsqueeze(0).to(device)

                # Forward pass
                slice_output = model(slice_image_tensor)

                # Calculate loss
                loss = F.binary_cross_entropy_with_logits(
                    slice_output, 
                    slice_mask_tensor,
                    reduction='mean'
                )

                # Backward pass
                loss.backward()
                slice_losses.append(loss.item())

            # Update weights after processing all slices
            optimizer.step()

            # Calculate average loss for this image
            if slice_losses:
                image_loss = sum(slice_losses) / len(slice_losses)
                batch_loss += image_loss
                print(f"Image {i+1} loss: {image_loss:.4f}")

        # Average loss for the batch
        batch_loss /= images.shape[0]
        train_loss += batch_loss
        print(f"Batch {batch_idx+1} average loss: {batch_loss:.4f}")

    # Calculate epoch training loss
    train_loss /= len(train_loader)

    # Validation
    print(f"Starting validation epoch {epoch+1}")
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(val_loader):
            print(f"Processing validation batch {batch_idx+1}/{len(val_loader)}")
            batch_loss = 0

            # Process each image in the batch
            for i in range(images.shape[0]):
                # Get single image and mask
                image = images[i]
                mask = masks[i]

                # Convert to SAHI-compatible format
                image_np = prepare_for_sahi(image)
                mask_np = prepare_for_sahi(mask.unsqueeze(0) if mask.dim() == 2 else mask)

                # Slice the image
                image_slices = slice_image(
                    image=image_np,
                    slice_height=slice_height,
                    slice_width=slice_width,
                    overlap_height_ratio=overlap_ratio,
                    overlap_width_ratio=overlap_ratio
                )

                # Process each slice
                slice_losses = []

                for slice_data in image_slices:
                    # Get the slice image and coordinates
                    slice_image_data = slice_data["image"]
                    x_min, y_min, x_max, y_max = slice_data["coordinates"]

                    # Extract corresponding mask slice
                    slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                    # Convert back to tensors
                    slice_image_tensor = torch.from_numpy(
                        slice_image_data.transpose(2, 0, 1)
                    ).float().div(255.0).unsqueeze(0).to(device)

                    slice_mask_tensor = torch.from_numpy(
                        slice_mask_data[:, :, 0:1].transpose(2, 0, 1)
                    ).float().div(255.0).unsqueeze(0).to(device)

                    # Forward pass
                    slice_output = model(slice_image_tensor)

                    # Calculate loss
                    loss = F.binary_cross_entropy_with_logits(
                        slice_output, 
                        slice_mask_tensor,
                        reduction='mean'
                    )

                    slice_losses.append(loss.item())

                # Calculate average loss for this image
                if slice_losses:
                    image_loss = sum(slice_losses) / len(slice_losses)
                    batch_loss += image_loss

            # Average loss for the batch
            batch_loss /= images.shape[0]
            val_loss += batch_loss
            print(f"Validation batch {batch_idx+1} average loss: {batch_loss:.4f}")

        # Calculate epoch validation loss
        val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, f'model_checkpoint_epoch_{epoch+1}.pth')

Starting training epoch 1
Processing - Image shape: torch.Size([3, 2197, 2573]), Mask shape: torch.Size([1, 2197, 2573])
Processing - Image shape: torch.Size([3, 2443, 2581]), Mask shape: torch.Size([1, 2443, 2581])
Processing - Image shape: torch.Size([3, 1963, 2016]), Mask shape: torch.Size([1, 1963, 2016])
Processing - Image shape: torch.Size([3, 2165, 2458]), Mask shape: torch.Size([1, 2165, 2458])
Batched shapes - Images: torch.Size([4, 3, 2464, 2592]), Masks: torch.Size([4, 1, 2464, 2592])
Processing training batch 1/5
Converted image shape: (2464, 2592, 3), mask shape: (2464, 2592, 3)


KeyError: 'coordinates'

In [None]:
def custom_slice_image(image_np, slice_height, slice_width, overlap_ratio):
    """
    Custom function to slice a numpy image array into smaller patches with overlap

    Args:
        image_np: Numpy array of shape [H, W, C]
        slice_height, slice_width: Size of slices
        overlap_ratio: Overlap between slices (0-1)

    Process:
        Groups images into batches of 4
        Aligns dimensions within each batch by padding to a common size
        Prepares each batch for SAHI slicing
        Processes 5 batches in total per epoch (based on "Processing training batch 3/5")

    Returns:
        List of dictionaries with 'image' and 'coordinates' keys
    """
    height, width = image_np.shape[:2]
    stride_h = int(slice_height * (1 - overlap_ratio))
    stride_w = int(slice_width * (1 - overlap_ratio))

    slices = []
    for y in range(0, height, stride_h):
        for x in range(0, width, stride_w):
            # Calculate slice coordinates
            x_min = x
            y_min = y
            x_max = min(x + slice_width, width)
            y_max = min(y + slice_height, height)

            # Handle edge cases - ensure slices are of size (slice_height, slice_width) when possible
            if x_max - x_min < slice_width and x_min > 0:
                x_min = max(0, x_max - slice_width)
            if y_max - y_min < slice_height and y_min > 0:
                y_min = max(0, y_max - slice_height)

            # Extract slice
            slice_image = image_np[y_min:y_max, x_min:x_max, :]

            # Create slice data in the expected format
            slice_data = {
                "image": slice_image,
                "coordinates": (x_min, y_min, x_max, y_max)
            }

            slices.append(slice_data)

    return slices


## Updated training loop THIS ONE WORKS 

In [150]:
num_epochs = 10
slice_height = 512
slice_width = 512
overlap_ratio = 0.2
print(device)

for epoch in range(num_epochs):
    print(f"Starting training epoch {epoch+1}")
    model.train()
    train_loss = 0

    for batch_idx, (images, masks) in enumerate(train_loader):
        print(f"Processing training batch {batch_idx+1}/{len(train_loader)}")
        batch_loss = 0

        # Process each image in the batch
        for i in range(images.shape[0]):
            # Get single image and mask
            image = images[i]
            mask = masks[i]

            # Convert to numpy arrays
            image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
            mask_np = mask.permute(1, 2, 0).cpu().numpy()    # [H, W, C]

            print(f"Converted image shape: {image_np.shape}, mask shape: {mask_np.shape}")

            # Slice the image using our custom function
            image_slices = custom_slice_image(
                image_np=image_np,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_ratio=overlap_ratio
            )

            # Process each slice
            slice_losses = []
            optimizer.zero_grad()

            for slice_data in image_slices:
                # Get the slice image and coordinates
                slice_image_data = slice_data["image"]
                x_min, y_min, x_max, y_max = slice_data["coordinates"]

                # Extract corresponding mask slice
                slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                # Convert back to tensors
                slice_image_tensor = torch.from_numpy(
                    slice_image_data.transpose(2, 0, 1)
                ).float().unsqueeze(0).to(device)  # Add batch dimension

                slice_mask_tensor = torch.from_numpy(
                    slice_mask_data.transpose(2, 0, 1)
                ).float().unsqueeze(0).to(device)  # Add batch dimension

                # Forward pass
                slice_output = model(slice_image_tensor)

                # Calculate loss
                loss = F.binary_cross_entropy_with_logits(
                    slice_output, 
                    slice_mask_tensor,
                    reduction='mean'
                )

                # Backward pass
                loss.backward()
                slice_losses.append(loss.item())

            # Update weights after processing all slices
            optimizer.step()

            # Calculate average loss for this image
            if slice_losses:
                image_loss = sum(slice_losses) / len(slice_losses)
                batch_loss += image_loss
                print(f"Image {i+1} loss: {image_loss:.4f}")

        # Average loss for the batch
        batch_loss /= images.shape[0]
        train_loss += batch_loss
        print(f"Batch {batch_idx+1} average loss: {batch_loss:.4f}")

    # Calculate epoch training loss
    train_loss /= len(train_loader)

    # Validation
    print(f"Starting validation epoch {epoch+1}")
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(val_loader):
            print(f"Processing validation batch {batch_idx+1}/{len(val_loader)}")
            batch_loss = 0

            # Process each image in the batch
            for i in range(images.shape[0]):
                # Get single image and mask
                image = images[i]
                mask = masks[i]

                # Convert to numpy arrays
                image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
                mask_np = mask.permute(1, 2, 0).cpu().numpy()    # [H, W, C]

                # Slice the image using our custom function
                image_slices = custom_slice_image(
                    image_np=image_np,
                    slice_height=slice_height,
                    slice_width=slice_width,
                    overlap_ratio=overlap_ratio
                )

                # Process each slice
                slice_losses = []

                for slice_data in image_slices:
                    # Get the slice image and coordinates
                    slice_image_data = slice_data["image"]
                    x_min, y_min, x_max, y_max = slice_data["coordinates"]

                    # Extract corresponding mask slice
                    slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                    # Convert back to tensors
                    slice_image_tensor = torch.from_numpy(
                        slice_image_data.transpose(2, 0, 1)
                    ).float().unsqueeze(0).to(device)  # Add batch dimension

                    slice_mask_tensor = torch.from_numpy(
                        slice_mask_data.transpose(2, 0, 1)
                    ).float().unsqueeze(0).to(device)  # Add batch dimension

                    # Forward pass
                    slice_output = model(slice_image_tensor)

                    # Calculate loss
                    loss = F.binary_cross_entropy_with_logits(
                        slice_output, 
                        slice_mask_tensor,
                        reduction='mean'
                    )

                    slice_losses.append(loss.item())

                # Calculate average loss for this image
                if slice_losses:
                    image_loss = sum(slice_losses) / len(slice_losses)
                    batch_loss += image_loss
                    print(f"Validation image {i+1} loss: {image_loss:.4f}")

            # Average loss for the batch
            batch_loss /= images.shape[0]
            val_loss += batch_loss
            print(f"Validation batch {batch_idx+1} average loss: {batch_loss:.4f}")

        # Calculate epoch validation loss
        val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, f'model_checkpoint_epoch_{epoch+1}.pth')


mps
Starting training epoch 1
Item 0 after alignment: Image torch.Size([3, 2199, 2639]), Mask torch.Size([1, 2199, 2639])
Item 1 after alignment: Image torch.Size([3, 2197, 2573]), Mask torch.Size([1, 2197, 2573])
Item 2 after alignment: Image torch.Size([3, 2525, 2908]), Mask torch.Size([1, 2525, 2908])
Item 3 after alignment: Image torch.Size([3, 2443, 2581]), Mask torch.Size([1, 2443, 2581])
Max dimensions in batch: Height=2525, Width=2908
Target dimensions: Height=2528, Width=2912
Processing item 0 - Image: torch.Size([3, 2199, 2639]), Mask: torch.Size([1, 2199, 2639])
Padding tensor from torch.Size([3, 2199, 2639]) with padding (0, 273, 0, 329)
Padding tensor from torch.Size([1, 2199, 2639]) with padding (0, 273, 0, 329)
Successfully processed item 0 - Image: torch.Size([3, 2528, 2912]), Mask: torch.Size([1, 2528, 2912])
Processing item 1 - Image: torch.Size([3, 2197, 2573]), Mask: torch.Size([1, 2197, 2573])
Padding tensor from torch.Size([3, 2197, 2573]) with padding (0, 339, 0,

In [152]:
num_epochs = 10
slice_height = 512
slice_width = 512
overlap_ratio = 0.2

for epoch in range(num_epochs):
    print(f"Starting training epoch {epoch+1}")
    model.train()
    train_loss = 0

    for batch_idx, (images, masks) in enumerate(train_loader):
        print(f"Processing training batch {batch_idx+1}/{len(train_loader)}")
        batch_loss = 0

        # Process each image in the batch
        for i in range(images.shape[0]):
            # Get single image and mask
            image = images[i]
            mask = masks[i]

            # Convert to numpy arrays
            image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
            mask_np = mask.permute(1, 2, 0).cpu().numpy()    # [H, W, C]

            # Slice the image using our custom function
            image_slices = custom_slice_image(
                image_np=image_np,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_ratio=overlap_ratio
            )

            # Process each slice
            slice_losses = []
            optimizer.zero_grad()

            for slice_data in image_slices:
                # Get the slice image and coordinates
                slice_image_data = slice_data["image"]
                x_min, y_min, x_max, y_max = slice_data["coordinates"]

                # Extract corresponding mask slice
                slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                # Convert back to tensors
                slice_image_tensor = torch.from_numpy(
                    slice_image_data.transpose(2, 0, 1)
                ).float().unsqueeze(0).to(device)  # Add batch dimension

                slice_mask_tensor = torch.from_numpy(
                    slice_mask_data.transpose(2, 0, 1)
                ).float().unsqueeze(0).to(device)  # Add batch dimension

                # Create a mask for valid (non-padded) regions
                # For slices, all pixels are valid since we're taking exact slices
                valid_mask = torch.ones_like(slice_mask_tensor)

                # Forward pass
                slice_output = model(slice_image_tensor)

                # Calculate loss with masking
                # Option 1: Using reduction='none' and manual masking
                loss = F.binary_cross_entropy_with_logits(
                    slice_output, 
                    slice_mask_tensor,
                    reduction='none'
                )
                # Apply mask and calculate mean over valid pixels
                masked_loss = (loss * valid_mask).sum() / valid_mask.sum()

                # Option 2: Using weight parameter (alternative)
                # loss = F.binary_cross_entropy_with_logits(
                #     slice_output, 
                #     slice_mask_tensor,
                #     weight=valid_mask,  # Weight parameter acts as our mask
                #     reduction='mean'
                # )

                # Backward pass
                masked_loss.backward()
                slice_losses.append(masked_loss.item())

            # Update weights after processing all slices
            optimizer.step()

            # Calculate average loss for this image
            if slice_losses:
                image_loss = sum(slice_losses) / len(slice_losses)
                batch_loss += image_loss
                print(f"Image {i+1} loss: {image_loss:.4f}")

        # Average loss for the batch
        batch_loss /= images.shape[0]
        train_loss += batch_loss
        print(f"Batch {batch_idx+1} average loss: {batch_loss:.4f}")

    # Calculate epoch training loss
    train_loss /= len(train_loader)

    # Validation
    print(f"Starting validation epoch {epoch+1}")
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(val_loader):
            print(f"Processing validation batch {batch_idx+1}/{len(val_loader)}")
            batch_loss = 0

            # Process each image in the batch
            for i in range(images.shape[0]):
                # Get single image and mask
                image = images[i]
                mask = masks[i]

                # Convert to numpy arrays
                image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
                mask_np = mask.permute(1, 2, 0).cpu().numpy()    # [H, W, C]

                # Slice the image using our custom function
                image_slices = custom_slice_image(
                    image_np=image_np,
                    slice_height=slice_height,
                    slice_width=slice_width,
                    overlap_ratio=overlap_ratio
                )

                # Process each slice
                slice_losses = []

                for slice_data in image_slices:
                    # Get the slice image and coordinates
                    slice_image_data = slice_data["image"]
                    x_min, y_min, x_max, y_max = slice_data["coordinates"]

                    # Extract corresponding mask slice
                    slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                    # Convert back to tensors
                    slice_image_tensor = torch.from_numpy(
                        slice_image_data.transpose(2, 0, 1)
                    ).float().unsqueeze(0).to(device)  # Add batch dimension

                    slice_mask_tensor = torch.from_numpy(
                        slice_mask_data.transpose(2, 0, 1)
                    ).float().unsqueeze(0).to(device)  # Add batch dimension

                    # Create a mask for valid (non-padded) regions
                    valid_mask = torch.ones_like(slice_mask_tensor)

                    # Forward pass
                    slice_output = model(slice_image_tensor)

                    # Calculate loss with masking
                    loss = F.binary_cross_entropy_with_logits(
                        slice_output, 
                        slice_mask_tensor,
                        reduction='none'
                    )
                    # Apply mask and calculate mean over valid pixels
                    masked_loss = (loss * valid_mask).sum() / valid_mask.sum()

                    slice_losses.append(masked_loss.item())

                # Calculate average loss for this image
                if slice_losses:
                    image_loss = sum(slice_losses) / len(slice_losses)
                    batch_loss += image_loss
                    print(f"Validation image {i+1} loss: {image_loss:.4f}")

            # Average loss for the batch
            batch_loss /= images.shape[0]
            val_loss += batch_loss
            print(f"Validation batch {batch_idx+1} average loss: {batch_loss:.4f}")

        # Calculate epoch validation loss
        val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, f'model_checkpoint_epoch_{epoch+1}.pth')


Starting training epoch 1
Item 0 after alignment: Image torch.Size([3, 1870, 2232]), Mask torch.Size([1, 1870, 2232])
Item 1 after alignment: Image torch.Size([3, 2197, 2573]), Mask torch.Size([1, 2197, 2573])
Item 2 after alignment: Image torch.Size([3, 2042, 2337]), Mask torch.Size([1, 2042, 2337])
Item 3 after alignment: Image torch.Size([3, 1743, 2050]), Mask torch.Size([1, 1743, 2050])
Max dimensions in batch: Height=2197, Width=2573
Target dimensions: Height=2208, Width=2592
Processing item 0 - Image: torch.Size([3, 1870, 2232]), Mask: torch.Size([1, 1870, 2232])
Padding tensor from torch.Size([3, 1870, 2232]) with padding (0, 360, 0, 338)
Padding tensor from torch.Size([1, 1870, 2232]) with padding (0, 360, 0, 338)
Successfully processed item 0 - Image: torch.Size([3, 2208, 2592]), Mask: torch.Size([1, 2208, 2592])
Processing item 1 - Image: torch.Size([3, 2197, 2573]), Mask: torch.Size([1, 2197, 2573])
Padding tensor from torch.Size([3, 2197, 2573]) with padding (0, 19, 0, 11)


KeyboardInterrupt: 

## Attempt to mask the loss reduction and only focus on unpadded, original images

In [151]:
num_epochs = 10
slice_height = 512
slice_width = 512
overlap_ratio = 0.2

for epoch in range(num_epochs):
    print(f"Starting training epoch {epoch+1}")
    model.train()
    train_loss = 0
    total_pixels = 0
    valid_pixels = 0

    for batch_idx, (images, masks) in enumerate(train_loader):
        print(f"Processing training batch {batch_idx+1}/{len(train_loader)}")
        batch_loss = 0
        batch_valid_pixels = 0

        # Process each image in the batch
        for i in range(images.shape[0]):
            # Get single image and mask
            image = images[i]
            mask = masks[i]

            # Convert to numpy arrays
            image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
            mask_np = mask.permute(1, 2, 0).cpu().numpy()    # [H, W, C]

            # Slice the image using our custom function
            image_slices = custom_slice_image(
                image_np=image_np,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_ratio=overlap_ratio
            )

            # Process each slice
            slice_losses = []
            slice_valid_pixels = 0
            optimizer.zero_grad()

            for slice_data in image_slices:
                # Get the slice image and coordinates
                slice_image_data = slice_data["image"]
                x_min, y_min, x_max, y_max = slice_data["coordinates"]

                # Extract corresponding mask slice
                slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                # Convert back to tensors
                slice_image_tensor = torch.from_numpy(
                    slice_image_data.transpose(2, 0, 1)
                ).float().unsqueeze(0).to(device)  # Add batch dimension

                slice_mask_tensor = torch.from_numpy(
                    slice_mask_data.transpose(2, 0, 1)
                ).float().unsqueeze(0).to(device)  # Add batch dimension

                # Create a validity mask for the original content (non-padded areas)
                # For slices that come from the interior of the image, all pixels are valid
                # For edge slices, we need to check if they were padded

                # By default, all pixels in the slice are valid
                validity_mask = torch.ones_like(slice_mask_tensor)

                # If this is an edge slice (right or bottom edge of the original image),
                # mark padded areas as invalid (0)
                if x_max == image_np.shape[1] and x_max - x_min < slice_width:
                    # This slice includes padding on the right
                    valid_width = image_np.shape[1] - x_min
                    validity_mask[:, :, :, valid_width:] = 0

                if y_max == image_np.shape[0] and y_max - y_min < slice_height:
                    # This slice includes padding on the bottom
                    valid_height = image_np.shape[0] - y_min
                    validity_mask[:, :, valid_height:, :] = 0

                # Forward pass
                slice_output = model(slice_image_tensor)

                # Calculate masked loss
                # Use reduction='none' to get per-pixel losses
                per_pixel_loss = F.binary_cross_entropy_with_logits(
                    slice_output, 
                    slice_mask_tensor,
                    reduction='none'
                )

                # Apply validity mask to the loss
                masked_loss = per_pixel_loss * validity_mask

                # Count valid pixels for normalization
                num_valid_pixels = validity_mask.sum().item()
                slice_valid_pixels += num_valid_pixels

                # Normalize by the number of valid pixels
                if num_valid_pixels > 0:
                    loss = masked_loss.sum() / num_valid_pixels
                    loss.backward()
                    slice_losses.append(loss.item())

            # Update weights after processing all slices
            optimizer.step()

            # Calculate average loss for this image
            if slice_losses and slice_valid_pixels > 0:
                image_loss = sum(slice_losses) / len(slice_losses)
                batch_loss += image_loss * slice_valid_pixels
                batch_valid_pixels += slice_valid_pixels
                print(f"Image {i+1} loss: {image_loss:.4f} (valid pixels: {slice_valid_pixels})")

        # Average loss for the batch, weighted by valid pixels
        if batch_valid_pixels > 0:
            batch_loss /= batch_valid_pixels
            train_loss += batch_loss * batch_valid_pixels
            total_pixels += batch_valid_pixels
            valid_pixels += batch_valid_pixels
            print(f"Batch {batch_idx+1} average loss: {batch_loss:.4f} (valid pixels: {batch_valid_pixels})")

    # Calculate epoch training loss, weighted by valid pixels
    if total_pixels > 0:
        train_loss /= total_pixels

    # Validation
    print(f"Starting validation epoch {epoch+1}")
    model.eval()
    val_loss = 0
    val_total_pixels = 0
    val_valid_pixels = 0

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(val_loader):
            print(f"Processing validation batch {batch_idx+1}/{len(val_loader)}")
            batch_loss = 0
            batch_valid_pixels = 0

            # Process each image in the batch
            for i in range(images.shape[0]):
                # Get single image and mask
                image = images[i]
                mask = masks[i]

                # Convert to numpy arrays
                image_np = image.permute(1, 2, 0).cpu().numpy()  # [H, W, C]
                mask_np = mask.permute(1, 2, 0).cpu().numpy()    # [H, W, C]

                # Slice the image using our custom function
                image_slices = custom_slice_image(
                    image_np=image_np,
                    slice_height=slice_height,
                    slice_width=slice_width,
                    overlap_ratio=overlap_ratio
                )

                # Process each slice
                slice_losses = []
                slice_valid_pixels = 0

                for slice_data in image_slices:
                    # Get the slice image and coordinates
                    slice_image_data = slice_data["image"]
                    x_min, y_min, x_max, y_max = slice_data["coordinates"]

                    # Extract corresponding mask slice
                    slice_mask_data = mask_np[y_min:y_max, x_min:x_max, :]

                    # Convert back to tensors
                    slice_image_tensor = torch.from_numpy(
                        slice_image_data.transpose(2, 0, 1)
                    ).float().unsqueeze(0).to(device)  # Add batch dimension

                    slice_mask_tensor = torch.from_numpy(
                        slice_mask_data.transpose(2, 0, 1)
                    ).float().unsqueeze(0).to(device)  # Add batch dimension

                    # Create validity mask for the original content
                    validity_mask = torch.ones_like(slice_mask_tensor)

                    # Mark padded areas as invalid
                    if x_max == image_np.shape[1] and x_max - x_min < slice_width:
                        valid_width = image_np.shape[1] - x_min
                        validity_mask[:, :, :, valid_width:] = 0

                    if y_max == image_np.shape[0] and y_max - y_min < slice_height:
                        valid_height = image_np.shape[0] - y_min
                        validity_mask[:, :, valid_height:, :] = 0

                    # Forward pass
                    slice_output = model(slice_image_tensor)

                    # Calculate masked loss
                    per_pixel_loss = F.binary_cross_entropy_with_logits(
                        slice_output, 
                        slice_mask_tensor,
                        reduction='none'
                    )

                    # Apply validity mask
                    masked_loss = per_pixel_loss * validity_mask

                    # Count valid pixels
                    num_valid_pixels = validity_mask.sum().item()
                    slice_valid_pixels += num_valid_pixels

                    # Normalize by valid pixels
                    if num_valid_pixels > 0:
                        loss = masked_loss.sum() / num_valid_pixels
                        slice_losses.append(loss.item())

                # Calculate average loss for this image
                if slice_losses and slice_valid_pixels > 0:
                    image_loss = sum(slice_losses) / len(slice_losses)
                    batch_loss += image_loss * slice_valid_pixels
                    batch_valid_pixels += slice_valid_pixels

            # Average loss for the batch, weighted by valid pixels
            if batch_valid_pixels > 0:
                batch_loss /= batch_valid_pixels
                val_loss += batch_loss * batch_valid_pixels
                val_total_pixels += batch_valid_pixels
                val_valid_pixels += batch_valid_pixels
                print(f"Validation batch {batch_idx+1} average loss: {batch_loss:.4f} (valid pixels: {batch_valid_pixels})")

        # Calculate epoch validation loss, weighted by valid pixels
        if val_total_pixels > 0:
            val_loss /= val_total_pixels

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f} (valid: {valid_pixels}/{total_pixels}), Val Loss: {val_loss:.4f} (valid: {val_valid_pixels}/{val_total_pixels})")

    # Save model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, f'model_checkpoint_epoch_{epoch+1}.pth')


Starting training epoch 1
Item 0 after alignment: Image torch.Size([3, 2525, 2908]), Mask torch.Size([1, 2525, 2908])
Item 1 after alignment: Image torch.Size([3, 2197, 2302]), Mask torch.Size([1, 2197, 2302])
Item 2 after alignment: Image torch.Size([3, 2350, 2472]), Mask torch.Size([1, 2350, 2472])
Item 3 after alignment: Image torch.Size([3, 1679, 1972]), Mask torch.Size([1, 1679, 1972])
Max dimensions in batch: Height=2525, Width=2908
Target dimensions: Height=2528, Width=2912
Processing item 0 - Image: torch.Size([3, 2525, 2908]), Mask: torch.Size([1, 2525, 2908])
Padding tensor from torch.Size([3, 2525, 2908]) with padding (0, 4, 0, 3)
Padding tensor from torch.Size([1, 2525, 2908]) with padding (0, 4, 0, 3)
Successfully processed item 0 - Image: torch.Size([3, 2528, 2912]), Mask: torch.Size([1, 2528, 2912])
Processing item 1 - Image: torch.Size([3, 2197, 2302]), Mask: torch.Size([1, 2197, 2302])
Padding tensor from torch.Size([3, 2197, 2302]) with padding (0, 610, 0, 331)
Paddin

KeyboardInterrupt: 

### Built to only train on originals while padding degrades

In [19]:
# Save model
torch.save(model.state_dict(), "unet_model.pth")