In [None]:
import os
########## Configuration
##### define neural net weights
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
autoencoder_path = "/mnt/local/data2/Bootsma/2D_CTC/src/analysis/publication_code/weights/02_SEE_TC_encoder.hdf5"
classifier_path = "/mnt/dho-nas06/zhaolab/long_term_storage/pipeline/digital_pathology/cellClass_2Channel_GBoost.hdf5"
##### define samples to process
# Directory holding "<sample>.cells.tiff" + "<sample>.region_features.tsv" pairs;
# "<sample>.cell_features.csv" outputs are written back into it.
img_dir_qa = "/mnt/local/data2/Bootsma/2D_CTC/src/analysis/publication_code/test_data/"

##### define GPU index to use
# Restrict TensorFlow to the selected physical GPU. This must run before TensorFlow is
# imported (next cell); inside the process the selected GPU is then numbered /gpu:0.
gpu_id = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

# Channels for exclusion and pCK: each value is matched as a prefix of the tiff's
# channel names (str.startswith), so only the leading integer is needed.
pCK_name = "7"
exclusion_name = "6"

# NOTE: the suffix settings below are currently unused; the suffixes are hard-coded
# in the processing cells.
# ##### define suffix for input tiff
# input_tiff_suffix = ".bias_corrected.tiff"

# ##### define suffix for output table 
# # must match the suffix of your physical feature table
# # results will be appended into said table
# feature_suffix = ".region_features.tsv"


In [None]:
##### Dependencies
import sys
sys.path.append('../src/') 
import SEE_TC as ctc
# import preprocessing.CTC_2d_preprocessing as ctc_pp
import umap
import tifffile
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model

def count_labels(y_pred):
    """Print and return how often each distinct value occurs in ``y_pred``.

    Parameters
    ----------
    y_pred : array-like
        Labels of any shape (``np.unique`` flattens it).

    Returns
    -------
    dict
        ``{label: count}`` with plain Python scalars, keys in ascending order.
    """
    unique_vals, counts = np.unique(y_pred, return_counts=True)
    # .tolist() turns NumPy scalars into Python scalars so the printout reads
    # "{10: 3, 11: 5}" instead of "{np.int64(10): np.int64(3), ...}" under NumPy >= 2.
    freq_dict = dict(zip(unique_vals.tolist(), counts.tolist()))
    print(freq_dict)
    return freq_dict

##### NORMALIZE DATA - method A; min-max scale across channels to maintain relative intensity
def normalize_CNN_input_ch6(z_stack_in, include_bf=False):
    """Min-max normalise the fluorescence channels of a stack of cell crops.

    Channels 0-4 are scaled per cell to [0, 1] using one min/max taken jointly over
    y, x and those five channels, so relative intensity between channels is kept.
    Channel 6 (the binary "UNET" channel, per the original naming) is appended unscaled.

    Parameters
    ----------
    z_stack_in : np.ndarray
        Indexed as (cell, y, x, channel) with at least 7 channels. Not modified.
    include_bf : bool, default False
        If True, also append channel 5 (brightfield, which is not fluorescence) unscaled
        between the fluorescence channels and the UNET channel. The previous code built
        this BF-inclusive stack but then discarded it, always returning
        5 fluorescence + UNET channels; the default preserves that 6-channel output.
        NOTE(review): confirm which layout the downstream model expects.

    Returns
    -------
    np.ndarray
        (cell, y, x, 6) by default, or (cell, y, x, 7) when ``include_bf`` is True.
        Cells whose fluorescence channels are constant become 0 instead of NaN.
    """
    z_stack_c = z_stack_in[:, :, :, 0:5]  # fluorescence channels (view; subtraction below copies)
    z_stack_c = z_stack_c - z_stack_c.min(axis=(1, 2, 3), keepdims=True)
    c_max = z_stack_c.max(axis=(1, 2, 3), keepdims=True)
    # Guard against 0/0 -> NaN for blank cells; their shifted values are all 0 anyway.
    z_stack_c = z_stack_c / np.where(c_max > 0, c_max, 1)

    parts = [z_stack_c]
    if include_bf:
        parts.append(z_stack_in[:, :, :, 5:6])  # BF handled apart from fluorescence: raw
    parts.append(z_stack_in[:, :, :, 6:7])  # binary UNET channel, appended as-is
    return np.concatenate(parts, axis=-1)

def visualize_embeddings(embeddings, labels, method='umap',
                         manual_colors=('#FF0000', '#000000', '#990066', '#999999', '#006600', '#6600CC'),
                         random_state=None):
    """Project embeddings to 2-D with UMAP or t-SNE and scatter-plot them by label.

    Parameters
    ----------
    embeddings : array-like, shape (n_samples, n_features)
    labels : array-like of length n_samples
        One label per row; each distinct label gets its own color and legend entry.
    method : {'umap', 'tsne'}
    manual_colors : sequence of str
        Colors assigned, in order, to the sorted distinct labels. The default is a tuple
        (immutable) to avoid the shared-mutable-default pitfall; lists work too.
    random_state : int, optional
        Reducer seed. Defaults to 710 for UMAP and 42 for t-SNE, so existing calls
        reproduce the same layout.

    Returns
    -------
    np.ndarray, shape (n_samples, 2)
        2-D coordinates in the same row order as ``embeddings`` (to inspect cells).

    Raises
    ------
    ValueError
        If there are more distinct labels than colors, if ``labels`` and
        ``embeddings`` differ in length, or if ``method`` is not recognised.
    """
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)

    if len(unique_labels) > len(manual_colors):
        raise ValueError(f"Not enough manual colors defined for {len(unique_labels)} labels.")
    if len(labels) != len(embeddings):
        raise ValueError(f"Got {len(labels)} labels for {len(embeddings)} embeddings.")

    # Use only as many colors as there are labels.
    label_to_color = {label: manual_colors[i] for i, label in enumerate(unique_labels)}

    # Dimensionality reduction
    if method == 'umap':
        reducer = umap.UMAP(n_neighbors=15, min_dist=0.1,
                            random_state=710 if random_state is None else random_state)
    elif method == 'tsne':
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=2, perplexity=30, learning_rate='auto', init='pca',
                       random_state=42 if random_state is None else random_state)
    else:
        raise ValueError("method must be 'umap' or 'tsne'")

    reduced = reducer.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(7, 7))
    for label in unique_labels:
        idx = labels == label
        ax.scatter(reduced[idx, 0], reduced[idx, 1], color=label_to_color[label],
                   label=f'Class {label}', s=10, alpha=0.8)
    ax.set(title=f"Embedding Visualization ({method.upper()})",
           xlabel="Component 1", ylabel="Component 2")
    ax.legend(markerscale=2, loc='best', fontsize=8)
    fig.tight_layout()
    plt.show()

    return reduced  # return the coordinates so we can inspect cells

from tensorflow.keras.models import load_model
autoencoder = load_model(autoencoder_path)

# Index of the autoencoder layer whose output is used as the cell embedding
# (presumably the encoder bottleneck — confirm with autoencoder.summary()).
ENCODER_OUTPUT_LAYER = 6
encoder = Model(inputs=autoencoder.input,
                outputs=autoencoder.layers[ENCODER_OUTPUT_LAYER].output)

import joblib
# SECURITY: joblib.load unpickles the file and can execute arbitrary code;
# only load classifier files from a trusted location.
clf = joblib.load(classifier_path)



In [None]:
# Define file names: one entry per "<sample>.cells.tiff" in img_dir_qa, stored without the
# suffix so that the matching "<sample>.region_features.tsv" can be located later.
cells_suffix = ".cells.tiff"
file_names_qa = ctc.parse_nd2_paths(img_dir_qa, cells_suffix, recursive=False)
# Strip only the trailing suffix (str.replace would also remove an earlier occurrence).
file_names_qa = [s[:-len(cells_suffix)] if s.endswith(cells_suffix) else s for s in file_names_qa]
if not file_names_qa:
    raise FileNotFoundError(f"No '*{cells_suffix}' files found in {img_dir_qa}")
print(f"{len(file_names_qa)} sample(s) found, e.g. {file_names_qa[0]}")

In [None]:
res_dict = {  # initialize data object: one list entry per processed sample
    "sample_ID": [],        # sample name repeated once per cell
    "cells_raw": [],        # raw crops, (cell, y, x, channel) after the transpose below
    "embeddings": [],       # flattened encoder output, one row per cell
    "labels_pred": [],      # classifier labels, shape (n_cells, 1)
    "PofZ": [],             # classifier class probabilities (predict_proba)
    "latent_features": []   # same arrays as "embeddings"
}


def _find_channel(c_names, prefix):
    """Return the index of the first channel name starting with ``prefix``, or None."""
    return next((i for i, s in enumerate(c_names) if s.startswith(prefix)), None)


for file_qa in file_names_qa:
    print("Reading input...")
    print(file_qa)

    # Keep only regions classified as cells. .copy() makes an independent frame so
    # adding 'idx_0' does not trigger SettingWithCopyWarning.
    features_qa = pd.read_csv(img_dir_qa + file_qa + ".region_features.tsv", sep="\t")
    features_qa = features_qa[features_qa['class_cleanSeg'] == "Cell"].copy()
    features_qa['idx_0'] = range(len(features_qa))
    features_cNames = features_qa.columns
    target_cells_qa = features_qa
    target_idx_qa = target_cells_qa['idx_0']

    if target_cells_qa.shape[0] == 0:
        continue

    ##### read image, normalize, predict
    img_qa_raw = tifffile.imread(img_dir_qa + file_qa + ".cells.tiff")  # stack of cells
    if img_qa_raw.ndim == 3:  # handle one-cell inputs (no leading cell axis)
        img_qa_raw = img_qa_raw[np.newaxis, :]
    img_qa_raw = img_qa_raw.transpose(0, 2, 3, 1)  # channels last
    print("########")
    print(len(features_qa))
    print(img_qa_raw.shape)

    # NOTE(review): 'idx_0' is 0..n_cells-1, so this keeps the first n_cells crops; it
    # assumes the tiff holds only 'Cell' regions in table order — confirm.
    img_qa_raw = img_qa_raw[target_idx_qa.to_numpy()]

    # Channel names locate the predictive channels; protein data is normalized apart
    # from the BF and UNET channels by ctc.normalize_CNN_input_chN.
    c_names = ctc.get_channel_names_tiff(img_dir_qa + file_qa + ".cells.tiff")
    print(c_names)

    c_exclusion = _find_channel(c_names, exclusion_name)
    c_pCK = _find_channel(c_names, pCK_name)
    if c_exclusion is None or c_pCK is None:
        # A missing name must not fall through to index -1, which would silently
        # select the last channel; skip the sample instead.
        print(f"Skipping {file_qa}: predictive channel(s) not found in {c_names}")
        continue
    channel_indices_pred = [c_exclusion, c_pCK]

    # If a predictive channel was not recorded, the sample cannot be used and is skipped;
    # pre-processing stores such a channel as an array of 0's. Take the max per channel
    # over all cells and pixels so one empty channel is detected even when the other has
    # signal, and a single blank crop does not discard the whole sample.
    if img_qa_raw[:, :, :, channel_indices_pred].max(axis=(0, 1, 2)).min() == 0:
        print(f"Skipping {file_qa}: a predictive channel is empty")
        continue

    img_qa_norm = ctc.normalize_CNN_input_chN(img_qa_raw, c_names)
    img_qa_CNN_in = img_qa_norm[:, :, :, channel_indices_pred]

    print("embedding...")
    # CUDA_VISIBLE_DEVICES (config cell) exposes a single GPU, which TensorFlow numbers
    # /gpu:0 regardless of its physical index.
    with tf.device('/gpu:0'):
        embeddings_qa = encoder.predict(img_qa_CNN_in, verbose=0)

    embeddings_qa = embeddings_qa.reshape(embeddings_qa.shape[0], -1)  # flatten per cell

    y_pred_qa = clf.predict(embeddings_qa)  # predict label (0: CTC, 1: non-CTC)
    y_pred_qa = np.expand_dims(y_pred_qa, axis=-1)
    y_pred_prob_qa = clf.predict_proba(embeddings_qa)
    y_pred_res_qa = np.concatenate([y_pred_qa, y_pred_prob_qa], axis=1)

    # encapsulate results
    n_cells = img_qa_raw.shape[0]
    print("Input sample: " + file_qa)
    print("N input cells: " + str(n_cells))

    res_dict["sample_ID"].append([file_qa] * n_cells)
    res_dict["cells_raw"].append(img_qa_raw)
    res_dict["embeddings"].append(embeddings_qa)
    res_dict["labels_pred"].append(y_pred_qa)
    res_dict["PofZ"].append(y_pred_prob_qa)
    res_dict["latent_features"].append(embeddings_qa)

    print("#####\n")


In [None]:
##### threshold for review purposes
th_PofZ_0 = 0.75   # a cell stays labelled 10 (CTC) only if P(class 0) >= this value
VIS_MAX_11 = 2000  # max. number of class-11 cells drawn in the UMAP
VIS_MAX_10 = 500   # max. number of class-10 cells drawn in the UMAP
VIS_SEED = 710     # seed for the random subsampling, so the plot is reproducible

# Pool the per-sample results (np.concatenate already returns new arrays).
all_cells_raw_qa = np.concatenate(res_dict["cells_raw"])
all_y_pred_qa = np.concatenate(res_dict["labels_pred"])
all_PofZ_qa = np.concatenate(res_dict["PofZ"])
all_embeddings_qa = np.concatenate(res_dict["embeddings"])

# Re-code classifier labels: 0 (CTC) -> 10, 1 (non-CTC) -> 11 (classifier default, PofZ = 0.5).
all_y_pred_qa[all_y_pred_qa == 0] = 10
all_y_pred_qa[all_y_pred_qa == 1] = 11
count_labels(all_y_pred_qa)

# Apply the user-defined PofZ cutoff: low-confidence cells become 11.
low_prob_0 = np.where(all_PofZ_qa[:, 0] < th_PofZ_0)
all_y_pred_qa[low_prob_0] = 11
count_labels(all_y_pred_qa)

# NOTE(review): y_pred_qa is the LAST sample's prediction left over from the loop, not
# the pooled predictions — TODO confirm which was intended (format for visualization).
y_pred_qa_vis = y_pred_qa.copy().squeeze()
# ravel (not squeeze) so a run with a single cell still yields a 1-D array.
all_y_pred_qa = all_y_pred_qa.ravel()


def _subsample_idx(idx, n_max, rng):
    """Return up to ``n_max`` entries of ``idx`` drawn at random without replacement."""
    return rng.choice(idx, size=min(n_max, len(idx)), replace=False)


##### visualize to assess how distinct CTCs are at first pass
rng_vis = np.random.default_rng(VIS_SEED)

# flatnonzero always returns a 1-D index array, even for 0 or 1 matches.
idx_pred_11 = np.flatnonzero(all_y_pred_qa == 11)
vis_11_embeddings_qa = all_embeddings_qa[_subsample_idx(idx_pred_11, VIS_MAX_11, rng_vis)]
vis_11_labs = [11] * vis_11_embeddings_qa.shape[0]

idx_pred_10 = np.flatnonzero(all_y_pred_qa == 10)
vis_10_embeddings_qa = all_embeddings_qa[_subsample_idx(idx_pred_10, VIS_MAX_10, rng_vis)]
vis_10_labs = [10] * vis_10_embeddings_qa.shape[0]

embeddings_plot_qa = np.concatenate([vis_10_embeddings_qa, vis_11_embeddings_qa], axis=0)
labels_plot_qa = np.concatenate([vis_10_labs, vis_11_labs])
umap_coords = visualize_embeddings(embeddings_plot_qa, labels_plot_qa, method='umap',
                                   manual_colors=['#FF0000', '#000000', '#990066', '#CC0033', '#999999'])

In [None]:
########## RUN INFERENCE ON INDEPENDENT SAMPLES
# PRINT OUT FEATURE TABLES WITH THE PREDICTIONS INCLUDED FOR ANALYSIS WITH CLINICAL DATA
all_sample_ID_qa = np.concatenate(res_dict["sample_ID"])
# sorted() gives a deterministic order (set iteration order of str varies between runs).
all_sample_ID_qa_unique = sorted(set(all_sample_ID_qa))

for sample_ID_qa in all_sample_ID_qa_unique:
    # Positions of this sample's cells in the pooled arrays. flatnonzero is always 1-D,
    # so a sample with a single cell is exported like any other.
    review_idx_qa = np.flatnonzero(all_sample_ID_qa == sample_ID_qa)
    cells_qa = all_cells_raw_qa[review_idx_qa]
    print(sample_ID_qa)

    # Same 'Cell' filter as the inference loop, so rows line up with the pooled arrays.
    features_qa = pd.read_csv(img_dir_qa + sample_ID_qa + ".region_features.tsv", sep="\t")
    features_qa = features_qa[features_qa['class_cleanSeg'] == "Cell"].copy()
    features_qa['idx_0'] = range(len(features_qa))

    if len(review_idx_qa) != len(features_qa):
        raise ValueError(
            f"{sample_ID_qa}: {len(review_idx_qa)} predicted cells but "
            f"{len(features_qa)} 'Cell' rows in the feature table"
        )

    print("N cells predicted: " + str(len(review_idx_qa)) + " (at PofZ = " + str(th_PofZ_0) + ")")

    features_cNames = features_qa.columns
    target_cells_qa = features_qa

    target_cells_qa['y_pred'] = all_y_pred_qa[review_idx_qa].ravel()
    target_cells_qa['PofZ_0'] = all_PofZ_qa[review_idx_qa, 0].ravel()
    target_cells_qa['PofZ_1'] = all_PofZ_qa[review_idx_qa, 1].ravel()
    # Store each embedding as a plain list: str() of a NumPy array wraps lines and
    # abbreviates arrays longer than 1000 values with "...", which loses data in the CSV.
    target_cells_qa['latent_features'] = [row.tolist() for row in all_embeddings_qa[review_idx_qa]]

    target_cells_qa['idx_pred'] = np.arange(len(target_cells_qa))
    print("N CTCs predicted: " + str(int((target_cells_qa['y_pred'] == 10).sum())))
    cell_features_qa = target_cells_qa.copy()

    df_p = cell_features_qa  # alias (not used in this cell)
    cell_features_qa['y_pred'] = cell_features_qa['y_pred'].fillna(9)

    # Write all examined cells so we can assess thresholding. y_pred is filled from the
    # pooled predictions for every row, so no row is NaN and none is dropped here.
    CTC_features_qa = cell_features_qa[cell_features_qa['y_pred'] != 9]
    CTC_features_qa.to_csv(img_dir_qa + sample_ID_qa + ".cell_features.csv")
    print("---\n")