# Photometry Analysis

The goal of this script is to:

1. import photometry data and crop
1. deinterleave data
1. normalize photometry data for photobleaching
1. chop up photometry data around keystrokes


## import and definition


In [None]:
import pandas as pd
import numpy as np

import plotly.express as px
from scipy.optimize import curve_fit
from sklearn.linear_model import HuberRegressor


# Path of the photometry CSV loaded with pd.read_csv in the "load data" cell.
IN_DATA = "./data/data.csv"
# Path of the keystroke CSV; it is read without a header row and its columns
# are named "Timestamp", "Key" and "Time" on load.
IN_TS = "./data/timestamp.csv"
# Names of the ROI columns in the photometry data that are plotted, corrected
# and event-aligned.
PARAM_ROIS = ["Region0G"]
# Only rows whose "FrameCounter" is strictly greater than this value are kept.
PARAM_NFM_DISCARD = 100
# Maps the numeric "LedState" column to the label stored in the "signal" column.
PARAM_LED_DICT = {7: "initial", 1: "415nm", 2: "470nm", 4: "560nm"}
# (start, end) offsets, in FrameCounter units, added to an event's frame to
# select the window pooled around that event (both ends inclusive).
PARAM_EVT_RANGE = (-200, 400)


def cut_df(df, nrow, sortby="Timestamp"):
    """Return the first ``nrow`` rows of ``df`` after sorting by ``sortby``.

    Args:
        df: DataFrame to truncate.
        nrow: Number of leading rows (in sorted order) to keep.
        sortby: Column name (or list of names) to sort by before truncating.

    Returns:
        A new DataFrame holding the ``nrow`` smallest rows by ``sortby``;
        the original index labels are preserved.
    """
    ordered = df.sort_values(sortby)
    return ordered.head(nrow)


def exp2(x, a, b, c, d, e):
    """Evaluate a biexponential with constant offset.

    Computes ``a * exp(b * x) + c * exp(d * x) + e``.

    Args:
        x: Scalar or array at which to evaluate the curve.
        a: Amplitude of the first exponential term.
        b: Rate of the first exponential term.
        c: Amplitude of the second exponential term.
        d: Rate of the second exponential term.
        e: Constant offset.

    Returns:
        The curve value(s), with the same shape as ``x``.
    """
    first_term = a * np.exp(b * x)
    second_term = c * np.exp(d * x)
    return first_term + second_term + e


def plot_signals(data):
    """Plot every ROI trace in ``PARAM_ROIS`` against time.

    The ROI columns are melted into long form (one row per timestamp, signal
    and ROI) so that each ROI gets its own facet row and each signal its own
    line color.

    Args:
        data: DataFrame with "Timestamp" and "signal" columns plus one column
            per entry of ``PARAM_ROIS``.

    Returns:
        The plotly figure produced by ``px.line``.
    """
    id_cols = ["Timestamp", "signal"]
    wide = data[id_cols + PARAM_ROIS]
    long_form = wide.melt(id_vars=id_cols, var_name="roi", value_name="raw")
    return px.line(
        long_form, x="Timestamp", y="raw", facet_row="roi", color="signal"
    )


## load data


Drop the initial frames, label each frame with its LED channel, and truncate every channel to the same number of frames.


In [None]:
# Load the photometry frames and the keystroke log (the latter has no header
# row, so its column names are supplied here).
data = pd.read_csv(IN_DATA)
ts = pd.read_csv(IN_TS, names=["Timestamp", "Key", "Time"])
# Drop the first frames of the recording and label each frame with its channel.
data = data[data["FrameCounter"] > PARAM_NFM_DISCARD].copy()
data["signal"] = data["LedState"].map(PARAM_LED_DICT)
# Truncate every channel to the length of the shortest one so that the
# channels can later be aligned sample-by-sample.
nfm = data.groupby("signal").size().min()
# Iterate over the groups instead of using GroupBy.apply: apply operating on
# the grouping column is deprecated in recent pandas, and the "signal" column
# must survive this step. Groups come out in sorted key order, each sorted by
# Timestamp, exactly as before.
data = pd.concat(
    [cut_df(grp, nrow=nfm) for _, grp in data.groupby("signal")],
    ignore_index=True,
)
channels = np.unique(data["signal"])


## visualize raw signal


In [None]:
# Show the raw ROI traces over time, one facet row per ROI, colored by signal.
plot_signals(data)


## correcting for photobleaching


order of operations:

1. deinterleave data by flag (LED) and save into a matrix
1. fit the isosbestic signal with a biexponential decay -- the shape of this decay is a good approximation of the CONCENTRATION of GCaMP molecules underneath your fiber.
   it decreases as the molecules photobleach.
   the amplitude, however, is tiny.
   to adjust for this, we:
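1. linearly scale the fitted decay to the 470 data using a robust fit.
1. divide the raw 470 data by this scaled fit to get a corrected signal
   (note: the code below currently subtracts the scaled fit from each channel rather than dividing by it)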

note: this isn't dF/F but it is INTERNALLY reliable -- that is, you can compare the beginning of the recording to the end of the recording.
dF/F requires a good approximation of baseline.
you can use the `FP.no_led` as an underestimation of this -- or determine it empirically.
it often isn't critical to a sound analysis.


In [None]:
# Isosbestic (415nm) frames, fitted against a normalized time axis in [0, 1].
dat_415 = data[data["signal"] == "415nm"].copy()
x = np.linspace(0, 1, len(dat_415))
# Copy of the 415nm frames whose ROI columns are replaced by the fitted decay.
dat_fit = dat_415.copy()
dat_fit["signal"] = "415nm-fit"
# One frame per remaining channel, relabelled "<channel>-norm". The relabel
# happens once here rather than inside the ROI loop, so the suffix is not
# appended again for every additional ROI. Channels are sorted so the row
# order of the result is deterministic.
sig_df_ls = []
for sig in sorted(set(channels) - {"415nm"}):
    sig_df = data[data["signal"] == sig].copy()
    sig_df["signal"] = sig + "-norm"
    sig_df_ls.append(sig_df)
for roi in PARAM_ROIS:
    # Biexponential fit of the isosbestic trace for this ROI.
    popt, pcov = curve_fit(
        exp2, x, dat_415[roi], p0=(1.0, -1.0, 1.0, -1.0, dat_415[roi].mean())
    )
    fit_415 = exp2(x, *popt)
    dat_fit[roi] = fit_415
    # Single-feature design matrix shared by every channel's regression.
    fit_col = fit_415.reshape((-1, 1))
    for sig_df in sig_df_ls:
        # Robustly scale the fitted decay to this channel and remove it,
        # leaving the residual as the corrected signal.
        model = HuberRegressor()
        model.fit(fit_col, sig_df[roi])
        sig_df[roi] = sig_df[roi] - model.predict(fit_col)
# Raw channels, the fitted decay and the corrected channels in one long frame.
data_norm = pd.concat([data, dat_fit] + sig_df_ls, ignore_index=True)


## visualize correction result


In [None]:
# Show raw, fitted and corrected traces together, one facet row per ROI.
plot_signals(data_norm)


## pool signals around `'Key'` events


In [None]:
# Label each keystroke with its key and an id built from timestamp and key.
ts["event"] = ts["Key"].astype(str)
ts["evt_id"] = ts["Timestamp"].astype(str) + "-" + ts["event"]
# Attach keystroke labels to the frames that share the same Timestamp value;
# frames without a matching keystroke get nulls in the event columns.
data_join = data_norm.merge(ts, on="Timestamp", how="left")
max_fm = data_join["FrameCounter"].max()
evt_lo, evt_hi = PARAM_EVT_RANGE
evt_ls = []
for _, dat_sig in data_join.groupby("signal"):
    evt_rows = dat_sig[dat_sig["evt_id"].notnull()]
    for _, row in evt_rows.iterrows():
        fm = row["FrameCounter"]
        # Window around the event frame, clamped to the valid frame range.
        fm_start = np.clip(fm + evt_lo, 0, max_fm)
        fm_end = np.clip(fm + evt_hi, 0, max_fm)
        in_window = dat_sig["FrameCounter"].between(fm_start, fm_end)
        dat_sub = dat_sig[in_window].copy()
        # Frame index relative to the event, plus the event's labels on
        # every row of the window.
        dat_sub["evt_fm"] = dat_sub["FrameCounter"] - fm
        dat_sub["event"] = row["event"]
        dat_sub["evt_id"] = row["evt_id"]
        # Z-score each ROI against its pre-event baseline (evt_fm < 0); a
        # baseline without positive spread yields an all-zero trace.
        is_baseline = dat_sub["evt_fm"] < 0
        for roi in PARAM_ROIS:
            baseline = dat_sub.loc[is_baseline, roi]
            mean, std = baseline.mean(), baseline.std()
            dat_sub[roi] = (dat_sub[roi] - mean) / std if std > 0 else 0
        evt_ls.append(dat_sub)
evt_df = pd.concat(evt_ls, ignore_index=True)


## visualize signals around events


In [None]:
# Plot the event-aligned traces: one line per event occurrence, one facet row
# per key and one facet column per signal. The ROI is taken from PARAM_ROIS
# rather than hard-coded, so the plot follows the column that the event
# pooling step actually normalized.
px.line(
    evt_df,
    x="evt_fm",
    y=PARAM_ROIS[0],
    color="evt_id",
    facet_row="event",
    facet_col="signal",
)
