In [1]:
import os
import traceback

import matplotlib.pyplot as plt
import numpy as np
import skimage as ski
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage import measure
from skimage import util as sk_util
from skimage.morphology import remove_small_objects
from torch.utils.data import Dataset, DataLoader
# import torchvision.transforms as transforms
# from PIL import Image


In [2]:
"""
Skin lesion segmentation - great for beginners
- URL: https://challenge.isic-archive.com/
- Images: ~2000+ dermoscopy images
- Task: Melanoma/skin lesion segmentation
- Format: RGB images with binary masks
"""

# Example loading for ISIC
class ISICDataset(Dataset):
    """
    ISIC skin-lesion dataset: RGB ``.jpg`` images paired with binary masks.

    Each item is ``(image, mask)`` where ``image`` is a float tensor of shape
    (3, H, W) and ``mask`` is a long tensor of shape (H, W) with values {0, 1}.
    Images and masks are read with scikit-image (PIL / torchvision are not
    imported by this notebook).
    """

    def __init__(self, root_dir, transform=None):
        """
        ISIC skin lesion dataset loader.

        Expected structure:
        root_dir/
            ISIC_0000000.jpg
            ISIC_0000000_segmentation.png
            ...

        Args:
            root_dir (str): Directory holding both images and masks.
            transform (callable, optional): Called as
                ``transform(image=<ndarray>, mask=<ndarray>)`` and expected to
                return a dict with ``'image'`` and ``'mask'`` numpy arrays.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic across runs/OSes.
        self.images = sorted(f for f in os.listdir(root_dir) if f.endswith('.jpg'))

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` and its ``<stem>_segmentation.png`` mask as tensors."""
        img_name = self.images[idx]
        img_path = os.path.join(self.root_dir, img_name)
        # splitext only touches the real extension (str.replace would also
        # rewrite a '.jpg' that appears inside the file stem).
        stem, _ = os.path.splitext(img_name)
        mask_path = os.path.join(self.root_dir, stem + '_segmentation.png')

        # Force 3 colour channels (the equivalent of PIL's convert('RGB')).
        image = np.asarray(ski.io.imread(img_path))
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.shape[-1] in (1, 2):  # grayscale, or grayscale + alpha
            image = np.repeat(image[..., :1], 3, axis=-1)
        elif image.shape[-1] == 4:  # RGBA -> RGB
            image = image[..., :3]

        # Collapse a multi-channel mask to one channel, ignoring any alpha, so a
        # pixel is foreground when any colour channel is non-zero.
        mask = np.asarray(ski.io.imread(mask_path))
        if mask.ndim == 3:
            colour = mask[..., :3] if mask.shape[-1] >= 3 else mask[..., :1]
            mask = colour.max(axis=-1)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # HWC -> CHW float tensor; uint8 [0, 255] is rescaled to [0, 1] like ToTensor.
        image = np.ascontiguousarray(image)
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float()
        if image.dtype == np.uint8:
            image_tensor = image_tensor / 255.0

        mask_tensor = torch.from_numpy(np.asarray(mask) > 0).long()
        return image_tensor, mask_tensor

In [3]:
class UNet(nn.Module):
    """
    U-Net architecture for semantic segmentation.

    U-Net is a convolutional neural network that was developed for biomedical image segmentation.
    The architecture consists of:
    1. A contracting path (encoder/downsampling) that captures context
    2. An expansive path (decoder/upsampling) that enables precise localization
    3. Skip connections between encoder and decoder that preserve spatial information

    The network has a distinctive U-shaped architecture when visualized, hence the name "U-Net".

    Key features:
    - Symmetric architecture with skip connections
    - Combines low-level feature maps with high-level ones
    - Works well with limited training data (common in medical imaging)
    - Produces dense predictions (output same size as input, thanks to padded
      convolutions and resizing of the upsampled maps to the skip-connection size)

    Architecture diagram:

    Input (572x572x3) ─┐
                      ↓
    ┌─────────────────┴─────────────────┐
    │          ENCODER PATH             │
    │                                   │
    │  Conv Block 1 (64 features)       │──── Skip Connection 1 ────┐
    │       ↓ MaxPool                   │                           │
    │  Conv Block 2 (128 features)      │──── Skip Connection 2 ───┐│
    │       ↓ MaxPool                   │                          ││
    │  Conv Block 3 (256 features)      │──── Skip Connection 3 ──┐││
    │       ↓ MaxPool                   │                         │││
    │  Conv Block 4 (512 features)      │──── Skip Connection 4 ─┐│││
    │       ↓ MaxPool                   │                        ││││
    └───────────────────────────────────┘                        ││││
                      ↓                                          ││││
    ┌─────────────────┴─────────────────┐                        ││││
    │         BOTTLENECK                │                        ││││
    │     (1024 features)               │                        ││││
    └───────────────────────────────────┘                        ││││
                      ↓                                          ││││
    ┌─────────────────┴─────────────────┐                        ││││
    │          DECODER PATH             │                        ││││
    │                                   │                        ││││
    │  UpConv + Concat ←───────────────────────────────────────┘│││
    │  Conv Block 5 (512 features)      │                         │││
    │       ↓                           │                         │││
    │  UpConv + Concat ←───────────────────────────────────────┘││
    │  Conv Block 6 (256 features)      │                          ││
    │       ↓                           │                          ││
    │  UpConv + Concat ←───────────────────────────────────────┘│
    │  Conv Block 7 (128 features)      │                           │
    │       ↓                           │                           │
    │  UpConv + Concat ←───────────────────────────────────────┘
    │  Conv Block 8 (64 features)       │
    │       ↓                           │
    │  Final Conv (output_channels)     │
    └───────────────────────────────────┘
                      ↓
               Output (572x572xC)
    """

    def __init__(self, in_channels=3, out_channels=2, features=[64, 128, 256, 512]):
        """
        Initialize U-Net model.

        Args:
            in_channels (int): Number of input channels
                              (3 for RGB images, 1 for grayscale)
            out_channels (int): Number of output channels
                               (number of classes for segmentation)
                               - 2 for binary segmentation (background + 1 object class)
                               - N for multi-class segmentation (background + N-1 object classes)
            features (list): Number of features/filters in each encoder layer
                            Default: [64, 128, 256, 512]
                            Each decoder layer has the same features in reverse order

        The network depth is determined by len(features).
        Deeper networks can capture more complex patterns but require more memory.
        """
        # NOTE: `features` is only read, never mutated, so the list default is safe.
        super().__init__()

        # ModuleList registers every block with PyTorch (parameters, .to(), state_dict).
        self.ups = nn.ModuleList()    # Decoder: [upconv, block, upconv, block, ...]
        self.downs = nn.ModuleList()  # Encoder blocks

        # 2x2 max-pooling with stride 2 halves H and W between encoder blocks.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        #############################################
        # ENCODER PATH (LEFT SIDE OF U)
        #############################################
        # Each block maps the previous channel count to `feature` channels.
        for feature in features:
            self.downs.append(self._block(in_channels, feature))
            in_channels = feature  # Input channels of the next encoder block

        #############################################
        # BOTTLENECK (BOTTOM OF U)
        #############################################
        # Deepest level: smallest spatial size, features[-1] * 2 channels.
        self.bottleneck = self._block(features[-1], features[-1] * 2)

        #############################################
        # DECODER PATH (RIGHT SIDE OF U)
        #############################################
        for feature in reversed(features):
            # Transposed conv: doubles H and W, maps feature*2 -> feature channels.
            self.ups.append(
                nn.ConvTranspose2d(
                    feature * 2, feature, kernel_size=2, stride=2
                )
            )
            # Double conv after concatenating with the skip connection, which
            # brings the channel count back to feature*2.
            self.ups.append(self._block(feature * 2, feature))

        # 1x1 convolution: per-pixel classification into out_channels classes.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        """
        Create the double-convolution block used throughout U-Net.

        Layers:
        1. 3x3 Convolution (padding=1 preserves spatial dimensions)
        2. Batch Normalization
        3. ReLU activation
        4. Another 3x3 Convolution (padding=1)
        5. Batch Normalization
        6. ReLU activation

        The double 3x3 convolution pattern comes from the original U-Net paper;
        the padding and BatchNorm layers are common later additions (the paper
        used unpadded convolutions without normalization).

        Args:
            in_channels (int): Number of input channels
            out_channels (int): Number of output channels

        Returns:
            nn.Sequential: Sequential container of layers

        Why this design:
        - Two 3x3 convolutions cover the same receptive field as one 5x5
          with fewer parameters and more non-linearity
        - Batch normalization helps with training stability and speed
        - ReLU is computationally efficient and helps with gradient flow
        """
        return nn.Sequential(
            # First convolution
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # inplace=True saves memory

            # Second convolution
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        Forward pass through the U-Net.

        1. Encoder: progressively downsample while increasing features
        2. Bottleneck: process at the deepest level
        3. Decoder: progressively upsample while decreasing features
        4. Skip connections: concatenate encoder features to decoder features

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W)
                             B = batch size
                             C = channels (e.g., 3 for RGB)
                             H = height
                             W = width

        Returns:
            torch.Tensor: Output logits of shape (B, num_classes, H, W).
            H and W need not be multiples of 2**len(features): upsampled maps
            are resized to match their skip connection.

        Example dimensions for a 256x256 RGB image:
        Input: (B, 3, 256, 256)
        After block 1: (B, 64, 256, 256) → After pool: (B, 64, 128, 128)
        After block 2: (B, 128, 128, 128) → After pool: (B, 128, 64, 64)
        After block 3: (B, 256, 64, 64) → After pool: (B, 256, 32, 32)
        After block 4: (B, 512, 32, 32) → After pool: (B, 512, 16, 16)
        After bottleneck: (B, 1024, 16, 16)
        ... (decoder reverses this process)
        Output: (B, num_classes, 256, 256)
        """
        # Encoder outputs saved *before* pooling, for the skip connections.
        skip_connections = []

        #############################################
        # ENCODER PATH
        #############################################
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)  # Halves H and W (floor for odd sizes)

        #############################################
        # BOTTLENECK
        #############################################
        x = self.bottleneck(x)

        #############################################
        # DECODER PATH
        #############################################
        # The decoder consumes skips from deepest to shallowest.
        skip_connections = skip_connections[::-1]

        # self.ups alternates: even index = transposed conv, odd index = conv block.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # With odd sizes, pooling floors H/W, so the upsampled map can be one
            # pixel smaller than the skip. Compare only spatial dims (channels
            # already match) and resize with torch's interpolate
            # (torch.nn.functional has no `resize`).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(
                    x, size=skip_connection.shape[2:],
                    mode='bilinear', align_corners=False
                )

            # Concatenate along channels: high-res encoder + upsampled decoder features.
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        #############################################
        # FINAL OUTPUT
        #############################################
        # 1x1 conv -> (B, num_classes, H, W) logits
        return self.final_conv(x)


In [4]:
from skimage import util as sk_util

class SegmentationDataset(Dataset):
    """
    Paired image / mask dataset read from two directories.

    Every regular file in ``image_dir`` is treated as an image. Its mask is read
    from ``mask_dir`` under the name ``mask_name_fn(image_name)``, which defaults
    to the same file name. Items are ``(image, mask)``: a float tensor of shape
    (3, H, W) scaled by 1/255 and a long tensor of shape (H, W) with values {0, 1}.
    """

    def __init__(self, image_dir, mask_dir, transform=None, debug=False, mask_name_fn=None):
        """
        Args:
            image_dir (str): Directory containing the input images.
            mask_dir (str): Directory containing the masks.
            transform (callable, optional): Called as
                ``transform(image=<ndarray>, mask=<ndarray>)`` and expected to
                return a dict with ``'image'`` and ``'mask'`` arrays.
            debug (bool): If True, print mask diagnostics for sample 0.
            mask_name_fn (callable, optional): Maps an image file name to its
                mask file name, e.g. ``lambda n: n.replace('image_', 'mask_')``.
                Defaults to identity (the mask shares the image's file name).
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted for a deterministic index -> file mapping. Sub-directories are
        # skipped because they cannot be read as images.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.debug = debug
        self.mask_name_fn = mask_name_fn

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image_tensor, binary_mask_tensor)``."""
        img_name = self.images[idx]
        mask_name = self.mask_name_fn(img_name) if self.mask_name_fn else img_name
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Load images
        image = np.asarray(ski.io.imread(img_path))
        mask = np.asarray(ski.io.imread(mask_path))

        # Debug: print original mask values
        if self.debug and idx == 0:
            print(f"\nDebug Sample {idx}:")
            print(f"Original mask shape: {mask.shape}")
            print(f"Original mask dtype: {mask.dtype}")
            print(f"Original mask unique values: {np.unique(mask)}")
            print(f"Original mask range: [{mask.min()}, {mask.max()}]")

        # Apply transforms
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Force 3 channels so the tensor is always (3, H, W); a 2-D grayscale
        # array would otherwise make permute(2, 0, 1) fail.
        image = np.asarray(image)
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.shape[-1] in (1, 2):  # grayscale, or grayscale + alpha
            image = np.repeat(image[..., :1], 3, axis=-1)
        elif image.shape[-1] == 4:  # RGBA -> RGB
            image = image[..., :3]
        # ascontiguousarray: from_numpy rejects negative strides (e.g. after flips)
        image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float() / 255.0

        # A multi-channel mask (e.g. saved as RGB/RGBA) is reduced to its first channel.
        mask = np.asarray(mask)
        if mask.ndim == 3:
            mask = mask[..., 0]

        if np.issubdtype(mask.dtype, np.integer) and mask.max() <= 1:
            # Already a {0, 1} label mask. img_as_float would map 1 to 1/255
            # (uint8) or smaller, and the 0.5 threshold would erase the foreground.
            mask_float = mask.astype(np.float64)
        else:
            # e.g. 0/255 masks: scale to [0, 1] by dtype range, then threshold at 0.5
            mask_float = sk_util.img_as_float(mask)
        mask_binary = (mask_float > 0.5).astype(np.int64)

        if self.debug and idx == 0:
            print(f"After conversion:")
            print(f"mask_float range: [{mask_float.min()}, {mask_float.max()}]")
            print(f"mask_binary unique values: {np.unique(mask_binary)}")

        mask_tensor = torch.from_numpy(mask_binary).long()

        # Final verification
        assert mask_tensor.max() <= 1, f"Mask has values > 1: {mask_tensor.max()}"
        assert mask_tensor.min() >= 0, f"Mask has negative values: {mask_tensor.min()}"

        return image, mask_tensor

In [5]:

def train_unet(model, train_loader, val_loader, epochs=50, lr=0.001, device=None):
    """
    Train the U-Net model with cross-entropy loss and Adam.

    ``model`` is moved to ``device`` and trained in place.

    Args:
        model: U-Net model (any nn.Module producing (B, C, H, W) logits)
        train_loader: Training data loader yielding (images, masks) batches
        val_loader: Validation data loader yielding (images, masks) batches
        epochs (int): Number of training epochs
        lr (float): Learning rate
        device (str | torch.device, optional): Device to train on. Defaults to
            CUDA when available, otherwise CPU.

    Returns:
        dict: Training history with keys 'train_loss' and 'val_loss', each a list
        of per-epoch losses. Each loss is the per-sample average, with batches
        weighted by size so a short last batch does not skew it. A loader that
        yields no samples gives NaN instead of raising ZeroDivisionError.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    elif isinstance(device, str):
        device = torch.device(device)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss_sum, train_count = 0.0, 0

        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight it by batch size for a per-sample epoch mean
            batch_size = images.size(0)
            train_loss_sum += loss.item() * batch_size
            train_count += batch_size

        # Validation phase
        model.eval()
        val_loss_sum, val_count = 0.0, 0

        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)
                batch_size = images.size(0)
                val_loss_sum += loss.item() * batch_size
                val_count += batch_size

        # Calculate average losses (NaN when a loader produced no samples)
        avg_train_loss = train_loss_sum / train_count if train_count else float('nan')
        avg_val_loss = val_loss_sum / val_count if val_count else float('nan')

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)

        print(f'Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    return history


In [6]:
# Example usage
def setup_and_train(train_image_dir='path/to/train/images',
                    train_mask_dir='path/to/train/masks',
                    val_image_dir='path/to/val/images',
                    val_mask_dir='path/to/val/masks',
                    out_channels=2,
                    epochs=50,
                    batch_size=4):
    """
    Example pipeline: build a U-Net and train/val loaders, then train.

    SegmentationDataset yields binary {0, 1} masks, so the default
    ``out_channels=2`` (background + foreground) matches it. Raise it only
    when your masks carry more class labels.

    Args:
        train_image_dir, train_mask_dir (str): Training image / mask directories.
        val_image_dir, val_mask_dir (str): Validation image / mask directories.
        out_channels (int): Number of segmentation classes.
        epochs (int): Number of training epochs.
        batch_size (int): Batch size for both loaders.

    Returns:
        tuple: (model, history). The model is trained in place; history is the
        per-epoch loss dict returned by train_unet.
    """
    # Initialize model
    model = UNet(in_channels=3, out_channels=out_channels)

    # Create data loaders (validation is not shuffled, for comparable epochs)
    train_dataset = SegmentationDataset(train_image_dir, train_mask_dir)
    val_dataset = SegmentationDataset(val_image_dir, val_mask_dir)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Train model
    history = train_unet(model, train_loader, val_loader, epochs=epochs)

    return model, history

def predict_and_detect_objects(model, image_path):
    """
    Predict a segmentation mask for one image, split it into objects, and plot.

    Args:
        model: Trained U-Net model (expects 3-channel input)
        image_path (str): Path to input image

    Returns:
        list: Objects returned by ``post_process_segmentation(pred_mask)``.
        Each object is expected to have a ``'bbox'`` entry
        ``(minr, minc, maxr, maxc)``, since that is how it is drawn below.
        NOTE(review): post_process_segmentation is not defined in this part of
        the notebook; confirm it is defined before this is called.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # Load image with scikit-image (PIL is not imported) and force 3 channels.
    image_np = np.asarray(ski.io.imread(image_path))
    if image_np.ndim == 2:
        image_np = image_np[..., np.newaxis]
    if image_np.shape[-1] in (1, 2):  # grayscale, or grayscale + alpha
        image_np = np.repeat(image_np[..., :1], 3, axis=-1)
    elif image_np.shape[-1] == 4:  # RGBA -> RGB
        image_np = image_np[..., :3]
    image_tensor = torch.from_numpy(np.ascontiguousarray(image_np)).permute(2, 0, 1).float() / 255.0
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # Predict segmentation mask: take batch item 0 explicitly. A bare .squeeze()
    # would also drop H or W when either equals 1.
    with torch.no_grad():
        output = model(image_tensor)
        pred_mask = torch.argmax(output, dim=1)[0].cpu().numpy()

    # Post-process to find individual objects
    objects = post_process_segmentation(pred_mask)

    # Visualize results
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image_np)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(pred_mask, cmap='tab20')
    axes[1].set_title('Segmentation Mask')
    axes[1].axis('off')

    # Draw bounding boxes; bbox is (min_row, min_col, max_row, max_col)
    axes[2].imshow(image_np)
    for obj in objects:
        minr, minc, maxr, maxc = obj['bbox']
        rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr,
                             fill=False, edgecolor='red', linewidth=2)
        axes[2].add_patch(rect)
    axes[2].set_title(f'Detected Objects ({len(objects)})')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return objects

# Classification component (can be added after U-Net)
class ObjectClassifier(nn.Module):
    """
    Small CNN that maps an RGB crop to class logits.

    Three 3x3 conv + ReLU stages (64 -> 128 -> 256 channels) are used. The
    first two stages are each followed by 2x2 max-pooling. Global average
    pooling then feeds a linear layer. Crops of varying size are accepted;
    each side must be at least 4 pixels because of the two poolings.
    """

    def __init__(self, num_classes):
        """
        Build the classifier.

        Args:
            num_classes (int): Number of object classes (length of the logit vector)
        """
        super().__init__()
        stages = []
        channels_in = 3
        for stage_idx, channels_out in enumerate((64, 128, 256)):
            stages.append(nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            # Only the first two stages downsample; the last is followed by global pooling.
            if stage_idx < 2:
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            channels_in = channels_out
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        # Layer order (and hence state_dict keys features.0/3/6) is unchanged.
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(channels_in, num_classes)

    def forward(self, x):
        """Return logits of shape (B, num_classes) for input x of shape (B, 3, H, W)."""
        pooled = self.features(x)
        return self.classifier(torch.flatten(pooled, start_dim=1))

In [7]:
# Paths are relative to the notebook's working directory (assumed to be the
# repository root) instead of a user-specific absolute Windows path.
dataset = SegmentationDataset(image_dir=os.path.join('data', 'images'),
                              mask_dir=os.path.join('data', 'masks'),
                              debug=True)


In [8]:
# Validation split, relative to the notebook's working directory (assumed to be
# the repository root) instead of a user-specific absolute Windows path.
val_data = SegmentationDataset(image_dir=os.path.join('data', 'val', 'images'),
                               mask_dir=os.path.join('data', 'val', 'masks'), debug=True)

In [None]:
# Test different conversion methods: load one mask and compare how each
# binarization method maps its values to {0, 1}.

# Load a sample mask. os.path.join avoids the invalid "\m" escape in
# "data\masks\..." and works on every OS.
mask_path = os.path.join("data", "masks", "mask_0000.png")  # Replace with your mask path
mask = ski.io.imread(mask_path)
mask_array = np.array(mask)

print("Original mask:")
print(f"  Shape: {mask_array.shape}")
print(f"  Dtype: {mask_array.dtype}")
print(f"  Unique values: {np.unique(mask_array)}")
print(f"  Range: [{mask_array.min()}, {mask_array.max()}]")

# Test different methods
methods = {
    "Method 1: skimage.img_as_float": lambda m: (sk_util.img_as_float(m) > 0.5).astype(np.int64),
    "Method 2: Threshold at 127": lambda m: (m > 127).astype(np.int64),
    "Method 3: Divide by 255": lambda m: np.round(m / 255.0).astype(np.int64),
    "Method 4: Non-zero to 1": lambda m: (m != 0).astype(np.int64),
}

# One panel for the original mask plus one per method.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

axes[0].imshow(mask_array, cmap='gray')
axes[0].set_title(f"Original (values: {np.unique(mask_array)})")

for i, (name, method) in enumerate(methods.items(), 1):
    converted = method(mask_array)
    axes[i].imshow(converted, cmap='gray')
    axes[i].set_title(f"{name}\n(values: {np.unique(converted)})")
    print(f"\n{name}:")
    print(f"  Unique values: {np.unique(converted)}")
    print(f"  Value counts: {dict(zip(*np.unique(converted, return_counts=True)))}")

# Hide the panels left over after the original + methods.
for ax in axes[len(methods) + 1:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Split the ground-truth mask of the first sample into individual objects.
post_segs = post_process_segmentation(dataset[0][1].numpy(), min_size=50)

# Inspect one object. The object count varies per mask, so guard the index
# instead of raising IndexError.
OBJECT_INDEX = 3
post_segs[OBJECT_INDEX] if OBJECT_INDEX < len(post_segs) else f"only {len(post_segs)} objects found"

In [None]:
# Number of objects post_process_segmentation returned for the first sample's mask.
len(post_segs)

In [9]:
# Training loader; pull one batch to sanity-check shapes and value ranges.
train_loader = DataLoader(dataset, batch_size=5, shuffle=True)
images, masks = next(iter(train_loader))
print(f"Image batch shape: {images.shape}")  # expected [B<=5, 3, H, W]
print(f"Mask batch shape: {masks.shape}")    # expected [B<=5, H, W]
print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")  # expected within [0, 1]
print(f"Unique mask values: {torch.unique(masks)}")  # Should be [0, 1]
assert set(torch.unique(masks).tolist()) <= {0, 1}, "masks must be binary {0, 1}"

: 

In [None]:
# Validation loader: no shuffling, so per-epoch validation passes see the same order.
val_loader = DataLoader(val_data, batch_size=5, shuffle=False)
images, masks = next(iter(val_loader))
print(f"Image batch shape: {images.shape}")  # expected [B<=5, 3, H, W]
print(f"Mask batch shape: {masks.shape}")    # expected [B<=5, H, W]
print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")  # expected within [0, 1]
print(f"Unique mask values: {torch.unique(masks)}")  # Should be [0, 1]
assert set(torch.unique(masks).tolist()) <= {0, 1}, "masks must be binary {0, 1}"

In [None]:
# Force CPU to get better error messages
device = torch.device('cpu')
model = UNet(in_channels=3, out_channels=2).to(device)

# Test with a single batch
images, masks = next(iter(train_loader))
images = images.to(device)
masks = masks.to(device)

print(f"Images shape: {images.shape}")
print(f"Masks shape: {masks.shape}")
print(f"Unique mask values: {torch.unique(masks)}")
print(f"Mask dtype: {masks.dtype}")

# Try the forward pass and the loss. On failure, print the full traceback (not
# just the message) so the failing layer/call is visible while the rest of the
# notebook keeps running.
try:
    outputs = model(images)
    print(f"Outputs shape: {outputs.shape}")

    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs, masks)
    print(f"Loss: {loss.item()}")
except Exception:
    traceback.print_exc()

In [None]:
model = UNet(in_channels=3, out_channels=2)  # 2 classes: background, circle
# train_unet trains `model` in place and returns the per-epoch loss history
# (a dict), not a model, so name the results accordingly.
history = train_unet(model, train_loader, val_loader, epochs=20)
trained_model = model
{split: losses[-1] for split, losses in history.items()}  # final-epoch losses

In [None]:
import numpy as np
from skimage import draw, io
import matplotlib.pyplot as plt
import os
from typing import Tuple, List
import random

def generate_circle_image_and_mask(
    image_size: Tuple[int, int] = (256, 256),
    num_circles: int = 5,
    radius_range: Tuple[int, int] = (10, 40),
    intensity_range: Tuple[float, float] = (0.3, 0.9),
    background_intensity: float = 0.1,
    noise_level: float = 0.05,
    overlap_allowed: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic image with circles of varying intensity and corresponding mask.

    Randomness comes from the module-level ``random`` (circle parameters) and
    ``np.random`` (noise) generators; seed both for reproducible output.

    Args:
        image_size: Size of the image (height, width)
        num_circles: Number of circles to generate (fewer may be placed when
            overlap is disallowed and no free spot is found within
            ``num_circles * 100`` attempts)
        radius_range: Range of circle radii (min, max), inclusive. Radii are
            clamped to ``min(height, width) // 2`` so a circle always fits.
        intensity_range: Range of circle intensities (min, max)
        background_intensity: Background pixel intensity
        noise_level: Standard deviation of additive Gaussian noise (0 disables it)
        overlap_allowed: Whether circles can overlap (if False, circles keep a
            5-pixel gap between edges)

    Returns:
        image: float32 array (height, width), values clipped to [0, 1] when noise is added
        mask: uint8 array (height, width), 0=background, 1=circle
    """
    height, width = image_size
    # Larger radii would make randint(radius, height - radius) raise ValueError.
    max_fitting_radius = min(height, width) // 2

    # Initialize image with background
    image = np.full((height, width), background_intensity, dtype=np.float32)
    mask = np.zeros((height, width), dtype=np.uint8)

    # Keep track of existing circles to avoid overlap if needed
    existing_circles = []

    circles_placed = 0
    attempts = 0
    max_attempts = num_circles * 100  # Prevent infinite loop

    while circles_placed < num_circles and attempts < max_attempts:
        attempts += 1

        # Random circle parameters
        radius = min(random.randint(*radius_range), max_fitting_radius)
        center_y = random.randint(radius, height - radius)
        center_x = random.randint(radius, width - radius)
        intensity = random.uniform(*intensity_range)

        # Check for overlap if not allowed
        if not overlap_allowed:
            overlap = False
            for existing_center, existing_radius in existing_circles:
                distance = np.sqrt((center_x - existing_center[0])**2 +
                                   (center_y - existing_center[1])**2)
                if distance < (radius + existing_radius + 5):  # 5 pixel buffer
                    overlap = True
                    break

            if overlap:
                continue

        # Draw circle (pixel coordinates clipped to the image shape)
        rr, cc = draw.disk((center_y, center_x), radius, shape=image.shape)

        # Later circles overwrite earlier ones where they overlap
        image[rr, cc] = intensity
        mask[rr, cc] = 1

        # Record circle position
        existing_circles.append(((center_x, center_y), radius))
        circles_placed += 1

    # Add noise to image
    if noise_level > 0:
        noise = np.random.normal(0, noise_level, image.shape)
        # np.random.normal returns float64; cast back so the dtype is float32
        # whether or not noise is added.
        image = np.clip(image + noise, 0, 1).astype(np.float32)

    return image, mask

def generate_dataset(
    num_images: int = 100,
    output_dir: str = "synthetic_circles",
    image_size: Tuple[int, int] = (256, 256),
    num_circles_range: Tuple[int, int] = (3, 8),
    **kwargs
) -> None:
    """
    Write synthetic circle images and their binary masks to disk as PNGs.

    Files are written to ``<output_dir>/images/image_XXXX.png`` and
    ``<output_dir>/masks/mask_XXXX.png``. The image and mask file names differ
    (``image_`` vs ``mask_`` prefix), so a loader must map one to the other.
    Images are scaled from [0, 1] to uint8 [0, 255]; masks are saved as 0/255.

    Args:
        num_images: Number of image/mask pairs to write
        output_dir: Root directory; ``images`` and ``masks`` sub-directories are created
        image_size: (height, width) of each image
        num_circles_range: Inclusive (min, max) number of circles per image
        **kwargs: Forwarded to generate_circle_image_and_mask
    """
    image_dir, mask_dir = (os.path.join(output_dir, sub) for sub in ("images", "masks"))
    for directory in (image_dir, mask_dir):
        os.makedirs(directory, exist_ok=True)

    for index in range(num_images):
        circle_count = random.randint(*num_circles_range)
        image, mask = generate_circle_image_and_mask(
            image_size=image_size,
            num_circles=circle_count,
            **kwargs
        )

        image_path = os.path.join(image_dir, f"image_{index:04d}.png")
        mask_path = os.path.join(mask_dir, f"mask_{index:04d}.png")

        # uint8 on disk: image [0, 1] -> [0, 255]; mask {0, 1} -> {0, 255}
        io.imsave(image_path, (image * 255).astype(np.uint8), check_contrast=False)
        io.imsave(mask_path, (mask * 255).astype(np.uint8), check_contrast=False)

        print(f"Generated image {index+1}/{num_images}")


In [None]:

def create_multi_class_circles(
    image_size: Tuple[int, int] = (256, 256),
    circle_types: List[dict] = None,
    num_circles_per_type: int = 3,
    background_intensity: float = 0.1,
    noise_level: float = 0.05
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic image with different types of circles (multi-class segmentation).

    Args:
        image_size: Size of the image (height, width)
        circle_types: List of dicts with keys ``intensity``, ``radius_range``
            (inclusive min, max) and ``class_id``. Defaults to three types
            (small/dark = 1, medium/gray = 2, large/bright = 3). Radii are
            clamped to ``min(height, width) // 2`` so a circle always fits.
        num_circles_per_type: Number of circles to generate per type
        background_intensity: Background pixel intensity
        noise_level: Standard deviation of additive Gaussian noise (0 disables it)

    Returns:
        image: float32 array (height, width), values clipped to [0, 1] when noise is added
        mask: uint8 multi-class mask (0=background, 1,2,3...=different circle types)
    """
    if circle_types is None:
        circle_types = [
            {"intensity": 0.3, "radius_range": (5, 15), "class_id": 1},   # Small dark circles
            {"intensity": 0.6, "radius_range": (15, 25), "class_id": 2},  # Medium gray circles
            {"intensity": 0.9, "radius_range": (25, 40), "class_id": 3},  # Large bright circles
        ]

    height, width = image_size
    # Larger radii would make randint(radius, height - radius) raise ValueError.
    max_fitting_radius = min(height, width) // 2
    image = np.full((height, width), background_intensity, dtype=np.float32)
    mask = np.zeros((height, width), dtype=np.uint8)

    for circle_type in circle_types:
        for _ in range(num_circles_per_type):
            radius = min(random.randint(*circle_type["radius_range"]), max_fitting_radius)
            center_y = random.randint(radius, height - radius)
            center_x = random.randint(radius, width - radius)

            rr, cc = draw.disk((center_y, center_x), radius, shape=image.shape)

            # First come, first served: pixels already claimed by an earlier
            # circle keep their class and intensity, so later circles never
            # overwrite them (circles may still touch or be partly hidden).
            empty_pixels = mask[rr, cc] == 0

            image[rr[empty_pixels], cc[empty_pixels]] = circle_type["intensity"]
            mask[rr[empty_pixels], cc[empty_pixels]] = circle_type["class_id"]

    # Add noise
    if noise_level > 0:
        noise = np.random.normal(0, noise_level, image.shape)
        # np.random.normal returns float64; cast back so the dtype is float32
        # whether or not noise is added.
        image = np.clip(image + noise, 0, 1).astype(np.float32)

    return image, mask

def visualize_samples(num_samples: int = 4):
    """
    Generate synthetic samples and plot each as an (image, mask, overlay) row.

    The first ``num_samples // 2`` rows are binary examples from
    ``generate_circle_image_and_mask`` (even rows allow overlap, odd rows do
    not); the remaining rows are multi-class examples from
    ``create_multi_class_circles``.

    Args:
        num_samples: Number of rows (samples) to draw. Must be >= 1.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # squeeze=False always returns a 2-D array of axes, so axes[i, j] indexing
    # works even when there is a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    num_binary = num_samples // 2
    for i in range(num_samples):
        is_multi_class = i >= num_binary

        # Generate image and mask
        if is_multi_class:
            image, mask = create_multi_class_circles()
            title_suffix = "(Multi-class)"
        else:
            overlap = i % 2 == 0
            image, mask = generate_circle_image_and_mask(
                num_circles=random.randint(3, 7),
                overlap_allowed=overlap,
            )
            title_suffix = "(Binary - Overlap)" if overlap else "(Binary - No Overlap)"

        # Original image
        axes[i, 0].imshow(image, cmap='gray')
        axes[i, 0].set_title(f'Original Image {title_suffix}')
        axes[i, 0].axis('off')

        # Mask
        axes[i, 1].imshow(mask, cmap='tab20' if is_multi_class else 'gray')
        axes[i, 1].set_title('Segmentation Mask')
        axes[i, 1].axis('off')

        # Overlay: grayscale image as RGB, with labelled pixels tinted by class.
        overlay = np.stack([image, image, image], axis=-1).astype(np.float64)
        mask_max = mask.max()
        if mask_max > 0:
            mask_colored = plt.cm.tab20(mask / mask_max)[:, :, :3]
            # Blend only foreground pixels; background (label 0) would otherwise
            # be tinted with tab20's first colour, since cmap(0.0) is not black.
            foreground = mask > 0
            overlay[foreground] = 0.7 * overlay[foreground] + 0.3 * mask_colored[foreground]
        axes[i, 2].imshow(overlay)
        axes[i, 2].set_title('Overlay')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

# PyTorch Dataset wrapper around the synthetic circle generators
class SyntheticCircleDataset(Dataset):
    """
    PyTorch Dataset of pre-generated synthetic circle images and masks.

    All samples are generated once in ``__init__`` so every epoch sees the
    same data. Multi-class samples come from ``create_multi_class_circles``;
    binary samples come from ``generate_circle_image_and_mask`` with 3-8 circles.
    """

    def __init__(self, num_samples: int = 1000, transform=None, multi_class: bool = False):
        """
        Initialize synthetic dataset.

        Args:
            num_samples: Number of samples to generate (must be >= 0).
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a dict with
                ``'image'`` and ``'mask'`` keys (albumentations-style).
            multi_class: If True, generate multi-class segmentation masks.

        Raises:
            ValueError: If ``num_samples`` is negative.
        """
        if num_samples < 0:
            # Fail early: a negative length would otherwise only surface later
            # as an error from len() / DataLoader.
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        self.num_samples = num_samples
        self.transform = transform
        self.multi_class = multi_class

        # Pre-generate all samples for consistency
        self.images = []
        self.masks = []

        print(f"Generating {num_samples} synthetic samples...")
        for i in range(num_samples):
            if multi_class:
                image, mask = create_multi_class_circles()
            else:
                image, mask = generate_circle_image_and_mask(
                    num_circles=random.randint(3, 8)
                )

            self.images.append(image)
            self.masks.append(mask)

            if (i + 1) % 100 == 0:
                print(f"Generated {i + 1}/{num_samples} samples")

    def __len__(self):
        """Return the number of pre-generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Return ``(image, mask)`` for sample ``idx``.

        Returns:
            image: float32 tensor with a leading channel dimension,
                shape ``(1, *image.shape)``.
            mask: int64 tensor with the same shape as the (transformed) mask.
        """
        image = self.images[idx]
        mask = self.masks[idx]

        if self.transform:
            # Apply augmentations
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # torch.from_numpy rejects arrays with negative strides (e.g. views
        # returned by flip augmentations such as arr[:, ::-1]).
        # np.ascontiguousarray fixes that and makes no copy for arrays that
        # are already contiguous.
        image = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0).float()  # Add channel dimension
        mask = torch.from_numpy(np.ascontiguousarray(mask)).long()

        return image, mask

# Complete example with training
def train_on_synthetic_data():
    """
    Example of training U-Net on synthetic circle data.
    """
    # Create synthetic dataset
    train_dataset = SyntheticCircleDataset(num_samples=1000, multi_class=False)
    val_dataset = SyntheticCircleDataset(num_samples=200, multi_class=False)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet(in_channels=1, out_channels=2).to(device)  # 1 input channel (grayscale)