In [None]:
# common imports
import os
import numpy as np
from tqdm import tqdm
from glob import glob
from numpy import zeros
from numpy.random import randint
import torch
import cv2
from statistics import mean
from torch.nn.functional import threshold, normalize
# Data Viz
import matplotlib.pyplot as plt
from pathlib import Path
import cv2

# Install dependencies
# NOTE(review): these installs sit below the import statements of this same cell,
# so on a fresh environment `import cv2` above is attempted before opencv-python
# is installed here. No versions are pinned, and the second command installs
# segment-anything directly from the GitHub repository.
! pip install opencv-python pycocotools matplotlib onnxruntime onnx
! pip install git+https://github.com/facebookresearch/segment-anything.git

"""Import Training data"""
# Training split: folders holding the images and their label masks.
# Edit these two paths to match where the dataset is mounted.
image_path = "/kaggle/input/lufi-riversnap/LuFI-RiverSnap.v1/Train/Images"
label_path = "/kaggle/input/lufi-riversnap/LuFI-RiverSnap.v1/Train/Labels"

# Images: sorted *.jpg files; the reported total counts every directory entry.
all_image_paths = sorted(glob(f"{image_path}/*.jpg"))
total_images = len(os.listdir(image_path))
print(f"Total Number of Training Images: {total_images}")

# Labels: sorted *.png files; the reported total counts every directory entry.
all_label_paths = sorted(glob(f"{label_path}/*.png"))
total_labels = len(os.listdir(label_path))
print(f"Total Number of Training Labels: {total_labels}")

# Cap each list at the number of entries found in its directory.
train_image_paths = all_image_paths[:total_images]
train_label_paths = all_label_paths[:total_labels]

"""Import Validation data"""
# Validation split: folders holding the images and their label masks.
val_image_path = "/kaggle/input/lufi-riversnap/LuFI-RiverSnap.v1/Val/Images"
# The original literal started with a doubled slash ("//kaggle/..."); it is
# normalised here to the single leading slash used by every other path.
# NOTE(review): this folder is spelled "labels" (lowercase) while the Train and
# Test splits use "Labels" — confirm against the dataset layout.
val_label_path = "/kaggle/input/lufi-riversnap/LuFI-RiverSnap.v1/Val/labels"

# The printed totals count every directory entry, while the path lists only
# contain *.jpg / *.png files, sorted by name.
val_total_images = len(os.listdir(val_image_path))
val_image_paths = sorted(glob(val_image_path + "/*.jpg"))
print(f"Total Number of Validation Images: {val_total_images}")

val_total_labels = len(os.listdir(val_label_path))
val_label_paths = sorted(glob(val_label_path + "/*.png"))
print(f"Total Number of Validation Labels: {val_total_labels}")

# Cap each list at the number of entries found in its directory.
val_image_paths = val_image_paths[0:val_total_images]
val_label_paths = val_label_paths[0:val_total_labels]

In [2]:
"""Reading ground_truth_masks for training and validation"""
desired_size = (512, 512)


def load_binary_mask(mask_path, size=None):
    """Read a label image and return it as a boolean foreground mask.

    Parameters
    ----------
    mask_path : str
        Path of the label image on disk.
    size : tuple of int or None
        Target size handed to ``cv2.resize``; ``None`` keeps the native
        resolution.

    Returns
    -------
    numpy.ndarray
        Boolean array that is ``True`` wherever the (resized) grayscale label
        is non-zero.

    Raises
    ------
    OSError
        If OpenCV cannot read ``mask_path`` (``cv2.imread`` returned ``None``).
    """
    gt_grayscale = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        raise OSError(f"Could not read mask image: {mask_path}")
    if size is not None:
        # Nearest-neighbour keeps label values intact. Bilinear interpolation
        # followed by "> 0" would mark every blended border pixel as
        # foreground and so dilate the mask. The test-mask cell of this
        # notebook already resizes with INTER_NEAREST.
        gt_grayscale = cv2.resize(gt_grayscale, size, interpolation=cv2.INTER_NEAREST)
    return gt_grayscale > 0


# Training masks, keyed by the index of the label path.
ground_truth_masks = {
    k: load_binary_mask(path, desired_size) for k, path in enumerate(train_label_paths)
}

# Validation masks, keyed by the index of the label path.
ground_truth_masksv = {
    s: load_binary_mask(path, desired_size) for s, path in enumerate(val_label_paths)
}

In [None]:
"""Import SAM model"""
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
from segment_anything import SamPredictor, sam_model_registry

# Specify model type and checkpoint path
model_type = 'vit_h'
checkpoint = 'sam_vit_h_4b8939.pth'  # Adjusted path
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load the SAM model
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)
sam_model.train()

In [4]:
"""Preprocess the images for training"""
from collections import defaultdict
from segment_anything.utils.transforms import ResizeLongestSide

# Per-image cache: transformed_data[k] holds the preprocessed tensor and the two
# sizes that train_on_batch later passes to sam_model.postprocess_masks.
transformed_data = defaultdict(dict)

# The resize transform depends only on the encoder's input size, so it is built
# once here instead of once per image.
transform = ResizeLongestSide(sam_model.image_encoder.img_size)

for k in range(len(train_image_paths)):
    image = cv2.imread(train_image_paths[k])
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"Could not read training image: {train_image_paths[k]}")
    if desired_size is not None:
        image = cv2.resize(image, desired_size, interpolation=cv2.INTER_LINEAR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device=device)
    # Channels-last image -> channels-first tensor with a leading batch axis.
    transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    # SAM's own preprocessing step; the result is what the image encoder receives.
    input_image = sam_model.preprocess(transformed_image)
    original_image_size = image.shape[:2]
    input_size = tuple(transformed_image.shape[-2:])

    transformed_data[k]['image'] = input_image
    transformed_data[k]['input_size'] = input_size
    transformed_data[k]['original_image_size'] = original_image_size

In [5]:
"""Set up the optimizer and Loss"""
# Hyper-parameters for the fine-tuning run.
num_epochs = 2
batch_size = 64
wd = 0
lr = 1e-5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the mask decoder's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(
    params=sam_model.mask_decoder.parameters(),
    lr=lr,
    weight_decay=wd,
)
loss_fn = torch.nn.BCEWithLogitsLoss()

# Sample indices of the training (keys) and validation (keys1) masks.
keys = [*ground_truth_masks]
keys1 = [*ground_truth_masksv]

In [6]:
"""Fine-tuning SAM using Training data"""
def calculate_accuracy(predictions, targets):
    """Return the fraction of elements whose thresholded prediction equals the target.

    Args:
        predictions: tensor of scores; values strictly greater than 0.5 count
            as 1, everything else as 0.
        targets: tensor compared element-wise against the thresholded scores.

    Returns:
        The mean agreement as a Python float.
    """
    # Hard 0/1 decisions: only values strictly above 0.5 become 1.
    hard_labels = predictions.gt(0.5).float()
    hits = (hard_labels == targets).float()
    return hits.mean().item()

def train_on_batch(keys, batch_start, batch_end):
    """Run one optimizer step per sample for ``keys[batch_start:batch_end]``.

    Samples are processed one at a time: each gets its own forward pass, loss,
    backward pass and ``optimizer.step()``. The image encoder and prompt
    encoder run under ``torch.no_grad()``; only the mask decoder call is
    tracked for gradients.

    Args:
        keys: sequence of sample indices into ``transformed_data`` and
            ``ground_truth_masks``.
        batch_start: index into ``keys`` of the first sample to process.
        batch_end: index into ``keys`` one past the last sample to process.

    Returns:
        ``(batch_losses, batch_accuracies)``: two lists of Python floats, one
        entry per processed sample.
    """
    batch_losses = []
    batch_accuracies = []

    for k in keys[batch_start:batch_end]:
        input_image = transformed_data[k]['image'].to(device)
        input_size = transformed_data[k]['input_size']
        original_image_size = transformed_data[k]['original_image_size']

        # No gradients are tracked through the image encoder or prompt encoder.
        with torch.no_grad():
            image_embedding = sam_model.image_encoder(input_image)

            # No point, box or mask prompt is supplied.
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
            )

        low_res_masks, _ = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        # Mask logits brought back to the size recorded as original_image_size.
        upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

        # Boolean mask -> float tensor with two leading singleton axes so it
        # lines up with the decoder output.
        gt_binary_mask = torch.as_tensor(
            ground_truth_masks[k], dtype=torch.float32, device=device
        )[None, None]

        # BCEWithLogitsLoss applies the sigmoid itself, so it must receive the
        # raw logits. Passing threshold(sigmoid(logits)) applied the sigmoid
        # twice and zeroed the gradient wherever the probability was <= 0.5.
        loss = loss_fn(upscaled_masks, gt_binary_mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

        # Accuracy is a metric only: calculate_accuracy thresholds the
        # probabilities at 0.5, and nothing here needs gradients.
        with torch.no_grad():
            train_accuracy = calculate_accuracy(torch.sigmoid(upscaled_masks), gt_binary_mask)
        batch_accuracies.append(train_accuracy)

    return batch_losses, batch_accuracies

In [None]:
# Training loop
# Per-epoch history: mean training loss/accuracy and mean validation loss/accuracy.
losses = []
val_losses = []
accuracies = []
best_val_loss = float('inf')  # Initialize best validation loss to positive infinity
val_acc = []

for epoch in range(num_epochs):
    epoch_losses = []
    epoch_accuracies = []

    # Walk over the training keys in chunks of batch_size; train_on_batch still
    # performs one optimizer step per sample inside each chunk.
    for batch_start in range(0, len(keys), batch_size):
        batch_end = min(batch_start + batch_size, len(keys))

        batch_losses, batch_accuracies = train_on_batch(keys, batch_start, batch_end)

        # Calculate accuracy for the current batch
        batch_accuracy = mean(batch_accuracies)
        epoch_accuracies.extend(batch_accuracies)

        # Calculate mean training loss for the current batch
        batch_loss = mean(batch_losses)
        epoch_losses.append(batch_loss)

        print(f'Batch: [{batch_start+1}-{batch_end}]')
        print(f'Batch Loss: {batch_loss}')
        print(f'Batch Accuracy: {batch_accuracy}')

    # Epoch loss is the mean of the per-batch means; epoch accuracy is the mean
    # over all individual samples.
    mean_train_loss = mean(epoch_losses)
    mean_train_accuracy = mean(epoch_accuracies)
    losses.append(mean_train_loss)
    accuracies.append(mean_train_accuracy)

    print(f'EPOCH: {epoch}')
    print(f'Mean training loss: {mean_train_loss}')
    print(f'Mean training accuracy: {mean_train_accuracy}')

    # Predictor wrapping the current model; later cells reuse predictor_tuned.
    predictor_tuned = SamPredictor(sam_model)

    # Validation loop
    val_loss = 0.0
    val_accuracy = 0.0
    num_val_examples = 0
    with torch.no_grad():
        for s in keys1[:len(val_image_paths)]:
            image = cv2.imread(val_image_paths[s])
            if desired_size is not None:
                image = cv2.resize(image, desired_size, interpolation=cv2.INTER_LINEAR)

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Forward pass on validation data
            predictor_tuned.set_image(image)

            # return_logits=True yields the un-thresholded mask logits. The
            # loss is BCEWithLogitsLoss, which needs logits; feeding it hard
            # 0/1 masks produced a number unrelated to the training loss.
            mask_logits, _, _ = predictor_tuned.predict(
                point_coords=None,
                box=None,
                multimask_output=False,
                return_logits=True,
            )

            # Add a leading axis to the prediction and two to the ground truth
            # so both tensors have the same number of dimensions.
            logits_tensor = torch.as_tensor(mask_logits, dtype=torch.float32, device=device).unsqueeze(0)
            gt_binary_mask = torch.as_tensor(
                ground_truth_masksv[s], dtype=torch.float32, device=device
            )[None, None]

            # Calculate validation loss
            val_loss += loss_fn(logits_tensor, gt_binary_mask).item()

            # Calculate accuracy for validation data (calculate_accuracy
            # thresholds the probabilities at 0.5)
            val_accuracy += calculate_accuracy(torch.sigmoid(logits_tensor), gt_binary_mask)
            num_val_examples += 1

    # Calculate mean validation loss for the current epoch
    val_loss /= num_val_examples
    val_losses.append(val_loss)
    print(f'Mean validation loss: {val_loss}')

    # Calculate mean validation accuracy for the current epoch
    mean_val_accuracy = val_accuracy / num_val_examples
    val_acc.append(mean_val_accuracy)
    print(f'Mean validation accuracy: {mean_val_accuracy}')

    # Save the model checkpoint if the validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        models_path = '/kaggle/working'  # Save in the working directory
        # NOTE(review): the file name says "ViTB" although model_type is 'vit_h';
        # left unchanged in case something downstream expects this name.
        torch.save(sam_model.state_dict(), os.path.join(models_path, 'SAM5122weights_ViTB.pth'))

    # Clear GPU cache after each epoch
    torch.cuda.empty_cache()

In [None]:
"""Testing fine-tuned SAM model"""
# Test split: folders holding the images and their label masks.
test_image_path = "/kaggle/input/lufi-riversnap/LuFI-RiverSnap.v1/Test/Images"
test_label_path = "/kaggle/input/lufi-riversnap/LuFI-RiverSnap.v1/Test/Labels"

# Images: sorted *.jpg files; the reported total counts every directory entry.
test_image_paths = sorted(glob(f"{test_image_path}/*.jpg"))
test_total_images = len(os.listdir(test_image_path))
print(f"Total Number of Test Images: {test_total_images}")

# Labels: sorted *.png files; the reported total counts every directory entry.
test_label_paths = sorted(glob(f"{test_label_path}/*.png"))
test_total_labels = len(os.listdir(test_label_path))
print(f"Total Number of Test Labels: {test_total_labels}")

# Cap each list at the number of entries found in its directory.
test_image_paths = test_image_paths[:test_total_images]
test_label_paths = test_label_paths[:test_total_labels]

In [11]:
"""Ground truth masks for testing"""
# Test masks as float32 0/1 arrays, keyed by sample index.
# The loop runs once per test image and reads the label at the same position.
ground_truth_test_masks = {}
for k, _ in enumerate(test_image_paths):
    label_gray = cv2.imread(test_label_paths[k], cv2.IMREAD_GRAYSCALE)
    # Binarise first (non-zero -> 1.0), then resize with nearest-neighbour.
    foreground = np.greater(label_gray, 0).astype(np.float32)
    ground_truth_test_masks[k] = (
        foreground
        if desired_size is None
        else cv2.resize(foreground, desired_size, interpolation=cv2.INTER_NEAREST)
    )

In [12]:
"""Prediction using Fine-tuned model"""
# Predictions on the test split, keyed by sample index:
#   images_tuned_list[k] -> RGB image that was fed to the predictor
#   masks_tuned_list[k]  -> float32 0/1 mask predicted for it
images_tuned_list = {}
masks_tuned_list = {}
for k, test_path in enumerate(test_image_paths):
    # Read as BGR, convert to RGB, then optionally resize.
    bgr_image = cv2.imread(test_path)
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    if desired_size is not None:
        image = cv2.resize(image, desired_size, interpolation=cv2.INTER_LINEAR)

    predictor_tuned.set_image(image)

    # Prompt-free prediction requesting a single mask.
    masks_tuned, _, _ = predictor_tuned.predict(
        point_coords=None,
        box=None,
        multimask_output=False,
    )

    # Keep the first returned mask as a float32 0/1 array.
    images_tuned_list[k] = image
    masks_tuned_list[k] = (masks_tuned[0] > 0).astype(np.float32)

In [None]:
"""Plot results on all of the Test data"""
import numpy as np
import matplotlib.pyplot as plt

# Preview at most five test predictions (image next to image + mask overlay).
# The count is capped by what is actually available, so the cell no longer
# raises KeyError when fewer than five predictions exist.
num_previews = min(5, len(images_tuned_list), len(masks_tuned_list))
for idx in range(num_previews):
    fig, (ax_image, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))

    ax_image.imshow(images_tuned_list[idx])
    ax_image.axis('off')
    ax_image.set_title('Original Image')

    # Draw the image first, then the mask half-transparent on top of it.
    ax_overlay.imshow(images_tuned_list[idx])
    ax_overlay.imshow(masks_tuned_list[idx], alpha=0.5)
    ax_overlay.axis('off')
    ax_overlay.set_title('Predicted Mask Overlay')

    plt.show()

In [None]:
# Assuming the model is already trained and the predictor_tuned is available
# Add this code after your existing code

import os
import cv2
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

# Folder with the new, unlabeled images (point this at your own dataset).
new_images_path = "/kaggle/input/flow-img-sample/images"  # Replace 'your-new-images' with your dataset name

# Every entry whose name contains a dot, sorted by name; narrow the pattern if needed.
new_image_pattern = os.path.join(new_images_path, "*.*")
new_image_paths = sorted(glob(new_image_pattern))

# Destination folder for the predicted masks; created only when it is missing.
output_masks_path = '/kaggle/working/flow-img/predicted_masks'
output_dir_missing = not os.path.exists(output_masks_path)
if output_dir_missing:
    os.makedirs(output_masks_path)

# Resize target used before prediction. This assignment replaces the value set
# earlier in the notebook; use None to keep the original image sizes.
desired_size = (512, 512)  # Or None if you used original image sizes during training

# Loop over the new images and generate predictions.
# The loop variable is named new_image_path so it does not overwrite the
# module-level image_path that holds the training image folder.
for idx, new_image_path in enumerate(new_image_paths):
    # Load and preprocess the image
    image = cv2.imread(new_image_path)
    if image is None:
        # The "*.*" pattern can match non-image files; skip anything OpenCV cannot read.
        print(f"Warning: Unable to read image at {new_image_path}")
        continue
    if desired_size is not None:
        image = cv2.resize(image, desired_size, interpolation=cv2.INTER_LINEAR)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Set the image for the predictor
    predictor_tuned.set_image(image_rgb)

    # Prompt-free prediction requesting a single mask
    masks, scores, logits = predictor_tuned.predict(
        point_coords=None,
        box=None,
        multimask_output=False,
    )

    # Get the first mask as a uint8 0/1 array
    mask = masks[0]
    binary_mask = (mask > 0).astype(np.uint8)

    # Save the predicted mask as a PNG where the mask is white (255) and the
    # background is black (0); the file is named "<image stem>_mask.png".
    mask_filename = os.path.basename(new_image_path)
    mask_filename = os.path.splitext(mask_filename)[0] + '_mask.png'
    mask_save_path = os.path.join(output_masks_path, mask_filename)
    # cv2.imwrite returns False on failure instead of raising, so keep the
    # result and only report success when the file was really written.
    mask_written = cv2.imwrite(mask_save_path, binary_mask * 255)

    # Optionally display the image and mask
    fig = plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image_rgb)
    plt.axis('off')
    plt.title('Original Image')

    plt.subplot(1, 2, 2)
    plt.imshow(image_rgb)
    plt.imshow(binary_mask, alpha=0.5, cmap='jet')  # You can change the colormap if you like
    plt.axis('off')
    plt.title('Predicted Mask Overlay')

    plt.show()
    # Release the figure so a long image list does not accumulate open figures.
    plt.close(fig)

    print(f"Processed image {idx + 1}/{len(new_image_paths)}: {new_image_path}")
    if mask_written:
        print(f"Predicted mask saved to: {mask_save_path}")
    else:
        print(f"Warning: Failed to write mask to: {mask_save_path}")
