# This notebook shows how to use a neural network to classify trajectories as CTRW, FBM, or Brownian motion

### Run the following cells to load all the data

In [None]:
# IPython magic commands
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# Standard library imports
import os
import sys

# Third-party numerical/scientific imports
import numpy as np
import scipy
from scipy.optimize import curve_fit

# Data analysis and tracking
import trackpy as tp
import yaml

# Image processing
from PIL import Image

# Plotting and visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation
from matplotlib.ticker import MultipleLocator
from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib.patches import Patch

# Interactive widgets
import ipywidgets as widgets

# TensorFlow/Keras
from tensorflow.keras.models import load_model

# Custom modules
from analysis_stm import MotionAnalyzer, ExpMetaData, DiffusionPlotter
from sxmreader import SXMReader
import frame_correct as fc
from monet.src.classification_net_training_input import create_ctrw_model
import keras_utils as ku

In [None]:
# Frame-drift settings; passed (wrapped in a list) to MotionAnalyzer as `frame_drift_par`.
frame_drift_par = dict(
    steps=2,
    adjust=[[], []],
    output_crop=250,
)

### Set up the parameters and file paths to use

In [None]:
# Per-dataset configuration. Entry i of each list below belongs to dataset i;
# to add another dataset, repeat the block further down with new values.
emds = []
frame_drift_pars = []
tears = []
params_ = []

# Experiment metadata for this dataset.
Vg = 60
FOLDER = "test_sets/15.5K"
voltages_temperatures = np.array([15.5])  # presumably temperatures in K (see FOLDER) — confirm
sets = [range(138, 156)]  # scan numbers 138..155
emd = ExpMetaData(sets, Vg, voltages_temperatures, FOLDER)

# Tracking parameters; dumped to params.yaml right before the analyzer is built.
params = dict(
    molecule_size=7,
    min_mass=2,
    max_mass=100,
    min_size=1,
    max_ecc=1,
    separation=1,
    search_range=50,
    adaptive_stop=1,
    diffusion_time=10,
    threshold=0.12,
)

# Register this dataset in every per-dataset list.
frame_drift_pars.append([frame_drift_par])
emds.append(emd)
tears.append([[[0, 64], [10, 61]]])  # assigned to MotionAnalyzer.tear_correct
params_.append(params)



In [None]:
# Load and analyze every configured dataset, one MotionAnalyzer per entry of `emds`.
# NOTE(review): params.yaml is rewritten right before each analyzer is created and is
# presumably read by MotionAnalyzer — confirm in analysis_stm.

ms = []

for i, emd in enumerate(emds):
    with open('params.yaml', 'w') as f:
        yaml.dump(params_[i], f, default_flow_style=False)
    m = MotionAnalyzer(
        emd.sets,
        emd.voltages_temperatures,
        emd.folder,
        frame_drift_par=frame_drift_pars[i],
        heater=True,
        drift_correction=False,
        rotation_check=True,
        correct='lines',
    )
    # Apply the manual tear correction only when one was configured.
    if tears[i] is not None:
        m.tear_correct = tears[i]
    m.analyze()
    ms.append(m)

# One DiffusionPlotter per analyzed dataset.
dps = [DiffusionPlotter(m) for m in ms]

### This cell animates the tracked trajectories on top of the image frames

In [None]:
# Animate the tracked trajectories of one dataset / temperature set on top of its frames.
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()
fig = plt.figure(figsize=(8, 6), dpi=120)
ax1 = plt.axes(xlim=(0, 256), ylim=(0, 256), frameon=False)
ax1.axis('off')
index = 0
dataset = ms[index]  # CHOOSE DATASET HERE

temp_set = 0
t = dataset.t3s_C[0][temp_set]  # DRIFT UNCORRECTED TO PROPERLY MATCH IMAGE
TRAJ_COLORS = ['g', 'r', 'c', 'm', 'y', 'k']
ax1.set_prop_cycle(color=TRAJ_COLORS)


def animate(i, ax=ax1, traj=t, frames=dataset.frames[temp_set]):
    """Draw trajectories up to frame `i` over image `frames[i]`.

    The axes, trajectory table and frames are bound as default arguments, and the
    explicit axes is used instead of pyplot's "current axes". Later cells reassign
    `t`, `dataset` and `ax1`, and re-displaying `line_ani` must not pick those up.
    """
    ax.cla()
    ax.set_title(i)
    tp.plot_traj(traj[traj['frame'] <= i], superimpose=frames[i], label=True, ax=ax,
                 plot_style={'alpha': 1})
    # cla() resets the color cycle, so restore it for the next frame.
    ax.set_prop_cycle(color=TRAJ_COLORS)


line_ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(dataset.frames[temp_set]))
line_ani


## Machine Learning Part

#### Run this cell to load the model (if the model file doesn't exist yet, it is trained and saved first)


In [None]:
temp = [14, 17]  # Temperature range the model is trained on
step = 18  # Number of steps in each model input
epochs = 50  # Number of training epochs
net_file = f'models/{step}_{temp[0]}_{temp[1]}_ep{epochs}_class_model.h5'  # Save-file name
print(net_file[:-3])
if os.path.exists(net_file):
    print('File already exists')
else:
    # Create the target folder first; saving into a missing directory would fail
    # only after the whole training run.
    os.makedirs(os.path.dirname(net_file) or '.', exist_ok=True)
    create_ctrw_model(step, net_file, temp_range=temp, epochs=epochs)
model = load_model(net_file)

### Run the next two cells to define helper functions

In [None]:
# Run this: helper that keeps only trajectories of a fixed length.

def trim_trajectories(t_, steps, crop_from='start'):
    """Keep particles tracked for at least `steps` frames and crop each to `steps` frames.

    Parameters
    ----------
    t_ : pandas.DataFrame
        Trajectory table with at least 'particle' and 'frame' columns. Not modified.
    steps : int
        Number of frames to keep per particle. Particles with fewer rows are dropped.
    crop_from : {'start', 'end'}
        'start' keeps frames first_frame .. first_frame + steps - 1 of each particle.
        'end' keeps frames last_frame - steps + 1 .. last_frame.

    Returns
    -------
    pandas.DataFrame
        The kept rows, in their original order.

    Raises
    ------
    ValueError
        If `crop_from` is neither 'start' nor 'end'.
    AssertionError
        If a particle does not end up with exactly `steps` rows, e.g. because its
        frame numbers have gaps inside the kept window.
    """
    if crop_from not in ('start', 'end'):
        raise ValueError("crop_from must be either 'start' or 'end'")

    t = t_.copy()
    particle_counts = t.groupby('particle')['frame'].count()
    good_particles = particle_counts[particle_counts >= steps].index
    t = t[t['particle'].isin(good_particles)]

    # Broadcast each particle's window bound onto its rows with Series.map.
    # This is the same frame-number window as before, without a row-wise apply.
    if crop_from == 'start':
        first_frames = t.groupby('particle')['frame'].min()
        last_allowed = t['particle'].map(first_frames + steps - 1)
        t = t[t['frame'] <= last_allowed]
    else:
        final_frames = t.groupby('particle')['frame'].max()
        first_allowed = t['particle'].map(final_frames - steps + 1)
        t = t[t['frame'] >= first_allowed]

    particle_counts_final = t.groupby('particle')['frame'].count()
    assert (particle_counts_final == steps).all(), "Not all trajectories have exactly 'steps' frames!"
    return t

In [None]:
# Check for a trackpy quirk: particles whose frame numbers skip values.
# There should be no such particles.
def check_frame_discontinuities(t):
    """Find particles whose sorted frame numbers contain gaps larger than 1.

    Returns a list of (particle, sorted_frames, gap_sizes) tuples, one per
    particle that has at least one gap.
    """
    flagged = []
    for pid, traj in t.groupby('particle'):
        sorted_frames = np.sort(traj['frame'].values)
        frame_steps = np.diff(sorted_frames)
        gaps = frame_steps[frame_steps > 1]
        if gaps.size > 0:
            flagged.append((pid, sorted_frames, gaps))
    return flagged
# Report frame gaps in the first drift-uncorrected trajectory table of each analyzer.
for m in ms:
    discontinuities = check_frame_discontinuities(m.t3s_C[0][0])
    flagged_ids = [entry[0] for entry in discontinuities]
    print(f"Particles with discontinuities: {flagged_ids}")


### Run the next cell to use the machine learning model

In [None]:
# Classify each trajectory, then draw a pie chart of the mean class probabilities
# (saved as PNG). The classified tables are stored per dataset in ms[i].t4s, in a
# new 'class' column plus one probability column per class.

# Model output index -> class name, plus the matching pie-chart colors (RGB).
diff_class = {0: 'fbm', 1: 'brownian', 2: 'ctrw'}
labels = [diff_class[k] for k in sorted(diff_class)]
colors = [np.array([1, 0, 0]), np.array([0, 0, 1]), np.array([0, 1, 0])]  # red, blue, green
# Dataset title = second path component of each analyzer's first SXM path.
titles = [m_.SXM_PATH[0][0].strip("/").split("/")[1] for m_ in ms]

# Classification function (applied to one particle's rows via groupby.apply)
def classify(df, steps=step, model=model):
    """Run the classifier on one trajectory and write the results into `df`.

    Adds a 'class' column (name from `diff_class`) and the columns 'fbm_prob',
    'brownian_prob' and 'ctrw_prob'; each holds one value for the whole trajectory.
    The defaults bind `step`/`model` as they were when this cell ran.
    """
    # NOTE(review): only x displacements are fed to the model — confirm y is
    # intentionally ignored.
    displacements = ku.generate_dx(df.x.values)
    probs, predicted = ku.classification_on_real(displacements, steps=steps, model=model)
    df['class'] = diff_class[predicted]
    df['fbm_prob'] = probs[0]
    df['brownian_prob'] = probs[1]
    df['ctrw_prob'] = probs[2]
    return df

# Prepare storage
# NOTE(review): results_start / results_end are never filled in this notebook;
# they are kept only so that any code still referencing them continues to work.
results_start = {'fbm': {}, 'brownian': {}, 'ctrw': {}}
results_end = {'fbm': {}, 'brownian': {}, 'ctrw': {}}

print(f"Processing model: {net_file}")
for i in range(len(ms)):
    t4s = []
    print(f"  Dataset {i}")
    t3_list = list(ms[i].t3s)
    for temp_index, t3 in enumerate(t3_list):
        # CHANGE 'end' TO 'start' to choose which side of each trajectory is kept.
        t4 = trim_trajectories(t3, steps=step, crop_from='end')
        t4 = t4.groupby('particle', group_keys=False).apply(classify, steps=step, model=model)
        t4s.append(t4)
        if len(t4) == 0:
            continue
        # Mean class probabilities: average per particle, then across particles.
        sizes = [
            t4.groupby('particle')[col].mean().mean()
            for col in ('fbm_prob', 'brownian_prob', 'ctrw_prob')
        ]
        # With several temperature sets, give each chart its own name so they do
        # not overwrite each other; a single set keeps the original file name.
        suffix = f"_{temp_index}" if len(t3_list) > 1 else ""
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.pie(sizes, labels=labels, colors=colors,
               autopct='%1.1f%%', startangle=140, textprops={'fontsize': 20})
        ax.set_title(f"{titles[i]}{suffix}", fontsize=25)
        ax.axis('equal')  # Equal aspect ratio keeps the pie round.
        fig.tight_layout()
        # Saves to the analysis folder; change this if you want another location.
        fig.savefig(f"{ms[i].ANALYSIS_FOLDER}/{titles[i]}{suffix}_ctrw.png")
    ms[i].t4s = t4s
plt.show()

### Run the next cell to create the movies

In [None]:
# Creates movies with trajectories colored by classification.
# Each trajectory gets its dominant class's color, faded toward white as the
# dominant probability approaches chance level.
line_anis = []
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()

# Look up the ffmpeg writer once, so that a missing ffmpeg fails before any work is done.
Writer = matplotlib.animation.writers['ffmpeg']

# Base RGB color per class-probability column. The key order (ctrw, fbm, brownian)
# also decides ties between equal probabilities.
color_dict = {
    'ctrw_prob': np.array([0, 1, 0]),
    'fbm_prob': np.array([1, 0, 0]),
    'brownian_prob': np.array([0, 0, 1])
}
CHANCE_PROB = 0.33  # dominant probability at (or below) which a trajectory is drawn white


def _trajectory_colors(t, color_dict=color_dict, chance=CHANCE_PROB):
    """Map each particle id to an RGB color from its first row's class probabilities."""
    color_map = {}
    for pid, row in t.groupby('particle').first().iterrows():
        probs = {key: row[key] for key in color_dict}
        dominant_type = max(probs, key=probs.get)
        intensity = max(0, (probs[dominant_type] - chance) / (1 - chance))
        color_map[pid] = (1 - intensity) * np.array([1, 1, 1]) + intensity * color_dict[dominant_type]
    return color_map


def _make_animate(ax, img_handle, frames, trajectories, color_map, title):
    """Build a frame callback bound to one dataset's objects.

    A factory is used because a plain closure inside the loop would see only the
    last iteration's variables. That would make re-displaying line_anis[k] draw
    the wrong dataset.
    """
    def animate(i):
        img_handle.set_data(frames[i])
        ax.set_title(f"{title} Frame {i}")
        while ax.lines:
            ax.lines[0].remove()
        for pid, traj in trajectories.items():
            traj_up_to_i = traj[traj['frame'] <= i]
            if not traj_up_to_i.empty:
                ax.plot(traj_up_to_i['x'], traj_up_to_i['y'], lw=1, color=color_map[pid])
    return animate


for index in range(len(ms)):
    dataset = ms[index]
    temp_index = 0
    # Use this dataset's own classification results (one table per temperature set).
    t = dataset.t4s[temp_index]
    if len(t) == 0:
        continue
    frames = dataset.frames[temp_index]

    color_map = _trajectory_colors(t)
    particle_trajectories = {pid: df for pid, df in t.groupby('particle')}

    fig, ax1 = plt.subplots(figsize=(8, 6), dpi=120)
    ax1.set_xlim(0, 256)
    ax1.set_ylim(0, 256)
    ax1.axis('off')

    legend_elements = [
        Patch(facecolor=color_dict['ctrw_prob'], edgecolor='black', label='CTRW'),
        Patch(facecolor=color_dict['fbm_prob'], edgecolor='black', label='FBM'),
        Patch(facecolor=color_dict['brownian_prob'], edgecolor='black', label='Brownian')
    ]
    ax1.legend(
        handles=legend_elements,
        loc='upper right',
        frameon=True,
        framealpha=0.8,
        edgecolor='black',
        fontsize=9
    )

    img_handle = ax1.imshow(frames[0], cmap='gray')

    # ---- Create and save animation ----
    animate = _make_animate(ax1, img_handle, frames, particle_trajectories, color_map, titles[index])
    line_ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(frames))
    # Saves to the analysis folder; change this if you want another location.
    line_ani.save(f"{dataset.ANALYSIS_FOLDER}/{titles[index]}_ctrw.mp4", writer='ffmpeg')
    line_anis.append(line_ani)
    

In [None]:
# Display the first movie created above inline (rendered with the "jshtml"
# animation setting configured earlier). Datasets with no classified
# trajectories were skipped, so index 0 is the first non-empty dataset.
line_anis[0]

In [None]:
import pandas as pd

# Count, per dataset and temperature set, how many particles fall into each
# dominant class. A particle counts only if its dominant probability exceeds
# `min_prob`. The input is each dataset's own results (ms[i].t4s).
prob_columns = ['ctrw_prob', 'fbm_prob', 'brownian_prob']
min_prob = 0.40

for index in range(len(ms)):
    t4_list = ms[index].t4s
    for temp_index, t4 in enumerate(t4_list):
        if len(t4) == 0:
            continue
        # classify() writes one probability per particle, so its first row suffices.
        particle_groups = t4.groupby('particle').first()

        # For each particle, find max probability and its type
        classifications = particle_groups[prob_columns].idxmax(axis=1)
        max_probs = particle_groups[prob_columns].max(axis=1)

        # Keep particles whose dominant probability is above the threshold
        filtered = classifications[max_probs > min_prob]

        # Count how many particles are in each classification
        counts = filtered.value_counts()
        print(titles[index] if len(t4_list) == 1 else f"{titles[index]} (temperature set {temp_index})")
        print(counts)


In [None]:
import pandas as pd

# Same count as above, with a stricter threshold: a particle counts only if its
# dominant class probability exceeds 0.50.
# NOTE(review): this cell used to read an undefined `t5s` (NameError on a fresh
# kernel). It now reads each dataset's classification results, ms[i].t4s —
# confirm this is the intended input.
prob_columns = ['ctrw_prob', 'fbm_prob', 'brownian_prob']
min_prob = 0.50

for index in range(len(ms)):
    t5_list = ms[index].t4s
    for temp_index, t5 in enumerate(t5_list):
        if len(t5) == 0:
            continue
        # classify() writes one probability per particle, so its first row suffices.
        particle_groups = t5.groupby('particle').first()

        # For each particle, find max probability and its type
        classifications = particle_groups[prob_columns].idxmax(axis=1)
        max_probs = particle_groups[prob_columns].max(axis=1)

        # Keep particles whose dominant probability is above 0.50
        filtered = classifications[max_probs > min_prob]

        # Count how many particles are in each classification
        counts = filtered.value_counts()
        print(titles[index] if len(t5_list) == 1 else f"{titles[index]} (temperature set {temp_index})")
        print(counts)
