In [None]:
# Root folder of the Kaggle "anomalous-diffusion-challenge" input data.
# Make sure the mounted dataset folder has this name, or change the path here.
public_data_path = '/kaggle/input/anomalous-diffusion-challenge/'

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy.stats import norm
from scipy.stats import anderson
from scipy.signal import find_peaks, argrelextrema

# Seed NumPy's global (legacy) RNG so any np.random-based step below is reproducible.
np.random.seed(0)

In [None]:
def fill_short_segments_np(labels, minseq=3):
    """
    Replace runs of equal labels shorter than ``minseq`` with a neighbouring label.

    Runs are processed for lengths 1, 2, ..., ``minseq - 1`` in turn (via
    recursion on ``minseq``). A short run takes the value of the element
    immediately before it. A short run at the very start takes the value
    immediately after it. A run with no neighbour on either side (i.e. the
    whole array) is left unchanged.

    Args:
        labels (numpy.ndarray): 1-D array of labels.
        minseq (int): Minimum run length to keep unchanged (default: 3).

    Returns:
        numpy.ndarray: A modified copy of ``labels``; the input is not mutated.
    """
    # Handle shorter runs first so they are merged before longer ones are judged
    if minseq > 2:
        labels = fill_short_segments_np(labels, minseq=minseq - 1)

    labels = labels.copy()
    n = len(labels)
    if n == 0:
        # np.split would yield a single empty segment and segment[0] would raise
        return labels

    # Boundaries between runs of different labels
    change_points = np.where(np.diff(labels) != 0)[0] + 1
    segments = np.split(np.arange(n), change_points)

    for segment in segments:
        if len(segment) >= minseq:
            continue
        start, end = segment[0], segment[-1]
        # Prefer the left neighbour (already updated earlier in this pass)
        if start > 0:
            fill_value = labels[start - 1]
        elif end < n - 1:
            fill_value = labels[end + 1]
        else:
            # The run spans the whole array: nothing to borrow from
            continue
        labels[segment] = fill_value

    return labels

In [None]:
def fill_short_segments(labels, minseq=3, fill_value=None):
    """
    Replace runs of equal labels shorter than ``minseq`` with a neighbouring label.

    Runs are processed for lengths 1, 2, ..., ``minseq - 1`` in turn (via
    recursion on ``minseq``). A short run takes the value of the element
    immediately before it. A short run at the very start takes the value
    immediately after it. A run with no neighbour on either side (the whole
    input) is left unchanged.

    Args:
        labels (pd.Series or np.ndarray): 1-D labels. For a Series, all
            neighbour lookups are positional, so any index works (e.g. a
            slice of a larger frame whose index does not start at 0).
        minseq (int): Minimum run length to keep unchanged.
        fill_value (Any, optional): If given, only short runs whose value
            equals ``fill_value`` are replaced; if None, every short run is.

    Returns:
        pd.Series or np.ndarray: A modified copy of ``labels`` of the same type.
    """
    # Handle shorter runs first so they are merged before longer ones are judged
    if minseq > 2:
        labels = fill_short_segments(labels, minseq=minseq - 1, fill_value=fill_value)

    labels = labels.copy()
    n = len(labels)
    if n == 0:
        # np.split would yield a single empty segment and segment[0] would raise
        return labels

    # Positional read/write access: .iloc for a Series (its index labels need
    # not be 0..n-1), plain indexing for a numpy array.
    positional = labels.iloc if isinstance(labels, pd.Series) else labels

    # Boundaries between runs of different labels
    change_points = np.where(np.diff(labels) != 0)[0] + 1
    segments = np.split(np.arange(n), change_points)

    for segment in segments:
        if len(segment) >= minseq:
            continue
        start, end = segment[0], segment[-1]
        # Only touch runs carrying the requested value, if one was given
        if fill_value is not None and not (positional[start] == fill_value):
            continue
        # Prefer the left neighbour (already updated earlier in this pass)
        if start > 0:
            new_value = positional[start - 1]
        elif end < n - 1:
            new_value = positional[end + 1]
        else:
            # The run spans the whole input: nothing to borrow from
            continue
        positional[segment] = new_value

    return labels

In [None]:
#################### The KEY FUNCTION of our METHOD ##############################################################
def RowHurst(X1, X2, n):
    """
    Estimate one anomalous-diffusion (AD) exponent per row.

    For every row, the NaN-ignoring median of ``X1`` is divided by the
    NaN-ignoring median of ``X2``, and the log of that ratio is taken in
    base ``n``. Note the ratio is X1 / X2 (1-step over n-step), exactly as
    computed here; any sign convention is left to the caller.

    Args:
        X1 (numpy.ndarray): 2-D array of SD values for 1 time step
            (presumably squared displacements; one row per sample).
        X2 (numpy.ndarray): 2-D array of SD values for ``n`` time steps.
        n (int): Number of time steps used for ``X2``.

    Returns:
        numpy.ndarray: One AD exponent per row.
    """
    one_step_median = np.nanmedian(X1, axis=1)
    n_step_median = np.nanmedian(X2, axis=1)
    log_ratio = np.log(one_step_median / n_step_median)
    return log_ratio / np.log(n)
#################### The end of KEY FUNCTION of our METHOD ##############################################################


In [None]:
def anderson_statistic(group):
    """
    Anderson-Darling statistic of a group against an exponential distribution.

    Args:
        group (pandas.Series or array-like): Input values; NaNs are ignored.
            Non-pandas inputs are wrapped in a Series first.

    Returns:
        float: The test statistic, or NaN if fewer than 30 non-NaN values remain.
    """
    # Pandas objects are used as-is; plain arrays/lists are wrapped so .dropna() exists
    values = group.dropna() if hasattr(group, 'dropna') else pd.Series(group).dropna()

    # Too few samples for a meaningful test
    if len(values) < 30:
        # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0
        return np.nan

    return anderson(values, dist='expon').statistic

In [None]:
def get_freedom(current_class, alpha, model_type=None, exp_id=None):
    """
    Map a segment's class and AD exponent to a state code (0, 1, 2 or 3).

    Rules, in priority order:
      * class 0 under the 'immobile_traps' model                 -> 0
      * 'confinement' model and (alpha < 0.6,
        or class 1 in experiment 5)                              -> 1
      * alpha > 1.89                                             -> 3
      * otherwise                                                -> 2
    NOTE(review): presumably these are the challenge's state labels
    (0 immobile, 1 confined, 2 free, 3 directed); confirm against the spec.

    Args:
        current_class (int): Class label of the particle-motion segment.
        alpha (float): Anomalous diffusion exponent of the segment.
        model_type (str, optional): Predicted model type. Defaults to the
            notebook-level global ``model_type_prediction`` (previous behaviour).
        exp_id (int, optional): Experiment number. Defaults to the
            notebook-level global ``exp`` (previous behaviour). It is only
            consulted for the class-1 confinement rule.

    Returns:
        int: The state code (0, 1, 2, or 3).
    """
    if model_type is None:
        model_type = globals()['model_type_prediction']

    # Immobile traps: class 0 marks the trapped segments
    if current_class == 0 and model_type == 'immobile_traps':
        return 0

    # Confinement: low alpha, or class-1 segments in experiment 5
    if model_type == 'confinement':
        if alpha < 0.6:
            return 1
        if current_class == 1:
            if exp_id is None:
                exp_id = globals()['exp']
            if exp_id == 5:
                return 1

    # Default: decided by alpha alone
    return 3 if alpha > 1.89 else 2

In [None]:
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from scipy.signal import argrelextrema

def visualize_rolling_stats(df, big=1, num_bins=30, min_exp=-6, max_exp=5, peak_height=0.05):
    """
    Plot count and density histograms for every rolling-statistic column.

    Histograms use log-spaced bins and mark detected peaks and local minima.
    Columns are those whose name contains 'rolling' but not 'alpha'. Each
    column gets one row of two panels, and the minima positions are printed
    in the panel titles.

    :param df: DataFrame containing the data
    :param big: Scaling factor applied to the data (default: 1)
    :param num_bins: Number of bin edges passed to np.linspace (default: 30)
    :param min_exp: Natural-log exponent of the first bin edge (default: -6)
    :param max_exp: Natural-log exponent of the last bin edge (default: 5)
    :param peak_height: Peak threshold. In the count panel it is relative to
        the tallest bar; in the density panel it is an absolute density
        (default: 0.05)
    """
    # str() so non-string column labels do not raise on the substring test
    columns = [col for col in df.columns if 'rolling' in str(col) and 'alpha' not in str(col)]
    if not columns:
        # plt.subplots(0, 2) would raise ValueError; report instead of crashing
        print('No rolling-statistic columns to visualize.')
        return

    bins = np.exp(np.linspace(min_exp, max_exp, num_bins))
    num_rows = len(columns)

    fig, axs = plt.subplots(num_rows, 2, figsize=(20, 5 * num_rows), squeeze=False)
    fig.suptitle('Rolling Statistics Visualization', fontsize=16)

    def _draw_panel(ax, data, column_name, density, show_legend):
        """Draw one histogram panel of ``data`` with its peaks and local minima."""
        hist, bin_edges = np.histogram(data, bins=bins, density=density)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        # Count panel: threshold relative to the tallest bar; density panel: absolute
        height = peak_height if density else max(hist) * peak_height
        peaks, _ = find_peaks(hist, height=height)
        minima = argrelextrema(hist, np.less, order=2)[0]
        kind, ylabel = ('Density', 'Density') if density else ('Absolute Frequency', 'Frequency')

        ax.set_xscale('log')
        ax.stairs(hist, bin_edges, fill=True, alpha=0.7, label='Histogram')
        ax.plot(bin_centers[peaks], hist[peaks], "x", color='red', label='Peaks')
        ax.plot(bin_centers[minima], hist[minima], "o", color='green', label='Minima')
        ax.set_title(f'{column_name} ({kind})  {bin_centers[minima]}', fontsize=10)
        ax.set_xlabel('Value (log scale)', fontsize=8)
        ax.set_ylabel(ylabel, fontsize=8)
        ax.tick_params(axis='both', which='major', labelsize=6)
        ax.tick_params(axis='both', which='minor', labelsize=4)
        ax.grid(True, which="both", ls="-", alpha=0.2)
        if show_legend:
            ax.legend(fontsize=6)

    for idx, column_name in enumerate(columns):
        data = big * df[column_name].dropna()
        # Legends only on the first row to keep the grid readable
        _draw_panel(axs[idx, 0], data, column_name, density=False, show_legend=idx == 0)
        _draw_panel(axs[idx, 1], data, column_name, density=True, show_legend=idx == 0)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [None]:
def visualize_traj(big_df1):
    """
    Scatter-plot every (x, y) point of ``big_df1``, coloured by trajectory index.

    Args:
        big_df1 (pandas.DataFrame): Must provide 'x', 'y' and 'traj_idx' columns.

    Returns:
        None. The figure is rendered with plt.show().
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Small, semi-transparent markers so dense regions stay readable
    marker_opts = {'cmap': 'viridis', 's': 4, 'alpha': 0.6}
    points = ax.scatter(big_df1['x'], big_df1['y'], c=big_df1['traj_idx'], **marker_opts)

    colour_bar = plt.colorbar(points)
    colour_bar.set_label('Trajectory Index', rotation=270, labelpad=15)

    ax.set(xlabel='X', ylabel='Y', title='Visualization of Trajectories on X-Y Field')
    ax.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
from scipy.signal import find_peaks, argrelextrema

def divide_single_statistic_1only(df, statistic_type='median', statistic_name='9', statistic_density=False, num_bins=30, min_exp=-6, max_exp=4, peak_height=0.05, visualize=False):
    """
    Split points into two classes at the valley between two histogram peaks.

    Rolling windows are tried in the order k = 1, 3, 4, 2 using the columns
    'rolling_{statistic_type}{k}_{statistic_name}'. The first window whose
    log-binned histogram has exactly two peaks is used, with the lowest bin
    between those peaks as the threshold. Values above it are labelled 1 and
    all others (including NaN) 0. Short runs (< 3) of 0s and then of 1s are
    smoothed away with fill_short_segments.

    Args:
        df (pandas.DataFrame): DataFrame containing the rolling-statistic columns.
        statistic_type (str): Type of statistic (default: 'median').
        statistic_name (str): Column-name suffix to analyse (default: '9').
        statistic_density (bool): Whether to use density for histogram (default: False).
        num_bins (int): Number of bin edges for the histogram (default: 30).
        min_exp (float): Natural-log exponent of the first bin edge (default: -6).
        max_exp (float): Natural-log exponent of the last bin edge (default: 4).
        peak_height (float): Peak threshold relative to the tallest bar (default: 0.05).
        visualize (bool): Plot the histogram of the chosen window (default: False).

    Returns:
        pandas.Series: 0/1 labels. If no window yields exactly two peaks, a
        zero Series is returned (NaN where the k=4 column is NaN).
    """
    bins = np.exp(np.linspace(min_exp, max_exp, num_bins))

    def column(k):
        return 'rolling_' + statistic_type + str(k) + '_' + statistic_name

    # Fallback result: zeros shaped like the data, NaN kept where the column is NaN
    labels = df[column(4)] * 0

    for k in [1, 3, 4, 2]:
        data = df[column(k)]

        hist_normal, bin_edges = np.histogram(data, bins=bins, density=statistic_density)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        peaks_normal, _ = find_peaks(hist_normal, height=max(hist_normal) * peak_height)

        # Only a clearly bimodal histogram gives an unambiguous split
        if len(peaks_normal) != 2:
            continue

        left_peak, right_peak = peaks_normal
        min_index = np.argmin(hist_normal[left_peak:right_peak + 1]) + left_peak
        min_value = bin_centers[min_index]

        print(f"k={k} minimum={min_value} peaks {bin_centers[peaks_normal]}   {column(k)}")

        if visualize:
            # Build the figure only once a split exists, so runs without one
            # do not leave an empty, never-shown figure behind.
            fig, ax = plt.subplots(figsize=(10, 5))
            fig.suptitle(f'Visualization of {statistic_name}', fontsize=16)
            ax.set_xscale('log')
            ax.stairs(hist_normal, bin_edges, fill=True, alpha=0.7, label='Histogram')
            ax.plot(bin_centers[peaks_normal], hist_normal[peaks_normal], "x", color='red', label='Peaks')
            ax.plot(bin_centers[min_index], hist_normal[min_index], "o", color='green', label='Minima')
            ax.set_title(f'{statistic_name} min= {min_value}', fontsize=12)
            ax.set_xlabel('Value (log scale)', fontsize=10)
            ax.set_ylabel('Density' if statistic_density else 'Frequency', fontsize=10)
            ax.tick_params(axis='both', which='major', labelsize=8)
            ax.tick_params(axis='both', which='minor', labelsize=6)
            ax.grid(True, which="both", ls="-", alpha=0.2)
            ax.legend(fontsize=8)
            plt.show()

        # Threshold, then remove short runs of 0s and afterwards of 1s
        labels = fill_short_segments((data > min_value).astype(int), minseq=3, fill_value=0)
        labels = fill_short_segments(labels, minseq=3, fill_value=1)
        break

    return labels

In [None]:
def visualize_single_statistic(df, statistic_name, big=1, num_bins=30, min_exp=-6, max_exp=4, peak_height=0.05):
    """
    Plot one column as a count histogram (left) and a density histogram (right).

    Both histograms use log-spaced bins and mark detected peaks and local
    minima.

    :param df: DataFrame containing the data
    :param statistic_name: Name of the column to visualize
    :param big: Scaling factor applied to the data (default: 1)
    :param num_bins: Number of bin edges passed to np.linspace (default: 30)
    :param min_exp: Natural-log exponent of the first bin edge (default: -6)
    :param max_exp: Natural-log exponent of the last bin edge (default: 4)
    :param peak_height: Peak threshold; relative to the tallest bar in the
        count panel, an absolute density in the density panel (default: 0.05)
    """
    bins = np.exp(np.linspace(min_exp, max_exp, num_bins))

    fig, (ax_normal, ax_density) = plt.subplots(1, 2, figsize=(20, 5))
    fig.suptitle(f'Visualization of {statistic_name}', fontsize=16)

    data = big * df[statistic_name].dropna()

    # (axis, density flag, title qualifier, y-axis label)
    panels = (
        (ax_normal, False, 'Absolute Frequency', 'Frequency'),
        (ax_density, True, 'Density', 'Density'),
    )
    for ax, use_density, title_kind, y_label in panels:
        counts, edges = np.histogram(data, bins=bins, density=use_density)
        centers = (edges[:-1] + edges[1:]) / 2

        # Count panel: threshold relative to the tallest bar; density panel: absolute
        threshold = peak_height if use_density else max(counts) * peak_height
        peak_idx, _ = find_peaks(counts, height=threshold)
        minima_idx = argrelextrema(counts, np.less, order=2)[0]

        ax.set_xscale('log')
        ax.stairs(counts, edges, fill=True, alpha=0.7, label='Histogram')
        ax.plot(centers[peak_idx], counts[peak_idx], "x", color='red', label='Peaks')
        ax.plot(centers[minima_idx], counts[minima_idx], "o", color='green', label='Minima')
        ax.set_title(f'{statistic_name} ({title_kind})', fontsize=12)
        ax.set_xlabel('Value (log scale)', fontsize=10)
        ax.set_ylabel(y_label, fontsize=10)
        ax.tick_params(axis='both', which='major', labelsize=8)
        ax.tick_params(axis='both', which='minor', labelsize=6)
        ax.grid(True, which="both", ls="-", alpha=0.2)
        ax.legend(fontsize=8)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [None]:
def divide_single_statistic(df, statistic_type='mean', statistic_name='15', statistic_density='True', num_bins=30, min_exp=-6, max_exp=4, peak_height=0.05):
    """
    Threshold a rolling statistic at the rightmost local minimum of its histogram.

    Rolling windows are tried in the order k = 4, 3, 2, 1 using the columns
    'rolling_{statistic_type}{k}_{statistic_name}'. The first window whose
    log-binned histogram has a local minimum (order 2) is used. Values above
    the rightmost minimum are labelled 1 and others (including NaN) 0. Runs
    shorter than 3 are then smoothed with fill_short_segments.

    Args:
        df (pandas.DataFrame): DataFrame containing the rolling-statistic columns.
        statistic_type (str): Type of statistic (default: 'mean').
        statistic_name (str): Column-name suffix to analyse (default: '15').
        statistic_density (bool or str): Whether to use density for the
            histogram. The strings 'True' / 'False' (any case) are accepted
            for backward compatibility.
        num_bins (int): Number of bin edges for the histogram (default: 30).
        min_exp (float): Natural-log exponent of the first bin edge (default: -6).
        max_exp (float): Natural-log exponent of the last bin edge (default: 4).
        peak_height (float): Unused; kept for signature compatibility.

    Returns:
        pandas.Series: 0/1 labels. If no window has a local minimum, a zero
        Series is returned (NaN where the k=4 column is NaN).
    """
    # Any non-empty string (including 'False') is truthy, so normalise strings
    # to a real bool before handing the flag to np.histogram.
    if isinstance(statistic_density, str):
        statistic_density = statistic_density.strip().lower() == 'true'

    bins = np.exp(np.linspace(min_exp, max_exp, num_bins))

    def column(k):
        return 'rolling_' + statistic_type + str(k) + '_' + statistic_name

    # Fallback result: zeros shaped like the data, NaN kept where the column is NaN
    labels = df[column(4)] * 0

    for k in [4, 3, 2, 1]:
        data = df[column(k)]

        hist, bin_edges = np.histogram(data, bins=bins, density=statistic_density)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        minima = argrelextrema(hist, np.less, order=2)[0]

        if len(minima):
            # Rightmost valley separates the high-value population
            threshold = bin_centers[minima][-1]
            labels = fill_short_segments((data > threshold).astype(int), 3)
            print(f'k={k} extr={threshold}')
            break

    return labels

In [None]:
def visualize_rolling_median(big_df1):
    """
    Scatter-plot every (x, y) point of ``big_df1``, coloured by its 'class' label.

    Note: despite the name, this plots the 'class' column, not a rolling median.

    Args:
        big_df1 (pandas.DataFrame): Must provide 'x', 'y' and 'class' columns.

    Returns:
        None. The figure is rendered with plt.show().
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Small, semi-transparent markers so overlapping classes stay visible
    marker_opts = {'cmap': 'plasma', 's': 5, 'alpha': 0.6}
    points = ax.scatter(big_df1['x'], big_df1['y'], c=big_df1['class'], **marker_opts)

    colour_bar = plt.colorbar(points)
    colour_bar.set_label('Class', rotation=270, labelpad=15)

    ax.set(xlabel='X', ylabel='Y', title='Visualization of Classes on X-Y Field')
    ax.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from scipy.stats import norm

def get_dominant_gaussian_params(data_series, n_components=2, class_name='K', visualize=True):
    """
    Fit a 1-D Gaussian mixture and return the parameters of its heaviest component.

    Parameters:
    data_series (pandas.Series): Series of 'alpha' data for a specific class.
    n_components (int): Number of components in the Gaussian mixture.
    class_name (str, optional): Label for the x axis of the plot.
    visualize (bool): Plot the histogram, every component and the median.

    Returns:
    tuple: (mu, sigma) of the component with the largest mixture weight.
    Returns (None, None) if there is too little data or fitting fails.
    """
    try:
        data = data_series.dropna().values.reshape(-1, 1)

        # GaussianMixture needs at least n_components samples (and 2 to be useful)
        if len(data) < max(2, n_components):
            print("Not enough data for GMM.")
            return None, None

        gmm = GaussianMixture(n_components=n_components, random_state=42)
        gmm.fit(data)

        means = gmm.means_.flatten()
        # With 1-D data and sklearn's default 'full' covariance, each component
        # has a 1x1 covariance, so flatten() gives one variance per component.
        sigmas = np.sqrt(gmm.covariances_.flatten())
        weights = gmm.weights_

        dominant_index = np.argmax(weights)
        mu, sigma = means[dominant_index], sigmas[dominant_index]

        if visualize:
            median = np.median(data)
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.hist(data, bins=30, density=True, alpha=0.7, color='skyblue')

            # Each component scaled by its weight, as it contributes to the mixture
            x = np.linspace(data.min(), data.max(), 100)
            for i in range(n_components):
                ax.plot(x, weights[i] * norm.pdf(x, means[i], sigmas[i]),
                        f'C{i}-', lw=2, label=f'Gaussian {i+1}')

            ax.axvline(mu, color='k', linestyle='--', lw=2, label='μ of dominant Gaussian')
            ax.axvline(median, color='g', linestyle='--', lw=2, label='median')

            ax.set_title(f'μ = {mu:.5f}, median = {median:.5f}, σ = {sigma:.5f}')
            ax.set_xlabel(class_name)
            ax.set_ylabel('Density')
            ax.legend()
            ax.grid(True, alpha=0.3)
            plt.show()

        return mu, sigma

    except Exception as e:
        # Deliberate best-effort: report the problem and signal failure to the caller
        print(f"An error occurred while processing the data: {e}")
        return None, None

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
from scipy.optimize import minimize

def truncated_normal_pdf(x, mu, sigma, a=0, b=2):
    """Density at ``x`` of N(mu, sigma**2) truncated to the interval [a, b]."""
    # truncnorm expects its clip points in standard-deviation units around loc
    lower_std = (a - mu) / sigma
    upper_std = (b - mu) / sigma
    return truncnorm.pdf(x, lower_std, upper_std, loc=mu, scale=sigma)

def neg_log_likelihood(params, data, a=0, b=2):
    """Negative log-likelihood of ``data`` under a truncated normal; params = (mu, sigma)."""
    mu, sigma = params
    # The 1e-10 floor keeps log() finite where the density is zero (points outside [a, b])
    densities = truncated_normal_pdf(data, mu, sigma, a, b) + 1e-10
    return -np.sum(np.log(densities))

def get_bounded_gaussian_params(data_series, class_name='Alpha', visualize=True, a=0, b=2):
    """
    Fit a normal distribution truncated to [a, b] by maximum likelihood.

    Parameters:
    data_series (pandas.Series): Series of 'alpha' data for a specific class.
    class_name (str, optional): Label for the x axis of the plot.
    visualize (bool): Plot the histogram, the fitted density, μ and the median.
    a (float): Lower bound of the distribution.
    b (float): Upper bound of the distribution.

    Returns:
    tuple: (mu, sigma) of the fitted truncated normal distribution.
    Returns (None, None) if there is too little data or an error occurs.
    """
    try:
        data = data_series.dropna().values

        if len(data) < 2:
            print("Not enough data to estimate parameters.")
            return None, None

        # Start the optimiser inside the bounds it is given. The sample mean
        # can fall outside [a, b], and the sample std is 0 for constant data.
        initial_mu = float(np.clip(np.mean(data), a, b))
        initial_sigma = max(float(np.std(data)), 1e-5)

        result = minimize(neg_log_likelihood, [initial_mu, initial_sigma],
                          args=(data, a, b),
                          bounds=[(a, b), (1e-5, None)])
        if not result.success:
            # Still return the best point found, but make the failure visible
            print(f"Warning: bounded Gaussian fit did not converge: {result.message}")

        mu, sigma = result.x

        if visualize:
            median = np.median(data)
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.hist(data, bins=30, density=True, alpha=0.7, color='skyblue')

            x = np.linspace(data.min(), data.max(), 100)
            ax.plot(x, truncated_normal_pdf(x, mu, sigma, a, b), 'r-', lw=2, label='Fitted distribution')
            ax.axvline(median, color='g', linestyle='--', lw=2, label='median')
            ax.axvline(mu, color='k', linestyle='--', lw=2, label='μ')

            ax.set_title(f'μ = {mu:.5f}, median = {median:.5f}, σ = {sigma:.5f}')
            ax.set_xlabel(class_name)
            ax.set_ylabel('Density')
            ax.legend()
            ax.grid(True, alpha=0.3)
            plt.show()

        return mu, sigma

    except Exception as e:
        # Deliberate best-effort: report the problem and signal failure to the caller
        print(f"An error occurred while processing the data: {e}")
        return None, None

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree

def confined_circles5(big_df1, rol_stat='rolling_median3_7', radius=5, step=0.25, visualization=False, min_new_points=5):
    """
    Label points lying inside greedily chosen confinement circles (confined exp 5).

    Steps:
    1. Provisional classes from ``rol_stat``. Values at or below the 7 %
       quantile become class 1, values at or above the 80 % quantile become
       class 2, and the rest become 0.
    2. Candidate centres lie on a regular grid with spacing ``step``. Each is
       scored by ``#class-1 points - (#class-2 points)**2`` within ``radius``,
       and candidates scoring below ``min_new_points`` are dropped.
    3. Candidates are taken best-first. A candidate is kept if it is at least
       1.5 * radius from every kept centre and covers at least
       ``min_new_points`` class-1 points not already covered.
    4. 'class' is reset to 0, then set to 1 for every point inside a kept
       circle. Short runs are smoothed with fill_short_segments (minseq=7).

    Side effect: overwrites big_df1['class'].

    Args:
        big_df1 (pandas.DataFrame): Needs 'x', 'y' and ``rol_stat`` columns.
        rol_stat (str): Column used for the provisional classification.
        radius (float): Circle radius, in the same units as x and y.
        step (float): Spacing of the candidate-centre grid.
        visualization (bool): Plot the provisional classes and chosen circles.
        min_new_points (int): Minimum score / newly covered class-1 points.

    Returns:
        pandas.Series: Smoothed 0/1 labels aligned with big_df1.
    """
    big_df1.loc[:, 'class'] = 0
    threshold_low = big_df1[rol_stat].quantile(0.07)
    threshold_high = big_df1[rol_stat].quantile(0.80)
    big_df1.loc[big_df1[rol_stat] <= threshold_low, 'class'] = 1
    big_df1.loc[big_df1[rol_stat] >= threshold_high, 'class'] = 2

    class_1_xy = big_df1.loc[big_df1['class'] == 1, ['x', 'y']].to_numpy()
    class_2_xy = big_df1.loc[big_df1['class'] == 2, ['x', 'y']].to_numpy()
    tree_class_1 = cKDTree(class_1_xy)
    tree_class_2 = cKDTree(class_2_xy)

    # Regular grid of candidate centres covering the data extent
    x_range = np.arange(big_df1['x'].min(), big_df1['x'].max() + step, step)
    y_range = np.arange(big_df1['y'].min(), big_df1['y'].max() + step, step)
    grid_points = np.array(np.meshgrid(x_range, y_range)).T.reshape(-1, 2)

    # Score all candidates in one vectorised query each (only counts are needed here)
    n_class_1 = np.asarray(tree_class_1.query_ball_point(grid_points, radius, return_length=True))
    n_class_2 = np.asarray(tree_class_2.query_ball_point(grid_points, radius, return_length=True))
    scores = n_class_1 - n_class_2 ** 2
    candidate_idx = np.flatnonzero(scores >= min_new_points)

    # Fetch covered class-1 index sets only for the candidates that qualified,
    # kept in grid order so the stable sort below breaks ties as before
    circle_data = []
    if candidate_idx.size:
        covered_lists = tree_class_1.query_ball_point(grid_points[candidate_idx], radius)
        for grid_i, covered in zip(candidate_idx, covered_lists):
            cx, cy = grid_points[grid_i]
            circle_data.append((cx, cy, scores[grid_i], set(covered)))

    # Best-scoring candidates first
    circle_data.sort(key=lambda x: x[2], reverse=True)

    selected_centers = []
    covered_points = set()

    def circles_overlap(c1, c2, threshold=1.5):
        return np.sqrt((c1[0] - c2[0])**2 + (c1[1] - c2[1])**2) < threshold * radius

    for cx, cy, _, indices in circle_data:
        if any(circles_overlap((cx, cy), center) for center in selected_centers):
            continue
        # Only count class-1 points not already claimed by a chosen circle
        new_points = indices - covered_points
        if len(new_points) >= min_new_points:
            selected_centers.append((cx, cy))
            covered_points.update(new_points)

    if visualization:
        plt.figure(figsize=(12, 8))
        scatter = plt.scatter(big_df1['x'], big_df1['y'], c=big_df1['class'], cmap='viridis', alpha=0.6)
        plt.colorbar(scatter, label='Class')
        plt.title('Circles Covering Maximum Class 1 Points and Excluding Class 2 Points')
        plt.xlabel('x')
        plt.ylabel('y')
        for center in selected_centers:
            circle = plt.Circle((center[0], center[1]), radius, fill=False, color='red')
            plt.gca().add_artist(circle)
        plt.show()

    # Final labels: 1 inside any chosen circle, 0 elsewhere
    big_df1.loc[:, 'class'] = 0
    tree_all_points = cKDTree(big_df1[['x', 'y']])
    for center in selected_centers:
        indices = tree_all_points.query_ball_point(center, radius)
        big_df1.loc[big_df1.index[indices], 'class'] = 1

    print(f"Found {len(selected_centers)} circle centers satisfying conditions.")

    # Smooth out short runs of 1s, then of 0s
    labels = fill_short_segments(big_df1['class'], minseq=7, fill_value=1)
    labels = fill_short_segments(labels, minseq=7, fill_value=0)

    return labels

In [None]:
def _confined_circles_count_within(centers, points, radius, chunk_size=512):
    """Count, for each row of `centers`, the rows of `points` at Euclidean distance <= radius.

    Both arguments are (n, 2) float arrays of (x, y) coordinates. Centers are processed in
    chunks of `chunk_size` so the temporary distance matrix stays bounded in memory.
    """
    counts = np.zeros(len(centers), dtype=np.int64)
    if len(centers) == 0 or len(points) == 0:
        return counts
    for start in range(0, len(centers), chunk_size):
        chunk = centers[start:start + chunk_size]
        # Same arithmetic as a scalar loop: (point - center)**2 summed per axis, then sqrt.
        dx = points[np.newaxis, :, 0] - chunk[:, np.newaxis, 0]
        dy = points[np.newaxis, :, 1] - chunk[:, np.newaxis, 1]
        counts[start:start + chunk_size] = np.count_nonzero(np.sqrt(dx ** 2 + dy ** 2) <= radius, axis=1)
    return counts


def confined_circles(big_df1, rol_stat='rolling_median3_7', radius=5, visualize=False):
    """
    Label points of one FOV as inside (1) or outside (0) greedily chosen confinement circles.

    Points whose `rol_stat` is at or below the 7% quantile become class 1, at or above the
    93% quantile class 2. Up to 20 circle centres are then chosen greedily among class-0/1
    points; each pick maximises (#class-1 points covered - 5 * #class-2 points covered),
    and the class-1 points it covers are removed before the next pick. Selection stops
    early when no class-1 points remain, when the best score is below 3, or when no
    candidate scores above -1000.

    Parameters:
    big_df1 (DataFrame): Points with columns 'x', 'y' and `rol_stat`. Its 'class' column
        is overwritten in place.
    rol_stat (str): Column with the rolling statistic. Default 'rolling_median3_7'.
    radius (float): Circle radius. Default 5.
    visualize (bool): Currently unused.

    Returns:
    Labels (1 inside any selected circle, else 0) smoothed by fill_short_segments(..., 3).
    """
    big_df1.loc[:, 'class'] = 0
    threshold_low = big_df1[rol_stat].quantile(0.07)
    threshold_high = big_df1[rol_stat].quantile(0.93)
    big_df1.loc[big_df1[rol_stat] <= threshold_low, 'class'] = 1
    big_df1.loc[big_df1[rol_stat] >= threshold_high, 'class'] = 2

    xy_cols = ['x', 'y']
    class_1_xy = big_df1.loc[big_df1['class'] == 1, xy_cols].to_numpy(dtype=float)
    class_2_xy = big_df1.loc[big_df1['class'] == 2, xy_cols].to_numpy(dtype=float)
    # Candidate centres: every class 0 and class 1 point, in frame order (ties -> first).
    candidates_xy = big_df1.loc[big_df1['class'] < 2, xy_cols].to_numpy(dtype=float)

    # Class 2 points are never removed, so their per-candidate counts are loop-invariant.
    class_2_counts = _confined_circles_count_within(candidates_xy, class_2_xy, radius)

    selected_centers = []
    while len(selected_centers) < 20 and len(class_1_xy) > 0:
        print(f'centers {len(selected_centers)}')

        # Greedy score: class-1 coverage, penalised 5x for every class-2 point covered.
        scores = _confined_circles_count_within(candidates_xy, class_1_xy, radius) - class_2_counts * 5
        if len(scores) == 0:
            break
        best_idx = int(np.argmax(scores))  # first maximum, like a strict '>' scan
        max_covered_points = scores[best_idx]
        if max_covered_points <= -1000:
            # No acceptable centre exists; without this exit the loop would never terminate.
            break

        best_center = (candidates_xy[best_idx, 0], candidates_xy[best_idx, 1])
        selected_centers.append(best_center)

        # Drop the class 1 points this circle covers; '<= radius' counts as covered,
        # so only points strictly outside are kept.
        distances = np.sqrt((class_1_xy[:, 0] - best_center[0]) ** 2 + (class_1_xy[:, 1] - best_center[1]) ** 2)
        class_1_xy = class_1_xy[distances > radius]
        if max_covered_points < 3:
            break

    if len(selected_centers) < 20:
        print(f"Error: Could not find 20 circle centers satisfying conditions. Found {len(selected_centers)} centers.")
    else:
        print("Found 20 circle centers satisfying conditions.")

    # Assign class 1 to points falling within any selected circle
    big_df1.loc[:, 'class'] = 0
    for center in selected_centers:
        distances_to_center = np.sqrt((big_df1['x'] - center[0])**2 + (big_df1['y'] - center[1])**2)
        big_df1.loc[distances_to_center <= radius, 'class'] = 1

    return fill_short_segments(big_df1['class'], 3)

In [None]:
def div_by_rollstat(big_df1, rol_stat='rolling_mean1_11', threshold=1.2, alpha_threshold=1.2, minseq=6):
    """
    Classify points of a multi_state experiment from a rolling statistic and 'rolling_alpha_15'.

    Labels before smoothing:
      0 -- `rol_stat` >= threshold (or NaN),
      1 -- `rol_stat` < threshold and 'rolling_alpha_15' <= alpha_threshold (or NaN),
      2 -- `rol_stat` < threshold and 'rolling_alpha_15' > alpha_threshold.

    Parameters:
    big_df1 (DataFrame): Trajectory points with columns `rol_stat` and 'rolling_alpha_15'.
    rol_stat (str): Rolling-statistic column. Default 'rolling_mean1_11'.
    threshold (float): Upper bound on `rol_stat` for labels 1/2. Default 1.2.
    alpha_threshold (float): 'rolling_alpha_15' bound separating label 2 from 1. Default 1.2.
    minseq (int): Minimum segment length passed to fill_short_segments. Default 6.

    Returns:
    Labels smoothed by fill_short_segments(labels, minseq).
    """
    low_stat = big_df1[rol_stat] < threshold
    high_alpha = big_df1['rolling_alpha_15'] > alpha_threshold
    fov_class = low_stat.astype(int) + (low_stat & high_alpha).astype(int)

    # Smooth away runs shorter than `minseq`
    return fill_short_segments(fov_class, minseq)

In [None]:
def div_by_rollstat9(big_df1):
    """
    Label points of a multi_state experiment using divide_single_statistic and 'rolling_alpha_15'.

    Each point first gets 2 when 'rolling_alpha_15' > 1.35 and 1 otherwise (NaN -> 1).
    If divide_single_statistic produced any positive label, the two labelings are
    multiplied element-wise; otherwise the alpha-based labels are used as they are.

    Parameters:
    big_df1 (DataFrame): Trajectory points containing 'rolling_alpha_15' plus whatever
        columns divide_single_statistic reads.

    Returns:
    Labels smoothed by fill_short_segments(labels, 6).
    """
    base_labels = divide_single_statistic(big_df1)
    alpha_labels = 1 + (big_df1['rolling_alpha_15'] > 1.35).astype(int)

    combined = alpha_labels * base_labels if base_labels.max() > 0 else alpha_labels
    return fill_short_segments(combined, 6)

In [None]:
def check_conf(big_df1, rol_stat='rolling_median3_9', radius=9, step=5.0, visualization=False,
               n_extreme=1000, min_empty_circles=90):
    """
    Heuristically check whether a field of view shows confined motion.

    The `n_extreme` smallest values of `rol_stat` are labelled class 1 and the `n_extreme`
    largest class 2. A regular grid of circle centres (spacing `step`, inset by `radius`
    from the data bounding box) is laid over the points. For each class, the function
    counts the grid circles of radius `radius` that contain none of that class's points.
    If both counts exceed `min_empty_circles`, both extreme classes are spatially
    concentrated, and confinement is reported.

    Parameters:
    big_df1 (DataFrame): Trajectory points with columns 'x', 'y' and `rol_stat`.
        Not modified (a copy is used).
    rol_stat (str): Column name for the rolling statistic. Default is 'rolling_median3_9'.
    radius (float): Radius for circular neighborhood search. Default is 9.
    step (float): Step size for grid creation. Default is 5.0.
    visualization (bool): Unused; kept for call compatibility. Default is False.
    n_extreme (int): Number of lowest/highest values defining classes 1 and 2. Default 1000.
    min_empty_circles (int): Both empty-circle counts must exceed this. Default 90.

    Returns:
    bool: True if confined motion is detected, False otherwise (also when the statistic
    has no valid values or the grid is empty).
    """
    from scipy.spatial import cKDTree

    big_df1 = big_df1.copy()
    big_df1.loc[:, 'class'] = 0
    print("Points in fov = ", big_df1['class'].count())

    stat = big_df1[rol_stat]
    # nsmallest/nlargest skip NaN; with no valid values .iloc[-1] would raise.
    if stat.dropna().empty:
        return False

    threshold_low = stat.nsmallest(n_extreme).iloc[-1]
    threshold_high = stat.nlargest(n_extreme).iloc[-1]
    big_df1.loc[stat <= threshold_low, 'class'] = 1
    big_df1.loc[stat >= threshold_high, 'class'] = 2

    # Grid of candidate circle centres
    x_min, x_max = big_df1['x'].min() + radius, big_df1['x'].max() - radius
    y_min, y_max = big_df1['y'].min() + radius, big_df1['y'].max() - radius
    x_range = np.arange(x_min, x_max + step, step)
    y_range = np.arange(y_min, y_max + step, step)
    grid_points = np.array(np.meshgrid(x_range, y_range)).T.reshape(-1, 2)

    def count_empty_circles(label):
        """Number of grid circles containing no point of class `label`."""
        if len(grid_points) == 0:
            return 0
        points = big_df1.loc[big_df1['class'] == label, ['x', 'y']].to_numpy(dtype=float)
        if len(points) == 0:
            return len(grid_points)
        tree = cKDTree(points)
        return sum(1 for cx, cy in grid_points if not tree.query_ball_point([cx, cy], radius))

    # Number of empty circles == index of the first non-zero count in the ascending-sorted
    # counts, but it also works when every circle is empty (no StopIteration).
    empty_1 = count_empty_circles(1)
    empty_2 = count_empty_circles(2)
    print(empty_1, empty_2)
    return min(empty_1, empty_2) > min_empty_circles


In [None]:
def check_dim(big_df1):
    """
    Check dimerization.

    Bins the 'dist' column (width 0.5) and averages 'rolling_mean1_9' per bin. The
    threshold `y_tr` is the midpoint of the mean in the first bin and in bin index
    len(bins) - 2. `r_min` is the left edge of the first bin (after the first) whose
    mean crosses `y_tr`, in the direction from the first-bin value towards the far-bin
    value. The function then measures, after fill_short_segments(..., 3) smoothing, the
    fraction of points with dist < r_min, with 'rolling_median1_9' on the first-bin side
    of `y_tr`, and with both conditions at once.

    :param big_df1: DataFrame containing required columns ('dist', 'rolling_mean1_9', 'rolling_median1_9')
    :return: True if the joint fraction is > 0.03 and > 0.33 * the 'rolling_median1_9'
        fraction; False otherwise, including when there is no data or no crossing bin
        (previously those cases raised)
    """
    x = big_df1.dist.values
    y = big_df1.rolling_mean1_9.values
    n_points = big_df1['dist'].count()
    if len(x) == 0 or n_points == 0:
        print("check_dim: no 'dist' values")
        return False

    bin_width = 0.5
    bins = np.arange(start=x.min(), stop=x.max() + bin_width, step=bin_width)
    indmax = len(bins) - 2

    # Divide x values into bins and calculate mean y values for each bin
    indices = np.digitize(x, bins)  # Find bin indices for each x value
    y_0 = np.nanmean(y[indices == 1])  # Mean y value for the first bin
    y_m = np.nanmean(y[indices == indmax])  # Mean y value for bin len(bins) - 2
    y_tr = (y_0 + y_m) * 0.5  # Threshold y value (NaN if either bin is empty)

    y_means = np.array([np.nanmean(y[indices == i]) for i in range(1, len(bins))])
    y_means = np.nan_to_num(y_means, nan=0.0)  # Replace NaN with 0.0

    median_9 = big_df1['rolling_median1_9']
    if y_0 < y_m:
        crosses = y_means[1:] > y_tr
        y_mask = median_9 < y_tr
    else:
        # Reversed inequality signs
        crosses = y_means[1:] < y_tr
        y_mask = median_9 > y_tr

    # Search and existence test use the same slice; with no crossing (or a NaN
    # threshold) there is no r_min, and comparing 'dist' to None would raise.
    if not crosses.any():
        print("check_dim: no bin crosses y_tr; r_min undefined")
        return False
    r_min = bins[np.argmax(crosses) + 1]

    dist_mask = big_df1['dist'] < r_min

    def filled_fraction(mask):
        """Fraction of points flagged by `mask` after short-segment smoothing."""
        return fill_short_segments(mask + 0, 3).sum() / n_points

    dist_part = filled_fraction(dist_mask)
    print(f"dist% {dist_part}")
    y_part = filled_fraction(y_mask)
    print(f"div% {y_part}")
    dist_and_y_part = filled_fraction(dist_mask & y_mask)
    print(f"d+r% {dist_and_y_part}")

    # Return True if conditions are met, otherwise False
    if dist_and_y_part > 0.03 and dist_and_y_part > y_part * 0.33:
        return True
    return False

In [None]:
def check_trap(df, statistic_type='median', statistic_name='15', statistic_density=False, num_bins=30, min_exp=-6, max_exp=4, peak_height=0.05, visualize=False):
    """
    Decide whether an experiment looks like the immobile-traps model from histogram peaks.

    For each lag k in the order 1, 3, 4, 2, the column
    'rolling_' + statistic_type + str(k) + '_' + statistic_name is histogrammed on
    `num_bins` log-spaced edges exp(linspace(min_exp, max_exp, num_bins)). Peaks are found
    with scipy.signal.find_peaks, keeping those at least `peak_height` * the tallest bin.
    Every column must show exactly two peaks, and the left edge of the bin holding the
    first one must not exceed 0.05. The first failing column ends the check.

    :param df: DataFrame holding the rolling-statistic columns
    :param statistic_type: Statistic part of the column name ('median', 'mean', ...)
    :param statistic_name: Column-name suffix after the lag (e.g. the window size '15')
    :param statistic_density: Passed to np.histogram as `density` (default: False)
    :param num_bins: Number of log-spaced edges (default: 30)
    :param min_exp: Natural-log exponent of the first edge (default: -6)
    :param max_exp: Natural-log exponent of the last edge (default: 4)
    :param peak_height: Minimum peak height relative to the tallest bin (default: 0.05)
    :param visualize: Unused
    :return: True if every column passes, False otherwise
    """
    edges = np.exp(np.linspace(min_exp, max_exp, num_bins))

    def column_passes(lag):
        column = 'rolling_' + statistic_type + str(lag) + '_' + statistic_name
        counts, _ = np.histogram(df[column], bins=edges, density=statistic_density)
        peak_idx, _ = find_peaks(counts, height=max(counts) * peak_height)
        return len(peak_idx) == 2 and not (edges[peak_idx[0]] > 0.05)

    return all(column_passes(lag) for lag in (1, 3, 4, 2))

In [None]:
def add_min_distance_column(df, max_dist=8):
    """
    Return a copy of `df` with a 'dist' column: distance to the nearest other point in the same frame.

    Distances are Euclidean in ('x', 'y') among rows sharing the same 'frame' value. The
    point's own distance is replaced by `max_dist`, so 'dist' is capped at `max_dist`. It
    equals `max_dist` when the point is alone in its frame or every neighbour is farther.
    NaN distances (from NaN coordinates) are skipped. Rows whose 'frame' is NaN keep NaN.

    Parameters:
    df (DataFrame): Points with columns 'frame', 'x', 'y'. Not modified.
    max_dist (float): Self-distance placeholder and upper cap for 'dist'. Default 8.

    Returns:
    DataFrame: Copy of `df` with the 'dist' column added (or overwritten).
    """
    updated_df = df.copy()

    xs = df['x'].to_numpy(dtype=float)
    ys = df['y'].to_numpy(dtype=float)
    min_dist = np.full(len(df), np.nan)

    # .indices gives the positional row numbers of each frame, so results line up with df's rows
    for positions in df.groupby('frame', sort=False).indices.values():
        fx = xs[positions]
        fy = ys[positions]
        # pair[i, j] = distance from point i to point j within this frame
        pair = np.sqrt((fx[np.newaxis, :] - fx[:, np.newaxis]) ** 2 + (fy[np.newaxis, :] - fy[:, np.newaxis]) ** 2)
        np.fill_diagonal(pair, max_dist)
        # nanmin mirrors Series.min(skipna=True); the diagonal guarantees a non-NaN entry per row
        min_dist[positions] = np.nanmin(pair, axis=1)

    updated_df['dist'] = min_dist
    return updated_df

In [None]:
# Number of experiments, fields of view (FOVs) per experiment, and frames per FOV
N_EXP = 12
N_FOVS = 30
N_FRAMES = 200

# Challenge track processed here (only track 2 in this example)
track = 2

# Accumulator: one predicted model type is appended per experiment by the loop below
model_type_predictions = []

import os

# Output layout: <path_results>/track_<track>/exp_<exp>/
path_results = '/kaggle/working/res/'
path_track = path_results + f'track_{track}/'

# exist_ok=True replaces the racy exists()/makedirs() pair and is a no-op if the folder exists
os.makedirs(path_results, exist_ok=True)
os.makedirs(path_track, exist_ok=True)


# Loop through each experiment
for exp in range(N_EXP):
    # Create a directory for each experiment
    path_exp = path_track + f'exp_{exp}/'
    if not os.path.exists(path_exp):
        os.makedirs(path_exp)
    
    # Initialize an empty DataFrame to store all FOV data for this experiment
    big_df1 = pd.DataFrame()

    # Loop through each Field of View (FOV)
    for fov in range(N_FOVS):
        # Read the corresponding csv file from the public data
        df = pd.read_csv(public_data_path+f'track_2/exp_{exp}/trajs_fov_{fov}.csv')
        df.sort_values(by=['traj_idx', 'frame'], inplace=True)
        df['fov'] = fov

        # Add minimum distance column
        df = add_min_distance_column(df)

        # Calculate differences in x and y coordinates for various time lags
        df['x_diff'] = df.groupby('traj_idx')['x'].diff(-1).ffill().bfill()
        df['y_diff'] = df.groupby('traj_idx')['y'].diff(-1).ffill().bfill()

        df['x_diff2'] = df.groupby('traj_idx')['x'].diff(-2).shift(-1).ffill().bfill()
        df['y_diff2'] = df.groupby('traj_idx')['y'].diff(-2).shift(-1).ffill().bfill()

        df['x_diff3'] = df.groupby('traj_idx')['x'].diff(-3).shift(-1).ffill().bfill()
        df['y_diff3'] = df.groupby('traj_idx')['y'].diff(-3).shift(-1).ffill().bfill()

        df['x_diff4'] = df.groupby('traj_idx')['x'].diff(-4).shift(-2).ffill().bfill()
        df['y_diff4'] = df.groupby('traj_idx')['y'].diff(-4).shift(-2).ffill().bfill()
        
        # Calculate Mean Squared Displacements (MSD) for different time lags
        df['msd1'] = (df['x_diff'] ** 2 + df['y_diff'] ** 2) / 4
        df['msd2'] = (df['x_diff2'] ** 2 + df['y_diff2'] ** 2) / 4
        df['msd3'] = (df['x_diff3'] ** 2 + df['y_diff3'] ** 2) / 4
        df['msd4'] = (df['x_diff4'] ** 2 + df['y_diff4'] ** 2) / 4
        
        # Calculate rolling medians and means for different window sizes
        df['rolling_median1_7'] = df.groupby('traj_idx')['msd1'].rolling(window=7, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_median1_9'] = df.groupby('traj_idx')['msd1'].rolling(window=9, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean1_9'] = df.groupby('traj_idx')['msd1'].rolling(window=9, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median1_11'] = df.groupby('traj_idx')['msd1'].rolling(window=11, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean1_11'] = df.groupby('traj_idx')['msd1'].rolling(window=11, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median1_15'] = df.groupby('traj_idx')['msd1'].rolling(window=15, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean1_15'] = df.groupby('traj_idx')['msd1'].rolling(window=15, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_mean1'] = df.groupby('traj_idx')['msd1'].rolling(window=5, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median1'] = df.groupby('traj_idx')['msd1'].rolling(window=5, min_periods=1, center=True).median().reset_index(level=0, drop=True)

        # Repeat rolling calculations for msd2
        df['rolling_median2_7'] = df.groupby('traj_idx')['msd2'].rolling(window=7, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_median2_9'] = df.groupby('traj_idx')['msd2'].rolling(window=9, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean2_9'] = df.groupby('traj_idx')['msd2'].rolling(window=9, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median2_11'] = df.groupby('traj_idx')['msd2'].rolling(window=11, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean2_11'] = df.groupby('traj_idx')['msd2'].rolling(window=11, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median2_15'] = df.groupby('traj_idx')['msd2'].rolling(window=15, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean2_15'] = df.groupby('traj_idx')['msd2'].rolling(window=15, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_mean2'] = df.groupby('traj_idx')['msd2'].rolling(window=5, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median2'] = df.groupby('traj_idx')['msd2'].rolling(window=5, min_periods=1, center=True).median().reset_index(level=0, drop=True)

        # Repeat rolling calculations for msd3
        df['rolling_median3_7'] = df.groupby('traj_idx')['msd3'].rolling(window=7, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_median3_9'] = df.groupby('traj_idx')['msd3'].rolling(window=9, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean3_9'] = df.groupby('traj_idx')['msd3'].rolling(window=9, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median3_11'] = df.groupby('traj_idx')['msd3'].rolling(window=11, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean3_11'] = df.groupby('traj_idx')['msd3'].rolling(window=11, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median3_15'] = df.groupby('traj_idx')['msd3'].rolling(window=15, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean3_15'] = df.groupby('traj_idx')['msd3'].rolling(window=15, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_mean3'] = df.groupby('traj_idx')['msd3'].rolling(window=5, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median3'] = df.groupby('traj_idx')['msd3'].rolling(window=5, min_periods=1, center=True).median().reset_index(level=0, drop=True)

        # Repeat rolling calculations for msd4
        df['rolling_median4_7'] = df.groupby('traj_idx')['msd4'].rolling(window=7, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_median4_9'] = df.groupby('traj_idx')['msd4'].rolling(window=9, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean4_9'] = df.groupby('traj_idx')['msd4'].rolling(window=9, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median4_11'] = df.groupby('traj_idx')['msd4'].rolling(window=11, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean4_11'] = df.groupby('traj_idx')['msd4'].rolling(window=11, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median4_15'] = df.groupby('traj_idx')['msd4'].rolling(window=15, min_periods=1, center=True).median().reset_index(level=0, drop=True)
        df['rolling_mean4_15'] = df.groupby('traj_idx')['msd4'].rolling(window=15, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_mean4'] = df.groupby('traj_idx')['msd4'].rolling(window=5, min_periods=1, center=True).mean().reset_index(level=0, drop=True)
        df['rolling_median4'] = df.groupby('traj_idx')['msd4'].rolling(window=5, min_periods=1, center=True).median().reset_index(level=0, drop=True)

        # Calculate rolling alpha
        df['rolling_alpha'] = np.log(df['rolling_median4_15'] / df['rolling_median1_15']) / np.log(4)
        
        # Concatenate the current dataframe to the main dataframe
        big_df1 = pd.concat([big_df1, df], ignore_index=True)

    # Convert 'frame' and 'traj_idx' columns to int32
    big_df1.frame = big_df1.frame.astype(np.int32)
    big_df1.traj_idx = big_df1.traj_idx.astype(np.int32)

    # Initialize 'class' column
    big_df1['class'] = 0

    print(f"exp={exp}")
    
    # Apply Anderson statistic to msd1, msd2, msd3, and msd4 and gather results
    anderson_results1 = big_df1.groupby(['traj_idx', 'fov'])['msd1'].apply(anderson_statistic).dropna().values
    anderson_results2 = big_df1.groupby(['traj_idx', 'fov'])['msd2'].apply(anderson_statistic).dropna().values
    anderson_results3 = big_df1.groupby(['traj_idx', 'fov'])['msd3'].apply(anderson_statistic).dropna().values
    anderson_results4 = big_df1.groupby(['traj_idx', 'fov'])['msd4'].apply(anderson_statistic).dropna().values
    
    # Print Anderson statistic results
    print(f"msd1 anderson_results1 mean {anderson_results1.mean()}  std {anderson_results1.std()}   max {anderson_results1.max()}  q95 {np.quantile(anderson_results1, 0.95)}")
    print(f"msd2 anderson_results2 mean {anderson_results2.mean()}  std {anderson_results2.std()}   max {anderson_results2.max()}  q95 {np.quantile(anderson_results2, 0.95)}")
    print(f"msd3 anderson_results3 mean {anderson_results3.mean()}  std {anderson_results3.std()}   max {anderson_results3.max()}  q95 {np.quantile(anderson_results3, 0.95)}")
    print(f"msd4 anderson_results4 mean {anderson_results4.mean()}  std {anderson_results4.std()}   max {anderson_results4.max()}  q95 {np.quantile(anderson_results4, 0.95)}")

    # Predict the model type
    model_type_prediction = 'multi_state'  # Default model
    if check_conf(big_df1[big_df1.fov == 1]):
        model_type_prediction = 'confinement'
    if np.quantile(anderson_results1, 0.95) < 2.5 and np.quantile(anderson_results2, 0.95) < 2.5 and np.quantile(anderson_results3, 0.95) < 2.5:
        model_type_prediction = 'single_state'
    if check_trap(big_df1):
        model_type_prediction = 'immobile_traps'
    if check_dim(big_df1):
        model_type_prediction = 'dimerization'

    print(f"Model type predicted: {model_type_prediction}")
    model_type_predictions.append(model_type_prediction)

    
    # Re-initialize 'class' column
    big_df1['class'] = 0
    
    # Calculate rolling alpha for msd3
    big_df1['rolling_alpha_15'] = np.log(big_df1['rolling_median3_15'] / big_df1['rolling_median1_15']) / np.log(3)

                                                                      
    if model_type_prediction == 'dimerization':
        x = big_df1.dist.values
        y = big_df1.rolling_mean1_9.values

        bin_width = 0.5
        bins = np.arange(start=x.min(), stop=x.max() + bin_width, step=bin_width)
        indmax = len(bins) - 2
        
        
        indices = np.digitize(x, bins)  # Находит индексы корзин для каждого значения в x
        y_0 = np.nanmean(y[indices == 1])
        y_m = np.nanmean(y[indices == indmax])
        y_tr = (y_0 + y_m) * 0.5
        # Splitting x values into bins and calculating the average y values for each bin
        indices = np.digitize(x, bins)  # Finds bin indices for each value in x
        y_0 = np.nanmean(y[indices == 1])
        y_m = np.nanmean(y[indices == indmax])
        y_tr = (y_0 + y_m) * 0.5
        indices = np.digitize(x, bins)  # Finds bin indices for each value in x        
        y_means = [np.nanmean(y[indices == i]) for i in range(1, len(bins))]
        y_means = np.array(y_means)
        y_means = np.nan_to_num(y_means, nan=0.0)
        plt.bar(bins[:-1], y_means, width=bin_width, align='edge', color='skyblue')

        # Adding labels and title
        plt.xlabel('x values')
        plt.ylabel('Average y values')
        plt.title('Histogram of average y values for x bins')

        # Displaying the histogram
        plt.show()            

        bin_width = 0.2
        bins = np.arange(start=x.min(), stop=x.max() + bin_width, step=bin_width)

        # Splitting x values into bins and calculating the average y values for each bin
        indices = np.digitize(x, bins)  # Finds bin indices for each value in x
        y_means = [np.nanmean(y[indices == i]) for i in range(1, len(bins))]

        # Removing NaN values for bins with no x elements
        y_means = np.array(y_means)
        y_means = np.nan_to_num(y_means, nan=0.0)

        if y_0 < y_m:
            r_min = bins[np.argmax(y_means[1:] > y_tr) + 1] if np.any(np.array(y_means) > y_tr) else None
            new_col = (((x < r_min) + 0 + (y < y_tr) + 0) == 2) + 0
        else:
            r_min = bins[np.argmax(y_means[1:] < y_tr) + 1] if np.any(np.array(y_means) < y_tr) else None
            new_col = (((x < r_min) + 0 + (y > y_tr) + 0) == 2) + 0
        big_df1['class'] = fill_short_segments_np(1 - new_col, 3)

        plt.bar(bins[:-1], y_means, width=bin_width, align='edge', color='skyblue')
        print(f'r_min={r_min}, y_tr={y_tr}')
        # Adding labels and title
        plt.xlabel('x values')
        plt.ylabel('Average y values')
        plt.title('Histogram of average y values for x bins')

        # Displaying the histogram
        plt.show() 

    elif model_type_prediction == 'multi_state':
        for fov in range(N_FOVS):
            print(fov)
            if exp == 10:
                fov_class = div_by_rollstat(big_df1[big_df1['fov']==fov])
            else:
                fov_class = div_by_rollstat9(big_df1[big_df1['fov']==fov])
            big_df1.loc[big_df1['fov'] == fov, 'class'] = fov_class
    #         visualize_rolling_median(big_df1[big_df1['fov'] == fov])       

    elif model_type_prediction == 'immobile_traps':
        for fov in range(N_FOVS):
            print(fov)
            fov_class = divide_single_statistic_1only(big_df1[big_df1['fov'] == fov])
            big_df1.loc[big_df1['fov'] == fov, 'class'] = fov_class

    #         visualize_rolling_median(big_df1[big_df1['fov'] == fov]) 

    elif model_type_prediction == 'confinement' and exp == 5:
    
        for fov in range(N_FOVS):
            print(fov)
            fov_class = confined_circles5(big_df1[big_df1['fov']==fov], rol_stat='rolling_alpha', visualization = True, radius=7)   
            big_df1.loc[big_df1['fov']==fov, 'class'] = fov_class
    
    elif model_type_prediction == 'confinement':
        for fov in range(N_FOVS):
            fov_class = confined_circles(big_df1[big_df1['fov']==fov], rol_stat='rolling_median3_9')
            big_df1.loc[big_df1['fov'] == fov, 'class'] = fov_class

    #         visualize_rolling_median(big_df1[big_df1['fov'] == fov]) 

    print(f"big_df1['class'].unique() {big_df1['class'].unique()} ")
    plt.hist(big_df1['rolling_alpha'], bins=100, range=[-1, 3], density=False)
    plt.show()
    visualize_traj(big_df1[big_df1['fov'] == 0])
    visualize_rolling_median(big_df1[big_df1['fov'] == 0])

    big_df1.reset_index(drop=True, inplace=True)
    xy_norm = np.median([big_df1['x_diff'].abs().mean(), big_df1['y_diff'].abs().mean()])
    big_df1.x /= xy_norm
    big_df1.y /= xy_norm
    big_df1['x_diff'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['x'].diff(-1).ffill().bfill()
    big_df1['y_diff'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['y'].diff(-1).ffill().bfill()
    big_df1['x_diff2'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['x'].diff(-2).shift(-1).ffill().bfill()
    big_df1['y_diff2'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['y'].diff(-2).shift(-1).ffill().bfill()
    big_df1['x_diff3'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['x'].diff(-3).shift(-1).ffill().bfill()
    big_df1['y_diff3'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['y'].diff(-3).shift(-1).ffill().bfill()
    big_df1['x_diff4'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['x'].diff(-4).shift(-2).ffill().bfill()
    big_df1['y_diff4'] = big_df1.groupby(['traj_idx', 'fov', 'class'])['y'].diff(-4).shift(-2).ffill().bfill()

    big_df1['msd1'] = (big_df1['x_diff'] ** 2 + big_df1['y_diff'] ** 2) / 4
    big_df1['msd2'] = (big_df1['x_diff2'] ** 2 + big_df1['y_diff2'] ** 2) / 4
    big_df1['msd3'] = (big_df1['x_diff3'] ** 2 + big_df1['y_diff3'] ** 2) / 4
    big_df1['msd4'] = (big_df1['x_diff4'] ** 2 + big_df1['y_diff4'] ** 2) / 4

    msd1 = pd.pivot_table(big_df1, values='msd1', index=['fov', 'traj_idx', 'class'], columns='frame', dropna=False).reset_index()
    msd2 = pd.pivot_table(big_df1, values='msd2', index=['fov', 'traj_idx', 'class'], columns='frame', dropna=False).reset_index()
    msd3 = pd.pivot_table(big_df1, values='msd3', index=['fov', 'traj_idx', 'class'], columns='frame', dropna=False).reset_index()
    msd4 = pd.pivot_table(big_df1, values='msd4', index=['fov', 'traj_idx', 'class'], columns='frame', dropna=False).reset_index()
    
    
    #################### The KEY POINT of our METHOD ##############################################################
    
    # Calculate K as median for frames range and divide by log(2)
    msd1['K'] = msd1.loc[:, range(1, N_FRAMES)].median(axis=1) / np.log(2)

    # Calculate alpha using RowHurst function
    msd1['alpha'] = RowHurst(msd3.loc[:, range(1, N_FRAMES)].values, msd1.loc[:, range(1, N_FRAMES)].values, 3)

    ################## The End of key point of our method ############################################################

    
    msd1['cnt'] = msd1.loc[:, range(1, N_FRAMES)].count(axis=1)

    
    file_name = path_exp + 'ensemble_labels.txt'

    with open(file_name, 'a') as f:
        # Write model type and number of unique classes
        f.write(f"model: {model_type_prediction}; num_state: {len(big_df1['class'].unique())} \n")
        data = np.random.rand(5, len(big_df1['class'].unique()))

        for col in big_df1['class'].unique():
            print(f'col={col}')
            # Calculate Anderson-Darling statistics for each MSD
            anderson_results1 = big_df1[big_df1['class']==col].groupby(['traj_idx', 'fov'])['msd1'].apply(anderson_statistic).dropna().values
            anderson_results2 = big_df1[big_df1['class']==col].groupby(['traj_idx', 'fov'])['msd2'].apply(anderson_statistic).dropna().values
            anderson_results3 = big_df1[big_df1['class']==col].groupby(['traj_idx', 'fov'])['msd3'].apply(anderson_statistic).dropna().values
            anderson_results4 = big_df1[big_df1['class']==col].groupby(['traj_idx', 'fov'])['msd4'].apply(anderson_statistic).dropna().values

            # Print mean, standard deviation, max, and 95th percentile for each MSD's Anderson-Darling statistics
            print(f"msd1 anderson_results1 mean {anderson_results1.mean()}  std {anderson_results1.std()}   max {anderson_results1.max()}  q95 {np.quantile(anderson_results1, 0.95)}")
            print(f"msd2 anderson_results2 mean {anderson_results2.mean()}  std {anderson_results2.std()}   max {anderson_results2.max()}  q95 {np.quantile(anderson_results2, 0.95)}")
            print(f"msd3 anderson_results1 mean {anderson_results3.mean()}  std {anderson_results3.std()}   max {anderson_results3.max()}  q95 {np.quantile(anderson_results3, 0.95)}")
            print(f"msd4 anderson_results1 mean {anderson_results4.mean()}  std {anderson_results4.std()}   max {anderson_results4.max()}  q95 {np.quantile(anderson_results4, 0.95)}")                

            # Calculate the proportion of the current class
            N2 = big_df1[big_df1['class']==col]['class'].count()            
            N1 = big_df1['class'].count()         
            print(f'col = {col}, % = {100*N2/N1:2f}')   

            # Get dominant Gaussian parameters for K
            K2, K_var2 = get_dominant_gaussian_params(msd1[msd1['class']==col]['K'], 2, 'K')
            K = msd1[msd1['class']==col]['K'].median()
            K3 = big_df1[big_df1['class']==col]['msd1'].median() / np.log(2)
            K_var3 = big_df1[big_df1['class']==col]['rolling_median3_9'].var() / 3
            K_var = msd1[msd1['class']==col]['K'].var() / 2

            # Calculate alpha and its variance
            alpha3 = msd1[msd1['class']==col]['alpha'].median()
            alpha_var3 = np.nanmean((msd1[msd1['class']==col]['alpha']-alpha3) ** 2) / 3
            if alpha3 > 0.6 and alpha3 < 1.4:
                alpha = alpha3
                alpha_var = alpha_var3
            else:

                alpha_med2, alpha1_var2 = get_dominant_gaussian_params(msd1[msd1['class']==col]['alpha'], 2, 'alpha')

                alpha_med, alpha1_var = get_bounded_gaussian_params(msd1[msd1['class']==col]['alpha'])
                alpha_var = (alpha1_var ** 2) / 3
                alpha = (alpha_med + alpha_med2) / 2  



            print(f'alpha3= {alpha3:4f}   alpha_var3={alpha_var3:4f}   K3={K3:4f}  K_var3={K_var3:4f}') 
            print(f'alpha= {alpha:4f}    alpha_var={alpha_var:4f}    K={K:4f}   K_var={K_var:4f}') 

            alpha = np.clip(alpha,0.00,1.999)
            
            # Store calculated data

            data[2,col] = K * xy_norm**2
            data[3,col] = K_var * xy_norm**2
            data[0,col] = alpha
            data[1,col] = alpha_var
            data[4,col] = N2 / N1

            if col == 0 and model_type_prediction == 'immobile_traps':
                data[2,col] = 0
                data[3,col] = 0
                data[0,col] = 0
                data[1,col] = 0


        # Save the data in the corresponding ensemble file
        np.savetxt(f, data, delimiter = ';') 
            
            
    # Process data for each field of view (FOV)        
    for fov in range(N_FOVS):            
            
        dfx = msd1[msd1.fov == fov]
        
        df = big_df1[big_df1.fov == fov]
            
        submission_file = path_exp + f'fov_{fov}.txt'
        traj_idx = df.traj_idx.unique()
        
        with open(submission_file, 'a') as f:
            # Loop over each trajectory index
            for idx in traj_idx:
                
                # Get the lenght of the trajectory
                length_traj = df[df.traj_idx == idx].shape[0]
              
                traj_x = df[df.traj_idx == idx]['class'].values.astype(int)
                # Array to save results
                prediction_traj = [int(idx)]

                # Initialize variables
                current_class = traj_x[0]  # Initial class value
                start_index = 0  # Start index of the current segment

                # Iterate over all elements in traj_x
                for i in range(1, len(traj_x)):
                    # Check if the class has changed compared to the previous value
                    if traj_x[i] != current_class:
                        # Determine the length of the segment
                        segment_length = i - start_index 

                        
                        Kt = msd1[(msd1.traj_idx == idx)  & (msd1.fov == fov)  & (msd1['class'] == current_class)]['K'].max()  * xy_norm**2
                        Nt = msd1[(msd1.traj_idx == idx)  & (msd1.fov == fov)  & (msd1['class'] == current_class)]['cnt'].max()
                        if not Kt:
                            Kt = data[2,current_class]
                        elif current_class==0 and model_type_prediction == 'immobile_traps':
                            Kt = data[2,current_class]
                        else:
                            Kt = (Kt * Nt + (400-Nt) * data[2,current_class]) / 400
                            
                        alphat = msd1[(msd1.traj_idx == idx)  & (msd1.fov == fov)  & (msd1['class'] == current_class)]['alpha'].max()  
                        if not alphat:
                            alphat = data[0,current_class]

                        if current_class==0 and model_type_prediction == 'immobile_traps':
                            alphat = data[0,current_class]
                        else:
                            alphat = (alphat * Nt + (400-Nt) * data[0,current_class]) / 400
                        alphat = np.clip(alphat,0.00,1.99)
                        
                        prediction_traj += [Kt, alphat, get_freedom(current_class, alphat), i+1]

                        # Update variables for the new segment
                        current_class = traj_x[i]
                        start_index = i  # Update the start index of the new segment

                # Process the last segment
                Kt = msd1[(msd1.traj_idx == idx)  & (msd1.fov == fov)  & (msd1['class'] == current_class)]['K'].max()  * xy_norm**2
                Nt = msd1[(msd1.traj_idx == idx)  & (msd1.fov == fov)  & (msd1['class'] == current_class)]['cnt'].max()
                if not Kt:
                    Kt = data[2,current_class]
                elif current_class==0 and model_type_prediction == 'immobile_traps':
                    Kt = data[2,current_class]
                else:
                    Kt = (Kt * Nt + (400-Nt) * data[2,current_class]) / 400
                    
                alphat = msd1[(msd1.traj_idx == idx)  & (msd1.fov == fov)  & (msd1['class'] == current_class)]['alpha'].max() 
                if not alphat:
                    alphat = data[0,current_class]
                elif current_class==0 and model_type_prediction == 'immobile_traps':
                    alphat = data[0,current_class]
                else:
                    alphat = (alphat * Nt + (400-Nt) * data[0,current_class]) / 400
                alphat = np.clip(alphat,0.00,1.99)
                segment_length = len(traj_x) - start_index
                prediction_traj += [Kt, alphat, get_freedom(current_class, alphat), length_traj]


                # Format and write the results

                formatted_numbers = ','.join(map(str, prediction_traj))
                f.write(formatted_numbers + '\n')



In [None]:
# Report the predicted model type(s); `model_type_predictions` is defined in an
# earlier cell (presumably accumulated per experiment — TODO confirm).
print("Model type predicted:", model_type_predictions)

In [None]:
import shutil

# Pack everything under /kaggle/working/res into /kaggle/working/res.zip,
# then remove the uncompressed folder so only the archive is left behind.
shutil.make_archive(
    base_name="/kaggle/working/res",
    format='zip',
    root_dir="/kaggle/working/res",
)
shutil.rmtree('/kaggle/working/res')