# maskgen — Colab Training Notebook

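Train the maskgen segmentation model on Google Colab, using Google Drive for the dataset and for the checkpoints written during training. Before running, replace the `<user>` placeholder in the install cell with the GitHub account that hosts the maskgen repository.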

In [None]:
# Cell 1: Install maskgen with training extras
!pip install "maskgen[train] @ git+https://github.com/<user>/maskgen.git"

In [None]:
# Cell 2: Mount Google Drive and set paths
from google.colab import drive
# Mount Drive at /content/drive so that the two paths below resolve.
drive.mount("/content/drive")

# Both locations live under MyDrive/maskgen on the mounted Drive.
# DATA_ROOT is passed to train() as data_root and is the base of the test image
# path used in the inference cell.
DATA_ROOT = "/content/drive/MyDrive/maskgen/data"
# CHECKPOINT_DIR is passed to train() as checkpoint_dir; the inference cell loads
# "best.pth" from this directory.
CHECKPOINT_DIR = "/content/drive/MyDrive/maskgen/checkpoints"

In [None]:
# Cell 3: Define training config
# Settings dictionary handed to maskgen.train.train() in the training cell.
config = dict(
    # --- Model architecture ---
    channels=128,
    layers=[1, 2, 4, 2],
    stochastic_depth=0.1,
    ema=0.9999,

    # --- Data ---
    img_size=512,
    crop_per_img=4,

    # --- Training ---
    batch_size=4,
    lr=1e-3,
    weight_decay=1e-2,
    warmup_epoch=5,
    # NOTE(review): the key is spelled "T-max" (hyphen). PyTorch's
    # CosineAnnealingLR names this argument T_max — confirm which spelling
    # maskgen.train actually reads before relying on this value.
    scheduler_params={"T-max": 100},
    epochs=100,
    gradient_clip=1.0,

    # --- W&B (optional — set use_wandb=False in train() to skip) ---
    # Prefer reading the key from an environment variable or getpass over
    # pasting it into the notebook.
    # wandb_key="your-api-key",
    # wandb_params={
    #     "entity": "your-entity",
    #     "project": "maskgen",
    #     "name": "colab-run-1",
    #     "reinit": True,
    # },
)

In [None]:
# Cell 4: Train
from maskgen.train import train

# Run training with the config dict and the Drive paths set in the earlier cells.
# W&B is disabled here: set use_wandb=True and fill in wandb_params in the config to enable it.
train(config, data_root=DATA_ROOT, checkpoint_dir=CHECKPOINT_DIR, use_wandb=False)

In [None]:
# Cell 5: Test inference with trained model
import os

from maskgen import MaskGenerator

# Load the checkpoint named "best.pth" from the checkpoint directory.
weights_path = os.path.join(CHECKPOINT_DIR, "best.pth")
gen = MaskGenerator(weights_path)

# Generate a mask from a test image.
# The original cell passed the "test/images" folder itself as the image path;
# resolve one concrete file instead: the first non-hidden file, sorted by name.
test_image_dir = os.path.join(DATA_ROOT, "test/images")  # adjust to your dataset layout
test_image_files = sorted(
    name
    for name in os.listdir(test_image_dir)
    if not name.startswith(".") and os.path.isfile(os.path.join(test_image_dir, name))
)
if not test_image_files:
    raise FileNotFoundError(f"No image files found in {test_image_dir}")
test_image = os.path.join(test_image_dir, test_image_files[0])

# NOTE(review): tile_size equals config["img_size"] (512) — presumably intentional; confirm.
mask = gen.generate(test_image, strategy={"name": "tile", "tile_size": 512, "overlap": 64})
mask

In [None]:
# Cell 6: (Alternative) Download pre-trained weights from GitHub Releases
# Alternative to training: fetch published weights and build a MaskGenerator from them.
from maskgen import MaskGenerator, download_weights

# The download is left commented out until a release asset exists.
# Update the URL after uploading weights (GitHub CLI command; presumably run from a
# shell in the repository checkout rather than in this notebook):
#   gh release create v0.1.0 model/best.pth --title "v0.1.0"
# Then replace <user> in the URL below with the repository owner and uncomment:
# path = download_weights(url="https://github.com/<user>/maskgen/releases/latest/download/best.pth")
# gen = MaskGenerator(path)