In [None]:
# Reload edited project modules (e.g. nnx.data) before each cell runs, so changes
# to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from functools import lru_cache
from pathlib import Path

import cv2
import datasets
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

import nnx.data
from nnx.data.ctspine1k.dataset import CTSpine1K, PromptAdapter
from nnx.data.data_structures import AnnotationSample, ImagePrompt, SAMSample
from nnx.data.plots import visualize_sam_sample


In [None]:
# Replace the package-level generator with a seeded one so that anything in
# nnx.data drawing from `nnx.data.rng` (presumably prompt sampling — confirm)
# is repeatable across runs.
nnx.data.rng = np.random.default_rng(1)

In [None]:
import os

# Root of the local CTSpine1K copy. Set the CTSPINE1K_DIR environment variable to
# point elsewhere instead of editing this machine-specific absolute path.
cache_dir = Path(
    os.environ.get(
        "CTSPINE1K_DIR", "/Users/alexanderdann/Documents/Privat/Code/Data/CTSpine1K"
    )
)
# Directories that contain at least one ``*.gz`` volume. The set drops the
# duplicates produced when one directory holds several volumes, and sorting
# makes the listing independent of filesystem traversal order.
files = sorted({file.parent for file in cache_dir.rglob("*") if file.suffix == ".gz"})
files


In [None]:
# Raw CTSpine1K dataset constructed on `cache_dir`, wrapped by PromptAdapter.
# NOTE(review): items of `dataset` are passed to visualize_sam_sample below, so they
# are presumably SAMSample-like prompt samples — confirm in nnx.data.ctspine1k.
ct_dataset = CTSpine1K(cache_dir=cache_dir)
dataset = PromptAdapter(dataset=ct_dataset)

In [None]:
# Visualise one adapted sample; 2149 is a hard-coded index.
sample = dataset[2149]
fig = visualize_sam_sample(sample)
# NOTE(review): if visualize_sam_sample returns a Matplotlib figure, fig.show()
# warns under the inline backend and ending the cell with `fig` is the usual way
# to display it; for a Plotly figure fig.show() is correct. Confirm the type.
fig.show()

In [None]:
import numpy as np
import cv2
from scipy.ndimage import distance_transform_edt


def create_gaussian_ramp(mask, sigma=10.0):
    """Create a smooth ramp rising from 0 at the mask edge towards 1 inside it.

    Each pixel is mapped to ``1 - exp(-d**2 / (2 * sigma**2))``, where ``d`` is
    its Euclidean distance (in pixels) to the nearest background pixel. Values
    therefore grow with depth inside the mask and saturate at 1. Background
    pixels have ``d == 0`` and map to exactly 0.

    Args:
        mask: Binary mask; any non-zero value is treated as foreground.
        sigma: Width of the transition, in pixels. Must be positive.

    Returns:
        Float array with the same shape as ``mask`` and values in ``[0, 1]``.

    Raises:
        ValueError: If ``sigma`` is not positive. ``sigma == 0`` would divide by
            zero and fill the background with NaN.

    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Distance from every foreground pixel to the closest background (zero)
    # pixel; background pixels themselves get 0.
    dist = distance_transform_edt(mask)

    # Complement of a Gaussian in the distance: 0 at the edge, -> 1 deep inside.
    return 1.0 - np.exp(-(dist**2) / (2.0 * sigma**2))

In [None]:
# TODO(review): `mask` is not defined in any visible cell (presumably derived from
# `sample`); define it explicitly so this cell survives Restart & Run All.
ramp_mask = create_gaussian_ramp(mask, sigma=0.5)
# Zero every pixel whose ramp value either equals the mask value (background
# 0 == 0, or interior that saturated to exactly 1.0 where the mask is 1) or is
# at most 0.99, keeping only the thin band just below saturation.
ramp_mask[(ramp_mask == mask) | (ramp_mask <= 0.99)] = 0
np.unique(1 - ramp_mask)

In [None]:
# TODO(review): `mask` is not defined in any visible cell (presumably derived from
# `sample`); define it explicitly so this cell survives Restart & Run All.
alpha = 0.2  # fraction of the maximum depth at which `prob` saturates to 1

dist = distance_transform_edt(mask)
max_dist = np.amax(dist)
# An empty mask has no foreground, so max_dist is 0 and dist / max_dist would be
# NaN everywhere; fall back to an all-zero map instead.
prob = dist / max_dist if max_dist > 0 else np.zeros_like(dist)
# Rescale so that pixels at depth >= alpha * max_dist saturate at 1.
prob = np.minimum(prob / alpha, 1.0)
# Zero pixels whose value equals the mask value: background (0 == 0) and the
# saturated interior (1 == 1 where the mask is 1).
prob[prob == mask] = 0


In [None]:
import plotly.graph_objects as go
import numpy as np
from PIL import Image

# Demo markers for plot_points_on_image; `visible` toggles whether each is drawn.
points = [
    {"name": name, "x": x, "y": y, "color": color, "visible": visible}
    for name, x, y, color, visible in (
        ("Point 1", 150, 150, "red", True),
        ("Point 2", 250, 250, "blue", True),
        ("Point 3", 350, 150, "green", False),
    )
]

import plotly.graph_objects as go
import numpy as np
import base64
from io import BytesIO
from PIL import Image


def plot_points_on_image(img_array, points):
    """Overlay labelled point markers on an image in an interactive Plotly figure.

    The image is embedded as a PNG data URI in a layout image spanning the full
    data range. The y-axis is reversed so that (0, 0) is the top-left pixel.

    Args:
        img_array: Image array of shape ``(H, W)`` or ``(H, W, C)``. Values are
            clipped to ``[0, 255]`` before conversion to ``uint8``, so the input
            is expected to already be in that range.
        points: Iterable of dicts with keys ``"name"``, ``"x"``, ``"y"`` and
            ``"color"``. Points whose ``"visible"`` entry is falsy are skipped;
            a missing ``"visible"`` key counts as visible.

    Returns:
        A ``plotly.graph_objects.Figure`` sized to the image, in pixels.

    """
    # Only the first two dimensions are spatial; a trailing channel axis is optional.
    img_height, img_width = img_array.shape[:2]

    # Clip before casting: astype(np.uint8) on its own wraps out-of-range integers
    # (e.g. 300 -> 44, negative CT values) and is undefined for out-of-range floats.
    # Going through float64 keeps the clip bounds representable for every input dtype.
    img_u8 = np.clip(np.asarray(img_array, dtype=np.float64), 0, 255).astype(np.uint8)

    # Convert numpy array to a base64 PNG data URI for Plotly
    img = Image.fromarray(img_u8)
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    img_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode()}"

    fig = go.Figure()

    # Place the image in data coordinates so one axis unit spans one pixel.
    fig.add_layout_image(
        {
            "source": img_base64,
            "x": 0,
            "y": 0,
            "xref": "x",
            "yref": "y",
            "sizex": img_width,
            "sizey": img_height,
            "sizing": "stretch",
            "layer": "below",
        }
    )

    # A missing "visible" flag means the point is shown.
    visible_points = [p for p in points if p.get("visible", True)]

    if visible_points:
        fig.add_trace(
            go.Scatter(
                x=[p["x"] for p in visible_points],
                y=[p["y"] for p in visible_points],
                mode="markers",
                marker=dict(color=[p["color"] for p in visible_points], size=10),
                text=[p["name"] for p in visible_points],
                hoverinfo="text",
            )
        )

    fig.update_layout(
        width=img_width,
        height=img_height,
        xaxis=dict(range=[0, img_width], visible=False),
        yaxis=dict(
            range=[img_height, 0],  # Inverted for image coordinates
            visible=False,
            scaleanchor="x",
            scaleratio=1,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
    )

    return fig

# SAM2: point-prompted segmentation

In [None]:
import os

# Set before `import torch` below: when running on Apple MPS, let ops that MPS
# does not support fall back to the CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import numpy as np
import torch
from PIL import Image

# Prefer CUDA when present, otherwise the CPU. Apple MPS is currently not
# selected; add a torch.backends.mps.is_available() check here to enable it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"using device: {device}")


In [None]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# SAM 2.1 Hiera-Large config and weights. The checkpoint is a relative path,
# presumably resolved against the notebook's working directory — confirm.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"

# Build the network on the selected device and wrap it in the image predictor.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
# TODO(review): `inputs` is not defined in any visible cell (presumably a 2-D slice
# taken from `sample`); define it explicitly so Restart & Run All works.
# Replicate the single-channel slice into three identical channels, e.g.
# (H, W) -> (H, W, 3) for a 2-D `inputs`, then cast to float32.
# NOTE(review): plt.imshow clips float RGB data to [0, 1], and SAM2's set_image
# documents [0, 255] pixel values — confirm the intended value range of `inputs`.
image = np.tile(inputs[:, :, None], 3).astype(np.float32)

In [None]:
# Hand the 3-channel image to the SAM2 predictor; presumably this computes and
# caches the image embedding reused by the predict() call below — confirm.
predictor.set_image(image)

In [None]:
import matplotlib.pyplot as plt
# Seed NumPy's legacy global RNG, which show_mask(random_color=True) draws from.
np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders=True):
    """Draw a translucent RGBA overlay of a mask onto a Matplotlib axes.

    Args:
        mask: Mask whose last two dimensions are (height, width). It is cast to
            ``uint8`` and must reshape to ``(height, width)``.
        ax: Axes (anything with an ``imshow`` method) to draw on.
        random_color: If True, draw the RGB colour from NumPy's global RNG;
            otherwise use a fixed blue. Alpha is 0.6 in both cases.
        borders: If True, also trace the mask's outer contours with a white,
            half-transparent, 2-pixel line (uses OpenCV).
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    color = np.concatenate([rgb, np.array([0.6])])

    height, width = mask.shape[-2:]
    mask_u8 = mask.astype(np.uint8)
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) float RGBA image that is
    # coloured where the mask is set and fully transparent (all zeros) elsewhere.
    overlay = mask_u8.reshape(height, width, 1) * color.reshape(1, 1, -1)

    if borders:
        import cv2

        contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Light polygon simplification to smooth the traced outline.
        contours = [cv2.approxPolyDP(c, epsilon=0.01, closed=True) for c in contours]
        overlay = cv2.drawContours(overlay, contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points as stars: label 1 in green, then label 0 in red.

    Args:
        coords: 2-D array of points; column 0 is plotted on the x-axis and
            column 1 on the y-axis.
        labels: Per-point labels aligned with ``coords``; only 1 and 0 are drawn.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area, passed to ``scatter`` as ``s``.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Outline an axis-aligned ``[x0, y0, x1, y1]`` box in green on ``ax``."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x0, y0), x1 - x0, y1 - y0, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show each mask in its own 10x10-inch figure over ``image``, with optional prompts.

    Args:
        image: Background passed to ``imshow``.
        masks: Masks to display, one figure per mask.
        scores: Scores paired with ``masks``. They are shown in the title only
            when there is more than one score.
        point_coords: Optional point prompts; requires ``input_labels``.
        box_coords: Optional ``[x0, y0, x1, y1]`` box prompt.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask`` to draw mask outlines.
    """
    for idx, (mask, score) in enumerate(zip(masks, scores), start=1):
        _fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        show_mask(mask, ax, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, ax)
        if box_coords is not None:
            show_box(box_coords, ax)
        if len(scores) > 1:
            ax.set_title(f"Mask {idx}, Score: {score:.3f}", fontsize=18)
        ax.axis('off')
        plt.show()

In [None]:
# One positive click prompt. Column 0 is x and column 1 is y (the order
# show_points plots them in); label 1 is drawn as a positive (green) point.
input_point = np.array([(240, 125)])
input_label = np.ones(1, dtype=int)

In [None]:
# Preview the click prompt on top of the image before running the predictor.
preview_ax = plt.figure(figsize=(10, 10)).gca()
preview_ax.imshow(image)
show_points(input_point, input_label, preview_ax)
preview_ax.axis('on')
plt.show()

In [None]:
# Request multiple candidate masks for the click (multimask_output=True), then
# reorder masks, scores and logits together by descending score.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks, scores, logits = masks[sorted_ind], scores[sorted_ind], logits[sorted_ind]

In [None]:
# One figure per candidate mask, best score first, with the click overlaid and
# mask outlines disabled.
show_masks(
    image,
    masks,
    scores,
    point_coords=input_point,
    input_labels=input_label,
    borders=False,
)