# Experimental
Notebook to generate * associated supplementary results

## Reqs

In order to run this notebook you have to download / get following file in addition to the files already present in the repository :
- `data/cmc/cmc_mutation_context.csv` available from Zenodo
- `data/clinvar/clinvar_mutation_context.csv` available from Zenodo
- `hg38/kmer_counts_3N.csv` available from github and Zenodo
- `hg38/kmer_counts_5N.csv` available from github and Zenodo

In [1]:
import polars as pl
from pathlib import Path
from utils import load_data, utils
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

GENERATE_RESULTS = (
    True  # if set to true it will generate results in outputs/ directory.
)
PROJECT_DIR = Path().absolute().parent

In [2]:
def compute_correlation(
    aggregated, category_col, categories, stat_method=stats.pearsonr
):
    """Correlation between F and vip for each category."""
    return utils.get_table_of_corr(
        aggregated.with_columns(pl.lit("_").alias("db")),
        categories=("db", ["_"]),
        sub_categories=(category_col, categories),
        corr_method=stat_method,
    )


def compute_corrected_correlation(
    aggregated, category_col, categories, x, y, z, stat_method=stats.pearsonr
):
    """Partial correlation between F and vip, controlling for GC content of context."""
    return pl.DataFrame(
        [
            {
                category_col: cat,
                **dict(
                    zip(
                        ["partial_r", "p_value"],
                        utils.partial_correlation(
                            aggregated.filter(pl.col(category_col) == cat),
                            x,
                            y,
                            z,
                            stat_method=stat_method,
                        ),
                    )
                ),
            }
            for cat in categories
        ]
    )


def plot_scatter(aggregated, hue):
    """Scatter plot of vip vs F coloured by category."""
    fig, ax = plt.subplots()
    sns.scatterplot(data=aggregated.to_pandas(), x="vip", y="F", hue=hue, ax=ax)
    plt.show()
    return fig


def _partial_r(x, y, z, stat_method=stats.pearsonr):
    """Partial correlation r(x, y | z) from arrays."""
    r_xy, _ = stat_method(x, y)
    r_xz, _ = stat_method(x, z)
    r_yz, _ = stat_method(y, z)
    denom = np.sqrt((1 - r_xz**2) * (1 - r_yz**2))
    return (r_xy - r_xz * r_yz) / denom if denom != 0 else np.nan


def permutation_test(
    aggregated_data,
    x: str,
    y: str,
    categories_col: str,
    n_permutations: int = 10000,
    z: str = None,
    stat_method=stats.pearsonr,
):
    """
    Permutation test for the difference in correlation between two groups.

    If z is provided, uses partial correlation r(x, y | z) instead of r(x, y).
    stat_method is applied for all correlation computations.
    """
    categories = aggregated_data[categories_col].unique().to_list()
    if len(categories) != 2:
        raise Exception(f"Number of categories is not 2! {categories}")
    for col in [x, y] + ([z] if z else []):
        if col not in aggregated_data.columns:
            raise Exception(
                f"{col} not in aggregated_data columns: {aggregated_data.columns}"
            )

    cat1 = aggregated_data.filter(pl.col(categories_col) == categories[0])
    cat2 = aggregated_data.filter(pl.col(categories_col) == categories[1])

    x1, x2 = cat1[x].to_numpy(), cat2[x].to_numpy()
    y1, y2 = cat1[y].to_numpy(), cat2[y].to_numpy()

    if z is not None:
        z1, z2 = cat1[z].to_numpy(), cat2[z].to_numpy()
        r1 = _partial_r(x1, y1, z1, stat_method=stat_method)
        r2 = _partial_r(x2, y2, z2, stat_method=stat_method)
    else:
        r1, _ = stat_method(x1, y1)
        r2, _ = stat_method(x2, y2)

    observed_diff = r1 - r2
    print(f"r {categories[0]}: {r1:.4f}")
    print(f"r {categories[1]}: {r2:.4f}")
    print(f"Observed diff: {observed_diff:.4f}")

    combined_x = np.concatenate([x1, x2])
    combined_y = np.concatenate([y1, y2])
    combined_z = np.concatenate([z1, z2]) if z is not None else None
    n1 = len(cat1)

    rng = np.random.default_rng(42)
    perm_diffs = np.zeros(n_permutations)

    for i in range(n_permutations):
        idx = rng.permutation(len(combined_x))
        xp = combined_x[idx]
        if z is not None:
            zp = combined_z[idx]
            r1p = _partial_r(xp[:n1], combined_y[:n1], zp[:n1], stat_method=stat_method)
            r2p = _partial_r(xp[n1:], combined_y[n1:], zp[n1:], stat_method=stat_method)
        else:
            r1p, _ = stat_method(xp[:n1], combined_y[:n1])
            r2p, _ = stat_method(xp[n1:], combined_y[n1:])
        perm_diffs[i] = r1p - r2p

    p_value = np.mean(np.abs(perm_diffs) >= np.abs(observed_diff))
    print(
        f"p-value ({'partial, z=' + z if z else stat_method.__name__}): {p_value:.4f}"
    )


def run_analysis(
    df, kmer, context_col, category_col, categories, stat_method=stats.pearsonr, prefix:str = None
):
    """Aggregate counts, print correlations and scatter for one (df, kmer) combination."""
    aggregated = utils.aggregate_data(
        df.group_by([category_col, context_col]).len(),
        kmer,
        on=[category_col, "context"],
        over=[category_col],
        context_col=context_col,
    ).with_columns(
        pl.col("context")
        .map_elements(
            lambda ctx: sum(1 for c in ctx if c in "GC"), return_dtype=pl.Int32
        )
        .alias("gc_count")
    )
    corr = compute_correlation(
            aggregated, category_col, categories, stat_method=stat_method
        )
    print("Correlation")
    utils.print_markdown(
        corr
    )
    if GENERATE_RESULTS and prefix is not None:
        output_path = PROJECT_DIR / "outputs" / "diff_pathogenicity" / f"{prefix}_{stat_method.__name__}.tsv"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        corr.write_csv(output_path, separator="\t")
    print("Corrected correlation")
    corrected_corr = compute_corrected_correlation(
            aggregated,
            category_col,
            categories,
            x="vip",
            y="F",
            z="gc_count",
            stat_method=stat_method,
        )
    utils.print_markdown(
        corrected_corr
    )
    if GENERATE_RESULTS and prefix is not None:
        output_path = PROJECT_DIR / "outputs" / "diff_pathogenicity" / f"{prefix}_{stat_method.__name__}_corrected.tsv"
        corrected_corr.write_csv(output_path, separator="\t")
    #plot_scatter(aggregated, category_col)
    print("Permutation test")
    permutation_test(
        aggregated, x="vip", y="F", categories_col=category_col, stat_method=stat_method
    )
    return aggregated

## Load data

In [3]:
cmc = (
    pl.read_csv(
        PROJECT_DIR / "data" / "cmc" / "cmc_mutation_context.csv",
        separator=",",
        schema_overrides={"chromosome": pl.Utf8},
    )
    .group_by(["chromosome", "position", "ref", "alt"])
    .first()
)
cmc_df = cmc.with_columns(
    pl.col("context").str.slice(4, 3).alias("3N"),
    pl.col("context").str.slice(3, 5).alias("5N"),
    pl.lit("cmc").alias("db"),
    pl.when(pl.col.type == pl.lit("Other"))
    .then(pl.lit("passenger"))
    .otherwise(pl.lit("driver"))
    .alias("driver"),
)
pathogenicity_map = {
    "Benign": "Benign",
    "Likely_benign": "Benign",
    "Likely_pathogenic": "Pathogenic",
    "Pathogenic": "Pathogenic",
}

clinvar = pl.read_csv(
    PROJECT_DIR / "data" / "clinvar" / "clinvar_mutation_context.csv", separator=","
)
clinvar_df = (
    clinvar.filter(pl.col("impact") != "Uncertain_significance")
    .with_columns(
        pl.col("context").str.slice(4, 3).alias("3N"),
        pl.col("context").str.slice(3, 5).alias("5N"),
        pl.col("impact").replace_strict(pathogenicity_map, default=pl.lit(None)),
        pl.lit("clinvar").alias("db"),
    )
    .drop_nulls()
)

kmer3 = load_data.get_kmer_df(PROJECT_DIR / "hg38", 3)["context", "frequencies"]
kmer5 = load_data.get_kmer_df(PROJECT_DIR / "hg38", 5)["context", "frequencies"]

## Triplicate

In [4]:
methods = [stats.pearsonr, stats.spearmanr]

for method in methods :
    print(f"\n\n ## {method.__name__} correlation ## \n\n" )
    agg_df = run_analysis(cmc_df, kmer3, "3N", "driver", ["driver", "passenger"], prefix="triplicate_cmc")



 ## pearsonr correlation ## 


Correlation
| driver    |     _ |
|-----------|-------|
| driver    | -0.69 |
| passenger | -0.8  |
Corrected correlation
| driver    |   partial_r |   p_value |
|-----------|-------------|-----------|
| driver    |       -0.52 |         0 |
| passenger |       -0.58 |         0 |
Permutation test
r passenger: -0.8004
r driver: -0.6851
Observed diff: -0.1153
p-value (pearsonr): 0.6508


 ## spearmanr correlation ## 


Correlation
| driver    |     _ |
|-----------|-------|
| driver    | -0.69 |
| passenger | -0.8  |
Corrected correlation
| driver    |   partial_r |   p_value |
|-----------|-------------|-----------|
| driver    |       -0.52 |         0 |
| passenger |       -0.58 |         0 |
Permutation test
r driver: -0.6851
r passenger: -0.8004
Observed diff: 0.1153
p-value (pearsonr): 0.6612


In [5]:
for method in methods :
    print(f"\n\n ## {method.__name__} correlation ## \n\n" )
    agg_df = run_analysis(clinvar_df, kmer3, "3N", "impact", ["Pathogenic", "Benign"], prefix="triplicate_clinvar")



 ## pearsonr correlation ## 


Correlation
| impact     |     _ |
|------------|-------|
| Pathogenic | -0.73 |
| Benign     | -0.83 |
Corrected correlation
| impact     |   partial_r |   p_value |
|------------|-------------|-----------|
| Pathogenic |       -0.49 |      0.01 |
| Benign     |       -0.66 |      0    |
Permutation test
r Benign: -0.8265
r Pathogenic: -0.7345
Observed diff: -0.0920
p-value (pearsonr): 0.7190


 ## spearmanr correlation ## 


Correlation
| impact     |     _ |
|------------|-------|
| Pathogenic | -0.73 |
| Benign     | -0.83 |
Corrected correlation
| impact     |   partial_r |   p_value |
|------------|-------------|-----------|
| Pathogenic |       -0.49 |      0.01 |
| Benign     |       -0.66 |      0    |
Permutation test
r Pathogenic: -0.7345
r Benign: -0.8265
Observed diff: 0.0920
p-value (pearsonr): 0.7181


## Quintuplets

In [6]:
for method in methods:
    print(f"\n\n ## {method.__name__} correlation ## \n\n")
    agg_df = run_analysis(cmc_df, kmer5, "5N", "driver", ["driver", "passenger"], prefix="quintuplet_cmc")



 ## pearsonr correlation ## 


Correlation
| driver    |     _ |
|-----------|-------|
| driver    | -0.54 |
| passenger | -0.76 |
Corrected correlation
| driver    |   partial_r |   p_value |
|-----------|-------------|-----------|
| driver    |       -0.42 |         0 |
| passenger |       -0.54 |         0 |
Permutation test
r passenger: -0.7557
r driver: -0.5369
Observed diff: -0.2188
p-value (pearsonr): 0.0007


 ## spearmanr correlation ## 


Correlation
| driver    |     _ |
|-----------|-------|
| driver    | -0.54 |
| passenger | -0.76 |
Corrected correlation
| driver    |   partial_r |   p_value |
|-----------|-------------|-----------|
| driver    |       -0.42 |         0 |
| passenger |       -0.54 |         0 |
Permutation test
r passenger: -0.7557
r driver: -0.5369
Observed diff: -0.2188
p-value (pearsonr): 0.0007


In [7]:
for method in methods:
    print(f"\n\n ## {method.__name__} correlation ## \n\n")
    agg_df = run_analysis(clinvar_df, kmer5, "5N", "impact", ["Pathogenic", "Benign"], prefix="quintuplet_clinvar")



 ## pearsonr correlation ## 


Correlation
| impact     |     _ |
|------------|-------|
| Pathogenic | -0.69 |
| Benign     | -0.77 |
Corrected correlation
| impact     |   partial_r |   p_value |
|------------|-------------|-----------|
| Pathogenic |       -0.5  |         0 |
| Benign     |       -0.58 |         0 |
Permutation test
r Benign: -0.7734
r Pathogenic: -0.6881
Observed diff: -0.0853
p-value (pearsonr): 0.1677


 ## spearmanr correlation ## 


Correlation
| impact     |     _ |
|------------|-------|
| Pathogenic | -0.69 |
| Benign     | -0.77 |
Corrected correlation
| impact     |   partial_r |   p_value |
|------------|-------------|-----------|
| Pathogenic |       -0.5  |         0 |
| Benign     |       -0.58 |         0 |
Permutation test
r Benign: -0.7734
r Pathogenic: -0.6881
Observed diff: -0.0853
p-value (pearsonr): 0.1677
