In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import matplotlib.colors as clr
import pandas as pd
import numpy as np

# Install a global "ignore" filter so no warnings are shown by the cells below.
# NOTE(review): this blanket filter also hides deprecation and numerical
# warnings raised by the libraries used in this notebook — consider narrowing
# it to specific categories or modules.
warnings.filterwarnings("ignore")

In [None]:
# Location of the time-course AnnData file, relative to this notebook
timecourse_h5ad = "../data/adata/timecourse.h5ad"
# Read the annotated data matrix used by every cell below
adata = sc.read_h5ad(timecourse_h5ad)

In [None]:
# Zissou palette: eleven hex colours, kept as a plain list of strings
zissou = (
    "#3A9AB2 #6FB2C1 #91BAB6 #A5C2A3 #BDC881 #DCCB4E "
    "#E3B710 #E79805 #EC7A05 #EF5703 #F11B00"
).split()
# Continuous colormap interpolated across the palette colours
colormap = clr.LinearSegmentedColormap.from_list(name="Zissou", colors=zissou)
# Make the palette seaborn's default colour cycle
sns.set_palette(palette=zissou)

In [None]:
# Samples (values of adata.obs["batch"]) to include, in plotting order
samples_to_plot = [
    "day6_SI",
    "day6_SI_r2",
    "day8_SI_Ctrl",
    "day8_SI_r2",
    "day30_SI",
    "day30_SI_r2",
    "day90_SI",
    "day90_SI_r2",
]

# Number of equal-width segments the predicted longitudinal axis is split into.
# Defined once so the table index and the histogram bins cannot drift apart.
num_bins = 8

# Rows are segments 1..num_bins, columns are samples; each column is filled
# with the fraction of that sample's P14 cells falling into each segment.
counts_df = pd.DataFrame(index=range(1, num_bins + 1), columns=samples_to_plot)

# Only cell metadata is needed, so filter adata.obs directly instead of
# slicing the whole AnnData object once per sample.
is_p14 = adata.obs["Subtype"] == "Cd8_T-Cell_P14"

for time_point in samples_to_plot:
    # Predicted longitudinal positions of this sample's P14 cells; cells
    # without a prediction (NaN) are dropped so they are neither binned nor
    # counted in the denominator.
    positions = adata.obs.loc[
        is_p14 & (adata.obs["batch"] == time_point), "predicted_longitudinal"
    ].dropna()

    if positions.empty:
        # Without this guard np.histogram fails on NaN bin edges with a
        # message that does not mention the offending sample.
        raise ValueError(
            f"No Cd8_T-Cell_P14 cells with a predicted_longitudinal value "
            f"found for sample {time_point!r}"
        )

    # NOTE(review): the edges span this sample's own min..max, so segment k
    # can cover a different interval in each sample. If predicted_longitudinal
    # lives on a fixed scale, shared edges may be what is intended — confirm.
    bin_edges = np.linspace(positions.min(), positions.max(), num_bins + 1)

    # Cells per segment, converted to a fraction of the sample's P14 cells
    counts, _ = np.histogram(positions, bins=bin_edges)
    counts_df[time_point] = counts / len(positions)

# Print the frequencies for each segment and time point
print("Frequencies for Each Segment and Time Point:")
print(counts_df)

In [None]:
# QC: the segment frequencies of every sample should add up to 1
counts_df.sum(axis="index")

In [None]:
# Reshape to long format: one row per (segment, sample) pair, keeping the
# segment number as the index.
df_long = counts_df.melt(var_name="sample", value_name="frequency", ignore_index=False)
# First run of digits in the sample name, e.g. "day30_SI_r2" -> "30".
# The raw string keeps "\d" from being read as an (invalid) string escape, and
# expand=False returns a Series instead of a one-column DataFrame.
df_long["day"] = df_long["sample"].str.extract(r"(\d+)", expand=False)
# Expose the segment number as a regular column so it can be used as hue
df_long["segment"] = df_long.index
df_long

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Dashed reference line at 0.125 (= 1/8), presumably the share each of the
# 8 segments would hold if cells were spread evenly — confirm.
ax.axhline(0.125, dashes=(4, 2), color="#8A8A8A")
# Bars summarise the samples that share a day; the strip plot overlays the
# individual sample values on top of each bar.
sns.barplot(data=df_long, x="day", y="frequency", hue="segment", ax=ax, palette=zissou)
sns.stripplot(
    data=df_long,
    x="day",
    y="frequency",
    hue="segment",
    dodge=True,
    ax=ax,
    jitter=False,
    palette=zissou,
    linewidth=1,
)
# NOTE(review): the y-axis runs from 0 to 0.25, i.e. values are fractions,
# not percentages — consider relabelling or formatting ticks as percent.
ax.set_title("Percentage of P14 cells for each Segment and Time Point")
ax.set_xlabel("day")
ax.set_ylabel("Percentage of P14 cells in segment")
ax.set_ylim(0, 0.25)

# Both seaborn calls register one legend entry per segment, so the default
# legend would list every segment twice. Keep the first handle per label.
handles, labels = ax.get_legend_handles_labels()
unique_entries = {}
for handle, label in zip(handles, labels):
    unique_entries.setdefault(label, handle)
ax.legend(
    list(unique_entries.values()),
    list(unique_entries.keys()),
    loc="upper center",
    ncol=8,
)

fig.savefig("out/longitudinal.pdf")