In [None]:
import glob
import os
import sys

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from sklearn.metrics import pairwise_distances
from tqdm import tqdm

sys.path.insert(1, "../helper_functions")
from create_metacells import *
from helper_functions import gini, lorenz, prepare_counts_df

In [None]:
# Collect the 100 kb read-count segment files. glob returns matches in an
# arbitrary, filesystem-dependent order, so sort them to make the input order
# (and everything derived from it) reproducible across machines and re-runs.
all_count_files = sorted(
    glob.glob("../../DEFND-seq/readCount_filtered_bam/*100kb*.seg")
)

In [None]:
# Sanity check: how many read-count files matched the glob pattern.
len(all_count_files)

In [None]:
# Load all per-file counts into one table plus the list of genomic bins.
# NOTE(review): the exact effect of `metacelling=True` is defined in
# helper_functions.prepare_counts_df — confirm there.
counts_df_orig, regions = prepare_counts_df(
    all_count_files,
    binsize=100_000,
    metacelling=True,
)

In [None]:
# Restrict the count table to the rows whose labels appear in the index of the
# filtered CNV table, in that table's order.
barcodes = pd.read_csv(
    "../data/CNVs_DEFND_filtered.csv.gz",
    index_col=0,
).index
counts_df_orig = counts_df_orig.loc[barcodes, :].copy()

In [None]:
# Row-wise (axis=1) Gini value and Lorenz curve for every row of the table.
gini_matac = counts_df_orig.apply(gini, axis=1).values
lorenz_matac = counts_df_orig.apply(lorenz, axis=1).values
# Turn the per-row curves into one 2-D array (one row per curve).
lorenz_matac_2d = np.array([np.array(curve) for curve in lorenz_matac])

In [None]:
# Summarise the per-row Lorenz curves: the pointwise median, plus the band
# between the 2.5th and 97.5th percentile at each point. This band describes
# the spread across curves; it is not a confidence interval of the median,
# so it is labelled as a percentile range.
median_lorenz = np.median(lorenz_matac_2d, axis=0)
lower_bound, upper_bound = np.percentile(lorenz_matac_2d, [2.5, 97.5], axis=0)

# X-axis values - normalized to range from 0 to 1 (linspace also copes with a
# single-point curve, where dividing by len - 1 would divide by zero).
x_values = np.linspace(0.0, 1.0, num=len(median_lorenz))

plt.plot(x_values, median_lorenz, label="Median Lorenz Curve", lw=2, color="blue")
# Diagonal = perfectly even distribution, for reference.
plt.plot(x_values, x_values, label="uniform", ls="--", color="grey")

plt.fill_between(
    x_values,
    lower_bound,
    upper_bound,
    color="lightblue",
    alpha=0.5,
    label="Central 95% of cells (2.5-97.5 percentile)",
)

plt.xlabel("Fraction of genome")
plt.ylabel("Cumulative Share of reads")
plt.legend()
plt.grid(True)

In [None]:
# Gini value of each row against that row's summed counts.
total_counts_per_cell = counts_df_orig.sum(axis=1)
plt.scatter(gini_matac, total_counts_per_cell)
plt.xlabel("Gini")
plt.ylabel("Total counts")

In [None]:
# Distribution of the Gini values in 20 bins.
plt.hist(x=gini_matac, bins=20)
plt.xlabel("Gini")
plt.ylabel("Count")

In [None]:
# Plain array of the count table, taken from a copy of the DataFrame; columns
# are indexed positionally (counts[:, i]) by the re-binning loop.
counts = counts_df_orig.copy().values

In [None]:
# Aggregate the fine-grained bins in `regions` into windows of `new_step` bp.
# rec[0] is compared against the current chromosome and rec[1] against the
# window start; a window is closed as soon as the next bin lies on another
# chromosome or starts `new_step` or more bp after the window start.
new_step = 1000000
chrom = "chr1"
start0 = 1
new_counts = []
new_regions = []
counts0 = np.zeros(counts_df_orig.shape[0])
region = f"{chrom}:{start0}-{start0+new_step}"
# Number of source bins accumulated into the open window. Lets us skip the
# initial placeholder window if the very first bin does not fall into it.
bins_in_window = 0
for i, rec in enumerate(tqdm(regions)):
    if rec[0] == chrom and (rec[1] - start0 < new_step):
        counts0 += counts[:, i]
    else:
        if bins_in_window:
            new_counts.append(counts0)
            new_regions.append(region)
        chrom = rec[0]
        start0 = rec[1]
        # astype() returns a fresh array. Taking counts[:, i] directly would
        # give a view, and the in-place `+=` above would then overwrite
        # `counts`, so re-running this cell would produce different numbers.
        counts0 = counts[:, i].astype(float)
        region = f"{chrom}:{start0}-{start0+new_step}"
        bins_in_window = 0
    bins_in_window += 1

# A window is only emitted when the next one starts, so the last open window
# must be emitted explicitly or the final region would be silently dropped.
if bins_in_window:
    new_counts.append(counts0)
    new_regions.append(region)

In [None]:
# Assemble the aggregated windows (one row per window), then transpose so
# that rows follow counts_df_orig's index and columns are the window labels.
counts_df = pd.DataFrame(
    data=new_counts,
    index=new_regions,
    columns=counts_df_orig.index,
).transpose()

In [None]:
# Preview the first rows of the re-binned count table.
counts_df.head()

In [None]:
# Wrap the re-binned count table in an AnnData object for the scanpy steps below.
adata = sc.AnnData(counts_df)

In [None]:
# Keep the current (not yet normalised / log-transformed) state in `.raw`.
adata.raw = adata

In [None]:
# Compute QC metrics in place; later cells read `total_counts` and
# `n_genes_by_counts` from adata.obs.
sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

In [None]:
sc.pl.violin(adata, ["n_genes_by_counts", "total_counts"], jitter=0.4, multi_panel=True)
# Keep only observations with more than 10000 total counts. Subsetting an
# AnnData yields a view; .copy() materialises it so the in-place
# normalisation below works on a real object instead of triggering an
# implicit copy (and the accompanying warning) on a view.
adata = adata[adata.obs.total_counts > 10000, :].copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [None]:
%%time
num_comp = 500
sc.tl.pca(adata, svd_solver="arpack", n_comps=num_comp)
print(np.cumsum(adata.uns["pca"]["variance_ratio"])[-1])

In [None]:
# PCA scatter coloured by total counts.
sc.pl.pca(adata, color="total_counts")

In [None]:
# Build the neighbourhood graph using all `num_comp` principal components.
sc.pp.neighbors(adata, n_pcs=num_comp)

In [None]:
# Leiden clustering -> PAGA on those clusters -> UMAP initialised from PAGA.
sc.tl.leiden(adata)
sc.tl.paga(adata)
# remove `plot=False` if you want to see the coarse-grained graph
sc.pl.paga(adata, plot=False)
sc.tl.umap(adata, init_pos="paga")

In [None]:
# NOTE(review): 163 and 3099750718 look like a per-read length (bp) and a
# genome size (bp), which would make this a fold-coverage estimate — confirm
# and consider moving them to named constants.
adata.obs["coverage"] = adata.obs["total_counts"] * 163 / 3099750718

In [None]:
# UMAP coloured by the `coverage` column, with the colour scale capped at "p99".
sc.pl.umap(adata, color="coverage", show=False, color_map="Greys", vmax="p99")

In [None]:
# Bounds compared against the `coverage` column of the cluster table (mean
# summed counts per aggregated window) when flagging and merging clusters.
max_coverage = 1100
# Misspelled name kept as an alias because other cells still reference it.
max_covergae = max_coverage
min_coverage = 200

In [None]:
# Set the resolution such that the maximum coverage is close to max_covergae:
# higher resolution => lower maximum coverage.
sc.tl.leiden(adata, resolution=190, key_added="leiden_small")

# One summary row per fine cluster: id, mean summed counts per window, number
# of cells, UMAP centroid and centroid in PCA space.
summary_columns = ["cl", "coverage", "n_cells", "umap_x", "umap_y"]
summary_columns += [f"PC_{i}" for i in range(num_comp)]

clusters = adata.obs.leiden_small.unique()
cluster_rows = []
for meta in clusters:
    dat = adata.obs[adata.obs.leiden_small == meta]
    subset = adata[dat.index]
    # Sum the counts over the cluster's cells, then average over windows.
    pooled_mean = counts_df.loc[dat.index].sum(axis=0).mean()
    umap_x = float(subset.obsm["X_umap"][:, 0].mean())
    umap_y = float(subset.obsm["X_umap"][:, 1].mean())
    pca_centroid = subset.obsm["X_pca"].mean(axis=0)
    cluster_rows.append(
        [meta, pooled_mean, subset.shape[0], umap_x, umap_y, *pca_centroid]
    )

clust_df = pd.DataFrame(cluster_rows, columns=summary_columns)
print(clust_df.coverage.min(), clust_df.coverage.max())

In [None]:
# Flag clusters that satisfy the lower / upper coverage bound.
clust_df["pass_min"] = clust_df["coverage"].ge(min_coverage)
clust_df["pass_max"] = clust_df["coverage"].le(max_covergae)

In [None]:
# Same values as `coverage`, under the column name the plotting cells use for hue.
clust_df["counts_per_1MB"] = clust_df["coverage"]

In [None]:
# Per-cell mean counts per window: total counts divided by the number of variables.
adata.obs["counts_per_1MB"] = adata.obs["total_counts"] / adata.n_vars

In [None]:
# Plot styling for the figures below: white background, default font scale.
sns.set(style="white", font_scale=1)

In [None]:
# UMAP of the cells, shaded by each cell's mean counts per window.
f, ax = plt.subplots(figsize=(8, 6))

p1 = sns.scatterplot(
    x=adata.obsm["X_umap"][:, 0],
    y=adata.obsm["X_umap"][:, 1],
    palette="Greys",
    hue=adata.obs["counts_per_1MB"],
    ax=ax,
)
# Replace seaborn's hue legend with a continuous colorbar over the same range.
norm = plt.Normalize(
    adata.obs["counts_per_1MB"].min(), adata.obs["counts_per_1MB"].max()
)
sm = plt.cm.ScalarMappable(cmap="Greys", norm=norm)
sm.set_array([])

p1.get_legend().remove()
# `sm` is not attached to any Axes, so the Axes to take space from must be
# given explicitly; recent Matplotlib raises instead of falling back to gca().
p1.figure.colorbar(sm, ax=ax)
ax.set_frame_on(False)
ax.axes.get_yaxis().set_visible(False)
ax.axes.get_xaxis().set_visible(False)
# plt.savefig("final_figures/metacells/cells_umap.png", dpi = 300)
# plt.savefig("final_figures/metacells/cells_umap.pdf", dpi = 300)
plt.show()

In [None]:
# Purple colormap built from three stops of seaborn's "Purples" palette,
# skipping its lightest shades.
palette = sns.color_palette("Purples", as_cmap=True)
purple_stops = [
    (0, palette(0.4)),
    (0.4, palette(0.8)),
    (1, palette(0.9)),
]
custom_colormap = mcolors.LinearSegmentedColormap.from_list(
    "CustomPurple", purple_stops
)
custom_colormap

In [None]:
# Cluster centroids (coloured by mean counts per window) drawn over the
# individual cells (grey) in UMAP space.
f, ax = plt.subplots(figsize=(8, 6))

p1 = sns.scatterplot(
    x=adata.obsm["X_umap"][:, 0],
    y=adata.obsm["X_umap"][:, 1],
    alpha=0.1,
    color="grey",
    ax=ax,
)
p2 = sns.scatterplot(
    x=clust_df.umap_x,
    y=clust_df.umap_y,
    hue=clust_df.counts_per_1MB,
    palette=custom_colormap,
    ax=ax,
)

# Continuous colorbar over the same value range as the hue mapping.
norm = plt.Normalize(clust_df.counts_per_1MB.min(), clust_df.counts_per_1MB.max())
sm = plt.cm.ScalarMappable(cmap=custom_colormap, norm=norm)
sm.set_array([])

p2.get_legend().remove()
# `sm` is not attached to any Axes, so the Axes to take space from must be
# given explicitly; recent Matplotlib raises instead of falling back to gca().
p2.figure.colorbar(sm, ax=ax)
ax.set_frame_on(False)
ax.axes.get_yaxis().set_visible(False)
ax.axes.get_xaxis().set_visible(False)
# plt.savefig("final_figures/metacells/pre_metacells.png", dpi = 300)
# plt.savefig("final_figures/metacells/pre_metacells.pdf", dpi = 300)
plt.show()

In [None]:
# Same overlay in PCA space: cells in grey, cluster centroids coloured by
# coverage, with marker style showing whether the minimum is met.
sns.scatterplot(
    x=adata.obsm["X_pca"][:, 0],
    y=adata.obsm["X_pca"][:, 1],
    alpha=0.1,
    color="grey",
)
sns.scatterplot(
    x=clust_df["PC_0"],
    y=clust_df["PC_1"],
    hue=clust_df["coverage"],
    style=clust_df["pass_min"],
    palette="crest",
)

In [None]:
# Larger fonts (scale 1.5) for the figures that follow.
sns.set(style="white", font_scale=1.5)

In [None]:
# Mean counts per window: individual cells vs. pre-metacell clusters.
plt.hist(counts_df.mean(axis=1), label="cells", alpha=0.6, color="grey")
plt.hist(clust_df.coverage, alpha=0.6, label="pre metacells", color="purple")

# Draw the threshold from the configured value rather than a literal, so the
# line stays correct if min_coverage is changed.
plt.axvline(min_coverage, label="minimal threshold", ls="--", color="grey")
plt.legend()
# plt.savefig("final_figures/metacells/pre_histo.png", dpi = 300)
# plt.savefig("final_figures/metacells/pre_histo.pdf", dpi = 300)
plt.xlabel("Mean counts per 1MB")
plt.show()

In [None]:
# Distribution of cluster sizes (number of cells per pre-metacell).
plt.hist(clust_df["n_cells"], label="pre_metacells", alpha=0.3)
plt.legend()

In [None]:
# Display the cluster summary table.
clust_df

In [None]:
# Pairwise distances between cluster centroids in PCA space.
pcs = [col for col in clust_df.columns if col.startswith("PC")]
centroids = clust_df[pcs]

dist = pairwise_distances(centroids)
# Strict upper triangle (k=1): each pair once, no self-distances.
upper_triangle = np.triu_indices(dist.shape[0], k=1)
plt.hist(dist[upper_triangle])

In [None]:
# Current adata.X as a labelled DataFrame (observations x variables); used
# below to compute the spread of values inside each metacell.
tmp = pd.DataFrame(
    data=adata.X,
    index=adata.obs.index,
    columns=adata.var.index,
)

In [None]:
# Candidate merge thresholds: 10 quantiles (from 1% to 100%) of the pairwise
# centroid distances.
quantile_levels = np.linspace(0.01, 1, num=10)
pairwise_upper = dist[np.triu_indices(dist.shape[0], k=1)]
ts = np.quantile(pairwise_upper, quantile_levels)

In [None]:
# Run merge_cells once per candidate distance threshold and record summary
# statistics of the resulting metacells for each run.
results = []
for i, d_tresh in enumerate(ts):
    new_df, cell_dict = merge_cells(
        d_tresh, clust_df, "leiden_small", adata, min_coverage, max_covergae
    )
    sns.scatterplot(
        x=adata.obsm["X_umap"][:, 0],
        y=adata.obsm["X_umap"][:, 1],
        alpha=0.1,
        color="grey",
    )
    sns.scatterplot(
        x=new_df.umap_x,
        y=new_df.umap_y,
        hue=new_df.coverage,
        palette="crest",
        sizes=(1, 10),
    )
    plt.show()
    num_meta = new_df.shape[0]
    size_meta = new_df.n_cells.median()
    med_cov = new_df.coverage.median()
    # Invert cell -> metacell into metacell -> [cells] in a single pass
    # instead of rescanning every cell once per metacell.
    reversed_dict = {}
    for cell, meta_id in cell_dict.items():
        reversed_dict.setdefault(meta_id, []).append(cell)
    # Cells assigned the label "-1" are the ones counted as excluded.
    excluded = len(reversed_dict.get("-1", []))
    # Mean per-variable standard deviation inside each group.
    # NOTE(review): the "-1" group is included in this average, as before —
    # confirm that is intended.
    stds = [tmp.loc[cells].std().mean() for cells in reversed_dict.values()]
    # A group with a single cell has an undefined (NaN) std; nanmean skips it
    # rather than turning the whole summary into NaN.
    results.append(
        (d_tresh, num_meta, size_meta, excluded, med_cov, np.nanmean(stds))
    )
    print(
        f"Done with run {i}: th. {d_tresh:.2f}, # metacells {num_meta}, median size {size_meta}, {excluded} cells excluded, median coverage {med_cov:.2f}"
    )

In [None]:
# One row per threshold run, columns in the order of the tuples in `results`.
result_columns = [
    "t",
    "# metacells",
    "median_meta_size",
    "cells_excluded",
    "median_coverage",
    "mean_std",
]
res = pd.DataFrame(data=results, columns=result_columns)

In [None]:
# Express excluded cells as a fraction of all cells. The values are rebuilt
# from the raw per-run tallies in `results` (position 3 of each tuple) so that
# re-running this cell does not divide an already-normalised column again.
res["cells_excluded"] = (
    np.array([run[3] for run in results]) / adata.obs.shape[0]
)

In [None]:
# Final metacells: merge using the largest candidate threshold.
final_threshold = ts[-1]
new_df, cell_dict = merge_cells(
    final_threshold, clust_df, "leiden_small", adata, min_coverage, max_covergae
)

In [None]:
# Number of rows in the merged table (one per metacell).
new_df.shape[0]

In [None]:
# Median number of cells per metacell.
new_df.n_cells.median()

In [None]:
# Number of cells whose assignment is "-1".
sum(1 for assigned in cell_dict.values() if assigned == "-1")

In [None]:
# Metacells that do not meet the minimum-coverage flag.
new_df.loc[new_df["pass_min"] == False]

In [None]:
# Same values as `coverage`, under the column name the plotting cells use for hue.
new_df["counts_per_1MB"] = new_df["coverage"]

In [None]:
# Final metacells (coloured by mean counts per window) drawn over the
# individual cells (grey) in UMAP space.
f, ax = plt.subplots(figsize=(8, 6))

p1 = sns.scatterplot(
    x=adata.obsm["X_umap"][:, 0],
    y=adata.obsm["X_umap"][:, 1],
    alpha=0.1,
    color="grey",
    ax=ax,
)
p2 = sns.scatterplot(
    x=new_df.umap_x,
    y=new_df.umap_y,
    hue=new_df.counts_per_1MB,
    palette=custom_colormap,
    ax=ax,
)

# The colorbar must span the range of the values actually plotted (new_df);
# using the pre-merge clust_df range would mislabel the point colours.
norm = plt.Normalize(new_df.counts_per_1MB.min(), new_df.counts_per_1MB.max())
sm = plt.cm.ScalarMappable(cmap=custom_colormap, norm=norm)
sm.set_array([])

p2.get_legend().remove()
# `sm` is not attached to any Axes, so the Axes to take space from must be
# given explicitly; recent Matplotlib raises instead of falling back to gca().
p2.figure.colorbar(sm, ax=ax)
ax.set_frame_on(False)
ax.axes.get_yaxis().set_visible(False)
ax.axes.get_xaxis().set_visible(False)


# plt.savefig("final_figures/metacells/post_metacells.png", dpi = 300)
# plt.savefig("final_figures/metacells/post_metacells.pdf", dpi = 300)
plt.show()

In [None]:
# Final metacells over the individual cells (grey) in PCA space.
sns.scatterplot(
    x=adata.obsm["X_pca"][:, 0], y=adata.obsm["X_pca"][:, 1], alpha=0.1, color="grey"
)
# The background shows the first two PCA columns, which the cluster table
# names PC_0 / PC_1, so the metacells must use those same components.
# Hue comes from new_df itself; clust_df describes the pre-merge clusters and
# does not line up with new_df's rows.
# NOTE(review): assumes merge_cells keeps the PC_0 column next to PC_1/PC_2 — confirm.
sns.scatterplot(
    x=new_df.PC_0,
    y=new_df.PC_1,
    hue=new_df.coverage,
    style=new_df["pass_min"],
    palette="crest",
)

In [None]:
# Mean counts per window: individual cells vs. final metacells.
plt.hist(counts_df.mean(axis=1), label="cells", alpha=0.6, color="grey")
plt.hist(new_df.coverage, alpha=0.6, label="metacells", color="purple")

# Draw the threshold from the configured value rather than a literal, so the
# line stays correct if min_coverage is changed.
plt.axvline(min_coverage, label="minimal threshold", ls="--", color="grey")
plt.legend()

plt.xlabel("Mean counts per 1MB")
plt.show()

In [None]:
# Metacell table ordered by counts per window, lowest first (display only;
# new_df itself is not modified).
new_df.sort_values(by="counts_per_1MB")

In [None]:
# Output directory prefix for the per-metacell file lists written below; the
# trailing slash matters because file names are appended to it directly.
mc_path = "./metacells/"

In [None]:
# Write one text file per metacell, listing the read-count file of every cell
# assigned to it (one path per line).

# open(..., "w") does not create missing directories, so make sure the output
# directory exists before writing.
os.makedirs(mc_path, exist_ok=True)

# Group cells by assigned metacell once, instead of rescanning the whole
# cell -> metacell mapping for every metacell.
cells_by_metacell = {}
for cell, assigned in cell_dict.items():
    cells_by_metacell.setdefault(assigned, []).append(cell)

for mc in tqdm(new_df.cl):
    # A metacell without cells still gets an (empty) file, as before.
    cells = cells_by_metacell.get(mc, [])

    with open(f"{mc_path}{mc}.txt", "w") as the_file:
        for cell in cells:
            full_path = f"readCount_filtered_bam/readcounts.100kb.cell_bc_{cell}.seg"
            the_file.write(f"{full_path}\n")