# Libraries and Dependencies

In [None]:
# Clone the official SAM 2 repository and install it in editable mode (-e), so local
# edits to the sam2 sources take effect without reinstalling.
# NOTE: %cd leaves the kernel's working directory at /content/sam2 for all later cells.
!git clone https://github.com/facebookresearch/sam2.git
%cd sam2
!pip install -e .

Cloning into 'sam2'...
remote: Enumerating objects: 1070, done.[K
remote: Total 1070 (delta 0), reused 0 (delta 0), pack-reused 1070 (from 1)[K
Receiving objects: 100% (1070/1070), 134.70 MiB | 14.36 MiB/s, done.
Resolving deltas: 100% (375/375), done.
/content/sam2
Obtaining file:///content/sam2
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core>=1.3.2 (from SAM-2==1.0)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting iopath>=0.1.10 (from SAM-2==1.0)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting omegaconf<2.4,>=2.2 (from hydra-core>=1.3.2->SAM-2==1.

In [None]:
!pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.3.73-py3-none-any.whl.metadata (35 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Downloading ultralytics-8.3.73-py3-none-any.whl (914 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m914.6/914.6 kB[0m [31m58.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.14-py3-none-any.whl (26 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.73 ultralytics-thop-2.0.14


In [None]:
import numpy as np
import torch
import cv2
import os
import json
import pickle
from google.colab import drive, files
from pycocotools import mask as coco_mask

from torch.onnx.symbolic_opset11 import hstack

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Mount Google Drive. The YOLO weights, the dataset, the SAM 2 checkpoint and the
# best-model output path used in later cells all live under /content/gdrive/MyDrive/.
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [None]:
from ultralytics import YOLO

# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# Ultralytics YOLO checkpoint stored on Drive (a YOLO11 model, per its file name).
model = YOLO("/content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/2733_augmented_adjusted_YOLOV11.pt")

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


# Dataset Manipulation & Extraction Methods

In [None]:
# Root of the exported dataset on Google Drive; each split lives in its own sub-folder.
data_dir = "/content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/coral_masks.v17-v2-mature.sam2/"
train_dir, test_dir, valid_dir = (os.path.join(data_dir, split) for split in ("train/", "test/", "valid/"))

def extract_dataset(directory, bbox_model=None):
    """Load every JSON annotation file in ``directory`` into memory.

    Each ``*.json`` file is expected to contain an ``image`` record with a
    ``file_name`` (relative to ``directory``) and an ``annotations`` list; only
    the first annotation of each file is used. Files missing any required
    piece (annotations, image record/file name, segmentation, bbox, a readable
    image file) are skipped and counted.

    Args:
        directory: Folder containing the JSON files and the images they reference.
        bbox_model: Optional Ultralytics-style detector. When ``None`` (default)
            the box comes from the annotation's COCO ``[x, y, w, h]`` bbox. When
            given, ``bbox_model.predict`` is run on the image and its first
            ``xyxy`` box is used instead; images with no detection are skipped.

    Returns:
        ``(image_arrays, binary_masks, bbox_coords)``: three dicts keyed by
        consecutive integer ids 0, 1, 2, ... holding the RGB image, the mask
        decoded by ``pycocotools.mask.decode`` and the box as
        ``np.array([x_min, y_min, x_max, y_max])``.
    """
    image_arrays = {}
    binary_masks = {}
    bbox_coords = {}
    skipped_count = 0
    image_id = 0

    # sorted() makes the id -> sample assignment reproducible; os.listdir order is arbitrary.
    for json_name in sorted(os.listdir(directory)):
        if not json_name.endswith('.json'):
            continue
        with open(os.path.join(directory, json_name), 'r') as f:
            coco_data = json.load(f)

        # .get() so a missing key is treated like an empty value (skip) rather than a KeyError.
        annotations = coco_data.get('annotations') or []
        image_info = coco_data.get('image') or {}
        image_file = image_info.get('file_name')
        if not annotations or not image_file:
            skipped_count += 1
            continue

        annotation = annotations[0]
        segmentation = annotation.get('segmentation')
        if not segmentation:
            skipped_count += 1
            continue

        bbox_array = None
        if bbox_model is None:
            bbox = annotation.get('bbox')
            if not bbox:
                skipped_count += 1
                continue
            # COCO boxes are [x_min, y_min, width, height]; SAM 2 box prompts use corners.
            x_min, y_min, width, height = bbox
            bbox_array = np.array([x_min, y_min, x_min + width, y_min + height])

        image_path = os.path.join(directory, image_file)
        if not os.path.exists(image_path):
            skipped_count += 1
            continue

        # cv2.imread returns None (it does not raise) for unreadable or corrupt files.
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            skipped_count += 1
            continue

        if bbox_model is not None:
            # Ultralytics treats numpy inputs as BGR, so predict before the RGB conversion.
            detections = bbox_model.predict(image_bgr, verbose=False)
            boxes_xyxy = detections[-1].boxes.xyxy.tolist() if len(detections) else []
            if not boxes_xyxy:
                skipped_count += 1
                continue
            bbox_array = np.array(boxes_xyxy[0])

        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        binary_mask = coco_mask.decode(segmentation)
        if binary_mask is None:  # Ensure mask is valid
            skipped_count += 1
            continue

        bbox_coords[image_id] = bbox_array
        binary_masks[image_id] = binary_mask
        image_arrays[image_id] = image
        image_id += 1

    print(f"Skipped {skipped_count} images in the {directory} dataset.")
    return image_arrays, binary_masks, bbox_coords

dataset_indices = {}    # split name -> shuffled array of sample ids for the current pass
dataset_index_ptr = {}  # split name -> position of the next id to hand out


def read_single(dataset_type, reset_epoch=False):
    """Return the next ``(image, mask, bbox)`` sample of a split.

    Samples are drawn without replacement from a shuffled permutation of the
    split. A new permutation is drawn on first use, when ``reset_epoch`` is
    true, when the current permutation is exhausted, or when the split's size
    no longer matches the cached permutation (e.g. after the data was reloaded).

    Args:
        dataset_type: "train", "test" or "val".
        reset_epoch: Force a reshuffle before drawing this sample.

    Returns:
        The image, binary mask and box stored under the drawn sample id in the
        split's module-level dicts.

    Raises:
        ValueError: for an unknown ``dataset_type`` or an empty split.
    """
    if dataset_type == "train":
        image_arrays, binary_masks, bbox_coords = train_image_arrays, train_binary_masks, train_bbox_coords
    elif dataset_type == "test":
        image_arrays, binary_masks, bbox_coords = test_image_arrays, test_binary_masks, test_bbox_coords
    elif dataset_type == "val":
        image_arrays, binary_masks, bbox_coords = valid_image_arrays, valid_binary_masks, valid_bbox_coords
    else:
        raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected 'train', 'test' or 'val'.")

    num_samples = len(image_arrays)
    if num_samples == 0:
        raise ValueError(f"The {dataset_type!r} split has no samples.")

    order = dataset_indices.get(dataset_type)
    # The length check guards against a stale permutation if the split was reloaded with a new size.
    if (
        reset_epoch
        or order is None
        or len(order) != num_samples
        or dataset_index_ptr[dataset_type] >= len(order)
    ):
        order = np.random.permutation(num_samples)
        dataset_indices[dataset_type] = order
        dataset_index_ptr[dataset_type] = 0

    # Fetch the next id and advance the pointer
    entry = int(order[dataset_index_ptr[dataset_type]])
    dataset_index_ptr[dataset_type] += 1

    return image_arrays[entry], binary_masks[entry], bbox_coords[entry]

def read_batch(dataset_type, current_iteration, interval, batch_size=4):
    """Draw ``batch_size`` consecutive samples of a split via ``read_single``.

    When ``current_iteration`` is a multiple of ``interval`` the split is
    reshuffled once, before the first sample of the batch.

    Args:
        dataset_type: "train", "test" or "val".
        current_iteration: Iteration counter used to decide the reshuffle.
        interval: Reshuffle period in iterations; values <= 0 disable the
            explicit reshuffle (``read_single`` still reshuffles when exhausted).
        batch_size: Number of samples to draw.

    Returns:
        ``(images, masks, boxes)``: the images as a list, and the masks and
        boxes stacked with ``np.array``.
    """
    # Reshuffle at most once per batch. Passing the flag to every read_single call would
    # reshuffle batch_size times and return element 0 of each fresh permutation,
    # i.e. a batch of random samples drawn with replacement.
    start_new_epoch = interval > 0 and current_iteration % interval == 0

    images, masks, boxes = [], [], []
    for i in range(batch_size):
        image, mask, bbox = read_single(dataset_type, reset_epoch=(start_new_epoch and i == 0))
        images.append(image)
        masks.append(mask)
        boxes.append(bbox)

    return images, np.array(masks), np.array(boxes)

def return_dataset_size(dataset_type):
    """Return the number of samples loaded for split "train", "test" or "val".

    Raises:
        ValueError: for any other ``dataset_type``.
    """
    if dataset_type == "train":
        return len(train_image_arrays)
    if dataset_type == "test":
        return len(test_image_arrays)
    if dataset_type == "val":
        return len(valid_image_arrays)
    raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected 'train', 'test' or 'val'.")


def get_itr_interval(dataset_type, epochs, batch_size=4):
    """Return ``(total_iterations, iterations_per_epoch)`` for a split.

    ``iterations_per_epoch`` is the number of full batches of ``batch_size``
    in the split (a trailing partial batch is not counted), and
    ``total_iterations`` is that value times ``epochs``.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``dataset_type`` is not
            "train", "test" or "val".
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")

    if dataset_type == "train":
        dataset_size = len(train_image_arrays)
    elif dataset_type == "test":
        dataset_size = len(test_image_arrays)
    elif dataset_type == "val":
        dataset_size = len(valid_image_arrays)
    else:
        raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected 'train', 'test' or 'val'.")

    iterations_per_epoch = dataset_size // batch_size
    return iterations_per_epoch * epochs, iterations_per_epoch

# Load all three splits into memory. Each extract_dataset call returns three dicts keyed
# by consecutive sample ids (images, binary masks, [x_min, y_min, x_max, y_max] boxes).
# read_single, return_dataset_size and get_itr_interval read these module-level names.
train_image_arrays, train_binary_masks, train_bbox_coords = extract_dataset(train_dir)
test_image_arrays, test_binary_masks, test_bbox_coords = extract_dataset(test_dir)
valid_image_arrays, valid_binary_masks, valid_bbox_coords = extract_dataset(valid_dir)

Skipped 0 images in the /content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/coral_masks.v17-v2-mature.sam2/train/ dataset.
Skipped 0 images in the /content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/coral_masks.v17-v2-mature.sam2/test/ dataset.
Skipped 0 images in the /content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/coral_masks.v17-v2-mature.sam2/valid/ dataset.


# SAM 2 Training

In [None]:
# Load model

sam2_checkpoint = "/content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/sam2.1_hiera_large.pt"  # path to model weights
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"  # model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")  # load model
predictor = SAM2ImagePredictor(sam2_model)

# Set training parameters

predictor.model.sam_mask_decoder.train(True)  # enable training of mask decoder
predictor.model.sam_prompt_encoder.train(True)  # enable training of prompt encoder
# Image encoder: for gradients to actually reach it you need to find the "no_grad"
# guards in the sam2 predictor code and remove them all.
predictor.model.image_encoder.train(True)
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
# Mixed-precision loss scaler; torch.amp.GradScaler("cuda") replaces the deprecated
# torch.cuda.amp.GradScaler(), which emits a FutureWarning on current PyTorch.
scaler = torch.amp.GradScaler("cuda")

NUM_EPOCHS = 40  # number of passes over the training split
iterations, VALIDATION_INTERVAL = get_itr_interval("train", NUM_EPOCHS)

def _segment_with_boxes(predictor, images, boxes):
    """Encode ``images`` and decode mask logits from one box prompt per image.

    Returns the decoder's multimask logits upscaled with
    ``predictor._orig_hw[-1]``, i.e. all images are assumed to share one size.
    """
    predictor.set_image_batch(images)
    _, _, _, unnorm_box = predictor._prep_prompts(
        point_coords=None,
        point_labels=None,
        box=boxes,
        mask_logits=None,
        normalize_coords=True,
    )
    sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
        points=None, boxes=unnorm_box, masks=None
    )
    # The high-res feature maps are batched like image_embed, so pass them whole.
    # `feat_level[-1].unsqueeze(0)` would give every image the last image's features
    # (silently, through broadcasting inside the decoder).
    high_res_features = list(predictor._features["high_res_feats"])
    low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"],
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
        repeat_image=False,
        high_res_features=high_res_features,
    )
    return predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])


def compute_metrics(predictor, dataset_type, batch_size):
    """Compute the mean segmentation loss and mask IoU over a whole split.

    Every sample of ``dataset_type`` is visited exactly once per call, in a
    freshly shuffled order from ``read_single``. Samples are processed in
    batches of ``batch_size``, with a smaller final batch if needed. The first
    multimask output is scored, as in the training loop. The train/eval flag
    of every submodule is restored on exit, so calling this mid-training does
    not leave the model in eval mode.

    Args:
        predictor: SAM2ImagePredictor whose model is evaluated.
        dataset_type: "train", "test" or "val".
        batch_size: Number of images encoded per forward pass.

    Returns:
        ``(mean_loss, mean_iou)`` averaged over samples. The loss is per-pixel
        binary cross entropy on the sigmoid mask; IoU uses the mask thresholded
        at 0.5, and an empty-vs-empty pair counts as IoU 1.

    Raises:
        ValueError: if ``batch_size`` < 1 or the split has no samples.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
    num_total = return_dataset_size(dataset_type)
    if not num_total:
        raise ValueError(f"No samples available for dataset_type {dataset_type!r}.")

    # Snapshot each submodule's train/eval flag so the caller's mode can be restored exactly.
    saved_modes = {module: module.training for module in predictor.model.modules()}
    predictor.model.eval()

    total_loss, total_iou, num_samples = 0.0, 0.0, 0
    try:
        with torch.no_grad():
            for start in range(0, num_total, batch_size):
                current = min(batch_size, num_total - start)
                # Reshuffling on the very first read starts a fresh pass, and exactly
                # num_total reads follow, so no sample is skipped or repeated.
                samples = [
                    read_single(dataset_type, reset_epoch=(start == 0 and k == 0))
                    for k in range(current)
                ]
                images = [sample[0] for sample in samples]
                masks = np.array([sample[1] for sample in samples], dtype=np.float32)
                boxes = np.array([sample[2] for sample in samples])

                prd_masks = _segment_with_boxes(predictor, images, boxes)
                gt_mask = torch.as_tensor(masks, device=prd_masks.device)
                prd_mask = torch.sigmoid(prd_masks[:, 0])
                seg_loss = (
                    -gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)
                ).mean()

                pred_binary = prd_mask > 0.5
                inter = (gt_mask * pred_binary).sum(dim=(1, 2))
                union = gt_mask.sum(dim=(1, 2)) + pred_binary.sum(dim=(1, 2)) - inter
                # union == 0 only when both masks are empty: count it as a match, not 0/0 = NaN.
                iou = torch.where(union > 0, inter / union.clamp(min=1), torch.ones_like(union))

                # Weight by batch size so a smaller final batch does not skew the means.
                total_loss += seg_loss.item() * current
                total_iou += iou.sum().item()
                num_samples += current
    finally:
        for module, was_training in saved_modes.items():
            module.training = was_training

    return total_loss / num_samples, total_iou / num_samples

# Training loop
#
# Fine-tunes with box prompts and validates every VALIDATION_INTERVAL iterations.
# It checkpoints the weights with the lowest validation loss, halves the LR on plateaus,
# and stops early after `patience` validations without improvement.

best_val_iou = -float('inf')  # Best validation IoU seen so far (reported; selection uses the loss)
best_model_path = "/content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/best_SAM2.pth"

patience = 5  # Number of validation intervals to wait before stopping
no_improve_counter = 0  # Validation intervals since the last improvement in validation loss
best_val_loss = float('inf')  # Lowest validation loss so far; drives checkpointing and early stopping
min_lr = 1e-7  # Floor for ReduceLROnPlateau: the learning rate is never reduced below this

# ReduceLROnPlateau's `verbose` flag is deprecated in recent PyTorch; LR drops are printed below instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2, min_lr=min_lr
)

# One entry per validation: (iteration, mean training loss since the previous validation, val loss, val IoU)
training_history = []
train_loss_sum, train_loss_count = 0.0, 0

for itr in range(iterations):
    # Load data batch (images, binary masks, [x_min, y_min, x_max, y_max] boxes)
    image, mask, input_bbox = read_batch("train", current_iteration=itr + 1, interval=VALIDATION_INTERVAL, batch_size=4)
    if mask.shape[0] == 0:
        continue  # Ignore empty batches

    # Only the forward pass and the loss run under autocast (torch.amp replaces the
    # deprecated torch.cuda.amp API).
    with torch.amp.autocast("cuda"):
        predictor.set_image_batch(image)  # Apply SAM image encoder to the images

        # Prompt encoding using bounding boxes
        _, _, _, unnorm_box = predictor._prep_prompts(
            point_coords=None,
            point_labels=None,
            box=input_bbox,
            mask_logits=None,
            normalize_coords=True,
        )
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=None,
            boxes=unnorm_box,
            masks=None,
        )

        # Mask decoder. The high-res feature maps are batched like image_embed, so they are
        # passed whole; `feat_level[-1].unsqueeze(0)` would hand the last image's features
        # to every image in the batch (silently, through broadcasting).
        high_res_features = list(predictor._features["high_res_feats"])
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"],
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # Upscale to the original image resolution

        # Segmentation loss: per-pixel binary cross entropy on the first multimask output
        gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
        seg_loss = (
            -gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)
        ).mean()

        # Score loss: the predicted IoU score should match the IoU of the thresholded mask
        pred_binary = prd_mask > 0.5
        inter = (gt_mask * pred_binary).sum(dim=(1, 2))
        union = gt_mask.sum(dim=(1, 2)) + pred_binary.sum(dim=(1, 2)) - inter
        # union == 0 only when both masks are empty: count it as IoU 1 instead of 0/0 = NaN
        iou = torch.where(union > 0, inter / union.clamp(min=1), torch.ones_like(union))
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05  # Mix losses

    # Backpropagation outside autocast, as recommended for automatic mixed precision
    predictor.model.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    train_loss_sum += loss.item()
    train_loss_count += 1

    # Validation, LR scheduling, checkpointing and early stopping
    if itr % VALIDATION_INTERVAL == 0:
        val_loss, val_iou = compute_metrics(predictor, dataset_type="val", batch_size=4)
        # compute_metrics evaluates in eval mode; make sure the fine-tuned parts are back
        # in training mode, as configured before the loop.
        predictor.model.sam_mask_decoder.train(True)
        predictor.model.sam_prompt_encoder.train(True)
        predictor.model.image_encoder.train(True)

        mean_train_loss = train_loss_sum / train_loss_count if train_loss_count else float('nan')
        train_loss_sum, train_loss_count = 0.0, 0
        training_history.append((itr, mean_train_loss, val_loss, val_iou))
        best_val_iou = max(best_val_iou, val_iou)
        print(f"Train - Step: {itr}, Mean loss since last validation: {mean_train_loss:.4f}")
        print(f"Validation - Step: {itr}, Loss: {val_loss:.4f}, IOU: {val_iou:.4f}")

        # Learning rate adjustment
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

        # Checkpointing and early stopping on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve_counter = 0
            torch.save(predictor.model.state_dict(), best_model_path)
            print(f"New best model saved at iteration {itr} with validation loss: {val_loss:.4f}")
        else:
            no_improve_counter += 1
            print(f"No improvement in validation loss for {no_improve_counter} intervals.")

        if no_improve_counter >= patience:
            print(f"Early stopping triggered at iteration {itr}. Best validation loss: {best_val_loss:.4f}")
            break

# Final evaluation on the test set, using the best checkpoint rather than the last weights
if os.path.exists(best_model_path):
    predictor.model.load_state_dict(torch.load(best_model_path, map_location="cpu", weights_only=True))
    print(f"Loaded best checkpoint (validation loss {best_val_loss:.4f}, best validation IOU {best_val_iou:.4f})")
test_loss, test_iou = compute_metrics(predictor, dataset_type="test", batch_size=4)
print(f"Test Results - Loss: {test_loss:.4f}, IOU: {test_iou:.4f}")

  scaler = torch.cuda.amp.GradScaler() # mixed precision
  with torch.cuda.amp.autocast():  # cast to mixed precision


Validation - Step: 0, Loss: 0.1812, IOU: 0.5707
New best model saved at iteration 0 with validation loss: 0.1812
Validation - Step: 790, Loss: 0.0131, IOU: 0.9710
New best model saved at iteration 790 with validation loss: 0.0131
Validation - Step: 1580, Loss: 0.0123, IOU: 0.9726
New best model saved at iteration 1580 with validation loss: 0.0123
Validation - Step: 2370, Loss: 0.0116, IOU: 0.9739
New best model saved at iteration 2370 with validation loss: 0.0116
Validation - Step: 3160, Loss: 0.0114, IOU: 0.9741
New best model saved at iteration 3160 with validation loss: 0.0114
Validation - Step: 3950, Loss: 0.0114, IOU: 0.9743
No improvement in validation loss for 1 intervals.
Validation - Step: 4740, Loss: 0.0113, IOU: 0.9744
New best model saved at iteration 4740 with validation loss: 0.0113
Validation - Step: 5530, Loss: 0.0111, IOU: 0.9751
New best model saved at iteration 5530 with validation loss: 0.0111
Validation - Step: 6320, Loss: 0.0110, IOU: 0.9753
New best model saved a

In [None]:
# TODO: add the training loss and validation loss (track both over iterations and plot them)