In [None]:
import os
import simba as si
from scipy import sparse
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mc
import seaborn as sns


si.__version__

In [None]:
# Figure defaults handed to SIMBA's settings helper.
figure_defaults = {
    'dpi': 80,
    'style': 'white',
    'fig_size': [5, 5],
    'rc': {'image.cmap': 'viridis'},
}
si.settings.set_figure_params(**figure_defaults)

# Ask the inline backend for 'retina' output so plots look sharper.
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('retina')

In [None]:
# Working directory registered with SIMBA.
# NOTE(review): absolute, machine-specific path — the notebook will not run elsewhere without editing it.
workdir = '/mnt/d/JorritvU/SIMBA/tests/0205-test/'
si.settings.set_workdir(workdir)

In [None]:
# Load the first scRNA-seq run (s143) from its .h5ad file.
adata_s143 = si.read_h5ad('/mnt/d/JorritvU/Tripolar/scRNA-seq/s143/old/SNV/s143.germline.updated.h5ad')

In [None]:
# Load the second scRNA-seq run (s145).
adata_s145 = si.read_h5ad('/mnt/d/JorritvU/Tripolar/scRNA-seq/s145/old/SNV/s145.germline.updated.h5ad')

In [None]:
# Load the two scDNA-seq runs (CHI-006 and CHI-007).
adata_chi006 = si.read_h5ad('/mnt/d/JorritvU/Tripolar/scDNA-seq/CHI-006/processed/SNV/CHI-006.germline_v2.h5ad')

adata_chi007 = si.read_h5ad('/mnt/d/JorritvU/Tripolar/scDNA-seq/CHI-007/processed/SNV/CHI-007.germline_v2.h5ad')

In [None]:
# Show the s143 matrix as a dense array.
# NOTE(review): `.A` presumably relies on X being a scipy sparse matrix — confirm; `.toarray()`
# (used further down in this notebook) is the more portable spelling.
adata_s143.X.A

In [None]:
# Same dense view for s145.
adata_s145.X.A

In [None]:
# DNA samples are already sparse matrices.
# `.A` converts the matrix to dense for display only; adata_chi006.X itself is not modified.
adata_chi006.X.A

In [None]:
# Displayed without densifying, unlike the three cells above.
adata_chi007.X

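Now we flag SNVs on Allele Frequency (AF): the `pass` column is set to `True` where AF > 0.3. <br/>
The threshold (0.3) is an arbitrary number for now; no SNVs are removed in this step.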

In [None]:
# Flag (do not drop) SNVs whose allele frequency exceeds 0.3; the result is stored in .var['pass'].
# NOTE(review): 'pass' is not read again in the visible part of this notebook — confirm it is still needed.
adata_s143.var['pass'] = adata_s143.var['AF'] > 0.3

In [None]:
# Same AF > 0.3 flag for the second RNA run.
adata_s145.var['pass'] = adata_s145.var['AF'] > 0.3

In [None]:
# Same AF > 0.3 flag for the first DNA run.
adata_chi006.var['pass'] = adata_chi006.var['AF'] > 0.3

In [None]:
# Same AF > 0.3 flag for the second DNA run.
adata_chi007.var['pass'] = adata_chi007.var['AF'] > 0.3

## Merge the RNA runs into 1, and merge the DNA runs into 1

In [None]:
import anndata as ad

def merge_datasets(adata1, adata2):
    """Concatenate two AnnData objects along observations, keeping only shared SNVs.

    Parameters
    ----------
    adata1, adata2
        AnnData objects whose ``var_names`` identify SNVs.

    Returns
    -------
    The result of ``ad.concat([adata1, adata2], merge='first', join='inner')``
    after both inputs were subset to their common variables.

    Notes
    -----
    The shared SNVs keep the order they have in ``adata1``. Building the list
    from a plain ``set`` would make the column order vary between kernel
    sessions, and plots that show "the first N SNVs" would then change from
    run to run.
    """
    in_second = set(adata2.var_names)
    # Ordered intersection: walk adata1's names and keep those also present in adata2.
    common_vars = [name for name in adata1.var_names if name in in_second]
    print(f"Number of intersecting SNVs: {len(common_vars)}")
    adata1 = adata1[:, common_vars]
    adata2 = adata2[:, common_vars]
    # merge='first': for .var annotations that differ between the inputs,
    # anndata keeps the value from the first object that provides one.
    adata = ad.concat([adata1, adata2], merge='first', join='inner')
    return adata

In [None]:
# Combine the two DNA runs into one object, restricted to SNVs present in both.
adata_dna = merge_datasets(adata_chi006, adata_chi007)  
print(adata_dna)
# Same for the two RNA runs.
adata_rna = merge_datasets(adata_s143, adata_s145)  
print(adata_rna)


# DNA and RNA

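Filter out observations whose name contains `nan` or `Control`, and keep only SNVs with AF > 0.05. Note that the `pass` flag computed above (AF > 0.3) is only stored in `.var` and is not applied as a filter, so the AF > 0.05 threshold here is the one that actually takes effect.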

Intersect the common vars (i.e. the common SNVs). 

This results in *5933* SNVs in each dataset.

In [None]:
"""
Filter the NaN samples and based on AF > 0.05.
If other sample types should be excluded, change code here.
"""

data = {'rna': adata_rna, 'dna': adata_dna}

for k in ['rna', 'dna']:
    data[f"{k}_filtered"] = data[k][~data[k].obs_names.str.contains('nan|Control', na=False), data[k].var['AF'] > 0.05].copy()
    
    print(f"{k}_filtered: {data[f'{k}_filtered'].shape}")


filtered_datasets = [d for d in data.keys() if 'filtered' in d]
common_vars = set(data[filtered_datasets[1]].var_names).intersection(set(data[filtered_datasets[0]].var_names))

filtered_datasets = [d for d in data.keys() if 'filtered' in d]
# Filtered datasets 1 = DNA
for key in filtered_datasets:
    data[key] = data[key][:, list(common_vars)]

data

## Add proportions to metadata

Here we add the proportions of each variant across the cells and across the SNVs.

Resulting in 3 lists in the variable metadata, and 3 lists in the observable metadata.

Proportions are calculated as the sum of the variant divided by the total count.

In [None]:
# Add per-observation (.obs) and per-SNV (.var) proportions of the matrix codes 1, 2 and 3.
# The column names label these codes as genotypes 0/0, 0/1 and 1/1.
# Entries with any other value (e.g. 0) are ignored in both the counts and the totals.
variant_columns = ['variant_1_proportion_0/0', 'variant_2_proportion_0/1', 'variant_3_proportion_1/1']
variant_codes = (1, 2, 3)

for key in filtered_datasets:
    # Columns are written below; materialise a view first so the write is explicit.
    if data[key].is_view:
        data[key] = data[key].copy()

    # Dense copy of the matrix (works for sparse and already-dense X).
    X = data[key].X
    X_dense = X.toarray() if sparse.issparse(X) else np.asarray(X)

    # One boolean mask per code; summing along an axis yields the counts.
    # Unlike np.bincount this also works when X holds floats.
    code_masks = [X_dense == code for code in variant_codes]
    variant_counts = np.stack([mask.sum(axis=1) for mask in code_masks], axis=1)  # rows = observations
    snv_counts = np.stack([mask.sum(axis=0) for mask in code_masks], axis=1)      # rows = SNVs

    # Rows/SNVs without any 1/2/3 entry give 0/0 = NaN; silence the warning, keep the NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        variant_proportions = variant_counts / variant_counts.sum(axis=1, keepdims=True)
        snv_proportions = snv_counts / snv_counts.sum(axis=1, keepdims=True)

    for col_idx, column in enumerate(variant_columns):
        data[key].obs[column] = variant_proportions[:, col_idx]
        data[key].var[column] = snv_proportions[:, col_idx]

## Functions for the plotting.

In [None]:
def heatmap(adata, phenotype=True, batches=True, name='heatmap', format="pdf", workdir=""):
    """Draw a clustered heatmap of ``adata.X`` (SNVs as rows, observations as columns) and save it.

    Parameters
    ----------
    adata
        AnnData object; ``.obs`` must contain 'Batch' and/or 'Phenotype'
        for the annotation tracks that are enabled.
    phenotype : bool
        Add a column-colour track (and legend entries) for ``obs['Phenotype']``.
    batches : bool
        Add a column-colour track (and legend entries) for ``obs['Batch']``.
    name : str
        File name without extension.
    format : str
        File extension and format passed to ``plt.savefig``.
    workdir : str
        Output directory. An empty string means the current working directory
        (previously this produced the path ``/<name>.<format>`` in the filesystem root).

    Returns
    -------
    None. The figure is saved and shown.
    """
    obs_data = adata.obs.reset_index(drop=True)

    def _color_track(column, palette_name):
        """Return (per-observation colours, {level: hex colour}) for one obs column."""
        levels = obs_data[column].unique()
        level_to_color = dict(zip(levels, sns.color_palette(palette_name, len(levels))))
        track = [level_to_color[value] for value in obs_data[column]]
        legend = {level: mc.to_hex(color) for level, color in level_to_color.items()}
        return track, legend

    # Only touch the obs columns that were requested, so a dataset without
    # 'Phenotype' can still be plotted with phenotype=False (same for 'Batch').
    columns = []
    legend_colors = {}
    if batches:
        track, legend = _color_track('Batch', "hls")
        columns.append(track)
        legend_colors.update(legend)
    if phenotype:
        track, legend = _color_track('Phenotype', "bright")
        columns.append(track)
        legend_colors.update(legend)

    # Dense matrix for seaborn; works whether X is sparse or already dense.
    X = adata.X
    matrix = X.toarray() if sparse.issparse(X) else np.asarray(X)

    # clustermap creates its own figure, so no separate plt.figure() is needed.
    heatmap_legend = {"label": "SNV variant", "ticks": [1,2,3]}
    g = sns.clustermap(matrix.T, col_colors=columns if columns else None, cmap="viridis",
                       yticklabels=False, row_cluster=True, cbar_kws=heatmap_legend)
    g.cax.set_yticklabels(['0/0', '0/1', '1/1'])

    # One legend patch per batch / phenotype level.
    legend_patches = [
        mpatches.Patch(color=color, label=level) for level, color in legend_colors.items()
    ]

    plt.legend(handles=legend_patches, title="Metadata", bbox_to_anchor=(2.00, -0.4), loc='upper right')
    # os.path.join keeps the path relative when workdir is empty.
    plt.savefig(os.path.join(workdir, f"{name}.{format}"), format=format, dpi=300)
    plt.show()


def stacked_barplot_variants(current_data, axis):
    """Draw a stacked bar chart of the mean variant proportions per batch.

    Parameters
    ----------
    current_data
        Object with an ``.obs`` DataFrame containing 'Batch' and the three
        ``variant_*_proportion_*`` columns.
    axis
        Matplotlib axes to draw on.

    Returns
    -------
    The same ``axis`` object.
    """
    variants = ['variant_1_proportion_0/0', 'variant_2_proportion_0/1', 'variant_3_proportion_1/1']

    # observed=True: for a categorical 'Batch', skip categories that have no rows.
    stacked_data = current_data.obs.groupby('Batch', observed=True)[variants].mean()
    # Take the bar labels from the grouped result itself. groupby orders the groups,
    # whereas unique() returns order of appearance, so labels taken from unique()
    # could end up on the wrong bars.
    batches = list(stacked_data.index)

    # Each variant is stacked on top of the running total of the previous ones.
    bottom = np.zeros(len(batches))
    for variant in variants:
        heights = stacked_data[variant].to_numpy()
        axis.bar(batches, heights, bottom=bottom.copy(), label=f'{variant.split("_")[-1]}')
        bottom += heights

    axis.set_title('Stacked Mean Proportion of Variants per Batch')
    axis.set_ylabel('Mean Proportion')
    axis.set_xlabel('Batch')
    axis.tick_params(axis='x', rotation=45)
    axis.legend(loc='upper right')

    return axis

def variant_linegraph(current_data, axis, max_n=100):
    """Plot the per-SNV proportion of each variant for the first ``max_n`` SNVs.

    Parameters
    ----------
    current_data
        Object with a ``.var`` DataFrame containing the three
        ``variant_*_proportion_*`` columns.
    axis
        Matplotlib axes to draw on.
    max_n : int
        Maximum number of SNVs (taken from the start of ``.var``) to draw.
        If fewer SNVs exist, all of them are drawn.

    Returns
    -------
    The same ``axis`` object.
    """
    variants = ['variant_1_proportion_0/0', 'variant_2_proportion_0/1', 'variant_3_proportion_1/1']

    for variant in variants:
        shown = current_data.var[variant].iloc[:max(max_n, 0)].to_numpy()
        # The x positions follow the data actually available, so x and y always
        # have equal length even when max_n exceeds the number of SNVs.
        snv_locations = range(len(shown))
        axis.plot(snv_locations, shown, label=f'{variant.split("_")[-1]}')

    axis.set_xlabel('SNV Location')
    axis.set_ylabel('Proportion')
    axis.set_title('Variant Proportions Across SNV Locations')
    axis.legend(loc='upper right')

    return axis

In [None]:
workdir = "/mnt/d/JorritvU/Tripolar/SNV_Profile"

## Stacked barplot + linegraph for variants

In [None]:
# One overview figure (stacked bars + line graph) and one clustered heatmap per filtered dataset.
for key in filtered_datasets:
    current_data = data[key]

    fig, (ax_bar, ax_line) = plt.subplots(nrows=1, ncols=2, figsize=(20, 6))
    fig.suptitle(f'Overview of SNVs for {key}')
    # Both helpers draw on the axes they are given and return that same object.
    stacked_barplot_variants(current_data, ax_bar)
    variant_linegraph(current_data, ax_line)

    # Reserve the top 5% of the canvas for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    heatmap(current_data, workdir=workdir, name=f"{key}_SNV_profile", format="png")
    

## Heatmap for the combined dataset

In [None]:
# Stack the filtered RNA and DNA observations into one AnnData over their shared SNVs.
# NOTE(review): merge_datasets concatenates with merge='first', so the per-SNV proportion columns
# in adata.var presumably come from the RNA object only — confirm before reading the merged line graph.
adata = merge_datasets(data['rna_filtered'], data['dna_filtered'])
print(adata)

In [None]:
# Overview of the merged RNA + DNA object: stacked bars on the left, line graph on the right.
fig, (ax_bar, ax_line) = plt.subplots(nrows=1, ncols=2, figsize=(20, 6))
fig.suptitle('Overview of SNVs for merged adata')

stacked_barplot_variants(adata, ax_bar)
variant_linegraph(adata, ax_line, 100)

# Reserve the top 5% of the canvas for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

# Batch track only; the phenotype track is switched off for the merged data.
heatmap(adata, phenotype=False, workdir=workdir, name="combined_SNV_profile", format="png")

# SNV filtering
We already added the proportions of each variant per SNV, as seen in the two cells below.

<br/>

Next, we compare the proportions per SNV between DNA and RNA datasets. 
For this we divide each DNA proportion by the respective RNA proportion (this is the order `check_snv` uses), like:

DNA_variant_1_proportion_0/0 / RNA_variant_1_proportion_0/0 <br/>
which is; 0.873508 / 0.995595 <br/>
r = 0.8773734344309039 <br/>
<br/>

If the ratio deviates too much from 1, decided by `window` and the corresponding proportions are above threshold `t`, the SNV is considered to be a _bad_ SNV. <br/>

In the example above, it passes.


In [None]:
# Per-SNV annotations of the filtered RNA data, including the proportion columns added above.
data['rna_filtered'].var

In [None]:
# Per-SNV annotations of the filtered DNA data.
data['dna_filtered'].var

In [None]:
def check_snv(adata, adata1, var_name, t=0.01, window=0.2, debug=False):
    """Decide whether one SNV has comparable variant proportions in two datasets.

    Parameters
    ----------
    adata, adata1
        Objects whose ``.var`` DataFrame holds the three
        ``variant_*_proportion_*`` columns, indexed by SNV name.
    var_name
        SNV (row label of ``.var``) to check.
    t : float
        A ratio outside the window only counts when both proportions exceed ``t``.
    window : float
        Allowed deviation of the ratio from 1 (accepted range ``[1-window, 1+window]``).
    debug : bool
        Print the proportions, ratios and the reason for a rejection.

    Returns
    -------
    bool
        ``False`` if either dataset shows (almost) a single variant
        (max proportion >= 0.999), or if any ratio ``adata1 / adata`` falls
        outside the window while both proportions are above ``t``;
        ``True`` otherwise.
    """
    variants = ['variant_1_proportion_0/0', 'variant_2_proportion_0/1', 'variant_3_proportion_1/1']
    p1 = [np.float64(adata.var.loc[var_name, k]) for k in variants]
    p2 = [np.float64(adata1.var.loc[var_name, k]) for k in variants]

    # A zero proportion in `adata` yields inf (or nan for 0/0). Those values are
    # handled by the checks below, so only the RuntimeWarning is suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratios = [p2[i] / p1[i] for i in range(len(variants))]

    if debug:
        print("\nChecking", var_name)
        print(f"p1: {p1}")
        print(f"p2: {p2}")
        print(f"ratio: {ratios}")

    if max(p1) >= 0.999 or max(p2) >= 0.999:
        if debug:
            print(f"BAD: {var_name} solely 1 variant")
        return False

    for i, r in enumerate(ratios):
        # nan compares False on both sides and therefore never rejects.
        if r < 1-window or r > 1+window:
            if p1[i] > t and p2[i] > t:
                if debug:
                    print("BAD RATIO")
                return False

    return True

In [None]:
# Run check_snv on every SNV and split the names into kept and rejected lists.
snvs = list(data['rna_filtered'].var_names)

good_snvs = []
bad_snvs = []

for s in snvs:
    keep = check_snv(data['rna_filtered'], data['dna_filtered'], s, window=0.3, t=0.1, debug=False)
    (good_snvs if keep else bad_snvs).append(s)

len_before = len(snvs)
len_after = len(good_snvs)
# Round after the subtraction (avoids artefacts such as 12.340000000000003)
# and guard against an empty SNV list.
pct_removed = round(100 * (len_before - len_after) / len_before, 2) if len_before else 0.0

print(f"No. SNVs before: {len_before}")
print(f"No. SNVs after: {len_after}")
print(f"Percentage thrown out: {pct_removed}%")



### Proportion check

The results: <br/>
p1: `[0.9955947136563876, 0.004405286343612335, 0.0]`<br/>
p2: `[0.8735083532219571, 0.12171837708830549, 0.00477326968973747]`<br/>
ratio: `[0.8773734344309039, 27.630071599045344, inf]`<br/>

Even though the second ratio is about 27.6, which is far outside the range allowed by `window=0.2`, the corresponding proportion in p1 (0.0044) is below the threshold `t=0.01` used by the call below, so the SNV is not rejected.


In [None]:
# Worked example for a single SNV with the default t and window; debug=True prints the
# proportions and ratios.
check_snv(data['rna_filtered'], data['dna_filtered'], "chr22:46475197_G/A", debug=True)

In [None]:
# Restrict the merged object to the SNVs currently held in good_snvs.
adata_filtered = adata[:, list(good_snvs)]

In [None]:
# Overview of the merged data restricted to the kept SNVs (first 250 in the line graph).
fig, (ax_bar, ax_line) = plt.subplots(nrows=1, ncols=2, figsize=(20, 6))
fig.suptitle('Overview of SNVs for merged adata')

stacked_barplot_variants(adata_filtered, ax_bar)
variant_linegraph(adata_filtered, ax_line, 250)

# Reserve the top 5% of the canvas for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

heatmap(adata_filtered, workdir=workdir, name="Filtered_combined_SNV_profile", format="png")

## The SNVs that are filtered out

In [None]:
# Merged data restricted to the SNVs that check_snv rejected.
# A separate name is used so `adata_filtered` keeps referring to the kept SNVs.
adata_rejected = adata[:, list(bad_snvs)]

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
# Title states that these are the rejected SNVs, to tell this figure apart from the kept set.
fig.suptitle('Overview of rejected SNVs for merged adata')

axes[0] = stacked_barplot_variants(adata_rejected, axes[0])
axes[1] = variant_linegraph(adata_rejected, axes[1], 250)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

heatmap(adata_rejected, workdir=workdir, name="rejected_SNV_profile", format="png")

## Strict filtering

In [None]:
def strict_filter(adata, adata1, var_name, t=0.01, window=0.03, debug=False):
    """Accept an SNV only if all three proportions agree closely between two datasets.

    For every variant column the value from ``adata`` must lie strictly between
    ``adata1``'s value minus ``window`` and plus ``window``. SNVs where either
    dataset has a maximum proportion >= 0.999 are rejected outright.
    ``t`` is accepted for signature compatibility but is not used.

    Returns ``True`` when the SNV passes, ``False`` otherwise.
    """
    if debug:
        print("\nChecking", var_name)

    variants = ['variant_1_proportion_0/0', 'variant_2_proportion_0/1', 'variant_3_proportion_1/1']
    p1 = [adata.var[k][var_name] for k in variants]
    p2 = [adata1.var[k][var_name] for k in variants]

    # Reject SNVs that show (almost) only one variant in either dataset.
    if max(p1) >= 0.999 or max(p2) >= 0.999:
        if debug:
            print(f"BAD: {var_name} solely 1 variant")
        return False

    for rna_p, dna_p in zip(p1, p2):
        within_window = dna_p + window > rna_p and dna_p - window < rna_p
        if not within_window:
            return False
        if debug:
            print(f"{dna_p} is very close to being equal to {rna_p}")

    if debug:
        print("All three are good")
    return True

In [None]:
# Run strict_filter on every SNV of the merged object; this overwrites good_snvs / bad_snvs.
snvs = list(adata.var_names)

good_snvs = []
bad_snvs = []

for s in snvs:
    keep = strict_filter(data['rna_filtered'], data['dna_filtered'], s)
    (good_snvs if keep else bad_snvs).append(s)

len_before = len(snvs)
len_after = len(good_snvs)
# Round after the subtraction (avoids artefacts such as 12.340000000000003)
# and guard against an empty SNV list.
pct_removed = round(100 * (len_before - len_after) / len_before, 2) if len_before else 0.0

print(f"No. SNVs before: {len_before}")
print(f"No. SNVs after: {len_after}")
print(f"Percentage thrown out: {pct_removed}%")


In [None]:
# Merged data restricted to the SNVs kept by strict_filter.
adata_filtered = adata[:, list(good_snvs)]

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
fig.suptitle(f'Overview of SNVs for merged adata')

axes[0] = stacked_barplot_variants(adata_filtered, axes[0])
axes[1] = variant_linegraph(adata_filtered, axes[1], 250)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

# Pass workdir and a distinct name: heatmap's defaults would target '/heatmap.pdf'
# and share that file name with the rejected-SNV plot in the next cell.
heatmap(adata_filtered, workdir=workdir, name="strict_filtered_SNV_profile", format="png")

In [None]:
# Merged data restricted to the SNVs rejected by strict_filter.
adata_filtered = adata[:, list(bad_snvs)]

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
# Title states that these are the rejected SNVs, to tell this figure apart from the kept set.
fig.suptitle('Overview of rejected SNVs for merged adata')

axes[0] = stacked_barplot_variants(adata_filtered, axes[0])
axes[1] = variant_linegraph(adata_filtered, axes[1], 250)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

# Pass workdir and a distinct name: heatmap's defaults would target '/heatmap.pdf'
# and share that file name with the kept-SNV plot in the previous cell.
heatmap(adata_filtered, workdir=workdir, name="strict_rejected_SNV_profile", format="png")

In [None]:
# Per-SNV annotations of the RNA data for the SNVs currently in good_snvs.
data['rna_filtered'][:, good_snvs].var

In [None]:
# Same table for the DNA data, for side-by-side comparison.
data['dna_filtered'][:, good_snvs].var

In [None]:
# DNA-only view of the kept SNVs; the subset is built once and reused for both panels.
dna_good = data['dna_filtered'][:, good_snvs]

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
# The data shown is the DNA subset, not the merged object, so the title says so.
fig.suptitle('Overview of kept SNVs for dna_filtered')

axes[0] = stacked_barplot_variants(dna_good, axes[0])
axes[1] = variant_linegraph(dna_good, axes[1], 250)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

In [None]:
# RNA-only view of the kept SNVs; the subset is built once and reused for both panels.
rna_good = data['rna_filtered'][:, good_snvs]

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
# The data shown is the RNA subset, not the merged object, so the title says so.
fig.suptitle('Overview of kept SNVs for rna_filtered')

axes[0] = stacked_barplot_variants(rna_good, axes[0])
axes[1] = variant_linegraph(rna_good, axes[1], 250)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

In [None]:
def variant_linegraph_residuals(adata, adata1, max_n=100):
    """Plot, per variant, the difference ``adata - adata1`` of the per-SNV proportions.

    Parameters
    ----------
    adata, adata1
        Objects whose ``.var`` DataFrame holds the three
        ``variant_*_proportion_*`` columns.
    max_n : int
        Maximum number of SNVs (taken from the start of each ``.var``) to draw.
        If fewer SNVs exist, all of them are drawn.

    Returns
    -------
    None. A figure with one panel per variant is shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(25, 6))
    fig.suptitle(f'Difference between RNA - DNA proportions')
    variants = ['variant_1_proportion_0/0', 'variant_2_proportion_0/1', 'variant_3_proportion_1/1']

    n_requested = max(max_n, 0)
    for i, variant in enumerate(variants):
        # Series subtraction aligns on the SNV labels of the two slices.
        residual = adata.var[variant].iloc[:n_requested] - adata1.var[variant].iloc[:n_requested]
        # x positions follow the data actually available, so x and y always have
        # equal length even when max_n exceeds the number of SNVs.
        snv_locations = range(len(residual))
        axes[i].plot(snv_locations, residual, label=f'{variant.split("_")[-1]}')
        axes[i].set_xlabel('SNV Location')
        axes[i].set_ylabel('Proportion')
        axes[i].set_title(f'{variant}')

    plt.show()

In [None]:
# Proportion differences (RNA minus DNA) for the SNVs currently in good_snvs; at most 600 are requested.
variant_linegraph_residuals(data['rna_filtered'][:, good_snvs], data['dna_filtered'][:, good_snvs], max_n = 600)

In [None]:
# NOTE(review): presumably reports the versions of the loaded packages and their dependencies — confirm.
import session_info
session_info.show(dependencies=True)

In [None]:
# Tally how many observations of run s143 carry each 'Phenotype' label,
# keeping the labels in order of first appearance.
ps = list(adata_s143.obs['Phenotype'])

phenotype_count = {}
for p in ps:
    phenotype_count[p] = phenotype_count.get(p, 0) + 1

print(phenotype_count)
