# SAM Inference Server (FastAPI + ngrok)

This notebook mounts Google Drive, loads a SAM checkpoint, verifies CUDA, and starts a FastAPI inference server exposed via ngrok.


In [1]:
# 1) Mount Google Drive
from google.colab import drive

drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
# 2) Configure checkpoint path (set DRIVE_SUBPATH env var if needed)
import os

DRIVE_SUBPATH = 'soil_microCT_images/drive_scripts/napari_loader/checkpoints'
checkpoint_path = f"/content/drive/MyDrive/{DRIVE_SUBPATH}/sam_vit_h_4b8939.pth"
print("Checkpoint path:", checkpoint_path)


Checkpoint path: /content/drive/MyDrive/soil_microCT_images/drive_scripts/napari_loader/checkpoints/sam_vit_h_4b8939.pth


In [3]:
# 3) Install dependencies
!pip -q install fastapi uvicorn pyngrok git+https://github.com/facebookresearch/segment-anything.git


  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone


In [4]:
# 4) Verify CUDA availability
import torch

assert torch.cuda.is_available() is True, "CUDA is not available. Please switch to a GPU runtime."
print("CUDA is available.")


CUDA is available.


In [5]:
# 5) Load SAM model
import os
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
sam.to(device="cuda")
predictor = SamPredictor(sam)
print("SAM model loaded.")


SAM model loaded.


In [6]:
# 6) Define FastAPI app
import base64
import io
import threading
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
_predictor_lock = threading.Lock()

class PredictRequest(BaseModel):
    image: str  # base64-encoded PNG
    box: list   # [x0, y0, x1, y1]

@app.get("/health")
def health():
    import torch
    return {
        "torch_available": True,
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count(),
    }

@app.post("/predict")
def predict(req: PredictRequest):
    image_bytes = base64.b64decode(req.image)
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    image_np = np.array(image)

    box = np.array(req.box, dtype=np.float32)
    with _predictor_lock:
        predictor.set_image(image_np)
        masks, _, _ = predictor.predict(box=box, multimask_output=False)

    mask = (masks[0] * 255).astype(np.uint8)
    mask_img = Image.fromarray(mask)
    buffer = io.BytesIO()
    mask_img.save(buffer, format="PNG")
    mask_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {
        "mask": mask_b64,
        "mask_metadata": {
            "shape": list(mask.shape),
            "dtype": str(mask.dtype),
        },
    }


In [7]:
# 7) Start FastAPI server in the background
import threading
import uvicorn

def run_server():
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

thread = threading.Thread(target=run_server, daemon=True)
thread.start()
print("Server started on port 8000.")


Server started on port 8000.


In [8]:
# 8) Expose server via ngrok
import os
from pyngrok import ngrok
from google.colab import userdata

ngrok_token = userdata.get('NGROK')
if not ngrok_token:
    raise ValueError('NGROK_AUTHTOKEN environment variable is required.')
ngrok.set_auth_token(ngrok_token)

public_url = ngrok.connect(8000, "http")
print("Public URL:", public_url)


Public URL: NgrokTunnel: "https://monogenetic-nonmalarial-nia.ngrok-free.dev" -> "http://localhost:8000"
