# LDA post-processing

In [None]:
import os
import numpy as np
import pandas as pd
import skimage.io as io
from skimage.segmentation import find_boundaries
import spatial_lda.model
import spatial_lda.visualization
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
from matplotlib import colors

import ark.utils.spatial_lda_utils as spu

## Make LDA cell table

In [None]:
# Root data folder; every path below is built from it.
base_dir = "../data"

# Cell tables: the segmentation cell table that is read, and the
# spatial-LDA cell table that this notebook writes.
seg_cell_tab_path = os.path.join(base_dir, "tables", "cell_table_size_normalized.csv")
save_path = os.path.join(base_dir, "tables", "cell_table_spatial_lda.csv")

# Folder (relative to base_dir) and file names of the saved spatial-LDA
# objects that are loaded further down.
processed_dir = "spatial_analysis/spatial_lda/preprocessed"
processed_path = os.path.join(base_dir, processed_dir)
complete_name = "complete_spatial_lda_model_num_topics=6_diff=250"
cell_table_name = "formatted_cell_table"

In [None]:
# Load the fitted model and the formatted cell table. Both are read from
# the same folder in pkl format, so the shared arguments are kept together.
_pkl_source = {"dir": processed_path, "format": "pkl"}
complete_model = spu.read_spatial_lda_file(file_name=complete_name, **_pkl_source)
cell_table = spu.read_spatial_lda_file(file_name=cell_table_name, **_pkl_source)

In [None]:
# Sanity check: plot the topic assignment of a few FOVs to make sure the
# model loaded properly.
fov_set = ['sample1_fov1','sample10_fov10']
model_topic_weights = complete_model.topic_weights
_plot_fn = spu.make_plot_fn(
    plot="topic_assignment",
    topic_weights=model_topic_weights,
    cell_table=cell_table,
)
spatial_lda.visualization.plot_samples_in_a_row(
    model_topic_weights,
    _plot_fn,
    cell_table,
    fov_set,
)

In [None]:
# Make one large cell table from the spatial LDA object.
# Only entries whose key contains 'sample' are used; any other entries of
# `cell_table` are skipped.
dfs_list = []

for key, df in cell_table.items():
    if 'sample' not in key:
        continue
    # Build the per-FOV frame with non-mutating calls only. Assigning a
    # column directly on `pd.DataFrame(df)` can, depending on the pandas
    # version, also add that column to the frame stored in `cell_table`,
    # because the wrapper may share its data with the original.
    fov_df = (
        pd.DataFrame(df)
        .reset_index(names="index")  # keep the original row index as a column
        .assign(fov=key)
    )
    dfs_list.append(fov_df[['fov','index','x','y','cluster']])

# Stack all FOVs into one table with a fresh 0..n-1 index.
combined_cell_tab = pd.concat(dfs_list, ignore_index=True)

In [None]:
# Get LDA weights.
# Work on a copy: `complete_model.topic_weights` is the model's own table,
# and adding the 'fov' / 'index' columns to it in place would change what
# every later use of the model (e.g. re-running the plotting cell) sees.
lda_tab = complete_model.topic_weights.copy()

# Each index entry is indexed as (fov, cell index); turn both parts into
# regular columns so the table can be merged on them.
lda_tab['fov'] = [key[0] for key in lda_tab.index]
lda_tab['index'] = [key[1] for key in lda_tab.index]
lda_tab = lda_tab.reset_index(drop=True)

# pd.merge defaults to an inner join: cells without topic weights are dropped.
combined_cell_tab = pd.merge(combined_cell_tab, lda_tab, on=['fov','index'])

# Rename to the column names used by the segmentation cell table.
combined_cell_tab = combined_cell_tab.rename(columns={'x':'centroid-0',
                                                      'y':'centroid-1',
                                                      'cluster':'cell_meta_cluster'})

In [None]:
# Attach the segmentation label of every cell.
# Cells are matched on FOV, centroid coordinates and cluster; pd.merge's
# default inner join keeps only rows that match exactly in both tables.
seg_cell_table = pd.read_csv(seg_cell_tab_path)
match_cols = ['fov','centroid-0','centroid-1','cell_meta_cluster']
seg_subset = seg_cell_table[['fov','label','centroid-0','centroid-1','cell_meta_cluster']]
combined_cell_tab = seg_subset.merge(combined_cell_tab, on=match_cols)
# The per-FOV row index is no longer needed once labels are attached.
combined_cell_tab = combined_cell_tab.drop(columns='index')

# For every cell, record the name of the topic column with the largest weight.
topic_cols = [col for col in combined_cell_tab.columns if 'Topic' in col]
topic_weights_only = combined_cell_tab[topic_cols]
combined_cell_tab['lda_me'] = topic_weights_only.idxmax(axis=1)

# Write the final table without the row index.
combined_cell_tab.to_csv(save_path, index=False)

## Make LDA masks

In [None]:
# Inputs: the spatial-LDA cell table and the folder of segmentation masks.
lda_tab_path = os.path.join("../data", "tables", "cell_table_spatial_lda.csv")
seg_dir = "../data/segmentation_masks"

combined_cell_tab = pd.read_csv(lda_tab_path)
all_fovs = np.unique(combined_cell_tab['fov'])
all_topics = [x for x in combined_cell_tab.columns.values if 'Topic' in x]

# 1-based number of the highest-weight topic: the first run of digits in
# 'lda_me', plus one, so that 0 stays free for "no topic" in the masks.
# The pattern is a raw string (a plain '\d' is an invalid escape sequence)
# and expand=False makes str.extract return a Series, not a 1-column frame.
combined_cell_tab['lda_me_int'] = (
    combined_cell_tab['lda_me'].str.extract(r'(\d+)', expand=False).astype(int) + 1
)

# Weight of that topic for each cell, looked up in column 'Topic-<n>'.
# Done as one array lookup instead of one .loc call per row.
me_columns = 'Topic-' + (combined_cell_tab['lda_me_int'] - 1).astype(str)
me_positions = pd.Index(all_topics).get_indexer(me_columns)
if (me_positions < 0).any():
    # get_indexer marks unknown names with -1; fail like a column lookup would.
    missing_cols = sorted(set(me_columns[me_positions < 0]))
    raise KeyError(f"topic columns not found in {lda_tab_path}: {missing_cols}")
topic_matrix = combined_cell_tab[all_topics].to_numpy()
combined_cell_tab['lda_me_score'] = topic_matrix[
    np.arange(len(combined_cell_tab)), me_positions
]

# Output folders, created if they do not exist yet.
output_dir_individual = "../data/colored_lda_masks_individual"
output_dir_composite = "../data/colored_lda_masks_scores"
output_dir_composite_int = "../data/colored_lda_masks"
for out_dir in (output_dir_individual, output_dir_composite, output_dir_composite_int):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Colors
# Bin edges at -0.5, 0.5, ..., so that every integer 0..maxnum gets its own
# bin when topic weights are scaled to 0-100.
maxnum = 100
bounds_individual = list(np.arange(maxnum + 2) - 0.5)

# Same idea for the integer topic map: one bin per value 0..num_topics.
num_topics = len(all_topics)
bounds_composite_int = list(np.arange(num_topics + 2) - 0.5)

white = np.array([1, 1, 1, 1])  # RGBA for white
# One RGBA color per topic, in topic order.
topic_colors = [[0.243, 0.349, 0.224, 1],
                [0.639, 0.784, 0.745, 1],
                [0.596, 0.420, 0.502, 1],
                [0.678, 0.804, 0.906, 1],
                [0.482, 0.537, 0.643, 1],
                [0.6, 0.573, 0.475, 1]]

if num_topics > len(topic_colors):
    raise ValueError(
        f"{num_topics} topics found but only {len(topic_colors)} topic colors "
        "are defined; add more entries to topic_colors"
    )

# Color 0 is used for "no topic" (background and cell borders); note that
# its alpha is 0, i.e. it is fully transparent.
# Use exactly one color per bin: when a colormap has more colors than the
# norm has bins, BoundaryNorm spreads the bins over the whole colormap and
# topics would no longer map to their own entry in topic_colors.
topic_colors_with_black = [[0, 0, 0, 0]] + topic_colors[:num_topics]
colmap_composite_int = colors.ListedColormap(topic_colors_with_black)
norm_composite_int = colors.BoundaryNorm(bounds_composite_int, colmap_composite_int.N)

In [None]:
def _label_lookup(labels, values, size, dtype=float):
    """Return an array ``table`` of length ``size`` with ``table[label] == value``.

    Labels that are not listed, and the background label 0, map to 0.
    Indexing the table with a label image maps every pixel in one array
    operation instead of one Python call per pixel.
    """
    table = np.zeros(size, dtype=dtype)
    table[labels] = values
    table[0] = 0
    return table


black = [0, 0, 0, 1]  # opaque black, used for cell borders and empty slide

if len(all_topics) > len(topic_colors):
    raise ValueError(
        f"{len(all_topics)} topics found but only {len(topic_colors)} topic "
        "colors are defined"
    )

# The white-to-topic-color ramp of each topic is the same for every FOV,
# so build the colormaps once instead of once per FOV and topic.
topic_colormaps = []
for dark_col in topic_colors[:len(all_topics)]:
    dark_col = np.asarray(dark_col)
    ramp = [white + (dark_col - white) * (step / (maxnum+1)) for step in range(maxnum+2)]
    colmap = colors.ListedColormap(ramp)
    topic_colormaps.append((colmap, colors.BoundaryNorm(bounds_individual, colmap.N)))

for fov in all_fovs:

    print(fov)
    seg_array = io.imread(os.path.join(seg_dir, fov+".tiff"))
    predicted_contour_mask = find_boundaries(seg_array, connectivity=1, mode='inner').astype(np.uint8)
    is_border = predicted_contour_mask > 0

    one_fov_cell_table = combined_cell_tab.loc[combined_cell_tab['fov'] == fov]

    # Integer views of the labels so they can be used as array indices.
    cell_labels = one_fov_cell_table['label'].to_numpy().astype(np.int64)
    label_img = seg_array.astype(np.int64)
    lut_size = int(max(label_img.max(), cell_labels.max())) + 1

    # Pixels without data: empty slide (label 0) and cells that are in the
    # mask but have no row in the cell table. Both are drawn as empty slide.
    has_row = _label_lookup(cell_labels, True, lut_size, dtype=bool)
    no_data = ~has_row[label_img]

    ## One image for each topic
    for topic, (colmap, norm) in zip(all_topics, topic_colormaps):
        topic_lut = _label_lookup(cell_labels, one_fov_cell_table[topic].to_numpy(), lut_size)
        # Scale the weights to 0-100 to match bounds_individual.
        lda_array = topic_lut[label_img] * 100

        image = colmap(norm(lda_array))
        # Add cell borders
        image[is_border] = black
        # Change empty slide to black
        image[no_data] = black

        plt.imsave(os.path.join(output_dir_individual, fov+"_"+topic+".tiff"), image)


    ## One image with a gradient for the score of the assigned topic
    me_lut = _label_lookup(
        cell_labels, one_fov_cell_table['lda_me_int'].to_numpy(), lut_size, dtype=np.int64)
    score_lut = _label_lookup(
        cell_labels, one_fov_cell_table['lda_me_score'].to_numpy(), lut_size)
    me_array = me_lut[label_img]
    score_array = score_lut[label_img]

    image = np.zeros((seg_array.shape[0], seg_array.shape[1], 4))
    for topic_number, color in enumerate(topic_colors, start=1):
        mask = (me_array == topic_number)
        # Blend from white to the topic color according to the score;
        # the column vector broadcasts over the three RGB channels.
        blend = score_array[mask][:, np.newaxis]
        image[mask, :3] = white[:3] + (np.asarray(color[:3]) - white[:3]) * blend
        # Set alpha channel
        image[mask, 3] = 1.0

    # Add cell borders
    image[is_border] = black
    # Change empty slide to black
    image[no_data] = black
    # Save
    plt.imsave(os.path.join(output_dir_composite, fov+".tiff"), image)


    ## One image with the assigned topic of each cell
    # Borders get value 0 and are therefore drawn in color 0 of the colormap.
    me_array[is_border] = 0
    image = colmap_composite_int(norm_composite_int(me_array))
    plt.imsave(os.path.join(output_dir_composite_int, fov+".tiff"), image)