### Begin
Select kernel `segment-anything`

Run this cell to import the necessary packages and initialise the SAM model, the automatic mask generator and the mask predictor.

In [90]:
# Import the necessary libraries
import numpy as np 
import torch
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2
from PIL import Image
import os
import shutil
from sklearn.model_selection import train_test_split
from IPython.display import display
from IPython.display import clear_output
import yaml
import glob
import pickle
from scipy import ndimage

from sam2yolo_functions import *

from jupyter_bbox_widget import BBoxWidget
import ipywidgets as widgets
import os
import json
import base64

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # since mps does not support all the operations, we need to enable fallback to cpu for some operations

sam_checkpoint = "../models/sam_vit_b_01ec64.pth" # Path to the checkpoint file
model_type = 'vit_b' # Model type

# Pick the best available accelerator: CUDA GPU, then Apple-silicon MPS, then CPU.
# "mps" must not be chosen unconditionally: it does not exist on non-Apple machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) # Load the model
sam.to(device=device) # Move the model to the device

mask_generator = SamAutomaticMaskGenerator(sam) # Create a mask generator
mask_predictor = SamPredictor(sam) # Create a mask predictor (used by encode_image_mask)

### Functions

In [97]:
def get_area(mask):
    """Return the area of a mask, i.e. the number of truthy (non-zero) pixels.

    Args:
        mask: 2-D array-like (nested lists or a numpy array); non-zero / True
            entries count as mask pixels.

    Returns:
        int: the number of non-zero entries.
    """
    # Vectorised count: same result as looping over every pixel in Python,
    # but runs in C instead of the interpreter.
    return int(np.count_nonzero(mask))


def get_max_area(masks):
    """Return the index of the mask with the largest area.

    Area is the number of non-zero pixels (same measure as get_area).

    Args:
        masks: sequence of 2-D masks.

    Returns:
        int: index of the largest mask. Ties resolve to the earliest index.
        Returns 0 when `masks` is empty or every mask is empty.
    """
    best_idx, best_area = 0, 0
    for i, mask in enumerate(masks):
        area = int(np.count_nonzero(mask))  # computed once per mask, not twice
        if area > best_area:  # strict '>' keeps the first of equal-sized masks
            best_idx, best_area = i, area
    return best_idx


def overlay_mask_on_image(image, coord, color=(0, 255, 0), thickness=2):
    """Draw polygon outlines onto an image.

    Args:
        image: image array as loaded by cv2.imread (BGR channel order in this
            notebook). cv2.drawContours draws into this array.
        coord: list of contours/polygons, e.g. the result of extract_segment.
        color: outline colour in the image's channel order (default green).
        thickness: outline thickness in pixels.

    Returns:
        The image with the outlines drawn.
    """
    # contourIdx=-1 draws every contour in `coord`.
    return cv2.drawContours(image, coord, -1, color, thickness)


def process_mask(mask):
    """Keep only the largest connected region of a mask and fill its holes.

    Args:
        mask: 2-D array; non-zero pixels are foreground.

    Returns:
        Boolean array with the same shape as `mask`. It contains the largest
        foreground region with interior holes filled. It is all False when
        `mask` has no foreground pixels.
    """
    # Label each separate foreground region (label 0 is background).
    labeled_mask, num_labels = ndimage.label(mask)

    if num_labels == 0:
        # Without this guard argmax below would select label 0 (the
        # background) and turn the whole image into foreground.
        return np.zeros(labeled_mask.shape, dtype=bool)

    # Pixel count per label; zero out the background so it is never chosen.
    region_sizes = np.bincount(labeled_mask.ravel())
    region_sizes[0] = 0

    largest_mask = labeled_mask == np.argmax(region_sizes)

    # ndimage.binary_fill_holes replaces the deprecated
    # ndimage.morphology.binary_fill_holes path (same function).
    return ndimage.binary_fill_holes(largest_mask)



def extract_segment(mask, epsilon_ratio=0.01):
    """Convert a SAM mask into simplified polygon(s).

    The mask is first cleaned with process_mask (largest region only, holes
    filled). Its outer boundary is then traced with cv2.findContours and
    simplified with the Douglas-Peucker algorithm (cv2.approxPolyDP).

    Args:
        mask: 2-D boolean/numeric array; scaled by 255 and cast to uint8
            before processing, so non-zero values after that cast are foreground.
        epsilon_ratio: Douglas-Peucker tolerance as a fraction of each
            contour's perimeter (0.01 reproduces the previous fixed value).

    Returns:
        list of point arrays as returned by cv2.approxPolyDP, one per outer
        contour (empty list if the mask is empty).
    """
    binary_mask = (np.array(mask) * 255).astype(np.uint8)
    binary_mask = (process_mask(binary_mask) * 255).astype(np.uint8)

    # RETR_EXTERNAL: only outer boundaries are wanted; inner contours would be
    # flattened into the same YOLO label line downstream.
    # [-2] picks the contour list for both OpenCV 3 (3 return values) and
    # OpenCV 4 (2 return values).
    contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    polygon_coords = []
    for contour in contours:
        # Tolerance proportional to the contour perimeter.
        epsilon = epsilon_ratio * cv2.arcLength(contour, True)
        polygon_coords.append(cv2.approxPolyDP(contour, epsilon, True))
    return polygon_coords


def encode_image(filepath):
    """Read an image file and return it as a base64 data URI for BBoxWidget.

    The MIME type is guessed from the file extension (frames written by
    extract_frames are .png). Unknown or non-image extensions fall back to
    image/jpeg.

    Args:
        filepath: path to the image on disk.

    Returns:
        str: "data:<mime>;base64,<payload>".
    """
    import mimetypes  # local import: only this helper needs it

    mime_type, _ = mimetypes.guess_type(filepath)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"  # "image/jpg" is not a registered MIME type

    with open(filepath, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode('ascii')
    return f"data:{mime_type};base64,{encoded}"


def encode_image_mask(filepath, boxes):
    """Segment each bbox with SAM, draw the polygons and return the result.

    Args:
        filepath: path to the image on disk.
        boxes: list of dicts with keys 'x', 'y', 'width', 'height' (pixel
            units; converted to [x0, y0, x1, y1] for the predictor), as stored
            in BBoxWidget.bboxes.

    Returns:
        tuple (data_uri, poly_coords_list, h, w):
            data_uri: JPEG data URI of the image with the outlines drawn.
            poly_coords_list: one extract_segment result per box, in box order.
            h, w: height and width of the original image in pixels.

    Raises:
        FileNotFoundError: if cv2 cannot read `filepath`.
        RuntimeError: if JPEG encoding of the overlay fails.
    """
    image = cv2.imread(filepath)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image: {filepath}")
    h, w = image.shape[:2]

    poly_coords_list = []
    if boxes:
        # Compute the SAM image embedding once, on the untouched image. Doing it
        # per box is slow, and later boxes would be segmented on an image that
        # already has earlier outlines drawn onto it.
        mask_predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    for box in boxes:
        # Convert the widget's x/y/width/height box to the xyxy form SAM expects.
        xyxy = np.array([
            box['x'],
            box['y'],
            box['x'] + box['width'],
            box['y'] + box['height']
        ])
        masks, _scores, _logits = mask_predictor.predict(
            box=xyxy,
            multimask_output=True
        )

        # Of SAM's candidate masks, keep the one with the largest area
        # (not the highest score).
        mask = masks[get_max_area(masks)]

        polygon_coords = extract_segment(mask)
        poly_coords_list.append(polygon_coords)
        image = overlay_mask_on_image(image, polygon_coords)

    is_success, im_buf_arr = cv2.imencode(".jpg", image)
    if not is_success:
        raise RuntimeError(f"Could not JPEG-encode annotated image: {filepath}")

    encoded = base64.b64encode(im_buf_arr.tobytes()).decode('utf-8')
    return "data:image/jpeg;base64," + encoded, poly_coords_list, h, w



def _normalise_segment(flat_coords, w, h):
    """Divide alternating x/y pixel coordinates by image width/height."""
    # Even positions are x (divided by width), odd positions are y (divided by height).
    return [value / (w if i % 2 == 0 else h) for i, value in enumerate(flat_coords)]


def init_widget(path, images, annotations, data, classes, cur_img_idx):
    """Build the interactive bbox-annotation widget for a folder of images.

    Navigation: Previous/Next store the current image's boxes in
    `annotations` and switch images (wrapping around).
    Submit: runs SAM on the drawn boxes (encode_image_mask) and shows the
    outlines. It stores `[class_id, x1/w, y1/h, x2/w, y2/h, ...]` rows in
    `data[image_name]`.

    Args:
        path: folder containing the images.
        images: list of image file names inside `path`; must be non-empty.
        annotations: dict, mutated in place: image name -> list of bbox dicts.
        data: dict, mutated in place: image name -> list of label rows.
        classes: list of class names; class id = index in this list.
        cur_img_idx: index of the image shown first.

    Returns:
        ipywidgets.VBox containing the BBoxWidget and the Previous/Next
        buttons. Assign it (e.g. `w_container = init_widget(...)`) and
        display it.
    """
    w_bbox = BBoxWidget(
        image=encode_image(os.path.join(path, images[cur_img_idx])),
        classes=classes
    )

    button_next = widgets.Button(description="Next")
    button_prev = widgets.Button(description="Previous")
    w_container = widgets.VBox([
        w_bbox,
        button_prev,
        button_next
    ])

    def update_image(button):
        # Without `nonlocal`, the assignments below make cur_img_idx local to
        # this handler and its first read raises UnboundLocalError.
        nonlocal cur_img_idx
        annotations[images[cur_img_idx]] = w_bbox.bboxes
        if button.description == "Next":
            cur_img_idx = (cur_img_idx + 1) % len(images)
        elif button.description == "Previous":
            cur_img_idx = (cur_img_idx - 1) % len(images)

        # Restore previously drawn boxes for this image, if any.
        w_bbox.bboxes = annotations.get(images[cur_img_idx], [])
        w_bbox.image = encode_image(os.path.join(path, images[cur_img_idx]))

    button_next.on_click(update_image)
    button_prev.on_click(update_image)

    @w_bbox.on_submit
    def submit():
        # Snapshot the boxes before the widget image is replaced below.
        bboxes = list(w_bbox.bboxes)
        if not bboxes:
            return
        image_name = images[cur_img_idx]
        w_bbox.image, poly_coords_list, h, w = encode_image_mask(
            os.path.join(path, image_name), bboxes
        )
        # numpy_to_list comes from sam2yolo_functions (not visible here);
        # presumably it flattens the polygons into [x1, y1, x2, y2, ...] — TODO confirm.
        data[image_name] = [
            [classes.index(bbox['label'])]
            + _normalise_segment(numpy_to_list(polygon_coords), w, h)
            for bbox, polygon_coords in zip(bboxes, poly_coords_list)
        ]

    return w_container
def extract_frames(video_path, output_dir, frame_interval=300):
    """Save every `frame_interval`-th frame of a video as a PNG image.

    Output files are named '<video base name>_frame_<frame index>.png' and
    are written into `output_dir`, which is created if missing. For videos
    whose reported FPS is >= 50 the interval is doubled.

    Args:
        video_path: path of the video file.
        output_dir: destination folder for the extracted frames.
        frame_interval: keep one frame out of this many (must be >= 1).

    Returns:
        None. Prints a message and returns early if the video cannot be opened.

    Raises:
        ValueError: if frame_interval < 1.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

    filename = os.path.splitext(os.path.basename(video_path))[0]
    os.makedirs(output_dir, exist_ok=True)

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            print(f"Could not open video file: {video_path}")
            return

        if video.get(cv2.CAP_PROP_FPS) >= 50:
            frame_interval *= 2

        frame_index = 0
        # read() is grab() + retrieve(); only retrieve the frames we keep.
        while video.grab():
            if frame_index % frame_interval == 0:
                success, frame = video.retrieve()
                if not success:
                    break
                output_path = os.path.join(output_dir, f"{filename}_frame_{frame_index}.png")
                if not cv2.imwrite(output_path, frame):
                    print(f"Could not write frame: {output_path}")
            frame_index += 1
    finally:
        # Release the capture even if an exception interrupts extraction.
        video.release()


def split_data(src_directory, out_directory, test_size=0.2, random_state=42):
    """Split the images in a folder into train/valid sets and copy them.

    Creates '<out>/train/images', '<out>/valid/images', '<out>/train/labels'
    and '<out>/valid/labels', then copies each source file into the images
    folder of its split. Hidden files (e.g. .DS_Store) and sub-directories
    are ignored.

    Args:
        src_directory: folder containing the images to split.
        out_directory: dataset root; created if missing.
        test_size: fraction of files assigned to the validation set.
        random_state: seed for train_test_split. File names are sorted first,
            so the split is reproducible regardless of os.listdir order.

    Raises:
        FileExistsError: if any split sub-folder already exists (exist_ok=False
            is deliberate: it avoids silently mixing with a previous split).
        ValueError: if there are no files to split.
    """
    os.makedirs(out_directory, exist_ok=True)

    # Same creation order as before: train/images, valid/images, train/labels, valid/labels.
    for sub in ('images', 'labels'):
        for split in ('train', 'valid'):
            os.makedirs(os.path.join(out_directory, split, sub), exist_ok=False)

    all_files = sorted(
        name for name in os.listdir(src_directory)
        if not name.startswith('.') and os.path.isfile(os.path.join(src_directory, name))
    )
    if not all_files:
        raise ValueError(f"No files to split in {src_directory}")

    train_files, valid_files = train_test_split(all_files, test_size=test_size, random_state=random_state)

    # Copy (sources are kept) into the images folder of each split.
    for split, names in (('train', train_files), ('valid', valid_files)):
        for file_name in names:
            shutil.copy(os.path.join(src_directory, file_name), os.path.join(out_directory, split, 'images', file_name))


def load_images_from_video(img_path, vid_path, ds_path, frame_interval):
    """Build the working dataset from a folder of videos.

    Every entry of `vid_path` is passed to extract_frames, which writes frames
    into `img_path`. The extracted frames are then divided into train/valid
    sets under `ds_path` by split_data (using its default test_size).
    """
    video_names = os.listdir(vid_path)
    for video_name in video_names:
        video_file = os.path.join(vid_path, video_name)
        extract_frames(video_file, img_path, frame_interval)

    split_data(img_path, ds_path)


def output_to_txt(data_dict, output_dir):
    """Write one YOLO label file per image.

    For each `image_name -> rows` entry, writes '<image stem>.txt' into
    `output_dir` (created if missing), overwriting any existing file. Each row
    becomes one line of space-separated values, e.g.
    `class_id x1 y1 x2 y2 ...`, which is the YOLO label format.

    Args:
        data_dict: dict mapping image file name -> list of rows (iterables of
            numbers).
        output_dir: folder for the .txt label files.
    """
    os.makedirs(output_dir, exist_ok=True)

    for image_name, rows in data_dict.items():
        base_name = os.path.splitext(image_name)[0]
        output_file_path = os.path.join(output_dir, base_name + '.txt')

        with open(output_file_path, 'w') as f:
            for row in rows:
                # YOLO parses whitespace-separated values; comma-joined rows
                # cannot be read as labels.
                f.write(' '.join(str(item) for item in row) + '\n')
    


def move_files(src_img, src_label, dest_img, dest_label):
    """Copy images and labels from the working dataset into the final dataset.

    Despite the name, files are copied (shutil.copy); the sources stay in place.
    Both destination folders are created if missing. Both source folders are
    listed before any copying starts. Images are copied first, then labels,
    and same-named files in the destination are overwritten.
    """
    for dest in (dest_img, dest_label):
        os.makedirs(dest, exist_ok=True)

    image_names = os.listdir(src_img)
    label_names = os.listdir(src_label)

    jobs = ((src_img, dest_img, image_names), (src_label, dest_label, label_names))
    for src_dir, dest_dir, names in jobs:
        for name in names:
            shutil.copy(os.path.join(src_dir, name), os.path.join(dest_dir, name))


def move_source_vid(src_vid, dest_vid):
    """Back up the source videos by copying every entry of `src_vid` into `dest_vid`.

    `dest_vid` is created if missing. Files are copied (shutil.copy), so the
    originals stay in `src_vid`.
    """
    os.makedirs(dest_vid, exist_ok=True)

    for video_name in os.listdir(src_vid):
        source_file = os.path.join(src_vid, video_name)
        shutil.copy(source_file, os.path.join(dest_vid, video_name))


def clear_directory(directory):
    """Delete `directory` and everything under it, then recreate it empty.

    WARNING: irreversible — every file and sub-directory is removed.
    A missing directory is simply created (including missing parents) instead
    of raising, so the clean-up cell can be re-run safely.
    """
    if os.path.isdir(directory):
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)


def create_yaml(labels, path, output_path, train_path="train/images", val_path="valid/images"):
    """Write the data.yaml file needed for YOLO training.

    Args:
        labels: class names; class id i maps to labels[i] (all labels are kept).
        path: dataset root, written under the 'path' key.
        output_path: file to write (overwritten if it exists).
        train_path: training-image folder, written under 'train'.
        val_path: validation-image folder, written under 'val'.
    """
    config = {
        'names': dict(enumerate(labels)),
        'path': path,
        'train': train_path,
        'val': val_path,
    }

    with open(output_path, 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False)

### Generate dataset

In [None]:
# Working folders of the pipeline: raw videos -> extracted frames -> train/valid split.
vid_path = '../1-source/'
img_path = '../2-source-extracted/'
ds_path = '../3-dataset/'
frame_interval = 300 # keep every 300th frame (extract_frames doubles this for videos reporting >= 50 fps)

# extract frames from videos in 1-source folder, extract them to 2-source-extracted folder, and split them into train and valid sets in 3-dataset folder
# NOTE: not re-runnable as-is — split_data creates the train/valid sub-folders with exist_ok=False,
# so clear 3-dataset before running this cell again.
load_images_from_video(img_path, vid_path, ds_path, frame_interval)

### Annotate training set

In [106]:
path = '../3-dataset/train/images/' # path to training images
# Skip hidden files such as .DS_Store so only images are offered for annotation.
images = sorted(name for name in os.listdir(path) if not name.startswith('.'))

annotations = {} # dictionary with key = image name, value = corresponding bbox
data = {} # dictionary with key = image name, value = list of list containing: [label_id, x1, y1, x2, y2, ...]
classes = ['plane', 'airtug'] # list of classes

# Keep the returned widget so the next cell can display it.
w_container = init_widget(path, images, annotations, data, classes, cur_img_idx=0)

In [107]:
w_container # display the annotation widget (BBoxWidget plus Previous/Next buttons)

VBox(children=(BBoxWidget(bboxes=[{'x': 312, 'y': 84, 'width': 89, 'height': 68, 'label': 'plane'}], classes=[…

In [None]:
output_to_txt(data, '../3-dataset/train/labels/') # write one label .txt per submitted training image (images never submitted get no file)

### Annotate validation set

In [None]:
path = '../3-dataset/valid/images/' # path to validation images
# Skip hidden files such as .DS_Store so only images are offered for annotation.
images = sorted(name for name in os.listdir(path) if not name.startswith('.'))

annotations = {} # dictionary with key = image name, value = corresponding bbox
data = {} # dictionary with key = image name, value = list of list containing: [label_id, x1, y1, x2, y2, ...]
classes = ['plane', 'airtug'] # list of classes

# Keep the returned widget so the next cell can display it.
w_container = init_widget(path, images, annotations, data, classes, cur_img_idx=0)

In [None]:
w_container # display the annotation widget (BBoxWidget plus Previous/Next buttons)

In [None]:
output_to_txt(data, '../3-dataset/valid/labels/') # write one label .txt per submitted validation image (images never submitted get no file)

### Clean up

In [None]:
# Copy the annotated working data into ../dataset/ for YOLO training. move_files and move_source_vid
# use shutil.copy, so the originals stay in place until the clean-up cell below deletes them.
move_files('../3-dataset/train/images/', '../3-dataset/train/labels/', '../dataset/train/images/', '../dataset/train/labels/')
move_files('../3-dataset/valid/images/', '../3-dataset/valid/labels/', '../dataset/valid/images/', '../dataset/valid/labels/') # copies the validation split into the actual dataset folder
move_source_vid('../1-source/', '../dataset/source/') # copies the videos from 1-source to dataset/source for documentation/backup

In [None]:
# Irreversible: wipes the working folders (clear_directory recreates each one empty).
# Only run this after the copy cell above has succeeded.
answer = input("Are you sure you want to clear the source videos, extracted images, and dataset? (y/n): ")
if answer.strip().lower() == 'y':
    for directory in ('../1-source/', '../2-source-extracted/', '../3-dataset/'):
        clear_directory(directory)

### YOLO Training

In [None]:
abspath_ds = os.path.abspath('../dataset/') # get the absolute path of the dataset folder
output_path = '../dataset/data.yaml' # path to the data.yaml file

# NOTE(review): `classes` is not defined in this cell; it comes from the annotation cells above, so this
# raises NameError on a fresh kernel unless one of those ran first. Class ids in data.yaml follow its order.
create_yaml(classes, abspath_ds, output_path) # create the data.yaml file in the dataset folder

In [None]:
# YOLO Training