# Tooth Segmentation with SAM 3

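This notebook segments individual teeth from dental CT scan images using SAM 3 and labels each detected tooth with a sequential index (the order in which the model returns the detections, not a standard dental numbering system).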

## 1. Import Libraries

In [None]:
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pathlib import Path
import cv2
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

%matplotlib inline

## 2. Define Tooth Segmentation Class

In [None]:
class ToothSegmenter:
    """Segment teeth in an image with SAM 3 and visualize or export the results.

    The model and its processor are built once in ``__init__`` and reused by
    every call to :meth:`segment_teeth`.
    """

    def __init__(self, checkpoint_path=r"D:\CDAC\Playground\Sam_3\Models_saved\sam3.pt"):
        """Initialize SAM 3 model for tooth segmentation.

        Args:
            checkpoint_path: Path of the SAM 3 checkpoint handed to
                ``build_sam3_image_model``.
        """
        print("Initializing SAM3 model for tooth segmentation...")
        self.model = build_sam3_image_model(checkpoint_path=checkpoint_path)
        self.processor = Sam3Processor(self.model)
        print("Model initialized.")

    @staticmethod
    def _to_numpy(value):
        """Return *value* as a NumPy array, moving torch tensors to the CPU first."""
        if hasattr(value, "detach"):
            # detach() keeps .numpy() from failing on tensors that track gradients.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _mask_to_2d(mask):
        """Return *mask* as an (H, W) array, dropping singleton leading dimensions."""
        mask_np = ToothSegmenter._to_numpy(mask)
        h, w = mask_np.shape[-2:]
        # reshape (rather than squeeze) raises if the mask holds more than one plane.
        return mask_np.reshape(h, w)

    @staticmethod
    def _binary_mask(mask):
        """Return *mask* as an (H, W) uint8 array with values 0 and 255."""
        return ToothSegmenter._mask_to_2d(mask).astype(np.uint8) * 255

    def segment_teeth(self, image_path, prompt="tooth"):
        """Segment all teeth in an image.

        Args:
            image_path: Path of the image file to load.
            prompt: Text prompt passed to the SAM 3 processor.

        Returns:
            A dict with the RGB PIL image under ``"image"`` and the processor
            output entries ``"masks"``, ``"boxes"`` and ``"scores"``.
        """
        print(f"Loading image: {image_path}")
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        # Process with SAM 3
        inference_state = self.processor.set_image(image)
        output = self.processor.set_text_prompt(state=inference_state, prompt=prompt)

        return {
            "image": image,
            "masks": output["masks"],
            "boxes": output["boxes"],
            "scores": output["scores"]
        }

    def visualize_segmentation(self, result, output_path=None, show_numbers=True):
        """Visualize segmented teeth with bounding boxes and numbers.

        Args:
            result: Dict returned by :meth:`segment_teeth`.
            output_path: Optional file path for the saved figure; missing
                parent directories are created.
            show_numbers: Whether to label each box with its index and score.
        """
        image = result["image"]
        masks = result["masks"]
        boxes = result["boxes"]
        scores = result["scores"]

        # Create figure
        fig, ax = plt.subplots(1, figsize=(12, 12))
        ax.imshow(image)

        # One distinct color per detection
        num_teeth = len(masks)
        colors = plt.cm.rainbow(np.linspace(0, 1, num_teeth))

        for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
            color = colors[i][:3]

            # Overlay the mask as an RGBA image: tooth color at 40% opacity
            # inside the mask, fully transparent elsewhere.
            mask_2d = self._mask_to_2d(mask)
            rgba = np.array([*color, 0.4]).reshape(1, 1, -1)
            ax.imshow(mask_2d[..., np.newaxis] * rgba)

            # Draw bounding box; unpacked as (x0, y0, x1, y1) corner
            # coordinates in image pixels.
            x0, y0, x1, y1 = self._to_numpy(box).reshape(-1)
            rect = patches.Rectangle(
                (x0, y0), x1 - x0, y1 - y0,
                linewidth=2, edgecolor=color, facecolor='none'
            )
            ax.add_patch(rect)

            # Add tooth number and score
            if show_numbers:
                ax.text(
                    x0, y0 - 5,
                    f"Tooth #{i+1}\n{float(score):.2f}",
                    bbox=dict(facecolor=color, alpha=0.7),
                    fontsize=10,
                    color='white',
                    weight='bold'
                )

        ax.set_title(f"Detected {num_teeth} teeth", fontsize=16, weight='bold')
        ax.axis('off')
        plt.tight_layout()

        # Save if path provided
        if output_path:
            # savefig does not create directories, so make sure the target exists.
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
            print(f"Saved visualization to: {output_path}")

        plt.show()
        # pyplot keeps figures alive until closed; release it so repeated
        # calls (e.g. batch processing) do not accumulate memory.
        plt.close(fig)

    def save_individual_masks(self, result, output_folder):
        """Save each tooth mask as a separate image.

        For detection ``i`` (1-based) two files are written into
        *output_folder*: ``tooth_NN_mask.png`` (the binary mask) and
        ``tooth_NN_image.png`` (the input image with everything outside the
        mask set to black).

        Args:
            result: Dict returned by :meth:`segment_teeth`.
            output_folder: Destination directory; created if missing.
        """
        output_path = Path(output_folder)
        output_path.mkdir(parents=True, exist_ok=True)

        masks = result["masks"]
        image_np = np.array(result["image"])

        failed = []
        for i, mask in enumerate(masks):
            # OpenCV needs a single-channel (H, W) uint8 mask; the raw mask may
            # carry a leading singleton dimension.
            mask_np = self._binary_mask(mask)

            # Create masked image (tooth only)
            masked_img = cv2.bitwise_and(image_np, image_np, mask=mask_np)

            # Save mask and masked image
            mask_path = output_path / f"tooth_{i+1:02d}_mask.png"
            tooth_path = output_path / f"tooth_{i+1:02d}_image.png"

            # imwrite reports failure through its return value instead of raising.
            mask_ok = cv2.imwrite(str(mask_path), mask_np)
            # The PIL image is RGB while OpenCV writes BGR.
            tooth_ok = cv2.imwrite(str(tooth_path), cv2.cvtColor(masked_img, cv2.COLOR_RGB2BGR))
            if not (mask_ok and tooth_ok):
                failed.append(i + 1)

        if failed:
            print(f"Warning: could not write files for tooth number(s): {failed}")
        print(f"Saved {len(masks)} individual tooth masks to: {output_folder}")

## 3. Initialize the Model

In [None]:
# Build the segmenter once and reuse it in the cells below. With no argument,
# the constructor loads the checkpoint from ToothSegmenter's default
# checkpoint_path; pass checkpoint_path=... to use a different file.
segmenter = ToothSegmenter()

## 4. Process a Single Image

In [None]:
# Path to your dental image (edit this to point at a file on your machine)
image_path = r"C:\Users\cdac\Downloads\results\ToothFairy_teeth_001_0000_slice_130.png"

# Segment teeth using the text prompt "tooth". `result` is a dict with the
# keys "image", "masks", "boxes" and "scores"; the cells below reuse it.
result = segmenter.segment_teeth(image_path, prompt="tooth")

# Report how many masks were returned and the score attached to each
print(f"\nFound {len(result['masks'])} teeth")
print(f"Confidence scores: {result['scores']}")

## 5. Visualize Results

In [None]:
# Display the segmentation with numbered teeth (mask overlay, bounding box and
# "Tooth #N" + score label per detection) and also save the figure to
# output_path.
segmenter.visualize_segmentation(
    result,
    output_path=r"C:\Users\cdac\Downloads\segmentation_results\visualization.png",
    show_numbers=True
)

## 6. Save Individual Tooth Masks

In [None]:
# Save each tooth as a separate image: for every mask this writes
# tooth_NN_mask.png (the mask) and tooth_NN_image.png (the masked input image)
# into output_folder, which is created if it does not exist.
segmenter.save_individual_masks(
    result,
    output_folder=r"C:\Users\cdac\Downloads\segmentation_results\individual_teeth"
)

## 7. Process Multiple Images (Optional)

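Uncomment and run the code below to process a folder of PNG images. As written, the example loop only handles the first 5 files; remove the `[:5]` slice to process the entire folder.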

In [None]:
# input_folder = r"C:\Users\cdac\Downloads\results"
# output_folder = r"C:\Users\cdac\Downloads\all_segmentation_results"

# # Find all PNG images
# from pathlib import Path
# image_files = list(Path(input_folder).glob("*.png"))

# print(f"Found {len(image_files)} images to process")

# # Process each image
# for img_file in image_files[:5]:  # Process first 5 images as example
#     print(f"\nProcessing: {img_file.name}")
#     result = segmenter.segment_teeth(str(img_file), prompt="tooth")
#     print(f"  Detected {len(result['masks'])} teeth")
    
#     if len(result['masks']) > 0:
#         output_path = Path(output_folder)
#         output_path.mkdir(parents=True, exist_ok=True)
        
#         vis_path = output_path / f"{img_file.stem}_segmented.png"
#         segmenter.visualize_segmentation(result, output_path=vis_path)