# In [None]:
# Image analysis pipeline for DNA vertex jumps under electric-field reversal.
# "Original" scenario: DNA is stretched rightward before the reversal and leftward after it.
# "Opposite" scenario: DNA is stretched leftward before the reversal and rightward after it.
# The code reads the raw, unrotated TIFF files directly: "rightward" in the figures
# corresponds to downward (increasing row index) in the imported TIFF, and "leftward" to upward.

import tifffile
import numpy as np
from scipy.spatial import KDTree
import cv2
import matplotlib.pyplot as plt
from skimage.filters import frangi
from scipy.ndimage import generic_filter, gaussian_filter1d
from scipy.signal import find_peaks
import math
import os
import csv
import uuid
from collections import defaultdict

# ===== Parameters =====
# Frame windows used to build the before/after median images.
window_size = 9     # number of frames in each median window
buffer_frame = 1    # frames skipped on each side of the reversal frame

# Vertex-matching constraints.
max_shift_distance = 40.0  # maximum pinpoint displacement (pixels) accepted as a match
max_angle = 15             # maximum deviation (degrees) from the streamwise (vertical) axis

# Input stacks and, for each stack, the manually identified central frame of
# every reversal event (one inner list per TIFF file).
tiff_path_list = [r"D:\DNA data\20250721\REV\20250731_1PVP_REV(4).tif"]
frame_list_original = [[91, 291, 504]]
frame_list_opposite = [[184, 378, 621]]

# Match records accumulated across all files.
all_matches_original = []
all_matches_opposite = []

# Connected-component size filters used when detecting DNA molecules.
params = dict(
    min_length=40,  # minimum component height in pixels (DNA runs vertically in the raw TIFF)
    max_width=8,    # maximum component width in pixels
)

# ===== Functions =====
def find_ends_by_morphology(image, direction=0):
    """Locate DNA end candidates from the Canny edge map of a vertical DNA ROI.

    An interior pixel (not on the 1-pixel border) is considered when it is a
    Canny edge pixel (value 255) brighter than the ROI mean. It is a
    *top-edge* point when the three edge-map pixels in the row above it are
    all 0; otherwise it is a *bottom-edge* point when the three pixels in the
    row below it are all 0.

    The topmost top-edge point and the bottommost bottom-edge point are
    selected (ties go to the smallest x). For ``direction == 0`` these are the
    pinpoint and the long end respectively; any other direction swaps them.

    Args:
        image: 2-D 8-bit ROI (``cv2.Canny`` requires 8-bit input).
        direction: 0 or 1, as described above.

    Returns:
        ``(pinpoint, [], longend)`` where each point is ``np.array([x, y])``
        in ROI coordinates, or None when no such point exists.
    """
    mean_val = np.mean(image)
    std_val = np.std(image)
    image_edges = cv2.Canny(image, int(mean_val), int(mean_val + 3 * std_val))

    height, width = image.shape
    if height < 3 or width < 3:
        # No pixel has a complete 3x3 neighbourhood.
        return None, [], None

    # Vectorised form of a per-pixel scan over the interior. Every mask below
    # is indexed by interior position (y - 1, x - 1).
    clear = image_edges == 0
    considered = (image_edges[1:-1, 1:-1] == 255) & (image[1:-1, 1:-1] > mean_val)
    row_above_clear = clear[:-2, :-2] & clear[:-2, 1:-1] & clear[:-2, 2:]
    row_below_clear = clear[2:, :-2] & clear[2:, 1:-1] & clear[2:, 2:]
    top_edge = considered & row_above_clear
    # The row below is only checked when the row above is not clear (if/elif).
    bottom_edge = considered & ~row_above_clear & row_below_clear

    def to_points(mask):
        # np.nonzero is row-major, so points come out ordered by (y, x);
        # min/max below therefore break ties on y by the smallest x.
        ys, xs = np.nonzero(mask)
        return list(zip((xs + 1).tolist(), (ys + 1).tolist()))

    top_points = to_points(top_edge)
    bottom_points = to_points(bottom_edge)
    topmost = min(top_points, key=lambda p: p[1]) if top_points else None
    bottommost = max(bottom_points, key=lambda p: p[1]) if bottom_points else None

    if direction == 0:
        # DNA extends downward: the pinpoint is at the top, the long end at the bottom.
        pinpoint, longend = topmost, bottommost
    else:
        # DNA extends upward: the pinpoint is at the bottom, the long end at the top.
        pinpoint, longend = bottommost, topmost

    return (
        np.array(pinpoint) if pinpoint is not None else None,
        [],
        np.array(longend) if longend is not None else None,
    )

def check_overlap(image, strip_width=3, min_run=3):
    """Check for overlapping DNA arms in a vertically oriented ROI.

    The image is scanned left to right in strips of ``strip_width`` columns
    (trailing columns that do not fill a whole strip are ignored). Each strip
    is summed into a per-row intensity profile, smoothed (Gaussian, sigma=1)
    and searched for peaks at least 3 rows apart and higher than 1.1x the
    profile mean. A run of at least ``min_run`` consecutive strips that each
    show exactly two peaks is reported as an overlap.

    Args:
        image: 2-D array (rows x columns).
        strip_width: number of columns summed per strip (default 3).
        min_run: minimum number of consecutive two-peak strips (default 3).

    Returns:
        ``(True, mean_peak_separation, end_column)`` for the first qualifying
        run, where ``mean_peak_separation`` is the average row distance
        between the two peaks and ``end_column`` is the first column after
        the run. Otherwise ``(False, 0, 0)``.
    """
    n_cols = image.shape[1] - image.shape[1] % strip_width
    cropped = image[:, :n_cols]

    run_length = 0
    separation_sum = 0
    for col_start in range(0, n_cols, strip_width):
        # Build the profile in float64: gaussian_filter1d returns the input
        # dtype, so an integer profile would have its smoothed values truncated.
        profile = np.sum(cropped[:, col_start:col_start + strip_width], axis=1,
                         dtype=np.float64)
        smoothed = gaussian_filter1d(profile, sigma=1)
        peaks, _ = find_peaks(smoothed, distance=3, height=np.mean(smoothed) * 1.1)

        if len(peaks) == 2:
            separation_sum += abs(peaks[0] - peaks[1])
            run_length += 1
            continue

        if run_length >= min_run:
            return True, separation_sum / run_length, col_start
        run_length = 0
        separation_sum = 0

    # A qualifying run that extends to the last strip is an overlap too.
    if run_length >= min_run:
        return True, separation_sum / run_length, n_cols
    return False, 0, 0

def find_ends(image, direction=0):
    """Locate the pinpoint and long end of a vertically oriented DNA ROI.

    Two detectors are combined:

    * morphology -- ``find_ends_by_morphology`` on the Canny edge map;
    * intensity -- columns at which the column-summed 1x3 local standard
      deviation (smoothed along x, sigma=1) has a peak; in each such column
      the brightest pixel is a candidate if it is brighter than the ROI mean.

    For ``direction == 0`` the pinpoint is the candidate with the smallest y
    and the long end the one with the largest y; any other direction swaps
    this. When both detectors give the same y, the morphology result is kept.

    Args:
        image: 2-D ROI, passed unchanged to ``find_ends_by_morphology``
            (which calls cv2.Canny, so presumably 8-bit -- confirm at callers).
        direction: 0 or 1, as described above.

    Returns:
        ``(pinpoint, None, longend)``; each point is ``(x, y)`` in ROI
        coordinates (a tuple or a 2-element array) or None if not found.
    """
    pinpoint_morph, _, longend_morph = find_ends_by_morphology(image, direction)

    # Topmost = smallest y. Direction 0 puts the pinpoint at the top and the
    # long end at the bottom; other directions reverse that.
    pick_pin, pick_end = (min, max) if direction == 0 else (max, min)

    def by_y(point):
        return point[1]

    # Intensity-based detection. Work in float64: generic_filter and
    # gaussian_filter1d return the input dtype by default, so an 8-bit ROI
    # would have every local std (and the smoothed profile) truncated.
    mean_intensity = np.mean(image)
    local_std = generic_filter(np.asarray(image, dtype=np.float64), np.std, size=(1, 3))
    column_profile = gaussian_filter1d(local_std.sum(axis=0), sigma=1)
    peak_columns, _ = find_peaks(column_profile)

    candidates = []
    for px in peak_columns:
        column = image[:, px]
        py = np.argmax(column)
        if column[py] > mean_intensity:
            candidates.append((px, py))

    # Morphology first so that it wins ties in min/max below.
    pin_options = [pinpoint_morph]
    end_options = [longend_morph]
    if candidates:
        pin_options.append(pick_pin(candidates, key=by_y))
        end_options.append(pick_end(candidates, key=by_y))
    pin_options = [p for p in pin_options if p is not None]
    end_options = [p for p in end_options if p is not None]

    pinpoint = pick_pin(pin_options, key=by_y) if pin_options else None
    longend = pick_end(end_options, key=by_y) if end_options else None
    return pinpoint, None, longend

def detect_pinpoints(median_image, params, direction=0):
    """Detect DNA pinpoints in a median image of vertically stretched DNA.

    Pipeline: contrast stretch -> blur along y -> difference of Gaussians
    (background removal) -> Frangi ridge filter -> threshold -> connected
    components. Every component taller than ``params['min_length']`` and
    narrower than ``params['max_width']`` is handed to ``find_ends``, and its
    pinpoint is shifted from ROI to full-image coordinates.

    Args:
        median_image: 2-D median image of a frame window.
        params: dict with 'min_length' and 'max_width' (pixels).
        direction: forwarded to ``find_ends`` (0 or 1).

    Returns:
        List of ``(x, y)`` pinpoints in full-image pixel coordinates.
    """
    # NOTE(review): convertScaleAbs computes saturate(|2*x - 200|), so pixels
    # below 100 are reflected (brightened) rather than clipped to 0 -- confirm
    # this is intended for the data's intensity range.
    stretched = cv2.convertScaleAbs(median_image, alpha=2.0, beta=-200)

    # ksize is (width, height): a 1x15 kernel blurs only along y, the DNA axis.
    smoothed = cv2.GaussianBlur(stretched, (1, 15), 0)
    dog = cv2.subtract(smoothed, cv2.GaussianBlur(smoothed, (25, 25), 0))

    ridges = frangi(dog, np.linspace(1.0, 6.0, 12), alpha=0.3, beta=1.2,
                    gamma=7.5, black_ridges=False)
    mask = (ridges > 0.01).astype(np.uint8) * 255

    n_labels, _, stats, _ = cv2.connectedComponentsWithStats(mask)
    print("Cropped Segmentation:", n_labels)

    found = []
    # Row 0 of stats is the background label; columns 0..3 are x, y, w, h.
    for left, top, box_w, box_h in stats[1:n_labels, :4]:
        if box_h <= params['min_length'] or box_w >= params['max_width']:
            continue
        pin, _, _ = find_ends(dog[top:top + box_h, left:left + box_w], direction)
        if pin is None:
            continue
        found.append((left + pin[0], top + pin[1]))
    return found

def compute_window_median(image_stack, center_frame, position='before',
                          window_frames=None, buffer_frames=None):
    """Median image of a frame window just before or just after a field reversal.

    With W = window size and B = buffer:

    * 'before' uses frames [center - W - B, center - B), clipped at 0;
    * any other position ('after') uses frames
      [center + 1 + B, center + 1 + B + W), clipped at len(image_stack).

    So B frames are skipped on each side of ``center_frame``, which itself is
    always excluded.

    Args:
        image_stack: frame stack indexed by frame along axis 0.
        center_frame: index of the reversal frame.
        position: 'before' or 'after'.
        window_frames: W; defaults to the module-level ``window_size``
            (looked up at call time).
        buffer_frames: B; defaults to the module-level ``buffer_frame``.

    Returns:
        Per-pixel median over the window (``np.median`` along axis 0).

    Raises:
        ValueError: if the clipped window contains no frames (previously a
            negative end index silently selected most of the stack, or an
            empty window produced an all-NaN median).
    """
    if window_frames is None:
        window_frames = window_size
    if buffer_frames is None:
        buffer_frames = buffer_frame

    if position == 'before':
        start = max(0, center_frame - window_frames - buffer_frames)
        # Clamp at 0: a negative end would be read as "count from the end".
        end = max(0, center_frame - buffer_frames)
    else:  # 'after'
        start = center_frame + 1 + buffer_frames
        end = min(center_frame + window_frames + 1 + buffer_frames, len(image_stack))

    print(f"Computing {position} median: frames {start} to {end-1}")
    if end <= start:
        raise ValueError(
            f"empty '{position}' window for center frame {center_frame} "
            f"(stack has {len(image_stack)} frames)"
        )
    return np.median(image_stack[start:end], axis=0)

def match_pinpoints(points_bef, points_aft, max_distance, scenario, max_angle=15):
    """Pair pinpoints detected before and after a field reversal.

    A (before, after) pair is a candidate when all of these hold:

    * the after point lies within ``max_distance`` pixels of the before point;
    * the displacement is at least 0.1 px;
    * it points the way the scenario requires: for 'original' the after point
      has a smaller y than the before point (dy < 0); for any other scenario
      ('opposite') a larger y (dy > 0);
    * it deviates at most ``max_angle`` degrees from the vertical (y) axis.

    Candidates are assigned greedily, one-to-one, in order of increasing angle
    and then distance. Considering every candidate pair lets an after point
    fall back to its next-best before point when its best one is taken.

    Args:
        points_bef: sequence of (x, y) pinpoints before the reversal.
        points_aft: sequence of (x, y) pinpoints after the reversal.
        max_distance: search radius in pixels.
        scenario: 'original' or 'opposite'.
        max_angle: maximum angular deviation from vertical, in degrees.

    Returns:
        ``(matches, unmatched_bef, unmatched_aft)``. ``matches`` is a list of
        ``(before_point, after_point)`` taken from the input sequences.
    """
    if len(points_bef) == 0 or len(points_aft) == 0:
        return [], points_bef, points_aft

    bef_array = np.asarray(points_bef, dtype=float)
    aft_array = np.asarray(points_aft, dtype=float)
    bef_tree = KDTree(bef_array)

    candidates = []
    for j, (x_aft, y_aft) in enumerate(aft_array):
        for i in bef_tree.query_ball_point((x_aft, y_aft), max_distance):
            dx = x_aft - bef_array[i, 0]
            dy = y_aft - bef_array[i, 1]
            distance = math.hypot(dx, dy)
            if distance < 0.1:
                continue  # effectively no displacement
            direction_ok = dy < 0 if scenario == 'original' else dy > 0
            if not direction_ok:
                continue
            # Angle from the vertical axis; atan2 avoids acos domain errors
            # when rounding pushes |dy| / distance slightly above 1.
            angle = math.degrees(math.atan2(abs(dx), abs(dy)))
            if angle <= max_angle:
                candidates.append((angle, distance, i, j))

    # Best quality first; indices make the order deterministic on ties
    # (query_ball_point does not sort single-point results).
    candidates.sort()

    matches = []
    matched_bef = set()
    matched_aft = set()
    for _, _, i, j in candidates:
        if i in matched_bef or j in matched_aft:
            continue
        matches.append((points_bef[i], points_aft[j]))
        matched_bef.add(i)
        matched_aft.add(j)

    unmatched_bef = [p for i, p in enumerate(points_bef) if i not in matched_bef]
    unmatched_aft = [p for j, p in enumerate(points_aft) if j not in matched_aft]
    return matches, unmatched_bef, unmatched_aft

# ===== MAIN PROCESSING LOOP =====
# For every TIFF stack: load it, then for each manually listed reversal frame
# build before/after median images, detect DNA pinpoints in each, pair them with
# match_pinpoints and record the pairs. Pairs are written to one CSV per scenario
# per file (current working directory) and also accumulated in
# all_matches_original / all_matches_opposite for the combined output.
# frame_list_original[idx] / frame_list_opposite[idx] must hold the reversal
# frames belonging to tiff_path_list[idx].
for idx in range(len(tiff_path_list)):
    tiff_path = tiff_path_list[idx]
    # NOTE: os.path.basename splits on the host OS's path separator, so the
    # Windows-style paths above are only reduced to a file name on Windows.
    tiff_filename = os.path.basename(tiff_path)
    
    print(f"\n{'#'*50}")
    print(f"Processing: {tiff_filename}")
    print(f"{'#'*50}")
    
    # Load TIFF stack
    # The whole stack is read into memory; the message below assumes a
    # single-channel stack shaped (frames, height, width) -- TODO confirm.
    with tifffile.TiffFile(tiff_path) as tif:
        image_stack = tif.asarray()
        print(f"Loaded stack: {image_stack.shape} (frames, height, width)")
    
    # Initialize per-file collections
    file_matches_original = []
    file_matches_opposite = []
    
    # ===== PROCESS ORIGINAL SCENARIO =====
    print("\nProcessing ORIGINAL scenario (down before, up after)")
    for cf in frame_list_original[idx]:
        print(f"\n{'='*40}")
        print(f"Processing frame {cf} (original scenario)")
        print(f"{'='*40}")
        
        # Compute medians
        median_bef = compute_window_median(image_stack, cf, 'before')
        median_aft = compute_window_median(image_stack, cf, 'after')
        
        # Detect pinpoints - original scenario
        # direction=0 takes the topmost end of each molecule as its pinpoint,
        # direction=1 the bottommost.
        pinpoints_bef = detect_pinpoints(median_bef, params, direction=0)  # Down before (direction=0)
        pinpoints_aft = detect_pinpoints(median_aft, params, direction=1)  # Up after (direction=1)
        
        # Match pinpoints with enhanced constraints
        # scenario='original' requires the after pinpoint to have a smaller y.
        matched_pairs, _, _ = match_pinpoints(
            pinpoints_bef, pinpoints_aft, max_shift_distance, 
            scenario='original', max_angle=max_angle
        )
        
        # Store results
        for pair in matched_pairs:
            p_bef, p_aft = pair
            # First 8 hex characters of a random UUID: very likely, but not
            # guaranteed, to be unique across rows.
            point_id = str(uuid.uuid4())[:8]
            match_data = {
                'id': point_id,
                'file': tiff_filename,
                'frame': cf,
                'x_before': p_bef[0],
                'y_before': p_bef[1],
                'x_after': p_aft[0],
                'y_after': p_aft[1],
                'status': 'matched',
                'scenario': 'original',
                # Euclidean pinpoint displacement in pixels
                'displacement': math.hypot(p_aft[0]-p_bef[0], p_aft[1]-p_bef[1])
            }
            file_matches_original.append(match_data)
            all_matches_original.append(match_data)
        
        print(f"Persistent molecules: {len(matched_pairs)}")
    
    # ===== PROCESS OPPOSITE SCENARIO =====
    # Mirror of the block above with the pinpoint directions swapped.
    print("\nProcessing OPPOSITE scenario (up before, down after)")
    for cf in frame_list_opposite[idx]:
        print(f"\n{'='*40}")
        print(f"Processing frame {cf} (opposite scenario)")
        print(f"{'='*40}")
        
        # Compute medians
        median_bef = compute_window_median(image_stack, cf, 'before')
        median_aft = compute_window_median(image_stack, cf, 'after')
        
        # Detect pinpoints - opposite scenario
        pinpoints_bef = detect_pinpoints(median_bef, params, direction=1)  # Up before (direction=1)
        pinpoints_aft = detect_pinpoints(median_aft, params, direction=0)  # Down after (direction=0)
        
        # Match pinpoints with enhanced constraints
        # scenario='opposite' requires the after pinpoint to have a larger y.
        matched_pairs, _, _ = match_pinpoints(
            pinpoints_bef, pinpoints_aft, max_shift_distance, 
            scenario='opposite', max_angle=max_angle
        )
        
        # Store results
        for pair in matched_pairs:
            p_bef, p_aft = pair
            point_id = str(uuid.uuid4())[:8]
            match_data = {
                'id': point_id,
                'file': tiff_filename,
                'frame': cf,
                'x_before': p_bef[0],
                'y_before': p_bef[1],
                'x_after': p_aft[0],
                'y_after': p_aft[1],
                'status': 'matched',
                'scenario': 'opposite',
                'displacement': math.hypot(p_aft[0]-p_bef[0], p_aft[1]-p_bef[1])
            }
            file_matches_opposite.append(match_data)
            all_matches_opposite.append(match_data)
        
        print(f"Persistent molecules: {len(matched_pairs)}")
    
    # ===== SAVE PER-FILE RESULTS =====
    # Written to the current working directory; a scenario with no matches
    # for this file produces no CSV.
    if file_matches_original:
        output_file = f'matched_pairs_ORIGINAL_{os.path.splitext(tiff_filename)[0]}.csv'
        with open(output_file, 'w', newline='') as csvfile:
            fieldnames = ['id', 'file', 'frame', 'x_before', 'y_before', 
                          'x_after', 'y_after', 'status', 'scenario', 'displacement']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(file_matches_original)
            print(f"Saved ORIGINAL scenario matches to {output_file}")
    
    if file_matches_opposite:
        output_file = f'matched_pairs_OPPOSITE_{os.path.splitext(tiff_filename)[0]}.csv'
        with open(output_file, 'w', newline='') as csvfile:
            fieldnames = ['id', 'file', 'frame', 'x_before', 'y_before', 
                          'x_after', 'y_after', 'status', 'scenario', 'displacement']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(file_matches_opposite)
            print(f"Saved OPPOSITE scenario matches to {output_file}")

# ===== FINAL OUTPUT =====
# Combined tables across all processed files: one CSV per scenario in the
# current working directory. A scenario without any match gets no file.
for _scenario_label, _scenario_rows, _scenario_csv in (
    ('ORIGINAL', all_matches_original, 'all_matched_pairs_ORIGINAL.csv'),
    ('OPPOSITE', all_matches_opposite, 'all_matched_pairs_OPPOSITE.csv'),
):
    if not _scenario_rows:
        continue
    csv_path = _scenario_csv
    with open(csv_path, 'w', newline='') as csvfile:
        fieldnames = ['id', 'file', 'frame', 'x_before', 'y_before',
                      'x_after', 'y_after', 'status', 'scenario', 'displacement']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(_scenario_rows)
    print(f"\nSaved all {_scenario_label} scenario matches to: {csv_path}")

print("\nProcessing complete!")