# Fine-tuning SAM 2 using Sentinel 2 data











## What is the Segment Anything Model?
[Segment Anything Model](https://segment-anything.com/) (SAM) is an AI model developed by [Meta AI](https://ai.meta.com/) that performs promptable image segmentation. SAM was trained to learn a general notion of what an object is. This enables zero-shot generalization to new, previously unseen images and objects without additional training, so SAM can be applied to a wide range of segmentation tasks. The first version of SAM was released on 5 April 2023. It was trained on the [SA-1B](https://ai.meta.com/datasets/segment-anything/) dataset, which was designed for training general-purpose object segmentation models. On 29 July 2024 Meta released [SAM 2](https://ai.meta.com/sam2/), the second version of its segmentation model, built on the foundations of SAM. SAM 2 was designed for promptable segmentation of both images and videos, and its image segmentation is faster and more accurate than SAM's. The SAM architecture was extended with a streaming memory for real-time video processing. SAM 2 was additionally trained on the [SA-V](https://ai.meta.com/datasets/segment-anything-video/) dataset, which consists of 51K diverse videos collected with a model-in-the-loop data engine.

### Purpose of Fine-Tuning SAM 2
Although SAM 2 was designed to segment even previously unseen objects, it may not be accurate enough on less common data such as Sentinel 2 imagery. In this notebook SAM 2 models are fine-tuned to improve the segmentation of agricultural fields.

The main benefits of fine-tuning SAM 2 are:
*   Improved accuracy
*   More realistic object masks
*   Faster than training from scratch

To run this notebook, make sure you are using a GPU with enough memory. On Google Colab an L4 GPU should be sufficient, but an A100 GPU is recommended.

Install the required libraries:
*   awscli: Amazon Web Services command line interface, used here to download the dataset from S3-compatible storage.
*   sam2: Segment Anything Model 2, Meta's model for image segmentation.
*   transformers: Interface for working with machine learning models.
*   datasets: Simplifies machine learning dataset management.



In [None]:
%pip install awscli &> /dev/null
%pip install sam2 &> /dev/null
%pip install transformers &> /dev/null
%pip install datasets &> /dev/null

In [None]:
import os
import glob
import subprocess
import gc
import imghdr  # File type checks (deprecated since Python 3.11, removed in 3.13)
import numpy as np
import matplotlib.pyplot as plt
import tifffile  # Handling TIFF images
import scipy.ndimage as ndimage
import torch
from google.colab import drive
from PIL import Image
from datasets import Dataset
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


Mount Google Drive and create a directory where the final fine-tuned models will be saved for later use.

In [None]:
drive.mount('/content/gdrive', force_remount=True)

!mkdir -p /content/gdrive/MyDrive/SAM_trained_models

## Download training data
[Fields of The World (FTW)](https://fieldsofthe.world/) is a dataset designed to improve machine learning models used in remote sensing and GIS. FTW aggregates and harmonizes many smaller open datasets into one large dataset with data from 24 countries across the world. It contains approximately 1.6 million parcel boundaries and over 70,000 samples of agricultural fields. Each sample contains a [Sentinel 2](https://www.esa.int/Applications/Observing_the_Earth/Copernicus/Sentinel-2) image with RGB and NIR bands, paired with semantic segmentation masks that represent agricultural field boundaries.
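The notebook relies on one specific layout of the downloaded FTW data. Sentinel 2 chips are expected in `<country>/s2_images/window_a/` and semantic masks in `<country>/label_masks/semantic_2class/`, with the two matched by file name. The sketch below pairs each mask with its image under that assumption. The helper name `pair_mask_and_image_paths` is hypothetical, and any other FTW folders are ignored.

```python
import glob
import os


def pair_mask_and_image_paths(country_dir):
    """Return (mask_path, image_path) pairs for masks that have a matching image.

    Assumes the directory layout used in this notebook:
      <country_dir>/label_masks/semantic_2class/<name>.tif
      <country_dir>/s2_images/window_a/<name>.tif
    """
    pairs = []
    pattern = os.path.join(country_dir, "label_masks", "semantic_2class", "*.tif")
    for mask_path in sorted(glob.glob(pattern)):
        name = os.path.basename(mask_path)
        image_path = os.path.join(country_dir, "s2_images", "window_a", name)
        # Keep only masks whose Sentinel 2 image was actually downloaded
        if os.path.exists(image_path):
            pairs.append((mask_path, image_path))
    return pairs
```

For example, `pair_mask_and_image_paths("sweden")` would list the training pairs after the download cell has run.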

In [None]:
def download_data(country):
  """
  Download data from Fields of The World for a specific country.

  Parameters:
  - country: The name of the country.

  Returns:
  - None
  """
  # Sync the country's folder from the public FTW bucket (no AWS credentials needed)
  command = [
      "aws", "s3", "sync",
      f"s3://kerner-lab/fields-of-the-world/{country}", f"./{country}",
      "--endpoint-url=https://data.source.coop", "--no-sign-request",
  ]
  # check=True raises an error if the download fails
  subprocess.run(command, check=True)


In [None]:
# Country for which Fields of The World data will be downloaded
country = "sweden"

In [None]:
# Create directories to store data
os.makedirs(country, exist_ok=True)

# Download data to created directories
download_data(country)


Plot the mask and the RGB bands of a Sentinel 2 image separately. Each image is 128x128 pixels, which corresponds to 1536x1536 meters.

In [None]:
# Plot mask and Sentinel2 RGB bands
img = tifffile.imread("sweden/s2_images/window_a/g13-5_00003_17.tif")
mask = tifffile.imread("sweden/label_masks/semantic_2class/g13-5_00003_17.tif")

fig, axes = plt.subplots(2, 2, figsize=(10, 10))

axes[0][0].imshow(np.array(img[:,:,0]), cmap='gray')
axes[0][0].set_title("Band B02")

axes[0][1].imshow(np.array(img[:,:,1]), cmap='gray')
axes[0][1].set_title("Band B03")

axes[1][0].imshow(np.array(img[:,:,2]), cmap='gray')
axes[1][0].set_title("Band B04")

axes[1][1].imshow(mask, cmap='gray')
axes[1][1].set_title("Mask")

plt.show()


## Data preprocessing
FTW data requires some preprocessing before it can be used for fine-tuning. First, an RGB composite of the Sentinel 2 image is created. The image is then divided by the Sentinel 2 scale factor (10000), clipped at the 98th percentile, normalized to the 0-1 range and converted to uint8. A morphological opening is applied to the masks to remove small regions, some of which consist of a single pixel. Finally, random points are selected for each region in the mask. These points serve as prompts when fine-tuning SAM 2.
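The following minimal sketch runs the image and mask steps on synthetic data and mirrors the arithmetic in `load_data`. It divides by 10000, clips at the 98th percentile, rescales to 0-255 and applies a 3x3 binary opening. The function names are hypothetical and the random arrays stand in for real FTW chips. Channels 0-2 are assumed to be B02, B03 and B04, as labelled in the band plot above.

```python
import numpy as np
import scipy.ndimage as ndimage


def to_uint8_rgb(image, scale=10000.0, upper_percentile=98):
    """Scale raw Sentinel 2 values to reflectance, clip and convert to uint8."""
    reflectance = image / scale
    max_reflectance = np.percentile(reflectance, upper_percentile)
    if max_reflectance <= 0:
        # Avoid division by zero on empty (all-zero) chips
        return np.zeros(image.shape, dtype=np.uint8)
    clipped = np.clip(reflectance, 0.0, max_reflectance)
    return (clipped / max_reflectance * 255).astype(np.uint8)


def clean_mask(mask):
    """Remove regions too small to survive a 3x3 morphological opening."""
    return ndimage.binary_opening(mask, structure=np.ones((3, 3)))


# Synthetic 4-band chip (H, W, C) with raw digital numbers
rng = np.random.default_rng(0)
raw = rng.integers(0, 3000, size=(64, 64, 4)).astype(np.float32)
rgb = to_uint8_rgb(raw[:, :, [2, 1, 0]])  # reorder B04, B03, B02 -> R, G, B

# Mask with one 10x10 field and one isolated pixel
mask = np.zeros((64, 64), dtype=bool)
mask[5:15, 5:15] = True
mask[40, 40] = True
cleaned = clean_mask(mask)  # the field survives, the single pixel is removed
```

Clipping at a high percentile rather than at the maximum keeps a few very bright pixels (clouds, roofs) from compressing the rest of the image into a narrow range of dark values.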

In [None]:
def get_points(binary_mask, n):
  """
  Get n random points from each region of a binary mask.

  Parameters:
  - binary_mask: The binary mask from which points are to be extracted.
  - n: The number of points to be extracted per region.

  Returns:
  - random_points: A list of [x, y] points, n for each region.
  """

  # Morphological erosion keeps the points away from region borders
  msk = ndimage.binary_erosion(binary_mask, structure=np.ones((3,3)))

  # Create labels for each unique region on mask
  labeled_mask, num_features = ndimage.label(msk)

  random_points = []

  for label_id in range(1, num_features + 1):
    # Indices of current region
    region_indices = np.argwhere(labeled_mask == label_id)

    for i in range(n):
      if len(region_indices) > 0:
        # Random points for current region
        yx = np.array(region_indices[np.random.randint(len(region_indices))])

        # Store as [x, y], the coordinate order SAM expects
        random_points.append([yx[1], yx[0]])

  return random_points

In [None]:
def get_num_masks(mask):
  """
  Get the number of unique regions in a binary mask.

  Parameters:
  - mask: The binary mask.

  Returns:
  - num_features: The number of unique regions in the mask.
  """

  labeled_mask, num_features = ndimage.label(mask)

  return num_features

In [None]:
def load_data(country, dataset):
  """
  Load and preprocess data from a specific country into a dataset.

  Parameters:
  - country: The name of the country.
  - dataset: The dataset dictionary to which the data will be added.

  Returns:
  - None
  """

  # Collect mask paths
  masks_names = glob.glob(f"{country}/label_masks/semantic_2class/*.tif")

  # Empty lists to store the loaded data
  images = []
  masks = []
  points = []
  num_masks = []

  print(f"Loading data from: {country}, number of images: {len(masks_names)}")

  for mask_path in masks_names:
    mask_name = os.path.splitext(os.path.basename(mask_path))[0]

    image_path = f"{country}/s2_images/window_a/" + mask_name + ".tif"

    # Skip pairs where the image is missing or either file is not a TIFF
    if not os.path.exists(image_path) or imghdr.what(mask_path) != 'tiff' or imghdr.what(image_path) != 'tiff':
      print(f"Warning: Skipping {mask_name} - missing or not a TIFF file.")
      continue

    # Read image and mask
    try:
      mask = tifffile.imread(mask_path)
      image = tifffile.imread(image_path)
    except Exception as e:
      print(e)
      continue

    # Morphological opening to remove small regions
    mask = ndimage.binary_opening(mask, structure=np.ones((3,3)))

    # Skip masks without objects
    if mask.max() == 0:
      continue

    # Create RGB image
    image = image[:, :, [2, 1, 0]]

    # Scale Sentinel 2 pixel values to reflectance and clip at the 98th percentile
    min_reflectance = 0
    max_reflectance = np.percentile(image / 10000, 98)
    if max_reflectance <= min_reflectance:
      # Avoid division by zero on empty images
      continue
    image_clip = np.clip(image / 10000, min_reflectance, max_reflectance)
    image_norm = (image_clip - min_reflectance) / (max_reflectance - min_reflectance)
    image_uint8 = (image_norm * 255).astype(np.uint8)

    # Get 2 points for every region (computed on the NumPy mask)
    p = get_points(mask, 2)

    # Add image, mask and points to the lists
    images.append(Image.fromarray(image_uint8).convert('RGB'))
    masks.append(Image.fromarray(mask))
    points.append(p)

    # Number of prompt points (2 per region), i.e. the number of predicted masks
    num_masks.append(len(p))

  # Add data from the current country to the dataset
  dataset["image"].extend(images)
  dataset["label"].extend(masks)
  dataset["points"].extend(points)
  dataset["num_masks"].extend(num_masks)


Load the downloaded data, preprocess it and convert it to a `datasets.Dataset` for easier data management. The loaded dataset contains 4180 images from Sweden.

In [None]:
# Create empty dict to store FTW data
dataset = {
    "image": [],
    "label": [],
    "points": [],
    "num_masks": []
}

# Load FTW data and store it in dict
load_data(country, dataset)

# Convert dictionary to dataset
dataset = Dataset.from_dict(dataset)


Print the dataset to check that it was created correctly.

In [None]:
print(dataset)


The data is split into training (80%) and validation (20%) subsets.

In [None]:
# Split the dataset into training (80%) and validation (20%)
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)

# Access the training and validation datasets
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]


Print both subsets after the split. The training subset contains 3344 images and the validation subset 837.

In [None]:
print(train_dataset)
print(val_dataset)


Plot an image and its mask, both overlaid with the randomly selected points. Each region has 2 random points. Erosion keeps the points away from region borders, which helps the model learn to segment agricultural fields properly.
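The effect of the erosion step in `get_points` can be checked on a toy mask. After a 3x3 binary erosion, every remaining pixel has all eight neighbours inside the original region, so a sampled point lies at least one pixel away from the field border. The helper below is a simplified, hypothetical variant of `get_points` that uses a seeded NumPy generator so that its output is reproducible.

```python
import numpy as np
import scipy.ndimage as ndimage


def sample_interior_points(binary_mask, n_per_region, seed=0):
    """Sample n points per connected region, at least one pixel from its border."""
    rng = np.random.default_rng(seed)
    eroded = ndimage.binary_erosion(binary_mask, structure=np.ones((3, 3)))
    labeled, num_regions = ndimage.label(eroded)
    points = []
    for label_id in range(1, num_regions + 1):
        region = np.argwhere(labeled == label_id)  # (row, col) pairs
        for _ in range(n_per_region):
            row, col = region[rng.integers(len(region))]
            points.append([int(col), int(row)])  # stored as (x, y)
    return points


mask = np.zeros((20, 20), dtype=bool)
mask[2:8, 2:8] = True      # field 1
mask[10:18, 10:16] = True  # field 2
points = sample_interior_points(mask, n_per_region=2)
```

Without erosion, a point could land on a one-pixel-wide edge, where the prompt is ambiguous between the field and its neighbour.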

In [None]:
# Plot image and mask with random points
data = train_dataset[np.random.randint(len(train_dataset))]

fig, axs = plt.subplots(1, 2, figsize=(10, 10))

# Add image with points
axs[0].imshow(data["image"])
axs[0].set_title("Sentinel image")
for point in data["points"]:
    axs[0].plot(point[0], point[1], 'ro')

# Add mask with points
axs[1].imshow(data['label'], cmap="gray")
axs[1].set_title("Mask")
for point in data["points"]:
    axs[1].plot(point[0], point[1], 'ro')

plt.savefig("plot_random_points_image.png", dpi=300)
plt.show()


Define a function that reads a random sample from the dataset. The function reshapes the mask and points into the format required by the training loop.
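The shape changes can be illustrated with dummy arrays. A mask of shape (H, W) becomes (1, H, W), and N points of shape (N, 2) become (N, 1, 2). This means N prompts with one point each, and every prompt gets one positive label. The sizes below are arbitrary.

```python
import numpy as np

height, width, num_points = 128, 128, 6

binary_mask = np.zeros((height, width), dtype=bool)
points = np.zeros((num_points, 2), dtype=np.int64)  # (x, y) per point

# (H, W) -> (1, H, W): same result as expand_dims(axis=-1) + transpose((2, 0, 1))
mask_chw = np.expand_dims(binary_mask, axis=-1).transpose((2, 0, 1))
assert np.array_equal(mask_chw, binary_mask[np.newaxis, :, :])

# (N, 2) -> (N, 1, 2): one point per prompt
points_per_prompt = np.expand_dims(points, axis=1)

# One positive label (1) per prompt, shape (N, 1)
labels = np.ones((num_points, 1))
```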

In [None]:
def read_random_batch(dataset):
  """
  Read a random sample from a dataset.

  Parameters:
  - dataset: The dataset from which the sample is to be read.

  Returns:
  - img, binary_mask, points, num_masks, rnd: The image, the mask of shape
    (1, H, W), the points of shape (N, 1, 2), the number of points and the sample index.
  """

  # Pick a random sample from the dataset
  rnd = np.random.randint(len(dataset))
  batch = dataset[rnd]

  # Convert to NumPy arrays
  img = np.array(batch["image"])
  binary_mask = np.array(batch["label"])
  points = np.array(batch["points"])

  # Add a leading channel dimension to the mask: (H, W) -> (1, H, W)
  binary_mask = np.expand_dims(binary_mask, axis=-1)
  binary_mask = binary_mask.transpose((2, 0, 1))

  # Treat every point as a separate prompt: (N, 2) -> (N, 1, 2)
  points = np.expand_dims(points, axis=1)

  return img, binary_mask, points, batch["num_masks"], rnd


Test the function by reading a random sample and plotting it.

In [None]:
# Plot random image and mask from random batch
img, mask, points, num_masks, _ = read_random_batch(dataset)

# Remove the leading mask dimension for plotting: (1, H, W) -> (H, W)
mask = np.squeeze(mask, axis=0)

fig, axs = plt.subplots(1, 2, figsize=(10, 10))

# Add image
axs[0].imshow(img)
axs[0].set_title("Sentinel image")

# Add mask
axs[1].imshow(mask, cmap="gray")
axs[1].set_title("Mask")

plt.savefig("image_and_mask.png", dpi=300)
plt.show()

## Download SAM 2 model checkpoints
Fine-tuning SAM 2 requires pre-trained model weights. These weights serve as the starting point for further improving the model's accuracy in segmenting agricultural fields. SAM 2 comes with 4 model checkpoints that differ in size and speed:

| Model | Size (M parameters) | Speed (FPS) |
|----------|----------|----------|
| sam2_hiera_tiny | 38.9 | 47.2 |
| sam2_hiera_small | 46 | 43.3 (53.0 compiled) |
| sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled) |
| sam2_hiera_large | 224.4 | 24.2 (30.2 compiled) |
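When calling `build_sam2`, each checkpoint has to be paired with the config of the same size. The pairs below are taken from the `@param` options used in this notebook. The dictionary and helper are a hypothetical convenience, not part of the `sam2` package.

```python
# Checkpoint file -> matching model config, as paired in this notebook's cells
SAM2_CONFIGS = {
    "sam2_hiera_tiny.pt": "sam2_hiera_t.yaml",
    "sam2_hiera_small.pt": "sam2_hiera_s.yaml",
    "sam2_hiera_base_plus.pt": "sam2_hiera_b+.yaml",
    "sam2_hiera_large.pt": "sam2_hiera_l.yaml",
}


def config_for_checkpoint(checkpoint):
    """Return the config that matches a checkpoint, or raise a clear error."""
    try:
        return SAM2_CONFIGS[checkpoint]
    except KeyError:
        raise ValueError(f"Unknown SAM 2 checkpoint: {checkpoint!r}") from None
```

With such a helper, `model_cfg = config_for_checkpoint(sam2_checkpoint)` would keep the two form fields from drifting apart.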

In [None]:
!wget -O sam2_hiera_tiny.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
!wget -O sam2_hiera_small.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
!wget -O sam2_hiera_base_plus.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
!wget -O sam2_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"


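## Fine-Tune the model
Clear CUDA memory every time before switching SAM 2 models; this helps prevent running out of GPU memory. Memory held by a previous model is only released once no variable (for example `predictor` or `sam2_model`) still references it.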

In [None]:
gc.collect()
torch.cuda.empty_cache()


In [None]:
sam2_checkpoint = "sam2_hiera_small.pt"  # @param ["sam2_hiera_tiny.pt", "sam2_hiera_small.pt", "sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"]
model_cfg = "sam2_hiera_s.yaml" # @param ["sam2_hiera_t.yaml", "sam2_hiera_s.yaml", "sam2_hiera_b+.yaml", "sam2_hiera_l.yaml"]

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model) # Create net


In [None]:
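# Put the mask decoder, image encoder and prompt encoder in training mode.
# Note: SAM2ImagePredictor.set_image computes image features under torch.no_grad(),
# so the image encoder most likely receives no gradients in the loop below and stays
# unchanged; effectively only the prompt encoder and mask decoder are fine-tuned.
predictor.model.sam_mask_decoder.train(True)
predictor.model.image_encoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

# Configure AdamW optimizer
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=0.0001, weight_decay=1e-4)

# Gradient scaler for mixed-precision (float16) training
scaler = torch.amp.GradScaler()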

# Number of steps per epoch to train the model
STEPS_PER_EPOCH = 500 # @param

# Fine-tuned model name to be saved
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2_small" # @param

# Number of epochs
EPOCHS = 15 # @param


In [None]:
# Initialize scheduler
accumulation_steps = 2
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEPS_PER_EPOCH, gamma=0.95)
mean_iou = 0
drive_model_path = ""
accuraces = []

for epoch in range(EPOCHS):
  print(f"Epoch {epoch}")
  for step in range(STEPS_PER_EPOCH):
    with torch.amp.autocast('cuda'):
        image, mask, input_point, num_masks, rnd = read_random_batch(train_dataset)
        input_label = np.ones((num_masks, 1))

        predictor.set_image(image)

        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True)

        if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
           continue

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None,)

        # Mask decoder
        batched_mode = unnorm_coords.shape[0] > 1

        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # Loss
        gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.000001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05

        # Apply gradient accumulation
        loss = loss / accumulation_steps
        scaler.scale(loss).backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)

        if step % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            predictor.model.zero_grad()

        # Update scheduler
        scheduler.step()

        mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())

        if step % 100 == 0:
            accuraces.append(mean_iou)
            print(f"Step {step}:\t Accuracy (IoU) = {mean_iou:.3f}, Loss = {loss.item():.3f}, LR = {scheduler.get_last_lr()[0]:.7f}")

        if step % 100 == 0:
            pass

  gc.collect()
  torch.cuda.empty_cache()

  torch.save(predictor.model.state_dict(), f"{FINE_TUNED_MODEL_NAME}_epoch_{epoch}" + ".torch")

  drive_model_path = f"/content/gdrive/MyDrive/SAM_trained_models/{FINE_TUNED_MODEL_NAME}_epoch_{epoch}.torch"

torch.save(predictor.model.state_dict(), drive_model_path)

torch.cuda.empty_cache()
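The training cell configures `StepLR` with `gamma=0.95` so that the learning rate decays once per epoch. Under that schedule, the rate during epoch `e` is `0.0001 * 0.95**e`. The short calculation below uses the notebook's values (initial learning rate 0.0001, 15 epochs) to show how far the rate has decayed by the final epoch. It is plain arithmetic and does not call PyTorch.

```python
initial_lr = 1e-4
gamma = 0.95
epochs = 15

# Learning rate used during each epoch, assuming one StepLR decay per epoch
lr_per_epoch = [initial_lr * gamma ** epoch for epoch in range(epochs)]

for epoch, lr in enumerate(lr_per_epoch):
    print(f"epoch {epoch:2d}: lr = {lr:.3e}")
```

By the last epoch the model trains at roughly half of the initial learning rate (about 4.9e-5).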

In [None]:
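# Plot training accuracy (accuracy is logged every 100 steps)
points_per_epoch = len(range(0, STEPS_PER_EPOCH, 100))
x = np.arange(len(accuraces)) / points_per_epoch

plt.plot(x, accuraces, color="green", linestyle="-")

plt.title(f"Fine-tuning accuracy of {FINE_TUNED_MODEL_NAME}")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (IoU, moving average)")

plt.show()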
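The plotted accuracy is not a per-step IoU. It is an exponential moving average, `mean_iou = mean_iou * 0.99 + 0.01 * iou`, that starts at 0. Because of that starting value, the first points of the curve are pulled towards zero even if the model already performs well. The sketch below feeds a constant, hypothetical IoU of 0.8 into the same update rule to show the size of this effect.

```python
def ema(values, decay=0.99, start=0.0):
    """Exponential moving average with the same update rule as the training loop."""
    current = start
    history = []
    for value in values:
        current = decay * current + (1 - decay) * value
        history.append(current)
    return history


# Constant IoU of 0.8 for 500 steps (one epoch in this notebook)
history = ema([0.8] * 500)
print(f"after 100 steps: {history[99]:.3f}")  # still well below 0.8
print(f"after 500 steps: {history[499]:.3f}")  # close to 0.8
```

After k steps the average equals `0.8 * (1 - 0.99**k)`, so most of the low start of the curve during the first epoch comes from the initialization, not from the model.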

## Verify model accuracy
To verify the model's accuracy, the intersection over union (IoU) metric is used. IoU, also known as the Jaccard index, is a standard metric in image segmentation that measures how well a model distinguishes objects from the background. It is the ratio of the area of intersection between the ground truth mask and the predicted mask to the area of their union.

The Jaccard index is defined as:

$J(A, B) = \frac{|A \cap B|}{|A \cup B|}$
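Here is a small worked example with two made-up 4x4 masks. The ground truth covers 8 pixels, the prediction covers 6, and 4 of them overlap. The union is therefore 8 + 6 - 4 = 10 pixels, and IoU = 4 / 10 = 0.4. The NumPy check below computes the value the same way `calculate_iou` does.

```python
import numpy as np

ground_truth = np.zeros((4, 4), dtype=bool)
ground_truth[0:2, :] = True          # 8 pixels

prediction = np.zeros((4, 4), dtype=bool)
prediction[1, :] = True              # 4 pixels, all inside the ground truth
prediction[2, 0:2] = True            # 2 pixels outside the ground truth

intersection = np.logical_and(ground_truth, prediction).sum()  # 4
union = np.logical_or(ground_truth, prediction).sum()          # 10
iou = intersection / union
print(iou)  # 0.4
```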





In [None]:
def calculate_iou(image1, image2):
  """
  Calculate intersection over union (IoU) between two binary images.

  Parameters:
  - image1: The first binary image.
  - image2: The second binary image.

  Returns:
  - iou: The IoU value between the two images.
  """

  intersection = np.sum(np.logical_and(image1, image2))  # Pixels where both are 1
  union = np.sum(np.logical_or(image1, image2))  # Pixels where either is 1

  # Calculate IoU
  iou = intersection / union

  return(iou)

In [None]:
def segment_image(predictor, image, input_points, origin_mask, fine_tune):
    """
    Segment an image with a SAM 2 predictor and visualize the segmentation.

    Parameters:
    - predictor: The SAM 2 image predictor used for segmentation.
    - image: The input image (e.g., a Sentinel image).
    - input_points: Point prompts of shape (N, 1, 2), one point per prompt.
    - origin_mask: The ground truth mask used to evaluate the segmentation.
    - fine_tune: Whether the predictor uses fine-tuned weights (only affects the plot title).

    Returns:
    - seg_map: The final segmentation map (boolean).
    - iou: The IoU between seg_map and origin_mask.
    - fig: The matplotlib figure containing the visualization.
    """

    # Create segmentation masks
    with torch.no_grad():
        predictor.set_image(image)

        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0], 1])
        )

    # With a single prompt the batch dimension is squeezed away; restore it
    if masks.ndim == 3:
      masks = np.expand_dims(masks, axis=0)
    if scores.ndim == 1:
      scores = np.expand_dims(scores, axis=0)

    # Keep the first predicted mask and its score for each prompt
    np_masks = np.array(masks[:, 0])
    np_scores = scores[:, 0]

    # Sort masks by score, highest first
    sorted_masks = np_masks[np.argsort(np_scores)][::-1]

    # Initialize the segmentation map (int32 avoids overflow with many masks) and the occupancy mask
    seg_map = np.zeros_like(sorted_masks[0], dtype=np.int32)
    occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

    # Merge all masks into one segmentation map
    for i in range(sorted_masks.shape[0]):
        mask = sorted_masks[i]
        # Skip empty masks and masks that overlap already placed masks by more than 20%
        if mask.sum() == 0 or (mask * occupancy_mask).sum() / mask.sum() > 0.2:
          continue

        mask_bool = mask.astype(bool)
        mask_bool[occupancy_mask] = False
        seg_map[mask_bool] = i + 1
        occupancy_mask[mask_bool] = True  # Update occupancy_mask

    seg_map = seg_map.astype(bool)

    # Calculate IoU
    iou = calculate_iou(seg_map, origin_mask)

    # Visualization
    if fine_tune:
      segmentation_title = "Fine-Tuned Model Segmentation"
    else:
      segmentation_title = "Pre-Trained Model Segmentation"

    fig, axs = plt.subplots(1, 3, figsize=(10, 10))

    # Add Sentinel 2 image
    axs[0].imshow(image)
    axs[0].set_title("Sentinel Image")

    # Add ground truth mask
    axs[1].imshow(origin_mask, cmap="gray")
    axs[1].set_title("Original Mask")

    # Add final segmentation map
    axs[2].imshow(seg_map, cmap="gray")
    axs[2].set_title(segmentation_title)

    # Close the figure so it is not displayed automatically; it is returned instead
    plt.close(fig)

    return seg_map, iou, fig
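`segment_image` merges the per-point masks greedily. It visits masks from the highest to the lowest predicted score and drops a mask if more than 20% of its pixels are already covered. Otherwise, the mask's pixels are added to the map. The standalone sketch below reproduces that rule on toy masks so the behaviour can be checked without a model. It also skips empty masks. The function name and inputs are hypothetical.

```python
import numpy as np


def merge_masks_by_score(masks, scores, max_overlap=0.2):
    """Greedily merge binary masks (N, H, W) into one boolean map, best score first."""
    order = np.argsort(scores)[::-1]
    merged = np.zeros(masks.shape[1:], dtype=bool)
    for idx in order:
        mask = masks[idx].astype(bool)
        area = mask.sum()
        # Skip empty masks and masks that mostly overlap already accepted ones
        if area == 0 or (mask & merged).sum() / area > max_overlap:
            continue
        merged |= mask
    return merged


masks = np.zeros((3, 10, 10), dtype=bool)
masks[0, 0:5, 0:5] = True   # 25 pixels, highest score
masks[1, 0:5, 3:8] = True   # overlaps mask 0 on 10 of its 25 pixels (40%)
masks[2, 6:10, 0:5] = True  # 20 pixels, no overlap
scores = np.array([0.9, 0.8, 0.5])
merged = merge_masks_by_score(masks, scores)
```

Because the order is driven by the predicted scores, swapping the scores of the first two masks would keep the second mask and drop the first. The quality of the score head therefore directly affects the merged map.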

## Pre-trained model

In [None]:
sam2_checkpoint = "sam2_hiera_large.pt"  # @param ["sam2_hiera_tiny.pt", "sam2_hiera_small.pt", "sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"]
model_cfg = "sam2_hiera_l.yaml" # @param ["sam2_hiera_t.yaml", "sam2_hiera_s.yaml", "sam2_hiera_b+.yaml", "sam2_hiera_l.yaml"]

# Build SAM2 net
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

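First, the pre-trained model is tested on 50 random samples from the validation dataset and the mean IoU is reported. For each sample, `segment_image` also returns a plot comparing the Sentinel image, the ground truth mask and the final segmentation mask. Usually, pre-trained models ...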

In [None]:
iou_list = []

# Fix the seed so this evaluation uses a reproducible set of 50 validation samples
np.random.seed(0)

for i in range(50):
  image, mask, input_points, region_num, _ = read_random_batch(val_dataset)
  mask = np.squeeze(mask, axis=0)

  origin_mask = mask

  seg_map, iou, fig = segment_image(predictor, image, input_points, origin_mask, fine_tune=False)

  iou_list.append(iou)

mean_iou = np.mean(iou_list)

print(f"Mean IoU: {mean_iou}")

# Save results to file
file_name="/content/gdrive/MyDrive/SAM_trained_models/iou_results.txt"
with open(file_name, 'a') as f:
  f.write(f"Pretrained {sam2_checkpoint} Accuracy: {mean_iou}\n")


## Fine-Tuned model

In [None]:
sam2_checkpoint = "sam2_hiera_large.pt"  # @param ["sam2_hiera_tiny.pt", "sam2_hiera_small.pt", "sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"]
model_cfg = "sam2_hiera_l.yaml" # @param ["sam2_hiera_t.yaml", "sam2_hiera_s.yaml", "sam2_hiera_b+.yaml", "sam2_hiera_l.yaml"]

# Fine-tuned weights saved on Google Drive (the checkpoint and config above must match the chosen size)
fine_tuned_weights = {
    "tiny": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_tiny_epoch_14.torch",
    "small": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_small_epoch_14.torch",
    "base_plus": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_base_epoch_14.torch",
    "large": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_large_epoch_14.torch",
}

# Choose fine-tuned model weights
FINE_TUNED_MODEL_WEIGHTS = "large" # @param ["tiny", "small", "base_plus", "large"]
FINE_TUNED_MODEL_WEIGHTS = fine_tuned_weights[FINE_TUNED_MODEL_WEIGHTS]

# Build SAM2 net and load fine-tuned weights
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(FINE_TUNED_MODEL_WEIGHTS, weights_only=True))


In [None]:
iou_list = []

# Fix the seed so this evaluation uses a reproducible set of 50 validation samples
np.random.seed(0)

for i in range(50):
  image, mask, input_points, region_num, _ = read_random_batch(val_dataset)
  mask = np.squeeze(mask, axis=0)

  origin_mask = mask

  seg_map, iou, fig = segment_image(predictor, image, input_points, origin_mask, fine_tune=True)

  iou_list.append(iou)

mean_iou = np.mean(iou_list)

print(f"Mean IoU: {mean_iou}")

# Save results to file
file_name="/content/gdrive/MyDrive/SAM_trained_models/iou_results.txt"
with open(file_name, 'a') as f:
  f.write(f"Finetune {sam2_checkpoint} Accuracy: {mean_iou}\n")

## Compare results

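Fine-tuned model weights are compared with the pre-trained ones on the same random validation sample. The sample is first segmented with the pre-trained model and then again after the fine-tuned weights are loaded.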

In [None]:
# Read random batch
image, mask, input_points, region_num, _ = read_random_batch(val_dataset)
mask = np.squeeze(mask, axis=0)

sam2_checkpoint = "sam2_hiera_tiny.pt"  # @param ["sam2_hiera_tiny.pt", "sam2_hiera_small.pt", "sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"]
model_cfg = "sam2_hiera_t.yaml" # @param ["sam2_hiera_t.yaml", "sam2_hiera_s.yaml", "sam2_hiera_b+.yaml", "sam2_hiera_l.yaml"]

# Build SAM 2 net with the pre-trained weights (the fine-tuned weights are loaded in a later cell)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Fine-tuned weights saved on Google Drive (choose the size that matches the checkpoint above)
fine_tuned_weights = {
    "tiny": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_tiny_epoch_14.torch",
    "small": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_small_epoch_14.torch",
    "base_plus": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_base_epoch_14.torch",
    "large": "/content/gdrive/MyDrive/SAM_trained_models/fine_tuned_sam2_large_epoch_14.torch",
}

# Choose fine-tuned model weights
FINE_TUNED_MODEL_WEIGHTS = "tiny" # @param ["tiny", "small", "base_plus", "large"]
FINE_TUNED_MODEL_WEIGHTS = fine_tuned_weights[FINE_TUNED_MODEL_WEIGHTS]


In [None]:
seg_map, iou, fig = segment_image(predictor, image, input_points, mask, fine_tune=False)

fig


In [None]:
# Load the fine-tuned weights into the same predictor
predictor.model.load_state_dict(torch.load(FINE_TUNED_MODEL_WEIGHTS, weights_only=True))

seg_map2, iou2, fig2 = segment_image(predictor, image, input_points, mask, fine_tune=True)

fig2
