In [1]:
import cv2
import numpy as np
import pandas as pd
from pandas import ExcelWriter
import os
import json
import plotly.graph_objs as go
import math

# Function to get integer input with error handling
def get_int_input(prompt):
    while True:
        try:
            input_value = input(prompt).strip()
            return int(input_value) if input_value else None
        except ValueError:
            print("Please enter a valid integer.")

# Function to get float input with error handling
def get_float_input(prompt):
    while True:
        try:
            return float(input(prompt))
        except ValueError:
            print("Please enter a valid float.")

# Class to handle segment processing and analysis
class Segments:
    def __init__(self, segment_boxes, base_values):
        # Initialize the class with segment boxes and base values
        self.flags = []  # List to store indices of flagged segments
        self.segments = segment_boxes  # Store the bounding boxes for each segment
        self.base_values = base_values  # Store base values associated with each segment
        self.baseline_counts = []  # List to store baseline counts for background subtraction

    # Function to analyze a given frame, detect segments, and flag the highest segment
    def digest(self, number, segment_threshold, debug_file, frame_number, binary_threshold, write_to_debug=True, calculate_baseline=False):
        self.flags = []  # Reset flags for each new frame
        h, w = number.shape[:2]  # Get image dimensions (height and width)
        all_segments = []  # List to store all segment details
        segment_counts = []  # List to store non-zero pixel counts within each segment

        if write_to_debug:
            debug_file.write(f"\nFrame Number: {frame_number}\n")
            debug_file.write(f"Binary Threshold: {binary_threshold}\n")

        for a in range(len(self.segments)):  # Loop through each segment
            seg = self.segments[a]  # Get the current segment
            rect = np.array([[seg[0], seg[4]], [seg[1], seg[5]], [seg[2], seg[6]], [seg[3], seg[7]]])  # Define the rectangle coordinates
            rect[:, 0] = (rect[:, 0] * w).astype(int)  # Scale x-coordinates to image size
            rect[:, 1] = (rect[:, 1] * h).astype(int)  # Scale y-coordinates to image size
            mask = np.zeros_like(number)  # Create a mask of the same size as the image
            cv2.fillPoly(mask, [rect.astype(int)], 255)  # Fill the mask with the segment shape
            count = np.count_nonzero(mask & number)  # Count non-zero pixels in the segment
            segment_counts.append((count, a))  # Append count and segment index to the list
            all_segments.append((rect, a, count, np.count_nonzero(mask)))  # Store segment details

        if write_to_debug:
            debug_file.write("Raw Counts:\n")
            for rect, a, count, area in all_segments:
                debug_file.write(f"Segment {a}: Count = {count}\n")

        if calculate_baseline:  # If calculating baseline counts
            self.baseline_counts = segment_counts.copy()  # Save the current counts as baseline

        if segment_counts:  # If there are segments
            highest_segment = max(segment_counts, key=lambda x: x[0])  # Find the segment with the highest count
            self.flags.append(highest_segment[1])  # Flag the highest segment

        # Loop through all segments and draw them on the image, highlighting flagged segments
        for rect, a, count, area in all_segments:
            color = (0, 0, 255) if a in self.flags else (255, 255, 255)  # Set color based on whether the segment is flagged
            cv2.polylines(number, [rect.astype(int)], isClosed=True, color=color, thickness=2)  # Draw the segment on the image
            value = self.base_values.get(a, 0)  # Get the base value for the segment
            if a % 5 == 0:  # Display the value for every 5th segment
                cv2.putText(number, f"{value:.2f}", (int(rect[0][0] + 50), int(rect[0][1])), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)  # Draw the value
            if write_to_debug:
                debug_file.write(f"Segment {a}: Coordinates = {rect.tolist()}, Count = {count}, Area = {area}, Ratio = {count / area:.2f}\n")

        if write_to_debug:  # Write final segment details to debug file
            final_segment = self.flags[0] if self.flags else None
            detected_number = self.get_num()  # Get the detected number
            debug_file.write(f"Flagged Segment: {final_segment}, Detected Number: {detected_number}\n")

    # Function to get the detected number based on flagged segment
    def get_num(self):
        value = 0
        if self.flags:
            value = self.base_values.get(self.flags[0], 0)
        return value

# Function to remove red color from a frame
def remove_red_color(frame):
    hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)  # Convert the frame from BGR to HSV color space

    # Define the HSV range for purple to orange
    # Part 1: Purple to red (130 to 180)
    lower_purple_to_red = np.array([130, 50, 50])  # Lower bound for purple to red
    upper_purple_to_red = np.array([180, 255, 255])  # Upper bound for purple to red

    # Part 2: Red to orange (0 to 25)
    lower_red_to_orange = np.array([0, 50, 50])  # Lower bound for red to orange
    upper_red_to_orange = np.array([25, 255, 255])  # Upper bound for red to orange

    # Create masks for both parts
    mask_purple_to_red = cv2.inRange(hsv_frame, lower_purple_to_red, upper_purple_to_red)
    mask_red_to_orange = cv2.inRange(hsv_frame, lower_red_to_orange, upper_red_to_orange)

    # Combine both masks into a single mask
    mask_combined = mask_purple_to_red | mask_red_to_orange

    # Create a white image the same size as the original frame
    white_frame = np.full_like(frame, 255)

    # Replace the detected colors with white
    frame_without_colors = cv2.bitwise_and(frame, frame, mask=cv2.bitwise_not(mask_combined))
    frame_with_white = cv2.add(frame_without_colors, cv2.bitwise_and(white_frame, white_frame, mask=mask_combined))

    return frame_with_white

# Function to resize an image while maintaining aspect ratio
def resize_image(image, width=None, height=None):
    (h, w) = image.shape[:2]  # Get image dimensions (height and width)
    if width is None and height is None:
        return image  # Return original image if no dimensions are provided
    if width is None:
        ratio = height / float(h)  # Calculate the ratio for the given height
        dim = (int(w * ratio), height)  # Calculate new dimensions
    else:
        ratio = width / float(w)  # Calculate the ratio for the given width
        dim = (width, int(h * ratio))  # Calculate new dimensions
    return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)  # Resize the image with calculated dimensions

# Function to handle mouse clicks on an image
def get_mouse_click(image, window_name):
    points = []

    def click_event(event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            points.append((x, y))
            cv2.circle(image, (x, y), 3, (255, 0, 0), -1)
            cv2.imshow(window_name, image)
            cv2.destroyAllWindows()  # Close the window after the first click

    while True:
        resized_image = resize_image(image, width=window_size)
        resize_ratio_x = image.shape[1] / resized_image.shape[1]
        resize_ratio_y = image.shape[0] / resized_image.shape[0]

        cv2.imshow(window_name, resized_image)
        cv2.setMouseCallback(window_name, click_event)
        cv2.waitKey(0)

        if points:
            # Convert the coordinates back to the original image size
            original_point = (
                int(points[0][0] * resize_ratio_x),
                int(points[0][1] * resize_ratio_y)
            )
            return [original_point]  # Return a list containing the tuple

        retry = input("No point selected. Would you like to select the point again? (y/n): ").strip().lower()
        if retry != 'y':
            cv2.destroyAllWindows()
            return None

# Function to select the center point of a dial
def get_center_point(image, dial_name):
    window_name = f'Select Center - Dial {dial_name}'
    print(f"Click to select the center of the dial {dial_name}.")
    while True:
        resized_image = resize_image(image, width=window_size)
        points = get_mouse_click(resized_image, window_name)
        if points:
            center = (int(points[0][0] * image.shape[1] / resized_image.shape[1]), int(points[0][1] * image.shape[0] / resized_image.shape[0]))
            print(f"Center selected: {center}")
            return center

        retry = input("No point selected. Would you like to select the point again? (y/n): ").strip().lower()
        if retry != 'y':
            return None

# Function to select the radius of a dial
def get_radius(image, center, dial_name):
    window_name = f'Select Radius - Dial {dial_name}'
    print(f"Click to select a point on the edge of the dial {dial_name} to define the radius.")

    while True:
        # Get the edge point from the user click
        edge_point = get_mouse_click(image.copy(), window_name)
        if edge_point:
            # Extract the actual point from the returned list
            edge_point = edge_point[0]  # Unwrap the first point from the list
            
            # Calculate the radius based on the center and edge point
            radius = int(np.sqrt((edge_point[0] - center[0])**2 + (edge_point[1] - center[1])**2))
            print(f"Radius selected: {radius}")
            return radius

        retry = input("No point selected. Would you like to select the point again? (y/n): ").strip().lower()
        if retry != 'y':
            return None

# Function to confirm the selected circle
def confirm_circle(image, center, radius, dial_name):
    while True:
        preview_image = image.copy()
        cv2.circle(preview_image, center, 5, (255, 0, 0), -1)
        cv2.circle(preview_image, center, radius, (255, 0, 0), 2)
        resized_preview_image = resize_image(preview_image, width=window_size)
        cv2.imshow(f'Confirm Circle - Dial {dial_name}', resized_preview_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

        if input(f"Is the circle correct for dial {dial_name}? (y/n): ").strip().lower() != 'n':
            return True
        else:
            return False

# Function to get angle from a click on an image
def get_angle_from_click(image, center, window_name="Select Point"):
    # Display the image and get the click point from the user
    print("Click on the point on the dial to select an angle.")
    
    # Resize the image for display
    resized_image = resize_image(image, width=window_size)
    scale_x = image.shape[1] / resized_image.shape[1]
    scale_y = image.shape[0] / resized_image.shape[0]

    # Get the clicked point
    points = get_mouse_click(resized_image, window_name)
    
    if not points:
        print("No point selected. Please try again.")
        return None
    
    # Scale the point back to the original image size
    point = (int(points[0][0] * scale_x), int(points[0][1] * scale_y))
    
    # Calculate the angle based on the center and the point
    angle = np.degrees(np.arctan2(center[1] - point[1], point[0] - center[0]))
    angle = (angle + 360) % 360
    print(f"Point selected: {point}, Angle: {angle:.2f}")
    return angle

# Function to rotate a box around a center point
def rotate_box(center, angle, box_size, image_shape):
    angle_rad = np.deg2rad(angle)
    rect = np.array([
        [-box_size // 2, -box_size // 2],
        [box_size // 2, -box_size // 2],
        [box_size // 2, box_size // 2],
        [-box_size // 2, box_size // 2]
    ])
    rotation_matrix = np.array([
        [np.cos(angle_rad), -np.sin(angle_rad)],
        [np.sin(angle_rad), np.cos(angle_rad)]
    ])
    rotated_rect = np.dot(rect, rotation_matrix) + center
    rotated_rect[:, 0] = np.clip(rotated_rect[:, 0], 0, image_shape[1] - 1)
    rotated_rect[:, 1] = np.clip(rotated_rect[:, 1], 0, image_shape[0] - 1)
    return rotated_rect

# Function to generate segment boxes
def get_segment_boxes(image, center, radius_x, radius_y, origin_angle, max_angle, value_at_origin, value_at_max, clockwise=True):
    # Normalize angles for the calculation
    origin_angle = origin_angle % 360
    max_angle = max_angle % 360
    
    # Determine angle range based on direction (clockwise or counterclockwise)
    if clockwise:
        if max_angle <= origin_angle:
            max_angle += 360
        angles = np.arange(origin_angle, max_angle, step_angle)
    else:
        if origin_angle <= max_angle:
            origin_angle += 360
        angles = np.arange(origin_angle, max_angle, -step_angle)

    segment_boxes = []
    box_size = int((min(radius_x, radius_y) * 2 * np.pi) / (360 / step_angle))

    # Iterate through each angle to create the segment boxes
    for angle in angles:
        angle_rad = np.deg2rad(angle % 360)
        x = center[0] + int(radius_x * np.cos(angle_rad))
        y = center[1] - int(radius_y * np.sin(angle_rad))
        rotated_rect = rotate_box([x, y], angle + 90, box_size, image.shape)
        segment_boxes.append([
            rotated_rect[0, 0] / image.shape[1], rotated_rect[1, 0] / image.shape[1], 
            rotated_rect[2, 0] / image.shape[1], rotated_rect[3, 0] / image.shape[1],
            rotated_rect[0, 1] / image.shape[0], rotated_rect[1, 1] / image.shape[0], 
            rotated_rect[2, 1] / image.shape[0], rotated_rect[3, 1] / image.shape[0]
        ])
        cv2.polylines(image, [rotated_rect.astype(int)], isClosed=True, color=(255, 255, 255), thickness=3)

    # Mapping values to segments based on angles
    base_values = map_values_to_segments(segment_boxes, origin_angle, max_angle, value_at_origin, value_at_max, clockwise)
    
    # Annotate the image with segment values
    for idx, segment in enumerate(segment_boxes):
        if idx % 5 == 0:
            rect = np.array([
                [segment[0] * image.shape[1], segment[4] * image.shape[0]],
                [segment[1] * image.shape[1], segment[5] * image.shape[0]],
                [segment[2] * image.shape[1], segment[6] * image.shape[0]],
                [segment[3] * image.shape[1], segment[7] * image.shape[0]],
            ]).astype(int)
            value = base_values.get(idx, 'N/A')
            center_x = int((rect[0, 0] + rect[2, 0]) / 2)
            center_y = int((rect[0, 1] + rect[2, 1]) / 2)
            cv2.putText(image, f"{value:.2f}", (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)

    resized_image = resize_image(image, width=window_size)
    cv2.imshow('Segment Boxes', resized_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    return segment_boxes

# Function to map values to segments
def map_values_to_segments(segment_boxes, origin_angle, max_angle, value_at_origin, value_at_max, clockwise=True):
    base_values = {}
    num_segments = len(segment_boxes)
    
    # Calculate total angle span
    if clockwise:
        angle_span = (max_angle - origin_angle) % 360
    else:
        angle_span = (origin_angle - max_angle) % 360
    
    value_step = (value_at_max - value_at_origin) / (num_segments - 1) if num_segments > 1 else 0

    current_value = value_at_origin
    for j in range(num_segments):
        base_values[j] = current_value
        current_value += value_step

    return base_values

# Function to save configuration to a JSON file
def save_config(segment_info_dict, config_path):
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    dict_to_save = {}
    for dial_letter, segment_info in segment_info_dict.items():
        (center, radius, radius_x, radius_y, origin_angle, max_angle, clockwise, segment_boxes, base_values) = segment_info
        dict_to_save[dial_letter] = {
            'center': center,
            'radius': radius,
            'radius_x': radius_x,
            'radius_y': radius_y,
            'origin_angle': origin_angle,
            'max_angle': max_angle,
            'clockwise': clockwise,
            'segment_boxes': segment_boxes,
            'base_values': base_values,
        }

    with open(config_path, 'w') as config_file:
        json.dump(dict_to_save, config_file, indent=4)
    print(f"Configuration saved to {config_path}")

# Function to load configuration from a JSON file
def load_config(config_path, video_paths):
    config_path = normalize_path(config_path)
    try:
        with open(config_path, 'r') as config_file:
            segment_info_dict = json.load(config_file)
        print("Loaded configuration data successfully.")
    except Exception as e:
        print(f"Failed to load or parse the configuration file: {e}")
        return None

    if not video_paths:
        print("No video paths provided.")
        return None

    video_path = normalize_path(video_paths[0])
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Failed to open video file {video_path}")
        return None

    ret, frame = cap.read()
    cap.release()

    if not ret:
        print("Failed to read the first frame from the video.")
        return None

    for dial_letter, segment_info in segment_info_dict.items():
        print(f"Processing dial {dial_letter} with segment info: {segment_info}")
        if isinstance(segment_info, dict):
            try:
                center = tuple(int(x) for x in segment_info['center'])
                radius = int(segment_info['radius'])
                radius_x = int(segment_info['radius_x'])
                radius_y = int(segment_info['radius_y'])
                origin_angle = int(segment_info['origin_angle'])
                max_angle = int(segment_info['max_angle'])
                clockwise = segment_info['clockwise']
                segment_boxes = [list(map(float, box)) for box in segment_info['segment_boxes']]

                base_values = {int(k): round(float(v), 2) for k, v in segment_info['base_values'].items()}

                segment_info_dict[dial_letter] = (center, radius, radius_x, radius_y, origin_angle, max_angle, clockwise, segment_boxes, base_values)

                for idx, segment in enumerate(segment_boxes):
                    rect = np.array([
                        [segment[0] * frame.shape[1], segment[4] * frame.shape[0]],
                        [segment[1] * frame.shape[1], segment[5] * frame.shape[0]],
                        [segment[2] * frame.shape[1], segment[6] * frame.shape[0]],
                        [segment[3] * frame.shape[1], segment[7] * frame.shape[0]],
                    ]).astype(int)
                    cv2.polylines(frame, [rect], isClosed=True, color=(255, 255, 255), thickness=2)
                    if idx % 5 == 0:
                        value = base_values.get(idx, 'N/A')
                        cv2.putText(frame, f"{value:.2f}", (rect[0, 0] + 10, rect[0, 1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

                resized_frame = resize_image(frame, width=window_size)
                cv2.imshow(f'Segment Preview for Dial {dial_letter}', resized_frame)
                cv2.waitKey(0)
                cv2.destroyAllWindows()

            except KeyError as e:
                print(f"Missing key in segment info for {dial_letter}: {e}")
            except TypeError as e:
                print(f"Type error in segment info for {dial_letter}: {e}")
        else:
            print(f"Error: Expected segment info for {dial_letter} to be a dictionary but got {type(segment_info)}")

    if input("Is the loaded config data correct? (y/n): ").strip().lower() != 'y':
        print("Exiting due to incorrect config data.")
        return None

    return segment_info_dict

# Function to normalize file paths to ensure compatibility across different operating systems
def normalize_path(path):
    return os.path.normpath(path.strip('"'))

def get_replicate_info(dial_letters):
    replicate_info = {}
    has_replicates = input("Are there any replicates? (y/n): ").strip().lower() == 'y'

    if not has_replicates:
        return replicate_info

    num_groups = get_int_input("How many groups of replicates are there? ")

    for group_num in range(1, num_groups + 1):
        num_replicates = get_int_input(f"How many replicates are in group {group_num}? ")
        while True:
            dials = input(f"Enter the dials for group {group_num} separated by spaces (e.g., A B C): ").strip().split()
            dials = [dial.strip().upper() for dial in dials]

            if all(dial in dial_letters for dial in dials) and len(dials) == num_replicates:
                replicate_info[f'Group {group_num}'] = dials
                break
            else:
                print("Invalid input. Please ensure all dials are listed and the number matches the replicate count.")

    return replicate_info

def process_multiple_videos(video_paths, output_dir_base, frame_skip, segment_threshold, frame_start, frame_end):
    output_dir_base = normalize_path(output_dir_base)  # Normalize the base output directory path

    combined_data = []  # List to store combined data from all videos

    segment_info_dict = {}

    # Ask the user how many dials are being used
    num_dials = get_int_input("How many dials are being used? (1, 2, 3, ...): ")
    dial_letters = [chr(65 + i) for i in range(num_dials)]  # Generate dial letters ['A', 'B', 'C', ...]

    # Ask the user for the configuration file path
    config_path = input("Enter the path of the config file (leave blank if not using a previous configuration): ").strip()
    use_existing_config = bool(config_path)

    if use_existing_config:
        segment_info_dict = load_config(config_path, video_paths)
        if segment_info_dict is None:
            return
    else:
        # Manually set up the dials if not using an existing configuration
        cap = cv2.VideoCapture(video_paths[0])
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start or 0)
        ret, unprocessed_frame = cap.read()
        cap.release()

        if not ret:
            print("Failed to read the first frame for segment configuration.")
            return

        # Use the unprocessed first frame for configuration
        for dial_letter in dial_letters:
            print(f"\nSetting up for Dial {dial_letter}")

            while True:
                # Select the center of the dial
                center = get_center_point(unprocessed_frame.copy(), dial_letter)
                if center is None:
                    print("Failed to select center. Exiting.")
                    return
                # Select the radius of the dial
                radius = get_radius(unprocessed_frame.copy(), center, dial_letter)
                if radius is None:
                    print("Failed to select radius. Exiting.")
                    return
                if confirm_circle(unprocessed_frame.copy(), center, radius, dial_letter):
                    break

            # Select the origin and maximum points on the dial
            print(f"Click on the point representing the minimum value for dial {dial_letter}.")
            origin_angle = get_angle_from_click(unprocessed_frame.copy(), center, window_name=f"Select Minimum Point for Dial {dial_letter}")
            if origin_angle is None:
                print("Failed to select minimum point. Exiting.")
                return
            print(f"Selected origin angle for dial {dial_letter}: {origin_angle:.2f} degrees")  # Debugging statement
            value_at_origin = get_float_input("Enter the value at the origin: ")

            print(f"Click on the point representing the maximum value for dial {dial_letter}.")
            max_angle = get_angle_from_click(unprocessed_frame.copy(), center, window_name=f"Select Maximum Point for Dial {dial_letter}")
            if max_angle is None:
                print("Failed to select maximum value point. Exiting.")
                return
            print(f"Selected maximum angle for dial {dial_letter}: {max_angle:.2f} degrees")  # Debugging statement
            value_at_max = get_float_input("Enter the value at the maximum position: ")

            # Assume clockwise direction
            clockwise = True

            # Generate segment boxes and map values to each segment
            segment_boxes = get_segment_boxes(unprocessed_frame.copy(), center, radius, radius, origin_angle, max_angle, value_at_origin, value_at_max, clockwise)
            base_values = map_values_to_segments(segment_boxes, origin_angle, max_angle, value_at_origin, value_at_max, clockwise)

            # Store the segment information in the dictionary
            segment_info_dict[dial_letter] = (center, radius, radius, radius, origin_angle, max_angle, clockwise, segment_boxes, base_values)

        # Save the segment configuration to a JSON file
        save_config(segment_info_dict, os.path.join(output_dir_base, "dial_segment_config.json"))

    # Ask for replicate information
    replicate_info = get_replicate_info(dial_letters)

    # Process each video and save the results
    combined_time = 1  # Initialize the combined time for the entire dataset
    for idx, video_path in enumerate(sorted(video_paths)):  # Ensure files are processed in alphanumeric order
        video_path = normalize_path(video_path)
        base_filename = os.path.splitext(os.path.basename(video_path))[0]
        output_excel = os.path.join(output_dir_base, f"output_data_{base_filename}.xlsx")
        output_dir = os.path.join(output_dir_base, f"processed_frames_{base_filename}")

        # Process the video and get the data
        video_data = process_video(
            video_path,
            output_excel,
            output_dir,
            frame_skip,
            segment_threshold,
            frame_start,
            frame_end,
            segment_info_dict,
            idx,
            len(video_paths)
        )

        # Combine data from all dials for the same frame
        frame_data = {}
        for dial_letter in dial_letters:
            for entry in video_data[dial_letter]:
                frame_number = entry[0]
                if frame_number not in frame_data:
                    frame_data[frame_number] = {
                        "individual_time_s": frame_number,
                        "video_file": base_filename,
                    }
                # Add the data for the specific dial to the current frame
                frame_data[frame_number][f"dial_{dial_letter}_primary"] = entry[1]

        # Append the combined frame data to the list
        combined_data.extend(frame_data.values())

    # Convert to DataFrame
    combined_df = pd.DataFrame(combined_data)

    # Calculate mean and standard deviation for each group of replicates
    for group_name, dials in replicate_info.items():
        dial_columns = [f'dial_{dial}_primary' for dial in dials]
        combined_df[f'{group_name}_mean'] = combined_df[dial_columns].mean(axis=1)
        combined_df[f'{group_name}_std'] = combined_df[dial_columns].std(axis=1)

    # Sort combined data by video file name and then by individual time
    combined_df = combined_df.sort_values(by=['video_file', 'individual_time_s']).reset_index(drop=True)

    # Add combined_time_s, combined_time_m, and combined_time_h
    combined_df['combined_time_s'] = range(1, 1 + frame_skip * len(combined_df), frame_skip)
    combined_df['combined_time_m'] = combined_df['combined_time_s'] / 60
    combined_df['combined_time_h'] = combined_df['combined_time_m'] / 60

    # Reorder columns for better readability
    columns_order = ['combined_time_s', 'combined_time_m', 'combined_time_h', 'individual_time_s', 'video_file']
    for dial_letter in dial_letters:
        columns_order.extend([f'dial_{dial_letter}_primary'])
    for group_name in replicate_info.keys():
        columns_order.extend([f'{group_name}_mean', f'{group_name}_std'])
    combined_df = combined_df[columns_order]

    # Save combined data to Excel
    combined_output_path = os.path.join(output_dir_base, f"{os.path.basename(output_dir_base)}_combined_output.xlsx")
    with ExcelWriter(combined_output_path) as writer:
        combined_df.to_excel(writer, sheet_name='Combined Data', index=False)
        print(f"Combined data saved to {combined_output_path}")

    # Plot combined data for all dials individually
    plot_combined_data(combined_df, output_dir_base)

    # Plot the mean data for each replicate with 1 SD error bars
    plot_replicate_means_with_error(combined_df, replicate_info, output_dir_base)


def process_video(video_path, output_excel, output_dir, frame_skip, segment_threshold, frame_start, frame_end, segment_info_dict, idx=0, total_files=1):
    video_path = normalize_path(video_path.strip('"'))  # Normalize the video path to ensure compatibility across different operating systems
    output_excel = normalize_path(output_excel.strip('"'))  # Normalize the output Excel file path
    output_dir = normalize_path(output_dir.strip('"'))  # Normalize the output directory path

    cap = cv2.VideoCapture(video_path)  # Open the video file using OpenCV
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # Get the total number of frames in the video
    video_data = {dial_letter: [] for dial_letter in segment_info_dict.keys()}  # Initialize a dictionary to store data for each dial
    summary_data = []  # List to store summary data

    if not os.path.exists(output_dir):  # Check if the output directory exists
        os.makedirs(output_dir)  # Create the output directory if it doesn't exist

    # Define the debug file path for writing debug information
    debug_file_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(video_path))[0]}_debug.txt")
    
    with open(debug_file_path, 'w') as debug_file:  # Open the debug file for writing
        if frame_start is None:
            frame_start = 0  # If no start frame is specified, start from the first frame
        if frame_end is None:
            frame_end = frame_count  # If no end frame is specified, process until the last frame

        frame_start = int(frame_start)
        frame_end = int(frame_end)

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start)  # Set the starting frame position
        ret, frame = cap.read()  # Read the first frame
        if not ret:  # If the frame cannot be read, exit the function
            print("Failed to read the first frame.")
            cap.release()
            return None, []
        
        # Process frame 0
        if frame_start == 0:
            processed_frame = remove_red_color(frame)  # Remove purple to orange colors
            # Save the processed frame before any binary threshold
            pre_threshold_image_path = os.path.join(output_dir, "frame_0000_pre_threshold.png")
            cv2.imwrite(pre_threshold_image_path, processed_frame)  # Save the image
            print(f"Frame 0 after color removal saved to {pre_threshold_image_path}")


        total_frames_to_process = (frame_end - frame_start) // frame_skip + 1  # Calculate the total number of frames to process

        starting_threshold = 255  # Initialize the starting threshold for the first frame

        # Loop through frames from start to end with specified skip
        for i in range(frame_start, frame_end, frame_skip):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)  # Set the frame position to the current frame
            ret, frame = cap.read()  # Read the current frame

            if not ret:  # If the frame cannot be read, skip to the next iteration
                print(f"Failed to read frame {i}.")
                continue

            frame = remove_red_color(frame)  # Remove red color from the frame before processing
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  # Convert the frame to grayscale

            # Process each dial
            for dial_letter, segment_info in segment_info_dict.items():
                center, radius, radius_x, radius_y, origin_angle, max_angle, clockwise, segment_boxes, base_values = segment_info
                model_primary = Segments(segment_boxes, base_values)  # Initialize the Segments model

                current_threshold = starting_threshold  # Start with the last threshold from the previous frame
                ratio_on_off = 0  # Initialize the on/off pixel ratio
                previous_threshold = None  # Initialize the previous threshold

                best_threshold = None  # Initialize the best threshold
                closest_ratio = None  # Initialize the closest ratio

                recent_thresholds = []  # List to store recent thresholds for oscillation detection

                # Threshold adjustment loop to find the best threshold for segment detection
                while True:
                    try:
                        # Apply binary thresholding to the grayscale frame
                        _, binary = cv2.threshold(gray, current_threshold, 255, cv2.THRESH_BINARY_INV)
                        mask = np.zeros_like(gray)  # Create a mask of the same size as the grayscale image
                        cv2.circle(mask, center, radius, 255, -1)  # Draw a circle on the mask

                        # Count pixels that are 'on' and 'off' within the mask
                        on_pixels = np.sum(binary[mask == 255] == 255)
                        off_pixels = np.sum(binary[mask == 255] == 0)
                        ratio_on_off = on_pixels / off_pixels if off_pixels > 0 else np.nan  # Calculate the on/off ratio

                    except Exception as e:  # Handle exceptions
                        print(f"Error calculating ratio: {e}")
                        ratio_on_off = np.nan

                    # Update the best threshold if this ratio is valid and closer to the target range
                    if not np.isnan(ratio_on_off):
                        if closest_ratio is None or abs(ratio_on_off - 0.015) < abs(closest_ratio - 0.015):
                            best_threshold = current_threshold
                            closest_ratio = ratio_on_off

                        recent_thresholds.append(current_threshold)  # Add the current threshold to the recent thresholds list
                        if len(recent_thresholds) > 5:
                            recent_thresholds.pop(0)  # Keep the last 5 thresholds for oscillation detection

                        # Check for threshold oscillation and pick the best threshold if detected
                        if len(recent_thresholds) > 3:
                            last_three_thresholds = recent_thresholds[-3:]
                            if len(set(last_three_thresholds)) == 1:
                                print(f"Oscillation detected. Picking the best threshold: {best_threshold} with closest ratio: {closest_ratio}")
                                current_threshold = best_threshold
                                break

                    # Adjust the threshold based on the ratio
                    if np.isnan(ratio_on_off) or ratio_on_off > 0.13:
                        current_threshold = math.floor(current_threshold * 0.995)  # Decrease threshold slightly
                    elif ratio_on_off < 0.1:
                        current_threshold = math.ceil(current_threshold * 1.01)  # Increase threshold slightly
                    else:
                        break  # Exit the loop if the ratio is within an acceptable range

                    previous_threshold = current_threshold  # Update the previous threshold

                    # If no valid ratio was found, stop adjusting and use the last valid best_threshold
                    if best_threshold is None:
                        best_threshold = starting_threshold  # Fallback to the last known good threshold or initial threshold

                # Handle cases where no valid ratio was found
                if closest_ratio is None:
                    closest_ratio = "N/A"  # If no valid ratio was found, set to "N/A"
                    print(f"No valid ratio found, using fallback threshold: {best_threshold}")

                binary_copy = binary.copy()  # Create a copy of the binary image
                # Analyze the binary image to detect segments
                model_primary.digest(binary_copy, segment_threshold, debug_file, i, current_threshold, write_to_debug=True, calculate_baseline=(i == 0))
                number_primary = model_primary.get_num()  # Get the detected number from the model

                # Annotate the binary image with detected information
                text_ratio = f"Ratio: {closest_ratio:.4f}"
                text_threshold = f"Threshold: {current_threshold}"
                font_scale = 3.0
                thickness = 6
                text_color = (255, 255, 255)

                text_position_ratio = (binary_copy.shape[1] - 800, 110)
                text_position_threshold = (binary_copy.shape[1] - 800, 220)

                # Draw text annotations on the binary image
                cv2.putText(binary_copy, text_ratio, text_position_ratio, cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, thickness)
                cv2.putText(binary_copy, text_threshold, text_position_threshold, cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, thickness)
                cv2.putText(binary_copy, f"{number_primary:.2f}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, thickness)

                # Conditionally save the annotated frame as a PNG file based on user input
                if save_images:
                    frame_output_path = os.path.join(output_dir, f"frame_{i:04d}_{dial_letter}.png")
                    cv2.imwrite(frame_output_path, binary_copy)  # Save the annotated frame as a PNG file
                    #print(f"Processed frame saved to {frame_output_path}")

                # Append frame data to video data
                video_data[dial_letter].append([
                    i,  # Frame Number
                    number_primary,  # Detected Number
                    current_threshold,  # Binary Threshold
                    closest_ratio  # On:Off Ratio
                ])

                print(f"Frame {i} | Dial {dial_letter} | Detected Number: {number_primary} (Threshold: {current_threshold})")

                # Store the final threshold used in this iteration to start the next frame
                starting_threshold = current_threshold

            # Calculate the number of processed frames
            frames_processed = (i - frame_start) // frame_skip
            # Calculate the percentage of completion
            percent_complete = (frames_processed / total_frames_to_process) * 100
            print(f"Now processing {os.path.basename(video_path)}, {idx + 1} of {total_files}. Frame {i} of {frame_count}, {percent_complete:.2f}% complete.")

    cap.release()  # Release the video capture object
    cv2.destroyAllWindows()  # Close all OpenCV windows

    # Save results to Excel
    with ExcelWriter(output_excel) as writer:
        for dial_letter in segment_info_dict.keys():
            # Specify only the required columns
            columns = ['Frame Number', 'Detected Number', 'Binary Threshold', 'On:Off Ratio']
            df = pd.DataFrame(video_data[dial_letter], columns=columns)  # Create a DataFrame from video data

            # Write DataFrame to Excel
            df.to_excel(writer, sheet_name=f'Dial {dial_letter}', index=False)
            print(f"Results saved to {output_excel} in sheet 'Dial {dial_letter}'")

            # Plotting all individual dials for the video on one plot
            fig_individual = go.Figure()
            for dial_letter in segment_info_dict.keys():
                fig_individual.add_trace(go.Scatter(
                    x=df['Frame Number'],
                    y=df['Detected Number'],
                    mode='lines+markers',
                    name=f'Dial {dial_letter}'
                ))

            fig_individual.update_layout(
                title=f'Individual Detected Numbers for {os.path.basename(video_path)} - All Dials',
                xaxis_title='Frame Number',
                yaxis_title='Detected Number',
                legend_title='Dials',
                template='plotly_white'
            )

            plot_output_path_individual = os.path.join(output_dir, f"output_plot_individual_{os.path.basename(video_path)}.html")
            fig_individual.write_html(plot_output_path_individual)  # Save plot to HTML file
            print(f"Individual plot saved to {plot_output_path_individual}")

        # Create a summary DataFrame for any additional summary data (if needed)
        summary_df = pd.DataFrame(summary_data)
        summary_df.to_excel(writer, sheet_name='Summary', index=False)  # Write summary data to Excel
        print(f"Summary sheet saved to {output_excel}")

    return video_data

def plot_combined_data(combined_df, output_dir_base):
    """
    Function to plot combined data for all dials individually.
    """
    # Plotting the combined data using Plotly
    fig_combined = go.Figure()

    # Loop over each column that matches the dial data pattern
    for dial_letter in combined_df.columns:
        # Check if the column name starts with 'dial_' and ends with '_primary'
        if dial_letter.startswith('dial_') and dial_letter.endswith('_primary'):
            fig_combined.add_trace(go.Scatter(
                x=combined_df['combined_time_s'],  # X-axis is the combined time in seconds
                y=combined_df[dial_letter],  # Y-axis is the data for this dial
                mode='lines+markers',  # Plot both lines and markers
                name=dial_letter  # Label for this trace
            ))

    # Update the layout of the plot to include titles and labels
    fig_combined.update_layout(
        title='Combined Detected Numbers - All Dials',  # Title of the plot
        xaxis_title='Combined Time (s)',  # X-axis label
        yaxis_title='Detected Number',  # Y-axis label
        legend_title='Legend',  # Legend title
        template='plotly_white'  # Use a clean white template for the plot
    )

    # Show the combined plot
    fig_combined.show()

    # Define the path where the combined plot will be saved as an HTML file
    combined_plot_path = os.path.join(output_dir_base, "combined_plot_all_dials.html")
    # Save the plot to an HTML file
    fig_combined.write_html(combined_plot_path)
    print(f"Combined plot saved to {combined_plot_path}")


def plot_replicate_means_with_error(combined_df, replicate_info, output_dir_base):
    """
    Function to plot replicate means with 1 SD error bars.
    """
    # Create a figure object for the plot
    fig = go.Figure()

    # Loop over each group of replicates
    for group_name in replicate_info.keys():
        fig.add_trace(go.Scatter(
            x=combined_df['combined_time_h'],  # X-axis is the combined time in hours
            y=combined_df[f'{group_name}_mean'],  # Y-axis is the mean for this group
            mode='lines+markers',  # Plot both lines and markers
            name=f'{group_name} Mean',  # Label for this trace
            error_y=dict(
                type='data',  # Error type is data-based
                array=combined_df[f'{group_name}_std'],  # Error bars are the standard deviation for this group
                visible=True  # Show error bars
            )
        ))

    # Update the layout of the plot to include titles and labels
    fig.update_layout(
        title='Combined Detected Numbers - All Dials with Replicate Means and Error Bars',  # Title of the plot
        xaxis_title='Time (Hours)',  # X-axis label
        yaxis_title='Vacuum drawn (inHg)',  # Y-axis label
        legend_title='Legend',  # Legend title
        template='plotly_white'  # Use a clean white template for the plot
    )

    # Show the plot
    fig.show()

    # Define the path where the plot will be saved as an HTML file
    plot_output_path = os.path.join(output_dir_base, "combined_plot_all_dials_with_replicates.html")
    # Save the plot to an HTML file
    fig.write_html(plot_output_path)
    print(f"Combined plot saved to {plot_output_path}")

# Inputs
video_paths = input("Enter video paths separated by spaces: ").split()
video_paths = [os.path.normpath(path.replace('"', '')) for path in video_paths]
output_dir_base = os.path.normpath(input("Enter the base output directory: ").replace('"', ''))

# Add input to ask the user if they want to save images
save_images = input("Do you want to save processed frame images? (y/n): ").strip().lower() == 'y'

frame_skip = 60
segment_threshold = 0.2
window_size = 400
step_angle = 2  # Updated to draw segment boxes every x degrees

frame_start_input = input("Enter the starting frame number (leave blank to start from the first frame): ")
frame_start = int(frame_start_input) if frame_start_input else None
frame_end_input = input("Enter the ending frame number (leave blank to process until the end): ")
frame_end = int(frame_end_input) if frame_end_input else None

print("Video Paths:", video_paths)
print("Output Directory Base:", output_dir_base)
print("Frame Skip:", frame_skip)
print("Segment Threshold:", segment_threshold)
print("Window Size:", window_size)
print("Step Angle:", step_angle)
print("Frame Start:", frame_start)
print("Frame End:", frame_end)
print()

process_multiple_videos(video_paths, output_dir_base, frame_skip, segment_threshold, frame_start, frame_end)


Video Paths: []
Output Directory Base: .
Frame Skip: 60
Segment Threshold: 0.2
Window Size: 400
Step Angle: 2
Frame Start: None
Frame End: None



TypeError: 'NoneType' object cannot be interpreted as an integer