In [9]:
import xml.etree.ElementTree as ET
import pandas as pd
import tkinter as tk
from tkinter import filedialog
import os
import numpy as np
from IPython.display import display, clear_output
import ipywidgets as widgets
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.signal import find_peaks
from scipy.signal import find_peaks_cwt

## Process active-motion XML files

This is the initialization step where we define the starting folder and ask the user to confirm or change it.

### Obtain active motion trajectories

**Note:** if the Jupyter kernel cannot find your conda environment, see https://stackoverflow.com/questions/53004311/how-to-add-conda-environment-to-jupyter-lab

In [10]:
import os
import xml.etree.ElementTree as ET
import pandas as pd
import tkinter as tk
from tkinter import filedialog

def extract_trajectories_from_xml_to_df(xml_file_path, fps=30):
    """Parse one active-motion XML export into a tidy trajectory DataFrame.

    The file must contain an ``<UltrasoundBurstFrame>`` child of the root
    element and ``<Particles>/<Particle>`` entries, each with ``x``, ``y``,
    ``frame``, ``particle``, ``size``, ``displacement``,
    ``total_distance_traveled`` and ``speed`` sub-elements.

    Parameters
    ----------
    xml_file_path : str or path-like
        Path of the XML file to read.
    fps : float, optional
        Frame rate used to convert frame numbers into seconds (default 30,
        the value that was previously hard-coded).

    Returns
    -------
    pandas.DataFrame
        One row per ``<Particle>`` entry with columns ``frame, particle, x, y,
        size, displacement, ultrasound_burst_frame, total_distance_traveled,
        speed, time`` (``time = frame / fps``), sorted by frame then particle
        and indexed 0..n-1. A file with no ``<Particle>`` entries yields an
        empty frame with the same columns instead of raising ``KeyError``.

    Raises
    ------
    ValueError
        If a required element is missing or has no text.
    """
    def _text(element, tag):
        # Name the file and tag in the error instead of failing with an
        # opaque AttributeError on ``None.text``.
        child = element.find(tag)
        if child is None or child.text is None:
            raise ValueError(f"Missing <{tag}> element in {xml_file_path}")
        return child.text

    root = ET.parse(xml_file_path).getroot()

    # Integer fields may be written as floats (e.g. "90.0"), hence int(float(...)).
    ultrasound_burst_frame = int(float(_text(root, 'UltrasoundBurstFrame')))

    columns = ['particle', 'frame', 'x', 'y', 'size', 'displacement',
               'ultrasound_burst_frame', 'total_distance_traveled', 'speed']
    records = [
        {
            'particle': int(float(_text(particle, 'particle'))),
            'frame': int(float(_text(particle, 'frame'))),
            'x': float(_text(particle, 'x')),
            'y': float(_text(particle, 'y')),
            'size': float(_text(particle, 'size')),
            'displacement': float(_text(particle, 'displacement')),
            'ultrasound_burst_frame': ultrasound_burst_frame,
            'total_distance_traveled': float(_text(particle, 'total_distance_traveled')),
            'speed': float(_text(particle, 'speed')),
        }
        for particle in root.findall('Particles/Particle')
    ]

    # Passing explicit columns keeps the schema even when there are no records.
    df = pd.DataFrame(records, columns=columns)
    df = df.sort_values(by=['frame', 'particle'])

    # Put 'frame' first and renumber rows 0..n-1 (what the previous
    # set_index('frame') / reset_index() pair achieved).
    df = df[['frame'] + [c for c in columns if c != 'frame']].reset_index(drop=True)

    df['time'] = df['frame'] / fps
    return df
    
def process_xml_files_recursively(folder_path):
    """Load every ``*.xml`` file under ``folder_path`` into one DataFrame.

    File names are expected to look like
    ``<particle type>-<voltage>mVpp-<replicate>_<anything>.xml``. The first
    three dash-separated fields become the ``particle_type``,
    ``voltage_input`` ("mVpp" stripped) and ``replicate`` (text before the
    first "_") columns, all stored as strings. A file whose name has fewer
    than three dash-separated fields is skipped with a printed message
    instead of aborting the whole batch with an ``IndexError``.

    Folders and files are visited in sorted order, so the row order of the
    result is reproducible across runs and machines.

    Parameters
    ----------
    folder_path : str or path-like
        Root folder to search recursively.

    Returns
    -------
    pandas.DataFrame or None
        Concatenation of all per-file frames (fresh 0..n-1 index), or
        ``None`` if no usable XML file was found.
    """
    all_data = []

    for dirpath, dirnames, filenames in os.walk(folder_path):
        # Sorting dirnames in place also fixes the order os.walk descends.
        dirnames.sort()
        for filename in sorted(filenames):
            # Accept the extension in any case (".xml", ".XML", ...).
            if not filename.lower().endswith(".xml"):
                continue

            # Parse the metadata from the name before reading the file, so a
            # badly named file is skipped without parsing it.
            file_parts = os.path.splitext(filename)[0].split("-")
            if len(file_parts) < 3:
                print(f"Skipping {os.path.join(dirpath, filename)}: expected "
                      f"'<type>-<voltage>mVpp-<replicate>' in the file name.")
                continue
            particle_type = file_parts[0].strip()
            voltage_input = file_parts[1].replace("mVpp", "").strip()
            replicate = file_parts[2].split("_")[0].strip()

            df = extract_trajectories_from_xml_to_df(os.path.join(dirpath, filename))
            df['particle_type'] = particle_type
            df['voltage_input'] = voltage_input
            df['replicate'] = replicate
            all_data.append(df)

    if not all_data:
        return None
    return pd.concat(all_data, ignore_index=True)

# Folder where the directory picker opens first (adjust to your data drive).
start_path = r'D:\Particle tracking'

# Reset the result first, so a cancelled or failed re-run cannot leave stale
# data from an earlier execution of this cell in the kernel.
combined_df = None

# Hidden Tk root: only needed to parent the directory dialog; kept on top so
# the dialog does not open behind the browser window.
root = tk.Tk()
root.withdraw()  # Hide the main window
root.lift()
root.attributes('-topmost', True)

try:
    # askdirectory returns an empty value (it does not raise) when cancelled.
    xml_folder_path = filedialog.askdirectory(initialdir=start_path)
    if xml_folder_path:
        print(f"Selected folder path: {xml_folder_path}")
        combined_df = process_xml_files_recursively(xml_folder_path)
        if combined_df is not None:
            print("Combined DataFrame:")
            print(combined_df)
        else:
            print("No XML files found in the selected folder or its subfolders.")
    else:
        print("No folder selected.")
except Exception as e:
    print(f"An error occurred: {str(e)}")
finally:
    # Destroy the hidden root; otherwise every re-run leaks another Tk instance.
    root.destroy()


# Example usage
# filtered_active_traj_filt = extract_trajectories_from_xml_to_df(xml_file_path)

# If you need to scale x, y, or size by a certain factor, you can do so here as shown for x and y
# filtered_active_traj_filt['x'] = filtered_active_traj_filt['x'] * 1.3
# filtered_active_traj_filt['y'] = filtered_active_traj_filt['y'] * 1.3

Selected folder path: E:/Particle tracking/Summer Semester Jun-Aug 2024/28JUN24 - Speed and distance profiles
Combined DataFrame:
       frame  particle          x          y    size  displacement  \
0          0         1  1916.4108   359.8824  2.0705        0.0000   
1          0         2  1340.3515   434.5234  3.2212        0.0000   
2          0         3   183.5619  1110.9484  2.8223        0.0000   
3          0         4  1339.8410  1458.8167  3.6826        0.0000   
4          0         5  1046.3659   898.8039  2.6239        0.0000   
...      ...       ...        ...        ...     ...           ...   
25344     88         1  1449.4597  1527.1134  6.6441        1.7478   
25345     89         1  1448.6724  1525.8729  6.9326        1.9101   
25346     90         1  1448.0645  1525.8544  6.7241        0.7906   
25347     91         1  1448.5866  1526.5003  6.6241        1.0796   
25348     92         1  1447.9776  1525.9566  6.8512        1.0613   

       ultrasound_burst_frame

### Plotting function

In [5]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import numpy as np
import os
from matplotlib.ticker import MultipleLocator, ScalarFormatter

# Requires combined_df and xml_folder_path from the XML-loading cell.
if combined_df is None:
    raise RuntimeError("No trajectory data loaded (combined_df is None); "
                       "run the XML-loading cell and select a folder first.")

# Define constants
frame_rate = 30  # frames per second; converts burst frames to seconds
max_time = 10  # upper bound (s) of the burst-time slider
plot_font_size = 24  # Font size for axis labels and title
tick_font_size = 22  # Font size for ticks
line_thickness = 3  # Thickness of the plot lines

# Output tree:
#   Compiled plots/<particle type>/{Speed, Total Distance Traveled}/<V> mVpp/
compiled_plots_folder = os.path.join(xml_folder_path, 'Compiled plots')
os.makedirs(compiled_plots_folder, exist_ok=True)

particle_types = combined_df['particle_type'].unique()
for particle_type in particle_types:
    particle_type_folder = os.path.join(compiled_plots_folder, particle_type)
    speed_folder = os.path.join(particle_type_folder, 'Speed')
    total_distance_folder = os.path.join(particle_type_folder, 'Total Distance Traveled')
    os.makedirs(particle_type_folder, exist_ok=True)
    os.makedirs(speed_folder, exist_ok=True)
    os.makedirs(total_distance_folder, exist_ok=True)

    # Only the voltages actually recorded for this particle type (using every
    # voltage in combined_df created empty folders for other types' voltages).
    voltage_levels = combined_df.loc[combined_df['particle_type'] == particle_type,
                                     'voltage_input'].unique()
    for voltage_level in voltage_levels:
        os.makedirs(os.path.join(speed_folder, f"{voltage_level} mVpp"), exist_ok=True)
        os.makedirs(os.path.join(total_distance_folder, f"{voltage_level} mVpp"), exist_ok=True)

def calculate_global_maximums(combined_df, max_distance=3500, max_speed=21000, buffer=0.1):
    """Return the shared y-axis upper limits ``(max_distance, max_speed)``.

    By default this returns the fixed limits 3500 (total distance) and
    21000 (speed), so every figure uses identical axes. The previous version
    also returned these values after discarding the data-derived maxima.
    Pass ``None`` for either limit to derive it from ``combined_df`` instead,
    as the column maximum enlarged by ``buffer`` (0.1 = +10 %).

    Parameters
    ----------
    combined_df : pandas.DataFrame
        Must have ``total_distance_traveled`` / ``speed`` columns when the
        corresponding limit is ``None``; otherwise it is not read.
    max_distance : float or None, optional
        Fixed distance limit, or ``None`` to compute it (default 3500).
    max_speed : float or None, optional
        Fixed speed limit, or ``None`` to compute it (default 21000).
    buffer : float, optional
        Relative head-room added to computed limits (default 0.1).

    Returns
    -------
    tuple
        ``(max_distance, max_speed)``.
    """
    if max_distance is None:
        max_distance = combined_df['total_distance_traveled'].max() * (1 + buffer)
    if max_speed is None:
        max_speed = combined_df['speed'].max() * (1 + buffer)
    return max_distance, max_speed

# Shared y-axis limits used by every distance / speed figure in this cell.
(global_max_distance,
 global_max_speed) = calculate_global_maximums(combined_df=combined_df)

def calculate_fourth_burst_time(burst_time, burst_interval=1.0, n_bursts=4, padding=0.5):
    """Return the end (s) of the plotting window covering the first ``n_bursts`` bursts.

    The result is ``burst_time + (n_bursts - 1) * burst_interval + padding``.
    With the defaults this is the original ``burst_time + 3.5``: three 1-s
    intervals up to the 4th burst, plus 0.5 s. NOTE(review): the extra 0.5 s
    is interpreted here as padding after the 4th burst; confirm.

    Parameters
    ----------
    burst_time : float
        Time of the first burst, in seconds.
    burst_interval : float, optional
        Spacing between bursts in seconds (default 1.0).
    n_bursts : int, optional
        Number of bursts the window should include (default 4).
    padding : float, optional
        Extra time after the last included burst (default 0.5).

    Returns
    -------
    float
        Window end time in seconds.
    """
    return burst_time + ((n_bursts - 1) * burst_interval + padding)

def save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames):
    """Export untitled distance and speed figures for every voltage of one particle type.

    For each voltage level, two 600-dpi PNGs are written under
    ``<compiled_plots_folder>/<particle_type>/``:
    ``Total Distance Traveled/<V> mVpp/<type>_<V>mVpp_total_distance_burst<t>s.png``
    and ``Speed/<V> mVpp/<type>_<V>mVpp_speed_burst<t>s.png``. Each line is one
    particle, truncated at ``calculate_fourth_burst_time(burst_time)``, and gray
    bands mark the bursts. Y limits come from the module-level
    ``global_max_distance`` / ``global_max_speed``.

    Parameters
    ----------
    particle_type : str
        Particle-type label; also used as the sub-folder name.
    particle_type_data : pandas.DataFrame
        Rows of ``combined_df`` for this particle type (uses ``voltage_input``,
        ``particle``, ``time``, ``total_distance_traveled`` and ``speed``).
    voltage_levels : iterable
        Voltage levels to export; each must be a key of ``burst_frames``.
    burst_frames : dict
        Voltage level -> first-burst time in seconds.
    """

    def _draw_metric(ax, voltage_data, column, ylabel, y_max, major_step, minor_step,
                     burst_time, fourth_burst_time):
        # Plot one metric (one line per particle) and apply the shared axis style.
        for particle in voltage_data['particle'].unique():
            particle_data = voltage_data[voltage_data['particle'] == particle]
            truncated_data = particle_data[particle_data['time'] <= fourth_burst_time]
            ax.plot(truncated_data['time'], truncated_data[column],
                    label=f'Particle {particle}', linewidth=line_thickness)

        # Gray +/-0.3 s bands at one-second steps from the first burst.
        for burst in np.arange(burst_time, fourth_burst_time + 1, 1):
            ax.axvspan(burst - 0.3, burst + 0.3, color='gray', alpha=0.5)

        ax.set_xlabel('Time (seconds)', fontsize=plot_font_size)
        ax.set_ylabel(ylabel, fontsize=plot_font_size)
        ax.set_xlim(0, fourth_burst_time)
        ax.set_ylim(0, y_max)

        ax.xaxis.set_major_locator(MultipleLocator(1))
        ax.xaxis.set_minor_locator(MultipleLocator(0.5))

        ax.set_yticks(np.arange(0, y_max * 1.1 + 1, major_step))
        ax.yaxis.set_minor_locator(MultipleLocator(minor_step))

        ax.tick_params(axis='both', which='major', width=2, length=8, labelsize=tick_font_size)
        ax.tick_params(axis='both', which='minor', width=1, length=4)

        # Use scientific notation for y-axis
        ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        ax.yaxis.get_offset_text().set_fontsize(tick_font_size)

    for voltage_level in voltage_levels:
        burst_time = burst_frames[voltage_level]
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fourth_burst_time = calculate_fourth_burst_time(burst_time)

        fig1, ax1 = plt.subplots(figsize=(12, 6))
        fig2, ax2 = plt.subplots(figsize=(12, 6))
        try:
            # Raw strings: '\m' in a normal literal is an invalid escape sequence.
            _draw_metric(ax1, voltage_data, 'total_distance_traveled',
                         r'Total Distance Traveled ($\mu m$)', global_max_distance, 500, 250,
                         burst_time, fourth_burst_time)
            _draw_metric(ax2, voltage_data, 'speed', r'Speed ($\mu m/s$)',
                         global_max_speed, 3000, 1500, burst_time, fourth_burst_time)

            # Lay out each figure explicitly; plt.tight_layout() only affects
            # the current figure (fig2) and left fig1 untouched.
            fig1.tight_layout()
            fig2.tight_layout()

            total_distance_path = os.path.join(compiled_plots_folder, particle_type,
                                               'Total Distance Traveled', f"{voltage_level} mVpp")
            speed_path = os.path.join(compiled_plots_folder, particle_type, 'Speed',
                                      f"{voltage_level} mVpp")
            # Create the targets here as well, so exporting does not depend on
            # the folder-setup code having seen this voltage level.
            os.makedirs(total_distance_path, exist_ok=True)
            os.makedirs(speed_path, exist_ok=True)

            fig1.savefig(os.path.join(total_distance_path,
                                      f'{particle_type}_{voltage_level}mVpp_total_distance_burst{burst_time:.1f}s.png'),
                         dpi=600, bbox_inches='tight')
            fig2.savefig(os.path.join(speed_path,
                                      f'{particle_type}_{voltage_level}mVpp_speed_burst{burst_time:.1f}s.png'),
                         dpi=600, bbox_inches='tight')
        finally:
            # Close even if drawing/saving fails, so figures do not accumulate.
            plt.close(fig1)
            plt.close(fig2)

    print(f"All plots saved for {particle_type}")

# Function to create interactive plot for a particle type
def create_interactive_plot(particle_type, global_max_distance, global_max_speed):
    """Build the slider/button widget panel for one particle type.

    The panel contains:
    - a voltage slider (an index into the numerically sorted voltage levels);
    - a burst-time slider (seconds, 0..``max_time``);
    - "Save Current Plot" and "Save All Plots" buttons;
    - an ``Output`` area showing titled distance and speed figures for the
      selected voltage.

    Parameters
    ----------
    particle_type : str
        Value of ``combined_df['particle_type']`` to show.
    global_max_distance, global_max_speed : float
        Y-axis upper limits for the figures shown and saved by this panel.

    Returns
    -------
    tuple
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
        ``burst_frames`` maps each voltage level to its burst time in seconds
        and is updated in place as the burst slider moves.
    """
    particle_type_data = combined_df[combined_df['particle_type'] == particle_type]

    def _voltage_sort_key(level):
        # voltage_input is parsed from file names as text; sort numerically
        # when possible so '1000' does not come before '200'.
        try:
            return (0, float(level), '')
        except (TypeError, ValueError):
            return (1, 0.0, str(level))

    voltage_levels = sorted(particle_type_data['voltage_input'].unique(), key=_voltage_sort_key)

    # Initial burst time (s) per voltage: mean burst frame / frame rate.
    burst_frames = {}
    for voltage in voltage_levels:
        voltage_rows = particle_type_data[particle_type_data['voltage_input'] == voltage]
        burst_frames[voltage] = voltage_rows['ultrasound_burst_frame'].mean() / frame_rate

    # (column, y-label, y-limit, major tick step, minor tick step) per figure.
    # Raw strings: '\m' in a normal literal is an invalid escape sequence.
    metric_specs = (
        ('total_distance_traveled', r'Total Distance Traveled ($\mu m$)', global_max_distance, 500, 250),
        ('speed', r'Speed ($\mu m/s$)', global_max_speed, 3000, 1500),
    )

    voltage_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(voltage_levels) - 1,
        step=1,
        description='Voltage:',
        continuous_update=False
    )

    burst_frame_slider = widgets.FloatSlider(
        value=0,
        min=0,
        max=max_time,
        step=0.1,
        description='Burst Time (s):',
        continuous_update=False
    )

    save_current_button = widgets.Button(description="Save Current Plot")
    save_all_button = widgets.Button(description=f"Save All Plots for {particle_type}")
    output = widgets.Output()

    # True while the burst slider is moved programmatically: its observer then
    # records the value but leaves the redraw to the caller (one render, not two).
    syncing_slider = False

    def current_voltage():
        return voltage_levels[voltage_slider.value]

    def sync_burst_slider(voltage_level):
        # Move the burst slider to the stored time for voltage_level without a redraw.
        nonlocal syncing_slider
        syncing_slider = True
        try:
            burst_frame_slider.value = burst_frames[voltage_level]
        finally:
            syncing_slider = False

    def draw_figures(voltage_level, burst_time, include_title):
        # Build the (distance, speed) figure pair; the caller must close both.
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fourth_burst_time = calculate_fourth_burst_time(burst_time)
        figures = []
        try:
            for column, ylabel, y_max, major_step, minor_step in metric_specs:
                fig, ax = plt.subplots(figsize=(12, 6))
                figures.append(fig)

                # One line per particle, truncated at the end of the burst window.
                for particle in voltage_data['particle'].unique():
                    particle_data = voltage_data[voltage_data['particle'] == particle]
                    truncated_data = particle_data[particle_data['time'] <= fourth_burst_time]
                    ax.plot(truncated_data['time'], truncated_data[column],
                            label=f'Particle {particle}', linewidth=line_thickness)

                # Gray +/-0.3 s bands at one-second steps from the first burst.
                for burst in np.arange(burst_time, fourth_burst_time + 1, 1):
                    ax.axvspan(burst - 0.3, burst + 0.3, color='gray', alpha=0.5)

                ax.set_xlabel('Time (seconds)', fontsize=plot_font_size)
                ax.set_ylabel(ylabel, fontsize=plot_font_size)
                ax.set_xlim(0, fourth_burst_time)
                ax.set_ylim(0, y_max)

                ax.xaxis.set_major_locator(MultipleLocator(1))
                ax.xaxis.set_minor_locator(MultipleLocator(0.5))

                ax.set_yticks(np.arange(0, y_max * 1.1 + 1, major_step))
                ax.yaxis.set_minor_locator(MultipleLocator(minor_step))

                ax.tick_params(axis='both', which='major', width=2, length=8, labelsize=tick_font_size)
                ax.tick_params(axis='both', which='minor', width=1, length=4)

                # Use scientific notation for y-axis
                ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
                ax.yaxis.get_offset_text().set_fontsize(tick_font_size)

                if include_title:
                    ax.set_title(f'{particle_type} - {voltage_level} mVpp', fontsize=plot_font_size)

                # Per-figure layout (plt.tight_layout() only touched the last figure).
                fig.tight_layout()
        except Exception:
            for fig in figures:
                plt.close(fig)
            raise
        return figures[0], figures[1]

    def update_plots(change, include_title=True):
        with output:
            clear_output(wait=True)
            fig1, fig2 = draw_figures(current_voltage(), burst_frame_slider.value, include_title)
            try:
                display(fig1)
                display(fig2)
            finally:
                plt.close(fig1)
                plt.close(fig2)

    def on_voltage_change(change):
        sync_burst_slider(voltage_levels[change.new])
        update_plots(None)

    def on_burst_frame_change(change):
        # Remember the (possibly user-adjusted) burst time for this voltage.
        burst_frames[current_voltage()] = change.new
        if not syncing_slider:
            update_plots(None)

    def save_current_plot(b):
        voltage_level = current_voltage()
        burst_time = burst_frame_slider.value

        # Saved figures have no title.
        fig1, fig2 = draw_figures(voltage_level, burst_time, include_title=False)
        try:
            total_distance_path = os.path.join(compiled_plots_folder, particle_type,
                                               'Total Distance Traveled', f"{voltage_level} mVpp")
            speed_path = os.path.join(compiled_plots_folder, particle_type, 'Speed',
                                      f"{voltage_level} mVpp")
            os.makedirs(total_distance_path, exist_ok=True)
            os.makedirs(speed_path, exist_ok=True)
            fig1.savefig(os.path.join(total_distance_path,
                                      f'{particle_type}_{voltage_level}mVpp_total_distance_burst{burst_time:.1f}s.png'),
                         dpi=600, bbox_inches='tight')
            fig2.savefig(os.path.join(speed_path,
                                      f'{particle_type}_{voltage_level}mVpp_speed_burst{burst_time:.1f}s.png'),
                         dpi=600, bbox_inches='tight')
        finally:
            plt.close(fig1)
            plt.close(fig2)

        # Redraw first (it clears the output area), then report inside the
        # Output widget so the confirmation is visible in the notebook.
        update_plots(None, include_title=True)
        with output:
            print(f"Plots saved for {particle_type} at {voltage_level} mVpp with burst time {burst_time:.1f}s")

    def save_all_plots_for_this_particle(b):
        # save_all_plots_for_particle prints its own confirmation; show it in
        # this panel's Output area instead of printing a duplicate message.
        # NOTE(review): the name is resolved at click time, and a later cell in
        # this notebook redefines save_all_plots_for_particle (trajectory export).
        with output:
            save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames)

    voltage_slider.observe(on_voltage_change, names='value')
    burst_frame_slider.observe(on_burst_frame_change, names='value')
    save_current_button.on_click(save_current_plot)
    save_all_button.on_click(save_all_plots_for_this_particle)

    # Position the burst slider for the first voltage; the single initial
    # render happens below.
    sync_burst_slider(voltage_levels[0])

    # Instead of displaying, return the widget layout and output
    widget_layout = widgets.VBox([
        widgets.HTML(f"<h3>Interactive plot for {particle_type}</h3>"),
        voltage_slider,
        burst_frame_slider,
        widgets.HBox([save_current_button, save_all_button]),
        output
    ])

    # Initial plot
    update_plots(None)

    return widget_layout, particle_type_data, voltage_levels, burst_frames

# New function to display all particle types
def display_all_particle_plots(global_max_distance, global_max_speed):
    """Create one interactive panel per entry of ``particle_types`` and show them stacked.

    Parameters
    ----------
    global_max_distance, global_max_speed : float
        Y-axis limits forwarded to every panel.

    Returns
    -------
    dict
        ``{particle_type: (widget_layout, particle_type_data, voltage_levels, burst_frames)}``.
    """
    built = [
        (ptype, create_interactive_plot(ptype, global_max_distance, global_max_speed))
        for ptype in particle_types
    ]

    # Display all widget layouts
    display(widgets.VBox([layout for _ptype, (layout, _data, _levels, _bursts) in built]))

    return {ptype: tuple(parts) for ptype, parts in built}

# Recompute the shared axis limits (cheap) so this step does not rely on an
# earlier line having run, then build and display one panel per particle type.
global_max_distance, global_max_speed = calculate_global_maximums(combined_df)
all_particle_widgets = display_all_particle_plots(
    global_max_distance=global_max_distance,
    global_max_speed=global_max_speed,
)


VBox(children=(VBox(children=(HTML(value='<h3>Interactive plot for DOPC HMSM</h3>'), IntSlider(value=0, contin…

In [25]:
# Inspect the combined trajectory table. As the cell's last expression it is
# rendered by Jupyter, which shows only head/tail rows of long frames.
combined_df

Unnamed: 0,frame,particle,x,y,size,displacement,ultrasound_burst_frame,total_distance_traveled,speed,time,particle_type,voltage_input,replicate
0,0,1,1916.4108,359.8824,2.0705,0.0000,90,0.0000,0.0000,0.000000,DOPC HMSM,0,1
1,0,2,1340.3515,434.5234,3.2212,0.0000,90,0.0000,0.0000,0.000000,DOPC HMSM,0,1
2,0,3,183.5619,1110.9484,2.8223,0.0000,90,0.0000,0.0000,0.000000,DOPC HMSM,0,1
3,0,4,1339.8410,1458.8167,3.6826,0.0000,90,0.0000,0.0000,0.000000,DOPC HMSM,0,1
4,0,5,1046.3659,898.8039,2.6239,0.0000,90,0.0000,0.0000,0.000000,DOPC HMSM,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
25344,88,1,1449.4597,1527.1134,6.6441,1.7478,90,111.7933,52.4341,2.933333,MSN,800,1
25345,89,1,1448.6724,1525.8729,6.9326,1.9101,90,113.7034,57.3042,2.966667,MSN,800,1
25346,90,1,1448.0645,1525.8544,6.7241,0.7906,90,114.4940,23.7167,3.000000,MSN,800,1
25347,91,1,1448.5866,1526.5003,6.6241,1.0796,90,115.5736,32.3873,3.033333,MSN,800,1


In [30]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import numpy as np
import os
import matplotlib
import pandas as pd
import matplotlib.ticker as ticker

# Define constants
frame_rate = 30  # frames per second; converts burst frames to seconds
max_time = 10  # upper bound (s) of the burst-time slider
plot_font_size = 24  # Font size for axis labels and title
tick_font_size = 22  # Font size for ticks
line_thickness = 3  # Thickness of the plot lines
microns_per_pixel = 1.3  # Multiplies x/y pixel positions before plotting

# Fixed axis limits (µm), so every trajectory plot shares the same field of view.
# 2662.4 = 2048 px * microns_per_pixel. NOTE(review): presumably the camera's
# field of view; confirm.
X_MIN, X_MAX = 0, 2662.4
Y_MIN, Y_MAX = 0, 2662.4

# Requires combined_df from the XML-loading cell.
particle_types = combined_df['particle_type'].unique()
print(f"Particle types found: {particle_types}")

# Reuse the folder chosen in the loading cell, and fall back to the default
# only if none was selected. Unconditionally reassigning a hard-coded path here
# would also redirect saves from earlier widget panels, because they read
# compiled_plots_folder when a button is clicked.
DEFAULT_XML_FOLDER = r'E:\Particle tracking\Summer Semester Jun-Aug 2024\28JUN24 - Speed and distance profiles'
xml_folder_path = globals().get('xml_folder_path') or DEFAULT_XML_FOLDER
compiled_plots_folder = os.path.join(xml_folder_path, 'Compiled plots')

if os.path.isdir(compiled_plots_folder):
    print(f"'Compiled plots' folder already exists at {compiled_plots_folder}")
else:
    # exist_ok avoids a race if the folder appears between the check and creation.
    os.makedirs(compiled_plots_folder, exist_ok=True)
    print(f"Created 'Compiled plots' folder at {compiled_plots_folder}")

def set_consistent_ax_properties(ax, major_step=250, minor_step=125):
    """Apply the shared trajectory-plot axis style to ``ax``.

    The style consists of:
    - fixed limits from the module constants ``X_MIN/X_MAX`` and ``Y_MIN/Y_MAX``,
      with the y axis inverted;
    - integer-labelled major ticks every ``major_step`` µm;
    - minor ticks every ``minor_step`` µm;
    - tick and label font sizes from ``tick_font_size`` / ``plot_font_size``;
    - no grid.

    Ticks start at 0 and stop strictly below each axis' upper limit, so the
    edge value (e.g. 2662.4) never gets a label.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to style in place.
    major_step : float, optional
        Major tick spacing in µm (default 250).
    minor_step : float, optional
        Minor tick spacing in µm (default 125).
    """
    # Set fixed axis limits
    ax.set_xlim(X_MIN, X_MAX)
    ax.set_ylim(Y_MAX, Y_MIN)  # Inverted y-axis

    # Derived from the limits instead of hard-coded 0..2500. The former step
    # that blanked a "2662.4" label was dead code: no tick was ever placed there.
    x_major = np.arange(0, X_MAX, major_step)
    y_major = np.arange(0, Y_MAX, major_step)
    ax.set_xticks(x_major)
    ax.set_yticks(y_major)

    ax.set_xticks(np.arange(0, X_MAX, minor_step), minor=True)
    ax.set_yticks(np.arange(0, Y_MAX, minor_step), minor=True)

    # Set tick parameters
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size, length=6)
    ax.tick_params(axis='both', which='minor', length=3)

    # Plain integer labels (e.g. "250" rather than "250.0").
    ax.set_xticklabels([str(int(tick)) for tick in x_major])
    ax.set_yticklabels([str(int(tick)) for tick in y_major])

    # Set labels
    ax.set_xlabel('X Position (µm)', fontsize=plot_font_size)
    ax.set_ylabel('Y Position (µm)', fontsize=plot_font_size)

    # Remove gridlines
    ax.grid(False)

def calculate_fourth_burst_time(burst_time, burst_interval=1.0, n_bursts=4, padding=0.5):
    """Return the end (s) of the plotting window covering the first ``n_bursts`` bursts.

    The result is ``burst_time + (n_bursts - 1) * burst_interval + padding``.
    With the defaults this is the original ``burst_time + 3.5``. This cell
    redefines the helper of the same name from the speed/distance cell, so
    keep the two definitions in sync. NOTE(review): the extra 0.5 s is
    interpreted here as padding after the last burst; confirm.

    Parameters
    ----------
    burst_time : float
        Time of the first burst, in seconds.
    burst_interval : float, optional
        Spacing between bursts in seconds (default 1.0).
    n_bursts : int, optional
        Number of bursts the window should include (default 4).
    padding : float, optional
        Extra time after the last included burst (default 0.5).

    Returns
    -------
    float
        Window end time in seconds.
    """
    return burst_time + ((n_bursts - 1) * burst_interval + padding)

def save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames):
    """Save one untitled trajectory plot per voltage level for a particle type.

    For each voltage level, the following is drawn:
    - each particle's (x, y) path, scaled by ``microns_per_pixel`` and cut at
      ``calculate_fourth_burst_time(burst_time)``;
    - one viridis colour per particle, assigned in sorted particle-id order;
    - axes styled by ``set_consistent_ax_properties``.

    The figure is written at 600 dpi to
    ``<compiled_plots_folder>/<particle_type>/Trajectories/<V> mVpp/``.

    NOTE(review): this reuses the name of the speed/distance exporter defined
    in an earlier cell. Widget callbacks look the name up when clicked, so
    after this cell runs, the earlier panels' "Save All" buttons export
    trajectories instead. Consider renaming once every caller is visible.

    Parameters
    ----------
    particle_type : str
        Particle-type label; also used as the sub-folder name.
    particle_type_data : pandas.DataFrame
        Rows for this particle type (uses ``voltage_input``, ``particle``,
        ``time``, ``x`` and ``y``).
    voltage_levels : iterable
        Voltage levels to export; each must be a key of ``burst_frames``.
    burst_frames : dict
        Voltage level -> first-burst time in seconds.
    """
    trajectories_folder = os.path.join(compiled_plots_folder, particle_type, 'Trajectories')
    os.makedirs(trajectories_folder, exist_ok=True)
    color_map = plt.get_cmap('viridis')  # loop-invariant

    for voltage_level in voltage_levels:
        burst_time = burst_frames[voltage_level]
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fourth_burst_time = calculate_fourth_burst_time(burst_time)

        # Sorted ids give a stable particle -> colour assignment.
        unique_particles = np.sort(voltage_data['particle'].unique())
        colors = color_map(np.linspace(0, 1, len(unique_particles)))

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            for particle, color in zip(unique_particles, colors):
                particle_data = voltage_data[voltage_data['particle'] == particle]
                truncated_data = particle_data[particle_data['time'] <= fourth_burst_time]
                if truncated_data.empty:
                    continue
                ax.plot(truncated_data['x'] * microns_per_pixel,
                        truncated_data['y'] * microns_per_pixel,
                        linestyle='-', linewidth=line_thickness, color=color,
                        label=f'Particle {particle}')

            set_consistent_ax_properties(ax)
            fig.tight_layout()

            voltage_folder = os.path.join(trajectories_folder, f"{voltage_level} mVpp")
            os.makedirs(voltage_folder, exist_ok=True)
            fig.savefig(os.path.join(voltage_folder,
                                     f'{particle_type}_{voltage_level}mVpp_trajectories_burst{burst_time:.1f}s.png'),
                        dpi=600, bbox_inches='tight')
        finally:
            # Close even if plotting/saving fails, so figures do not accumulate.
            plt.close(fig)

    print(f"All plots saved for {particle_type}")

def create_interactive_plot(particle_type):
    """Build the interactive trajectory viewer for one particle type.

    The layout holds:
    - a voltage selector (an index into the sorted voltage levels);
    - a burst-time slider in seconds;
    - two save buttons;
    - a status area for save messages;
    - the plot preview.

    Trajectories are drawn only up to ``calculate_fourth_burst_time(burst_time)``.

    Each voltage keeps its own burst time in ``burst_frames``. Despite the name,
    the values are in seconds. Each starts as the mean ``ultrasound_burst_frame``
    of that voltage's rows divided by ``frame_rate``. It is overwritten whenever
    the burst slider moves, so "Save All" uses the user's adjustments.

    Parameters
    ----------
    particle_type : str
        Value of ``combined_df['particle_type']`` to display.

    Returns
    -------
    tuple
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
        ``burst_frames`` is the live dict that the sliders mutate.
    """
    particle_type_data = combined_df[combined_df['particle_type'] == particle_type]
    voltage_levels = sorted(particle_type_data['voltage_input'].unique())

    burst_frames = {}
    # True while code (not the user) moves the burst slider. The slider's
    # observer still records the value but skips its redraw, because the
    # caller redraws once itself. Previously each voltage change rendered twice.
    syncing_slider = False

    voltage_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(voltage_levels) - 1,
        step=1,
        description='Voltage:',
        continuous_update=False
    )

    burst_frame_slider = widgets.FloatSlider(
        value=0,
        min=0,
        max=max_time,
        step=0.1,
        description='Burst Time (s):',
        continuous_update=False
    )

    save_current_button = widgets.Button(description="Save Current Plot")
    save_all_button = widgets.Button(description=f"Save All Plots for {particle_type}")
    output = widgets.Output()
    # Messages and errors from the button callbacks. A bare print() in a widget
    # callback is not shown in JupyterLab, and `output` is cleared on every
    # redraw, so save confirmations need their own area.
    status = widgets.Output()

    def draw_trajectories(ax, voltage_data, cutoff_time, annotate, with_labels):
        """Plot every particle's path (in µm) up to ``cutoff_time`` on ``ax``.

        Colours follow each particle's rank among the sorted particle IDs.
        """
        unique_particles = np.sort(voltage_data['particle'].unique())
        colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(unique_particles)))
        # Split the frame once instead of re-filtering it for every particle.
        by_particle = {key: group for key, group in voltage_data.groupby('particle')}

        for color, particle in zip(colors, unique_particles):
            particle_data = by_particle.get(particle)
            if particle_data is None:
                continue
            truncated = particle_data[particle_data['time'] <= cutoff_time]
            if truncated.empty:
                continue

            x_um = truncated['x'] * microns_per_pixel
            y_um = truncated['y'] * microns_per_pixel
            extra = {'label': f'Particle {particle}'} if with_labels else {}
            ax.plot(x_um, y_um, linestyle='-', linewidth=line_thickness, color=color, **extra)

            if annotate:
                # Tag the last visible point with the particle ID.
                ax.annotate(f'{particle}', xy=(x_um.iloc[-1], y_um.iloc[-1]),
                            xytext=(3, 3), textcoords='offset points', color=color)

    def render_figure(voltage_level, burst_time, include_title, annotate, with_labels):
        """Return a new, fully styled trajectory figure. The caller must close it."""
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            draw_trajectories(ax, voltage_data, calculate_fourth_burst_time(burst_time),
                              annotate, with_labels)
            set_consistent_ax_properties(ax)
            if include_title:
                ax.set_title(f'{particle_type} - {voltage_level} mVpp', fontsize=plot_font_size)
            fig.tight_layout()
        except Exception:
            plt.close(fig)  # do not leak the figure if drawing fails
            raise
        return fig

    def initialize_burst_frames():
        for voltage in voltage_levels:
            voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage]
            burst_frames[voltage] = voltage_data['ultrasound_burst_frame'].mean() / frame_rate

    def update_burst_frame(voltage_level):
        """Move the burst slider to the stored time for ``voltage_level`` without redrawing."""
        nonlocal syncing_slider
        syncing_slider = True
        try:
            burst_frame_slider.value = burst_frames[voltage_level]
        finally:
            syncing_slider = False

    def on_voltage_change(change):
        update_burst_frame(voltage_levels[change.new])
        update_plots(None)

    def on_burst_frame_change(change):
        # Remember the (possibly slider-clamped) time for the current voltage.
        burst_frames[voltage_levels[voltage_slider.value]] = change.new
        if not syncing_slider:
            update_plots(None)

    def update_plots(change, include_title=True):
        """Redraw the preview (annotated) for the current slider positions."""
        with output:
            clear_output(wait=True)
            fig = render_figure(voltage_levels[voltage_slider.value], burst_frame_slider.value,
                                include_title, annotate=True, with_labels=True)
            try:
                display(fig)
            finally:
                plt.close(fig)

    def save_current_plot(b):
        """Save the current voltage/burst combination as an untitled 600-dpi PNG."""
        voltage_level = voltage_levels[voltage_slider.value]
        burst_time = burst_frame_slider.value
        status.clear_output()
        with status:
            fig = render_figure(voltage_level, burst_time, include_title=False,
                                annotate=False, with_labels=True)
            try:
                voltage_folder = os.path.join(compiled_plots_folder, particle_type,
                                              'Trajectories', f"{voltage_level} mVpp")
                os.makedirs(voltage_folder, exist_ok=True)
                fig.savefig(os.path.join(voltage_folder, f'{particle_type}_{voltage_level}mVpp_trajectories_burst{burst_time:.1f}s.png'), dpi=600, bbox_inches='tight')
            finally:
                plt.close(fig)
            print(f"Plot saved for {particle_type} at {voltage_level} mVpp with burst time {burst_time:.1f}s")

    def save_all_plots_for_this_particle(b):
        """Export every voltage level using the stored per-voltage burst times."""
        status.clear_output()
        with status:
            # save_all_plots_for_particle prints its own completion message;
            # printing it here as well duplicated it.
            save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames)

    voltage_slider.observe(on_voltage_change, names='value')
    burst_frame_slider.observe(on_burst_frame_change, names='value')
    save_current_button.on_click(save_current_plot)
    save_all_button.on_click(save_all_plots_for_this_particle)

    initialize_burst_frames()
    update_burst_frame(voltage_levels[0])

    widget_layout = widgets.VBox([
        widgets.HTML(f"<h3>Interactive plot for {particle_type}</h3>"),
        voltage_slider,
        burst_frame_slider,
        widgets.HBox([save_current_button, save_all_button]),
        status,
        output
    ])

    # Single initial render (the slider sync above no longer triggers one).
    update_plots(None)

    return widget_layout, particle_type_data, voltage_levels, burst_frames

def display_all_particle_plots():
    """Create, display and return one interactive viewer per particle type.

    Returns
    -------
    dict
        Maps each entry of ``particle_types`` to the tuple returned by
        ``create_interactive_plot``:
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
    """
    built = [(ptype, create_interactive_plot(ptype)) for ptype in particle_types]

    # Stack every per-type layout in a single container so they render together.
    display(widgets.VBox([bundle[0] for _, bundle in built]))

    return dict(built)

# Call the function to display all plots.
# Builds one interactive widget per particle type and displays them stacked.
# The returned dict keeps the per-type (widget, data, voltages, burst_frames)
# tuples, so slider-adjusted burst times stay reachable after the cell runs.
all_particle_widgets = display_all_particle_plots()

Particle types found: ['DOPC HMSM' 'DOPC HMSN' 'MSM' 'MSN']
'Compiled plots' folder already exists at E:\Particle tracking\Summer Semester Jun-Aug 2024\28JUN24 - Speed and distance profiles\Compiled plots


VBox(children=(VBox(children=(HTML(value='<h3>Interactive plot for DOPC HMSM</h3>'), IntSlider(value=0, contin…

In [40]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

# Define constants
frame_rate = 30  # Frames per second; converts ultrasound_burst_frame to seconds
max_time = 10  # Upper bound (s) of the burst-time slider below
plot_font_size = 24  # Font size for axis labels and title
tick_font_size = 22  # Font size for ticks
line_thickness = 3  # Thickness of the plot lines
microns_per_pixel = 1.3  # Conversion factor (µm per pixel)

# Axis limits in µm. 2662.4 = 2048 px * 1.3 µm/px, presumably the full frame
# width. NOTE(review): confirm against the camera sensor size.
X_MIN, X_MAX = 0, 2662.4
Y_MIN, Y_MAX = 0, 2662.4

# Requires `combined_df` and `xml_folder_path` from the loading cells above.
# combined_df must provide the columns particle_type, voltage_input, particle,
# time, x, y and ultrasound_burst_frame.

# Extract unique particle types
particle_types = combined_df['particle_type'].unique()
print(f"Particle types found: {particle_types}")

# Set up the output folder
compiled_plots_folder = os.path.join(xml_folder_path, 'Compiled plots')

# Use isdir rather than exists: a stray *file* with this name must not be
# reported as a usable folder. makedirs then raises instead of failing later.
# exist_ok covers a folder created between the check and the call.
if os.path.isdir(compiled_plots_folder):
    print(f"'Compiled plots' folder already exists at {compiled_plots_folder}")
else:
    os.makedirs(compiled_plots_folder, exist_ok=True)
    print(f"Created 'Compiled plots' folder at {compiled_plots_folder}")

def set_consistent_ax_properties(ax, major_step=250, minor_step=125, major_length=6, minor_length=3):
    """Apply the shared axis styling used by every trajectory plot.

    Fixes the view to X_MIN..X_MAX by Y_MIN..Y_MAX (µm) with the y-axis
    inverted, so the origin is at the top-left as in image coordinates. It puts
    labelled major ticks every ``major_step`` µm and unlabelled minor ticks
    every ``minor_step`` µm, both starting at 0. It also sets tick and axis-label
    fonts and removes gridlines.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to style in place.
    major_step, minor_step : float, optional
        Tick spacing in µm. Defaults reproduce the original 250 / 125 µm grid.
    major_length, minor_length : float, optional
        Tick mark lengths in points.
    """
    def ticks_up_to(limit, step):
        # Multiples of `step` from 0 up to the last one not exceeding `limit`
        # (0, 250, ..., 2500 for a 2662.4 µm axis with a 250 µm step).
        return np.arange(0, np.floor(limit / step) * step + step / 2, step)

    # Set fixed axis limits
    ax.set_xlim(X_MIN, X_MAX)
    ax.set_ylim(Y_MAX, Y_MIN)  # Inverted y-axis

    # Major ticks never reach the axis end, so no "2662" label needs blanking.
    x_major = ticks_up_to(X_MAX, major_step)
    y_major = ticks_up_to(Y_MAX, major_step)
    ax.set_xticks(x_major)
    ax.set_yticks(y_major)

    # Minor ticks; each axis now uses its own limit (y previously used X_MAX).
    ax.set_xticks(np.arange(0, X_MAX, minor_step), minor=True)
    ax.set_yticks(np.arange(0, Y_MAX, minor_step), minor=True)

    # Set tick parameters
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size, length=major_length)
    ax.tick_params(axis='both', which='minor', length=minor_length)

    # Plain labels: "250" rather than "250.0".
    ax.set_xticklabels([f'{tick:g}' for tick in x_major])
    ax.set_yticklabels([f'{tick:g}' for tick in y_major])

    # Set labels
    ax.set_xlabel('X Position (µm)', fontsize=plot_font_size)
    ax.set_ylabel('Y Position (µm)', fontsize=plot_font_size)

    # Remove gridlines
    ax.grid(False)

def calculate_fourth_burst_time(burst_time, offset=3.5):
    """Return the cut-off time (s) used to truncate trajectories.

    The cut-off is ``burst_time + offset``. If bursts are 1 s apart (bursts at
    t, t+1, t+2, t+3), the default 3.5 s is half a second past the fourth
    burst. The previous comment claimed the offset simply followed from 1 s
    spacing, which does not match 3.5. NOTE(review): confirm the burst spacing
    against the acquisition protocol.

    Parameters
    ----------
    burst_time : float
        Time of the first burst, in seconds.
    offset : float, optional
        Seconds added to ``burst_time``; defaults to 3.5.
    """
    return burst_time + offset

def _plot_truncated_trajectories(ax, voltage_data, cutoff_time):
    """Draw each particle's path (µm) up to ``cutoff_time``, coloured by ID rank (viridis)."""
    unique_particles = np.sort(voltage_data['particle'].unique())
    colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(unique_particles)))
    # One groupby pass instead of a full boolean scan of the frame per particle.
    by_particle = {key: group for key, group in voltage_data.groupby('particle')}
    for color, particle in zip(colors, unique_particles):
        particle_data = by_particle.get(particle)
        if particle_data is None:
            continue
        truncated = particle_data[particle_data['time'] <= cutoff_time]
        if not truncated.empty:
            ax.plot(truncated['x'] * microns_per_pixel, truncated['y'] * microns_per_pixel,
                    linestyle='-', linewidth=line_thickness, color=color,
                    label=f'Particle {particle}')


def save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames):
    """Render and save one untitled trajectory PNG (600 dpi) per voltage level.

    For each voltage, trajectories are truncated at
    ``calculate_fourth_burst_time(burst_frames[voltage])``. Output goes to
    ``<compiled_plots_folder>/<particle_type>/Trajectories/<voltage> mVpp/``.

    Parameters
    ----------
    particle_type : str
        Label used for the folder and file names.
    particle_type_data : pandas.DataFrame
        Rows for this particle type, with columns voltage_input, particle,
        time, x and y. x and y are in pixels.
    voltage_levels : iterable
        Voltages to render.
    burst_frames : dict
        Maps each voltage to its burst time in seconds (despite the name).
    """
    trajectories_folder = os.path.join(compiled_plots_folder, particle_type, 'Trajectories')
    os.makedirs(trajectories_folder, exist_ok=True)

    for voltage_level in voltage_levels:
        burst_time = burst_frames[voltage_level]
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fourth_burst_time = calculate_fourth_burst_time(burst_time)

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            _plot_truncated_trajectories(ax, voltage_data, fourth_burst_time)
            set_consistent_ax_properties(ax)
            fig.tight_layout()

            voltage_folder = os.path.join(trajectories_folder, f"{voltage_level} mVpp")
            os.makedirs(voltage_folder, exist_ok=True)
            fig.savefig(os.path.join(voltage_folder, f'{particle_type}_{voltage_level}mVpp_trajectories_burst{burst_time:.1f}s.png'), dpi=600, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving fails, so a batch
            # export does not accumulate open figures.
            plt.close(fig)

    print(f"All plots saved for {particle_type}")

def create_interactive_plot(particle_type):
    """Build the interactive trajectory viewer for one particle type.

    The layout holds:
    - a voltage selector (an index into the sorted voltage levels);
    - a burst-time slider in seconds;
    - two save buttons;
    - a status area for save messages;
    - the plot preview (no per-particle annotations).

    Trajectories are drawn only up to ``calculate_fourth_burst_time(burst_time)``.

    Each voltage keeps its own burst time in ``burst_frames``. Despite the name,
    the values are in seconds. Each starts as the mean ``ultrasound_burst_frame``
    of that voltage's rows divided by ``frame_rate``. It is overwritten whenever
    the burst slider moves, so "Save All" uses the user's adjustments.

    Parameters
    ----------
    particle_type : str
        Value of ``combined_df['particle_type']`` to display.

    Returns
    -------
    tuple
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
        ``burst_frames`` is the live dict that the sliders mutate.
    """
    particle_type_data = combined_df[combined_df['particle_type'] == particle_type]
    voltage_levels = sorted(particle_type_data['voltage_input'].unique())

    burst_frames = {}
    # True while code (not the user) moves the burst slider. The slider's
    # observer still records the value but skips its redraw, because the
    # caller redraws once itself. Previously each voltage change rendered twice.
    syncing_slider = False

    voltage_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(voltage_levels) - 1,
        step=1,
        description='Voltage:',
        continuous_update=False
    )

    burst_frame_slider = widgets.FloatSlider(
        value=0,
        min=0,
        max=max_time,
        step=0.1,
        description='Burst Time (s):',
        continuous_update=False
    )

    save_current_button = widgets.Button(description="Save Current Plot")
    save_all_button = widgets.Button(description=f"Save All Plots for {particle_type}")
    output = widgets.Output()
    # Messages and errors from the button callbacks. A bare print() in a widget
    # callback is not shown in JupyterLab, and `output` is cleared on every
    # redraw, so save confirmations need their own area.
    status = widgets.Output()

    def draw_trajectories(ax, voltage_data, cutoff_time, annotate, with_labels):
        """Plot every particle's path (in µm) up to ``cutoff_time`` on ``ax``.

        Colours follow each particle's rank among the sorted particle IDs.
        """
        unique_particles = np.sort(voltage_data['particle'].unique())
        colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(unique_particles)))
        # Split the frame once instead of re-filtering it for every particle.
        by_particle = {key: group for key, group in voltage_data.groupby('particle')}

        for color, particle in zip(colors, unique_particles):
            particle_data = by_particle.get(particle)
            if particle_data is None:
                continue
            truncated = particle_data[particle_data['time'] <= cutoff_time]
            if truncated.empty:
                continue

            x_um = truncated['x'] * microns_per_pixel
            y_um = truncated['y'] * microns_per_pixel
            extra = {'label': f'Particle {particle}'} if with_labels else {}
            ax.plot(x_um, y_um, linestyle='-', linewidth=line_thickness, color=color, **extra)

            if annotate:
                # Tag the last visible point with the particle ID.
                ax.annotate(f'{particle}', xy=(x_um.iloc[-1], y_um.iloc[-1]),
                            xytext=(3, 3), textcoords='offset points', color=color)

    def render_figure(voltage_level, burst_time, include_title, annotate, with_labels):
        """Return a new, fully styled trajectory figure. The caller must close it."""
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            draw_trajectories(ax, voltage_data, calculate_fourth_burst_time(burst_time),
                              annotate, with_labels)
            set_consistent_ax_properties(ax)
            if include_title:
                ax.set_title(f'{particle_type} - {voltage_level} mVpp', fontsize=plot_font_size)
            fig.tight_layout()
        except Exception:
            plt.close(fig)  # do not leak the figure if drawing fails
            raise
        return fig

    def initialize_burst_frames():
        for voltage in voltage_levels:
            voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage]
            burst_frames[voltage] = voltage_data['ultrasound_burst_frame'].mean() / frame_rate

    def update_burst_frame(voltage_level):
        """Move the burst slider to the stored time for ``voltage_level`` without redrawing."""
        nonlocal syncing_slider
        syncing_slider = True
        try:
            burst_frame_slider.value = burst_frames[voltage_level]
        finally:
            syncing_slider = False

    def on_voltage_change(change):
        update_burst_frame(voltage_levels[change.new])
        update_plots(None)

    def on_burst_frame_change(change):
        # Remember the (possibly slider-clamped) time for the current voltage.
        burst_frames[voltage_levels[voltage_slider.value]] = change.new
        if not syncing_slider:
            update_plots(None)

    def update_plots(change, include_title=True):
        """Redraw the preview for the current slider positions."""
        with output:
            clear_output(wait=True)
            fig = render_figure(voltage_levels[voltage_slider.value], burst_frame_slider.value,
                                include_title, annotate=False, with_labels=True)
            try:
                display(fig)
            finally:
                plt.close(fig)

    def save_current_plot(b):
        """Save the current voltage/burst combination as an untitled 600-dpi PNG."""
        voltage_level = voltage_levels[voltage_slider.value]
        burst_time = burst_frame_slider.value
        status.clear_output()
        with status:
            fig = render_figure(voltage_level, burst_time, include_title=False,
                                annotate=False, with_labels=True)
            try:
                voltage_folder = os.path.join(compiled_plots_folder, particle_type,
                                              'Trajectories', f"{voltage_level} mVpp")
                os.makedirs(voltage_folder, exist_ok=True)
                fig.savefig(os.path.join(voltage_folder, f'{particle_type}_{voltage_level}mVpp_trajectories_burst{burst_time:.1f}s.png'), dpi=600, bbox_inches='tight')
            finally:
                plt.close(fig)
            print(f"Plot saved for {particle_type} at {voltage_level} mVpp with burst time {burst_time:.1f}s")

    def save_all_plots_for_this_particle(b):
        """Export every voltage level using the stored per-voltage burst times."""
        status.clear_output()
        with status:
            # save_all_plots_for_particle prints its own completion message;
            # printing it here as well duplicated it.
            save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames)

    voltage_slider.observe(on_voltage_change, names='value')
    burst_frame_slider.observe(on_burst_frame_change, names='value')
    save_current_button.on_click(save_current_plot)
    save_all_button.on_click(save_all_plots_for_this_particle)

    initialize_burst_frames()
    update_burst_frame(voltage_levels[0])

    widget_layout = widgets.VBox([
        widgets.HTML(f"<h3>Interactive plot for {particle_type}</h3>"),
        voltage_slider,
        burst_frame_slider,
        widgets.HBox([save_current_button, save_all_button]),
        status,
        output
    ])

    # Single initial render (the slider sync above no longer triggers one).
    update_plots(None)

    return widget_layout, particle_type_data, voltage_levels, burst_frames

def display_all_particle_plots():
    """Create, display and return one interactive viewer per particle type.

    Returns
    -------
    dict
        Maps each entry of ``particle_types`` to the tuple returned by
        ``create_interactive_plot``:
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
    """
    built = [(ptype, create_interactive_plot(ptype)) for ptype in particle_types]

    # Stack every per-type layout in a single container so they render together.
    display(widgets.VBox([bundle[0] for _, bundle in built]))

    return dict(built)

# Call the function to display all plots.
# Builds one interactive widget per particle type and displays them stacked.
# The returned dict keeps the per-type (widget, data, voltages, burst_frames)
# tuples, so slider-adjusted burst times stay reachable after the cell runs.
all_particle_widgets = display_all_particle_plots()


Particle types found: ['DOPC HMSM' 'DOPC HMSN' 'MSM' 'MSN']
'Compiled plots' folder already exists at E:\Particle tracking\Summer Semester Jun-Aug 2024\28JUN24 - Speed and distance profiles\Compiled plots


VBox(children=(VBox(children=(HTML(value='<h3>Interactive plot for DOPC HMSM</h3>'), IntSlider(value=0, contin…

In [42]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

# Define constants
frame_rate = 30  # Frames per second; converts ultrasound_burst_frame to seconds
max_time = 10  # Upper bound (s) of the burst-time slider below
plot_font_size = 24  # Font size for axis labels and title
tick_font_size = 22  # Font size for ticks
line_thickness = 3  # Thickness of the plot lines
microns_per_pixel = 1.3  # Conversion factor (µm per pixel)

# Axis limits in µm. 2662.4 = 2048 px * 1.3 µm/px, presumably the full frame
# width. NOTE(review): confirm against the camera sensor size.
X_MIN, X_MAX = 0, 2662.4
Y_MIN, Y_MAX = 0, 2662.4

# Requires `combined_df` and `xml_folder_path` from the loading cells above.
# combined_df must provide the columns particle_type, voltage_input, particle,
# time, x, y and ultrasound_burst_frame.

# Extract unique particle types
particle_types = combined_df['particle_type'].unique()
print(f"Particle types found: {particle_types}")

# Set up the output folder
compiled_plots_folder = os.path.join(xml_folder_path, 'Compiled plots')

# Use isdir rather than exists: a stray *file* with this name must not be
# reported as a usable folder. makedirs then raises instead of failing later.
# exist_ok covers a folder created between the check and the call.
if os.path.isdir(compiled_plots_folder):
    print(f"'Compiled plots' folder already exists at {compiled_plots_folder}")
else:
    os.makedirs(compiled_plots_folder, exist_ok=True)
    print(f"Created 'Compiled plots' folder at {compiled_plots_folder}")

def set_consistent_ax_properties(ax, major_step=250, minor_step=125, major_length=8, minor_length=4):
    """Apply the shared axis styling used by every trajectory plot.

    Fixes the view to X_MIN..X_MAX by Y_MIN..Y_MAX (µm) with the y-axis
    inverted, so the origin is at the top-left as in image coordinates. It puts
    labelled major ticks every ``major_step`` µm and unlabelled minor ticks
    every ``minor_step`` µm, both starting at 0. It also sets tick and axis-label
    fonts and removes gridlines.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to style in place.
    major_step, minor_step : float, optional
        Tick spacing in µm. Defaults reproduce the original 250 / 125 µm grid.
    major_length, minor_length : float, optional
        Tick mark lengths in points (this cell uses the longer 8 / 4 marks).
    """
    def ticks_up_to(limit, step):
        # Multiples of `step` from 0 up to the last one not exceeding `limit`
        # (0, 250, ..., 2500 for a 2662.4 µm axis with a 250 µm step).
        return np.arange(0, np.floor(limit / step) * step + step / 2, step)

    # Set fixed axis limits
    ax.set_xlim(X_MIN, X_MAX)
    ax.set_ylim(Y_MAX, Y_MIN)  # Inverted y-axis

    # Major ticks never reach the axis end, so no "2662" label needs blanking.
    x_major = ticks_up_to(X_MAX, major_step)
    y_major = ticks_up_to(Y_MAX, major_step)
    ax.set_xticks(x_major)
    ax.set_yticks(y_major)

    # Minor ticks; each axis now uses its own limit (y previously used X_MAX).
    ax.set_xticks(np.arange(0, X_MAX, minor_step), minor=True)
    ax.set_yticks(np.arange(0, Y_MAX, minor_step), minor=True)

    # Set tick parameters
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size, length=major_length)
    ax.tick_params(axis='both', which='minor', length=minor_length)

    # Plain labels: "250" rather than "250.0".
    ax.set_xticklabels([f'{tick:g}' for tick in x_major])
    ax.set_yticklabels([f'{tick:g}' for tick in y_major])

    # Set labels
    ax.set_xlabel('X Position (µm)', fontsize=plot_font_size)
    ax.set_ylabel('Y Position (µm)', fontsize=plot_font_size)

    # Remove gridlines
    ax.grid(False)

def calculate_fourth_burst_time(burst_time, offset=3.5):
    """Return the cut-off time (s) used to truncate trajectories.

    The cut-off is ``burst_time + offset``. If bursts are 1 s apart (bursts at
    t, t+1, t+2, t+3), the default 3.5 s is half a second past the fourth
    burst. The previous comment claimed the offset simply followed from 1 s
    spacing, which does not match 3.5. NOTE(review): confirm the burst spacing
    against the acquisition protocol.

    Parameters
    ----------
    burst_time : float
        Time of the first burst, in seconds.
    offset : float, optional
        Seconds added to ``burst_time``; defaults to 3.5.
    """
    return burst_time + offset

def _plot_truncated_trajectories(ax, voltage_data, cutoff_time):
    """Draw each particle's path (µm) up to ``cutoff_time``, coloured by ID rank (viridis).

    Lines are unlabelled; this version of the export carries no legend entries.
    """
    unique_particles = np.sort(voltage_data['particle'].unique())
    colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(unique_particles)))
    # One groupby pass instead of a full boolean scan of the frame per particle.
    by_particle = {key: group for key, group in voltage_data.groupby('particle')}
    for color, particle in zip(colors, unique_particles):
        particle_data = by_particle.get(particle)
        if particle_data is None:
            continue
        truncated = particle_data[particle_data['time'] <= cutoff_time]
        if not truncated.empty:
            ax.plot(truncated['x'] * microns_per_pixel, truncated['y'] * microns_per_pixel,
                    linestyle='-', linewidth=line_thickness, color=color)


def save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames):
    """Render and save one untitled trajectory PNG (600 dpi) per voltage level.

    For each voltage, trajectories are truncated at
    ``calculate_fourth_burst_time(burst_frames[voltage])``. Output goes to
    ``<compiled_plots_folder>/<particle_type>/Trajectories/<voltage> mVpp/``.

    Parameters
    ----------
    particle_type : str
        Label used for the folder and file names.
    particle_type_data : pandas.DataFrame
        Rows for this particle type, with columns voltage_input, particle,
        time, x and y. x and y are in pixels.
    voltage_levels : iterable
        Voltages to render.
    burst_frames : dict
        Maps each voltage to its burst time in seconds (despite the name).
    """
    trajectories_folder = os.path.join(compiled_plots_folder, particle_type, 'Trajectories')
    os.makedirs(trajectories_folder, exist_ok=True)

    for voltage_level in voltage_levels:
        burst_time = burst_frames[voltage_level]
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fourth_burst_time = calculate_fourth_burst_time(burst_time)

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            _plot_truncated_trajectories(ax, voltage_data, fourth_burst_time)
            set_consistent_ax_properties(ax)
            fig.tight_layout()

            voltage_folder = os.path.join(trajectories_folder, f"{voltage_level} mVpp")
            os.makedirs(voltage_folder, exist_ok=True)
            fig.savefig(os.path.join(voltage_folder, f'{particle_type}_{voltage_level}mVpp_trajectories_burst{burst_time:.1f}s.png'), dpi=600, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving fails, so a batch
            # export does not accumulate open figures.
            plt.close(fig)

    print(f"All plots saved for {particle_type}")

def create_interactive_plot(particle_type):
    """Build the interactive trajectory viewer for one particle type.

    The layout holds:
    - a voltage selector (an index into the sorted voltage levels);
    - a burst-time slider in seconds;
    - two save buttons;
    - a status area for save messages;
    - the plot preview (annotated with particle IDs).

    Trajectories are drawn only up to ``calculate_fourth_burst_time(burst_time)``.
    Saved PNGs carry neither title, annotations nor line labels.

    Each voltage keeps its own burst time in ``burst_frames``. Despite the name,
    the values are in seconds. Each starts as the mean ``ultrasound_burst_frame``
    of that voltage's rows divided by ``frame_rate``. It is overwritten whenever
    the burst slider moves, so "Save All" uses the user's adjustments.

    Parameters
    ----------
    particle_type : str
        Value of ``combined_df['particle_type']`` to display.

    Returns
    -------
    tuple
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
        ``burst_frames`` is the live dict that the sliders mutate.
    """
    particle_type_data = combined_df[combined_df['particle_type'] == particle_type]
    voltage_levels = sorted(particle_type_data['voltage_input'].unique())

    burst_frames = {}
    # True while code (not the user) moves the burst slider. The slider's
    # observer still records the value but skips its redraw, because the
    # caller redraws once itself. Previously each voltage change rendered twice.
    syncing_slider = False

    voltage_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(voltage_levels) - 1,
        step=1,
        description='Voltage:',
        continuous_update=False
    )

    burst_frame_slider = widgets.FloatSlider(
        value=0,
        min=0,
        max=max_time,
        step=0.1,
        description='Burst Time (s):',
        continuous_update=False
    )

    save_current_button = widgets.Button(description="Save Current Plot")
    save_all_button = widgets.Button(description=f"Save All Plots for {particle_type}")
    output = widgets.Output()
    # Messages and errors from the button callbacks. A bare print() in a widget
    # callback is not shown in JupyterLab, and `output` is cleared on every
    # redraw, so save confirmations need their own area.
    status = widgets.Output()

    def draw_trajectories(ax, voltage_data, cutoff_time, annotate, with_labels):
        """Plot every particle's path (in µm) up to ``cutoff_time`` on ``ax``.

        Colours follow each particle's rank among the sorted particle IDs.
        """
        unique_particles = np.sort(voltage_data['particle'].unique())
        colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(unique_particles)))
        # Split the frame once instead of re-filtering it for every particle.
        by_particle = {key: group for key, group in voltage_data.groupby('particle')}

        for color, particle in zip(colors, unique_particles):
            particle_data = by_particle.get(particle)
            if particle_data is None:
                continue
            truncated = particle_data[particle_data['time'] <= cutoff_time]
            if truncated.empty:
                continue

            x_um = truncated['x'] * microns_per_pixel
            y_um = truncated['y'] * microns_per_pixel
            extra = {'label': f'Particle {particle}'} if with_labels else {}
            ax.plot(x_um, y_um, linestyle='-', linewidth=line_thickness, color=color, **extra)

            if annotate:
                # Tag the last visible point with the particle ID.
                ax.annotate(f'{particle}', xy=(x_um.iloc[-1], y_um.iloc[-1]),
                            xytext=(3, 3), textcoords='offset points', color=color)

    def render_figure(voltage_level, burst_time, include_title, annotate, with_labels):
        """Return a new, fully styled trajectory figure. The caller must close it."""
        voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage_level]
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            draw_trajectories(ax, voltage_data, calculate_fourth_burst_time(burst_time),
                              annotate, with_labels)
            set_consistent_ax_properties(ax)
            if include_title:
                ax.set_title(f'{particle_type} - {voltage_level} mVpp', fontsize=plot_font_size)
            fig.tight_layout()
        except Exception:
            plt.close(fig)  # do not leak the figure if drawing fails
            raise
        return fig

    def initialize_burst_frames():
        for voltage in voltage_levels:
            voltage_data = particle_type_data[particle_type_data['voltage_input'] == voltage]
            burst_frames[voltage] = voltage_data['ultrasound_burst_frame'].mean() / frame_rate

    def update_burst_frame(voltage_level):
        """Move the burst slider to the stored time for ``voltage_level`` without redrawing."""
        nonlocal syncing_slider
        syncing_slider = True
        try:
            burst_frame_slider.value = burst_frames[voltage_level]
        finally:
            syncing_slider = False

    def on_voltage_change(change):
        update_burst_frame(voltage_levels[change.new])
        update_plots(None)

    def on_burst_frame_change(change):
        # Remember the (possibly slider-clamped) time for the current voltage.
        burst_frames[voltage_levels[voltage_slider.value]] = change.new
        if not syncing_slider:
            update_plots(None)

    def update_plots(change, include_title=True):
        """Redraw the annotated preview for the current slider positions."""
        with output:
            clear_output(wait=True)
            fig = render_figure(voltage_levels[voltage_slider.value], burst_frame_slider.value,
                                include_title, annotate=True, with_labels=True)
            try:
                display(fig)
            finally:
                plt.close(fig)

    def save_current_plot(b):
        """Save the current voltage/burst combination as an untitled 600-dpi PNG."""
        voltage_level = voltage_levels[voltage_slider.value]
        burst_time = burst_frame_slider.value
        status.clear_output()
        with status:
            fig = render_figure(voltage_level, burst_time, include_title=False,
                                annotate=False, with_labels=False)
            try:
                voltage_folder = os.path.join(compiled_plots_folder, particle_type,
                                              'Trajectories', f"{voltage_level} mVpp")
                os.makedirs(voltage_folder, exist_ok=True)
                fig.savefig(os.path.join(voltage_folder, f'{particle_type}_{voltage_level}mVpp_trajectories_burst{burst_time:.1f}s.png'), dpi=600, bbox_inches='tight')
            finally:
                plt.close(fig)
            print(f"Plot saved for {particle_type} at {voltage_level} mVpp with burst time {burst_time:.1f}s")

    def save_all_plots_for_this_particle(b):
        """Export every voltage level using the stored per-voltage burst times."""
        status.clear_output()
        with status:
            # save_all_plots_for_particle prints its own completion message;
            # printing it here as well duplicated it.
            save_all_plots_for_particle(particle_type, particle_type_data, voltage_levels, burst_frames)

    voltage_slider.observe(on_voltage_change, names='value')
    burst_frame_slider.observe(on_burst_frame_change, names='value')
    save_current_button.on_click(save_current_plot)
    save_all_button.on_click(save_all_plots_for_this_particle)

    initialize_burst_frames()
    update_burst_frame(voltage_levels[0])

    widget_layout = widgets.VBox([
        widgets.HTML(f"<h3>Interactive plot for {particle_type}</h3>"),
        voltage_slider,
        burst_frame_slider,
        widgets.HBox([save_current_button, save_all_button]),
        status,
        output
    ])

    # Single initial render (the slider sync above no longer triggers one).
    update_plots(None)

    return widget_layout, particle_type_data, voltage_levels, burst_frames

def display_all_particle_plots():
    """Create, display and return one interactive viewer per particle type.

    Returns
    -------
    dict
        Maps each entry of ``particle_types`` to the tuple returned by
        ``create_interactive_plot``:
        ``(widget_layout, particle_type_data, voltage_levels, burst_frames)``.
    """
    built = [(ptype, create_interactive_plot(ptype)) for ptype in particle_types]

    # Stack every per-type layout in a single container so they render together.
    display(widgets.VBox([bundle[0] for _, bundle in built]))

    return dict(built)

# Call the function to display all plots.
# Builds one interactive widget per particle type and displays them stacked.
# The returned dict keeps the per-type (widget, data, voltages, burst_frames)
# tuples, so slider-adjusted burst times stay reachable after the cell runs.
all_particle_widgets = display_all_particle_plots()


Particle types found: ['DOPC HMSM' 'DOPC HMSN' 'MSM' 'MSN']
'Compiled plots' folder already exists at E:\Particle tracking\Summer Semester Jun-Aug 2024\28JUN24 - Speed and distance profiles\Compiled plots


VBox(children=(VBox(children=(HTML(value='<h3>Interactive plot for DOPC HMSM</h3>'), IntSlider(value=0, contin…