In [None]:
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Dataset root on the analysis machine; the exported statistics live in a sub-folder of it.
base_dir = Path("/Users/sylvi/topo_data/shelterin")
assert base_dir.exists()
output_data_dir = base_dir / "output-data-redo"


def _read_required_csv(csv_path: Path) -> pd.DataFrame:
    """Return the CSV at ``csv_path`` as a DataFrame, asserting first that the file exists."""
    assert csv_path.exists()
    return pd.read_csv(csv_path)


# Main statistics table (its columns are printed further down this cell).
all_stats_file = output_data_dir / "all_statistics.csv"
all_stats_df = _read_required_csv(all_stats_file)

# Disordered-segment statistics table.
all_disordered_segments_file = output_data_dir / "all_disordered_segment_statistics.csv"
all_disordered_segments_df = _read_required_csv(all_disordered_segments_file)

# Molecule statistics table.
all_molstats_file = output_data_dir / "all_mol_statistics.csv"
all_molstats_df = _read_required_csv(all_molstats_file)

# Images excluded from the whole analysis (matched as substrings of the "image" column).
images_to_delete = [
    "20250502_5nMTRF2_1ng_tel12picozEE_nicl.0_00010",
    "20250508_5nMTRF1cTRFH_1ngTel12picozEE_3mMNiCl.0_00005",
]


def _drop_image_rows(df: pd.DataFrame, image_name: str, require_match: bool = True) -> pd.DataFrame:
    """
    Return a copy of ``df`` without the rows whose "image" value contains ``image_name``.

    Parameters
    ----------
    df : pd.DataFrame
        Table with an "image" column of strings.
    image_name : str
        Literal substring identifying the image to remove.
    require_match : bool
        If True, assert that at least one row was removed (guards against typos in the name).

    Returns
    -------
    pd.DataFrame
        An independent copy of the remaining rows.
    """
    # regex=False: the image names contain "." which would otherwise act as a regex wildcard.
    # na=False: rows with a missing image name are never treated as a match.
    matches = df["image"].str.contains(image_name, regex=False, na=False)
    if require_match:
        assert matches.any(), f"no rows found for image {image_name}"
    # .copy() so that columns added later do not trigger SettingWithCopyWarning
    return df[~matches].copy()


# remove images to delete
for image_name in images_to_delete:
    all_stats_df = _drop_image_rows(all_stats_df, image_name)
    all_disordered_segments_df = _drop_image_rows(all_disordered_segments_df, image_name)
    # the molecule statistics are filtered too, but a match is not required for this table
    all_molstats_df = _drop_image_rows(all_molstats_df, image_name, require_match=False)
    print(f"removed {image_name} from all stats and all disordered segments")

# The display name is the last "/"-separated component of the basename.
all_stats_df["display_name"] = all_stats_df["basename"].str.split("/").str[-1]
print(f"all stats display names: {all_stats_df['display_name'].unique()}")

# Rescale length columns by 1e-9 and area columns by 1e-18 so they read in nm / nm^2
# (assumes the CSV stores them in metres / square metres — TODO confirm).
unit_divisors = {
    "total_branch_lengths": 1e-9,
    "smallest_bounding_area": 1e-18,
    "area": 1e-18,
    "height_median": 1e-9,
    "total_contour_length": 1e-9,
}
for column_name, divisor in unit_divisors.items():
    all_stats_df[column_name] = all_stats_df[column_name] / divisor

# Fill colour shared by the box and violin plots below.
boxplotcolour = "lightgrey"


def _print_table_overview(label: str, table: pd.DataFrame) -> None:
    """Print the row count, the column count and then every column name of ``table``."""
    print(f"{label}: {len(table)}, cols: {len(table.columns)}")
    for col in table.columns:
        print(f"  {col}")


_print_table_overview("all stats", all_stats_df)
_print_table_overview("all disordered segments", all_disordered_segments_df)

#### pretty plotting

In [None]:
def boxplot(x, y, data):
    """
    Draw a seaborn box plot of ``y`` grouped by ``x`` and print per-group summary statistics.

    Parameters
    ----------
    x : str
        Column of ``data`` used for grouping (the categorical axis).
    y : str
        Column of ``data`` holding the plotted values.
    data : pd.DataFrame
        Table containing both columns.

    Notes
    -----
    For every group of ``x`` one line is printed with the median of ``y`` and its
    interquartile range (75th minus 25th percentile).
    """
    sns.boxplot(data=data, x=x, y=y, color=boxplotcolour)
    for group_label, group_rows in data.groupby(x):
        values = group_rows[y]
        upper_quartile = values.quantile(0.75)
        lower_quartile = values.quantile(0.25)
        print(f"{group_label}: median={values.median()}, IQR={upper_quartile - lower_quartile}")


def _simplify_sample_label(label: str) -> str:
    """
    Shorten a sample label by handling its "TEL12" / "TEL80" construct name.

    A label that mentions a construct and "control" collapses to the construct name alone;
    any other label simply has the construct name removed. Labels mentioning neither
    construct are returned unchanged.
    """
    for construct in ("TEL12", "TEL80"):
        if construct in label:
            label = construct if "control" in label else label.replace(construct, "")
    return label


def _simplified_xtick_labels(ax) -> list:
    """Return the simplified text of every x tick label currently on ``ax``."""
    return [_simplify_sample_label(tick.get_text()) for tick in ax.get_xticklabels()]


def simpleticks():
    """
    Replace the x tick labels of the current axes with their simplified form.

    The labels are written back at the tick positions the axes already use, so the
    function works for categorical plots (ticks at 0, 1, 2, ...) as well as for axes
    whose ticks are placed elsewhere.
    """
    ax = plt.gca()
    plt.xticks(ax.get_xticks(), _simplified_xtick_labels(ax))


def simpleticks_x_y():
    """
    Put the simplified x tick labels on both the x and the y axis of the current axes.

    Intended for square matrices (e.g. pairwise heatmaps) where both axes list the same
    samples. Existing tick positions are kept: seaborn heatmaps place their ticks at the
    cell centres (0.5, 1.5, ...), so forcing integer positions would shift every label
    onto a cell edge.
    """
    ax = plt.gca()
    labels = _simplified_xtick_labels(ax)
    x_positions = ax.get_xticks()
    y_positions = ax.get_yticks()
    if len(y_positions) != len(labels):
        # the y axis does not have one tick per label; fall back to the x positions
        y_positions = x_positions
    plt.xticks(x_positions, labels)
    plt.yticks(y_positions, labels)


# Default font size for axis labels on every figure in this notebook.
plt.rcParams["axes.labelsize"] = 16

# Left-to-right order of the samples in the violin / strip plots.
sample_order = [
    "TEL12Shelterin",
    "TEL12TRF1",
    "TEL12TRF2",
    "TEL12TRF1cTRFH",
    "TEL12TRF2cTRFH",
    "TEL12TRF2dTRFH",
]
# Any sample found in the data but missing from the list is reported and appended,
# so that no sample is silently left out of the ordered plots.
for sample in all_stats_df["display_name"].unique():
    if sample in sample_order:
        continue
    print(f"sample {sample} not in sample_order")
    sample_order.append(sample)

### analysis functions

In [None]:
def percentage_shift(
    x: np.ndarray,
    y: np.ndarray,
) -> float:
    """
    Express the change from ``x`` to ``y`` as a percentage of ``x``.

    Parameters
    ----------
    x : np.ndarray
        The reference (starting) value.
    y : np.ndarray
        The value being compared against the reference.

    Returns
    -------
    float
        ``(y - x) / x * 100``; positive when ``y`` exceeds a positive ``x``.
    """
    relative_change = (y - x) / x
    return relative_change * 100


def percentage_shift_between_samples(
    df: pd.DataFrame, sample1: str, sample2: str, column: str, average_type: str = "median"
) -> None:
    """
    Print the percentage shift of a column's average from one sample to another.

    The rows of each sample are selected through the "display_name" column. The shift is
    computed from the sample averages with ``percentage_shift`` (sample1 is the reference)
    and printed together with each sample's average and spread.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a "display_name" column and the column to compare.
    sample1 : str
        Display name of the reference sample.
    sample2 : str
        Display name of the sample compared against the reference.
    column : str
        Name of the numeric column to summarise.
    average_type : str
        "mean" (spread reported as standard deviation) or "median" (spread reported as
        interquartile range). Default is "median".

    Raises
    ------
    ValueError
        If ``average_type`` is neither "mean" nor "median".
    """
    first_rows = df[df["display_name"] == sample1]
    second_rows = df[df["display_name"] == sample2]

    def _interquartile_range(values):
        return values.quantile(0.75) - values.quantile(0.25)

    # average_type -> (centre function, spread function)
    summarisers = {
        "mean": (lambda values: values.mean(), lambda values: values.std()),
        "median": (lambda values: values.median(), _interquartile_range),
    }
    if average_type not in summarisers:
        raise ValueError("average_type must be 'mean' or 'median'")
    centre_of, spread_of = summarisers[average_type]

    first_centre = centre_of(first_rows[column])
    second_centre = centre_of(second_rows[column])
    first_spread = spread_of(first_rows[column])
    second_spread = spread_of(second_rows[column])
    avg_str = (
        f"{sample1}: {average_type}={first_centre:.2f}±{first_spread:.2f}, "
        f"{sample2}: {average_type}={second_centre:.2f}±{second_spread:.2f}"
    )

    shift = percentage_shift(first_centre, second_centre)
    print(
        f"percentage shift in {column} between {sample1} and {sample2}: {shift:.2f}%. Average type: {average_type}.\n"
        f"Distribution info: ({avg_str})"
    )

### plotting

In [None]:
# One strip plot per area column: (column, y label, title, y value of the dashed red guide line or None).
area_strip_plots = [
    ("area", "area (nm^2)", "molecule area", None),
    ("smallest_bounding_area", "smallest bounding box area (nm^2)", "molecule bounding box area", 1000),
]
for area_column, area_ylabel, area_title, guide_line_y in area_strip_plots:
    sns.stripplot(data=all_stats_df, x="display_name", y=area_column, order=sample_order)
    plt.xticks(rotation=90)
    plt.ylabel(area_ylabel)
    plt.xlabel("sample type")
    plt.title(area_title)
    if guide_line_y is not None:
        # horizontal reference line spanning the categorical axis
        plt.hlines(guide_line_y, xmin=-1, xmax=10, colors="red", linestyles="dashed")
    simpleticks()
    plt.show()

In [None]:
# create a DF of just grain data (no subgrains, so no double-counting of total branch length stats)
# Each (image, grain_number) pair becomes one row. Rows with class_number 1 are counted as
# DNA segments and rows with class_number 2 as protein segments.
# NOTE(review): "volume" and "branch_distance" are used as loaded — unlike the length and
# area columns they are not rescaled in the loading cell, so confirm their units before
# labelling plots as nm^3 / nm.
grains_list = []
unique_images = all_stats_df["image"].unique()
print(f"len all stats: {len(all_stats_df)}")
print(f"unique images: {len(unique_images)}")
for image in unique_images:
    all_grains_data = all_stats_df[all_stats_df["image"] == image]
    # the segment rows of this image are selected once here rather than once per grain
    image_segment_data = all_disordered_segments_df[all_disordered_segments_df["image"] == image]
    grain_numbers = all_grains_data["grain_number"].unique()
    print(f"  image: {image}, grain numbers: {len(grain_numbers)}")
    for grain_number in grain_numbers:
        # grab just this grain's data
        grain_data = all_grains_data[all_grains_data["grain_number"] == grain_number]
        classes = grain_data["class_number"].unique()
        dna_rows = grain_data[grain_data["class_number"] == 1]
        protein_rows = grain_data[grain_data["class_number"] == 2]
        num_dna_segments = len(dna_rows)
        num_protein_segments = len(protein_rows)
        # flagged only when the grain holds both a DNA (1) and a protein (2) segment
        protein_present = 1 if num_dna_segments > 0 and num_protein_segments > 0 else 0
        print(f"    grain number: {grain_number}, classes: {classes}")

        # get the segment data for this grain from the disordered segments df
        disordered_segment_data = image_segment_data[image_segment_data["grain_number"] == grain_number]

        # get important stats
        # NaN when the grain has no disordered segments
        mean_branch_distance = disordered_segment_data["branch_distance"].mean()
        # tbl is the same for all subgrains since it's the total, and just repeated for each row of the grain
        tbl = grain_data["total_branch_lengths"].values[0]
        tcl = grain_data["total_contour_length"].values[0]
        smallest_bounding_area = grain_data["smallest_bounding_area"].values[0]
        # protein / DNA totals are the sums over the grain's class 2 / class 1 rows
        total_protein_volume = protein_rows["volume"].sum()
        total_protein_area = protein_rows["area"].sum()
        mean_protein_area = total_protein_area / num_protein_segments if num_protein_segments > 0 else 0
        total_dna_volume = dna_rows["volume"].sum()

        # add to the list
        grains_list.append(
            {
                "image": image,
                "grain_number": grain_number,
                "protein_present": protein_present,
                "num_dna_segments": num_dna_segments,
                "num_protein_segments": num_protein_segments,
                "mean_branch_distance": mean_branch_distance,
                "total_branch_lengths": tbl,
                "total_contour_length": tcl,
                "total_protein_volume": total_protein_volume,
                "total_dna_volume": total_dna_volume,
                "total_grain_volume": total_protein_volume + total_dna_volume,
                "mean_protein_area": mean_protein_area,
                "total_protein_area": total_protein_area,
                "smallest_bounding_area": smallest_bounding_area,
                "display_name": grain_data["display_name"].values[0],
                "basename": grain_data["basename"].values[0],
            }
        )

grains_df = pd.DataFrame(grains_list)

# print display name quantities
display_names = grains_df["display_name"].unique()
for display_name in display_names:
    display_name_data = grains_df[grains_df["display_name"] == display_name]
    has_protein = display_name_data["protein_present"] == 1
    has_dna = display_name_data["num_dna_segments"] > 0
    num_grains = len(display_name_data)
    num_proteins = int(has_protein.sum())
    num_dna = int(has_dna.sum())
    num_proteins_and_dna = int((has_dna & has_protein).sum())
    print(
        f"{display_name:<15} {num_grains:>4} grains, {num_proteins:>4} proteins, {num_dna:>4} dna, {num_proteins_and_dna:>4} proteins and dna"
    )

In [None]:
# Plots for grains with protein only (protein_present == 1), in this order:
# violins of length / volume, joint plots of protein volume vs TBL, per-sample scatter
# grids, then violins of segment counts, branch distance, volumes and protein areas.
# NOTE(review): the volume and branch-distance axis labels state nm^3 / nm, but those
# columns are not rescaled in the visible loading code — confirm the units.
protein_grains_df = grains_df[grains_df["protein_present"] == 1]


def plot_violin_strip_by_sample(data, column, label, unit=None):
    """
    Draw a violin plot of ``column`` per sample with the individual grains overlaid.

    Parameters
    ----------
    data : pd.DataFrame
        Grain table with a "display_name" column and ``column``.
    column : str
        Numeric column to plot on the y axis.
    label : str
        Human-readable name of the quantity; used for the y label and the title.
    unit : str, optional
        Unit appended to the y label in parentheses; omitted when None.
    """
    plt.figure()
    sns.violinplot(x="display_name", y=column, data=data, color=boxplotcolour, order=sample_order)
    # the overlay gets the same explicit order as the violins so that every point is
    # drawn on its own sample's violin regardless of the row order in ``data``
    sns.stripplot(x="display_name", y=column, data=data, color="black", alpha=0.5, order=sample_order)
    plt.xticks(rotation=90)
    plt.ylabel(f"{label} ({unit})" if unit else label)
    plt.xlabel("sample type")
    plt.title(f"{label} for grains with protein")
    simpleticks()
    plt.show()


def plot_protein_volume_scatter_grid(x_column, x_label):
    """
    Scatter total protein volume against ``x_column`` in one panel per sample.

    All panels share the same axis limits (110 % of the maximum over all grains in
    ``grains_df``) so the samples can be compared by eye.
    """
    plt.figure(figsize=(plot_cols * 5, plot_rows * 5))
    for i, display_name in enumerate(display_names):
        ax = plt.subplot(plot_rows, plot_cols, i + 1)
        sns.scatterplot(
            x=x_column,
            y="total_protein_volume",
            data=protein_grains_df[protein_grains_df["display_name"] == display_name],
            color="orange",
            alpha=1,
            ax=ax,
        )
        ax.set_title(display_name)
        ax.set_xlabel(x_label)
        ax.set_ylabel("total protein volume (nm^3)")
        ax.set_xlim(0, grains_df[x_column].max() * 1.1)
        ax.set_ylim(0, grains_df["total_protein_volume"].max() * 1.1)
        ax.grid(True)
    plt.tight_layout()
    plt.show()


# TBL, TCL and total grain volume per sample
for column, label, unit in [
    ("total_branch_lengths", "total branch lengths", "nm"),
    ("total_contour_length", "total contour length", "nm"),
    ("total_grain_volume", "total grain volume", "nm^3"),
]:
    plot_violin_strip_by_sample(protein_grains_df, column, label, unit)

# protein volume against TBL, coloured by sample.
# sns.jointplot is figure-level and creates its own figure, so no plt.figure() call
# precedes it (that would only leave an empty figure in the output).
sns.jointplot(
    x="total_branch_lengths",
    y="total_protein_volume",
    data=protein_grains_df,
    hue="display_name",
    alpha=0.5,
)
# same relationship with 2D histograms instead of scatter points
tbl_tpv_hist_grid = sns.jointplot(
    x="total_branch_lengths",
    y="total_protein_volume",
    data=protein_grains_df,
    hue="display_name",
    alpha=1,
    kind="hist",
)
tbl_tpv_hist_grid.ax_joint.set_xlabel("total branch lengths (nm)")
tbl_tpv_hist_grid.ax_joint.set_ylabel("total protein volume (nm^3)")
plt.tight_layout()
plt.show()

# per-sample scatter grids: TBL vs protein volume, then smallest bounding area vs protein volume
plot_cols = 3
plot_rows = int(np.ceil(len(display_names) / plot_cols))
plot_protein_volume_scatter_grid("total_branch_lengths", "total branch lengths (nm)")
plot_protein_volume_scatter_grid("smallest_bounding_area", "smallest bounding area (nm^2)")

# segment counts, branch distance and volumes per sample
for column, label, unit in [
    ("num_dna_segments", "number of DNA segments", None),
    ("num_protein_segments", "number of protein segments", None),
    ("mean_branch_distance", "mean branch distance", "nm"),
    ("total_protein_volume", "total protein volume", "nm^3"),
    ("total_dna_volume", "total DNA volume", "nm^3"),
]:
    plot_violin_strip_by_sample(protein_grains_df, column, label, unit)

# total protein volume vs total DNA volume
dna_tpv_grid = sns.jointplot(
    x="total_dna_volume",
    y="total_protein_volume",
    data=protein_grains_df,
    hue="display_name",
    alpha=0.5,
)
dna_tpv_grid.ax_joint.set_xlabel("total DNA volume (nm^3)")
dna_tpv_grid.ax_joint.set_ylabel("total protein volume (nm^3)")
dna_tpv_grid.ax_joint.set_title("total protein volume vs total DNA volume for grains with protein")
plt.tight_layout()
plt.show()

# protein areas per sample
for column, label, unit in [
    ("total_protein_area", "total protein area", "nm^2"),
    ("mean_protein_area", "mean protein area", "nm^2"),
]:
    plot_violin_strip_by_sample(protein_grains_df, column, label, unit)

# kruskal wallis test for all samples for total grain volume
# NOTE(review): this uses every grain in grains_df, whereas the plots above only show
# grains with protein — confirm that is the intended population.
from scipy.stats import kruskal

threshold = 0.05

volumes_by_sample = {
    sample: grains_df.loc[grains_df["display_name"] == sample, "total_grain_volume"] for sample in sample_order
}
# sample_order can list samples that are absent from the data; an empty group would make
# the test return NaN, which would then be reported as "not significant"
samples_without_grains = [sample for sample, volumes in volumes_by_sample.items() if volumes.empty]
if samples_without_grains:
    print(f"samples without grains left out of the kruskal wallis test: {samples_without_grains}")

result = kruskal(*[volumes for volumes in volumes_by_sample.values() if not volumes.empty])
print(f"kruskal wallis test for total grain volume: {result.statistic:.2f}, p-value: {result.pvalue:.2e}")
if result.pvalue < threshold:
    print(f"results are statistically significant (p-value < {threshold})")
else:
    print(f"results are not statistically significant (p-value >= {threshold})")

In [None]:
# post-hoc test
from scikit_posthocs import posthoc_dunn


def _dunn_posthoc_with_heatmaps(frame, value_column, quantity):
    """
    Run Dunn's post-hoc test (Holm-adjusted) between samples and show it as two heatmaps.

    The pairwise p-value matrix is printed, drawn as an annotated heatmap, and drawn again
    as a boolean matrix marking the pairs whose p-value is below ``threshold``.

    Parameters
    ----------
    frame : pd.DataFrame
        Table with a "display_name" grouping column and ``value_column``.
    value_column : str
        Column holding the compared values.
    quantity : str
        Name of the quantity as it should appear in the figure titles.

    Returns
    -------
    pd.DataFrame
        The matrix of pairwise p-values returned by ``posthoc_dunn``.
    """
    pvalues = posthoc_dunn(frame, val_col=value_column, group_col="display_name", p_adjust="holm")
    print("post-hoc test results:")
    print(pvalues)

    # matrix of the p-values themselves
    plt.figure(figsize=(10, 10))
    sns.heatmap(pvalues, annot=True, fmt=".2e", cmap="binary_r", cbar_kws={"label": "p-value"})
    plt.title(f"Post-hoc Dunn's test results for {quantity}")
    plt.yticks(rotation=0)
    plt.xticks(rotation=90)
    simpleticks_x_y()
    plt.show()

    # same matrix, reduced to below / not below the significance threshold
    is_significant = pvalues < threshold
    plt.figure(figsize=(10, 10))
    sns.heatmap(is_significant, annot=is_significant, fmt="", cmap="binary", cbar=False)
    plt.title(f"Post-hoc Dunn's test results for {quantity} (p-value < 0.05)")
    plt.yticks(rotation=0)
    plt.xticks(rotation=90)
    simpleticks_x_y()
    plt.show()
    return pvalues


posthoc_total_grain_volume_results = _dunn_posthoc_with_heatmaps(
    grains_df, "total_grain_volume", "total grain volume"
)

# post-hoc for total branch lengths
posthoc_total_branch_lengths_results = _dunn_posthoc_with_heatmaps(
    grains_df, "total_branch_lengths", "total branch lengths"
)

In [None]:
# Two-sample tests: do total branch lengths differ between TRF1 and TRF2 grains?
from scipy import stats

significance_level = 0.05

# Samples are selected by substring, so every display name containing "TRF1" (or "TRF2")
# is pooled into the respective group.
# NOTE(review): this also pulls in variants such as TRF1cTRFH / TRF2cTRFH / TRF2dTRFH —
# confirm that pooling is intended.
is_trf1 = grains_df["display_name"].str.contains("TRF1")
is_trf2 = grains_df["display_name"].str.contains("TRF2")
trf1_tbl = grains_df.loc[is_trf1, "total_branch_lengths"]
trf2_tbl = grains_df.loc[is_trf2, "total_branch_lengths"]

print("stats tests for TBL TRF1 vs TRF2")

# independent two-sample t-test
t_test_result = stats.ttest_ind(trf1_tbl, trf2_tbl)
t_stat, p_val = t_test_result[0], t_test_result[1]
print(f"t-test: t-statistic: {t_stat:.4f}, p-value: {p_val:.4f} Significant: {p_val < significance_level}")

# Mann-Whitney U test
u_test_result = stats.mannwhitneyu(trf1_tbl, trf2_tbl)
u_stat, p_val = u_test_result[0], u_test_result[1]
print(
    f"Mann-Whitney u-test: u-statistic: {u_stat:.4f}, p-value: {p_val:.4f} Significant: {p_val < significance_level}"
)

In [None]:
# Work on an independent copy of the grains that contain protein.
grains_centroid_diffs_protein_only_df = grains_df[grains_df["protein_present"] == 1].copy()

# normalise the TBL and TPV values across all samples to keep relative differences
# (the min / max come from every grain in grains_df, not only those with protein)
tbl_max = grains_df["total_branch_lengths"].max()
tbl_min = grains_df["total_branch_lengths"].min()
tpv_max = grains_df["total_protein_volume"].max()
tpv_min = grains_df["total_protein_volume"].min()

# The scaling is the same for every sample, so it is applied to all rows in one step.
grains_centroid_diffs_protein_only_df["total_branch_lengths_normalised"] = (
    grains_centroid_diffs_protein_only_df["total_branch_lengths"] - tbl_min
) / (tbl_max - tbl_min)
grains_centroid_diffs_protein_only_df["total_protein_volume_normalised"] = (
    grains_centroid_diffs_protein_only_df["total_protein_volume"] - tpv_min
) / (tpv_max - tpv_min)


# now for each sample calculate the centroid of the TBL and TPV values and then create a new column with the euclidean distance for each grain
def euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
    """
    Return the straight-line (L2) distance between two points.

    Parameters
    ----------
    x : np.ndarray
        Coordinates of the first point.
    y : np.ndarray
        Coordinates of the second point.

    Returns
    -------
    float
        Square root of the summed squared coordinate differences.
    """
    difference = x - y
    squared_total = np.sum(difference**2)
    return np.sqrt(squared_total)


# Grid of scatter plots (one panel per sample) of the normalised TBL against the
# normalised total protein volume.
plot_cols = 3
plot_rows = int(np.ceil(len(display_names) / plot_cols))
plt.figure(figsize=(plot_cols * 5, plot_rows * 5))
for panel_number, display_name in enumerate(display_names, start=1):
    in_sample = grains_centroid_diffs_protein_only_df["display_name"] == display_name
    has_protein = grains_centroid_diffs_protein_only_df["protein_present"] == 1
    ax = plt.subplot(plot_rows, plot_cols, panel_number)
    sns.scatterplot(
        data=grains_centroid_diffs_protein_only_df[in_sample & has_protein],
        x="total_branch_lengths_normalised",
        y="total_protein_volume_normalised",
        color="orange",
        alpha=1,
        ax=ax,
    )
    ax.set_title(display_name)
    ax.set_xlabel("normalised total branch lengths")
    ax.set_ylabel("normalised total protein volume")
    # identical limits on every panel; slightly beyond 1 so points at 1.0 stay visible
    ax.set_xlim(0, 1.1)
    ax.set_ylim(0, 1.1)
    ax.grid(True)
plt.tight_layout()
plt.show()

# Distance of every grain from its own sample's centroid in the normalised (TBL, TPV) plane.
# The column starts as NaN and is filled one sample at a time.
grains_centroid_diffs_protein_only_df["tbl_tpv_euclidean_distance"] = np.nan
for display_name in display_names:
    in_sample = grains_centroid_diffs_protein_only_df["display_name"] == display_name
    sample_tbl = grains_centroid_diffs_protein_only_df.loc[in_sample, "total_branch_lengths_normalised"]
    sample_tpv = grains_centroid_diffs_protein_only_df.loc[in_sample, "total_protein_volume_normalised"]
    # the centroid is the per-sample mean of each normalised coordinate
    tbl_centroid = sample_tbl.mean()
    tpv_centroid = sample_tpv.mean()
    squared_distances = (sample_tbl - tbl_centroid) ** 2 + (sample_tpv - tpv_centroid) ** 2
    grains_centroid_diffs_protein_only_df.loc[in_sample, "tbl_tpv_euclidean_distance"] = np.sqrt(
        squared_distances
    )


# turn this to a function
def _min_max_normalise(values: pd.Series) -> pd.Series:
    """Rescale ``values`` linearly so that its minimum maps to 0 and its maximum to 1."""
    lowest = values.min()
    value_range = values.max() - lowest
    if value_range == 0:
        # constant column: every value maps to 0 instead of 0 / 0 = NaN
        value_range = 1
    return (values - lowest) / value_range


def euclidean_distance_from_centroid(df: pd.DataFrame, col1: str, col2: str) -> pd.DataFrame:
    """
    Calculate each row's Euclidean distance from the centroid of its sample.

    Both columns are first min-max normalised over the whole DataFrame. Then, for every
    value of "display_name", the centroid is taken as the mean of the two normalised
    columns and each row's distance to that centroid is stored.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a "display_name" column plus ``col1`` and ``col2``. It is modified
        in place: the columns ``{col1}_normalised``, ``{col2}_normalised`` and
        ``{col1}_{col2}_euclidean_distance`` are added or overwritten.
    col1 : str
        Name of the first numeric column.
    col2 : str
        Name of the second numeric column.

    Returns
    -------
    pd.DataFrame
        The same ``df`` object, returned for convenience.
    """
    col1_normalised = f"{col1}_normalised"
    col2_normalised = f"{col2}_normalised"
    distance_column = f"{col1}_{col2}_euclidean_distance"

    # Normalise the columns values across all samples
    df[col1_normalised] = _min_max_normalise(df[col1])
    df[col2_normalised] = _min_max_normalise(df[col2])

    # created up front so the column exists even when the loop below has nothing to fill
    df[distance_column] = np.nan

    # iterate over each display name
    for display_name in df["display_name"].unique():
        in_sample = df["display_name"] == display_name
        sample_col1 = df.loc[in_sample, col1_normalised]
        sample_col2 = df.loc[in_sample, col2_normalised]
        # offsets of each row from the sample centroid (mean of each normalised column)
        col1_offsets = sample_col1 - sample_col1.mean()
        col2_offsets = sample_col2 - sample_col2.mean()
        df.loc[in_sample, distance_column] = np.sqrt(col1_offsets**2 + col2_offsets**2)
    return df


# plot the euclidean distance vs display name as a violin plot with the grains overlaid
protein_only_distances_df = grains_centroid_diffs_protein_only_df[
    grains_centroid_diffs_protein_only_df["protein_present"] == 1
]
plt.figure()
sns.violinplot(
    x="display_name",
    y="tbl_tpv_euclidean_distance",
    data=protein_only_distances_df,
    color=boxplotcolour,
    order=sample_order,
)
# the overlay gets the same explicit order as the violins so that every point is drawn
# on its own sample's violin regardless of the row order in the data
sns.stripplot(
    x="display_name",
    y="tbl_tpv_euclidean_distance",
    data=protein_only_distances_df,
    color="black",
    alpha=0.5,
    order=sample_order,
)
plt.xticks(rotation=90)
plt.ylabel("euclidean distance from\ncentroid for normalised\ntotal branch lengths\nvs total protein volume")
plt.xlabel("sample type")
simpleticks()
plt.tight_layout()
plt.show()

# Dunn's post-hoc test (Holm-adjusted) between samples for the centroid distance
posthoc_tbl_tpv_euclidean_distance_results = posthoc_dunn(
    grains_centroid_diffs_protein_only_df,
    val_col="tbl_tpv_euclidean_distance",
    group_col="display_name",
    p_adjust="holm",
)
print("post-hoc test results:")
print(posthoc_tbl_tpv_euclidean_distance_results)

# Two views of the pairwise matrix: the p-values themselves, then a boolean matrix
# marking the pairs whose p-value is below the threshold.
distance_is_significant = posthoc_tbl_tpv_euclidean_distance_results < threshold
distance_heatmaps = [
    (
        posthoc_tbl_tpv_euclidean_distance_results,
        {"annot": True, "fmt": ".2e", "cmap": "binary_r", "cbar_kws": {"label": "p-value"}},
        "Post-hoc Dunn's test results for tbl tpv euclidean distance",
    ),
    (
        distance_is_significant,
        {"annot": distance_is_significant, "fmt": "", "cmap": "binary", "cbar": False},
        "Post-hoc Dunn's test results for tbl tpv euclidean distance (p-value < 0.05)",
    ),
]
for heatmap_matrix, heatmap_options, heatmap_title in distance_heatmaps:
    plt.figure(figsize=(10, 10))
    sns.heatmap(heatmap_matrix, **heatmap_options)
    plt.title(heatmap_title)
    plt.yticks(rotation=0)
    plt.xticks(rotation=90)
    simpleticks_x_y()
    plt.show()