In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import json
import os
import pandas as pd

In [None]:
# Load the TissueNet test split from its .npz archive.
# The archive is indexed by key: 'X' and 'y' are the two arrays read here.

test_data = np.load('/home/yiming/WorkSpace/datasets/tissuenet_1.0/tissuenet_v1.0_test.npz')
X_test = test_data['X']  # Original images
# Keep only index 1 of the last axis; the 1:2 slice (rather than a plain 1)
# preserves that axis with length 1, so later cells call .squeeze() per image.
# NOTE(review): presumably index 1 is the nuclear label channel (the count cell
# reports "nuclei masks") — confirm against the TissueNet documentation.
y_true = test_data['y'][...,1:2]


In [None]:
# Sanity check: display the dimensions of the ground-truth label array
y_true.shape

In [None]:
# Basic count of the number of images and annotated instances.
# Label 0 is the background (the evaluation functions in this notebook strip it
# the same way), and np.unique returns it along with the instance ids, so only
# the non-zero unique labels are counted as masks.
total_image_count = y_true.shape[0]
ground_truth_masks_count = 0
for i in range(total_image_count):
    ground_truth_masks_count += np.count_nonzero(np.unique(y_true[i]))



print('Toal nuclei masks in ground truth:' , ground_truth_masks_count)
# ground_truth_masks_count = sum(len(detection) for detection in y_true)
# print(ground_truth_masks_count)

## Example image

In [None]:
# Read the detector output from disk into a dictionary.
# Keys are used further down as image paths (passed to cv2.imread); each value
# is a list of rows whose first four entries are used as a box and whose last
# entry is compared against a confidence threshold.

with open('/home/yiming/WorkSpace/CenterSAM/demo_results/TissueNet_model_on_TissueNet_results.json', 'r') as f:
    detec_results = json.load(f)

# Keep the keys in file order; the inference loop iterates over this list
json_image_names = list(detec_results.keys())


In [None]:
# Peek at the first key to check what the stored image identifiers look like
json_image_names[0]

In [None]:
# Keep only the detections whose score (last entry of each row) is strictly
# above the threshold, storing just the first four entries (the box) per row.
threshold = 0.45
filtered_data = {}
masks_counter = 0

for image_key, detections in detec_results.items():
    confident_boxes = []
    for detection in detections:
        if detection[-1] > threshold:
            confident_boxes.append(detection[:4])

    # Results are re-keyed by bare file name instead of the full key
    filtered_data[os.path.basename(image_key)] = confident_boxes
    masks_counter += len(confident_boxes)

print('Toal nucleis masks count after filter', masks_counter)



# Visualazation
# for image in base_image_names:
#     image = cv2.imread(target_image_folder+str(image))
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#     plt.figure(figsize=(10,10))
#     plt.imshow(image)
#     plt.axis('on')
#     plt.show()

## Selecting objects with SAM

First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results.

In [None]:
# For single GPU
import sys
import time

import torch
import torch.nn as nn
from segment_anything import sam_model_registry, SamPredictor


# Load the pretrained VIT_H checkpoint
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (much more slowly) on a machine without CUDA instead of failing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create SAM and send to the device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Instantiating the SAM predictor
predictor = SamPredictor(sam)

In [None]:
# Run SAM on the test images in the order of the JSON keys, so that
# combined_masks[i] lines up with json_image_names[i].
predict_times = []
combined_masks = []
image_counter = 0
total_annotation_process = 0
start_time = time.time()
for image_path in json_image_names:

    image_counter += 1
    print(f"Processing image {image_counter} of {len(json_image_names)}")
    # Extract file name (filtered_data is keyed by base name)
    image_name = os.path.basename(image_path)

    # Read image and perform color space conversion
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports an unreadable/missing file by returning None;
        # fail here with the offending path instead of inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Setting the input image for the predictor
    predictor.set_image(image)

    # Getting input boxes from the filtered data dictionary
    boxes = filtered_data[image_name]
    print(len(boxes))
    total_annotation_process += len(boxes)

    # Label canvas sized from the image itself rather than a fixed 256x256,
    # so the boolean-mask indexing below cannot mismatch the mask shape.
    combined_mask = np.zeros(image.shape[:2], dtype=np.uint16)

    if boxes:
        # Build the tensor directly on the predictor's device
        input_boxes = torch.tensor(boxes, dtype=torch.float, device=predictor.device)

        # Map the boxes into the predictor's resized input frame
        transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

        # CUDA work is launched asynchronously, so synchronise on both sides of
        # the call; otherwise the wall-clock interval can miss pending kernels.
        on_gpu = transformed_boxes.is_cuda
        if on_gpu:
            torch.cuda.synchronize()
        start_predict = time.time()

        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        if on_gpu:
            torch.cuda.synchronize()
        end_predict = time.time()
        time_taken = end_predict - start_predict
        print(f"Time taken for predictor.predict_torch: {time_taken} seconds")
        predict_times.append(time_taken)

        print(masks.shape)

        # One device-to-host copy for the whole batch instead of one per mask
        masks_np = masks.cpu().numpy().astype(bool)
        for idx, mask in enumerate(masks_np):
            # Assign a different integer label to each mask; where masks
            # overlap, the later box wins.
            combined_mask[mask.squeeze()] = idx + 1
    else:
        # Nothing to prompt with: keep the all-background canvas and record a
        # zero-duration entry so predict_times stays aligned with the images.
        time_taken = 0.0
        print(f"Time taken for predictor.predict_torch: {time_taken} seconds")
        predict_times.append(time_taken)

    # Add a dimension to convert the list to a four-dimensional numpy array later on
    combined_mask = combined_mask[..., np.newaxis]

    # Add the merged mask to the list
    combined_masks.append(combined_mask)
    current_time = time.time()
    elapsed_time = current_time - start_time
    print('Speed:', total_annotation_process/elapsed_time)

# After all the images have been processed, the time list is converted into a DataFrame
df_predict_times = pd.DataFrame(predict_times, columns=['PredictTime'])

# Save DataFrame as CSV file
csv_file_path = '/home/yiming/WorkSpace/deepcell-tf/inference_time/SAM_stage_time.csv'
df_predict_times.to_csv(csv_file_path, index=False)

print(f"Predict times saved to {csv_file_path}")

combined_masks = np.stack(combined_masks)
print(np.shape(combined_masks))

Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction.

In [None]:
# Save the stacked label masks locally as a .npy file.
# The CenterSAM evaluation section loads this same path back with np.load.

save_path = "/home/yiming/WorkSpace/CenterSAM/demo_results/SAM_on_TissueNet_with_TissueNet_Model_combined_masks.npy"
np.save(save_path, combined_masks)

In [None]:
# Define helper functions for visualization

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent colored overlay.

    Args:
        mask: Array whose last two dimensions are (height, width) and which
            can be reshaped to exactly that size.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: If true, draw the RGB part of the color from
            ``np.random.random``; otherwise use a fixed blue. Alpha is 0.6
            in both cases.
    """
    alpha = np.array([0.6])
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4): mask pixels take the color, the rest stay 0
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: 2-D array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; entries equal to 1 or 0 select
            the green and red groups respectively.
        ax: Object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Value passed as the ``s`` argument of ``scatter``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    for group, colour in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(group[:, 0], group[:, 1], color=colour, **star_style)
    
def show_box(box, ax):
    """Outline a box on ``ax`` with a green, unfilled rectangle.

    Args:
        box: Indexable with four entries read as (x0, y0, x1, y1).
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)


In [None]:
# Import the JSON of the original detection results.
# This re-reads the same file loaded near the top of the notebook and rebinds
# detec_results.

import json
with open('/home/yiming/WorkSpace/CenterSAM/demo_results/TissueNet_model_on_TissueNet_results.json', 'r') as f:
    detec_results = json.load(f)

images = list(detec_results.keys())


# target_image_folder = '/home/yiming/WorkSpace/datasets/tissuenet_1.0/rgb_images/test/'
# print(target_image_folder+images)

# List every key stored in the results file, one per line
for image in images:
    print(image)

## Define the evaluation metric

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import find_boundaries

def match_pred_to_true(y_true, y_pred, iou_threshold=0.5):
    """Greedily match predicted instances to ground-truth instances by IoU.

    Predicted ids are visited in ascending order. Each one is paired with the
    still-unmatched ground-truth id of highest IoU, and the pair is kept only
    if that IoU is strictly greater than ``iou_threshold``. A ground-truth id
    is used at most once.

    Args:
        y_true: Integer label image; 0 is background.
        y_pred: Integer label image of the same shape; 0 is background.
        iou_threshold: Minimum IoU (exclusive) for a pair to count as a match.

    Returns:
        dict mapping predicted id -> matched ground-truth id.
    """
    # Filter out the background instead of list.remove(0), which raises
    # ValueError on an image that happens to contain no background pixel.
    true_ids = [true_id for true_id in np.unique(y_true) if true_id != 0]
    pred_ids = [pred_id for pred_id in np.unique(y_pred) if pred_id != 0]

    best_matches = {}

    for pred_id in pred_ids:
        # The predicted mask does not change in the inner loop; build it once
        pred_mask = y_pred == pred_id
        best_iou = 0
        best_true_id = None
        for true_id in true_ids:
            true_mask = y_true == true_id
            intersection = np.sum(pred_mask & true_mask)
            # union is never 0: pred_id and true_id both occur in their images
            union = np.sum(pred_mask | true_mask)
            iou = intersection / union
            if iou > best_iou:
                best_iou = iou
                best_true_id = true_id

        if best_iou > iou_threshold:
            best_matches[pred_id] = best_true_id
            true_ids.remove(best_true_id)  # remove this true_id so it can't be matched again

    return best_matches

def compute_instance_metrics(y_true, y_pred):
    """Compute object-level detection metrics for one pair of label images.

    Instances are paired with ``match_pred_to_true`` (default IoU threshold).
    A matched pair is a true positive, an unmatched ground-truth instance a
    false negative, and an unmatched prediction a false positive.

    Any ratio whose denominator is zero is reported as 0.0 instead of raising
    ZeroDivisionError. This happens when an image has no matches (F1 would be
    0/0), no predictions, or no ground truth.

    Args:
        y_true: Integer label image; 0 is background.
        y_pred: Integer label image of the same shape; 0 is background.

    Returns:
        Tuple ``(recall, precision, jaccard, f1, dice, matches, ious)``, in
        that order. ``matches`` is the pred-id -> true-id dict and ``ious``
        holds the pixel IoU of each matched pair.
    """
    def _safe_ratio(numerator, denominator):
        # A zero denominator means the metric is undefined; report it as 0.0
        return numerator / denominator if denominator else 0.0

    matches = match_pred_to_true(y_true, y_pred)
    TP = len(matches)
    # Count non-zero labels rather than subtracting 1, so an image without any
    # background pixel is not under-counted.
    FN = np.count_nonzero(np.unique(y_true)) - TP
    FP = np.count_nonzero(np.unique(y_pred)) - TP

    jaccard = _safe_ratio(TP, TP + FP + FN)
    precision = _safe_ratio(TP, TP + FP)
    recall = _safe_ratio(TP, TP + FN)
    f1 = _safe_ratio(2 * (precision * recall), precision + recall)
    dice = _safe_ratio(2 * TP, 2 * TP + FP + FN)

    ious = []
    for pred_id, true_id in matches.items():
        pred_mask = y_pred == pred_id
        true_mask = y_true == true_id
        intersection = np.sum(pred_mask & true_mask)
        union = np.sum(pred_mask | true_mask)
        ious.append(intersection / union)

    return recall, precision, jaccard, f1, dice, matches, ious

    
def visualize_comparison(X_orig, y_true, y_pred, index, prefix="comparison"):
    """Save a side-by-side figure of ground-truth and predicted label images.

    The figure is written to ``f"{prefix}_{index}.png"`` in the working
    directory with the instance metrics printed underneath. It is then closed,
    so calling this in a loop over the whole test set does not keep every
    figure alive in memory. A closed figure is also not rendered inline.

    Args:
        X_orig: Original image; accepted for call compatibility, not drawn.
        y_true: 2-D ground-truth label image.
        y_pred: 2-D predicted label image.
        index: Value embedded in the output file name.
        prefix: File-name prefix. Pass a per-model prefix to keep runs for
            different models from overwriting each other's figures.
    """
    fig, ax = plt.subplots(1, 2, figsize=(20, 6))
    try:
        ax[0].imshow(y_true, cmap='tab20b')
        ax[0].set_title('Ground Truth Segmentation')

        ax[1].imshow(y_pred, cmap='tab20b')
        ax[1].set_title('Prediction Segmentation')

        # Compute metrics; the trailing _ receives the extra return value (per-match IoUs)
        recall, precision, jaccard, f1, dice, matches, _ = compute_instance_metrics(y_true, y_pred)
        metrics_text = f"recall: {recall:.5f}\nPrecision: {precision:.5f}\nJaccard: {jaccard:.5f}\nF1 Score: {f1:.5f}\nDice Score: {dice:.5f}"
        fig.text(0.5, 0.04, metrics_text, ha='center')

        fig.savefig(f"{prefix}_{index}.png")
    finally:
        # Release the figure even if metric computation or saving fails
        plt.close(fig)

def compute_seg(y_true, y_pred):
    """Compute the SEG score: mean pixel IoU over the matched instance pairs.

    Pairs come from ``match_pred_to_true``. When nothing is matched the score
    is 0.
    """
    pairs = match_pred_to_true(y_true, y_pred)
    if not pairs:
        return 0

    iou_sum = 0
    for pred_id, true_id in pairs.items():
        overlap = np.sum((y_pred == pred_id) & (y_true == true_id))
        combined = np.sum((y_pred == pred_id) | (y_true == true_id))
        iou_sum += overlap / combined

    return iou_sum / len(pairs)

def compute_aji(y_true, y_pred):
    """Compute the Aggregated Jaccard Index (AJI) as implemented in this notebook.

    Predicted ids are visited in ascending order. Each is paired with the
    not-yet-used ground-truth id of highest non-zero IoU, and that pair's
    intersection and union pixel counts are accumulated. Ground-truth
    instances that stay unpaired add their full area to the union. The score
    is accumulated intersection divided by accumulated union.

    NOTE(review): predictions that overlap no available ground-truth instance
    contribute nothing to the union here, so such false positives are not
    penalised. Confirm whether this is intended before comparing with other
    AJI implementations.

    Args:
        y_true: Integer label image; 0 is background.
        y_pred: Integer label image of the same shape; 0 is background.

    Returns:
        The AJI score, or 0.0 when the accumulated union is empty (for example
        when both images contain only background).
    """
    # Filter out the background instead of list.remove(0), which raises
    # ValueError on an image that happens to contain no background pixel.
    true_ids = [true_id for true_id in np.unique(y_true) if true_id != 0]
    pred_ids = [pred_id for pred_id in np.unique(y_pred) if pred_id != 0]

    union_total = 0
    intersection_total = 0
    already_matched = set()  # set gives O(1) membership tests

    for pred_id in pred_ids:
        pred_mask = y_pred == pred_id  # invariant across the inner loop
        best_iou = 0
        best_true_id = None
        best_intersection = 0
        best_union = 0
        for true_id in true_ids:
            if true_id in already_matched:
                continue

            true_mask = y_true == true_id
            intersection = np.sum(pred_mask & true_mask)
            union = np.sum(pred_mask | true_mask)
            iou = intersection / union

            if iou > best_iou:
                best_iou = iou
                best_true_id = true_id
                # Remember the counts so they are not recomputed afterwards
                best_intersection = intersection
                best_union = union

        if best_true_id is not None:
            intersection_total += best_intersection
            union_total += best_union
            already_matched.add(best_true_id)

    # Consider the unmatched true instances
    for true_id in true_ids:
        if true_id not in already_matched:
            union_total += np.sum(y_true == true_id)

    # Nothing accumulated: avoid 0/0
    if union_total == 0:
        return 0.0
    return intersection_total / union_total



## Evaluating CenterSAM on the TissueNet Dataset

In [None]:
# Load the predicted label masks.
# This is the same path the SAM inference section writes with np.save.
CenterSAM_predict_result = np.load('/home/yiming/WorkSpace/CenterSAM/demo_results/SAM_on_TissueNet_with_TissueNet_Model_combined_masks.npy')

In [None]:
# Sanity check: display the dimensions of the loaded prediction array
CenterSAM_predict_result.shape

In [None]:
# Quick check on a handful of images: save comparison figures for the first X only

X = 10

# Slice each array up front so the loop needs no explicit break
preview_triplets = zip(X_test[:X], y_true[:X], CenterSAM_predict_result[:X])
for i, (orig_img, true_img, pred_img) in enumerate(preview_triplets):
    print(f"Processing image {i + 1}/{X}...")
    visualize_comparison(orig_img, true_img.squeeze(), pred_img.squeeze(), i)


In [None]:
import pandas as pd
import numpy as np

# Calculate the Jaccard, precision, recall, F1 score, dice score, AJI and SEG
# for every CenterSAM prediction, write them to CSV, then report the averages.

# Initialize lists to store metrics for each image
jaccards = []
precisions = []
recalls = []
f1_scores = []
dices = []
ious = []  # flat list: one IoU per matched instance pair, over all images
ajis = []
segs = []

# Initialize a list to store results for each image
results = []

for i, (orig_img, true_img, pred_img) in enumerate(zip(X_test, y_true, CenterSAM_predict_result)):
    print(f"Processing image {i + 1}/{len(y_true)}...")
    # Drop the singleton channel axis once and reuse the 2-D label images
    true_labels = true_img.squeeze()
    pred_labels = pred_img.squeeze()
    visualize_comparison(orig_img, true_labels, pred_labels, i)

    # Compute metrics for the current image.
    # compute_instance_metrics returns (recall, precision, jaccard, f1, dice,
    # matches, ious); unpack in that order so recall and jaccard are not swapped.
    recall, precision, jaccard, f1, dice, _matches, image_ious = compute_instance_metrics(true_labels, pred_labels)
    aji = compute_aji(true_labels, pred_labels)
    seg = compute_seg(true_labels, pred_labels)

    # Store the results in the list
    results.append({'Image Index': i+1, 'Recall': recall, 'Precision': precision, 'Jaccard': jaccard, 'F1 Score': f1, 'Dice Score': dice, 'AJI': aji, 'SEG': seg})

    # Store metrics in the corresponding lists
    jaccards.append(jaccard)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)
    dices.append(dice)
    # extend, not append: images have different numbers of matches, and
    # np.mean cannot average a ragged list of lists
    ious.extend(image_ious)
    ajis.append(aji)
    segs.append(seg)

    print('AJI: ' + str(aji))
    print('SEG: ' + str(seg))
    print('recall: ' + str(recall))
    print('precision: ' + str(precision))
    print('jaccard: ' + str(jaccard))
    print('f1: ' + str(f1))
    print('dice: ' + str(dice))

# Convert the results list to a dataframe
df = pd.DataFrame(results)

# Save the dataframe to CSV
output_path = "/home/yiming/WorkSpace/CenterSAM/demo_results/CenterSAM_on_TissueNet.csv"
df.to_csv(output_path, index=False)
print(f"Saved results to {output_path}")

# Compute the average of the metrics over all images
avg_recall = np.mean(recalls)
avg_precision = np.mean(precisions)
avg_jaccard = np.mean(jaccards)
avg_f1 = np.mean(f1_scores)
avg_dice = np.mean(dices)
# Mean IoU over all matched pairs; undefined (NaN) if nothing matched at all
avg_iou = np.mean(ious) if ious else float('nan')
avg_aji = np.mean(ajis)
avg_seg = np.mean(segs)

print("\nOverall Evaluation Metrics:")
print(f"Average Recall: {avg_recall:.5f}")
print(f"Average Precision: {avg_precision:.5f}")
print(f"Average Jaccard: {avg_jaccard:.5f}")
print(f"Average F1 Score: {avg_f1:.5f}")
print(f"Average Dice Score: {avg_dice:.5f}")
print(f"Average IoU: {avg_iou:.5f}")
print(f"Average AJI: {avg_aji:.5f}")
print(f"Average SEG: {avg_seg:.5f}")


## Evaluation of the Mesmer prediction results

In [None]:
# Load the Mesmer prediction results from a .npy file.
# NOTE(review): this file is not produced anywhere in this notebook — confirm
# that its image order matches y_true, since the evaluation zips them together.
Mesmer_predict_result = np.load('/home/yiming/WorkSpace/CenterSAM/demo_results/mesmer_prediction_results.npy')

In [None]:
# Sanity check: display the dimensions of the Mesmer prediction array
Mesmer_predict_result.shape

In [None]:
# Display the first prediction to eyeball its label values
Mesmer_predict_result[0]

In [None]:
import pandas as pd

# Calculate the Jaccard, precision, recall, F1 score, dice score, AJI and SEG
# for every Mesmer prediction, write them to CSV, then report the averages.

# Initialize lists to store metrics for each image
jaccards = []
precisions = []
recalls = []
f1_scores = []
dices = []
ious = []  # flat list: one IoU per matched instance pair, over all images
ajis = []
segs = []

# Initialize a list to store results for each image
results = []

for i, (orig_img, true_img, pred_img) in enumerate(zip(X_test, y_true, Mesmer_predict_result)):
    print(f"Processing image {i + 1}/{len(y_true)}...")
    # Drop the singleton channel axis once and reuse the 2-D label images
    true_labels = true_img.squeeze()
    pred_labels = pred_img.squeeze()
    visualize_comparison(orig_img, true_labels, pred_labels, i)

    # Compute metrics for the current image.
    # compute_instance_metrics returns (recall, precision, jaccard, f1, dice,
    # matches, ious); unpack in that order so recall and jaccard are not swapped.
    recall, precision, jaccard, f1, dice, _matches, image_ious = compute_instance_metrics(true_labels, pred_labels)
    aji = compute_aji(true_labels, pred_labels)
    seg = compute_seg(true_labels, pred_labels)

    # Store the results in the list
    results.append({'Image Index': i+1, 'Recall': recall, 'Precision': precision, 'Jaccard': jaccard, 'F1 Score': f1, 'Dice Score': dice, 'AJI': aji, 'SEG': seg})

    # Store metrics in the corresponding lists. Without these appends the
    # lists stay empty and every average below is NaN.
    jaccards.append(jaccard)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)
    dices.append(dice)
    # extend, not append: images have different numbers of matches, and
    # np.mean cannot average a ragged list of lists
    ious.extend(image_ious)
    ajis.append(aji)
    segs.append(seg)

    print('AJI: ' + str(aji))
    print('SEG: ' + str(seg))
    print('recall: ' + str(recall))
    print('precision: ' + str(precision))
    print('jaccard: ' + str(jaccard))
    print('f1: ' + str(f1))
    print('dice: ' + str(dice))

# Convert the results list to a dataframe
df = pd.DataFrame(results)

# Save the dataframe to CSV
output_path = "/home/yiming/WorkSpace/CenterSAM/demo_results/Mesmer_on_TissueNet.csv"
df.to_csv(output_path, index=False)
print(f"Saved results to {output_path}")

# Compute the average of the metrics over all images
avg_recall = np.mean(recalls)
avg_precision = np.mean(precisions)
avg_jaccard = np.mean(jaccards)
avg_f1 = np.mean(f1_scores)
avg_dice = np.mean(dices)
# Mean IoU over all matched pairs; undefined (NaN) if nothing matched at all
avg_iou = np.mean(ious) if ious else float('nan')
avg_aji = np.mean(ajis)
avg_seg = np.mean(segs)

print("\nOverall Evaluation Metrics:")
print(f"Average Recall: {avg_recall:.5f}")
print(f"Average Precision: {avg_precision:.5f}")
print(f"Average Jaccard: {avg_jaccard:.5f}")
print(f"Average F1 Score: {avg_f1:.5f}")
print(f"Average Dice Score: {avg_dice:.5f}")
print(f"Average IoU: {avg_iou:.5f}")
print(f"Average AJI: {avg_aji:.5f}")
print(f"Average SEG: {avg_seg:.5f}")


## Detectron2 on TissueNet Evaluation

In [None]:
# Load the Detectron2 predicted masks from a .npy file.
# NOTE(review): this file is not produced anywhere in this notebook — confirm
# that its image order matches y_true, since the evaluation zips them together.
Detectron2_predict_result = np.load('/home/yiming/WorkSpace/detectron2/Detectron2_predict_results_on_TissueNet.npy')

In [None]:
# This section is for a limited test with the first X images

X = 10

# Iterate over the Detectron2 predictions. This section is the Detectron2
# evaluation, but the loop was zipping CenterSAM_predict_result, so the preview
# showed the wrong model's masks.
for i, (orig_img, true_img, pred_img) in enumerate(zip(X_test, y_true, Detectron2_predict_result)):
    if i >= X:
        break

    print(f"Processing image {i + 1}/{X}...")
    visualize_comparison(orig_img, true_img.squeeze(), pred_img.squeeze(), i)


In [None]:
import pandas as pd
import numpy as np

# Calculate the Jaccard, precision, recall, F1 score, dice score, AJI and SEG
# for every Detectron2 prediction, write them to CSV, then report the averages.

# Initialize lists to store metrics for each image
jaccards = []
precisions = []
recalls = []
f1_scores = []
dices = []
ious = []  # flat list: one IoU per matched instance pair, over all images
ajis = []
segs = []

# Initialize a list to store results for each image
results = []

for i, (orig_img, true_img, pred_img) in enumerate(zip(X_test, y_true, Detectron2_predict_result)):
    print(f"Processing image {i + 1}/{len(y_true)}...")
    # Drop the singleton channel axis once and reuse the 2-D label images
    true_labels = true_img.squeeze()
    pred_labels = pred_img.squeeze()
    visualize_comparison(orig_img, true_labels, pred_labels, i)

    # Compute metrics for the current image.
    # compute_instance_metrics returns (recall, precision, jaccard, f1, dice,
    # matches, ious); unpack in that order so recall and jaccard are not swapped.
    recall, precision, jaccard, f1, dice, _matches, image_ious = compute_instance_metrics(true_labels, pred_labels)
    aji = compute_aji(true_labels, pred_labels)
    seg = compute_seg(true_labels, pred_labels)

    # Store the results in the list
    results.append({'Image Index': i+1, 'Recall': recall, 'Precision': precision, 'Jaccard': jaccard, 'F1 Score': f1, 'Dice Score': dice, 'AJI': aji, 'SEG': seg})

    # Store metrics in the corresponding lists
    jaccards.append(jaccard)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)
    dices.append(dice)
    # extend, not append: images have different numbers of matches, and
    # np.mean cannot average a ragged list of lists
    ious.extend(image_ious)
    ajis.append(aji)
    segs.append(seg)

    print('AJI: ' + str(aji))
    print('SEG: ' + str(seg))
    print('recall: ' + str(recall))
    print('precision: ' + str(precision))
    print('jaccard: ' + str(jaccard))
    print('f1: ' + str(f1))
    print('dice: ' + str(dice))

# Convert the results list to a dataframe
df = pd.DataFrame(results)

# Save the dataframe to CSV
output_path = "/home/yiming/WorkSpace/CenterSAM/demo_results/Detectron2_on_TissueNet.csv"
df.to_csv(output_path, index=False)
print(f"Saved results to {output_path}")

# Compute the average of the metrics over all images
avg_recall = np.mean(recalls)
avg_precision = np.mean(precisions)
avg_jaccard = np.mean(jaccards)
avg_f1 = np.mean(f1_scores)
avg_dice = np.mean(dices)
# Mean IoU over all matched pairs; undefined (NaN) if nothing matched at all
avg_iou = np.mean(ious) if ious else float('nan')
avg_aji = np.mean(ajis)
avg_seg = np.mean(segs)

print("\nOverall Evaluation Metrics:")
print(f"Average Recall: {avg_recall:.5f}")
print(f"Average Precision: {avg_precision:.5f}")
print(f"Average Jaccard: {avg_jaccard:.5f}")
print(f"Average F1 Score: {avg_f1:.5f}")
print(f"Average Dice Score: {avg_dice:.5f}")
print(f"Average IoU: {avg_iou:.5f}")
print(f"Average AJI: {avg_aji:.5f}")
print(f"Average SEG: {avg_seg:.5f}")
