# 1. CNV analysis of Visium data

Before running the code below, you need to first run CopyKAT.
The code to run CopyKAT is `../../scripts/run_copykat.R`.

In [None]:
import sys
import warnings

# The helper directory must be on sys.path *before* importing from it;
# previously the path was extended only after `plotting_settings` had
# already been imported, so the insert could not help that import.
sys.path.insert(1, "../../helper_functions")

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from natsort import natsort_keygen, natsorted
from plotting_settings import PLOTTING_PARAMS
from scipy.stats import mannwhitneyu
from sklearn.metrics import calinski_harabasz_score
from tqdm import tqdm

# Silence all warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

In [None]:
# Sample identifiers iterated over by every per-sample loop in this notebook.
samples = "B22 B24 B60 B154 B156 B175 B178 B4 B42 B123".split()

In [None]:
# Colour per subclone label ("0".."6") plus the "diploid" label; passed as
# `palette` to the scanpy plots further down.
colors_clones = dict(
    zip(
        [str(clone_id) for clone_id in range(7)] + ["diploid"],
        [
            "#0173b2",
            "#de8f05",
            "#029e73",
            "#d55e00",
            "#949494",
            "#ca9161",
            "#fbafe4",
            "#97ffff",
        ],
    )
)

## Map inferred CNAs from CopyKAT to ONT fixed genomic bins

In [None]:
def assign_bins(chromosome, start, end, gene, binsize):
    """Return the fixed-width genomic bins overlapped by a gene.

    Bins are aligned to multiples of ``binsize`` and labelled
    ``"{chromosome}:{bin_start}-{bin_end}"``.

    The gene is treated as the half-open interval ``[start, end)``. Every bin
    the interval touches is reported, including the last one when the gene
    crosses a bin boundary less than ``binsize`` after ``start`` (the previous
    implementation stepped from ``start`` and silently skipped that bin). A
    zero-length or inverted interval is assigned the single bin containing
    ``start``.

    Parameters
    ----------
    chromosome : str
        Chromosome label used as the prefix of the bin name.
    start, end : int
        Gene coordinates.
    gene : str
        Gene identifier paired with every bin.
    binsize : int
        Bin width, in the same unit as ``start``/``end``; must be positive.

    Returns
    -------
    list[tuple[str, str]]
        ``(bin_label, gene)`` tuples in ascending bin order.

    Raises
    ------
    ValueError
        If ``binsize`` is not positive.
    """
    if binsize <= 0:
        raise ValueError("binsize must be a positive integer")

    first_bin = (start // binsize) * binsize
    # `end - 1` is the last position inside [start, end); fall back to `start`
    # so that degenerate intervals still map to one bin.
    last_bin = (max(end - 1, start) // binsize) * binsize
    return [
        (f"{chromosome}:{bin_start}-{bin_start + binsize}", gene)
        for bin_start in range(first_bin, last_bin + binsize, binsize)
    ]

In [None]:
def prepare_annotations(gene_pos, binsize):
    """Build a (bin, gene_name) table from a gene coordinate table.

    ``gene_pos`` must provide the columns ``gene``, ``chrom``, ``start`` and
    ``end``. ``binsize`` is multiplied by 1000 before being handed to
    ``assign_bins`` (callers in this notebook pass 25).

    Returns a DataFrame with the columns ``bin`` and ``gene_name``, one row
    per (bin, gene) pair produced by ``assign_bins``.
    """
    bin_width = binsize * 1000
    pairs = []
    # tqdm only adds a progress counter while the rows are traversed.
    for _, record in tqdm(gene_pos.iterrows()):
        pairs += assign_bins(
            record["chrom"],
            record["start"],
            record["end"],
            record["gene"],
            binsize=bin_width,
        )
    return pd.DataFrame(pairs, columns=["bin", "gene_name"])

In [None]:
def gene2bin(scrna_mtx, gene_pos):
    """Average a spot-by-gene matrix into a spot-by-bin matrix.

    ``scrna_mtx`` has genes as columns; ``gene_pos`` is a table with the
    columns ``gene_name`` and ``bin``. Only genes present in both inputs are
    used. The returned frame keeps the rows of ``scrna_mtx`` and has one
    column per bin (mean over the genes assigned to that bin), with the bin
    columns in natural sort order.
    """
    # Keep annotated genes that exist in the matrix, ordered by bin.
    known = gene_pos.gene_name.isin(scrna_mtx.columns)
    annotation = gene_pos[known].sort_values(
        by="bin", key=natsort_keygen(), ascending=True
    )

    # genes x spots, in exactly the same row order as `annotation`
    per_gene = scrna_mtx[annotation.gene_name].T
    # The index of the right-hand side is identical to `per_gene`'s index, so
    # the bin labels line up row by row.
    per_gene["bin"] = annotation.set_index("gene_name")["bin"]

    per_bin = per_gene.groupby("bin").mean().T
    return per_bin[natsorted(per_bin.columns)]

In [None]:
# Gene-by-cell CopyKAT output of sample B123 (tab separated); the gene
# coordinate columns of this table are reused below.
scrna_copykat = pd.read_csv(
    "copykat/B123_filtered_estimate/B123_copykat_CNA_raw_results_gene_by_cell.txt",
    delimiter="\t",
)

In [None]:
# Gene coordinate table taken from the CopyKAT output loaded above.
# `rename` returns a new frame, so the assignments below never write into a
# selection of `scrna_copykat` (the previous in-place rename and column
# assignments operated on a derived slice).
gene_pos = scrna_copykat[
    ["chromosome_name", "start_position", "end_position", "hgnc_symbol"]
].rename(
    columns={
        "chromosome_name": "chrom",
        "start_position": "start",
        "end_position": "end",
        "hgnc_symbol": "gene",
    }
)
gene_pos["length"] = gene_pos["end"] - gene_pos["start"]
# Prefix the chromosome with "chr" and relabel chromosome 23 as X.
gene_pos["chrom"] = ("chr" + gene_pos["chrom"].astype(str)).replace("chr23", "chrX")

In [None]:
# Suffix of the CopyKAT run directory; samples not listed use "_filtered".
version_by_sample = {"B154": "_filtered_estimate", "B178": "_filtered_estimate2"}

# The gene -> bin annotation depends only on `gene_pos`, not on the sample,
# so it is built once instead of once per sample.
gene_pos_bins = prepare_annotations(gene_pos, binsize=25)

for sample in samples:
    print(f"Working on sample {sample}")
    version = version_by_sample.get(sample, "_filtered")

    ont = pd.read_csv(
        f"chromothripsis/bladder_cancer/ont/cnvs/coverage/{sample}tumor.cov.gz",
        sep="\t",
    )

    scrna_copykat = pd.read_csv(
        f"copykat/{sample}{version}/{sample}_copykat_CNA_raw_results_gene_by_cell.txt",
        sep="\t",
    )
    # Index by gene symbol, drop the first six remaining columns and
    # transpose so that genes become columns.
    scrna_copykat = scrna_copykat.set_index("hgnc_symbol").iloc[:, 6:].T

    g2b = gene2bin(scrna_copykat, gene_pos_bins)

    # Label ONT rows the same way the CopyKAT bins are labelled.
    ont["segment"] = (
        ont["chr"].astype(str)
        + ":"
        + ont["start"].astype(str)
        + "-"
        + ont["end"].astype(str)
    )
    shared_segments = g2b.columns.intersection(ont["segment"])
    ont = ont.set_index("segment").loc[shared_segments, :]
    g2b = g2b[shared_segments]

    # Row-wise correlation of each CopyKAT profile with the last ONT column.
    cors = g2b.corrwith(ont[ont.columns[-1]], axis=1)

    copykatpred = pd.read_csv(
        f"copykat/{sample}{version}/{sample}_copykat_prediction.txt",
        sep="\t",
        index_col=0,
    )
    estimate = pd.read_csv(
        f"chromothripsis/j462r/spatial_transcriptomics/ESTIMATE_scores/{sample}_ESTIMATE.csv",
        index_col=0,
    )
    patho = pd.read_csv(
        f"chromothripsis/j462r/spatial_transcriptomics/CNVs/healthy_annotations/{sample}_healthy.csv",
        index_col=0,
    )

    metadata = pd.concat(
        [pd.DataFrame(cors, columns=["correlation"]), copykatpred, estimate, patho],
        axis=1,
    ).rename(columns={"Healthy": "Pathologist"})
    # Rows without a pathologist annotation are labelled "Tumor".
    metadata["Pathologist"] = metadata["Pathologist"].fillna("Tumor")

    g2b.to_csv(f"copykat/{sample}{version}/{sample}_copykat_bins.csv")
    ont.to_csv(f"copykat/{sample}{version}/{sample}_ont_bins.csv")
    metadata.to_csv(f"copykat/{sample}{version}/{sample}_metadata.csv")

### CNV score

In [None]:
# Suffix of the CopyKAT run directory; samples not listed use "_filtered".
version_by_sample = {"B154": "_filtered_estimate", "B178": "_filtered_estimate2"}

# Collect the per-sample frames and concatenate once at the end instead of
# growing a DataFrame inside the loop.
cnv_frames = []
meta_frames = []
for sample in samples:
    version = version_by_sample.get(sample, "_filtered")

    print(f"Working on sample {sample}")
    metadata = pd.read_csv(
        f"copykat/{sample}{version}/{sample}_metadata.csv", index_col=0
    )
    df = pd.read_csv(
        f"copykat/{sample}{version}/{sample}_copykat_bins.csv", index_col=0
    )
    # CNV score = mean absolute bin value per row.
    cnv_score = pd.DataFrame(np.abs(df).mean(axis=1), columns=["cnv_score"])
    # Suffix the row labels with the sample so they stay unique after merging.
    cnv_score.index = cnv_score.index + f"_{sample}"
    metadata.index = metadata.index + f"_{sample}"
    # Only rows predicted aneuploid by CopyKAT contribute a CNV score.
    cnv_score = cnv_score.loc[metadata["copykat.pred"] == "aneuploid", :]
    cnv_frames.append(cnv_score)
    meta_frames.append(metadata)

master_cnv = pd.concat(cnv_frames) if cnv_frames else pd.DataFrame()
master_meta = pd.concat(meta_frames) if meta_frames else pd.DataFrame()

In [None]:
# Sample groupings used to annotate the tables below.
low_grade = ["B22", "B24", "B60", "B154", "B156"]
high_grade = ["B4", "B42", "B123", "B175", "B178"]
ecdna = ["B4", "B42", "B123"]
mibc = ["B42", "B123"]
# The complementary groups, spelled out relative to the lists above.
nonecdna = low_grade + ["B175", "B178"]
nmibc = nonecdna + ["B4"]

In [None]:
# The sample id is the part of the row label after the last underscore.
master_cnv["sample"] = master_cnv.index.str.rsplit("_", n=1).str[-1]

In [None]:
# Annotate every row with grade, invasiveness type and ecDNA status of its
# sample, based on membership in the sample lists defined above.
master_cnv["grade"] = (
    master_cnv["sample"]
    .isin(high_grade)
    .map({True: "High grade", False: "Low grade"})
)
master_cnv["type"] = (
    master_cnv["sample"].isin(mibc).map({True: "MIBC", False: "NMIBC"})
)
master_cnv["ecdna"] = (
    master_cnv["sample"].isin(ecdna).map({True: "ecDNA", False: "non-ecDNA"})
)

In [None]:
# The sample id is the part of the row label after the last underscore.
master_meta["sample"] = master_meta.index.str.rsplit("_", n=1).str[-1]

In [None]:
# Same sample-level annotation as for the CNV score table.
master_meta["grade"] = (
    master_meta["sample"]
    .isin(high_grade)
    .map({True: "High grade", False: "Low grade"})
)
master_meta["type"] = (
    master_meta["sample"].isin(mibc).map({True: "MIBC", False: "NMIBC"})
)
master_meta["ecdna"] = (
    master_meta["sample"].isin(ecdna).map({True: "ecDNA", False: "non-ecDNA"})
)

In [None]:
# Violin plot of the CNV score split by ecDNA status, saved as SVG.
sns.set_theme(style="white", rc=PLOTTING_PARAMS)

fig, axes = plt.subplots(figsize=(6, 4))
sns.violinplot(
    data=master_cnv,
    x="ecdna",
    y="cnv_score",
    palette=["#949494", "#78D3D3"],
    inner="quart",
    ax=axes,
)
sns.despine(top=True, right=True, left=False, bottom=False)

axes.set_xlabel("ecDNA status", size=14)
axes.set_ylabel("CNV score", size=14)
plt.tight_layout()
plt.savefig("suppfig_6B_cnv_score.svg", dpi=300)
plt.close()

# Restore the default seaborn/matplotlib styling for later cells.
sns.reset_defaults()

In [None]:
# Mann-Whitney U test of the CNV score: ecDNA vs. non-ecDNA rows.
print(
    mannwhitneyu(
        master_cnv.loc[master_cnv["ecdna"] == "ecDNA", "cnv_score"],
        master_cnv.loc[master_cnv["ecdna"] == "non-ecDNA", "cnv_score"],
    )
)

## Defining the number of subclones

In [None]:
def optimal_leiden(
    adata, n_pcs=10, resolutions=(0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4)
):
    """Sweep Leiden resolutions and return the key of the preferred clustering.

    For every resolution, ``sc.tl.leiden`` is run and its labels are stored in
    ``adata.obs["leiden_<resolution>"]``. Each clustering is scored with the
    Calinski-Harabasz index on the first ``n_pcs`` columns of
    ``adata.obsm["X_pca"]``.

    Parameters
    ----------
    adata
        AnnData object with a neighbour graph and ``obsm["X_pca"]``; it is
        modified in place (one ``obs`` column per resolution).
    n_pcs : int
        Number of principal components used for scoring.
    resolutions : iterable of float
        Resolutions to try, in order.

    Returns
    -------
    str
        If more than one resolution produced a single cluster, the key of the
        last such resolution. Otherwise the key with the highest score (ties
        keep the earlier resolution), or ``""`` if no clustering scored
        above 0.
    """
    single_cluster_keys = []
    best_score = 0
    best_key = ""
    for resolution in resolutions:
        key = f"leiden_{resolution}"
        sc.tl.leiden(adata, resolution=resolution, key_added=key)
        try:
            # Raises ValueError when the clustering has only one label.
            score = calinski_harabasz_score(
                adata.obsm["X_pca"][:, :n_pcs],
                adata.obs[key].astype(int),
            )
        except ValueError:
            print(f"Only 1 cluster for {key}")
            single_cluster_keys.append(key)
            continue

        print(
            f"Resolution: {resolution}, N clusters: {adata.obs[key].nunique()},  Calinski Score: {score}"
        )
        if score > best_score:
            best_score = score
            best_key = key

    if len(single_cluster_keys) > 1:
        return single_cluster_keys[-1]

    return best_key


def compute_pseudobulk_correlations(acnvs, groupby="subclones", threshold=0.95):
    """Find clusters whose pseudobulk profiles are highly correlated.

    A pseudobulk profile is the column-wise mean of ``acnvs.X`` over the rows
    of one cluster (``acnvs.obs[groupby]``). All pairs of distinct clusters
    with a Pearson correlation above ``threshold`` are linked, and linked
    clusters are merged transitively.

    Compared with the previous implementation:

    * only the upper triangle of the correlation matrix is inspected, so the
      diagonal is excluded structurally instead of via ``corr < 1`` (which
      also discarded perfectly correlated distinct clusters);
    * the result no longer depends on set iteration order when a cluster has
      several partners;
    * chains such as 2 -> 1 -> 0 are resolved, so every merged cluster maps
      directly to its final label.

    Parameters
    ----------
    acnvs
        Object exposing ``X`` (rows x features) and ``obs[groupby]``.
    groupby : str
        Column of ``acnvs.obs`` holding the cluster labels.
    threshold : float
        Correlation above which two clusters are merged.

    Returns
    -------
    dict
        ``{label: target_label}`` for every cluster that should be relabelled;
        the target is the smallest label (by ``sorted``) of its merged group.
        Empty when nothing exceeds the threshold.
    """
    labels = acnvs.obs[groupby].values
    cluster_score = {
        cluster: np.asarray(acnvs.X[labels == cluster, :].mean(axis=0)).ravel()
        for cluster in acnvs.obs[groupby].unique()
    }
    pseudobulks = pd.DataFrame.from_dict(cluster_score, orient="index")
    correlations = pseudobulks.T.corr()
    print(correlations)

    # Strict upper triangle: every unordered pair exactly once, no diagonal.
    corr_matrix = correlations.to_numpy()
    rows, cols = np.triu_indices_from(corr_matrix, k=1)
    above = corr_matrix[rows, cols] > threshold  # NaN compares False
    high_corr_pairs = sorted(
        tuple(sorted((correlations.columns[i], correlations.columns[j])))
        for i, j in zip(rows[above], cols[above])
    )
    print(high_corr_pairs)

    # Union-find over the linked pairs; the smaller label always stays root.
    parent = {}

    def find_root(label):
        while label in parent:
            label = parent[label]
        return label

    for first, second in high_corr_pairs:
        root_a, root_b = find_root(first), find_root(second)
        if root_a != root_b:
            kept, merged = sorted((root_a, root_b))
            parent[merged] = kept

    return {label: find_root(label) for label in parent}

In [None]:
# Number of principal components used for the neighbour graph and for
# scoring the Leiden sweep.
# NOTE(review): `npcs` was used below but never assigned anywhere in this
# notebook (NameError on a fresh kernel). 10 mirrors the default of
# `optimal_leiden(n_pcs=10)` — confirm this is the value used originally.
npcs = 10

# Suffix of the CopyKAT run directory; samples not listed use "_filtered/".
version_by_sample = {"B154": "_filtered_estimate/", "B178": "_filtered_estimate2/"}

for sample in samples:

    print(f"Working on sample {sample}")
    version = version_by_sample.get(sample, "_filtered/")

    cnvs = pd.read_csv(
        f"chromothripsis/j462r/spatial_transcriptomics/CNVs/copykat/{sample}{version}{sample}_copykat_bins.csv",
        index_col=0,
    )
    acnvs = ad.AnnData(cnvs)
    adata = sc.read_h5ad(
        f"chromothripsis/j462r/spatial_transcriptomics/scripts/h5ad_objects/{sample}.h5ad"
    )

    # Keep only the part of the observation name before the first "_".
    adata.obs.index = adata.obs.index.str.split("_").str[0]

    # Both objects must list the same observations in the same order,
    # because the spatial coordinates are copied over positionally.
    assert all(acnvs.obs_names == adata.obs_names)

    acnvs.uns["spatial"] = adata.uns["spatial"]
    acnvs.obsm["spatial"] = adata.obsm["spatial"]

    copykat = pd.read_csv(
        f"chromothripsis/j462r/spatial_transcriptomics/CNVs/copykat/{sample}{version}{sample}_copykat_prediction.txt",
        sep="\t",
        index_col=0,
    )
    acnvs.obs = acnvs.obs.join(copykat)

    adata_subset = acnvs[acnvs.obs["copykat.pred"] == "aneuploid", :].copy()
    adata_diploid = acnvs[acnvs.obs["copykat.pred"] == "diploid", :].copy()

    # Clustering is performed on the aneuploid observations only.
    sc.pp.pca(adata_subset)
    print(
        f"Variance explained in the first {npcs}: {adata_subset.uns['pca']['variance_ratio'][:npcs].sum()}"
    )
    sc.pp.neighbors(adata_subset, n_pcs=npcs)
    sc.tl.umap(adata_subset)

    # sweep of different Leiden resolutions
    leiden = optimal_leiden(adata_subset, n_pcs=npcs)

    # The second best solution for the sample B4
    if sample == "B4":
        leiden = "leiden_0.3"

    # Diploid observations have no Leiden label after concatenation; they
    # are labelled "diploid".
    adata_together = ad.concat([adata_subset, adata_diploid], join="outer")
    adata_together.obs[leiden] = (
        adata_together.obs[leiden].astype(str).replace("nan", "diploid")
    )
    adata_together.uns["spatial"] = acnvs.uns["spatial"]

    # Explicit copy: the labels are written into `obs` right below.
    acnvs = acnvs[adata_together.obs_names, :].copy()
    acnvs.obs[leiden] = adata_together.obs[leiden]

    # Start from the identity mapping label -> label.
    current_groups = {group: group for group in adata_together.obs[leiden].unique()}

    print(current_groups)
    print(acnvs.obs)

    to_merge = compute_pseudobulk_correlations(acnvs, groupby=leiden)
    print(to_merge)

    if to_merge:
        print("Merging clones")
        for group in adata_together.obs[leiden].unique():
            if group in to_merge:
                print(f"Changing from {current_groups[group]} to {to_merge[group]}")
                current_groups[group] = to_merge[group]
        adata_together.obs[f"{leiden}_merged"] = adata_together.obs[leiden].map(
            current_groups
        )
        adata_subset.obs[f"{leiden}_merged"] = adata_subset.obs[leiden].map(
            current_groups
        )
    else:
        adata_together.obs[f"{leiden}_merged"] = adata_together.obs[leiden]
        adata_subset.obs[f"{leiden}_merged"] = adata_subset.obs[leiden]

    # Persist the final labels under the column name "subclones".
    adata_together.obs[[f"{leiden}_merged"]].rename(
        columns={f"{leiden}_merged": "subclones"}
    ).to_csv(f"{sample}{version}{sample}_leiden_subclones_final.csv")

    # Left: UMAP of the clustered observations; right: spatial view of all.
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    sc.pl.umap(
        adata_subset,
        color=f"{leiden}_merged",
        title="Subclones",
        show=False,
        ax=ax[0],
        palette=colors_clones,
    )
    sc.pl.spatial(
        adata_together,
        color=f"{leiden}_merged",
        title="Subclones",
        size=1.5,
        show=False,
        ax=ax[1],
        palette=colors_clones,
    )

    fig.tight_layout()
    plt.savefig(f"{sample}{version}{sample}_leiden_subclones_final.svg", dpi=300)
    plt.close()

### Number of clones per sample

In [None]:
# Suffix of the CopyKAT run directory; samples not listed use "_filtered/".
version_by_sample = {"B154": "_filtered_estimate/", "B178": "_filtered_estimate2/"}

clones_per_sample = {}
# Per-sample label proportions; concatenated once after the loop instead of
# growing a DataFrame on every iteration.
proportion_frames = []
for sample in samples:

    print(f"Working on sample {sample}")
    version = version_by_sample.get(sample, "_filtered/")

    subclones = pd.read_csv(
        f"chromothripsis/j462r/spatial_transcriptomics/CNVs/copykat/{sample}{version}{sample}_leiden_subclones_final.csv",
        index_col=0,
    )
    # Normalize the number of subclones to the total number of aneuploid spots
    aneuploid_labels = subclones.loc[subclones["subclones"] != "diploid", "subclones"]
    clones_per_sample[sample] = aneuploid_labels.nunique() / aneuploid_labels.shape[0]

    # Fraction of rows per label (the "diploid" label included).
    sub = subclones["subclones"].value_counts(normalize=True).reset_index()
    sub["sample"] = sample
    proportion_frames.append(sub)

master = pd.concat(proportion_frames) if proportion_frames else pd.DataFrame()

In [None]:
# Load the merged AnnData object, attach the combined metadata table and
# keep only the observations predicted aneuploid by CopyKAT.
adata = sc.read_h5ad(
    "chromothripsis/j462r/spatial_transcriptomics/scripts/h5ad_objects/merged_samples.h5ad"
)
metadata = pd.read_csv(
    "chromothripsis/j462r/spatial_transcriptomics/CNVs/copykat/metadata_all.csv",
    index_col=0,
)

adata.obs = adata.obs.join(metadata)
adata = adata[adata.obs["copykat.pred"].eq("aneuploid"), :].copy()

In [None]:
# Per-sample row counts: the group key plus the first counted column, with
# the key column renamed from "sample" to "index".
no_spots = (
    adata.obs.groupby("sample")
    .count()
    .reset_index()
    .iloc[:, :2]
    .rename(columns={"sample": "index"})
)

In [None]:
# One row per sample: "index" holds the sample id, "subclones" the
# normalized subclone count computed above.
clones = pd.Series(clones_per_sample, name="subclones").to_frame().reset_index()

In [None]:
# Annotate each sample with its grade and ecDNA status, then index by sample.
low_grade = ["B22", "B24", "B60", "B154", "B156"]
clones["grade"] = np.where(
    clones["index"].isin(low_grade), "Low grade", "High grade"
)
clones["ecDNA_status"] = np.where(
    clones["index"].isin(["B123", "B4", "B42"]), "ecDNA", "non-ecDNA"
)
clones = clones.set_index("index")

In [None]:
# Bar plot (mean with SD error bars) of the normalized subclone count per
# ecDNA status, with the individual samples overlaid as black points.
sns.set_theme(style="white", rc=PLOTTING_PARAMS)
plt.figure(figsize=(5, 4))
sns.barplot(
    data=clones,
    x="ecDNA_status",
    y="subclones",
    hue="ecDNA_status",
    palette=["#949494", "#78D3D3"],
    width=0.5,
    errorbar="sd",
)
sns.stripplot(data=clones, x="ecDNA_status", y="subclones", color="black")
sns.despine(top=True, right=True, left=False, bottom=False)

plt.xlabel("ecDNA status", size=14)
plt.ylabel("Normalized number of subclones", size=14)
plt.xticks(rotation=0, size=12)
plt.yticks(size=12)

plt.tight_layout()
plt.savefig("suppfig_6C_subclones.svg", dpi=300)
plt.show()

# Restore the default seaborn/matplotlib styling for later cells.
sns.reset_defaults()

In [None]:
# Mann-Whitney U test of the normalized subclone count: ecDNA vs. non-ecDNA.
print(
    mannwhitneyu(
        clones.loc[clones["ecDNA_status"] == "ecDNA", "subclones"],
        clones.loc[clones["ecDNA_status"] == "non-ecDNA", "subclones"],
    )
)

### Generate pseudobulks per subclone

In [None]:
def _get_group_mean(adata, group):
    """Return the column-wise mean of ``adata.X`` over one subclone.

    Rows are selected where ``adata.obs["subclones"]`` equals ``group``. The
    result always has two dimensions: a 1-D mean (dense input) is promoted to
    a single-row 2-D array, a 2-D mean is returned unchanged.
    """
    in_group = adata.obs["subclones"].values == group
    averaged = np.mean(adata.X[in_group, :], axis=0)
    if averaged.ndim != 1:
        return averaged
    # Dense arrays lose the row axis when averaged; add it back.
    return averaged[np.newaxis, :]

In [None]:
# Build one pseudobulk profile (mean bin value) per subclone and sample.
version_suffix = {"B154": "_filtered_estimate/", "B178": "_filtered_estimate2/"}

dfs = []

for sample in samples:

    print(f"Working on sample {sample}")
    version = version_suffix.get(sample, "_filtered/")

    cnvs = pd.read_csv(f"{sample}{version}{sample}_copykat_bins.csv", index_col=0)
    acnvs = ad.AnnData(cnvs)

    clones = pd.read_csv(
        f"{sample}{version}{sample}_leiden_subclones_final.csv", index_col=0
    )
    acnvs.obs = acnvs.obs.join(clones)

    # One row per distinct subclone label, in order of first appearance.
    group_labels = acnvs.obs["subclones"].unique()
    X = np.vstack([_get_group_mean(acnvs, group) for group in group_labels])
    df = pd.DataFrame(X, index=group_labels, columns=acnvs.var.index)

    df["sample"] = sample
    # Move the subclone labels from the index into an "index" column.
    df = df.reset_index()

    dfs.append(df)

In [None]:
# Stack all samples, drop every column containing a missing value, and
# collapse the string labels other than "diploid" into "aneuploid"
# (non-string labels are left untouched).
merged = (
    pd.concat(dfs)
    .dropna(axis=1)
    .reset_index()
    .rename(columns={"index": "subclone"})
)
merged["subclone"] = merged["subclone"].map(
    lambda label: "aneuploid"
    if isinstance(label, str) and label != "diploid"
    else label
)

In [None]:
# Write the merged pseudobulk table to the current directory.
merged.to_csv("./subclones_pseudobulks.csv")

## Prepare plots for figure 2

In [None]:
# Load the AnnData object of sample B123 for the figure 2 panels.
adata = sc.read_h5ad(f"../../scripts/h5ad_objects/B123.h5ad")

In [None]:
# Load the precomputed CNV score table (first CSV column becomes the index).
cnv_scores = pd.read_csv("../../CNVs/copykat/CNV_scores.csv", index_col=0)
# Bare expression: displays the table when run as a notebook cell.
cnv_scores

In [None]:
# Attach the CNV scores to the observations (join on the index).
adata.obs = adata.obs.join(cnv_scores)

In [None]:
# Load the final subclone labels written for B123 by the subclone loop.
clones = pd.read_csv("B123_filtered/B123_leiden_subclones_final.csv", index_col=0)
# Bare expression: displays the table when run as a notebook cell.
clones

In [None]:
# Append the sample suffix to the row labels before joining on the index.
# NOTE(review): presumably adata.obs is indexed by "<barcode>_B123" — confirm.
clones.index = clones.index + "_B123"
adata.obs = adata.obs.join(clones)

In [None]:
# Three spatial panels for B123: no colouring (titled "H&E"), the CNV score
# and the subclone labels; the figure is saved as SVG and then shown.
sc.pl.spatial(
    adata,
    color=[None, "cnv_score", "subclones"],
    title=["H&E", "CNV score", "Subclones"],
    palette=colors_clones,
    cmap="viridis",
    size=1.5,
    wspace=0.05,
    show=False,  # keep the figure open so it can be saved below
)
plt.savefig("fig2F_B123_space.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

### To generate the heatmaps shown in Figure 2E and Supplementary Figure 6A, run `../scripts/Visium_plot_heatmap_pseudobulk_subclones.R`