# Figure 4 - Spectral Unmixing + SAM Segmentation

In [68]:
# import modules
import sys
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings
import os
import tifffile as tf
from skimage.morphology import disk, binary_dilation, binary_erosion
import pandas as pd

warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)
from functions_EDX import *

import torch
print("PyTorch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())
from skimage.filters import gaussian 
#from segment_anything_hq import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
from skimage.feature import peak_local_max



PyTorch version: 2.3.0
CUDA is available: False


### Set variables

In [69]:
HomePath = '/Users/AJ/Desktop/CellFigures/raw_material/Figure 4/PeterMasks' #r'D:\Projects\IDENTIFY\Data\Figure 3'
# HomePath Structure:
# /path/to/directory
#   |-- PaCMAP_instance
#   |-- HAADFS (Exported using preprocessing/ExtractAndCorrectHAADFS)

supp_file_path = os.path.join(HomePath, "PaCMAP_instance", "pacmap_panc_euc_20percent_SavedTree20240209-105636.npz")

haadf_folder = os.path.join(HomePath, 'HAADFs')


sam_checkpoint = '/Users/aj/Desktop/work/PostDoc_UMCG/work/analysis/EDX_Project/primary_data/sam_vit_h_4b8939.pth'   #r"D:\Projects\IDENTIFY\SAM\sam_vit_h_4b8939.pth"
model_type = "vit_h"

### functions

In [70]:
# sub-routine to get file names
def get_file_paths(spectrum_folder):
    file_names = []
    tmp = os.listdir(spectrum_folder)
    for filename in tmp:
        if filename.endswith('.npz'):
            file_names.append(filename)
    
    file_names = sorted(file_names)
    files = [os.path.join(spectrum_folder,file_name) for file_name in file_names]
    return files

def show_mask(mask, ax, random_color=False,alpha=0.35):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, alpha])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='o', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='o', s=marker_size, edgecolor='white', linewidth=1.25)   

### structure indices

In [71]:
structure_names = ['Unclassified','Insulin','Nucleic acids','Exocrine granules','PP or Ghrelin','Glucagon',
              'Membranes','Lysosomes A','Lysosomes B','Lysosomes C','Nucleolus','Lysosomes D']

# print structure indices
for idx, structure in enumerate(structure_names):
    print("%02d - %s" % (idx,structure))

00 - Unclassified
01 - Insulin
02 - Nucleic acids
03 - Exocrine granules
04 - PP or Ghrelin
05 - Glucagon
06 - Membranes
07 - Lysosomes A
08 - Lysosomes B
09 - Lysosomes C
10 - Nucleolus
11 - Lysosomes D


### Load ColorEM data

In [5]:
# choose a tile and a structure to segment
tile_indices = [i for i in range(30)]
#tile_idx = 4
structure_indices = [1, 2, 3, 4, 5]
tresholds = [0.4, 0.5, 0.4, 0.4, 0.4]


abundance_maps = np.load(supp_file_path)['abundance_maps']
abundance_maps = abundance_maps / np.max(abundance_maps) # Scale to have max coefficient = 1
colors = np.load(supp_file_path)['colors']
colors[0] = [0,0,0]

# get HAADF img
haadf_stack = []
for i in tile_indices:
    haadf_stack.append(tf.imread(os.path.join(haadf_folder, "Tile_%02d.tiff" % i)))
haadf_stack = np.asarray(haadf_stack).transpose((1,2,0))
    
#spectrum_folder = '/Volumes/Microscopy3/EDX_data/Identify/main_mosaic_6by5/NPZ/'
#files = get_file_paths(spectrum_folder)
#haadf = rebin_XY(np.load(files[tile_idx])['haadf'],1024)

### Pre-process abundance maps

In [6]:
for c, structure_idx in enumerate(structure_indices):
    for tile_idx in tile_indices:
        img = gaussian(abundance_maps[structure_idx,:,:,tile_idx], 4)
        img = img* (img>tresholds[c])
        abundance_maps[structure_idx,:,:,tile_idx] = img     

### Apply SAM to all tiles for each structure and save as arrays

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

try:
    os.mkdir(os.path.join(HomePath, 'SAM Masks'))
except:
    pass

for structure_idx in structure_indices:
    print(f'Starting {structure_names[structure_idx]}', end = '\n') 
    try:
        os.mkdir(os.path.join(HomePath, 'SAM Masks', structure_names[structure_idx]))
    except:
        pass
    print('Tile: ', end = '\t')
    for tile_idx in tile_indices:
        print(f'{tile_idx} ', end = '\t')
        img = normalize8(haadf_stack[:,:,tile_idx])
        coordinates = peak_local_max(abundance_maps[structure_idx,:,:,tile_idx], min_distance=10)
        all_masks = []
        all_scores = []
        predictor.set_image(np.dstack((img,img,img)))
        for i in range(coordinates.shape[0]):
            input_point = np.array([(coordinates[i][1],coordinates[i][0])])
            input_label = np.array([1])
        
            masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )
            all_masks.append(masks)
            all_scores.append(scores)
        masks_1, masks_2, masks_3 = [], [], []
        for i in range(len(all_masks)):
            masks_1.append(all_masks[i][0])
            masks_2.append(all_masks[i][1])
            masks_3.append(all_masks[i][2])
        masks_1, masks_2, masks_3, scores = np.asarray(masks_1), np.asarray(masks_2), np.asarray(masks_3), np.asarray(all_scores)
        np.savez_compressed(os.path.join(HomePath, 'SAM Masks', structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)),
                            masks_1 = masks_1,
                            masks_2 = masks_2,
                            masks_3 = masks_3,
                            scores = scores)
    print('Done', end = '\n')
                

### Repeat the above cell but with random purtubations to the prompt locations (as a benchmark)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
perturb_limits = [40,60,80,100] #20 #pixels

try:
    os.mkdir(os.path.join(HomePath, 'SAM Masks_purturbed'))
except:
    pass


for perturb_lim in perturb_limits:
    try:
        os.mkdir(os.path.join(HomePath, 'SAM Masks_purturbed','%02d' % perturb_lim))
    except:
        pass
    
    for structure_idx in structure_indices:
        print(f'Starting {structure_names[structure_idx]}', end = '\n') 
        try:
            os.mkdir(os.path.join(HomePath, 'SAM Masks_purturbed', '%02d' % perturb_lim ,structure_names[structure_idx]))
        except:
            pass
        print('Tile: ', end = '\t')
        for tile_idx in tile_indices:
            print(f'{tile_idx} ', end = '\t')
            img = normalize8(haadf_stack[:,:,tile_idx])
            coordinates = peak_local_max(abundance_maps[structure_idx,:,:,tile_idx], min_distance=10)
            all_masks = []
            all_scores = []
            predictor.set_image(np.dstack((img,img,img)))
            for i in range(coordinates.shape[0]):
                input_point = np.array([(coordinates[i][1], coordinates[i][0])]) + np.random.randint(-perturb_lim, perturb_lim + 1, size=2)
                input_point = np.clip(input_point, 0, 1023)
                            
                input_label = np.array([1])
            
                masks, scores, logits = predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    multimask_output=True,
                )
                all_masks.append(masks)
                all_scores.append(scores)
            masks_1, masks_2, masks_3 = [], [], []
            for i in range(len(all_masks)):
                masks_1.append(all_masks[i][0])
                masks_2.append(all_masks[i][1])
                masks_3.append(all_masks[i][2])
            masks_1, masks_2, masks_3, scores = np.asarray(masks_1), np.asarray(masks_2), np.asarray(masks_3), np.asarray(all_scores)
            np.savez_compressed(os.path.join(HomePath, 'SAM Masks_purturbed','%02d' % perturb_lim, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)),
                                #masks_1 = masks_1,
                                #masks_2 = masks_2,
                                #masks_3 = masks_3,
                                scores = scores)
        print('Done', end = '\n')
                

### Load arrays and convert to individual tiffs with black background

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs'))
except:
    pass

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs', 'Individual'))
except:
    pass

for structure_idx in structure_indices:
    print(f'Converting {structure_names[structure_idx]}', end = '\n')
    try:
        os.mkdir(os.path.join(HomePath, 'SAM Tiffs', 'Individual', structure_names[structure_idx] + f' {mask_name}')) # With overlay!
    except:
        pass
    print('Tile: ', end = '\t')
    for tile_idx in tile_indices:
        print(f'{tile_idx} ', end = '\t')
        img = np.zeros((haadf_stack.shape[0], haadf_stack.shape[1], 3))
        mask = np.load(os.path.join(mask_folder, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)))[mask_name]
        for m in mask:
            img[m,0], img[m,1], img[m,2] = colors[structure_idx][0], colors[structure_idx][1], colors[structure_idx][2]
        tf.imwrite(os.path.join(HomePath,'SAM Tiffs', 'Individual', structure_names[structure_idx] + f' {mask_name}','Tile_%02d.tiff' % (tile_idx)), (img*255).astype('uint8'))
    print('Done', end = '\n')
    


### Load arrays and convert to individual tiffs with HAADF background

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'
alpha = 0.35

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs'))
except:
    pass

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs', 'Individual (HAADF Overlay)'))
except:
    pass

for structure_idx in structure_indices:
    print(f'Converting {structure_names[structure_idx]}', end = '\n')
    try:
        os.mkdir(os.path.join(HomePath, 'SAM Tiffs', 'Individual (HAADF Overlay)', structure_names[structure_idx] + f' {mask_name}')) 
    except:
        pass
    print('Tile: ', end = '\t')
    for tile_idx in tile_indices:
        print(f'{tile_idx} ', end = '\t')
        img = np.zeros((haadf_stack.shape[0], haadf_stack.shape[1], 3))
        haadf_img = np.dstack((haadf_stack[:,:,tile_idx], haadf_stack[:,:,tile_idx], haadf_stack[:,:,tile_idx]))/255
        mask = np.load(os.path.join(mask_folder, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)))[mask_name]
        for m in mask:
            img[m,0], img[m,1], img[m,2] = colors[structure_idx][0], colors[structure_idx][1], colors[structure_idx][2]  
        img = (alpha * img) + ((1-alpha)*haadf_img)
        tf.imwrite(os.path.join(HomePath, 'SAM Tiffs', 'Individual (HAADF Overlay)', structure_names[structure_idx] + f' {mask_name}','Tile_%02d.tiff' % (tile_idx)), (img*255).astype('uint8'))
    print('Done', end = '\n')

### Load arrays and convert to combined tiffs with black background

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs'))
except:
    pass

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs', f'Combined + {mask_name}'))
except:
    pass

for tile_idx in tile_indices:
        print(f'Converting tile {tile_idx}', end = '\t')
        img = np.zeros((haadf_stack.shape[0], haadf_stack.shape[1], 3))
        for structure_idx in structure_indices:
            mask = np.load(os.path.join(mask_folder, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)))[mask_name]
            bin_mask = np.zeros((1024,1024),dtype='bool')
            for m in mask:  
                bin_mask[m] = 1
                
            # dilate for the nucleic acid masks (optional)
            if structure_idx == 2:
                radius = 70 if tile_idx == 16 else 20
                bin_mask = binary_dilation(bin_mask, disk(radius, dtype=bool))
                bin_mask = binary_erosion(bin_mask, disk(radius, dtype=bool))

            # color
            img[bin_mask,0], img[bin_mask,1], img[bin_mask,2] = colors[structure_idx][0], colors[structure_idx][1], colors[structure_idx][2]
        tf.imwrite(os.path.join(HomePath, 'SAM Tiffs', f'Combined + {mask_name}','Tile_%02d.tiff' % (tile_idx)), (img*255).astype('uint8'))
        print('Done', end = '\n')

### Load arrays and convert to combined tiffs with HAADF background

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'
alpha = 0.5

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs'))
except:
    pass

try:
    os.mkdir(os.path.join(HomePath, 'SAM Tiffs', f'Combined + {mask_name} (HAADF Overlay)'))
except:
    pass

for tile_idx in tile_indices:
        print(f'Converting tile {tile_idx}', end = '\t')
        #img = np.zeros((haadf_stack.shape[0], haadf_stack.shape[1], 3))
        haadf_img = np.dstack((haadf_stack[:,:,tile_idx], haadf_stack[:,:,tile_idx], haadf_stack[:,:,tile_idx]))/255
        img = haadf_img
        for structure_idx in structure_indices:
            mask = np.load(os.path.join(mask_folder, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)))[mask_name]
            bin_mask = np.zeros((1024,1024),dtype='bool')
            for m in mask:  
                bin_mask[m] = 1
                
            # dilate for the nucleic acid masks (optional)
            if structure_idx == 2:
                radius = 70 if tile_idx == 16 else 20
                bin_mask = binary_dilation(bin_mask, disk(radius, dtype=bool))
                bin_mask = binary_erosion(bin_mask, disk(radius, dtype=bool))

            # color
            #img[bin_mask,0], img[bin_mask,1], img[bin_mask,2] = colors[structure_idx][0], colors[structure_idx][1], colors[structure_idx][2]
            img[bin_mask,0] = alpha*colors[structure_idx][0] + (1-alpha)*haadf_img[bin_mask,0]
            img[bin_mask,1] = alpha*colors[structure_idx][1] + (1-alpha)*haadf_img[bin_mask,1]
            img[bin_mask,2] = alpha*colors[structure_idx][2] + (1-alpha)*haadf_img[bin_mask,2]
        #img = (alpha * img) + ((1-alpha)*haadf_img)
        tf.imwrite(os.path.join(HomePath, 'SAM Tiffs', f'Combined + {mask_name} (HAADF Overlay)','Tile_%02d.tiff' % (tile_idx)), (img*255).astype('uint8'))
        print('Done', end = '\n')

### Import and display scores (for revision)

In [None]:
score_array = np.zeros((len(tile_indices),len(structure_indices)))
mask_folder = os.path.join(HomePath, 'SAM Masks_purturbed')
mask_name = 'masks_1'
                       
for structure_idx in structure_indices:
    for tile_idx in tile_indices:
        scores = np.load(os.path.join(mask_folder, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)))['scores']
        print(scores)
        try:
            score_array[tile_idx,structure_idx-1] = np.mean(scores,axis=0)[0]
        except:
            score_array[tile_idx,structure_idx-1] = np.nan
        

In [58]:
scores_df = pd.DataFrame(data=score_array, index=['%02d' % i for i in tile_indices], columns=[structure_names[i] for i in structure_indices])
display(scores_df)

df2 = scores_df.describe()
display(df2)

Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
0,0.896835,0.91176,,,0.909657
1,0.940427,0.762272,0.998435,,0.869396
2,0.880091,0.934199,0.799018,,0.826607
3,0.871137,0.822926,,,0.902646
4,0.865876,0.906072,0.626717,,0.859756
5,0.880612,0.845197,,,0.600254
6,0.861151,0.737689,0.916362,,0.824971
7,0.907513,0.838343,0.881573,0.864934,0.843662
8,0.857615,0.795383,0.826038,0.873079,0.905944
9,0.833215,0.88435,0.902132,,0.894988


Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
count,22.0,23.0,18.0,3.0,15.0
mean,0.852826,0.820816,0.898892,0.899196,0.851019
std,0.079623,0.082102,0.118881,0.052447,0.076743
min,0.610817,0.586674,0.592664,0.864934,0.600254
25%,0.838535,0.791344,0.886712,0.869007,0.839895
50%,0.866112,0.822926,0.943208,0.873079,0.859756
75%,0.888075,0.877138,0.973607,0.916326,0.898817
max,0.977129,0.946953,1.000808,0.959574,0.931055


### row averages

In [59]:
print(df2.mean(axis=1))

count    16.200000
mean      0.864550
std       0.081959
min       0.651069
25%       0.845098
50%       0.873016
75%       0.910793
max       0.963104
dtype: float64


#### Export the results

In [None]:
with pd.ExcelWriter('/Users/AJ/Desktop/CellFigures/raw_material/Tables/IoU.xlsx') as writer:  
    scores_df.to_excel(writer, sheet_name='Per tile segmentation IoU')
    df2.to_excel(writer, sheet_name='Average over tiles segmentation IoU')

### Import scores for perturbed masks

In [75]:


for perturb_lim in perturb_limits:
    score_array = np.zeros((len(tile_indices),len(structure_indices)))
    mask_folder = os.path.join(HomePath, 'SAM Masks_purturbed','%02d' % perturb_lim)
    mask_name = 'masks_1'
    for structure_idx in structure_indices:
        for tile_idx in tile_indices:
            scores = np.load(os.path.join(mask_folder, structure_names[structure_idx],'Tile_%02d.npz' % (tile_idx)))['scores']
            #print(scores)
            try:
                score_array[tile_idx,structure_idx-1] = np.mean(scores,axis=0)[0]
            except:
                score_array[tile_idx,structure_idx-1] = np.nan
            scores_df = pd.DataFrame(data=score_array, index=['%02d' % i for i in tile_indices], columns=[structure_names[i] for i in structure_indices])
            #display(scores_df)
            
    df2 = scores_df.describe()
    print('Perturbation = %02d pixels' % perturb_lim)
    display(df2)

Perturbation = 40 pixels


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
count,22.0,23.0,18.0,3.0,15.0
mean,0.751406,0.8277,0.854948,0.831419,0.797666
std,0.172957,0.058713,0.157286,0.140849,0.076408
min,0.267805,0.715135,0.427671,0.680619,0.61182
25%,0.769535,0.784633,0.867239,0.767343,0.770716
50%,0.817301,0.826653,0.915658,0.854068,0.803701
75%,0.84113,0.866339,0.933383,0.906819,0.824406
max,0.885932,0.948348,0.997157,0.959571,0.934178


Perturbation = 60 pixels


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
count,22.0,23.0,18.0,3.0,15.0
mean,0.797109,0.824496,0.804783,0.87551,0.788563
std,0.082996,0.060389,0.180361,0.075682,0.073477
min,0.534447,0.69052,0.314189,0.808143,0.637139
25%,0.753648,0.803264,0.767988,0.834564,0.751574
50%,0.812591,0.824749,0.872854,0.860985,0.794546
75%,0.841308,0.856818,0.906815,0.909193,0.835011
max,0.957634,0.93051,0.972034,0.957402,0.885367


Perturbation = 80 pixels


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
count,22.0,23.0,18.0,3.0,15.0
mean,0.800723,0.830769,0.802731,0.740716,0.793285
std,0.116883,0.059976,0.18284,0.052089,0.056478
min,0.391806,0.703406,0.318738,0.709067,0.704047
25%,0.786654,0.797589,0.813213,0.710656,0.746996
50%,0.82299,0.832821,0.858343,0.712245,0.791523
75%,0.859866,0.864986,0.90483,0.75654,0.814344
max,0.961589,0.936747,0.966424,0.800835,0.908612


Perturbation = 100 pixels


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
count,22.0,23.0,18.0,3.0,15.0
mean,0.82387,0.826406,0.832892,0.697042,0.773371
std,0.087569,0.054152,0.073636,0.071543,0.054088
min,0.502384,0.69734,0.681747,0.616135,0.689475
25%,0.805799,0.799004,0.793143,0.669588,0.734843
50%,0.835826,0.817542,0.837028,0.723041,0.773275
75%,0.85574,0.858127,0.878028,0.737496,0.811963
max,0.984867,0.924597,0.953822,0.751951,0.864744
