# ContrAI — Contrail Detection & Visualization Demo

This notebook shows how to:
1. **Run contrail detection** with ContrAI
2. **Visualize** the results (original image, mask, and overlay)

You can either:
- Use local file paths (recommended for repo examples), **or**
- **Upload any image** right inside the notebook and run detection on it.

---
**Requirements**
- `contrai` (installed from PyPI)
- Python ≥ 3.11
- `ipywidgets` (only needed for the upload cell in section 5)

If you're viewing this on GitHub, open it in a Jupyter environment to run the cells.

## 0) Install & Imports
If ContrAI isn't installed in your environment, uncomment the `pip install` line below.

In [None]:
# %%
# !pip install -U contrai

from pathlib import Path
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt

from contrai.inference import predict  # core API used below
try:
    # Optional GOES-16 helper for generating Ash RGBs
    from contrai.data.goes16 import generate_ash_rgb_for_datetime
    HAS_GOES = True
except Exception:
    HAS_GOES = False

print('ContrAI available. GOES helpers:', HAS_GOES)

## 1) Configure paths (local example)
Update the paths below to point to your **model weights** and an **input image**. The overlay is saved to `OUTPUT_PATH` unless you set it to `None`.

In [None]:
# %%
MODEL_PATH = Path("/path/to/your/weights/model.pth")  # <- change me
IMAGE_PATH = Path("/path/to/an/image.png")            # <- change me
OUTPUT_PATH = Path("./outputs/contrail_overlay.png")  # set to None to skip saving

# Basic safety checks (comment out if you want to set them later)
if MODEL_PATH and MODEL_PATH != Path("/path/to/your/weights/model.pth"):
    assert MODEL_PATH.exists(), f"Model weights not found: {MODEL_PATH}"
if IMAGE_PATH and IMAGE_PATH != Path("/path/to/an/image.png"):
    assert IMAGE_PATH.exists(), f"Image not found: {IMAGE_PATH}"

# Create the output directory up front so saving the overlay cannot fail on a missing folder
if OUTPUT_PATH is not None:
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
print('Paths OK')

## 2) Run detection
The helper below wraps `contrai.inference.predict(...)`, which returns the **overlay**, **mask**, and **original image** arrays.

In [None]:
# %%
def run_detection(model_path: Path, image_path: Path, output: Optional[Path] = None,
                  tile_h: int = 256, tile_w: int = 256, stride: int = 128,
                  threshold: float = 0.10, log_level: str = "INFO"):
    overlay, mask, image = predict(
        model_path=model_path,
        image_path=image_path,
        tile_h=tile_h,
        tile_w=tile_w,
        stride=stride,
        threshold=threshold,
        output=output,
        show=False,
        log_level=log_level,
    )
    return overlay, mask, image

if MODEL_PATH.exists() and IMAGE_PATH.exists():
    overlay, mask, image = run_detection(MODEL_PATH, IMAGE_PATH, OUTPUT_PATH)
    print('Detection complete.')
else:
    overlay = mask = image = None
    print('Set MODEL_PATH and IMAGE_PATH to valid files to run this cell.')
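
The `tile_h`, `tile_w`, and `stride` arguments suggest that the image is processed in overlapping windows. The sketch below shows one conventional way such a sliding-window grid can be laid out. It is an assumption for illustration only: the source does not say how `predict` places tiles or handles image edges, and `tile_origins` and `tile_grid` are hypothetical helper names, not part of ContrAI.

```python
from typing import List, Tuple


def tile_origins(length: int, tile: int, stride: int) -> List[int]:
    """Start offsets of sliding windows along one axis (hypothetical helper)."""
    if tile <= 0 or stride <= 0:
        raise ValueError("tile and stride must be positive")
    if length <= tile:
        return [0]  # a single window; a real pipeline would pad or crop here
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] != length - tile:
        starts.append(length - tile)  # extra window flush with the far edge
    return starts


def tile_grid(height: int, width: int, tile_h: int = 256, tile_w: int = 256,
              stride: int = 128) -> List[Tuple[int, int]]:
    """(row, col) origins of every window; an assumed layout, not ContrAI's."""
    return [(y, x)
            for y in tile_origins(height, tile_h, stride)
            for x in tile_origins(width, tile_w, stride)]


grid = tile_grid(512, 640)
print(len(grid), "tiles; first three origins:", grid[:3])
```

With the defaults used in this notebook (256-pixel tiles, stride 128), neighbouring windows overlap by half a tile. In a scheme like this, a smaller stride means more windows and more compute per image.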

## 3) Visualize results
We show (a) the original image, (b) the probability/binary mask, and (c) the overlay.

In [None]:
# %%
def show_results(image: np.ndarray, mask: np.ndarray, overlay: np.ndarray):
    assert image is not None and mask is not None and overlay is not None, 'Run detection first.'
    
    # Show original
    plt.figure(figsize=(8, 6))
    plt.title('Original Image')
    plt.imshow(image)
    plt.axis('off')
    plt.show()

    # Show mask (as-is; you can threshold elsewhere if needed)
    plt.figure(figsize=(8, 6))
    plt.title('Contrail Mask')
    plt.imshow(mask, cmap='gray')  # grayscale instead of the default false-color map for 2-D arrays
    plt.axis('off')
    plt.show()

    # Show overlay (already blended)
    plt.figure(figsize=(8, 6))
    plt.title('Overlay')
    plt.imshow(overlay)
    plt.axis('off')
    plt.show()

if image is not None:
    show_results(image, mask, overlay)
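
The mask is displayed as-is above. If it holds per-pixel probabilities rather than 0/1 values (the source does not say which form `predict` returns), a binary mask and a rough coverage figure can be derived with NumPy. This is a minimal sketch on a toy array; `binarize` and `coverage` are hypothetical names, and the 0.10 cutoff simply mirrors the default `threshold` used earlier.

```python
import numpy as np


def binarize(prob_map: np.ndarray, threshold: float = 0.10) -> np.ndarray:
    """Boolean mask that is True where the value reaches the threshold."""
    return np.asarray(prob_map) >= threshold


def coverage(binary_mask: np.ndarray) -> float:
    """Fraction of pixels flagged as contrail."""
    return float(np.asarray(binary_mask).mean())


# Toy 2x3 "probability map" standing in for a real mask
demo = np.array([[0.02, 0.08, 0.15],
                 [0.50, 0.09, 0.12]])
binary = binarize(demo)
print(binary)
print("coverage:", coverage(binary))
```

To apply it to a real result, pass the `mask` array returned by `run_detection` instead of `demo`.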

## 4) (Option A) Generate GOES-16 Ash RGB (optional)
If you have GOES data configured, you can create an Ash RGB frame and then run detection on it.

In [None]:
# %%
if not HAS_GOES:
    print('GOES-16 helpers are unavailable in this environment.')
elif not MODEL_PATH.exists():
    print('MODEL_PATH is invalid. Set it in the paths cell.')
else:
    # Example: adjust date/time to your needs
    png_path, rgb, dt = generate_ash_rgb_for_datetime(2023, 1, 30, "1600")
    print('Ash RGB saved to:', png_path)
    print('Using scan time:', dt)

    # Note: this reuses OUTPUT_PATH, so it overwrites the overlay saved in step 2
    overlay2, mask2, image2 = run_detection(MODEL_PATH, Path(png_path), OUTPUT_PATH)
    show_results(image2, mask2, overlay2)

## 5) (Option B) Upload any image and detect
Use the widget below to upload a local image, then run detection on it.

In [None]:
# %%
import ipywidgets as widgets
from IPython.display import display

uploader = widgets.FileUpload(accept='image/*', multiple=False)
display(widgets.HTML('<b>Upload an image (PNG/JPG):</b>'))
display(uploader)

def on_click(_):
    if not uploader.value:
        print('No file uploaded yet.')
        return
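    # Get uploaded file. ipywidgets 7 exposes a dict keyed by filename;
    # ipywidgets 8 exposes a tuple of dicts with 'name' and 'content' keys.
    if isinstance(uploader.value, dict):
        item = next(iter(uploader.value.values()))
        name = item['metadata']['name']
    else:
        item = uploader.value[0]
        name = item['name']
    data = bytes(item['content'])  # content is a memoryview in ipywidgets 8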

    # Save to a temp path inside ./uploads
    up_dir = Path('./uploads')
    up_dir.mkdir(parents=True, exist_ok=True)
    tmp_path = up_dir / name
    with open(tmp_path, 'wb') as f:
        f.write(data)
    print('Saved upload to:', tmp_path)

    # Run detection using the already-set MODEL_PATH
    if not MODEL_PATH.exists():
        print('MODEL_PATH is invalid. Set it in the paths cell.')
        return
    ov, mk, im = run_detection(MODEL_PATH, tmp_path, OUTPUT_PATH)
    show_results(im, mk, ov)

button = widgets.Button(description='Run detection on upload')
button.on_click(on_click)
display(button)
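
The upload cell writes the file under `./uploads` using the name reported by the widget. Browsers normally send only a base name, but if that name ever contains directory components, it is safer to strip them before building the path. The sketch below is one way to do that with the standard library; `safe_upload_path` and the `upload.bin` fallback are hypothetical and not part of ContrAI or ipywidgets.

```python
from pathlib import Path, PureWindowsPath


def safe_upload_path(name: str, up_dir: Path, fallback: str = "upload.bin") -> Path:
    """Return a path inside up_dir built from the final component of name only."""
    # PureWindowsPath splits on both '/' and '\\', so it copes with either style
    base = PureWindowsPath(name).name
    if base in ("", ".", ".."):
        base = fallback  # nothing usable left after stripping directories
    return up_dir / base


print(safe_upload_path("sky.png", Path("uploads")))
print(safe_upload_path("../../etc/passwd", Path("uploads")))
```

In `on_click`, the result could replace the `up_dir / name` expression. Name collisions between uploads are not handled here.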

---
### Notes
- To ignore heavyweight artifacts (weights, builds) in Git, add to `.gitignore`:
  ```
  **/weights/
  *.egg-info/
  build/
  dist/
  __pycache__/
  ```
- `predict(...)` parameters you may want to tune: `tile_h`, `tile_w`, `stride`, `threshold`.