In [None]:
import io
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output
from scipy.stats import linregress

## Quantification

This notebook provides functionality to perform quantification using external standards. It is tailored towards LA-ICP-TOFMS data but can be adapted to accommodate other data types.  
The notebook expects 3 CSV files: 1) single-cell CSV (e.g. obtained from the data_extraction notebook); 2) 'standards summary' CSV (e.g. exported ROIs from the *image_reconstruction* notebook); 3) 'amounts' CSV.

1. **Single-cell CSV file(s)**  
   Measured integrated isotope counts in single cells (e.g. output of the *data_extraction* notebook).

   **Example:**
   <table>
     <tr><th>cell_label</th><th>56Fe</th><th>63Cu</th><th>...</th></tr>
     <tr><td>1</td><td>12345</td><td>6789</td><td>...</td></tr>
     <tr><td>2</td><td>23456</td><td>7890</td><td>...</td></tr>
     <tr><td>...</td><td>...</td><td>...</td><td>...</td></tr>
   </table>

2. **Standards Summary CSV**  
   Contains summary values for each ROI of the standards measurement. Must contain headers including a `ROI` column and one column per isotope ending with `sum (counts_pixel)`. Each row represents a single ROI from the standards acquisition.

   **Example:**
   <table>
     <tr><th>ROI</th><th>56Fe sum (counts_pixel)</th><th>63Cu sum (counts_pixel)</th><th>...</th></tr>
     <tr><td>1</td><td>1000</td><td>2000</td><td>...</td></tr>
     <tr><td>...</td><td>...</td><td>...</td><td>...</td></tr>
     <tr><td>2</td><td>1500</td><td>2500</td><td>...</td></tr>
     <tr><td>...</td><td>...</td><td>...</td><td>...</td></tr>
   </table>
  
3. **Amounts CSV**  
   Contains known amounts for each element in the standards, used to calculate calibration slopes. Format: two columns without headers.  
   First column is the element symbol (or `general`), second column is a comma-separated list of numeric amount values.  
   If included, the `general` row will be applied to any isotope whose element has no entry of its own.

   **Example:**
   <table>
     <tr><td>Fe</td><td>0.1, 0.2 ,0.3, 0.4, 0.5, 0.6, 0.7</td></tr>
     <tr><td>Cu</td><td>0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35</td></tr>
     <tr><td>...</td><td>...</td></tr>
     <tr><td>general</td><td>0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7</td></tr>
   </table>
  
**Workflow**  
Upon uploading the files the notebook will:
1. Match each isotope in the standards file to its corresponding element amounts.
2. Combine all ROI values for each isotope and match them to the known amounts.
3. Perform a linear regression (amounts vs. measured counts) for each isotope.
4. Output regression statistics.
5. Construct a summary DataFrame containing calibration slopes for each element symbol.

In [None]:
# widgets
# One output area per upload button, used by the load callbacks for feedback.
singlecell_out, standard_out, amounts_out = (widgets.Output() for _ in range(3))


def _make_csv_upload(description, multiple):
    """Build a CSV-only upload button whose width adapts to its label."""
    return widgets.FileUpload(
        accept=".csv",
        multiple=multiple,
        description=description,
        layout=widgets.Layout(width="auto"),
    )


singlecell_upload = _make_csv_upload("Upload single-cell CSVs", multiple=True)
standard_upload = _make_csv_upload("Upload standards summary CSV", multiple=False)
amounts_upload = _make_csv_upload("Upload amounts CSV", multiple=False)

# Free-text field holding the directory used by the optional save cells.
save_path_widget = widgets.Text(
    value="",
    placeholder="Enter working directory/desired save path",
    description="Path:",
    style={"description_width": "initial"},
)


def load_singlecell(change):
    """Read every uploaded single-cell CSV into the global ``singlecell_dfs``.

    Registered as an observer on ``singlecell_upload``; the ``change`` payload
    is ignored and the widget's current ``value`` is read instead. The list is
    rebuilt from scratch on every call and a short summary is printed into
    ``singlecell_out``.
    """
    global singlecell_dfs
    with singlecell_out:
        clear_output(wait=True)
        singlecell_dfs = []
        # extend() consumes the generator one frame at a time, so frames that
        # were parsed before a failing file stay in the list.
        singlecell_dfs.extend(
            pd.read_csv(io.BytesIO(entry["content"]))
            for entry in singlecell_upload.value
        )
        n_loaded = len(singlecell_dfs)
        print(f"Loaded {n_loaded} single-cell file(s).")


# Re-run the loader every time the upload widget's `value` trait changes.
singlecell_upload.observe(load_singlecell, names="value")


# load standards
def load_standards(change):
    """Split the uploaded standards summary CSV into one DataFrame per ROI.

    Registered as an observer on ``standard_upload``; ``change`` is ignored and
    the widget's current ``value`` is read. The global ``standards_dict`` is
    reset on every call and then filled with ``"<file stem>_ROI_<roi>"`` keys,
    each mapping to that ROI's rows without the ``ROI`` column. The full table
    and the list of keys are shown in ``standard_out``.
    """
    global standards_dict
    with standard_out:
        clear_output(wait=True)
        standards_dict = {}
        if not standard_upload.value:
            return

        uploaded = standard_upload.value[0]
        table = pd.read_csv(io.BytesIO(uploaded["content"]))
        # File name without its last extension becomes the key prefix.
        stem = uploaded["name"].rsplit(".", 1)[0]

        # One entry per distinct ROI value, in order of first appearance.
        for roi_id in table["ROI"].unique():
            roi_rows = table[table["ROI"] == roi_id]
            roi_rows = roi_rows.drop(columns=["ROI"]).reset_index(drop=True)
            standards_dict[f"{stem}_ROI_{roi_id}"] = roi_rows

        display(table)
        print("Standards loaded from ROIs:")
        for roi_key in standards_dict:
            print(f" - {roi_key}:")


# Re-run the loader every time the upload widget's `value` trait changes.
standard_upload.observe(load_standards, names="value")


# load amounts
def load_amounts(change):
    """Parse the uploaded amounts CSV into the global ``element_x_values``.

    Registered as an observer on ``amounts_upload``; ``change`` is ignored and
    the widget's current ``value`` is read. The CSV has no header. The first
    column is the element symbol (or ``general``); every remaining non-empty
    cell of the row contributes amounts, and each cell may itself hold a
    comma-separated list. Element names are stripped of surrounding whitespace
    so they match the symbols extracted from the isotope columns.

    The result maps element symbol -> 1-D float ``numpy`` array. Rows whose
    amounts cannot be converted to numbers are reported and skipped instead of
    being silently truncated.
    """
    global element_x_values
    with amounts_out:
        clear_output(wait=True)
        element_x_values = {}
        if not amounts_upload.value:
            return

        upload = amounts_upload.value[0]
        df_amt = pd.read_csv(io.BytesIO(upload["content"]), header=None)
        display(df_amt.head())

        parsed = {}
        for row in df_amt.itertuples(index=False, name=None):
            element = str(row[0]).strip()
            tokens = []
            for cell in row[1:]:
                if pd.isna(cell):
                    continue
                # str() also covers cells pandas already parsed as numbers
                # (e.g. a row holding a single amount).
                tokens.extend(tok.strip() for tok in str(cell).split(","))
            tokens = [tok for tok in tokens if tok]
            try:
                parsed[element] = np.array([float(tok) for tok in tokens], dtype=float)
            except ValueError:
                print(f"Warning: could not parse amounts for '{element}'; row skipped.")
        element_x_values = parsed

        if "general" not in element_x_values:
            print(
                "Warning: 'general' element not found. Unmatched elements may cause errors."
            )
        print("Amounts parsed successfully.")


# Re-run the loader every time the upload widget's `value` trait changes.
amounts_upload.observe(load_amounts, names="value")

# Display widgets and outputs
# Each upload button sits directly above its own feedback area; the save-path
# field closes the panel.
panel_rows = []
for upload_button, feedback_area in (
    (singlecell_upload, singlecell_out),
    (standard_upload, standard_out),
    (amounts_upload, amounts_out),
):
    panel_rows.append(widgets.HBox([upload_button]))
    panel_rows.append(feedback_area)
panel_rows.append(widgets.HBox([save_path_widget]))

display(widgets.VBox(panel_rows))

In [None]:
# Perform regression using ROI vectors per isotope with truncation on mismatch
def fit_isotope_calibrations(standards_by_roi, amounts_by_element):
    """Fit one calibration line (known amount -> ROI counts) per isotope column.

    Parameters
    ----------
    standards_by_roi : dict[str, pandas.DataFrame]
        One frame per ROI. Columns ending in ``sum (counts_pixel)`` are treated
        as isotope columns; the first row of each frame supplies that ROI's
        value. ROIs are paired with amounts by position, in dict order.
    amounts_by_element : dict[str, sequence of float]
        Known amounts keyed by element symbol. The ``general`` entry, if
        present, is used for isotopes whose element has no entry of its own.

    Returns
    -------
    dict[str, dict]
        Keyed by isotope column name, each value holding ``symbol``, ``slope``,
        ``intercept``, ``r2``, ``p_value`` and ``std_err``. Isotopes that cannot
        be fitted (no amounts, fewer than two points, degenerate x values) are
        reported and left out instead of aborting the whole run.
    """
    suffix = "sum (counts_pixel)"

    # Determine all isotope columns from any ROI
    isotope_cols = set()
    for df in standards_by_roi.values():
        isotope_cols.update(col for col in df.columns if col.endswith(suffix))

    results = {}
    for col in sorted(isotope_cols):
        # Extract element symbol (mass number followed by the symbol, e.g. 63Cu)
        m = re.search(r"\d{1,3}([A-Z][a-z]?)", col)
        if not m:
            print(f"{col}: pattern mismatch; skipping.")
            continue
        sym = m.group(1)

        # Retrieve amounts array, falling back to the 'general' entry
        x_vals = amounts_by_element.get(sym)
        if x_vals is None:
            x_vals = amounts_by_element.get("general")
            if x_vals is None:
                print(f"{col}: no amounts for '{sym}' or 'general'; skipping.")
                continue
            print(f"{col}: no entry for '{sym}'; using 'general' amounts of length {len(x_vals)}.")
        else:
            print(f"{col}: matched to '{sym}', amounts length {len(x_vals)}.")

        # Collect ROI-specific values for this isotope (one summary value per ROI)
        y_vals = [
            df[col].iloc[0]
            for df in standards_by_roi.values()
            if col in df.columns and len(df) > 0
        ]
        print(f"Collected {len(y_vals)} ROI values for '{col}'.")

        # Handle length mismatch by truncation
        n_x, n_y = len(x_vals), len(y_vals)
        n = min(n_x, n_y)
        if n_x != n_y:
            print(f"Warning: length mismatch for '{col}' ({n_y} ROI vs {n_x} amounts), truncating to {n} points.")
        x = np.asarray(x_vals[:n], dtype=float)
        y = np.asarray(y_vals[:n], dtype=float)

        # A line needs at least two points; skip instead of crashing the cell.
        if n < 2:
            print(f"{col}: need at least 2 calibration points, got {n}; skipping.\n")
            continue

        # Run regression: amounts vs ROI-sums
        try:
            slope, intercept, r_value, p_value, std_err = linregress(x, y)
        except ValueError as err:
            # Raised e.g. when all amounts are identical.
            print(f"{col}: regression failed ({err}); skipping.\n")
            continue

        results[col] = {
            "symbol": sym,
            "slope": slope,
            "intercept": intercept,
            "r2": r_value**2,
            "p_value": p_value,
            "std_err": std_err,
        }
        print(f"Regression for '{col}': slope={slope:.4f}, intercept={intercept:.4f}, r2={r_value**2:.4f}\n")

    return results


def build_slope_table(results):
    """Return a one-row DataFrame (index ``slope``) with one column per element.

    When several isotope columns share an element symbol only one slope can be
    kept under that symbol; the later column wins, and a warning names both
    columns so the replacement is visible.
    """
    slopes_by_symbol = {}
    source_column = {}
    for col, stats in results.items():
        symbol = stats["symbol"]
        if symbol in slopes_by_symbol:
            print(
                f"Warning: multiple isotopes map to '{symbol}'; slope from '{col}' "
                f"replaces the one from '{source_column[symbol]}'."
            )
        slopes_by_symbol[symbol] = stats["slope"]
        source_column[symbol] = col
    return pd.DataFrame([slopes_by_symbol], index=["slope"])


# Look the inputs up defensively: they only exist after the upload callbacks ran.
loaded_standards = globals().get("standards_dict")
loaded_amounts = globals().get("element_x_values")

if not loaded_standards or not loaded_amounts:
    print("Please load both standards summary and amounts CSVs before running this cell.")
else:
    results = fit_isotope_calibrations(loaded_standards, loaded_amounts)

    # Construct DataFrame of slopes by element symbol
    slope_df = build_slope_table(results)
    slopes_dict = slope_df.iloc[0].to_dict()
    print("All regressions complete across ROIs.")
    print("\nQuantification factor DataFrame:")
    display(slope_df)



## Pair standards with image channels

The cell below provides widgets to match image channels to corresponding standard measurements.  
Upon uploading a standard file above, one first selects all required image channels from a list and subsequently matches them with the individual standards.  
Optionally, a channel-specific factor can be applied.  

In [None]:
# Fail early with a clear message instead of a NameError/IndexError further down.
if not globals().get("singlecell_dfs"):
    raise RuntimeError("Upload at least one single-cell CSV before running this cell.")
if "slope_df" not in globals():
    raise RuntimeError("Run the regression cell successfully before running this cell.")

# Selectable features are the columns of the first uploaded single-cell file.
quant_channels = widgets.SelectMultiple(
    options=list(singlecell_dfs[0].columns),
    description="Select features for quantification",
    disabled=False,
    rows=10,
    style={"description_width": "initial"},
    layout=widgets.Layout(width="auto", height="auto"),
)
quant_channels_selected = []

# Assigns groups of channels to standard isotopes
assignment_title = widgets.HTML("<b>Assign (groups of) cell features to a standard isotope</b>")
assignment_groups = {}
group_widgets = []
for std in list(slope_df.columns):
    std_widget = widgets.SelectMultiple(
        options=[],
        description=f"{std}:",
        rows=5,
        style={"description_width": "initial"},
        layout=widgets.Layout(width="auto"),
    )
    assignment_groups[std] = std_widget
    group_widgets.append(std_widget)
assignment_box = widgets.VBox(group_widgets)


def update_group_options():
    """Offer the currently selected channels in every per-standard selector.

    Replacing a selector's ``options`` clears its selection, so the channels
    that were already assigned and are still available are re-selected
    afterwards; otherwise every change of the channel list would wipe all
    existing assignments.
    """
    for widget in assignment_groups.values():
        still_assigned = tuple(
            channel for channel in widget.value if channel in quant_channels_selected
        )
        widget.options = list(quant_channels_selected)
        widget.value = still_assigned


def on_quant_channels_change(change):
    """Mirror the channel selection into ``quant_channels_selected`` and refresh the selectors."""
    if change["type"] == "change" and change["name"] == "value":
        quant_channels_selected.clear()
        quant_channels_selected.extend(change["new"])
        update_group_options()


quant_channels.observe(on_quant_channels_change, names="value")
display(quant_channels)
display(assignment_title, assignment_box)

### Quantify and plot channels of all datasets

In [None]:
def quantify_frames(frames, channel_to_standard, slope_table, mult_factors=None):
    """Convert assigned channels from counts to amounts on copies of ``frames``.

    Parameters
    ----------
    frames : list[pandas.DataFrame]
        Single-cell tables; they are copied, never modified.
    channel_to_standard : dict[str, str]
        Maps a channel (column name) to the ``slope_table`` column whose
        calibration slope should be applied.
    slope_table : pandas.DataFrame
        Its first row holds one calibration slope per standard.
    mult_factors : dict[str, float], optional
        Channel-specific multiplication factors; channels not listed use 1.0.

    Returns
    -------
    list[pandas.DataFrame]
        Copies in which every assigned channel equals
        ``counts / slope * factor``. A channel missing from a frame is reported
        and skipped for that frame.
    """
    mult_factors = mult_factors or {}
    quantified = [frame.copy() for frame in frames]
    for idx, frame in enumerate(quantified):
        for channel, standard in channel_to_standard.items():
            if channel not in frame.columns:
                print(f"Warning: channel '{channel}' not found in dataset {idx}; skipping.")
                continue
            div_factor = float(slope_table.iloc[0, slope_table.columns.get_loc(standard)])
            mult_factor = float(mult_factors.get(channel, 1.0))
            frame[channel] = frame[channel].astype(float) / div_factor * mult_factor
    return quantified


# Optional channel-specific multiplication factors (channel name -> factor).
# Edit this dict to rescale individual channels; unlisted channels use 1.0.
channel_mult_factors = {}

# Look the inputs up defensively: they only exist after the earlier cells ran.
loaded_frames = globals().get("singlecell_dfs")
loaded_groups = globals().get("assignment_groups")
loaded_slopes = globals().get("slope_df")

if not loaded_frames or not loaded_groups or loaded_slopes is None:
    print(
        "Please upload single-cell data, run the regression and assign channels "
        "before running this cell."
    )
else:
    # Build mapping (channel -> standard)
    assignment_mapping = {}
    for std, widget in loaded_groups.items():
        for channel in widget.value:
            if channel in assignment_mapping:
                print(
                    f"Warning: channel '{channel}' is assigned to both "
                    f"'{assignment_mapping[channel]}' and '{std}'; using '{std}'."
                )
            assignment_mapping[channel] = std

    # quantify each channel by applying corresponding standard slope
    quant_df = quantify_frames(
        loaded_frames, assignment_mapping, loaded_slopes, channel_mult_factors
    )

In [None]:
# plot quantified data
def ensure_2d_axs(axs):
    """Return ``axs`` as a 2-D object array shaped (channels, datasets).

    ``plt.subplots`` squeezes its result by default, handing back a bare Axes
    for a 1x1 grid and a 1-D array for a single row or column. Normalising the
    shape lets the plotting loop always index ``axs[i, j]``.
    """
    return np.asarray(axs, dtype=object).reshape(
        len(quant_channels_selected), len(quant_df)
    )


n_channels, n_datasets = len(quant_channels_selected), len(quant_df)

if n_channels == 0 or n_datasets == 0:
    # plt.subplots cannot build a grid with zero rows or columns.
    print("Nothing to plot: select at least one channel and load at least one single-cell file.")
else:
    fig, axs = plt.subplots(
        n_channels, n_datasets, figsize=(16, 9), sharey=False, squeeze=False
    )
    axs = ensure_2d_axs(axs)

    # Extracting single cell file names
    uploaded_filenames = [file_dict["name"] for file_dict in singlecell_upload.value]

    for i, c in enumerate(quant_channels_selected):
        for j, df in enumerate(quant_df):
            ax = sns.histplot(df, x=c, ax=axs[i, j])
            ax.set(xlabel=f"{c} content in fg/cell")

            if j == 0:
                ax.set(ylabel="Cell Count")

            # splitext drops the extension whatever its length.
            dataset_name = os.path.splitext(uploaded_filenames[j])[0]
            ax.set_title(f"Quantified {c} distribution of {dataset_name}")

    # Lay the grid out once all panels exist, then widen the gaps between them.
    fig.tight_layout()
    plt.subplots_adjust(hspace=1.0, wspace=0.2)
    plt.show()

### Uncomment the cell below to save individual histograms

In [None]:
# Writes one histogram PNG per (dataset, selected channel) into the directory
# entered in the "Path:" field; each figure is closed after saving.
# for idx, df in enumerate(quant_df):
#     for c in quant_channels.value:
#         ax = sns.displot(df, x=c)

#         file_name = singlecell_upload.value[idx]["name"]
#         file_base_name = os.path.splitext(file_name)[0]

#         plt.title(f"Quantified {c} distribution of {file_base_name}")
#         ax.set(xlabel=f"{c} content in fg/cell", ylabel="Cell Count")
#         ax.figure.tight_layout()

#         save_path = os.path.join(
#             save_path_widget.value, f"{file_base_name}_quantified_{c}_hist.png"
#         )
#         plt.savefig(fname=save_path, dpi=600)
#         plt.close()

### Uncomment the cell below to save quantified data as .csv files

In [None]:
# Writes one "<input file>_quantified.csv" per dataset into the directory
# entered in the "Path:" field; selected channels get a "_quantified" suffix.
# for idx, df in enumerate(quant_df):
#     df_copy = df.copy()

#     rename_dict = {c: c + "_quantified" for c in quant_channels.value}
#     df_copy.rename(columns=rename_dict, inplace=True)

#     file_name = singlecell_upload.value[idx]["name"]
#     file_base_name = os.path.splitext(file_name)[0]

#     save_path = os.path.join(save_path_widget.value, f"{file_base_name}_quantified.csv")
#     df_copy.to_csv(save_path, index=False)