In [None]:
import gc
import importlib
import importlib.util  # `import importlib` alone does not guarantee that importlib.util is loaded
import re
import warnings
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib import colors
from scipy.interpolate import UnivariateSpline, LSQUnivariateSpline
from scipy.ndimage import gaussian_filter1d

import utils as utl
import data_processing as dp
from data_processing import didx
import common_plot_parameters as cpprm

# Shorthand for MultiIndex slicing (pd.IndexSlice); NOTE(review): not referenced in the cells shown here
idx = pd.IndexSlice

In [None]:
# Import the raw data of two experiments:
#  - the dark-light transition experiment, which provides the untreated (no MV) curves
#  - the 1 mM methyl viologen (MV) OJIP experiment
def _experiment_subdirs(experiment_path, exclude=()):
    """Return the sorted measurement sub-directories of `experiment_path`.

    Directories whose name starts with "_" are skipped, as are directories whose
    name contains any of the substrings in `exclude`. Sorting makes the load
    order independent of the filesystem's directory listing order.
    """
    return sorted(
        p for p in Path(experiment_path).glob("*")
        if p.is_dir()
        and not p.name.startswith("_")
        and not any(token in p.name for token in exclude)
    )


EXPERIMENT_PATH_NOMV = "data/20250227_DarkLightTransition_25umol"
FILEPATHS_NOMV = _experiment_subdirs(EXPERIMENT_PATH_NOMV)

# EXPERIMENT_PATH points at the MV experiment from here on: the plot parameters,
# the light phases and the figure file names below are all derived from it.
EXPERIMENT_PATH = "data/20250314_OJIP_MV_1mM"
# Skip the 250 uM MV measurements; only the 1 mM MV data is used
FILEPATHS_MV = _experiment_subdirs(EXPERIMENT_PATH, exclude=("250uM",))

FILEPATHS = FILEPATHS_NOMV + FILEPATHS_MV

_ojip, levels = dp.load_data(FILEPATHS)

# Load the plot parameters of the MV experiment from its plot_parameters.py.
# NOTE: `plot_parameters` is redefined by hand in the "Paper figure" section.
spec = importlib.util.spec_from_file_location("plot_parameters", Path(EXPERIMENT_PATH) / "plot_parameters.py")
pprm = importlib.util.module_from_spec(spec)
spec.loader.exec_module(pprm)
plot_parameters = pprm.plot_parameters

In [None]:
# Relabel the treatments of the combined data set: the 1 mM MV measurements were
# recorded with treatment 0 in their own experiment, so they get treatment 1 here.
# Treatment 0 then means "no MV" (taken from the dark-light experiment).
_ojip_mv = _ojip.loc[:, didx("MV1mM", treatment=[0])]
# `.loc` keeps the unused MultiIndex levels of `_ojip`; drop them first so that
# level 3 (the treatment level, cf. the `treatment=` selection) holds exactly the
# value 0 before it is replaced by 1.
_ojip_mv.columns = _ojip_mv.columns.remove_unused_levels().set_levels([1], level=3)

# Every strain now has the same two treatment levels: 0 (no MV) and 1 (1 mM MV)
levels["treatments"] = pd.Series({key: [0, 1] for key in levels["treatments"].to_dict()})

# Combine the untreated dark-light data with the relabelled MV data
ojip = pd.concat([
    _ojip.loc[:, didx("DarkLight", treatment=[0])],
    _ojip_mv,
], axis=1)

# Free the memory held by the intermediate frames
del _ojip, _ojip_mv
gc.collect()

## Get the main points of the OJIP curves

In [None]:
# Extract the characteristic features of the raw OJIP curves; `ojip_features` is used
# below for the double normalization. `ojip_features_meansd` presumably summarises the
# features as mean/SD across replicates (TODO confirm; not used in the cells shown here).
ojip_features, ojip_features_meansd = dp.get_ojip_features(ojip)

## Double normalize the data

In [None]:
# Double-normalize the OJIP curves using the features extracted above.
# NOTE(review): presumably scales each curve so that F_O -> 0 and F_M -> 1; confirm in dp.normalize_ojip
ojip_norm = dp.normalize_ojip(ojip, ojip_features=ojip_features)

# Identify the OJIP points and plot the OJIP curves for each strain and condition

In [None]:
# Identify the characteristic OJIP points on the double-normalized curves.
# Warnings raised during point finding are silenced for this call only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    ojip_points_res = utl.determine_OJIP_points(ojip_norm, **dp.feature_finding_options)

ojip_points = ojip_points_res["points"]

In [None]:
# # Plot the raw ojip curves
# dp.plot_ojip(ojip=ojip, levels=levels, treatment_label=plot_parameters["treatment_label"])

In [None]:
# # Plot the normalized OJIP curves with the identified curves
# dp.plot_ojip(
#     ojip=ojip_norm,
#     levels=levels,
#     ojip_points=ojip_points,
#     treatment_label=plot_parameters["treatment_label"],
#     point_finding_span=(dp.feature_finding_options["FJ_time_min"], dp.feature_finding_options["FJ_time_max"])
# )

## Plot the identified VJ values

In [None]:
# # Get the identified VJ points
# VJ_identified = ojip_points.loc[:, ("inflection", "FJ_value")]

# # Plot
# dp.plot_VJ_per_replicate(VJ_identified, levels=levels, title="Identified VJs", treatment_label=plot_parameters["treatment_label"])

## Plot the VJ value at a common time

In [None]:
# Read V_J from the normalized curves at common times instead of at each curve's own
# identified J point. The points found above are passed in as well.
# NOTE(review): the exact meaning of VJ_timing / VJ_timing_range (e.g. one time per
# strain/condition) is defined in dp.get_common_time_VJ — confirm there.
VJ_timing, VJ_timing_range, VJ_values = dp.get_common_time_VJ(
    ojip_points=ojip_points,
    ojip_norm=ojip_norm,
    levels=levels
)

# # Plot
# dp.plot_VJ_per_replicate(VJ_values, levels=levels, title="Common time VJs", treatment_label=plot_parameters["treatment_label"])

## Paper figure

In [None]:
# Mean and standard deviation over the last level (presumably the replicate) of the
# normalized curves and of the V_J values. The results are stacked under a new outer
# "Measure" level with the keys "mean" and "sd".
_ojip_norm_groups = ojip_norm.T.groupby(ojip_norm.columns.names[:-1])
ojip_norm_meansd = pd.concat(
    {"mean": _ojip_norm_groups.mean(), "sd": _ojip_norm_groups.std()},
    names=["Measure"],
).T

_VJ_values_groups = VJ_values.T.groupby(VJ_values.index.names[:-1])
VJ_values_meansd = pd.concat(
    {"mean": _VJ_values_groups.mean(), "sd": _VJ_values_groups.std()},
    names=["Measure"],
).T

In [None]:
# Figure-specific plot parameters for the MV comparison. They replace the values
# loaded from the experiment's plot_parameters.py. `treatment_format` has no
# {treatment} placeholder, so every non-zero treatment is labelled "MV treated".
plot_parameters = dict(
    treatment_format='MV treated',
    treatment_0_label='Untreated',
    treatment_label='CO$_2$ concentration',
    treatments_select=[0, 1],
    treatment_center=0.5,
    strain_map=dict(
        Syn='$\\mathit{\\boldsymbol{Synechocystis}}$ dark-acclimated + 1mM MV',
        Chlo='$\\mathit{\\boldsymbol{Chlorella}}$ dark-acclimated + 1mM MV',
    ),
)

In [None]:
from utils import _plot_normalized_OJIP, log_tick_formatter

def _treatment_legend_label(treatment, treatment_format, treatment_0_label):
    """Return the legend text for one treatment level.

    Treatment 0 is labelled with `treatment_0_label` when one is given; every other
    case uses `treatment_format.format(treatment=treatment)`.
    """
    if treatment == 0 and treatment_0_label is not None:
        return treatment_0_label
    return treatment_format.format(treatment=treatment)


def _plot_treatment_bars(ax, means, sds, treatments, center, bar_width, color, hatches, label_for):
    """Draw one bar per treatment on `ax`, with the group of bars centred on x = `center`.

    Parameters
    ----------
    means, sds : Series-like indexed by treatment
        Bar heights and error-bar sizes. Treatments missing from `means` are skipped;
        a treatment missing from `sds` is drawn without an error bar.
    treatments : sequence
        Treatments in drawing order (left to right within the group).
    hatches : dict
        Treatment -> hatch pattern; treatments not in the dict are hatched "//".
    label_for : callable or None
        Treatment -> legend label. None draws the bars without legend entries.
    """
    n_treatments = len(treatments)
    for position, treatment in enumerate(treatments):
        if treatment not in means.index:
            continue
        # Symmetric offsets around `center`; for two treatments this is -/+ bar_width / 2
        offset = (position - (n_treatments - 1) / 2) * bar_width
        ax.bar(
            center + offset,
            means[treatment],
            yerr=sds.get(treatment),
            width=bar_width,
            color=color,
            edgecolor='k',
            capsize=3,
            label=label_for(treatment) if label_for is not None else None,
            hatch=hatches.get(treatment, "//"),
            zorder=3,
        )


def get_base_plot_MV(
    ojip_norm,
    ojip_norm_meansd,
    ojip_points,
    VJ_timing,
    VJ_values,
    VJ_values_meansd,
    levels,
    treatment_label="TREATMENT UNIT",
    treatment_format="{treatment}",
    treatment_0_label=None,
    treatments_select=None,
    treatment_center=None,
    plot_replicates = False,
    use_colorbar = False,
    mark_sampled = True,
    cmap = cm.coolwarm,
    plot_strains = ["Syn", "Chlo"],
    row_label_ys = [1, 0.49],
    right_column_y_label="V$_{J}$ (r.u.)",
    left_column_y_label = "Double normalized Fluorescence (r.u.)",
    right_column_mark_zero=False,
    strain_map = {
        "Syn": r'$\mathit{\boldsymbol{Synechocystis}}$',
        "Chlo": r'$\mathit{\boldsymbol{Chlorella}}$'
    },
    light_phases=None,
    broken_logx_firstvalue=False,
    right_column_legend_loc="lower right",
    variance_sleeve_alpha=0.6,
    point_x_selector= ("inflection", "FJ_time"),
    point_y_selector= ("inflection", "FJ_value"),
    point_label="Inflection point",
    return_subplot_arguments=False,
    ojip_ymin=None,
    ojip_ymax=None,
    feature_ymin=None,
    feature_ymax=None,
    bar_width=0.35,
):
    """Draw the 2x3 MV comparison figure: OJIP curves per condition plus a feature bar chart.

    Each row is one strain of `plot_strains`. The first two panels of a row show the
    OJIP curves (drawn by `_plot_normalized_OJIP`) of the conditions present for that
    strain, in the order "lowCO2", "highCO2". The last panel shows, per condition, one
    bar per treatment. Heights and error bars are taken from the "mean" and "sd"
    entries of `VJ_values_meansd`, conditions are on the x-axis, and treatments are
    distinguished by hatching.

    Parameters
    ----------
    ojip_norm, ojip_norm_meansd, ojip_points, VJ_timing
        Forwarded to `_plot_normalized_OJIP` for the curve panels.
    VJ_values
        Not used by the current implementation (kept for interface compatibility).
    VJ_values_meansd
        Indexable by "mean" and "sd". Each entry is selected with `didx(...)` and, after
        dropping its first three index levels, is looked up by treatment.
    levels
        Mapping whose "conditions" and "treatments" entries are indexable by strain.
    treatment_label, treatment_format, treatment_0_label, treatments_select,
    treatment_center, use_colorbar, cmap, variance_sleeve_alpha, point_x_selector,
    point_y_selector, point_label, plot_replicates
        Forwarded to `_plot_normalized_OJIP`. `treatment_format` and
        `treatment_0_label` also build the bar legend labels; `treatment_label` is the
        x-label of the bar panels. `treatments_select` defaults to all treatments.
    plot_strains, row_label_ys, strain_map
        Strains (one per row), the figure y-positions of the row labels, and their texts.
    right_column_y_label, left_column_y_label
        Y-labels of the bar panels and of the first two columns.
    broken_logx_firstvalue
        Drops the first treatment of every strain from the bars, applies the log tick
        formatter to the bar panels instead of their y-label, and shifts their panel letter.
    right_column_legend_loc
        Legend location of the bar panels.
    ojip_ymin, ojip_ymax, feature_ymin, feature_ymax
        Optional y-limits of the curve panels / bar panels.
    mark_sampled, right_column_mark_zero, light_phases, return_subplot_arguments
        Accepted for interface compatibility; currently ignored.
    bar_width
        Width of a single bar. The bars of one condition are centred on its tick.

    Returns
    -------
    fig, axes
        The matplotlib figure and its 2x3 array of axes.
    """
    fig, axes = plt.subplots(2, 3, sharey="col", figsize=(15, 10))

    # Identity colour "norm": treatment values are handed to `_plot_normalized_OJIP` unchanged
    colornorm = lambda x: x  # noqa: E731

    # Column / bar order of the conditions, their colours and their display names
    plot_conditions = ["lowCO2", "highCO2"]
    plot_conditions_colors = {
        "highCO2": np.array((0, 176, 80, 255)) / 255,
        "lowCO2": np.array((146, 208, 80, 255)) / 255,
    }
    conditions_map = {
        "lowCO2": "Air",
        "highCO2": "High CO$_{2}$",
    }
    # Untreated bars are plain, treated bars hatched (unknown treatments fall back to "//")
    plot_treatments_hatch = {0: None, 1: "//"}

    # If no selected treatments were given, use all
    if treatments_select is None:
        treatments_select = np.sort(np.unique(np.concatenate(levels["treatments"].values)))

    for i, strain in enumerate(plot_strains):
        conditions = [c for c in plot_conditions if c in levels["conditions"][strain]]
        treatments = levels["treatments"][strain]
        if broken_logx_firstvalue:
            treatments = treatments[1:]

        for j, condition in enumerate(conditions):
            # OJIP curves of this strain/condition
            fig, ax = _plot_normalized_OJIP(
                fig=fig,
                ax=axes[i, j],
                strain=strain,
                condition=condition,
                levels=levels,
                ojip_norm=ojip_norm,
                ojip_norm_meansd=ojip_norm_meansd,
                ojip_points=ojip_points,
                VJ_timing=VJ_timing,
                colornorm=colornorm,
                conditions_map=conditions_map,
                treatment_label=treatment_label,
                treatment_format=treatment_format,
                treatment_0_label=treatment_0_label,
                treatments_select=treatments_select,
                treatment_center=treatment_center,
                use_colorbar=use_colorbar,
                cmap=cmap,
                variance_sleeve_alpha=variance_sleeve_alpha,
                point_x_selector=point_x_selector,
                point_y_selector=point_y_selector,
                point_label=point_label,
                plot_replicates=plot_replicates,
            )
            if ojip_ymin is not None or ojip_ymax is not None:
                ax.set_ylim(ojip_ymin, ojip_ymax)

            # Feature bars of this condition go into the last column of the row
            ax = axes[i, -1]
            try:
                selector = didx(strain=strain, condition=condition, replicate=None, treatment=treatments)
                # Drop the first three index levels so the values can be looked up by treatment
                means = VJ_values_meansd["mean"].loc[selector].dropna().droplevel([0, 1, 2])
                sds = VJ_values_meansd["sd"].loc[selector].dropna().droplevel([0, 1, 2])
            except KeyError:
                ax.grid(which="both")
                continue

            _plot_treatment_bars(
                ax,
                means,
                sds,
                treatments,
                center=j,
                bar_width=bar_width,
                color=plot_conditions_colors[condition],
                hatches=plot_treatments_hatch,
                # Only the first condition adds legend entries, so each treatment is listed once
                label_for=(
                    (lambda tr: _treatment_legend_label(tr, treatment_format, treatment_0_label))
                    if j == 0 else None
                ),
            )

            # One tick per condition on the bar panel
            ax.set_xticks(np.arange(len(conditions)))
            ax.set_xticklabels([conditions_map[c] for c in conditions])
            ax.legend(loc=right_column_legend_loc)
            ax.grid(which="both")

    # Bar panels
    for ax in axes[:, -1]:
        ax.grid(True, zorder=0)
        if not broken_logx_firstvalue:
            ax.set_ylabel(right_column_y_label)
        else:
            ax.xaxis.set_major_formatter(mticker.FuncFormatter(log_tick_formatter))
        ax.set_xlabel(treatment_label)
        if feature_ymin is not None or feature_ymax is not None:
            ax.set_ylim(feature_ymin, feature_ymax)

    # Curve panels: logarithmic time axis
    for ax in axes[:, :-1].flatten():
        ax.set_xlabel("Time (ms)")
        ax.set_xscale("log")
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(log_tick_formatter))

    for ax in axes[:, :2].flatten():
        ax.set_ylabel(left_column_y_label)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.3)

    # Row labels (strain names)
    for y, strain in zip(row_label_ys, plot_strains):
        fig.text(0.045, y, strain_map[strain], weight='bold', size=20)

    # Panel letters A, B, C, ...; the bar panels' letter moves left with broken_logx_firstvalue
    for i, ax in enumerate(axes.flatten()):
        x = 0.07 if not (broken_logx_firstvalue and (i + 1) % 3 == 0) else -0.03
        ax.text(x, 0.9, chr(65 + i), transform=ax.transAxes, weight='bold', size=20)

    return fig, axes


In [None]:
# Paper figure: MV comparison of the double-normalized OJIP curves and V_J.
# The light phases are read when the experiment folder has a light_phases.csv.
# NOTE(review): get_base_plot_MV currently ignores `light_phases`.
LIGHTPHASES_PATH = Path(EXPERIMENT_PATH) / "light_phases.csv"
light_phases = pd.read_csv(LIGHTPHASES_PATH) if LIGHTPHASES_PATH.is_file() else None

fig, axes = get_base_plot_MV(
    ojip_norm,
    ojip_norm_meansd,
    ojip_points,
    VJ_timing,
    VJ_values,
    VJ_values_meansd,
    levels,
    plot_replicates = False,
    use_colorbar = False,
    mark_sampled = True,
    cmap = cm.coolwarm,
    light_phases=light_phases,
    **plot_parameters,
    **cpprm.common_plot_parameters_main,
)

# Name the files after the experiment folder, e.g. "20250314_OJIP_MV_1mM_comparison.<ext>".
# Path(...).name avoids re-using the f-string's own quote character inside the
# replacement field, which is a SyntaxError before Python 3.12.
Path("figures").mkdir(parents=True, exist_ok=True)
for ext in cpprm.plot_format:
    fig.savefig(Path("figures") / f"{Path(EXPERIMENT_PATH).name}_comparison.{ext}", bbox_inches="tight")

## SI: Raw OJIP and P-timing

In [None]:
# Identify the OJIP points on the raw (not normalized) curves, strain by strain.
# FP_time_min is removed from the shared defaults because it is set per strain below:
# the F_P search window of Synechocystis starts later (100) than that of Chlorella (40).
feature_finding_options = dp.feature_finding_options.copy()
feature_finding_options.pop("FP_time_min")

ojip_points_raw_res = {}
for strain in levels["strains"]:
    is_chlorella = strain == "Chlo"
    strain_options = dict(
        return_derivatives=True,
        return_fits=True,
        FP_time_min=40 if is_chlorella else 100,
        choose_method="closest",
        FJ_time_exp=2,
        FI_time_exp=30,
        FP_time_exp=100 if is_chlorella else 200,
    )
    # Silence warnings raised during point finding for this call only
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        ojip_points_raw_res[strain] = utl.determine_OJIP_points(
            ojip.loc[:, didx(strain=strain)],
            **strain_options,
            **feature_finding_options,
        )

ojip_points_raw = pd.concat([ojip_points_raw_res[strain]["points"] for strain in levels["strains"]], axis=0)

In [None]:
def _mean_sd(data, group_names):
    """Return the mean and SD of `data` grouped by `group_names`, stacked under "Measure".

    `data` is transposed before grouping. For a DataFrame this groups its columns;
    for a Series `.T` is a no-op and the index is grouped. The concatenated result
    has "mean" and "sd" as its outer "Measure" level and is transposed back.
    """
    grouped = data.T.groupby(group_names)
    return pd.concat({"mean": grouped.mean(), "sd": grouped.std()}, names=["Measure"]).T


# Average over the last level (presumably the replicate) of the raw curves and F_P features
ojip_meansd = _mean_sd(ojip, ojip.columns.names[:-1])

FP_values = ojip_points_raw[("grad2-min", "FP_value")]
FP_values_meansd = _mean_sd(FP_values, FP_values.index.names[:-1])

FP_timing = ojip_points_raw[("grad2-min", "FP_time")]
FP_timing_meansd = _mean_sd(FP_timing, FP_timing.index.names[:-1])

In [None]:
# SI figure: raw OJIP curves with the identified F_P, and the mean/SD F_P timing as bars.
# NOTE(review): get_base_plot_MV currently ignores `light_phases`.
LIGHTPHASES_PATH = Path(EXPERIMENT_PATH) / "light_phases.csv"
light_phases = pd.read_csv(LIGHTPHASES_PATH) if LIGHTPHASES_PATH.is_file() else None

fig, axes = get_base_plot_MV(
    ojip,
    ojip_meansd,
    ojip_points_raw,
    None,  # VJ_timing: not passed for the raw curves
    FP_values,
    FP_timing_meansd,  # bar heights / error bars of the right column
    levels,
    plot_replicates = False,
    use_colorbar = False,
    mark_sampled = True,
    cmap = cm.coolwarm,
    light_phases=light_phases,
    right_column_y_label=r"F$_{\mathrm{P}}$ timing (ms)",
    left_column_y_label = "Fluorescence (V)",
    right_column_mark_zero=True,
    point_x_selector=("grad2-min", "FP_time"),
    point_y_selector=("grad2-min", "FP_value"),
    point_label="Identified FP",
    **plot_parameters,
    **cpprm.common_plot_parameters_SI,
)

# Raw fluorescence panels start at zero
for ax in axes[:, :-1].flatten():
    ax.set_ylim(0)

# Path(...).name avoids the nested-quote f-string that is a SyntaxError before Python 3.12
Path("figures").mkdir(parents=True, exist_ok=True)
for ext in cpprm.plot_format:
    fig.savefig(Path("figures") / f"{Path(EXPERIMENT_PATH).name}_comparison_SI.{ext}", bbox_inches="tight")

In [None]:
# Diagnostic: the derivative used to locate F_P, plotted for all strains
plot_feature = "derivative2"

fp_derivatives = pd.concat(
    [ojip_points_raw_res[strain]["derivatives"]["FP"][plot_feature] for strain in levels["strains"]],
    axis=1,
)

fig, axes = dp.plot_ojip(
    ojip=fp_derivatives,
    levels=levels,
    treatment_label=plot_parameters["treatment_label"],
    # NOTE: the shaded span uses the default options, not the per-strain FP_time_min used above
    point_finding_span=(dp.feature_finding_options["FP_time_min"], dp.feature_finding_options["FP_time_max"]),
)

for ax in axes.flatten():
    ax.axhline(0, c="k", ls="--")
    # presumably marks FP_time_min=100 used for Synechocystis in the raw point finding
    ax.axvline(100, c="k", ls="--")

# Path(...).name avoids the nested-quote f-string that is a SyntaxError before Python 3.12
fig.suptitle(f"{Path(EXPERIMENT_PATH).name} - {plot_feature}", y=1.02);