In [1]:
# Dataset directory: every dataset path below is relative to this root.
dataset_folder = '/mnt/Data1/Nick/transcription_pipeline/'

RBSPWM_datasets = [
    "test_data/2024-02-26/Halo-RBSPWM_embryo01",
    "test_data/2024-02-26/Halo-RBSPWM_embryo02",
    "test_data/2024-05-07/Halo552-RBSPWM_embryo01",
    "test_data/2024-05-07/Halo552-RBSPWM_embryo02",
    "test_data/2024-05-09/Halo552-RBSPWM_embryo01",
]

RBSVar2_datasets = [
    "test_data/2024-07-23/Halo673_RBSVar2_embryo01",
    "test_data/2024-07-25/Halo673_RBSVar2_embryo01",
    "test_data/2024-10-10/Halo673_RBSVar2_embryo01",
    "test_data/2024-10-10/Halo673_RBSVar2_embryo02",
]

MCP_mSG_datasets = [
    "test_data/2024-10-31/MCP-mSG_ParB-mScar_RBSPWM_embryo01",
    "test_data/2024-10-31/MCP-mSG_ParB-mScar_RBSPWM_embryo02",
    "test_data/2025-03-18/MCP-mSG_His-RFP_RBSPWM(003)_embryo01",
    "test_data/2025-03-18/MCP-mSG_His-RFP_RBSPWM(003)_embryo02",
]

NSPARC_datasets = [
    'test_data/NSPARC/2025-03-17/MCP-Halo552_His-GFP_Var2(001)_embryo01',
    'test_data/NSPARC/2025-03-31/MCP-mSG_His-RFP_Var2(001)_embryo01',
    'test_data/NSPARC/2025-03-31/MCP-mSG_His-RFP_Var2(001)_embryo02',
    'test_data/NSPARC/2025-04-01/MCP-mSG_His-RFP_Var2(001)_embryo20',
    'test_data/NSPARC/2025-04-01/MCP-mSG_His-RFP_Var2(001)_embryo38',
    'test_data/NSPARC/2025-04-14/MCP-mSG_His-RFP_Var2(001)_embryo28',
    'test_data/NSPARC/2025-04-15/MCP-mSG_His-RFP_Var2(001)_embryo01',
]

# Embryo analysed by the rest of the notebook (change the index to switch datasets).
selected_dataset = NSPARC_datasets[5]
test_dataset_name = f"{dataset_folder}{selected_dataset}"
print(f'Dataset Path: {test_dataset_name}')

Dataset Path: /mnt/Data1/Nick/transcription_pipeline/test_data/NSPARC/2025-04-14/MCP-mSG_His-RFP_Var2(001)_embryo28


In [2]:
# Import pipeline
from transcription_pipeline import nuclear_pipeline
from transcription_pipeline import preprocessing_pipeline

from transcription_pipeline import spot_pipeline
from transcription_pipeline import fullEmbryo_pipeline

from transcription_pipeline.spot_analysis import compile_data
from transcription_pipeline.utils import plottable

import os
import matplotlib.pyplot as plt
import matplotlib as mpl

`JAVA_HOME` environment variable set to /home/nickgravina/miniforge3/envs/transcription_pipeline


In [3]:
# Plot display mode: 'TkAgg' opens figures in separate windows (use it from PyCharm).
# For a browser-based notebook, comment out mpl.use(...) and enable `%matplotlib widget`.
PLOT_BACKEND = 'TkAgg'
mpl.use(PLOT_BACKEND)
# %matplotlib widget

In [4]:
# A 'collated_dataset' folder means an earlier run already saved the imported data.
ms2_import_previous = os.path.isdir(os.path.join(test_dataset_name, 'collated_dataset'))
ms2_import_previous

True

In [5]:
# Import the time-lapse dataset (trim_series=True, zarr working storage). When a
# 'collated_dataset' folder already exists, import_previous re-uses the saved data.
dataset = preprocessing_pipeline.DataImport(
    name_folder=test_dataset_name,
    trim_series=True,
    working_storage_mode='zarr',
    import_previous=ms2_import_previous, 
)
# Write the collated dataset to disk only on the first import.
if not ms2_import_previous:
    dataset.save()

In [None]:
# Shape of the per-frame timestamp ('t_s') array of the nuclear channel.
# Channel index 1 is written out because `nuclear_channel` is only defined further
# down the notebook; referencing it here raised a NameError on a fresh kernel.
dataset.export_frame_metadata[1]['t_s'].shape

### Import FullEmbryo Dataset

In [None]:
# Import the full-embryo images (surface and mid planes) for this dataset.
# NOTE(review): import_previous is hard-coded to True here, unlike the DataImport cell,
# which first checks whether saved data exists — confirm this is intended.
FullEmbryo_dataset = preprocessing_pipeline.FullEmbryoImport(
    name_folder=test_dataset_name,
    import_previous=True
)
# Known issue (reported to Yovan): loading a saved FullEmbryo dataset only reads in the
# last channel, so saving is disabled for now.
# FullEmbryo_dataset.save()

## Start a Dask client for parallel processing

In [None]:
from dask.distributed import LocalCluster, Client

# Fixed scheduler port. The fallback below reconnects to this address, so the cluster
# must be started on the same port; with the port left unset, a random port was used
# and the fallback pointed at an unrelated (or non-existent) scheduler.
SCHEDULER_PORT = 37763

try:
    cluster = LocalCluster(
        host="localhost",
        scheduler_port=SCHEDULER_PORT,
        threads_per_worker=1,
        n_workers=14,
        memory_limit="6GB",
    )

    client = Client(cluster)
except Exception as exc:
    # Typically the port is already held by a cluster from an earlier run of this cell.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate,
    # and report the actual error instead of assuming its cause.
    print("Cluster already running:", repr(exc))
    client = Client(f"localhost:{SCHEDULER_PORT}")

print(client)

In [None]:
# Restart the workers of the Dask cluster attached to `client`.
client.restart()

In [None]:
# URL of the Dask dashboard for monitoring the cluster.
client.dashboard_link

## Nuclear Tracking

Check whether nuclear tracking has already been run for this dataset (i.e. a `nuclear_analysis_results` folder exists). If so, load the saved results instead of re-running the tracking.

In [None]:
# Plot dataset
plt.figure(figsize=(12,6))

plt.subplot(1, 2, 1)
plt.imshow(dataset.channels_full_dataset[1][49,5, :, :], cmap='gray')
plt.show()

In [None]:
# Channel indices into dataset.channels_full_dataset / export_*_metadata.
spot_channel, nuclear_channel = 0, 1

In [None]:
# A 'nuclear_analysis_results' folder means nuclear tracking results were saved before.
nuclear_tracking_previous = os.path.isdir(os.path.join(test_dataset_name, 'nuclear_analysis_results'))
nuclear_tracking_previous

In [None]:
if nuclear_tracking_previous:
    # Load the nuclear tracking results saved by an earlier run.
    print('Load from previous nuclear tracking results')

    nuclear_tracking = nuclear_pipeline.Nuclear()
    nuclear_tracking.read_results(name_folder=test_dataset_name)

else:
    # Run nuclear tracking from scratch and save the results.
    print('Do nuclear tracking for the dataset')

    nuclear_tracking = nuclear_pipeline.Nuclear(
        data=dataset.channels_full_dataset[nuclear_channel],
        global_metadata=dataset.export_global_metadata[nuclear_channel],
        frame_metadata=dataset.export_frame_metadata[nuclear_channel],
        series_splits=dataset.series_splits,
        series_shifts=dataset.series_shifts,
        search_range_um=1.5,
        stitch=False,
        stitch_max_distance=4,
        stitch_max_frame_distance=2,
        client=client,
        keep_futures=False,
    )

    nuclear_tracking.track_nuclei(
        working_memory_mode="zarr",
        working_memory_folder=test_dataset_name,
        # Join with a path separator: "".join([test_dataset_name, "trackpy_log"]) produced
        # ".../embryo28trackpy_log", a sibling of the dataset folder rather than a path
        # inside it (the spot tracking cell uses test_dataset_name + '/trackpy_log').
        trackpy_log_path=os.path.join(test_dataset_name, "trackpy_log"),
    )

    # Saves tracked nuclear mask as a zarr, and pickles dataframes with segmentation and
    # tracking information.
    nuclear_tracking.save_results(
        name_folder=test_dataset_name, save_array_as=None
    )

## Spot Tracking

Check whether spot tracking has already been run for this dataset (i.e. a `spot_analysis_results` folder exists). If so, load the saved results instead of re-running the tracking.

In [None]:
# A 'spot_analysis_results' folder means spot tracking results were saved before.
spot_tracking_previous = os.path.isdir(os.path.join(test_dataset_name, 'spot_analysis_results'))
spot_tracking_previous

In [None]:
%%time

if spot_tracking_previous:
    # Load spot tracking results
    print('Load from spot tracking results')
    
    spot_tracking = spot_pipeline.Spot()
    spot_tracking.read_results(name_folder=test_dataset_name)
    
else:
    # Do spot tracking and save the results
    print('Do spot tracking for the dataset')
    
    spot_tracking = spot_pipeline.Spot(
        data=dataset.channels_full_dataset[spot_channel],
        global_metadata=dataset.export_global_metadata[spot_channel],
        frame_metadata=dataset.export_frame_metadata[spot_channel],
        labels=nuclear_tracking.reordered_labels,
        expand_distance=3,
        search_range_um=4.2,
        retrack_search_range_um=4.5,
        threshold_factor=1.3,
        memory=3,
        retrack_after_filter=False,
        stitch=True,
        min_track_length=0,
        series_splits=dataset.series_splits,
        series_shifts=dataset.series_shifts,
        keep_bandpass=False,
        keep_futures=False,
        keep_spot_labels=False,
        evaluate=True,
        retrack_by_intensity=True,
        client=client,
    )
    
    spot_tracking.extract_spot_traces(
        working_memory_folder=test_dataset_name, 
        stitch=True,
        retrack_after_filter=True,
        trackpy_log_path = test_dataset_name+'/trackpy_log'
    )
    
    # Saves tracked spot mask as a zarr, and pickles dataframes with spot fitting and
    # quantification information.
    spot_tracking.save_results(name_folder=test_dataset_name, save_array_as=None)

### Make Compiled Dataframe

In [None]:
# Display the reordered spot label array produced by spot tracking.
# NOTE(review): this may be a large array; `.shape` is a lighter-weight check.
spot_tracking.reordered_spot_labels

In [None]:
# Spot-level columns carried into each compiled trace.
spot_columns_to_compile = [
    "frame",
    "t_s",
    "intensity_from_neighborhood",
    "intensity_std_error_from_neighborhood",
    "x",
    "y",
]

spot_df = spot_tracking.spot_dataframe

# Rows with particle == 0 are spots that were not detected; keep only the rest.
detected_spots = spot_df.loc[spot_df["particle"].ne(0)]

compiled_dataframe = compile_data.compile_traces(
    detected_spots,
    compile_columns_spot=spot_columns_to_compile,
    nuclear_tracking_dataframe=None,
)

compiled_dataframe.head()

In [None]:
from transcription_pipeline.gui import check_spots

In [None]:
# Open the interactive spot-checking GUI on the spot channel, its tracked spot labels
# and the compiled traces.
check_spots.CheckSpotsGUI(
    spot_channel=dataset.channels_full_dataset[spot_channel],
    labels=spot_tracking.reordered_spot_labels,
    dataset_name=test_dataset_name,
    spot_channel_index=spot_channel,
    compiled_dataframe=compiled_dataframe,
)

## Full Embryo Analysis

In [None]:
# Side-by-side preview of the first plane of channel 0 of the surface and mid images.
embryo_fig, (ax_surf, ax_mid) = plt.subplots(1, 2, figsize=(12, 6))

ax_surf.imshow(FullEmbryo_dataset.channels_full_dataset_surf[0][0, :, :], cmap='gray')
ax_surf.set_title('Full Embryo Surf')

ax_mid.imshow(FullEmbryo_dataset.channels_full_dataset_mid[0][0, :, :], cmap='gray')
ax_mid.set_title('Full Embryo Mid')

embryo_fig.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib
from skimage import color, filters, morphology, util, measure, transform#, exposure, segmentation, io
from scipy.spatial import ConvexHull
from skimage.draw import line
import feret
import cv2 as cv
from skimage.transform import AffineTransform

def contour_mask(binary_mask):
    """
    Generate a filled mask from the convex hull of the largest contour in ``binary_mask``.

    The convex hull of the longest contour returned by ``measure.find_contours`` is drawn
    as a closed polygon outline. The region outside the outline is flood filled, and the
    result is inverted so that only the hull interior is set.

    Parameters
    ----------
    binary_mask : 2-D array
        Thresholded image in which the embryo is the foreground.

    Returns
    -------
    mask : ndarray of float
        1.0 strictly inside the hull outline and 0.0 elsewhere, including on the
        outline pixels themselves.
    outline : ndarray of float
        1.0 on the drawn hull outline and 0.0 elsewhere.

    Raises
    ------
    ValueError
        If ``binary_mask`` contains no contour.
    """
    contours = measure.find_contours(binary_mask)
    if not contours:
        raise ValueError("No contour found in binary_mask; cannot build a FullEmbryo mask.")

    # Use the contour with the most points as the embryo edge.
    largest_contour = max(contours, key=len)

    # Fit a convex hull to the contour; contour points are (row, col).
    hull = ConvexHull(largest_contour)

    # Hull corner points as (x, y) = (col, row). hull.vertices lists each corner once,
    # whereas the endpoints of hull.simplices repeat every corner twice.
    pts = largest_contour[hull.vertices][:, ::-1]

    # Order the corners by polar angle around their centroid so that consecutive
    # points are neighbours on the hull.
    reference_point = pts.mean(axis=0)
    angles = np.arctan2(pts[:, 1] - reference_point[1], pts[:, 0] - reference_point[0])
    corners = np.round(pts[np.argsort(angles, kind="stable")]).astype(int)

    # Draw the closed polygon: each corner connects to the next, and the last to the first.
    outline = np.zeros(binary_mask.shape)
    for (x1, y1), (x2, y2) in zip(corners, np.roll(corners, -1, axis=0)):
        rr, cc = line(y1, x1, y2, x2)
        outline[rr, cc] = 1

    # Pad with one background pixel before flood filling from the corner. Without the
    # padding, a hull that touches the image border can split the background into several
    # pieces (or put an outline pixel at (0, 0)). Pieces not connected to (0, 0) were never
    # filled, so they ended up inside the inverted mask.
    padded = np.pad(outline, 1)
    outside = morphology.flood_fill(padded, (0, 0), 1, connectivity=1)[1:-1, 1:-1]
    mask = util.invert(outside)

    return mask, outline

def gen_full_embryo_mask(tif_array, sigma=10, radius=5, threshold_factor=1.0):
    """
    Create a FullEmbryo mask by detecting the embryo edge.

    The image is Gaussian-blurred, thresholded at ``threshold_factor`` times its Otsu
    threshold, morphologically closed with a disk, and then passed to ``contour_mask``.

    Parameters
    ----------
    tif_array : 2-D array
        Image of the embryo. No RGB-to-grayscale conversion is done, so presumably a
        single-channel plane is expected — confirm for new callers.
    sigma : float
        Standard deviation of the Gaussian blur.
    radius : int
        Radius of the disk footprint used for the morphological closing.
    threshold_factor : float
        Multiplier applied to the Otsu threshold of the blurred image. The default of 1.0
        keeps the plain Otsu threshold, as before this parameter existed.

    Returns
    -------
    mask, contour : ndarray
        The filled mask and the hull outline, exactly as returned by ``contour_mask``.
    """
    # Smooth before thresholding so the Otsu split follows the embryo body, not noise.
    blurred = filters.gaussian(tif_array, sigma)

    # Foreground = pixels brighter than the (scaled) Otsu threshold.
    foreground = blurred > threshold_factor * filters.threshold_otsu(blurred)

    # Close small gaps in the foreground with a disk of the given radius.
    closed = morphology.closing(foreground, morphology.disk(radius))

    mask, contour = contour_mask(closed)
    return mask, contour

In [None]:
image = FullEmbryo_dataset.channels_full_dataset_mid[nuclear_channel][1, :, :]

# Otsu threshold of the raw mid-plane image (printed for reference).
otsu_threshold = filters.threshold_otsu(image)
print(otsu_threshold)

# NOTE(review): this keeps the pixels *below* the threshold and zeroes the rest —
# confirm that is intended rather than `image > otsu_threshold`.
binary_image = image < otsu_threshold
final_image = np.where(binary_image, image, 0)

# Embryo mask with the same sigma/radius values passed to find_ap_axis later on.
mask, _ = gen_full_embryo_mask(tif_array=image, sigma=10, radius=5)

preview_fig, (ax_thresholded, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
ax_thresholded.imshow(final_image, cmap='gray')
ax_mask.imshow(mask, cmap='gray')
plt.show()


In [None]:
# Combine the full-embryo images with the time-lapse dataset; his_channel selects the nuclear channel.
fullEmbryo = fullEmbryo_pipeline.FullEmbryo(FullEmbryo_dataset, dataset, his_channel=nuclear_channel)

In [None]:
# Locate the AP axis (ap_method='minf90'); sigma/radius are the same values used for the gen_full_embryo_mask preview.
fullEmbryo.find_ap_axis(make_plots=True, remove_small_objects=False, ap_method='minf90', sigma=10, radius=5)

In [None]:
# Swap the two AP-axis end points (presumably for when anterior/posterior were assigned the wrong way round — confirm).
fullEmbryo.swap_ap_points(make_plots=True)

In [None]:
# Convert the compiled traces' x/y positions to AP coordinates.
# NOTE(review): this overwrites compiled_dataframe, so re-running the cell applies
# xy_to_ap to already-converted data — confirm that is harmless.
compiled_dataframe = fullEmbryo.xy_to_ap(compiled_dataframe)
compiled_dataframe.head()

In [None]:
# Persist the compiled traces as a pickle inside the dataset folder.
compiled_path = os.path.join(test_dataset_name, 'compiled_dataframe.pkl')
compiled_dataframe.to_pickle(compiled_path)

## RateExtraction Analysis


### Fit and Average

In [None]:
from transcription_pipeline.RateExtraction import FitAndAverage

In [None]:
# Rate-extraction settings
nc14_start_frame = 0  # frame at which NC14 starts
min_frames = 40       # traces shorter than this many frames are filtered out (per original note — confirm wording)
num_bins = 42         # number of bins the full embryo is split into

In [None]:
# Fit-and-average analysis of the compiled traces over num_bins embryo bins.
faadata = FitAndAverage(compiled_dataframe, nc14_start_frame, min_frames, num_bins, test_dataset_name)

In [None]:
# Review the individual particle fits.
faadata.check_particle_fits()

In [None]:
# Save the reviewed particle fits to disk.
faadata.save_checked_particle_fits()

In [None]:
# Plot the averaged particle fits; the trailing ';' suppresses the returned values.
faadata.average_particle_fits(plot_results=True, show_slopes=True);

In [None]:
# Recompute the averages without plotting and keep x, y and y-error for the comparison plot.
faax, faay, faay_err, _, _ = faadata.average_particle_fits();

### Average and Fit (using approved particle fits)

In [None]:
import pandas as pd

# Load the approved particle fits saved by save_checked_particle_fits().
# NOTE: read_pickle can execute arbitrary code — only load files this pipeline wrote.
faadatapoints = pd.read_pickle(faadata.checked_particle_fits_file_path)

In [None]:
from transcription_pipeline.RateExtraction import AverageAndFit

# Frame interval = difference between the first two frame timestamps of the spot channel.
# t_s[1, 0] alone is the absolute time of frame 1, which equals the interval only when
# t_s[0, 0] == 0. Assumes axis 0 of t_s indexes frames — TODO confirm.
spot_frame_times = dataset.export_frame_metadata[spot_channel]['t_s']
time_bin_width = spot_frame_times[1, 0] - spot_frame_times[0, 0]

# NOTE(review): this run starts 3 frames before nc14_start_frame, unlike the AverageAndFit
# run on compiled_dataframe — confirm the offset is intended.
aafdata_sp = AverageAndFit(faadatapoints, nc14_start_frame - 3, time_bin_width, num_bins, test_dataset_name)

In [None]:
# Review the bin-averaged fits built from the approved particle fits.
aafdata_sp.check_bin_fits()

In [None]:
# Display the bin-average fit table. The trailing ';' was removed: it suppressed the only
# output, so the cell did nothing (the matching cell for `aafdata` displays it).
aafdata_sp.bin_average_fit_dataframe

In [None]:
# Save the reviewed bin fits.
aafdata_sp.save_checked_bin_fits()

In [None]:
# Plot the bin fits and keep x, y and y-error for the comparison plot.
aafspx, aafspy, aafspy_err = aafdata_sp.plot_bin_fits()

### Average and Fit

In [None]:
from transcription_pipeline.RateExtraction import AverageAndFit

In [None]:
# Frame interval = difference between the first two frame timestamps of the spot channel.
# t_s[1, 0] alone is the absolute time of frame 1, which equals the interval only when
# t_s[0, 0] == 0. Assumes axis 0 of t_s indexes frames — TODO confirm.
spot_frame_times = dataset.export_frame_metadata[spot_channel]['t_s']
time_bin_width = spot_frame_times[1, 0] - spot_frame_times[0, 0]
aafdata = AverageAndFit(compiled_dataframe, nc14_start_frame, time_bin_width, num_bins, test_dataset_name)

In [None]:
# Review the bin-averaged fits built from compiled_dataframe.
aafdata.check_bin_fits()

In [None]:
# Display the bin-average fit table.
aafdata.bin_average_fit_dataframe

In [None]:
# Save the reviewed bin fits.
aafdata.save_checked_bin_fits()

In [None]:
# Plot the bin fits and keep x, y and y-error for the comparison plot.
aafx, aafy, aafy_err = aafdata.plot_bin_fits()

In [None]:
# Compare the results of the three approaches: faa = FitAndAverage,
# aaf_sp = AverageAndFit on the approved particle fits, aaf = AverageAndFit on compiled_dataframe.
plt.errorbar(faax, faay, yerr=faay_err, capsize=2, fmt='o', label='faa')
plt.errorbar(aafspx, aafspy, yerr=aafspy_err, capsize=2, fmt='o', label='aaf_sp')
plt.errorbar(aafx, aafy, yerr=aafy_err, capsize=2, fmt='o', label='aaf')

# The series are labelled, but the labels only appear once a legend is drawn.
plt.legend()
plt.show()