## DUSP1 Confirmation Notebook
The purpose of this notebook is to:
1. Confirm successful segmentation.
2. Confirm successful BigFISH spot and cluster detection.
3. Refine spots and clusters with additional SNR-based filtering for gating and final dataframe preparation (still to decide whether this is best done before or after concatenating all datasets):  
    a. Find SNR threshold.  
    b. Filter `df_spots`.  
    c. (Optional) Check whether any removed spot belonged to a cluster (very unlikely, given how clusters are defined).  
    d. Create final dataframes (`df_spots`, `df_clusters`, `df_cellspots`, `df_cellprops`).  
    e. Save the dataframes.

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import dask.array as da
import os
import sys
import logging


# Quiet chatty third-party loggers so the notebook output stays readable.
logging.getLogger('matplotlib.font_manager').disabled = True
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)

matplotlib_logger = logging.getLogger('matplotlib')
matplotlib_logger.setLevel(logging.WARNING)

# Make the directory two levels above the working directory importable, so that
# `from src.Analysis import ...` resolves. The membership guard keeps re-runs of
# this cell from piling duplicate entries onto sys.path.
src_path = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
print(src_path)
if src_path not in sys.path:
    sys.path.append(src_path)

from src.Analysis import AnalysisManager, Analysis, SpotDetection_SNRConfirmation, Spot_Cluster_Analysis_WeightedSNR, GR_Confirmation

Load the data from the specified location.

In [2]:
# Where the AnalysisManager looks for its logs. An alternative location used
# previously: r'/Volumes/share/Users/Jack/All_Analysis'
log_location = r'/Volumes/share/Users/Eric/AngelFISH_data'
loc = None
am = AnalysisManager(
    location=loc,
    log_location=log_location,
    mac=True,
)

In [None]:
# List the names of all analyses that have been done (the call's output is the cell result).
am.list_analysis_names()

In [4]:
# Select the analysis to work with; selection can filter on name and dates.
am.select_analysis('DUSP1_D_Jan2125')


In [None]:
# Show the selected analysis names; the bare trailing expression below is what
# the notebook displays as this cell's output.
print(am.analysis_names)
am.location

In [None]:
# List datasets (presumably those of the analysis selected above — confirm in src.Analysis).
am.list_datasets()

Run the analysis and confirmation steps.

In [7]:
# select DUSP1 spot detection
# NOTE(review): SpotDetection_Confirmation is not among the names imported from
# src.Analysis in this notebook (SpotDetection_SNRConfirmation is), so uncommenting
# the line below as-is would raise NameError.
# SD = SpotDetection_Confirmation(am)
SD = Spot_Cluster_Analysis_WeightedSNR(am)

In [8]:
# This loads the data into memory; the cells below read SD.spots, SD.clusters,
# SD.cellspots and SD.cellprops.
SD.get_data()

In [None]:
# Run this multiple times to see a new randomly selected FOV/cell.
SD.display(newFOV=True, newCell=True)  # other options used before: num_fovs_to_display=2, num_cells_to_display=2, num_spots_to_display=4

In [None]:
# Wrap each SD result table in a DataFrame (constructed in the same order as listed).
df_spots, df_clusters, df_cellspots, df_cellprops = (
    pd.DataFrame(SD.spots),
    pd.DataFrame(SD.clusters),
    pd.DataFrame(SD.cellspots),
    pd.DataFrame(SD.cellprops),
)

# Print the column names of each table under its own header.
for _header, _frame in (
    ("df_spots columns:", df_spots),
    ("\ndf_clusters columns:", df_clusters),
    ("\ndf_cellspots columns:", df_cellspots),
    ("\ndf_cellprops columns:", df_cellprops),
):
    print(_header)
    print(", ".join(_frame.columns))

In [None]:
# Show the distinct timepoints and h5 file indices present in the spot table.
for _column in ('time', 'h5_idx'):
    print(df_spots[_column].unique())

In [None]:
# Per-timepoint SNR summary statistics.
stats = df_spots.groupby('time')['snr'].agg(['mean', 'median', 'std'])

# Dynamic SNR window per timepoint: mean ± 2*std.
lower_bounds = stats['mean'] - 2 * stats['std']
upper_bounds = stats['mean'] + 2 * stats['std']
# Same {time: (lower, upper)} mapping as before, kept for later inspection.
thresholds = dict(zip(stats.index, zip(lower_bounds, upper_bounds)))

# Apply the thresholds vectorized: look up each spot's window by its timepoint
# instead of a per-row DataFrame.apply. Comparisons with a NaN bound (e.g. a
# timepoint with a single spot, whose std is NaN) or with a timepoint that has no
# window (NaN time, dropped by groupby) evaluate to False rather than raising.
spot_lower = df_spots['time'].map(lower_bounds)
spot_upper = df_spots['time'].map(upper_bounds)
df_spots['threshold_pass'] = df_spots['snr'].ge(spot_lower) & df_spots['snr'].le(spot_upper)

# Filtered DataFrame: spots whose SNR lies inside their timepoint's window.
df_spots_filtered = df_spots[df_spots['threshold_pass']]

# Categorize noise levels based on the paper's description.
def categorize_noise(snr):
    """Map a spot's SNR to a coarse noise-level label.

    Bins (contiguous, no gaps):
        snr < 2          -> 'very_high_noise'
        2 <= snr < 5     -> 'snr:2-5'
        5 <= snr < 8     -> 'snr:5-8'
        8 <= snr <= 26   -> 'snr:8-26'
        snr > 26         -> 'snr>26'
    Values that fit no bin (NaN) -> 'unknown'.
    """
    if snr < 2:
        return 'very_high_noise'
    if snr < 5:
        return 'snr:2-5'
    if snr < 8:
        # Previously unbinned: these spots fell through to a debug placeholder label.
        return 'snr:5-8'
    if snr <= 26:
        return 'snr:8-26'
    if snr > 26:
        return 'snr>26'
    # Only reachable for NaN, where every comparison above is False.
    return 'unknown'

# Label every spot with its coarse noise-level category.
df_spots['noise_level'] = df_spots['snr'].apply(categorize_noise)

# One SNR histogram per timepoint, with the mean (blue) and median (green) marked.
for time, group in df_spots.groupby('time'):
    snr_values = group['snr']
    mean_snr = snr_values.mean()
    median_snr = snr_values.median()

    plt.hist(snr_values, bins=50, alpha=0.7, label=f'time: {time}')
    plt.axvline(mean_snr, color='b', linestyle='dashed', linewidth=1, label=f'Mean: {mean_snr:.2f}')
    plt.axvline(median_snr, color='g', linestyle='dashed', linewidth=1, label=f'Median: {median_snr:.2f}')
    plt.title(f'SNR Histogram for time {time}min')
    plt.xlabel('SNR')
    plt.ylabel('Count')
    plt.legend()
    plt.show()


# Scatter of every spot's intensity against its SNR on log-log axes.
plt.scatter(x=df_spots['signal'], y=df_spots['snr'], s=1, alpha=0.7)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Intensity')
plt.ylabel('SNR')
plt.title('Intensity vs SNR (All Spots)')
plt.show()

# Bar chart: number of spots in each noise_level category (most frequent first,
# as ordered by value_counts).
noise_level_counts = df_spots['noise_level'].value_counts()
plt.bar(
    x=noise_level_counts.index,
    height=noise_level_counts.values,
    alpha=0.7,
)
plt.title('Noise Level Distribution')
plt.xlabel('Noise Level')
plt.ylabel('Spot Count')
plt.show()


In [None]:
# Keep only spots assigned to a cell (cell_label > 0) and report how many remain.
# An extra cluster-membership condition (df_spots['cluster_index'] > -1) was considered.
# NOTE(review): this rebinds df_spots_filtered from the full df_spots, so the SNR
# threshold_pass filter computed earlier is not applied here — confirm intended.
_in_cell = df_spots['cell_label'] > 0
df_spots_filtered = df_spots.loc[_in_cell]
print(df_spots_filtered.shape[0])

In [None]:
# Rebuild the cluster table and keep clusters with is_nuc > 0 (presumably clusters
# inside a nucleus — confirm against the is_nuc values shown in the next cell).
# A bare `df_clusters.columns` mid-cell was removed: only a cell's last expression
# is displayed, so it had no effect.
df_clusters = pd.DataFrame(SD.clusters)
df_clusters_filtered = df_clusters[df_clusters['is_nuc'] > 0]
print(len(df_clusters_filtered))

In [None]:
# Display the distinct values of is_nuc (the cluster filter above keeps rows with is_nuc > 0).
df_clusters['is_nuc'].unique()

In [None]:
# Display SD.cellprops as this cell's output.
SD.cellprops


In [17]:
# Wrap SD.cellprops in a DataFrame under the generic name `df` for column inspection.
df = pd.DataFrame(SD.cellprops)

In [None]:
# Display the column labels of the cell-properties DataFrame.
df.columns

In [None]:
# Per-cell spot table as a DataFrame; the trailing expression displays its columns.
df_cellspots = pd.DataFrame(SD.cellspots)
df_cellspots.columns

In [None]:
# Cells with at least one transcription site. Despite its name, num_TS holds the
# matching rows (a DataFrame); the count is what gets printed.
_has_ts = df_cellspots['nb_transcription_site'] > 0
num_TS = df_cellspots.loc[_has_ts]
print(num_TS.shape[0])

In [None]:
# Find cells that have properties but no row in the per-cell spot table.
# Left join on (nuc_label <-> cell_id, fov, NAS_location) with indicator=True tags
# each row 'both' (present in both tables) or 'left_only' (props only).
allcells = SD.cellprops
cells_wSpots = SD.cellspots
_left_keys = ['nuc_label', 'fov', 'NAS_location']
_right_keys = ['cell_id', 'fov', 'NAS_location']
merged = allcells.merge(
    cells_wSpots,
    how='left',
    left_on=_left_keys,
    right_on=_right_keys,
    indicator=True,
)
print(merged.shape)

# Split by match status, dropping the join-only helper columns.
_helper_cols = ['cell_id', '_merge']
_status = merged['_merge']
same_entries = merged.loc[_status == 'both'].drop(columns=_helper_cols)
different_entries = merged.loc[_status == 'left_only'].drop(columns=_helper_cols)

print("Same entries:")
print(same_entries.shape)
print("\nDifferent entries:")
print(different_entries.shape)

In [None]:
import random
import dask.array as da
print(f'There are {allcells.shape[0]} cells in this data set')
print(f'There are {cells_wSpots.shape[0]} cells with spots')

# How many of the cells without spots have bounding boxes touching the border.
# The subscript uses double quotes: reusing the f-string's own quote character
# inside the braces is a SyntaxError before Python 3.12.
print(f'{different_entries["touching_border"].sum()} cells are touching the border and are not counted')

# Root of the NAS share that NAS_location paths are joined onto. Joining onto the
# Windows UNC path on macOS (this notebook uses mac=True and /Volumes/share paths)
# produced a path h5py could not open.
# NOTE(review): '/Volumes/share' is the macOS mount point; confirm for Linux hosts.
# TODO: obtain the share root from the AnalysisManager instead of hard-coding it.
nas_root = r'\\munsky-nas.engr.colostate.edu\share' if os.name == 'nt' else '/Volumes/share'

# Only inspect cells away from the image border.
candidate_cells = different_entries[~different_entries['touching_border']]
if candidate_cells.empty:
    print('No non-border cells without spots to display')
else:
    # Show two randomly selected cells (each draw is independent, so repeats are possible).
    for _ in range(2):
        random_row = candidate_cells.sample(n=1).iloc[0]

        # Read the h5 file; the dask arrays read lazily, so all work stays inside the `with`.
        h5_file = os.path.join(nas_root, random_row['NAS_location'])
        with h5py.File(h5_file, 'r') as f:
            masks = da.from_array(f['/masks'])
            raw_images = da.from_array(f['/raw_images'])

            # Bounding box as (min_row, min_col, max_row, max_col), matching the
            # Rectangle construction below.
            bbox = [random_row['cell_bbox-0'], random_row['cell_bbox-1'], random_row['cell_bbox-2'], random_row['cell_bbox-3']]

            img = raw_images[random_row['fov'], random_row['timepoint_x']].squeeze()
            for c in range(img.shape[0]):
                # Max projection along the first axis after the channel (presumably z);
                # compute once and keep the NumPy result instead of discarding it.
                t = np.max(img[c, :, :, :], axis=0).compute()
                fig, ax = plt.subplots(1, 1, figsize=(10, 10))
                ax.imshow(t, cmap='gray')
                rect = plt.Rectangle((bbox[1], bbox[0]), bbox[3] - bbox[1], bbox[2] - bbox[0], edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                plt.show()

In [None]:
# Histogram of spots
keys_to_plot = ['signal', 'snr']
for k in SD.spots.keys():
    if k in keys_to_plot:
        # Plot histogram for 'area'
        plt.figure(figsize=(10, 5))
        plt.hist(SD.spots[k], bins=200, density=True)
        plt.ylabel('Frequency')
        plt.title(f'Histogram of {k}')
        plt.legend()
        plt.show()

In [None]:
# Display the keys (columns) of SD.cellspots, used by the plotting cell below.
SD.cellspots.keys()

In [None]:
# spot counts as a function of time and dex
keys_to_plot = ['nb_rna', 'nb_rna_in_nuc']


tp_set = sorted(set(SD.cellspots['time']))
dex_set = sorted(set(SD.cellspots['Dex_Conc']))
for k in keys_to_plot:
    fig, axs = plt.subplots(len(tp_set), len(dex_set), figsize=(15, 15))
    fig.suptitle(f'{k} as a function of time and dex', fontsize=16)
    for i_d, d in enumerate(dex_set):
        data = SD.cellspots[SD.cellspots['Dex_Conc'] == d]
        for i_t, t in enumerate(tp_set):
            temp = data[data['time'] == t]
            mean_val = temp[k].mean()
            std_val = temp[k].std()
            if d == 0 and t == 0:
                for ax in axs[i_t, :]:
                    ax.hist(temp[k], bins=200, density=True)
                    ax.axvline(mean_val, color='r', linestyle='solid', linewidth=2)
                    ax.axvline(mean_val + std_val, color='g', linestyle='dashed', linewidth=1)
                    ax.axvline(mean_val - std_val, color='g', linestyle='dashed', linewidth=1)
                    ax.set_xlim([0, SD.cellspots[k].max()])
                    ax.grid(True)  # Turn on grid lines
                    if i_t != len(tp_set) - 1:
                        axs[i_t, i_d].set_xticks([])
                    ax.set_yticks([])
                axs[i_t, 0].set_ylabel(f'Time: {t}')
            else:
                axs[i_t, i_d].hist(temp[k], bins=200, density=True)
                axs[i_t, i_d].axvline(mean_val, color='r', linestyle='solid', linewidth=2)
                axs[i_t, i_d].axvline(mean_val + std_val, color='g', linestyle='dashed', linewidth=1)
                axs[i_t, i_d].axvline(mean_val - std_val, color='g', linestyle='dashed', linewidth=1)
                axs[i_t, i_d].set_xlim([0, SD.cellspots[k].max()])
                axs[i_t, i_d].grid(True)  # Turn on grid lines
                if i_t != len(tp_set) - 1:
                    axs[i_t, i_d].set_xticks([])
                axs[i_t, i_d].set_yticks([])
                axs[i_t, 0].set_ylabel(f'Time: {t}')
                axs[0, i_d].set_title(f'Dex: {d}')
    plt.show()



In [None]:
# Close the AnalysisManager when finished (presumably releases open file handles — confirm in src.Analysis).
am.close()