# maskgen Demo

This notebook demonstrates inference with `MaskGenerator` using all three strategies:
**whole**, **tile**, and **downsample**.

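A synthetic random RGB image and a randomly initialised model are used so the notebook runs without real data or downloaded weights.
To see meaningful masks, load trained weights (see the comments in the setup cell) and replace the synthetic image with a real microscopy image.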

In [None]:
# Uncomment to install in Colab:
# !pip install "maskgen @ git+https://github.com/Katsurado/bernlab_maskgen.git"

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from maskgen import MaskGenerator
from maskgen.model import Net
import torch

# --- Create a mock generator (no real weights needed for demo) ---
# To use real weights, replace this block with:
#   from maskgen import download_weights
#   gen = MaskGenerator(download_weights())
# or point to a local checkpoint:
#   gen = MaskGenerator("path/to/best.pth")

class DemoGenerator(MaskGenerator):
    """Randomly initialised ``MaskGenerator`` so the demo runs without weights.

    ``MaskGenerator.__init__`` is intentionally not called, so no checkpoint
    path is needed; the ``config``, ``device`` and ``model`` attributes are
    filled in by hand instead.

    NOTE(review): assumes ``generate``/``generate_and_save`` only rely on
    these three attributes — confirm against ``MaskGenerator.__init__``.
    """

    def __init__(self):
        # Run the demo on the CPU so no GPU is required.
        self.device = torch.device("cpu")
        # Use the class-level default configuration for the network.
        cfg = self.DEFAULT_CONFIG
        self.config = cfg
        # Untrained network, switched to inference mode before use.
        net = Net(cfg["channels"], cfg)
        net.eval()
        self.model = net


gen = DemoGenerator()

# Create a synthetic test image (random RGB).
# endpoint=True makes the upper bound inclusive so the full uint8 range
# 0..255 is sampled; a fixed seed keeps the demo reproducible across runs.
h, w = 256, 384
rng = np.random.default_rng(seed=0)
test_image = rng.integers(0, 255, size=(h, w, 3), dtype=np.uint8, endpoint=True)
print(f"Test image shape: {test_image.shape}")

In [None]:
# Run the same input through each of the three inference strategies.
mask_whole = gen.generate(test_image, strategy={"name": "whole"})
mask_tile = gen.generate(
    test_image,
    strategy={"name": "tile", "tile_size": 128, "overlap": 16},
)
mask_down = gen.generate(
    test_image,
    strategy={"name": "downsample", "max_dim": 128},
)

# Show the input next to each mask: (image, panel title, colormap).
panels = [
    (test_image, "Input", None),
    (mask_whole, "Whole", "gray"),
    (mask_tile, "Tile", "gray"),
    (mask_down, "Downsample", "gray"),
]

fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (img, title, cmap) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import os

# generate_and_save demo: run tiled inference and write the mask to disk.
output_path = "demo_output/mask.png"
# Make sure the output directory exists so the save cannot fail on a fresh
# run; exist_ok=True makes this a no-op when it is already there.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
mask = gen.generate_and_save(
    test_image,
    output_path,
    strategy={"name": "tile", "tile_size": 128, "overlap": 16},
)

print(f"Saved mask to: {output_path}")
print(f"File exists: {os.path.exists(output_path)}")
# NOTE(review): `.size` is (width, height) for a PIL image but the total
# element count for a NumPy array — confirm which type generate_and_save
# returns (use `.shape` instead if it is an array).
print(f"Mask size: {mask.size}")