In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import untangle
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import h5py
import tifffile

In [2]:
# Time interval between consecutive frames of the movie, in minutes
# (multiplied with the frame index further down to obtain 'Time in min')
time_interval = 2

# Path to the spot table (-Spot.csv) with the tracking output
path_csv = 'E:/Agnese/20241212_fgfeGFP_h2a_modERK-rnanuclow/20241212_124318_Experiment/BordersTracking-Spot.csv' 

# Path to the .xml file generated when creating the .hdf5
path_xml = 'E:/Agnese/20241212_fgfeGFP_h2a_modERK-rnanuclow/20241212_124318_Experiment/pos1.xml' 

# Path to the .h5 (HDF5) file generated together with the .xml
path_hdf5 = 'E:/Agnese/20241212_fgfeGFP_h2a_modERK-rnanuclow/20241212_124318_Experiment/pos1.h5'

In [3]:
# Obtain all the features that are in the .xml file which has been generated when the data
# is converted to .hdf5 using BigDataViewer/BigStitcher/MultiviewReconstruction in Fiji
import untangle
class xml_features:
    """Image metadata read from the BigDataViewer-style ``.xml`` descriptor.

    Attributes set by the constructor:
        channels: ``len()`` of the second ``Attributes`` element when several
            exist, otherwise ``len()`` of ``Attributes.Channel``.
        dim: always 3.
        width, height, n_slices: the three integers in the first view setup's
            ``size`` text.
        x_pixel, y_pixel, z_pixel: the three floats in its ``voxelSize/size``.
        units: text of ``voxelSize/unit``.
        n_frames: number of whitespace-separated tokens in
            ``Timepoints/integerpattern`` when that element exists, otherwise
            the integer stored in ``Timepoints/last``.

    Raises:
        AttributeError: when neither time point description can be read.
    """

    def __init__(self, path_xml):
        # Parse .xml file
        obj = untangle.parse(path_xml)
        sequence = obj.SpimData.SequenceDescription
        view_setups = sequence.ViewSetups

        # Number of channels. With a single <Attributes> element the [1]
        # lookup does not yield something len() accepts, so fall back to
        # counting its <Channel> children. Only lookup-type errors are caught
        # so that KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.channels = len(view_setups.Attributes[1])
        except (AttributeError, IndexError, KeyError, TypeError):
            self.channels = len(view_setups.Attributes.Channel)
        self.dim = 3

        # Several <ViewSetup> children come back as a list, a single one as
        # the element itself; the first setup is used in both cases.
        setup = view_setups.ViewSetup
        if isinstance(setup, list):
            setup = setup[0]

        size = setup.size.cdata.split()
        self.width = int(size[0])
        self.height = int(size[1])
        self.n_slices = int(size[2])

        voxel = setup.voxelSize.size.cdata.split()
        self.x_pixel = float(voxel[0])
        self.y_pixel = float(voxel[1])
        self.z_pixel = float(voxel[2])

        self.units = setup.voxelSize.unit.cdata

        # NOTE(review): the two branches below do not obviously count the same
        # thing (token count vs. value of <last>); code further down indexes
        # with n_frames + 1 -- confirm which convention is intended.
        timepoints = sequence.Timepoints
        try:
            self.n_frames = len(timepoints.integerpattern.cdata.split())
        except AttributeError:
            # No <integerpattern>: fall back to <last>.
            try:
                self.n_frames = int(timepoints.last.cdata.split()[0])
            except (AttributeError, KeyError):
                raise AttributeError("There is something wrong with the .xml file - Did you compute the features?") from None
fts = xml_features(path_xml)

# Report the image geometry that was read from the .xml descriptor
image_dims = (fts.width, fts.height, fts.n_slices, fts.channels)
pixel_report = (fts.x_pixel, fts.units, fts.y_pixel, fts.units, fts.z_pixel, fts.units)
print('The image has the following dimensions (XYZC): %d, %d, %d, %d' % image_dims)
print('There are %d frames in total.' % fts.n_frames)
print('Pixel Size: x = %.3g %s, y = %.3g %s z = %.2g %s' % pixel_report)
print('There are %d Z-slices in total' % fts.n_slices)

The image has the following dimensions (XYZC): 2304, 2304, 200, 3
There are 149 frames in total.
Pixel Size: x = 0.347 micron, y = 0.347 micron z = 1.5 micron
There are 200 Z-slices in total


#### Import Mastodon Output

In [4]:
# Names of the columns used from the spot table
track_column = 'Spot track ID'          # track identifier
time_column = 'Spot frame'              # frame index (time step)
time_column_corrected = 'Time in min'   # frame index multiplied by time_interval
position_x_column = 'Spot position'     # X position
position_y_column = 'Spot position.1'   # Y position
position_z_column = 'Spot position.2'   # Z position
columns_to_convert = [track_column, time_column, position_x_column, position_y_column, position_z_column]

# Read the table, skipping the rows at index 1 and 2
df = pd.read_csv(path_csv, skiprows=[1, 2])

# Coerce the needed columns to numbers; entries that cannot be parsed become NaN
df[columns_to_convert] = df[columns_to_convert].apply(pd.to_numeric, errors='coerce')

# Frame index -> time in minutes
df[time_column_corrected] = df[time_column] * time_interval

# Discard rows with a NaN in any needed column, then order by track and time
required_columns = [track_column, time_column, time_column_corrected,
                    position_x_column, position_y_column, position_z_column]
df = (
    df.dropna(subset=required_columns)
      .sort_values(by=[track_column, time_column_corrected])
)

### Crop ROI around tracks

#### rearrange the data 

In [5]:
# Give every track ID a row index, in order of first appearance in df
# (the same order df[track_column].unique() yields).
track_rows, track_ids = pd.factorize(df[track_column])

# Number of unique tracks (cells)
n_cells = len(track_ids)

# One row per track, one column per frame index; NaN marks "no spot at this frame"
n_timepoints = fts.n_frames + 1
tracks_save = np.full((n_cells, n_timepoints), np.nan)
x_save = np.full((n_cells, n_timepoints), np.nan)
y_save = np.full((n_cells, n_timepoints), np.nan)
z_save = np.full((n_cells, n_timepoints), np.nan)

# The frame number is used directly as the column index (no offset is applied)
frame_idx = df[time_column].to_numpy().astype(np.int64)

# Keep only spots whose frame index fits into the matrices
valid = (track_rows >= 0) & (frame_idx >= 0) & (frame_idx < n_timepoints)
rows, cols = track_rows[valid], frame_idx[valid]

# Scatter all spots at once instead of filtering df per track and iterating rows.
# If a (track, frame) pair occurs more than once, the later row in df wins.
tracks_save[rows, cols] = df[track_column].to_numpy()[valid]
x_save[rows, cols] = df[position_x_column].to_numpy()[valid]
y_save[rows, cols] = df[position_y_column].to_numpy()[valid]
z_save[rows, cols] = df[position_z_column].to_numpy()[valid]


#### apply transformation to reverse registration

In [6]:
# Parse the XML file
obj = untangle.parse(path_xml)

n_channels = fts.channels
n_timepoints = fts.n_frames + 1
view_registrations = obj.SpimData.ViewRegistrations.ViewRegistration

# The registration depends only on the frame, not on the cell, so parse and
# invert it once per frame instead of once per (cell, frame) pair.
# Entry frame * n_channels is read for each frame, and only its first
# <ViewTransform> is applied.
# NOTE(review): this presumes the registrations are stored frame by frame with
# n_channels entries each -- confirm against the .xml.
inverse_linear = []
translations = []
for frame in range(n_timepoints):
    affine_text = view_registrations[frame * n_channels].ViewTransform[0].affine.cdata
    registration = np.array(affine_text.split(), dtype=float).reshape(3, 4)
    # Left 3x3 block is the linear part, last column the translation
    inverse_linear.append(np.linalg.inv(registration[:, :3]))
    translations.append(registration[:, -1].reshape(-1, 1))

x_save_r = []
y_save_r = []
z_save_r = []

for j in range(n_cells):

    aux_x = []
    aux_y = []
    aux_z = []

    for frame in range(n_timepoints):

        # XYZ coordinates from mastodon (NaN where the track has no spot)
        xyz_coord = np.array([x_save[j][frame], y_save[j][frame], z_save[j][frame]]).reshape(-1, 1)

        # Undo the registration: p = A^-1 (p_registered - T)
        final = inverse_linear[frame] @ (xyz_coord - translations[frame])

        # Save xyz de-registered coordinates
        aux_x.append(final[0][0])
        aux_y.append(final[1][0])
        aux_z.append(final[2][0])

    x_save_r.append(aux_x)
    y_save_r.append(aux_y)
    z_save_r.append(aux_z)

#### crop the hdf5 file

In [None]:
#TO ADJUST FOR 2D CROPPING


def _zero_padded_plane_crop(cells, z_val, y_val, x_val, y_crop, x_crop):
    """Return a (2*y_crop, 2*x_crop) window of plane ``z_val`` centred on (y_val, x_val).

    ``cells`` is indexed as ``cells[z, y, x]``. Parts of the window that fall
    outside the image stay zero.

    Raises:
        IndexError: if ``z_val`` is not a valid plane index or the window does
            not overlap the image at all.
    """
    shape = cells.shape
    if not 0 <= z_val < shape[0]:
        raise IndexError(f"z index {z_val} is outside the volume (0..{shape[0] - 1})")

    # Window bounds clipped to the image (image coordinates)
    y_min, y_max = max(0, y_val - y_crop), min(shape[1], y_val + y_crop)
    x_min, x_max = max(0, x_val - x_crop), min(shape[2], x_val + x_crop)
    if y_min >= y_max or x_min >= x_max:
        raise IndexError("crop window lies completely outside the image")

    # Where the clipped window starts inside the padded output (crop coordinates)
    y_off = y_min - (y_val - y_crop)
    x_off = x_min - (x_val - x_crop)

    crop = np.zeros((2 * y_crop, 2 * x_crop), dtype=cells.dtype)
    crop[y_off:y_off + (y_max - y_min), x_off:x_off + (x_max - x_min)] = \
        cells[z_val, y_min:y_max, x_min:x_max]
    return crop


def crop_hdf5_optimized(x_save_r, y_save_r, z_save_r, fts, path_hdf5, data_path, 
                         n_cells, x_crop, y_crop, z_crop, tp, channels, subsampling):
    """
    Crop a 2D window around every tracked cell and save one multi-channel TIFF
    per cell and time point.

    For each cell and each time point in ``tp`` the cell position is converted
    to pixel indices (coordinate / pixel size, rounded), the plane at the
    cell's z index is read from every requested channel, and a window of
    ``2*y_crop`` x ``2*x_crop`` pixels centred on the cell is cut out. Windows
    that extend past the image border are zero-padded. The channel crops are
    stacked to (C, Y, X) and written to
    ``{data_path}/Cell_{cell}_T_{tp:03d}_merged.tif``.

    Parameters:
        x_save_r, y_save_r, z_save_r: per cell, a sequence of coordinates
            indexed by time point; NaN marks a missing position.
        fts: object providing ``x_pixel``, ``y_pixel`` and ``z_pixel``.
        path_hdf5: HDF5 file with datasets at
            ``t{tp:05d}/s{channel:02d}/{subsampling}/cells``.
        data_path: existing output directory.
        n_cells: number of cells to process.
        x_crop, y_crop: half-width of the window in pixels.
        z_crop: accepted for interface compatibility; not used, only a single
            plane is cropped.
        tp: iterable of time point indices.
        channels: iterable of channel (setup) indices.
        subsampling: resolution level key inside each setup group.
            NOTE(review): indices are computed from the full pixel size; for a
            level other than 0 they presumably need rescaling -- confirm.

    Time points, channels or crops that cannot be read are reported with a
    printed warning and skipped.
    """
    subsampling_key = str(subsampling)

    time_points = list(tp)
    if not time_points:
        return  # nothing requested
    last_time_point = max(time_points)

    # Open HDF5 file ONCE
    with h5py.File(path_hdf5, 'r') as f:
        for cell_n in range(n_cells):

            # Ensure the cell has enough time points
            if len(x_save_r[cell_n]) <= last_time_point:
                print(f"Warning: Cell {cell_n} has fewer time points than requested. Skipping.")
                continue

            for j in time_points:
                position = (x_save_r[cell_n][j], y_save_r[cell_n][j], z_save_r[cell_n][j])
                if np.isnan(position).any():
                    continue  # Skip missing data

                # Convert coordinates to indices
                x_val = int(np.round(position[0] / fts.x_pixel))
                y_val = int(np.round(position[1] / fts.y_pixel))
                z_val = int(np.round(position[2] / fts.z_pixel))

                # Get the time point group
                group_name = f't{j:05d}'
                if group_name not in f:
                    print(f"Warning: Time point {j} not found. Skipping.")
                    continue

                group = f[group_name]
                merged_channels_data = []

                # Loop through channels
                for ch in channels:
                    # Two-digit setup id, so ids of 10 and above keep the 's' + 2 digits form
                    channel_name = f's{int(ch):02d}'
                    if channel_name not in group:
                        print(f"Warning: Channel {channel_name} not found at time point {j}. Skipping.")
                        continue

                    try:
                        cells = group[channel_name][subsampling_key]['cells']
                        crop_slice = _zero_padded_plane_crop(cells, z_val, y_val, x_val, y_crop, x_crop)
                    except (IndexError, KeyError) as e:
                        print(f"Warning: Issue accessing {channel_name} at time point {j}: {e}. Skipping.")
                        continue

                    merged_channels_data.append(crop_slice)

                if merged_channels_data:
                    # Stack channels into (C, Y, X)
                    merged_stack = np.stack(merged_channels_data, axis=0)

                    # Save as multi-channel TIFF
                    tifffile.imwrite(
                        f'{data_path}/Cell_{cell_n}_T_{j:03d}_merged.tif',
                        merged_stack,
                        imagej=True,
                        metadata={'axes': 'CYX'}
                    )

                    print(f"Saved cell {cell_n}, time point {j}")
                else:
                    print(f"No valid channels for cell {cell_n}, time point {j}. Skipping save.")

# Crop settings: x_crop / y_crop are half-widths of the window in pixels
x_crop, y_crop, z_crop = 35, 35, 25
subsampling = 0          # resolution level key in the HDF5 file
channels = [0, 1, 2]     # channels to merge into one TIFF

# Output directory next to the HDF5 file
path_hdf5 = Path(path_hdf5)
data_path = path_hdf5.parent / "cropped-data_2d"
data_path.mkdir(parents=True, exist_ok=True)

# Crop every cell at every time point in a single pass over the file
crop_hdf5_optimized(
    x_save_r, y_save_r, z_save_r, fts, path_hdf5, data_path,
    n_cells, x_crop, y_crop, z_crop,
    tp=range(fts.n_frames), channels=channels, subsampling=subsampling,
)


Saved cell 0, time point 50
Saved cell 0, time point 51
Saved cell 0, time point 52
Saved cell 0, time point 53
Saved cell 0, time point 54
Saved cell 0, time point 55
Saved cell 0, time point 56
Saved cell 0, time point 57
Saved cell 0, time point 58
Saved cell 0, time point 59
Saved cell 0, time point 60
Saved cell 0, time point 61
Saved cell 0, time point 62
Saved cell 0, time point 63
Saved cell 0, time point 64
Saved cell 0, time point 65
Saved cell 0, time point 66
Saved cell 0, time point 67
Saved cell 0, time point 68
Saved cell 0, time point 69
Saved cell 0, time point 70
Saved cell 0, time point 71
Saved cell 0, time point 72
Saved cell 0, time point 73
Saved cell 0, time point 74
Saved cell 0, time point 75
Saved cell 0, time point 76
Saved cell 0, time point 77
Saved cell 0, time point 78
Saved cell 0, time point 79
Saved cell 0, time point 80
Saved cell 0, time point 81
Saved cell 0, time point 82
Saved cell 0, time point 83
Saved cell 0, time point 84
Saved cell 0, time p