In [None]:
# Prior to running, install SAM: https://github.com/facebookresearch/segment-anything
# Download the SAM model from: https://github.com/facebookresearch/segment-anything#:~:text=First%20download%20a-,model%20checkpoint,-.%20Then%20the%20model

# Set up 
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import os
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [None]:
# Seed the global random number generators of torch and NumPy
import torch
import numpy as np

torch.manual_seed(0)
# Kept as the final statement: it evaluates to None, so the cell shows no output
np.random.seed(0)

In [None]:
def show_anns(anns):
    """Draw every annotation mask on the current matplotlib axes in a random colour.

    Args:
        anns: Sequence of annotation dicts. Each dict must provide
            'segmentation' (a mask array whose first two dimensions are
            height and width) and 'area' (used only for ordering).

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    ax = plt.gca()
    ax.set_autoscale_on(False)

    # Largest masks are drawn first so that smaller ones end up on top of them
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # One random RGB triple per mask, broadcast over the whole image plane
        rgb = np.empty((mask.shape[0], mask.shape[1], 3))
        rgb[:] = np.random.random(3)
        # The mask scaled by 0.35 becomes the alpha channel: pixels outside
        # the mask are fully transparent, pixels inside are 35 % opaque
        ax.imshow(np.dstack((rgb, mask * 0.35)))

In [None]:
# Build the SAM network and the automatic mask generator used for segmentation

# Active checkpoint: ViT-H (checkpoint file and model type must match)
sam_checkpoint = "/Users/innaa/Documents/seg anything models/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Alternative checkpoint: ViT-B
# sam_checkpoint = "/Users/innaa/Documents/seg anything models/sam_vit_b_01ec64.pth"
# model_type     = "vit_b"

# Alternative checkpoint: ViT-L
# sam_checkpoint = "/Users/innaa/Documents/seg anything models/sam_vit_l_0b3195.pth"
# model_type     = "vit_l"

# Inference device (alternative: device = "cuda")
device = torch.device("cpu")

# Look up the constructor registered for the chosen model type, load the weights,
# then move the network to the chosen device
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device)

# SAM model custom params
generator_params = {
    "points_per_side": 20,
    "pred_iou_thresh": 0.88,
    "stability_score_thresh": 0.95,
    "crop_n_layers": 0,
    "crop_n_points_downscale_factor": 1,
}
mask_generator = SamAutomaticMaskGenerator(model=sam, **generator_params)

In [None]:
# Make grayscale composite from myoep+basal+luminal+tumor markers
def create_composite_image(folder_path, fov, name_list):
    """Sum the requested channel TIFFs into a single 8-bit composite image.

    Args:
        folder_path: Directory containing one ``<name>.tiff`` file per channel.
        fov: Field-of-view identifier. Not used by this function; kept so
            that existing calls keep working.
        name_list: Channel names (file names without the ``.tiff`` suffix)
            to include in the composite.

    Returns:
        ``np.uint8`` array holding the per-pixel channel sum multiplied by 255
        and saturated to the range [0, 255], or ``None`` when ``name_list``
        is empty.

    Raises:
        FileNotFoundError: If a channel file is missing or cannot be decoded.
    """
    channels = []

    # Load channels and add them to composite
    for name in name_list:
        image_path = os.path.join(folder_path, name + ".tiff")
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None instead of raising
        if image is None:
            raise FileNotFoundError(f"Could not read channel image: {image_path}")
        channels.append(image)

    if not channels:
        return None

    # Accumulate in float64 so the sum itself cannot overflow a narrow dtype
    composite = np.sum(channels, axis=0, dtype=np.float64)
    # Saturate before the cast: a bare astype(np.uint8) wraps values above 255
    # around, turning the brightest pixels into dark ones
    composite = np.clip(composite * 255, 0, 255).round().astype(np.uint8)
    return composite

In [None]:
def calculate_overlap(mask1, mask2):
    """Return the percentage of non-zero pixels of ``mask1`` that are also non-zero in ``mask2``.

    Args:
        mask1: Reference mask; its non-zero pixel count is the denominator.
        mask2: Mask compared against ``mask1``; must have the same shape.

    Returns:
        Overlap as a percentage in [0, 100]. An empty ``mask1`` yields 0.0
        rather than NaN from a division by zero.

    Raises:
        AssertionError: If the two masks differ in shape.
    """
    # Make sure both masks have the same dimensions
    assert mask1.shape == mask2.shape, "Masks must have the same shape"

    mask1_pixels = np.count_nonzero(mask1)
    # Nothing to overlap with: avoid 0/0, which would produce NaN and a RuntimeWarning
    if mask1_pixels == 0:
        return 0.0

    # Number of pixels that are non-zero in both masks
    overlap = np.logical_and(mask1, mask2).sum()

    return overlap / mask1_pixels * 100

In [None]:
# Channel names, kept in named groups, that go into the composite
myoep = ["ANXA1", "Calponin1", "SMA"]
tumor = ["Ecadherin", "EpCAM"]
luminal = ["KRT15", "KRT81", "KRT18"]
basal = ["KRT5", "KRT14", "KRT17"]
luminal_tumor = ["KRT7"]

# Flat list of every channel, in group order
name_list = [*myoep, *tumor, *luminal, *basal, *luminal_tumor]

In [None]:
# Define fov list to loop over: DCIS and Normal breast
fov_list = [
    "TA535_R12C6",
    "TA536_R1C1",
]

In [None]:
# For every FOV: build the composite, segment it with SAM, keep only the masks that
# pass the size / signal / collagen filters, and save the merged mask and the composite.
for fov in fov_list:
    # load and composite image
    print(f'Masking {fov}')
    folder_path = f'/Users/innaa/Documents/DCIS 2.0/20230417_image_test_set/{fov}/TIFs'  # Change based on folder structure
    composite   = create_composite_image(folder_path, fov, name_list)
    if composite is None:
        raise ValueError(f'No channels were composited for {fov}; check name_list')

    # Segment image. The mask generator takes an H x W x 3 uint8 image. A single-channel
    # composite has to be expanded with GRAY2RGB, because BGR2RGB rejects 1-channel input.
    if composite.ndim == 2:
        image_all = cv2.cvtColor(composite, cv2.COLOR_GRAY2RGB)
    else:
        image_all = cv2.cvtColor(composite, cv2.COLOR_BGR2RGB)
    masks2 = mask_generator.generate(image_all)

    # Load the collagen image of the fov and binarize it
    image_path = os.path.join(folder_path, "COL1A1.tiff")
    col_image  = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising
    if col_image is None:
        raise FileNotFoundError(f'Could not read collagen image: {image_path}')
    # Saturate before the cast: a bare astype(np.uint8) wraps values above 255 around,
    # which can push bright collagen pixels below the binarization threshold
    col_image    = np.clip(col_image.astype(np.float64) * 255, 0, 255).round().astype(np.uint8)
    thresh_value = 2
    ret, bin_col = cv2.threshold(col_image, thresh_value, 255, cv2.THRESH_BINARY)

    # Binarize signal image
    thresh_value = 3
    ret, bin_sig = cv2.threshold(composite, thresh_value, 255, cv2.THRESH_BINARY)

    # Define thresholds for mask exclusion
    height, width  = bin_col.shape
    # NOTE(review): 70*height*width is 70x the image area, so this upper bound can never
    # reject a mask. 0.70*height*width (70 % of the image) may have been intended -
    # confirm before changing, since it would alter which masks are kept.
    size_max_tr    = 70*height*width
    size_min_tr    = 1000
    overlap_col_tr = 10   # percent of mask pixels on collagen
    overlap_sig_tr = 2    # percent of mask pixels on composite signal

    # Mask post processing
    new_masks2 = []  # masks that pass every filter
    for mask_dict in masks2:
        mask = mask_dict['segmentation']
        per_overlap_col = calculate_overlap(mask, bin_col)
        per_overlap_sig = calculate_overlap(mask, bin_sig)
        non_zero_pixels = mask_dict['area']

        size_ok   = size_min_tr <= non_zero_pixels <= size_max_tr
        signal_ok = overlap_sig_tr <= per_overlap_sig
        # Accept low collagen overlap outright; otherwise collagen must not exceed the
        # signal overlap. Logically equivalent to the original two-branch condition.
        collagen_ok = (per_overlap_col <= overlap_col_tr) or (per_overlap_col <= per_overlap_sig)
        if size_ok and signal_ok and collagen_ok:
            new_masks2.append(mask_dict)

    # Combine the masks into a single binary mask
    sum_mask = np.zeros_like(composite, dtype=np.uint8)
    # cv2.add saturates, so overlapping masks cannot wrap back to zero
    for mask_dict in new_masks2:
        mask = mask_dict['segmentation'].astype(np.uint8)
        sum_mask = cv2.add(sum_mask, mask)
    # Convert the sum of the masks to a binary mask (0 / 255)
    sum_mask = cv2.threshold(sum_mask, 0, 255, cv2.THRESH_BINARY)[1]

    # Save segmentation mask; cv2.imwrite reports failure through its return value
    epi_mask_filename = f'{folder_path}/epi_mask.tiff'
    if not cv2.imwrite(epi_mask_filename, sum_mask):
        raise IOError(f'Failed to write {epi_mask_filename}')

    # Save the composite image
    composite_filename = f'{folder_path}/composite_image.tiff'
    if not cv2.imwrite(composite_filename, composite):
        raise IOError(f'Failed to write {composite_filename}')