# Import the YOLOv8 pretrained model

- The model was fine-tuned (in another notebook) on a Roboflow dataset version of XMM-OM images.

In [None]:
# Notebook setup: imports, global RNG seeding, project helper modules and GPU selection.
from ultralytics import YOLO
from matplotlib import pyplot as plt
from PIL import Image
# NOTE(review): this only creates a Python variable. PyTorch reads
# PYTORCH_NO_CUDA_MEMORY_CACHING from the process environment, so as written it does
# not disable the CUDA caching allocator. os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
# (set before the first CUDA allocation) would do that. Confirm it is actually wanted
# before switching, because uncached allocation is much slower.
PYTORCH_NO_CUDA_MEMORY_CACHING=1

# Several imports below repeat earlier ones (plt, Image, np); they are harmless rebinds.
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
import torch
from torch import cuda
import os
import numpy as np
import random
from PIL import Image
import matplotlib.colors as mcolors
import numpy.ma as ma
import json
np.set_printoptions(precision=15)

# Reproducibility: seed Python, NumPy and torch (CPU and all CUDA devices) and force
# deterministic, non-benchmarked cuDNN kernels. (`deterministic` is set twice.)
torch.backends.cudnn.deterministic = True
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from typing import Any, Dict, Generator,List
import matplotlib.pyplot as plt
import numpy as np

# Project-local helper modules (not shown here). reload() re-executes each module so
# edits are picked up without restarting the kernel. The star imports then bind their
# public names in this namespace; on a name clash the later star import wins.
# Presumably these provide the helpers used below (create_mask_0_1, segm_loss_match, ...)
# -- confirm.
from importlib import reload
import dataset_utils
reload(dataset_utils)
from dataset_utils import *

import predictor_utils
reload(predictor_utils)
from predictor_utils import *

import loss
reload(loss)
from loss import *

import torch.autograd.profiler as profiler
# GPU index used by the whole notebook (YOLO and MobileSAM are both moved to cuda:{device_id}).
device_id = 3
torch.cuda.set_device(device_id) # ❗️❗️❗️

## Dataset (YOLOv8 format)

In [None]:
# Root folder of the YOLOv8-format dataset export. Keep the trailing '/': the paths
# below are built by plain string concatenation.
yolo_dataset_path = "./xmm_om_images_512_no_stars-4-YOLO/"

In [None]:
import yaml
# Read the dataset descriptor that ships with the YOLOv8 export.
with open(yolo_dataset_path + "data.yaml", 'r', encoding='utf-8') as stream:
    yam_data = yaml.safe_load(stream)  # dict with keys 'names', 'nc', 'roboflow', 'test', 'train', 'val'

# YOLO class index -> class name.
classes = dict(enumerate(yam_data['names']))
train_path = yam_data['train']
val_path = yam_data['val']
test_path = yam_data['test']
print(classes)

In [None]:
# get masks from dataset (in YOLOv8 format) given an image file

def get_label_file_path(dataset_path, image_location):
    """Return the YOLO label (.txt) path that belongs to an image.

    Images and labels live in sibling folders of a split, e.g.
    ``<split>/images/foo.png`` -> ``<split>/labels/foo.txt``.

    Args:
        dataset_path: the split's image directory, e.g. ``'./ds/train/images/'``.
            A trailing '/' is optional.
        image_location: image file name inside that directory, e.g. ``'foo.png'``.

    Returns:
        ``'<dataset_path without its last folder>/labels/<image stem>.txt'``.
    """
    # Strip any trailing separator first, so that '.../images' and '.../images/'
    # both resolve to the same split directory.
    split_dir = '/'.join(dataset_path.rstrip('/').split('/')[:-1])
    stem, _ext = os.path.splitext(image_location)
    return split_dir + '/labels/' + stem + '.txt'

def read_annotations(label_file_path):
    """Parse a YOLO segmentation label file.

    Each non-empty line holds a class id followed by polygon coordinates. In the
    YOLOv8 segmentation format these are normalised x, y pairs; they are returned
    flat, as written in the file.

    Args:
        label_file_path: path of the ``.txt`` label file.

    Returns:
        A list with one ``{'class_id': int, 'segmentation_points': list[float]}``
        dict per non-empty line, in file order.
    """
    annotations = []
    with open(label_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.split()
            if not parts:
                # Blank lines (e.g. a trailing empty line) carry no annotation.
                continue
            class_id, *coords = parts
            annotations.append({
                'class_id': int(class_id),
                'segmentation_points': [float(p) for p in coords],
            })
    return annotations

def get_masks_from_image(yolo_dataset_path, image_location, mask_shape=(512, 512)):
    """Return one mask per object annotated in an image's YOLO label file.

    Args:
        yolo_dataset_path: image directory of a split, e.g. ``'./ds/train/images/'``.
            The label file is located with ``get_label_file_path``.
        image_location: image file name inside that directory.
        mask_shape: output size handed to ``create_mask_0_1``. Defaults to
            (512, 512); the dataset folder name suggests 512-pixel images. Pass the
            real size for other datasets.

    Returns:
        A list with one ``create_mask_0_1(points, mask_shape)`` result per label
        line, in file order. The training loop converts these with
        ``torch.from_numpy``, so they are presumably NumPy arrays (confirm).
    """
    label_file_path = get_label_file_path(yolo_dataset_path, image_location)
    return [
        create_mask_0_1(annot['segmentation_points'], mask_shape)
        for annot in read_annotations(label_file_path)
    ]


**Hyperparameter docs: https://docs.ultralytics.com/usage/cfg/#train**

In [None]:
def show_masks(masks, ax, random_color=False):
    """Overlay every mask in ``masks`` on the matplotlib axes ``ax``.

    Each mask's last two dimensions are used as (height, width), and the mask is
    drawn as an RGBA layer with alpha 0.6. With ``random_color=True`` every mask
    gets its own RGB colour drawn from ``np.random``. Otherwise all masks share
    the same blue.
    """
    default_rgba = np.array([30/255, 144/255, 255/255, 0.6])
    for m in masks:
        rgba = (np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
                if random_color else default_rgba)
        height, width = m.shape[-2:]
        ax.imshow(m.reshape(height, width, 1) * rgba.reshape(1, 1, -1))

# Load the fine-tuned YOLOv8 segmentation checkpoint and put it on the GPU chosen by
# `device_id`. Below it is only used through .predict().
# NOTE(review): unlike `device` in the MobileSAM cell, there is no CPU fallback here.
# The trailing ';' only suppresses notebook output.
yolov8_pretrained_model = YOLO('./yolov8-segm-fine-tuning/200_epochs-2/weights/best.pt');
yolov8_pretrained_model.to(f'cuda:{device_id}');
# yolov8_pretrained_model.eval();

# Couple YOLO bounding boxes with SAM (YOLO boxes are used as SAM prompts)

In [None]:
def show_anns(anns):
    """Paint every mask of the stack ``anns`` on the current axes in a random opaque colour.

    ``anns`` is indexed along its first dimension, and each entry supports
    ``.bool()`` and ``.cpu()``; presumably it is a torch tensor of shape
    (N, H, W) (TODO confirm). An empty stack draws nothing.
    """
    if not len(anns):
        return
    axis = plt.gca()
    axis.set_autoscale_on(False)
    # Start fully transparent; each mask's pixels become opaque in its colour.
    overlay = np.ones((anns.shape[1], anns.shape[2], 4))
    overlay[..., 3] = 0
    for mask in anns:
        region = mask.bool().cpu().numpy()
        overlay[region] = np.concatenate([np.random.random(3), [1]])
    axis.imshow(overlay)

def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """Yield aligned slices of several same-length sequences, ``batch_size`` items at a time.

    Each yielded value is a list holding one ``seq[start:start + batch_size]``
    slice per argument. The final batch is shorter when the length is not a
    multiple of ``batch_size``. At least one sequence is required, and all of them
    must have the same length.
    """
    assert args and all(len(seq) == len(args[0]) for seq in args), \
        "Batched iteration must have inputs of all the same size."
    full_batches, leftover = divmod(len(args[0]), batch_size)
    n_batches = full_batches + (1 if leftover else 0)
    for b in range(n_batches):
        start = b * batch_size
        yield [seq[start:start + batch_size] for seq in args]

**Load the MobileSAM model**

In [None]:
import sys
import PIL
from PIL import Image

# Make the MobileSAM fine-tuning repository importable; it provides `ft_mobile_sam`.
# The path must be appended before the import below.
sys.path.append('/workspace/raid/OM_DeepLearning/MobileSAM-fine-tuning/')
from ft_mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

mobile_sam_checkpoint = "/workspace/raid/OM_DeepLearning/MobileSAM-fine-tuning/weights/mobile_sam.pt"
# Same GPU as `device_id`, with a CPU fallback when CUDA is unavailable.
# NOTE(review): torch.cuda.set_device and the YOLO .to('cuda:...') call have no such
# fallback, so the notebook as a whole still requires a GPU.
device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
print("device:", device)

# Build the "vit_t" model from the checkpoint, move it to `device`, and wrap it in a
# SamPredictor. The predictor is used later for set_image, features and box transforms.
mobile_sam_model = sam_model_registry["vit_t"](checkpoint=mobile_sam_checkpoint)
mobile_sam_model.to(device)
predictor = SamPredictor(mobile_sam_model)

In [None]:
# Image folders of the YOLOv8 export, and the image file names in each.
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting makes
# the batch composition reproducible, in line with the RNG seeding done at setup.
train_dir = yolo_dataset_path + 'train/images/'
valid_dir = yolo_dataset_path + 'valid/images/'
train_image_files = sorted(os.listdir(train_dir))
valid_image_files = sorted(os.listdir(valid_dir))

In [None]:
import glob
from roboflow import Roboflow

def export_image_det_to_Roboflow(input_dir, filename, masks, obj_results):
    """Turn SAM masks for one image into a VOC annotation and upload it to Roboflow.

    Args:
        input_dir: directory that contains ``filename``. It is joined by string
            concatenation, so it must end with a path separator.
        filename: image file name.
        masks: per-detection masks, indexed as ``masks[i] -> [1, H, W]`` (torch),
            aligned with the detections of ``obj_results[0]``.
        obj_results: YOLO prediction results. ``obj_results[0].names`` maps class
            ids to names, and the last column of ``obj_results[0].boxes.data`` holds
            the class id.

    Detections labelled 'star' or 'other' are ignored, as are masks whose polygon
    conversion yields no polygon. Nothing is written or uploaded when no
    detection survives.
    """
    class_names = obj_results[0].names
    class_labels = obj_results[0].boxes.data[:, -1].int().tolist()
    ignored_classes = {'star', 'other'}  # ignore stars and 'other' label

    objects = []
    for i, mask in enumerate(masks):
        name = class_names[class_labels[i]]
        if name in ignored_classes:
            continue  # filter before the polygon / bbox conversion instead of after
        mask_np = mask.detach().cpu().numpy()  # [1, H, W]
        polygon = binary_image_to_polygon(mask_np[0])
        if len(polygon) == 0:
            print(f"Empty polygon for a '{name}' mask in {filename}, skipping it.")
            continue
        objects.append({
            'name': name,
            'bbox': mask_to_bbox(mask_np),
            'segmentations': polygon[0],
        })

    if not objects:
        print("No objects after label filtering.")
        return

    # generating xml file for VOC format
    create_annotation_SAM(filename=filename, width=512, height=512, depth=3, objects=objects)
    # NOTE(review): assumes create_annotation_SAM writes "<stem>.xml" into the current
    # working directory for a ".png" image. For other extensions this replace() is a
    # no-op and os.remove would target `filename` in the CWD -- confirm.
    annotation_filename = filename.replace(".png", ".xml")
    try:
        upload_project.upload(input_dir + filename, annotation_filename, overwrite=False)
    finally:
        # Remove the temporary XML even if the upload fails.
        os.remove(annotation_filename)

In [None]:
# Fine-tune only these parts of the mask decoder; every other parameter is frozen.
# The list is loop-invariant, so build it once.
params_to_train = ['mask_tokens', 'output_upscaling', 'output_hypernetworks_mlps', 'iou_prediction_head']
for name, param in mobile_sam_model.named_parameters():
    param.requires_grad = 'mask_decoder' in name and any(s in name for s in params_to_train)

In [None]:
def check_requires_grad(model, show=True):
    """Print, for each named parameter of ``model``, whether it requires grad.

    Frozen parameters are always reported. Trainable ones are reported only when
    ``show`` is true.
    """
    for pname, p in model.named_parameters():
        if not p.requires_grad:
            print("❌ Param", pname, " doesn't require grad.")
        elif show:
            print("✅ Param", pname, " requires grad.")

# Report how many scalar parameters will be trained, then list per-parameter status.
n_trainable = sum(p.numel() for p in mobile_sam_model.parameters() if p.requires_grad)
print(f"🚀 The model has {n_trainable} trainable parameters.\n")
check_requires_grad(mobile_sam_model)

In [None]:
import time
start_time = time.time()  # wall-clock start, presumably for timing the run below
import torch.nn.functional as F

# Re-import the project helpers so edits to them are picked up without restarting the
# kernel. The star imports rebind names in this namespace; on a name clash the last
# star import wins. (`reload` is used before it is re-imported below; it is already in
# scope from the setup cell.)
import loss
reload(loss)
from loss import *
from importlib import reload
import astronomy_utils, predictor_utils, voc_annotate_and_Roboflow_export
reload(astronomy_utils)
reload(predictor_utils)
reload(voc_annotate_and_Roboflow_export)

from predictor_utils import *
from astronomy_utils import *
from voc_annotate_and_Roboflow_export import * 

import tqdm
from tqdm import tqdm  # rebinds `tqdm` from the module to the progress-bar callable

# Batching. Both counts use floor division, so run_epoch never visits the images
# beyond num_batches * batch_size (a trailing partial batch is dropped).
batch_size = 8
train_num_batches, valid_num_batches = (
    len(files) // batch_size for files in (train_image_files, valid_image_files)
)

# AdamW over the mask-decoder parameters that have requires_grad=True.
lr, wd = 3e-4, 0.0
parameters_to_optimize = list(filter(lambda p: p.requires_grad, mobile_sam_model.mask_decoder.parameters()))
optimizer = torch.optim.AdamW(parameters_to_optimize, lr=lr, weight_decay=wd)

In [None]:
from torch.nn.functional import threshold, normalize

def _plot_val_predictions(image, yolo_masks, boxes, sam_masks, out_path):
    """Show YOLO masks, YOLO boxes and MobileSAM masks side by side and save the figure.

    ``boxes`` are in the frame returned by ``predictor.transform.apply_boxes``, so
    they are drawn on a 1024x1024 resize of ``image``. This assumes square input
    images whose side maps to 1024 (TODO confirm).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Plot 1: YOLO masks
    axes[0].imshow(image)
    axes[0].set_title('YOLOv8n predicted Masks')
    show_masks(yolo_masks, axes[0], random_color=True)

    # Plot 2: bounding boxes. `image` is already RGB, so no further colour conversion.
    boxed = cv2.resize(image, (1024, 1024))
    for bbox in boxes:
        x1, y1, x2, y2 = (int(v) for v in bbox.detach().cpu().numpy())
        cv2.rectangle(boxed, (x1, y1), (x2, y2), (0, 255, 0), 2)
    axes[1].imshow(boxed)
    axes[1].set_title('YOLOv8n predicted Bboxes')

    # Plot 3: SAM masks
    axes[2].imshow(image)
    show_masks(sam_masks, axes[2], random_color=True)
    axes[2].set_title('MobileSAM predicted masks')

    plt.tight_layout()
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    plt.savefig(out_path)
    plt.show()
    plt.close(fig)  # do not accumulate open figures over the validation set


def run_epoch(phase, image_files, images_dir, num_batches, model_, optimizer=None):
    """Run one train or validation pass of YOLO-box-prompted MobileSAM.

    For every image, YOLOv8 predicts boxes and masks without gradients. The boxes
    are fed to ``model_`` as prompts, and the resulting masks are scored against
    the ground-truth label masks with ``segm_loss_match``; the YOLO masks are
    scored the same way as a reference. In the 'train' phase, each batch's mean
    SAM loss is back-propagated and ``optimizer`` steps. In the 'val' phase, one
    comparison figure per image is saved under ``./plots/``.

    Args:
        phase: 'train' or 'val'.
        image_files: image file names inside ``images_dir``.
        images_dir: image directory, joined by string concatenation (must end with '/').
        num_batches: number of batches of the global ``batch_size`` to process.
            Files beyond ``num_batches * batch_size`` are not visited.
        model_: MobileSAM model whose mask decoder is evaluated or trained.
        optimizer: required when ``phase == 'train'``.

    Returns:
        ``(mean SAM loss, mean YOLO loss)`` over the batches that produced a loss.
        Both are NaN when no batch did.

    Raises:
        ValueError: ``phase == 'train'`` without an optimizer.
    """
    assert phase in ['train', 'val'], "Phase must be 'train' or 'val'"
    is_train = phase == 'train'
    if is_train and optimizer is None:
        raise ValueError("run_epoch('train', ...) requires an optimizer")

    if is_train:
        model_.train()
    else:
        model_.eval()

    epoch_sam_loss = []
    epoch_yolo_loss = []

    for batch_idx in tqdm(range(num_batches), desc=f"{phase.capitalize()} Batch"):
        start_idx = batch_idx * batch_size
        batch_files = image_files[start_idx:start_idx + batch_size]

        batch_losses_sam = []
        batch_losses_yolo = []

        for image_name in batch_files:
            image_path = images_dir + image_name
            image = cv2.imread(image_path)
            if image is None:  # unreadable or non-image file in the folder
                print(f"Could not read {image_path}, skipping it.")
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            with torch.no_grad():
                obj_results = yolov8_pretrained_model.predict(image_path, verbose=False, conf=0.2)
                predictor.set_image(image)

            gt_masks = get_masks_from_image(images_dir, image_name)
            if len(obj_results[0]) == 0 or len(gt_masks) == 0 or obj_results[0].masks is None:
                continue

            # YOLO boxes -> the coordinate frame expected by the SAM prompt encoder.
            input_boxes = obj_results[0].boxes.xyxy.cpu().numpy()
            input_boxes = predictor.transform.apply_boxes(input_boxes, predictor.original_size)
            input_boxes = torch.from_numpy(input_boxes).to(device)

            with torch.no_grad():
                image_embedding = predictor.features
                prompt_embedding = model_.prompt_encoder.get_dense_pe()

            # Resize YOLO's masks to the image size (cv2.resize takes (W, H)).
            target_wh = image.shape[:2][::-1]
            yolo_masks = [
                cv2.resize(m, target_wh, interpolation=cv2.INTER_LINEAR)
                for m in obj_results[0].masks.data.cpu().numpy()
            ]

            # The GT masks, YOLO masks and YOLO reference loss do not depend on the box
            # chunk, so compute them once per image.
            gt_masks_tensor = torch.stack([torch.from_numpy(m).unsqueeze(0) for m in gt_masks], dim=0).to(device)
            yolo_masks_tensor = torch.stack([torch.from_numpy(m).unsqueeze(0) for m in yolo_masks], dim=0).to(device)
            batch_losses_yolo.append(segm_loss_match(yolo_masks_tensor, gt_masks_tensor))

            sam_mask = []
            for (boxes,) in batch_iterator(320, input_boxes):
                with torch.no_grad():
                    # Presumably no-ops while the embeddings have batch size 1 (predictor
                    # features of a single image); kept unchanged.
                    image_embedding = image_embedding[0:boxes.shape[0], :, :, :]
                    prompt_embedding = prompt_embedding[0:boxes.shape[0], :, :, :]
                    sparse_embeddings, dense_embeddings = model_.prompt_encoder(
                        points=None,
                        boxes=boxes,
                        masks=None,
                    )

                # Gradients flow through the mask decoder only while training.
                with torch.set_grad_enabled(is_train):
                    low_res_masks, _ = model_.mask_decoder(
                        image_embeddings=image_embedding,
                        image_pe=prompt_embedding,
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        multimask_output=False,
                    )
                masks = predictor.model.postprocess_masks(low_res_masks, predictor.input_size, predictor.original_size)
                # NOTE(review): normalize() runs over dim 1, which has size 1 here, so every
                # positive logit maps to ~1 and almost no gradient reaches the decoder. A
                # differentiable variant such as torch.sigmoid(masks - model_.mask_threshold)
                # may be intended -- confirm what segm_loss_match expects before switching.
                threshold_masks = normalize(threshold(masks, 0.0, 0)).to(device)
                sam_mask.append(((masks > model_.mask_threshold) * 1.0).squeeze(1))
                batch_losses_sam.append(segm_loss_match(threshold_masks, gt_masks_tensor))
                del sparse_embeddings, dense_embeddings, low_res_masks, masks, threshold_masks
                torch.cuda.empty_cache()

            if phase == 'val':
                # One figure per image, named after the image, so files do not overwrite each other.
                stem = os.path.splitext(image_name)[0]
                _plot_val_predictions(
                    image,
                    yolo_masks,
                    input_boxes,
                    torch.cat(sam_mask, dim=0).detach().cpu().numpy(),
                    f'./plots/combined_plots_{stem}.png',
                )
            del gt_masks_tensor, yolo_masks_tensor

        if not batch_losses_sam:
            # Every image of this batch was skipped; torch.stack([]) would raise.
            continue

        mean_loss_sam = torch.mean(torch.stack(batch_losses_sam))
        mean_loss_yolo = torch.mean(torch.stack(batch_losses_yolo))
        epoch_sam_loss.append(mean_loss_sam.item())
        epoch_yolo_loss.append(mean_loss_yolo.item())

        if is_train:
            optimizer.zero_grad()
            mean_loss_sam.backward()
            optimizer.step()

    # The epoch number comes from the notebook's global training loop variable, if any.
    epoch = globals().get('epoch', '?')
    print(f'Epoch {epoch}, {phase.capitalize()} Segmentation loss SAM: {np.mean(epoch_sam_loss)}. YOLO: {np.mean(epoch_yolo_loss)}')
    return np.mean(epoch_sam_loss), np.mean(epoch_yolo_loss)
    
num_epochs = 1
# Per-epoch mean losses (one entry per epoch) for MobileSAM and the YOLO reference.
epoch_sam_loss_train, epoch_sam_loss_val, epoch_yolo_loss_train, epoch_yolo_loss_val = [], [], [], []

for epoch in range(num_epochs):  # run_epoch reads this global `epoch` for its log line
    sam_train, yolo_train = run_epoch('train', train_image_files, train_dir, train_num_batches, mobile_sam_model, optimizer)
    sam_val, yolo_val = run_epoch('val', valid_image_files, valid_dir, valid_num_batches, mobile_sam_model)
    # Append rather than assign, so the lists keep the full per-epoch history.
    epoch_sam_loss_train.append(sam_train)
    epoch_yolo_loss_train.append(yolo_train)
    epoch_sam_loss_val.append(sam_val)
    epoch_yolo_loss_val.append(yolo_val)

In [None]:
# torch.save(mobile_sam_model.state_dict(), f'sam_checkpoint_with_yolo.pth')

In [None]:
# Optional Roboflow export in VOC format given filenames
export_to_Roboflow = False

if export_to_Roboflow:
    new_images_dir = '../XMM_OM_dataset/zscaled_512_stretched/'
    # Sorted so that the [600:610] slice selects the same files on every run/filesystem.
    new_image_files = sorted(os.listdir(new_images_dir))
    mobile_sam_model.eval()

    with torch.no_grad():
        for image_name in new_image_files[600:610]:
            print('*****', new_images_dir, image_name)
            image_path = new_images_dir + image_name
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            obj_results = yolov8_pretrained_model.predict(image_path, conf=0.2)
            predictor.set_image(image)

            if len(obj_results[0]) == 0:
                print(f"No masks for {image_name}.")
                plt.imshow(image)
                plt.show()
                plt.close()
                continue

            # YOLO boxes -> the coordinate frame expected by the SAM prompt encoder.
            input_boxes = obj_results[0].boxes.xyxy.cpu().numpy()
            input_boxes = predictor.transform.apply_boxes(input_boxes, predictor.original_size)
            input_boxes = torch.from_numpy(input_boxes).to(device)
            image_embedding = predictor.features
            prompt_embedding = mobile_sam_model.prompt_encoder.get_dense_pe()

            sam_mask_chunks = []  # each entry: [n_boxes_in_chunk, 1, H, W] binary masks
            for (boxes,) in batch_iterator(320, input_boxes):
                image_embedding = image_embedding[0:boxes.shape[0], :, :, :]
                prompt_embedding = prompt_embedding[0:boxes.shape[0], :, :, :]
                sparse_embeddings, dense_embeddings = mobile_sam_model.prompt_encoder(
                    points=None,
                    boxes=boxes,
                    masks=None,
                )
                low_res_masks, _ = mobile_sam_model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=prompt_embedding,
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                masks = predictor.model.postprocess_masks(low_res_masks, predictor.input_size, predictor.original_size)
                sam_mask_chunks.append((masks > mobile_sam_model.mask_threshold) * 1.0)

            # Export once per image with the masks of all boxes, so that mask i lines up
            # with detection/class label i of obj_results.
            export_image_det_to_Roboflow(new_images_dir, image_name, torch.cat(sam_mask_chunks, dim=0), obj_results)

In [None]:
def export_image_det_to_Roboflow(input_dir, filename, masks, obj_results):
    """Turn SAM masks for one image into a VOC annotation and upload it to Roboflow (debug variant).

    Besides exporting, this version prints ``masks.shape`` and the polygon count,
    and displays every mask that gets exported. It reuses the name of the plain
    export helper, so running this cell rebinds ``export_image_det_to_Roboflow``.

    Args:
        input_dir: directory that contains ``filename``. It is joined by string
            concatenation, so it must end with a path separator.
        filename: image file name.
        masks: per-detection masks, indexed as ``masks[i] -> [1, H, W]`` (torch),
            aligned with the detections of ``obj_results[0]``.
        obj_results: YOLO prediction results. ``obj_results[0].names`` maps class
            ids to names, and the last column of ``obj_results[0].boxes.data`` holds
            the class id.

    Detections labelled 'star' or 'other' are ignored, as are masks whose polygon
    conversion yields no polygon. Nothing is written or uploaded when no
    detection survives.
    """
    class_names = obj_results[0].names
    class_labels = obj_results[0].boxes.data[:, -1].int().tolist()
    ignored_classes = {'star', 'other'}  # ignore stars and 'other' label

    objects = []
    print(masks.shape)
    for i, mask in enumerate(masks):
        name = class_names[class_labels[i]]
        if name in ignored_classes:
            continue  # filter before converting / plotting
        mask_np = mask.detach().cpu().numpy()  # [1, H, W]
        plt.imshow(mask_np[0])
        plt.show()
        plt.close()
        polygon = binary_image_to_polygon(mask_np[0])
        print(len(polygon))
        if len(polygon) == 0:
            print(f"Empty polygon for a '{name}' mask in {filename}, skipping it.")
            continue
        objects.append({
            'name': name,
            'bbox': mask_to_bbox(mask_np),
            'segmentations': polygon[0],
        })

    if not objects:
        print("No objects after label filtering.")
        return

    # generating xml file for VOC format
    create_annotation_SAM(filename=filename, width=512, height=512, depth=3, objects=objects)
    # NOTE(review): assumes create_annotation_SAM writes "<stem>.xml" into the current
    # working directory for a ".png" image. For other extensions this replace() is a
    # no-op and os.remove would target `filename` in the CWD -- confirm.
    annotation_filename = filename.replace(".png", ".xml")
    try:
        upload_project.upload(input_dir + filename, annotation_filename, overwrite=False)
    finally:
        # Remove the temporary XML even if the upload fails.
        os.remove(annotation_filename)

In [None]:
# Per-epoch mean mask loss of MobileSAM and of the YOLO reference, on train and val.
# np.atleast_1d makes this work whether each variable holds a per-epoch list or a
# single scalar.
loss_curves = {
    'SAM train': epoch_sam_loss_train,
    'SAM val': epoch_sam_loss_val,
    'YOLO train': epoch_yolo_loss_train,
    'YOLO val': epoch_yolo_loss_val,
}
for label, values in loss_curves.items():
    values = np.atleast_1d(values)
    # Markers keep single-epoch runs visible: a one-point line draws nothing.
    plt.plot(np.arange(len(values)), values, marker='o', label=label)
plt.title('SAM vs. YOLO mask loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# plt.savefig('sam_vs_yolo_masks_loss.png')
plt.show()
plt.close()

In [None]:
# import json 
# import os
# import tarfile

# second_directory_path = '../XMM_OM_dataset/zscaled_512_stretched/'
# archive_name = 'imgs_512_50.tar.gz'

# files_to_archive = []
# for file in os.listdir(second_directory_path):
#         files_to_archive.append(os.path.join(second_directory_path, file))

# files_to_archive = files_to_archive[500:550]
# # Create a tar.gz archive of the filtered files
# with tarfile.open(archive_name, "w:gz") as tar:
#     for file_path in files_to_archive:
#         tar.add(file_path, arcname=os.path.basename(file_path))

# print(f"Archive {archive_name} created with {len(files_to_archive)} files.")