<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/SAM_Point_Prompt_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Segment Anything straight from the upstream GitHub repository.
!pip -q install git+https://github.com/facebookresearch/segment-anything.git
# OpenCV (headless build) and matplotlib.
!pip -q install opencv-python-headless matplotlib
# Extra packages; none of them is imported by the cells shown in this notebook.
!pip -q install bitsandbytes transformers accelerate peft
# Fetch the ViT-H SAM checkpoint that later cells load by file name.
# -c resumes a partial download and does nothing when the file is already
# complete, so re-running this cell does not create sam_vit_h_4b8939.pth.1.
!wget -c https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

Download the KITTI Segmentation Dataset

In [None]:
# Download the dataset archive from Google Drive by file ID.
!gdown 1EB9JSbcQIqjwI5wc8idMWEjmL3CRhh2D
# -o overwrites existing files without asking; without it a re-run of this
# cell stops at unzip's interactive "replace?" prompt.
!unzip -q -o kitti_autonomous_driving_seg.zip

Prepare Dataset and Dataloader:

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Label names for the segmentation masks. A name's position in this list is
# the integer class ID it stands for (looked up with class_names.index(...)).
class_names = [
    # IDs 0-6
    "unlabeled", "ego vehicle", "rectification border", "out of roi",
    "static", "dynamic", "ground",
    # IDs 7-10
    "road", "sidewalk", "parking", "rail track",
    # IDs 11-16
    "building", "wall", "fence", "guard rail", "bridge", "tunnel",
    # IDs 17-20
    "pole", "polegroup", "traffic light", "traffic sign",
    # IDs 21-23
    "vegetation", "terrain", "sky",
    # IDs 24-25
    "person", "rider",
    # IDs 26-33
    "car", "truck", "bus", "caravan", "trailer", "train", "motorcycle",
    "bicycle",
]

class KITTISegmentationDataset(Dataset):
    """Image / label-mask pairs read from ``<root_dir>/images`` and ``<root_dir>/masks``.

    Only ``.png`` files are considered. Images and masks are paired by their
    position in the sorted file lists, so both folders must hold the same
    number of files, named so that they sort in the same order.

    Each item is ``(image, mask)``:
        image: tensor produced by ``Resize`` + ``ToTensor`` at ``target_size``.
        mask:  ``long`` tensor of per-pixel class IDs at ``target_size``,
               resized with nearest-neighbour so IDs are never blended.

    Raises:
        ValueError: if the number of images and masks differs.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        self.image_files = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.png'))
        self.mask_files = sorted(f for f in os.listdir(self.masks_dir) if f.endswith('.png'))

        # Pairing is purely positional: fail early instead of silently
        # mispairing samples or hitting an IndexError mid-epoch.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.image_files)} images in '{self.images_dir}' but "
                f"{len(self.mask_files)} masks in '{self.masks_dir}'; counts must match."
            )

        # (height, width) shared by image and mask
        self.target_size = (256, 512)
        self.image_transform = transforms.Compose([
            transforms.Resize(self.target_size),  # Resize image
            transforms.ToTensor()
        ])

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and return the ``idx``-th ``(image, mask)`` pair."""
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        mask_path = os.path.join(self.masks_dir, self.mask_files[idx])

        # Load image and mask
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # Load as grayscale

        image = self.image_transform(image)
        # PIL takes (width, height), the reverse of target_size's (height, width).
        mask = mask.resize((self.target_size[1], self.target_size[0]), resample=Image.NEAREST)

        # Convert mask to tensor without scaling
        mask = torch.from_numpy(np.array(mask)).long()
        return image, mask

# Initialize dataset and dataloader
dataset_train = KITTISegmentationDataset(root_dir='kitti_autonomous_driving_seg/train')
dataset_test = KITTISegmentationDataset(root_dir='kitti_autonomous_driving_seg/test')
dataloader_train = DataLoader(dataset_train, batch_size=12, shuffle=True)
# The test split is iterated in a fixed order so evaluation is repeatable.
dataloader_test = DataLoader(dataset_test, batch_size=12, shuffle=False)
print('training sample:', len(dataset_train), 'testing sample:', len(dataset_test))

# Decode the first training sample once and report on both of its parts.
first_image, first_mask = dataset_train[0]
print('1st sample shape:', first_image.shape, 'classes inside masks:', first_mask.unique())

# Validating the point prompt with a random point from one class: sky or road

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import matplotlib.pyplot as plt

# Build the ViT-H SAM model from the local checkpoint file, move it to the GPU
# when one is available, and wrap it in a predictor for prompted inference.
sam_checkpoint = "sam_vit_h_4b8939.pth"  # Download from SAM GitHub release page
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint).to(device)
predictor = SamPredictor(sam)

# Label names for the segmentation masks. A name's position in this list is
# the integer class ID it stands for (looked up with class_names.index(...)).
class_names = [
    # IDs 0-6
    "unlabeled", "ego vehicle", "rectification border", "out of roi",
    "static", "dynamic", "ground",
    # IDs 7-10
    "road", "sidewalk", "parking", "rail track",
    # IDs 11-16
    "building", "wall", "fence", "guard rail", "bridge", "tunnel",
    # IDs 17-20
    "pole", "polegroup", "traffic light", "traffic sign",
    # IDs 21-23
    "vegetation", "terrain", "sky",
    # IDs 24-25
    "person", "rider",
    # IDs 26-33
    "car", "truck", "bus", "caravan", "trailer", "train", "motorcycle",
    "bicycle",
]

# Sample one foreground click for a named class from a label mask
def get_random_point_from_class(mask, class_name):
    """Pick a random pixel labelled ``class_name`` and return it as a SAM point prompt.

    Args:
        mask: Tensor of per-pixel class IDs. It is moved to the CPU and
            converted to NumPy; each hit is unpacked as (row, column).
        class_name: An entry of ``class_names``; its index in that list is
            the class ID searched for in ``mask``.

    Returns:
        ``(point, label)`` where ``point`` is ``np.array([[x, y]])`` and
        ``label`` is ``np.array([1])`` (1 = foreground), or ``(None, None)``
        when no pixel in ``mask`` carries the class.

    Raises:
        ValueError: if ``class_name`` is not in ``class_names``.
    """
    wanted_id = class_names.index(class_name)

    # One (row, column) pair per pixel of the wanted class
    hits = np.argwhere(mask.cpu().numpy() == wanted_id)
    if hits.shape[0] == 0:
        return None, None

    row, col = hits[np.random.choice(hits.shape[0])]

    # SAM takes points as (x, y), i.e. (column, row)
    return np.array([[col, row]]), np.array([1])

# Load dataset and dataloader (for validation)
dataset_test = KITTISegmentationDataset(root_dir='kitti_autonomous_driving_seg/test')
dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False)

# Get a batch of data from the validation dataloader
images, masks = next(iter(dataloader_test))

# Extract image and mask
image = images[0].permute(1, 2, 0).cpu().numpy()  # HWC float in [0, 1] (ToTensor output), used for plotting
mask = masks[0]

# SamPredictor.set_image expects an HWC uint8 image with values in [0, 255];
# it subtracts a pixel mean on that scale, so a [0, 1] float image would look
# almost constant to the encoder.
image_uint8 = np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)

# Specify the target class ('sky' or 'road')
target_class = "road"  # Change to "sky" if needed
point, label = get_random_point_from_class(mask, target_class)

if point is not None:
    # Run SAM prediction. The result gets its own name so the ground-truth
    # batch in `masks` is not overwritten.
    predictor.set_image(image_uint8)
    pred_masks, _, _ = predictor.predict(
        point_coords=point,
        point_labels=label,
        multimask_output=False  # Get a single best mask
    )

    # ---- Visualization ----
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Input Image
    axes[0].imshow(image)
    axes[0].set_title('Input Image')
    axes[0].axis('off')

    # Ground Truth Mask with Random Point Overlay
    axes[1].imshow(mask.cpu().numpy(), cmap='gray')
    axes[1].scatter(
        point[0][0], point[0][1],  # x, y coordinates
        color='red', s=100, marker='o', edgecolors='white'
    )
    axes[1].set_title(f'Ground Truth Mask with "{target_class}" Random Point')
    axes[1].axis('off')

    # Predicted Mask from SAM (WITHOUT Point Overlay)
    axes[2].imshow(image)
    axes[2].imshow(pred_masks[0], alpha=0.5, cmap='viridis')
    axes[2].set_title(f'Predicted Mask for "{target_class}"')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

else:
    print(f"No valid points found for class '{target_class}'.")


# SAM LoRA fine-tuning for sky and road only, with a point prompt

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
import random

# Label names for the segmentation masks. A name's position in this list is
# the integer class ID it stands for.
class_names = [
    # IDs 0-6
    "unlabeled", "ego vehicle", "rectification border", "out of roi",
    "static", "dynamic", "ground",
    # IDs 7-10
    "road", "sidewalk", "parking", "rail track",
    # IDs 11-16
    "building", "wall", "fence", "guard rail", "bridge", "tunnel",
    # IDs 17-20
    "pole", "polegroup", "traffic light", "traffic sign",
    # IDs 21-23
    "vegetation", "terrain", "sky",
    # IDs 24-25
    "person", "rider",
    # IDs 26-33
    "car", "truck", "bus", "caravan", "trailer", "train", "motorcycle",
    "bicycle",
]

# Integer IDs of the two classes this section works with
road_id = class_names.index("road")
sky_id = class_names.index("sky")

class KITTISegmentationDataset(Dataset):
    """Sky/road segmentation samples with a random point prompt.

    Files are read from ``<root_dir>/images`` and ``<root_dir>/masks``
    (``.png`` only) and paired by position in the sorted file lists.

    Each item is ``(image, mask_filtered, point, label, target_class)``:
        image:         tensor produced by ``Resize`` + ``ToTensor``.
        mask_filtered: ``long`` tensor, 0 = other, 1 = sky, 2 = road.
        point:         ``np.array([[x, y]])`` inside the target class, or
                       ``[[-1, -1]]`` when the image has neither class.
        label:         ``np.array([1])`` for a valid point, ``[-1]`` otherwise.
        target_class:  1 (sky) or 2 (road), the class the point was drawn from.

    Raises:
        ValueError: if the number of images and masks differs.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        self.image_files = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.png'))
        self.mask_files = sorted(f for f in os.listdir(self.masks_dir) if f.endswith('.png'))

        # Pairing is purely positional: fail early instead of silently
        # mispairing samples or hitting an IndexError mid-epoch.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.image_files)} images in '{self.images_dir}' but "
                f"{len(self.mask_files)} masks in '{self.masks_dir}'; counts must match."
            )

        # (height, width) shared by image and mask
        self.target_size = (256, 512)
        self.image_transform = transforms.Compose([
            transforms.Resize(self.target_size),  # Resize image
            transforms.ToTensor()
        ])

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th sample and draw a random prompt point for it."""
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        mask_path = os.path.join(self.masks_dir, self.mask_files[idx])

        # Load image and mask
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # Load as grayscale

        image = self.image_transform(image)
        # PIL takes (width, height), the reverse of target_size's (height, width).
        mask = mask.resize((self.target_size[1], self.target_size[0]), resample=Image.NEAREST)
        mask = np.array(mask)

        # Filter mask to only include sky and road classes
        mask_filtered = np.zeros_like(mask)
        mask_filtered[mask == sky_id] = 1  # Sky class
        mask_filtered[mask == road_id] = 2  # Road class

        # Draw the target only from classes that actually occur in this image,
        # so an image with just one of the two never yields an invalid prompt.
        present = [c for c in (1, 2) if np.any(mask_filtered == c)]
        if present:
            target_class = random.choice(present)  # 1: Sky, 2: Road
            ys, xs = np.nonzero(mask_filtered == target_class)
            pick = np.random.choice(len(ys))
            point = np.array([[xs[pick], ys[pick]]])  # Point in (x, y) format
            label = np.array([1])  # Label 1 for foreground
        else:
            # Neither class is present: keep the original sentinel values.
            target_class = random.choice([1, 2])
            point = np.array([[-1, -1]])  # Invalid point
            label = np.array([-1])  # Invalid label

        # Convert mask to tensor
        mask_filtered = torch.from_numpy(mask_filtered).long()

        return image, mask_filtered, point, label, target_class

# Initialize dataset and dataloader
dataset_train = KITTISegmentationDataset(root_dir='kitti_autonomous_driving_seg/train')
dataset_test = KITTISegmentationDataset(root_dir='kitti_autonomous_driving_seg/test')
dataloader_train = DataLoader(dataset_train, batch_size=12, shuffle=True)
# The test split is iterated in a fixed order so evaluation is repeatable.
dataloader_test = DataLoader(dataset_test, batch_size=12, shuffle=False)

# Load SAM model
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam.to(device)

# Build the predictor once the model is on its final device.
predictor = SamPredictor(sam)

# Define loss function and optimizer
# The decoder emits one logit per pixel for the prompted object, so the loss is
# binary cross-entropy on logits against a {0, 1} mask of the prompted class.
criterion = torch.nn.BCEWithLogitsLoss()

# Only the mask decoder is trained; the image and prompt encoders are frozen and
# run without gradient tracking.
# NOTE(review): the section title mentions LoRA, but no adapter is attached
# anywhere in this notebook. Adding LoRA to the image encoder would be a
# separate change.
for frozen_module in (sam.image_encoder, sam.prompt_encoder):
    for param in frozen_module.parameters():
        param.requires_grad_(False)
optimizer = torch.optim.Adam(sam.mask_decoder.parameters(), lr=1e-4)

# SamPredictor.predict runs under no_grad and returns NumPy arrays, so it cannot
# be used for training. The SAM sub-modules are called directly instead, reusing
# the predictor's resize helper for images and point coordinates.
resize = predictor.transform

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    sam.eval()                 # frozen parts stay in inference mode
    sam.mask_decoder.train()
    epoch_loss, num_batches = 0.0, 0

    for images, masks, points, labels, target_classes in dataloader_train:
        images, masks = images.to(device), masks.to(device)
        points, labels = points.to(device), labels.to(device)
        target_classes = target_classes.to(device)

        original_size = tuple(images.shape[-2:])
        logits_list, target_list = [], []

        # The mask decoder pairs one image embedding with that image's
        # prompts, so samples are processed one at a time.
        for i in range(images.shape[0]):
            if labels[i, 0].item() != 1:
                continue  # the dataset found no sky/road pixel to prompt with

            with torch.no_grad():
                # SAM works on a 0-255 scale; the dataset images are in [0, 1].
                resized = resize.apply_image_torch(images[i:i + 1] * 255.0)
                input_size = tuple(resized.shape[-2:])
                embedding = sam.image_encoder(sam.preprocess(resized))

                coords = resize.apply_coords_torch(points[i:i + 1].float(), original_size)
                sparse_emb, dense_emb = sam.prompt_encoder(
                    points=(coords, labels[i:i + 1]),
                    boxes=None,
                    masks=None,
                )

            # Differentiable part: decoder plus upsampling back to image size.
            low_res_logits, _ = sam.mask_decoder(
                image_embeddings=embedding,
                image_pe=sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_emb,
                dense_prompt_embeddings=dense_emb,
                multimask_output=False,
            )
            logits = sam.postprocess_masks(
                low_res_logits,
                input_size=input_size,
                original_size=original_size,
            )  # (1, 1, H, W)

            logits_list.append(logits[0, 0])
            # Binary target: pixels of the class the prompt point was drawn from.
            target_list.append((masks[i] == target_classes[i]).float())

        if not logits_list:
            continue  # every sample in this batch had an invalid prompt

        # Compute loss
        loss = criterion(torch.stack(logits_list), torch.stack(target_list))

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than the last batch only; this
    # also stays defined when no batch produced a loss.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")