In [None]:
%load fast_spt_analysis.py

### Test masking parameters for SPT trajectories

In [None]:
# Placeholder: set tracks_path before running; it is passed unchanged to
# bulk_test_masking_params.
tracks_path = "" # provide path to tracks
bulk_test_masking_params(tracks_path)

### Batch apply masking parameters to trajectories

In [None]:
# masking_file is handed unchanged to batch_apply_mask; "example.csv" is the
# sample value and should be replaced with the real masking description file.
masking_file = "example.csv" # csv file indicating how to perform masking for trajectories
batch_apply_mask(masking_file)

# fastSPT analysis using Spot-On

In [1]:
# Camera parameters; both values are reused by the format-conversion, Spot-On
# fitting and saSPT settings cells further down.
frame_interval = 0.00608576 # frame interval, in seconds
pixel_size = 0.11 # pixel size, in micrometres (um)

### Convert trajectory formats

In [None]:
quot_path = "" # path to masked quot output

output_path = "" # output path for new format

# Mirror every sub-directory of quot_path under output_path and convert its
# contents with batch_prepare_for_spoton, passing the camera pixel size and
# frame interval along.
# The listing is sorted so the processing order is the same on every run.
for quot_dir in sorted(os.listdir(quot_path)):
    in_path = os.path.join(quot_path, quot_dir)
    if not os.path.isdir(in_path):
        # Stray files (e.g. .DS_Store) are not trajectory folders: skip them
        # rather than creating an output directory for them.
        continue

    print(quot_dir)
    out_path = os.path.join(output_path, quot_dir)
    # exist_ok=True replaces the separate existence check and keeps the cell
    # safe to re-run.
    os.makedirs(out_path, exist_ok=True)

    batch_prepare_for_spoton(in_path, out_path, pixel_size, frame_interval)

### Convert trajectories to Anders format (see the Spot-On website)

In [None]:
# One list of converted datasets per condition name. The name is forwarded as
# the `contains` argument of convert_spoton_to_anders_format, and only non-None
# results are kept.
datasets = {"condition_1": [], "condition_2": []}

for spoton_dir in os.listdir(output_path):
    spoton_path = os.path.join(output_path, spoton_dir)
    for key, condition_datasets in datasets.items():
        print(spoton_dir, key)
        key_dataset = convert_spoton_to_anders_format(spoton_path, contains=key)
        if key_dataset is None:
            continue
        condition_datasets.append(key_dataset)

### Example analysis for a condition
#### Per cell analysis

In [None]:
# Fit the condition_1 dataset: one two-state fit per entry of
# datasets["condition_1"], keeping the three values returned by each fit in
# three parallel lists.
dZ = 0.7
two_state_fit_options = dict(cdf=True, use_entire_traj=True,
                             loc_error=None, fit_sigma=True, dZ=dZ)

h1s = []
fits = []
ys = []
for rep_dataset in datasets.get("condition_1"):
    # perform a two-state fit for each cell
    h1, fit, y = fit_spoton_2_state(rep_dataset, frame_interval, **two_state_fit_options)
    h1s.append(h1)
    fits.append(fit)
    ys.append(y)

In [None]:
# Draw one histogram figure per fitted dataset.
for h1, fit, y in zip(h1s, fits, ys):
    # The first four entries of h1 are the CDF bins, CDF values, histogram bins
    # and histogram values, in that order.
    HistVecJumpsCDF, JumpProbCDF, HistVecJumps, JumpProb = h1[0], h1[1], h1[2], h1[3]
    plt.figure(figsize=(10, 6))  # start a new figure for this dataset
    fastspt_plot_histogram(HistVecJumps, JumpProb, HistVecJumpsCDF, y)

#### Average analysis

In [None]:
# Fit the condition_1 dataset as a whole: concatenate all of its entries and run
# a single two-state fit on the pooled data.
# dZ is reused from the per-cell analysis cell above.
condition_1_datasets = datasets.get("condition_1")
avg_dataset = np.concatenate(condition_1_datasets)
avg_h1, avg_fit, avg_y = fit_spoton_2_state(
    avg_dataset,
    frame_interval,
    cdf=True,
    use_entire_traj=True,
    loc_error=None,
    fit_sigma=True,
    dZ=dZ,
)

In [None]:
# Plot the pooled condition_1 fit and save the figure as SVG.
# avg_h1 holds, in order: CDF bins, CDF values, histogram bins, histogram values.
HistVecJumpsCDF, JumpProbCDF, HistVecJumps, JumpProb = avg_h1[0], avg_h1[1], avg_h1[2], avg_h1[3]

plt.figure(figsize=(10, 6))  # start a new figure
fastspt_plot_histogram(HistVecJumps, JumpProb, HistVecJumpsCDF, avg_y)
plt.savefig("condition_1_average.svg", format="svg")

In [None]:
# Stacked view of the pooled fit: one panel per row of avg_h1[3], each showing
# the cumulative sum of that row as bars with the matching row of avg_h1[1]
# drawn on top as a line.
# (Presumably one row per time lag -- confirm against fit_spoton_2_state.)

# The panel count follows the data instead of being hard-coded, so it always
# agrees with the colour list and with the rows indexed in the loop.
n_panels = avg_h1[3].shape[0]

cmap = plt.get_cmap('viridis')
colors = [cmap(i) for i in np.linspace(0, 1, n_panels)]

# squeeze=False keeps the result 2-D even when there is a single panel; the
# only column is then taken so axs is always a 1-D array of Axes.
fig, axs = plt.subplots(n_panels, 1, sharex=True, squeeze=False)
axs = axs[:, 0]

# Width of the first bin, used as the width of every bar.
bar_width = np.diff(avg_h1[2])[0]

# Panels are filled bottom-up: row 0 goes into the lowest panel.
for i, ax in enumerate(axs[::-1]):
    ax.plot(avg_h1[0], avg_h1[1][i, :], 'k-', linewidth=1)
    ax.bar(avg_h1[2], np.cumsum(avg_h1[3][i, :]), align='edge', width=bar_width, color=colors[i])
    ax.set_ylim([0, 1])
    # Transparent background so the overlapping panels (negative hspace below)
    # do not hide each other.
    ax.patch.set_alpha(0)

sns.despine()
plt.subplots_adjust(hspace=-0.25)

# fastSPT analysis with saSPT
#### Read trajectories

In [None]:
quot_dataframes = [""] # list of directories to replicates

# Condition labels looked for in each file name, in priority order; a file whose
# name contains none of them is labelled "else".
known_conditions = ("condition_1", "condition_2")

# Three parallel lists, one entry per non-empty CSV file: its path, its
# condition label, and the 1-based index of the directory it came from.
directories = []
conditions = []
replicate = []
for i, input_dir in enumerate(quot_dataframes):
    # Sorted so that the row order of expt_conditions is the same on every run.
    for file in tqdm(sorted(os.listdir(input_dir))):
        if file.split('.')[-1] != "csv":
            continue

        filepath = os.path.join(input_dir, file)
        try:
            # A single row is enough to tell whether the file holds any data,
            # so the whole file is not parsed just for this check.
            has_data = not pd.read_csv(filepath, nrows=1).empty
        except pd.errors.EmptyDataError:
            # Zero-byte file (no header at all): treat it like an empty table
            # instead of aborting the whole scan.
            has_data = False
        if not has_data:
            continue

        directories.append(filepath)
        replicate.append(i + 1)
        conditions.append(next((name for name in known_conditions if name in file), "else"))

print(len(directories), len(conditions), len(replicate))
expt_conditions = pd.DataFrame({'filepath': directories, 'condition': conditions, 'replicate': replicate})

#### Perform state array analysis

In [None]:
# Keyword arguments for StateArrayDataset.from_kwargs. The pixel size and frame
# interval come from the camera-parameters cell; the two *_col entries name the
# columns of expt_conditions built in the previous cell.
# NOTE(review): num_workers is hard-coded to 64 -- consider matching it to the
# number of CPUs on the machine running the notebook.
settings = {
    "likelihood_type": RBME,
    "pixel_size_um": pixel_size,
    "frame_interval": frame_interval,
    "focal_depth": 0.7,
    "path_col": 'filepath',
    "condition_col": 'condition',
    "progress_bar": True,
    "num_workers": 64,
}
SAD = StateArrayDataset.from_kwargs(expt_conditions, **settings)

In [None]:
# Fetch the naive and posterior marginal occupations from the state array
# dataset and print the shape of each.
marginal_naive_occs, marginal_posterior_occs = SAD.marginal_naive_occs, SAD.marginal_posterior_occs
for occs in (marginal_naive_occs, marginal_posterior_occs):
    print(occs.shape)

In [None]:
# Work on a copy of the posterior-occupation table and keep only the rows whose
# condition is not the catch-all label "else".
modified_df = SAD.marginal_posterior_occs_dataframe.copy()
is_named_condition = modified_df["condition"].ne("else")
modified_df = modified_df[is_named_condition]

In [None]:
# Three stacked panels sharing the x axis: a line plot of posterior occupation
# against diffusion coefficient for all conditions, then one heat map per
# condition.

# Shared upper bound: y-limit of the line plot and vmax of both heat maps.
axis_lim = 0.075

fig, ax = plt.subplots(3, 1, figsize=(10, 10),
                       dpi=600, sharex=True, height_ratios=[3, 1, 1])

sns.lineplot(modified_df, x="diff_coef", y="posterior_occupation", hue="condition", ax=ax[0])
ax[0].set_xscale('log')
# ax[0].set_xlim([0.01, 100])
ax[0].set_ylim([0, axis_lim])

for heatmap_ax, condition_name in zip(ax[1:], ("condition_1", "condition_2")):
    condition_rows = modified_df[modified_df["condition"] == condition_name]
    # One row per diff_coef value; the posterior_occupation values sharing that
    # diff_coef are gathered into a list and spread across columns.
    converted_df = (
        condition_rows
        .set_index("diff_coef")
        .groupby(level=0)
        .agg(list)["posterior_occupation"]
        .apply(pd.Series)
    )
    # Reorder the columns, using every row (in index order) as a sort key.
    converted_df = converted_df.sort_values(by=list(converted_df.index), axis=1)

    x = converted_df.index
    y = np.arange(converted_df.shape[1])
    X, Y = np.meshgrid(x, y)

    heatmap_ax.pcolormesh(X, Y, converted_df.to_numpy().T, vmax=axis_lim)
    heatmap_ax.set_aspect('auto')

plt.savefig("conditions_saspt_plot.svg", format="svg")