# maskgen Demo

This notebook demonstrates inference with `MaskGenerator` using all three strategies:
**whole**, **tile**, and **downsample**.

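Uses a real microscopy image (`Picture1.png` from `data/generator_test/`, loaded via the path `../data/generator_test/Picture1.png` relative to the notebook's working directory) and real model weights downloaded from GitHub Releases. Falls back to a synthetic random-noise image if the real image isn't available (e.g. when running in Colab without the repository's data folder).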

In [None]:
# Colab only: uncomment to install maskgen into the running kernel's environment.
# `%pip` (rather than `!pip`) installs into the interpreter backing this kernel.
# For reproducible runs, consider pinning a tag or commit: `...bernlab_maskgen.git@<tag>`.
# %pip install -q "maskgen @ git+https://github.com/Katsurado/bernlab_maskgen.git"

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from maskgen import MaskGenerator, download_weights

# --- Configuration ---------------------------------------------------------
# Resolved relative to the notebook's working directory.
IMAGE_PATH = Path("../data/generator_test/Picture1.png")
SYNTHETIC_SHAPE = (256, 384, 3)  # (height, width, RGB) of the fallback image
SEED = 0  # makes the synthetic fallback image identical on every run

# Load real model weights from GitHub Releases
gen = MaskGenerator(download_weights())

# Load the real microscopy image, or fall back to reproducible synthetic noise.
image_path = IMAGE_PATH
if image_path.exists():
    # The context manager closes the underlying file once the pixels are copied out.
    with Image.open(image_path) as img:
        test_image = np.array(img.convert("RGB"))
    print(f"Loaded real image: {image_path} ({test_image.shape})")
else:
    rng = np.random.default_rng(SEED)
    # The upper bound is exclusive, so 256 covers the full uint8 range 0..255.
    test_image = rng.integers(0, 256, size=SYNTHETIC_SHAPE, dtype=np.uint8)
    print(f"Image not found at {image_path}, using synthetic ({test_image.shape})")

In [None]:
# Generate masks with all 3 strategies.
# Panel title -> strategy config passed to `gen.generate`.
STRATEGIES = {
    "Whole": {"name": "whole"},
    "Tile": {"name": "tile", "tile_size": 512, "overlap": 64},
    "Downsample": {"name": "downsample", "max_dim": 256},
}

# Pass a fresh copy of each config: whether `generate` mutates its `strategy`
# argument isn't visible here, and the shared table must survive re-runs.
masks = {
    title: gen.generate(test_image, strategy=dict(strategy))
    for title, strategy in STRATEGIES.items()
}
mask_whole, mask_tile, mask_down = masks["Whole"], masks["Tile"], masks["Downsample"]

# Display the input next to each mask: (title, image, colormap) per panel.
panels = [("Input", test_image, None)] + [
    (title, mask, "gray") for title, mask in masks.items()
]
fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))

for ax, (title, img, cmap) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import os

# generate_and_save demo: run the tile strategy and write the mask to disk.
output_path = "demo_output/mask.png"

# Create the output directory up front; whether generate_and_save() creates
# missing parent directories itself isn't visible from this notebook.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

mask = gen.generate_and_save(
    test_image,
    output_path,
    strategy={"name": "tile", "tile_size": 512, "overlap": 64},
)

print(f"Saved mask to: {output_path}")
print(f"File exists: {os.path.exists(output_path)}")
# NOTE(review): `.size` is (width, height) for a PIL Image but the element
# count for a numpy array — confirm which type generate_and_save returns.
print(f"Mask size: {mask.size}")