# Figure 4 - Spectral Unmixing + SAM Segmentation

In [8]:
# import modules
# Names bound BEFORE the star import below (same order as before, so any name that
# functions_EDX also exports keeps the same final binding).
import os
import sys
import warnings

import pandas as pd
import tifffile as tf
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
from skimage.morphology import disk, binary_dilation, binary_erosion

# Silence numba deprecation noise before the project helpers are imported
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

# functions_EDX lives one directory above this notebook
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)
# NOTE(review): star import; presumably supplies normalize8 (used in the SAM cell) -- prefer explicit names
from functions_EDX import *

# Names bound AFTER the star import, so functions_EDX cannot shadow them.
# numpy is imported explicitly: `np` is used in every cell and must not depend on the star import.
import numpy as np
import torch
from skimage.filters import gaussian
# Alternative backend exposing the same names:
# from segment_anything_hq import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
from skimage.feature import peak_local_max

print("PyTorch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())



PyTorch version: 2.3.0
CUDA is available: False


### Set variables

In [2]:
# Root folder of this figure's data. Override with the FIG4_HOME environment variable
# (e.g. r'D:\Projects\IDENTIFY\Data\Figure 3') instead of editing the notebook.
HomePath = os.environ.get('FIG4_HOME', '/Users/AJ/Desktop/CellFigures/raw_material/Figure 4/PeterMasks')
# HomePath structure:
# /path/to/directory
#   |-- PaCMAP_instance   (.npz holding 'abundance_maps' and 'colors')
#   |-- HAADFs            (Tile_XX.tiff, exported using preprocessing/ExtractAndCorrectHAADFS)
#   |-- SAM Masks         (written by this notebook)
#   |-- SAM Tiffs         (written by this notebook)

supp_file_path = os.path.join(HomePath, "PaCMAP_instance", "pacmap_panc_euc_20percent_SavedTree20240209-105636.npz")

haadf_folder = os.path.join(HomePath, 'HAADFs')

# SAM ViT-H checkpoint. The default is a Windows path while HomePath defaults to a
# macOS path, so set SAM_CHECKPOINT when running on another machine.
sam_checkpoint = os.environ.get('SAM_CHECKPOINT', r"D:\Projects\IDENTIFY\SAM\sam_vit_h_4b8939.pth")
model_type = "vit_h"

### Helper functions

In [3]:
# sub-routine to get file names
def get_file_paths(spectrum_folder):
    """List the ``.npz`` files inside *spectrum_folder*.

    Returns full paths (folder joined with file name), ordered by file name.
    Sub-folders and files with other extensions are skipped.
    """
    npz_names = sorted(name for name in os.listdir(spectrum_folder) if name.endswith('.npz'))
    return [os.path.join(spectrum_folder, name) for name in npz_names]

def show_mask(mask, ax, random_color=False, alpha=0.35):
    """Draw *mask* on *ax* as a translucent RGBA overlay.

    The mask's last two dimensions are taken as (height, width). Masked pixels are
    drawn in RGB (30, 144, 255) with opacity *alpha*. With ``random_color=True`` a
    random RGB colour with a fixed opacity of 0.6 is used instead (*alpha* is ignored).
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) mask against a (1, 1, 4) colour -> (H, W, 4) image
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=200):
    """Scatter point prompts on *ax*.

    Points with label 1 are drawn green and points with label 0 red, both as
    white-edged circles of size *marker_size*. ``coords[:, 0]`` is plotted on the
    x axis and ``coords[:, 1]`` on the y axis.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='o',
                   s=marker_size, edgecolor='white', linewidth=1.25)

### Structure indices

In [4]:
# Structure labels; the list position is the structure_idx used by every later cell
structure_names = [
    'Unclassified', 'Insulin', 'Nucleic acids', 'Exocrine granules',
    'PP or Ghrelin', 'Glucagon', 'Membranes', 'Lysosomes A',
    'Lysosomes B', 'Lysosomes C', 'Nucleolus', 'Lysosomes D',
]

# print the index table
for position, label in enumerate(structure_names):
    print("%02d - %s" % (position, label))

00 - Unclassified
01 - Insulin
02 - Nucleic acids
03 - Exocrine granules
04 - PP or Ghrelin
05 - Glucagon
06 - Membranes
07 - Lysosomes A
08 - Lysosomes B
09 - Lysosomes C
10 - Nucleolus
11 - Lysosomes D


### Load ColorEM data

In [5]:
# Choose the tiles and the structures to segment
tile_indices = list(range(30))
# structure_idx values (see the index table above) paired by position with the
# abundance threshold applied in the pre-processing cell.
# The spelling `tresholds` is kept because later cells reference that name.
structure_indices = [1, 2, 3, 4, 5]
tresholds = [0.4, 0.5, 0.4, 0.4, 0.4]
assert len(structure_indices) == len(tresholds), "need exactly one threshold per structure"

# Read both arrays through a single handle; `with` closes the .npz file afterwards
with np.load(supp_file_path) as supp_data:
    abundance_maps = supp_data['abundance_maps']
    colors = supp_data['colors']
abundance_maps = abundance_maps / np.max(abundance_maps)  # Scale to have max coefficient = 1
colors[0] = [0, 0, 0]  # structure 0 ('Unclassified') is drawn black

# Stack the HAADF tiles along the last axis -> (H, W, len(tile_indices)).
# NOTE(review): later cells index haadf_stack with tile_idx itself, which only matches
# the stack position because tile_indices is 0..29 -- keep that in mind when sub-selecting tiles.
haadf_stack = np.stack(
    [tf.imread(os.path.join(haadf_folder, "Tile_%02d.tiff" % i)) for i in tile_indices],
    axis=-1,
)

### Pre-process abundance maps

In [6]:
# Smooth each selected structure's abundance map (Gaussian, sigma = 4 px), zero every
# pixel that is not above that structure's threshold, and write the result back.
# NOTE: this updates abundance_maps in place and is not idempotent -- re-running it
# without re-running the load cell blurs and thresholds the maps a second time.
for position, structure_idx in enumerate(structure_indices):
    for tile_idx in tile_indices:
        smoothed = gaussian(abundance_maps[structure_idx, :, :, tile_idx], 4)
        above_threshold = smoothed > tresholds[position]
        abundance_maps[structure_idx, :, :, tile_idx] = smoothed * above_threshold

### Apply SAM to all tiles for each structure and save as arrays

In [None]:
# Load the SAM model once; runs on CPU when CUDA is unavailable
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# Output layout: <HomePath>/SAM Masks/<structure name>/Tile_XX.npz
# exist_ok=True tolerates existing folders without hiding other OS errors.
sam_mask_root = os.path.join(HomePath, 'SAM Masks')
os.makedirs(sam_mask_root, exist_ok=True)

for structure_idx in structure_indices:
    structure_name = structure_names[structure_idx]
    print(f'Starting {structure_name}', end='\n')
    structure_dir = os.path.join(sam_mask_root, structure_name)
    os.makedirs(structure_dir, exist_ok=True)
    print('Tile: ', end='\t')
    for tile_idx in tile_indices:
        print(f'{tile_idx} ', end='\t')
        # normalize8 comes from functions_EDX (presumably rescales the tile to 8-bit)
        img = normalize8(haadf_stack[:, :, tile_idx])
        # Prompt locations: local maxima of the pre-processed abundance map, as (row, col)
        coordinates = peak_local_max(abundance_maps[structure_idx, :, :, tile_idx], min_distance=10)
        # SAM takes a 3-channel image, so the HAADF channel is replicated. The image is
        # set once per tile and reused for every prompt below.
        predictor.set_image(np.dstack((img, img, img)))
        all_masks, all_scores = [], []
        for row, col in coordinates:
            # One positive (label 1) point per peak; SAM point prompts are (x, y) = (col, row)
            masks, scores, _logits = predictor.predict(
                point_coords=np.array([(col, row)]),
                point_labels=np.array([1]),
                multimask_output=True,
            )
            all_masks.append(masks)    # three candidate masks for this prompt
            all_scores.append(scores)  # one score per candidate
        # Regroup by candidate: masks_k holds candidate k-1 of every prompt.
        # With no peaks these are empty arrays; the export/score cells loop over them harmlessly.
        np.savez_compressed(
            os.path.join(structure_dir, 'Tile_%02d.npz' % tile_idx),
            masks_1=np.asarray([m[0] for m in all_masks]),
            masks_2=np.asarray([m[1] for m in all_masks]),
            masks_3=np.asarray([m[2] for m in all_masks]),
            scores=np.asarray(all_scores),
        )
    print('Done', end='\n')
                

### Load mask arrays and export one TIFF per structure and tile (black background)

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'  # SAM candidate to export: 'masks_1', 'masks_2' or 'masks_3'

# makedirs creates 'SAM Tiffs/Individual' in one call; exist_ok only tolerates existing
# folders instead of silencing every error.
individual_root = os.path.join(HomePath, 'SAM Tiffs', 'Individual')
os.makedirs(individual_root, exist_ok=True)

for structure_idx in structure_indices:
    structure_name = structure_names[structure_idx]
    print(f'Converting {structure_name}', end='\n')
    out_dir = os.path.join(individual_root, structure_name + f' {mask_name}')
    os.makedirs(out_dir, exist_ok=True)
    print('Tile: ', end='\t')
    for tile_idx in tile_indices:
        print(f'{tile_idx} ', end='\t')
        # `with` closes the .npz file once the array has been read
        with np.load(os.path.join(mask_folder, structure_name, 'Tile_%02d.npz' % tile_idx)) as npz:
            mask = npz[mask_name]
        img = np.zeros((haadf_stack.shape[0], haadf_stack.shape[1], 3))
        for m in mask:
            # paint every pixel of this prompt's mask with the structure's RGB colour
            img[m] = colors[structure_idx][:3]
        tf.imwrite(os.path.join(out_dir, 'Tile_%02d.tiff' % tile_idx), (img * 255).astype('uint8'))
    print('Done', end='\n')
    


### Load mask arrays and export one TIFF per structure and tile (HAADF background)

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'  # SAM candidate to export
alpha = 0.35           # weight of the colour layer in the blend

overlay_root = os.path.join(HomePath, 'SAM Tiffs', 'Individual (HAADF Overlay)')
os.makedirs(overlay_root, exist_ok=True)

for structure_idx in structure_indices:
    structure_name = structure_names[structure_idx]
    print(f'Converting {structure_name}', end='\n')
    out_dir = os.path.join(overlay_root, structure_name + f' {mask_name}')
    os.makedirs(out_dir, exist_ok=True)
    print('Tile: ', end='\t')
    for tile_idx in tile_indices:
        print(f'{tile_idx} ', end='\t')
        # NOTE(review): /255 assumes the HAADF tiles are already in 0-255; the SAM cell passes
        # the same tiles through normalize8, so confirm their dtype/range.
        haadf = haadf_stack[:, :, tile_idx]
        haadf_img = np.dstack((haadf, haadf, haadf)) / 255
        with np.load(os.path.join(mask_folder, structure_name, 'Tile_%02d.npz' % tile_idx)) as npz:
            mask = npz[mask_name]
        img = np.zeros((haadf_stack.shape[0], haadf_stack.shape[1], 3))
        for m in mask:
            img[m] = colors[structure_idx][:3]
        # Blend over the whole tile: unmasked pixels (colour 0) become (1 - alpha) * HAADF.
        # NOTE(review): the combined HAADF-overlay cell blends only inside masks -- confirm
        # which look is intended here.
        img = (alpha * img) + ((1 - alpha) * haadf_img)
        # clip before the cast so out-of-range values saturate instead of wrapping around
        tf.imwrite(os.path.join(out_dir, 'Tile_%02d.tiff' % tile_idx),
                   (np.clip(img, 0, 1) * 255).astype('uint8'))
    print('Done', end='\n')

### Load mask arrays and export combined TIFFs of all structures (black background)

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'

combined_dir = os.path.join(HomePath, 'SAM Tiffs', f'Combined + {mask_name}')
os.makedirs(combined_dir, exist_ok=True)

# Mask canvas is sized from the HAADF tiles rather than assuming 1024 x 1024
tile_shape = haadf_stack.shape[:2]

for tile_idx in tile_indices:
    print(f'Converting tile {tile_idx}', end='\t')
    img = np.zeros((*tile_shape, 3))
    # Structures are painted in structure_indices order; later ones overwrite earlier ones on overlap
    for structure_idx in structure_indices:
        with np.load(os.path.join(mask_folder, structure_names[structure_idx], 'Tile_%02d.npz' % tile_idx)) as npz:
            mask = npz[mask_name]
        # union of all per-prompt masks of this structure
        bin_mask = np.zeros(tile_shape, dtype=bool)
        for m in mask:
            bin_mask[m] = True

        # close gaps in the nucleic-acid masks (dilation then erosion, optional)
        if structure_idx == 2:
            radius = 70 if tile_idx == 16 else 20  # per-tile override for tile 16
            bin_mask = binary_dilation(bin_mask, disk(radius, dtype=bool))
            bin_mask = binary_erosion(bin_mask, disk(radius, dtype=bool))

        img[bin_mask] = colors[structure_idx][:3]
    tf.imwrite(os.path.join(combined_dir, 'Tile_%02d.tiff' % tile_idx), (img * 255).astype('uint8'))
    print('Done', end='\n')

### Load mask arrays and export combined TIFFs of all structures (HAADF background)

In [None]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'
alpha = 0.5  # weight of the structure colour inside masked pixels

overlay_dir = os.path.join(HomePath, 'SAM Tiffs', f'Combined + {mask_name} (HAADF Overlay)')
os.makedirs(overlay_dir, exist_ok=True)

# Mask canvas is sized from the HAADF tiles rather than assuming 1024 x 1024
tile_shape = haadf_stack.shape[:2]

for tile_idx in tile_indices:
    print(f'Converting tile {tile_idx}', end='\t')
    # NOTE(review): /255 assumes the HAADF tiles are already in 0-255 -- confirm
    haadf = haadf_stack[:, :, tile_idx]
    haadf_img = np.dstack((haadf, haadf, haadf)) / 255
    # Work on a copy: each blend below must mix with the untouched HAADF. Without the
    # copy, img and haadf_img are the same array, so overlapping structures would
    # blend with colours already painted by earlier structures.
    img = haadf_img.copy()
    for structure_idx in structure_indices:
        with np.load(os.path.join(mask_folder, structure_names[structure_idx], 'Tile_%02d.npz' % tile_idx)) as npz:
            mask = npz[mask_name]
        # union of all per-prompt masks of this structure
        bin_mask = np.zeros(tile_shape, dtype=bool)
        for m in mask:
            bin_mask[m] = True

        # close gaps in the nucleic-acid masks (dilation then erosion, optional)
        if structure_idx == 2:
            radius = 70 if tile_idx == 16 else 20  # per-tile override for tile 16
            bin_mask = binary_dilation(bin_mask, disk(radius, dtype=bool))
            bin_mask = binary_erosion(bin_mask, disk(radius, dtype=bool))

        # blend only the masked pixels; later structures overwrite earlier ones on overlap
        color = np.asarray(colors[structure_idx][:3], dtype=float)
        img[bin_mask] = alpha * color + (1 - alpha) * haadf_img[bin_mask]
    # clip before the cast so out-of-range values saturate instead of wrapping around
    tf.imwrite(os.path.join(overlay_dir, 'Tile_%02d.tiff' % tile_idx),
               (np.clip(img, 0, 1) * 255).astype('uint8'))
    print('Done', end='\n')

### Load and summarize SAM mask scores (for revision)

In [28]:
mask_folder = os.path.join(HomePath, 'SAM Masks')
mask_name = 'masks_1'
# masks_k was saved from SAM candidate k-1, whose score is column k-1 of the saved 'scores'
score_col = int(mask_name.rsplit('_', 1)[1]) - 1

# Rows follow tile_indices and columns follow structure_indices (by position, so any
# subset of structures/tiles works). NaN marks tiles where no prompt was found.
score_array = np.full((len(tile_indices), len(structure_indices)), np.nan)
for col, structure_idx in enumerate(structure_indices):
    for row, tile_idx in enumerate(tile_indices):
        with np.load(os.path.join(mask_folder, structure_names[structure_idx], 'Tile_%02d.npz' % tile_idx)) as npz:
            scores = npz['scores']
        # scores is (n_prompts, 3); it is an empty 1-D array when the tile had no peaks
        if scores.ndim == 2 and scores.shape[0] > 0:
            score_array[row, col] = scores[:, score_col].mean()
        

In [43]:
# One row per tile (zero-padded tile number), one column per segmented structure
tile_labels = ['%02d' % i for i in tile_indices]
structure_labels = [structure_names[i] for i in structure_indices]
scores_df = pd.DataFrame(score_array, index=tile_labels, columns=structure_labels)
display(scores_df)

# Per-structure summary statistics; NaN tiles are left out of each column's count
df2 = scores_df.describe()
display(df2)

Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
0,0.938628,0.930942,,,0.945925
1,0.918185,0.77121,0.997868,,0.950979
2,0.928562,0.93373,0.916907,,0.94973
3,0.940646,0.845485,,,0.947697
4,0.93172,0.891188,0.901846,,0.957204
5,0.935663,0.857975,,,0.951016
6,0.934795,0.757035,0.92405,,0.945541
7,0.943373,0.845824,0.923518,0.972459,0.950741
8,0.95301,0.828955,0.936359,0.975944,0.952667
9,0.917916,0.899371,0.929943,,0.961632


Unnamed: 0,Insulin,Nucleic acids,Exocrine granules,PP or Ghrelin,Glucagon
count,22.0,23.0,18.0,3.0,15.0
mean,0.913659,0.821629,0.944013,0.969678,0.939218
std,0.063143,0.118243,0.087343,0.008027,0.038173
min,0.711047,0.357698,0.618057,0.96063,0.807111
25%,0.917983,0.799075,0.925523,0.966544,0.945733
50%,0.934303,0.838986,0.978079,0.972459,0.950741
75%,0.940382,0.888994,0.987754,0.974202,0.954675
max,0.95301,0.947056,1.000236,0.975944,0.961632


In [46]:
# Average each describe() statistic across the structure columns (unweighted; the
# 'count' row is simply the mean number of scored tiles per structure)
row_averages = df2.mean(axis=1)
print(row_averages)

count    16.200000
mean      0.917639
std       0.062986
min       0.690909
25%       0.910972
50%       0.934914
75%       0.949201
max       0.967575
dtype: float64
