Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: CI

on:
workflow_dispatch:
push:
branches: [main, master]
pull_request:
branches: [main, master]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -e .
pip install -r requirements.txt

- name: Detect GPU availability
id: detect-gpu
shell: bash
run: |
HAS_GPU=$(python -c "import torch; print('true' if torch.cuda.is_available() else 'false')")
echo "has_gpu=$HAS_GPU" >> "$GITHUB_OUTPUT"
echo "GPU available: $HAS_GPU"

- name: Pre-cache model checkpoint
if: steps.detect-gpu.outputs.has_gpu == 'true'
run: |
mkdir -p ~/.cache/torch/hub/checkpoints
if [ ! -f ~/.cache/torch/hub/checkpoints/sharp_2572gikvuh.pt ]; then
curl -L -o ~/.cache/torch/hub/checkpoints/sharp_2572gikvuh.pt https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt
fi

- name: Run tests
run: |
pytest
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,4 @@ cython_debug/
.DS_STORE
*.pt
.aider*
.vscode/
Empty file added tests/__init__.py
Empty file.
Binary file added tests/data/example.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 11 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
def test_import():
"""Test that the package can be imported."""
import sharp
assert sharp is not None


def test_cli_import():
"""Test that CLI modules can be imported."""
from sharp import cli
assert cli is not None

73 changes: 73 additions & 0 deletions tests/test_gaussians.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import tempfile
from pathlib import Path

import numpy as np
import torch

from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply


def test_gaussians3d_creation():
"""Test creating a Gaussians3D object."""
num_gaussians = 100
gaussians = Gaussians3D(
mean_vectors=torch.randn(1, num_gaussians, 3),
singular_values=torch.rand(1, num_gaussians, 3),
quaternions=torch.randn(1, num_gaussians, 4),
colors=torch.rand(1, num_gaussians, 3),
opacities=torch.rand(1, num_gaussians),
)

assert gaussians.mean_vectors.shape == (1, num_gaussians, 3)
assert gaussians.singular_values.shape == (1, num_gaussians, 3)
assert gaussians.quaternions.shape == (1, num_gaussians, 4)
assert gaussians.colors.shape == (1, num_gaussians, 3)
assert gaussians.opacities.shape == (1, num_gaussians)


def test_gaussians3d_to_device():
"""Test moving Gaussians3D to a device."""
num_gaussians = 50
gaussians = Gaussians3D(
mean_vectors=torch.randn(1, num_gaussians, 3),
singular_values=torch.rand(1, num_gaussians, 3),
quaternions=torch.randn(1, num_gaussians, 4),
colors=torch.rand(1, num_gaussians, 3),
opacities=torch.rand(1, num_gaussians),
)

gaussians_cpu = gaussians.to(torch.device("cpu"))
assert gaussians_cpu.mean_vectors.device.type == "cpu"
assert gaussians_cpu.singular_values.device.type == "cpu"


def test_scene_metadata():
"""Test creating SceneMetaData."""
metadata = SceneMetaData(
focal_length_px=1000.0,
resolution_px=(1920, 1080),
color_space="linearRGB",
)

assert metadata.focal_length_px == 1000.0
assert metadata.resolution_px == (1920, 1080)
assert metadata.color_space == "linearRGB"


def test_save_ply(tmp_path):
"""Test saving Gaussians to PLY file."""
num_gaussians = 10
gaussians = Gaussians3D(
mean_vectors=torch.randn(1, num_gaussians, 3),
singular_values=torch.rand(1, num_gaussians, 3),
quaternions=torch.randn(1, num_gaussians, 4),
colors=torch.rand(1, num_gaussians, 3),
opacities=torch.rand(1, num_gaussians),
)

output_path = tmp_path / "test_gaussians.ply"
save_ply(gaussians, f_px=1000.0, image_shape=(1920, 1080), path=output_path)

assert output_path.exists()
assert output_path.stat().st_size > 0

69 changes: 69 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import tempfile
from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from sharp.utils import io


def test_load_rgb(tmp_path):
"""Test loading an RGB image."""
test_image_path = tmp_path / "test.jpg"
img_array = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
img = Image.fromarray(img_array)
img.save(test_image_path)

image, icc_profile, f_px = io.load_rgb(test_image_path)

assert image.shape == (100, 100, 3)
assert image.dtype == np.uint8
assert f_px > 0
assert isinstance(icc_profile, (list, type(None)))


def test_get_supported_image_extensions():
"""Test getting supported image extensions."""
extensions = io.get_supported_image_extensions()
assert isinstance(extensions, list)
assert len(extensions) > 0
assert ".jpg" in extensions or ".JPG" in extensions
assert ".png" in extensions or ".PNG" in extensions


def test_get_supported_video_extensions():
"""Test getting supported video extensions."""
extensions = io.get_supported_video_extensions()
assert isinstance(extensions, list)
assert ".mp4" in extensions or ".MP4" in extensions


def test_save_image(tmp_path):
"""Test saving an image."""
test_image_path = tmp_path / "test_output.jpg"
img_array = np.random.randint(0, 255, (50, 50, 3), dtype=np.uint8)

io.save_image(img_array, test_image_path)

assert test_image_path.exists()
loaded_image, _, _ = io.load_rgb(test_image_path)
assert loaded_image.shape == (50, 50, 3)


def test_convert_focallength():
"""Test focal length conversion."""
f_px = io.convert_focallength(1920, 1080, 30.0)
assert f_px > 0
assert isinstance(f_px, float)


def test_load_example_image():
"""Test loading the example image from tests/data directory."""
example_path = Path(__file__).parent / "data" / "example.jpg"
if example_path.exists():
image, _, f_px = io.load_rgb(example_path)
assert image.ndim == 3
assert image.shape[2] == 3
assert f_px > 0

75 changes: 75 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np
import pytest
import torch
from pathlib import Path

from sharp.cli.predict import predict_image
from sharp.models import PredictorParams, create_predictor
from sharp.utils.gaussians import Gaussians3D


@pytest.mark.skipif(
not torch.cuda.is_available() and not torch.mps.is_available(),
reason="Requires CUDA or MPS for model inference",
)
def test_predict_image_with_model(tmp_path):
"""Test predict_image function with a real model checkpoint."""
example_path = Path(__file__).parent / "data" / "example.jpg"
if not example_path.exists():
pytest.skip("Test image not found")

from sharp.utils import io

image, _, f_px = io.load_rgb(example_path)

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu")

# Use the pre-cached model checkpoint
checkpoint_path = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" / "sharp_2572gikvuh.pt"
if not checkpoint_path.exists():
pytest.skip("Model checkpoint not found in cache")

try:
predictor = create_predictor(PredictorParams(checkpoint_path=str(checkpoint_path)))
predictor.eval()
predictor.to(device)

gaussians = predict_image(predictor, image, f_px, device)

assert isinstance(gaussians, Gaussians3D)
assert gaussians.mean_vectors.shape[0] == 1
assert gaussians.mean_vectors.shape[2] == 3
assert gaussians.colors.shape[2] == 3
assert gaussians.opacities.shape[1] > 0
except Exception as e:
pytest.skip(f"Model inference failed (likely missing checkpoint): {e}")


def test_predict_image_signature():
"""Test that predict_image function has correct signature."""
import inspect

sig = inspect.signature(predict_image)
params = sig.parameters

assert "predictor" in params
assert "image" in params
assert "f_px" in params
assert "device" in params


def test_create_predictor():
"""Test creating a predictor model."""
params = PredictorParams()
predictor = create_predictor(params)

assert predictor is not None
assert hasattr(predictor, "eval")
assert hasattr(predictor, "to")


def test_predictor_params():
"""Test PredictorParams creation."""
params = PredictorParams()
assert params is not None