# Pupil & Iris Segmentation Model Training

Trains a lightweight U-Net segmentation model for pupil and iris detection
using freely available open-source datasets.

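**Datasets (real datasets are optional and must be downloaded manually):**
- OpenEDS (Facebook Research) - eye images with semantic segmentation labels
- LPW (Labelled Pupils in the Wild) - pupil center annotations
- CASIA-Iris-Thousand - iris segmentation masks
- Synthetic eye generation, including anisocoria simulation (always used)

Note: the loader in this notebook only reads OpenEDS and Roboflow exports placed under the data directory; LPW and CASIA-Iris-Thousand are listed as candidate sources and are not loaded by the code below.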

**Output:** TFLite model + SavedModel for Cloud Run deployment

In [None]:
# Install dependencies
!pip install -q tensorflow opencv-python-headless albumentations gdown kaggle pillow scikit-learn

In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
import albumentations as A
from pathlib import Path
import json
import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Report the runtime environment so training logs can be matched to a setup.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU available: {gpu_devices}")

## 1. Dataset Download & Preparation

We use multiple open-source datasets and unify them into a common format:
- Image: 256x256 RGB
- Mask: 256x256 with classes {0: background, 1: iris, 2: pupil}

In [None]:
# Shared image/label constants used by every later cell.
IMG_SIZE = 256
NUM_CLASSES = 3  # background, iris, pupil

# Working directories for downloaded and generated data.
DATA_DIR = Path('/content/eye_data')
IMAGES_DIR = DATA_DIR / 'images'
MASKS_DIR = DATA_DIR / 'masks'

for required_dir in (IMAGES_DIR, MASKS_DIR):
    os.makedirs(required_dir, exist_ok=True)

print(f"Data directory: {DATA_DIR}")

In [None]:
# Download OpenEDS dataset (Facebook Research - open for research use)
# OpenEDS provides eye images with semantic segmentation labels
# Classes: background, sclera, iris, pupil
# If direct download is unavailable, we generate synthetic training data

def download_openeds():
    """Create the OpenEDS folder and print manual download instructions.

    Nothing is fetched here: the function only makes sure the target
    directory exists and tells the user where to place the files.

    Returns:
        Path of the directory where OpenEDS files are expected.
    """
    target_dir = DATA_DIR / 'openeds'
    os.makedirs(target_dir, exist_ok=True)

    # OpenEDS requires sign-up, so the data has to be placed here by hand.
    print("OpenEDS dataset: Place downloaded files in", target_dir)
    for notice in (
        "Download from: https://research.facebook.com/openeds-challenge",
        "If unavailable, synthetic data will be generated below.",
    ):
        print(notice)
    return target_dir


openeds_dir = download_openeds()

In [None]:
# Download LPW (Labelled Pupils in the Wild) dataset
# Provides pupil center coordinates which we convert to circular masks

def download_lpw():
    """Create the LPW folder and print manual download instructions.

    No download is performed; the user is pointed at the dataset page and
    told where to put the files.

    Returns:
        Path of the directory where LPW files are expected.
    """
    target_dir = DATA_DIR / 'lpw'
    os.makedirs(target_dir, exist_ok=True)

    # Dataset home page (long form):
    # https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/labelled-pupils-in-the-wild-lpw
    download_hint = "Download from: https://perceptualui.org/research/datasets/LPW/"
    print("LPW dataset: Place downloaded files in", target_dir)
    print(download_hint)
    return target_dir


lpw_dir = download_lpw()

In [None]:
# Try to download from Roboflow Universe - community eye segmentation datasets
# These are freely available with API key

def download_roboflow_datasets():
    """Print instructions for obtaining eye datasets from Roboflow Universe.

    The ``roboflow`` package is only probed for availability. When it can be
    imported, the manual export steps are printed; otherwise a notice says
    that synthetic data will be used. No data is downloaded either way.
    """
    try:
        from roboflow import Roboflow  # availability probe only
    except ImportError:
        print("Roboflow not installed. Using synthetic data generation.")
    else:
        # Users can get a free API key at https://roboflow.com
        steps = (
            "To use Roboflow datasets:",
            "1. Sign up at https://roboflow.com (free tier)",
            "2. Search 'pupil segmentation' or 'iris segmentation'",
            "3. Export in 'Semantic Segmentation' format",
            "4. Place in /content/eye_data/roboflow/",
        )
        for step in steps:
            print(step)


download_roboflow_datasets()

## 2. Synthetic Data Generation

To ensure we have sufficient training data regardless of dataset availability,
we generate realistic synthetic eye images with precise ground-truth masks.

This approach:
- Creates eyes with varying pupil/iris ratios (simulating anisocoria)
- Adds realistic lighting, reflections, and noise
- Generates perfect pixel-level segmentation masks
- Covers edge cases (dilated, constricted, off-center pupils)

In [None]:
def generate_synthetic_eye(img_size=256, pupil_ratio=None, anisocoria=False):
    """
    Generate a synthetic eye image with precise segmentation mask.

    Args:
        img_size: side length in pixels of the square image and mask.
        pupil_ratio: pupil radius as a fraction of the iris radius. When None,
            a ratio is sampled (see ``anisocoria``).
        anisocoria: only used when ``pupil_ratio`` is None. If True the ratio
            is drawn from an extreme range (0.15-0.25 or 0.60-0.80),
            otherwise from 0.25-0.65.

    Returns:
        image: RGB uint8 array (img_size, img_size, 3)
        mask: uint8 array (img_size, img_size) with values {0: bg, 1: iris, 2: pupil}
    """
    img = np.zeros((img_size, img_size, 3), dtype=np.uint8)
    mask = np.zeros((img_size, img_size), dtype=np.uint8)

    # Eye centre jittered by up to 15 px around the image centre.
    center_x = img_size // 2 + np.random.randint(-15, 16)
    center_y = img_size // 2 + np.random.randint(-15, 16)
    center = (center_x, center_y)

    # Iris parameters
    iris_radius = np.random.randint(img_size // 4, img_size // 3)

    # Iris color variations (brown, blue, green, hazel, gray)
    # NOTE(review): the labels below match these tuples only when read as
    # BGR, while the image is displayed/returned as RGB (the skin colour
    # further down is R-dominant). The palette still covers warm and cool
    # irises, so the values are left untouched - confirm the intent.
    iris_colors = [
        (np.random.randint(40, 90), np.random.randint(60, 120), np.random.randint(100, 180)),   # brown
        (np.random.randint(140, 200), np.random.randint(100, 160), np.random.randint(40, 80)),   # blue
        (np.random.randint(60, 120), np.random.randint(120, 180), np.random.randint(60, 100)),   # green
        (np.random.randint(80, 130), np.random.randint(100, 150), np.random.randint(120, 180)),  # hazel
        (np.random.randint(120, 160), np.random.randint(120, 160), np.random.randint(120, 160)), # gray
    ]
    iris_color = iris_colors[np.random.randint(0, len(iris_colors))]

    # Pupil parameters
    if pupil_ratio is None:
        if anisocoria:
            # Simulate pathological pupils
            pupil_ratio = np.random.choice([
                np.random.uniform(0.15, 0.25),  # very constricted
                np.random.uniform(0.60, 0.80),  # very dilated
            ])
        else:
            pupil_ratio = np.random.uniform(0.25, 0.65)  # normal range

    pupil_radius = int(iris_radius * pupil_ratio)

    # Slight pupil offset from iris center (realistic)
    pupil_offset_x = np.random.randint(-3, 4)
    pupil_offset_y = np.random.randint(-3, 4)
    pupil_center = (center_x + pupil_offset_x, center_y + pupil_offset_y)

    # Draw sclera (white background around iris)
    sclera_color = (
        np.random.randint(220, 250),
        np.random.randint(220, 250),
        np.random.randint(225, 255)
    )
    # Elliptical sclera
    sclera_w = int(iris_radius * np.random.uniform(1.6, 2.2))
    sclera_h = int(iris_radius * np.random.uniform(1.1, 1.4))
    cv2.ellipse(img, center, (sclera_w, sclera_h), 0, 0, 360, sclera_color, -1)

    # Draw iris as concentric filled circles, outermost first, so each
    # smaller circle overwrites the centre of the previous one.
    for r in range(iris_radius, 0, -1):
        t = r / iris_radius  # 1.0 at edge, approaching 0.0 at center
        # Brightness factor runs from 1.0 at the rim down to ~0.5 at the centre.
        darken = 0.5 + 0.5 * t
        c = tuple(int(v * darken) for v in iris_color)
        cv2.circle(img, center, r, c, -1)
    cv2.circle(mask, center, iris_radius, 1, -1)  # iris mask

    # Add iris texture (radial patterns)
    num_fibers = np.random.randint(20, 50)
    for _ in range(num_fibers):
        angle = np.random.uniform(0, 2 * np.pi)
        length = np.random.uniform(0.3, 0.95) * iris_radius
        x1 = int(center_x + np.cos(angle) * pupil_radius * 1.1)
        y1 = int(center_y + np.sin(angle) * pupil_radius * 1.1)
        x2 = int(center_x + np.cos(angle) * length)
        y2 = int(center_y + np.sin(angle) * length)
        fiber_color = tuple(int(v * np.random.uniform(0.7, 1.3)) for v in iris_color)
        fiber_color = tuple(max(0, min(255, v)) for v in fiber_color)
        cv2.line(img, (x1, y1), (x2, y2), fiber_color, 1)

    # Draw pupil (black)
    pupil_darkness = np.random.randint(5, 30)
    cv2.circle(img, pupil_center, pupil_radius, (pupil_darkness, pupil_darkness, pupil_darkness), -1)
    cv2.circle(mask, pupil_center, pupil_radius, 2, -1)  # pupil mask (overwrites iris)

    # Add specular reflection (corneal light reflex); image only, mask untouched.
    num_reflections = np.random.randint(1, 4)
    for _ in range(num_reflections):
        ref_x = center_x + np.random.randint(-iris_radius//2, iris_radius//2)
        ref_y = center_y + np.random.randint(-iris_radius//2, iris_radius//2)
        ref_r = np.random.randint(2, 8)
        cv2.circle(img, (ref_x, ref_y), ref_r, (240, 240, 255), -1)

    # Add eyelids (partial occlusion) in roughly 70% of samples
    if np.random.random() > 0.3:
        lid_y_top = center_y - int(iris_radius * np.random.uniform(0.7, 1.2))
        lid_y_bot = center_y + int(iris_radius * np.random.uniform(0.7, 1.2))
        # Skin color
        skin_color = (
            np.random.randint(140, 220),
            np.random.randint(120, 200),
            np.random.randint(100, 180)
        )
        # cv2.fillPoly only accepts int32 point arrays; NumPy's default
        # integer dtype (int64 on Linux/macOS) makes it raise, so the dtype
        # is set explicitly.
        # Top eyelid
        pts_top = np.array([
            [0, 0], [img_size, 0], [img_size, lid_y_top],
            [center_x, lid_y_top - np.random.randint(5, 25)],
            [0, lid_y_top]
        ], dtype=np.int32)
        cv2.fillPoly(img, [pts_top], skin_color)
        cv2.fillPoly(mask, [pts_top], 0)  # background where lid is

        # Bottom eyelid
        pts_bot = np.array([
            [0, lid_y_bot], [center_x, lid_y_bot + np.random.randint(5, 25)],
            [img_size, lid_y_bot], [img_size, img_size], [0, img_size]
        ], dtype=np.int32)
        cv2.fillPoly(img, [pts_bot], skin_color)
        cv2.fillPoly(mask, [pts_bot], 0)

    # Add noise (int16 arithmetic so negative noise cannot wrap around)
    noise = np.random.normal(0, np.random.uniform(3, 12), img.shape).astype(np.int16)
    img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

    # Random brightness/contrast
    alpha = np.random.uniform(0.7, 1.3)
    beta = np.random.randint(-30, 31)
    img = np.clip(alpha * img.astype(np.float32) + beta, 0, 255).astype(np.uint8)

    return img, mask


# Test the generator
# Preview the generator: two normal and two anisocoria samples, each shown
# as the image (top row) above its segmentation mask (bottom row).
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for col, (image_ax, mask_ax) in enumerate(zip(axes[0], axes[1])):
    aniso = col >= 2
    img, msk = generate_synthetic_eye(IMG_SIZE, anisocoria=aniso)
    kind_label = 'Anisocoria' if aniso else 'Normal'

    image_ax.imshow(img)
    image_ax.set_title(f"{kind_label} eye")
    image_ax.axis('off')

    mask_ax.imshow(msk, cmap='viridis', vmin=0, vmax=2)
    mask_ax.set_title('Mask (0=bg, 1=iris, 2=pupil)')
    mask_ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Generate training dataset
NUM_SYNTHETIC = 5000  # Generate 5000 synthetic samples
NUM_ANISOCORIA = 2000  # Extra anisocoria cases for the clinical use case

print(f"Generating {NUM_SYNTHETIC + NUM_ANISOCORIA} synthetic eye images...")

images = []
masks = []

# Each entry: progress label, sample count, progress interval, generator kwargs.
# Normal pupils come first, followed by the extreme-pupil (anisocoria) cases.
generation_plan = (
    ("Normal", NUM_SYNTHETIC, 1000, {}),
    ("Anisocoria", NUM_ANISOCORIA, 500, {"anisocoria": True}),
)
for plan_label, plan_count, report_every, generator_kwargs in generation_plan:
    for done in range(1, plan_count + 1):
        img, msk = generate_synthetic_eye(IMG_SIZE, **generator_kwargs)
        images.append(img)
        masks.append(msk)
        if done % report_every == 0:
            print(f"  {plan_label}: {done}/{plan_count}")

# Stack the per-sample arrays into single uint8 tensors.
images = np.array(images, dtype=np.uint8)
masks = np.array(masks, dtype=np.uint8)

print(f"\nDataset shape: images={images.shape}, masks={masks.shape}")
print(f"Mask classes: {np.unique(masks)}")
print(f"Class distribution: bg={np.mean(masks==0):.1%}, iris={np.mean(masks==1):.1%}, pupil={np.mean(masks==2):.1%}")

In [None]:
# Load any real datasets that are available and merge with synthetic

def _read_image_mask_pair(img_path, mask_path):
    """Read one image/mask pair resized to IMG_SIZE, or return None if unusable.

    A pair is unusable when the mask path is the image itself, the mask file
    does not exist, or either file cannot be decoded by OpenCV.
    """
    if mask_path == img_path or not os.path.exists(mask_path):
        return None
    img = cv2.imread(img_path)
    msk = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if img is None or msk is None:
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    # Nearest-neighbour keeps class ids intact (no blended label values).
    msk = cv2.resize(msk, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
    return img, msk


def load_real_datasets():
    """Load any downloaded real eye segmentation datasets.

    Looks for OpenEDS PNGs and Roboflow JPG exports under ``DATA_DIR`` and
    loads at most 2000 candidate images from each source.

    Returns:
        (images, masks) as NumPy arrays when at least one pair was loaded,
        otherwise (None, None).
    """
    real_images = []
    real_masks = []

    # Check for OpenEDS data
    openeds_pngs = sorted(glob.glob(str(DATA_DIR / 'openeds' / '**' / '*.png'), recursive=True))
    # The recursive glob also matches the label PNGs. A path whose derived
    # mask path is itself is not an image, so drop it up front; otherwise a
    # label file would be loaded as a training image paired with itself.
    openeds_pairs = []
    for img_path in openeds_pngs:
        # OpenEDS has paired image/label files
        mask_path = img_path.replace('/images/', '/labels/').replace('/image/', '/label/')
        if mask_path != img_path:
            openeds_pairs.append((img_path, mask_path))
    if openeds_pairs:
        print(f"Found {len(openeds_pairs)} OpenEDS images")
        for img_path, mask_path in openeds_pairs[:2000]:  # Limit to 2000
            pair = _read_image_mask_pair(img_path, mask_path)
            if pair is None:
                continue
            img, msk = pair
            # Remap OpenEDS classes to our format:
            # OpenEDS: 0=bg, 1=sclera, 2=iris, 3=pupil
            # Ours: 0=bg, 1=iris, 2=pupil
            new_msk = np.zeros_like(msk)
            new_msk[msk == 2] = 1  # iris
            new_msk[msk == 3] = 2  # pupil
            real_images.append(img)
            real_masks.append(new_msk)

    # Check for Roboflow exports
    roboflow_imgs = sorted(glob.glob(str(DATA_DIR / 'roboflow' / '**' / '*.jpg'), recursive=True))
    if roboflow_imgs:
        print(f"Found {len(roboflow_imgs)} Roboflow images")
        skipped_bad_labels = 0
        for img_path in roboflow_imgs[:2000]:
            mask_path = img_path.replace('/images/', '/masks/').replace('.jpg', '.png')
            pair = _read_image_mask_pair(img_path, mask_path)
            if pair is None:
                continue
            img, msk = pair
            # Masks are used as-is, so ids outside [0, NUM_CLASSES) would
            # break the one-hot encoding later; skip those exports.
            if msk.size and int(msk.max()) >= NUM_CLASSES:
                skipped_bad_labels += 1
                continue
            real_images.append(img)
            real_masks.append(msk)
        if skipped_bad_labels:
            print(f"Skipped {skipped_bad_labels} Roboflow masks with class ids >= {NUM_CLASSES}")

    if real_images:
        print(f"Loaded {len(real_images)} real images total")
        return np.array(real_images), np.array(real_masks)
    else:
        print("No real datasets found. Using synthetic data only.")
        return None, None

# Append whatever real data was found to the synthetic arrays.
real_imgs, real_msks = load_real_datasets()

have_real_data = real_imgs is not None
if have_real_data:
    images = np.concatenate((images, real_imgs), axis=0)
    masks = np.concatenate((masks, real_msks), axis=0)
    print(f"Combined dataset: {images.shape[0]} samples")

In [None]:
# Train/validation split
# Hold out 15% of the samples for validation (fixed seed for repeatability).
X_train, X_val, y_train, y_val = train_test_split(
    images,
    masks,
    test_size=0.15,
    random_state=42,
)

for split_name, split_images in (("Training", X_train), ("Validation", X_val)):
    print(f"{split_name} set: {split_images.shape[0]} samples")

# Normalize images to [0, 1]
X_train, X_val = (
    split_images.astype(np.float32) / 255.0 for split_images in (X_train, X_val)
)

# One-hot encode masks
y_train_oh, y_val_oh = (
    tf.keras.utils.to_categorical(split_masks, NUM_CLASSES)
    for split_masks in (y_train, y_val)
)

print(f"X_train: {X_train.shape}, y_train: {y_train_oh.shape}")

## 3. Data Augmentation Pipeline

In [None]:
# Augmentation pipeline using albumentations
# Geometric jitter; constant border so newly exposed pixels are filled
# with a constant value rather than reflected content.
geometric_jitter = A.ShiftScaleRotate(
    shift_limit=0.1, scale_limit=0.15, rotate_limit=15, p=0.7,
    border_mode=cv2.BORDER_CONSTANT,
)

# One blur variant, applied to 30% of samples.
blur_choice = A.OneOf([
    A.GaussianBlur(blur_limit=(3, 7), p=1),
    A.MotionBlur(blur_limit=(3, 7), p=1),
], p=0.3)

# One tone adjustment, applied to 50% of samples.
tone_choice = A.OneOf([
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
    A.CLAHE(clip_limit=4.0, p=1),
], p=0.5)

# Full pipeline, applied in this order.
augment = A.Compose([
    A.HorizontalFlip(p=0.5),
    geometric_jitter,
    blur_choice,
    tone_choice,
    A.GaussNoise(var_limit=(5, 30), p=0.3),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.4),
])


def create_tf_dataset(images, masks, batch_size=16, augment_fn=None, shuffle=True):
    """Create a tf.data.Dataset with optional augmentation.

    Args:
        images: float images scaled to [0, 1], one sample per leading index.
        masks: one-hot masks with NUM_CLASSES channels in the last axis.
        batch_size: number of samples per batch.
        augment_fn: optional callable invoked as
            ``augment_fn(image=uint8_image, mask=uint8_class_ids)`` and
            returning a dict with 'image' and 'mask' entries. When None, no
            augmentation is applied.
        shuffle: whether to shuffle samples (buffer capped at 2000).

    Returns:
        A batched, prefetched tf.data.Dataset of (image, one-hot mask) pairs.
    """
    def augment_sample(image, mask):
        """Apply augmentation to a single sample."""
        def _augment(img, msk):
            # Round before casting: float32 v/255*255 can land just below v
            # and a plain astype() would truncate it to v-1.
            img_np = np.clip(np.rint(img.numpy() * 255.0), 0, 255).astype(np.uint8)
            # The augmenter works on class-id masks, so undo the one-hot encoding.
            msk_np = np.argmax(msk.numpy(), axis=-1).astype(np.uint8)
            result = augment_fn(image=img_np, mask=msk_np)
            aug_img = result['image'].astype(np.float32) / 255.0
            aug_msk = tf.keras.utils.to_categorical(result['mask'], NUM_CLASSES)
            return aug_img.astype(np.float32), aug_msk.astype(np.float32)

        img_aug, msk_aug = tf.py_function(
            _augment, [image, mask], [tf.float32, tf.float32]
        )
        # py_function drops static shapes; restore them for downstream layers.
        img_aug.set_shape(image.shape)
        msk_aug.set_shape(mask.shape)
        return img_aug, msk_aug

    dataset = tf.data.Dataset.from_tensor_slices((images, masks))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=min(len(images), 2000))
    if augment_fn is not None:
        dataset = dataset.map(augment_sample, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

BATCH_SIZE = 16

# Training data is shuffled and augmented; validation data is left untouched
# and in order.
train_ds = create_tf_dataset(X_train, y_train_oh, BATCH_SIZE, augment_fn=augment)
val_ds = create_tf_dataset(X_val, y_val_oh, BATCH_SIZE, augment_fn=None, shuffle=False)

for split_name, split_ds in (("Training", train_ds), ("Validation", val_ds)):
    print(f"{split_name} batches: {tf.data.experimental.cardinality(split_ds).numpy()}")

## 4. Model Architecture

Lightweight U-Net with MobileNetV3-Small encoder.
- Optimized for mobile inference (~2MB model)
- Fast enough for Cloud Run (<100ms per image)
- 3-class output: background, iris, pupil

In [None]:
def conv_block(x, filters, kernel_size=3):
    """Apply a bias-free same-padded convolution, then BatchNorm, then ReLU."""
    stages = (
        layers.Conv2D(filters, kernel_size, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
    )
    for stage in stages:
        x = stage(x)
    return x


def _layer_output_shape(layer):
    """Return a layer's single output shape as a tuple, or None if unavailable.

    Uses ``layer.output.shape`` rather than ``layer.output_shape`` because the
    latter attribute is not available on Keras 3 layers.
    """
    try:
        return tuple(layer.output.shape)
    except (AttributeError, ValueError, RuntimeError, TypeError):
        # Multi-output layers, layers that were never called, unknown rank.
        return None


def build_mobilenet_unet(input_shape=(256, 256, 3), num_classes=3):
    """
    Build a U-Net with MobileNetV3-Small encoder.

    Architecture:
    - Encoder: MobileNetV3-Small (pretrained on ImageNet)
    - Decoder: Lightweight upsampling with skip connections
    - Output: num_classes channel softmax

    Args:
        input_shape: (height, width, channels) of the model input.
        num_classes: number of output channels of the softmax head.

    Returns:
        (model, encoder): the full segmentation model, whose input is an image
        scaled to [0, 1], and the encoder sub-model (created frozen) so the
        caller can unfreeze it later.

    Raises:
        ValueError: if fewer than two encoder feature maps can be located.
    """
    # Encoder (MobileNetV3-Small)
    base_model = tf.keras.applications.MobileNetV3Small(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
        minimalistic=True,
    )

    # Extract skip connection layers at different resolutions
    # MobileNetV3-Small feature maps at various scales
    layer_names = [
        'multiply',          # 128x128, early features
        'multiply_1',        # 64x64
        'multiply_3',        # 32x32
        'multiply_7',        # 16x16
        'multiply_11',       # 8x8, deep features
    ]

    # Try to find actual layer names (they may vary by TF version)
    available_layers = {l.name for l in base_model.layers}
    skip_layers = [name for name in layer_names if name in available_layers]

    if len(skip_layers) < 3:
        # Fallback: pick layers by output shape. For each target resolution
        # take the last layer in the network that produces it.
        print("Using fallback layer selection...")
        target_sizes = [input_shape[0] // divisor for divisor in (2, 4, 8, 16, 32)]
        skip_layers = []
        for target in target_sizes:
            for layer in reversed(base_model.layers):
                shape = _layer_output_shape(layer)
                if shape is None or len(shape) != 4 or shape[1] != target:
                    continue
                if layer.name not in skip_layers:
                    skip_layers.append(layer.name)
                    break

    if len(skip_layers) < 2:
        raise ValueError(
            "Could not locate enough encoder feature maps for skip connections "
            f"(found {skip_layers})"
        )

    print(f"Skip connection layers: {skip_layers}")

    # Get encoder outputs at each scale
    encoder_outputs = [base_model.get_layer(name).output for name in skip_layers]

    # Create encoder model
    encoder = Model(inputs=base_model.input, outputs=encoder_outputs)

    # Freeze encoder initially (transfer learning)
    encoder.trainable = False

    # Build decoder
    inputs = layers.Input(shape=input_shape)
    # The Keras MobileNetV3 graph carries its own preprocessing and expects
    # pixel values in [0, 255]. This notebook feeds images scaled to [0, 1],
    # so scale them back up before the encoder; otherwise the pretrained
    # weights see an almost constant input.
    encoder_input = layers.Rescaling(255.0, name='to_pixel_range')(inputs)
    skips = list(encoder(encoder_input))

    # Start from deepest features
    x = skips[-1]

    # Decoder path with skip connections. Filters are listed deepest stage
    # first, so the widest convolutions run at the lowest resolution and the
    # full-resolution stages stay cheap.
    decoder_filters = [128, 64, 48, 32, 24]

    for step, skip in enumerate(reversed(skips[:-1])):
        # Upsample
        x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)

        # Resize if needed to match skip connection
        if x.shape[1] != skip.shape[1] or x.shape[2] != skip.shape[2]:
            x = layers.Resizing(skip.shape[1], skip.shape[2])(x)

        # Concatenate skip connection
        x = layers.Concatenate()([x, skip])

        # Conv blocks (reuse the last width if there are more stages than entries)
        f = decoder_filters[min(step, len(decoder_filters) - 1)]
        x = conv_block(x, f)
        x = conv_block(x, f)

    # Final upsampling to input resolution
    x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    if x.shape[1] != input_shape[0] or x.shape[2] != input_shape[1]:
        x = layers.Resizing(input_shape[0], input_shape[1])(x)
    x = conv_block(x, 16)

    # Output layer
    outputs = layers.Conv2D(num_classes, 1, activation='softmax', name='segmentation')(x)

    model = Model(inputs=inputs, outputs=outputs, name='pupil_iris_segnet')
    return model, encoder


model, encoder = build_mobilenet_unet()
model.summary()

## 5. Loss Function & Metrics

In [None]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Mean Dice score over the non-background classes.

    Both tensors are flattened to (pixels, NUM_CLASSES); Dice is computed per
    class over all pixels and the background column (index 0) is left out
    of the average. ``smooth`` keeps the ratio defined for empty classes.
    """
    flat_true = tf.cast(tf.reshape(y_true, [-1, NUM_CLASSES]), tf.float32)
    flat_pred = tf.cast(tf.reshape(y_pred, [-1, NUM_CLASSES]), tf.float32)

    overlap = tf.reduce_sum(flat_true * flat_pred, axis=0)
    total_mass = tf.reduce_sum(flat_true, axis=0) + tf.reduce_sum(flat_pred, axis=0)

    per_class_dice = (2.0 * overlap + smooth) / (total_mass + smooth)
    foreground_dice = per_class_dice[1:]  # Exclude background
    return tf.reduce_mean(foreground_dice)


def dice_loss(y_true, y_pred):
    """Return one minus the foreground Dice coefficient."""
    score = dice_coefficient(y_true, y_pred)
    return 1.0 - score


def combined_loss(y_true, y_pred):
    """Sum of mean categorical cross-entropy and Dice loss."""
    mean_cross_entropy = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    )
    return mean_cross_entropy + dice_loss(y_true, y_pred)


def _class_iou(y_true, y_pred, class_index):
    """IoU between the one-hot target channel and the argmax prediction for one class."""
    target = y_true[..., class_index]
    predicted = tf.cast(tf.argmax(y_pred, axis=-1) == class_index, tf.float32)
    overlap = tf.reduce_sum(target * predicted)
    combined = tf.reduce_sum(target) + tf.reduce_sum(predicted) - overlap
    # Small epsilon keeps the ratio defined when the class is absent.
    return (overlap + 1e-6) / (combined + 1e-6)


def pupil_iou(y_true, y_pred):
    """IoU specifically for the pupil class."""
    return _class_iou(y_true, y_pred, 2)


def iris_iou(y_true, y_pred):
    """IoU specifically for the iris class."""
    return _class_iou(y_true, y_pred, 1)


print("Loss and metrics defined.")

## 6. Training

Two-phase training (early stopping may end either phase sooner):
1. **Phase 1**: Frozen encoder, train decoder only (up to 10 epochs)
2. **Phase 2**: Unfreeze encoder, fine-tune entire model with lower LR (up to 20 epochs)

In [None]:
# Phase 1: Train decoder only (encoder frozen)
banner_p1 = "=" * 60
print(banner_p1)
print("Phase 1: Training decoder with frozen encoder")
print(banner_p1)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=combined_loss,
    metrics=[dice_coefficient, pupil_iou, iris_iou],
)

# Both callbacks watch validation Dice, where higher is better.
stop_early_p1 = tf.keras.callbacks.EarlyStopping(
    monitor='val_dice_coefficient',
    patience=5,
    mode='max',
    restore_best_weights=True,
)
halve_lr_p1 = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_dice_coefficient',
    factor=0.5,
    patience=3,
    mode='max',
    min_lr=1e-5,
)
callbacks_p1 = [stop_early_p1, halve_lr_p1]

history_p1 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks_p1,
)

In [None]:
# Phase 2: Fine-tune entire model
banner_p2 = "=" * 60
print("\n" + banner_p2)
print("Phase 2: Fine-tuning entire model")
print(banner_p2)

# Unfreeze encoder
encoder.trainable = True

# Recompile with lower learning rate (needed for the trainable change to apply)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=[dice_coefficient, pupil_iou, iris_iou],
)

# All callbacks watch validation Dice, where higher is better.
stop_early_p2 = tf.keras.callbacks.EarlyStopping(
    monitor='val_dice_coefficient',
    patience=7,
    mode='max',
    restore_best_weights=True,
)
halve_lr_p2 = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_dice_coefficient',
    factor=0.5,
    patience=3,
    mode='max',
    min_lr=1e-6,
)
keep_best_p2 = tf.keras.callbacks.ModelCheckpoint(
    '/content/best_model.keras',
    monitor='val_dice_coefficient',
    mode='max',
    save_best_only=True,
)
callbacks_p2 = [stop_early_p2, halve_lr_p2, keep_best_p2]

history_p2 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=callbacks_p2,
)

In [None]:
# Plot training history
def plot_history(h1, h2):
    """Plot loss, Dice and pupil IoU across both training phases.

    The two Keras History objects are concatenated per metric and a dashed
    vertical line marks the last epoch of the first phase (encoder unfreeze).
    """
    panels = (
        ('loss', 'Loss'),
        ('dice_coefficient', 'Dice Coefficient'),
        ('pupil_iou', 'Pupil IoU'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for ax, (metric, title) in zip(axes, panels):
        val_key = f'val_{metric}'
        phase1_train = h1.history.get(metric, [])

        # Combine histories
        train_vals = phase1_train + h2.history.get(metric, [])
        val_vals = h1.history.get(val_key, []) + h2.history.get(val_key, [])
        epochs = range(1, len(train_vals) + 1)

        ax.plot(epochs, train_vals, 'b-', label='Train')
        ax.plot(epochs, val_vals, 'r-', label='Validation')
        ax.axvline(x=len(phase1_train), color='g', linestyle='--', label='Unfreeze')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


plot_history(history_p1, history_p2)

## 7. Evaluation & Visualization

In [None]:
# Evaluate on validation set
# evaluate() returns the loss followed by the compiled metrics, in order.
results = model.evaluate(val_ds)
print(f"\nValidation Results:")
for position, metric_label in enumerate(("Loss", "Dice", "Pupil IoU", "Iris IoU")):
    print(f"  {metric_label}: {results[position]:.4f}")

In [None]:
# Visualize predictions
def visualize_predictions(model, images, masks, n=6):
    """Show input images, ground truth masks, and predictions side by side.

    Args:
        model: trained segmentation model exposing ``predict``.
        images: float images in [0, 1], one sample per leading index.
        masks: one-hot ground-truth masks aligned with ``images``.
        n: number of random samples to show; capped at ``len(images)``.
    """
    # Sampling without replacement needs n <= len(images).
    n = min(n, len(images))
    if n <= 0:
        print("No samples available to visualize.")
        return

    indices = np.random.choice(len(images), n, replace=False)
    # squeeze=False keeps `axes` two-dimensional even when n == 1.
    fig, axes = plt.subplots(n, 4, figsize=(16, 4 * n), squeeze=False)

    # One batched forward pass instead of one predict() call per image.
    batch_pred = model.predict(images[indices], verbose=0)

    for i, idx in enumerate(indices):
        img = images[idx]
        gt = np.argmax(masks[idx], axis=-1)
        pred_class = np.argmax(batch_pred[i], axis=-1)

        # Input image
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Input')
        axes[i, 0].axis('off')

        # Ground truth
        axes[i, 1].imshow(gt, cmap='viridis', vmin=0, vmax=2)
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        # Prediction
        axes[i, 2].imshow(pred_class, cmap='viridis', vmin=0, vmax=2)
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

        # Overlay: blend a green tint over iris pixels and a red tint over pupil pixels
        overlay = img.copy()
        overlay[pred_class == 1] = overlay[pred_class == 1] * 0.5 + np.array([0, 0.5, 0]) * 0.5
        overlay[pred_class == 2] = overlay[pred_class == 2] * 0.5 + np.array([0.5, 0, 0]) * 0.5
        axes[i, 3].imshow(np.clip(overlay, 0, 1))
        axes[i, 3].set_title('Overlay (green=iris, red=pupil)')
        axes[i, 3].axis('off')

    plt.tight_layout()
    plt.show()


visualize_predictions(model, X_val, y_val_oh)

In [None]:
# Circle fitting from segmentation mask - this is what the cloud service will use

def fit_circle_from_mask(mask, class_id):
    """
    Fit a circle to a segmentation mask region.

    The largest external contour of the class region is used: its centroid
    becomes the circle centre and the radius is that of a circle with the
    same area as the contour. Because only the external contour is taken,
    holes inside the region count towards the area.

    Args:
        mask: 2D array of predicted class IDs
        class_id: which class to fit (1=iris, 2=pupil)

    Returns:
        dict with {x, y, radius} in pixel coordinates, or None if no usable
        region is found (class absent, contour too small, or zero area)
    """
    binary = (mask == class_id).astype(np.uint8)

    # Find contours. Indexing with [-2] selects the contour list on both
    # OpenCV 3 (three return values) and OpenCV 4 (two return values).
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if not contours:
        return None

    # Get largest contour
    largest = max(contours, key=cv2.contourArea)

    if len(largest) < 5:
        return None

    area = cv2.contourArea(largest)
    M = cv2.moments(largest)
    # A degenerate (zero-area) contour has no centroid and would yield a
    # zero radius, which breaks ratio computations downstream.
    if area <= 0 or M['m00'] <= 0:
        return None

    # Centroid as the center estimate
    cx = M['m10'] / M['m00']
    cy = M['m01'] / M['m00']

    # Compute effective radius from area
    eff_radius = np.sqrt(area / np.pi)

    return {
        'x': float(cx),
        'y': float(cy),
        'radius': float(eff_radius),
    }


def predict_circles(model, image):
    """
    Full inference pipeline: image -> segmentation -> circle fitting.

    Args:
        model: trained segmentation model
        image: RGB image (any size)

    Returns:
        dict with pupil and iris circle parameters, scaled to original image:
        'pupil' / 'iris' are {x, y, radius} dicts or None, 'confidence' holds
        the mean softmax probability over the pixels assigned to each class
        (0.0 when the class is absent), and 'ratio' is pupil radius divided
        by iris radius, or None when it cannot be computed.
    """
    h, w = image.shape[:2]

    # Preprocess
    resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    normalized = resized.astype(np.float32) / 255.0

    # Predict
    pred = model.predict(normalized[np.newaxis, ...], verbose=0)[0]
    pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)

    # Fit circles
    pupil = fit_circle_from_mask(pred_mask, 2)
    iris = fit_circle_from_mask(pred_mask, 1)

    # Scale back to original image coordinates. The radius uses the mean of
    # both axis scales because the resize may not preserve aspect ratio.
    scale_x = w / IMG_SIZE
    scale_y = h / IMG_SIZE
    radius_scale = (scale_x + scale_y) / 2

    for circle in (pupil, iris):
        if circle:
            circle['x'] *= scale_x
            circle['y'] *= scale_y
            circle['radius'] *= radius_scale

    # Compute confidence from mask probabilities
    pupil_pixels = pred_mask == 2
    iris_pixels = pred_mask == 1
    pupil_conf = float(np.mean(pred[..., 2][pupil_pixels])) if np.any(pupil_pixels) else 0.0
    iris_conf = float(np.mean(pred[..., 1][iris_pixels])) if np.any(iris_pixels) else 0.0

    # Guard the division: a zero iris radius would raise ZeroDivisionError.
    ratio = None
    if pupil and iris and iris['radius'] > 0:
        ratio = round(pupil['radius'] / iris['radius'], 4)

    return {
        'pupil': pupil,
        'iris': iris,
        'confidence': {
            'pupil': round(pupil_conf, 3),
            'iris': round(iris_conf, 3),
        },
        'ratio': ratio,
    }


# Test the full pipeline on validation images
for sample_no in range(1, 4):
    # Pick a random validation image and convert it back to uint8 pixels,
    # which is what predict_circles normalizes internally.
    idx = np.random.randint(len(X_val))
    sample_image = (X_val[idx] * 255).astype(np.uint8)
    result = predict_circles(model, sample_image)
    print(f"\nSample {sample_no}: {json.dumps(result, indent=2)}")

## 8. Export Models

Export in two formats:
1. **SavedModel** - for TensorFlow Serving on Cloud Run
2. **TFLite** - optional client-side fallback

In [None]:
EXPORT_DIR = Path('/content/exported_models')
os.makedirs(EXPORT_DIR, exist_ok=True)

# 1. SavedModel for Cloud Run
saved_model_path = EXPORT_DIR / 'pupil_segnet'
# Keras 3 rejects model.save() on a path without a .keras/.h5 extension and
# provides model.export() for SavedModel output. Use export() when present
# and fall back to save() on older Keras versions that lack it.
if hasattr(model, 'export'):
    model.export(str(saved_model_path))
else:
    model.save(str(saved_model_path))
print(f"SavedModel exported to: {saved_model_path}")

# Check size
import subprocess
size = subprocess.check_output(['du', '-sh', str(saved_model_path)]).decode().split()[0]
print(f"SavedModel size: {size}")

# 2. TFLite for potential client-side use (default optimizations, float16 weights)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()

tflite_path = EXPORT_DIR / 'pupil_segnet.tflite'
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

print(f"TFLite model size: {len(tflite_model) / 1024 / 1024:.1f} MB")
print(f"TFLite exported to: {tflite_path}")

In [None]:
# Verify exported SavedModel works
# Reload the export and grab its default serving signature.
loaded_model = tf.saved_model.load(str(saved_model_path))
infer = loaded_model.signatures['serving_default']

# Test inference on a single validation image (batch of one)
test_img = X_val[0:1].astype(np.float32)
test_input = tf.constant(test_img)
output = infer(test_input)

# The signature returns a dict; take its first key as the output tensor name
output_key = list(output)[0]
pred = output[output_key].numpy()

savedmodel_report = (
    ("input shape", test_input.shape),
    ("output shape", pred.shape),
    ("output key", output_key),
)
for caption, value in savedmodel_report:
    print(f"SavedModel {caption}: {value}")
print("SavedModel verification: OK")

In [None]:
# Verify TFLite model works
# Load the converted model and allocate its tensors before use.
interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

for direction, tensor_spec in (("input", input_details[0]), ("output", output_details[0])):
    print(f"TFLite {direction}: {tensor_spec['shape']}, dtype={tensor_spec['dtype']}")

# Run inference on the same single-image batch used for the SavedModel check
tflite_input = test_img.astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_pred = interpreter.get_tensor(output_details[0]['index'])
print(f"TFLite prediction shape: {tflite_pred.shape}")
print("TFLite verification: OK")

In [None]:
# Measure inference speed
import time


def _mean_latency_ms(run_once, repeats=50):
    """Return the mean latency of ``run_once`` in milliseconds over ``repeats`` calls.

    One untimed warm-up call is made first so one-time setup work does not
    skew the average. perf_counter() is used because it is monotonic and
    high-resolution.
    """
    run_once()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        samples.append(time.perf_counter() - start)
    return float(np.mean(samples)) * 1000


def _run_keras_once():
    """Single forward pass through the in-memory Keras model."""
    model.predict(test_img, verbose=0)


def _run_tflite_once():
    """Single TFLite inference including tensor upload and result fetch."""
    interpreter.set_tensor(input_details[0]['index'], test_img.astype(np.float32))
    interpreter.invoke()
    interpreter.get_tensor(output_details[0]['index'])


# Keras model speed. This times model.predict on the in-memory model, not
# the reloaded SavedModel signature, so the label says so.
times = _mean_latency_ms(_run_keras_once)
print(f"Keras model inference: {times:.1f}ms (avg of 50 runs)")

# TFLite speed
times = _mean_latency_ms(_run_tflite_once)
print(f"TFLite inference: {times:.1f}ms (avg of 50 runs)")

## 9. Download Models

Download the exported models for Cloud Run deployment.

In [None]:
# Package SavedModel for download
!cd /content/exported_models && tar -czf /content/pupil_segnet_savedmodel.tar.gz pupil_segnet/
print("Packaged SavedModel for download.")

# In Colab, use files.download
try:
    from google.colab import files
    files.download('/content/pupil_segnet_savedmodel.tar.gz')
    files.download(str(tflite_path))
    print("Downloads initiated.")
except ImportError:
    print("Not in Colab. Models saved at:")
    print(f"  SavedModel: /content/pupil_segnet_savedmodel.tar.gz")
    print(f"  TFLite: {tflite_path}")

## Summary

**Model: `pupil_segnet`**
- Architecture: U-Net with MobileNetV3-Small encoder
- Input: 256x256 RGB image
- Output: 256x256x3 softmax (background, iris, pupil)
- Training: 7,000 synthetic + optional real data images
- Two-phase training with transfer learning

**Outputs:**
1. `pupil_segnet/` - TF SavedModel for Cloud Run
2. `pupil_segnet.tflite` - Quantized TFLite for optional client-side

**Next steps:**
1. Deploy SavedModel to Cloud Run (see `cloud/` directory)
2. Add real datasets for fine-tuning when available
3. Integrate cloud detection endpoint into PWA