# Libraries and Dependancies

In [None]:
!git clone https://github.com/facebookresearch/sam2.git
%cd sam2
!pip install -e .

In [None]:
!pip install ultralytics

In [5]:
import numpy as np
import torch
import cv2
import os
import json
import pickle
from pycocotools import mask as coco_mask

from torch.onnx.symbolic_opset11 import hstack

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

RuntimeError: You're likely running Python from the parent directory of the sam2 repository (i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). This is not supported since the `sam2` Python package could be shadowed by the repository name (the repository is also named `sam2` and contains the Python package in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir rather than its parent dir, or from your home directory) after installing SAM 2.

In [6]:
from ultralytics import YOLO

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = YOLO("C:/Users/richa_0/Documents/Coral Research/Machine Learning Models/V2 Developmental Models/2612-augmented.yolov11/results/50_epochs-/weights/best.pt")

# Dataset Manipulation & Extraction Methods

In [9]:
data_dir = "C:/Users/richa_0/Documents/Coral Research/Machine Learning Models/V2 Developmental Models/2526-augmented.sam2/"
train_dir = os.path.join(data_dir, "train/")
test_dir = os.path.join(data_dir, "test/")
valid_dir = os.path.join(data_dir, "valid/")

def extract_dataset(directory):
    image_arrays = {}
    binary_masks = {}
    bbox_coords = {}

    for file_name in os.listdir(directory):
        if file_name.endswith('.json'):
            json_file_path = os.path.join(directory, file_name)
            with open(json_file_path, 'r') as f:
                coco_data = json.load(f)
            annotations = coco_data['annotations']
            image = coco_data['image']

            image_id = image['image_id']
            file_name = image['file_name']
            image_path = os.path.join(directory, file_name)
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_arrays[image_id] = image

            for annotation in annotations:
                image_id = annotation['id']
                segmentation = annotation['segmentation']

                results = model.predict(image, verbose=False)
                for result in results:
                    boxes = result.boxes
                if boxes.xyxy.tolist():
                    bbox = boxes.xyxy.tolist()[0]
                else:
                    bbox = [0, 1024, 0, 1024]
                bbox_array = np.array(bbox)

                bbox_coords[image_id] = (bbox_array)

                binary_mask = coco_mask.decode(segmentation)
                binary_masks[image_id] = binary_mask

    return image_arrays, binary_masks, bbox_coords

dataset_indices = {}
dataset_index_ptr = {}

def read_single(dataset_type, reset_epoch=False):
    global dataset_indices, dataset_index_ptr

    if (dataset_type == "train"):
        image_arrays = train_image_arrays
        binary_masks = train_binary_masks
        bbox_coords = train_bbox_coords
    elif (dataset_type == "test"):
        image_arrays = test_image_arrays
        binary_masks = test_binary_masks
        bbox_coords = test_bbox_coords
    elif (dataset_type == "val"):
        image_arrays = valid_image_arrays
        binary_masks = valid_binary_masks
        bbox_coords = valid_bbox_coords

    # Initialize indices if not already done
    if dataset_type not in dataset_indices:
        dataset_indices[dataset_type] = np.arange(len(image_arrays))
        np.random.shuffle(dataset_indices[dataset_type])  # Shuffle indices
        dataset_index_ptr[dataset_type] = 0  # Start at the beginning

    # Reset and reshuffle if epoch is flagged to reset
    if reset_epoch or dataset_index_ptr[dataset_type] >= len(dataset_indices[dataset_type]):
        dataset_indices[dataset_type] = np.arange(len(image_arrays))
        np.random.shuffle(dataset_indices[dataset_type])
        dataset_index_ptr[dataset_type] = 0

    # Fetch the next index and increment pointer
    entry = dataset_indices[dataset_type][dataset_index_ptr[dataset_type]]
    dataset_index_ptr[dataset_type] += 1  # Increment pointer

    Img = image_arrays[entry]
    mask = binary_masks[entry]
    bbox = bbox_coords[entry]

    return Img, mask, bbox

def read_batch(dataset_type, current_iteration, interval, batch_size=4):
    limage = []
    lmask = []
    lbbox = []
    for i in range(batch_size):
            image,mask,bbox = read_single(dataset_type, reset_epoch=(current_iteration % interval == 0))
            limage.append(image)
            lmask.append(mask)
            lbbox.append(bbox)

    return limage, np.array(lmask), np.array(lbbox)

def return_dataset_size(dataset_type):
    if (dataset_type == "train"):
        return len(train_image_arrays)
    elif (dataset_type == "test"):
        return len(test_image_arrays)
    elif (dataset_type == "val"):
        return len(valid_image_arrays)


def get_itr_interval(dataset_type, epochs):
  if (dataset_type == "train"):
      return (len(train_image_arrays) // 4) * epochs, (len(train_image_arrays) // 4)
  elif (dataset_type == "test"):
      return (len(test_image_arrays) // 4) * epochs, (len(test_image_arrays) // 4)
  elif (dataset_type == "val"):
      return (len(valid_image_arrays) // 4) * epochs, (len(valid_image_arrays) // 4)

train_image_arrays, train_binary_masks, train_bbox_coords = extract_dataset(train_dir)
test_image_arrays, test_binary_masks, test_bbox_coords = extract_dataset(test_dir)
valid_image_arrays, valid_binary_masks, valid_bbox_coords = extract_dataset(valid_dir)

# SAM 2 Training

In [None]:
# Load model

sam2_checkpoint = "/content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/sam2.1_hiera_base_plus.pt" # path to model weight
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" #  model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # load model
predictor = SAM2ImagePredictor(sam2_model)

# Set training parameters

predictor.model.sam_mask_decoder.train(True) # enable training of mask decoder
predictor.model.sam_prompt_encoder.train(True) # enable training of prompt encoder
predictor.model.image_encoder.train(True) # enable training of image encoder: For this to work you need to scan the code for "no_grad" and remove them all
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler() # mixed precision

iterations, VALIDATION_INTERVAL = get_itr_interval("train", 30)

def compute_metrics(predictor, dataset_type, batch_size):
    """Computes validation/test metrics for the given dataset type."""
    predictor.model.eval()
    total_iou, total_loss = 0, 0
    num_samples = 0
    itr, end = get_itr_interval(dataset_type, 1)

    with torch.no_grad():
        for _ in range(return_dataset_size(dataset_type) // batch_size):
            image, mask, input_bbox = read_batch(dataset_type, current_iteration=itr + 1, interval=end, batch_size = 4)
            if mask.shape[0] == 0:
                continue

            predictor.set_image_batch(image)
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                point_coords=None,
                point_labels=None,
                box=input_bbox,
                mask_logits=None,
                normalize_coords=True
            )
            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=None, boxes=unnorm_box, masks=None
            )
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"],
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=False,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
            gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
            prd_mask = torch.sigmoid(prd_masks[:, 0])
            seg_loss = (
                -gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)
            ).mean()
            inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
            iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

            total_loss += seg_loss.item()
            total_iou += iou.mean().item()
            num_samples += 1

    return total_loss / num_samples, total_iou / num_samples

# Training loop

best_val_iou = -float('inf')  # Initialize to negative infinity to always save the best model
best_model_path = "/content/gdrive/MyDrive/Coral SAM 2 Tuner Folder/best_SAM2.pth"

for itr in range(iterations):
    with torch.cuda.amp.autocast():  # cast to mixed precision
        # Load data batch
        image, mask, input_bbox = read_batch("train", current_iteration=itr + 1, interval=VALIDATION_INTERVAL, batch_size = 4)  # Update the function to provide bounding boxes
        if mask.shape[0] == 0:
            continue  # Ignore empty batches

        predictor.set_image_batch(image)  # Apply SAM image encoder to the image

        # Prompt encoding using bounding boxes
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            point_coords=None,
            point_labels=None,
            box=input_bbox,  # Use bounding boxes here
            mask_logits=None,
            normalize_coords=True
        )
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=None,
            boxes=unnorm_box,  # Pass the bounding boxes to the encoder
            masks=None
        )

        # Mask decoder
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"],
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # Upscale the masks to the original image resolution

        # Segmentation loss calculation
        gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
        seg_loss = (
            -gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)
        ).mean()  # Cross entropy loss

        # Score loss calculation (Intersection over Union - IOU)
        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05  # Mix losses

        # Apply backpropagation
        predictor.model.zero_grad()  # Empty gradient
        scaler.scale(loss).backward()  # Backpropagate
        scaler.step(optimizer)
        scaler.update()  # Mixed precision

    # Validation step
    if itr % VALIDATION_INTERVAL == 0:
        val_loss, val_iou = compute_metrics(predictor, dataset_type="val", batch_size=4)
        print(f"Validation - Step: {itr}, Loss: {val_loss:.4f}, IOU: {val_iou:.4f}")

        # Save the best model
        if val_iou > best_val_iou:  # Check if this is the best IoU so far
            best_val_iou = val_iou
            torch.save(predictor.model.state_dict(), best_model_path)
            print(f"New best model saved at iteration {itr} with IoU: {val_iou:.4f}")

# Final evaluation on test set
test_loss, test_iou = compute_metrics(predictor, dataset_type="test", batch_size=4)
print(f"Test Results - Loss: {test_loss:.4f}, IOU: {test_iou:.4f}")

  scaler = torch.cuda.amp.GradScaler() # mixed precision
  with torch.cuda.amp.autocast():  # cast to mixed precision


Validation - Step: 0, Loss: 0.0196, IOU: 0.9235
New best model saved at iteration 0 with IoU: 0.9235
Validation - Step: 1326, Loss: 0.0068, IOU: 0.9674
New best model saved at iteration 1326 with IoU: 0.9674
Validation - Step: 2652, Loss: 0.0065, IOU: 0.9697
New best model saved at iteration 2652 with IoU: 0.9697
Validation - Step: 3978, Loss: 0.0064, IOU: 0.9698
New best model saved at iteration 3978 with IoU: 0.9698
Validation - Step: 5304, Loss: 0.0064, IOU: 0.9705
New best model saved at iteration 5304 with IoU: 0.9705
Validation - Step: 6630, Loss: 0.0064, IOU: 0.9710
New best model saved at iteration 6630 with IoU: 0.9710
Validation - Step: 7956, Loss: 0.0062, IOU: 0.9712
New best model saved at iteration 7956 with IoU: 0.9712
Validation - Step: 9282, Loss: 0.0062, IOU: 0.9711
Validation - Step: 10608, Loss: 0.0063, IOU: 0.9712
Validation - Step: 11934, Loss: 0.0063, IOU: 0.9718
New best model saved at iteration 11934 with IoU: 0.9718
Validation - Step: 13260, Loss: 0.0061, IOU: 

In [None]:
#add training loss and validation loss