## This notebook gives instructions on how to tune parameters and get the image data into a good state

### Run the imports to get the libraries

In [None]:
%matplotlib inline

In [None]:
%reload_ext autoreload
%autoreload 2
from analysis_stm import MotionAnalyzer,ExpMetaData, DiffusionPlotter
import numpy as np
from scipy.optimize import curve_fit
import trackpy as tp
from PIL import Image
import matplotlib.pyplot as plt
from sxmreader import SXMReader
import matplotlib as mpl
import scipy
from sxmreader import SXMReader
from PIL import Image
from matplotlib.ticker import MultipleLocator
from matplotlib.animation import FuncAnimation, PillowWriter
%matplotlib inline
import ipywidgets as widgets 
import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
import frame_correct as fc
import pims
import yaml
#import ffmpeg
import matplotlib as mpl

### The next cell loads all the parameters. After tuning them, I suggest editing this initial cell so you don't have to repeat the tuning.

In [None]:
## THIS IS TEST DATA, NOT NECESSARILY ACCURATE
import copy

# Default parameters for the scan-area (frame drift) correction from frame_correct.
frame_drift_par = {
    'steps': 2,  # Number of correction loops: 1 is fastest, 2 is more consistent, 3+ is rarely different from 2.
    'adjust': [[], []],  # Manual corrections [frame indices, (shiftx, shifty) pairs]; empty for default.
    'output_crop': 200,  # Output image size with grey padding: must be large enough to absorb the frame shifts.
}
emds = []  # ExpMetaData objects, one per dataset
frame_drift_pars = []  # frame_correct parameters: one list (one dict per temperature set) per dataset
tears = []  # visual tears: one list (one entry per temperature set) per dataset
params_ = []  # trackpy parameter dicts, one per dataset


Vg = 60  # Gate voltage: used for labeling and plotting
FOLDER = "test_sets/15.5K"  # Folder where the data is
voltages_temperatures = np.array([15.5, 15.6])  # Temperatures or source-drain voltages: can have multiple

# Ranges of image numbers (files must be named like Image_123.sxm),
# one range per entry of voltages_temperatures.
sets = [range(138, 145), range(145, 156)]
emd = ExpMetaData(sets, Vg, voltages_temperatures, FOLDER)  # object that stores this metadata

# trackpy parameters
params = {
    'molecule_size': 7,  # size of the circles considered to be a molecule
    'min_mass': 2,  # minimum total brightness inside each circle
    'max_mass': 100,  # filtering we added
    'min_size': 1,  # filtering we added
    'max_ecc': 1,  # filtering we added
    'separation': 1,  # for tp.locate
    'search_range': 50,  # for tp.link
    'adaptive_stop': 1,  # for tp.link
    'diffusion_time': 10,  # time scale of the data: time between frames
    'threshold': 0.12,  # lower bound of intensity that trackpy sees; also needed to ignore the grey padding
}

# Append everything to the per-dataset lists.
# Every temperature set gets its OWN copy of the drift parameters: the manual-correction
# cell assigns frame_drift_pars[dataset][temp_set]['adjust'], and a single shared dict
# would leak that adjustment into every temperature set and into frame_drift_par itself.
frame_drift_pars.append([copy.deepcopy(frame_drift_par) for _ in sets])
emds.append(emd)
tears.append([None] * len(sets))
# For the test data this should be:
# tears.append([[[0,64]],[[3,61]]])

params_.append(params)

### This cell organizes the data 

In [None]:
ms = []  # one MotionAnalyzer per entry of emds, in the same order

for idx, meta in enumerate(emds):
    # Write this dataset's trackpy parameters to params.yaml just before its
    # analyzer is built (presumably MotionAnalyzer reads them from there -- TODO confirm).
    with open('params.yaml', 'w') as param_file:
        yaml.dump(params_[idx], param_file, default_flow_style=False)
    analyzer = MotionAnalyzer(
        meta.sets,
        meta.voltages_temperatures,
        meta.folder,
        frame_drift_par=frame_drift_pars[idx],
        heater=True,
        drift_correction=True,
        rotation_check=True,
        correct='lines',
    )
    ms.append(analyzer)


### A very common error occurs when the data you entered does not actually exist or is incorrectly named.
The error message will contain something like `assert os.path.exists(self.filename)` followed by an `AssertionError`.
The following cell checks whether the files are present.

In [None]:
import os

# Gather every expected .sxm path (per analyzer, per temperature set) that is missing on disk.
# NOTE(review): this relies on `ms` having been built successfully by the cell above.
empty = [
    sxm_path
    for analyzer in ms
    for temp_paths in analyzer.SXM_PATH
    for sxm_path in temp_paths
    if not os.path.exists(sxm_path)
]
for sxm_path in empty:
    print(f'{sxm_path} does not exists')
if not empty:
    print("All files present")


### The next cell plots all the images currently loaded, sorted by dataset and temperature, so you can identify any visual tears.
In the test dataset provided, there are tears in the first temperature set on frame 0 at line 64, and in the second temperature set on frame 3 at line 61.

In [None]:
# Run this, and USE THIS TO FIND TEARS.
# Shows every frame of every temperature set (with the tears already recorded in
# `tears` repaired) so torn scan lines can be located by frame number and row.
for i, m in enumerate(ms):
    # Thin, empty figure used only as a visual header for the dataset.
    plt.figure(figsize=(13, 0.1))
    plt.title(f'dataset {i}')
    for temp_set, path in enumerate(m.SXM_PATH):
        # Thin header figure for this temperature set.
        plt.figure(figsize=(10, 0.1))
        plt.title(f'Temperature set {temp_set}')
        frames = list(SXMReader(path, correct='lines'))
        tear = tears[i][temp_set]
        if tear is not None:
            for frame_no, row in tear:
                frame_new = np.array(frames[frame_no])
                # Replace the torn scan line with the mean of its neighbouring rows.
                # On the first/last row only one neighbour exists; use it alone rather
                # than wrapping to the opposite edge (row -1) or indexing past the end.
                neighbours = [r for r in (row - 1, row + 1) if 0 <= r < frame_new.shape[0]]
                frame_new[row, :] = frame_new[neighbours, :].mean(axis=0)
                frame_new = (frame_new - frame_new.min()) / (frame_new.max() - frame_new.min())  # scale to [0, 1]
                frame_new = frame_new * 2 - 1  # rescale to [-1, 1]
                frames[frame_no] = pims.Frame(frame_new, frame_no=frame_no)
        for frame_number, frame in enumerate(frames):
            fig, ax = plt.subplots()
            ax.set_title(f'frame {frame_number}')
            # Minor ticks every pixel and major y-ticks every 5 rows make it easy
            # to read off the row index of a tear.
            ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(1))
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(5))
            ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(1))
            ax.imshow(frame)

plt.show()

### Fixing the Tears
Obviously, we don't want those tears in our data. We record where they are so they can be fixed.
Once you identify the tears, you can run this cell and then the previous cell to check that they were properly corrected.
I suggest going back to the beginning afterward and replacing `tears.append([None,None])`
with the correct `tears.append([[[0,64]],[[3,61]]])`.

In [None]:
# OPTIONAL: fix tears here.
index = 0  # The dataset of interest
# Format:
#   tears[index] = [tears for first temperature, tears for second temperature, ...]
# where each "tears for a temperature" is [[frame, row], [frame 2, row 2], ...]
# (or None when that temperature set has no tears).
tears[index] = [[[0, 64]], [[3, 61]]]
# Rerun the previous cell to check the result.


### Optimizing search parameters
This is the most important (and also most tedious) step. We optimize the search parameters so that only real particles are picked up.
Experiment with molecule_size and the kwargs (keyword arguments) to see what works. You want to find all real particles without picking up false ones. The parameters at the beginning are the ones I used and found satisfactory.

In [None]:
# IMPORTANT for testing
###### Optimize Parameters
index = 0  # Choose data set
temp_index = 0  # Choose temperature set

# Tuning knobs: once these look good, copy them into `params` in the first parameter cell.
molecule_size = 7
locate_kwargs = {
    "minmass": 2,  # minimum integrated brightness of a feature
    "separation": 1,  # minimum separation between features
    "threshold": 0.12,  # lowest intensity the program sees; everything is normalized to -1..1
    # threshold is the most useful knob for getting rid of particles caused by the grey padding
}
output_crop = frame_drift_pars[index][temp_index]['output_crop']

path = ms[index].SXM_PATH[temp_index]
frames = list(SXMReader(path, correct='lines'))
tear = tears[index][temp_index]
if tear is not None:
    for frame_no, row in tear:
        frame_new = np.array(frames[frame_no])
        # Replace the torn scan line with the mean of its neighbouring rows; on the
        # first/last row only the single existing neighbour is used.
        neighbours = [r for r in (row - 1, row + 1) if 0 <= r < frame_new.shape[0]]
        frame_new[row, :] = frame_new[neighbours, :].mean(axis=0)
        frame_new = (frame_new - frame_new.min()) / (frame_new.max() - frame_new.min())  # scale to [0, 1]
        frame_new = frame_new * 2 - 1  # rescale to [-1, 1]
        frames[frame_no] = pims.Frame(frame_new, frame_no=frame_no)

images = []
for frame in range(len(frames)):  # change len(frames) to a number to see a smaller set of images
    image = fc.add_crop(frames[frame], 0, 0, output_crop)
    f = tp.locate(image, molecule_size, **locate_kwargs, engine='python')
    images.append(image)
    # One figure per frame: without an explicit ax, every frame would be drawn
    # onto the same current axes.
    fig, ax = plt.subplots()
    tp.annotate(f, image, ax=ax)
    ax.set_title(frame)

### Once you have good parameters, GO BACK TO THE BEGINNING and change the parameters there.
Then run the following cell to update everything.

In [None]:
# IMPORTANT: reloads the analyzers with the current parameters.
# SAME AS THE EARLIER CELL, REPEATED HERE FOR CONVENIENCE.

ms = []
for dataset_no, meta in enumerate(emds):
    # Refresh params.yaml with this dataset's trackpy parameters before building its analyzer.
    with open('params.yaml', 'w') as param_file:
        yaml.dump(params_[dataset_no], param_file, default_flow_style=False)
    ms.append(MotionAnalyzer(meta.sets, meta.voltages_temperatures, meta.folder,
                             frame_drift_par=frame_drift_pars[dataset_no],
                             heater=True, drift_correction=True,
                             rotation_check=True, correct='lines'))


### Adding drift corrections
The next cell uses frame_correct to stabilize the images, and the cell after it plots them. Warning: this can be fairly slow.

In [None]:
# RUN IF YOU WANT TO CHECK FRAME CORRECTION BEFOREHAND AND ADD MANUAL CORRECTIONS
# Builds frames_all[dataset][temp_set]: the tear-repaired frames of each temperature
# set after frame_correct's stabilization, using the default `frame_drift_par` dict
# (not the per-set entries of frame_drift_pars).
frames_all = []
for i, m in enumerate(ms):
    frames_corrected = []
    for temp_set, path in enumerate(m.SXM_PATH):
        frames = list(SXMReader(path, correct='lines'))
        tear = tears[i][temp_set]
        if tear is not None:
            for frame_no, row in tear:
                frame_new = np.array(frames[frame_no])
                # Replace the torn scan line with the mean of its neighbouring rows; on the
                # first/last row only the single existing neighbour is used.
                neighbours = [r for r in (row - 1, row + 1) if 0 <= r < frame_new.shape[0]]
                frame_new[row, :] = frame_new[neighbours, :].mean(axis=0)
                frame_new = (frame_new - frame_new.min()) / (frame_new.max() - frame_new.min())  # scale to [0, 1]
                frame_new = frame_new * 2 - 1  # rescale to [-1, 1]
                frame_new = pims.Frame(frame_new, frame_no=frame_no)
                frames[frame_no] = frame_new
                # Show the repaired frame for a visual check.
                plt.figure()
                plt.imshow(frame_new)
                plt.show()

        frames_corrected.append(fc.Frame_correct_loop(frames, m.PARAMS[temp_set], **frame_drift_par))  # uses default parameters
    frames_all.append(frames_corrected)

In [None]:
## plot trajectories to check for errors
index = 0  # CHANGE THIS TO BE THE INDEX OF THE DATASET YOU'RE INTERESTED IN
temp_set = 1
frames_corrected = frames_all[index][temp_set]
molecule_size, min_mass, max_mass, separation, min_size, max_ecc, adaptive_stop, search_range, threshold, _ = ms[index].PARAMS[temp_set]

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()
fig = plt.figure(figsize=(8, 6), dpi=120)
ax1 = plt.axes(xlim=(0, 256), ylim=(0, 256), frameon=False)
plt.axis('off')
# Locate features in every drift-corrected frame, then link them into trajectories.
f = tp.batch(frames_corrected, molecule_size, minmass=min_mass, separation=separation, threshold=threshold, engine='python')
t = tp.link(f, search_range=search_range, adaptive_stop=adaptive_stop, memory=0)


def animate(i):
    """Draw frame i with all trajectory segments up to and including frame i."""
    # Use ax1 explicitly: the animation is rendered after plt.close() below, when
    # `fig` is no longer pyplot's current figure, so plt.cla()/plt.title() would
    # act on a different figure and ax1 would never be cleared.
    ax1.cla()
    # cla() resets the colour cycle, so the custom cycle must be set after clearing
    # and before plotting to take effect.
    ax1.set_prop_cycle(color=['g', 'r', 'c', 'm', 'y', 'k'])
    ax1.set_title(i)
    tp.plot_traj(t[t['frame'] <= i], superimpose=frames_corrected[i], label=True, ax=ax1, plot_style={'alpha': 1})


line_ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(frames_corrected))
plt.close()
line_ani


### Manual Correction
Sometimes frame_correct doesn't converge properly, or converges to the wrong result (e.g. with a periodic background, it might
be off by a lattice constant). The next cell lets you measure the offset by clicking a stationary point in each frame. (This might not work on a remote server because of how the popup works.) It opens a window where you click a stationary point in the left image and the same point in the right image, and it measures the shift between them.

In [None]:
# Running this cell should open up a window for every frame listed in shift_frames.
dataset = 0
temp_set = 1
shift_frames = []  # Frames after shifts: indices of the frames whose offset you want to measure
# One fc.measure_shift call per listed frame; results line up index-for-index with shift_frames.
measured_shifts = [fc.measure_shift(frames_all[dataset][temp_set], frame) for frame in shift_frames]


#### This cell confirms your selection

In [None]:
# RUN THIS TO CONFIRM MANUAL CORRECTION
# Store the shifts measured above for this dataset / temperature set.
# A fresh dict replaces the entry instead of assigning into the existing one, so the
# adjustment cannot leak into other temperature sets (or into the default
# frame_drift_par) if they share the same dict object. The lists are copied so later
# edits of shift_frames / measured_shifts do not change the stored adjustment.
frame_drift_pars[dataset][temp_set] = {
    **frame_drift_pars[dataset][temp_set],
    'adjust': [list(shift_frames), list(measured_shifts)],
}


### OPTIONAL:The next two cells add and animate your manual selection. 

In [None]:
# Re-run the drift correction, this time with each temperature set's own
# frame_drift_pars entry (including any manual 'adjust' shifts).
frames_manual = []
for i, m in enumerate(ms):
    frames_corrected = []
    for temp_set, path in enumerate(m.SXM_PATH):
        frames = list(SXMReader(path, correct='lines'))
        tear = tears[i][temp_set]
        if tear is not None:
            for frame_no, row in tear:
                frame_new = np.array(frames[frame_no])
                # Replace the torn scan line with the mean of its neighbouring rows; on the
                # first/last row only the single existing neighbour is used.
                neighbours = [r for r in (row - 1, row + 1) if 0 <= r < frame_new.shape[0]]
                frame_new[row, :] = frame_new[neighbours, :].mean(axis=0)
                frame_new = (frame_new - frame_new.min()) / (frame_new.max() - frame_new.min())  # scale to [0, 1]
                frame_new = frame_new * 2 - 1  # rescale to [-1, 1]
                frame_new = pims.Frame(frame_new, frame_no=frame_no)
                frames[frame_no] = frame_new
                # Show the repaired frame for a visual check.
                plt.figure()
                plt.imshow(frame_new)
                plt.show()

        # Index by the loop's dataset `i` (not a leftover global) so every dataset
        # uses its own drift parameters.
        frames_corrected.append(fc.Frame_correct_loop(frames, m.PARAMS[temp_set], **frame_drift_pars[i][temp_set]))
    frames_manual.append(frames_corrected)

In [None]:
## plot trajectories to check for errors
index = 0  # CHANGE THIS TO BE THE INDEX OF THE DATASET YOU'RE INTERESTED IN
temp_set = 1
frames_corrected = frames_manual[index][temp_set]
molecule_size, min_mass, max_mass, separation, min_size, max_ecc, adaptive_stop, search_range, threshold, _ = ms[index].PARAMS[temp_set]

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()
fig = plt.figure(figsize=(8, 6), dpi=120)
ax1 = plt.axes(xlim=(0, 256), ylim=(0, 256), frameon=False)
plt.axis('off')
# Locate features in every manually corrected frame, then link them into trajectories.
f = tp.batch(frames_corrected, molecule_size, minmass=min_mass, separation=separation, threshold=threshold, engine='python')
t = tp.link(f, search_range=search_range, adaptive_stop=adaptive_stop, memory=0)


def animate(i):
    """Draw frame i with all trajectory segments up to and including frame i."""
    # Use ax1 explicitly: the animation is rendered after plt.close() below, when
    # `fig` is no longer pyplot's current figure, so plt.cla()/plt.title() would
    # act on a different figure and ax1 would never be cleared.
    ax1.cla()
    # cla() resets the colour cycle, so the custom cycle must be set after clearing
    # and before plotting to take effect.
    ax1.set_prop_cycle(color=['g', 'r', 'c', 'm', 'y', 'k'])
    ax1.set_title(i)
    tp.plot_traj(t[t['frame'] <= i], superimpose=frames_corrected[i], label=True, ax=ax1, plot_style={'alpha': 1})


line_ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(frames_corrected))
plt.close()
line_ani


### Optional but recommended: Add info back to the beginning
At the beginning, we did 

frame_drift_pars.append([frame_drift_par,frame_drift_par])

which applies the default (no manual adjustment) to every temperature set.
If you want to record the manual adjustments so you don't have to redo them, use the following format at the beginning instead:

frame_drift_par_0={

                   'steps':2,
                   
                   'adjust':[[1],[(-1.3,-24)]],
                   
                   'output_crop':200, 
} 

frame_drift_par_1={

                   'steps':2,
                   
                   'adjust':[[shift_frame1,shiftframe2,shiftframe3,etc],[(shiftx,shifty),(shiftx,shifty),(shiftx,shifty),etc]],
                   
                   'output_crop':200, 
} 

frame_drift_pars.append([frame_drift_par_0,frame_drift_par_1])



## Now We look at the final product.
Run the next cell to put everything together with some additional filtering. 

In [None]:
# Rebuild every analyzer with the final parameters, apply the recorded tears,
# run the analysis, and create one DiffusionPlotter per dataset.
ms = []

for ds_idx, meta in enumerate(emds):
    with open('params.yaml', 'w') as param_file:
        yaml.dump(params_[ds_idx], param_file, default_flow_style=False)
    m = MotionAnalyzer(meta.sets, meta.voltages_temperatures, meta.folder,
                       frame_drift_par=frame_drift_pars[ds_idx], heater=True,
                       drift_correction=False, rotation_check=True, correct='lines')
    if tears[ds_idx] is not None:
        m.tear_correct = tears[ds_idx]
    m.analyze()
    ms.append(m)

dps = [DiffusionPlotter(analyzer) for analyzer in ms]

### This cell animates the final trajectories
The animations are saved to the analysis folder

In [None]:
## plot trajectories to check for errors
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()
fig = plt.figure(figsize=(8, 6), dpi=120)
ax1 = plt.axes(xlim=(0, 256), ylim=(0, 256), frameon=False)
plt.axis('off')
# Named `analyzer` (not `dataset`) so the integer `dataset` index used by other cells is not clobbered.
analyzer = ms[0]  ## CHOOSE DATASET HERE
temp_index = 0
t = analyzer.t3s_C[0][temp_index]  # DRIFT UNCORRECTED TO PROPERLY MATCH IMAGE


def animate(i):
    """Draw frame i with all trajectory segments up to and including frame i."""
    # Operate on ax1 explicitly rather than pyplot's "current" axes.
    ax1.cla()
    # cla() resets the colour cycle, so the custom cycle must be set after clearing
    # and before plotting to take effect.
    ax1.set_prop_cycle(color=['g', 'r', 'c', 'm', 'y', 'k'])
    ax1.set_title(i)
    tp.plot_traj(t[t['frame'] <= i], superimpose=analyzer.frames[temp_index][i], label=True, ax=ax1, plot_style={'alpha': 1})


# Set up formatting for the movie files (requires ffmpeg to be available).
Writer = matplotlib.animation.writers['ffmpeg']
writer = Writer(fps=5, metadata=dict(artist='Me'), bitrate=1800)
line_ani = matplotlib.animation.FuncAnimation(fig, animate, frames=len(analyzer.frames[temp_index]))
# NOTE(review): the path is "<ANALYSIS_FOLDER>_tracking.mp4", i.e. next to the folder unless
# ANALYSIS_FOLDER ends with a path separator, and it does not include temp_index, so saving
# another temperature set overwrites it -- confirm this is intended.
line_ani.save(f'{analyzer.ANALYSIS_FOLDER}_tracking.mp4', writer=writer)
plt.close(fig)
line_ani

#### The following cell is to manually remove any particles. 
Ideally you shouldn't have to.


In [None]:
# REMOVING PARTICLES
# Fill in the trackpy particle IDs to drop, one list per temperature set (you are not
# limited to two temperature sets), e.g. bad_particles_for_first_temp = [2000, 3000].
# The lists are empty by default: placeholder IDs could coincide with real particle
# IDs and would silently remove them when the notebook is run top to bottom.
index = 0

bad_particles_for_first_temp = []
bad_particles_for_second_temp = []
removed_particles = [bad_particles_for_first_temp, bad_particles_for_second_temp]

if any(removed_particles):
    ms[index].removed_particles = removed_particles
    ms[index].analyze(refilter=True)  # removes them and recalculates
    dps[index] = DiffusionPlotter(ms[index])  # updates the plotter


### The next cells show how the data is stored

In [None]:
# Choose which dataset / temperature set the following cells inspect.
dataset, temp_set = 0, 0
data = ms[dataset]  # all analyzed data lives in the list ms, one entry per dataset



In [None]:
# t3s holds all the trackpy output for each temperature set: if you want to make plots
# or do data analysis, this is where to look.
# The values are in PIXELS. Multiply by data.NM_PER_PIXEL to get physical values.
data.t3s[temp_set] 

In [None]:
# t3s_C holds both versions of the results: i=0 is raw and i=1 has the ensemble drift
# subtracted. Any attribute ending in _C is structured the same way.
# t3s is just the i=0 or i=1 entry, depending on what you chose when m was initialized.
i=0
data.t3s_C[i][temp_set] 

#### There are many values already calculated, stored as attributes.
They follow similar organization:

data.attribute[temp_set], 
and 

data.attribute_C[drift_corrected][temp_set].

Look at analysis_stm.py for a more detailed description. Below I have printed all the attribute names.

In [None]:

for key in m.__dict__.keys():
    print(key)

### DiffusionPlotter
DiffusionPlotter is a class that stores a bunch of plotting functions. We create instances of this class in the list named dps. 
More detail is in analysis_stm.py.

In [None]:
dataset = 0
temp_set = 0
plotter = dps[dataset]  # DiffusionPlotter of the chosen dataset (previously hard-coded to dps[0])

In [None]:
# Plot MSD vs. time step; scale is either 'linear' or 'log'.
# The curve is fitted with d*(x^a)+c: linearfit=True forces a=1, intercept=False forces c=0,
# and end takes an integer and cuts the data off after that point.
# Points at later times have larger errors, because fewer data points exist at that time scale.
plotter.plot_msd(
    end=None,
    intercept=False,
    linearfit=False,
    scale='linear',
)
plt.show()

#### All the available functions are printed below. Look at analysis_stm.py for how to use them

In [None]:
import inspect

# Public plain functions defined directly on DiffusionPlotter (inherited ones and
# names starting with an underscore are skipped).
methods = []
for name, obj in vars(DiffusionPlotter).items():
    if name.startswith('_') or not inspect.isfunction(obj):
        continue
    methods.append(name)

for method in methods:
    print(method)